diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 166e680..49cef01 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -27,22 +27,7 @@ import Grammar.Abs import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr qualified as T --- TODO: Save all substition sets encountered in the program and apply --- to all top level functions in the end. - -initCtx = Ctx mempty -initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty mempty - -run :: Infer a -> Either Error a -run = run' initEnv initCtx - -run' :: Env -> Ctx -> Infer a -> Either Error a -run' e c = - runIdentity - . runExceptT - . flip runReaderT c - . flip evalStateT e - . runInfer +-- TODO: Disallow mutual recursion -- | Type check a program typecheck :: Program -> Either String (T.Program' Type) @@ -73,29 +58,23 @@ preRun [] = return () preRun (x : xs) = case x of DSig (Sig n t) -> do collect (collectTVars t) - gets (M.member (coerce n) . sigs) - >>= flip - when - ( uncatchableErr $ Aux.do - "Duplicate signatures of function" - quote $ printTree n - ) + duplicateDecl n $ Aux.do + "Multiple signatures of function" + quote $ printTree n insertSig (coerce n) (Just t) >> preRun xs DBind (Bind n _ e) -> do - binds <- gets declaredBinds - when - (coerce n `S.member` binds) - ( uncatchableErr $ Aux.do - "Duplicate declarations of function" - quote $ printTree n - ) - modify (\st -> st{declaredBinds = S.insert (coerce n) st.declaredBinds}) + duplicateDecl n $ Aux.do + "Multiple declarations of function" + quote $ printTree n collect (collectTVars e) s <- gets sigs case M.lookup (coerce n) s of Nothing -> insertSig (coerce n) Nothing >> preRun xs Just _ -> preRun xs - DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs + DData d@(Data t _) -> let collected = collect (collectTVars t) in checkData d collected >> preRun xs + where + -- Check if function body / signature has been declared already + duplicateDecl n msg = gets (M.member (coerce n) . sigs) >>= flip when (uncatchableErr msg) checkDef :: [Def] -> Infer [T.Def' Type] checkDef [] = return [] @@ -126,10 +105,10 @@ checkBind bind@(Bind name args e) = do insertSig (coerce name) (Just lambda_t) return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) -checkData :: Data -> Infer () -checkData err@(Data typ injs) = do +checkData :: (MonadReader Ctx m, Monad m, MonadError Error m) => Data -> m () -> m () +checkData err@(Data typ injs) ma = do (name, tvars) <- go typ - dataErr (mapM_ (\i -> checkInj i name tvars) injs) err + dataErr (mapM_ (\i -> checkInj i name tvars ma) injs) err where go = \case TData name typs @@ -140,8 +119,8 @@ checkData err@(Data typ injs) = do uncatchableErr $ unwords ["Bad data type definition: ", printTree typ] -checkInj :: Inj -> UIdent -> [TVar] -> Infer () -checkInj (Inj c inj_typ) name tvars +checkInj :: (MonadError Error m, MonadReader Ctx m, Monad m) => Inj -> UIdent -> [TVar] -> m a -> m a +checkInj (Inj c inj_typ) name tvars ma | Right False <- boundTVars tvars inj_typ = catchableErr "Unbound type variables" | TData name' typs <- returnType inj_typ @@ -156,7 +135,7 @@ checkInj (Inj c inj_typ) name tvars "with type" quote $ printTree t "already exist" - Nothing -> insertInj (coerce c) inj_typ + Nothing -> insertInj (coerce c) inj_typ ma | otherwise = uncatchableErr $ unwords @@ -246,11 +225,11 @@ algoW = \case return (nullSubst, (T.EVar $ coerce i, x)) Nothing -> do sig <- gets sigs + cb <- gets currentBind case M.lookup (coerce i) sig of Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t)) Just Nothing -> do fr <- fresh - cb <- gets currentBind modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs}) return (nullSubst, (T.EVar $ coerce i, fr)) Nothing -> @@ -258,7 +237,7 @@ algoW = \case "Unbound variable: " <> printTree i EInj i -> do - constr <- gets injections + constr <- asks injections case M.lookup (coerce i) constr of Just t -> return (nullSubst, (T.EVar $ coerce i, t)) Nothing -> @@ -304,11 +283,13 @@ algoW = \case err@(EApp e0 e1) -> do fr <- fresh (s0, (e0', t0)) <- algoW e0 - (s1, (e1', t1)) <- algoW e1 - s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err - let t = apply s2 fr - let comp = s2 `compose` s1 `compose` s0 - return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) + applySt s0 $ do + modify (\st -> st{sigs = apply s0 st.sigs}) + (s1, (e1', t1)) <- algoW e1 + s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err + let t = apply s2 fr + let comp = s2 `compose` s1 `compose` s0 + return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) -- \| Γ ⊢ e₀ : τ, S₀ S₀Γ, x : S̅₀Γ̅(τ) ⊢ e₁ : τ', S₁ -- \| ---------------------------------------------- @@ -368,8 +349,38 @@ inferBranch (Branch pat expr) = do inferPattern :: Pattern -> Infer (T.Pattern' Type, Type) inferPattern = \case PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt) + PCatch -> (T.PCatch,) <$> fresh + PVar x -> do + fr <- fresh + let pvar = T.PVar (coerce x, fr) + return (pvar, fr) + PEnum p -> do + t <- asks (M.lookup (coerce p) . injections) + t <- + maybeToRightM + ( Error + ( Aux.do + "Constructor:" + quote $ printTree p + "does not exist" + ) + True + ) + t + unless + (typeLength t == 1) + ( catchableErr $ Aux.do + "The constructor" + quote $ printTree p + " should have " + show (typeLength t - 1) + " arguments but has been given 0" + ) + let (TData _data _ts) = t -- nasty nasty + frs <- mapM (const fresh) _ts + return (T.PEnum $ coerce p, TData _data frs) PInj constr patterns -> do - t <- gets (M.lookup (coerce constr) . injections) + t <- asks (M.lookup (coerce constr) . injections) t <- maybeToRightM ( Error @@ -399,36 +410,6 @@ inferPattern = \case ( T.PInj (coerce constr) (apply sub (map fst patterns)) , apply sub ret ) - PCatch -> (T.PCatch,) <$> fresh - PEnum p -> do - t <- gets (M.lookup (coerce p) . injections) - t <- - maybeToRightM - ( Error - ( Aux.do - "Constructor:" - quote $ printTree p - "does not exist" - ) - True - ) - t - unless - (typeLength t == 1) - ( catchableErr $ Aux.do - "The constructor" - quote $ printTree p - " should have " - show (typeLength t - 1) - " arguments but has been given 0" - ) - let (TData _data _ts) = t -- nasty nasty - frs <- mapM (const fresh) _ts - return (T.PEnum $ coerce p, TData _data frs) - PVar x -> do - fr <- fresh - let pvar = T.PVar (coerce x, fr) - return (pvar, fr) -- | Unify two types producing a new substitution unify :: Type -> Type -> Infer Subst @@ -437,7 +418,7 @@ unify t0 t1 = (TFun a b, TFun c d) -> do s1 <- unify a c s2 <- unify (apply s1 b) (apply s1 d) - return $ s1 `compose` s2 + return $ s2 `compose` s1 (TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t (t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t (TVar (MkTVar a), t) -> occurs (coerce a) t @@ -605,6 +586,9 @@ instance SubstType (Map T.Ident Type) where apply :: Subst -> Map T.Ident Type -> Map T.Ident Type apply = M.map . apply +instance SubstType (Map T.Ident (Maybe Type)) where + apply s = M.map (fmap $ apply s) + instance SubstType (T.ExpT' Type) where apply s (e, t) = (apply s e, apply s t) @@ -688,15 +672,18 @@ insertSig :: T.Ident -> Maybe Type -> Infer () insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) -- | Insert a constructor into the start with its type -insertInj :: T.Ident -> Type -> Infer () +insertInj :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a insertInj i t = - modify (\st -> st{injections = M.insert i t (injections st)}) + local (\st -> st{injections = M.insert i t (injections st)}) + +applySt :: Subst -> Infer a -> Infer a +applySt s = local (\st -> st{vars = apply s st.vars}) {- | Check if an injection (constructor of data type) with an equivalent name has been declared already -} -existInj :: T.Ident -> Infer (Maybe Type) -existInj n = gets (M.lookup n . injections) +existInj :: (Monad m, MonadReader Ctx m) => T.Ident -> m (Maybe Type) +existInj n = asks (M.lookup n . injections) setCurrentBind :: T.Ident -> Infer () setCurrentBind i = modify (\st -> st{currentBind = i}) @@ -705,11 +692,12 @@ solveUndecidable :: Infer Subst solveUndecidable = do sigs <- gets sigs undecided <- gets undecidedSigs - let xs = M.toList undecided ys <- maybeToRightM (Error "SIGNATURE MISSING" False) - (mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) xs) + ( mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) $ + M.toList undecided + ) composeAll <$> mapM (uncurry unify) ys tupSequence :: Monad m => (m a, b) -> m (a, b) @@ -738,48 +726,6 @@ litType (LChar _) = char int = TLit "Int" char = TLit "Char" -typeEq :: Type -> Type -> StateT Subst (ExceptT Error Identity) () -typeEq (TVar (MkTVar a)) t@(TVar _) = do - st <- get - case M.lookup (coerce a) st of - Nothing -> put $ M.insert (coerce a) t st - Just t' -> - unless - (t == t') - ( catchableErr $ Aux.do - quote $ printTree t - "does not match with" - quote $ printTree t' - ) -typeEq (TFun l r) (TFun l' r') = typeEq l l' *> typeEq r r' -typeEq (TAll _ l) (TAll _ r) = typeEq l r -typeEq t@(TLit a) t'@(TLit b) = - unless - (a == b) - ( catchableErr $ Aux.do - quote $ printTree t - "does not match with" - quote $ printTree t' - ) -typeEq t@(TData nameL tL) t'@(TData nameR tR) = do - unless - (nameL == nameR) - ( catchableErr $ Aux.do - quote $ printTree t - "does not match with" - quote $ printTree t' - ) - zipWithM_ typeEq tL tR -typeEq t@(TEVar _) t'@(TEVar _) = - catchableErr $ Aux.do - quote $ printTree t - "does not match with" - quote $ printTree t' -typeEq t t' = catchableErr $ Aux.do - quote $ printTree t - "does not match with" - quote $ printTree t' - {- | Catch an error if possible and add the given expression as addition to the error message -} @@ -824,7 +770,7 @@ bindErr ma bind = {- | Catch an error if possible and add the given data as addition to the error message -} -dataErr :: Infer a -> Data -> Infer a +dataErr :: (MonadError Error m, Monad m) => m a -> Data -> m a dataErr ma d = catchError ma @@ -850,19 +796,31 @@ unzip4 = ) ([], [], [], []) -newtype Ctx = Ctx {vars :: Map T.Ident Type} +initCtx = Ctx mempty mempty +initEnv = Env 0 'a' mempty mempty "" mempty mempty + +run :: Infer a -> Either Error a +run = run' initEnv initCtx + +run' :: Env -> Ctx -> Infer a -> Either Error a +run' e c = + runIdentity + . runExceptT + . flip runReaderT c + . flip evalStateT e + . runInfer + +data Ctx = Ctx {vars :: Map T.Ident Type, injections :: Map T.Ident Type} deriving (Show) data Env = Env { count :: Int , nextChar :: Char , sigs :: Map T.Ident (Maybe Type) - , injections :: Map T.Ident Type , takenTypeVars :: Set T.Ident , currentBind :: T.Ident , undecidedSigs :: Map T.Ident Type , toDecide :: Set T.Ident - , declaredBinds :: Set T.Ident } deriving (Show)