Progress on new checkPattern

This commit is contained in:
Martin Fredin 2023-03-27 23:55:04 +02:00
parent f20b80cab3
commit 528369c95c

View file

@ -10,11 +10,12 @@ import Auxiliary (maybeToRightM, snoc)
import Control.Applicative (Alternative, Applicative (liftA2),
(<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError),
runExceptT, unless, zipWithM,
zipWithM_)
liftEither, runExceptT, unless,
zipWithM, zipWithM_)
import Control.Monad.State (MonadState (get, put), State,
evalState, gets, modify)
import Data.Coerce (coerce)
import Data.Either.Combinators (maybeToRight)
import Data.Function (on)
import Data.List (intercalate)
import Data.Map (Map)
@ -45,11 +46,12 @@ type Env = Seq EnvElem
-- | Ordered context
-- Γ ::= ・| Γ, α | Γ, ά | Γ, ▶ ά | Γ, x:A
data Cxt = Cxt
{ env :: Env -- ^ Local scope context Γ
, sig :: Map LIdent Type -- ^ Top-level signatures x : A
, binds :: Map LIdent Exp -- ^ Top-level binds x : e
, next_tevar :: Int -- ^ Counter to distinguish ά
, data_injs :: Map UIdent Type -- ^ Data injections (constructors) K
{ env :: Env -- ^ Local scope context Γ
, sig :: Map LIdent Type -- ^ Top-level signatures x : A
, binds :: Map LIdent Exp -- ^ Top-level binds x : e
, next_tevar :: Int -- ^ Counter to distinguish ά
, data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A
, data_types :: Map UIdent [(UIdent, Type)] -- ^ Data types D : (K₁:A₁ + ‥ + Kₙ:Aₙ)
} deriving (Show, Eq)
newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a }
@ -69,16 +71,24 @@ typecheck (Program defs) = do
| DBind' name vars rhs <- defs
]
, next_tevar = 0
, data_injs = Map.fromList [ (name, typ)
| Data _ injs <- datatypes
, Inj name typ <- injs
]
, data_injs = Map.fromList [ (name, foldr ($) typ $ getForallsData typ)
| Data _ injs <- datatypes
, Inj name typ <- injs
]
, data_types = Map.fromList [ let
TData name _ = getTData typ
kts = [(k,t) | Inj k t <- injs ]
in
(name, kts)
| Data typ injs <- datatypes
]
}
binds' <- evalState (runExceptT (runTc $ mapM typecheckBind binds)) initCxt;
pure . T.Program $ map T.DData (coerceData datatypes) ++ map T.DBind binds'
where
binds = [ b | DBind b <- defs ]
-- TODO this should happen in typecheckDataType
coerceData = map (\(Data t injs) -> T.Data t $ map
(\(Inj name typ) -> T.Inj (coerce name) typ) injs)
@ -138,8 +148,8 @@ typecheckInj (Inj inj_name inj_typ) name tvars
| TData name' typs <- getReturn inj_typ
, name' == name
, Right tvars' <- mapM toTVar typs
, tvars' == tvars
= pure (Inj inj_name $ foldr TAll inj_typ tvars)
, all (`elem` tvars) tvars'
= pure (Inj inj_name inj_typ)
| otherwise
= throwError $ unwords
["Bad type constructor: ", show name
@ -216,6 +226,7 @@ subtype t1 t2 = case (t1, t2) of
-- Γ[ά] ⊢ A <: ά ⊣ Δ
(typ, TEVar tevar) | notElem tevar $ frees typ -> instantiateR typ tevar
(TData name1 typs1, TData name2 typs2)
-- D₁ = D₂
@ -542,33 +553,82 @@ checkBranch (Branch patt exp) t_patt t_exp = do
pure (T.Branch patt' (exp, t_exp))
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of
checkPattern patt t_patt = (, t_patt) <$> case patt of
PVar x -> do
insertEnv $ EnvVar x t_patt
pure (T.PVar (coerce x, dummy), dummy) -- TODO
PCatch -> pure (T.PCatch, dummy) -- TODO
PLit lit | inferLit lit == t_patt -> let
t = inferLit lit
in
pure (T.PLit (lit, t), t)
pure $ T.PVar (coerce x, t_patt)
PCatch -> pure T.PCatch
PLit lit | inferLit lit == t_patt -> pure $ T.PLit (lit, t_patt)
| otherwise -> throwError "Literal in pattern have wrong type"
PEnum name -> do
t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name
subtype t t_patt
pure (T.PEnum (coerce name), dummy) -- TODO
pure $ T.PEnum (coerce name)
-- Θ₁ ⊢ p₁ ↑ [Θ₁]B₁ ⊣ Θ₂
-- Γ ⊢ (xₖ : B₁ → ‥ → Bₘ₊₁) ∈ Γ ...
-- Γ ⊢ B₁ → ‥ → Bₘ₊₁ <: A₁ + ‥ + Aₙ ⊣ Θ₁ Θₘ ⊢ pₘ ↑ [Θₘ₋₁]Bₘ ⊣ Δ
-- --------------------------------------------------------------
-- Γ ⊢ injₖ xₖ. p₁ ‥ pₘ ↑ A₁ + ‥ + Aₙ ⊣ Δ
PInj name ps -> do
t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name
subtype t t_patt
let (t_ps, t_return) = partitionTypeWithForall t
unless (length ps == length t_ps) $
throwError "Wrong number of variables"
subtype t_return t_patt
ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps t_ps
let ps'' = map fst ps' -- TODO
pure (T.PInj (coerce name) ps'', dummy)
-- §ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps t_ps
-- let ps'' = map fst ps' -- TODO
pure $ T.PInj (coerce name) []
subtypeData :: UIdent -> Type -> Tc ()
subtypeData name_inj typ = do
injs <- maybeToRightM err1 =<< lookupDataType name_d
t_inj <- liftEither . maybeToRight err2 $ lookup name_k injs
(t_inj', typs')<- substituteTVars foralls t_inj data_t
subtype ()
undefined
where
substituteTVars fas t1 t2 = case fas of
[] -> pure (t1, t2)
fa:fas' -> do
(t1', t2') <- go fa (t1, t2)
substituteTVars fas' t1' t2'
where
go fa (t1, t2) = let TAll tvar _ = fa dummy in do
tevar <- fresh
insertEnv (EnvTEVar tevar)
pure $ on (,) (substitute tvar tevar) t1 t2
(foralls, data_t@(TData name_d typs)) = partitionData typ
err1 = unwords ["Unknown data type", show name_d]
err2 = unwords ["No", show name_k, "constructor for data type", show name_d]
-- TAll tvar t -> do
-- tevar <- fresh
-- let -- env_marker = EnvMark tevar
-- env_tevar = EnvTEVar tevar
-- -- insertEnv env_marker
-- insertEnv env_tevar
-- let a' = substitute tvar tevar a
-- subtype a' b
-- -- dropTrailing env_marker
-- TData name_d typs -> do
--
-- subtype t_k typ
-- undefined
-- where
---------------------------------------------------------------------------
-- * Auxiliary
@ -725,6 +785,7 @@ getReturn = snd . partitionType
-- ([a, ∀c. c → c], b)
--
-- Unsure if foralls should be added to the return type or not.
-- FIXME
partitionType :: Type -> ([Type], Type)
partitionType = go [] . skipForalls'
where
@ -743,6 +804,22 @@ skipForalls = go []
TAll tvar t -> go (snoc (TAll tvar) acc) t
_ -> (acc, typ)
getForallsData :: Type -> [Type -> Type]
getForallsData = fst . partitionData
getTData :: Type -> Type
getTData = snd . partitionData
partitionData :: Type -> ([Type -> Type], Type)
partitionData = go . ([],)
where
go (acc, typ) = case typ of
TAll tvar t -> go (snoc (TAll tvar) acc, t)
TData {} -> (acc, typ)
_ -> error "Bad data type"
partitionTypeWithForall :: Type -> ([Type], Type)
partitionTypeWithForall typ = (t_vars', t_return')
where
@ -798,6 +875,9 @@ insertEnv x = modifyEnv (:|> x)
lookupBind :: LIdent -> Tc (Maybe Exp)
lookupBind x = gets (Map.lookup x . binds)
lookupDataType :: UIdent -> Tc (Maybe [(UIdent, Type)])
lookupDataType x = gets (Map.lookup x . data_types)
lookupSig :: LIdent -> Tc (Maybe Type)
lookupSig x = gets (Map.lookup x . sig)