From 7fa677e3d3729b2755d35ba380905d6ea1deb43c Mon Sep 17 00:00:00 2001 From: sebastianselander Date: Thu, 23 Mar 2023 14:18:23 +0100 Subject: [PATCH] typechecker working, still unsure of quality --- src/TypeChecker/TypeChecker.hs | 84 +++++++++++++++++--------------- src/TypeChecker/TypeCheckerIr.hs | 30 +++++------- 2 files changed, 56 insertions(+), 58 deletions(-) diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 112bf7d..7da23a6 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -15,7 +15,6 @@ import Data.List (foldl') import Data.List.Extra (unsnoc) import Data.Map (Map) import Data.Map qualified as M -import Data.Maybe (fromMaybe) import Data.Set (Set) import Data.Set qualified as S import Debug.Trace (trace) @@ -26,7 +25,6 @@ import TypeChecker.TypeCheckerIr ( Env (..), Error, Infer, - Poly (..), Subst, ) import TypeChecker.TypeCheckerIr qualified as T @@ -78,15 +76,21 @@ retType a = a checkPrg :: Program -> Infer T.Program checkPrg (Program bs) = do preRun bs - bs' <- checkDef bs - return $ T.Program bs' + -- Type check the program twice to produce all top-level types in the first pass through + _ <- checkDef bs + bs'' <- checkDef bs + return $ T.Program bs'' where preRun :: [Def] -> Infer () preRun [] = return () preRun (x : xs) = case x of -- TODO: Check for no overlapping signature definitions - DSig (Sig n t) -> insertSig (coerce n) (toNew t) >> preRun xs - DBind (Bind{}) -> preRun xs + DSig (Sig n t) -> insertSig (coerce n) (Just $ toNew t) >> preRun xs + DBind (Bind n _ _) -> do + s <- gets sigs + case M.lookup (coerce n) s of + Nothing -> insertSig (coerce n) Nothing >> preRun xs + Just _ -> preRun xs DData d@(Data _ _) -> checkData d >> preRun xs checkDef :: [Def] -> Infer [T.Def] @@ -102,25 +106,33 @@ checkBind :: Bind -> Infer T.Bind checkBind (Bind name args e) = do let lambda = makeLambda e (reverse $ coerce args) e@(_, t') <- inferExp lambda - -- TODO: Check for match against existing signatures - return $ T.Bind (coerce name, t') [] e -- (apply s e) + s <- gets sigs + -- let fs = map (second Just) $ getFunctionTypes s e + -- mapM_ (uncurry insertSig) fs + case M.lookup (coerce name) s of + Just (Just t) -> do + sub <- unify t t' + let newT = apply sub t + insertSig (coerce name) (Just newT) + return $ T.Bind (coerce name, newT) [] e + _ -> do + insertSig (coerce name) (Just t') + return (T.Bind (coerce name, t') [] e) -- (apply s e) where makeLambda :: Exp -> [Ident] -> Exp makeLambda = foldl (flip (EAbs . coerce)) -{- | Check if two types are considered equal - For the purpose of the algorithm two polymorphic types are always considered - equal --} -typeEq :: Type -> Type -> Bool -typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r' -typeEq (TLit a) (TLit b) = a == b -typeEq (TIndexed (Indexed name a)) (TIndexed (Indexed name' b)) = - length a == length b - && name == name' - && and (zipWith typeEq a b) -typeEq (TAll n1 t1) (TAll n2 t2) = t1 `typeEq` t2 -typeEq _ _ = False + -- getFunctionTypes :: Map Ident (Maybe T.Type) -> T.ExpT -> [(Ident, T.Type)] + -- getFunctionTypes s = \case + -- (T.EId b, t) -> case M.lookup b s of + -- Just Nothing -> return (b, t) + -- _ -> [] + -- (T.ELit _, _) -> [] + -- (T.ELet (T.Bind _ _ e1) e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2 + -- (T.EApp e1 e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2 + -- (T.EAdd e1 e2, _) -> getFunctionTypes s e1 ++ getFunctionTypes s e2 + -- (T.EAbs _ e, _) -> getFunctionTypes s e + -- (T.ECase e injs, _) -> getFunctionTypes s e ++ concatMap (getFunctionTypes s . \(T.Inj _ e) -> e) injs isMoreSpecificOrEq :: T.Type -> T.Type -> Bool isMoreSpecificOrEq _ (T.TAll _ _) = True @@ -193,20 +205,20 @@ algoW = \case -- \| x : σ ∈ Γ   τ = inst(σ) -- \| ---------------------- -- \| Γ ⊢ x : τ, ∅ - EVar i -> do var <- asks vars case M.lookup (coerce i) var of - Just t -> inst t >>= \x -> return (nullSubst, (T.EId (coerce i, x), x)) + Just t -> inst t >>= \x -> return (nullSubst, (T.EId $ coerce i, x)) Nothing -> do sig <- gets sigs case M.lookup (coerce i) sig of - Just t -> return (nullSubst, (T.EId (coerce i, t), t)) - Nothing -> throwError $ "Unbound variable: " ++ show i + Just (Just t) -> return (nullSubst, (T.EId $ coerce i, t)) + Just Nothing -> (\x -> (nullSubst, (T.EId $ coerce i, x))) <$> fresh + Nothing -> throwError $ "Unbound variable: " ++ printTree i ECons i -> do constr <- gets constructors case M.lookup (coerce i) constr of - Just t -> return (nullSubst, (T.EId (coerce i, t), t)) + Just t -> return (nullSubst, (T.EId $ coerce i, t)) Nothing -> throwError $ "Constructor: '" ++ printTree i ++ "' is not defined" -- \| τ = newvar Γ, x : τ ⊢ e : τ', S @@ -219,7 +231,7 @@ algoW = \case (s1, (e', t')) <- algoW e let varType = apply s1 fr let newArr = T.TFun varType t' - return (s1, apply s1 (T.EAbs (coerce name, varType) (e', newArr), newArr)) + return (s1, apply s1 (T.EAbs (coerce name, varType) (e', t'), newArr)) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) @@ -250,7 +262,6 @@ algoW = \case (s0, (e0', t0)) <- algoW e0 applySt s0 $ do (s1, (e1', t1)) <- algoW e1 - -- applySt s1 $ do s2 <- unify (apply s1 t0) (T.TFun t1 fr) let t = apply s2 fr let comp = s2 `compose` s1 `compose` s0 @@ -309,17 +320,10 @@ unify t0 t1 = do , "(" ++ printTree t' ++ ")" ] (a, b) -> do - ctx <- ask - env <- get throwError . unwords $ - [ "T.Type:" - , printTree a - , "can't be unified with:" - , printTree b - , "\nCtx:" - , show ctx - , "\nEnv:" - , show env + [ "'" ++ printTree a ++ "'" + , "can't be unified with" + , "'" ++ printTree b ++ "'" ] {- | Check if a type is contained in another type. @@ -415,7 +419,7 @@ instance FreeVars T.ExpT where free = error "free not implemented for T.Exp" apply :: Subst -> T.ExpT -> T.ExpT apply s = \case - (T.EId (i, innerT), outerT) -> (T.EId (i, apply s innerT), apply s outerT) + (T.EId i, outerT) -> (T.EId i, apply s outerT) (T.ELit lit, t) -> (T.ELit lit, apply s t) (T.ELet (T.Bind (ident, t1) args e1) e2, t2) -> (T.ELet (T.Bind (ident, apply s t1) args (apply s e1)) (apply s e2), apply s t2) (T.EApp e1 e2, t) -> (T.EApp (apply s e1) (apply s e2), apply s t) @@ -459,7 +463,7 @@ withBindings xs = local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) -- | Insert a function signature into the environment -insertSig :: Ident -> T.Type -> Infer () +insertSig :: Ident -> Maybe T.Type -> Infer () insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) -- | Insert a constructor with its data type diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index 9cf2059..7c24ab3 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -17,16 +17,12 @@ import Grammar.Print import Prelude import Prelude qualified as C (Eq, Ord, Read, Show) --- | A data type representing type variables -data Poly = Forall [Ident] Type - deriving (Show) - newtype Ctx = Ctx {vars :: Map Ident Type} deriving (Show) data Env = Env { count :: Int - , sigs :: Map Ident Type + , sigs :: Map Ident (Maybe Type) , constructors :: Map Ident Type } deriving (Show) @@ -39,7 +35,7 @@ type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity)) newtype Program = Program [Def] deriving (C.Eq, C.Ord, C.Show, C.Read) -data TVar = MkTVar Ident +newtype TVar = MkTVar Ident deriving (Show, Eq, Ord, Read) data Type @@ -51,7 +47,7 @@ data Type deriving (Show, Eq, Ord, Read) data Exp - = EId Id + = EId Ident | ELit Lit | ELet Bind ExpT | EApp ExpT ExpT @@ -78,7 +74,7 @@ data Bind = Bind Id [Id] ExpT instance Print [Def] where prt _ [] = concatD [] - prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n"), prt 0 xs] + prt _ (x : xs) = concatD [prt 0 x, doc (showString "\n\n"), prt 0 xs] instance Print Def where prt i (DBind bind) = prt i bind @@ -88,7 +84,7 @@ instance Print Program where prt i (Program sc) = prPrec i 0 $ prt 0 sc instance Print Bind where - prt i (Bind (t, name) args rhs) = + prt i (Bind (name, t) _ rhs) = prPrec i 0 $ concatD [ prt 0 name @@ -112,9 +108,11 @@ prtId :: Int -> Id -> Doc prtId i (name, t) = prPrec i 0 $ concatD - [ prt 0 name + [ doc $ showString "(" + , prt 0 name , doc $ showString ":" , prt 0 t + , doc $ showString ")" ] prtIdP :: Int -> Id -> Doc @@ -130,8 +128,8 @@ prtIdP i (name, t) = instance Print Exp where prt i = \case - EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"] - ELit lit -> prPrec i 3 $ concatD [prt 0 lit, doc $ showString "\n"] + EId n -> prPrec i 3 $ concatD [prt 0 n] + ELit lit -> prPrec i 3 $ concatD [prt 0 lit] ELet bs e -> prPrec i 3 $ concatD @@ -139,7 +137,6 @@ instance Print Exp where , prt 0 bs , doc $ showString "in" , prt 0 e - , doc $ showString "\n" ] EApp e1 e2 -> prPrec i 2 $ @@ -154,16 +151,14 @@ instance Print Exp where , prt 1 e1 , doc $ showString "+" , prt 2 e2 - , doc $ showString "\n" ] EAbs n e -> prPrec i 0 $ concatD - [ doc $ showString "@" + [ doc $ showString "λ" , prtId 0 n , doc $ showString "." , prt 0 e - , doc $ showString "\n" ] ECase exp injs -> prPrec @@ -177,12 +172,11 @@ instance Print Exp where , prt 0 injs , doc (showString "}") , doc (showString ":") - , doc $ showString "\n" ] ) instance Print ExpT where - prt i (e, t) = concatD [prt i e, doc (showString ":"), prt i t] + prt i (e, t) = concatD [doc $ showString "(", prt i e, doc (showString ":"), prt i t, doc $ showString ")"] instance Print Inj where prt i = \case