diff --git a/src/TypeChecker/TypeCheckerBidir.hs b/src/TypeChecker/TypeCheckerBidir.hs index 7cb0081..5ad5021 100644 --- a/src/TypeChecker/TypeCheckerBidir.hs +++ b/src/TypeChecker/TypeCheckerBidir.hs @@ -10,11 +10,12 @@ import Auxiliary (maybeToRightM, snoc) import Control.Applicative (Alternative, Applicative (liftA2), (<|>)) import Control.Monad.Except (ExceptT, MonadError (throwError), - runExceptT, unless, zipWithM, - zipWithM_) + liftEither, runExceptT, unless, + zipWithM, zipWithM_) import Control.Monad.State (MonadState (get, put), State, evalState, gets, modify) import Data.Coerce (coerce) +import Data.Either.Combinators (maybeToRight) import Data.Function (on) import Data.List (intercalate) import Data.Map (Map) @@ -45,11 +46,12 @@ type Env = Seq EnvElem -- | Ordered context -- Γ ::= ・| Γ, α | Γ, ά | Γ, ▶ ά | Γ, x:A data Cxt = Cxt - { env :: Env -- ^ Local scope context Γ - , sig :: Map LIdent Type -- ^ Top-level signatures x : A - , binds :: Map LIdent Exp -- ^ Top-level binds x : e - , next_tevar :: Int -- ^ Counter to distinguish ά - , data_injs :: Map UIdent Type -- ^ Data injections (constructors) K + { env :: Env -- ^ Local scope context Γ + , sig :: Map LIdent Type -- ^ Top-level signatures x : A + , binds :: Map LIdent Exp -- ^ Top-level binds x : e + , next_tevar :: Int -- ^ Counter to distinguish ά + , data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A + , data_types :: Map UIdent [(UIdent, Type)] -- ^ Data types D : (K₁:A₁ + ‥ + Kₙ:Aₙ) } deriving (Show, Eq) newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a } @@ -69,16 +71,24 @@ typecheck (Program defs) = do | DBind' name vars rhs <- defs ] , next_tevar = 0 - , data_injs = Map.fromList [ (name, typ) - | Data _ injs <- datatypes - , Inj name typ <- injs - ] + , data_injs = Map.fromList [ (name, foldr ($) typ $ getForallsData typ) + | Data _ injs <- datatypes + , Inj name typ <- injs + ] + , data_types = Map.fromList [ let + TData name _ = getTData typ + kts = [(k,t) | Inj k t <- injs ] + in + (name, kts) + | Data typ injs <- datatypes + ] } binds' <- evalState (runExceptT (runTc $ mapM typecheckBind binds)) initCxt; pure . T.Program $ map T.DData (coerceData datatypes) ++ map T.DBind binds' where binds = [ b | DBind b <- defs ] + -- TODO this should happen in typecheckDataType coerceData = map (\(Data t injs) -> T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs) @@ -138,8 +148,8 @@ typecheckInj (Inj inj_name inj_typ) name tvars | TData name' typs <- getReturn inj_typ , name' == name , Right tvars' <- mapM toTVar typs - , tvars' == tvars - = pure (Inj inj_name $ foldr TAll inj_typ tvars) + , all (`elem` tvars) tvars' + = pure (Inj inj_name inj_typ) | otherwise = throwError $ unwords ["Bad type constructor: ", show name @@ -216,6 +226,7 @@ subtype t1 t2 = case (t1, t2) of -- Γ[ά] ⊢ A <: ά ⊣ Δ (typ, TEVar tevar) | notElem tevar $ frees typ -> instantiateR typ tevar + (TData name1 typs1, TData name2 typs2) -- D₁ = D₂ @@ -542,33 +553,82 @@ checkBranch (Branch patt exp) t_patt t_exp = do pure (T.Branch patt' (exp, t_exp)) checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) -checkPattern patt t_patt = case patt of +checkPattern patt t_patt = (, t_patt) <$> case patt of PVar x -> do insertEnv $ EnvVar x t_patt - pure (T.PVar (coerce x, dummy), dummy) -- TODO - PCatch -> pure (T.PCatch, dummy) -- TODO - PLit lit | inferLit lit == t_patt -> let - t = inferLit lit - in - pure (T.PLit (lit, t), t) + pure $ T.PVar (coerce x, t_patt) + PCatch -> pure T.PCatch + PLit lit | inferLit lit == t_patt -> pure $ T.PLit (lit, t_patt) | otherwise -> throwError "Literal in pattern have wrong type" PEnum name -> do t <- maybeToRightM ("Unknown constructor " ++ show name) =<< lookupInj name subtype t t_patt - pure (T.PEnum (coerce name), dummy) -- TODO + pure $ T.PEnum (coerce name) + + + -- Θ₁ ⊢ p₁ ↑ [Θ₁]B₁ ⊣ Θ₂ + -- Γ ⊢ (xₖ : B₁ → ‥ → Bₘ₊₁) ∈ Γ ... + -- Γ ⊢ B₁ → ‥ → Bₘ₊₁ <: A₁ + ‥ + Aₙ ⊣ Θ₁ Θₘ ⊢ pₘ ↑ [Θₘ₋₁]Bₘ ⊣ Δ + -- -------------------------------------------------------------- + -- Γ ⊢ injₖ xₖ. p₁ ‥ pₘ ↑ A₁ + ‥ + Aₙ ⊣ Δ PInj name ps -> do t <- maybeToRightM ("Unknown constructor " ++ show name) =<< lookupInj name + subtype t t_patt + + let (t_ps, t_return) = partitionTypeWithForall t unless (length ps == length t_ps) $ throwError "Wrong number of variables" - subtype t_return t_patt - ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps t_ps - let ps'' = map fst ps' -- TODO - pure (T.PInj (coerce name) ps'', dummy) + + -- §ps' <- zipWithM (\p t -> checkPattern p =<< applyEnv t) ps t_ps + -- let ps'' = map fst ps' -- TODO + pure $ T.PInj (coerce name) [] + +subtypeData :: UIdent -> Type -> Tc () +subtypeData name_inj typ = do + injs <- maybeToRightM err1 =<< lookupDataType name_d + t_inj <- liftEither . maybeToRight err2 $ lookup name_k injs + (t_inj', typs')<- substituteTVars foralls t_inj data_t + subtype () + + undefined + where + substituteTVars fas t1 t2 = case fas of + [] -> pure (t1, t2) + fa:fas' -> do + (t1', t2') <- go fa (t1, t2) + substituteTVars fas' t1' t2' + where + go fa (t1, t2) = let TAll tvar _ = fa dummy in do + tevar <- fresh + insertEnv (EnvTEVar tevar) + pure $ on (,) (substitute tvar tevar) t1 t2 + + + + (foralls, data_t@(TData name_d typs)) = partitionData typ + err1 = unwords ["Unknown data type", show name_d] + err2 = unwords ["No", show name_k, "constructor for data type", show name_d] + + -- TAll tvar t -> do + -- tevar <- fresh + -- let -- env_marker = EnvMark tevar + -- env_tevar = EnvTEVar tevar + -- -- insertEnv env_marker + -- insertEnv env_tevar + -- let a' = substitute tvar tevar a + -- subtype a' b + -- -- dropTrailing env_marker + + -- TData name_d typs -> do + -- + -- subtype t_k typ + -- undefined + -- where --------------------------------------------------------------------------- -- * Auxiliary @@ -725,6 +785,7 @@ getReturn = snd . partitionType -- ([a, ∀c. c → c], b) -- -- Unsure if foralls should be added to the return type or not. +-- FIXME partitionType :: Type -> ([Type], Type) partitionType = go [] . skipForalls' where @@ -743,6 +804,22 @@ skipForalls = go [] TAll tvar t -> go (snoc (TAll tvar) acc) t _ -> (acc, typ) + +getForallsData :: Type -> [Type -> Type] +getForallsData = fst . partitionData + +getTData :: Type -> Type +getTData = snd . partitionData + +partitionData :: Type -> ([Type -> Type], Type) +partitionData = go . ([],) + where + go (acc, typ) = case typ of + TAll tvar t -> go (snoc (TAll tvar) acc, t) + TData {} -> (acc, typ) + _ -> error "Bad data type" + + partitionTypeWithForall :: Type -> ([Type], Type) partitionTypeWithForall typ = (t_vars', t_return') where @@ -798,6 +875,9 @@ insertEnv x = modifyEnv (:|> x) lookupBind :: LIdent -> Tc (Maybe Exp) lookupBind x = gets (Map.lookup x . binds) +lookupDataType :: UIdent -> Tc (Maybe [(UIdent, Type)]) +lookupDataType x = gets (Map.lookup x . data_types) + lookupSig :: LIdent -> Tc (Maybe Type) lookupSig x = gets (Map.lookup x . sig)