diff --git a/src/TypeChecker/TypeCheckerBidir.hs b/src/TypeChecker/TypeCheckerBidir.hs index fb9e93d..9222755 100644 --- a/src/TypeChecker/TypeCheckerBidir.hs +++ b/src/TypeChecker/TypeCheckerBidir.hs @@ -4,19 +4,18 @@ {-# LANGUAGE PatternSynonyms #-} {-# OPTIONS_GHC -Wno-incomplete-patterns #-} -module TypeChecker.TypeCheckerBidir (typecheck, getVars) where +module TypeChecker.TypeCheckerBidir (typecheck) where -import Auxiliary (int, liftMM2, litType, - maybeToRightM, onM, onMM, snoc) -import Control.Applicative (Alternative, Applicative (liftA2), - (<|>)) +import Auxiliary (int, litType, maybeToRightM, snoc) +import Control.Applicative (Applicative (liftA2), (<|>)) import Control.Monad.Except (ExceptT, MonadError (throwError), runExceptT, unless, zipWithM, zipWithM_) -import Control.Monad.Extra (fromMaybeM, maybeM) +import Control.Monad.Extra (fromMaybeM) import Control.Monad.State (MonadState, State, evalState, gets, modify) import Data.Coerce (coerce) +import Data.Foldable (foldlM) import Data.Function (on) import Data.List (intercalate) import Data.Map (Map) @@ -38,7 +37,8 @@ import qualified TypeChecker.TypeCheckerIr as T -- -- TODO -- • Fix problems with types in Pattern/Branch in TypeCheckerIr --- • Fix the different type getters functions (e.g. partitionType) functions +-- • Remove EAdd +-- • Add kinds!! data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A | EnvTVar TVar -- ^ Universal type variable. α @@ -140,7 +140,7 @@ typecheckInj :: Inj -> UIdent -> [TVar] -> Err (T.Inj' Type) typecheckInj (Inj inj_name inj_typ) name tvars | not $ boundTVars tvars inj_typ = throwError "Unbound type variables" - | TData name' typs <- getReturn inj_typ + | TData name' typs <- getDataId inj_typ , name' == name , Right tvars' <- mapM toTVar typs , all (`elem` tvars) tvars' @@ -149,7 +149,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars = throwError $ unwords ["Bad type constructor: ", show name , "\nExpected: ", ppT . TData name $ map TVar tvars - , "\nActual: ", ppT $ getReturn inj_typ + , "\nActual: ", ppT $ getDataId inj_typ ] where boundTVars :: [TVar] -> Type -> Bool @@ -161,6 +161,8 @@ typecheckInj (Inj inj_name inj_typ) name tvars TLit _ -> True TEVar _ -> error "TEVar in data type declaration" + + --------------------------------------------------------------------------- -- * Typing rules --------------------------------------------------------------------------- @@ -200,10 +202,72 @@ check e b = do subtype a b' apply (e', b) +checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) +checkPattern patt t_patt = case patt of + + -- ------------------- + -- Γ ⊢ x ↑ A ⊣ Γ,(x:A) + PVar x -> do + insertEnv $ EnvVar x t_patt + apply (T.PVar (coerce x, t_patt), t_patt) + + -- ------------- + -- Γ ⊢ _ ↑ A ⊣ Γ + PCatch -> apply (T.PCatch, t_patt) + + -- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ + -- ------------------------------ + -- Γ ⊢ τ ↑ B ⊣ Δ + PLit lit -> do + subtype (litType lit) t_patt + apply (T.PLit (lit, t_patt), t_patt) + + -- Γ ∋ (K : A) Γ ⊢ A <: B ⊣ Δ + -- --------------------------- + -- Γ ⊢ K ↑ B ⊣ Δ + PEnum name -> do + t <- maybeToRightM ("Unknown constructor " ++ show name) + =<< lookupInj name + subtype t t_patt + apply (T.PEnum (coerce name), t_patt) + + -- Example + -- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs + -- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁ + -- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂ + -- Θ₂ ⊢ p₂ ↑ [Θ₂][ά/α]A₂ ⊣ Δ + -- --------------------------- + -- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ + PInj name ps -> do + t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name + let ts = getArgs t_inj + unless (length ts == length ps) + $ throwError "Wrong number of arguments!" + + -- [ά/α] + sub <- substituteTVarsOf t_inj + subtype (sub $ getDataId t_inj) t_patt + let check p t = checkPattern p =<< apply (sub t) + ps' <- zipWithM check ps ts + apply (T.PInj (coerce name) (map fst ps'), t_patt) + where + substituteTVarsOf = \case + TAll tvar t -> do + tevar <- fresh + (substitute tvar tevar .) <$> substituteTVarsOf t + _ -> pure id + + getArgs = \case + TAll _ t -> getArgs t + t -> go [] t + where + go acc = \case + TFun t1 t2 -> go (snoc t1 acc) t2 + _ -> acc + -- | Γ ⊢ e ↓ A ⊣ Δ -- Under input context Γ, e infers output type A, with output context ∆ infer :: Exp -> Tc (T.ExpT' Type) - infer (ELit lit) = apply (T.ELit lit, litType lit) -- Γ ∋ (x : A) Γ ∌ (x : A) @@ -273,14 +337,23 @@ infer (EAdd e1 e2) = do e2' <- check e2 int apply (T.EAdd e1' e2', int) --- Θ ⊢ Π ∷ A ↓ C ⊣ Δ --- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO --- --------------------------------------- +-- Γ ⊢ e ↑ A ⊣ Θ Θ ⊢ Π ∷ [Θ]A ↑ C ⊣ Δ +-- ------------------------------------ Case -- Γ ⊢ case e of Π ↓ C ⊣ Δ -infer (ECase scrut branches) = do - (scrut', t_scrut) <- infer scrut - (branches', t_return) <- inferBranches branches t_scrut - apply (T.ECase (scrut', t_scrut) branches', t_return) +infer (ECase scrut pi) = do + (scrut', a) <- infer scrut + case pi of + [] -> apply (T.ECase (scrut', a) [], a) + (Branch _ e):_ -> do + (_, c)<- infer e + (pi', c') <- foldlM go ([], c) pi + apply (T.ECase (scrut', a) pi', c') + where + go (bs, c) (Branch p e) = do + p' <- checkPattern p =<< apply a + e'@(_, c') <- infer e + subtype c' c + apply (T.Branch p' e' : bs, c') -- | Γ ⊢ A • e ⇓ C ⊣ Δ -- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ @@ -319,112 +392,6 @@ applyInfer (TFun a c) e = do applyInfer a e = throwError ("Cannot apply type " ++ show a ++ " with expression " ++ show e) ---------------------------------------------------------------------------- --- * Pattern matching ---------------------------------------------------------------------------- - --- Γ ⊢ p ⇒ e ∷ A ↓ B ⊣ Θ --- Θ ⊢ Π ∷ [Θ]A ↓ C ⊣ Δ --- [Δ]B <: C --- --------------------------- --- Γ ⊢ (p ⇒ e),Π ∷ A ↓ C ⊣ Δ -inferBranches :: [Branch] -> Type -> Tc ([T.Branch' Type], Type) -inferBranches branches t_patt = do - (branches', ts_exp) <- inferBranches' t_patt branches - t_exp <- case ts_exp of - [] -> pure t_patt - t:_ -> do - zipWithM_ (onMM subtype apply) (init ts_exp) (tail ts_exp) - apply t - apply (branches', t_exp) - where - - inferBranches' = go [] [] - where - go branches ts_exp t = \case - [] -> pure (branches, ts_exp) - b:bs -> do - (b', t_e) <- inferBranch b t - t' <- apply t - go (snoc b' branches) (snoc t_e ts_exp) t' bs - --- Γ ⊢ p ↑ A ⊣ Θ Θ ⊢ e ↓ C ⊣ Δ --- ------------------------------- --- Γ ⊢ p ⇒ e ∷ A ↓ C ⊣ Δ -inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type) -inferBranch (Branch patt exp) t_patt = do - patt' <- checkPattern patt t_patt - (exp', t_exp) <- infer exp - apply (T.Branch patt' (exp', t_exp), t_exp) - -checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) -checkPattern patt t_patt = case patt of - - -- ------------------- - -- Γ ⊢ x ↑ A ⊣ Γ,(x:A) - PVar x -> do - insertEnv $ EnvVar x t_patt - apply (T.PVar (coerce x, t_patt), t_patt) - - -- ------------- - -- Γ ⊢ _ ↑ A ⊣ Γ - PCatch -> apply (T.PCatch, t_patt) - - -- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ - -- ------------------------------ - -- Γ ⊢ τ ↑ B ⊣ Δ - PLit lit -> do - subtype (litType lit) t_patt - apply (T.PLit (lit, t_patt), t_patt) - - -- Γ ∋ (K : A) Γ ⊢ A <: B ⊣ Δ - -- --------------------------- - -- Γ ⊢ K ↑ B ⊣ Δ - PEnum name -> do - t <- maybeToRightM ("Unknown constructor " ++ show name) - =<< lookupInj name - subtype t t_patt - apply (T.PEnum (coerce name), t_patt) - - - -- Example - -- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs - -- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁ - -- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂ - -- Θ ⊢ p₂ ↑ [Θ][ά/α]A₂ ⊣ Δ - -- --------------------------- - -- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ - PInj name ps -> do - t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name - let ts = getParams t_inj - unless (length ts == length ps) $ - throwError "Wrong number of arguments!" - sub <- substituteTVarsOf t_inj - subtype (sub $ getDataId t_inj) t_patt - let checkP p t = checkPattern p =<< apply (sub t) - ps' <- zipWithM checkP ps ts - apply (T.PInj (coerce name) (map fst ps'), t_patt) - where - substituteTVarsOf = \case - TAll tvar t -> do - tevar <- fresh - (substitute tvar tevar .) <$> substituteTVarsOf t - _ -> pure id - - getParams = \case - TAll _ t -> getParams t - t -> go [] t - where - go acc = \case - TFun t1 t2 -> go (snoc t1 acc) t2 - _ -> acc - - getDataId typ = case typ of - TAll _ t -> getDataId t - TFun _ t -> getDataId t - TData {} -> typ - - --------------------------------------------------------------------------- -- * Subtyping rules --------------------------------------------------------------------------- @@ -482,7 +449,6 @@ subtype (TEVar alpha) a | notElem alpha $ frees a = instantiateL alpha a -- Γ[ά] ⊢ A <: ά ⊣ Δ subtype a (TEVar alpha) | notElem alpha $ frees a = instantiateR a alpha - subtype t1 t2 = case (t1, t2) of (TData name1 typs1, TData name2 typs2) @@ -564,14 +530,13 @@ instantiateL alpha a = gets env >>= \env -> go env alpha a putEnv env_l go _ alpha a = error $ "Trying to instantiateL: " ++ ppT (TEVar alpha) - ++ " <: " ++ ppT a + ++ " <: " ++ ppT a -- | Γ ⊢ A =:< ά ⊣ Δ -- Under input context Γ, instantiate ά such that A <: ά, with output context ∆ instantiateR :: Type -> TEVar -> Tc () instantiateR a alpha = gets env >>= \env -> go env a alpha where - -- Γ ⊢ τ -- ----------------------------- InstRSolve -- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ' @@ -588,11 +553,9 @@ instantiateR a alpha = gets env >>= \env -> go env a alpha let (env_l, env_r) = splitOn (EnvTEVar epsilon) env putEnv $ (env_l :|> EnvTEVarSolved epsilon (TEVar alpha)) <> env_r - - - -- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ - -- ------------------------------------------------------- InstRArr - -- Γ[ά] ⊢ A₁ → A₂ =:< ά ⊣ Δ + -- Γ[ά₂ά₁,(ά=ά₁→ά₂)] ⊢ A₁ :=< ά₁ ⊣ Θ Θ ⊢ ά₂ =:< [Θ]A₂ ⊣ Δ + -- ------------------------------------------------------- InstRArr + -- Γ[ά] ⊢ A₁ → A₂ =:< ά ⊣ Δ go _ (TFun a1 a2) alpha = do alpha1 <- fresh alpha2 <- fresh @@ -603,24 +566,19 @@ instantiateR a alpha = gets env >>= \env -> go env a alpha a2' <- apply a2 instantiateR a2' alpha2 - - - -- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ' - -- ---------------------------------- InstRAIIL - -- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ + -- Γ[ά],▶έ,ε ⊢ [έ/ε]E =:< ά ⊣ Δ,▶έ,Δ' + -- ---------------------------------- InstRAIIL + -- Γ[ά] ⊢ ∀ε.Ε =:< ά ⊣ Δ go env (TAll epsilon e) alpha = do - epsilon' <- fresh - insertEnv $ EnvMark epsilon' - insertEnv $ EnvTVar epsilon - instantiateR (substitute epsilon epsilon' e) alpha - let (env_l, _) = splitOn (EnvMark epsilon') env - putEnv env_l + epsilon' <- fresh + insertEnv $ EnvMark epsilon' + insertEnv $ EnvTVar epsilon + instantiateR (substitute epsilon epsilon' e) alpha + let (env_l, _) = splitOn (EnvMark epsilon') env + putEnv env_l go _ a alpha = error $ "Trying to instantiateR: " ++ ppT a ++ " <: " - ++ ppT (TEVar alpha) - - - + ++ ppT (TEVar alpha) --------------------------------------------------------------------------- -- * Auxiliary @@ -713,35 +671,6 @@ fresh = do modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar } pure tevar -getVars :: Type -> [Type] -getVars = fst . partitionType - -getReturn :: Type -> Type -getReturn = snd . partitionType - --- | Partion type into variable types and return type. --- --- ∀a.∀b. a → (∀c. c → c) → b --- ([a, ∀c. c → c], b) --- --- Unsure if foralls should be added to the return type or not. --- FIXME -partitionType :: Type -> ([Type], Type) -partitionType = go [] . skipForalls' - where - go acc t = case t of - TFun t1 t2 -> go (snoc t1 acc) t2 - _ -> (acc, t) - -skipForalls' :: Type -> Type -skipForalls' = snd . skipForalls - -skipForalls :: Type -> ([Type -> Type], Type) -skipForalls = go [] - where - go acc typ = case typ of - TAll tvar t -> go (snoc (TAll tvar) acc) t - _ -> (acc, typ) isComplete :: Env -> Bool isComplete = isNothing . S.findIndexL unSolvedTEVar @@ -750,6 +679,12 @@ isComplete = isNothing . S.findIndexL unSolvedTEVar EnvTEVar _ -> True _ -> False +getDataId :: Type -> Type +getDataId typ = case typ of + TAll _ t -> getDataId t + TFun _ t -> getDataId t + TData {} -> typ + toTVar :: Type -> Err TVar toTVar = \case TVar tvar -> pure tvar @@ -764,7 +699,6 @@ lookupSig x = gets (Map.lookup x . sig) insertSig :: LIdent -> Type -> Tc () insertSig name t = modify $ \cxt -> cxt { sig = Map.insert name t cxt.sig } - lookupEnv :: LIdent -> Tc (Maybe Type) lookupEnv x = gets (findId . env) where @@ -786,7 +720,6 @@ modifyEnv f = pattern DBind' name vars exp = DBind (Bind name vars exp) pattern DSig' name typ = DSig (Sig name typ) - --------------------------------------------------------------------------- -- * Apply ---------------------------------------------------------------------------