Remade <<=, better err msg, removed writer monad
This commit is contained in:
parent
5eaf7ae00d
commit
de1ca23db7
2 changed files with 114 additions and 71 deletions
|
|
@ -4,7 +4,6 @@ import Control.Monad ((<=<))
|
||||||
import qualified Grammar.Abs as G
|
import qualified Grammar.Abs as G
|
||||||
import Grammar.ErrM (Err)
|
import Grammar.ErrM (Err)
|
||||||
import TypeChecker.RemoveForall (removeForall)
|
import TypeChecker.RemoveForall (removeForall)
|
||||||
import qualified TypeChecker.ReportTEVar as R
|
|
||||||
import TypeChecker.ReportTEVar (reportTEVar)
|
import TypeChecker.ReportTEVar (reportTEVar)
|
||||||
import qualified TypeChecker.TypeCheckerBidir as Bi
|
import qualified TypeChecker.TypeCheckerBidir as Bi
|
||||||
import qualified TypeChecker.TypeCheckerHm as Hm
|
import qualified TypeChecker.TypeCheckerHm as Hm
|
||||||
|
|
@ -17,4 +16,4 @@ typecheck tc = fmap removeForall . (reportTEVar <=< f)
|
||||||
where
|
where
|
||||||
f = case tc of
|
f = case tc of
|
||||||
Bi -> Bi.typecheck
|
Bi -> Bi.typecheck
|
||||||
Hm -> fmap fst . Hm.typecheck
|
Hm -> Hm.typecheck
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,6 @@ import Control.Monad.Except
|
||||||
import Control.Monad.Identity (Identity, runIdentity)
|
import Control.Monad.Identity (Identity, runIdentity)
|
||||||
import Control.Monad.Reader
|
import Control.Monad.Reader
|
||||||
import Control.Monad.State
|
import Control.Monad.State
|
||||||
import Control.Monad.Writer
|
|
||||||
import Data.Coerce (coerce)
|
import Data.Coerce (coerce)
|
||||||
import Data.Function (on)
|
import Data.Function (on)
|
||||||
import Data.List (foldl')
|
import Data.List (foldl')
|
||||||
|
|
@ -27,10 +26,9 @@ import Grammar.Abs
|
||||||
import Grammar.Print (printTree)
|
import Grammar.Print (printTree)
|
||||||
import TypeChecker.TypeCheckerIr (T, T')
|
import TypeChecker.TypeCheckerIr (T, T')
|
||||||
import TypeChecker.TypeCheckerIr qualified as T
|
import TypeChecker.TypeCheckerIr qualified as T
|
||||||
import Debug.Trace (trace)
|
|
||||||
|
|
||||||
-- | Type check a program
|
-- | Type check a program
|
||||||
typecheck :: Program -> Either String (T.Program' Type, [Warning])
|
typecheck :: Program -> Either String (T.Program' Type)
|
||||||
typecheck = onLeft msg . run . checkPrg
|
typecheck = onLeft msg . run . checkPrg
|
||||||
where
|
where
|
||||||
onLeft :: (Error -> String) -> Either Error a -> Either String a
|
onLeft :: (Error -> String) -> Either Error a -> Either String a
|
||||||
|
|
@ -108,9 +106,9 @@ checkBind (Bind name args e) = do
|
||||||
Error
|
Error
|
||||||
( Aux.do
|
( Aux.do
|
||||||
"Inferred type"
|
"Inferred type"
|
||||||
quote $ printTree genInfSig
|
pretty genInfSig
|
||||||
"doesn't match given type"
|
"doesn't match given type"
|
||||||
quote $ printTree typSig
|
pretty typSig
|
||||||
)
|
)
|
||||||
False
|
False
|
||||||
)
|
)
|
||||||
|
|
@ -146,7 +144,7 @@ checkInj (Inj c inj_typ) name tvars
|
||||||
"Constructor"
|
"Constructor"
|
||||||
quote $ coerce name
|
quote $ coerce name
|
||||||
"with type"
|
"with type"
|
||||||
quote $ printTree t
|
pretty t
|
||||||
"already exist"
|
"already exist"
|
||||||
Nothing -> insertInj (coerce c) inj_typ
|
Nothing -> insertInj (coerce c) inj_typ
|
||||||
| otherwise =
|
| otherwise =
|
||||||
|
|
@ -155,9 +153,9 @@ checkInj (Inj c inj_typ) name tvars
|
||||||
[ "Bad type constructor: "
|
[ "Bad type constructor: "
|
||||||
, show name
|
, show name
|
||||||
, "\nExpected: "
|
, "\nExpected: "
|
||||||
, printTree . TData name $ map TVar tvars
|
, printTree . TData name $ map (clean . TVar) tvars
|
||||||
, "\nActual: "
|
, "\nActual: "
|
||||||
, printTree $ returnType inj_typ
|
, pretty $ returnType inj_typ
|
||||||
]
|
]
|
||||||
|
|
||||||
toTVar :: Type -> Either Error TVar
|
toTVar :: Type -> Either Error TVar
|
||||||
|
|
@ -203,9 +201,9 @@ algoW = \case
|
||||||
b
|
b
|
||||||
( uncatchableErr $ Aux.do
|
( uncatchableErr $ Aux.do
|
||||||
"Annotated type"
|
"Annotated type"
|
||||||
quote $ printTree t
|
pretty t
|
||||||
"does not match inferred type"
|
"is more polymorphic than the inferred type"
|
||||||
quote $ printTree t'
|
pretty t'
|
||||||
)
|
)
|
||||||
let comp = sub1 <> sub0
|
let comp = sub1 <> sub0
|
||||||
return (comp, (apply comp e', t))
|
return (comp, (apply comp e', t))
|
||||||
|
|
@ -327,7 +325,6 @@ checkCase expT brnchs = do
|
||||||
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
|
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
|
||||||
-- compose all probably wrong
|
-- compose all probably wrong
|
||||||
let sub0 = composeAll subs
|
let sub0 = composeAll subs
|
||||||
trace ("Substitutions: " ++ show subs) pure ()
|
|
||||||
(sub1, _) <-
|
(sub1, _) <-
|
||||||
foldM
|
foldM
|
||||||
( \(sub, acc) x ->
|
( \(sub, acc) x ->
|
||||||
|
|
@ -453,17 +450,17 @@ unify t0 t1 = case (t0, t1) of
|
||||||
Aux.do
|
Aux.do
|
||||||
"Type constructor:"
|
"Type constructor:"
|
||||||
printTree name
|
printTree name
|
||||||
quote $ printTree t
|
quote $ printTree (map clean t)
|
||||||
"does not match with:"
|
"does not match with:"
|
||||||
printTree name'
|
printTree name'
|
||||||
quote $ printTree t'
|
quote $ printTree (map clean t')
|
||||||
(a, b) -> do
|
(a, b) -> do
|
||||||
catchableErr $
|
catchableErr $
|
||||||
Aux.do
|
Aux.do
|
||||||
"Can not unify"
|
"Can not unify"
|
||||||
quote $ printTree a
|
pretty a
|
||||||
"with"
|
"with"
|
||||||
quote $ printTree b
|
pretty b
|
||||||
|
|
||||||
{- | Check if a type is contained in another type.
|
{- | Check if a type is contained in another type.
|
||||||
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
|
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
|
||||||
|
|
@ -476,9 +473,9 @@ occurs i t
|
||||||
catchableErr
|
catchableErr
|
||||||
( Aux.do
|
( Aux.do
|
||||||
"Occurs check failed, can't unify"
|
"Occurs check failed, can't unify"
|
||||||
quote $ printTree (TVar i)
|
pretty $ TVar i
|
||||||
"with"
|
"with"
|
||||||
quote $ printTree t
|
pretty t
|
||||||
)
|
)
|
||||||
| otherwise = return $ singleton i t
|
| otherwise = return $ singleton i t
|
||||||
|
|
||||||
|
|
@ -529,56 +526,88 @@ fresh :: Infer Type
|
||||||
fresh = do
|
fresh = do
|
||||||
n <- gets count
|
n <- gets count
|
||||||
modify (\st -> st{count = succ (count st)})
|
modify (\st -> st{count = succ (count st)})
|
||||||
return . TVar . MkTVar . LIdent $ letters !! n
|
return . TVar . MkTVar . LIdent $ '1' : (letters !! n)
|
||||||
|
|
||||||
-- Is the left more general than the right
|
-- Is the left more general than the right
|
||||||
-- TODO: A bug might exist
|
-- TODO: A bug might exist
|
||||||
|
-- (<<=) :: Type -> Type -> Infer Bool
|
||||||
|
-- (<<=) a b = case (a, b) of
|
||||||
|
-- (TVar _, _) -> return True
|
||||||
|
-- (TFun a b, TFun c d) -> do
|
||||||
|
-- bfirst <- a <<= c
|
||||||
|
-- bsecond <- b <<= d
|
||||||
|
-- return (bfirst && bsecond)
|
||||||
|
-- (TData n1 ts1, TData n2 ts2) -> do
|
||||||
|
-- b <- and <$> zipWithM (<<=) ts1 ts2
|
||||||
|
-- return (b && n1 == n2 && length ts1 == length ts2)
|
||||||
|
-- (t1@(TAll _ _), t2) ->
|
||||||
|
-- let (tvars1, t1') = gatherTVars [] t1
|
||||||
|
-- (tvars2, t2') = gatherTVars [] t2
|
||||||
|
-- in go (tvars1 ++ tvars2) t1' t2'
|
||||||
|
-- (t1, t2@(TAll _ _)) ->
|
||||||
|
-- let (tvars1, t1') = gatherTVars [] t1
|
||||||
|
-- (tvars2, t2') = gatherTVars [] t2
|
||||||
|
-- in go (tvars1 ++ tvars2) t1' t2'
|
||||||
|
-- (t1, t2) -> return $ t1 == t2
|
||||||
|
-- where
|
||||||
|
-- go :: [TVar] -> Type -> Type -> Infer Bool
|
||||||
|
-- go tvars t1 t2 = do
|
||||||
|
-- freshies <- mapM (const fresh) tvars
|
||||||
|
-- let sub = Subst . M.fromList $ zip tvars freshies
|
||||||
|
-- let t1' = apply sub t1
|
||||||
|
-- let t2' = apply sub t2
|
||||||
|
-- let alph = Subst $ execState (alpha t1' t2') mempty
|
||||||
|
-- return $ apply alph t1' == t2'
|
||||||
|
|
||||||
|
-- -- Alpha rename the first type's type variable to match second.
|
||||||
|
-- -- Pre-condition: No TAll are checked
|
||||||
|
-- alpha :: Type -> Type -> State (Map TVar Type) ()
|
||||||
|
-- alpha t1 t2 = case (t1, t2) of
|
||||||
|
-- (TVar i, t2) -> do
|
||||||
|
-- m <- get
|
||||||
|
-- put (M.insert i t2 m)
|
||||||
|
-- (TFun t1 t2, TFun t3 t4) -> do
|
||||||
|
-- alpha t1 t3
|
||||||
|
-- alpha t2 t4
|
||||||
|
-- (TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2
|
||||||
|
-- _ -> return ()
|
||||||
|
-- Pre-condition: All TAlls are outermost
|
||||||
|
gatherTVars :: [TVar] -> Type -> ([TVar], Type)
|
||||||
|
gatherTVars tvars (TAll tvar t) = gatherTVars (tvar : tvars) t
|
||||||
|
gatherTVars tvars t = (tvars, t)
|
||||||
|
|
||||||
|
t1 = TAll (MkTVar "a") a
|
||||||
|
a = TVar $ MkTVar "a"
|
||||||
|
b = TVar $ MkTVar "b"
|
||||||
|
t2 = TFun a b
|
||||||
|
|
||||||
(<<=) :: Type -> Type -> Infer Bool
|
(<<=) :: Type -> Type -> Infer Bool
|
||||||
(<<=) a b = case (a, b) of
|
(<<=) t1@(TAll _ _ ) t2 =
|
||||||
(TVar _, _) -> return True
|
let (tvars1, t1') = gatherTVars [] t1
|
||||||
(TFun a b, TFun c d) -> do
|
(_, t2') = gatherTVars [] t2
|
||||||
bfirst <- a <<= c
|
(b1, vars) = runState (match t1' t2') mempty
|
||||||
bsecond <- b <<= d
|
b2 = all (`elem` tvars1) (M.keys vars)
|
||||||
return (bfirst && bsecond)
|
in return $ b1 && b2
|
||||||
(TData n1 ts1, TData n2 ts2) -> do
|
(<<=) t1 t2 = return $ t1 == t2
|
||||||
b <- and <$> zipWithM (<<=) ts1 ts2
|
|
||||||
return (b && n1 == n2 && length ts1 == length ts2)
|
-- Left represents the one that should be more general
|
||||||
(t1@(TAll _ _), t2) ->
|
match :: Type -> Type -> State (Map TVar Type) Bool
|
||||||
let (tvars1, t1') = gatherTVars [] t1
|
match t1 t2 = case (t1, t2) of
|
||||||
(tvars2, t2') = gatherTVars [] t2
|
(TVar a, t2) -> insertMatch a t2
|
||||||
in go (tvars1 ++ tvars2) t1' t2'
|
(TFun t1 t2, TFun t3 t4) -> (&&) <$> match t1 t3 <*> match t2 t4
|
||||||
(t1, t2@(TAll _ _)) ->
|
(TData ident1 ts1, TData ident2 ts2) ->
|
||||||
let (tvars1, t1') = gatherTVars [] t1
|
(ident1 == ident2 && ) . and <$> zipWithM match ts1 ts2
|
||||||
(tvars2, t2') = gatherTVars [] t2
|
|
||||||
in go (tvars1 ++ tvars2) t1' t2'
|
|
||||||
(t1, t2) -> return $ t1 == t2
|
(t1, t2) -> return $ t1 == t2
|
||||||
where
|
|
||||||
go :: [TVar] -> Type -> Type -> Infer Bool
|
|
||||||
go tvars t1 t2 = do
|
|
||||||
freshies <- mapM (const fresh) tvars
|
|
||||||
let sub = Subst . M.fromList $ zip tvars freshies
|
|
||||||
let t1' = apply sub t1
|
|
||||||
let t2' = apply sub t2
|
|
||||||
let alph = Subst $ execState (alpha t1' t2') mempty
|
|
||||||
return $ apply alph t1' == t2'
|
|
||||||
|
|
||||||
-- Pre-condition: All TAlls are outermost
|
|
||||||
gatherTVars :: [TVar] -> Type -> ([TVar], Type)
|
|
||||||
gatherTVars tvars (TAll tvar t) = gatherTVars (tvar : tvars) t
|
|
||||||
gatherTVars tvars t = (tvars, t)
|
|
||||||
|
|
||||||
-- Alpha rename the first type's type variable to match second.
|
where
|
||||||
-- Pre-condition: No TAll are checked
|
insertMatch :: TVar -> Type -> State (Map TVar Type) Bool
|
||||||
alpha :: Type -> Type -> State (Map TVar Type) ()
|
insertMatch tvar t = do
|
||||||
alpha t1 t2 = case (t1, t2) of
|
m <- gets $ M.lookup tvar
|
||||||
(TVar i, t2) -> do
|
case m of
|
||||||
m <- get
|
Nothing -> modify (M.insert tvar t) >> return True
|
||||||
put (M.insert i t2 m)
|
Just t' -> return $ t == t'
|
||||||
(TFun t1 t2, TFun t3 t4) -> do
|
|
||||||
alpha t1 t3
|
|
||||||
alpha t2 t4
|
|
||||||
(TData _ ts1, TData _ ts2) -> zipWithM_ alpha ts1 ts2
|
|
||||||
_ -> return ()
|
|
||||||
|
|
||||||
-- | A class for substitutions
|
-- | A class for substitutions
|
||||||
class SubstType t where
|
class SubstType t where
|
||||||
|
|
@ -808,14 +837,13 @@ dataErr ma d =
|
||||||
initCtx = Ctx mempty
|
initCtx = Ctx mempty
|
||||||
initEnv = Env 0 'a' mempty mempty mempty mempty
|
initEnv = Env 0 'a' mempty mempty mempty mempty
|
||||||
|
|
||||||
run :: Infer a -> Either Error (a, [Warning])
|
run :: Infer a -> Either Error a
|
||||||
run = run' initEnv initCtx
|
run = run' initEnv initCtx
|
||||||
|
|
||||||
run' :: Env -> Ctx -> Infer a -> Either Error (a, [Warning])
|
run' :: Env -> Ctx -> Infer a -> Either Error a
|
||||||
run' e c =
|
run' e c =
|
||||||
runIdentity
|
runIdentity
|
||||||
. runExceptT
|
. runExceptT
|
||||||
. runWriterT
|
|
||||||
. flip runReaderT c
|
. flip runReaderT c
|
||||||
. flip evalStateT e
|
. flip evalStateT e
|
||||||
. runInfer
|
. runInfer
|
||||||
|
|
@ -851,11 +879,16 @@ instance Semigroup Subst where
|
||||||
instance Monoid Subst where
|
instance Monoid Subst where
|
||||||
mempty = Subst mempty
|
mempty = Subst mempty
|
||||||
|
|
||||||
newtype Warning = NonExhaustive String
|
newtype Infer a = Infer {runInfer :: StateT Env
|
||||||
deriving (Show)
|
(ReaderT Ctx
|
||||||
|
(ExceptT Error Identity)) a}
|
||||||
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (WriterT [Warning] (ExceptT Error Identity))) a}
|
deriving ( Functor
|
||||||
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
|
, Applicative
|
||||||
|
, Monad
|
||||||
|
, MonadReader Ctx
|
||||||
|
, MonadError Error
|
||||||
|
, MonadState Env
|
||||||
|
)
|
||||||
|
|
||||||
catchableErr :: MonadError Error m => String -> m a
|
catchableErr :: MonadError Error m => String -> m a
|
||||||
catchableErr msg = throwError $ Error msg True
|
catchableErr msg = throwError $ Error msg True
|
||||||
|
|
@ -868,3 +901,14 @@ quote s = "'" ++ s ++ "'"
|
||||||
|
|
||||||
letters :: [String]
|
letters :: [String]
|
||||||
letters = [1 ..] >>= flip replicateM ['a' .. 'z']
|
letters = [1 ..] >>= flip replicateM ['a' .. 'z']
|
||||||
|
|
||||||
|
clean :: Type -> Type
|
||||||
|
clean (TVar (MkTVar (LIdent a))) =
|
||||||
|
TVar . MkTVar . LIdent $ dropWhile (`notElem` ['a' .. 'z']) a
|
||||||
|
clean (TFun t1 t2) = TFun (clean t1) (clean t2)
|
||||||
|
clean (TData i ts) = TData i (map clean ts)
|
||||||
|
clean (TAll tvar t) = let (TVar tvar') = clean (TVar tvar)
|
||||||
|
in TAll tvar' (clean t)
|
||||||
|
clean t = t
|
||||||
|
|
||||||
|
pretty = quote . printTree . clean
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue