Bind now does correct subtype check.

This commit is contained in:
sebastianselander 2023-05-03 17:59:09 +02:00
parent f8a70b4cf4
commit 5a28f9d909

View file

@ -22,7 +22,7 @@ import Data.Map qualified as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace)
import Debug.Trace (trace, traceShow)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T
@ -158,36 +158,31 @@ freeOrdered _ = mempty
checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args))
(e, lambda_t) <- inferExp lambda
(e, infSig) <- inferExp lambda
s <- gets sigs
case M.lookup (coerce name) s of
Just (Just t') -> do
let fvs0 = nub $ freeOrdered t'
let m0 = M.fromList $ zip fvs0 letters
let fvs1 = nub $ freeOrdered lambda_t
let m1 = M.fromList $ zip fvs1 letters
let t0 = replace m0 t'
let t1 = replace m1 lambda_t
-- Not sure if this is actually correct
sub <- unify t' lambda_t
Just (Just typSig) -> do
let genInfSig = generalize mempty infSig
(trace ("Inferred: " ++ printTree infSig ++ "\nGeneralized inferred: " ++ printTree genInfSig ++ "\nGiven: " ++ printTree typSig ++ "\n") pure ())
sub <- genInfSig `unify` typSig
unless
(t1 <<= t0)
(genInfSig <<= typSig)
( throwError $
Error
( Aux.do
"Inferred type"
quote $ printTree t1
quote $ printTree infSig
"doesn't match given type"
quote $ printTree $ mkForall t0
quote $ printTree typSig
)
False
)
-- Applying sub to t' will worsen error messages.
-- Applying sub to typSig will worsen error messages.
-- Unfortunately I do not know a better solution at the moment.
return $ T.Bind (coerce name, apply sub t') [] (apply sub e, lambda_t)
return $ T.Bind (coerce name, apply sub typSig) [] (apply sub e, typSig)
_ -> do
insertSig (coerce name) (Just lambda_t)
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
insertSig (coerce name) (Just infSig)
return (T.Bind (coerce name, infSig) [] (e, infSig))
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
checkData err@(Data typ injs) = do
@ -276,7 +271,7 @@ algoW = \case
quote $ printTree t'
)
let comp = sub2 `compose` sub1 `compose` sub0
return (comp, (apply comp e', skolemize t))
return (comp, (apply comp e', t))
-- \| ------------------
-- \| Γ ⊢ i : Int, ∅
@ -384,7 +379,9 @@ algoW = \case
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
checkCase _ [] = catchableErr "Atleast one case required"
checkCase _ [] = do
fr <- fresh
return (nullSubst, [], fr)
checkCase expT brnchs = do
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
let sub0 = composeAll subs
@ -608,15 +605,37 @@ currently this is not the case, the TAll pattern match is incorrectly implemente
-}
-- Is the left a subtype of the right
(<<=) :: Type -> Type -> Bool
(<<=) (TVar _) _ = True
(<<=) t1@TAll{} t2 = skipForalls t1 <<= t2
(<<=) t1 t2@TAll{} = t1 <<= skipForalls t2
(<<=) (TFun a b) (TFun c d) = a <<= c && b <<= d
(<<=) (TData n1 ts1) (TData n2 ts2) =
n1 == n2
&& length ts1 == length ts2
&& and (zipWith (<<=) ts1 ts2)
(<<=) a b = a == b
(<<=) a b = case (a,b) of
(TVar _, _) -> True
(TFun a b,TFun c d) -> a <<= c && b <<= d
(TAll tvar1 t1, TAll tvar2 t2) -> ungo [tvar1, tvar2] t1 t2
(TAll tvar t1, t2) -> ungo [tvar] t1 t2
(t1, TAll tvar t2) -> ungo [tvar] t1 t2
(TData n1 ts1, TData n2 ts2) -> n1 == n2
&& length ts1 == length ts2
&& and (zipWith (<<=) ts1 ts2)
(t1,t2) -> t1 == t2
where
ungo :: [TVar] -> Type -> Type -> Bool
ungo tvars t1 t2 = case run (go tvars t1 t2) of
Right (b,_) -> b
_ -> False
go :: [TVar] -> Type -> Type -> Infer Bool
go tvars t1 t2 = do
fr <- fresh
let sub = M.fromList [(coerce x, fr) | (MkTVar x) <- tvars]
return (apply sub t1 <<= apply sub t2)
{-
typSig = TAll (MkTVar "a") (TAll (MkTVar "b") ((TVar (MkTVar "a") `TFun` (TVar (MkTVar "b")))))
infSig = generalize mempty $ TFun (TVar $ MkTVar "x") (TVar $ MkTVar "x")
a = TAll (MkTVar "a") (TFun (TVar $ MkTVar "a") (TVar $ MkTVar "a"))
b = TAll (MkTVar "b") (TFun (TVar $ MkTVar "b") (TVar $ MkTVar "b"))
int = TFun (TLit "Int") (TLit "Int")
-}
skipForalls :: Type -> Type
skipForalls = \case
@ -897,6 +916,7 @@ data Error = Error {msg :: String, catchable :: Bool}
type Subst = Map T.Ident Type
newtype Warning = NonExhaustive String
deriving (Show)
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (WriterT [Warning] (ExceptT Error Identity))) a}
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)