From c4f78ca37d713975043f43c415052a1990a10333 Mon Sep 17 00:00:00 2001 From: sebastianselander Date: Fri, 31 Mar 2023 18:26:58 +0200 Subject: [PATCH] Merge in mutual recursion handling --- src/TypeChecker/TypeCheckerHm.hs | 194 ++++++++++++++----------------- 1 file changed, 89 insertions(+), 105 deletions(-) diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 11cb94e..01a7e16 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -6,16 +6,15 @@ -- | A module for type checking and inference using algorithm W, Hindley-Milner module TypeChecker.TypeCheckerHm where -import Auxiliary (int, litType, maybeToRightM, tupSequence, unzip4) +import Auxiliary (int, litType, maybeToRightM, unzip4) import Auxiliary qualified as Aux import Control.Monad.Except import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Reader import Control.Monad.State -import Data.Bifunctor (first) import Data.Coerce (coerce) import Data.Function (on) -import Data.List (foldl') +import Data.List (foldl', intercalate) import Data.List.Extra (unsnoc) import Data.Map (Map) import Data.Map qualified as M @@ -40,19 +39,10 @@ typecheck = onLeft msg . run . checkPrg checkPrg :: Program -> Infer (T.Program' Type) checkPrg (Program bs) = do preRun bs - bs <- checkDef bs - sub0 <- solveUndecidable - bs <- mapM (mono sub0) bs + (subs, bs) <- checkDef bs + ctrace "SUBS" $ unionSubsts subs return $ T.Program bs -mono :: Subst -> T.Def' Type -> Infer (T.Def' Type) -mono s bind@(T.DBind (T.Bind (name, t) args e)) = do - b <- gets (S.member name . toDecide) - if b - then return $ T.DBind $ T.Bind (name, apply s t) (apply s args) (apply s e) - else return bind -mono _ (T.DData d) = return $ T.DData d - preRun :: [Def] -> Infer () preRun [] = return () preRun (x : xs) = case x of @@ -62,7 +52,8 @@ preRun (x : xs) = case x of duplicateDecl n s $ Aux.do "Multiple signatures of function" quote $ printTree n - insertSig (coerce n) (Just t) >> preRun xs + insertSig (coerce n) t + preRun xs DBind (Bind n _ e) -> do s <- gets (S.toList . declaredBinds) duplicateDecl n s $ Aux.do @@ -70,43 +61,46 @@ preRun (x : xs) = case x of quote $ printTree n collect (collectTVars e) insertBind $ coerce n - s <- gets sigs - case M.lookup (coerce n) s of - Nothing -> insertSig (coerce n) Nothing >> preRun xs + sigs <- gets sigs + case M.lookup (coerce n) sigs of + Nothing -> do + fr <- fresh + insertSig (coerce n) fr + preRun xs Just _ -> preRun xs DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs where -- Check if function body / signature has been declared already + duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m () duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg) -checkDef :: [Def] -> Infer [T.Def' Type] -checkDef [] = return [] +checkDef :: [Def] -> Infer ([Subst], [T.Def' Type]) +checkDef [] = return ([], []) checkDef (x : xs) = case x of (DBind b) -> do - b' <- checkBind b - xs' <- checkDef xs - return $ T.DBind b' : xs' + (sub0, b') <- checkBind b + (sub1, xs') <- checkDef xs + return (sub1 ++ sub0, T.DBind b' : xs') (DData d) -> do - xs' <- checkDef xs - return $ T.DData (coerceData d) : xs' + (sub, xs') <- checkDef xs + return (sub, T.DData (coerceData d) : xs') (DSig _) -> checkDef xs where coerceData (Data t injs) = T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs -checkBind :: Bind -> Infer (T.Bind' Type) +checkBind :: Bind -> Infer ([Subst], T.Bind' Type) checkBind bind@(Bind name args e) = do - setCurrentBind $ coerce name let lambda = makeLambda e (reverse (coerce args)) (sub0, (e, lambda_t)) <- inferExp lambda s <- gets sigs case M.lookup (coerce name) s of - Just (Just t') -> do - sub1 <- bindErr (unify lambda_t (skolemize t')) bind - return $ T.Bind (coerce name, apply (sub1 `compose` sub0) t') [] (e, lambda_t) - _ -> do - insertSig (coerce name) (Just lambda_t) - return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) + Just t' -> do + sub1 <- bindErr (unify t' lambda_t) bind + ctrace "SUB0" sub0 + ctrace "SUB1" sub1 + return ([sub1, sub0], T.Bind (coerce name, t') [] (e, lambda_t)) + _ -> error "First pass through failed to add function to env" checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m () checkData err@(Data typ injs) = do @@ -174,8 +168,7 @@ returnType a = a inferExp :: Exp -> Infer (Subst, T.ExpT' Type) inferExp e = do (s, (e', t)) <- algoW e - let subbed = apply s t - modify (\st -> st{undecidedSigs = apply s st.undecidedSigs}) + subbed <- apply s t return (s, (e', subbed)) class CollectTVars a where @@ -202,16 +195,9 @@ algoW = \case (sub0, (e', t')) <- exprErr (algoW e) err sub1 <- unify t t' sub2 <- unify t' t - unless - (apply sub1 t == t' && apply sub2 t' == t) - ( uncatchableErr $ Aux.do - "Annotated type" - quote $ printTree t - "does not match inferred type" - quote $ printTree t' - ) let comp = sub2 `compose` sub1 `compose` sub0 - return (comp, apply comp (e', t)) + et <- apply comp (e', t) + return (comp, et) -- \| ------------------ -- \| Γ ⊢ i : Int, ∅ @@ -228,13 +214,8 @@ algoW = \case return (nullSubst, (T.EVar $ coerce i, x)) Nothing -> do sig <- gets sigs - cb <- gets currentBind case M.lookup (coerce i) sig of - Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t)) - Just Nothing -> do - fr <- fresh - modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs}) - return (nullSubst, (T.EVar $ coerce i, fr)) + Just t -> return (nullSubst, (T.EVar $ coerce i, t)) Nothing -> uncatchableErr $ "Unbound variable: " @@ -257,9 +238,10 @@ algoW = \case fr <- fresh withBinding (coerce name) fr $ do (s1, (e', t')) <- exprErr (algoW e) err - let varType = apply s1 fr + varType <- apply s1 fr let newArr = TFun varType t' - return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) + eabs <- apply s1 (T.EAbs (coerce name) (e', t'), newArr) + return (s1, eabs) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) @@ -302,7 +284,7 @@ algoW = \case err@(ELet b@(Bind name args e) e1) -> do (s1, (_, t0)) <- algoW (makeLambda e (coerce args)) - bind' <- exprErr (checkBind b) err + (_, bind') <- exprErr (checkBind b) err env <- asks vars let t' = generalize (apply s1 env) t0 withBinding (coerce name) t' $ do @@ -422,15 +404,15 @@ unify t0 t1 = s1 <- unify a c s2 <- unify (apply s1 b) (apply s1 d) return $ s2 `compose` s1 - (TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t - (t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t + (TVar (MkTVar a), t@(TData _ _)) -> return $ coerce $ M.singleton (coerce a) t + (t@(TData _ _), TVar (MkTVar b)) -> return $ coerce $ M.singleton (coerce b) t (TVar (MkTVar a), t) -> occurs (coerce a) t (t, TVar (MkTVar b)) -> occurs (coerce b) t (TAll _ t, b) -> unify t b (a, TAll _ t) -> unify a t (TLit a, TLit b) -> if a == b - then return M.empty + then return nullSubst else catchableErr $ Aux.do "Can not unify" @@ -452,7 +434,7 @@ unify t0 t1 = quote $ printTree t' (TEVar a, TEVar b) -> if a == b - then return M.empty + then return nullSubst else catchableErr $ Aux.do "Can not unify" @@ -472,7 +454,7 @@ I.E. { a = a -> b } is an unsolvable constraint since there is no substitution where these are equal -} occurs :: T.Ident -> Type -> Infer Subst -occurs i t@(TVar _) = return (M.singleton i t) +occurs i t@(TVar _) = return (coerce $ M.singleton i t) occurs i t = if S.member i (free t) then @@ -483,7 +465,7 @@ occurs i t = "with" quote $ printTree t ) - else return $ M.singleton i t + else return $ coerce $ M.singleton i t {- | Generalize a type over all free variables in the substitution set Used for let bindings to allow expression that do not type check in @@ -509,7 +491,7 @@ inst :: Type -> Infer Type inst = \case TAll (MkTVar bound) t -> do fr <- fresh - let s = M.singleton (coerce bound) fr + let s = coerce $ M.singleton (coerce bound) fr apply s <$> inst t TFun t1 t2 -> TFun <$> inst t1 <*> inst t2 rest -> return rest @@ -545,7 +527,8 @@ skolemize t = t -- | A class for substitutions class SubstType t where -- | Apply a substitution to t - apply :: Subst -> t -> t + -- apply :: MonadError e m => Subst -> t -> m t + apply :: Subst -> t -> Infer t class FreeVars t where -- | Get all free variables from t @@ -565,32 +548,47 @@ instance FreeVars a => FreeVars [a] where free = let f acc x = acc `S.union` free x in foldl' f S.empty instance SubstType Type where - apply :: Subst -> Type -> Type - apply sub t = do + apply sub@(Subst s) t = do case t of - TLit a -> TLit a - TVar (MkTVar a) -> case M.lookup (coerce a) sub of - Nothing -> TVar (MkTVar $ coerce a) - Just t -> t - TAll (MkTVar i) t -> case M.lookup (coerce i) sub of - Nothing -> TAll (MkTVar i) (apply sub t) + TLit a -> return $ TLit a + TVar (MkTVar a) -> case M.lookup (coerce a) s of + Nothing -> return $ TVar (MkTVar $ coerce a) + Just t -> return $ t + TAll (MkTVar i) t -> case M.lookup (coerce i) s of + Nothing -> TAll (MkTVar i) <$> apply sub t Just _ -> apply sub t - TFun a b -> TFun (apply sub a) (apply sub b) - TData name a -> TData name (apply sub a) - TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of - Nothing -> TEVar (MkTEVar a) - Just t -> t + TFun a b -> TFun <$> apply sub a <*> apply sub b + TData name a -> TData name <$> apply sub a + TEVar (MkTEVar a) -> case M.lookup (coerce a) s of + Nothing -> return $ TEVar (MkTEVar a) + Just t -> return $ t instance FreeVars (Map T.Ident Type) where free :: Map T.Ident Type -> Set T.Ident free = free . M.elems instance SubstType (Map T.Ident Type) where - apply :: Subst -> Map T.Ident Type -> Map T.Ident Type - apply = M.map . apply + apply s = undefined -- M.map (apply s) -instance SubstType (Map T.Ident (Maybe Type)) where - apply s = M.map (fmap $ apply s) +instance SubstType Subst where + apply s@(Subst m1) (Subst m2) = do + let both = M.keys $ M.intersection m1 m2 + case both of + [] -> Subst <$> apply s m2 + xs -> do + sub0 <- apply s m2 + sub1 <- loop xs m1 m2 + apply sub1 (Subst sub0) + where + loop [] _ _ = return nullSubst + loop (x : xs) m1 m2 = do + let k1 = m1 M.! x + let k2 = m2 M.! x + sub <- unify k1 k2 + subs <- loop xs m1 m2 + return $ sub `compose` subs + +-- Subst $ M.map (apply s) m2 instance SubstType (T.ExpT' Type) where apply s (e, t) = (apply s e, apply s t) @@ -636,16 +634,19 @@ instance SubstType (T.Id' Type) where -- | Represents the empty substition set nullSubst :: Subst -nullSubst = mempty +nullSubst = Subst mempty -- | Compose two substitution sets compose :: Subst -> Subst -> Subst -compose m1 m2 = M.map (apply m1) m2 `M.union` m1 +compose m1 m2 = Subst $ M.map (apply $ coerce m1) (coerce m2) `M.union` coerce m1 -- | Compose a list of substitution sets into one composeAll :: [Subst] -> Subst composeAll = foldl' compose nullSubst +unionSubsts :: [Subst] -> Subst +unionSubsts = Subst . foldl' M.union M.empty . map coerce + {- | Convert a function with arguments to its pointfree version > makeLambda (add x y = x + y) = add = \x. \y. x + y -} @@ -671,7 +672,7 @@ withPattern p ma = case p of T.PEnum _ -> ma -- | Insert a function signature into the environment -insertSig :: T.Ident -> Maybe Type -> Infer () +insertSig :: T.Ident -> Type -> Infer () insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) insertBind :: T.Ident -> Infer () @@ -691,24 +692,6 @@ with an equivalent name has been declared already existInj :: (Monad m, MonadState Env m) => T.Ident -> m (Maybe Type) existInj n = gets (M.lookup n . injections) -setCurrentBind :: T.Ident -> Infer () -setCurrentBind i = modify (\st -> st{currentBind = i}) - -solveUndecidable :: Infer Subst -solveUndecidable = do - sigs <- gets sigs - undecided <- gets undecidedSigs - ys <- - maybeToRightM - (Error "SIGNATURE MISSING" False) - ( mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) $ - M.toList undecided - ) - composeAll <$> mapM (uncurry unify) ys - -getOriginal :: T.Ident -> T.Ident -getOriginal (T.Ident i) = coerce $ takeWhile (/= delim) $ drop 1 i - delim :: Char delim = '_' prefix :: Char @@ -785,7 +768,7 @@ dataErr ma d = ) initCtx = Ctx mempty -initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty +initEnv = Env 0 'a' mempty mempty mempty mempty run :: Infer a -> Either Error a run = run' initEnv initCtx @@ -804,19 +787,20 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type} data Env = Env { count :: Int , nextChar :: Char - , sigs :: Map T.Ident (Maybe Type) + , sigs :: Map T.Ident Type , takenTypeVars :: Set T.Ident , injections :: Map T.Ident Type - , currentBind :: T.Ident - , undecidedSigs :: Map T.Ident Type - , toDecide :: Set T.Ident , declaredBinds :: Set T.Ident } deriving (Show) data Error = Error {msg :: String, catchable :: Bool} deriving (Show) -type Subst = Map T.Ident Type + +newtype Subst = Subst (Map T.Ident Type) + +instance Show Subst where + show (Subst s) = "[" ++ let xs = (map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s) in intercalate " | " xs ++ "]" newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)