Type inference/checking on ADTs mostly complete(?). Still have to test
This commit is contained in:
parent
2f45f39435
commit
bbf6e159c7
8 changed files with 563 additions and 467 deletions
|
|
@ -3,6 +3,7 @@
|
|||
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||
{-# HLINT ignore "Use traverse_" #-}
|
||||
{-# OPTIONS_GHC -Wno-overlapping-patterns #-}
|
||||
{-# HLINT ignore "Use zipWithM" #-}
|
||||
|
||||
module TypeChecker.TypeChecker where
|
||||
|
||||
|
|
@ -16,6 +17,7 @@ import qualified Data.Map as M
|
|||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
|
||||
import Data.Foldable (traverse_)
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
|
|
@ -24,10 +26,12 @@ import qualified TypeChecker.TypeCheckerIr as T
|
|||
data Poly = Forall [Ident] Type
|
||||
deriving Show
|
||||
|
||||
newtype Ctx = Ctx { vars :: Map Ident Poly }
|
||||
newtype Ctx = Ctx { vars :: Map Ident Poly
|
||||
}
|
||||
|
||||
data Env = Env { count :: Int
|
||||
, sigs :: Map Ident Type
|
||||
data Env = Env { count :: Int
|
||||
, sigs :: Map Ident Type
|
||||
, dtypes :: Map Ident Type
|
||||
}
|
||||
|
||||
type Error = String
|
||||
|
|
@ -36,7 +40,7 @@ type Subst = Map Ident Type
|
|||
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
|
||||
|
||||
initCtx = Ctx mempty
|
||||
initEnv = Env 0 mempty
|
||||
initEnv = Env 0 mempty mempty
|
||||
|
||||
runPretty :: Exp -> Either Error String
|
||||
runPretty = fmap (printTree . fst). run . inferExp
|
||||
|
|
@ -50,21 +54,44 @@ runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
|
|||
typecheck :: Program -> Either Error T.Program
|
||||
typecheck = run . checkPrg
|
||||
|
||||
checkData :: Data -> Infer ()
|
||||
checkData d = case d of
|
||||
(Data typ@(TConstr name _) constrs) -> do
|
||||
traverse_ (\(Constructor name' t')
|
||||
-> if typ == retType t'
|
||||
then insertConstr name' t' else
|
||||
throwError $
|
||||
unwords
|
||||
[ "return type of constructor:"
|
||||
, printTree name
|
||||
, "with type:"
|
||||
, printTree (retType t')
|
||||
, "does not match data: "
|
||||
, printTree typ]) constrs
|
||||
_ -> throwError "Data type incorrectly declared"
|
||||
where
|
||||
retType :: Type -> Type
|
||||
retType (TArr _ t2) = retType t2
|
||||
retType a = a
|
||||
|
||||
checkPrg :: Program -> Infer T.Program
|
||||
checkPrg (Program bs) = do
|
||||
let bs' = getBinds bs
|
||||
traverse (\(Bind n t _ _ _) -> insertSig n t) bs'
|
||||
bs' <- mapM checkBind bs'
|
||||
return $ T.Program bs'
|
||||
preRun bs
|
||||
T.Program <$> checkDef bs
|
||||
where
|
||||
getBinds :: [Def] -> [Bind]
|
||||
getBinds = map toBind . filter isBind
|
||||
isBind :: Def -> Bool
|
||||
isBind (DBind _) = True
|
||||
isBind _ = True
|
||||
toBind :: Def -> Bind
|
||||
toBind (DBind bind) = bind
|
||||
toBind _ = error "Can't convert DData to Bind"
|
||||
preRun :: [Def] -> Infer ()
|
||||
preRun [] = return ()
|
||||
preRun (x:xs) = case x of
|
||||
DBind (Bind n t _ _ _ ) -> insertSig n t >> preRun xs
|
||||
DData d@(Data _ _) -> checkData d >> preRun xs
|
||||
|
||||
checkDef :: [Def] -> Infer [T.Def]
|
||||
checkDef [] = return []
|
||||
checkDef (x:xs) = case x of
|
||||
(DBind b) -> do
|
||||
b' <- checkBind b
|
||||
fmap (T.DBind b' :) (checkDef xs)
|
||||
(DData d) -> fmap (T.DData d :) (checkDef xs)
|
||||
|
||||
checkBind :: Bind -> Infer T.Bind
|
||||
checkBind (Bind n t _ args e) = do
|
||||
|
|
@ -77,15 +104,18 @@ checkBind (Bind n t _ args e) = do
|
|||
makeLambda :: Exp -> [Ident] -> Exp
|
||||
makeLambda = foldl (flip EAbs)
|
||||
|
||||
-- | Check if two types are considered equal
|
||||
-- For the purpose of the algorithm two polymorphic types are always considered equal
|
||||
typeEq :: Type -> Type -> Bool
|
||||
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
|
||||
typeEq (TMono a) (TMono b) = a == b
|
||||
typeEq (TPol _) (TPol _) = True
|
||||
typeEq _ _ = False
|
||||
typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r'
|
||||
typeEq (TMono a) (TMono b) = a == b
|
||||
typeEq (TConstr name a) (TConstr name' b) = name == name' && and (zipWith typeEq a b)
|
||||
typeEq (TPol _) (TPol _) = True
|
||||
typeEq _ _ = False
|
||||
|
||||
inferExp :: Exp -> Infer (Type, T.Exp)
|
||||
inferExp e = do
|
||||
(s, t, e') <- w e
|
||||
(s, t, e') <- algoW e
|
||||
let subbed = apply s t
|
||||
return (subbed, replace subbed e')
|
||||
|
||||
|
|
@ -98,19 +128,26 @@ replace t = \case
|
|||
T.EAdd _ e1 e2 -> T.EAdd t e1 e2
|
||||
T.ELet (T.Bind (n, _) args e1) e2 -> T.ELet (T.Bind (n, t) args e1) e2
|
||||
|
||||
w :: Exp -> Infer (Subst, Type, T.Exp)
|
||||
w = \case
|
||||
algoW :: Exp -> Infer (Subst, Type, T.Exp)
|
||||
algoW = \case
|
||||
|
||||
EAnn e t -> do
|
||||
(s1, t', e') <- w e
|
||||
(s1, t', e') <- algoW e
|
||||
applySt s1 $ do
|
||||
s2 <- unify (apply s1 t) t'
|
||||
return (s2 `compose` s1, t, e')
|
||||
|
||||
-- | ------------------
|
||||
-- | Γ ⊢ e₀ : Int, ∅
|
||||
|
||||
ELit (LInt n) -> return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n))
|
||||
|
||||
ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a
|
||||
|
||||
-- | x : σ ∈ Γ τ = inst(σ)
|
||||
-- | ----------------------
|
||||
-- | Γ ⊢ x : τ, ∅
|
||||
|
||||
EId i -> do
|
||||
var <- asks vars
|
||||
case M.lookup i var of
|
||||
|
|
@ -118,42 +155,67 @@ w = \case
|
|||
Nothing -> do
|
||||
sig <- gets sigs
|
||||
case M.lookup i sig of
|
||||
Nothing -> throwError $ "Unbound variable: " ++ show i
|
||||
Just t -> return (nullSubst, t, T.EId (i, t))
|
||||
Nothing -> do
|
||||
constr <- gets dtypes
|
||||
case M.lookup i constr of
|
||||
Just t -> return (nullSubst, t, T.EId (i, t))
|
||||
Nothing -> throwError $ "Unbound variable: " ++ show i
|
||||
|
||||
-- | τ = newvar Γ, x : τ ⊢ e : τ', S
|
||||
-- | ---------------------------------
|
||||
-- | Γ ⊢ w λx. e : Sτ → τ', S
|
||||
|
||||
EAbs name e -> do
|
||||
fr <- fresh
|
||||
withBinding name (Forall [] fr) $ do
|
||||
(s1, t', e') <- w e
|
||||
(s1, t', e') <- algoW e
|
||||
let varType = apply s1 fr
|
||||
let newArr = TArr varType t'
|
||||
return (s1, newArr, T.EAbs newArr (name, varType) e')
|
||||
|
||||
-- | Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
|
||||
-- | s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
|
||||
-- | ------------------------------------------
|
||||
-- | Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
|
||||
-- This might be wrong
|
||||
|
||||
EAdd e0 e1 -> do
|
||||
(s1, t0, e0') <- w e0
|
||||
(s1, t0, e0') <- algoW e0
|
||||
applySt s1 $ do
|
||||
(s2, t1, e1') <- w e1
|
||||
applySt s2 $ do
|
||||
s3 <- unify (apply s2 t0) (TMono "Int")
|
||||
s4 <- unify (apply s3 t1) (TMono "Int")
|
||||
return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1')
|
||||
(s2, t1, e1') <- algoW e1
|
||||
-- applySt s2 $ do
|
||||
s3 <- unify (apply s2 t0) (TMono "Int")
|
||||
s4 <- unify (apply s3 t1) (TMono "Int")
|
||||
return (s4 `compose` s3 `compose` s2 `compose` s1, TMono "Int", T.EAdd (TMono "Int") e0' e1')
|
||||
|
||||
-- | Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1
|
||||
-- | τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ')
|
||||
-- | --------------------------------------
|
||||
-- | Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
|
||||
|
||||
EApp e0 e1 -> do
|
||||
fr <- fresh
|
||||
(s0, t0, e0') <- w e0
|
||||
(s0, t0, e0') <- algoW e0
|
||||
applySt s0 $ do
|
||||
(s1, t1, e1') <- w e1
|
||||
(s1, t1, e1') <- algoW e1
|
||||
-- applySt s1 $ do
|
||||
s2 <- unify (apply s1 t0) (TArr t1 fr)
|
||||
let t = apply s2 fr
|
||||
return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1')
|
||||
|
||||
-- | Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
|
||||
-- | ----------------------------------------------
|
||||
-- | Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
|
||||
|
||||
-- The bar over S₀ and Γ means "generalize"
|
||||
|
||||
ELet name e0 e1 -> do
|
||||
(s1, t1, e0') <- w e0
|
||||
(s1, t1, e0') <- algoW e0
|
||||
env <- asks vars
|
||||
let t' = generalize (apply s1 env) t1
|
||||
withBinding name t' $ do
|
||||
(s2, t2, e1') <- w e1
|
||||
(s2, t2, e1') <- algoW e1
|
||||
return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) [] e0') e1' )
|
||||
|
||||
ECase a b -> error $ "NOT IMPLEMENTED YET: ECase" ++ show a ++ " " ++ show b
|
||||
|
|
@ -168,6 +230,12 @@ unify t0 t1 = case (t0, t1) of
|
|||
(TPol a, b) -> occurs a b
|
||||
(a, TPol b) -> occurs b a
|
||||
(TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify"
|
||||
-- | TODO: Figure out a cleaner way to express the same thing
|
||||
(TConstr name t, TConstr name' t') -> if name == name' && length t == length t'
|
||||
then do
|
||||
xs <- sequence $ zipWith unify t t'
|
||||
return $ foldr compose nullSubst xs
|
||||
else throwError $ unwords ["Type constructor:", printTree name, "(" ++ printTree t ++ ")", "does not match with:", printTree name', "(" ++ printTree t' ++ ")"]
|
||||
(a, b) -> throwError . unwords $ ["Type:", printTree a, "can't be unified with:", printTree b]
|
||||
|
||||
-- | Check if a type is contained in another type.
|
||||
|
|
@ -202,9 +270,11 @@ class FreeVars t where
|
|||
|
||||
instance FreeVars Type where
|
||||
free :: Type -> Set Ident
|
||||
free (TPol a) = S.singleton a
|
||||
free (TMono _) = mempty
|
||||
free (TArr a b) = free a `S.union` free b
|
||||
free (TPol a) = S.singleton a
|
||||
free (TMono _) = mempty
|
||||
free (TArr a b) = free a `S.union` free b
|
||||
-- | Not guaranteed to be correct
|
||||
free (TConstr _ a) = foldl' (\acc x -> free x `S.union` acc) S.empty a
|
||||
apply :: Subst -> Type -> Type
|
||||
apply sub t = do
|
||||
case t of
|
||||
|
|
@ -213,6 +283,7 @@ instance FreeVars Type where
|
|||
Nothing -> TPol a
|
||||
Just t -> t
|
||||
TArr a b -> TArr (apply sub a) (apply sub b)
|
||||
TConstr name a -> TConstr name (map (apply sub) a)
|
||||
|
||||
instance FreeVars Poly where
|
||||
free :: Poly -> Set Ident
|
||||
|
|
@ -248,3 +319,7 @@ withBinding i p = local (\st -> st { vars = M.insert i p (vars st) })
|
|||
-- | Insert a function signature into the environment
|
||||
insertSig :: Ident -> Type -> Infer ()
|
||||
insertSig i t = modify (\st -> st { sigs = M.insert i t (sigs st) })
|
||||
|
||||
-- | Insert a constructor with its data type
|
||||
insertConstr :: Ident -> Type -> Infer ()
|
||||
insertConstr i t = modify (\st -> st { dtypes = M.insert i t (dtypes st) })
|
||||
|
|
|
|||
|
|
@ -5,12 +5,12 @@ module TypeChecker.TypeCheckerIr
|
|||
, module TypeChecker.TypeCheckerIr
|
||||
) where
|
||||
|
||||
import Grammar.Abs (Ident (..), Literal (..), Type (..))
|
||||
import Grammar.Abs (Data (..), Ident (..), Literal (..), Type (..))
|
||||
import Grammar.Print
|
||||
import Prelude
|
||||
import qualified Prelude as C (Eq, Ord, Read, Show)
|
||||
|
||||
newtype Program = Program [Bind]
|
||||
newtype Program = Program [Def]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
data Exp
|
||||
|
|
@ -22,11 +22,18 @@ data Exp
|
|||
| EAbs Type Id Exp
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
|
||||
data Def = DBind Bind | DData Data
|
||||
deriving (C.Eq, C.Ord, C.Read, C.Show)
|
||||
|
||||
type Id = (Ident, Type)
|
||||
|
||||
data Bind = Bind Id [Id] Exp
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
|
||||
instance Print Def where
|
||||
prt i (DBind bind) = prt i bind
|
||||
prt i (DData d) = prt i d
|
||||
|
||||
instance Print Program where
|
||||
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
||||
|
||||
|
|
@ -75,7 +82,7 @@ instance Print Exp where
|
|||
, doc $ showString "in"
|
||||
, prt 0 e
|
||||
]
|
||||
EApp t e1 e2 -> prPrec i 2 $ concatD
|
||||
EApp _ e1 e2 -> prPrec i 2 $ concatD
|
||||
[ prt 2 e1
|
||||
, prt 3 e2
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue