Progress on new checkPattern

This commit is contained in:
Martin Fredin 2023-03-27 23:55:04 +02:00
parent f20b80cab3
commit 528369c95c

View file

@ -10,11 +10,12 @@ import Auxiliary (maybeToRightM, snoc)
import Control.Applicative (Alternative, Applicative (liftA2), import Control.Applicative (Alternative, Applicative (liftA2),
(<|>)) (<|>))
import Control.Monad.Except (ExceptT, MonadError (throwError), import Control.Monad.Except (ExceptT, MonadError (throwError),
runExceptT, unless, zipWithM, liftEither, runExceptT, unless,
zipWithM_) zipWithM, zipWithM_)
import Control.Monad.State (MonadState (get, put), State, import Control.Monad.State (MonadState (get, put), State,
evalState, gets, modify) evalState, gets, modify)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Either.Combinators (maybeToRight)
import Data.Function (on) import Data.Function (on)
import Data.List (intercalate) import Data.List (intercalate)
import Data.Map (Map) import Data.Map (Map)
@ -49,7 +50,8 @@ data Cxt = Cxt
, sig :: Map LIdent Type -- ^ Top-level signatures x : A , sig :: Map LIdent Type -- ^ Top-level signatures x : A
, binds :: Map LIdent Exp -- ^ Top-level binds x : e , binds :: Map LIdent Exp -- ^ Top-level binds x : e
, next_tevar :: Int -- ^ Counter to distinguish ά , next_tevar :: Int -- ^ Counter to distinguish ά
, data_injs :: Map UIdent Type -- ^ Data injections (constructors) K , data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A
, data_types :: Map UIdent [(UIdent, Type)] -- ^ Data types D : (K₁:A₁ + ‥ + Kₙ:Aₙ)
} deriving (Show, Eq) } deriving (Show, Eq)
newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a } newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a }
@ -69,16 +71,24 @@ typecheck (Program defs) = do
| DBind' name vars rhs <- defs | DBind' name vars rhs <- defs
] ]
, next_tevar = 0 , next_tevar = 0
, data_injs = Map.fromList [ (name, typ) , data_injs = Map.fromList [ (name, foldr ($) typ $ getForallsData typ)
| Data _ injs <- datatypes | Data _ injs <- datatypes
, Inj name typ <- injs , Inj name typ <- injs
] ]
, data_types = Map.fromList [ let
TData name _ = getTData typ
kts = [(k,t) | Inj k t <- injs ]
in
(name, kts)
| Data typ injs <- datatypes
]
} }
binds' <- evalState (runExceptT (runTc $ mapM typecheckBind binds)) initCxt; binds' <- evalState (runExceptT (runTc $ mapM typecheckBind binds)) initCxt;
pure . T.Program $ map T.DData (coerceData datatypes) ++ map T.DBind binds' pure . T.Program $ map T.DData (coerceData datatypes) ++ map T.DBind binds'
where where
binds = [ b | DBind b <- defs ] binds = [ b | DBind b <- defs ]
-- TODO this should happen in typecheckDataType
coerceData = map (\(Data t injs) -> T.Data t $ map coerceData = map (\(Data t injs) -> T.Data t $ map
(\(Inj name typ) -> T.Inj (coerce name) typ) injs) (\(Inj name typ) -> T.Inj (coerce name) typ) injs)
@ -138,8 +148,8 @@ typecheckInj (Inj inj_name inj_typ) name tvars
| TData name' typs <- getReturn inj_typ | TData name' typs <- getReturn inj_typ
, name' == name , name' == name
, Right tvars' <- mapM toTVar typs , Right tvars' <- mapM toTVar typs
, tvars' == tvars , all (`elem` tvars) tvars'
= pure (Inj inj_name $ foldr TAll inj_typ tvars) = pure (Inj inj_name inj_typ)
| otherwise | otherwise
= throwError $ unwords = throwError $ unwords
["Bad type constructor: ", show name ["Bad type constructor: ", show name
@ -216,6 +226,7 @@ subtype t1 t2 = case (t1, t2) of
-- Γ[ά] ⊢ A <: ά ⊣ Δ -- Γ[ά] ⊢ A <: ά ⊣ Δ
(typ, TEVar tevar) | notElem tevar $ frees typ -> instantiateR typ tevar (typ, TEVar tevar) | notElem tevar $ frees typ -> instantiateR typ tevar
(TData name1 typs1, TData name2 typs2) (TData name1 typs1, TData name2 typs2)
-- D₁ = D₂ -- D₁ = D₂
@ -542,33 +553,82 @@ checkBranch (Branch patt exp) t_patt t_exp = do
pure (T.Branch patt' (exp, t_exp)) pure (T.Branch patt' (exp, t_exp))
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of checkPattern patt t_patt = (, t_patt) <$> case patt of
PVar x -> do PVar x -> do
insertEnv $ EnvVar x t_patt insertEnv $ EnvVar x t_patt
pure (T.PVar (coerce x, dummy), dummy) -- TODO pure $ T.PVar (coerce x, t_patt)
PCatch -> pure (T.PCatch, dummy) -- TODO PCatch -> pure T.PCatch
PLit lit | inferLit lit == t_patt -> let PLit lit | inferLit lit == t_patt -> pure $ T.PLit (lit, t_patt)
t = inferLit lit
in
pure (T.PLit (lit, t), t)
| otherwise -> throwError "Literal in pattern have wrong type" | otherwise -> throwError "Literal in pattern have wrong type"
PEnum name -> do PEnum name -> do
t <- maybeToRightM ("Unknown constructor " ++ show name) t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name =<< lookupInj name
subtype t t_patt subtype t t_patt
pure (T.PEnum (coerce name), dummy) -- TODO pure $ T.PEnum (coerce name)
-- Θ₁ ⊢ p₁ ↑ [Θ₁]B₁ ⊣ Θ₂
-- Γ ⊢ (xₖ : B₁ → ‥ → Bₘ₊₁) ∈ Γ ...
-- Γ ⊢ B₁ → ‥ → Bₘ₊₁ <: A₁ + ‥ + Aₙ ⊣ Θ₁ Θₘ ⊢ pₘ ↑ [Θₘ₋₁]Bₘ ⊣ Δ
-- --------------------------------------------------------------
-- Γ ⊢ injₖ xₖ. p₁ ‥ pₘ ↑ A₁ + ‥ + Aₙ ⊣ Δ
PInj name ps -> do PInj name ps -> do
t <- maybeToRightM ("Unknown constructor " ++ show name) t <- maybeToRightM ("Unknown constructor " ++ show name)
=<< lookupInj name =<< lookupInj name
subtype t t_patt
let (t_ps, t_return) = partitionTypeWithForall t let (t_ps, t_return) = partitionTypeWithForall t
unless (length ps == length t_ps) $ unless (length ps == length t_ps) $
throwError "Wrong number of variables" throwError "Wrong number of variables"
subtype t_return t_patt
ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps t_ps -- §ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps t_ps
let ps'' = map fst ps' -- TODO -- let ps'' = map fst ps' -- TODO
pure (T.PInj (coerce name) ps'', dummy) pure $ T.PInj (coerce name) []
subtypeData :: UIdent -> Type -> Tc ()
subtypeData name_inj typ = do
injs <- maybeToRightM err1 =<< lookupDataType name_d
t_inj <- liftEither . maybeToRight err2 $ lookup name_k injs
(t_inj', typs')<- substituteTVars foralls t_inj data_t
subtype ()
undefined
where
substituteTVars fas t1 t2 = case fas of
[] -> pure (t1, t2)
fa:fas' -> do
(t1', t2') <- go fa (t1, t2)
substituteTVars fas' t1' t2'
where
go fa (t1, t2) = let TAll tvar _ = fa dummy in do
tevar <- fresh
insertEnv (EnvTEVar tevar)
pure $ on (,) (substitute tvar tevar) t1 t2
(foralls, data_t@(TData name_d typs)) = partitionData typ
err1 = unwords ["Unknown data type", show name_d]
err2 = unwords ["No", show name_k, "constructor for data type", show name_d]
-- TAll tvar t -> do
-- tevar <- fresh
-- let -- env_marker = EnvMark tevar
-- env_tevar = EnvTEVar tevar
-- -- insertEnv env_marker
-- insertEnv env_tevar
-- let a' = substitute tvar tevar a
-- subtype a' b
-- -- dropTrailing env_marker
-- TData name_d typs -> do
--
-- subtype t_k typ
-- undefined
-- where
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
-- * Auxiliary -- * Auxiliary
@ -725,6 +785,7 @@ getReturn = snd . partitionType
-- ([a, ∀c. c → c], b) -- ([a, ∀c. c → c], b)
-- --
-- Unsure if foralls should be added to the return type or not. -- Unsure if foralls should be added to the return type or not.
-- FIXME
partitionType :: Type -> ([Type], Type) partitionType :: Type -> ([Type], Type)
partitionType = go [] . skipForalls' partitionType = go [] . skipForalls'
where where
@ -743,6 +804,22 @@ skipForalls = go []
TAll tvar t -> go (snoc (TAll tvar) acc) t TAll tvar t -> go (snoc (TAll tvar) acc) t
_ -> (acc, typ) _ -> (acc, typ)
getForallsData :: Type -> [Type -> Type]
getForallsData = fst . partitionData
getTData :: Type -> Type
getTData = snd . partitionData
partitionData :: Type -> ([Type -> Type], Type)
partitionData = go . ([],)
where
go (acc, typ) = case typ of
TAll tvar t -> go (snoc (TAll tvar) acc, t)
TData {} -> (acc, typ)
_ -> error "Bad data type"
partitionTypeWithForall :: Type -> ([Type], Type) partitionTypeWithForall :: Type -> ([Type], Type)
partitionTypeWithForall typ = (t_vars', t_return') partitionTypeWithForall typ = (t_vars', t_return')
where where
@ -798,6 +875,9 @@ insertEnv x = modifyEnv (:|> x)
lookupBind :: LIdent -> Tc (Maybe Exp) lookupBind :: LIdent -> Tc (Maybe Exp)
lookupBind x = gets (Map.lookup x . binds) lookupBind x = gets (Map.lookup x . binds)
lookupDataType :: UIdent -> Tc (Maybe [(UIdent, Type)])
lookupDataType x = gets (Map.lookup x . data_types)
lookupSig :: LIdent -> Tc (Maybe Type) lookupSig :: LIdent -> Tc (Maybe Type)
lookupSig x = gets (Map.lookup x . sig) lookupSig x = gets (Map.lookup x . sig)