This commit is contained in:
Martin Fredin 2023-03-30 19:07:12 +02:00
parent 72352d9619
commit 7d2a0e60d8

View file

@ -39,7 +39,6 @@ import qualified TypeChecker.TypeCheckerIr as T
-- • Fix problems with types in Pattern/Branch in TypeCheckerIr -- • Fix problems with types in Pattern/Branch in TypeCheckerIr
-- • Use applyEnvExp consistently -- • Use applyEnvExp consistently
-- • Fix the different type getters functions (e.g. partitionType) functions -- • Fix the different type getters functions (e.g. partitionType) functions
-- • Handle recursive functions. Maybe use a isRec : Bool variable.
data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A data EnvElem = EnvVar LIdent Type -- ^ Term variable typing. x : A
| EnvTVar TVar -- ^ Universal type variable. α | EnvTVar TVar -- ^ Universal type variable. α
@ -63,12 +62,8 @@ data Cxt = Cxt
newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a } newtype Tc a = Tc { runTc :: ExceptT String (State Cxt) a }
deriving (Functor, Applicative, Monad, Alternative, MonadState Cxt, MonadError String) deriving (Functor, Applicative, Monad, Alternative, MonadState Cxt, MonadError String)
typecheck :: Program -> Err (T.Program' Type) initCxt :: [Def] -> Cxt
typecheck (Program defs) = do initCxt defs = Cxt
datatypes <- mapM typecheckDataType [ d | DData d <- defs ]
let initCxt = Cxt
{ env = mempty { env = mempty
, sig = Map.fromList [ (name, t) , sig = Map.fromList [ (name, t)
| DSig' name t <- defs | DSig' name t <- defs
@ -78,19 +73,22 @@ typecheck (Program defs) = do
] ]
, next_tevar = 0 , next_tevar = 0
, data_injs = Map.fromList [ (name, t) , data_injs = Map.fromList [ (name, t)
| Data _ injs <- datatypes | DData (Data _ injs) <- defs
, Inj name t <- injs , Inj name t <- injs
] ]
} }
binds' <- evalState (runExceptT (runTc $ mapM typecheckBind binds)) initCxt; typecheck :: Program -> Err (T.Program' Type)
pure . T.Program $ map T.DData (coerceData datatypes) ++ map T.DBind binds' typecheck (Program defs) = do
where dataTypes' <- mapM typecheckDataType [ d | DData d <- defs ]
binds = [ b | DBind b <- defs ] binds' <- typecheckBinds (initCxt defs) [b | DBind b <- defs]
-- TODO this should happen in typecheckDataType pure . T.Program $ map T.DData dataTypes' ++ map T.DBind binds'
coerceData = map (\(Data t injs) -> T.Data t $ map
(\(Inj name typ) -> T.Inj (coerce name) typ) injs)
typecheckBinds :: Cxt -> [Bind] -> Err [T.Bind' Type]
typecheckBinds cxt = flip evalState cxt
. runExceptT
. runTc
. mapM typecheckBind
typecheckBind :: Bind -> Tc (T.Bind' Type) typecheckBind :: Bind -> Tc (T.Bind' Type)
typecheckBind (Bind name vars rhs) = do typecheckBind (Bind name vars rhs) = do
@ -105,7 +103,7 @@ typecheckBind (Bind name vars rhs) = do
pure (T.Bind (coerce name, t') [] (e', t')) pure (T.Bind (coerce name, t') [] (e', t'))
env <- gets env env <- gets env
unless (isComplete env) err unless (isComplete env) err
insertSig (coerce name) typ -- HERE insertSig (coerce name) typ
putEnv Empty putEnv Empty
pure bind' pure bind'
where where
@ -114,11 +112,11 @@ typecheckBind (Bind name vars rhs) = do
, "Did you forget to add type annotation to a polymorphic function?" , "Did you forget to add type annotation to a polymorphic function?"
] ]
typecheckDataType :: Data -> Err Data typecheckDataType :: Data -> Err (T.Data' Type)
typecheckDataType (Data typ injs) = do typecheckDataType (Data typ injs) = do
(name, tvars) <- go [] typ (name, tvars) <- go [] typ
injs' <- mapM (\i -> typecheckInj i name tvars) injs injs' <- mapM (\i -> typecheckInj i name tvars) injs
pure (Data typ injs') pure (T.Data typ injs')
where where
go tvars = \case go tvars = \case
TAll tvar t -> go (tvar:tvars) t TAll tvar t -> go (tvar:tvars) t
@ -128,7 +126,7 @@ typecheckDataType (Data typ injs) = do
-> pure (name, tvars') -> pure (name, tvars')
_ -> throwError $ unwords ["Bad data type definition: ", ppT typ] _ -> throwError $ unwords ["Bad data type definition: ", ppT typ]
typecheckInj :: Inj -> UIdent -> [TVar] -> Err Inj typecheckInj :: Inj -> UIdent -> [TVar] -> Err (T.Inj' Type)
typecheckInj (Inj inj_name inj_typ) name tvars typecheckInj (Inj inj_name inj_typ) name tvars
| not $ boundTVars tvars inj_typ | not $ boundTVars tvars inj_typ
= throwError "Unbound type variables" = throwError "Unbound type variables"
@ -136,7 +134,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars
, name' == name , name' == name
, Right tvars' <- mapM toTVar typs , Right tvars' <- mapM toTVar typs
, all (`elem` tvars) tvars' , all (`elem` tvars) tvars'
= pure (Inj inj_name $ foldr TAll inj_typ tvars') = pure $ T.Inj (coerce inj_name) (foldr TAll inj_typ tvars')
| otherwise | otherwise
= throwError $ unwords = throwError $ unwords
["Bad type constructor: ", show name ["Bad type constructor: ", show name
@ -470,7 +468,6 @@ infer = \case
e2'' <- applyEnvExpT e2' e2'' <- applyEnvExpT e2'
pure (T.EAdd e1'' e2'', int) pure (T.EAdd e1'' e2'', int)
-- Θ ⊢ Π ∷ A ↓ C ⊣ Δ -- Θ ⊢ Π ∷ A ↓ C ⊣ Δ
-- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO -- Γ ⊢ e ↓ A ⊣ Θ Δ ⊢ Π covers [Δ]A TODO
-- --------------------------------------- -- ---------------------------------------
@ -612,7 +609,7 @@ checkPattern patt t_patt = case patt of
-- insertEnv (EnvTEVar tevar) -- insertEnv (EnvTEVar tevar)
pure $ substitute tvar tevar t pure $ substitute tvar tevar t
where where
TAll tvar _ = fa dummy TAll tvar _ = fa int
getParams = \case getParams = \case
TAll _ t -> getParams t TAll _ t -> getParams t
@ -856,8 +853,6 @@ modifyEnv f =
pattern DBind' name vars exp = DBind (Bind name vars exp) pattern DBind' name vars exp = DBind (Bind name vars exp)
pattern DSig' name typ = DSig (Sig name typ) pattern DSig' name typ = DSig (Sig name typ)
dummy = TLit "Int"
--------------------------------------------------------------------------- ---------------------------------------------------------------------------
-- * Debug -- * Debug
--------------------------------------------------------------------------- ---------------------------------------------------------------------------