Rework type checking of data types to make it reliable

This commit is contained in:
Martin Fredin 2023-05-15 00:31:30 +02:00
parent 46d4ef3923
commit 5e15983f4c
2 changed files with 151 additions and 208 deletions

View file

@ -7,17 +7,18 @@
module TypeChecker.TypeCheckerBidir (typecheck) where
import Auxiliary (int, maybeToRightM, onM, snoc,
typeof)
import Auxiliary (int, mapAccumM, maybeToRightM, onM,
snoc, typeof)
import Control.Applicative (Applicative (liftA2), liftA3, (<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError),
MonadTrans (lift), forM, runExceptT,
unless, zipWithM, zipWithM_)
import Control.Monad.Extra (fromMaybeM, ifM)
MonadTrans (lift), foldM, forM,
forM_, runExceptT, unless, zipWithM,
zipWithM_)
import Control.Monad.Extra (fromMaybeM, ifM, maybeM, unlessM)
import Control.Monad.State (MonadState (get), State, StateT,
evalState, evalStateT, gets, modify)
import Data.Coerce (coerce)
import Data.Foldable (foldlM)
import Data.Foldable (foldlM, foldrM)
import Data.Function (on)
import Data.List (intercalate)
import Data.Map (Map)
@ -25,8 +26,9 @@ import qualified Data.Map as Map
import Data.Maybe (fromMaybe, isNothing)
import Data.Sequence (Seq (..))
import qualified Data.Sequence as S
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Tuple.Extra (second)
import Data.Tuple.Extra (first, second)
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.ErrM
@ -59,14 +61,13 @@ data Cxt = Cxt
, sig :: Map LIdent Type -- ^ Top-level signatures x : A
, binds :: Map LIdent Exp -- ^ Top-level binds x : e
, next_tevar :: Int -- ^ Counter to distinguish ά
, data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A
, data_injs :: Map UIdent (Type, [TVar]) -- ^ Data injections (constructors) K : A
, currentBind :: LIdent -- ^ Used for recursive functions
} deriving (Show, Eq)
-- | The type-checker monad. A 'Tc' computation reads and updates the checker
-- context ('Cxt') as state and may abort with a 'String' error message.
-- 'runTc' unwraps it to the underlying 'ExceptT' over 'State' stack.
newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a }
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
initCxt :: [Def] -> Cxt
initCxt defs = Cxt
{ env = mempty
@ -77,14 +78,14 @@ initCxt defs = Cxt
| DBind' name vars rhs <- defs
]
, next_tevar = 0
, data_injs = Map.fromList [ (name, foldr TAll t $ unboundedTVars t)
, data_injs = Map.fromList [ (name, (t, unboundedTVars t))
| DData (Data _ injs) <- defs
, Inj name t <- injs
]
, currentBind = ""
}
where
unboundedTVars = uncurry (Set.\\) . go (mempty, mempty)
unboundedTVars = Set.toList . uncurry (Set.\\) . go (mempty, mempty)
where
go (unbounded, bounded) = \case
TAll tvar t -> go (unbounded, Set.insert tvar bounded) t
@ -108,16 +109,17 @@ typecheckBinds cxt = flip evalState cxt
typecheckBind :: Bind -> Tc (T.Bind' Type)
typecheckBind (Bind name vars rhs) = do
modify $ \cxt -> cxt { currentBind = name }
bind'@(T.Bind (name, typ) _ _) <- lookupSig name >>= \case
bind' <- lookupSig name >>= \case
Just t -> do
insertSig (coerce name) t
(rhs', _) <- check (foldr EAbs rhs vars) t
pure (T.Bind (coerce name, t) [] (rhs', t))
Nothing -> do
(e, t) <- apply =<< infer (foldr EAbs rhs vars)
insertSig (coerce name) t
pure (T.Bind (coerce name, t) [] (e, t))
env <- gets env
unless (isComplete env) err
insertSig (coerce name) typ
putEnv Empty
pure bind'
where
@ -166,8 +168,6 @@ typecheckInj (Inj inj_name inj_typ) name tvars
TLit _ -> True
TEVar _ -> error "TEVar in data type declaration"
---------------------------------------------------------------------------
-- * Typing rules
---------------------------------------------------------------------------
@ -199,26 +199,30 @@ check (EAbs x e) (TFun a b) = do
putEnv env_l
apply (T.EAbs (coerce x) e', TFun a b)
-- Γ ⊢ e ↑ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]C ⊣ Δ
-- ----------------------------------- CaseEmpty
-- Γ ⊢ case e of {} ↓ C ⊣ Δ
-- Θ₁ ⊢ p₁⇒e₁ [Θ₁]C ⊣ Θ₂
-- Θ₁ ⊢ p₁⇒e₁ [Θ₁]C ⊣ Θ₂
-- ...
-- Γ ⊢ e ↑ A ⊣ Θ₁ Θₙ ⊢ pₙ⇒eₙ [Θₙ]C ⊣ Δ
-- Γ ⊢ e ↑ A ⊣ Θ₁ Θₙ ⊢ pₙ⇒eₙ [Θₙ]C ⊣ Δ
-- --------------------------------------- Case
-- Γ ⊢ case e of {p₁⇒e₁ ‥ pₙ⇒eₙ} C ⊣ Δ
check (ECase scrut branches) c = do
(scrut', a) <- infer scrut
-- Γ ⊢ case e of {p₁⇒e₁ ‥ pₙ⇒eₙ} C ⊣ Δ
check (ECase e branches) c = do
(e', a) <- infer e
case branches of
[] -> do
subtype a c
apply (T.ECase (scrut', a) [], a)
subtype a =<< apply c
apply (T.ECase (e', a) [], a)
_ -> do
branches' <- checkBranches branches a c
apply (T.ECase (scrut', a) branches', c)
branches' <- evalStateT (checkBranches a) mempty
apply (T.ECase (e', a) branches', c)
where
checkBranches a = forM branches $ \(Branch p e) -> do
p' <- match p =<< lift (apply a)
lift $ do
e' <- check e c
apply (T.Branch p' e')
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
-- -------------------------------------- Sub
@ -229,134 +233,6 @@ check e b = do
subtype a b'
apply (e', b)
-- Γ ⊢ p₁⇒e₁ ‥ pₙ⇒eₙ :: A ↑ C ⊣ Δ
checkBranches :: [Branch] -> Type -> Type -> Tc [T.Branch' Type]
checkBranches branches a c = evalStateT go mempty
where
go = forM branches $ \(Branch p e) -> do
p' <- patternMatch p =<< lift (apply a)
lift $ do
e' <- check e c
apply (T.Branch p' e')
substituteAll :: Type -> StateT (Map TVar TEVar) Tc Type
substituteAll t = get >>= \subs -> case t of
TAll tvar t
| Just tevar <- Map.lookup tvar subs ->
lift . pure $ substitute tvar tevar t
| otherwise -> do
tevar <- lift fresh
modify (Map.insert tvar tevar)
substituteAll (substitute tvar tevar t)
TFun t1 t2 -> onM TFun substituteAll t1 t2
t -> pure t
patternMatch :: Pattern -> Type -> StateT (Map TVar TEVar) Tc (T' T.Pattern' Type)
-- ------------------- PVar
-- Γ ⊢ x :: A ⊣ Γ,(x:A)
patternMatch (PVar x) a = lift $ do
insertEnv $ EnvVar x a
apply (T.PVar (coerce x), a)
-- ------------- PCatch
-- Γ ⊢ _ :: A ⊣ Γ
patternMatch PCatch a = lift $ apply (T.PCatch, a)
-- Γ ⊢ typeof(lit) <: A ⊣ Δ
-- ------------------------- PLit
-- Γ ⊢ lit :: A ⊣ Δ
patternMatch (PLit lit) a = lift $ do
subtype (typeof lit) a
apply (T.PLit lit, typeof lit)
-- Γ ∋ (K : A) Γ ⊢ A <: C ⊣ Δ
-- ---------------------------
-- Γ ⊢ K :: C ⊣ Δ
patternMatch (PEnum k) b = do
a <- lift (maybeToRightM ("Unknown constructor " ++ show k) =<< lookupInj k)
a <- substituteAll a
lift $ do
subtype a b
apply (T.PEnum (coerce k), a)
-- β α Γ Γ Δ
--
-- Γ ∋ (K : A) Θ₂ ⊢ p₁ :: [Θ₁]A₁ ⊣ Θ₂
-- Γ ⊢ ∀ά₁‥άₘ A₁ → ‥ → Aₙ₊₁ = substituteAll(A) ⊣ Θ₁ ...
-- Θ₁ ⊢ Aₙ₊₁ <: B ⊣ Θ₂ Β Θₙ₊₁ ⊢ pₙ :: [Θₙ₊₁]Aₙ ⊣ Δ
-- ----------------------------------------------------------------------------- PInj
-- Γ ⊢ K p₁‥pₙ ↑ B ⊣ Δ
patternMatch (PInj k ps) b = do
a <- maybeToRightM ("Unknown constructor " ++ show k) =<< lift (lookupInj k)
a <- substituteAll a
lift $ subtype (getDataId a) b
let as = getArgs a
unless (length as == length ps) $ throwError "Wrong number of arguments!"
ps' <- zipWithM (\p a -> patternMatch p =<< lift (apply a)) ps as
lift $ apply (T.PInj (coerce k) ps', a)
where
getArgs = \case
TAll _ t -> getArgs t
t -> go [] t
where
go acc = \case
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> acc
-- Example
-- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs
-- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁
-- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂
-- Θ₂ ⊢ p₂ ↑ [Θ₂][ά/α]A₂ ⊣ Δ
-- ---------------------------
-- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ
-- patternMatch (PInj name ps) t_patt = do
-- t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
-- let ts = getArgs t_inj
-- unless (length ts == length ps)
-- $ throwError "Wrong number of arguments!"
--
-- -- [ά/α]
-- sub <- substituteTVarsOf t_inj
-- subtype (sub $ getDataId t_inj) t_patt
-- let check p t = checkPattern p =<< apply (sub t)
-- ps' <- zipWithM check ps ts
-- apply (T.PInj (coerce name) ps', t_patt)
-- where
-- substituteTVarsOf = \case
-- TAll tvar t -> do
-- tevar <- fresh
-- (substitute tvar tevar .) <$> substituteTVarsOf t
-- _ -> pure id
--
-- getArgs = \case
-- TAll _ t -> getArgs t
-- t -> go [] t
-- where
-- go acc = \case
-- TFun t1 t2 -> go (snoc t1 acc) t2
-- _ -> acc
-- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆
infer :: Exp -> Tc (T' T.Exp' Type)
@ -366,23 +242,27 @@ infer (ELit lit) = apply (T.ELit lit, typeof lit)
-- ------------- Var --------------------- VarRec
-- Γ ⊢ x ↓ A ⊣ Γ Γ ⊢ x ↓ ά ⊣ Γ,(x : ά)
infer (EVar x) = do
a <- ifM (gets $ (x==) . currentBind) varRec var
a <- fromMaybeM varRec var
apply (T.EVar (coerce x), a)
where
var = maybeToRightM "Can't infer" =<<
liftA2 (<|>) (lookupEnv x) (lookupSig x)
var = liftA2 (<|>) (lookupEnv x) (lookupSig x)
varRec = do
unlessM (gets $ (x==) . currentBind) $
throwError ("Can't infer " ++ show x)
alpha <- TEVar <$> fresh
insertEnv (EnvVar x alpha)
pure alpha
-- Γ ∋ (k : A)
-- ------------- Inj
-- Γ ⊢ k ↓ A ⊣ Γ
-- TODO
infer (EInj k) = do
t <- maybeToRightM ("Unknown constructor: " ++ show k)
=<< lookupInj k
apply (T.EInj $ coerce k, t)
(t, as) <- maybeToRightM ("Unknown constructor here2: " ++ show k)
=<< lookupInj k
t' <- foldM go t as
apply (T.EInj $ coerce k, t')
where
go t a = do
a' <- fresh
pure $ substitute a a' t
-- Γ ⊢ A Γ ⊢ e ↑ A ⊣ Δ
-- --------------------- Anno
@ -414,17 +294,22 @@ infer (EAbs name e) = do
dropTrailing env_var
apply (T.EAbs (coerce name) e', on TFun TEVar alpha epsilon)
-- Γ ⊢ rhs ↓ A ⊣ Θ Θ,(x:A) ⊢ e ↑ C ⊣ Δ,(x:A),Θ
-- Γ ⊢ λys.rhs ↓ A ⊣ Θ Θ,(x:A) ⊢ e ↑ C ⊣ Δ,(x:A),Θ
-- -------------------------------------------- LetI
-- Γ ⊢ let x = rhs in e ↑ C ⊣ Δ
-- Γ ⊢ let x ys = rhs in e' ↑ C ⊣ Δ
infer (ELet (Bind x vars rhs) e) = do
(rhs', a) <- infer $ foldr EAbs rhs vars
let env_var = EnvVar x a
let (a_ret, a_vars) = go vars a
env_var = EnvVar x a
insertEnv env_var
e'@(_, c) <- infer e
(env_l, _) <- gets (splitOn env_var . env)
putEnv env_l
apply (T.ELet (T.Bind (coerce x, a) [] (rhs', a)) e', c)
apply (T.ELet (T.Bind (coerce x, a) a_vars (rhs', a_ret)) e', c)
where
go [] t = (t, [])
go (x:xs) (TFun t1 t2) = second (snoc (coerce x, t1)) $ go xs t2
go _ _ = error "IMPOSSIBLE"
-- Γ ⊢ e₁ ↑ Int ⊣ Θ Θ ⊢ e₂ ↑ Int
-- --------------------------- +I
@ -434,30 +319,21 @@ infer (EAdd e1 e2) = do
e2' <- check e2 int
apply (T.EAdd e1' e2', int)
-- Experimental inference for case
infer (ECase e branches) = do
(e', a) <- infer e
case branches of
[] -> apply (T.ECase (e', a) [], a)
_ -> do
let inferBranch bs (Branch p e) = do
p' <- match p =<< lift (apply a)
lift $ do
e'@(_, b) <- infer e
mapM_ (subtype b) bs
apply (b:bs, T.Branch p' e')
-- Γ ⊢ e ↑ A ⊣ Δ
-- ------------------------ CaseEmpty↓
-- Γ ⊢ case e of {} ↑ A ⊣ Δ
-- Θ₁ ⊢ p₁⇒e₁ ↓ [Θ₁]C ⊣ Θ₂
-- ...
-- Γ ⊢ e ↑ A ⊣ Θ₁ Θₙ ⊢ pₙ⇒eₙ ↓ [Θₙ]C ⊣ Δ
-- --------------------------------------- Case↓
-- Γ ⊢ case e of {p₁⇒e₁ ‥ pₙ⇒eₙ} ↓ C ⊣ Δ
-- infer (ECase scrut branches) = do
-- (scrut', a) <- infer scrut
-- case branches of
-- [] -> apply (T.ECase (scrut', a) [], a)
-- (Branch _ e):_ -> do
-- (_, b)<- infer e
-- (branches', b') <- foldlM go ([], b) branches
-- apply (T.ECase (scrut', a) branches', b')
-- where
-- go (pi, b) (Branch p e) = do
-- p' <- checkPattern p =<< apply a
-- e'@(_, b') <- infer e
-- subtype b' b
-- apply (T.Branch p' e' : pi, b')
(bs, branches') <- evalStateT (mapAccumM inferBranch [] branches) mempty
apply (T.ECase (e', a) branches', head bs)
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
@ -643,16 +519,15 @@ instantiateL alpha a = gets env >>= \env -> go env alpha a
instantiateR :: Type -> TEVar -> Tc ()
instantiateR a alpha = gets env >>= \env -> go env a alpha
where
-- Γ ⊢ τ
-- ----------------------------- InstRSolve
-- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ'
-- Γ ⊢ τ
-- ----------------------------- InstRSolve
-- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ'
go env tau alpha
| isMono tau
, (env_l, env_r) <- splitOn (EnvTEVar alpha) env
, Right _ <- wellFormed env_l tau
= putEnv $ (env_l :|> EnvTEVarSolved alpha tau) <> env_r
--
-- ----------------------------- InstRReach
-- Γ[ά][έ] ⊢ έ =:< ά ⊣ Γ[ά][έ=ά]
go env (TEVar epsilon) alpha = do
@ -686,6 +561,75 @@ instantiateR a alpha = gets env >>= \env -> go env a alpha
go _ a alpha = throwError $ "Trying to instantiateR: " ++ ppT a ++ " <: "
++ ppT (TEVar alpha)
---------------------------------------------------------------------------
-- * Pattern Matching
---------------------------------------------------------------------------
-- | Replace the universal variable @alpha@ in the type @a@ by an existential
-- variable.
--
-- The state is a map from universal variables to the existential variable
-- already chosen for them. If @alpha@ is present in the map, that existential
-- is reused; otherwise a new one is drawn with 'fresh' and recorded in the map
-- before the substitution is performed.
substituteP :: TVar -> Type -> StateT (Map TVar TEVar) Tc Type
substituteP alpha a = do
    known <- gets (Map.lookup alpha)
    tevar <- maybe allocate pure known
    pure (substitute alpha tevar a)
  where
    -- Draw a fresh existential and remember it for later occurrences of alpha.
    allocate = do
        new <- lift fresh
        modify (Map.insert alpha new)
        pure new
-- | Match a pattern against the type it is expected to have and return the
-- annotated pattern together with its type.
--
-- The 'StateT' layer carries the map from universal to existential variables
-- used by 'substituteP' when constructor types are instantiated.
match :: Pattern -> Type -> StateT (Map TVar TEVar) Tc (T' T.Pattern' Type)

-- ------------------- PVar
-- Γ ⊢ x :: A ⊣ Γ,(x:A)
match (PVar x) a = lift $
    insertEnv (EnvVar x a) >> apply (T.PVar (coerce x), a)

-- ------------- PCatch
-- Γ ⊢ _ :: A ⊣ Γ
match PCatch a = lift (apply (T.PCatch, a))

-- Γ ⊢ typeof(lit) <: A ⊣ Δ
-- ------------------------- PLit
-- Γ ⊢ lit :: A ⊣ Δ
match (PLit lit) a = lift $ do
    let t_lit = typeof lit
    subtype t_lit a
    apply (T.PLit lit, t_lit)

-- Nullary constructor: its instantiated type must be a subtype of the
-- expected type.
match (PEnum k) b = do
    t_inj <- instantiateInj k
    lift $ subtype t_inj b >> apply (T.PEnum (coerce k), t_inj)

-- Constructor applied to sub-patterns: the data type named by the constructor
-- type is checked against the expected type first, then each sub-pattern is
-- matched against the corresponding argument type.
match (PInj k ps) b = do
    t_inj <- instantiateInj k
    let t_return = getDataId t_inj
    lift (subtype t_return b)
    t_args <- argTypes <$> lift (apply t_inj)
    unless (length t_args == length ps) $
        throwError "Wrong number of arguments!"
    ps' <- zipWithM matchArg ps t_args
    lift $ apply (T.PInj (coerce k) ps', t_return)
  where
    -- Only variable patterns are accepted as constructor arguments.
    matchArg p a = case p of
        PVar _ -> lift (apply a) >>= match p
        _      -> throwError "Nested pattern matching not supported!"

    -- Skip leading quantifiers, then collect the argument types of the
    -- function type (the final result type is not included).
    argTypes (TAll _ t) = argTypes t
    argTypes t          = collect [] t

    collect acc (TFun t1 t2) = collect (snoc t1 acc) t2
    collect acc _            = acc

-- | Look up the constructor @k@ and substitute each of its recorded type
-- variables via 'substituteP'. Fails with an "Unknown constructor" error when
-- @k@ has no entry.
instantiateInj :: UIdent -> StateT (Map TVar TEVar) Tc Type
instantiateInj k = do
    (t, tvars) <- lift $
        lookupInj k >>= maybeToRightM ("Unknown constructor: " ++ show k)
    foldrM substituteP t tvars
---------------------------------------------------------------------------
-- * Auxiliary
---------------------------------------------------------------------------
@ -704,16 +648,16 @@ substitute :: TVar -- α
-> TEVar -- ά
-> Type -- A
-> Type -- [ά/α]A
substitute tvar tevar typ = case typ of
TLit _ -> typ
TVar tvar' | tvar' == tvar -> TEVar tevar
| otherwise -> typ
TEVar _ -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar' t -> TAll tvar' (substitute' t)
TData name typs -> TData name $ map substitute' typs
substitute alpha alpha' typ = case typ of
TLit _ -> typ
TVar tvar | tvar == alpha -> TEVar alpha'
| otherwise -> typ
TEVar _ -> typ
TFun t1 t2 -> on TFun subs t1 t2
TAll tvar' t -> TAll tvar' (subs t)
TData name typs -> TData name $ map subs typs
where
substitute' = substitute tvar tevar
subs = substitute alpha alpha'
-- | Γ,x,Γ' → (Γ, Γ')
splitOn :: EnvElem -> Env -> (Env, Env)
@ -777,7 +721,6 @@ fresh = do
modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar }
pure tevar
isComplete :: Env -> Bool
isComplete = isNothing . S.findIndexL unSolvedTEVar
where
@ -813,15 +756,15 @@ lookupEnv x = gets (findId . env)
EnvVar x' t | x==x' -> Just t
_ -> findId ys
lookupInj :: UIdent -> Tc (Maybe Type)
lookupInj x = gets (Map.lookup x . data_injs)
lookupInj :: UIdent -> Tc (Maybe (Type, [TVar]))
lookupInj x = gets $ Map.lookup x . data_injs
-- | Overwrite the environment stored in the checker context with @env'@,
-- discarding whatever environment was there before.
putEnv :: Env -> Tc ()
putEnv env' = modifyEnv (const env')
modifyEnv :: (Env -> Env) -> Tc ()
modifyEnv f =
modify $ \cxt -> {- trace (ppEnv (f cxt.env)) -} cxt { env = f cxt.env }
modify $ \cxt -> {- trace (ppEnv (f cxt.env)) -}cxt { env = f cxt.env }
-- | Shorthand for a top-level bind definition: matches (or builds) a 'DBind'
-- and exposes the name, variables and expression of the wrapped 'Bind'.
pattern DBind' name vars exp = DBind (Bind name vars exp)
-- | Shorthand for a top-level signature definition: matches (or builds) a
-- 'DSig' and exposes the name and type of the wrapped 'Sig'.
pattern DSig' name typ = DSig (Sig name typ)

View file

@ -90,7 +90,7 @@ prtSig (x, t) =
]
instance (Print a, Print t) => Print (T a t) where
prt i (x, t) = withT
prt i (x, t) = noT
where
noT = prt i x
withT = concatD