From 0749ca062d7a73b2444af8f1b19163945d8135b3 Mon Sep 17 00:00:00 2001 From: sebastianselander Date: Fri, 31 Mar 2023 18:28:04 +0200 Subject: [PATCH] Merge in mutual recursion handling --- src/TypeChecker/TypeCheckerHm.hs | 131 ++++++++++++++++++------------- src/TypeChecker/TypeCheckerIr.hs | 5 ++ test_program.crf | 43 ---------- 3 files changed, 81 insertions(+), 98 deletions(-) diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 518b3e8..33765e0 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -8,6 +8,7 @@ module TypeChecker.TypeCheckerHm where import Auxiliary (int, litType, maybeToRightM, unzip4) import Auxiliary qualified as Aux +import Control.Arrow ((&&&)) import Control.Monad.Except import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Reader @@ -18,7 +19,7 @@ import Data.List (foldl', intercalate) import Data.List.Extra (unsnoc) import Data.Map (Map) import Data.Map qualified as M -import Data.Maybe (fromJust) +import Data.Maybe (fromJust, fromMaybe, mapMaybe) import Data.Set (Set) import Data.Set qualified as S import Debug.Trace (trace) @@ -26,8 +27,6 @@ import Grammar.Abs import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr qualified as T --- TODO: Disallow mutual recursion - -- | Type check a program typecheck :: Program -> Either String (T.Program' Type) typecheck = onLeft msg . run . checkPrg @@ -37,10 +36,16 @@ typecheck = onLeft msg . run . checkPrg onLeft _ (Right x) = Right x checkPrg :: Program -> Infer (T.Program' Type) -checkPrg (Program bs) = do - preRun bs - (sub, bs) <- checkDef bs - return $ T.Program $ apply sub bs +checkPrg (Program bs) = T.Program <$> (preRun bs >> checkDef bs >>= mapM substPrg) + +substPrg :: T.Def' Type -> Infer (T.Def' Type) +substPrg (T.DBind (T.Bind (name, t) args e)) = do + (bu, sub) <- gets (bindUsages &&& bindSubs) + let uses = fromMaybe [] $ M.lookup name bu + let subs = mapMaybe (`M.lookup` sub) (name : uses) + sub <- foldM composey nullSubst (reverse subs) + return . T.DBind $ T.Bind (name, apply sub t) (apply sub args) (apply sub e) +substPrg d = return d preRun :: [Def] -> Infer () preRun [] = return () @@ -51,7 +56,7 @@ preRun (x : xs) = case x of duplicateDecl n s $ Aux.do "Multiple signatures of function" quote $ printTree n - insertSig (coerce n) t + insertSig (coerce n) (Instantiated t) preRun xs DBind (Bind n _ e) -> do s <- gets (S.toList . declaredBinds) @@ -64,7 +69,7 @@ preRun (x : xs) = case x of case M.lookup (coerce n) sigs of Nothing -> do fr <- fresh - insertSig (coerce n) fr + insertSig (coerce n) (Generalized fr) preRun xs Just _ -> preRun xs DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs @@ -73,33 +78,38 @@ preRun (x : xs) = case x of duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m () duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg) -checkDef :: [Def] -> Infer (Subst, [T.Def' Type]) -checkDef [] = return (nullSubst, []) +checkDef :: [Def] -> Infer [T.Def' Type] +checkDef [] = return [] checkDef (x : xs) = case x of (DBind b) -> do - (sub0, b') <- checkBind b - (sub1, xs') <- checkDef xs - comp <- sub0 `composey` sub1 - return (comp, T.DBind b' : xs') + b' <- checkBind b + xs' <- checkDef xs + return $ T.DBind b' : xs' (DData d) -> do - (sub, xs') <- checkDef xs - return (sub, T.DData (coerceData d) : xs') + xs' <- checkDef xs + return $ T.DData (coerceData d) : xs' (DSig _) -> checkDef xs where coerceData (Data t injs) = T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs -checkBind :: Bind -> Infer (Subst, T.Bind' Type) +checkBind :: Bind -> Infer (T.Bind' Type) checkBind bind@(Bind name args e) = do + setCurrentBind $ coerce name let lambda = makeLambda e (reverse (coerce args)) (sub0, (e, lambda_t)) <- inferExp lambda s <- gets sigs case M.lookup (coerce name) s of Just t -> do - sub1 <- bindErr (unify t lambda_t) bind + let t' = case t of + Instantiated a -> skolemize a + Generalized a -> a + sub1 <- bindErr (unify t' lambda_t) bind comp <- sub1 `composey` sub0 - return (comp, T.Bind (coerce name, apply comp t) [] (e, lambda_t)) - _ -> error "First pass through failed to add function to env" + insertBindSubst (coerce name) comp + return (T.Bind (coerce name, apply comp t') [] (e, lambda_t)) + _ -> do + uncatchableErr $ "Undeclared function: " ++ printTree name checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m () checkData err@(Data typ injs) = do @@ -203,7 +213,6 @@ algoW = \case quote $ printTree t' ) let comp = sub2 `compose` sub1 `compose` sub0 - -- return (comp, apply comp (e', t)) return (comp, (e', t)) -- \| ------------------ @@ -221,8 +230,11 @@ algoW = \case return (nullSubst, (T.EVar $ coerce i, x)) Nothing -> do sig <- gets sigs + cb <- gets currentBind case M.lookup (coerce i) sig of - Just t -> return (nullSubst, (T.EVar $ coerce i, t)) + Just t -> do + insertBindUsage cb (coerce i) + return (nullSubst, (T.EVar $ coerce i, unlevel t)) Nothing -> uncatchableErr $ "Unbound variable: " @@ -247,7 +259,6 @@ algoW = \case (s1, (e', t')) <- exprErr (algoW e) err let varType = apply s1 fr let newArr = TFun varType t' - -- return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) return (s1, (T.EAbs (coerce name) (e', t'), newArr)) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ @@ -262,7 +273,6 @@ algoW = \case s3 <- exprErr (unify (apply s2 t0) int) err s4 <- exprErr (unify (apply s3 t1) int) err let comp = s4 `compose` s3 `compose` s2 `compose` s1 - -- return (comp, apply comp (T.EAdd (e0', t0) (e1', t1), int)) return (comp, (T.EAdd (e0', t0) (e1', t1), int)) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 @@ -274,13 +284,10 @@ algoW = \case fr <- fresh (s0, (e0', t0)) <- algoW e0 applySt s0 $ do - modify (\st -> st{sigs = apply s0 st.sigs}) (s1, (e1', t1)) <- algoW e1 s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err let t = apply s2 fr comp <- foldM composey nullSubst [s2, s1, s0] - -- let comp = s2 `compose` s1 `compose` s0 - -- return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) return (comp, (T.EApp (e0', t0) (e1', t1), t)) -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ @@ -289,16 +296,17 @@ algoW = \case -- The bar over S₀ and Γ means "generalize" - err@(ELet b@(Bind name args e) e1) -> do - (s1, (_, t0)) <- algoW (makeLambda e (coerce args)) - (_, bind') <- exprErr (checkBind b) err + (ELet (Bind name args e) e1) -> do + (s1, (e, t0)) <- algoW (makeLambda e (coerce args)) env <- asks vars let t' = generalize (apply s1 env) t0 withBinding (coerce name) t' $ do (s2, (e1', t2)) <- algoW e1 let comp = s2 `compose` s1 - -- return (comp, apply comp (T.ELet bind' (e1', t2), t2)) - return (comp, (T.ELet bind' (e1', t2), t2)) + return + ( comp + , (T.ELet (T.Bind (coerce name, t0) [] (e, t0)) (e1', t2), t2) + ) ECase caseExpr injs -> do (sub, (e', t)) <- algoW caseExpr (subst, injs, ret_t) <- checkCase t injs @@ -413,8 +421,10 @@ unify t0 t1 = s1 <- unify a c s2 <- unify (apply s1 b) (apply s1 d) return $ s2 `compose` s1 - (TVar (MkTVar a), t@(TData _ _)) -> return $ coerce $ M.singleton (coerce a) t - (t@(TData _ _), TVar (MkTVar b)) -> return $ coerce $ M.singleton (coerce b) t + (TVar (MkTVar a), t@(TData _ _)) -> + return $ coerce $ M.singleton (coerce a) t + (t@(TData _ _), TVar (MkTVar b)) -> + return $ coerce $ M.singleton (coerce b) t (TVar (MkTVar a), t) -> occurs (coerce a) t (t, TVar (MkTVar b)) -> occurs (coerce b) t (TAll _ t, b) -> unify t b @@ -603,7 +613,8 @@ instance SubstType (T.Exp' Type) where instance SubstType (T.Def' Type) where apply s = \case - T.DBind (T.Bind name args e) -> T.DBind $ T.Bind (apply s name) (apply s args) (apply s e) + T.DBind (T.Bind name args e) -> + T.DBind $ T.Bind (apply s name) (apply s args) (apply s e) d -> d instance SubstType (T.Branch' Type) where @@ -636,8 +647,14 @@ compose m1 m2 = Subst $ M.map (apply $ coerce m1) (coerce m2) `M.union` coerce m -- Order matters. {- -sub0 = Subst $ (M.singleton "a" (arr d e)) `M.union` (M.singleton "b" (arr d f)) `M.union` (M.singleton "c" (arr f e)) -sub1 = Subst $ (M.singleton "a" (arr g bool)) `M.union` (M.singleton "b" (arr g bool)) `M.union` (M.singleton "c" (arr bool bool)) `M.union` (M.singleton "h" bool) `M.union` (M.singleton "i" bool) +sub0 = Subst $ (M.singleton "a" (arr d e)) + `M.union` (M.singleton "b" (arr d f)) + `M.union` (M.singleton "c" (arr f e)) +sub1 = Subst $ (M.singleton "a" (arr g bool)) + `M.union` (M.singleton "b" (arr g bool)) + `M.union` (M.singleton "c" (arr bool bool)) + `M.union` (M.singleton "h" bool) + `M.union` (M.singleton "i" bool) sub0 `composey` sub1 != sub1 `composey` sub0 -} composey :: Subst -> Subst -> Infer Subst @@ -690,12 +707,21 @@ withPattern p ma = case p of T.PEnum _ -> ma -- | Insert a function signature into the environment -insertSig :: T.Ident -> Type -> Infer () +insertSig :: T.Ident -> Level Type -> Infer () insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) insertBind :: T.Ident -> Infer () insertBind i = modify (\st -> st{declaredBinds = S.insert i st.declaredBinds}) +insertBindSubst :: T.Ident -> Subst -> Infer () +insertBindSubst name sub = modify (\st -> st{bindSubs = M.insert name sub st.bindSubs}) + +setCurrentBind :: T.Ident -> Infer () +setCurrentBind n = modify (\st -> st{currentBind = n, bindUsages = M.insertWith (++) n [] st.bindUsages}) + +insertBindUsage :: T.Ident -> T.Ident -> Infer () +insertBindUsage cur use = modify (\st -> st{bindUsages = M.insertWith (++) cur [use] st.bindUsages}) + -- | Insert a constructor into the start with its type insertInj :: (Monad m, MonadState Env m) => T.Ident -> Type -> m () insertInj i t = @@ -786,7 +812,7 @@ dataErr ma d = ) initCtx = Ctx mempty -initEnv = Env 0 'a' mempty mempty mempty mempty +initEnv = Env 0 'a' mempty mempty mempty mempty "" mempty mempty run :: Infer a -> Either Error a run = run' initEnv initCtx @@ -805,20 +831,28 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type} data Env = Env { count :: Int , nextChar :: Char - , sigs :: Map T.Ident Type + , sigs :: Map T.Ident (Level Type) , takenTypeVars :: Set T.Ident , injections :: Map T.Ident Type , declaredBinds :: Set T.Ident + , currentBind :: T.Ident + , bindSubs :: Map T.Ident Subst + , bindUsages :: Map T.Ident [T.Ident] } deriving (Show) +data Level a = Instantiated {unlevel :: a} | Generalized {unlevel :: a} + deriving (Show) + data Error = Error {msg :: String, catchable :: Bool} deriving (Show) newtype Subst = Subst (Map T.Ident Type) instance Show Subst where - show (Subst s) = "[ " ++ let xs = (map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s) in intercalate " | " xs ++ " ]" + show (Subst s) = "[ " ++ intercalate " | " xs ++ " ]" + where + xs = map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env) @@ -834,16 +868,3 @@ quote s = "'" ++ s ++ "'" ctrace :: (Monad m, Show a) => String -> a -> m () ctrace str a = trace (str ++ ": " ++ show a) pure () - -{- -Save each subst mapped to their respective function -Apply composition of all used functions to the function - -a = id 0 ; -b = id 'a' ; -id x = x ; - -apply_on_a = id_sub `compose` a_sub -apply_on_b = id_sub `compose` b_sub -apply_on_id = id_sub --} diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index b3f51d7..d59e429 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -188,6 +188,11 @@ instance Print t => Print (Inj' t) where prt i = \case Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) +instance Print t => Print [Inj' t] where + prt _ [] = concatD [] + prt i [x] = prt i x + prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs] + instance Print t => Print (Pattern' t) where prt i = \case PVar name -> prPrec i 1 (concatD [prt 0 name]) diff --git a/test_program.crf b/test_program.crf index b584ff8..432d33f 100644 --- a/test_program.crf +++ b/test_program.crf @@ -26,46 +26,3 @@ bind : Maybe () -> (Int -> Maybe ()) -> Maybe () ; bind x f = case x of { Just x => f x ; Nothing => Nothing ; -}; - --- represents minus one :) -minusOne : Int ; -minusOne = 9223372036854775807 + 9223372036854775807 + 1; - ----- LIST STUFF ---- --- a simple list data type containing ints -data List () where { - Cons : Int -> List () -> List () - Nil : List () -}; - --- take the length of a list -length : List () -> Int ; -length x = case x of { - Cons _ xs => 1 + length xs ; - Nil => 0 ; -}; - --- sum a list -sum : List () -> Int ; -sum x = case x of { - Cons a xs => a + sum xs ; - Nil => 0 ; -}; - --- sum + length of a list -sumlength: List () -> Int ; -sumlength x = sum x + length x ; - --- take the head of a list -head : List () -> Int ; -head x = case x of { - Cons h _ => h ; -}; - --- repeat an element n times -repeat : Int -> Int -> List () ; -repeat x n = case n of { - 0 => Nil ; - n => Cons x (repeat x (n + minusOne)) ; -};