Add signature of inferred bind to allow some mutually defined definitions

This commit is contained in:
Martin Fredin 2023-03-30 12:35:47 +02:00
parent a37a52d9f8
commit bbe0d77a19
4 changed files with 111 additions and 28 deletions

View file

@ -1,15 +1,20 @@
data forall a. List (a) where { data Bool () where {
Nil : List (a) True : Bool ()
Cons : a -> List (a) -> List (a) False : Bool ()
}; };
length : forall c. List (List (c)) -> Int; even : Int -> Bool ();
length = \list. case list of { even x = not (odd x) ;
Cons x xs => 1 + length xs;
-- Nil => 0; odd x = not (even x) ;
-- Cons x (Cons y Nil) => 2;
not x = case x of {
True => False;
False => True;
}; };
f = g;
g = f;

View file

@ -1,9 +1,9 @@
data True() where { data Bool () where {
True: True() True : Bool ()
False : Bool ()
};
toBool = case 0 of {
0 => False;
_ => True;
}; };
main: Int;
main =
case True of {
True => 1;
_ => 0;
};

View file

@ -6,18 +6,19 @@
module TypeChecker.TypeCheckerBidir (typecheck, getVars) where module TypeChecker.TypeCheckerBidir (typecheck, getVars) where
import Auxiliary (maybeToRightM, snoc, int, char) import Auxiliary (char, int, maybeToRightM, snoc)
import Control.Applicative (Alternative, Applicative (liftA2), import Control.Applicative (Alternative, Applicative (liftA2),
(<|>)) (<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError), import Control.Monad.Except (ExceptT, MonadError (throwError),
runExceptT, unless, zipWithM, mapAndUnzipM, runExceptT, unless,
zipWithM_) zipWithM, zipWithM_)
import Control.Monad.State (MonadState (get, put), State, import Control.Monad.State (MonadState (get, put), State,
evalState, gets, modify) evalState, gets, modify)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Foldable (foldrM) import Data.Foldable (foldrM)
import Data.Function (on) import Data.Function (on)
import Data.List (intercalate) import Data.List (intercalate)
import Data.List.Extra (allSame)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Maybe (fromMaybe, isNothing) import Data.Maybe (fromMaybe, isNothing)
@ -92,7 +93,7 @@ typecheck (Program defs) = do
typecheckBind :: Bind -> Tc (T.Bind' Type) typecheckBind :: Bind -> Tc (T.Bind' Type)
typecheckBind (Bind name vars rhs) = do typecheckBind (Bind name vars rhs) = do
bind' <- lookupSig name >>= \case bind'@(T.Bind (name, typ) _ _) <- lookupSig name >>= \case
-- TODO These Judgment aren't accurate -- TODO These Judgment aren't accurate
-- (f:A → B) ∈ Γ -- (f:A → B) ∈ Γ
-- Γ,(xs:A) ⊢ e ↑ Β ⊣ Δ -- Γ,(xs:A) ⊢ e ↑ Β ⊣ Δ
@ -101,8 +102,6 @@ typecheckBind (Bind name vars rhs) = do
Just t -> do Just t -> do
(rhs', _) <- check (foldr EAbs rhs vars) t (rhs', _) <- check (foldr EAbs rhs vars) t
pure (T.Bind (coerce name, t) [] (rhs', t)) pure (T.Bind (coerce name, t) [] (rhs', t))
where
vars' = zip vars $ getVars t
-- Γ ⊢ (λxs. e) ↓ A → B ⊣ Δ -- Γ ⊢ (λxs. e) ↓ A → B ⊣ Δ
-- ------------------------------ -- ------------------------------
@ -114,6 +113,7 @@ typecheckBind (Bind name vars rhs) = do
pure (T.Bind (coerce name, t') [] (e', t')) pure (T.Bind (coerce name, t') [] (e', t'))
env <- gets env env <- gets env
unless (isComplete env) err unless (isComplete env) err
insertSig (coerce name) typ
putEnv Empty putEnv Empty
pure bind' pure bind'
where where
@ -389,12 +389,13 @@ check exp typ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO -- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- --------------------------------------- -- ---------------------------------------
-- Γ ⊢ case e of Π ↑ C ⊣ Δ -- Γ ⊢ case e of Π ↑ C ⊣ Δ
| ECase scrut branches <- exp = do -- TODO maybe remove only use infer rule
(scrut', t_scrut) <- infer scrut | ECase scrut branches <- exp = do
t_scrut' <- applyEnv t_scrut (scrut', t_scrut) <- infer scrut
typ' <- applyEnv typ t_scrut' <- applyEnv t_scrut
branches' <- mapM (\b -> checkBranch b t_scrut' typ') branches typ' <- applyEnv typ
pure (T.ECase (scrut', t_scrut') branches', typ') branches' <- mapM (\b -> checkBranch b t_scrut' typ') branches
pure (T.ECase (scrut', t_scrut') branches', typ')
| otherwise = subsumption | otherwise = subsumption
where where
@ -490,6 +491,18 @@ infer = \case
e2' <- check e2 t e2' <- check e2 t
pure (T.EAdd e1' e2', t) pure (T.EAdd e1' e2', t)
-- Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- ---------------------------------------
-- Γ ⊢ case e of Π ↓ C ⊣ Δ
ECase scrut branches -> do
(scrut', t_scrut) <- infer scrut
t_scrut' <- applyEnv t_scrut
(branches', ts) <- mapAndUnzipM (`inferBranch` t_scrut') branches
unless (allSame ts) $ throwError "Branches have different return types"
pure (T.ECase (scrut', t_scrut') branches', head ts)
-- | Γ ⊢ A • e ⇓ C ⊣ Δ -- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ -- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
-- Instantiate existential type variables until there is an arrow type. -- Instantiate existential type variables until there is an arrow type.
@ -534,6 +547,19 @@ apply typ exp = case typ of
-- * Pattern matching -- * Pattern matching
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
-- | Γ ⊢ p ⇒ e ∷ A ↓ C
-- Under context Γ, check pattern in branch p ⇒ e of type A and infer bodies of type C
inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type)
inferBranch (Branch patt exp) t_patt = do
env_marker <- EnvMark <$> fresh
insertEnv env_marker
patt' <- checkPattern patt t_patt
(exp', t_exp) <- infer exp
(env_l, _) <- gets (splitOn env_marker . env)
putEnv env_l
pure (T.Branch patt' (exp', t_exp), t_exp)
-- | Γ ⊢ p ⇒ e ∷ A ↑ C -- | Γ ⊢ p ⇒ e ∷ A ↑ C
-- Under context Γ, check branch p ⇒ e of type A and bodies of type C -- Under context Γ, check branch p ⇒ e of type A and bodies of type C
checkBranch :: Branch -> Type -> Type -> Tc (T.Branch' Type) checkBranch :: Branch -> Type -> Type -> Tc (T.Branch' Type)
@ -852,6 +878,10 @@ lookupBind x = gets (Map.lookup x . binds)
lookupSig :: LIdent -> Tc (Maybe Type) lookupSig :: LIdent -> Tc (Maybe Type)
lookupSig x = gets (Map.lookup x . sig) lookupSig x = gets (Map.lookup x . sig)
insertSig :: LIdent -> Type -> Tc ()
insertSig name t = modify $ \cxt -> cxt { sig = Map.insert name t cxt.sig }
lookupEnv :: LIdent -> Tc (Maybe Type) lookupEnv :: LIdent -> Tc (Maybe Type)
lookupEnv x = gets (findId . env) lookupEnv x = gets (findId . env)
where where

View file

@ -31,6 +31,8 @@ testTypeCheckerBidir = describe "Bidirectional type checker test" $ do
tc_tree tc_tree
tc_mono_case tc_mono_case
tc_pol_case tc_pol_case
tc_mut_rec
tc_infer_case
tc_id = tc_id =
specify "Basic identity function polymorphism" $ specify "Basic identity function polymorphism" $
@ -266,6 +268,52 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
, "};" , "};"
] ]
tc_mut_rec = specify "Feasible mutuable recursive definitions" $ run
[ "data Bool () where {"
, " True : Bool ()"
, " False : Bool ()"
, "};"
, "even : Int -> Bool ();"
, "even x = not (odd x);"
, "odd x = not (even x);"
, "not x = case x of {"
, " True => False;"
, " False => True;"
, "};"
] `shouldSatisfy` ok
tc_infer_case = describe "Infer case expression" $ do
specify "Wrong case expression rejected" $
run (fs ++ wrong) `shouldNotSatisfy` ok
specify "Correct case expression accepted" $
run (fs ++ correct) `shouldSatisfy` ok
where
fs =
[ "data Bool () where {"
, " True : Bool ()"
, " False : Bool ()"
, "};"
]
correct =
[ "toBool = case 0 of {"
, " 0 => False;"
, " _ => True;"
, "};"
]
wrong =
[ "toBool = case 0 of {"
, " 0 => False;"
, " _ => 1;"
, "};"
]
run :: [String] -> Err T.Program run :: [String] -> Err T.Program
run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines