diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 01a7e16..518b3e8 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -39,9 +39,8 @@ typecheck = onLeft msg . run . checkPrg checkPrg :: Program -> Infer (T.Program' Type) checkPrg (Program bs) = do preRun bs - (subs, bs) <- checkDef bs - ctrace "SUBS" $ unionSubsts subs - return $ T.Program bs + (sub, bs) <- checkDef bs + return $ T.Program $ apply sub bs preRun :: [Def] -> Infer () preRun [] = return () @@ -74,13 +73,14 @@ preRun (x : xs) = case x of duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m () duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg) -checkDef :: [Def] -> Infer ([Subst], [T.Def' Type]) -checkDef [] = return ([], []) +checkDef :: [Def] -> Infer (Subst, [T.Def' Type]) +checkDef [] = return (nullSubst, []) checkDef (x : xs) = case x of (DBind b) -> do (sub0, b') <- checkBind b (sub1, xs') <- checkDef xs - return (sub1 ++ sub0, T.DBind b' : xs') + comp <- sub0 `composey` sub1 + return (comp, T.DBind b' : xs') (DData d) -> do (sub, xs') <- checkDef xs return (sub, T.DData (coerceData d) : xs') @@ -89,17 +89,16 @@ checkDef (x : xs) = case x of coerceData (Data t injs) = T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs -checkBind :: Bind -> Infer ([Subst], T.Bind' Type) +checkBind :: Bind -> Infer (Subst, T.Bind' Type) checkBind bind@(Bind name args e) = do let lambda = makeLambda e (reverse (coerce args)) (sub0, (e, lambda_t)) <- inferExp lambda s <- gets sigs case M.lookup (coerce name) s of - Just t' -> do - sub1 <- bindErr (unify t' lambda_t) bind - ctrace "SUB0" sub0 - ctrace "SUB1" sub1 - return ([sub1, sub0], T.Bind (coerce name, t') [] (e, lambda_t)) + Just t -> do + sub1 <- bindErr (unify t lambda_t) bind + comp <- sub1 `composey` sub0 + return (comp, T.Bind (coerce name, apply comp t) [] (e, lambda_t)) _ -> error "First pass through failed to add function to env" checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m () @@ -168,7 +167,7 @@ returnType a = a inferExp :: Exp -> Infer (Subst, T.ExpT' Type) inferExp e = do (s, (e', t)) <- algoW e - subbed <- apply s t + let subbed = apply s t return (s, (e', subbed)) class CollectTVars a where @@ -195,9 +194,17 @@ algoW = \case (sub0, (e', t')) <- exprErr (algoW e) err sub1 <- unify t t' sub2 <- unify t' t + unless + (apply sub1 t == t' && apply sub2 t' == t) + ( uncatchableErr $ Aux.do + "Annotated type" + quote $ printTree t + "does not match inferred type" + quote $ printTree t' + ) let comp = sub2 `compose` sub1 `compose` sub0 - et <- apply comp (e', t) - return (comp, et) + -- return (comp, apply comp (e', t)) + return (comp, (e', t)) -- \| ------------------ -- \| Γ ⊢ i : Int, ∅ @@ -238,10 +245,10 @@ algoW = \case fr <- fresh withBinding (coerce name) fr $ do (s1, (e', t')) <- exprErr (algoW e) err - varType <- apply s1 fr + let varType = apply s1 fr let newArr = TFun varType t' - eabs <- apply s1 (T.EAbs (coerce name) (e', t'), newArr) - return (s1, eabs) + -- return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) + return (s1, (T.EAbs (coerce name) (e', t'), newArr)) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) @@ -255,10 +262,8 @@ algoW = \case s3 <- exprErr (unify (apply s2 t0) int) err s4 <- exprErr (unify (apply s3 t1) int) err let comp = s4 `compose` s3 `compose` s2 `compose` s1 - return - ( comp - , apply comp (T.EAdd (e0', t0) (e1', t1), int) - ) + -- return (comp, apply comp (T.EAdd (e0', t0) (e1', t1), int)) + return (comp, (T.EAdd (e0', t0) (e1', t1), int)) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 -- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ') @@ -273,8 +278,10 @@ algoW = \case (s1, (e1', t1)) <- algoW e1 s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err let t = apply s2 fr - let comp = s2 `compose` s1 `compose` s0 - return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) + comp <- foldM composey nullSubst [s2, s1, s0] + -- let comp = s2 `compose` s1 `compose` s0 + -- return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) + return (comp, (T.EApp (e0', t0) (e1', t1), t)) -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ -- \| ---------------------------------------------- @@ -290,12 +297,14 @@ algoW = \case withBinding (coerce name) t' $ do (s2, (e1', t2)) <- algoW e1 let comp = s2 `compose` s1 - return (comp, apply comp (T.ELet bind' (e1', t2), t2)) + -- return (comp, apply comp (T.ELet bind' (e1', t2), t2)) + return (comp, (T.ELet bind' (e1', t2), t2)) ECase caseExpr injs -> do (sub, (e', t)) <- algoW caseExpr (subst, injs, ret_t) <- checkCase t injs let comp = subst `compose` sub - return (comp, apply comp (T.ECase (e', t) injs, ret_t)) + -- return (comp, apply comp (T.ECase (e', t) injs, ret_t)) + return (comp, (T.ECase (e', t) injs, ret_t)) EAppInf{} -> error "desugar phase failed" checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type) @@ -528,7 +537,7 @@ skolemize t = t class SubstType t where -- | Apply a substitution to t -- apply :: MonadError e m => Subst -> t -> m t - apply :: Subst -> t -> Infer t + apply :: Subst -> t -> t class FreeVars t where -- | Get all free variables from t @@ -550,43 +559,28 @@ instance FreeVars a => FreeVars [a] where instance SubstType Type where apply sub@(Subst s) t = do case t of - TLit a -> return $ TLit a + TLit a -> TLit a TVar (MkTVar a) -> case M.lookup (coerce a) s of - Nothing -> return $ TVar (MkTVar $ coerce a) - Just t -> return $ t + Nothing -> TVar (MkTVar $ coerce a) + Just t -> t TAll (MkTVar i) t -> case M.lookup (coerce i) s of - Nothing -> TAll (MkTVar i) <$> apply sub t + Nothing -> TAll (MkTVar i) (apply sub t) Just _ -> apply sub t - TFun a b -> TFun <$> apply sub a <*> apply sub b - TData name a -> TData name <$> apply sub a + TFun a b -> TFun (apply sub a) (apply sub b) + TData name a -> TData name (apply sub a) TEVar (MkTEVar a) -> case M.lookup (coerce a) s of - Nothing -> return $ TEVar (MkTEVar a) - Just t -> return $ t + Nothing -> TEVar (MkTEVar a) + Just t -> t instance FreeVars (Map T.Ident Type) where free :: Map T.Ident Type -> Set T.Ident free = free . M.elems instance SubstType (Map T.Ident Type) where - apply s = undefined -- M.map (apply s) + apply s = M.map (apply s) instance SubstType Subst where - apply s@(Subst m1) (Subst m2) = do - let both = M.keys $ M.intersection m1 m2 - case both of - [] -> Subst <$> apply s m2 - xs -> do - sub0 <- apply s m2 - sub1 <- loop xs m1 m2 - apply sub1 (Subst sub0) - where - loop [] _ _ = return nullSubst - loop (x : xs) m1 m2 = do - let k1 = m1 M.! x - let k2 = m2 M.! x - sub <- unify k1 k2 - subs <- loop xs m1 m2 - return $ sub `compose` subs + apply s (Subst m2) = Subst $ apply s m2 -- Subst $ M.map (apply s) m2 @@ -640,6 +634,30 @@ nullSubst = Subst mempty compose :: Subst -> Subst -> Subst compose m1 m2 = Subst $ M.map (apply $ coerce m1) (coerce m2) `M.union` coerce m1 +-- Order matters. +{- +sub0 = Subst $ (M.singleton "a" (arr d e)) `M.union` (M.singleton "b" (arr d f)) `M.union` (M.singleton "c" (arr f e)) +sub1 = Subst $ (M.singleton "a" (arr g bool)) `M.union` (M.singleton "b" (arr g bool)) `M.union` (M.singleton "c" (arr bool bool)) `M.union` (M.singleton "h" bool) `M.union` (M.singleton "i" bool) +sub0 `composey` sub1 != sub1 `composey` sub0 + -} +composey :: Subst -> Subst -> Infer Subst +composey s0@(Subst m1) s1@(Subst m2) = do + let both = M.keys $ M.intersection m1 m2 + case both of + [] -> return $ s0 `compose` s1 + xs -> do + let m2' = apply s0 m2 + sub <- loop xs m1 m2' + return $ sub `compose` Subst m2 + where + loop [] _ _ = return nullSubst + loop (x : xs) m1 m2 = do + let k1 = m1 M.! x + let k2 = m2 M.! x + sub <- unify k1 k2 + subs <- loop xs m1 m2 + return $ sub `compose` subs + -- | Compose a list of substitution sets into one composeAll :: [Subst] -> Subst composeAll = foldl' compose nullSubst @@ -800,7 +818,7 @@ data Error = Error {msg :: String, catchable :: Bool} newtype Subst = Subst (Map T.Ident Type) instance Show Subst where - show (Subst s) = "[" ++ let xs = (map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s) in intercalate " | " xs ++ "]" + show (Subst s) = "[ " ++ let xs = (map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s) in intercalate " | " xs ++ " ]" newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env) @@ -816,3 +834,16 @@ quote s = "'" ++ s ++ "'" ctrace :: (Monad m, Show a) => String -> a -> m () ctrace str a = trace (str ++ ": " ++ show a) pure () + +{- +Save each subst mapped to their respective function +Apply composition of all used functions to the function + +a = id 0 ; +b = id 'a' ; +id x = x ; + +apply_on_a = id_sub `compose` a_sub +apply_on_b = id_sub `compose` b_sub +apply_on_id = id_sub +-}