Simplify pattern matching

This commit is contained in:
Martin Fredin 2023-04-25 22:59:33 +02:00
parent 9ffcbf66b9
commit e138cb27ec

View file

@ -4,19 +4,18 @@
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
module TypeChecker.TypeCheckerBidir (typecheck, getVars) where
module TypeChecker.TypeCheckerBidir (typecheck) where
import Auxiliary (int, liftMM2, litType,
maybeToRightM, onM, onMM, snoc)
import Control.Applicative (Alternative, Applicative (liftA2),
(<|>))
import Auxiliary (int, litType, maybeToRightM, snoc)
import Control.Applicative (Applicative (liftA2), (<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError),
runExceptT, unless, zipWithM,
zipWithM_)
import Control.Monad.Extra (fromMaybeM, maybeM)
import Control.Monad.Extra (fromMaybeM)
import Control.Monad.State (MonadState, State, evalState, gets,
modify)
import Data.Coerce (coerce)
import Data.Foldable (foldlM)
import Data.Function (on)
import Data.List (intercalate)
import Data.Map (Map)
@ -38,7 +37,8 @@ import qualified TypeChecker.TypeCheckerIr as T
--
-- TODO
-- • Fix problems with types in Pattern/Branch in TypeCheckerIr
-- • Fix the different type getters functions (e.g. partitionType) functions
-- • Remove EAdd
-- • Add kinds!!
data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A
| EnvTVar TVar -- ^ Universal type variable. α
@ -140,7 +140,7 @@ typecheckInj :: Inj -> UIdent -> [TVar] -> Err (T.Inj' Type)
typecheckInj (Inj inj_name inj_typ) name tvars
| not $ boundTVars tvars inj_typ
= throwError "Unbound type variables"
| TData name' typs <- getReturn inj_typ
| TData name' typs <- getDataId inj_typ
, name' == name
, Right tvars' <- mapM toTVar typs
, all (`elem` tvars) tvars'
@ -149,7 +149,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars
= throwError $ unwords
["Bad type constructor: ", show name
, "\nExpected: ", ppT . TData name $ map TVar tvars
, "\nActual: ", ppT $ getReturn inj_typ
, "\nActual: ", ppT $ getDataId inj_typ
]
where
boundTVars :: [TVar] -> Type -> Bool
@ -161,6 +161,8 @@ typecheckInj (Inj inj_name inj_typ) name tvars
TLit _ -> True
TEVar _ -> error "TEVar in data type declaration"
---------------------------------------------------------------------------
-- * Typing rules
---------------------------------------------------------------------------
@ -200,10 +202,72 @@ check e b = do
subtype a b'
apply (e', b)
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of
-- -------------------
-- Γ ⊢ x ↑ A ⊣ Γ,(x:A)
PVar x -> do
insertEnv $ EnvVar x t_patt
apply (T.PVar (coerce x, t_patt), t_patt)
-- -------------
-- Γ ⊢ _ ↑ A ⊣ Γ
PCatch -> apply (T.PCatch, t_patt)
-- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ
-- ------------------------------
-- Γ ⊢ τ ↑ B ⊣ Δ
PLit lit -> do
subtype (litType lit) t_patt
apply (T.PLit (lit, t_patt), t_patt)
-- Γ ∋ (K : A) Γ ⊢ A <: B ⊣ Δ
-- ---------------------------
-- Γ ⊢ K ↑ B ⊣ Δ
PEnum name -> do
t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name
subtype t t_patt
apply (T.PEnum (coerce name), t_patt)
-- Example
-- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs
-- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁
-- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂
-- Θ₂ ⊢ p₂ ↑ [Θ₂][ά/α]A₂ ⊣ Δ
-- ---------------------------
-- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ
PInj name ps -> do
t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
let ts = getArgs t_inj
unless (length ts == length ps)
$ throwError "Wrong number of arguments!"
-- [ά/α]
sub <- substituteTVarsOf t_inj
subtype (sub $ getDataId t_inj) t_patt
let check p t = checkPattern p =<< apply (sub t)
ps' <- zipWithM check ps ts
apply (T.PInj (coerce name) (map fst ps'), t_patt)
where
substituteTVarsOf = \case
TAll tvar t -> do
tevar <- fresh
(substitute tvar tevar .) <$> substituteTVarsOf t
_ -> pure id
getArgs = \case
TAll _ t -> getArgs t
t -> go [] t
where
go acc = \case
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> acc
-- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆
infer :: Exp -> Tc (T.ExpT' Type)
infer (ELit lit) = apply (T.ELit lit, litType lit)
-- Γ ∋ (x : A) Γ ∌ (x : A)
@ -273,14 +337,23 @@ infer (EAdd e1 e2) = do
e2' <- check e2 int
apply (T.EAdd e1' e2', int)
-- Θ ⊢ Π ∷ A ↓ C ⊣ Δ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- ---------------------------------------
-- Γ ⊢ e ↑ A ⊣ Θ Θ ⊢ Π ∷ [Θ]A ↑ C ⊣ Δ
-- ------------------------------------ Case
-- Γ ⊢ case e of Π ↓ C ⊣ Δ
infer (ECase scrut branches) = do
(scrut', t_scrut) <- infer scrut
(branches', t_return) <- inferBranches branches t_scrut
apply (T.ECase (scrut', t_scrut) branches', t_return)
infer (ECase scrut pi) = do
(scrut', a) <- infer scrut
case pi of
[] -> apply (T.ECase (scrut', a) [], a)
(Branch _ e):_ -> do
(_, c)<- infer e
(pi', c') <- foldlM go ([], c) pi
apply (T.ECase (scrut', a) pi', c')
where
go (bs, c) (Branch p e) = do
p' <- checkPattern p =<< apply a
e'@(_, c') <- infer e
subtype c' c
apply (T.Branch p' e' : bs, c')
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
@ -319,112 +392,6 @@ applyInfer (TFun a c) e = do
applyInfer a e = throwError ("Cannot apply type " ++ show a ++ " with expression " ++ show e)
---------------------------------------------------------------------------
-- * Pattern matching
---------------------------------------------------------------------------
-- Γ ⊢ p ⇒ e ∷ A ↓ B ⊣ Θ
-- Θ ⊢ Π ∷ [Θ]A ↓ C ⊣ Δ
-- [Δ]B <: C
-- ---------------------------
-- Γ ⊢ (p ⇒ e),Π ∷ A ↓ C ⊣ Δ
inferBranches :: [Branch] -> Type -> Tc ([T.Branch' Type], Type)
inferBranches branches t_patt = do
(branches', ts_exp) <- inferBranches' t_patt branches
t_exp <- case ts_exp of
[] -> pure t_patt
t:_ -> do
zipWithM_ (onMM subtype apply) (init ts_exp) (tail ts_exp)
apply t
apply (branches', t_exp)
where
inferBranches' = go [] []
where
go branches ts_exp t = \case
[] -> pure (branches, ts_exp)
b:bs -> do
(b', t_e) <- inferBranch b t
t' <- apply t
go (snoc b' branches) (snoc t_e ts_exp) t' bs
-- Γ ⊢ p ↑ A ⊣ Θ Θ ⊢ e ↓ C ⊣ Δ
-- -------------------------------
-- Γ ⊢ p ⇒ e ∷ A ↓ C ⊣ Δ
inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type)
inferBranch (Branch patt exp) t_patt = do
patt' <- checkPattern patt t_patt
(exp', t_exp) <- infer exp
apply (T.Branch patt' (exp', t_exp), t_exp)
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of
-- -------------------
-- Γ ⊢ x ↑ A ⊣ Γ,(x:A)
PVar x -> do
insertEnv $ EnvVar x t_patt
apply (T.PVar (coerce x, t_patt), t_patt)
-- -------------
-- Γ ⊢ _ ↑ A ⊣ Γ
PCatch -> apply (T.PCatch, t_patt)
-- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ
-- ------------------------------
-- Γ ⊢ τ ↑ B ⊣ Δ
PLit lit -> do
subtype (litType lit) t_patt
apply (T.PLit (lit, t_patt), t_patt)
-- Γ ∋ (K : A) Γ ⊢ A <: B ⊣ Δ
-- ---------------------------
-- Γ ⊢ K ↑ B ⊣ Δ
PEnum name -> do
t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name
subtype t t_patt
apply (T.PEnum (coerce name), t_patt)
-- Example
-- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs
-- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁
-- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂
-- Θ ⊢ p₂ ↑ [Θ][ά/α]A₂ ⊣ Δ
-- ---------------------------
-- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ
PInj name ps -> do
t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
let ts = getParams t_inj
unless (length ts == length ps) $
throwError "Wrong number of arguments!"
sub <- substituteTVarsOf t_inj
subtype (sub $ getDataId t_inj) t_patt
let checkP p t = checkPattern p =<< apply (sub t)
ps' <- zipWithM checkP ps ts
apply (T.PInj (coerce name) (map fst ps'), t_patt)
where
substituteTVarsOf = \case
TAll tvar t -> do
tevar <- fresh
(substitute tvar tevar .) <$> substituteTVarsOf t
_ -> pure id
getParams = \case
TAll _ t -> getParams t
t -> go [] t
where
go acc = \case
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> acc
getDataId typ = case typ of
TAll _ t -> getDataId t
TFun _ t -> getDataId t
TData {} -> typ
---------------------------------------------------------------------------
-- * Subtyping rules
---------------------------------------------------------------------------
@ -482,7 +449,6 @@ subtype (TEVar alpha) a | notElem alpha $ frees a = instantiateL alpha a
-- Γ[ά] ⊢ A <: ά ⊣ Δ
subtype a (TEVar alpha) | notElem alpha $ frees a = instantiateR a alpha
subtype t1 t2 = case (t1, t2) of
(TData name1 typs1, TData name2 typs2)
@ -564,14 +530,13 @@ instantiateL alpha a = gets env >>= \env -> go env alpha a
putEnv env_l
go _ alpha a = error $ "Trying to instantiateL: " ++ ppT (TEVar alpha)
++ " <: " ++ ppT a
++ " <: " ++ ppT a
-- | Γ ⊢ A =:< ά ⊣ Δ
-- Under input context Γ, instantiate ά such that A <: ά, with output context ∆
instantiateR :: Type -> TEVar -> Tc ()
instantiateR a alpha = gets env >>= \env -> go env a alpha
where
-- Γ ⊢ τ
-- ----------------------------- InstRSolve
-- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ'
@ -588,11 +553,9 @@ instantiateR a alpha = gets env >>= \env -> go env a alpha
let (env_l, env_r) = splitOn (EnvTEVar epsilon) env
putEnv $ (env_l :|> EnvTEVarSolved epsilon (TEVar alpha)) <> env_r
-- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ
-- ------------------------------------------------------- InstRArr
-- Γ[ά] ⊢ A₁ → A₂ =:< ά ⊣ Δ
-- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ
-- ------------------------------------------------------- InstRArr
-- Γ[ά] ⊢ A₁ → A₂ =:< ά ⊣ Δ
go _ (TFun a1 a2) alpha = do
alpha1 <- fresh
alpha2 <- fresh
@ -603,24 +566,19 @@ instantiateR a alpha = gets env >>= \env -> go env a alpha
a2' <- apply a2
instantiateR a2' alpha2
-- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ'
-- ---------------------------------- InstRAIIL
-- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ
-- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ'
-- ---------------------------------- InstRAIIL
-- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ
go env (TAll epsilon e) alpha = do
epsilon' <- fresh
insertEnv $ EnvMark epsilon'
insertEnv $ EnvTVar epsilon
instantiateR (substitute epsilon epsilon' e) alpha
let (env_l, _) = splitOn (EnvMark epsilon') env
putEnv env_l
epsilon' <- fresh
insertEnv $ EnvMark epsilon'
insertEnv $ EnvTVar epsilon
instantiateR (substitute epsilon epsilon' e) alpha
let (env_l, _) = splitOn (EnvMark epsilon') env
putEnv env_l
go _ a alpha = error $ "Trying to instantiateR: " ++ ppT a ++ " <: "
++ ppT (TEVar alpha)
++ ppT (TEVar alpha)
---------------------------------------------------------------------------
-- * Auxiliary
@ -713,35 +671,6 @@ fresh = do
modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar }
pure tevar
getVars :: Type -> [Type]
getVars = fst . partitionType
getReturn :: Type -> Type
getReturn = snd . partitionType
-- | Partion type into variable types and return type.
--
-- ∀a.∀b. a → (∀c. c → c) → b
-- ([a, ∀c. c → c], b)
--
-- Unsure if foralls should be added to the return type or not.
-- FIXME
partitionType :: Type -> ([Type], Type)
partitionType = go [] . skipForalls'
where
go acc t = case t of
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> (acc, t)
skipForalls' :: Type -> Type
skipForalls' = snd . skipForalls
skipForalls :: Type -> ([Type -> Type], Type)
skipForalls = go []
where
go acc typ = case typ of
TAll tvar t -> go (snoc (TAll tvar) acc) t
_ -> (acc, typ)
isComplete :: Env -> Bool
isComplete = isNothing . S.findIndexL unSolvedTEVar
@ -750,6 +679,12 @@ isComplete = isNothing . S.findIndexL unSolvedTEVar
EnvTEVar _ -> True
_ -> False
getDataId :: Type -> Type
getDataId typ = case typ of
TAll _ t -> getDataId t
TFun _ t -> getDataId t
TData {} -> typ
toTVar :: Type -> Err TVar
toTVar = \case
TVar tvar -> pure tvar
@ -764,7 +699,6 @@ lookupSig x = gets (Map.lookup x . sig)
insertSig :: LIdent -> Type -> Tc ()
insertSig name t = modify $ \cxt -> cxt { sig = Map.insert name t cxt.sig }
lookupEnv :: LIdent -> Tc (Maybe Type)
lookupEnv x = gets (findId . env)
where
@ -786,7 +720,6 @@ modifyEnv f =
pattern DBind' name vars exp = DBind (Bind name vars exp)
pattern DSig' name typ = DSig (Sig name typ)
---------------------------------------------------------------------------
-- * Apply
---------------------------------------------------------------------------