diff --git a/src/TypeChecker/TypeCheckerBidir.hs b/src/TypeChecker/TypeCheckerBidir.hs index 1f16e11..60667c5 100644 --- a/src/TypeChecker/TypeCheckerBidir.hs +++ b/src/TypeChecker/TypeCheckerBidir.hs @@ -39,7 +39,6 @@ import qualified TypeChecker.TypeCheckerIr as T -- • Fix problems with types in Pattern/Branch in TypeCheckerIr -- • Use applyEnvExp consistently -- • Fix the different type getters functions (e.g. partitionType) functions --- • Handle recursive functions. Maybe use a isRec : Bool variable. data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A | EnvTVar TVar -- ^ Universal type variable. α @@ -53,44 +52,43 @@ type Env = Seq EnvElem -- | Ordered context -- Γ ::= ・| Γ, α | Γ, ά | Γ, ▶ ά | Γ, x:A data Cxt = Cxt - { env :: Env -- ^ Local scope context Γ - , sig :: Map LIdent Type -- ^ Top-level signatures x : A - , binds :: Map LIdent Exp -- ^ Top-level binds x : e - , next_tevar :: Int -- ^ Counter to distinguish ά - , data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A + { env :: Env -- ^ Local scope context Γ + , sig :: Map LIdent Type -- ^ Top-level signatures x : A + , binds :: Map LIdent Exp -- ^ Top-level binds x : e + , next_tevar :: Int -- ^ Counter to distinguish ά + , data_injs :: Map UIdent Type -- ^ Data injections (constructors) K/inj : A } deriving (Show, Eq) newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a } deriving (Functor, Applicative, Monad, Alternative, MonadState Cxt, MonadError String) -typecheck :: Program -> Err (T.Program' Type) -typecheck (Program defs) = do - datatypes <- mapM typecheckDataType [ d | DData d <- defs ] - - - let initCxt = Cxt - { env = mempty - , sig = Map.fromList [ (name, t) - | DSig' name t <- defs - ] - , binds = Map.fromList [ (name, foldr EAbs rhs vars) - | DBind' name vars rhs <- defs - ] - , next_tevar = 0 - , data_injs = Map.fromList [ (name, t) - | Data _ injs <- datatypes - , Inj name t <- injs - ] +initCxt :: [Def] -> Cxt +initCxt defs = Cxt + { env = mempty + , sig = Map.fromList [ (name, t) + | DSig' name t <- defs + ] + , binds = Map.fromList [ (name, foldr EAbs rhs vars) + | DBind' name vars rhs <- defs + ] + , next_tevar = 0 + , data_injs = Map.fromList [ (name, t) + | DData (Data _ injs) <- defs + , Inj name t <- injs + ] } - binds' <- evalState (runExceptT (runTc $ mapM typecheckBind binds)) initCxt; - pure . T.Program $ map T.DData (coerceData datatypes) ++ map T.DBind binds' - where - binds = [ b | DBind b <- defs ] - -- TODO this should happen in typecheckDataType - coerceData = map (\(Data t injs) -> T.Data t $ map - (\(Inj name typ) -> T.Inj (coerce name) typ) injs) +typecheck :: Program -> Err (T.Program' Type) +typecheck (Program defs) = do + dataTypes' <- mapM typecheckDataType [ d | DData d <- defs ] + binds' <- typecheckBinds (initCxt defs) [b | DBind b <- defs] + pure . T.Program $ map T.DData dataTypes' ++ map T.DBind binds' +typecheckBinds :: Cxt -> [Bind] -> Err [T.Bind' Type] +typecheckBinds cxt = flip evalState cxt + . runExceptT + . runTc + . mapM typecheckBind typecheckBind :: Bind -> Tc (T.Bind' Type) typecheckBind (Bind name vars rhs) = do @@ -105,7 +103,7 @@ typecheckBind (Bind name vars rhs) = do pure (T.Bind (coerce name, t') [] (e', t')) env <- gets env unless (isComplete env) err - insertSig (coerce name) typ -- HERE + insertSig (coerce name) typ putEnv Empty pure bind' where @@ -114,11 +112,11 @@ typecheckBind (Bind name vars rhs) = do , "Did you forget to add type annotation to a polymorphic function?" ] -typecheckDataType :: Data -> Err Data +typecheckDataType :: Data -> Err (T.Data' Type) typecheckDataType (Data typ injs) = do (name, tvars) <- go [] typ injs' <- mapM (\i -> typecheckInj i name tvars) injs - pure (Data typ injs') + pure (T.Data typ injs') where go tvars = \case TAll tvar t -> go (tvar:tvars) t @@ -128,7 +126,7 @@ typecheckDataType (Data typ injs) = do -> pure (name, tvars') _ -> throwError $ unwords ["Bad data type definition: ", ppT typ] -typecheckInj :: Inj -> UIdent -> [TVar] -> Err Inj +typecheckInj :: Inj -> UIdent -> [TVar] -> Err (T.Inj' Type) typecheckInj (Inj inj_name inj_typ) name tvars | not $ boundTVars tvars inj_typ = throwError "Unbound type variables" @@ -136,7 +134,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars , name' == name , Right tvars' <- mapM toTVar typs , all (`elem` tvars) tvars' - = pure (Inj inj_name $ foldr TAll inj_typ tvars') + = pure $ T.Inj (coerce inj_name) (foldr TAll inj_typ tvars') | otherwise = throwError $ unwords ["Bad type constructor: ", show name @@ -470,7 +468,6 @@ infer = \case e2'' <- applyEnvExpT e2' pure (T.EAdd e1'' e2'', int) - -- Θ ⊢ Π ∷ A ↓ C ⊣ Δ -- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO -- --------------------------------------- @@ -612,7 +609,7 @@ checkPattern patt t_patt = case patt of -- insertEnv (EnvTEVar tevar) pure $ substitute tvar tevar t where - TAll tvar _ = fa dummy + TAll tvar _ = fa int getParams = \case TAll _ t -> getParams t @@ -856,8 +853,6 @@ modifyEnv f = pattern DBind' name vars exp = DBind (Bind name vars exp) pattern DSig' name typ = DSig (Sig name typ) -dummy = TLit "Int" - --------------------------------------------------------------------------- -- * Debug ---------------------------------------------------------------------------