diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 33765e0..710343f 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -6,20 +6,20 @@ -- | A module for type checking and inference using algorithm W, Hindley-Milner module TypeChecker.TypeCheckerHm where -import Auxiliary (int, litType, maybeToRightM, unzip4) +import Auxiliary (int, litType, maybeToRightM, tupSequence, unzip4) import Auxiliary qualified as Aux -import Control.Arrow ((&&&)) import Control.Monad.Except import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Reader import Control.Monad.State +import Data.Bifunctor (first) import Data.Coerce (coerce) import Data.Function (on) -import Data.List (foldl', intercalate) +import Data.List (foldl') import Data.List.Extra (unsnoc) import Data.Map (Map) import Data.Map qualified as M -import Data.Maybe (fromJust, fromMaybe, mapMaybe) +import Data.Maybe (fromJust) import Data.Set (Set) import Data.Set qualified as S import Debug.Trace (trace) @@ -27,6 +27,8 @@ import Grammar.Abs import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr qualified as T +-- TODO: Disallow mutual recursion + -- | Type check a program typecheck :: Program -> Either String (T.Program' Type) typecheck = onLeft msg . run . checkPrg @@ -36,16 +38,20 @@ typecheck = onLeft msg . run . checkPrg onLeft _ (Right x) = Right x checkPrg :: Program -> Infer (T.Program' Type) -checkPrg (Program bs) = T.Program <$> (preRun bs >> checkDef bs >>= mapM substPrg) +checkPrg (Program bs) = do + preRun bs + bs <- checkDef bs + sub0 <- solveUndecidable + bs <- mapM (mono sub0) bs + return $ T.Program bs -substPrg :: T.Def' Type -> Infer (T.Def' Type) -substPrg (T.DBind (T.Bind (name, t) args e)) = do - (bu, sub) <- gets (bindUsages &&& bindSubs) - let uses = fromMaybe [] $ M.lookup name bu - let subs = mapMaybe (`M.lookup` sub) (name : uses) - sub <- foldM composey nullSubst (reverse subs) - return . T.DBind $ T.Bind (name, apply sub t) (apply sub args) (apply sub e) -substPrg d = return d +mono :: Subst -> T.Def' Type -> Infer (T.Def' Type) +mono s bind@(T.DBind (T.Bind (name, t) args e)) = do + b <- gets (S.member name . toDecide) + if b + then return $ T.DBind $ T.Bind (name, apply s t) (apply s args) (apply s e) + else return bind +mono _ (T.DData d) = return $ T.DData d preRun :: [Def] -> Infer () preRun [] = return () @@ -56,8 +62,7 @@ preRun (x : xs) = case x of duplicateDecl n s $ Aux.do "Multiple signatures of function" quote $ printTree n - insertSig (coerce n) (Instantiated t) - preRun xs + insertSig (coerce n) (Just t) >> preRun xs DBind (Bind n _ e) -> do s <- gets (S.toList . declaredBinds) duplicateDecl n s $ Aux.do @@ -65,17 +70,13 @@ preRun (x : xs) = case x of quote $ printTree n collect (collectTVars e) insertBind $ coerce n - sigs <- gets sigs - case M.lookup (coerce n) sigs of - Nothing -> do - fr <- fresh - insertSig (coerce n) (Generalized fr) - preRun xs + s <- gets sigs + case M.lookup (coerce n) s of + Nothing -> insertSig (coerce n) Nothing >> preRun xs Just _ -> preRun xs DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs where -- Check if function body / signature has been declared already - duplicateDecl :: (Monad m, MonadError Error m) => LIdent -> [T.Ident] -> String -> m () duplicateDecl n env msg = when (coerce n `elem` env) (uncatchableErr msg) checkDef :: [Def] -> Infer [T.Def' Type] @@ -100,16 +101,12 @@ checkBind bind@(Bind name args e) = do (sub0, (e, lambda_t)) <- inferExp lambda s <- gets sigs case M.lookup (coerce name) s of - Just t -> do - let t' = case t of - Instantiated a -> skolemize a - Generalized a -> a - sub1 <- bindErr (unify t' lambda_t) bind - comp <- sub1 `composey` sub0 - insertBindSubst (coerce name) comp - return (T.Bind (coerce name, apply comp t') [] (e, lambda_t)) + Just (Just t') -> do + sub1 <- bindErr (unify lambda_t (skolemize t')) bind + return $ T.Bind (coerce name, apply (sub1 `compose` sub0) t') [] (e, lambda_t) _ -> do - uncatchableErr $ "Undeclared function: " ++ printTree name + insertSig (coerce name) (Just lambda_t) + return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m () checkData err@(Data typ injs) = do @@ -178,6 +175,7 @@ inferExp :: Exp -> Infer (Subst, T.ExpT' Type) inferExp e = do (s, (e', t)) <- algoW e let subbed = apply s t + modify (\st -> st{undecidedSigs = apply s st.undecidedSigs}) return (s, (e', subbed)) class CollectTVars a where @@ -213,7 +211,7 @@ algoW = \case quote $ printTree t' ) let comp = sub2 `compose` sub1 `compose` sub0 - return (comp, (e', t)) + return (comp, apply comp (e', t)) -- \| ------------------ -- \| Γ ⊢ i : Int, ∅ @@ -232,9 +230,11 @@ algoW = \case sig <- gets sigs cb <- gets currentBind case M.lookup (coerce i) sig of - Just t -> do - insertBindUsage cb (coerce i) - return (nullSubst, (T.EVar $ coerce i, unlevel t)) + Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t)) + Just Nothing -> do + fr <- fresh + modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs}) + return (nullSubst, (T.EVar $ coerce i, fr)) Nothing -> uncatchableErr $ "Unbound variable: " @@ -259,7 +259,7 @@ algoW = \case (s1, (e', t')) <- exprErr (algoW e) err let varType = apply s1 fr let newArr = TFun varType t' - return (s1, (T.EAbs (coerce name) (e', t'), newArr)) + return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S₁ -- \| s₂ = mgu(s₁τ₀, Int) s₃ = mgu(s₂τ₁, Int) @@ -273,7 +273,10 @@ algoW = \case s3 <- exprErr (unify (apply s2 t0) int) err s4 <- exprErr (unify (apply s3 t1) int) err let comp = s4 `compose` s3 `compose` s2 `compose` s1 - return (comp, (T.EAdd (e0', t0) (e1', t1), int)) + return + ( comp + , apply comp (T.EAdd (e0', t0) (e1', t1), int) + ) -- \| Γ ⊢ e₀ : τ₀, S₀ S₀Γ ⊢ e₁ : τ₁, S1 -- \| τ' = newvar S₂ = mgu(S₁τ₀, τ₁ → τ') @@ -284,11 +287,12 @@ algoW = \case fr <- fresh (s0, (e0', t0)) <- algoW e0 applySt s0 $ do + modify (\st -> st{sigs = apply s0 st.sigs}) (s1, (e1', t1)) <- algoW e1 s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err let t = apply s2 fr - comp <- foldM composey nullSubst [s2, s1, s0] - return (comp, (T.EApp (e0', t0) (e1', t1), t)) + let comp = s2 `compose` s1 `compose` s0 + return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ -- \| ---------------------------------------------- @@ -296,23 +300,20 @@ algoW = \case -- The bar over S₀ and Γ means "generalize" - (ELet (Bind name args e) e1) -> do - (s1, (e, t0)) <- algoW (makeLambda e (coerce args)) + err@(ELet b@(Bind name args e) e1) -> do + (s1, (_, t0)) <- algoW (makeLambda e (coerce args)) + bind' <- exprErr (checkBind b) err env <- asks vars let t' = generalize (apply s1 env) t0 withBinding (coerce name) t' $ do (s2, (e1', t2)) <- algoW e1 let comp = s2 `compose` s1 - return - ( comp - , (T.ELet (T.Bind (coerce name, t0) [] (e, t0)) (e1', t2), t2) - ) + return (comp, apply comp (T.ELet bind' (e1', t2), t2)) ECase caseExpr injs -> do (sub, (e', t)) <- algoW caseExpr (subst, injs, ret_t) <- checkCase t injs let comp = subst `compose` sub - -- return (comp, apply comp (T.ECase (e', t) injs, ret_t)) - return (comp, (T.ECase (e', t) injs, ret_t)) + return (comp, apply comp (T.ECase (e', t) injs, ret_t)) EAppInf{} -> error "desugar phase failed" checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type) @@ -421,17 +422,15 @@ unify t0 t1 = s1 <- unify a c s2 <- unify (apply s1 b) (apply s1 d) return $ s2 `compose` s1 - (TVar (MkTVar a), t@(TData _ _)) -> - return $ coerce $ M.singleton (coerce a) t - (t@(TData _ _), TVar (MkTVar b)) -> - return $ coerce $ M.singleton (coerce b) t + (TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t + (t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t (TVar (MkTVar a), t) -> occurs (coerce a) t (t, TVar (MkTVar b)) -> occurs (coerce b) t (TAll _ t, b) -> unify t b (a, TAll _ t) -> unify a t (TLit a, TLit b) -> if a == b - then return nullSubst + then return M.empty else catchableErr $ Aux.do "Can not unify" @@ -453,7 +452,7 @@ unify t0 t1 = quote $ printTree t' (TEVar a, TEVar b) -> if a == b - then return nullSubst + then return M.empty else catchableErr $ Aux.do "Can not unify" @@ -473,7 +472,7 @@ I.E. { a = a -> b } is an unsolvable constraint since there is no substitution where these are equal -} occurs :: T.Ident -> Type -> Infer Subst -occurs i t@(TVar _) = return (coerce $ M.singleton i t) +occurs i t@(TVar _) = return (M.singleton i t) occurs i t = if S.member i (free t) then @@ -484,7 +483,7 @@ occurs i t = "with" quote $ printTree t ) - else return $ coerce $ M.singleton i t + else return $ M.singleton i t {- | Generalize a type over all free variables in the substitution set Used for let bindings to allow expression that do not type check in @@ -510,7 +509,7 @@ inst :: Type -> Infer Type inst = \case TAll (MkTVar bound) t -> do fr <- fresh - let s = coerce $ M.singleton (coerce bound) fr + let s = M.singleton (coerce bound) fr apply s <$> inst t TFun t1 t2 -> TFun <$> inst t1 <*> inst t2 rest -> return rest @@ -546,7 +545,6 @@ skolemize t = t -- | A class for substitutions class SubstType t where -- | Apply a substitution to t - -- apply :: MonadError e m => Subst -> t -> m t apply :: Subst -> t -> t class FreeVars t where @@ -567,18 +565,19 @@ instance FreeVars a => FreeVars [a] where free = let f acc x = acc `S.union` free x in foldl' f S.empty instance SubstType Type where - apply sub@(Subst s) t = do + apply :: Subst -> Type -> Type + apply sub t = do case t of TLit a -> TLit a - TVar (MkTVar a) -> case M.lookup (coerce a) s of + TVar (MkTVar a) -> case M.lookup (coerce a) sub of Nothing -> TVar (MkTVar $ coerce a) Just t -> t - TAll (MkTVar i) t -> case M.lookup (coerce i) s of + TAll (MkTVar i) t -> case M.lookup (coerce i) sub of Nothing -> TAll (MkTVar i) (apply sub t) Just _ -> apply sub t TFun a b -> TFun (apply sub a) (apply sub b) TData name a -> TData name (apply sub a) - TEVar (MkTEVar a) -> case M.lookup (coerce a) s of + TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of Nothing -> TEVar (MkTEVar a) Just t -> t @@ -587,12 +586,11 @@ instance FreeVars (Map T.Ident Type) where free = free . M.elems instance SubstType (Map T.Ident Type) where - apply s = M.map (apply s) + apply :: Subst -> Map T.Ident Type -> Map T.Ident Type + apply = M.map . apply -instance SubstType Subst where - apply s (Subst m2) = Subst $ apply s m2 - --- Subst $ M.map (apply s) m2 +instance SubstType (Map T.Ident (Maybe Type)) where + apply s = M.map (fmap $ apply s) instance SubstType (T.ExpT' Type) where apply s (e, t) = (apply s e, apply s t) @@ -613,8 +611,7 @@ instance SubstType (T.Exp' Type) where instance SubstType (T.Def' Type) where apply s = \case - T.DBind (T.Bind name args e) -> - T.DBind $ T.Bind (apply s name) (apply s args) (apply s e) + T.DBind (T.Bind name args e) -> T.DBind $ T.Bind (apply s name) (apply s args) (apply s e) d -> d instance SubstType (T.Branch' Type) where @@ -639,49 +636,16 @@ instance SubstType (T.Id' Type) where -- | Represents the empty substition set nullSubst :: Subst -nullSubst = Subst mempty +nullSubst = mempty -- | Compose two substitution sets compose :: Subst -> Subst -> Subst -compose m1 m2 = Subst $ M.map (apply $ coerce m1) (coerce m2) `M.union` coerce m1 - --- Order matters. -{- -sub0 = Subst $ (M.singleton "a" (arr d e)) - `M.union` (M.singleton "b" (arr d f)) - `M.union` (M.singleton "c" (arr f e)) -sub1 = Subst $ (M.singleton "a" (arr g bool)) - `M.union` (M.singleton "b" (arr g bool)) - `M.union` (M.singleton "c" (arr bool bool)) - `M.union` (M.singleton "h" bool) - `M.union` (M.singleton "i" bool) -sub0 `composey` sub1 != sub1 `composey` sub0 - -} -composey :: Subst -> Subst -> Infer Subst -composey s0@(Subst m1) s1@(Subst m2) = do - let both = M.keys $ M.intersection m1 m2 - case both of - [] -> return $ s0 `compose` s1 - xs -> do - let m2' = apply s0 m2 - sub <- loop xs m1 m2' - return $ sub `compose` Subst m2 - where - loop [] _ _ = return nullSubst - loop (x : xs) m1 m2 = do - let k1 = m1 M.! x - let k2 = m2 M.! x - sub <- unify k1 k2 - subs <- loop xs m1 m2 - return $ sub `compose` subs +compose m1 m2 = M.map (apply m1) m2 `M.union` m1 -- | Compose a list of substitution sets into one composeAll :: [Subst] -> Subst composeAll = foldl' compose nullSubst -unionSubsts :: [Subst] -> Subst -unionSubsts = Subst . foldl' M.union M.empty . map coerce - {- | Convert a function with arguments to its pointfree version > makeLambda (add x y = x + y) = add = \x. \y. x + y -} @@ -707,21 +671,12 @@ withPattern p ma = case p of T.PEnum _ -> ma -- | Insert a function signature into the environment -insertSig :: T.Ident -> Level Type -> Infer () +insertSig :: T.Ident -> Maybe Type -> Infer () insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) insertBind :: T.Ident -> Infer () insertBind i = modify (\st -> st{declaredBinds = S.insert i st.declaredBinds}) -insertBindSubst :: T.Ident -> Subst -> Infer () -insertBindSubst name sub = modify (\st -> st{bindSubs = M.insert name sub st.bindSubs}) - -setCurrentBind :: T.Ident -> Infer () -setCurrentBind n = modify (\st -> st{currentBind = n, bindUsages = M.insertWith (++) n [] st.bindUsages}) - -insertBindUsage :: T.Ident -> T.Ident -> Infer () -insertBindUsage cur use = modify (\st -> st{bindUsages = M.insertWith (++) cur [use] st.bindUsages}) - -- | Insert a constructor into the start with its type insertInj :: (Monad m, MonadState Env m) => T.Ident -> Type -> m () insertInj i t = @@ -736,6 +691,24 @@ with an equivalent name has been declared already existInj :: (Monad m, MonadState Env m) => T.Ident -> m (Maybe Type) existInj n = gets (M.lookup n . injections) +setCurrentBind :: T.Ident -> Infer () +setCurrentBind i = modify (\st -> st{currentBind = i}) + +solveUndecidable :: Infer Subst +solveUndecidable = do + sigs <- gets sigs + undecided <- gets undecidedSigs + ys <- + maybeToRightM + (Error "SIGNATURE MISSING" False) + ( mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) $ + M.toList undecided + ) + composeAll <$> mapM (uncurry unify) ys + +getOriginal :: T.Ident -> T.Ident +getOriginal (T.Ident i) = coerce $ takeWhile (/= delim) $ drop 1 i + delim :: Char delim = '_' prefix :: Char @@ -812,7 +785,7 @@ dataErr ma d = ) initCtx = Ctx mempty -initEnv = Env 0 'a' mempty mempty mempty mempty "" mempty mempty +initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty run :: Infer a -> Either Error a run = run' initEnv initCtx @@ -831,28 +804,19 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type} data Env = Env { count :: Int , nextChar :: Char - , sigs :: Map T.Ident (Level Type) + , sigs :: Map T.Ident (Maybe Type) , takenTypeVars :: Set T.Ident , injections :: Map T.Ident Type - , declaredBinds :: Set T.Ident , currentBind :: T.Ident - , bindSubs :: Map T.Ident Subst - , bindUsages :: Map T.Ident [T.Ident] + , undecidedSigs :: Map T.Ident Type + , toDecide :: Set T.Ident + , declaredBinds :: Set T.Ident } deriving (Show) -data Level a = Instantiated {unlevel :: a} | Generalized {unlevel :: a} - deriving (Show) - data Error = Error {msg :: String, catchable :: Bool} deriving (Show) - -newtype Subst = Subst (Map T.Ident Type) - -instance Show Subst where - show (Subst s) = "[ " ++ intercalate " | " xs ++ " ]" - where - xs = map (\(a, b) -> printTree a ++ " = " ++ printTree b) $ M.toList s +type Subst = Map T.Ident Type newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env) @@ -868,3 +832,4 @@ quote s = "'" ++ s ++ "'" ctrace :: (Monad m, Show a) => String -> a -> m () ctrace str a = trace (str ++ ": " ++ show a) pure () +