Use use tevars for bind without type signatures, fix recursive functions

This commit is contained in:
Martin Fredin 2023-03-30 18:46:37 +02:00
parent 4831205e67
commit 72352d9619
2 changed files with 107 additions and 137 deletions

View file

@ -6,18 +6,18 @@
module TypeChecker.TypeCheckerBidir (typecheck, getVars) where module TypeChecker.TypeCheckerBidir (typecheck, getVars) where
import Auxiliary (char, int, maybeToRightM, snoc) import Auxiliary (int, litType, maybeToRightM, snoc)
import Control.Applicative (Alternative, Applicative (liftA2), import Control.Applicative (Alternative, Applicative (liftA2),
(<|>)) (<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError), import Control.Monad.Except (ExceptT, MonadError (throwError),
mapAndUnzipM, runExceptT, unless, liftEither, runExceptT, unless,
zipWithM, zipWithM_) zipWithM, zipWithM_)
import Control.Monad.State (MonadState (get, put), State, import Control.Monad.State (MonadState (get, put), State,
evalState, gets, modify) evalState, gets, modify)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Foldable (foldrM) import Data.Foldable (foldrM)
import Data.Function (on) import Data.Function (on)
import Data.List (intercalate) import Data.List (intercalate, partition)
import Data.List.Extra (allSame) import Data.List.Extra (allSame)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
@ -39,6 +39,7 @@ import qualified TypeChecker.TypeCheckerIr as T
-- • Fix problems with types in Pattern/Branch in TypeCheckerIr -- • Fix problems with types in Pattern/Branch in TypeCheckerIr
-- • Use applyEnvExp consistently -- • Use applyEnvExp consistently
-- • Fix the different type getters functions (e.g. partitionType) functions -- • Fix the different type getters functions (e.g. partitionType) functions
-- • Handle recursive functions. Maybe use a isRec : Bool variable.
data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A
| EnvTVar TVar -- ^ Universal type variable. α | EnvTVar TVar -- ^ Universal type variable. α
@ -94,18 +95,9 @@ typecheck (Program defs) = do
typecheckBind :: Bind -> Tc (T.Bind' Type) typecheckBind :: Bind -> Tc (T.Bind' Type)
typecheckBind (Bind name vars rhs) = do typecheckBind (Bind name vars rhs) = do
bind'@(T.Bind (name, typ) _ _) <- lookupSig name >>= \case bind'@(T.Bind (name, typ) _ _) <- lookupSig name >>= \case
-- TODO These Judgment aren't accurate
-- (f:A → B) ∈ Γ
-- Γ,(xs:A) ⊢ e ↑ Β ⊣ Δ
---------------------------
-- Γ ⊢ f xs = e ↓ Α → B ⊣ Δ
Just t -> do Just t -> do
(rhs', _) <- check (foldr EAbs rhs vars) t (rhs', _) <- check (foldr EAbs rhs vars) t
pure (T.Bind (coerce name, t) [] (rhs', t)) pure (T.Bind (coerce name, t) [] (rhs', t))
-- Γ ⊢ (λxs. e) ↓ A → B ⊣ Δ
-- ------------------------------
-- Γ ⊢ f xs = e ↓ [Γ]A → [Γ]B ⊣ Δ
Nothing -> do Nothing -> do
(e, t) <- infer $ foldr EAbs rhs vars (e, t) <- infer $ foldr EAbs rhs vars
t' <- applyEnv t t' <- applyEnv t
@ -113,7 +105,7 @@ typecheckBind (Bind name vars rhs) = do
pure (T.Bind (coerce name, t') [] (e', t')) pure (T.Bind (coerce name, t') [] (e', t'))
env <- gets env env <- gets env
unless (isComplete env) err unless (isComplete env) err
insertSig (coerce name) typ insertSig (coerce name) typ -- HERE
putEnv Empty putEnv Empty
pure bind' pure bind'
where where
@ -265,7 +257,7 @@ instantiateL tevar typ = gets env >>= go
-- Γ ⊢ τ -- Γ ⊢ τ
-- ----------------------------- InstLSolve -- ----------------------------- InstLSolve
-- Γ,ά,Γ' ⊢ ά :=< τ ⊣ Γ,(ά=τ),Γ' -- Γ,ά,Γ' ⊢ ά :=< τ ⊣ Γ,(ά=τ),Γ'
| isMono typ | noForall typ
, (env_l, env_r) <- splitOn (EnvTEVar tevar) env , (env_l, env_r) <- splitOn (EnvTEVar tevar) env
, Right _ <- wellFormed env_l typ , Right _ <- wellFormed env_l typ
= putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r = putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r
@ -305,7 +297,7 @@ instantiateR typ tevar = gets env >>= go
-- Γ ⊢ τ -- Γ ⊢ τ
-- ----------------------------- InstRSolve -- ----------------------------- InstRSolve
-- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ' -- Γ,ά,Γ' ⊢ τ =:< ά ⊣ Γ,(ά=τ),Γ'
| isMono typ | noForall typ
, (env_l, env_r) <- splitOn (EnvTEVar tevar) env , (env_l, env_r) <- splitOn (EnvTEVar tevar) env
, Right _ <- wellFormed env_l typ , Right _ <- wellFormed env_l typ
= putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r = putEnv $ (env_l :|> EnvTEVarSolved tevar typ) <> env_r
@ -337,7 +329,6 @@ instantiateR typ tevar = gets env >>= go
let (env_l, _) = splitOn (EnvTVar tvar) env let (env_l, _) = splitOn (EnvTVar tvar) env
putEnv env_l putEnv env_l
| otherwise = error $ "Trying to instantiateR: " ++ ppT typ ++ " <: " | otherwise = error $ "Trying to instantiateR: " ++ ppT typ ++ " <: "
++ ppT (TEVar tevar) ++ ppT (TEVar tevar)
@ -385,18 +376,6 @@ check exp typ
putEnv env_l putEnv env_l
pure (T.EAbs (coerce name) e', typ) pure (T.EAbs (coerce name) e', typ)
-- Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- ---------------------------------------
-- Γ ⊢ case e of Π ↑ C ⊣ Δ
-- TODO maybe remove only use infer rule
| ECase scrut branches <- exp = do
(scrut', t_scrut) <- infer scrut
t_scrut' <- applyEnv t_scrut
typ' <- applyEnv typ
branches' <- mapM (\b -> checkBranch b t_scrut' typ') branches
pure (T.ECase (scrut', t_scrut') branches', typ')
| otherwise = subsumption | otherwise = subsumption
where where
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ -- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
@ -415,19 +394,20 @@ check exp typ
infer :: Exp -> Tc (T.ExpT' Type) infer :: Exp -> Tc (T.ExpT' Type)
infer = \case infer = \case
ELit lit -> pure (T.ELit lit, inferLit lit) ELit lit -> pure (T.ELit lit, litType lit)
-- (x : A) ∈ Γ -- (x : A) ∈ Γ (x : A) ∉ Γ
-- ------------- Var -- ------------- Var --------------- Var'
-- Γ ⊢ x ↓ A ⊣ Γ -- Γ ⊢ x ↓ A ⊣ Γ Γ ⊢ x ↓ ά ⊣ Γ,ά
EVar name -> do EVar name -> do
t <- liftA2 (<|>) (lookupEnv name) (lookupSig name) >>= \case t <- liftA2 (<|>) (lookupEnv name) (lookupSig name) >>= \case
Just t -> pure t Just t -> pure t
Nothing -> do Nothing -> do
e <- maybeToRightM tevar <- fresh
("Unbound variable " ++ show name) insertEnv (EnvTEVar tevar)
=<< lookupBind name let t = TEVar tevar
snd <$> infer e insertEnv (EnvVar name t)
pure t
pure (T.EVar (coerce name), t) pure (T.EVar (coerce name), t)
EInj name -> do EInj name -> do
@ -480,28 +460,25 @@ infer = \case
putEnv env_l putEnv env_l
pure (T.ELet (T.Bind (coerce name, t_rhs) [] (rhs', t_rhs)) (e',t), t) pure (T.ELet (T.Bind (coerce name, t_rhs) [] (rhs', t_rhs)) (e',t), t)
-- Γ ⊢ e₁ ↑ Int Γ ⊢ e₁ ↑ Int -- Γ ⊢ e₁ ↑ Int ⊣ Θ Θ ⊢ e₂ ↑ Int
-- --------------------------- +I -- --------------------------- +I
-- Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ -- Γ ⊢ e₁ + e₂ ↓ Int ⊣ Δ
EAdd e1 e2 -> do EAdd e1 e2 -> do
cxt <- get e1' <- check e1 int
let t = int e2' <- check e2 int
e1' <- check e1 t e1'' <- applyEnvExpT e1'
put cxt e2'' <- applyEnvExpT e2'
e2' <- check e2 t pure (T.EAdd e1'' e2'', int)
pure (T.EAdd e1' e2', t)
-- Θ ⊢ Π ∷ [Θ]A ↑ [Θ]C ⊣ Δ -- Θ ⊢ Π ∷ A ↓ C ⊣ Δ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO -- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- --------------------------------------- -- ---------------------------------------
-- Γ ⊢ case e of Π ↓ C ⊣ Δ -- Γ ⊢ case e of Π ↓ C ⊣ Δ
ECase scrut branches -> do ECase scrut branches -> do
(scrut', t_scrut) <- infer scrut (scrut', t_scrut) <- infer scrut
t_scrut' <- applyEnv t_scrut (branches', t_return) <- inferBranches branches t_scrut
(branches', ts) <- mapAndUnzipM (`inferBranch` t_scrut') branches pure (T.ECase (scrut', t_scrut) branches', t_return)
unless (allSame ts) $ throwError "Branches have different return types"
pure (T.ECase (scrut', t_scrut') branches', head ts)
-- | Γ ⊢ A • e ⇓ C ⊣ Δ -- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ -- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
@ -547,45 +524,71 @@ apply typ exp = case typ of
-- * Pattern matching -- * Pattern matching
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
-- | Γ ⊢ p ⇒ e ∷ A ↓ C -- Γ ⊢ p ⇒ e ∷ A ↓ B ⊣ Θ
-- Under context Γ, check pattern in branch p ⇒ e of type A and infer bodies of type C -- Θ ⊢ Π ∷ [Θ]A ↓ C ⊣ Δ
-- [Δ]B <: C
-- ---------------------------
-- Γ ⊢ (p ⇒ e),Π ∷ A ↓ C ⊣ Δ
inferBranches :: [Branch] -> Type -> Tc ([T.Branch' Type], Type)
inferBranches branches t_patt = do
(branches', ts_exp) <- inferBranches' t_patt branches
ts_exp' <- mapM applyEnv ts_exp
let (monos, pols) = partition isMono ts_exp'
t_exp <- liftEither $ bodyType t_patt monos
mapM_ (subtype t_exp) pols
pure (branches', t_exp)
where
bodyType :: Type -> [Type] -> Err Type
bodyType t_patt = \case
[] -> pure t_patt
[m] -> pure m
m:n:ms | m == n -> bodyType t_patt (n:ms)
| otherwise -> throwError $ unwords [ "Wrong return types: "
, ppT m, "", ppT n ]
inferBranches' = go [] []
where
go branches ts_exp t = \case
[] -> pure (branches, ts_exp)
b:bs -> do
(b', t_e) <- inferBranch b t
t' <- applyEnv t
go (snoc b' branches) (snoc t_e ts_exp) t' bs
-- Γ ⊢ p ↑ A ⊣ Θ Θ ⊢ e ↓ C ⊣ Δ
-- -------------------------------
-- Γ ⊢ p ⇒ e ∷ A ↓ C ⊣ Δ
inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type) inferBranch :: Branch -> Type -> Tc (T.Branch' Type, Type)
inferBranch (Branch patt exp) t_patt = do inferBranch (Branch patt exp) t_patt = do
env_marker <- EnvMark <$> fresh
insertEnv env_marker
patt' <- checkPattern patt t_patt patt' <- checkPattern patt t_patt
(exp', t_exp) <- infer exp (exp', t_exp) <- infer exp
(env_l, _) <- gets (splitOn env_marker . env)
putEnv env_l
pure (T.Branch patt' (exp', t_exp), t_exp) pure (T.Branch patt' (exp', t_exp), t_exp)
-- | Γ ⊢ p ⇒ e ∷ A ↑ C
-- Under context Γ, check branch p ⇒ e of type A and bodies of type C
checkBranch :: Branch -> Type -> Type -> Tc (T.Branch' Type)
checkBranch (Branch patt exp) t_patt t_exp = do
env_marker <- EnvMark <$> fresh
insertEnv env_marker
patt' <- checkPattern patt t_patt
t_exp' <- applyEnv t_exp
(exp, t_exp) <- check exp t_exp'
(env_l, _) <- gets (splitOn env_marker . env)
putEnv env_l
pure (T.Branch patt' (exp, t_exp))
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of checkPattern patt t_patt = case patt of
-- -------------------
-- Γ ⊢ x ↑ A ⊣ Γ,(x:A)
PVar x -> do PVar x -> do
insertEnv $ EnvVar x t_patt insertEnv $ EnvVar x t_patt
pure (T.PVar (coerce x, t_patt), t_patt) pure (T.PVar (coerce x, t_patt), t_patt)
-- -------------
-- Γ ⊢ _ ↑ A ⊣ Γ
PCatch -> pure (T.PCatch, t_patt) PCatch -> pure (T.PCatch, t_patt)
-- Γ ⊢ τ ↓ A ⊣ Γ Γ ⊢ A <: B ⊣ Δ
-- ------------------------------
-- Γ ⊢ τ ↑ B ⊣ Δ
PLit lit -> do PLit lit -> do
subtype (inferLit lit) t_patt subtype (litType lit) t_patt
t_patt' <- applyEnv t_patt t_patt' <- applyEnv t_patt
pure (T.PLit (lit, t_patt), t_patt') pure (T.PLit (lit, t_patt), t_patt')
-- (x : A) ∈ Γ Γ ⊢ A <: B ⊣ Δ
-- ---------------------------
-- Γ ⊢ inj₀ x ↑ B ⊣ Δ
PEnum name -> do PEnum name -> do
t <- maybeToRightM ("Unknown constructor " ++ show name) t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name =<< lookupInj name
@ -599,13 +602,14 @@ checkPattern patt t_patt = case patt of
t_inj' <- foldrM substitute' t_inj $ getInitForalls t_inj t_inj' <- foldrM substitute' t_inj $ getInitForalls t_inj
subtype (getDataId t_inj') t_patt subtype (getDataId t_inj') t_patt
t_inj'' <- applyEnv t_inj' t_inj'' <- applyEnv t_inj'
t_patt' <- applyEnv t_patt
let ts_inj = getParams t_inj'' let ts_inj = getParams t_inj''
ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps ts_inj ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps ts_inj
t_patt' <- applyEnv t_patt
pure (T.PInj (coerce name) (map fst ps'), t_patt') pure (T.PInj (coerce name) (map fst ps'), t_patt')
where where
substitute' fa t = do substitute' fa t = do
tevar <- fresh tevar <- fresh
-- insertEnv (EnvTEVar tevar)
pure $ substitute tvar tevar t pure $ substitute tvar tevar t
where where
TAll tvar _ = fa dummy TAll tvar _ = fa dummy
@ -666,6 +670,9 @@ splitOn x env = second (S.drop 1) $ S.breakl (==x) env
dropTrailing :: EnvElem -> Tc () dropTrailing :: EnvElem -> Tc ()
dropTrailing x = modifyEnv $ S.takeWhileL (/= x) dropTrailing x = modifyEnv $ S.takeWhileL (/= x)
applyEnvExpT :: (T.Exp' Type, Type) -> Tc (T.Exp' Type, Type)
applyEnvExpT (e, t) = liftA2 (,) (applyEnvExp e) (applyEnv t)
applyEnvExp :: T.Exp' Type -> Tc (T.Exp' Type) applyEnvExp :: T.Exp' Type -> Tc (T.Exp' Type)
applyEnvExp exp = case exp of applyEnvExp exp = case exp of
T.ELet (T.Bind id vars rhs) exp -> do T.ELet (T.Bind id vars rhs) exp -> do
@ -681,7 +688,6 @@ applyEnvExp exp = case exp of
(mapM applyEnvBranch branches) (mapM applyEnvBranch branches)
_ -> pure exp _ -> pure exp
where where
applyEnvExpT (e, t) = liftA2 (,) (applyEnvExp e) (applyEnv t)
applyEnvId = secondM applyEnv applyEnvId = secondM applyEnv
applyEnvBranch (T.Branch (p, t) e) = do applyEnvBranch (T.Branch (p, t) e) = do
pt <- liftA2 (,) (applyEnvPattern p) (applyEnv t) pt <- liftA2 (,) (applyEnvPattern p) (applyEnv t)
@ -752,20 +758,24 @@ wellFormed env = \case
TData _ typs -> mapM_ (wellFormed env) typs TData _ typs -> mapM_ (wellFormed env) typs
noForall :: Type -> Bool
noForall = \case
TAll{} -> False
TFun t1 t2 -> on (&&) noForall t1 t2
TData _ typs -> all noForall typs
TVar _ -> True
TEVar _ -> True
TLit _ -> True
isMono :: Type -> Bool isMono :: Type -> Bool
isMono = \case isMono = \case
TAll{} -> False TAll{} -> False
TFun t1 t2 -> on (&&) isMono t1 t2 TFun t1 t2 -> on (&&) isMono t1 t2
TData _ typs -> all isMono typs TData _ typs -> all isMono typs
TVar _ -> True TVar _ -> False
TEVar _ -> True TEVar _ -> False
TLit _ -> True TLit _ -> True
inferLit :: Lit -> Type
inferLit = \case
LInt _ -> TLit "Int"
LChar _ -> TLit "Char"
fresh :: Tc TEVar fresh :: Tc TEVar
fresh = do fresh = do
tevar <- gets (MkTEVar . LIdent . ("a#" ++) . show . next_tevar) tevar <- gets (MkTEVar . LIdent . ("a#" ++) . show . next_tevar)
@ -803,60 +813,6 @@ skipForalls = go []
TAll tvar t -> go (snoc (TAll tvar) acc) t TAll tvar t -> go (snoc (TAll tvar) acc) t
_ -> (acc, typ) _ -> (acc, typ)
getForallsData :: Type -> [Type -> Type]
getForallsData = fst . partitionData
getTData :: Type -> Type
getTData = snd . partitionData
partitionData :: Type -> ([Type -> Type], Type)
partitionData = go . ([],)
where
go (acc, typ) = case typ of
TAll tvar t -> go (snoc (TAll tvar) acc, t)
TData {} -> (acc, typ)
TFun _ t -> go (acc, t)
_ -> error "Bad data type"
partitionTypeWithForall :: Type -> ([Type], Type)
partitionTypeWithForall typ = (t_vars', t_return')
where
t_vars' = map (\t -> foldr applyForall t foralls) t_vars
t_return' = foldr applyForall t_return foralls
applyForall fa t | usesTVar tvar t = fa t
| otherwise = t
where TAll tvar _ = fa t
(t_vars, t_return) = go [] typ'
(foralls, typ') = skipForalls typ
go acc t = case t of
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> (acc, t)
usesTVar :: TVar -> Type -> Bool
usesTVar tvar = \case
TLit _ -> False
TVar tvar' | tvar' == tvar -> True
| otherwise -> False
TFun t1 t2 -> on (||) usesTVar' t1 t2
TAll tvar' t | tvar' == tvar -> error "Redeclaration of TVar"
| otherwise -> usesTVar' t
TData _ typs -> any usesTVar' typs
_ -> error "Impossible"
where
usesTVar' = usesTVar tvar
skipLambdas :: Int -> T.Exp' Type -> T.Exp' Type
skipLambdas i exp
| i == 0 = exp
| T.EAbs _ (e, _) <- exp = skipLambdas (i-1) e
| otherwise = error "Number of expected lambdas doesn't match expression"
isComplete :: Env -> Bool isComplete :: Env -> Bool
isComplete = isNothing . S.findIndexL unSolvedTEVar isComplete = isNothing . S.findIndexL unSolvedTEVar
where where
@ -872,9 +828,6 @@ toTVar = \case
insertEnv :: EnvElem -> Tc () insertEnv :: EnvElem -> Tc ()
insertEnv x = modifyEnv (:|> x) insertEnv x = modifyEnv (:|> x)
lookupBind :: LIdent -> Tc (Maybe Exp)
lookupBind x = gets (Map.lookup x . binds)
lookupSig :: LIdent -> Tc (Maybe Type) lookupSig :: LIdent -> Tc (Maybe Type)
lookupSig x = gets (Map.lookup x . sig) lookupSig x = gets (Map.lookup x . sig)

View file

@ -32,6 +32,8 @@ testTypeCheckerBidir = describe "Bidirectional type checker test" $ do
tc_mono_case tc_mono_case
tc_pol_case tc_pol_case
tc_infer_case tc_infer_case
tc_rec1
tc_rec2
tc_id = tc_id =
specify "Basic identity function polymorphism" $ specify "Basic identity function polymorphism" $
@ -295,6 +297,21 @@ tc_infer_case = describe "Infer case expression" $ do
, "};" , "};"
] ]
tc_rec1 = specify "Infer simple recursive definition" $
run ["test x = 1 + test (x + 1);"] `shouldSatisfy` ok
tc_rec2 = specify "Infer recursive definition with pattern matching" $ run
[ "data Bool () where {"
, " False : Bool ()"
, " True : Bool ()"
, "};"
, "test = \\x. case x of {"
, " 10 => True;"
, " _ => test (x+1);"
, "};"
] `shouldSatisfy` ok
run :: [String] -> Err T.Program run :: [String] -> Err T.Program
run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines