From bbe0d77a19e88055b597ed597c8da27ca6f490ca Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Thu, 30 Mar 2023 12:35:47 +0200 Subject: [PATCH] Add signature of inferred bind to allow some mutually defined definitions --- sample-programs/basic-0 | 21 ++++++----- sample-programs/basic-1.crf | 16 ++++----- src/TypeChecker/TypeCheckerBidir.hs | 54 ++++++++++++++++++++++------- tests/TestTypeCheckerBidir.hs | 48 +++++++++++++++++++++++++ 4 files changed, 111 insertions(+), 28 deletions(-) diff --git a/sample-programs/basic-0 b/sample-programs/basic-0 index 88e4071..35b9c04 100644 --- a/sample-programs/basic-0 +++ b/sample-programs/basic-0 @@ -1,15 +1,20 @@ -data forall a. List (a) where { - Nil : List (a) - Cons : a -> List (a) -> List (a) +data Bool () where { + True : Bool () + False : Bool () }; -length : forall c. List (List (c)) -> Int; -length = \list. case list of { - Cons x xs => 1 + length xs; --- Nil => 0; --- Cons x (Cons y Nil) => 2; +even : Int -> Bool (); +even x = not (odd x) ; + +odd x = not (even x) ; + +not x = case x of { + True => False; + False => True; }; +f = g; +g = f; diff --git a/sample-programs/basic-1.crf b/sample-programs/basic-1.crf index 91317cd..a5e2ae4 100644 --- a/sample-programs/basic-1.crf +++ b/sample-programs/basic-1.crf @@ -1,9 +1,9 @@ -data True() where { - True: True() +data Bool () where { + True : Bool () + False : Bool () +}; + +toBool = case 0 of { + 0 => False; + _ => True; }; -main: Int; -main = - case True of { - True => 1; - _ => 0; - }; \ No newline at end of file diff --git a/src/TypeChecker/TypeCheckerBidir.hs b/src/TypeChecker/TypeCheckerBidir.hs index ffadf07..3930a0e 100644 --- a/src/TypeChecker/TypeCheckerBidir.hs +++ b/src/TypeChecker/TypeCheckerBidir.hs @@ -6,18 +6,19 @@ module TypeChecker.TypeCheckerBidir (typecheck, getVars) where -import Auxiliary (maybeToRightM, snoc, int, char) +import Auxiliary (char, int, maybeToRightM, snoc) import Control.Applicative (Alternative, Applicative (liftA2), (<|>)) import Control.Monad.Except (ExceptT, MonadError (throwError), - runExceptT, unless, zipWithM, - zipWithM_) + mapAndUnzipM, runExceptT, unless, + zipWithM, zipWithM_) import Control.Monad.State (MonadState (get, put), State, evalState, gets, modify) import Data.Coerce (coerce) import Data.Foldable (foldrM) import Data.Function (on) import Data.List (intercalate) +import Data.List.Extra (allSame) import Data.Map (Map) import qualified Data.Map as Map import Data.Maybe (fromMaybe, isNothing) @@ -92,7 +93,7 @@ typecheck (Program defs) = do typecheckBind :: Bind -> Tc (T.Bind' Type) typecheckBind (Bind name vars rhs) = do - bind' <- lookupSig name >>= \case + bind'@(T.Bind (name, typ) _ _) <- lookupSig name >>= \case -- TODO These Judgment aren't accurate -- (f:A → B) ∈ Γ -- Γ,(xs:A) ⊢ e ↑ Β ⊣ Δ @@ -101,8 +102,6 @@ typecheckBind (Bind name vars rhs) = do Just t -> do (rhs', _) <- check (foldr EAbs rhs vars) t pure (T.Bind (coerce name, t) [] (rhs', t)) - where - vars' = zip vars $ getVars t -- Γ ⊢ (λxs. e) ↓ A → B ⊣ Δ -- ------------------------------ @@ -114,6 +113,7 @@ typecheckBind (Bind name vars rhs) = do pure (T.Bind (coerce name, t') [] (e', t')) env <- gets env unless (isComplete env) err + insertSig (coerce name) typ putEnv Empty pure bind' where @@ -389,12 +389,13 @@ check exp typ -- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO -- --------------------------------------- -- Γ ⊢ case e of Π ↑ C ⊣ Δ - | ECase scrut branches <- exp = do - (scrut', t_scrut) <- infer scrut - t_scrut' <- applyEnv t_scrut - typ' <- applyEnv typ - branches' <- mapM (\b -> checkBranch b t_scrut' typ') branches - pure (T.ECase (scrut', t_scrut') branches', typ') + -- TODO maybe remove only use infer rule + | ECase scrut branches <- exp = do + (scrut', t_scrut) <- infer scrut + t_scrut' <- applyEnv t_scrut + typ' <- applyEnv typ + branches' <- mapM (\b -> checkBranch b t_scrut' typ') branches + pure (T.ECase (scrut', t_scrut') branches', typ') | otherwise = subsumption where @@ -490,6 +491,18 @@ infer = \case e2' <- check e2 t pure (T.EAdd e1' e2', t) + + -- Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ + -- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO + -- --------------------------------------- + -- Γ ⊢ case e of Π ↓ C ⊣ Δ + ECase scrut branches -> do + (scrut', t_scrut) <- infer scrut + t_scrut' <- applyEnv t_scrut + (branches', ts) <- mapAndUnzipM (`inferBranch` t_scrut') branches + unless (allSame ts) $ throwError "Branches have different return types" + pure (T.ECase (scrut', t_scrut') branches', head ts) + -- | Γ ⊢ A • e ⇓ C ⊣ Δ -- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ -- Instantiate existential type variables until there is an arrow type. @@ -534,6 +547,19 @@ apply typ exp = case typ of -- * Pattern matching --------------------------------------------------------------------------- +-- | Γ ⊢ p ⇒ e ∷ A ↓ C +-- Under context Γ, check pattern in branch p ⇒ e of type A and infer bodies of type C +inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type) +inferBranch (Branch patt exp) t_patt = do + env_marker <- EnvMark <$> fresh + insertEnv env_marker + patt' <- checkPattern patt t_patt + (exp', t_exp) <- infer exp + (env_l, _) <- gets (splitOn env_marker . env) + putEnv env_l + pure (T.Branch patt' (exp', t_exp), t_exp) + + -- | Γ ⊢ p ⇒ e ∷ A ↑ C -- Under context Γ, check branch p ⇒ e of type A and bodies of type C checkBranch :: Branch -> Type -> Type -> Tc (T.Branch' Type) @@ -852,6 +878,10 @@ lookupBind x = gets (Map.lookup x . binds) lookupSig :: LIdent -> Tc (Maybe Type) lookupSig x = gets (Map.lookup x . sig) +insertSig :: LIdent -> Type -> Tc () +insertSig name t = modify $ \cxt -> cxt { sig = Map.insert name t cxt.sig } + + lookupEnv :: LIdent -> Tc (Maybe Type) lookupEnv x = gets (findId . env) where diff --git a/tests/TestTypeCheckerBidir.hs b/tests/TestTypeCheckerBidir.hs index 48bf230..c75457e 100644 --- a/tests/TestTypeCheckerBidir.hs +++ b/tests/TestTypeCheckerBidir.hs @@ -31,6 +31,8 @@ testTypeCheckerBidir = describe "Bidirectional type checker test" $ do tc_tree tc_mono_case tc_pol_case + tc_mut_rec + tc_infer_case tc_id = specify "Basic identity function polymorphism" $ @@ -266,6 +268,52 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do , "};" ] + +tc_mut_rec = specify "Feasible mutuable recursive definitions" $ run + [ "data Bool () where {" + , " True : Bool ()" + , " False : Bool ()" + , "};" + + , "even : Int -> Bool ();" + , "even x = not (odd x);" + + , "odd x = not (even x);" + + , "not x = case x of {" + , " True => False;" + , " False => True;" + , "};" + ] `shouldSatisfy` ok + +tc_infer_case = describe "Infer case expression" $ do + specify "Wrong case expression rejected" $ + run (fs ++ wrong) `shouldNotSatisfy` ok + specify "Correct case expression accepted" $ + run (fs ++ correct) `shouldSatisfy` ok + where + fs = + [ "data Bool () where {" + , " True : Bool ()" + , " False : Bool ()" + , "};" + ] + + correct = + [ "toBool = case 0 of {" + , " 0 => False;" + , " _ => True;" + , "};" + ] + + wrong = + [ "toBool = case 0 of {" + , " 0 => False;" + , " _ => 1;" + , "};" + ] + + run :: [String] -> Err T.Program run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines