208 lines
5.6 KiB
Haskell
208 lines
5.6 KiB
Haskell
{-# LANGUAGE LambdaCase #-}
|
|
|
|
module TypeChecker.TypeCheckerIr where
|
|
|
|
import Control.Monad.Except
|
|
import Control.Monad.Reader
|
|
import Control.Monad.State
|
|
import Data.Functor.Identity (Identity)
|
|
import Data.Map (Map)
|
|
import Grammar.Abs (
|
|
Data (..),
|
|
Ident (..),
|
|
Init (..),
|
|
Lit (..),
|
|
TVar (..),
|
|
)
|
|
import Grammar.Abs qualified as GA (Type (..))
|
|
import Grammar.Print
|
|
import Prelude
|
|
import Prelude qualified as C (Eq, Ord, Read, Show)
|
|
|
|
-- | A polymorphic type (type scheme): the listed identifiers are the
-- variables quantified over the wrapped 'Type', as the 'Forall' constructor
-- name indicates.
data Poly = Forall [Ident] Type deriving (Show)
|
|
|
|
-- | Read-only context of the 'Infer' monad (its 'ReaderT' environment):
-- a map from identifiers to their type schemes.
newtype Ctx = Ctx {vars :: Map Ident Poly} deriving (Show)
|
|
|
|
-- | Mutable state threaded through the 'Infer' monad (its 'StateT' state).
data Env = Env
  { count :: Int
  -- ^ An integer counter; presumably used to generate fresh names (confirm
  -- against the inference code).
  , sigs :: Map Ident GA.Type
  -- ^ Type signatures keyed by identifier, in the source grammar's types.
  , constructors :: Map Ident GA.Type
  -- ^ Types keyed by constructor name, in the source grammar's types.
  }
  deriving (Show)
|
|
|
|
-- | Error messages reported by type inference.
type Error = String

-- | A map from identifiers to types; by its name, used as a type
-- substitution.
type Subst = Map Ident Type

-- | The inference monad: 'Env' state over a read-only 'Ctx', able to fail
-- with an 'Error'.
type Infer =
  StateT Env (ReaderT Ctx (ExceptT Error Identity))
|
|
|
|
-- | A complete typed program: the sequence of its top-level definitions.
newtype Program = Program [Def]
  deriving (C.Show, C.Read, C.Eq, C.Ord)
|
|
|
|
-- | Types of the typed intermediate representation.
--
-- NOTE: the derived 'Ord' depends on constructor order; keep it stable.
data Type
  = -- | A type referred to by name.
    TLit Ident
  | -- | A type variable.
    TVar TVar
  | -- | A function type, argument then result.
    TFun Type Type
  | -- | A type quantified over one type variable.
    TAll TVar Type
  | -- | A type constructor applied to argument types.
    TIndexed Indexed
  deriving (C.Show, C.Read, C.Eq, C.Ord)
|
|
|
|
-- | Expressions of the typed IR. Nested sub-expressions are 'ExpT' pairs,
-- so each of them carries its own type.
data Exp
  = -- | A variable together with its type.
    EId Id
  | -- | A literal.
    ELit Lit
  | -- | A @let@ binding and the body it scopes over.
    ELet Bind ExpT
  | -- | Application of a function to an argument.
    EApp ExpT ExpT
  | -- | Addition of two operands.
    EAdd ExpT ExpT
  | -- | Lambda abstraction over a typed parameter.
    EAbs Id ExpT
  | -- | A @case@ on a scrutinee with its alternatives.
    ECase ExpT [Inj]
  deriving (C.Show, C.Read, C.Eq, C.Ord)

-- | An expression paired with its type.
type ExpT = (Exp, Type)
|
|
|
|
-- | A named type constructor applied to a list of argument types.
data Indexed = Indexed Ident [Type]
  deriving (C.Show, C.Read, C.Eq, C.Ord)
|
|
|
|
-- | A case alternative: its left-hand side (an 'Init' paired with a type)
-- and the typed expression it selects.
data Inj = Inj (Init, Type) ExpT
  deriving (C.Show, C.Read, C.Eq, C.Ord)
|
|
|
|
-- | A top-level definition: a value binding or a 'Data' declaration from
-- the source grammar.
data Def
  = DBind Bind
  | DData Data
  deriving (C.Show, C.Read, C.Eq, C.Ord)

-- | An identifier paired with its type.
type Id = (Ident, Type)

-- | A binding: the bound name with its type, the typed parameters, and the
-- typed right-hand side.
data Bind = Bind Id [Id] ExpT
  deriving (C.Show, C.Read, C.Eq, C.Ord)
|
|
|
|
-- | Print definitions one after another, each followed by a newline.
instance Print [Def] where
  prt _ defs = concatD (concatMap withNewline defs)
    where
      withNewline d = [prt 0 d, doc (showString "\n")]
|
|
|
|
-- | Print a definition by delegating to its payload's printer, passing the
-- precedence through unchanged.
instance Print Def where
  prt i def = case def of
    DBind b -> prt i b
    DData d -> prt i d
|
|
|
|
-- | Print a program as its list of definitions.
instance Print Program where
  prt i (Program defs) = prPrec i 0 (prt 0 defs)
|
|
|
|
-- | Print a binding as a signature line followed by its defining equation,
-- roughly:
--
-- > name : type
-- > name (arg1 : t1) (arg2 : t2) = rhs
--
-- The bound 'Id' is destructured as @(name, type)@, the field order of
-- @type Id = (Ident, Type)@. (The previous pattern had the pair reversed,
-- which printed the type where the name belongs and vice versa.)
instance Print Bind where
  prt i (Bind (name, t) args rhs) =
    prPrec i 0 $
      concatD
        [ prt 0 name
        , doc $ showString ":"
        , prt 0 t
        , doc $ showString "\n"
        , prt 0 name
        , -- Parameters were previously dropped from the output entirely.
          concatD (map prtArg args)
        , doc $ showString "="
        , prt 0 rhs
        ]
    where
      -- One parameter as a parenthesised @(name : type)@ group, the same
      -- layout produced by 'prtIdP'.
      prtArg (x, ty) =
        concatD
          [ doc $ showString "("
          , prt 0 x
          , doc $ showString ":"
          , prt 0 ty
          , doc $ showString ")"
          ]
|
|
|
|
-- | Print bindings separated by @;@ and a newline; nothing follows the
-- last binding.
instance Print [Bind] where
  prt _ binds =
    case binds of
      [] -> concatD []
      x : rest -> concatD (prt 0 x : concatMap separated rest)
    where
      separated b = [doc (showString ";"), doc (showString "\n"), prt 0 b]
|
|
|
|
-- | Print a sequence of typed identifiers, each rendered by 'prtIdP' as a
-- parenthesised @(name : type)@ group.
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i ids = prPrec i 0 (concatD [prtIdP i x | x <- ids])
|
|
|
|
-- | Print a typed identifier as @name : type@.
prtId :: Int -> Id -> Doc
prtId i (name, t) = prPrec i 0 (concatD [prt 0 name, colon, prt 0 t])
  where
    colon = doc (showString ":")
|
|
|
|
-- | Print a typed identifier inside explicit parentheses: @(name : type)@.
prtIdP :: Int -> Id -> Doc
prtIdP i (name, t) =
  prPrec i 0 . concatD $
    [ tok "("
    , prt 0 name
    , tok ":"
    , prt 0 t
    , tok ")"
    ]
  where
    tok = doc . showString
|
|
|
|
-- | Pretty-print one node of the typed IR.
--
-- Types are not printed here. Sub-expressions are 'ExpT' pairs, and the
-- 'ExpT' printer appends @: type@ after each one. The patterns below must
-- match the arity of the 'Exp' declaration exactly. (The previous version
-- still matched an extra type argument on 'ELit', 'EApp', 'EAdd', 'EAbs'
-- and 'ECase', and so did not compile.)
--
-- Precedence levels passed to 'prPrec': 3 for variables, literals and
-- @let@; 2 for application; 1 for addition; 0 for lambda and @case@.
instance Print Exp where
  prt i = \case
    EId n -> prPrec i 3 $ concatD [prtId 0 n, doc $ showString "\n"]
    ELit lit -> prPrec i 3 $ concatD [prt 0 lit, doc $ showString "\n"]
    ELet bind body ->
      prPrec i 3 $
        concatD
          [ doc $ showString "let"
          , prt 0 bind
          , doc $ showString "in"
          , prt 0 body
          , doc $ showString "\n"
          ]
    EApp e1 e2 ->
      prPrec i 2 $
        concatD
          [ prt 2 e1
          , prt 3 e2
          ]
    EAdd e1 e2 ->
      prPrec i 1 $
        concatD
          [ prt 1 e1
          , doc $ showString "+"
          , prt 2 e2
          , doc $ showString "\n"
          ]
    EAbs param body ->
      prPrec i 0 $
        concatD
          [ doc $ showString "\\"
          , prtId 0 param
          , doc $ showString "."
          , prt 0 body
          , doc $ showString "\n"
          ]
    ECase scrutinee injs ->
      -- The result type of the whole case is printed by the enclosing ExpT.
      prPrec i 0 $
        concatD
          [ doc $ showString "case"
          , prt 0 scrutinee
          , doc $ showString "of"
          , doc $ showString "{"
          , prt 0 injs
          , doc $ showString "}"
          , doc $ showString "\n"
          ]
|
|
|
|
-- | Print a typed expression as @expr : type@, handing the surrounding
-- precedence to both the expression and its type.
instance Print ExpT where
  prt prec (expr, ty) =
    concatD
      [ prt prec expr
      , doc (showString ":")
      , prt prec ty
      ]
|
|
|
|
-- | Print a case alternative as @lhs : type => body@.
instance Print Inj where
  prt i (Inj (lhs, lhsTy) body) =
    prPrec i 0 $
      concatD
        [ prt 0 lhs
        , doc (showString ":")
        , prt 0 lhsTy
        , doc (showString "=>")
        , prt 0 body
        ]
|
|
|
|
-- | Print case alternatives separated by @;@, with no trailing separator.
instance Print [Inj] where
  prt _ injs =
    case injs of
      [] -> concatD []
      x : rest -> concatD (prt 0 x : concatMap separated rest)
    where
      separated alt = [doc (showString ";"), prt 0 alt]
|
|
|
|
-- | Print a type. Precedence levels: 2 for names and variables, 1 for
-- quantified and indexed types, 0 for function arrows. An arrow's domain is
-- printed at level 1, so a nested arrow there is parenthesised; its
-- codomain is printed at level 0.
instance Print Type where
  prt i ty = case ty of
    TLit name -> prPrec i 2 (concatD [prt 0 name])
    TVar v -> prPrec i 2 (concatD [prt 0 v])
    TAll v body ->
      prPrec i 1 $
        concatD
          [ doc (showString "forall")
          , prt 0 v
          , doc (showString ".")
          , prt 0 body
          ]
    TIndexed ix -> prPrec i 1 (concatD [prt 0 ix])
    TFun dom cod ->
      prPrec i 0 $
        concatD
          [ prt 1 dom
          , doc (showString "->")
          , prt 0 cod
          ]
|
|
|
|
-- | Print an applied type constructor, e.g. @Maybe a@.
--
-- Each argument is printed at precedence 2, the atomic level the 'Type'
-- printer uses for names and variables. Compound arguments therefore get
-- parentheses: @Maybe (a -> b)@ and @List (Maybe a)@. Previously arguments
-- inherited the caller's precedence (0 when reached via 'TIndexed'), which
-- produced the ambiguous @Maybe a -> b@ and @List Maybe a@.
instance Print Indexed where
  prt i (Indexed name args) = concatD (prt i name : map (prt 2) args)
|