From a98135827c5fb1df1afeb5387df4199abe2dc50d Mon Sep 17 00:00:00 2001 From: sebastianselander Date: Mon, 20 Feb 2023 16:51:44 +0100 Subject: [PATCH] EAdd is bugged. Mostly complete though. --- src/TypeChecker/AlgoW.hs | 183 ++++++++++++++++++++++++--------------- 1 file changed, 113 insertions(+), 70 deletions(-) diff --git a/src/TypeChecker/AlgoW.hs b/src/TypeChecker/AlgoW.hs index e630da2..3667761 100644 --- a/src/TypeChecker/AlgoW.hs +++ b/src/TypeChecker/AlgoW.hs @@ -8,108 +8,151 @@ import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor (bimap, second) import Data.Functor.Identity (Identity, runIdentity) -import Data.List (intersect) +import Data.List (foldl', intersect) import Data.Map (Map) import qualified Data.Map as M import Data.Maybe (fromMaybe) +import Data.Set (Set) +import qualified Data.Set as S import Grammar.Abs +import Grammar.Print (printTree) import qualified TypeChecker.HMIr as T data Poly = Forall [Ident] Type deriving Show -a = TPol "a" -b = TPol "b" -int = TMono "int" -arr = TArr - data Ctx = Ctx { vars :: Map Ident Poly - , sigs :: Map Ident Poly } + , sigs :: Map Ident Type } -data Env = Env { counter :: Int - , substitutions :: Map Type Type - } - -type Subst = Map Type Type type Error = String +type Subst = Map Ident Type -newtype Infer a = Infer { runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a } - deriving (Functor, Applicative, Monad, MonadState Env, MonadReader Ctx, MonadError Error) +type Infer = StateT Int (ReaderT Ctx (ExceptT Error Identity)) -initCtx :: Ctx initCtx = Ctx mempty mempty -initEnv :: Env -initEnv = Env 0 mempty +run :: Infer a -> Either Error a +run = runC initCtx 0 -run :: Ctx -> Env -> Infer a -> Either Error a -run c e = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e . runInfer +runC :: Ctx -> Int -> Infer a -> Either Error a +runC c e = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e -w :: Exp -> Infer Type -w = \case - EInt n -> return int - EId i -> (\(Forall _ t) -> t) <$> (lookupVar i >>= inst) +inferExp :: Exp -> Infer Type +inferExp e = snd <$> w nullSubst e + +w :: Subst -> Exp -> Infer (Subst, Type) +w s = \case + EAnn e t -> do + (s1, t') <- w nullSubst e + let t'' = apply s1 t + return (s1, t'') + EInt n -> return (nullSubst, TMono "Int") + EId i -> do + var <- asks vars + case M.lookup i var of + Nothing -> throwError $ "Unbound variable: " ++ show i + Just t -> (nullSubst,) <$> inst t EAbs var e -> do fr <- fresh - withBinding var (Forall [] (TPol fr)) $ do - t' <- w e - subst (Forall [] $ TArr (TPol fr) t') + withBinding var (Forall [] fr) $ do + (s1, t') <- w s e + return (s, TArr (apply s1 fr) t') + EAdd e0 e1 -> do + (s1, t1) <- w s e0 + (s2, t2) <- w s1 e1 + return (s2, TMono "Int") EApp e0 e1 -> do - t0 <- substCtx (w e0) - t1 <- w e1 - undefined + fr <- fresh + (s1, t0) <- w s e0 + (s2, t1) <- w s1 e1 + s3 <- unify (subst s2 t0) (TArr t1 fr) + return (s3 `compose` s2 `compose` s1, apply s3 fr) + ELet name e0 e1 -> do + (s1, t1) <- w s e0 + env <- asks vars + let t' = generalize (apply s1 env) t1 + withBinding name t' $ do + (s2, t2) <- w s1 e1 + return (s1 `compose` s2, t2) -substCtx :: Infer Type -> Infer Type -substCtx m = do - vs <- asks (M.toList . vars) - ks <- traverse (subst . snd) vs - let x = map fst vs - local (\st -> st { vars = M.fromList $ zip x ks }) m +unify :: Type -> Type -> Infer Subst +unify t0 t1 = case (t0, t1) of + (TArr a b, TArr c d) -> do + s1 <- unify a c + s2 <- unify (subst s1 b) (subst s1 c) + return $ s1 `compose` s2 + (TPol a, b) -> occurs a b + (a, TPol b) -> occurs b a + (TMono a, TMono b) -> if a == b then return M.empty else throwError "Types do not unify" -subst :: Poly -> Infer Poly -subst (Forall xs t) = do - subs <- gets substitutions - case t of - TPol a -> case M.lookup (TPol a) subs of - Nothing -> return $ Forall xs t - Just t' -> return $ Forall (remove a xs) t' - TMono a -> case M.lookup (TMono a) subs of - Nothing -> return $ Forall xs t - Just t' -> return $ Forall (remove a xs) t' - TArr a b -> do - (Forall xs' a') <- subst (Forall xs a) - (Forall xs'' b') <- subst (Forall xs b) - return $ Forall (xs' `intersect` xs'') (TArr a' b') +occurs :: Ident -> Type -> Infer Subst +occurs i (TPol a) = return nullSubst +occurs i t = if S.member i (free t) + then throwError "Occurs check failed" + else return $ M.singleton i t +generalize :: Map Ident Poly -> Type -> Poly +generalize env t = Forall (S.toList $ free t S.\\ free env) t -remove :: Ord a => a -> [a] -> [a] -remove a = foldr (\x acc -> if x == a then acc else x : acc) [] - -inst :: Poly -> Infer Poly +inst :: Poly -> Infer Type inst (Forall xs t) = do xs' <- mapM (const fresh) xs - let sub = zip xs xs' - let subst' t = case t of - TMono a -> return $ TMono a - TPol a -> case lookup a sub of - Nothing -> return $ TPol a - Just t -> return $ TPol t - TArr a b -> TArr <$> subst' a <*> subst' b - Forall [] <$> subst' t + let s = M.fromList $ zip xs xs' + return $ apply s t + +compose :: Subst -> Subst -> Subst +compose m1 m2 = M.map (subst m1) m2 `M.union` m1 + +class FreeVars t where + free :: t -> Set Ident + apply :: Subst -> t -> t + +instance FreeVars Type where + free :: Type -> Set Ident + free (TPol a) = S.singleton a + free (TMono _) = mempty + free (TArr a b) = free a `S.union` free b + apply :: Subst -> Type -> Type + apply sub t = do + case t of + TMono a -> TMono a + TPol a -> case M.lookup a sub of + Nothing -> TPol a + Just t -> t + TArr a b -> TArr (apply sub a) (apply sub b) + +instance FreeVars Poly where + free :: Poly -> Set Ident + free (Forall xs t) = free t S.\\ S.fromList xs + apply :: Subst -> Poly -> Poly + apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t) + +instance FreeVars (Map Ident Poly) where + free :: Map Ident Poly -> Set Ident + free m = foldl' S.union S.empty (map free $ M.elems m) + apply :: Subst -> Map Ident Poly -> Map Ident Poly + apply s = M.map (apply s) + +nullSubst :: Subst +nullSubst = M.empty + +subst :: Subst -> Type -> Type +subst m t = do + case t of + TPol a -> fromMaybe t (M.lookup a m) + TMono a -> TMono a + TArr a b -> TArr (subst m a) (subst m b) -- | Generate a new fresh variable and increment the state -fresh :: Infer Ident +fresh :: Infer Type fresh = do - n <- gets counter - modify (\st -> st { counter = n + 1 }) - return . Ident $ "t" ++ show n + n <- get + put (n + 1) + return . TPol . Ident $ "t" ++ show n -insertSub :: Type -> Type -> Infer () -insertSub t1 t2 = modify (\st -> st { substitutions = M.insert t1 t2 (substitutions st) }) - -withBinding :: Ident -> Poly -> Infer Poly -> Infer Type -withBinding i t m = (\(Forall _ t) -> t) <$> local (\re -> re { vars = M.insert i t (vars re) }) m +withBinding :: Ident -> Poly -> Infer (Subst, Type) -> Infer (Subst, Type) +withBinding i t = local (\re -> re { vars = M.insert i t (vars re) }) lookupVar :: Ident -> Infer Poly lookupVar i = do