Add signature of inferred bind to allow some mutually defined definitions

This commit is contained in:
Martin Fredin 2023-03-30 12:35:47 +02:00
parent a37a52d9f8
commit bbe0d77a19
4 changed files with 111 additions and 28 deletions

View file

@ -1,15 +1,20 @@
data forall a. List (a) where {
Nil : List (a)
Cons : a -> List (a) -> List (a)
data Bool () where {
True : Bool ()
False : Bool ()
};
length : forall c. List (List (c)) -> Int;
length = \list. case list of {
Cons x xs => 1 + length xs;
-- Nil => 0;
-- Cons x (Cons y Nil) => 2;
even : Int -> Bool ();
even x = not (odd x) ;
odd x = not (even x) ;
not x = case x of {
True => False;
False => True;
};
f = g;
g = f;

View file

@ -1,9 +1,9 @@
data True() where {
True: True()
data Bool () where {
True : Bool ()
False : Bool ()
};
toBool = case 0 of {
0 => False;
_ => True;
};
main: Int;
main =
case True of {
True => 1;
_ => 0;
};

View file

@ -6,18 +6,19 @@
module TypeChecker.TypeCheckerBidir (typecheck, getVars) where
import Auxiliary (maybeToRightM, snoc, int, char)
import Auxiliary (char, int, maybeToRightM, snoc)
import Control.Applicative (Alternative, Applicative (liftA2),
(<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError),
runExceptT, unless, zipWithM,
zipWithM_)
mapAndUnzipM, runExceptT, unless,
zipWithM, zipWithM_)
import Control.Monad.State (MonadState (get, put), State,
evalState, gets, modify)
import Data.Coerce (coerce)
import Data.Foldable (foldrM)
import Data.Function (on)
import Data.List (intercalate)
import Data.List.Extra (allSame)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe, isNothing)
@ -92,7 +93,7 @@ typecheck (Program defs) = do
typecheckBind :: Bind -> Tc (T.Bind' Type)
typecheckBind (Bind name vars rhs) = do
bind' <- lookupSig name >>= \case
bind'@(T.Bind (name, typ) _ _) <- lookupSig name >>= \case
-- TODO These Judgment aren't accurate
-- (f:A → B) ∈ Γ
-- Γ,(xs:A) ⊢ e ↑ Β ⊣ Δ
@ -101,8 +102,6 @@ typecheckBind (Bind name vars rhs) = do
Just t -> do
(rhs', _) <- check (foldr EAbs rhs vars) t
pure (T.Bind (coerce name, t) [] (rhs', t))
where
vars' = zip vars $ getVars t
-- Γ ⊢ (λxs. e) ↓ A → B ⊣ Δ
-- ------------------------------
@ -114,6 +113,7 @@ typecheckBind (Bind name vars rhs) = do
pure (T.Bind (coerce name, t') [] (e', t'))
env <- gets env
unless (isComplete env) err
insertSig (coerce name) typ
putEnv Empty
pure bind'
where
@ -389,12 +389,13 @@ check exp typ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- ---------------------------------------
-- Γ ⊢ case e of Π ↑ C ⊣ Δ
| ECase scrut branches <- exp = do
(scrut', t_scrut) <- infer scrut
t_scrut' <- applyEnv t_scrut
typ' <- applyEnv typ
branches' <- mapM (\b -> checkBranch b t_scrut' typ') branches
pure (T.ECase (scrut', t_scrut') branches', typ')
-- TODO maybe remove only use infer rule
| ECase scrut branches <- exp = do
(scrut', t_scrut) <- infer scrut
t_scrut' <- applyEnv t_scrut
typ' <- applyEnv typ
branches' <- mapM (\b -> checkBranch b t_scrut' typ') branches
pure (T.ECase (scrut', t_scrut') branches', typ')
| otherwise = subsumption
where
@ -490,6 +491,18 @@ infer = \case
e2' <- check e2 t
pure (T.EAdd e1' e2', t)
-- Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- ---------------------------------------
-- Γ ⊢ case e of Π ↓ C ⊣ Δ
ECase scrut branches -> do
(scrut', t_scrut) <- infer scrut
t_scrut' <- applyEnv t_scrut
(branches', ts) <- mapAndUnzipM (`inferBranch` t_scrut') branches
unless (allSame ts) $ throwError "Branches have different return types"
pure (T.ECase (scrut', t_scrut') branches', head ts)
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
-- Instantiate existential type variables until there is an arrow type.
@ -534,6 +547,19 @@ apply typ exp = case typ of
-- * Pattern matching
---------------------------------------------------------------------------
-- | Γ ⊢ p ⇒ e ∷ A ↓ C
-- Under context Γ, check pattern in branch p ⇒ e of type A and infer bodies of type C
inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type)
inferBranch (Branch patt exp) t_patt = do
env_marker <- EnvMark <$> fresh
insertEnv env_marker
patt' <- checkPattern patt t_patt
(exp', t_exp) <- infer exp
(env_l, _) <- gets (splitOn env_marker . env)
putEnv env_l
pure (T.Branch patt' (exp', t_exp), t_exp)
-- | Γ ⊢ p ⇒ e ∷ A ↑ C
-- Under context Γ, check branch p ⇒ e of type A and bodies of type C
checkBranch :: Branch -> Type -> Type -> Tc (T.Branch' Type)
@ -852,6 +878,10 @@ lookupBind x = gets (Map.lookup x . binds)
lookupSig :: LIdent -> Tc (Maybe Type)
lookupSig x = gets (Map.lookup x . sig)
insertSig :: LIdent -> Type -> Tc ()
insertSig name t = modify $ \cxt -> cxt { sig = Map.insert name t cxt.sig }
lookupEnv :: LIdent -> Tc (Maybe Type)
lookupEnv x = gets (findId . env)
where

View file

@ -31,6 +31,8 @@ testTypeCheckerBidir = describe "Bidirectional type checker test" $ do
tc_tree
tc_mono_case
tc_pol_case
tc_mut_rec
tc_infer_case
tc_id =
specify "Basic identity function polymorphism" $
@ -266,6 +268,52 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
, "};"
]
tc_mut_rec = specify "Feasible mutuable recursive definitions" $ run
[ "data Bool () where {"
, " True : Bool ()"
, " False : Bool ()"
, "};"
, "even : Int -> Bool ();"
, "even x = not (odd x);"
, "odd x = not (even x);"
, "not x = case x of {"
, " True => False;"
, " False => True;"
, "};"
] `shouldSatisfy` ok
tc_infer_case = describe "Infer case expression" $ do
specify "Wrong case expression rejected" $
run (fs ++ wrong) `shouldNotSatisfy` ok
specify "Correct case expression accepted" $
run (fs ++ correct) `shouldSatisfy` ok
where
fs =
[ "data Bool () where {"
, " True : Bool ()"
, " False : Bool ()"
, "};"
]
correct =
[ "toBool = case 0 of {"
, " 0 => False;"
, " _ => True;"
, "};"
]
wrong =
[ "toBool = case 0 of {"
, " 0 => False;"
, " _ => 1;"
, "};"
]
run :: [String] -> Err T.Program
run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines