From 72352d9619e862484f16ee12140e0b3a5d23f32e Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Thu, 30 Mar 2023 18:46:37 +0200 Subject: [PATCH] Use tevars for bind without type signatures, fix recursive functions --- src/TypeChecker/TypeCheckerBidir.hs | 227 +++++++++++----------------- tests/TestTypeCheckerBidir.hs | 17 +++ 2 files changed, 107 insertions(+), 137 deletions(-) diff --git a/src/TypeChecker/TypeCheckerBidir.hs b/src/TypeChecker/TypeCheckerBidir.hs index 3930a0e..1f16e11 100644 --- a/src/TypeChecker/TypeCheckerBidir.hs +++ b/src/TypeChecker/TypeCheckerBidir.hs @@ -6,18 +6,18 @@ module TypeChecker.TypeCheckerBidir (typecheck, getVars) where -import Auxiliary (char, int, maybeToRightM, snoc) +import Auxiliary (int, litType, maybeToRightM, snoc) import Control.Applicative (Alternative, Applicative (liftA2), (<|>)) import Control.Monad.Except (ExceptT, MonadError (throwError), - mapAndUnzipM, runExceptT, unless, + liftEither, runExceptT, unless, zipWithM, zipWithM_) import Control.Monad.State (MonadState (get, put), State, evalState, gets, modify) import Data.Coerce (coerce) import Data.Foldable (foldrM) import Data.Function (on) -import Data.List (intercalate) +import Data.List (intercalate, partition) import Data.List.Extra (allSame) import Data.Map (Map) import qualified Data.Map as Map @@ -39,6 +39,7 @@ import qualified TypeChecker.TypeCheckerIr as T -- • Fix problems with types in Pattern/Branch in TypeCheckerIr -- • Use applyEnvExp consistently -- • Fix the different type getters functions (e.g. partitionType) functions +-- • Handle recursive functions. Maybe use an isRec : Bool variable. data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A | EnvTVar TVar -- ^ Universal type variable. 
α @@ -94,18 +95,9 @@ typecheck (Program defs) = do typecheckBind :: Bind -> Tc (T.Bind' Type) typecheckBind (Bind name vars rhs) = do bind'@(T.Bind (name, typ) _ _) <- lookupSig name >>= \case - -- TODO These Judgment aren't accurate - -- (f:A → B) ∈ Γ - -- Γ,(xs:A) ⊢ e ↑ Β ⊣ Δ - --------------------------- - -- Γ ⊢ f xs = e ↓ Α → B ⊣ Δ Just t -> do (rhs', _) <- check (foldr EAbs rhs vars) t pure (T.Bind (coerce name, t) [] (rhs', t)) - - -- Γ ⊢ (λxs. e) ↓ A → B ⊣ Δ - -- ------------------------------ - -- Γ ⊢ f xs = e ↓ [Γ]A → [Γ]B ⊣ Δ Nothing -> do (e, t) <- infer $ foldr EAbs rhs vars t' <- applyEnv t @@ -113,7 +105,7 @@ typecheckBind (Bind name vars rhs) = do pure (T.Bind (coerce name, t') [] (e', t')) env <- gets env unless (isComplete env) err - insertSig (coerce name) typ + insertSig (coerce name) typ -- HERE putEnv Empty pure bind' where @@ -265,9 +257,9 @@ instantiateL tevar typ = gets env >>= go -- Γ ⊢ τ -- ----------------------------- InstLSolve -- Γ,ά,Γ' ⊢ ά :=< τ ⊣ Γ,(ά=τ),Γ' - | isMono typ + | noForall typ , (env_l, env_r) <- splitOn (EnvTEVar tevar) env - , Right _ <- wellFormed env_l typ + , Right _ <- wellFormed env_l typ = putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r | TEVar tevar' <- typ = instReach tevar tevar' @@ -305,7 +297,7 @@ instantiateR typ tevar = gets env >>= go -- Γ ⊢ τ -- ----------------------------- InstRSolve -- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ' - | isMono typ + | noForall typ , (env_l, env_r) <- splitOn (EnvTEVar tevar) env , Right _ <- wellFormed env_l typ = putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r @@ -337,7 +329,6 @@ instantiateR typ tevar = gets env >>= go let (env_l, _) = splitOn (EnvTVar tvar) env putEnv env_l - | otherwise = error $ "Trying to instantiateR: " ++ ppT typ ++ " <: " ++ ppT (TEVar tevar) @@ -385,18 +376,6 @@ check exp typ putEnv env_l pure (T.EAbs (coerce name) e', typ) - -- Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ - -- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO - -- --------------------------------------- - -- Γ 
⊢ case e of Π ↑ C ⊣ Δ - -- TODO maybe remove only use infer rule - | ECase scrut branches <- exp = do - (scrut', t_scrut) <- infer scrut - t_scrut' <- applyEnv t_scrut - typ' <- applyEnv typ - branches' <- mapM (\b -> checkBranch b t_scrut' typ') branches - pure (T.ECase (scrut', t_scrut') branches', typ') - | otherwise = subsumption where -- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ @@ -405,7 +384,7 @@ check exp typ subsumption = do (exp', t) <- infer exp exp'' <- applyEnvExp exp' - t' <- applyEnv t + t' <- applyEnv t typ' <- applyEnv typ subtype t' typ' pure (exp'', t') @@ -415,19 +394,20 @@ check exp typ infer :: Exp -> Tc (T.ExpT' Type) infer = \case - ELit lit -> pure (T.ELit lit, inferLit lit) + ELit lit -> pure (T.ELit lit, litType lit) - -- (x : A) ∈ Γ - -- ------------- Var - -- Γ ⊢ x ↓ A ⊣ Γ + -- (x : A) ∈ Γ (x : A) ∉ Γ + -- ------------- Var --------------- Var' + -- Γ ⊢ x ↓ A ⊣ Γ Γ ⊢ x ↓ ά ⊣ Γ,ά EVar name -> do t <- liftA2 (<|>) (lookupEnv name) (lookupSig name) >>= \case Just t -> pure t Nothing -> do - e <- maybeToRightM - ("Unbound variable " ++ show name) - =<< lookupBind name - snd <$> infer e + tevar <- fresh + insertEnv (EnvTEVar tevar) + let t = TEVar tevar + insertEnv (EnvVar name t) + pure t pure (T.EVar (coerce name), t) EInj name -> do @@ -480,28 +460,25 @@ infer = \case putEnv env_l pure (T.ELet (T.Bind (coerce name, t_rhs) [] (rhs', t_rhs)) (e',t), t) - -- Γ ⊢ e₁ ↑ Int Γ ⊢ e₁ ↑ Int + -- Γ ⊢ e₁ ↑ Int ⊣ Θ Θ ⊢ e₂ ↑ Int -- --------------------------- +I -- Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ EAdd e1 e2 -> do - cxt <- get - let t = int - e1' <- check e1 t - put cxt - e2' <- check e2 t - pure (T.EAdd e1' e2', t) + e1' <- check e1 int + e2' <- check e2 int + e1'' <- applyEnvExpT e1' + e2'' <- applyEnvExpT e2' + pure (T.EAdd e1'' e2'', int) - -- Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ + -- Θ ⊢ Π ∷ A ↓ C ⊣ Δ -- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO -- --------------------------------------- -- Γ ⊢ case e of Π ↓ C ⊣ Δ ECase scrut branches -> do (scrut', t_scrut) <- infer scrut - 
t_scrut' <- applyEnv t_scrut - (branches', ts) <- mapAndUnzipM (`inferBranch` t_scrut') branches - unless (allSame ts) $ throwError "Branches have different return types" - pure (T.ECase (scrut', t_scrut') branches', head ts) + (branches', t_return) <- inferBranches branches t_scrut + pure (T.ECase (scrut', t_scrut) branches', t_return) -- | Γ ⊢ A • e ⇓ C ⊣ Δ -- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ @@ -547,45 +524,71 @@ apply typ exp = case typ of -- * Pattern matching --------------------------------------------------------------------------- --- | Γ ⊢ p ⇒ e ∷ A ↓ C --- Under context Γ, check pattern in branch p ⇒ e of type A and infer bodies of type C +-- Γ ⊢ p ⇒ e ∷ A ↓ B ⊣ Θ +-- Θ ⊢ Π ∷ [Θ]A ↓ C ⊣ Δ +-- [Δ]B <: C +-- --------------------------- +-- Γ ⊢ (p ⇒ e),Π ∷ A ↓ C ⊣ Δ +inferBranches :: [Branch] -> Type -> Tc ([T.Branch' Type], Type) +inferBranches branches t_patt = do + (branches', ts_exp) <- inferBranches' t_patt branches + ts_exp' <- mapM applyEnv ts_exp + let (monos, pols) = partition isMono ts_exp' + t_exp <- liftEither $ bodyType t_patt monos + mapM_ (subtype t_exp) pols + pure (branches', t_exp) + where + + bodyType :: Type -> [Type] -> Err Type + bodyType t_patt = \case + [] -> pure t_patt + [m] -> pure m + m:n:ms | m == n -> bodyType t_patt (n:ms) + | otherwise -> throwError $ unwords [ "Wrong return types: " + , ppT m, "≠", ppT n ] + + inferBranches' = go [] [] + where + go branches ts_exp t = \case + [] -> pure (branches, ts_exp) + b:bs -> do + (b', t_e) <- inferBranch b t + t' <- applyEnv t + go (snoc b' branches) (snoc t_e ts_exp) t' bs + +-- Γ ⊢ p ↑ A ⊣ Θ Θ ⊢ e ↓ C ⊣ Δ +-- ------------------------------- +-- Γ ⊢ p ⇒ e ∷ A ↓ C ⊣ Δ inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type) inferBranch (Branch patt exp) t_patt = do - env_marker <- EnvMark <$> fresh - insertEnv env_marker patt' <- checkPattern patt t_patt (exp', t_exp) <- infer exp - (env_l, _) <- gets (splitOn env_marker 
. env) - putEnv env_l pure (T.Branch patt' (exp', t_exp), t_exp) - --- | Γ ⊢ p ⇒ e ∷ A ↑ C --- Under context Γ, check branch p ⇒ e of type A and bodies of type C -checkBranch :: Branch -> Type -> Type -> Tc (T.Branch' Type) -checkBranch (Branch patt exp) t_patt t_exp = do - env_marker <- EnvMark <$> fresh - insertEnv env_marker - patt' <- checkPattern patt t_patt - t_exp' <- applyEnv t_exp - (exp, t_exp) <- check exp t_exp' - (env_l, _) <- gets (splitOn env_marker . env) - putEnv env_l - pure (T.Branch patt' (exp, t_exp)) - checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) checkPattern patt t_patt = case patt of + + -- ------------------- + -- Γ ⊢ x ↑ A ⊣ Γ,(x:A) PVar x -> do insertEnv $ EnvVar x t_patt pure (T.PVar (coerce x, t_patt), t_patt) + -- ------------- + -- Γ ⊢ _ ↑ A ⊣ Γ PCatch -> pure (T.PCatch, t_patt) + -- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ + -- ------------------------------ + -- Γ ⊢ τ ↑ B ⊣ Δ PLit lit -> do - subtype (inferLit lit) t_patt + subtype (litType lit) t_patt t_patt' <- applyEnv t_patt pure (T.PLit (lit, t_patt), t_patt') + -- (x : A) ∈ Γ Γ ⊢ A <: B ⊣ Δ + -- --------------------------- + -- Γ ⊢ inj₀ x ↑ B ⊣ Δ PEnum name -> do t <- maybeToRightM ("Unknown constructor " ++ show name) =<< lookupInj name @@ -599,13 +602,14 @@ checkPattern patt t_patt = case patt of t_inj' <- foldrM substitute' t_inj $ getInitForalls t_inj subtype (getDataId t_inj') t_patt t_inj'' <- applyEnv t_inj' - t_patt' <- applyEnv t_patt let ts_inj = getParams t_inj'' ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps ts_inj + t_patt' <- applyEnv t_patt pure (T.PInj (coerce name) (map fst ps'), t_patt') where substitute' fa t = do tevar <- fresh + -- insertEnv (EnvTEVar tevar) pure $ substitute tvar tevar t where TAll tvar _ = fa dummy @@ -666,6 +670,9 @@ splitOn x env = second (S.drop 1) $ S.breakl (==x) env dropTrailing :: EnvElem -> Tc () dropTrailing x = modifyEnv $ S.takeWhileL (/= x) +applyEnvExpT :: (T.Exp' Type, Type) -> Tc (T.Exp' Type, Type) 
+applyEnvExpT (e, t) = liftA2 (,) (applyEnvExp e) (applyEnv t) + applyEnvExp :: T.Exp' Type -> Tc (T.Exp' Type) applyEnvExp exp = case exp of T.ELet (T.Bind id vars rhs) exp -> do @@ -681,7 +688,6 @@ applyEnvExp exp = case exp of (mapM applyEnvBranch branches) _ -> pure exp where - applyEnvExpT (e, t) = liftA2 (,) (applyEnvExp e) (applyEnv t) applyEnvId = secondM applyEnv applyEnvBranch (T.Branch (p, t) e) = do pt <- liftA2 (,) (applyEnvPattern p) (applyEnv t) @@ -752,20 +758,24 @@ wellFormed env = \case TData _ typs -> mapM_ (wellFormed env) typs +noForall :: Type -> Bool +noForall = \case + TAll{} -> False + TFun t1 t2 -> on (&&) noForall t1 t2 + TData _ typs -> all noForall typs + TVar _ -> True + TEVar _ -> True + TLit _ -> True + isMono :: Type -> Bool isMono = \case TAll{} -> False TFun t1 t2 -> on (&&) isMono t1 t2 TData _ typs -> all isMono typs - TVar _ -> True - TEVar _ -> True + TVar _ -> False + TEVar _ -> False TLit _ -> True -inferLit :: Lit -> Type -inferLit = \case - LInt _ -> TLit "Int" - LChar _ -> TLit "Char" - fresh :: Tc TEVar fresh = do tevar <- gets (MkTEVar . LIdent . ("a#" ++) . show . next_tevar) @@ -803,60 +813,6 @@ skipForalls = go [] TAll tvar t -> go (snoc (TAll tvar) acc) t _ -> (acc, typ) - -getForallsData :: Type -> [Type -> Type] -getForallsData = fst . partitionData - -getTData :: Type -> Type -getTData = snd . partitionData - -partitionData :: Type -> ([Type -> Type], Type) -partitionData = go . 
([],) - where - go (acc, typ) = case typ of - TAll tvar t -> go (snoc (TAll tvar) acc, t) - TData {} -> (acc, typ) - TFun _ t -> go (acc, t) - _ -> error "Bad data type" - - -partitionTypeWithForall :: Type -> ([Type], Type) -partitionTypeWithForall typ = (t_vars', t_return') - where - t_vars' = map (\t -> foldr applyForall t foralls) t_vars - t_return' = foldr applyForall t_return foralls - - applyForall fa t | usesTVar tvar t = fa t - | otherwise = t - where TAll tvar _ = fa t - - (t_vars, t_return) = go [] typ' - (foralls, typ') = skipForalls typ - - - go acc t = case t of - TFun t1 t2 -> go (snoc t1 acc) t2 - _ -> (acc, t) - -usesTVar :: TVar -> Type -> Bool -usesTVar tvar = \case - TLit _ -> False - TVar tvar' | tvar' == tvar -> True - | otherwise -> False - TFun t1 t2 -> on (||) usesTVar' t1 t2 - TAll tvar' t | tvar' == tvar -> error "Redeclaration of TVar" - | otherwise -> usesTVar' t - TData _ typs -> any usesTVar' typs - _ -> error "Impossible" - where - usesTVar' = usesTVar tvar - -skipLambdas :: Int -> T.Exp' Type -> T.Exp' Type -skipLambdas i exp - | i == 0 = exp - | T.EAbs _ (e, _) <- exp = skipLambdas (i-1) e - | otherwise = error "Number of expected lambdas doesn't match expression" - isComplete :: Env -> Bool isComplete = isNothing . S.findIndexL unSolvedTEVar where @@ -872,9 +828,6 @@ toTVar = \case insertEnv :: EnvElem -> Tc () insertEnv x = modifyEnv (:|> x) -lookupBind :: LIdent -> Tc (Maybe Exp) -lookupBind x = gets (Map.lookup x . binds) - lookupSig :: LIdent -> Tc (Maybe Type) lookupSig x = gets (Map.lookup x . 
sig) diff --git a/tests/TestTypeCheckerBidir.hs b/tests/TestTypeCheckerBidir.hs index f423720..5e1d5b1 100644 --- a/tests/TestTypeCheckerBidir.hs +++ b/tests/TestTypeCheckerBidir.hs @@ -32,6 +32,8 @@ testTypeCheckerBidir = describe "Bidirectional type checker test" $ do tc_mono_case tc_pol_case tc_infer_case + tc_rec1 + tc_rec2 tc_id = specify "Basic identity function polymorphism" $ @@ -295,6 +297,21 @@ tc_infer_case = describe "Infer case expression" $ do , "};" ] +tc_rec1 = specify "Infer simple recursive definition" $ + run ["test x = 1 + test (x + 1);"] `shouldSatisfy` ok + +tc_rec2 = specify "Infer recursive definition with pattern matching" $ run + [ "data Bool () where {" + , " False : Bool ()" + , " True : Bool ()" + , "};" + + , "test = \\x. case x of {" + , " 10 => True;" + , " _ => test (x+1);" + , "};" + ] `shouldSatisfy` ok + run :: [String] -> Err T.Program run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines