Remade <<=, better err msg, removed writer monad

This commit is contained in:
sebastian 2023-05-17 17:31:08 +02:00
parent 5eaf7ae00d
commit de1ca23db7
2 changed files with 114 additions and 71 deletions

View file

@ -4,7 +4,6 @@ import Control.Monad ((<=<))
import qualified Grammar.Abs as G
import Grammar.ErrM (Err)
import TypeChecker.RemoveForall (removeForall)
import qualified TypeChecker.ReportTEVar as R
import TypeChecker.ReportTEVar (reportTEVar)
import qualified TypeChecker.TypeCheckerBidir as Bi
import qualified TypeChecker.TypeCheckerHm as Hm
@ -17,4 +16,4 @@ typecheck tc = fmap removeForall . (reportTEVar <=< f)
where
f = case tc of
Bi -> Bi.typecheck
Hm -> fmap fst . Hm.typecheck
Hm -> Hm.typecheck

View file

@ -15,7 +15,6 @@ import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl')
@ -27,10 +26,9 @@ import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr (T, T')
import TypeChecker.TypeCheckerIr qualified as T
import Debug.Trace (trace)
-- | Type check a program
typecheck :: Program -> Either String (T.Program' Type, [Warning])
typecheck :: Program -> Either String (T.Program' Type)
typecheck = onLeft msg . run . checkPrg
where
onLeft :: (Error -> String) -> Either Error a -> Either String a
@ -108,9 +106,9 @@ checkBind (Bind name args e) = do
Error
( Aux.do
"Inferred type"
quote $ printTree genInfSig
pretty genInfSig
"doesn't match given type"
quote $ printTree typSig
pretty typSig
)
False
)
@ -146,7 +144,7 @@ checkInj (Inj c inj_typ) name tvars
"Constructor"
quote $ coerce name
"with type"
quote $ printTree t
pretty t
"already exist"
Nothing -> insertInj (coerce c) inj_typ
| otherwise =
@ -155,9 +153,9 @@ checkInj (Inj c inj_typ) name tvars
[ "Bad type constructor: "
, show name
, "\nExpected: "
, printTree . TData name $ map TVar tvars
, printTree . TData name $ map (clean . TVar) tvars
, "\nActual: "
, printTree $ returnType inj_typ
, pretty $ returnType inj_typ
]
toTVar :: Type -> Either Error TVar
@ -203,9 +201,9 @@ algoW = \case
b
( uncatchableErr $ Aux.do
"Annotated type"
quote $ printTree t
"does not match inferred type"
quote $ printTree t'
pretty t
"is more polymorphic than the inferred type"
pretty t'
)
let comp = sub1 <> sub0
return (comp, (apply comp e', t))
@ -327,7 +325,6 @@ checkCase expT brnchs = do
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
-- compose all probably wrong
let sub0 = composeAll subs
trace ("Substitutions: " ++ show subs) pure ()
(sub1, _) <-
foldM
( \(sub, acc) x ->
@ -453,17 +450,17 @@ unify t0 t1 = case (t0, t1) of
Aux.do
"Type constructor:"
printTree name
quote $ printTree t
quote $ printTree (map clean t)
"does not match with:"
printTree name'
quote $ printTree t'
quote $ printTree (map clean t')
(a, b) -> do
catchableErr $
Aux.do
"Can not unify"
quote $ printTree a
pretty a
"with"
quote $ printTree b
pretty b
{- | Check if a type is contained in another type.
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
@ -476,9 +473,9 @@ occurs i t
catchableErr
( Aux.do
"Occurs check failed, can't unify"
quote $ printTree (TVar i)
pretty $ TVar i
"with"
quote $ printTree t
pretty t
)
| otherwise = return $ singleton i t
@ -529,56 +526,88 @@ fresh :: Infer Type
fresh = do
n <- gets count
modify (\st -> st{count = succ (count st)})
return . TVar . MkTVar . LIdent $ letters !! n
return . TVar . MkTVar . LIdent $ '1' : (letters !! n)
-- Is the left more general than the right
-- TODO: A bug might exist
(<<=) :: Type -> Type -> Infer Bool
(<<=) a b = case (a, b) of
(TVar _, _) -> return True
(TFun a b, TFun c d) -> do
bfirst <- a <<= c
bsecond <- b <<= d
return (bfirst && bsecond)
(TData n1 ts1, TData n2 ts2) -> do
b <- and <$> zipWithM (<<=) ts1 ts2
return (b && n1 == n2 && length ts1 == length ts2)
(t1@(TAll _ _), t2) ->
let (tvars1, t1') = gatherTVars [] t1
(tvars2, t2') = gatherTVars [] t2
in go (tvars1 ++ tvars2) t1' t2'
(t1, t2@(TAll _ _)) ->
let (tvars1, t1') = gatherTVars [] t1
(tvars2, t2') = gatherTVars [] t2
in go (tvars1 ++ tvars2) t1' t2'
(t1, t2) -> return $ t1 == t2
where
go :: [TVar] -> Type -> Type -> Infer Bool
go tvars t1 t2 = do
freshies <- mapM (const fresh) tvars
let sub = Subst . M.fromList $ zip tvars freshies
let t1' = apply sub t1
let t2' = apply sub t2
let alph = Subst $ execState (alpha t1' t2') mempty
return $ apply alph t1' == t2'
-- (<<=) :: Type -> Type -> Infer Bool
-- (<<=) a b = case (a, b) of
-- (TVar _, _) -> return True
-- (TFun a b, TFun c d) -> do
-- bfirst <- a <<= c
-- bsecond <- b <<= d
-- return (bfirst && bsecond)
-- (TData n1 ts1, TData n2 ts2) -> do
-- b <- and <$> zipWithM (<<=) ts1 ts2
-- return (b && n1 == n2 && length ts1 == length ts2)
-- (t1@(TAll _ _), t2) ->
-- let (tvars1, t1') = gatherTVars [] t1
-- (tvars2, t2') = gatherTVars [] t2
-- in go (tvars1 ++ tvars2) t1' t2'
-- (t1, t2@(TAll _ _)) ->
-- let (tvars1, t1') = gatherTVars [] t1
-- (tvars2, t2') = gatherTVars [] t2
-- in go (tvars1 ++ tvars2) t1' t2'
-- (t1, t2) -> return $ t1 == t2
-- where
-- go :: [TVar] -> Type -> Type -> Infer Bool
-- go tvars t1 t2 = do
-- freshies <- mapM (const fresh) tvars
-- let sub = Subst . M.fromList $ zip tvars freshies
-- let t1' = apply sub t1
-- let t2' = apply sub t2
-- let alph = Subst $ execState (alpha t1' t2') mempty
-- return $ apply alph t1' == t2'
-- -- Alpha rename the first type's type variable to match second.
-- -- Pre-condition: No TAll are checked
-- alpha :: Type -> Type -> State (Map TVar Type) ()
-- alpha t1 t2 = case (t1, t2) of
-- (TVar i, t2) -> do
-- m <- get
-- put (M.insert i t2 m)
-- (TFun t1 t2, TFun t3 t4) -> do
-- alpha t1 t3
-- alpha t2 t4
-- (TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2
-- _ -> return ()
-- Pre-condition: All TAlls are outermost
gatherTVars :: [TVar] -> Type -> ([TVar], Type)
gatherTVars tvars (TAll tvar t) = gatherTVars (tvar : tvars) t
gatherTVars tvars t = (tvars, t)
-- Alpha rename the first type's type variable to match second.
-- Pre-condition: No TAll are checked
alpha :: Type -> Type -> State (Map TVar Type) ()
alpha t1 t2 = case (t1, t2) of
(TVar i, t2) -> do
m <- get
put (M.insert i t2 m)
(TFun t1 t2, TFun t3 t4) -> do
alpha t1 t3
alpha t2 t4
(TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2
_ -> return ()
t1 = TAll (MkTVar "a") a
a = TVar $ MkTVar "a"
b = TVar $ MkTVar "b"
t2 = TFun a b
(<<=) :: Type -> Type -> Infer Bool
(<<=) t1@(TAll _ _ ) t2 =
let (tvars1, t1') = gatherTVars [] t1
(_, t2') = gatherTVars [] t2
(b1, vars) = runState (match t1' t2') mempty
b2 = all (`elem` tvars1) (M.keys vars)
in return $ b1 && b2
(<<=) t1 t2 = return $ t1 == t2
-- Left represents the one that should be more general
match :: Type -> Type -> State (Map TVar Type) Bool
match t1 t2 = case (t1, t2) of
(TVar a, t2) -> insertMatch a t2
(TFun t1 t2, TFun t3 t4) -> (&&) <$> match t1 t3 <*> match t2 t4
(TData ident1 ts1, TData ident2 ts2) ->
(ident1 == ident2 && ) . and <$> zipWithM match ts1 ts2
(t1, t2) -> return $ t1 == t2
where
insertMatch :: TVar -> Type -> State (Map TVar Type) Bool
insertMatch tvar t = do
m <- gets $ M.lookup tvar
case m of
Nothing -> modify (M.insert tvar t) >> return True
Just t' -> return $ t == t'
-- | A class for substitutions
class SubstType t where
@ -808,14 +837,13 @@ dataErr ma d =
initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty mempty
run :: Infer a -> Either Error (a, [Warning])
run :: Infer a -> Either Error a
run = run' initEnv initCtx
run' :: Env -> Ctx -> Infer a -> Either Error (a, [Warning])
run' :: Env -> Ctx -> Infer a -> Either Error a
run' e c =
runIdentity
. runExceptT
. runWriterT
. flip runReaderT c
. flip evalStateT e
. runInfer
@ -851,11 +879,16 @@ instance Semigroup Subst where
instance Monoid Subst where
mempty = Subst mempty
newtype Warning = NonExhaustive String
deriving (Show)
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (WriterT [Warning] (ExceptT Error Identity))) a}
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
newtype Infer a = Infer {runInfer :: StateT Env
(ReaderT Ctx
(ExceptT Error Identity)) a}
deriving ( Functor
, Applicative
, Monad
, MonadReader Ctx
, MonadError Error
, MonadState Env
)
catchableErr :: MonadError Error m => String -> m a
catchableErr msg = throwError $ Error msg True
@ -868,3 +901,14 @@ quote s = "'" ++ s ++ "'"
letters :: [String]
letters = [1 ..] >>= flip replicateM ['a' .. 'z']
clean :: Type -> Type
clean (TVar (MkTVar (LIdent a))) =
TVar . MkTVar . LIdent $ dropWhile (`notElem` ['a' .. 'z']) a
clean (TFun t1 t2) = TFun (clean t1) (clean t2)
clean (TData i ts) = TData i (map clean ts)
clean (TAll tvar t) = let (TVar tvar') = clean (TVar tvar)
in TAll tvar' (clean t)
clean t = t
pretty = quote . printTree . clean