From de1ca23db77ac6eb975b710cd083dc687c6f19dd Mon Sep 17 00:00:00 2001 From: sebastian Date: Wed, 17 May 2023 17:31:08 +0200 Subject: [PATCH] Remade <<=, better err msg, removed writer monad --- src/TypeChecker/TypeChecker.hs | 3 +- src/TypeChecker/TypeCheckerHm.hs | 182 +++++++++++++++++++------------ 2 files changed, 114 insertions(+), 71 deletions(-) diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 7f3d67a..008f086 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -4,7 +4,6 @@ import Control.Monad ((<=<)) import qualified Grammar.Abs as G import Grammar.ErrM (Err) import TypeChecker.RemoveForall (removeForall) -import qualified TypeChecker.ReportTEVar as R import TypeChecker.ReportTEVar (reportTEVar) import qualified TypeChecker.TypeCheckerBidir as Bi import qualified TypeChecker.TypeCheckerHm as Hm @@ -17,4 +16,4 @@ typecheck tc = fmap removeForall . (reportTEVar <=< f) where f = case tc of Bi -> Bi.typecheck - Hm -> fmap fst . Hm.typecheck + Hm -> Hm.typecheck diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 73a9bc8..5e76846 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -15,7 +15,6 @@ import Control.Monad.Except import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Reader import Control.Monad.State -import Control.Monad.Writer import Data.Coerce (coerce) import Data.Function (on) import Data.List (foldl') @@ -27,10 +26,9 @@ import Grammar.Abs import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr (T, T') import TypeChecker.TypeCheckerIr qualified as T -import Debug.Trace (trace) -- | Type check a program -typecheck :: Program -> Either String (T.Program' Type, [Warning]) +typecheck :: Program -> Either String (T.Program' Type) typecheck = onLeft msg . run . checkPrg where onLeft :: (Error -> String) -> Either Error a -> Either String a @@ -108,9 +106,9 @@ checkBind (Bind name args e) = do Error ( Aux.do "Inferred type" - quote $ printTree genInfSig + pretty genInfSig "doesn't match given type" - quote $ printTree typSig + pretty typSig ) False ) @@ -146,7 +144,7 @@ checkInj (Inj c inj_typ) name tvars "Constructor" quote $ coerce name "with type" - quote $ printTree t + pretty t "already exist" Nothing -> insertInj (coerce c) inj_typ | otherwise = @@ -155,9 +153,9 @@ checkInj (Inj c inj_typ) name tvars [ "Bad type constructor: " , show name , "\nExpected: " - , printTree . TData name $ map TVar tvars + , printTree . TData name $ map (clean . TVar) tvars , "\nActual: " - , printTree $ returnType inj_typ + , pretty $ returnType inj_typ ] toTVar :: Type -> Either Error TVar @@ -203,9 +201,9 @@ algoW = \case b ( uncatchableErr $ Aux.do "Annotated type" - quote $ printTree t - "does not match inferred type" - quote $ printTree t' + pretty t + "is more polymorphic than the inferred type" + pretty t' ) let comp = sub1 <> sub0 return (comp, (apply comp e', t)) @@ -327,7 +325,6 @@ checkCase expT brnchs = do (subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs -- compose all probably wrong let sub0 = composeAll subs - trace ("Substitutions: " ++ show subs) pure () (sub1, _) <- foldM ( \(sub, acc) x -> @@ -453,17 +450,17 @@ unify t0 t1 = case (t0, t1) of Aux.do "Type constructor:" printTree name - quote $ printTree t + quote $ printTree (map clean t) "does not match with:" printTree name' - quote $ printTree t' + quote $ printTree (map clean t') (a, b) -> do catchableErr $ Aux.do "Can not unify" - quote $ printTree a + pretty a "with" - quote $ printTree b + pretty b {- | Check if a type is contained in another type. I.E. { a = a -> b } is an unsolvable constraint since there is no substitution @@ -476,9 +473,9 @@ occurs i t catchableErr ( Aux.do "Occurs check failed, can't unify" - quote $ printTree (TVar i) + pretty $ TVar i "with" - quote $ printTree t + pretty t ) | otherwise = return $ singleton i t @@ -529,56 +526,88 @@ fresh :: Infer Type fresh = do n <- gets count modify (\st -> st{count = succ (count st)}) - return . TVar . MkTVar . LIdent $ letters !! n + return . TVar . MkTVar . LIdent $ '1' : (letters !! n) -- Is the left more general than the right -- TODO: A bug might exist +-- (<<=) :: Type -> Type -> Infer Bool +-- (<<=) a b = case (a, b) of +-- (TVar _, _) -> return True +-- (TFun a b, TFun c d) -> do +-- bfirst <- a <<= c +-- bsecond <- b <<= d +-- return (bfirst && bsecond) +-- (TData n1 ts1, TData n2 ts2) -> do +-- b <- and <$> zipWithM (<<=) ts1 ts2 +-- return (b && n1 == n2 && length ts1 == length ts2) +-- (t1@(TAll _ _), t2) -> +-- let (tvars1, t1') = gatherTVars [] t1 +-- (tvars2, t2') = gatherTVars [] t2 +-- in go (tvars1 ++ tvars2) t1' t2' +-- (t1, t2@(TAll _ _)) -> +-- let (tvars1, t1') = gatherTVars [] t1 +-- (tvars2, t2') = gatherTVars [] t2 +-- in go (tvars1 ++ tvars2) t1' t2' +-- (t1, t2) -> return $ t1 == t2 +-- where +-- go :: [TVar] -> Type -> Type -> Infer Bool +-- go tvars t1 t2 = do +-- freshies <- mapM (const fresh) tvars +-- let sub = Subst . M.fromList $ zip tvars freshies +-- let t1' = apply sub t1 +-- let t2' = apply sub t2 +-- let alph = Subst $ execState (alpha t1' t2') mempty +-- return $ apply alph t1' == t2' + +-- -- Alpha rename the first type's type variable to match second. +-- -- Pre-condition: No TAll are checked +-- alpha :: Type -> Type -> State (Map TVar Type) () +-- alpha t1 t2 = case (t1, t2) of +-- (TVar i, t2) -> do +-- m <- get +-- put (M.insert i t2 m) +-- (TFun t1 t2, TFun t3 t4) -> do +-- alpha t1 t3 +-- alpha t2 t4 +-- (TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2 +-- _ -> return () +-- Pre-condition: All TAlls are outermost +gatherTVars :: [TVar] -> Type -> ([TVar], Type) +gatherTVars tvars (TAll tvar t) = gatherTVars (tvar : tvars) t +gatherTVars tvars t = (tvars, t) + +t1 = TAll (MkTVar "a") a +a = TVar $ MkTVar "a" +b = TVar $ MkTVar "b" +t2 = TFun a b + (<<=) :: Type -> Type -> Infer Bool -(<<=) a b = case (a, b) of - (TVar _, _) -> return True - (TFun a b, TFun c d) -> do - bfirst <- a <<= c - bsecond <- b <<= d - return (bfirst && bsecond) - (TData n1 ts1, TData n2 ts2) -> do - b <- and <$> zipWithM (<<=) ts1 ts2 - return (b && n1 == n2 && length ts1 == length ts2) - (t1@(TAll _ _), t2) -> - let (tvars1, t1') = gatherTVars [] t1 - (tvars2, t2') = gatherTVars [] t2 - in go (tvars1 ++ tvars2) t1' t2' - (t1, t2@(TAll _ _)) -> - let (tvars1, t1') = gatherTVars [] t1 - (tvars2, t2') = gatherTVars [] t2 - in go (tvars1 ++ tvars2) t1' t2' +(<<=) t1@(TAll _ _ ) t2 = + let (tvars1, t1') = gatherTVars [] t1 + (_, t2') = gatherTVars [] t2 + (b1, vars) = runState (match t1' t2') mempty + b2 = all (`elem` tvars1) (M.keys vars) + in return $ b1 && b2 +(<<=) t1 t2 = return $ t1 == t2 + +-- Left represents the one that should be more general +match :: Type -> Type -> State (Map TVar Type) Bool +match t1 t2 = case (t1, t2) of + (TVar a, t2) -> insertMatch a t2 + (TFun t1 t2, TFun t3 t4) -> (&&) <$> match t1 t3 <*> match t2 t4 + (TData ident1 ts1, TData ident2 ts2) -> + (ident1 == ident2 && ) . and <$> zipWithM match ts1 ts2 (t1, t2) -> return $ t1 == t2 - where - go :: [TVar] -> Type -> Type -> Infer Bool - go tvars t1 t2 = do - freshies <- mapM (const fresh) tvars - let sub = Subst . M.fromList $ zip tvars freshies - let t1' = apply sub t1 - let t2' = apply sub t2 - let alph = Subst $ execState (alpha t1' t2') mempty - return $ apply alph t1' == t2' - -- Pre-condition: All TAlls are outermost - gatherTVars :: [TVar] -> Type -> ([TVar], Type) - gatherTVars tvars (TAll tvar t) = gatherTVars (tvar : tvars) t - gatherTVars tvars t = (tvars, t) - -- Alpha rename the first type's type variable to match second. - -- Pre-condition: No TAll are checked - alpha :: Type -> Type -> State (Map TVar Type) () - alpha t1 t2 = case (t1, t2) of - (TVar i, t2) -> do - m <- get - put (M.insert i t2 m) - (TFun t1 t2, TFun t3 t4) -> do - alpha t1 t3 - alpha t2 t4 - (TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2 - _ -> return () + where + insertMatch :: TVar -> Type -> State (Map TVar Type) Bool + insertMatch tvar t = do + m <- gets $ M.lookup tvar + case m of + Nothing -> modify (M.insert tvar t) >> return True + Just t' -> return $ t == t' + -- | A class for substitutions class SubstType t where @@ -808,14 +837,13 @@ dataErr ma d = initCtx = Ctx mempty initEnv = Env 0 'a' mempty mempty mempty mempty -run :: Infer a -> Either Error (a, [Warning]) +run :: Infer a -> Either Error a run = run' initEnv initCtx -run' :: Env -> Ctx -> Infer a -> Either Error (a, [Warning]) +run' :: Env -> Ctx -> Infer a -> Either Error a run' e c = runIdentity . runExceptT - . runWriterT . flip runReaderT c . flip evalStateT e . runInfer @@ -851,11 +879,16 @@ instance Semigroup Subst where instance Monoid Subst where mempty = Subst mempty -newtype Warning = NonExhaustive String - deriving (Show) - -newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (WriterT [Warning] (ExceptT Error Identity))) a} - deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env) +newtype Infer a = Infer {runInfer :: StateT Env + (ReaderT Ctx + (ExceptT Error Identity)) a} + deriving ( Functor + , Applicative + , Monad + , MonadReader Ctx + , MonadError Error + , MonadState Env + ) catchableErr :: MonadError Error m => String -> m a catchableErr msg = throwError $ Error msg True @@ -868,3 +901,14 @@ quote s = "'" ++ s ++ "'" letters :: [String] letters = [1 ..] >>= flip replicateM ['a' .. 'z'] + +clean :: Type -> Type +clean (TVar (MkTVar (LIdent a))) = + TVar . MkTVar . LIdent $ dropWhile (`notElem` ['a' .. 'z']) a +clean (TFun t1 t2) = TFun (clean t1) (clean t2) +clean (TData i ts) = TData i (map clean ts) +clean (TAll tvar t) = let (TVar tvar') = clean (TVar tvar) + in TAll tvar' (clean t) +clean t = t + +pretty = quote . printTree . clean