Fix bad inference on case expression, and make pretty for report

This commit is contained in:
Martin Fredin 2023-04-08 21:52:57 +02:00
parent 29de6c49e4
commit a109b3010d
6 changed files with 406 additions and 391 deletions

View file

@ -2,7 +2,14 @@ data Bool () where
True : Bool () True : Bool ()
False : Bool () False : Bool ()
main : Bool () -> a -> Int -- Both valid
main b = case b of -- f : Bool () -> a -> Int
False => (\x. 1) f : Bool () -> (forall a. a -> Int)
True => (\x. 0) f b = case b of
False => (\x. 0 : forall a. a -> Int)
True => (\x. 1 : forall a. a -> Int)
main : Int
main = (f True) 'h'

View file

@ -2,7 +2,7 @@ data Bool () where
True : Bool () True : Bool ()
False : Bool () False : Bool ()
ifThenElse : forall a. Bool () -> a -> a -> a ifThenElse : Bool () -> a -> a -> a
ifThenElse b if else = case b of ifThenElse b if else = case b of
True => if True => if
False => else False => else

View file

@ -4,9 +4,8 @@
module Auxiliary (module Auxiliary) where module Auxiliary (module Auxiliary) where
import Control.Applicative (Applicative (liftA2))
import Control.Monad.Error.Class (liftEither) import Control.Monad.Error.Class (liftEither)
import Control.Monad.Except (MonadError) import Control.Monad.Except (MonadError, liftM2)
import Data.Either.Combinators (maybeToRight) import Data.Either.Combinators (maybeToRight)
import Data.List (foldl') import Data.List (foldl')
import Grammar.Abs import Grammar.Abs
@ -31,8 +30,11 @@ mapAccumM f = go
(acc'', xs') <- go acc' xs (acc'', xs') <- go acc' xs
pure (acc'', x' : xs') pure (acc'', x' : xs')
onMM :: Monad m => (b -> b -> m c) -> (a -> m b) -> a -> a -> m c
onMM f g x y = liftMM2 f (g x) (g y)
onM :: Monad m => (b -> b -> c) -> (a -> m b) -> a -> a -> m c onM :: Monad m => (b -> b -> c) -> (a -> m b) -> a -> a -> m c
onM f g x y = liftA2 f (g x) (g y) onM f g x y = liftM2 f (g x) (g y)
unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 = unzip4 =
@ -42,6 +44,12 @@ unzip4 =
) )
([], [], [], []) ([], [], [], [])
liftMM2 :: Monad m => (a -> b -> m c) -> m a -> m b -> m c
liftMM2 f m1 m2 = do
x1 <- m1
x2 <- m2
f x1 x2
litType :: Lit -> Type litType :: Lit -> Type
litType (LInt _) = int litType (LInt _) = int
litType (LChar _) = char litType (LChar _) = char

View file

@ -9,6 +9,7 @@ import Data.Coerce (coerce)
import Data.Tuple.Extra (secondM) import Data.Tuple.Extra (secondM)
import qualified Grammar.Abs as G import qualified Grammar.Abs as G
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr hiding (Type (..)) import TypeChecker.TypeCheckerIr hiding (Type (..))
@ -78,4 +79,4 @@ instance ReportTEVar G.Type Type where
G.TData name typs -> TData (coerce name) <$> reportTEVar typs G.TData name typs -> TData (coerce name) <$> reportTEVar typs
G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2) G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2)
G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t
G.TEVar _ -> throwError "NewType TEVar!" G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar)

View file

