Use use tevars for bind without type signatures, fix recursive functions

This commit is contained in:
Martin Fredin 2023-03-30 18:46:37 +02:00
parent 4831205e67
commit 72352d9619
2 changed files with 107 additions and 137 deletions

View file

@ -6,18 +6,18 @@
module TypeChecker.TypeCheckerBidir (typecheck, getVars) where
import Auxiliary (char, int, maybeToRightM, snoc)
import Auxiliary (int, litType, maybeToRightM, snoc)
import Control.Applicative (Alternative, Applicative (liftA2),
(<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError),
mapAndUnzipM, runExceptT, unless,
liftEither, runExceptT, unless,
zipWithM, zipWithM_)
import Control.Monad.State (MonadState (get, put), State,
evalState, gets, modify)
import Data.Coerce (coerce)
import Data.Foldable (foldrM)
import Data.Function (on)
import Data.List (intercalate)
import Data.List (intercalate, partition)
import Data.List.Extra (allSame)
import Data.Map (Map)
import qualified Data.Map as Map
@ -39,6 +39,7 @@ import qualified TypeChecker.TypeCheckerIr as T
-- • Fix problems with types in Pattern/Branch in TypeCheckerIr
-- • Use applyEnvExp consistently
-- • Fix the different type getters functions (e.g. partitionType) functions
-- • Handle recursive functions. Maybe use a isRec : Bool variable.
data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A
| EnvTVar TVar -- ^ Universal type variable. α
@ -94,18 +95,9 @@ typecheck (Program defs) = do
typecheckBind :: Bind -> Tc (T.Bind' Type)
typecheckBind (Bind name vars rhs) = do
bind'@(T.Bind (name, typ) _ _) <- lookupSig name >>= \case
-- TODO These Judgment aren't accurate
-- (f:A → B) ∈ Γ
-- Γ,(xs:A) ⊢ e ↑ Β ⊣ Δ
---------------------------
-- Γ ⊢ f xs = e ↓ Α → B ⊣ Δ
Just t -> do
(rhs', _) <- check (foldr EAbs rhs vars) t
pure (T.Bind (coerce name, t) [] (rhs', t))
-- Γ ⊢ (λxs. e) ↓ A → B ⊣ Δ
-- ------------------------------
-- Γ ⊢ f xs = e ↓ [Γ]A → [Γ]B ⊣ Δ
Nothing -> do
(e, t) <- infer $ foldr EAbs rhs vars
t' <- applyEnv t
@ -113,7 +105,7 @@ typecheckBind (Bind name vars rhs) = do
pure (T.Bind (coerce name, t') [] (e', t'))
env <- gets env
unless (isComplete env) err
insertSig (coerce name) typ
insertSig (coerce name) typ -- HERE
putEnv Empty
pure bind'
where
@ -265,9 +257,9 @@ instantiateL tevar typ = gets env >>= go
-- Γ ⊢ τ
-- ----------------------------- InstLSolve
-- Γ,ά,Γ' ⊢ ά :=< τ ⊣ Γ,(ά=τ),Γ'
| isMono typ
| noForall typ
, (env_l, env_r) <- splitOn (EnvTEVar tevar) env
, Right _ <- wellFormed env_l typ
, Right _ <- wellFormed env_l typ
= putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r
| TEVar tevar' <- typ = instReach tevar tevar'
@ -305,7 +297,7 @@ instantiateR typ tevar = gets env >>= go
-- Γ ⊢ τ
-- ----------------------------- InstRSolve
-- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ'
| isMono typ
| noForall typ
, (env_l, env_r) <- splitOn (EnvTEVar tevar) env
, Right _ <- wellFormed env_l typ
= putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r
@ -337,7 +329,6 @@ instantiateR typ tevar = gets env >>= go
let (env_l, _) = splitOn (EnvTVar tvar) env
putEnv env_l
| otherwise = error $ "Trying to instantiateR: " ++ ppT typ ++ " <: "
++ ppT (TEVar tevar)
@ -385,18 +376,6 @@ check exp typ
putEnv env_l
pure (T.EAbs (coerce name) e', typ)
-- Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- ---------------------------------------
-- Γ ⊢ case e of Π ↑ C ⊣ Δ
-- TODO maybe remove only use infer rule
| ECase scrut branches <- exp = do
(scrut', t_scrut) <- infer scrut
t_scrut' <- applyEnv t_scrut
typ' <- applyEnv typ
branches' <- mapM (\b -> checkBranch b t_scrut' typ') branches
pure (T.ECase (scrut', t_scrut') branches', typ')
| otherwise = subsumption
where
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
@ -405,7 +384,7 @@ check exp typ
subsumption = do
(exp', t) <- infer exp
exp'' <- applyEnvExp exp'
t' <- applyEnv t
t' <- applyEnv t
typ' <- applyEnv typ
subtype t' typ'
pure (exp'', t')
@ -415,19 +394,20 @@ check exp typ
infer :: Exp -> Tc (T.ExpT' Type)
infer = \case
ELit lit -> pure (T.ELit lit, inferLit lit)
ELit lit -> pure (T.ELit lit, litType lit)
-- (x : A) ∈ Γ
-- ------------- Var
-- Γ ⊢ x ↓ A ⊣ Γ
-- (x : A) ∈ Γ (x : A) ∉ Γ
-- ------------- Var --------------- Var'
-- Γ ⊢ x ↓ A ⊣ Γ Γ ⊢ x ↓ ά ⊣ Γ,ά
EVar name -> do
t <- liftA2 (<|>) (lookupEnv name) (lookupSig name) >>= \case
Just t -> pure t
Nothing -> do
e <- maybeToRightM
("Unbound variable " ++ show name)
=<< lookupBind name
snd <$> infer e
tevar <- fresh
insertEnv (EnvTEVar tevar)
let t = TEVar tevar
insertEnv (EnvVar name t)
pure t
pure (T.EVar (coerce name), t)
EInj name -> do
@ -480,28 +460,25 @@ infer = \case
putEnv env_l
pure (T.ELet (T.Bind (coerce name, t_rhs) [] (rhs', t_rhs)) (e',t), t)
-- Γ ⊢ e₁ ↑ Int Γ ⊢ e₁ ↑ Int
-- Γ ⊢ e₁ ↑ Int ⊣ Θ Θ ⊢ e₂ ↑ Int
-- --------------------------- +I
-- Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ
EAdd e1 e2 -> do
cxt <- get
let t = int
e1' <- check e1 t
put cxt
e2' <- check e2 t
pure (T.EAdd e1' e2', t)
e1' <- check e1 int
e2' <- check e2 int
e1'' <- applyEnvExpT e1'
e2'' <- applyEnvExpT e2'
pure (T.EAdd e1'' e2'', int)
-- Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ
-- Θ ⊢ Π ∷ A ↓ C ⊣ Δ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- ---------------------------------------
-- Γ ⊢ case e of Π ↓ C ⊣ Δ
ECase scrut branches -> do
(scrut', t_scrut) <- infer scrut
t_scrut' <- applyEnv t_scrut
(branches', ts) <- mapAndUnzipM (`inferBranch` t_scrut') branches
unless (allSame ts) $ throwError "Branches have different return types"
pure (T.ECase (scrut', t_scrut') branches', head ts)
(branches', t_return) <- inferBranches branches t_scrut
pure (T.ECase (scrut', t_scrut) branches', t_return)
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
@ -547,45 +524,71 @@ apply typ exp = case typ of
-- * Pattern matching
---------------------------------------------------------------------------
-- | Γ ⊢ p ⇒ e ∷ A ↓ C
-- Under context Γ, check pattern in branch p ⇒ e of type A and infer bodies of type C
-- Γ ⊢ p ⇒ e ∷ A ↓ B ⊣ Θ
-- Θ ⊢ Π ∷ [Θ]A ↓ C ⊣ Δ
-- [Δ]B <: C
-- ---------------------------
-- Γ ⊢ (p ⇒ e),Π ∷ A ↓ C ⊣ Δ
inferBranches :: [Branch] -> Type -> Tc ([T.Branch' Type], Type)
inferBranches branches t_patt = do
(branches', ts_exp) <- inferBranches' t_patt branches
ts_exp' <- mapM applyEnv ts_exp
let (monos, pols) = partition isMono ts_exp'
t_exp <- liftEither $ bodyType t_patt monos
mapM_ (subtype t_exp) pols
pure (branches', t_exp)
where
bodyType :: Type -> [Type] -> Err Type
bodyType t_patt = \case
[] -> pure t_patt
[m] -> pure m
m:n:ms | m == n -> bodyType t_patt (n:ms)
| otherwise -> throwError $ unwords [ "Wrong return types: "
, ppT m, "", ppT n ]
inferBranches' = go [] []
where
go branches ts_exp t = \case
[] -> pure (branches, ts_exp)
b:bs -> do
(b', t_e) <- inferBranch b t
t' <- applyEnv t
go (snoc b' branches) (snoc t_e ts_exp) t' bs
-- Γ ⊢ p ↑ A ⊣ Θ Θ ⊢ e ↓ C ⊣ Δ
-- -------------------------------
-- Γ ⊢ p ⇒ e ∷ A ↓ C ⊣ Δ
inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type)
inferBranch (Branch patt exp) t_patt = do
env_marker <- EnvMark <$> fresh
insertEnv env_marker
patt' <- checkPattern patt t_patt
(exp', t_exp) <- infer exp
(env_l, _) <- gets (splitOn env_marker . env)
putEnv env_l
pure (T.Branch patt' (exp', t_exp), t_exp)
-- | Γ ⊢ p ⇒ e ∷ A ↑ C
-- Under context Γ, check branch p ⇒ e of type A and bodies of type C
checkBranch :: Branch -> Type -> Type -> Tc (T.Branch' Type)
checkBranch (Branch patt exp) t_patt t_exp = do
env_marker <- EnvMark <$> fresh
insertEnv env_marker
patt' <- checkPattern patt t_patt
t_exp' <- applyEnv t_exp
(exp, t_exp) <- check exp t_exp'
(env_l, _) <- gets (splitOn env_marker . env)
putEnv env_l
pure (T.Branch patt' (exp, t_exp))
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of
-- -------------------
-- Γ ⊢ x ↑ A ⊣ Γ,(x:A)
PVar x -> do
insertEnv $ EnvVar x t_patt
pure (T.PVar (coerce x, t_patt), t_patt)
-- -------------
-- Γ ⊢ _ ↑ A ⊣ Γ
PCatch -> pure (T.PCatch, t_patt)
-- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ
-- ------------------------------
-- Γ ⊢ τ ↑ B ⊣ Δ
PLit lit -> do
subtype (inferLit lit) t_patt
subtype (litType lit) t_patt
t_patt' <- applyEnv t_patt
pure (T.PLit (lit, t_patt), t_patt')
-- (x : A) ∈ Γ Γ ⊢ A <: B ⊣ Δ
-- ---------------------------
-- Γ ⊢ inj₀ x ↑ B ⊣ Δ
PEnum name -> do
t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name
@ -599,13 +602,14 @@ checkPattern patt t_patt = case patt of
t_inj' <- foldrM substitute' t_inj $ getInitForalls t_inj
subtype (getDataId t_inj') t_patt
t_inj'' <- applyEnv t_inj'
t_patt' <- applyEnv t_patt
let ts_inj = getParams t_inj''
ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps ts_inj
t_patt' <- applyEnv t_patt
pure (T.PInj (coerce name) (map fst ps'), t_patt')
where
substitute' fa t = do
tevar <- fresh
-- insertEnv (EnvTEVar tevar)
pure $ substitute tvar tevar t
where
TAll tvar _ = fa dummy
@ -666,6 +670,9 @@ splitOn x env = second (S.drop 1) $ S.breakl (==x) env
dropTrailing :: EnvElem -> Tc ()
dropTrailing x = modifyEnv $ S.takeWhileL (/= x)
applyEnvExpT :: (T.Exp' Type, Type) -> Tc (T.Exp' Type, Type)
applyEnvExpT (e, t) = liftA2 (,) (applyEnvExp e) (applyEnv t)
applyEnvExp :: T.Exp' Type -> Tc (T.Exp' Type)
applyEnvExp exp = case exp of
T.ELet (T.Bind id vars rhs) exp -> do
@ -681,7 +688,6 @@ applyEnvExp exp = case exp of
(mapM applyEnvBranch branches)
_ -> pure exp
where
applyEnvExpT (e, t) = liftA2 (,) (applyEnvExp e) (applyEnv t)
applyEnvId = secondM applyEnv
applyEnvBranch (T.Branch (p, t) e) = do
pt <- liftA2 (,) (applyEnvPattern p) (applyEnv t)
@ -752,20 +758,24 @@ wellFormed env = \case
TData _ typs -> mapM_ (wellFormed env) typs
noForall :: Type -> Bool
noForall = \case
TAll{} -> False
TFun t1 t2 -> on (&&) noForall t1 t2
TData _ typs -> all noForall typs
TVar _ -> True
TEVar _ -> True
TLit _ -> True
isMono :: Type -> Bool
isMono = \case
TAll{} -> False
TFun t1 t2 -> on (&&) isMono t1 t2
TData _ typs -> all isMono typs
TVar _ -> True
TEVar _ -> True
TVar _ -> False
TEVar _ -> False
TLit _ -> True
inferLit :: Lit -> Type
inferLit = \case
LInt _ -> TLit "Int"
LChar _ -> TLit "Char"
fresh :: Tc TEVar
fresh = do
tevar <- gets (MkTEVar . LIdent . ("a#" ++) . show . next_tevar)
@ -803,60 +813,6 @@ skipForalls = go []
TAll tvar t -> go (snoc (TAll tvar) acc) t
_ -> (acc, typ)
getForallsData :: Type -> [Type -> Type]
getForallsData = fst . partitionData
getTData :: Type -> Type
getTData = snd . partitionData
partitionData :: Type -> ([Type -> Type], Type)
partitionData = go . ([],)
where
go (acc, typ) = case typ of
TAll tvar t -> go (snoc (TAll tvar) acc, t)
TData {} -> (acc, typ)
TFun _ t -> go (acc, t)
_ -> error "Bad data type"
partitionTypeWithForall :: Type -> ([Type], Type)
partitionTypeWithForall typ = (t_vars', t_return')
where
t_vars' = map (\t -> foldr applyForall t foralls) t_vars
t_return' = foldr applyForall t_return foralls
applyForall fa t | usesTVar tvar t = fa t
| otherwise = t
where TAll tvar _ = fa t
(t_vars, t_return) = go [] typ'
(foralls, typ') = skipForalls typ
go acc t = case t of
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> (acc, t)
usesTVar :: TVar -> Type -> Bool
usesTVar tvar = \case
TLit _ -> False
TVar tvar' | tvar' == tvar -> True
| otherwise -> False
TFun t1 t2 -> on (||) usesTVar' t1 t2
TAll tvar' t | tvar' == tvar -> error "Redeclaration of TVar"
| otherwise -> usesTVar' t
TData _ typs -> any usesTVar' typs
_ -> error "Impossible"
where
usesTVar' = usesTVar tvar
skipLambdas :: Int -> T.Exp' Type -> T.Exp' Type
skipLambdas i exp
| i == 0 = exp
| T.EAbs _ (e, _) <- exp = skipLambdas (i-1) e
| otherwise = error "Number of expected lambdas doesn't match expression"
isComplete :: Env -> Bool
isComplete = isNothing . S.findIndexL unSolvedTEVar
where
@ -872,9 +828,6 @@ toTVar = \case
insertEnv :: EnvElem -> Tc ()
insertEnv x = modifyEnv (:|> x)
lookupBind :: LIdent -> Tc (Maybe Exp)
lookupBind x = gets (Map.lookup x . binds)
lookupSig :: LIdent -> Tc (Maybe Type)
lookupSig x = gets (Map.lookup x . sig)

View file

@ -32,6 +32,8 @@ testTypeCheckerBidir = describe "Bidirectional type checker test" $ do
tc_mono_case
tc_pol_case
tc_infer_case
tc_rec1
tc_rec2
tc_id =
specify "Basic identity function polymorphism" $
@ -295,6 +297,21 @@ tc_infer_case = describe "Infer case expression" $ do
, "};"
]
tc_rec1 = specify "Infer simple recursive definition" $
run ["test x = 1 + test (x + 1);"] `shouldSatisfy` ok
tc_rec2 = specify "Infer recursive definition with pattern matching" $ run
[ "data Bool () where {"
, " False : Bool ()"
, " True : Bool ()"
, "};"
, "test = \\x. case x of {"
, " 10 => True;"
, " _ => test (x+1);"
, "};"
] `shouldSatisfy` ok
run :: [String] -> Err T.Program
run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines