From a109b3010df5782edd475e5673c2f11c29348127 Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Sat, 8 Apr 2023 21:52:57 +0200 Subject: [PATCH] Fix bad inference on case expression, and make pretty for report --- sample-programs/basic-6.crf | 15 +- sample-programs/basic-7.crf | 2 +- src/Auxiliary.hs | 14 +- src/TypeChecker/ReportTEVar.hs | 3 +- src/TypeChecker/TypeCheckerBidir.hs | 752 ++++++++++++++-------------- tests/TestTypeCheckerBidir.hs | 11 + 6 files changed, 406 insertions(+), 391 deletions(-) diff --git a/sample-programs/basic-6.crf b/sample-programs/basic-6.crf index bc8bebe..ed51a1c 100644 --- a/sample-programs/basic-6.crf +++ b/sample-programs/basic-6.crf @@ -2,7 +2,14 @@ data Bool () where True : Bool () False : Bool () -main : Bool () -> a -> Int -main b = case b of - False => (\x. 1) - True => (\x. 0) +-- Both valid +-- f : Bool () -> a -> Int +f : Bool () -> (forall a. a -> Int) +f b = case b of + False => (\x. 0 : forall a. a -> Int) + True => (\x. 1 : forall a. a -> Int) + + +main : Int +main = (f True) 'h' + diff --git a/sample-programs/basic-7.crf b/sample-programs/basic-7.crf index 6fed9b7..f0fc916 100644 --- a/sample-programs/basic-7.crf +++ b/sample-programs/basic-7.crf @@ -2,7 +2,7 @@ data Bool () where True : Bool () False : Bool () -ifThenElse : forall a. Bool () -> a -> a -> a +ifThenElse : Bool () -> a -> a -> a ifThenElse b if else = case b of True => if False => else diff --git a/src/Auxiliary.hs b/src/Auxiliary.hs index cfdd828..22095aa 100644 --- a/src/Auxiliary.hs +++ b/src/Auxiliary.hs @@ -4,9 +4,8 @@ module Auxiliary (module Auxiliary) where -import Control.Applicative (Applicative (liftA2)) import Control.Monad.Error.Class (liftEither) -import Control.Monad.Except (MonadError) +import Control.Monad.Except (MonadError, liftM2) import Data.Either.Combinators (maybeToRight) import Data.List (foldl') import Grammar.Abs @@ -31,8 +30,11 @@ mapAccumM f = go (acc'', xs') <- go acc' xs pure (acc'', x' : xs') +onMM :: Monad m => (b -> b -> m c) -> (a -> m b) -> a -> a -> m c +onMM f g x y = liftMM2 f (g x) (g y) + onM :: Monad m => (b -> b -> c) -> (a -> m b) -> a -> a -> m c -onM f g x y = liftA2 f (g x) (g y) +onM f g x y = liftM2 f (g x) (g y) unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) unzip4 = @@ -42,6 +44,12 @@ unzip4 = ) ([], [], [], []) +liftMM2 :: Monad m => (a -> b -> m c) -> m a -> m b -> m c +liftMM2 f m1 m2 = do + x1 <- m1 + x2 <- m2 + f x1 x2 + litType :: Lit -> Type litType (LInt _) = int litType (LChar _) = char diff --git a/src/TypeChecker/ReportTEVar.hs b/src/TypeChecker/ReportTEVar.hs index e69c8b6..61ed688 100644 --- a/src/TypeChecker/ReportTEVar.hs +++ b/src/TypeChecker/ReportTEVar.hs @@ -9,6 +9,7 @@ import Data.Coerce (coerce) import Data.Tuple.Extra (secondM) import qualified Grammar.Abs as G import Grammar.ErrM (Err) +import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr hiding (Type (..)) @@ -78,4 +79,4 @@ instance ReportTEVar G.Type Type where G.TData name typs -> TData (coerce name) <$> reportTEVar typs G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2) G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t - G.TEVar _ -> throwError "NewType TEVar!" + G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar) diff --git a/src/TypeChecker/TypeCheckerBidir.hs b/src/TypeChecker/TypeCheckerBidir.hs index d6ec572..1ad5bea 100644 --- a/src/TypeChecker/TypeCheckerBidir.hs +++ b/src/TypeChecker/TypeCheckerBidir.hs @@ -6,30 +6,28 @@ module TypeChecker.TypeCheckerBidir (typecheck, getVars) where -import Auxiliary (int, litType, maybeToRightM, snoc) -import Control.Applicative (Alternative, Applicative (liftA2), - (<|>)) +import Auxiliary (int, liftMM2, litType, + maybeToRightM, onM, onMM, snoc) +import Control.Applicative (Applicative (liftA2), (<|>)) import Control.Monad.Except (ExceptT, MonadError (throwError), - liftEither, runExceptT, unless, - zipWithM, zipWithM_) -import Control.Monad.State (MonadState, State, evalState, gets, - modify) + runExceptT, unless, zipWithM, + zipWithM_) +import Control.Monad.State (State, evalState, gets, modify) import Data.Coerce (coerce) -import Data.Foldable (foldrM) import Data.Function (on) -import Data.List (intercalate, partition) +import Data.List (intercalate) import Data.Map (Map) import qualified Data.Map as Map import Data.Maybe (fromMaybe, isNothing) import Data.Sequence (Seq (..)) import qualified Data.Sequence as S import qualified Data.Set as Set -import Data.Tuple.Extra (second, secondM) +import Data.Tuple.Extra (second) import Debug.Trace (trace) import Grammar.Abs import Grammar.ErrM import Grammar.Print (printTree) -import Prelude hiding (exp, id) +import Prelude hiding (exp) import qualified TypeChecker.TypeCheckerIr as T -- Implementation is derived from the paper (Dunfield and Krishnaswami 2013) @@ -59,8 +57,9 @@ data Cxt = Cxt , data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A } deriving (Show, Eq) -newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a } - deriving (Functor, Applicative, Monad, Alternative, MonadState Cxt, MonadError String) +type Tc a = ExceptT String (State Cxt) a + -- deriving (Functor, Applicative, Monad, Alternative, MonadState Cxt, MonadError String) + initCxt :: [Def] -> Cxt initCxt defs = Cxt @@ -96,7 +95,7 @@ typecheck (Program defs) = do typecheckBinds :: Cxt -> [Bind] -> Err [T.Bind' Type] typecheckBinds cxt = flip evalState cxt . runExceptT - . runTc + -- . runTc . mapM typecheckBind typecheckBind :: Bind -> Tc (T.Bind' Type) @@ -106,10 +105,8 @@ typecheckBind (Bind name vars rhs) = do (rhs', _) <- check (foldr EAbs rhs vars) t pure (T.Bind (coerce name, t) [] (rhs', t)) Nothing -> do - (e, t) <- infer $ foldr EAbs rhs vars - t' <- applyEnv t - e' <- applyEnvExp e - pure (T.Bind (coerce name, t') [] (e', t')) + (e, t) <- apply =<< infer (foldr EAbs rhs vars) + pure (T.Bind (coerce name, t) [] (e, t)) env <- gets env unless (isComplete env) err insertSig (coerce name) typ @@ -162,6 +159,275 @@ typecheckInj (Inj inj_name inj_typ) name tvars TLit _ -> True TEVar _ -> error "TEVar in data type declaration" +--------------------------------------------------------------------------- +-- * Typing rules +--------------------------------------------------------------------------- + +-- | Γ ⊢ e ↑ A ⊣ Δ +-- Under input context Γ, e checks against input type A, with output context ∆ +check :: Exp -> Type -> Tc (T.ExpT' Type) +check exp typ + + -- Γ,α ⊢ e ↑ A ⊣ Δ,α,Θ + -- ------------------- ∀I + -- Γ ⊢ e ↑ ∀α.A ⊣ Δ + | TAll tvar t <- typ = do + let env_tvar = EnvTVar tvar + insertEnv env_tvar + exp' <- check exp t + (env_l, _) <- gets (splitOn env_tvar . env) + putEnv env_l + pure exp' + + -- Γ,(x:A) ⊢ e ↑ B ⊢ Δ,(x:A),Θ + -- --------------------------- →I + -- Γ ⊢ λx.e ↑ A → B ⊣ Δ + | EAbs name e <- exp + , TFun t1 t2 <- typ = do + let env_var = EnvVar name t1 + insertEnv env_var + e' <- check e t2 + (env_l, _) <- gets (splitOn env_var . env) + putEnv env_l + pure (T.EAbs (coerce name) e', typ) + + | otherwise = subsumption + where + -- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ + -- -------------------------------------- Sub + -- Γ ⊢ e ↑ B ⊣ Δ + subsumption = do + (exp', t) <- infer exp + typ' <- apply typ + subtype t typ' + apply (exp', t) + +-- | Γ ⊢ e ↓ A ⊣ Δ +-- Under input context Γ, e infers output type A, with output context ∆ +infer :: Exp -> Tc (T.ExpT' Type) +infer = \case + + ELit lit -> pure (T.ELit lit, litType lit) + + -- Γ ∋ (x : A) Γ ∌ (x : A) + -- ------------- Var --------------- Var' + -- Γ ⊢ x ↓ A ⊣ Γ Γ ⊢ x ↓ ά ⊣ Γ,ά + EVar name -> do + t <- liftA2 (<|>) (lookupEnv name) (lookupSig name) >>= \case + Just t -> pure t + Nothing -> do + tevar <- fresh + insertEnv (EnvTEVar tevar) + let t = TEVar tevar + insertEnv (EnvVar name t) + pure t + apply (T.EVar (coerce name), t) + + EInj name -> do + t <- maybeToRightM ("Unknown constructor: " ++ show name) + =<< lookupInj name + apply (T.EInj $ coerce name, t) + + -- Γ ⊢ A Γ ⊢ e ↑ A ⊣ Δ + -- --------------------- Anno + -- Γ ⊢ (e : A) ↓ A ⊣ Δ + EAnn e t -> do + _ <- gets $ (`wellFormed` t) . env + (e', _) <- check e t + apply (e', t) + + -- Γ ⊢ e₁ ↓ A ⊣ Θ Γ ⊢ [Θ]A • ⇓ C ⊣ Δ + -- ----------------------------------- →E + -- Γ ⊢ e₁ e₂ ↓ C ⊣ Δ + EApp e1 e2 -> do + (e1', t) <- infer e1 + (e2', t'') <- applyInfer t e2 + apply (T.EApp (e1', t) e2', t'') + + -- Γ,ά,έ,(x:ά) ⊢ e ↑ έ ⊣ Δ,(x:ά),Θ + -- ------------------------------- →I + -- Γ ⊢ λx.e ↓ ά → έ ⊣ Δ + EAbs name e -> do + tevar1 <- fresh + tevar2 <- fresh + insertEnv $ EnvTEVar tevar1 + insertEnv $ EnvTEVar tevar2 + let env_var = EnvVar name (TEVar tevar1) + insertEnv env_var + e' <- check e $ TEVar tevar2 + dropTrailing env_var + let t_exp = on TFun TEVar tevar1 tevar2 + apply (T.EAbs (coerce name) e', t_exp) + + + -- Γ ⊢ e ↓ A ⊣ Θ Θ,(x:A) ⊢ e' ↑ C ⊣ Δ,(x:A),Θ + -- -------------------------------------------- LetI + -- Γ ⊢ let x=e in e' ↑ C ⊣ Δ + ELet (Bind name [] rhs) e -> do -- TODO vars + (rhs', t_rhs) <- infer rhs + let env_var = EnvVar name t_rhs + insertEnv env_var + (e', t) <- infer e + (env_l, _) <- gets (splitOn env_var . env) + putEnv env_l + apply (T.ELet (T.Bind (coerce name, t_rhs) [] (rhs', t_rhs)) (e',t), t) + + -- Γ ⊢ e₁ ↑ Int ⊣ Θ Θ ⊢ e₂ ↑ Int + -- --------------------------- +I + -- Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ + EAdd e1 e2 -> (, int) <$> onM T.EAdd (`check` int) e1 e2 + + -- Θ ⊢ Π ∷ A ↓ C ⊣ Δ + -- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO + -- --------------------------------------- + -- Γ ⊢ case e of Π ↓ C ⊣ Δ + ECase scrut branches -> do + (scrut', t_scrut) <- infer scrut + (branches', t_return) <- inferBranches branches t_scrut + apply (T.ECase (scrut', t_scrut) branches', t_return) + +-- | Γ ⊢ A • e ⇓ C ⊣ Δ +-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ +-- Instantiate existential type variables until there is an arrow type. +applyInfer :: Type -> Exp -> Tc (T.ExpT' Type, Type) +applyInfer typ exp = case typ of + + -- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ + -- ------------------------ ∀App + -- Γ ⊢ ∀α.A • e ⇓ C ⊣ Δ + TAll tvar t -> do + tevar <- fresh + insertEnv $ EnvTEVar tevar + let t' = substitute tvar tevar t + applyInfer t' exp + + -- Γ[ά₂,ά₁,(ά=ά₁→ά₂)] ⊢ e ↑ ά₁ ⊣ Δ + -- ------------------------------- άApp + -- Γ[ά] ⊢ ά • e ⇓ ά₂ ⊣ Δ + TEVar tevar -> do + tevar1 <- fresh + tevar2 <- fresh + let env_tevar1 = EnvTEVar tevar1 + env_tevar2 = EnvTEVar tevar2 + t_fun = on TFun TEVar tevar1 tevar2 + env_tevar_solved = EnvTEVarSolved tevar t_fun + (env_l, env_r) <- gets (splitOn (EnvTEVar tevar) . env) + putEnv $ + (env_l :|> env_tevar2 :|> env_tevar1 :|> env_tevar_solved) <> env_r + expT' <- check exp $ TEVar tevar1 + apply (expT', TEVar tevar2) + + -- Γ ⊢ e ↑ A ⊣ Δ + -- --------------------- →App + -- Γ ⊢ A → C • e ⇓ C ⊣ Δ + TFun t1 t2 -> (, t2) <$> check exp t1 + + _ -> throwError ("Cannot apply type " ++ show typ ++ " with expression " ++ show exp) + +--------------------------------------------------------------------------- +-- * Pattern matching +--------------------------------------------------------------------------- + +-- Γ ⊢ p ⇒ e ∷ A ↓ B ⊣ Θ +-- Θ ⊢ Π ∷ [Θ]A ↓ C ⊣ Δ +-- [Δ]B <: C +-- --------------------------- +-- Γ ⊢ (p ⇒ e),Π ∷ A ↓ C ⊣ Δ +inferBranches :: [Branch] -> Type -> Tc ([T.Branch' Type], Type) +inferBranches branches t_patt = do + (branches', ts_exp) <- inferBranches' t_patt branches + traceTs "TYPES " ts_exp + t_exp <- case ts_exp of + [] -> pure t_patt + t:_ -> do + zipWithM_ (onMM subtype apply) (init ts_exp) (tail ts_exp) + apply t + apply (branches', t_exp) + where + + inferBranches' = go [] [] + where + go branches ts_exp t = \case + [] -> pure (branches, ts_exp) + b:bs -> do + (b', t_e) <- inferBranch b t + t' <- apply t + go (snoc b' branches) (snoc t_e ts_exp) t' bs + +-- Γ ⊢ p ↑ A ⊣ Θ Θ ⊢ e ↓ C ⊣ Δ +-- ------------------------------- +-- Γ ⊢ p ⇒ e ∷ A ↓ C ⊣ Δ +inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type) +inferBranch (Branch patt exp) t_patt = do + patt' <- checkPattern patt t_patt + (exp', t_exp) <- infer exp + apply (T.Branch patt' (exp', t_exp), t_exp) + +checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) +checkPattern patt t_patt = case patt of + + -- ------------------- + -- Γ ⊢ x ↑ A ⊣ Γ,(x:A) + PVar x -> do + insertEnv $ EnvVar x t_patt + apply (T.PVar (coerce x, t_patt), t_patt) + + -- ------------- + -- Γ ⊢ _ ↑ A ⊣ Γ + PCatch -> apply (T.PCatch, t_patt) + + -- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ + -- ------------------------------ + -- Γ ⊢ τ ↑ B ⊣ Δ + PLit lit -> do + subtype (litType lit) t_patt + apply (T.PLit (lit, t_patt), t_patt) + + -- Γ ∋ (K : A) Γ ⊢ A <: B ⊣ Δ + -- --------------------------- + -- Γ ⊢ K ↑ B ⊣ Δ + PEnum name -> do + t <- maybeToRightM ("Unknown constructor " ++ show name) + =<< lookupInj name + subtype t t_patt + apply (T.PEnum (coerce name), t_patt) + + + -- Example + -- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs + -- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁ + -- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂ + -- Θ ⊢ p₂ ↑ [Θ][ά/α]A₂ ⊣ Δ + -- --------------------------- + -- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ + PInj name ps -> do + t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name + sub <- substituteTVarsOf t_inj + subtype (sub $ getDataId t_inj) t_patt + let checkP p t = checkPattern p =<< apply (sub t) + ps' <- zipWithM checkP ps $ getParams t_inj + apply (T.PInj (coerce name) (map fst ps'), t_patt) + where + substituteTVarsOf = \case + TAll tvar t -> do + tevar <- fresh + (substitute tvar tevar .) <$> substituteTVarsOf t + _ -> pure id + + getParams = \case + TAll _ t -> getParams t + t -> go [] t + where + go acc = \case + TFun t1 t2 -> go (snoc t1 acc) t2 + _ -> acc + + getDataId typ = case typ of + TAll _ t -> getDataId t + TFun _ t -> getDataId t + TData {} -> typ + + --------------------------------------------------------------------------- -- * Subtyping rules --------------------------------------------------------------------------- @@ -186,8 +452,8 @@ subtype t1 t2 = case (t1, t2) of -- Γ ⊢ A₁ → A₂ <: B₁ → B₂ ⊣ Δ (TFun a1 a2, TFun b1 b2) -> do subtype b1 a1 - a2' <- applyEnv a2 - b2' <- applyEnv b2 + a2' <- apply a2 + b2' <- apply b2 subtype a2' b2' -- Γ, α ⊢ A <: B ⊣ Δ,α,Θ @@ -245,8 +511,8 @@ subtype t1 t2 = case (t1, t2) of zipWithM_ go t1s t2s where go t1' t2' = do - t1'' <- applyEnv t1' - t2'' <- applyEnv t2' + t1'' <- apply t1' + t2'' <- apply t2' subtype t1'' t2'' _ -> throwError $ unwords ["Types", ppT t1, "and", ppT t2, "doesn't match!"] @@ -265,7 +531,7 @@ instantiateL tevar typ = gets env >>= go -- Γ ⊢ τ -- ----------------------------- InstLSolve -- Γ,ά,Γ' ⊢ ά :=< τ ⊣ Γ,(ά=τ),Γ' - | noForall typ + | isMono typ , (env_l, env_r) <- splitOn (EnvTEVar tevar) env , Right _ <- wellFormed env_l typ = putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r @@ -282,7 +548,7 @@ instantiateL tevar typ = gets env >>= go insertEnv $ EnvTEVar tevar1 insertEnv $ EnvTEVarSolved tevar (on TFun TEVar tevar1 tevar2) instantiateR t1 tevar1 - instantiateL tevar2 =<< applyEnv t2 + instantiateL tevar2 =<< apply t2 -- Γ[ά],ε ⊢ ά :=< E ⊣ Δ,ε,Δ' -- ------------------------- InstLAIIR @@ -305,7 +571,7 @@ instantiateR typ tevar = gets env >>= go -- Γ ⊢ τ -- ----------------------------- InstRSolve -- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ' - | noForall typ + | isMono typ , (env_l, env_r) <- splitOn (EnvTEVar tevar) env , Right _ <- wellFormed env_l typ = putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r @@ -322,7 +588,7 @@ instantiateR typ tevar = gets env >>= go insertEnv $ EnvTEVar tevar1 insertEnv $ EnvTEVarSolved tevar (on TFun TEVar tevar1 tevar2) instantiateL tevar1 t1 - t2' <- applyEnv t2 + t2' <- apply t2 instantiateR t2' tevar2 -- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ' @@ -352,293 +618,6 @@ instReach tevar tevar' = do let env_solved = EnvTEVarSolved tevar' $ TEVar tevar putEnv $ (env_l :|> env_solved) <> env_r ---------------------------------------------------------------------------- --- * Typing rules ---------------------------------------------------------------------------- - --- | Γ ⊢ e ↑ A ⊣ Δ --- Under input context Γ, e checks against input type A, with output context ∆ -check :: Exp -> Type -> Tc (T.ExpT' Type) -check exp typ - - -- Γ,α ⊢ e ↑ A ⊣ Δ,α,Θ - -- ------------------- ∀I - -- Γ ⊢ e ↑ ∀α.A ⊣ Δ - | TAll tvar t <- typ = do - let env_tvar = EnvTVar tvar - insertEnv env_tvar - exp' <- check exp t - (env_l, _) <- gets (splitOn env_tvar . env) - putEnv env_l - pure exp' - - -- Γ,(x:A) ⊢ e ↑ B ⊢ Δ,(x:A),Θ - -- --------------------------- →I - -- Γ ⊢ λx.e ↑ A → B ⊣ Δ - | EAbs name e <- exp - , TFun t1 t2 <- typ = do - let env_var = EnvVar name t1 - insertEnv env_var - e' <- check e t2 - (env_l, _) <- gets (splitOn env_var . env) - putEnv env_l - pure (T.EAbs (coerce name) e', typ) - - | otherwise = subsumption - where - -- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ - -- -------------------------------------- Sub - -- Γ ⊢ e ↑ B ⊣ Δ - subsumption = do - (exp', t) <- infer exp - exp'' <- applyEnvExp exp' - t' <- applyEnv t - typ' <- applyEnv typ - subtype t' typ' - pure (exp'', t') - --- | Γ ⊢ e ↓ A ⊣ Δ --- Under input context Γ, e infers output type A, with output context ∆ -infer :: Exp -> Tc (T.ExpT' Type) -infer = \case - - ELit lit -> pure (T.ELit lit, litType lit) - - -- (x : A) ∈ Γ (x : A) ∉ Γ - -- ------------- Var --------------- Var' - -- Γ ⊢ x ↓ A ⊣ Γ Γ ⊢ x ↓ ά ⊣ Γ,ά - EVar name -> do - t <- liftA2 (<|>) (lookupEnv name) (lookupSig name) >>= \case - Just t -> pure t - Nothing -> do - tevar <- fresh - insertEnv (EnvTEVar tevar) - let t = TEVar tevar - insertEnv (EnvVar name t) - pure t - pure (T.EVar (coerce name), t) - - EInj name -> do - t <- maybeToRightM ("Unknown constructor: " ++ show name) =<< lookupInj name - pure (T.EInj $ coerce name, t) - - -- Γ ⊢ A Γ ⊢ e ↑ A ⊣ Δ - -- --------------------- Anno - -- Γ ⊢ (e : A) ↓ A ⊣ Δ - EAnn e t -> do - _ <- gets $ (`wellFormed` t) . env - (e', _) <- check e t - pure (e', t) - - -- Γ ⊢ e₁ ↓ A ⊣ Θ Γ ⊢ [Θ]A • ⇓ C ⊣ Δ - -- ----------------------------------- →E - -- Γ ⊢ e₁ e₂ ↓ C ⊣ Δ - EApp e1 e2 -> do - (e1', t) <- infer e1 - t' <- applyEnv t - e1'' <- applyEnvExp e1' - (e2', t'') <- apply t' e2 - pure (T.EApp (e1'', t) e2', t'') - - -- Γ,ά,έ,(x:ά) ⊢ e ↑ έ ⊣ Δ,(x:ά),Θ - -- ------------------------------- →I - -- Γ ⊢ λx.e ↓ ά → έ ⊣ Δ - EAbs name e -> do - tevar1 <- fresh - tevar2 <- fresh - insertEnv $ EnvTEVar tevar1 - insertEnv $ EnvTEVar tevar2 - let env_var = EnvVar name (TEVar tevar1) - insertEnv env_var - e' <- check e $ TEVar tevar2 - dropTrailing env_var - let t_exp = on TFun TEVar tevar1 tevar2 - pure (T.EAbs (coerce name) e', t_exp) - - - -- Γ ⊢ e ↓ A ⊣ Θ Θ,(x:A) ⊢ e' ↑ C ⊣ Δ,(x:A),Θ - -- -------------------------------------------- LetI - -- Γ ⊢ let x=e in e' ↑ C ⊣ Δ - ELet (Bind name [] rhs) e -> do -- TODO vars - (rhs', t_rhs) <- infer rhs - let env_var = EnvVar name t_rhs - insertEnv env_var - (e', t) <- infer e - (env_l, _) <- gets (splitOn env_var . env) - putEnv env_l - pure (T.ELet (T.Bind (coerce name, t_rhs) [] (rhs', t_rhs)) (e',t), t) - - -- Γ ⊢ e₁ ↑ Int ⊣ Θ Θ ⊢ e₂ ↑ Int - -- --------------------------- +I - -- Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ - EAdd e1 e2 -> do - e1' <- check e1 int - e2' <- check e2 int - e1'' <- applyEnvExpT e1' - e2'' <- applyEnvExpT e2' - pure (T.EAdd e1'' e2'', int) - - -- Θ ⊢ Π ∷ A ↓ C ⊣ Δ - -- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO - -- --------------------------------------- - -- Γ ⊢ case e of Π ↓ C ⊣ Δ - ECase scrut branches -> do - (scrut', t_scrut) <- infer scrut - (branches', t_return) <- inferBranches branches t_scrut - pure (T.ECase (scrut', t_scrut) branches', t_return) - --- | Γ ⊢ A • e ⇓ C ⊣ Δ --- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ --- Instantiate existential type variables until there is an arrow type. -apply :: Type -> Exp -> Tc (T.ExpT' Type, Type) -apply typ exp = case typ of - - -- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ - -- ------------------------ ∀App - -- Γ ⊢ ∀α.A • e ⇓ C ⊣ Δ - TAll tvar t -> do - tevar <- fresh - insertEnv $ EnvTEVar tevar - let t' = substitute tvar tevar t - apply t' exp - - -- Γ[ά₂,ά₁,(ά=ά₁→ά₂)] ⊢ e ↑ ά₁ ⊣ Δ - -- ------------------------------- άApp - -- Γ[ά] ⊢ ά • e ⇓ ά₂ ⊣ Δ - TEVar tevar -> do - tevar1 <- fresh - tevar2 <- fresh - let env_tevar1 = EnvTEVar tevar1 - env_tevar2 = EnvTEVar tevar2 - t_fun = on TFun TEVar tevar1 tevar2 - env_tevar_solved = EnvTEVarSolved tevar t_fun - (env_l, env_r) <- gets (splitOn (EnvTEVar tevar) . env) - putEnv $ - (env_l :|> env_tevar2 :|> env_tevar1 :|> env_tevar_solved) <> env_r - expT' <- check exp $ TEVar tevar1 - pure (expT', TEVar tevar2) - - -- Γ ⊢ e ↑ A ⊣ Δ - -- --------------------- →App - -- Γ ⊢ A → C • e ⇓ C ⊣ Δ - TFun t1 t2 -> do - expt' <- check exp t1 - pure (expt', t2) - - _ -> throwError ("Cannot apply type " ++ show typ ++ " with expression " ++ show exp) - ---------------------------------------------------------------------------- --- * Pattern matching ---------------------------------------------------------------------------- - --- Γ ⊢ p ⇒ e ∷ A ↓ B ⊣ Θ --- Θ ⊢ Π ∷ [Θ]A ↓ C ⊣ Δ --- [Δ]B <: C --- --------------------------- --- Γ ⊢ (p ⇒ e),Π ∷ A ↓ C ⊣ Δ -inferBranches :: [Branch] -> Type -> Tc ([T.Branch' Type], Type) -inferBranches branches t_patt = do - (branches', ts_exp) <- inferBranches' t_patt branches - ts_exp' <- mapM applyEnv ts_exp - let (monos, pols) = partition isMono ts_exp' - t_exp <- liftEither $ bodyType t_patt monos - mapM_ (subtype t_exp) pols - pure (branches', t_exp) - where - - bodyType :: Type -> [Type] -> Err Type - bodyType t_patt = \case - [] -> pure t_patt - [m] -> pure m - m:n:ms | m == n -> bodyType t_patt (n:ms) - | otherwise -> throwError $ unwords [ "Wrong return types: " - , ppT m, "≠", ppT n ] - - inferBranches' = go [] [] - where - go branches ts_exp t = \case - [] -> pure (branches, ts_exp) - b:bs -> do - (b', t_e) <- inferBranch b t - t' <- applyEnv t - go (snoc b' branches) (snoc t_e ts_exp) t' bs - --- Γ ⊢ p ↑ A ⊣ Θ Θ ⊢ e ↓ C ⊣ Δ --- ------------------------------- --- Γ ⊢ p ⇒ e ∷ A ↓ C ⊣ Δ -inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type) -inferBranch (Branch patt exp) t_patt = do - patt' <- checkPattern patt t_patt - (exp', t_exp) <- infer exp - pure (T.Branch patt' (exp', t_exp), t_exp) - -checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) -checkPattern patt t_patt = case patt of - - -- ------------------- - -- Γ ⊢ x ↑ A ⊣ Γ,(x:A) - PVar x -> do - insertEnv $ EnvVar x t_patt - pure (T.PVar (coerce x, t_patt), t_patt) - - -- ------------- - -- Γ ⊢ _ ↑ A ⊣ Γ - PCatch -> pure (T.PCatch, t_patt) - - -- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ - -- ------------------------------ - -- Γ ⊢ τ ↑ B ⊣ Δ - PLit lit -> do - subtype (litType lit) t_patt - t_patt' <- applyEnv t_patt - pure (T.PLit (lit, t_patt), t_patt') - - -- (x : A) ∈ Γ Γ ⊢ A <: B ⊣ Δ - -- --------------------------- - -- Γ ⊢ inj₀ x ↑ B ⊣ Δ - PEnum name -> do - t <- maybeToRightM ("Unknown constructor " ++ show name) - =<< lookupInj name - subtype t t_patt - t_patt' <- applyEnv t_patt - pure (T.PEnum (coerce name), t_patt') - - - PInj name ps -> do - t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name - t_inj' <- foldrM substitute' t_inj $ getInitForalls t_inj - subtype (getDataId t_inj') t_patt - t_inj'' <- applyEnv t_inj' - let ts_inj = getParams t_inj'' - ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps ts_inj - t_patt' <- applyEnv t_patt - pure (T.PInj (coerce name) (map fst ps'), t_patt') - where - substitute' fa t = do - tevar <- fresh - -- insertEnv (EnvTEVar tevar) - pure $ substitute tvar tevar t - where - TAll tvar _ = fa int - - getParams = \case - TAll _ t -> getParams t - t -> go [] t - where - go acc = \case - TFun t1 t2 -> go (snoc t1 acc) t2 - _ -> acc - - getDataId typ = case typ of - TAll _ t -> getDataId t - TFun _ t -> getDataId t - TData {} -> typ - - getInitForalls = go [] - where - go acc = \case - TAll tvar t -> go (snoc (TAll tvar) acc) t - _ -> acc --------------------------------------------------------------------------- -- * Auxiliary @@ -677,55 +656,6 @@ splitOn x env = second (S.drop 1) $ S.breakl (==x) env dropTrailing :: EnvElem -> Tc () dropTrailing x = modifyEnv $ S.takeWhileL (/= x) -applyEnvExpT :: (T.Exp' Type, Type) -> Tc (T.Exp' Type, Type) -applyEnvExpT (e, t) = liftA2 (,) (applyEnvExp e) (applyEnv t) - -applyEnvExp :: T.Exp' Type -> Tc (T.Exp' Type) -applyEnvExp exp = case exp of - T.ELet (T.Bind id vars rhs) exp -> do - id <- applyEnvId id - vars' <- mapM applyEnvId vars - rhs' <- applyEnvExpT rhs - exp' <- applyEnvExpT exp - pure $ T.ELet (T.Bind id vars' rhs') exp' - T.EApp e1 e2 -> liftA2 T.EApp (applyEnvExpT e1) (applyEnvExpT e2) - T.EAdd e1 e2 -> liftA2 T.EAdd (applyEnvExpT e1) (applyEnvExpT e2) - T.EAbs name e -> T.EAbs name <$> applyEnvExpT e - T.ECase e branches -> liftA2 T.ECase (applyEnvExpT e) - (mapM applyEnvBranch branches) - _ -> pure exp - where - applyEnvId = secondM applyEnv - applyEnvBranch (T.Branch (p, t) e) = do - pt <- liftA2 (,) (applyEnvPattern p) (applyEnv t) - e' <- applyEnvExpT e - pure $ T.Branch pt e' - applyEnvPattern = \case - T.PVar id -> T.PVar <$> applyEnvId id - T.PLit (lit, t) -> T.PLit . (lit, ) <$> applyEnv t - T.PInj name ps -> T.PInj name <$> mapM applyEnvPattern ps - p -> pure p - -applyEnv :: Type -> Tc Type -applyEnv t = gets $ (`applyEnv'` t) . env - --- | [Γ]A. Applies context to type until fully applied. -applyEnv' :: Env -> Type -> Type -applyEnv' cxt typ | typ == typ' = typ' - | otherwise = applyEnv' cxt typ' - where - typ' = case typ of - TLit _ -> typ - TData name typs -> TData name $ map (applyEnv' cxt) typs - -- [Γ]α = α - TVar _ -> typ - -- [Γ[ά=τ]]ά = [Γ[ά=τ]]τ - -- [Γ[ά]]ά = [Γ[ά]]ά - TEVar tevar -> fromMaybe typ $ findSolved tevar cxt - -- [Γ](A → B) = [Γ]A → [Γ]B - TFun t1 t2 -> on TFun (applyEnv' cxt) t1 t2 - -- [Γ](∀α. A) = (∀α. [Γ]A) - TAll tvar t -> TAll tvar $ applyEnv' cxt t findSolved :: TEVar -> Env -> Maybe Type findSolved _ Empty = Nothing @@ -765,27 +695,18 @@ wellFormed env = \case TData _ typs -> mapM_ (wellFormed env) typs -noForall :: Type -> Bool -noForall = \case - TAll{} -> False - TFun t1 t2 -> on (&&) noForall t1 t2 - TData _ typs -> all noForall typs - TVar _ -> True - TEVar _ -> True - TLit _ -> True - isMono :: Type -> Bool isMono = \case TAll{} -> False TFun t1 t2 -> on (&&) isMono t1 t2 TData _ typs -> all isMono typs - TVar _ -> False - TEVar _ -> False + TVar _ -> True + TEVar _ -> True TLit _ -> True fresh :: Tc TEVar fresh = do - tevar <- gets (MkTEVar . LIdent . ("a#" ++) . show . next_tevar) + tevar <- gets (MkTEVar . LIdent . show . next_tevar) modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar } pure tevar @@ -805,7 +726,6 @@ getReturn = snd . partitionType partitionType :: Type -> ([Type], Type) partitionType = go [] . skipForalls' where - go acc t = case t of TFun t1 t2 -> go (snoc t1 acc) t2 _ -> (acc, t) @@ -863,6 +783,74 @@ modifyEnv f = pattern DBind' name vars exp = DBind (Bind name vars exp) pattern DSig' name typ = DSig (Sig name typ) + +--------------------------------------------------------------------------- +-- * Apply +--------------------------------------------------------------------------- + +class Apply a where + apply :: a -> Tc a + +instance Apply Type where apply = applyType +instance Apply (T.Exp' Type) where apply = applyExp +instance Apply (T.Branch' Type) where apply = applyBranch +instance Apply (T.Pattern' Type) where apply = applyPattern +instance Apply a => Apply [a] where apply = mapM apply +instance (Apply a, Apply b) => Apply (a, b) where apply = applyPair +instance Apply T.Ident where apply = pure + +applyType :: Type -> Tc Type +applyType t = gets $ (`applyType'` t) . env + +-- | [Γ]A. Applies context to type until fully applied. +applyType' :: Env -> Type -> Type +applyType' cxt typ | typ == typ' = typ' + | otherwise = applyType' cxt typ' + where + typ' = case typ of + TLit _ -> typ + TData name typs -> TData name $ map (applyType' cxt) typs + -- [Γ]α = α + TVar _ -> typ + -- [Γ[ά=τ]]ά = [Γ[ά=τ]]τ + -- [Γ[ά]]ά = [Γ[ά]]ά + TEVar tevar -> fromMaybe typ $ findSolved tevar cxt + -- [Γ](A → B) = [Γ]A → [Γ]B + TFun t1 t2 -> on TFun (applyType' cxt) t1 t2 + -- [Γ](∀α. A) = (∀α. [Γ]A) + TAll tvar t -> TAll tvar $ applyType' cxt t + +applyExp :: T.Exp' Type -> Tc (T.Exp' Type) +applyExp exp = case exp of + T.ELet (T.Bind id vars rhs) exp -> do + id <- apply id + vars' <- mapM apply vars + rhs' <- apply rhs + exp' <- apply exp + pure $ T.ELet (T.Bind id vars' rhs') exp' + T.EApp e1 e2 -> liftA2 T.EApp (apply e1) (apply e2) + T.EAdd e1 e2 -> liftA2 T.EAdd (apply e1) (apply e2) + T.EAbs name e -> T.EAbs name <$> apply e + T.ECase e branches -> liftA2 T.ECase (apply e) + (mapM apply branches) + _ -> pure exp + +applyBranch :: T.Branch' Type -> Tc (T.Branch' Type) +applyBranch (T.Branch (p, t) e) = do + pt <- liftA2 (,) (apply p) (apply t) + e' <- apply e + pure $ T.Branch pt e' + +applyPattern :: T.Pattern' Type -> Tc (T.Pattern' Type) +applyPattern = \case + T.PVar id -> T.PVar <$> apply id + T.PLit (lit, t) -> T.PLit . (lit, ) <$> apply t + T.PInj name ps -> T.PInj name <$> apply ps + p -> pure p + +applyPair :: (Apply a, Apply b) => (a, b) -> Tc (a, b) +applyPair (x, y) = liftA2 (,) (apply x) (apply y) + --------------------------------------------------------------------------- -- * Debug --------------------------------------------------------------------------- @@ -873,24 +861,24 @@ traceEnv s = do traceD s x = trace (s ++ " " ++ show x) pure () -traceT s x = trace (s ++ " " ++ ppT x) pure () +traceT s x = trace (s ++ " : " ++ ppT x) pure () traceTs s xs = trace (s ++ " [ " ++ intercalate ", " (map ppT xs) ++ " ]") pure () ppT = \case TLit (UIdent s) -> s - TVar (MkTVar (LIdent s)) -> "a_" ++ s + TVar (MkTVar (LIdent s)) -> "tvar_" ++ s TFun t1 t2 -> ppT t1 ++ "->" ++ ppT t2 TAll (MkTVar (LIdent s)) t -> "forall " ++ s ++ ". " ++ ppT t - TEVar (MkTEVar (LIdent s)) -> "a^_" ++ s + TEVar (MkTEVar (LIdent s)) -> "tevar_" ++ s TData (UIdent name) typs -> name ++ " (" ++ unwords (map ppT typs) ++ " )" ppEnvElem = \case EnvVar (LIdent s) t -> s ++ ":" ++ ppT t - EnvTVar (MkTVar (LIdent s)) -> "a_" ++ s - EnvTEVar (MkTEVar (LIdent s)) -> "a^_" ++ s - EnvTEVarSolved (MkTEVar (LIdent s)) t -> "_" ++ s ++ "=" ++ ppT t - EnvMark (MkTEVar (LIdent s)) -> "▶" ++ "a^_" ++ s + EnvTVar (MkTVar (LIdent s)) -> "tvar_" ++ s + EnvTEVar (MkTEVar (LIdent s)) -> "tevar_" ++ s + EnvTEVarSolved (MkTEVar (LIdent s)) t -> "tevar_" ++ s ++ "=" ++ ppT t + EnvMark (MkTEVar (LIdent s)) -> "▶" ++ "tevar_" ++ s ppEnv = \case Empty -> "·" diff --git a/tests/TestTypeCheckerBidir.hs b/tests/TestTypeCheckerBidir.hs index 4cf98f2..916b688 100644 --- a/tests/TestTypeCheckerBidir.hs +++ b/tests/TestTypeCheckerBidir.hs @@ -260,6 +260,17 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do , " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs)" ] +tc_if = specify "Test if else case expression" $ do + run [ "data Bool () where" + , " True : Bool ()" + , " False : Bool ()" + + , "ifThenElse : Bool () -> a -> a -> a" + , "ifThenElse b if else = case b of" + , " True => if" + , " False => else" + ] `shouldSatisfy` ok + tc_infer_case = describe "Infer case expression" $ do specify "Wrong case expression rejected" $