@ -6,30 +6,28 @@
module TypeChecker.TypeCheckerBidir (typecheck, getVars) where module TypeChecker.TypeCheckerBidir (typecheck, getVars) where
import Auxiliary (int, litType, maybeToRightM, snoc) import Auxiliary (int, liftMM2, litType,
import Control.Applicative (Alternative, Applicative (liftA2), maybeToRightM, onM, onMM, snoc)
(<|>)) import Control.Applicative (Applicative (liftA2), (<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError), import Control.Monad.Except (ExceptT, MonadError (throwError),
liftEither, runExceptT, unless, runExceptT, unless, zipWithM,
zipWithM, zipWithM_) zipWithM_)
import Control.Monad.State (MonadState, State, evalState, gets, import Control.Monad.State (State, evalState, gets, modify)
modify)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Foldable (foldrM)
import Data.Function (on) import Data.Function (on)
import Data.List (intercalate, partition) import Data.List (intercalate)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Maybe (fromMaybe, isNothing) import Data.Maybe (fromMaybe, isNothing)
import Data.Sequence (Seq (..)) import Data.Sequence (Seq (..))
import qualified Data.Sequence as S import qualified Data.Sequence as S
import qualified Data.Set as Set import qualified Data.Set as Set
import Data.Tuple.Extra (second, secondM) import Data.Tuple.Extra (second)
import Debug.Trace (trace) import Debug.Trace (trace)
import Grammar.Abs import Grammar.Abs
import Grammar.ErrM import Grammar.ErrM
import Grammar.Print (printTree) import Grammar.Print (printTree)
import Prelude hiding (exp, id) import Prelude hiding (exp)
import qualified TypeChecker.TypeCheckerIr as T import qualified TypeChecker.TypeCheckerIr as T
-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013) -- Implementation is derived from the paper (Dunfield and Krishnaswami 2013)
@ -59,8 +57,9 @@ data Cxt = Cxt
, data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A , data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A
} deriving (Show, Eq) } deriving (Show, Eq)
newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a } type Tc a = ExceptT String (State Cxt) a
deriving (Functor, Applicative, Monad, Alternative, MonadState Cxt, MonadError String) -- deriving (Functor, Applicative, Monad, Alternative, MonadState Cxt, MonadError String)
initCxt :: [Def] -> Cxt initCxt :: [Def] -> Cxt
initCxt defs = Cxt initCxt defs = Cxt
@ -96,7 +95,7 @@ typecheck (Program defs) = do
typecheckBinds :: Cxt -> [Bind] -> Err [T.Bind' Type] typecheckBinds :: Cxt -> [Bind] -> Err [T.Bind' Type]
typecheckBinds cxt = flip evalState cxt typecheckBinds cxt = flip evalState cxt
. runExceptT . runExceptT
. runTc -- . runTc
. mapM typecheckBind . mapM typecheckBind
typecheckBind :: Bind -> Tc (T.Bind' Type) typecheckBind :: Bind -> Tc (T.Bind' Type)
@ -106,10 +105,8 @@ typecheckBind (Bind name vars rhs) = do
(rhs', _) <- check (foldr EAbs rhs vars) t (rhs', _) <- check (foldr EAbs rhs vars) t
pure (T.Bind (coerce name, t) [] (rhs', t)) pure (T.Bind (coerce name, t) [] (rhs', t))
Nothing -> do Nothing -> do
(e, t) <- infer $ foldr EAbs rhs vars (e, t) <- apply =<< infer (foldr EAbs rhs vars)
t' <- applyEnv t pure (T.Bind (coerce name, t) [] (e, t))
e' <- applyEnvExp e
pure (T.Bind (coerce name, t') [] (e', t'))
env <- gets env env <- gets env
unless (isComplete env) err unless (isComplete env) err
insertSig (coerce name) typ insertSig (coerce name) typ
@ -162,6 +159,275 @@ typecheckInj (Inj inj_name inj_typ) name tvars
TLit _ -> True TLit _ -> True
TEVar _ -> error "TEVar in data type declaration" TEVar _ -> error "TEVar in data type declaration"
---------------------------------------------------------------------------
-- * Typing rules
---------------------------------------------------------------------------
-- | Γ ⊢ e ↑ A ⊣ Δ
-- Under input context Γ, e checks against input type A, with output context ∆
check :: Exp -> Type -> Tc (T.ExpT' Type)
check exp typ
-- Γ,α ⊢ e ↑ A ⊣ Δ,α
-- ------------------- ∀I
-- Γ ⊢ e ↑ ∀α.A ⊣ Δ
| TAll tvar t <- typ = do
let env_tvar = EnvTVar tvar
insertEnv env_tvar
exp' <- check exp t
(env_l, _) <- gets (splitOn env_tvar . env)
putEnv env_l
pure exp'
-- Γ,(x:A) ⊢ e ↑ B ⊢ Δ,(x:A),Θ
-- --------------------------- →I
-- Γ ⊢ λx.e ↑ A → B ⊣ Δ
| EAbs name e <- exp
, TFun t1 t2 <- typ = do
let env_var = EnvVar name t1
insertEnv env_var
e' <- check e t2
(env_l, _) <- gets (splitOn env_var . env)
putEnv env_l
pure (T.EAbs (coerce name) e', typ)
| otherwise = subsumption
where
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
-- -------------------------------------- Sub
-- Γ ⊢ e ↑ B ⊣ Δ
subsumption = do
(exp', t) <- infer exp
typ' <- apply typ
subtype t typ'
apply (exp', t)
-- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆
infer :: Exp -> Tc (T.ExpT' Type)
infer = \case
ELit lit -> pure (T.ELit lit, litType lit)
-- Γ ∋ (x : A) Γ ∌ (x : A)
-- ------------- Var --------------- Var'
-- Γ ⊢ x ↓ A ⊣ Γ Γ ⊢ x ↓ ά ⊣ Γ,ά
EVar name -> do
t <- liftA2 (<|>) (lookupEnv name) (lookupSig name) >>= \case
Just t -> pure t
Nothing -> do
tevar <- fresh
insertEnv (EnvTEVar tevar)
let t = TEVar tevar
insertEnv (EnvVar name t)
pure t
apply (T.EVar (coerce name), t)
EInj name -> do
t <- maybeToRightM ("Unknown constructor: " ++ show name)
=<< lookupInj name
apply (T.EInj $ coerce name, t)
-- Γ ⊢ A Γ ⊢ e ↑ A ⊣ Δ
-- --------------------- Anno
-- Γ ⊢ (e : A) ↓ A ⊣ Δ
EAnn e t -> do
_ <- gets $ (`wellFormed` t) . env
(e', _) <- check e t
apply (e', t)
-- Γ ⊢ e₁ ↓ A ⊣ Θ Γ ⊢ [Θ]A • ⇓ C ⊣ Δ
-- ----------------------------------- →E
-- Γ ⊢ e₁ e₂ ↓ C ⊣ Δ
EApp e1 e2 -> do
(e1', t) <- infer e1
(e2', t'') <- applyInfer t e2
apply (T.EApp (e1', t) e2', t'')
-- Γ,ά,έ,(x:ά) ⊢ e ↑ έ ⊣ Δ,(x:ά),Θ
-- ------------------------------- →I
-- Γ ⊢ λx.e ↓ ά → έ ⊣ Δ
EAbs name e -> do
tevar1 <- fresh
tevar2 <- fresh
insertEnv $ EnvTEVar tevar1
insertEnv $ EnvTEVar tevar2
let env_var = EnvVar name (TEVar tevar1)
insertEnv env_var
e' <- check e $ TEVar tevar2
dropTrailing env_var
let t_exp = on TFun TEVar tevar1 tevar2
apply (T.EAbs (coerce name) e', t_exp)
-- Γ ⊢ e ↓ A ⊣ Θ Θ,(x:A) ⊢ e' ↑ C ⊣ Δ,(x:A),Θ
-- -------------------------------------------- LetI
-- Γ ⊢ let x=e in e' ↑ C ⊣ Δ
ELet (Bind name [] rhs) e -> do -- TODO vars
(rhs', t_rhs) <- infer rhs
let env_var = EnvVar name t_rhs
insertEnv env_var
(e', t) <- infer e
(env_l, _) <- gets (splitOn env_var . env)
putEnv env_l
apply (T.ELet (T.Bind (coerce name, t_rhs) [] (rhs', t_rhs)) (e',t), t)
-- Γ ⊢ e₁ ↑ Int ⊣ Θ Θ ⊢ e₂ ↑ Int
-- --------------------------- +I
-- Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ
EAdd e1 e2 -> (, int) <$> onM T.EAdd (`check` int) e1 e2
-- Θ ⊢ Π ∷ A ↓ C ⊣ Δ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- ---------------------------------------
-- Γ ⊢ case e of Π ↓ C ⊣ Δ
ECase scrut branches -> do
(scrut', t_scrut) <- infer scrut
(branches', t_return) <- inferBranches branches t_scrut
apply (T.ECase (scrut', t_scrut) branches', t_return)
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
-- Instantiate existential type variables until there is an arrow type.
applyInfer :: Type -> Exp -> Tc (T.ExpT' Type, Type)
applyInfer typ exp = case typ of
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
-- ------------------------ ∀App
-- Γ ⊢ ∀α.A • e ⇓ C ⊣ Δ
TAll tvar t -> do
tevar <- fresh
insertEnv $ EnvTEVar tevar
let t' = substitute tvar tevar t
applyInfer t' exp
-- Γ[ά₂,ά₁,(ά=ά₁→ά₂)] ⊢ e ↑ ά₁ ⊣ Δ
-- ------------------------------- άApp
-- Γ[ά] ⊢ ά • e ⇓ ά₂ ⊣ Δ
TEVar tevar -> do
tevar1 <- fresh
tevar2 <- fresh
let env_tevar1 = EnvTEVar tevar1
env_tevar2 = EnvTEVar tevar2
t_fun = on TFun TEVar tevar1 tevar2
env_tevar_solved = EnvTEVarSolved tevar t_fun
(env_l, env_r) <- gets (splitOn (EnvTEVar tevar) . env)
putEnv $
(env_l :|> env_tevar2 :|> env_tevar1 :|> env_tevar_solved) <> env_r
expT' <- check exp $ TEVar tevar1
apply (expT', TEVar tevar2)
-- Γ ⊢ e ↑ A ⊣ Δ
-- --------------------- →App
-- Γ ⊢ A → C • e ⇓ C ⊣ Δ
TFun t1 t2 -> (, t2) <$> check exp t1
_ -> throwError ("Cannot apply type " ++ show typ ++ " with expression " ++ show exp)
---------------------------------------------------------------------------
-- * Pattern matching
---------------------------------------------------------------------------
-- Γ ⊢ p ⇒ e ∷ A ↓ B ⊣ Θ
-- Θ ⊢ Π ∷ [Θ]A ↓ C ⊣ Δ
-- [Δ]B <: C
-- ---------------------------
-- Γ ⊢ (p ⇒ e),Π ∷ A ↓ C ⊣ Δ
inferBranches :: [Branch] -> Type -> Tc ([T.Branch' Type], Type)
inferBranches branches t_patt = do
(branches', ts_exp) <- inferBranches' t_patt branches
traceTs "TYPES " ts_exp
t_exp <- case ts_exp of
[] -> pure t_patt
t:_ -> do
zipWithM_ (onMM subtype apply) (init ts_exp) (tail ts_exp)
apply t
apply (branches', t_exp)
where
inferBranches' = go [] []
where
go branches ts_exp t = \case
[] -> pure (branches, ts_exp)
b:bs -> do
(b', t_e) <- inferBranch b t
t' <- apply t
go (snoc b' branches) (snoc t_e ts_exp) t' bs
-- Γ ⊢ p ↑ A ⊣ Θ Θ ⊢ e ↓ C ⊣ Δ
-- -------------------------------
-- Γ ⊢ p ⇒ e ∷ A ↓ C ⊣ Δ
inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type)
inferBranch (Branch patt exp) t_patt = do
patt' <- checkPattern patt t_patt
(exp', t_exp) <- infer exp
apply (T.Branch patt' (exp', t_exp), t_exp)
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of
-- -------------------
-- Γ ⊢ x ↑ A ⊣ Γ,(x:A)
PVar x -> do
insertEnv $ EnvVar x t_patt
apply (T.PVar (coerce x, t_patt), t_patt)
-- -------------
-- Γ ⊢ _ ↑ A ⊣ Γ
PCatch -> apply (T.PCatch, t_patt)
-- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ
-- ------------------------------
-- Γ ⊢ τ ↑ B ⊣ Δ
PLit lit -> do
subtype (litType lit) t_patt
apply (T.PLit (lit, t_patt), t_patt)
-- Γ ∋ (K : A) Γ ⊢ A <: B ⊣ Δ
-- ---------------------------
-- Γ ⊢ K ↑ B ⊣ Δ
PEnum name -> do
t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name
subtype t t_patt
apply (T.PEnum (coerce name), t_patt)
-- Example
-- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs
-- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁
-- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂
-- Θ ⊢ p₂ ↑ [Θ][ά/α]A₂ ⊣ Δ
-- ---------------------------
-- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ
PInj name ps -> do
t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
sub <- substituteTVarsOf t_inj
subtype (sub $ getDataId t_inj) t_patt
let checkP p t = checkPattern p =<< apply (sub t)
ps' <- zipWithM checkP ps $ getParams t_inj
apply (T.PInj (coerce name) (map fst ps'), t_patt)
where
substituteTVarsOf = \case
TAll tvar t -> do
tevar <- fresh
(substitute tvar tevar .) <$> substituteTVarsOf t
_ -> pure id
getParams = \case
TAll _ t -> getParams t
t -> go [] t
where
go acc = \case
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> acc
getDataId typ = case typ of
TAll _ t -> getDataId t
TFun _ t -> getDataId t
TData {} -> typ
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
-- * Subtyping rules -- * Subtyping rules
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
@ -186,8 +452,8 @@ subtype t1 t2 = case (t1, t2) of
-- Γ ⊢ A₁ → A₂ <: B₁ → B₂ ⊣ Δ -- Γ ⊢ A₁ → A₂ <: B₁ → B₂ ⊣ Δ
(TFun a1 a2, TFun b1 b2) -> do (TFun a1 a2, TFun b1 b2) -> do
subtype b1 a1 subtype b1 a1
a2' <- applyEnv a2 a2' <- apply a2
b2' <- applyEnv b2 b2' <- apply b2
subtype a2' b2' subtype a2' b2'
-- Γ, α ⊢ A <: B ⊣ Δ,α -- Γ, α ⊢ A <: B ⊣ Δ,α
@ -245,8 +511,8 @@ subtype t1 t2 = case (t1, t2) of
zipWithM_ go t1s t2s zipWithM_ go t1s t2s
where where
go t1' t2' = do go t1' t2' = do
t1'' <- applyEnv t1' t1'' <- apply t1'
t2'' <- applyEnv t2' t2'' <- apply t2'
subtype t1'' t2'' subtype t1'' t2''
_ -> throwError $ unwords ["Types", ppT t1, "and", ppT t2, "doesn't match!"] _ -> throwError $ unwords ["Types", ppT t1, "and", ppT t2, "doesn't match!"]
@ -265,7 +531,7 @@ instantiateL tevar typ = gets env >>= go
-- Γ ⊢ τ -- Γ ⊢ τ
-- ----------------------------- InstLSolve -- ----------------------------- InstLSolve
-- Γ,ά,Γ' ⊢ ά :=< τ ⊣ Γ,(ά=τ),Γ' -- Γ,ά,Γ' ⊢ ά :=< τ ⊣ Γ,(ά=τ),Γ'
| noForall typ | isMono typ
, (env_l, env_r) <- splitOn (EnvTEVar tevar) env , (env_l, env_r) <- splitOn (EnvTEVar tevar) env
, Right _ <- wellFormed env_l typ , Right _ <- wellFormed env_l typ
= putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r = putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r
@ -282,7 +548,7 @@ instantiateL tevar typ = gets env >>= go
insertEnv $ EnvTEVar tevar1 insertEnv $ EnvTEVar tevar1
insertEnv $ EnvTEVarSolved tevar (on TFun TEVar tevar1 tevar2) insertEnv $ EnvTEVarSolved tevar (on TFun TEVar tevar1 tevar2)
instantiateR t1 tevar1 instantiateR t1 tevar1
instantiateL tevar2 =<< applyEnv t2 instantiateL tevar2 =<< apply t2
-- Γ[ά],ε ⊢ ά :=< E ⊣ Δ,ε,Δ' -- Γ[ά],ε ⊢ ά :=< E ⊣ Δ,ε,Δ'
-- ------------------------- InstLAIIR -- ------------------------- InstLAIIR
@ -305,7 +571,7 @@ instantiateR typ tevar = gets env >>= go
-- Γ ⊢ τ -- Γ ⊢ τ
-- ----------------------------- InstRSolve -- ----------------------------- InstRSolve
-- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ' -- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ'
| noForall typ | isMono typ
, (env_l, env_r) <- splitOn (EnvTEVar tevar) env , (env_l, env_r) <- splitOn (EnvTEVar tevar) env
, Right _ <- wellFormed env_l typ , Right _ <- wellFormed env_l typ
= putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r = putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r
@ -322,7 +588,7 @@ instantiateR typ tevar = gets env >>= go
insertEnv $ EnvTEVar tevar1 insertEnv $ EnvTEVar tevar1
insertEnv $ EnvTEVarSolved tevar (on TFun TEVar tevar1 tevar2) insertEnv $ EnvTEVarSolved tevar (on TFun TEVar tevar1 tevar2)
instantiateL tevar1 t1 instantiateL tevar1 t1
t2' <- applyEnv t2 t2' <- apply t2
instantiateR t2' tevar2 instantiateR t2' tevar2
-- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ' -- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ'
@ -352,293 +618,6 @@ instReach tevar tevar' = do
let env_solved = EnvTEVarSolved tevar' $ TEVar tevar let env_solved = EnvTEVarSolved tevar' $ TEVar tevar
putEnv $ (env_l :|> env_solved) <> env_r putEnv $ (env_l :|> env_solved) <> env_r
---------------------------------------------------------------------------
-- * Typing rules
---------------------------------------------------------------------------
-- | Γ ⊢ e ↑ A ⊣ Δ
-- Under input context Γ, e checks against input type A, with output context ∆
check :: Exp -> Type -> Tc (T.ExpT' Type)
check exp typ
-- Γ,α ⊢ e ↑ A ⊣ Δ,α
-- ------------------- ∀I
-- Γ ⊢ e ↑ ∀α.A ⊣ Δ
| TAll tvar t <- typ = do
let env_tvar = EnvTVar tvar
insertEnv env_tvar
exp' <- check exp t
(env_l, _) <- gets (splitOn env_tvar . env)
putEnv env_l
pure exp'
-- Γ,(x:A) ⊢ e ↑ B ⊢ Δ,(x:A),Θ
-- --------------------------- →I
-- Γ ⊢ λx.e ↑ A → B ⊣ Δ
| EAbs name e <- exp
, TFun t1 t2 <- typ = do
let env_var = EnvVar name t1
insertEnv env_var
e' <- check e t2
(env_l, _) <- gets (splitOn env_var . env)
putEnv env_l
pure (T.EAbs (coerce name) e', typ)
| otherwise = subsumption
where
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
-- -------------------------------------- Sub
-- Γ ⊢ e ↑ B ⊣ Δ
subsumption = do
(exp', t) <- infer exp
exp'' <- applyEnvExp exp'
t' <- applyEnv t
typ' <- applyEnv typ
subtype t' typ'
pure (exp'', t')
-- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆
infer :: Exp -> Tc (T.ExpT' Type)
infer = \case
ELit lit -> pure (T.ELit lit, litType lit)
-- (x : A) ∈ Γ (x : A) ∉ Γ
-- ------------- Var --------------- Var'
-- Γ ⊢ x ↓ A ⊣ Γ Γ ⊢ x ↓ ά ⊣ Γ,ά
EVar name -> do
t <- liftA2 (<|>) (lookupEnv name) (lookupSig name) >>= \case
Just t -> pure t
Nothing -> do
tevar <- fresh
insertEnv (EnvTEVar tevar)
let t = TEVar tevar
insertEnv (EnvVar name t)
pure t
pure (T.EVar (coerce name), t)
EInj name -> do
t <- maybeToRightM ("Unknown constructor: " ++ show name) =<< lookupInj name
pure (T.EInj $ coerce name, t)
-- Γ ⊢ A Γ ⊢ e ↑ A ⊣ Δ
-- --------------------- Anno
-- Γ ⊢ (e : A) ↓ A ⊣ Δ
EAnn e t -> do
_ <- gets $ (`wellFormed` t) . env
(e', _) <- check e t
pure (e', t)
-- Γ ⊢ e₁ ↓ A ⊣ Θ Γ ⊢ [Θ]A • ⇓ C ⊣ Δ
-- ----------------------------------- →E
-- Γ ⊢ e₁ e₂ ↓ C ⊣ Δ
EApp e1 e2 -> do
(e1', t) <- infer e1
t' <- applyEnv t
e1'' <- applyEnvExp e1'
(e2', t'') <- apply t' e2
pure (T.EApp (e1'', t) e2', t'')
-- Γ,ά,έ,(x:ά) ⊢ e ↑ έ ⊣ Δ,(x:ά),Θ
-- ------------------------------- →I
-- Γ ⊢ λx.e ↓ ά → έ ⊣ Δ
EAbs name e -> do
tevar1 <- fresh
tevar2 <- fresh
insertEnv $ EnvTEVar tevar1
insertEnv $ EnvTEVar tevar2
let env_var = EnvVar name (TEVar tevar1)
insertEnv env_var
e' <- check e $ TEVar tevar2
dropTrailing env_var
let t_exp = on TFun TEVar tevar1 tevar2
pure (T.EAbs (coerce name) e', t_exp)
-- Γ ⊢ e ↓ A ⊣ Θ Θ,(x:A) ⊢ e' ↑ C ⊣ Δ,(x:A),Θ
-- -------------------------------------------- LetI
-- Γ ⊢ let x=e in e' ↑ C ⊣ Δ
ELet (Bind name [] rhs) e -> do -- TODO vars
(rhs', t_rhs) <- infer rhs
let env_var = EnvVar name t_rhs
insertEnv env_var
(e', t) <- infer e
(env_l, _) <- gets (splitOn env_var . env)
putEnv env_l
pure (T.ELet (T.Bind (coerce name, t_rhs) [] (rhs', t_rhs)) (e',t), t)
-- Γ ⊢ e₁ ↑ Int ⊣ Θ Θ ⊢ e₂ ↑ Int
-- --------------------------- +I
-- Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ
EAdd e1 e2 -> do
e1' <- check e1 int
e2' <- check e2 int
e1'' <- applyEnvExpT e1'
e2'' <- applyEnvExpT e2'
pure (T.EAdd e1'' e2'', int)
-- Θ ⊢ Π ∷ A ↓ C ⊣ Δ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- ---------------------------------------
-- Γ ⊢ case e of Π ↓ C ⊣ Δ
ECase scrut branches -> do
(scrut', t_scrut) <- infer scrut
(branches', t_return) <- inferBranches branches t_scrut
pure (T.ECase (scrut', t_scrut) branches', t_return)
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
-- Instantiate existential type variables until there is an arrow type.
apply :: Type -> Exp -> Tc (T.ExpT' Type, Type)
apply typ exp = case typ of
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
-- ------------------------ ∀App
-- Γ ⊢ ∀α.A • e ⇓ C ⊣ Δ
TAll tvar t -> do
tevar <- fresh
insertEnv $ EnvTEVar tevar
let t' = substitute tvar tevar t
apply t' exp
-- Γ[ά₂,ά₁,(ά=ά₁→ά₂)] ⊢ e ↑ ά₁ ⊣ Δ
-- ------------------------------- άApp
-- Γ[ά] ⊢ ά • e ⇓ ά₂ ⊣ Δ
TEVar tevar -> do
tevar1 <- fresh
tevar2 <- fresh
let env_tevar1 = EnvTEVar tevar1
env_tevar2 = EnvTEVar tevar2
t_fun = on TFun TEVar tevar1 tevar2
env_tevar_solved = EnvTEVarSolved tevar t_fun
(env_l, env_r) <- gets (splitOn (EnvTEVar tevar) . env)
putEnv $
(env_l :|> env_tevar2 :|> env_tevar1 :|> env_tevar_solved) <> env_r
expT' <- check exp $ TEVar tevar1
pure (expT', TEVar tevar2)
-- Γ ⊢ e ↑ A ⊣ Δ
-- --------------------- →App
-- Γ ⊢ A → C • e ⇓ C ⊣ Δ
TFun t1 t2 -> do
expt' <- check exp t1
pure (expt', t2)
_ -> throwError ("Cannot apply type " ++ show typ ++ " with expression " ++ show exp)
---------------------------------------------------------------------------
-- * Pattern matching
---------------------------------------------------------------------------
-- Γ ⊢ p ⇒ e ∷ A ↓ B ⊣ Θ
-- Θ ⊢ Π ∷ [Θ]A ↓ C ⊣ Δ
-- [Δ]B <: C
-- ---------------------------
-- Γ ⊢ (p ⇒ e),Π ∷ A ↓ C ⊣ Δ
inferBranches :: [Branch] -> Type -> Tc ([T.Branch' Type], Type)
inferBranches branches t_patt = do
(branches', ts_exp) <- inferBranches' t_patt branches
ts_exp' <- mapM applyEnv ts_exp
let (monos, pols) = partition isMono ts_exp'
t_exp <- liftEither $ bodyType t_patt monos
mapM_ (subtype t_exp) pols
pure (branches', t_exp)
where
bodyType :: Type -> [Type] -> Err Type
bodyType t_patt = \case
[] -> pure t_patt
[m] -> pure m
m:n:ms | m == n -> bodyType t_patt (n:ms)
| otherwise -> throwError $ unwords [ "Wrong return types: "
, ppT m, "", ppT n ]
inferBranches' = go [] []
where
go branches ts_exp t = \case
[] -> pure (branches, ts_exp)
b:bs -> do
(b', t_e) <- inferBranch b t
t' <- applyEnv t
go (snoc b' branches) (snoc t_e ts_exp) t' bs
-- Γ ⊢ p ↑ A ⊣ Θ Θ ⊢ e ↓ C ⊣ Δ
-- -------------------------------
-- Γ ⊢ p ⇒ e ∷ A ↓ C ⊣ Δ
inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type)
inferBranch (Branch patt exp) t_patt = do
patt' <- checkPattern patt t_patt
(exp', t_exp) <- infer exp
pure (T.Branch patt' (exp', t_exp), t_exp)
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of
-- -------------------
-- Γ ⊢ x ↑ A ⊣ Γ,(x:A)
PVar x -> do
insertEnv $ EnvVar x t_patt
pure (T.PVar (coerce x, t_patt), t_patt)
-- -------------
-- Γ ⊢ _ ↑ A ⊣ Γ
PCatch -> pure (T.PCatch, t_patt)
-- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ
-- ------------------------------
-- Γ ⊢ τ ↑ B ⊣ Δ
PLit lit -> do
subtype (litType lit) t_patt
t_patt' <- applyEnv t_patt
pure (T.PLit (lit, t_patt), t_patt')
-- (x : A) ∈ Γ Γ ⊢ A <: B ⊣ Δ
-- ---------------------------
-- Γ ⊢ inj₀ x ↑ B ⊣ Δ
PEnum name -> do
t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name
subtype t t_patt
t_patt' <- applyEnv t_patt
pure (T.PEnum (coerce name), t_patt')
PInj name ps -> do
t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name
t_inj' <- foldrM substitute' t_inj $ getInitForalls t_inj
subtype (getDataId t_inj') t_patt
t_inj'' <- applyEnv t_inj'
let ts_inj = getParams t_inj''
ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps ts_inj
t_patt' <- applyEnv t_patt
pure (T.PInj (coerce name) (map fst ps'), t_patt')
where
substitute' fa t = do
tevar <- fresh
-- insertEnv (EnvTEVar tevar)
pure $ substitute tvar tevar t
where
TAll tvar _ = fa int
getParams = \case
TAll _ t -> getParams t
t -> go [] t
where
go acc = \case
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> acc
getDataId typ = case typ of
TAll _ t -> getDataId t
TFun _ t -> getDataId t
TData {} -> typ
getInitForalls = go []
where
go acc = \case
TAll tvar t -> go (snoc (TAll tvar) acc) t
_ -> acc
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
-- * Auxiliary -- * Auxiliary
@ -677,55 +656,6 @@ splitOn x env = second (S.drop 1) $ S.breakl (==x) env
dropTrailing :: EnvElem -> Tc () dropTrailing :: EnvElem -> Tc ()
dropTrailing x = modifyEnv $ S.takeWhileL (/= x) dropTrailing x = modifyEnv $ S.takeWhileL (/= x)
applyEnvExpT :: (T.Exp' Type, Type) -> Tc (T.Exp' Type, Type)
applyEnvExpT (e, t) = liftA2 (,) (applyEnvExp e) (applyEnv t)
applyEnvExp :: T.Exp' Type -> Tc (T.Exp' Type)
applyEnvExp exp = case exp of
T.ELet (T.Bind id vars rhs) exp -> do
id <- applyEnvId id
vars' <- mapM applyEnvId vars
rhs' <- applyEnvExpT rhs
exp' <- applyEnvExpT exp
pure $ T.ELet (T.Bind id vars' rhs') exp'
T.EApp e1 e2 -> liftA2 T.EApp (applyEnvExpT e1) (applyEnvExpT e2)
T.EAdd e1 e2 -> liftA2 T.EAdd (applyEnvExpT e1) (applyEnvExpT e2)
T.EAbs name e -> T.EAbs name <$> applyEnvExpT e
T.ECase e branches -> liftA2 T.ECase (applyEnvExpT e)
(mapM applyEnvBranch branches)
_ -> pure exp
where
applyEnvId = secondM applyEnv
applyEnvBranch (T.Branch (p, t) e) = do
pt <- liftA2 (,) (applyEnvPattern p) (applyEnv t)
e' <- applyEnvExpT e
pure $ T.Branch pt e'
applyEnvPattern = \case
T.PVar id -> T.PVar <$> applyEnvId id
T.PLit (lit, t) -> T.PLit . (lit, ) <$> applyEnv t
T.PInj name ps -> T.PInj name <$> mapM applyEnvPattern ps
p -> pure p
applyEnv :: Type -> Tc Type
applyEnv t = gets $ (`applyEnv'` t) . env
-- | [Γ]A. Applies context to type until fully applied.
applyEnv' :: Env -> Type -> Type
applyEnv' cxt typ | typ == typ' = typ'
| otherwise = applyEnv' cxt typ'
where
typ' = case typ of
TLit _ -> typ
TData name typs -> TData name $ map (applyEnv' cxt) typs
-- [Γ]α = α
TVar _ -> typ
-- [Γ[ά=τ]]ά = [Γ[ά=τ]]τ
-- [Γ[ά]]ά = [Γ[ά]]ά
TEVar tevar -> fromMaybe typ $ findSolved tevar cxt
-- [Γ](A → B) = [Γ]A → [Γ]B
TFun t1 t2 -> on TFun (applyEnv' cxt) t1 t2
-- [Γ](∀α. A) = (∀α. [Γ]A)
TAll tvar t -> TAll tvar $ applyEnv' cxt t
findSolved :: TEVar -> Env -> Maybe Type findSolved :: TEVar -> Env -> Maybe Type
findSolved _ Empty = Nothing findSolved _ Empty = Nothing
@ -765,27 +695,18 @@ wellFormed env = \case
TData _ typs -> mapM_ (wellFormed env) typs TData _ typs -> mapM_ (wellFormed env) typs
noForall :: Type -> Bool
noForall = \case
TAll{} -> False
TFun t1 t2 -> on (&&) noForall t1 t2
TData _ typs -> all noForall typs
TVar _ -> True
TEVar _ -> True
TLit _ -> True
isMono :: Type -> Bool isMono :: Type -> Bool
isMono = \case isMono = \case
TAll{} -> False TAll{} -> False
TFun t1 t2 -> on (&&) isMono t1 t2 TFun t1 t2 -> on (&&) isMono t1 t2
TData _ typs -> all isMono typs TData _ typs -> all isMono typs
TVar _ -> False TVar _ -> True
TEVar _ -> False TEVar _ -> True
TLit _ -> True TLit _ -> True
fresh :: Tc TEVar fresh :: Tc TEVar
fresh = do fresh = do
tevar <- gets (MkTEVar . LIdent . ("a#" ++) . show . next_tevar) tevar <- gets (MkTEVar . LIdent . show . next_tevar)
modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar } modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar }
pure tevar pure tevar
@ -805,7 +726,6 @@ getReturn = snd . partitionType
partitionType :: Type -> ([Type], Type) partitionType :: Type -> ([Type], Type)
partitionType = go [] . skipForalls' partitionType = go [] . skipForalls'
where where
go acc t = case t of go acc t = case t of
TFun t1 t2 -> go (snoc t1 acc) t2 TFun t1 t2 -> go (snoc t1 acc) t2
_ -> (acc, t) _ -> (acc, t)
@ -863,6 +783,74 @@ modifyEnv f =
pattern DBind' name vars exp = DBind (Bind name vars exp) pattern DBind' name vars exp = DBind (Bind name vars exp)
pattern DSig' name typ = DSig (Sig name typ) pattern DSig' name typ = DSig (Sig name typ)
---------------------------------------------------------------------------
-- * Apply
---------------------------------------------------------------------------
class Apply a where
apply :: a -> Tc a
instance Apply Type where apply = applyType
instance Apply (T.Exp' Type) where apply = applyExp
instance Apply (T.Branch' Type) where apply = applyBranch
instance Apply (T.Pattern' Type) where apply = applyPattern
instance Apply a => Apply [a] where apply = mapM apply
instance (Apply a, Apply b) => Apply (a, b) where apply = applyPair
instance Apply T.Ident where apply = pure
applyType :: Type -> Tc Type
applyType t = gets $ (`applyType'` t) . env
-- | [Γ]A. Applies context to type until fully applied.
applyType' :: Env -> Type -> Type
applyType' cxt typ | typ == typ' = typ'
| otherwise = applyType' cxt typ'
where
typ' = case typ of
TLit _ -> typ
TData name typs -> TData name $ map (applyType' cxt) typs
-- [Γ]α = α
TVar _ -> typ
-- [Γ[ά=τ]]ά = [Γ[ά=τ]]τ
-- [Γ[ά]]ά = [Γ[ά]]ά
TEVar tevar -> fromMaybe typ $ findSolved tevar cxt
-- [Γ](A → B) = [Γ]A → [Γ]B
TFun t1 t2 -> on TFun (applyType' cxt) t1 t2
-- [Γ](∀α. A) = (∀α. [Γ]A)
TAll tvar t -> TAll tvar $ applyType' cxt t
applyExp :: T.Exp' Type -> Tc (T.Exp' Type)
applyExp exp = case exp of
T.ELet (T.Bind id vars rhs) exp -> do
id <- apply id
vars' <- mapM apply vars
rhs' <- apply rhs
exp' <- apply exp
pure $ T.ELet (T.Bind id vars' rhs') exp'
T.EApp e1 e2 -> liftA2 T.EApp (apply e1) (apply e2)
T.EAdd e1 e2 -> liftA2 T.EAdd (apply e1) (apply e2)
T.EAbs name e -> T.EAbs name <$> apply e
T.ECase e branches -> liftA2 T.ECase (apply e)
(mapM apply branches)
_ -> pure exp
applyBranch :: T.Branch' Type -> Tc (T.Branch' Type)
applyBranch (T.Branch (p, t) e) = do
pt <- liftA2 (,) (apply p) (apply t)
e' <- apply e
pure $ T.Branch pt e'
applyPattern :: T.Pattern' Type -> Tc (T.Pattern' Type)
applyPattern = \case
T.PVar id -> T.PVar <$> apply id
T.PLit (lit, t) -> T.PLit . (lit, ) <$> apply t
T.PInj name ps -> T.PInj name <$> apply ps
p -> pure p
applyPair :: (Apply a, Apply b) => (a, b) -> Tc (a, b)
applyPair (x, y) = liftA2 (,) (apply x) (apply y)
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
-- * Debug -- * Debug
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
@ -873,24 +861,24 @@ traceEnv s = do
traceD s x = trace (s ++ " " ++ show x) pure () traceD s x = trace (s ++ " " ++ show x) pure ()
traceT s x = trace (s ++ " " ++ ppT x) pure () traceT s x = trace (s ++ " : " ++ ppT x) pure ()
traceTs s xs = trace (s ++ " [ " ++ intercalate ", " (map ppT xs) ++ " ]") pure () traceTs s xs = trace (s ++ " [ " ++ intercalate ", " (map ppT xs) ++ " ]") pure ()
ppT = \case ppT = \case
TLit (UIdent s) -> s TLit (UIdent s) -> s
TVar (MkTVar (LIdent s)) -> "a_" ++ s TVar (MkTVar (LIdent s)) -> "tvar_" ++ s
TFun t1 t2 -> ppT t1 ++ "->" ++ ppT t2 TFun t1 t2 -> ppT t1 ++ "->" ++ ppT t2
TAll (MkTVar (LIdent s)) t -> "forall " ++ s ++ ". " ++ ppT t TAll (MkTVar (LIdent s)) t -> "forall " ++ s ++ ". " ++ ppT t
TEVar (MkTEVar (LIdent s)) -> "a^_" ++ s TEVar (MkTEVar (LIdent s)) -> "tevar_" ++ s
TData (UIdent name) typs -> name ++ " (" ++ unwords (map ppT typs) TData (UIdent name) typs -> name ++ " (" ++ unwords (map ppT typs)
++ " )" ++ " )"
ppEnvElem = \case ppEnvElem = \case
EnvVar (LIdent s) t -> s ++ ":" ++ ppT t EnvVar (LIdent s) t -> s ++ ":" ++ ppT t
EnvTVar (MkTVar (LIdent s)) -> "a_" ++ s EnvTVar (MkTVar (LIdent s)) -> "tvar_" ++ s
EnvTEVar (MkTEVar (LIdent s)) -> "a^_" ++ s EnvTEVar (MkTEVar (LIdent s)) -> "tevar_" ++ s
EnvTEVarSolved (MkTEVar (LIdent s)) t -> "_" ++ s ++ "=" ++ ppT t EnvTEVarSolved (MkTEVar (LIdent s)) t -> "tevar_" ++ s ++ "=" ++ ppT t
EnvMark (MkTEVar (LIdent s)) -> "" ++ "a^_" ++ s EnvMark (MkTEVar (LIdent s)) -> "" ++ "tevar_" ++ s
ppEnv = \case ppEnv = \case
Empty -> "·" Empty -> "·"

View file

@ -260,6 +260,17 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
, " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)" , " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)"
] ]
tc_if = specify "Test if else case expression" $ do
run [ "data Bool () where"
, " True : Bool ()"
, " False : Bool ()"
, "ifThenElse : Bool () -> a -> a -> a"
, "ifThenElse b if else = case b of"
, " True => if"
, " False => else"
] `shouldSatisfy` ok
tc_infer_case = describe "Infer case expression" $ do tc_infer_case = describe "Infer case expression" $ do
specify "Wrong case expression rejected" $ specify "Wrong case expression rejected" $