Bind now does correct subtype check.
This commit is contained in:
parent
f8a70b4cf4
commit
5a28f9d909
1 changed files with 49 additions and 29 deletions
|
|
@ -22,7 +22,7 @@ import Data.Map qualified as M
|
||||||
import Data.Maybe (fromJust)
|
import Data.Maybe (fromJust)
|
||||||
import Data.Set (Set)
|
import Data.Set (Set)
|
||||||
import Data.Set qualified as S
|
import Data.Set qualified as S
|
||||||
import Debug.Trace (trace)
|
import Debug.Trace (trace, traceShow)
|
||||||
import Grammar.Abs
|
import Grammar.Abs
|
||||||
import Grammar.Print (printTree)
|
import Grammar.Print (printTree)
|
||||||
import TypeChecker.TypeCheckerIr qualified as T
|
import TypeChecker.TypeCheckerIr qualified as T
|
||||||
|
|
@ -158,36 +158,31 @@ freeOrdered _ = mempty
|
||||||
checkBind :: Bind -> Infer (T.Bind' Type)
|
checkBind :: Bind -> Infer (T.Bind' Type)
|
||||||
checkBind (Bind name args e) = do
|
checkBind (Bind name args e) = do
|
||||||
let lambda = makeLambda e (reverse (coerce args))
|
let lambda = makeLambda e (reverse (coerce args))
|
||||||
(e, lambda_t) <- inferExp lambda
|
(e, infSig) <- inferExp lambda
|
||||||
s <- gets sigs
|
s <- gets sigs
|
||||||
case M.lookup (coerce name) s of
|
case M.lookup (coerce name) s of
|
||||||
Just (Just t') -> do
|
Just (Just typSig) -> do
|
||||||
let fvs0 = nub $ freeOrdered t'
|
let genInfSig = generalize mempty infSig
|
||||||
let m0 = M.fromList $ zip fvs0 letters
|
(trace ("Inferred: " ++ printTree infSig ++ "\nGeneralized inferred: " ++ printTree genInfSig ++ "\nGiven: " ++ printTree typSig ++ "\n") pure ())
|
||||||
let fvs1 = nub $ freeOrdered lambda_t
|
sub <- genInfSig `unify` typSig
|
||||||
let m1 = M.fromList $ zip fvs1 letters
|
|
||||||
let t0 = replace m0 t'
|
|
||||||
let t1 = replace m1 lambda_t
|
|
||||||
-- Not sure if this is actually correct
|
|
||||||
sub <- unify t' lambda_t
|
|
||||||
unless
|
unless
|
||||||
(t1 <<= t0)
|
(genInfSig <<= typSig)
|
||||||
( throwError $
|
( throwError $
|
||||||
Error
|
Error
|
||||||
( Aux.do
|
( Aux.do
|
||||||
"Inferred type"
|
"Inferred type"
|
||||||
quote $ printTree t1
|
quote $ printTree infSig
|
||||||
"doesn't match given type"
|
"doesn't match given type"
|
||||||
quote $ printTree $ mkForall t0
|
quote $ printTree typSig
|
||||||
)
|
)
|
||||||
False
|
False
|
||||||
)
|
)
|
||||||
-- Applying sub to t' will worsen error messages.
|
-- Applying sub to typSig will worsen error messages.
|
||||||
-- Unfortunately I do not know a better solution at the moment.
|
-- Unfortunately I do not know a better solution at the moment.
|
||||||
return $ T.Bind (coerce name, apply sub t') [] (apply sub e, lambda_t)
|
return $ T.Bind (coerce name, apply sub typSig) [] (apply sub e, typSig)
|
||||||
_ -> do
|
_ -> do
|
||||||
insertSig (coerce name) (Just lambda_t)
|
insertSig (coerce name) (Just infSig)
|
||||||
return (T.Bind (coerce name, lambda_t) [] (e, lambda_t))
|
return (T.Bind (coerce name, infSig) [] (e, infSig))
|
||||||
|
|
||||||
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
|
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
|
||||||
checkData err@(Data typ injs) = do
|
checkData err@(Data typ injs) = do
|
||||||
|
|
@ -276,7 +271,7 @@ algoW = \case
|
||||||
quote $ printTree t'
|
quote $ printTree t'
|
||||||
)
|
)
|
||||||
let comp = sub2 `compose` sub1 `compose` sub0
|
let comp = sub2 `compose` sub1 `compose` sub0
|
||||||
return (comp, (apply comp e', skolemize t))
|
return (comp, (apply comp e', t))
|
||||||
|
|
||||||
-- \| ------------------
|
-- \| ------------------
|
||||||
-- \| Γ ⊢ i : Int, ∅
|
-- \| Γ ⊢ i : Int, ∅
|
||||||
|
|
@ -384,7 +379,9 @@ algoW = \case
|
||||||
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
|
return (comp, apply comp (T.ECase (e', t) injs, ret_t))
|
||||||
|
|
||||||
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
|
checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
|
||||||
checkCase _ [] = catchableErr "Atleast one case required"
|
checkCase _ [] = do
|
||||||
|
fr <- fresh
|
||||||
|
return (nullSubst, [], fr)
|
||||||
checkCase expT brnchs = do
|
checkCase expT brnchs = do
|
||||||
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
|
(subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
|
||||||
let sub0 = composeAll subs
|
let sub0 = composeAll subs
|
||||||
|
|
@ -608,15 +605,37 @@ currently this is not the case, the TAll pattern match is incorrectly implemente
|
||||||
-}
|
-}
|
||||||
-- Is the left a subtype of the right
|
-- Is the left a subtype of the right
|
||||||
(<<=) :: Type -> Type -> Bool
|
(<<=) :: Type -> Type -> Bool
|
||||||
(<<=) (TVar _) _ = True
|
(<<=) a b = case (a,b) of
|
||||||
(<<=) t1@TAll{} t2 = skipForalls t1 <<= t2
|
(TVar _, _) -> True
|
||||||
(<<=) t1 t2@TAll{} = t1 <<= skipForalls t2
|
(TFun a b,TFun c d) -> a <<= c && b <<= d
|
||||||
(<<=) (TFun a b) (TFun c d) = a <<= c && b <<= d
|
(TAll tvar1 t1, TAll tvar2 t2) -> ungo [tvar1, tvar2] t1 t2
|
||||||
(<<=) (TData n1 ts1) (TData n2 ts2) =
|
(TAll tvar t1, t2) -> ungo [tvar] t1 t2
|
||||||
n1 == n2
|
(t1, TAll tvar t2) -> ungo [tvar] t1 t2
|
||||||
|
(TData n1 ts1, TData n2 ts2) -> n1 == n2
|
||||||
&& length ts1 == length ts2
|
&& length ts1 == length ts2
|
||||||
&& and (zipWith (<<=) ts1 ts2)
|
&& and (zipWith (<<=) ts1 ts2)
|
||||||
(<<=) a b = a == b
|
(t1,t2) -> t1 == t2
|
||||||
|
where
|
||||||
|
ungo :: [TVar] -> Type -> Type -> Bool
|
||||||
|
ungo tvars t1 t2 = case run (go tvars t1 t2) of
|
||||||
|
Right (b,_) -> b
|
||||||
|
_ -> False
|
||||||
|
go :: [TVar] -> Type -> Type -> Infer Bool
|
||||||
|
go tvars t1 t2 = do
|
||||||
|
fr <- fresh
|
||||||
|
let sub = M.fromList [(coerce x, fr) | (MkTVar x) <- tvars]
|
||||||
|
return (apply sub t1 <<= apply sub t2)
|
||||||
|
|
||||||
|
{-
|
||||||
|
|
||||||
|
typSig = TAll (MkTVar "a") (TAll (MkTVar "b") ((TVar (MkTVar "a") `TFun` (TVar (MkTVar "b")))))
|
||||||
|
|
||||||
|
infSig = generalize mempty $ TFun (TVar $ MkTVar "x") (TVar $ MkTVar "x")
|
||||||
|
|
||||||
|
a = TAll (MkTVar "a") (TFun (TVar $ MkTVar "a") (TVar $ MkTVar "a"))
|
||||||
|
b = TAll (MkTVar "b") (TFun (TVar $ MkTVar "b") (TVar $ MkTVar "b"))
|
||||||
|
int = TFun (TLit "Int") (TLit "Int")
|
||||||
|
-}
|
||||||
|
|
||||||
skipForalls :: Type -> Type
|
skipForalls :: Type -> Type
|
||||||
skipForalls = \case
|
skipForalls = \case
|
||||||
|
|
@ -897,6 +916,7 @@ data Error = Error {msg :: String, catchable :: Bool}
|
||||||
type Subst = Map T.Ident Type
|
type Subst = Map T.Ident Type
|
||||||
|
|
||||||
newtype Warning = NonExhaustive String
|
newtype Warning = NonExhaustive String
|
||||||
|
deriving (Show)
|
||||||
|
|
||||||
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (WriterT [Warning] (ExceptT Error Identity))) a}
|
newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (WriterT [Warning] (ExceptT Error Identity))) a}
|
||||||
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
|
deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue