{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}

-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeChecker where

import Auxiliary
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Coerce (coerce)
import Data.Foldable (traverse_)
import Data.Functor.Identity (runIdentity)
import Data.List (foldl')
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr (
    Ctx (..),
    Env (..),
    Error,
    Infer,
    Subst,
 )
import TypeChecker.TypeCheckerIr qualified as T

-- | The empty reader context: no local variable bindings.
initCtx :: Ctx
initCtx = Ctx mempty

-- | The initial state: fresh-variable counter at 0, no signatures, no constructors.
initEnv :: Env
initEnv = Env 0 mempty mempty

-- | Infer the type of an expression and pretty print the annotated expression.
runPretty :: Exp -> Either Error String
runPretty = fmap (printTree . fst) . run . inferExp

-- | Run an inference action from the initial environment and context.
run :: Infer a -> Either Error a
run = runC initEnv initCtx

-- | Run an inference action from the given environment and context.
runC :: Env -> Ctx -> Infer a -> Either Error a
runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e

-- | Type check a whole program.
typecheck :: Program -> Either Error T.Program
typecheck = run . checkPrg

{- | Validate a data declaration and register its constructors.

All type arguments of the declared type must be type variables (or
quantified), and every constructor must return exactly the declared type.
-}
checkData :: Data -> Infer ()
checkData d = case d of
    Data typ@(TData _ ts) constrs -> do
        unless (all isPoly ts) (throwError $ unwords ["Data type incorrectly declared"])
        traverse_ (checkConstr typ) constrs
    _ -> throwError $ "incorrectly declared data type '" <> printTree d <> "'"
  where
    checkConstr :: Type -> Constructor -> Infer ()
    checkConstr typ (Constructor name' t')
        | typ == retType t' = insertConstr (coerce name') (toNew t')
        | otherwise =
            throwError $
                unwords
                    [ "return type of constructor:"
                    , -- Report the offending constructor, not the data type.
                      printTree name'
                    , "with type:"
                    , printTree (retType t')
                    , "does not match data: "
                    , printTree typ
                    ]

-- | The result type of a (possibly curried) function type.
retType :: Type -> Type
retType (TFun _ t2) = retType t2
retType a = a

{- | Type check a program.

First registers every signature, binding name and data declaration, then
checks all definitions twice so that the second pass sees the top-level
types inferred by the first.
-}
checkPrg :: Program -> Infer T.Program
checkPrg (Program bs) = do
    preRun bs
    -- Type check the program twice to produce all top-level types in the first pass through
    _ <- checkDef bs
    bs'' <- checkDef bs
    return $ T.Program bs''
  where
    preRun :: [Def] -> Infer ()
    preRun = traverse_ register

    register :: Def -> Infer ()
    register = \case
        DSig (Sig n t) -> do
            duplicate <- gets (M.member (coerce n) . sigs)
            when duplicate . throwError $
                "Duplicate signatures for function '" <> printTree n <> "'"
            insertSig (coerce n) (Just $ toNew t)
        DBind (Bind n _ _) -> do
            -- Bindings without a signature are registered with an unknown type.
            known <- gets (M.member (coerce n) . sigs)
            unless known $ insertSig (coerce n) Nothing
        DData d -> checkData d

-- | Check every definition in order; signatures produce no output definition.
checkDef :: [Def] -> Infer [T.Def]
checkDef [] = return []
checkDef (x : xs) = case x of
    DBind b -> do
        b' <- checkBind b
        (T.DBind b' :) <$> checkDef xs
    DData d -> (T.DData (toNew d) :) <$> checkDef xs
    DSig _ -> checkDef xs

{- | Infer the type of a binding by turning it into a lambda over its
arguments. If a signature was declared the inferred type must match it
(see 'typeEq'); otherwise the inferred type is recorded as the signature.
-}
checkBind :: Bind -> Infer T.Bind
checkBind (Bind name args e) = do
    let lambda = makeLambda e (reverse (coerce args))
    body@(_, inferred) <- inferExp lambda
    declared <- gets (M.lookup (coerce name) . sigs)
    case declared of
        Just (Just t') -> do
            -- NOTE: the declared and inferred types are only compared
            -- structurally; they are not unified.
            unless (inferred `typeEq` t') . throwError $
                "Inferred type '"
                    ++ printTree inferred
                    ++ "' does not match specified type '"
                    ++ printTree t'
                    ++ "'"
            return $ T.Bind (coerce name, t') [] body
        _ -> do
            insertSig (coerce name) (Just inferred)
            return (T.Bind (coerce name, inferred) [] body)

{- | Structural equality of types where quantifiers are ignored and any two
type variables are considered equal.
-}
typeEq :: T.Type -> T.Type -> Bool
typeEq (T.TFun l r) (T.TFun l' r') = typeEq l l' && typeEq r r'
typeEq (T.TLit a) (T.TLit b) = a == b
typeEq (T.TData name a) (T.TData name' b) =
    length a == length b
        && name == name'
        && and (zipWith typeEq a b)
typeEq (T.TAll _ t1) (T.TAll _ t2) = t1 `typeEq` t2
typeEq (T.TVar _) (T.TVar _) = True
typeEq _ _ = False

{- | Structural comparison used for annotations: quantifiers on the right
hand side are skipped, everything else must be equal.
-}
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool
isMoreSpecificOrEq t1 (T.TAll _ t2) = isMoreSpecificOrEq t1 t2
isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) =
    isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (T.TData n1 ts1) (T.TData n2 ts2) =
    n1 == n2
        && length ts1 == length ts2
        && and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq a b = a == b

-- | Is the type a type variable or a quantified type?
isPoly :: Type -> Bool
isPoly (TAll _ _) = True
isPoly (TVar _) = True
isPoly _ = False

-- | Run algorithm W on an expression and apply the resulting substitution to its type.
inferExp :: Exp -> Infer T.ExpT
inferExp e = do
    (s, (e', t)) <- algoW e
    let subbed = apply s t
    return $ replace subbed (e', t)

-- | Replace the type of a typed expression.
replace :: T.Type -> T.ExpT -> T.ExpT
replace t = second (const t)

-- | Conversion from the parser's syntax tree to the type checker's IR.
class NewType a b where
    toNew :: a -> b

instance NewType Type T.Type where
    toNew = \case
        TLit i -> T.TLit $ coerce i
        TVar v -> T.TVar $ toNew v
        TFun t1 t2 -> T.TFun (toNew t1) (toNew t2)
        TAll b t -> T.TAll (toNew b) (toNew t)
        TData i ts -> T.TData (coerce i) (map toNew ts)
        TEVar _ -> error "Should not exist after typechecker"

instance NewType Lit T.Lit where
    toNew (LInt i) = T.LInt i
    toNew (LChar i) = T.LChar i

instance NewType Data T.Data where
    toNew (Data t xs) = T.Data (name $ retType t) (toNew xs)
      where
        name (TData n _) = coerce n
        name _ = error "Bug in toNew Data -> T.Data"

instance NewType Constructor T.Constructor where
    toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs)

instance NewType TVar T.TVar where
    toNew (MkTVar i) = T.MkTVar $ coerce i

instance NewType a b => NewType [a] [b] where
    toNew = map toNew

{- | Algorithm W: infer the type of an expression, returning the substitution
that was built together with the type-annotated expression.
-}
algoW :: Exp -> Infer (Subst, T.ExpT)
algoW = \case
    -- TODO: More testing need to be done. Unsure of the correctness of this
    err@(EAnn e t) -> do
        (s1, (e', t')) <- exprErr (algoW e) err
        unless (toNew t `isMoreSpecificOrEq` t') . throwError $
            unwords
                [ "Annotated type:"
                , printTree t
                , "does not match inferred type:"
                , printTree t'
                ]
        applySt s1 $ do
            s2 <- exprErr (unify (toNew t) t') err
            let comp = s2 `compose` s1
            return (comp, apply comp (e', toNew t))

    -- ------------------
    --  Γ ⊢ i : Int, ∅
    ELit lit -> return (nullSubst, (T.ELit $ toNew lit, litType lit))
    --  x : σ ∈ Γ   τ = inst(σ)
    -- ----------------------
    --  Γ ⊢ x : τ, ∅
    EVar i -> do
        var <- asks vars
        case M.lookup (coerce i) var of
            Just t -> inst t >>= \x -> return (nullSubst, (T.EId $ coerce i, x))
            Nothing -> do
                -- Not a local variable: fall back to the top-level signatures.
                sig <- gets sigs
                case M.lookup (coerce i) sig of
                    Just (Just t) -> return (nullSubst, (T.EId $ coerce i, t))
                    Just Nothing -> do
                        -- Known binding whose type has not been inferred yet.
                        fr <- fresh
                        insertSig (coerce i) (Just fr)
                        return (nullSubst, (T.EId $ coerce i, fr))
                    Nothing -> throwError $ "Unbound variable: " <> printTree i
    EInj i -> do
        constr <- gets constructors
        case M.lookup (coerce i) constr of
            Just t -> return (nullSubst, (T.EId $ coerce i, t))
            Nothing ->
                throwError $
                    "Constructor: '" <> printTree i <> "' is not defined"

    --  τ = newvar   Γ, x : τ ⊢ e : τ', S
    -- ---------------------------------
    --  Γ ⊢ λx. e : Sτ → τ', S
    err@(EAbs name e) -> do
        fr <- fresh
        exprErr
            ( withBinding (coerce name) fr $ do
                (s1, (e', t')) <- exprErr (algoW e) err
                let varType = apply s1 fr
                let newArr = T.TFun varType t'
                return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
            )
            err

    --  Γ ⊢ e₀ : τ₀, S₀    S₀Γ ⊢ e₁ : τ₁, S₁
    --  s₂ = mgu(s₁τ₀, Int)    s₃ = mgu(s₂τ₁, Int)
    -- ------------------------------------------
    --  Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀
    -- This might be wrong
    err@(EAdd e0 e1) -> do
        (s1, (e0', t0)) <- algoW e0
        applySt s1 $ do
            (s2, (e1', t1)) <- algoW e1
            s3 <- exprErr (unify (apply s2 t0) int) err
            s4 <- exprErr (unify (apply s3 t1) int) err
            let comp = s4 `compose` s3 `compose` s2 `compose` s1
            return
                ( comp
                , apply comp (T.EAdd (e0', t0) (e1', t1), int)
                )

    --  Γ ⊢ e₀ : τ₀, S₀    S₀Γ ⊢ e₁ : τ₁, S₁
    --  τ' = newvar    S₂ = mgu(S₁τ₀, τ₁ → τ')
    -- --------------------------------------
    --  Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀
    err@(EApp e0 e1) -> do
        fr <- fresh
        (s0, (e0', t0)) <- algoW e0
        applySt s0 $ do
            (s1, (e1', t1)) <- algoW e1
            s2 <- exprErr (unify (apply s1 t0) (T.TFun t1 fr)) err
            let t = apply s2 fr
            let comp = s2 `compose` s1 `compose` s0
            return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))

    --  Γ ⊢ e₀ : τ, S₀    S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁
    -- ----------------------------------------------
    --  Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀
    -- The bar over S₀ and Γ means "generalize"
    err@(ELet b@(Bind name args e) e1) -> do
        -- The arguments are reversed exactly as in 'checkBind', so that the
        -- first argument becomes the outermost lambda.
        (s1, (_, t0)) <- algoW (makeLambda e (reverse (coerce args)))
        bind' <- exprErr (checkBind b) err
        env <- asks vars
        let t' = generalize (apply s1 env) t0
        withBinding (coerce name) t' $ do
            (s2, (e1', t2)) <- algoW e1
            let comp = s2 `compose` s1
            return (comp, apply comp (T.ELet bind' (e1', t2), t2))

    -- TODO: Add judgement
    ECase caseExpr injs -> do
        (sub, (e', t)) <- algoW caseExpr
        (subst, injs', ret_t) <- checkCase t injs
        let comp = subst `compose` sub
        let t' = apply comp ret_t
        return (comp, apply comp (T.ECase (e', t) injs', t'))

{- | Wrap an expression in one lambda per identifier. The LAST identifier in
the list becomes the OUTERMOST lambda, so callers pass the arguments reversed.
-}
makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl' (flip (EAbs . coerce))

-- | Unify two types producing a new substitution
unify :: T.Type -> T.Type -> Infer Subst
unify t0 t1 = case (t0, t1) of
    (T.TFun a b, T.TFun c d) -> do
        s1 <- unify a c
        s2 <- unify (apply s1 b) (apply s1 d)
        -- s1 is applied first, then s2; 'compose' applies its right argument first.
        return $ s2 `compose` s1
    -- CAREFUL: these two cases bind a variable to a data type without an
    -- occurs check. This is probably wrong but is kept as is.
    (T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t
    (t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t
    (T.TVar (T.MkTVar a), t) -> occurs a t
    (t, T.TVar (T.MkTVar b)) -> occurs b t
    (T.TAll _ t, b) -> unify t b
    (a, T.TAll _ t) -> unify a t
    (T.TLit a, T.TLit b)
        | a == b -> return M.empty
        | otherwise ->
            throwError . unwords $
                [ "Can not unify"
                , "'" <> printTree (T.TLit a) <> "'"
                , "with"
                , "'" <> printTree (T.TLit b) <> "'"
                ]
    (T.TData name ts, T.TData name' ts')
        | name == name' && length ts == length ts' ->
            -- Thread the substitution through the arguments so that a
            -- variable occurring twice must be given one consistent type.
            foldM unifyArg nullSubst (zip ts ts')
        | otherwise ->
            throwError $
                unwords
                    [ "T.Type constructor:"
                    , printTree name
                    , "(" <> printTree ts <> ")"
                    , "does not match with:"
                    , printTree name'
                    , "(" <> printTree ts' <> ")"
                    ]
    (a, b) ->
        throwError . unwords $
            [ "'" <> printTree a <> "'"
            , "can't be unified with"
            , "'" <> printTree b <> "'"
            ]
  where
    unifyArg :: Subst -> (T.Type, T.Type) -> Infer Subst
    unifyArg acc (l, r) = (`compose` acc) <$> unify (apply acc l) (apply acc r)

{- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal
-}
occurs :: T.Ident -> T.Type -> Infer Subst
occurs i t@(T.TVar _) = return (M.singleton i t)
occurs i t
    | S.member i (free t) =
        throwError $
            unwords
                [ "Occurs check failed, can't unify"
                , printTree (T.TVar $ T.MkTVar i)
                , "with"
                , printTree t
                ]
    | otherwise = return $ M.singleton i t

{- | Generalize a type over all of its type variables that are not free in the
environment. Existing quantifiers are stripped first and their variables are
quantified again together with the rest.
-}
generalize :: Map T.Ident T.Type -> T.Type -> T.Type
generalize env t = foldr (T.TAll . T.MkTVar) stripped freeVars
  where
    stripped :: T.Type
    stripped = removeForalls t

    freeVars :: [T.Ident]
    freeVars = S.toList $ free stripped S.\\ free env

    removeForalls :: T.Type -> T.Type
    removeForalls (T.TAll _ ty) = removeForalls ty
    removeForalls (T.TFun t1 t2) = T.TFun (removeForalls t1) (removeForalls t2)
    removeForalls ty = ty

{- | Instantiate a polymorphic type.
The quantified type variables are substituted with fresh ones.
-}
inst :: T.Type -> Infer T.Type
inst = \case
    T.TAll (T.MkTVar bound) t -> do
        fr <- fresh
        let s = M.singleton bound fr
        apply s <$> inst t
    T.TFun t1 t2 -> T.TFun <$> inst t1 <*> inst t2
    rest -> return rest

{- | Compose two substitution sets. @compose m1 m2@ behaves like applying
@m2@ first and @m1@ afterwards.
-}
compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1

-- | A class of things a substitution can be applied to
class SubstType t where
    -- | Apply a substitution to t
    apply :: Subst -> t -> t

-- | A class representing free variables functions
class FreeVars t where
    -- | Get all free variables from t
    free :: t -> Set T.Ident

instance FreeVars T.Type where
    free :: T.Type -> Set T.Ident
    free (T.TVar (T.MkTVar a)) = S.singleton a
    -- The quantified variable is bound, hence not free in the whole type.
    free (T.TAll (T.MkTVar bound) t) = S.delete bound (free t)
    free (T.TLit _) = mempty
    free (T.TFun a b) = free a `S.union` free b
    -- Not guaranteed to be correct
    free (T.TData _ a) = foldMap free a
-- | Substituting inside a type replaces every mapped type variable.
instance SubstType T.Type where
    apply :: Subst -> T.Type -> T.Type
    apply sub = \case
        lit@(T.TLit _) -> lit
        var@(T.TVar (T.MkTVar a)) -> M.findWithDefault var a sub
        T.TAll tvar@(T.MkTVar i) body
            -- A quantifier whose variable is mapped is dropped.
            | i `M.member` sub -> apply sub body
            | otherwise -> T.TAll tvar (apply sub body)
        T.TFun a b -> T.TFun (apply sub a) (apply sub b)
        T.TData name args -> T.TData name (map (apply sub) args)

-- | The free variables of an environment are those of all its types.
instance FreeVars (Map T.Ident T.Type) where
    free :: Map T.Ident T.Type -> Set T.Ident
    free = foldMap free

-- | A substitution is applied to every type in an environment.
instance SubstType (Map T.Ident T.Type) where
    apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type
    apply s env = apply s <$> env

-- | A substitution is applied to every type annotation of a typed expression.
instance SubstType T.ExpT where
    apply :: Subst -> T.ExpT -> T.ExpT
    apply s (T.EId i, outerT) = (T.EId i, apply s outerT)
    apply s (T.ELit lit, t) = (T.ELit lit, apply s t)
    apply s (T.ELet (T.Bind (ident, t1) args e1) e2, t2) =
        ( T.ELet (T.Bind (ident, apply s t1) args (apply s e1)) (apply s e2)
        , apply s t2
        )
    apply s (T.EApp e1 e2, t) = (T.EApp (apply s e1) (apply s e2), apply s t)
    apply s (T.EAdd e1 e2, t) = (T.EAdd (apply s e1) (apply s e2), apply s t)
    apply s (T.EAbs ident e, t1) = (T.EAbs ident (apply s e), apply s t1)
    apply s (T.ECase e brnch, t) = (T.ECase (apply s e) (apply s brnch), apply s t)

instance SubstType T.Branch where
    apply :: Subst -> T.Branch -> T.Branch
    apply s (T.Branch (pat, patT) body) =
        T.Branch (apply s pat, apply s patT) (apply s body)

instance SubstType T.Pattern where
    apply :: Subst -> T.Pattern -> T.Pattern
    apply s pat = case pat of
        T.PVar (iden, t) -> T.PVar (iden, apply s t)
        T.PLit (lit, t) -> T.PLit (lit, apply s t)
        T.PInj i ps -> T.PInj i (apply s ps)
        T.PCatch -> T.PCatch
        T.PEnum i -> T.PEnum i

instance SubstType a => SubstType [a] where
    apply s xs = [apply s x | x <- xs]

instance SubstType T.Id where
    apply s (ident, ty) = (ident, apply s ty)

-- | Apply substitutions to the environment.
applySt :: Subst -> Infer a -> Infer a
applySt s = local (\st -> st{vars = apply s (vars st)})

-- | Represents the empty substitution set
nullSubst :: Subst
nullSubst = M.empty

-- | Generate a new fresh variable and increment the state counter
fresh :: Infer T.Type
fresh = do
    n <- gets count
    modify (\st -> st{count = n + 1})
    return . T.TVar . T.MkTVar . T.Ident $ show n

-- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})

-- | Run the monadic action with several additional bindings
withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, T.Type)] -> m a -> m a
withBindings xs =
    local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})

-- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe T.Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})

-- | Insert a constructor with its data type
insertConstr :: T.Ident -> T.Type -> Infer ()
insertConstr i t =
    modify (\st -> st{constructors = M.insert i t (constructors st)})

-------- PATTERN MATCHING ---------

{- | Check the branches of a case expression against the type of the
scrutinee.

Every pattern type is unified with the scrutinee type and all branch result
types are unified with each other. Returns the combined substitution, the
typed branches and the common result type. A case expression without any
branches is rejected with an error.
-}
checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type)
checkCase expT injs = do
    (injTs, injs', returns) <- unzip3 <$> mapM inferBranch injs
    (sub1, _) <- foldM unifyStep (nullSubst, expT) injTs
    case returns of
        -- Without a branch there is no result type to report.
        [] -> throwError "Case expression without any branches"
        (r : rs) -> do
            (sub2, returnsType) <- foldM unifyStep (nullSubst, r) rs
            return (sub2 `compose` sub1, injs', returnsType)
  where
    -- Unify the next type with the accumulated one, refining both the
    -- accumulated substitution and the accumulated type.
    unifyStep :: (Subst, T.Type) -> T.Type -> Infer (Subst, T.Type)
    unifyStep (sub, acc) x =
        (\a -> (a `compose` sub, a `apply` acc)) <$> unify x acc

{- | Infer a single branch. Returns the type of the pattern, the typed branch
and the type of the branch expression.
-}
inferBranch :: Branch -> Infer (T.Type, T.Branch, T.Type)
inferBranch (Branch pat expr) = do
    newPat@(pat', branchT) <- inferPattern pat
    -- The expression is inferred with the pattern's variables in scope.
    newExp@(_, exprT) <- withPattern pat' (inferExp expr)
    return (branchT, T.Branch newPat newExp, exprT)
-- | Run the action with every variable bound by the pattern in scope.
withPattern :: T.Pattern -> Infer a -> Infer a
withPattern p ma = case p of
    T.PVar (x, t) -> withBinding x t ma
    T.PInj _ ps -> foldl' (flip withPattern) ma ps
    T.PLit _ -> ma
    T.PCatch -> ma
    T.PEnum _ -> ma

{- | Infer the type of a pattern, returning the typed pattern together with
the type of the values it matches.

A constructor pattern must be applied to exactly as many sub-patterns as the
constructor has arguments; the types of the sub-patterns are unified with the
argument types and the result is applied to the sub-patterns, so that pattern
variables get the type dictated by the constructor.
-}
inferPattern :: Pattern -> Infer (T.Pattern, T.Type)
inferPattern = \case
    PLit lit ->
        let lt = litType lit
         in return (T.PLit (toNew lit, lt), lt)
    PInj constr patterns -> do
        found <- gets (M.lookup (coerce constr) . constructors)
        constrT <-
            maybeToRightM
                ("Constructor: " <> printTree constr <> " does not exist")
                found
        (vs, ret) <-
            maybeToRightM
                "Partial pattern match not allowed"
                (unsnoc $ flattenType constrT)
        -- zipWithM below would silently ignore surplus or missing patterns.
        unless (length vs == length patterns) . throwError $
            "Constructor: "
                <> printTree constr
                <> " expects "
                <> show (length vs)
                <> " argument(s) but the pattern has "
                <> show (length patterns)
        inferred <- mapM inferPattern patterns
        sub <- foldl' compose nullSubst <$> zipWithM unify vs (map snd inferred)
        return
            ( T.PInj (coerce constr) (apply sub (map fst inferred))
            , apply sub ret
            )
    PCatch -> (T.PCatch,) <$> fresh
    PEnum p -> do
        found <- gets (M.lookup (coerce p) . constructors)
        t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") found
        return (T.PEnum $ coerce p, t)
    PVar x -> do
        fr <- fresh
        let pvar = T.PVar (coerce x, fr)
        return (pvar, fr)

{- | Split a curried function type into its argument types followed by its
result type. Arguments that are themselves functions are kept intact.
-}
flattenType :: T.Type -> [T.Type]
flattenType (T.TFun a b) = a : flattenType b
flattenType a = [a]

-- | The type of a literal.
litType :: Lit -> T.Type
litType (LInt _) = int
litType (LChar _) = char

-- | The built-in types of integer and character literals.
int, char :: T.Type
int = T.TLit "Int"
char = T.TLit "Char"

{- | Split off the first @n@ argument types of a function type, returning
them together with the remaining type. Quantifiers that are passed on the way
are kept on the remaining type. Calls 'error' if the type has fewer than @n@
arguments.
-}
partitionType ::
    Int -> -- Number of parameters to apply
    Type ->
    ([Type], Type)
partitionType = go []
  where
    -- The accumulator is built in reverse to avoid repeated appends.
    go acc 0 t = (reverse acc, t)
    go acc i t = case t of
        TAll tvar t' -> second (TAll tvar) $ go acc i t'
        TFun t1 t2 -> go (t1 : acc) (i - 1) t2
        _ -> error "Number of parameters and type doesn't match"

-- | Extend the error of a failing action with the offending expression.
exprErr :: Infer a -> Exp -> Infer a
exprErr ma expr =
    catchError ma (\x -> throwError $ x <> " on expression: " <> printTree expr)

-- | Extend the error of a failing action with the offending binding.
bindErr :: Infer a -> Bind -> Infer a
bindErr ma bind =
    catchError ma (\x -> throwError $ x <> " on expression: " <> printTree bind)