typechecker working, still unsure of quality

This commit is contained in:
sebastianselander 2023-03-23 14:18:23 +01:00
parent 8d1330ad42
commit 7fa677e3d3
2 changed files with 56 additions and 58 deletions

View file

@ -15,7 +15,6 @@ import Data.List (foldl')
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as M import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as S import Data.Set qualified as S
import Debug.Trace (trace) import Debug.Trace (trace)
@ -26,7 +25,6 @@ import TypeChecker.TypeCheckerIr (
Env (..), Env (..),
Error, Error,
Infer, Infer,
Poly (..),
Subst, Subst,
) )
import TypeChecker.TypeCheckerIr qualified as T import TypeChecker.TypeCheckerIr qualified as T
@ -78,15 +76,21 @@ retType a = a
checkPrg :: Program -> Infer T.Program checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
bs' <- checkDef bs -- Type check the program twice to produce all top-level types in the first pass through
return $ T.Program bs' _ <- checkDef bs
bs'' <- checkDef bs
return $ T.Program bs''
where where
preRun :: [Def] -> Infer () preRun :: [Def] -> Infer ()
preRun [] = return () preRun [] = return ()
preRun (x : xs) = case x of preRun (x : xs) = case x of
-- TODO: Check for no overlapping signature definitions -- TODO: Check for no overlapping signature definitions
DSig (Sig n t) -> insertSig (coerce n) (toNew t) >> preRun xs DSig (Sig n t) -> insertSig (coerce n) (Just $ toNew t) >> preRun xs
DBind (Bind{}) -> preRun xs DBind (Bind n _ _) -> do
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
DData d@(Data _ _) -> checkData d >> preRun xs DData d@(Data _ _) -> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def] checkDef :: [Def] -> Infer [T.Def]
@ -102,25 +106,33 @@ checkBind :: Bind -> Infer T.Bind
checkBind (Bind name args e) = do checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse $ coerce args) let lambda = makeLambda e (reverse $ coerce args)
e@(_, t') <- inferExp lambda e@(_, t') <- inferExp lambda
-- TODO: Check for match against existing signatures s <- gets sigs
return $ T.Bind (coerce name, t') [] e -- (apply s e) -- let fs = map (second Just) $ getFunctionTypes s e
-- mapM_ (uncurry insertSig) fs
case M.lookup (coerce name) s of
Just (Just t) -> do
sub <- unify t t'
let newT = apply sub t
insertSig (coerce name) (Just newT)
return $ T.Bind (coerce name, newT) [] e
_ -> do
insertSig (coerce name) (Just t')
return (T.Bind (coerce name, t') [] e) -- (apply s e)
where where
makeLambda :: Exp -> [Ident] -> Exp makeLambda :: Exp -> [Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce)) makeLambda = foldl (flip (EAbs . coerce))
{- | Check if two types are considered equal -- getFunctionTypes :: Map Ident (Maybe T.Type) -> T.ExpT -> [(Ident, T.Type)]
For the purpose of the algorithm two polymorphic types are always considered -- getFunctionTypes s = \case
equal -- (T.EId b, t) -> case M.lookup b s of
-} -- Just Nothing -> return (b, t)
typeEq :: Type -> Type -> Bool -- _ -> []
typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r' -- (T.ELit _, _) -> []
typeEq (TLit a) (TLit b) = a == b -- (T.ELet (T.Bind _ _ e1) e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2
typeEq (TIndexed (Indexed name a)) (TIndexed (Indexed name' b)) = -- (T.EApp e1 e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2
length a == length b -- (T.EAdd e1 e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2
&& name == name' -- (T.EAbs _ e, _) -> getFunctionTypes s e
&& and (zipWith typeEq a b) -- (T.ECase e injs, _) -> getFunctionTypes s e ++ concatMap (getFunctionTypes s . \(T.Inj _ e) -> e) injs
typeEq (TAll n1 t1) (TAll n2 t2) = t1 `typeEq` t2
typeEq _ _ = False
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
isMoreSpecificOrEq _ (T.TAll _ _) = True isMoreSpecificOrEq _ (T.TAll _ _) = True
@ -193,20 +205,20 @@ algoW = \case
-- \| x : σ ∈ Γ τ = inst(σ) -- \| x : σ ∈ Γ τ = inst(σ)
-- \| ---------------------- -- \| ----------------------
-- \| Γ ⊢ x : τ, ∅ -- \| Γ ⊢ x : τ, ∅
EVar i -> do EVar i -> do
var <- asks vars var <- asks vars
case M.lookup (coerce i) var of case M.lookup (coerce i) var of
Just t -> inst t >>= \x -> return (nullSubst, (T.EId (coerce i, x), x)) Just t -> inst t >>= \x -> return (nullSubst, (T.EId $ coerce i, x))
Nothing -> do Nothing -> do
sig <- gets sigs sig <- gets sigs
case M.lookup (coerce i) sig of case M.lookup (coerce i) sig of
Just t -> return (nullSubst, (T.EId (coerce i, t), t)) Just (Just t) -> return (nullSubst, (T.EId $ coerce i, t))
Nothing -> throwError $ "Unbound variable: " ++ show i Just Nothing -> (\x -> (nullSubst, (T.EId $ coerce i, x))) <$> fresh
Nothing -> throwError $ "Unbound variable: " ++ printTree i
ECons i -> do ECons i -> do
constr <- gets constructors constr <- gets constructors
case M.lookup (coerce i) constr of case M.lookup (coerce i) constr of
Just t -> return (nullSubst, (T.EId (coerce i, t), t)) Just t -> return (nullSubst, (T.EId $ coerce i, t))
Nothing -> throwError $ "Constructor: '" ++ printTree i ++ "' is not defined" Nothing -> throwError $ "Constructor: '" ++ printTree i ++ "' is not defined"
-- \| τ = newvar Γ, x : τ ⊢ e : τ', S -- \| τ = newvar Γ, x : τ ⊢ e : τ', S
@ -219,7 +231,7 @@ algoW = \case
(s1, (e', t')) <- algoW e (s1, (e', t')) <- algoW e
let varType = apply s1 fr let varType = apply s1 fr
let newArr = T.TFun varType t' let newArr = T.TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name, varType) (e', newArr), newArr)) return (s1, apply s1 (T.EAbs (coerce name, varType) (e', t'), newArr))
-- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁
-- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int)
@ -250,7 +262,6 @@ algoW = \case
(s0, (e0', t0)) <- algoW e0 (s0, (e0', t0)) <- algoW e0
applySt s0 $ do applySt s0 $ do
(s1, (e1', t1)) <- algoW e1 (s1, (e1', t1)) <- algoW e1
-- applySt s1 $ do
s2 <- unify (apply s1 t0) (T.TFun t1 fr) s2 <- unify (apply s1 t0) (T.TFun t1 fr)
let t = apply s2 fr let t = apply s2 fr
let comp = s2 `compose` s1 `compose` s0 let comp = s2 `compose` s1 `compose` s0
@ -309,17 +320,10 @@ unify t0 t1 = do
, "(" ++ printTree t' ++ ")" , "(" ++ printTree t' ++ ")"
] ]
(a, b) -> do (a, b) -> do
ctx <- ask
env <- get
throwError . unwords $ throwError . unwords $
[ "T.Type:" [ "'" ++ printTree a ++ "'"
, printTree a , "can't be unified with"
, "can't be unified with:" , "'" ++ printTree b ++ "'"
, printTree b
, "\nCtx:"
, show ctx
, "\nEnv:"
, show env
] ]
{- | Check if a type is contained in another type. {- | Check if a type is contained in another type.
@ -415,7 +419,7 @@ instance FreeVars T.ExpT where
free = error "free not implemented for T.Exp" free = error "free not implemented for T.Exp"
apply :: Subst -> T.ExpT -> T.ExpT apply :: Subst -> T.ExpT -> T.ExpT
apply s = \case apply s = \case
(T.EId (i, innerT), outerT) -> (T.EId (i, apply s innerT), apply s outerT) (T.EId i, outerT) -> (T.EId i, apply s outerT)
(T.ELit lit, t) -> (T.ELit lit, apply s t) (T.ELit lit, t) -> (T.ELit lit, apply s t)
(T.ELet (T.Bind (ident, t1) args e1) e2, t2) -> (T.ELet (T.Bind (ident, apply s t1) args (apply s e1)) (apply s e2), apply s t2) (T.ELet (T.Bind (ident, t1) args e1) e2, t2) -> (T.ELet (T.Bind (ident, apply s t1) args (apply s e1)) (apply s e2), apply s t2)
(T.EApp e1 e2, t) -> (T.EApp (apply s e1) (apply s e2), apply s t) (T.EApp e1 e2, t) -> (T.EApp (apply s e1) (apply s e2), apply s t)
@ -459,7 +463,7 @@ withBindings xs =
local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: Ident -> T.Type -> Infer () insertSig :: Ident -> Maybe T.Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type -- | Insert a constructor with its data type

View file

@ -17,16 +17,12 @@ import Grammar.Print
import Prelude import Prelude
import Prelude qualified as C (Eq, Ord, Read, Show) import Prelude qualified as C (Eq, Ord, Read, Show)
-- | A data type representing type variables
data Poly = Forall [Ident] Type
deriving (Show)
newtype Ctx = Ctx {vars :: Map Ident Type} newtype Ctx = Ctx {vars :: Map Ident Type}
deriving (Show) deriving (Show)
data Env = Env data Env = Env
{ count :: Int { count :: Int
, sigs :: Map Ident Type , sigs :: Map Ident (Maybe Type)
, constructors :: Map Ident Type , constructors :: Map Ident Type
} }
deriving (Show) deriving (Show)
@ -39,7 +35,7 @@ type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))
newtype Program = Program [Def] newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data TVar = MkTVar Ident newtype TVar = MkTVar Ident
deriving (Show, Eq, Ord, Read) deriving (Show, Eq, Ord, Read)
data Type data Type
@ -51,7 +47,7 @@ data Type
deriving (Show, Eq, Ord, Read) deriving (Show, Eq, Ord, Read)
data Exp data Exp
= EId Id = EId Ident
| ELit Lit | ELit Lit
| ELet Bind ExpT | ELet Bind ExpT
| EApp ExpT ExpT | EApp ExpT ExpT
@ -78,7 +74,7 @@ data Bind = Bind Id [Id] ExpT
instance Print [Def] where instance Print [Def] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n\n"), prt 0 xs]
instance Print Def where instance Print Def where
prt i (DBind bind) = prt i bind prt i (DBind bind) = prt i bind
@ -88,7 +84,7 @@ instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print Bind where instance Print Bind where
prt i (Bind (t, name) args rhs) = prt i (Bind (name, t) _ rhs) =
prPrec i 0 $ prPrec i 0 $
concatD concatD
[ prt 0 name [ prt 0 name
@ -112,9 +108,11 @@ prtId :: Int -> Id -> Doc
prtId i (name, t) = prtId i (name, t) =
prPrec i 0 $ prPrec i 0 $
concatD concatD
[ prt 0 name [ doc $ showString "("
, prt 0 name
, doc $ showString ":" , doc $ showString ":"
, prt 0 t , prt 0 t
, doc $ showString ")"
] ]
prtIdP :: Int -> Id -> Doc prtIdP :: Int -> Id -> Doc
@ -130,8 +128,8 @@ prtIdP i (name, t) =
instance Print Exp where instance Print Exp where
prt i = \case prt i = \case
EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"] EId n -> prPrec i 3 $ concatD [prt 0 n]
ELit lit -> prPrec i 3 $ concatD [prt 0 lit, doc $ showString "\n"] ELit lit -> prPrec i 3 $ concatD [prt 0 lit]
ELet bs e -> ELet bs e ->
prPrec i 3 $ prPrec i 3 $
concatD concatD
@ -139,7 +137,6 @@ instance Print Exp where
, prt 0 bs , prt 0 bs
, doc $ showString "in" , doc $ showString "in"
, prt 0 e , prt 0 e
, doc $ showString "\n"
] ]
EApp e1 e2 -> EApp e1 e2 ->
prPrec i 2 $ prPrec i 2 $
@ -154,16 +151,14 @@ instance Print Exp where
, prt 1 e1 , prt 1 e1
, doc $ showString "+" , doc $ showString "+"
, prt 2 e2 , prt 2 e2
, doc $ showString "\n"
] ]
EAbs n e -> EAbs n e ->
prPrec i 0 $ prPrec i 0 $
concatD concatD
[ doc $ showString "@" [ doc $ showString "λ"
, prtId 0 n , prtId 0 n
, doc $ showString "." , doc $ showString "."
, prt 0 e , prt 0 e
, doc $ showString "\n"
] ]
ECase exp injs -> ECase exp injs ->
prPrec prPrec
@ -177,12 +172,11 @@ instance Print Exp where
, prt 0 injs , prt 0 injs
, doc (showString "}") , doc (showString "}")
, doc (showString ":") , doc (showString ":")
, doc $ showString "\n"
] ]
) )
instance Print ExpT where instance Print ExpT where
prt i (e, t) = concatD [prt i e, doc (showString ":"), prt i t] prt i (e, t) = concatD [doc $ showString "(", prt i e, doc (showString ":"), prt i t, doc $ showString ")"]
instance Print Inj where instance Print Inj where
prt i = \case prt i = \case