From 62724964d7144256c1c456a71f50a7af7539b3bf Mon Sep 17 00:00:00 2001 From: sebastian Date: Wed, 8 Mar 2023 15:22:42 +0100 Subject: [PATCH] fixed Maybe ('a -> 'a) bug. Pattern matching still wonky, will have to redo --- "\\" | 511 +++++++++++++++++++++++++++++++++ src/TypeChecker/TypeChecker.hs | 117 ++++---- test_program | 48 +--- 3 files changed, 579 insertions(+), 97 deletions(-) create mode 100644 "\\" diff --git "a/\\" "b/\\" new file mode 100644 index 0000000..90c24ff --- /dev/null +++ "b/\\" @@ -0,0 +1,511 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + +-- | A module for type checking and inference using algorithm W, Hindley-Milner +module TypeChecker.TypeChecker where + +import Control.Monad.Except +import Control.Monad.Reader +import Control.Monad.State +import Data.Foldable (traverse_) +import Data.Functor.Identity (runIdentity) +import Debug.Trace (trace) +import Data.List (foldl') +import Data.Map (Map) +import Data.Map qualified as M +import Data.Set (Set) +import Data.Set qualified as S +import Data.Maybe (fromMaybe) +import Grammar.Abs +import Grammar.Print (printTree) +import TypeChecker.TypeCheckerIr ( + Ctx (..), + Env (..), + Error, + Infer, + Poly (..), + Subst, + ) +import TypeChecker.TypeCheckerIr qualified as T + +initCtx = Ctx mempty + +initEnv = Env 0 mempty mempty + +runPretty :: Exp -> Either Error String +runPretty = fmap (printTree . fst) . run . inferExp + +run :: Infer a -> Either Error a +run = runC initEnv initCtx + +runC :: Env -> Ctx -> Infer a -> Either Error a +runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e + +typecheck :: Program -> Either Error T.Program +typecheck = run . checkPrg + +{- | Start by freshening the type variable of data types to avoid clash with +other user defined polymorphic types +This might be wrong for type constructors that work over several variables +-} +-- freshenData :: Data -> Infer Data +-- freshenData (Data (Constr name ts) constrs) = do +-- new_ts <- traverse freshenType ts +-- new_constrs <- traverse freshenConstr constrs +-- return $ Data (Constr name new_ts) new_constrs +--TODO: Fix incorrect behavior here + +{- | Freshen all polymorphic variables, regardless of name +| freshenType "d" (a -> b -> c) becomes (d -> d -> d) +-} +-- freshenType :: Type -> Infer Type +-- freshenType t = do +-- let freeVars = (S.toList $ free t) +-- frs <- sequenceA $ map (const fresh) freeVars +-- let remaps = M.fromList $ zip freeVars frs +-- return $ go remaps t +-- where +-- go :: Map Ident Type -> Type -> Type +-- go m t = case t of +-- TPol a -> fromMaybe (error "bug in \'free\'") (M.lookup a m ) +-- TMono a -> TMono a +-- TArr t1 t2 -> TArr (go m t1) (go m t2) +-- TConstr (Constr ident ts) -> TConstr (Constr ident (map (go m) ts)) + +-- freshenConstr :: Constructor -> Infer Constructor +-- freshenConstr (Constructor name t) = do +-- t' <- freshenType t +-- return $ Constructor name t' + +checkData :: Data -> Infer () +checkData d = do + case d of + (Data typ@(Constr name ts) constrs) -> do + unless + (all isPoly ts) + (throwError $ unwords ["Data type incorrectly declared"]) + traverse_ + ( \(Constructor name' t') -> + if TConstr typ == retType t' + then insertConstr name' t' + else + throwError $ + unwords + [ "return type of constructor:" + , printTree name + , "with type:" + , printTree (retType t') + , "does not match data: " + , printTree typ + ] + ) + constrs + +retType :: Type -> Type +retType (TArr _ t2) = retType t2 +retType a = a + +checkPrg :: Program -> Infer T.Program +checkPrg (Program bs) = do + preRun bs + bs' <- checkDef bs + return $ T.Program bs' + where + preRun :: [Def] -> Infer () + preRun [] = return () + preRun (x : xs) = case x of + DBind (Bind n t _ _ _) -> insertSig n t >> preRun xs + DData d@(Data _ _) -> checkData d >> preRun xs + + checkDef :: [Def] -> Infer [T.Def] + checkDef [] = return [] + checkDef (x : xs) = case x of + (DBind b) -> do + b' <- checkBind b + fmap (T.DBind b' :) (checkDef xs) + (DData d) -> fmap (T.DData d :) (checkDef xs) + +checkBind :: Bind -> Infer T.Bind +checkBind (Bind n t _ args e) = do + (t', e') <- inferExp $ makeLambda e (reverse args) + s <- unify t t' + let t'' = apply s t + unless + (t `typeEq` t'') + ( throwError $ + unwords + [ "Top level signature" + , printTree t + , "does not match body with inferred type:" + , printTree t'' + ] + ) + return $ T.Bind (n, t) e' + where + makeLambda :: Exp -> [Ident] -> Exp + makeLambda = foldl (flip EAbs) + +{- | Check if two types are considered equal + For the purpose of the algorithm two polymorphic types are always considered + equal +-} +typeEq :: Type -> Type -> Bool +typeEq (TArr l r) (TArr l' r') = typeEq l l' && typeEq r r' +typeEq (TMono a) (TMono b) = a == b +typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) = + length a == length b + && name == name' + && and (zipWith typeEq a b) +typeEq (TPol _) (TPol _) = True +typeEq _ _ = False + +isMoreSpecificOrEq :: Type -> Type -> Bool +isMoreSpecificOrEq _ (TPol _) = True +isMoreSpecificOrEq (TArr a b) (TArr c d) = + isMoreSpecificOrEq a c && isMoreSpecificOrEq b d +isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2)) = + n1 == n2 + && length ts1 == length ts2 + && and (zipWith isMoreSpecificOrEq ts1 ts2) +isMoreSpecificOrEq a b = a == b + +isPoly :: Type -> Bool +isPoly (TPol _) = True +isPoly _ = False + +inferExp :: Exp -> Infer (Type, T.Exp) +inferExp e = do + (s, t, e') <- algoW e + let subbed = apply s t + return (subbed, replace subbed e') + +replace :: Type -> T.Exp -> T.Exp +replace t = \case + T.ELit _ e -> T.ELit t e + T.EId (n, _) -> T.EId (n, t) + T.EAbs _ name e -> T.EAbs t name e + T.EApp _ e1 e2 -> T.EApp t e1 e2 + T.EAdd _ e1 e2 -> T.EAdd t e1 e2 + T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2 + T.ECase _ expr injs -> T.ECase t expr injs + +algoW :: Exp -> Infer (Subst, Type, T.Exp) +algoW = \case + -- \| TODO: More testing need to be done. Unsure of the correctness of this + EAnn e t -> do + (s1, t', e') <- algoW e + unless + (t `isMoreSpecificOrEq` t') + ( throwError $ + unwords + [ "Annotated type:" + , printTree t + , "does not match inferred type:" + , printTree t' + ] + ) + applySt s1 $ do + s2 <- unify t t' + return (s2 `compose` s1, t, e') + + -- \| ------------------ + -- \| Γ ⊢ i : Int, ∅ + + ELit (LInt n) -> + return (nullSubst, TMono "Int", T.ELit (TMono "Int") (LInt n)) + ELit a -> error $ "NOT IMPLEMENTED YET: ELit " ++ show a + -- \| x : σ ∈ Γ   τ = inst(σ) + -- \| ---------------------- + -- \| Γ ⊢ x : τ, ∅ + + EId i -> do + var <- asks vars + case M.lookup i var of + Just t -> inst t >>= \x -> return (nullSubst, x, T.EId (i, x)) + Nothing -> do + sig <- gets sigs + case M.lookup i sig of + Just t -> return (nullSubst, t, T.EId (i, t)) + Nothing -> do + constr <- gets constructors + case M.lookup i constr of + Just t -> return (nullSubst, t, T.EId (i, t)) + Nothing -> + throwError $ + "Unbound variable: " ++ show i + + -- \| τ = newvar Γ, x : τ ⊢ e : τ', S + -- \| --------------------------------- + -- \| Γ ⊢ w λx. e : Sτ → τ', S + + EAbs name e -> do + fr <- fresh + withBinding name (Forall [] fr) $ do + (s1, t', e') <- algoW e + let varType = apply s1 fr + let newArr = TArr varType t' + return (s1, newArr, T.EAbs newArr (name, varType) e') + + -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ + -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) + -- \| ------------------------------------------ + -- \| Γ ⊢ e₀ + e₁ : Int, S₃S₂S₁S₀ + -- This might be wrong + + EAdd e0 e1 -> do + (s1, t0, e0') <- algoW e0 + applySt s1 $ do + (s2, t1, e1') <- algoW e1 + -- applySt s2 $ do + s3 <- unify (apply s2 t0) (TMono "Int") + s4 <- unify (apply s3 t1) (TMono "Int") + return + ( s4 `compose` s3 `compose` s2 `compose` s1 + , TMono "Int" + , T.EAdd (TMono "Int") e0' e1' + ) + + -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 + -- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ') + -- \| -------------------------------------- + -- \| Γ ⊢ e₀ e₁ : S₂τ', S₂S₁S₀ + + EApp e0 e1 -> do + fr <- fresh + (s0, t0, e0') <- algoW e0 + applySt s0 $ do + (s1, t1, e1') <- algoW e1 + -- applySt s1 $ do + s2 <- unify (apply s1 t0) (TArr t1 fr) + let t = apply s2 fr + return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1') + + -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ + -- \| ---------------------------------------------- + -- \| Γ ⊢ let x = e₀ in e₁ : τ', S₁S₀ + + -- The bar over S₀ and Γ means "generalize" + + ELet name e0 e1 -> do + (s1, t1, e0') <- algoW e0 + env <- asks vars + let t' = generalize (apply s1 env) t1 + withBinding name t' $ do + (s2, t2, e1') <- algoW e1 + return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1') + ECase caseExpr injs -> do + (_, t0, e0') <- algoW caseExpr + (injs', ts) <- mapAndUnzipM (checkInj t0) injs + case ts of + [] -> throwError "Case expression missing any matches" + ts -> do + unified <- zipWithM unify ts (tail ts) + let unified' = foldl' compose mempty unified + let typ = apply unified' (head ts) + return (unified', typ, T.ECase typ e0' injs') + +-- | Unify two types producing a new substitution +unify :: Type -> Type -> Infer Subst +unify t0 t1 = do + case (t0, t1) of + (TArr a b, TArr c d) -> do + s1 <- unify a c + s2 <- unify (apply s1 b) (apply s1 d) + return $ s1 `compose` s2 + (TPol a, b) -> occurs a b + (a, TPol b) -> occurs b a + (TMono a, TMono b) -> + if a == b then return M.empty else throwError "Types do not unify" + -- \| TODO: Figure out a cleaner way to express the same thing + (TConstr (Constr name t), TConstr (Constr name' t')) -> + if name == name' && length t == length t' + then do + xs <- zipWithM unify t t' + return $ foldr compose nullSubst xs + else + throwError $ + unwords + [ "Type constructor:" + , printTree name + , "(" ++ printTree t ++ ")" + , "does not match with:" + , printTree name' + , "(" ++ printTree t' ++ ")" + ] + (a, b) -> + throwError . unwords $ + [ "Type:" + , printTree a + , "can't be unified with:" + , printTree b + ] + +{- | Check if a type is contained in another type. +I.E. { a = a -> b } is an unsolvable constraint since there is no substitution +such that these are equal +-} +occurs :: Ident -> Type -> Infer Subst +occurs _ (TPol _) = return nullSubst +occurs i t = + if S.member i (free t) + then + throwError $ + unwords + [ "Occurs check failed, can't unify" + , printTree (TPol i) + , "with" + , printTree t + ] + else return $ M.singleton i t + +-- | Generalize a type over all free variables in the substitution set +generalize :: Map Ident Poly -> Type -> Poly +generalize env t = Forall (S.toList $ free t S.\\ free env) t + +{- | Instantiate a polymorphic type. The free type variables are substituted +with fresh ones. +-} +inst :: Poly -> Infer Type +inst (Forall xs t) = do + xs' <- mapM (const fresh) xs + let s = M.fromList $ zip xs xs' + return $ apply s t + +-- | Compose two substitution sets +compose :: Subst -> Subst -> Subst +compose m1 m2 = M.map (apply m1) m2 `M.union` m1 + +-- | A class representing free variables functions +class FreeVars t where + -- | Get all free variables from t + free :: t -> Set Ident + + -- | Apply a substitution to t + apply :: Subst -> t -> t + +instance FreeVars Type where + free :: Type -> Set Ident + free (TPol a) = S.singleton a + free (TMono _) = mempty + free (TArr a b) = free a `S.union` free b + -- \| Not guaranteed to be correct + free (TConstr (Constr _ a)) = + foldl' (\acc x -> free x `S.union` acc) S.empty a + + apply :: Subst -> Type -> Type + apply sub t = do + case t of + TMono a -> TMono a + TPol a -> case M.lookup a sub of + Nothing -> TPol a + Just t -> t + TArr a b -> TArr (apply sub a) (apply sub b) + TConstr (Constr name a) -> TConstr (Constr name (map (apply sub) a)) + +instance FreeVars Poly where + free :: Poly -> Set Ident + free (Forall xs t) = free t S.\\ S.fromList xs + apply :: Subst -> Poly -> Poly + apply s (Forall xs t) = Forall xs (apply (foldr M.delete s xs) t) + +instance FreeVars (Map Ident Poly) where + free :: Map Ident Poly -> Set Ident + free m = foldl' S.union S.empty (map free $ M.elems m) + apply :: Subst -> Map Ident Poly -> Map Ident Poly + apply s = M.map (apply s) + +-- | Apply substitutions to the environment. +applySt :: Subst -> Infer a -> Infer a +applySt s = local (\st -> st{vars = apply s (vars st)}) + +-- | Represents the empty substition set +nullSubst :: Subst +nullSubst = M.empty + +-- | Generate a new fresh variable and increment the state counter +fresh :: Infer Type +fresh = do + n <- gets count + modify (\st -> st{count = n + 1}) + return . TPol . Ident $ show n + +-- | Run the monadic action with an additional binding +withBinding :: (Monad m, MonadReader Ctx m) => Ident -> Poly -> m a -> m a +withBinding i p = local (\st -> st{vars = M.insert i p (vars st)}) + +-- | Insert a function signature into the environment +insertSig :: Ident -> Type -> Infer () +insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) + +-- | Insert a constructor with its data type +insertConstr :: Ident -> Type -> Infer () +insertConstr i t = + modify (\st -> st{constructors = M.insert i t (constructors st)}) + +-------- PATTERN MATCHING --------- + +-- "case expr of", the type of 'expr' is caseType +checkInj :: Type -> Inj -> Infer (T.Inj, Type) +checkInj caseType (Inj it expr) = do + (args, t') <- initType caseType it + subst <- unify caseType t' + trace ("SUBST: " ++ show subst) return () + applySt subst $ do + (_, t, e') <- local (\st -> st { vars = args `M.union` vars st }) (algoW expr) + return (T.Inj (it, t') e', t) + +initType :: Type -> Init -> Infer (Map Ident Poly, Type) +initType expected = \case + InitLit lit -> do + trace (show "EXPECTED: " ++ show expected ++ "\nreturnType: " ++ show (litType lit)) return () + if litType lit `isMoreSpecificOrEq` expected + then return (mempty, litType lit) + else + throwError $ + unwords + [ "Inferred type" + , printTree $ litType lit + , "does not match expected type:" + , printTree expected + ] + InitConstr c args -> do + st <- gets constructors + case M.lookup c st of + Nothing -> + throwError $ + unwords + [ "Constructor:" + , printTree c + , "does not exist" + ] + Just t -> do + let flat = flattenType t + let returnType = last flat + case ( length (init flat) == length args + , returnType `isMoreSpecificOrEq` expected + ) of + (True, True) -> + return + ( M.fromList $ zip args (map (Forall []) flat) + , expected + ) + (False, _) -> + throwError $ + "Can't partially match on the constructor: " + ++ printTree c + (_, False) -> + throwError $ + unwords + [ "Inferred type" + , printTree returnType + , "does not match expected type:" + , printTree expected + ] + InitCatch -> return (mempty, expected) + +flattenType :: Type -> [Type] +flattenType (TArr a b) = flattenType a ++ flattenType b +flattenType a = [a] + +litType :: Literal -> Type +litType (LInt _) = TMono "Int" diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index af4734d..0c3df12 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -9,12 +9,13 @@ import Control.Monad.Reader import Control.Monad.State import Data.Foldable (traverse_) import Data.Functor.Identity (runIdentity) +import Debug.Trace (trace) import Data.List (foldl') import Data.Map (Map) import Data.Map qualified as M import Data.Set (Set) import Data.Set qualified as S -import Debug.Trace (trace) +import Data.Maybe (fromMaybe) import Grammar.Abs import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr ( @@ -45,36 +46,23 @@ typecheck = run . checkPrg {- | Start by freshening the type variable of data types to avoid clash with other user defined polymorphic types -This might be wrong for type constructors that work over several variables -} freshenData :: Data -> Infer Data freshenData (Data (Constr name ts) constrs) = do - fr <- fresh - let fr' = case fr of - TPol a -> a - -- Meh, this part assumes fresh generates a polymorphic type - _ -> - error - "Bug: implementation of \ - \ fresh and freshenData are not compatible" - let new_ts = map (freshenType fr') ts - let new_constrs = map (freshenConstr fr') constrs - return $ Data (Constr name new_ts) new_constrs - + let xs = (S.toList . free) =<< ts + frs <- traverse (const fresh) xs + let m = M.fromList $ zip xs frs + return $ Data (Constr name (map (freshenType m) ts)) (map (\(Constructor ident t) -> Constructor ident (freshenType m t)) constrs) + {- | Freshen all polymorphic variables, regardless of name | freshenType "d" (a -> b -> c) becomes (d -> d -> d) -} -freshenType :: Ident -> Type -> Type -freshenType iden = \case - (TPol _) -> TPol iden - (TArr a b) -> TArr (freshenType iden a) (freshenType iden b) - (TConstr (Constr a ts)) -> - TConstr (Constr a (map (freshenType iden) ts)) - rest -> rest - -freshenConstr :: Ident -> Constructor -> Constructor -freshenConstr iden (Constructor name t) = - Constructor name (freshenType iden t) +freshenType :: Map Ident Type -> Type -> Type +freshenType m t = case t of + TPol poly -> fromMaybe (error "bug in \'free\'") (M.lookup poly m) + TMono mono -> TMono mono + TArr t1 t2 -> TArr (freshenType m t1) (freshenType m t2) + TConstr (Constr ident ts) -> TConstr (Constr ident (map (freshenType m) ts)) checkData :: Data -> Infer () checkData d = do @@ -108,7 +96,8 @@ retType a = a checkPrg :: Program -> Infer T.Program checkPrg (Program bs) = do preRun bs - T.Program <$> checkDef bs + bs' <- checkDef bs + return $ T.Program bs' where preRun :: [Def] -> Infer () preRun [] = return () @@ -122,7 +111,9 @@ checkPrg (Program bs) = do (DBind b) -> do b' <- checkBind b fmap (T.DBind b' :) (checkDef xs) - (DData d) -> fmap (T.DData d :) (checkDef xs) + (DData d) -> do + d' <- freshenData d + fmap (T.DData d' :) (checkDef xs) checkBind :: Bind -> Infer T.Bind checkBind (Bind n t _ args e) = do @@ -205,7 +196,8 @@ algoW = \case ) applySt s1 $ do s2 <- unify t t' - return (s2 `compose` s1, t, e') + let composition = s2 `compose` s1 + return (composition, t, apply composition e') -- \| ------------------ -- \| Γ ⊢ i : Int, ∅ @@ -243,7 +235,7 @@ algoW = \case (s1, t', e') <- algoW e let varType = apply s1 fr let newArr = TArr varType t' - return (s1, newArr, T.EAbs newArr (name, varType) e') + return (s1, newArr, apply s1 $ T.EAbs newArr (name, varType) e') -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) @@ -258,10 +250,11 @@ algoW = \case -- applySt s2 $ do s3 <- unify (apply s2 t0) (TMono "Int") s4 <- unify (apply s3 t1) (TMono "Int") + let composition = s4 `compose` s3 `compose` s2 `compose` s1 return - ( s4 `compose` s3 `compose` s2 `compose` s1 + ( composition , TMono "Int" - , T.EAdd (TMono "Int") e0' e1' + , apply composition $ T.EAdd (TMono "Int") e0' e1' ) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 @@ -277,7 +270,8 @@ algoW = \case -- applySt s1 $ do s2 <- unify (apply s1 t0) (TArr t1 fr) let t = apply s2 fr - return (s2 `compose` s1 `compose` s0, t, T.EApp t e0' e1') + let composition = s2 `compose` s1 `compose` s0 + return (composition, t, apply composition $ T.EApp t e0' e1') -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ -- \| ---------------------------------------------- @@ -291,7 +285,9 @@ algoW = \case let t' = generalize (apply s1 env) t1 withBinding name t' $ do (s2, t2, e1') <- algoW e1 - return (s2 `compose` s1, t2, T.ELet (T.Bind (name, t2) e0') e1') + let composition = s2 `compose` s1 + return (composition, t2, apply composition $ T.ELet (T.Bind (name, t2) e0') e1') + ECase caseExpr injs -> do (_, t0, e0') <- algoW caseExpr (injs', ts) <- mapAndUnzipM (checkInj t0) injs @@ -299,15 +295,13 @@ algoW = \case [] -> throwError "Case expression missing any matches" ts -> do unified <- zipWithM unify ts (tail ts) - let unified' = foldl' compose mempty unified - let typ = apply unified' (head ts) - return (unified', typ, T.ECase typ e0' injs') + let composition = foldl' compose mempty unified + let typ = apply composition (head ts) + return (composition, typ, apply composition $ T.ECase typ e0' injs') -- | Unify two types producing a new substitution unify :: Type -> Type -> Infer Subst unify t0 t1 = do - trace ("t0: " ++ show t0) return () - trace ("t1: " ++ show t1) return () case (t0, t1) of (TArr a b, TArr c d) -> do s1 <- unify a c @@ -343,7 +337,7 @@ unify t0 t1 = do {- | Check if a type is contained in another type. I.E. { a = a -> b } is an unsolvable constraint since there is no substitution -such that these are equal +where these are equal -} occurs :: Ident -> Type -> Infer Subst occurs _ (TPol _) = return nullSubst @@ -415,6 +409,30 @@ instance FreeVars (Map Ident Poly) where apply :: Subst -> Map Ident Poly -> Map Ident Poly apply s = M.map (apply s) +instance FreeVars T.Exp where + free :: T.Exp -> Set Ident + free = error "free not implemented for T.Exp" + apply :: Subst -> T.Exp -> T.Exp + apply s = \case + T.EId (ident, t) -> T.EId (ident, apply s t) + T.ELit t lit -> T.ELit (apply s t) lit + T.ELet (T.Bind (ident, t) e1) e2 -> T.ELet (T.Bind (ident, apply s t) (apply s e1)) (apply s e2) + T.EApp t e1 e2 -> T.EApp (apply s t) (apply s e1) (apply s e2) + T.EAdd t e1 e2 -> T.EAdd (apply s t) (apply s e1) (apply s e2) + T.EAbs t1 (ident, t2) e -> T.EAbs (apply s t1) (ident, apply s t2) (apply s e) + T.ECase t e injs -> T.ECase (apply s t) (apply s e) (apply s injs) + +instance FreeVars T.Inj where + free :: T.Inj -> Set Ident + free = undefined + apply :: Subst -> T.Inj -> T.Inj + apply s (T.Inj (i, t) e) = T.Inj (i, apply s t) (apply s e) + +instance FreeVars [T.Inj] where + free :: [T.Inj] -> Set Ident + free = foldl' (\acc x -> free x `S.union` acc) mempty + apply s = map (apply s) + -- | Apply substitutions to the environment. applySt :: Subst -> Infer a -> Infer a applySt s = local (\st -> st{vars = apply s (vars st)}) @@ -449,23 +467,16 @@ insertConstr i t = checkInj :: Type -> Inj -> Infer (T.Inj, Type) checkInj caseType (Inj it expr) = do (args, t') <- initType caseType it - (_, t, e') <- local (\st -> st{vars = args `M.union` vars st}) (algoW expr) - return (T.Inj (it, t') e', t) + subst <- unify caseType t' + applySt subst $ do + (_, t, e') <- local (\st -> st { vars = args `M.union` vars st }) (algoW expr) + return (T.Inj (it, t') e', t) initType :: Type -> Init -> Infer (Map Ident Poly, Type) initType expected = \case - InitLit lit -> - let returnType = litType lit - in if expected == returnType - then return (mempty, expected) - else - throwError $ - unwords - [ "Inferred type" - , printTree returnType - , "does not match expected type:" - , printTree expected - ] + + InitLit lit -> error "Pattern match on literals not implemented yet" + InitConstr c args -> do st <- gets constructors case M.lookup c st of diff --git a/test_program b/test_program index efa8eea..0d74a4e 100644 --- a/test_program +++ b/test_program @@ -1,50 +1,10 @@ --- data Bool () where { --- True : Bool () --- False : Bool () --- }; --- --- data List ('a) where { --- Nil : List ('a) --- Cons : ('a) -> List ('a) -> List ('a) --- }; - data Maybe ('a) where { Nothing : Maybe ('a) Just : 'a -> Maybe ('a) }; --- id : 'a -> 'a ; --- id x = x ; +id : 'a -> 'a ; +id x = x ; --- main : Maybe ('a -> 'a) ; --- main = Just id; - --- data Either ('a 'b) where { --- Left : 'a -> Either ('a 'b) --- Right : 'b -> Either ('a 'b) --- }; - --- safeHead : List ('a) -> Maybe ('a) ; --- safeHead xs = --- case xs of { --- Nil => Nothing ; --- Cons x xs => Just x --- }; - --- main : Maybe (_Int) ; --- main = safeHead (Cons 0 (Cons 1 Nil)) ; --- --- maybeToEither : Either ('a 'b) -> Maybe ('a) ; --- maybeToEither e = --- case e of { --- Left y => Nothing ; --- Right x => Just x --- }; - --- Bug. f not included in the case-expression context -fmap : ('a -> 'b) -> Maybe ('a) -> Maybe ('b) ; -fmap f x = - case x of { - Just x => Just (f x) ; - Nothing => Nothing - } +main : Maybe ('a -> 'a) ; +main = Just id ;