Remade <<=, better err msg, removed writer monad
This commit is contained in:
parent
5eaf7ae00d
commit
de1ca23db7
2 changed files with 114 additions and 71 deletions
|
|
@ -4,7 +4,6 @@ import Control.Monad ((<=<))
|
|||
import qualified Grammar.Abs as G
|
||||
import Grammar.ErrM (Err)
|
||||
import TypeChecker.RemoveForall (removeForall)
|
||||
import qualified TypeChecker.ReportTEVar as R
|
||||
import TypeChecker.ReportTEVar (reportTEVar)
|
||||
import qualified TypeChecker.TypeCheckerBidir as Bi
|
||||
import qualified TypeChecker.TypeCheckerHm as Hm
|
||||
|
|
@ -17,4 +16,4 @@ typecheck tc = fmap removeForall . (reportTEVar <=< f)
|
|||
where
|
||||
f = case tc of
|
||||
Bi -> Bi.typecheck
|
||||
Hm -> fmap fst . Hm.typecheck
|
||||
Hm -> Hm.typecheck
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import Control.Monad.Except
|
|||
import Control.Monad.Identity (Identity, runIdentity)
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Control.Monad.Writer
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Function (on)
|
||||
import Data.List (foldl')
|
||||
|
|
@ -27,10 +26,9 @@ import Grammar.Abs
|
|||
import Grammar.Print (printTree)
|
||||
import TypeChecker.TypeCheckerIr (T, T')
|
||||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
import Debug.Trace (trace)
|
||||
|
||||
-- | Type check a program
|
||||
typecheck :: Program -> Either String (T.Program' Type, [Warning])
|
||||
typecheck :: Program -> Either String (T.Program' Type)
|
||||
typecheck = onLeft msg . run . checkPrg
|
||||
where
|
||||
onLeft :: (Error -> String) -> Either Error a -> Either String a
|
||||
|
|
@ -108,9 +106,9 @@ checkBind (Bind name args e) = do
|
|||
Error
|
||||
( Aux.do
|
||||
"Inferred type"
|
||||
quote $ printTree genInfSig
|
||||
pretty genInfSig
|
||||
"doesn't match given type"
|
||||
quote $ printTree typSig
|
||||
pretty typSig
|
||||
)
|
||||
False
|
||||
)
|
||||
|
|
@ -146,7 +144,7 @@ checkInj (Inj c inj_typ) name tvars
|
|||
"Constructor"
|
||||
quote $ coerce name
|
||||
"with type"
|
||||
quote $ printTree t
|
||||
pretty t
|
||||
"already exist"
|
||||
Nothing -> insertInj (coerce c) inj_typ
|
||||
| otherwise =
|
||||
|
|
@ -155,9 +153,9 @@ checkInj (Inj c inj_typ) name tvars
|
|||
[ "Bad type constructor: "
|
||||
, show name
|
||||
, "\nExpected: "
|
||||
, printTree . TData name $ map TVar tvars
|
||||
, printTree . TData name $ map (clean . TVar) tvars
|
||||
, "\nActual: "
|
||||
, printTree $ returnType inj_typ
|
||||
, pretty $ returnType inj_typ
|
||||
]
|
||||
|
||||
toTVar :: Type -> Either Error TVar
|
||||
|
|
@ -203,9 +201,9 @@ algoW = \case
|
|||
b
|
||||
( uncatchableErr $ Aux.do
|
||||
"Annotated type"
|
||||
quote $ printTree t
|
||||
"does not match inferred type"
|
||||
quote $ printTree t'
|
||||
pretty t
|
||||
"is more polymorphic than the inferred type"
|
||||
pretty t'
|
||||
)
|
||||
let comp = sub1 <> sub0
|
||||
return (comp, (apply comp e', t))
|
||||
|
|
@ -327,7 +325,6 @@ checkCase expT brnchs = do
|
|||
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
|
||||
-- compose all probably wrong
|
||||
let sub0 = composeAll subs
|
||||
trace ("Substitutions: " ++ show subs) pure ()
|
||||
(sub1, _) <-
|
||||
foldM
|
||||
( \(sub, acc) x ->
|
||||
|
|
@ -453,17 +450,17 @@ unify t0 t1 = case (t0, t1) of
|
|||
Aux.do
|
||||
"Type constructor:"
|
||||
printTree name
|
||||
quote $ printTree t
|
||||
quote $ printTree (map clean t)
|
||||
"does not match with:"
|
||||
printTree name'
|
||||
quote $ printTree t'
|
||||
quote $ printTree (map clean t')
|
||||
(a, b) -> do
|
||||
catchableErr $
|
||||
Aux.do
|
||||
"Can not unify"
|
||||
quote $ printTree a
|
||||
pretty a
|
||||
"with"
|
||||
quote $ printTree b
|
||||
pretty b
|
||||
|
||||
{- | Check if a type is contained in another type.
|
||||
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
|
||||
|
|
@ -476,9 +473,9 @@ occurs i t
|
|||
catchableErr
|
||||
( Aux.do
|
||||
"Occurs check failed, can't unify"
|
||||
quote $ printTree (TVar i)
|
||||
pretty $ TVar i
|
||||
"with"
|
||||
quote $ printTree t
|
||||
pretty t
|
||||
)
|
||||
| otherwise = return $ singleton i t
|
||||
|
||||
|
|
@ -529,56 +526,88 @@ fresh :: Infer Type
|
|||
fresh = do
|
||||
n <- gets count
|
||||
modify (\st -> st{count = succ (count st)})
|
||||
return . TVar . MkTVar . LIdent $ letters !! n
|
||||
return . TVar . MkTVar . LIdent $ '1' : (letters !! n)
|
||||
|
||||
-- Is the left more general than the right
|
||||
-- TODO: A bug might exist
|
||||
-- (<<=) :: Type -> Type -> Infer Bool
|
||||
-- (<<=) a b = case (a, b) of
|
||||
-- (TVar _, _) -> return True
|
||||
-- (TFun a b, TFun c d) -> do
|
||||
-- bfirst <- a <<= c
|
||||
-- bsecond <- b <<= d
|
||||
-- return (bfirst && bsecond)
|
||||
-- (TData n1 ts1, TData n2 ts2) -> do
|
||||
-- b <- and <$> zipWithM (<<=) ts1 ts2
|
||||
-- return (b && n1 == n2 && length ts1 == length ts2)
|
||||
-- (t1@(TAll _ _), t2) ->
|
||||
-- let (tvars1, t1') = gatherTVars [] t1
|
||||
-- (tvars2, t2') = gatherTVars [] t2
|
||||
-- in go (tvars1 ++ tvars2) t1' t2'
|
||||
-- (t1, t2@(TAll _ _)) ->
|
||||
-- let (tvars1, t1') = gatherTVars [] t1
|
||||
-- (tvars2, t2') = gatherTVars [] t2
|
||||
-- in go (tvars1 ++ tvars2) t1' t2'
|
||||
-- (t1, t2) -> return $ t1 == t2
|
||||
-- where
|
||||
-- go :: [TVar] -> Type -> Type -> Infer Bool
|
||||
-- go tvars t1 t2 = do
|
||||
-- freshies <- mapM (const fresh) tvars
|
||||
-- let sub = Subst . M.fromList $ zip tvars freshies
|
||||
-- let t1' = apply sub t1
|
||||
-- let t2' = apply sub t2
|
||||
-- let alph = Subst $ execState (alpha t1' t2') mempty
|
||||
-- return $ apply alph t1' == t2'
|
||||
|
||||
-- -- Alpha rename the first type's type variable to match second.
|
||||
-- -- Pre-condition: No TAll are checked
|
||||
-- alpha :: Type -> Type -> State (Map TVar Type) ()
|
||||
-- alpha t1 t2 = case (t1, t2) of
|
||||
-- (TVar i, t2) -> do
|
||||
-- m <- get
|
||||
-- put (M.insert i t2 m)
|
||||
-- (TFun t1 t2, TFun t3 t4) -> do
|
||||
-- alpha t1 t3
|
||||
-- alpha t2 t4
|
||||
-- (TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2
|
||||
-- _ -> return ()
|
||||
-- Pre-condition: All TAlls are outermost
|
||||
gatherTVars :: [TVar] -> Type -> ([TVar], Type)
|
||||
gatherTVars tvars (TAll tvar t) = gatherTVars (tvar : tvars) t
|
||||
gatherTVars tvars t = (tvars, t)
|
||||
|
||||
t1 = TAll (MkTVar "a") a
|
||||
a = TVar $ MkTVar "a"
|
||||
b = TVar $ MkTVar "b"
|
||||
t2 = TFun a b
|
||||
|
||||
(<<=) :: Type -> Type -> Infer Bool
|
||||
(<<=) a b = case (a, b) of
|
||||
(TVar _, _) -> return True
|
||||
(TFun a b, TFun c d) -> do
|
||||
bfirst <- a <<= c
|
||||
bsecond <- b <<= d
|
||||
return (bfirst && bsecond)
|
||||
(TData n1 ts1, TData n2 ts2) -> do
|
||||
b <- and <$> zipWithM (<<=) ts1 ts2
|
||||
return (b && n1 == n2 && length ts1 == length ts2)
|
||||
(t1@(TAll _ _), t2) ->
|
||||
(<<=) t1@(TAll _ _ ) t2 =
|
||||
let (tvars1, t1') = gatherTVars [] t1
|
||||
(tvars2, t2') = gatherTVars [] t2
|
||||
in go (tvars1 ++ tvars2) t1' t2'
|
||||
(t1, t2@(TAll _ _)) ->
|
||||
let (tvars1, t1') = gatherTVars [] t1
|
||||
(tvars2, t2') = gatherTVars [] t2
|
||||
in go (tvars1 ++ tvars2) t1' t2'
|
||||
(_, t2') = gatherTVars [] t2
|
||||
(b1, vars) = runState (match t1' t2') mempty
|
||||
b2 = all (`elem` tvars1) (M.keys vars)
|
||||
in return $ b1 && b2
|
||||
(<<=) t1 t2 = return $ t1 == t2
|
||||
|
||||
-- Left represents the one that should be more general
|
||||
match :: Type -> Type -> State (Map TVar Type) Bool
|
||||
match t1 t2 = case (t1, t2) of
|
||||
(TVar a, t2) -> insertMatch a t2
|
||||
(TFun t1 t2, TFun t3 t4) -> (&&) <$> match t1 t3 <*> match t2 t4
|
||||
(TData ident1 ts1, TData ident2 ts2) ->
|
||||
(ident1 == ident2 && ) . and <$> zipWithM match ts1 ts2
|
||||
(t1, t2) -> return $ t1 == t2
|
||||
|
||||
|
||||
where
|
||||
go :: [TVar] -> Type -> Type -> Infer Bool
|
||||
go tvars t1 t2 = do
|
||||
freshies <- mapM (const fresh) tvars
|
||||
let sub = Subst . M.fromList $ zip tvars freshies
|
||||
let t1' = apply sub t1
|
||||
let t2' = apply sub t2
|
||||
let alph = Subst $ execState (alpha t1' t2') mempty
|
||||
return $ apply alph t1' == t2'
|
||||
insertMatch :: TVar -> Type -> State (Map TVar Type) Bool
|
||||
insertMatch tvar t = do
|
||||
m <- gets $ M.lookup tvar
|
||||
case m of
|
||||
Nothing -> modify (M.insert tvar t) >> return True
|
||||
Just t' -> return $ t == t'
|
||||
|
||||
-- Pre-condition: All TAlls are outermost
|
||||
gatherTVars :: [TVar] -> Type -> ([TVar], Type)
|
||||
gatherTVars tvars (TAll tvar t) = gatherTVars (tvar : tvars) t
|
||||
gatherTVars tvars t = (tvars, t)
|
||||
|
||||
-- Alpha rename the first type's type variable to match second.
|
||||
-- Pre-condition: No TAll are checked
|
||||
alpha :: Type -> Type -> State (Map TVar Type) ()
|
||||
alpha t1 t2 = case (t1, t2) of
|
||||
(TVar i, t2) -> do
|
||||
m <- get
|
||||
put (M.insert i t2 m)
|
||||
(TFun t1 t2, TFun t3 t4) -> do
|
||||
alpha t1 t3
|
||||
alpha t2 t4
|
||||
(TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2
|
||||
_ -> return ()
|
||||
|
||||
-- | A class for substitutions
|
||||
class SubstType t where
|
||||
|
|
@ -808,14 +837,13 @@ dataErr ma d =
|
|||
initCtx = Ctx mempty
|
||||
initEnv = Env 0 'a' mempty mempty mempty mempty
|
||||
|
||||
run :: Infer a -> Either Error (a, [Warning])
|
||||
run :: Infer a -> Either Error a
|
||||
run = run' initEnv initCtx
|
||||
|
||||
run' :: Env -> Ctx -> Infer a -> Either Error (a, [Warning])
|
||||
run' :: Env -> Ctx -> Infer a -> Either Error a
|
||||
run' e c =
|
||||
runIdentity
|
||||
. runExceptT
|
||||
. runWriterT
|
||||
. flip runReaderT c
|
||||
. flip evalStateT e
|
||||
. runInfer
|
||||
|
|
@ -851,11 +879,16 @@ instance Semigroup Subst where
|
|||
instance Monoid Subst where
|
||||
mempty = Subst mempty
|
||||
|
||||
newtype Warning = NonExhaustive String
|
||||
deriving (Show)
|
||||
|
||||
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (WriterT [Warning] (ExceptT Error Identity))) a}
|
||||
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
|
||||
newtype Infer a = Infer {runInfer :: StateT Env
|
||||
(ReaderT Ctx
|
||||
(ExceptT Error Identity)) a}
|
||||
deriving ( Functor
|
||||
, Applicative
|
||||
, Monad
|
||||
, MonadReader Ctx
|
||||
, MonadError Error
|
||||
, MonadState Env
|
||||
)
|
||||
|
||||
catchableErr :: MonadError Error m => String -> m a
|
||||
catchableErr msg = throwError $ Error msg True
|
||||
|
|
@ -868,3 +901,14 @@ quote s = "'" ++ s ++ "'"
|
|||
|
||||
letters :: [String]
|
||||
letters = [1 ..] >>= flip replicateM ['a' .. 'z']
|
||||
|
||||
clean :: Type -> Type
|
||||
clean (TVar (MkTVar (LIdent a))) =
|
||||
TVar . MkTVar . LIdent $ dropWhile (`notElem` ['a' .. 'z']) a
|
||||
clean (TFun t1 t2) = TFun (clean t1) (clean t2)
|
||||
clean (TData i ts) = TData i (map clean ts)
|
||||
clean (TAll tvar t) = let (TVar tvar') = clean (TVar tvar)
|
||||
in TAll tvar' (clean t)
|
||||
clean t = t
|
||||
|
||||
pretty = quote . printTree . clean
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue