Simplified quite a bit. Added a unify function. Some bugs still remain.

This commit is contained in:
sebastianselander 2023-02-17 11:09:48 +01:00
parent eafe0fea0b
commit a9f54dbca1
4 changed files with 192 additions and 209 deletions

View file

@ -3,12 +3,12 @@ function-arrows: trailing
comma-style: leading comma-style: leading
import-export-style: diff-friendly import-export-style: diff-friendly
indent-wheres: false indent-wheres: false
record-brace-space: false record-brace-space: true
newlines-between-decls: 1 newlines-between-decls: 1
haddock-style: multi-line haddock-style: multi-line
haddock-style-module: haddock-style-module:
let-style: auto let-style: auto
in-style: right-align in-style: right-align
respectful: true respectful: false
fixities: [] fixities: []
unicode: never unicode: never

View file

@ -17,7 +17,7 @@ extra-source-files:
common warnings common warnings
ghc-options: -Wdefault ghc-options: -W
executable language executable language
import: warnings import: warnings

View file

@ -1,26 +1,30 @@
{-# LANGUAGE LambdaCase, OverloadedStrings, OverloadedRecordDot #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use camelCase" #-}
module TypeChecker.TypeChecker where module TypeChecker.TypeChecker where
import Control.Monad (when, void) import Control.Monad (void)
import Control.Monad.Except (ExceptT, throwError, runExceptT) import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.State (StateT) import Control.Monad.State (StateT)
import qualified Control.Monad.State as St import Control.Monad.State qualified as St
import Data.Functor.Identity (Identity, runIdentity) import Data.Functor.Identity (Identity, runIdentity)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as M import Data.Map qualified as M
import Grammar.ErrM (Err)
import Grammar.Print
import Data.List (findIndex)
import Debug.Trace (trace)
import TypeChecker.TypeCheckerIr import TypeChecker.TypeCheckerIr
data Ctx = Ctx { vars :: Map Integer Type data Ctx = Ctx
{ vars :: Map Integer Type
, sigs :: Map Ident Type , sigs :: Map Ident Type
, nextFresh :: Ident , nextFresh :: Int
} }
deriving Show deriving (Show)
-- Perhaps switch to a Reader monad for vars and sigs instead.
type Infer = StateT Ctx (ExceptT Error Identity)
{- {-
@ -28,15 +32,17 @@ The type checker will assume we first rename all variables to unique name, as to
have to care about scoping. It significantly improves the quality of life of the have to care about scoping. It significantly improves the quality of life of the
programmer. programmer.
TODOs:
Add skolemization (rigid type variables), i.e.
{ \x. 3 : forall a. a -> a }
should not type check, since the body has type Int rather than a.
Generalize inferred types. Not yet sure exactly what that involves.
-} -}
type Infer = StateT Ctx (ExceptT Error Identity)
initEnv :: Ctx
initEnv = Ctx mempty mempty "a"
run :: Infer a -> Either Error a run :: Infer a -> Either Error a
run = runIdentity . runExceptT . flip St.evalStateT initEnv run = runIdentity . runExceptT . flip St.evalStateT (Ctx mempty mempty 0)
typecheck :: RProgram -> Either Error TProgram typecheck :: RProgram -> Either Error TProgram
typecheck = run . inferPrg typecheck = run . inferPrg
@ -54,92 +60,60 @@ inferBind (RBind name e) = do
inferExp :: RExp -> Infer (Type, TExp) inferExp :: RExp -> Infer (Type, TExp)
inferExp = \case inferExp = \case
RAnn expr typ -> do RAnn expr typ -> do
(t,expr') <- inferExp expr (t, expr') <- inferExp expr
when (not (t == typ || isPoly t)) (throwError $ AnnotatedMismatch "inferExp, RAnn") void $ t =:= typ
return (typ,expr') return (typ, expr')
RBound num name -> do
-- Name is only here for proper error messages t <- lookupVars num
RBound num name -> return (t, TBound num name t)
M.lookup num <$> St.gets vars >>= \case
Nothing -> throwError $ UnboundVar "RBound"
Just t -> return (t, TBound num name t)
RFree name -> do RFree name -> do
M.lookup name <$> St.gets sigs >>= \case t <- lookupSigs name
Nothing -> throwError $ UnboundVar "RFree" return (t, TFree name t)
Just t -> return (t, TFree name t) RConst (CInt i) -> return (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return (TMono "Str", TConst (CStr str) (TMono "Str"))
RConst (CInt i) -> return $ (TMono "Int", TConst (CInt i) (TMono "Int"))
RConst (CStr str) -> return $ (TMono "Str", TConst (CStr str) (TMono "Str"))
-- Should do proper unification using union-find. Some nice libs exist
RApp expr1 expr2 -> do
(typ1, expr1') <- inferExp expr1
(typ2, expr2') <- inferExp expr2
fvar <- fresh
case typ1 of
(TPoly (Ident x)) -> do
let newType = (TArrow (TPoly (Ident x)) (TPoly fvar))
specifyType expr1 newType
typ1' <- apply newType typ1
return $ (typ1', TApp expr1' expr2' typ1')
_ -> (\t -> (t, TApp expr1' expr2' t)) <$> apply typ1 typ2
RAdd expr1 expr2 -> do RAdd expr1 expr2 -> do
(typ1, expr1') <- inferExp expr1 (typ1, expr1') <- check expr1 (TMono "Int")
(typ2, expr2') <- inferExp expr2 (_, expr2') <- check expr2 (TMono "Int")
when (not $ (isInt typ1 || isPoly typ1) && (isInt typ2 || isPoly typ2)) (throwError $ TypeMismatch "inferExp, RAdd") return (typ1, TAdd expr1' expr2' typ1)
specifyType expr1 (TMono "Int") RApp expr1 expr2 -> do
specifyType expr2 (TMono "Int") (fn_t, expr1') <- inferExp expr1
return (TMono "Int", TAdd expr1' expr2' (TMono "Int")) (arg_t, expr2') <- inferExp expr2
res <- fresh
-- TODO: Double-check whether this is the correct behavior.
-- new_t is the whole unified function type; the type of the application
-- should probably be res (the result type) rather than new_t.
new_t <- fn_t =:= TArrow arg_t res
return (new_t, TApp expr1' expr2' new_t)
RAbs num name expr -> do RAbs num name expr -> do
insertVars num (TPoly "a") arg <- fresh
insertVars num arg
(typ, expr') <- inferExp expr (typ, expr') <- inferExp expr
newTyp <- lookupVars num return (TArrow arg typ, TAbs num name expr' typ)
return $ (TArrow newTyp typ, TAbs num name expr' typ)
-- Aux check :: RExp -> Type -> Infer (Type, TExp)
isInt :: Type -> Bool check e t = do
isInt (TMono "Int") = True (t', e') <- inferExp e
isInt _ = False t'' <- t' =:= t
return (t'', e')
isArrow :: Type -> Bool fresh :: Infer Type
isArrow (TArrow _ _) = True
isArrow _ = False
isPoly :: Type -> Bool
isPoly (TPoly _) = True
isPoly _ = False
fresh :: Infer Ident
fresh = do fresh = do
(Ident var) <- St.gets nextFresh var <- St.gets nextFresh
when (length var == 0) (throwError $ Default "fresh") St.modify (\st -> st {nextFresh = succ var})
index <- case findIndex (== (head var)) alphabet of return (TPoly $ Ident (show var))
Nothing -> throwError $ Default "fresh"
Just i -> return i
let nextIndex = (index + 1) `mod` 26
let newVar = Ident $ [alphabet !! nextIndex]
St.modify (\st -> st { nextFresh = newVar })
return newVar
where
alphabet = "abcdefghijklmnopqrstuvwxyz" :: [Char]
unify :: Type -> Type -> Infer Type -- | Unify two types.
unify = todo (=:=) :: Type -> Type -> Infer Type
(=:=) (TPoly _) b = return b
-- | Specify the type of a bound variable (=:=) a (TPoly _) = return a
-- Because in lambdas we have to assume a general type and update it (=:=) (TMono a) (TMono b) | a == b = return (TMono a)
specifyType :: RExp -> Type -> Infer () (=:=) (TArrow a b) (TArrow c d) = do
specifyType (RBound num name) typ = do t1 <- a =:= c
insertVars num typ t2 <- b =:= d
return () return $ TArrow t1 t2
specifyType _ _ = return () (=:=) a b = throwError (TypeMismatch $ unwords ["Can not unify type", show a, "with", show b])
-- Unused currently
lookupVars :: Integer -> Infer Type lookupVars :: Integer -> Infer Type
lookupVars i = do lookupVars i = do
st <- St.gets vars st <- St.gets vars
@ -150,7 +124,7 @@ lookupVars i = do
insertVars :: Integer -> Type -> Infer () insertVars :: Integer -> Type -> Infer ()
insertVars i t = do insertVars i t = do
st <- St.get st <- St.get
St.put ( st { vars = M.insert i t st.vars } ) St.put (st {vars = M.insert i t st.vars})
lookupSigs :: Ident -> Infer Type lookupSigs :: Ident -> Infer Type
lookupSigs i = do lookupSigs i = do
@ -162,14 +136,7 @@ lookupSigs i = do
insertSigs :: Ident -> Type -> Infer () insertSigs :: Ident -> Type -> Infer ()
insertSigs i t = do insertSigs i t = do
st <- St.get st <- St.get
St.put ( st { sigs = M.insert i t st.sigs } ) St.put (st {sigs = M.insert i t st.sigs})
-- Have to figure out the equivalence classes for types.
-- Currently this does not support more than exact matches.
apply :: Type -> Type -> Infer Type
apply (TArrow t1 t2) t3
| t1 == t3 = return t2
apply t1 t2 = throwError $ TypeMismatch "apply"
{-# WARNING todo "TODO IN CODE" #-} {-# WARNING todo "TODO IN CODE" #-}
todo :: a todo :: a
@ -183,21 +150,30 @@ data Error
| UnboundVar String | UnboundVar String
| AnnotatedMismatch String | AnnotatedMismatch String
| Default String | Default String
deriving Show deriving (Show)
-- Tests -- Tests
lambda = RAbs 0 "x" (RAdd (RBound 0 "x") (RBound 0 "x")) -- (\x. x + 1) 1
lambda2 = RAbs 0 "x" (RAnn (RBound 0 "x") (TArrow (TMono "Int") (TMono "String"))) app_lambda :: RExp
app_lambda = app lambda one
fn_on_var = RAbs 0 (Ident "f") (RAbs 1 (Ident "x") (RApp (RBound 0 (Ident "f")) (RBound 1 (Ident "x")))) lambda :: RExp
lambda = RAbs 0 "x" $ add bound one
add :: RExp -> RExp -> RExp
add = RAdd
--add x = \y. x+y; bound = RBound 0 "x"
add = RAbs 0 "x" (RAbs 1 "y" (RAdd (RBound 0 "x") (RBound 1 "y")))
-- main = (\z. z+z) ((add 4) 6); app :: RExp -> RExp -> RExp
main = RApp (RAbs 0 "z" (RAdd (RBound 0 "z") (RBound 0 "z"))) applyAdd app = RApp
four = RConst (CInt 4)
six = RConst (CInt 6) one :: RExp
applyAdd = (RApp (RApp add four) six) one = RConst (CInt 1)
partialAdd = RApp add four
fn_t = TArrow (TPoly (Ident "0")) (TMono (Ident "Int"))
arr_t = TArrow (TMono "Int") (TPoly "1")
f_x = RAbs 0 "f" (RAbs 1 "x" (RApp (RBound 0 "f") (RBound 1 "x")))

View file

@ -1,28 +1,31 @@
{-# LANGUAGE DeriveAnyClass, PatternSynonyms, GADTs, LambdaCase, OverloadedStrings #-} {-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
module TypeChecker.Unification where module TypeChecker.Unification where
import Renamer.Renamer import Control.Arrow ((>>>))
import Renamer.RenamerIr (Const(..), RExp(..), RBind(..), RProgram(..), Ident(..)) import Control.Monad.Except
import qualified Renamer.RenamerIr as R
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Control.Monad.Except import Control.Unification hiding (applyBindings, (=:=))
import Data.Functor.Identity import Control.Unification qualified as U
import Control.Arrow ((>>>))
import Control.Unification hiding ((=:=), applyBindings)
import qualified Control.Unification as U
import Control.Unification.IntVar import Control.Unification.IntVar
import Data.Functor.Fixedpoint
import GHC.Generics (Generic1)
import Data.Foldable (fold) import Data.Foldable (fold)
import Data.Functor.Fixedpoint
import Data.Functor.Identity
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as M import Data.Map qualified as M
import Data.Maybe (fromMaybe, fromJust) import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set, (\\)) import Data.Set (Set, (\\))
import qualified Data.Set as S import Data.Set qualified as S
import Debug.Trace (trace) import Debug.Trace (trace)
import GHC.Generics (Generic1)
import Renamer.Renamer
import Renamer.RenamerIr (Const (..), Ident (..), RBind (..), RExp (..), RProgram (..))
import Renamer.RenamerIr qualified as R
type Ctx = Map Ident UPolytype type Ctx = Map Ident UPolytype
@ -39,6 +42,7 @@ instance Show a => Show (TypeT a) where
type Infer = StateT (Map Ident UPolytype) (ReaderT Ctx (ExceptT TypeError (IntBindingT TypeT Identity))) type Infer = StateT (Map Ident UPolytype) (ReaderT Ctx (ExceptT TypeError (IntBindingT TypeT Identity)))
type Type = Fix TypeT type Type = Fix TypeT
type UType = UTerm TypeT IntVar type UType = UTerm TypeT IntVar
data Poly t = Forall [Ident] t data Poly t = Forall [Ident] t
@ -67,13 +71,13 @@ pattern UTPoly :: Ident -> UType
pattern UTPoly v = UTerm (TPolyT v) pattern UTPoly v = UTerm (TPolyT v)
data TType = TTPoly Ident | TTMono Ident | TTArrow TType TType data TType = TTPoly Ident | TTMono Ident | TTArrow TType TType
deriving Show deriving (Show)
data Program = Program [Bind] newtype Program = Program [Bind]
deriving Show deriving (Show)
data Bind = Bind Ident Exp Polytype data Bind = Bind Ident Exp Polytype
deriving Show deriving (Show)
data Exp data Exp
= EAnn Exp Polytype = EAnn Exp Polytype
@ -83,7 +87,7 @@ data Exp
| EApp Exp Exp Polytype | EApp Exp Exp Polytype
| EAdd Exp Exp Polytype | EAdd Exp Exp Polytype
| EAbs Ident Exp Polytype | EAbs Ident Exp Polytype
deriving Show deriving (Show)
data TExp data TExp
= TAnn TExp UType = TAnn TExp UType
@ -93,7 +97,7 @@ data TExp
| TApp TExp TExp UType | TApp TExp TExp UType
| TAdd TExp TExp UType | TAdd TExp TExp UType
| TAbs Ident TExp UType | TAbs Ident TExp UType
deriving Show deriving (Show)
---------------------------------------------------------- ----------------------------------------------------------
typecheck :: RProgram -> Either TypeError Program typecheck :: RProgram -> Either TypeError Program
@ -106,7 +110,7 @@ inferProgram (RProgram binds) = do
inferBind :: RBind -> Infer Bind inferBind :: RBind -> Infer Bind
inferBind (RBind i e) = do inferBind (RBind i e) = do
(t,e') <- infer e (t, e') <- infer e
e'' <- convert fromUType e' e'' <- convert fromUType e'
t' <- fromUType t t' <- fromUType t
insertSigs i (Forall [] t) insertSigs i (Forall [] t)
@ -119,8 +123,7 @@ convert :: (UType -> Infer Polytype) -> TExp -> Infer Exp
convert f = \case convert f = \case
(TAnn e t) -> do (TAnn e t) -> do
e' <- convert f e e' <- convert f e
t' <- (f t) EAnn e' <$> f t
return $ EAnn e' t'
(TFree i t) -> do (TFree i t) -> do
t' <- f t t' <- f t
return $ EFree i t' return $ EFree i t'
@ -146,7 +149,8 @@ convert f = \case
return $ EAbs i e' t' return $ EAbs i e' t'
run :: Infer a -> Either TypeError a run :: Infer a -> Either TypeError a
run = flip evalStateT mempty run =
flip evalStateT mempty
>>> flip runReaderT mempty >>> flip runReaderT mempty
>>> runExceptT >>> runExceptT
>>> evalIntBindingT >>> evalIntBindingT
@ -154,23 +158,23 @@ run = flip evalStateT mempty
infer :: RExp -> Infer (UType, TExp) infer :: RExp -> Infer (UType, TExp)
infer = \case infer = \case
(RConst (CInt i)) -> return $ (UTMono "Int", TConst (CInt i) (UTMono "Int")) (RConst (CInt i)) -> return (UTMono "Int", TConst (CInt i) (UTMono "Int"))
(RConst (CStr str)) -> return $ (UTMono "String", TConst (CStr str) (UTMono "String")) (RConst (CStr str)) -> return (UTMono "String", TConst (CStr str) (UTMono "String"))
(RAdd e1 e2) -> do (RAdd e1 e2) -> do
(t1,e1') <- infer e2 (t1, e1') <- infer e2
(t2,e2') <- infer e1 (t2, e2') <- infer e1
t1 =:= (UTMono "Int") t1 =:= UTMono "Int"
t2 =:= (UTMono "Int") t2 =:= UTMono "Int"
return $ (UTMono "Int", TAdd e1' e2' (UTMono "Int")) return (UTMono "Int", TAdd e1' e2' (UTMono "Int"))
(RAnn e t) -> do (RAnn e t) -> do
(t',e') <- infer e (t', e') <- infer e
check e t' check e t'
return (t', TAnn e' t') return (t', TAnn e' t')
(RApp e1 e2) -> do (RApp e1 e2) -> do
(f,e1') <- infer e1 (f, e1') <- infer e1
(arg,e2') <- infer e2 (arg, e2') <- infer e2
res <- fresh res <- fresh
f =:= UTArrow f arg f =:= UTArrow arg res
return (res, TApp e1' e2' res) return (res, TApp e1' e2' res)
(RAbs _ i e) -> do (RAbs _ i e) -> do
arg <- fresh arg <- fresh
@ -223,8 +227,10 @@ class FreeVars a where
instance FreeVars UType where instance FreeVars UType where
freeVars ut = do freeVars ut = do
fuvs <- fmap (S.fromList . map Right) . lift . lift . lift $ getFreeVars ut fuvs <- fmap (S.fromList . map Right) . lift . lift . lift $ getFreeVars ut
let ftvs = ucata (const S.empty) let ftvs =
(\case {TMonoT x -> S.singleton (Left x); f -> fold f}) ucata
(const S.empty)
(\case TMonoT x -> S.singleton (Left x); f -> fold f)
ut ut
return $ fuvs `S.union` ftvs return $ fuvs `S.union` ftvs
@ -253,9 +259,10 @@ instantiate (Forall xs uty) = do
return $ substU (M.fromList (zip (map Left xs) xs')) uty return $ substU (M.fromList (zip (map Left xs) xs')) uty
substU :: Map (Either Ident IntVar) UType -> UType -> UType substU :: Map (Either Ident IntVar) UType -> UType -> UType
substU m = ucata substU m =
ucata
(\v -> fromMaybe (UVar v) (M.lookup (Right v) m)) (\v -> fromMaybe (UVar v) (M.lookup (Right v) m))
(\case ( \case
TPolyT v -> fromMaybe (UTPoly v) (M.lookup (Left v) m) TPolyT v -> fromMaybe (UTPoly v) (M.lookup (Left v) m)
f -> UTerm f f -> UTerm f
) )