Simplify pattern matching

This commit is contained in:
Martin Fredin 2023-04-25 22:59:33 +02:00
parent 9ffcbf66b9
commit e138cb27ec

View file

@ -4,19 +4,18 @@
{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-} {-# OPTIONS_GHC -Wno-incomplete-patterns #-}
module TypeChecker.TypeCheckerBidir (typecheck, getVars) where module TypeChecker.TypeCheckerBidir (typecheck) where
import Auxiliary (int, liftMM2, litType, import Auxiliary (int, litType, maybeToRightM, snoc)
maybeToRightM, onM, onMM, snoc) import Control.Applicative (Applicative (liftA2), (<|>))
import Control.Applicative (Alternative, Applicative (liftA2),
(<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError), import Control.Monad.Except (ExceptT, MonadError (throwError),
runExceptT, unless, zipWithM, runExceptT, unless, zipWithM,
zipWithM_) zipWithM_)
import Control.Monad.Extra (fromMaybeM, maybeM) import Control.Monad.Extra (fromMaybeM)
import Control.Monad.State (MonadState, State, evalState, gets, import Control.Monad.State (MonadState, State, evalState, gets,
modify) modify)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Foldable (foldlM)
import Data.Function (on) import Data.Function (on)
import Data.List (intercalate) import Data.List (intercalate)
import Data.Map (Map) import Data.Map (Map)
@ -38,7 +37,8 @@ import qualified TypeChecker.TypeCheckerIr as T
-- --
-- TODO -- TODO
-- • Fix problems with types in Pattern/Branch in TypeCheckerIr -- • Fix problems with types in Pattern/Branch in TypeCheckerIr
-- • Fix the different type getters functions (e.g. partitionType) functions -- • Remove EAdd
-- • Add kinds!!
data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A
| EnvTVar TVar -- ^ Universal type variable. α | EnvTVar TVar -- ^ Universal type variable. α
@ -140,7 +140,7 @@ typecheckInj :: Inj -> UIdent -> [TVar] -> Err (T.Inj' Type)
typecheckInj (Inj inj_name inj_typ) name tvars typecheckInj (Inj inj_name inj_typ) name tvars
| not $ boundTVars tvars inj_typ | not $ boundTVars tvars inj_typ
= throwError "Unbound type variables" = throwError "Unbound type variables"
| TData name' typs <- getReturn inj_typ | TData name' typs <- getDataId inj_typ
, name' == name , name' == name
, Right tvars' <- mapM toTVar typs , Right tvars' <- mapM toTVar typs
, all (`elem` tvars) tvars' , all (`elem` tvars) tvars'
@ -149,7 +149,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars
= throwError $ unwords = throwError $ unwords
["Bad type constructor: ", show name ["Bad type constructor: ", show name
, "\nExpected: ", ppT . TData name $ map TVar tvars , "\nExpected: ", ppT . TData name $ map TVar tvars
, "\nActual: ", ppT $ getReturn inj_typ , "\nActual: ", ppT $ getDataId inj_typ
] ]
where where
boundTVars :: [TVar] -> Type -> Bool boundTVars :: [TVar] -> Type -> Bool
@ -161,6 +161,8 @@ typecheckInj (Inj inj_name inj_typ) name tvars
TLit _ -> True TLit _ -> True
TEVar _ -> error "TEVar in data type declaration" TEVar _ -> error "TEVar in data type declaration"
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
-- * Typing rules -- * Typing rules
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
@ -200,10 +202,72 @@ check e b = do
subtype a b' subtype a b'
apply (e', b) apply (e', b)
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of
-- -------------------
-- Γ ⊢ x ↑ A ⊣ Γ,(x:A)
PVar x -> do
insertEnv $ EnvVar x t_patt
apply (T.PVar (coerce x, t_patt), t_patt)
-- -------------
-- Γ ⊢ _ ↑ A ⊣ Γ
PCatch -> apply (T.PCatch, t_patt)
-- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ
-- ------------------------------
-- Γ ⊢ τ ↑ B ⊣ Δ
PLit lit -> do
subtype (litType lit) t_patt
apply (T.PLit (lit, t_patt), t_patt)
-- Γ ∋ (K : A) Γ ⊢ A <: B ⊣ Δ
-- ---------------------------
-- Γ ⊢ K ↑ B ⊣ Δ
PEnum name -> do
t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name
subtype t t_patt
apply (T.PEnum (coerce name), t_patt)
-- Example
-- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs
-- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁
-- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂
-- Θ₂ ⊢ p₂ ↑ [Θ₂][ά/α]A₂ ⊣ Δ
-- ---------------------------
-- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ
PInj name ps -> do
t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
let ts = getArgs t_inj
unless (length ts == length ps)
$ throwError "Wrong number of arguments!"
-- [ά/α]
sub <- substituteTVarsOf t_inj
subtype (sub $ getDataId t_inj) t_patt
let check p t = checkPattern p =<< apply (sub t)
ps' <- zipWithM check ps ts
apply (T.PInj (coerce name) (map fst ps'), t_patt)
where
substituteTVarsOf = \case
TAll tvar t -> do
tevar <- fresh
(substitute tvar tevar .) <$> substituteTVarsOf t
_ -> pure id
getArgs = \case
TAll _ t -> getArgs t
t -> go [] t
where
go acc = \case
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> acc
-- | Γ ⊢ e ↓ A ⊣ Δ -- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆ -- Under input context Γ, e infers output type A, with output context ∆
infer :: Exp -> Tc (T.ExpT' Type) infer :: Exp -> Tc (T.ExpT' Type)
infer (ELit lit) = apply (T.ELit lit, litType lit) infer (ELit lit) = apply (T.ELit lit, litType lit)
-- Γ ∋ (x : A) Γ ∌ (x : A) -- Γ ∋ (x : A) Γ ∌ (x : A)
@ -273,14 +337,23 @@ infer (EAdd e1 e2) = do
e2' <- check e2 int e2' <- check e2 int
apply (T.EAdd e1' e2', int) apply (T.EAdd e1' e2', int)
-- Θ ⊢ Π ∷ A ↓ C ⊣ Δ -- Γ ⊢ e ↑ A ⊣ Θ Θ ⊢ Π ∷ [Θ]A ↑ C ⊣ Δ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO -- ------------------------------------ Case
-- ---------------------------------------
-- Γ ⊢ case e of Π ↓ C ⊣ Δ -- Γ ⊢ case e of Π ↓ C ⊣ Δ
infer (ECase scrut branches) = do infer (ECase scrut pi) = do
(scrut', t_scrut) <- infer scrut (scrut', a) <- infer scrut
(branches', t_return) <- inferBranches branches t_scrut case pi of
apply (T.ECase (scrut', t_scrut) branches', t_return) [] -> apply (T.ECase (scrut', a) [], a)
(Branch _ e):_ -> do
(_, c)<- infer e
(pi', c') <- foldlM go ([], c) pi
apply (T.ECase (scrut', a) pi', c')
where
go (bs, c) (Branch p e) = do
p' <- checkPattern p =<< apply a
e'@(_, c') <- infer e
subtype c' c
apply (T.Branch p' e' : bs, c')
-- | Γ ⊢ A • e ⇓ C ⊣ Δ -- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ -- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
@ -319,112 +392,6 @@ applyInfer (TFun a c) e = do
applyInfer a e = throwError ("Cannot apply type " ++ show a ++ " with expression " ++ show e) applyInfer a e = throwError ("Cannot apply type " ++ show a ++ " with expression " ++ show e)
---------------------------------------------------------------------------
-- * Pattern matching
---------------------------------------------------------------------------
-- Γ ⊢ p ⇒ e ∷ A ↓ B ⊣ Θ
-- Θ ⊢ Π ∷ [Θ]A ↓ C ⊣ Δ
-- [Δ]B <: C
-- ---------------------------
-- Γ ⊢ (p ⇒ e),Π ∷ A ↓ C ⊣ Δ
inferBranches :: [Branch] -> Type -> Tc ([T.Branch' Type], Type)
inferBranches branches t_patt = do
(branches', ts_exp) <- inferBranches' t_patt branches
t_exp <- case ts_exp of
[] -> pure t_patt
t:_ -> do
zipWithM_ (onMM subtype apply) (init ts_exp) (tail ts_exp)
apply t
apply (branches', t_exp)
where
inferBranches' = go [] []
where
go branches ts_exp t = \case
[] -> pure (branches, ts_exp)
b:bs -> do
(b', t_e) <- inferBranch b t
t' <- apply t
go (snoc b' branches) (snoc t_e ts_exp) t' bs
-- Γ ⊢ p ↑ A ⊣ Θ Θ ⊢ e ↓ C ⊣ Δ
-- -------------------------------
-- Γ ⊢ p ⇒ e ∷ A ↓ C ⊣ Δ
inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type)
inferBranch (Branch patt exp) t_patt = do
patt' <- checkPattern patt t_patt
(exp', t_exp) <- infer exp
apply (T.Branch patt' (exp', t_exp), t_exp)
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of
-- -------------------
-- Γ ⊢ x ↑ A ⊣ Γ,(x:A)
PVar x -> do
insertEnv $ EnvVar x t_patt
apply (T.PVar (coerce x, t_patt), t_patt)
-- -------------
-- Γ ⊢ _ ↑ A ⊣ Γ
PCatch -> apply (T.PCatch, t_patt)
-- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ
-- ------------------------------
-- Γ ⊢ τ ↑ B ⊣ Δ
PLit lit -> do
subtype (litType lit) t_patt
apply (T.PLit (lit, t_patt), t_patt)
-- Γ ∋ (K : A) Γ ⊢ A <: B ⊣ Δ
-- ---------------------------
-- Γ ⊢ K ↑ B ⊣ Δ
PEnum name -> do
t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name
subtype t t_patt
apply (T.PEnum (coerce name), t_patt)
-- Example
-- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs
-- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁
-- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂
-- Θ ⊢ p₂ ↑ [Θ][ά/α]A₂ ⊣ Δ
-- ---------------------------
-- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ
PInj name ps -> do
t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
let ts = getParams t_inj
unless (length ts == length ps) $
throwError "Wrong number of arguments!"
sub <- substituteTVarsOf t_inj
subtype (sub $ getDataId t_inj) t_patt
let checkP p t = checkPattern p =<< apply (sub t)
ps' <- zipWithM checkP ps ts
apply (T.PInj (coerce name) (map fst ps'), t_patt)
where
substituteTVarsOf = \case
TAll tvar t -> do
tevar <- fresh
(substitute tvar tevar .) <$> substituteTVarsOf t
_ -> pure id
getParams = \case
TAll _ t -> getParams t
t -> go [] t
where
go acc = \case
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> acc
getDataId typ = case typ of
TAll _ t -> getDataId t
TFun _ t -> getDataId t
TData {} -> typ
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
-- * Subtyping rules -- * Subtyping rules
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
@ -482,7 +449,6 @@ subtype (TEVar alpha) a | notElem alpha $ frees a = instantiateL alpha a
-- Γ[ά] ⊢ A <: ά ⊣ Δ -- Γ[ά] ⊢ A <: ά ⊣ Δ
subtype a (TEVar alpha) | notElem alpha $ frees a = instantiateR a alpha subtype a (TEVar alpha) | notElem alpha $ frees a = instantiateR a alpha
subtype t1 t2 = case (t1, t2) of subtype t1 t2 = case (t1, t2) of
(TData name1 typs1, TData name2 typs2) (TData name1 typs1, TData name2 typs2)
@ -571,7 +537,6 @@ instantiateL alpha a = gets env >>= \env -> go env alpha a
instantiateR :: Type -> TEVar -> Tc () instantiateR :: Type -> TEVar -> Tc ()
instantiateR a alpha = gets env >>= \env -> go env a alpha instantiateR a alpha = gets env >>= \env -> go env a alpha
where where
-- Γ ⊢ τ -- Γ ⊢ τ
-- ----------------------------- InstRSolve -- ----------------------------- InstRSolve
-- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ' -- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ'
@ -588,8 +553,6 @@ instantiateR a alpha = gets env >>= \env -> go env a alpha
let (env_l, env_r) = splitOn (EnvTEVar epsilon) env let (env_l, env_r) = splitOn (EnvTEVar epsilon) env
putEnv $ (env_l :|> EnvTEVarSolved epsilon (TEVar alpha)) <> env_r putEnv $ (env_l :|> EnvTEVarSolved epsilon (TEVar alpha)) <> env_r
-- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ -- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ
-- ------------------------------------------------------- InstRArr -- ------------------------------------------------------- InstRArr
-- Γ[ά] ⊢ A₁ → A₂ =:< ά ⊣ Δ -- Γ[ά] ⊢ A₁ → A₂ =:< ά ⊣ Δ
@ -603,8 +566,6 @@ instantiateR a alpha = gets env >>= \env -> go env a alpha
a2' <- apply a2 a2' <- apply a2
instantiateR a2' alpha2 instantiateR a2' alpha2
-- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ' -- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ'
-- ---------------------------------- InstRAIIL -- ---------------------------------- InstRAIIL
-- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ -- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ
@ -619,9 +580,6 @@ instantiateR a alpha = gets env >>= \env -> go env a alpha
go _ a alpha = error $ "Trying to instantiateR: " ++ ppT a ++ " <: " go _ a alpha = error $ "Trying to instantiateR: " ++ ppT a ++ " <: "
++ ppT (TEVar alpha) ++ ppT (TEVar alpha)
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
-- * Auxiliary -- * Auxiliary
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
@ -713,35 +671,6 @@ fresh = do
modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar } modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar }
pure tevar pure tevar
getVars :: Type -> [Type]
getVars = fst . partitionType
getReturn :: Type -> Type
getReturn = snd . partitionType
-- | Partion type into variable types and return type.
--
-- ∀a.∀b. a → (∀c. c → c) → b
-- ([a, ∀c. c → c], b)
--
-- Unsure if foralls should be added to the return type or not.
-- FIXME
partitionType :: Type -> ([Type], Type)
partitionType = go [] . skipForalls'
where
go acc t = case t of
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> (acc, t)
skipForalls' :: Type -> Type
skipForalls' = snd . skipForalls
skipForalls :: Type -> ([Type -> Type], Type)
skipForalls = go []
where
go acc typ = case typ of
TAll tvar t -> go (snoc (TAll tvar) acc) t
_ -> (acc, typ)
isComplete :: Env -> Bool isComplete :: Env -> Bool
isComplete = isNothing . S.findIndexL unSolvedTEVar isComplete = isNothing . S.findIndexL unSolvedTEVar
@ -750,6 +679,12 @@ isComplete = isNothing . S.findIndexL unSolvedTEVar
EnvTEVar _ -> True EnvTEVar _ -> True
_ -> False _ -> False
getDataId :: Type -> Type
getDataId typ = case typ of
TAll _ t -> getDataId t
TFun _ t -> getDataId t
TData {} -> typ
toTVar :: Type -> Err TVar toTVar :: Type -> Err TVar
toTVar = \case toTVar = \case
TVar tvar -> pure tvar TVar tvar -> pure tvar
@ -764,7 +699,6 @@ lookupSig x = gets (Map.lookup x . sig)
insertSig :: LIdent -> Type -> Tc () insertSig :: LIdent -> Type -> Tc ()
insertSig name t = modify $ \cxt -> cxt { sig = Map.insert name t cxt.sig } insertSig name t = modify $ \cxt -> cxt { sig = Map.insert name t cxt.sig }
lookupEnv :: LIdent -> Tc (Maybe Type) lookupEnv :: LIdent -> Tc (Maybe Type)
lookupEnv x = gets (findId . env) lookupEnv x = gets (findId . env)
where where
@ -786,7 +720,6 @@ modifyEnv f =
pattern DBind' name vars exp = DBind (Bind name vars exp) pattern DBind' name vars exp = DBind (Bind name vars exp)
pattern DSig' name typ = DSig (Sig name typ) pattern DSig' name typ = DSig (Sig name typ)
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
-- * Apply -- * Apply
--------------------------------------------------------------------------- ---------------------------------------------------------------------------