Simplify pattern matching
This commit is contained in:
parent
9ffcbf66b9
commit
e138cb27ec
1 changed files with 110 additions and 177 deletions
|
|
@ -4,19 +4,18 @@
|
|||
{-# LANGUAGE PatternSynonyms #-}
|
||||
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
|
||||
|
||||
module TypeChecker.TypeCheckerBidir (typecheck, getVars) where
|
||||
module TypeChecker.TypeCheckerBidir (typecheck) where
|
||||
|
||||
import Auxiliary (int, liftMM2, litType,
|
||||
maybeToRightM, onM, onMM, snoc)
|
||||
import Control.Applicative (Alternative, Applicative (liftA2),
|
||||
(<|>))
|
||||
import Auxiliary (int, litType, maybeToRightM, snoc)
|
||||
import Control.Applicative (Applicative (liftA2), (<|>))
|
||||
import Control.Monad.Except (ExceptT, MonadError (throwError),
|
||||
runExceptT, unless, zipWithM,
|
||||
zipWithM_)
|
||||
import Control.Monad.Extra (fromMaybeM, maybeM)
|
||||
import Control.Monad.Extra (fromMaybeM)
|
||||
import Control.Monad.State (MonadState, State, evalState, gets,
|
||||
modify)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Foldable (foldlM)
|
||||
import Data.Function (on)
|
||||
import Data.List (intercalate)
|
||||
import Data.Map (Map)
|
||||
|
|
@ -38,7 +37,8 @@ import qualified TypeChecker.TypeCheckerIr as T
|
|||
--
|
||||
-- TODO
|
||||
-- • Fix problems with types in Pattern/Branch in TypeCheckerIr
|
||||
-- • Fix the different type getters functions (e.g. partitionType) functions
|
||||
-- • Remove EAdd
|
||||
-- • Add kinds!!
|
||||
|
||||
data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A
|
||||
| EnvTVar TVar -- ^ Universal type variable. α
|
||||
|
|
@ -140,7 +140,7 @@ typecheckInj :: Inj -> UIdent -> [TVar] -> Err (T.Inj' Type)
|
|||
typecheckInj (Inj inj_name inj_typ) name tvars
|
||||
| not $ boundTVars tvars inj_typ
|
||||
= throwError "Unbound type variables"
|
||||
| TData name' typs <- getReturn inj_typ
|
||||
| TData name' typs <- getDataId inj_typ
|
||||
, name' == name
|
||||
, Right tvars' <- mapM toTVar typs
|
||||
, all (`elem` tvars) tvars'
|
||||
|
|
@ -149,7 +149,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars
|
|||
= throwError $ unwords
|
||||
["Bad type constructor: ", show name
|
||||
, "\nExpected: ", ppT . TData name $ map TVar tvars
|
||||
, "\nActual: ", ppT $ getReturn inj_typ
|
||||
, "\nActual: ", ppT $ getDataId inj_typ
|
||||
]
|
||||
where
|
||||
boundTVars :: [TVar] -> Type -> Bool
|
||||
|
|
@ -161,6 +161,8 @@ typecheckInj (Inj inj_name inj_typ) name tvars
|
|||
TLit _ -> True
|
||||
TEVar _ -> error "TEVar in data type declaration"
|
||||
|
||||
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Typing rules
|
||||
---------------------------------------------------------------------------
|
||||
|
|
@ -200,10 +202,72 @@ check e b = do
|
|||
subtype a b'
|
||||
apply (e', b)
|
||||
|
||||
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
|
||||
checkPattern patt t_patt = case patt of
|
||||
|
||||
-- -------------------
|
||||
-- Γ ⊢ x ↑ A ⊣ Γ,(x:A)
|
||||
PVar x -> do
|
||||
insertEnv $ EnvVar x t_patt
|
||||
apply (T.PVar (coerce x, t_patt), t_patt)
|
||||
|
||||
-- -------------
|
||||
-- Γ ⊢ _ ↑ A ⊣ Γ
|
||||
PCatch -> apply (T.PCatch, t_patt)
|
||||
|
||||
-- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ
|
||||
-- ------------------------------
|
||||
-- Γ ⊢ τ ↑ B ⊣ Δ
|
||||
PLit lit -> do
|
||||
subtype (litType lit) t_patt
|
||||
apply (T.PLit (lit, t_patt), t_patt)
|
||||
|
||||
-- Γ ∋ (K : A) Γ ⊢ A <: B ⊣ Δ
|
||||
-- ---------------------------
|
||||
-- Γ ⊢ K ↑ B ⊣ Δ
|
||||
PEnum name -> do
|
||||
t <- maybeToRightM ("Unknown constructor " ++ show name)
|
||||
=<< lookupInj name
|
||||
subtype t t_patt
|
||||
apply (T.PEnum (coerce name), t_patt)
|
||||
|
||||
-- Example
|
||||
-- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs
|
||||
-- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁
|
||||
-- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂
|
||||
-- Θ₂ ⊢ p₂ ↑ [Θ₂][ά/α]A₂ ⊣ Δ
|
||||
-- ---------------------------
|
||||
-- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ
|
||||
PInj name ps -> do
|
||||
t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
|
||||
let ts = getArgs t_inj
|
||||
unless (length ts == length ps)
|
||||
$ throwError "Wrong number of arguments!"
|
||||
|
||||
-- [ά/α]
|
||||
sub <- substituteTVarsOf t_inj
|
||||
subtype (sub $ getDataId t_inj) t_patt
|
||||
let check p t = checkPattern p =<< apply (sub t)
|
||||
ps' <- zipWithM check ps ts
|
||||
apply (T.PInj (coerce name) (map fst ps'), t_patt)
|
||||
where
|
||||
substituteTVarsOf = \case
|
||||
TAll tvar t -> do
|
||||
tevar <- fresh
|
||||
(substitute tvar tevar .) <$> substituteTVarsOf t
|
||||
_ -> pure id
|
||||
|
||||
getArgs = \case
|
||||
TAll _ t -> getArgs t
|
||||
t -> go [] t
|
||||
where
|
||||
go acc = \case
|
||||
TFun t1 t2 -> go (snoc t1 acc) t2
|
||||
_ -> acc
|
||||
|
||||
-- | Γ ⊢ e ↓ A ⊣ Δ
|
||||
-- Under input context Γ, e infers output type A, with output context ∆
|
||||
infer :: Exp -> Tc (T.ExpT' Type)
|
||||
|
||||
infer (ELit lit) = apply (T.ELit lit, litType lit)
|
||||
|
||||
-- Γ ∋ (x : A) Γ ∌ (x : A)
|
||||
|
|
@ -273,14 +337,23 @@ infer (EAdd e1 e2) = do
|
|||
e2' <- check e2 int
|
||||
apply (T.EAdd e1' e2', int)
|
||||
|
||||
-- Θ ⊢ Π ∷ A ↓ C ⊣ Δ
|
||||
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
|
||||
-- ---------------------------------------
|
||||
-- Γ ⊢ e ↑ A ⊣ Θ Θ ⊢ Π ∷ [Θ]A ↑ C ⊣ Δ
|
||||
-- ------------------------------------ Case
|
||||
-- Γ ⊢ case e of Π ↓ C ⊣ Δ
|
||||
infer (ECase scrut branches) = do
|
||||
(scrut', t_scrut) <- infer scrut
|
||||
(branches', t_return) <- inferBranches branches t_scrut
|
||||
apply (T.ECase (scrut', t_scrut) branches', t_return)
|
||||
infer (ECase scrut pi) = do
|
||||
(scrut', a) <- infer scrut
|
||||
case pi of
|
||||
[] -> apply (T.ECase (scrut', a) [], a)
|
||||
(Branch _ e):_ -> do
|
||||
(_, c)<- infer e
|
||||
(pi', c') <- foldlM go ([], c) pi
|
||||
apply (T.ECase (scrut', a) pi', c')
|
||||
where
|
||||
go (bs, c) (Branch p e) = do
|
||||
p' <- checkPattern p =<< apply a
|
||||
e'@(_, c') <- infer e
|
||||
subtype c' c
|
||||
apply (T.Branch p' e' : bs, c')
|
||||
|
||||
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
|
||||
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
|
||||
|
|
@ -319,112 +392,6 @@ applyInfer (TFun a c) e = do
|
|||
|
||||
applyInfer a e = throwError ("Cannot apply type " ++ show a ++ " with expression " ++ show e)
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Pattern matching
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
-- Γ ⊢ p ⇒ e ∷ A ↓ B ⊣ Θ
|
||||
-- Θ ⊢ Π ∷ [Θ]A ↓ C ⊣ Δ
|
||||
-- [Δ]B <: C
|
||||
-- ---------------------------
|
||||
-- Γ ⊢ (p ⇒ e),Π ∷ A ↓ C ⊣ Δ
|
||||
inferBranches :: [Branch] -> Type -> Tc ([T.Branch' Type], Type)
|
||||
inferBranches branches t_patt = do
|
||||
(branches', ts_exp) <- inferBranches' t_patt branches
|
||||
t_exp <- case ts_exp of
|
||||
[] -> pure t_patt
|
||||
t:_ -> do
|
||||
zipWithM_ (onMM subtype apply) (init ts_exp) (tail ts_exp)
|
||||
apply t
|
||||
apply (branches', t_exp)
|
||||
where
|
||||
|
||||
inferBranches' = go [] []
|
||||
where
|
||||
go branches ts_exp t = \case
|
||||
[] -> pure (branches, ts_exp)
|
||||
b:bs -> do
|
||||
(b', t_e) <- inferBranch b t
|
||||
t' <- apply t
|
||||
go (snoc b' branches) (snoc t_e ts_exp) t' bs
|
||||
|
||||
-- Γ ⊢ p ↑ A ⊣ Θ Θ ⊢ e ↓ C ⊣ Δ
|
||||
-- -------------------------------
|
||||
-- Γ ⊢ p ⇒ e ∷ A ↓ C ⊣ Δ
|
||||
inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type)
|
||||
inferBranch (Branch patt exp) t_patt = do
|
||||
patt' <- checkPattern patt t_patt
|
||||
(exp', t_exp) <- infer exp
|
||||
apply (T.Branch patt' (exp', t_exp), t_exp)
|
||||
|
||||
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
|
||||
checkPattern patt t_patt = case patt of
|
||||
|
||||
-- -------------------
|
||||
-- Γ ⊢ x ↑ A ⊣ Γ,(x:A)
|
||||
PVar x -> do
|
||||
insertEnv $ EnvVar x t_patt
|
||||
apply (T.PVar (coerce x, t_patt), t_patt)
|
||||
|
||||
-- -------------
|
||||
-- Γ ⊢ _ ↑ A ⊣ Γ
|
||||
PCatch -> apply (T.PCatch, t_patt)
|
||||
|
||||
-- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ
|
||||
-- ------------------------------
|
||||
-- Γ ⊢ τ ↑ B ⊣ Δ
|
||||
PLit lit -> do
|
||||
subtype (litType lit) t_patt
|
||||
apply (T.PLit (lit, t_patt), t_patt)
|
||||
|
||||
-- Γ ∋ (K : A) Γ ⊢ A <: B ⊣ Δ
|
||||
-- ---------------------------
|
||||
-- Γ ⊢ K ↑ B ⊣ Δ
|
||||
PEnum name -> do
|
||||
t <- maybeToRightM ("Unknown constructor " ++ show name)
|
||||
=<< lookupInj name
|
||||
subtype t t_patt
|
||||
apply (T.PEnum (coerce name), t_patt)
|
||||
|
||||
|
||||
-- Example
|
||||
-- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs
|
||||
-- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁
|
||||
-- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂
|
||||
-- Θ ⊢ p₂ ↑ [Θ][ά/α]A₂ ⊣ Δ
|
||||
-- ---------------------------
|
||||
-- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ
|
||||
PInj name ps -> do
|
||||
t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
|
||||
let ts = getParams t_inj
|
||||
unless (length ts == length ps) $
|
||||
throwError "Wrong number of arguments!"
|
||||
sub <- substituteTVarsOf t_inj
|
||||
subtype (sub $ getDataId t_inj) t_patt
|
||||
let checkP p t = checkPattern p =<< apply (sub t)
|
||||
ps' <- zipWithM checkP ps ts
|
||||
apply (T.PInj (coerce name) (map fst ps'), t_patt)
|
||||
where
|
||||
substituteTVarsOf = \case
|
||||
TAll tvar t -> do
|
||||
tevar <- fresh
|
||||
(substitute tvar tevar .) <$> substituteTVarsOf t
|
||||
_ -> pure id
|
||||
|
||||
getParams = \case
|
||||
TAll _ t -> getParams t
|
||||
t -> go [] t
|
||||
where
|
||||
go acc = \case
|
||||
TFun t1 t2 -> go (snoc t1 acc) t2
|
||||
_ -> acc
|
||||
|
||||
getDataId typ = case typ of
|
||||
TAll _ t -> getDataId t
|
||||
TFun _ t -> getDataId t
|
||||
TData {} -> typ
|
||||
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Subtyping rules
|
||||
---------------------------------------------------------------------------
|
||||
|
|
@ -482,7 +449,6 @@ subtype (TEVar alpha) a | notElem alpha $ frees a = instantiateL alpha a
|
|||
-- Γ[ά] ⊢ A <: ά ⊣ Δ
|
||||
subtype a (TEVar alpha) | notElem alpha $ frees a = instantiateR a alpha
|
||||
|
||||
|
||||
subtype t1 t2 = case (t1, t2) of
|
||||
(TData name1 typs1, TData name2 typs2)
|
||||
|
||||
|
|
@ -564,14 +530,13 @@ instantiateL alpha a = gets env >>= \env -> go env alpha a
|
|||
putEnv env_l
|
||||
|
||||
go _ alpha a = error $ "Trying to instantiateL: " ++ ppT (TEVar alpha)
|
||||
++ " <: " ++ ppT a
|
||||
++ " <: " ++ ppT a
|
||||
|
||||
-- | Γ ⊢ A =:< ά ⊣ Δ
|
||||
-- Under input context Γ, instantiate ά such that A <: ά, with output context ∆
|
||||
instantiateR :: Type -> TEVar -> Tc ()
|
||||
instantiateR a alpha = gets env >>= \env -> go env a alpha
|
||||
where
|
||||
|
||||
-- Γ ⊢ τ
|
||||
-- ----------------------------- InstRSolve
|
||||
-- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ'
|
||||
|
|
@ -588,11 +553,9 @@ instantiateR a alpha = gets env >>= \env -> go env a alpha
|
|||
let (env_l, env_r) = splitOn (EnvTEVar epsilon) env
|
||||
putEnv $ (env_l :|> EnvTEVarSolved epsilon (TEVar alpha)) <> env_r
|
||||
|
||||
|
||||
|
||||
-- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ
|
||||
-- ------------------------------------------------------- InstRArr
|
||||
-- Γ[ά] ⊢ A₁ → A₂ =:< ά ⊣ Δ
|
||||
-- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ
|
||||
-- ------------------------------------------------------- InstRArr
|
||||
-- Γ[ά] ⊢ A₁ → A₂ =:< ά ⊣ Δ
|
||||
go _ (TFun a1 a2) alpha = do
|
||||
alpha1 <- fresh
|
||||
alpha2 <- fresh
|
||||
|
|
@ -603,24 +566,19 @@ instantiateR a alpha = gets env >>= \env -> go env a alpha
|
|||
a2' <- apply a2
|
||||
instantiateR a2' alpha2
|
||||
|
||||
|
||||
|
||||
-- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ'
|
||||
-- ---------------------------------- InstRAIIL
|
||||
-- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ
|
||||
-- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ'
|
||||
-- ---------------------------------- InstRAIIL
|
||||
-- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ
|
||||
go env (TAll epsilon e) alpha = do
|
||||
epsilon' <- fresh
|
||||
insertEnv $ EnvMark epsilon'
|
||||
insertEnv $ EnvTVar epsilon
|
||||
instantiateR (substitute epsilon epsilon' e) alpha
|
||||
let (env_l, _) = splitOn (EnvMark epsilon') env
|
||||
putEnv env_l
|
||||
epsilon' <- fresh
|
||||
insertEnv $ EnvMark epsilon'
|
||||
insertEnv $ EnvTVar epsilon
|
||||
instantiateR (substitute epsilon epsilon' e) alpha
|
||||
let (env_l, _) = splitOn (EnvMark epsilon') env
|
||||
putEnv env_l
|
||||
|
||||
go _ a alpha = error $ "Trying to instantiateR: " ++ ppT a ++ " <: "
|
||||
++ ppT (TEVar alpha)
|
||||
|
||||
|
||||
|
||||
++ ppT (TEVar alpha)
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Auxiliary
|
||||
|
|
@ -713,35 +671,6 @@ fresh = do
|
|||
modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar }
|
||||
pure tevar
|
||||
|
||||
getVars :: Type -> [Type]
|
||||
getVars = fst . partitionType
|
||||
|
||||
getReturn :: Type -> Type
|
||||
getReturn = snd . partitionType
|
||||
|
||||
-- | Partion type into variable types and return type.
|
||||
--
|
||||
-- ∀a.∀b. a → (∀c. c → c) → b
|
||||
-- ([a, ∀c. c → c], b)
|
||||
--
|
||||
-- Unsure if foralls should be added to the return type or not.
|
||||
-- FIXME
|
||||
partitionType :: Type -> ([Type], Type)
|
||||
partitionType = go [] . skipForalls'
|
||||
where
|
||||
go acc t = case t of
|
||||
TFun t1 t2 -> go (snoc t1 acc) t2
|
||||
_ -> (acc, t)
|
||||
|
||||
skipForalls' :: Type -> Type
|
||||
skipForalls' = snd . skipForalls
|
||||
|
||||
skipForalls :: Type -> ([Type -> Type], Type)
|
||||
skipForalls = go []
|
||||
where
|
||||
go acc typ = case typ of
|
||||
TAll tvar t -> go (snoc (TAll tvar) acc) t
|
||||
_ -> (acc, typ)
|
||||
|
||||
isComplete :: Env -> Bool
|
||||
isComplete = isNothing . S.findIndexL unSolvedTEVar
|
||||
|
|
@ -750,6 +679,12 @@ isComplete = isNothing . S.findIndexL unSolvedTEVar
|
|||
EnvTEVar _ -> True
|
||||
_ -> False
|
||||
|
||||
getDataId :: Type -> Type
|
||||
getDataId typ = case typ of
|
||||
TAll _ t -> getDataId t
|
||||
TFun _ t -> getDataId t
|
||||
TData {} -> typ
|
||||
|
||||
toTVar :: Type -> Err TVar
|
||||
toTVar = \case
|
||||
TVar tvar -> pure tvar
|
||||
|
|
@ -764,7 +699,6 @@ lookupSig x = gets (Map.lookup x . sig)
|
|||
insertSig :: LIdent -> Type -> Tc ()
|
||||
insertSig name t = modify $ \cxt -> cxt { sig = Map.insert name t cxt.sig }
|
||||
|
||||
|
||||
lookupEnv :: LIdent -> Tc (Maybe Type)
|
||||
lookupEnv x = gets (findId . env)
|
||||
where
|
||||
|
|
@ -786,7 +720,6 @@ modifyEnv f =
|
|||
pattern DBind' name vars exp = DBind (Bind name vars exp)
|
||||
pattern DSig' name typ = DSig (Sig name typ)
|
||||
|
||||
|
||||
---------------------------------------------------------------------------
|
||||
-- * Apply
|
||||
---------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue