Bind now does correct subtype check.
This commit is contained in:
parent
f8a70b4cf4
commit
5a28f9d909
1 changed files with 49 additions and 29 deletions
|
|
@ -22,7 +22,7 @@ import Data.Map qualified as M
|
|||
import Data.Maybe (fromJust)
|
||||
import Data.Set (Set)
|
||||
import Data.Set qualified as S
|
||||
import Debug.Trace (trace)
|
||||
import Debug.Trace (trace, traceShow)
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
|
|
@ -158,36 +158,31 @@ freeOrdered _ = mempty
|
|||
checkBind :: Bind -> Infer (T.Bind' Type)
|
||||
checkBind (Bind name args e) = do
|
||||
let lambda = makeLambda e (reverse (coerce args))
|
||||
(e, lambda_t) <- inferExp lambda
|
||||
(e, infSig) <- inferExp lambda
|
||||
s <- gets sigs
|
||||
case M.lookup (coerce name) s of
|
||||
Just (Just t') -> do
|
||||
let fvs0 = nub $ freeOrdered t'
|
||||
let m0 = M.fromList $ zip fvs0 letters
|
||||
let fvs1 = nub $ freeOrdered lambda_t
|
||||
let m1 = M.fromList $ zip fvs1 letters
|
||||
let t0 = replace m0 t'
|
||||
let t1 = replace m1 lambda_t
|
||||
-- Not sure if this is actually correct
|
||||
sub <- unify t' lambda_t
|
||||
Just (Just typSig) -> do
|
||||
let genInfSig = generalize mempty infSig
|
||||
(trace ("Inferred: " ++ printTree infSig ++ "\nGeneralized inferred: " ++ printTree genInfSig ++ "\nGiven: " ++ printTree typSig ++ "\n") pure ())
|
||||
sub <- genInfSig `unify` typSig
|
||||
unless
|
||||
(t1 <<= t0)
|
||||
(genInfSig <<= typSig)
|
||||
( throwError $
|
||||
Error
|
||||
( Aux.do
|
||||
"Inferred type"
|
||||
quote $ printTree t1
|
||||
quote $ printTree infSig
|
||||
"doesn't match given type"
|
||||
quote $ printTree $ mkForall t0
|
||||
quote $ printTree typSig
|
||||
)
|
||||
False
|
||||
)
|
||||
-- Applying sub to t' will worsen error messages.
|
||||
-- Applying sub to typSig will worsen error messages.
|
||||
-- Unfortunately I do not know a better solution at the moment.
|
||||
return $ T.Bind (coerce name, apply sub t') [] (apply sub e, lambda_t)
|
||||
return $ T.Bind (coerce name, apply sub typSig) [] (apply sub e, typSig)
|
||||
_ -> do
|
||||
insertSig (coerce name) (Just lambda_t)
|
||||
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
|
||||
insertSig (coerce name) (Just infSig)
|
||||
return (T.Bind (coerce name, infSig) [] (e, infSig))
|
||||
|
||||
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
|
||||
checkData err@(Data typ injs) = do
|
||||
|
|
@ -276,7 +271,7 @@ algoW = \case
|
|||
quote $ printTree t'
|
||||
)
|
||||
let comp = sub2 `compose` sub1 `compose` sub0
|
||||
return (comp, (apply comp e', skolemize t))
|
||||
return (comp, (apply comp e', t))
|
||||
|
||||
-- \| ------------------
|
||||
-- \| Γ ⊢ i : Int, ∅
|
||||
|
|
@ -384,7 +379,9 @@ algoW = \case
|
|||
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
|
||||
|
||||
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
|
||||
checkCase _ [] = catchableErr "Atleast one case required"
|
||||
checkCase _ [] = do
|
||||
fr <- fresh
|
||||
return (nullSubst, [], fr)
|
||||
checkCase expT brnchs = do
|
||||
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
|
||||
let sub0 = composeAll subs
|
||||
|
|
@ -608,15 +605,37 @@ currently this is not the case, the TAll pattern match is incorrectly implemente
|
|||
-}
|
||||
-- Is the left a subtype of the right
|
||||
(<<=) :: Type -> Type -> Bool
|
||||
(<<=) (TVar _) _ = True
|
||||
(<<=) t1@TAll{} t2 = skipForalls t1 <<= t2
|
||||
(<<=) t1 t2@TAll{} = t1 <<= skipForalls t2
|
||||
(<<=) (TFun a b) (TFun c d) = a <<= c && b <<= d
|
||||
(<<=) (TData n1 ts1) (TData n2 ts2) =
|
||||
n1 == n2
|
||||
&& length ts1 == length ts2
|
||||
&& and (zipWith (<<=) ts1 ts2)
|
||||
(<<=) a b = a == b
|
||||
(<<=) a b = case (a,b) of
|
||||
(TVar _, _) -> True
|
||||
(TFun a b,TFun c d) -> a <<= c && b <<= d
|
||||
(TAll tvar1 t1, TAll tvar2 t2) -> ungo [tvar1, tvar2] t1 t2
|
||||
(TAll tvar t1, t2) -> ungo [tvar] t1 t2
|
||||
(t1, TAll tvar t2) -> ungo [tvar] t1 t2
|
||||
(TData n1 ts1, TData n2 ts2) -> n1 == n2
|
||||
&& length ts1 == length ts2
|
||||
&& and (zipWith (<<=) ts1 ts2)
|
||||
(t1,t2) -> t1 == t2
|
||||
where
|
||||
ungo :: [TVar] -> Type -> Type -> Bool
|
||||
ungo tvars t1 t2 = case run (go tvars t1 t2) of
|
||||
Right (b,_) -> b
|
||||
_ -> False
|
||||
go :: [TVar] -> Type -> Type -> Infer Bool
|
||||
go tvars t1 t2 = do
|
||||
fr <- fresh
|
||||
let sub = M.fromList [(coerce x, fr) | (MkTVar x) <- tvars]
|
||||
return (apply sub t1 <<= apply sub t2)
|
||||
|
||||
{-
|
||||
|
||||
typSig = TAll (MkTVar "a") (TAll (MkTVar "b") ((TVar (MkTVar "a") `TFun` (TVar (MkTVar "b")))))
|
||||
|
||||
infSig = generalize mempty $ TFun (TVar $ MkTVar "x") (TVar $ MkTVar "x")
|
||||
|
||||
a = TAll (MkTVar "a") (TFun (TVar $ MkTVar "a") (TVar $ MkTVar "a"))
|
||||
b = TAll (MkTVar "b") (TFun (TVar $ MkTVar "b") (TVar $ MkTVar "b"))
|
||||
int = TFun (TLit "Int") (TLit "Int")
|
||||
-}
|
||||
|
||||
skipForalls :: Type -> Type
|
||||
skipForalls = \case
|
||||
|
|
@ -897,6 +916,7 @@ data Error = Error {msg :: String, catchable :: Bool}
|
|||
type Subst = Map T.Ident Type
|
||||
|
||||
newtype Warning = NonExhaustive String
|
||||
deriving (Show)
|
||||
|
||||
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (WriterT [Warning] (ExceptT Error Identity))) a}
|
||||
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue