Remade <<=, better err msg, removed writer monad

This commit is contained in:
sebastian 2023-05-17 17:31:08 +02:00
parent 5eaf7ae00d
commit de1ca23db7
2 changed files with 114 additions and 71 deletions

View file

@ -4,7 +4,6 @@ import Control.Monad ((<=<))
import qualified Grammar.Abs as G import qualified Grammar.Abs as G
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import TypeChecker.RemoveForall (removeForall) import TypeChecker.RemoveForall (removeForall)
import qualified TypeChecker.ReportTEVar as R
import TypeChecker.ReportTEVar (reportTEVar) import TypeChecker.ReportTEVar (reportTEVar)
import qualified TypeChecker.TypeCheckerBidir as Bi import qualified TypeChecker.TypeCheckerBidir as Bi
import qualified TypeChecker.TypeCheckerHm as Hm import qualified TypeChecker.TypeCheckerHm as Hm
@ -17,4 +16,4 @@ typecheck tc = fmap removeForall . (reportTEVar <=< f)
where where
f = case tc of f = case tc of
Bi -> Bi.typecheck Bi -> Bi.typecheck
Hm -> fmap fst . Hm.typecheck Hm -> Hm.typecheck

View file

@ -15,7 +15,6 @@ import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Control.Monad.Writer
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Function (on) import Data.Function (on)
import Data.List (foldl') import Data.List (foldl')
@ -27,10 +26,9 @@ import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr (T, T') import TypeChecker.TypeCheckerIr (T, T')
import TypeChecker.TypeCheckerIr qualified as T import TypeChecker.TypeCheckerIr qualified as T
import Debug.Trace (trace)
-- | Type check a program -- | Type check a program
typecheck :: Program -> Either String (T.Program' Type, [Warning]) typecheck :: Program -> Either String (T.Program' Type)
typecheck = onLeft msg . run . checkPrg typecheck = onLeft msg . run . checkPrg
where where
onLeft :: (Error -> String) -> Either Error a -> Either String a onLeft :: (Error -> String) -> Either Error a -> Either String a
@ -108,9 +106,9 @@ checkBind (Bind name args e) = do
Error Error
( Aux.do ( Aux.do
"Inferred type" "Inferred type"
quote $ printTree genInfSig pretty genInfSig
"doesn't match given type" "doesn't match given type"
quote $ printTree typSig pretty typSig
) )
False False
) )
@ -146,7 +144,7 @@ checkInj (Inj c inj_typ) name tvars
"Constructor" "Constructor"
quote $ coerce name quote $ coerce name
"with type" "with type"
quote $ printTree t pretty t
"already exist" "already exist"
Nothing -> insertInj (coerce c) inj_typ Nothing -> insertInj (coerce c) inj_typ
| otherwise = | otherwise =
@ -155,9 +153,9 @@ checkInj (Inj c inj_typ) name tvars
[ "Bad type constructor: " [ "Bad type constructor: "
, show name , show name
, "\nExpected: " , "\nExpected: "
, printTree . TData name $ map TVar tvars , printTree . TData name $ map (clean . TVar) tvars
, "\nActual: " , "\nActual: "
, printTree $ returnType inj_typ , pretty $ returnType inj_typ
] ]
toTVar :: Type -> Either Error TVar toTVar :: Type -> Either Error TVar
@ -203,9 +201,9 @@ algoW = \case
b b
( uncatchableErr $ Aux.do ( uncatchableErr $ Aux.do
"Annotated type" "Annotated type"
quote $ printTree t pretty t
"does not match inferred type" "is more polymorphic than the inferred type"
quote $ printTree t' pretty t'
) )
let comp = sub1 <> sub0 let comp = sub1 <> sub0
return (comp, (apply comp e', t)) return (comp, (apply comp e', t))
@ -327,7 +325,6 @@ checkCase expT brnchs = do
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs (subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
-- compose all probably wrong -- compose all probably wrong
let sub0 = composeAll subs let sub0 = composeAll subs
trace ("Substitutions: " ++ show subs) pure ()
(sub1, _) <- (sub1, _) <-
foldM foldM
( \(sub, acc) x -> ( \(sub, acc) x ->
@ -453,17 +450,17 @@ unify t0 t1 = case (t0, t1) of
Aux.do Aux.do
"Type constructor:" "Type constructor:"
printTree name printTree name
quote $ printTree t quote $ printTree (map clean t)
"does not match with:" "does not match with:"
printTree name' printTree name'
quote $ printTree t' quote $ printTree (map clean t')
(a, b) -> do (a, b) -> do
catchableErr $ catchableErr $
Aux.do Aux.do
"Can not unify" "Can not unify"
quote $ printTree a pretty a
"with" "with"
quote $ printTree b pretty b
{- | Check if a type is contained in another type. {- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
@ -476,9 +473,9 @@ occurs i t
catchableErr catchableErr
( Aux.do ( Aux.do
"Occurs check failed, can't unify" "Occurs check failed, can't unify"
quote $ printTree (TVar i) pretty $ TVar i
"with" "with"
quote $ printTree t pretty t
) )
| otherwise = return $ singleton i t | otherwise = return $ singleton i t
@ -529,56 +526,88 @@ fresh :: Infer Type
fresh = do fresh = do
n <- gets count n <- gets count
modify (\st -> st{count = succ (count st)}) modify (\st -> st{count = succ (count st)})
return . TVar . MkTVar . LIdent $ letters !! n return . TVar . MkTVar . LIdent $ '1' : (letters !! n)
-- Is the left more general than the right -- Is the left more general than the right
-- TODO: A bug might exist -- TODO: A bug might exist
-- (<<=) :: Type -> Type -> Infer Bool
-- (<<=) a b = case (a, b) of
-- (TVar _, _) -> return True
-- (TFun a b, TFun c d) -> do
-- bfirst <- a <<= c
-- bsecond <- b <<= d
-- return (bfirst && bsecond)
-- (TData n1 ts1, TData n2 ts2) -> do
-- b <- and <$> zipWithM (<<=) ts1 ts2
-- return (b && n1 == n2 && length ts1 == length ts2)
-- (t1@(TAll _ _), t2) ->
-- let (tvars1, t1') = gatherTVars [] t1
-- (tvars2, t2') = gatherTVars [] t2
-- in go (tvars1 ++ tvars2) t1' t2'
-- (t1, t2@(TAll _ _)) ->
-- let (tvars1, t1') = gatherTVars [] t1
-- (tvars2, t2') = gatherTVars [] t2
-- in go (tvars1 ++ tvars2) t1' t2'
-- (t1, t2) -> return $ t1 == t2
-- where
-- go :: [TVar] -> Type -> Type -> Infer Bool
-- go tvars t1 t2 = do
-- freshies <- mapM (const fresh) tvars
-- let sub = Subst . M.fromList $ zip tvars freshies
-- let t1' = apply sub t1
-- let t2' = apply sub t2
-- let alph = Subst $ execState (alpha t1' t2') mempty
-- return $ apply alph t1' == t2'
-- -- Alpha rename the first type's type variable to match second.
-- -- Pre-condition: No TAll are checked
-- alpha :: Type -> Type -> State (Map TVar Type) ()
-- alpha t1 t2 = case (t1, t2) of
-- (TVar i, t2) -> do
-- m <- get
-- put (M.insert i t2 m)
-- (TFun t1 t2, TFun t3 t4) -> do
-- alpha t1 t3
-- alpha t2 t4
-- (TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2
-- _ -> return ()
-- Pre-condition: All TAlls are outermost
gatherTVars :: [TVar] -> Type -> ([TVar], Type)
gatherTVars tvars (TAll tvar t) = gatherTVars (tvar : tvars) t
gatherTVars tvars t = (tvars, t)
t1 = TAll (MkTVar "a") a
a = TVar $ MkTVar "a"
b = TVar $ MkTVar "b"
t2 = TFun a b
(<<=) :: Type -> Type -> Infer Bool (<<=) :: Type -> Type -> Infer Bool
(<<=) a b = case (a, b) of (<<=) t1@(TAll _ _ ) t2 =
(TVar _, _) -> return True let (tvars1, t1') = gatherTVars [] t1
(TFun a b, TFun c d) -> do (_, t2') = gatherTVars [] t2
bfirst <- a <<= c (b1, vars) = runState (match t1' t2') mempty
bsecond <- b <<= d b2 = all (`elem` tvars1) (M.keys vars)
return (bfirst && bsecond) in return $ b1 && b2
(TData n1 ts1, TData n2 ts2) -> do (<<=) t1 t2 = return $ t1 == t2
b <- and <$> zipWithM (<<=) ts1 ts2
return (b && n1 == n2 && length ts1 == length ts2) -- Left represents the one that should be more general
(t1@(TAll _ _), t2) -> match :: Type -> Type -> State (Map TVar Type) Bool
let (tvars1, t1') = gatherTVars [] t1 match t1 t2 = case (t1, t2) of
(tvars2, t2') = gatherTVars [] t2 (TVar a, t2) -> insertMatch a t2
in go (tvars1 ++ tvars2) t1' t2' (TFun t1 t2, TFun t3 t4) -> (&&) <$> match t1 t3 <*> match t2 t4
(t1, t2@(TAll _ _)) -> (TData ident1 ts1, TData ident2 ts2) ->
let (tvars1, t1') = gatherTVars [] t1 (ident1 == ident2 && ) . and <$> zipWithM match ts1 ts2
(tvars2, t2') = gatherTVars [] t2
in go (tvars1 ++ tvars2) t1' t2'
(t1, t2) -> return $ t1 == t2 (t1, t2) -> return $ t1 == t2
where
go :: [TVar] -> Type -> Type -> Infer Bool
go tvars t1 t2 = do
freshies <- mapM (const fresh) tvars
let sub = Subst . M.fromList $ zip tvars freshies
let t1' = apply sub t1
let t2' = apply sub t2
let alph = Subst $ execState (alpha t1' t2') mempty
return $ apply alph t1' == t2'
-- Pre-condition: All TAlls are outermost
gatherTVars :: [TVar] -> Type -> ([TVar], Type)
gatherTVars tvars (TAll tvar t) = gatherTVars (tvar : tvars) t
gatherTVars tvars t = (tvars, t)
-- Alpha rename the first type's type variable to match second. where
-- Pre-condition: No TAll are checked insertMatch :: TVar -> Type -> State (Map TVar Type) Bool
alpha :: Type -> Type -> State (Map TVar Type) () insertMatch tvar t = do
alpha t1 t2 = case (t1, t2) of m <- gets $ M.lookup tvar
(TVar i, t2) -> do case m of
m <- get Nothing -> modify (M.insert tvar t) >> return True
put (M.insert i t2 m) Just t' -> return $ t == t'
(TFun t1 t2, TFun t3 t4) -> do
alpha t1 t3
alpha t2 t4
(TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2
_ -> return ()
-- | A class for substitutions -- | A class for substitutions
class SubstType t where class SubstType t where
@ -808,14 +837,13 @@ dataErr ma d =
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty mempty initEnv = Env 0 'a' mempty mempty mempty mempty
run :: Infer a -> Either Error (a, [Warning]) run :: Infer a -> Either Error a
run = run' initEnv initCtx run = run' initEnv initCtx
run' :: Env -> Ctx -> Infer a -> Either Error (a, [Warning]) run' :: Env -> Ctx -> Infer a -> Either Error a
run' e c = run' e c =
runIdentity runIdentity
. runExceptT . runExceptT
. runWriterT
. flip runReaderT c . flip runReaderT c
. flip evalStateT e . flip evalStateT e
. runInfer . runInfer
@ -851,11 +879,16 @@ instance Semigroup Subst where
instance Monoid Subst where instance Monoid Subst where
mempty = Subst mempty mempty = Subst mempty
newtype Warning = NonExhaustive String newtype Infer a = Infer {runInfer :: StateT Env
deriving (Show) (ReaderT Ctx
(ExceptT Error Identity)) a}
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (WriterT [Warning] (ExceptT Error Identity))) a} deriving ( Functor
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env) , Applicative
, Monad
, MonadReader Ctx
, MonadError Error
, MonadState Env
)
catchableErr :: MonadError Error m => String -> m a catchableErr :: MonadError Error m => String -> m a
catchableErr msg = throwError $ Error msg True catchableErr msg = throwError $ Error msg True
@ -868,3 +901,14 @@ quote s = "'" ++ s ++ "'"
letters :: [String] letters :: [String]
letters = [1 ..] >>= flip replicateM ['a' .. 'z'] letters = [1 ..] >>= flip replicateM ['a' .. 'z']
clean :: Type -> Type
clean (TVar (MkTVar (LIdent a))) =
TVar . MkTVar . LIdent $ dropWhile (`notElem` ['a' .. 'z']) a
clean (TFun t1 t2) = TFun (clean t1) (clean t2)
clean (TData i ts) = TData i (map clean ts)
clean (TAll tvar t) = let (TVar tvar') = clean (TVar tvar)
in TAll tvar' (clean t)
clean t = t
pretty = quote . printTree . clean