From 5e15983f4c36d2087b446ddc1f29c3bfe846f278 Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Mon, 15 May 2023 00:31:30 +0200 Subject: [PATCH] Revork type checking of data types to make in reliable --- src/TypeChecker/TypeCheckerBidir.hs | 357 ++++++++++++---------------- src/TypeChecker/TypeCheckerIr.hs | 2 +- 2 files changed, 151 insertions(+), 208 deletions(-) diff --git a/src/TypeChecker/TypeCheckerBidir.hs b/src/TypeChecker/TypeCheckerBidir.hs index 3745b0d..9b4765f 100644 --- a/src/TypeChecker/TypeCheckerBidir.hs +++ b/src/TypeChecker/TypeCheckerBidir.hs @@ -7,17 +7,18 @@ module TypeChecker.TypeCheckerBidir (typecheck) where -import Auxiliary (int, maybeToRightM, onM, snoc, - typeof) +import Auxiliary (int, mapAccumM, maybeToRightM, onM, + snoc, typeof) import Control.Applicative (Applicative (liftA2), liftA3, (<|>)) import Control.Monad.Except (ExceptT, MonadError (throwError), - MonadTrans (lift), forM, runExceptT, - unless, zipWithM, zipWithM_) -import Control.Monad.Extra (fromMaybeM, ifM) + MonadTrans (lift), foldM, forM, + forM_, runExceptT, unless, zipWithM, + zipWithM_) +import Control.Monad.Extra (fromMaybeM, ifM, maybeM, unlessM) import Control.Monad.State (MonadState (get), State, StateT, evalState, evalStateT, gets, modify) import Data.Coerce (coerce) -import Data.Foldable (foldlM) +import Data.Foldable (foldlM, foldrM) import Data.Function (on) import Data.List (intercalate) import Data.Map (Map) @@ -25,8 +26,9 @@ import qualified Data.Map as Map import Data.Maybe (fromMaybe, isNothing) import Data.Sequence (Seq (..)) import qualified Data.Sequence as S +import Data.Set (Set) import qualified Data.Set as Set -import Data.Tuple.Extra (second) +import Data.Tuple.Extra (first, second) import Debug.Trace (trace) import Grammar.Abs import Grammar.ErrM @@ -59,14 +61,13 @@ data Cxt = Cxt , sig :: Map LIdent Type -- ^ Top-level signatures x : A , binds :: Map LIdent Exp -- ^ Top-level binds x : e , next_tevar :: Int -- ^ Counter to distinguish ά - , data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A + , data_injs :: Map UIdent (Type, [TVar]) -- ^ Data injections (constructors) K : A , currentBind :: LIdent -- ^ Used for recursive functions } deriving (Show, Eq) newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a } deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String) - initCxt :: [Def] -> Cxt initCxt defs = Cxt { env = mempty @@ -77,14 +78,14 @@ initCxt defs = Cxt | DBind' name vars rhs <- defs ] , next_tevar = 0 - , data_injs = Map.fromList [ (name, foldr TAll t $ unboundedTVars t) + , data_injs = Map.fromList [ (name, (t, unboundedTVars t)) | DData (Data _ injs) <- defs , Inj name t <- injs ] , currentBind = "" } where - unboundedTVars = uncurry (Set.\\) . go (mempty, mempty) + unboundedTVars = Set.toList . uncurry (Set.\\) . go (mempty, mempty) where go (unbounded, bounded) = \case TAll tvar t -> go (unbounded, Set.insert tvar bounded) t @@ -108,16 +109,17 @@ typecheckBinds cxt = flip evalState cxt typecheckBind :: Bind -> Tc (T.Bind' Type) typecheckBind (Bind name vars rhs) = do modify $ \cxt -> cxt { currentBind = name } - bind'@(T.Bind (name, typ) _ _) <- lookupSig name >>= \case + bind' <- lookupSig name >>= \case Just t -> do + insertSig (coerce name) t (rhs', _) <- check (foldr EAbs rhs vars) t pure (T.Bind (coerce name, t) [] (rhs', t)) Nothing -> do (e, t) <- apply =<< infer (foldr EAbs rhs vars) + insertSig (coerce name) t pure (T.Bind (coerce name, t) [] (e, t)) env <- gets env unless (isComplete env) err - insertSig (coerce name) typ putEnv Empty pure bind' where @@ -166,8 +168,6 @@ typecheckInj (Inj inj_name inj_typ) name tvars TLit _ -> True TEVar _ -> error "TEVar in data type declaration" - - --------------------------------------------------------------------------- -- * Typing rules --------------------------------------------------------------------------- @@ -199,26 +199,30 @@ check (EAbs x e) (TFun a b) = do putEnv env_l apply (T.EAbs (coerce x) e', TFun a b) - -- Γ ⊢ e ↑ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]C ⊣ Δ -- ----------------------------------- CaseEmpty -- Γ ⊢ case e of {} ↓ C ⊣ Δ --- Θ₁ ⊢ p₁⇒e₁ ↓ [Θ₁]C ⊣ Θ₂ +-- Θ₁ ⊢ p₁⇒e₁ ↑ [Θ₁]C ⊣ Θ₂ -- ... --- Γ ⊢ e ↑ A ⊣ Θ₁ Θₙ ⊢ pₙ⇒eₙ ↓ [Θₙ]C ⊣ Δ +-- Γ ⊢ e ↑ A ⊣ Θ₁ Θₙ ⊢ pₙ⇒eₙ ↑ [Θₙ]C ⊣ Δ -- --------------------------------------- Case --- Γ ⊢ case e of {p₁⇒e₁ ‥ pₙ⇒eₙ} ↓ C ⊣ Δ -check (ECase scrut branches) c = do - (scrut', a) <- infer scrut +-- Γ ⊢ case e of {p₁⇒e₁ ‥ pₙ⇒eₙ} ↑ C ⊣ Δ +check (ECase e branches) c = do + (e', a) <- infer e case branches of [] -> do - subtype a c - apply (T.ECase (scrut', a) [], a) + subtype a =<< apply c + apply (T.ECase (e', a) [], a) _ -> do - branches' <- checkBranches branches a c - apply (T.ECase (scrut', a) branches', c) - + branches' <- evalStateT (checkBranches a) mempty + apply (T.ECase (e', a) branches', c) + where + checkBranches a = forM branches $ \(Branch p e) -> do + p' <- match p =<< lift (apply a) + lift $ do + e' <- check e c + apply (T.Branch p' e') -- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ -- -------------------------------------- Sub @@ -229,134 +233,6 @@ check e b = do subtype a b' apply (e', b) - --- Γ ⊢ p₁⇒e₁ ‥ pₙ⇒eₙ :: A ↑ C ⊣ Δ -checkBranches :: [Branch] -> Type -> Type -> Tc [T.Branch' Type] -checkBranches branches a c = evalStateT go mempty - where - go = forM branches $ \(Branch p e) -> do - p' <- patternMatch p =<< lift (apply a) - lift $ do - e' <- check e c - apply (T.Branch p' e') - - - -substituteAll :: Type -> StateT (Map TVar TEVar) Tc Type -substituteAll t = get >>= \subs -> case t of - TAll tvar t - | Just tevar <- Map.lookup tvar subs -> - lift . pure $ substitute tvar tevar t - - | otherwise -> do - tevar <- lift fresh - modify (Map.insert tvar tevar) - substituteAll (substitute tvar tevar t) - TFun t1 t2 -> onM TFun substituteAll t1 t2 - t -> pure t - - - -patternMatch :: Pattern -> Type -> StateT (Map TVar TEVar) Tc (T' T.Pattern' Type) - --- ------------------- PVar --- Γ ⊢ x :: A ⊣ Γ,(x:A) -patternMatch (PVar x) a = lift $ do - insertEnv $ EnvVar x a - apply (T.PVar (coerce x), a) - - - --- ------------- PCatch --- Γ ⊢ _ :: A ⊣ Γ -patternMatch PCatch a = lift $ apply (T.PCatch, a) - --- Γ ⊢ typeof(lit) <: A ⊣ Δ --- ------------------------- PLit --- Γ ⊢ lit :: A ⊣ Δ -patternMatch (PLit lit) a = lift $ do - subtype (typeof lit) a - apply (T.PLit lit, typeof lit) - --- Γ ∋ (K : A) Γ ⊢ A <: C ⊣ Δ --- --------------------------- --- Γ ⊢ K :: C ⊣ Δ - -patternMatch (PEnum k) b = do - a <- lift (maybeToRightM ("Unknown constructor " ++ show k) =<< lookupInj k) - a <- substituteAll a - lift $ do - subtype a b - apply (T.PEnum (coerce k), a) - - --- β α Γ Γ Δ --- --- Γ ∋ (K : A) Θ₂ ⊢ p₁ :: [Θ₁]A₁ ⊣ Θ₂ --- Γ ⊢ ∀ά₁‥άₘ A₁ → ‥ → Aₙ₊₁ = substituteAll(A) ⊣ Θ₁ ... --- Θ₁ ⊢ Aₙ₊₁ <: B ⊣ Θ₂ Β Θₙ₊₁ ⊢ pₙ :: [Θₙ₊₁]Aₙ ⊣ Δ --- ----------------------------------------------------------------------------- PInj --- Γ ⊢ K p₁‥pₙ ↑ B ⊣ Δ -patternMatch (PInj k ps) b = do - a <- maybeToRightM ("Unknown constructor " ++ show k) =<< lift (lookupInj k) - a <- substituteAll a - lift $ subtype (getDataId a) b - - let as = getArgs a - unless (length as == length ps) $ throwError "Wrong number of arguments!" - - ps' <- zipWithM (\p a -> patternMatch p =<< lift (apply a)) ps as - lift $ apply (T.PInj (coerce k) ps', a) - where - - getArgs = \case - TAll _ t -> getArgs t - t -> go [] t - where - go acc = \case - TFun t1 t2 -> go (snoc t1 acc) t2 - _ -> acc - - - -- Example - -- Γ ∋ (K : A) let A = ∀α. A₁ -> A₂ -> Tτs - -- Γ ⊢ [ά/α]Tτs <: B ⊣ Θ₁ - -- Θ ⊢ p₁ ↑ [Θ][ά/α]A₁ ⊣ Θ₂ - -- Θ₂ ⊢ p₂ ↑ [Θ₂][ά/α]A₂ ⊣ Δ - -- --------------------------- - -- Γ ⊢ K p₁ p₂ ↑ B ⊣ Δ --- patternMatch (PInj name ps) t_patt = do --- t_inj <- maybeToRightM "unknown constructor" =<< lookupInj name --- let ts = getArgs t_inj --- unless (length ts == length ps) --- $ throwError "Wrong number of arguments!" --- --- -- [ά/α] --- sub <- substituteTVarsOf t_inj --- subtype (sub $ getDataId t_inj) t_patt --- let check p t = checkPattern p =<< apply (sub t) --- ps' <- zipWithM check ps ts --- apply (T.PInj (coerce name) ps', t_patt) --- where --- substituteTVarsOf = \case --- TAll tvar t -> do --- tevar <- fresh --- (substitute tvar tevar .) <$> substituteTVarsOf t --- _ -> pure id --- --- getArgs = \case --- TAll _ t -> getArgs t --- t -> go [] t --- where --- go acc = \case --- TFun t1 t2 -> go (snoc t1 acc) t2 --- _ -> acc - - - - - - -- | Γ ⊢ e ↓ A ⊣ Δ -- Under input context Γ, e infers output type A, with output context ∆ infer :: Exp -> Tc (T' T.Exp' Type) @@ -366,23 +242,27 @@ infer (ELit lit) = apply (T.ELit lit, typeof lit) -- ------------- Var --------------------- VarRec -- Γ ⊢ x ↓ A ⊣ Γ Γ ⊢ x ↓ ά ⊣ Γ,(x : ά) infer (EVar x) = do - a <- ifM (gets $ (x==) . currentBind) varRec var + a <- fromMaybeM varRec var apply (T.EVar (coerce x), a) where - var = maybeToRightM "Can't infer" =<< - liftA2 (<|>) (lookupEnv x) (lookupSig x) + var = liftA2 (<|>) (lookupEnv x) (lookupSig x) varRec = do + unlessM (gets $ (x==) . currentBind) $ + throwError ("Can't infer " ++ show x) alpha <- TEVar <$> fresh insertEnv (EnvVar x alpha) pure alpha --- Γ ∋ (k : A) --- ------------- Inj --- Γ ⊢ k ↓ A ⊣ Γ +-- TODO infer (EInj k) = do - t <- maybeToRightM ("Unknown constructor: " ++ show k) - =<< lookupInj k - apply (T.EInj $ coerce k, t) + (t, as) <- maybeToRightM ("Unknown constructor here2: " ++ show k) + =<< lookupInj k + t' <- foldM go t as + apply (T.EInj $ coerce k, t') + where + go t a = do + a' <- fresh + pure $ substitute a a' t -- Γ ⊢ A Γ ⊢ e ↑ A ⊣ Δ -- --------------------- Anno @@ -414,17 +294,22 @@ infer (EAbs name e) = do dropTrailing env_var apply (T.EAbs (coerce name) e', on TFun TEVar alpha epsilon) --- Γ ⊢ rhs ↓ A ⊣ Θ Θ,(x:A) ⊢ e ↑ C ⊣ Δ,(x:A),Θ +-- Γ ⊢ λys.rhs ↓ A ⊣ Θ Θ,(x:A) ⊢ e ↑ C ⊣ Δ,(x:A),Θ -- -------------------------------------------- LetI --- Γ ⊢ let x = rhs in e ↑ C ⊣ Δ +-- Γ ⊢ let x ys = rhs in e' ↑ C ⊣ Δ infer (ELet (Bind x vars rhs) e) = do (rhs', a) <- infer $ foldr EAbs rhs vars - let env_var = EnvVar x a + let (a_ret, a_vars) = go vars a + env_var = EnvVar x a insertEnv env_var e'@(_, c) <- infer e (env_l, _) <- gets (splitOn env_var . env) putEnv env_l - apply (T.ELet (T.Bind (coerce x, a) [] (rhs', a)) e', c) + apply (T.ELet (T.Bind (coerce x, a) a_vars (rhs', a_ret)) e', c) + where + go [] t = (t, []) + go (x:xs) (TFun t1 t2) = second (snoc (coerce x, t1)) $ go xs t2 + go _ _ = error "IMPOSSIBLE" -- Γ ⊢ e₁ ↑ Int ⊣ Θ Θ ⊢ e₂ ↑ Int -- --------------------------- +I @@ -434,30 +319,21 @@ infer (EAdd e1 e2) = do e2' <- check e2 int apply (T.EAdd e1' e2', int) +-- Experimental inferece for case +infer (ECase e branches) = do + (e', a) <- infer e + case branches of + [] -> apply (T.ECase (e', a) [], a) + _ -> do + let inferBranch bs (Branch p e) = do + p' <- match p =<< lift (apply a) + lift $ do + e'@(_, b) <- infer e + mapM_ (subtype b) bs + apply (b:bs, T.Branch p' e') --- Γ ⊢ e ↑ A ⊣ Δ --- ------------------------ CaseEmpty↓ --- Γ ⊢ case e of {} ↑ A ⊣ Δ - --- Θ₁ ⊢ p₁⇒e₁ ↓ [Θ₁]C ⊣ Θ₂ --- ... --- Γ ⊢ e ↑ A ⊣ Θ₁ Θₙ ⊢ pₙ⇒eₙ ↓ [Θₙ]C ⊣ Δ --- --------------------------------------- Case↓ --- Γ ⊢ case e of {p₁⇒e₁ ‥ pₙ⇒eₙ} ↓ C ⊣ Δ --- infer (ECase scrut branches) = do --- (scrut', a) <- infer scrut --- case branches of --- [] -> apply (T.ECase (scrut', a) [], a) --- (Branch _ e):_ -> do --- (_, b)<- infer e --- (branches', b') <- foldlM go ([], b) branches --- apply (T.ECase (scrut', a) branches', b') --- where --- go (pi, b) (Branch p e) = do --- p' <- checkPattern p =<< apply a --- e'@(_, b') <- infer e --- subtype b' b --- apply (T.Branch p' e' : pi, b') + (bs, branches') <- evalStateT (mapAccumM inferBranch [] branches) mempty + apply (T.ECase (e', a) branches', head bs) -- | Γ ⊢ A • e ⇓ C ⊣ Δ -- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ @@ -643,16 +519,15 @@ instantiateL alpha a = gets env >>= \env -> go env alpha a instantiateR :: Type -> TEVar -> Tc () instantiateR a alpha = gets env >>= \env -> go env a alpha where - -- Γ ⊢ τ - -- ----------------------------- InstRSolve - -- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ' + -- Γ ⊢ τ + -- ----------------------------- InstRSolve + -- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ' go env tau alpha | isMono tau , (env_l, env_r) <- splitOn (EnvTEVar alpha) env , Right _ <- wellFormed env_l tau = putEnv $ (env_l :|> EnvTEVarSolved alpha tau) <> env_r - -- -- ----------------------------- InstRReach -- Γ[ά][έ] ⊢ έ =:< ά ⊣ Γ[ά][έ=ά] go env (TEVar epsilon) alpha = do @@ -686,6 +561,75 @@ instantiateR a alpha = gets env >>= \env -> go env a alpha go _ a alpha = throwError $ "Trying to instantiateR: " ++ ppT a ++ " <: " ++ ppT (TEVar alpha) +--------------------------------------------------------------------------- +-- * Pattern Matching +--------------------------------------------------------------------------- + +substituteP :: TVar -> Type -> StateT (Map TVar TEVar) Tc Type +substituteP alpha a = go =<< get + where + go subs + | Just alpha' <- Map.lookup alpha subs + = lift $ pure $ substitute alpha alpha' a + | otherwise + = do + alpha' <- lift fresh + modify (Map.insert alpha alpha') + lift $ pure $ substitute alpha alpha' a + + +match :: Pattern -> Type -> StateT (Map TVar TEVar) Tc (T' T.Pattern' Type) + +-- ------------------- PVar +-- Γ ⊢ x :: A ⊣ Γ,(x:A) +match (PVar x) a = lift $ do + insertEnv $ EnvVar x a + apply (T.PVar (coerce x), a) + +-- ------------- PCatch +-- Γ ⊢ _ :: A ⊣ Γ +match PCatch a = lift $ apply (T.PCatch, a) + +-- Γ ⊢ typeof(lit) <: A ⊣ Δ +-- ------------------------- PLit +-- Γ ⊢ lit :: A ⊣ Δ +match (PLit lit) a = lift $ do + subtype (typeof lit) a + apply (T.PLit lit, typeof lit) + +match (PEnum k) b = do + (a, tvars) <- lift (maybeToRightM ("Unknown constructor: " ++ show k) =<< lookupInj k) + a' <- foldrM substituteP a tvars + lift $ do + subtype a' b + apply (T.PEnum (coerce k), a') + +match (PInj k ps) b = do + (a, tvars) <- lift (maybeToRightM ("Unknown constructor: " ++ show k) =<< lookupInj k) + a' <- foldrM substituteP a tvars + + let t_return = getDataId a' + + lift $ subtype t_return b + a'' <- lift $ apply a' + + let as = getArgs a'' + unless (length as == length ps) $ throwError "Wrong number of arguments!" + + ps' <- zipWithM matchArgs ps as + lift $ apply (T.PInj (coerce k) ps', t_return) + where + matchArgs p@(PVar _) a = match p =<< lift (apply a) + matchArgs _ _ = throwError "Nested pattern matching not supported!" + + getArgs = \case + TAll _ t -> getArgs t + t -> go [] t + where + go acc = \case + TFun t1 t2 -> go (snoc t1 acc) t2 + _ -> acc + --------------------------------------------------------------------------- -- * Auxiliary --------------------------------------------------------------------------- @@ -704,16 +648,16 @@ substitute :: TVar -- α -> TEVar -- ά -> Type -- A -> Type -- [ά/α]A -substitute tvar tevar typ = case typ of - TLit _ -> typ - TVar tvar' | tvar' == tvar -> TEVar tevar - | otherwise -> typ - TEVar _ -> typ - TFun t1 t2 -> on TFun substitute' t1 t2 - TAll tvar' t -> TAll tvar' (substitute' t) - TData name typs -> TData name $ map substitute' typs +substitute alpha alpha' typ = case typ of + TLit _ -> typ + TVar tvar | tvar == alpha -> TEVar alpha' + | otherwise -> typ + TEVar _ -> typ + TFun t1 t2 -> on TFun subs t1 t2 + TAll tvar' t -> TAll tvar' (subs t) + TData name typs -> TData name $ map subs typs where - substitute' = substitute tvar tevar + subs = substitute alpha alpha' -- | Γ,x,Γ' → (Γ, Γ') splitOn :: EnvElem -> Env -> (Env, Env) @@ -777,7 +721,6 @@ fresh = do modify $ \cxt -> cxt { next_tevar = succ cxt.next_tevar } pure tevar - isComplete :: Env -> Bool isComplete = isNothing . S.findIndexL unSolvedTEVar where @@ -813,15 +756,15 @@ lookupEnv x = gets (findId . env) EnvVar x' t | x==x' -> Just t _ -> findId ys -lookupInj :: UIdent -> Tc (Maybe Type) -lookupInj x = gets (Map.lookup x . data_injs) +lookupInj :: UIdent -> Tc (Maybe (Type, [TVar])) +lookupInj x = gets $ Map.lookup x . data_injs putEnv :: Env -> Tc () putEnv = modifyEnv . const modifyEnv :: (Env -> Env) -> Tc () modifyEnv f = - modify $ \cxt -> {- trace (ppEnv (f cxt.env)) -} cxt { env = f cxt.env } + modify $ \cxt -> {- trace (ppEnv (f cxt.env)) -}cxt { env = f cxt.env } pattern DBind' name vars exp = DBind (Bind name vars exp) pattern DSig' name typ = DSig (Sig name typ) diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index b9207c1..b7b1ef0 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -90,7 +90,7 @@ prtSig (x, t) = ] instance (Print a, Print t) => Print (T a t) where - prt i (x, t) = withT + prt i (x, t) = noT where noT = prt i x withT = concatD