Fixed errors in tc hm

This commit is contained in:
sebastianselander 2023-03-27 16:48:23 +02:00
parent 847ec37117
commit 6e54378327
4 changed files with 346 additions and 345 deletions

View file

@ -8,19 +8,18 @@ module Codegen.LlvmIr (
LLVMComp (..), LLVMComp (..),
Visibility (..), Visibility (..),
CallingConvention (..), CallingConvention (..),
ToIr(..) ToIr (..),
) where ) where
import Data.List (intercalate) import Data.List (intercalate)
import Grammar.Abs (Character) import TypeChecker.TypeCheckerIr (Ident (..))
import TypeChecker.TypeCheckerIr (Ident (..))
data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving Show data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving (Show)
instance ToIr CallingConvention where instance ToIr CallingConvention where
toIr :: CallingConvention -> String toIr :: CallingConvention -> String
toIr TailCC = "tailcc" toIr TailCC = "tailcc"
toIr FastCC = "fastcc" toIr FastCC = "fastcc"
toIr CCC = "ccc" toIr CCC = "ccc"
toIr ColdCC = "coldcc" toIr ColdCC = "coldcc"
-- | A datatype which represents some basic LLVM types -- | A datatype which represents some basic LLVM types
@ -34,7 +33,7 @@ data LLVMType
| Function LLVMType [LLVMType] | Function LLVMType [LLVMType]
| Array Integer LLVMType | Array Integer LLVMType
| CustomType Ident | CustomType Ident
deriving Show deriving (Show)
class ToIr a where class ToIr a where
toIr :: a -> String toIr :: a -> String
@ -63,12 +62,12 @@ data LLVMComp
| LLSge | LLSge
| LLSlt | LLSlt
| LLSle | LLSle
deriving Show deriving (Show)
instance ToIr LLVMComp where instance ToIr LLVMComp where
toIr :: LLVMComp -> String toIr :: LLVMComp -> String
toIr = \case toIr = \case
LLEq -> "eq" LLEq -> "eq"
LLNe -> "ne" LLNe -> "ne"
LLUgt -> "ugt" LLUgt -> "ugt"
LLUge -> "uge" LLUge -> "uge"
LLUlt -> "ult" LLUlt -> "ult"
@ -78,30 +77,31 @@ instance ToIr LLVMComp where
LLSlt -> "slt" LLSlt -> "slt"
LLSle -> "sle" LLSle -> "sle"
data Visibility = Local | Global deriving Show data Visibility = Local | Global deriving (Show)
instance ToIr Visibility where instance ToIr Visibility where
toIr :: Visibility -> String toIr :: Visibility -> String
toIr Local = "%" toIr Local = "%"
toIr Global = "@" toIr Global = "@"
-- | Represents a LLVM "value", as in an integer, a register variable, {- | Represents a LLVM "value", as in an integer, a register variable,
-- or a string contstant or a string contstant
-}
data LLVMValue data LLVMValue
= VInteger Integer = VInteger Integer
| VChar Character | VChar Char
| VIdent Ident LLVMType | VIdent Ident LLVMType
| VConstant String | VConstant String
| VFunction Ident Visibility LLVMType | VFunction Ident Visibility LLVMType
deriving Show deriving (Show)
instance ToIr LLVMValue where instance ToIr LLVMValue where
toIr :: LLVMValue -> String toIr :: LLVMValue -> String
toIr v = case v of toIr v = case v of
VInteger i -> show i VInteger i -> show i
VChar i -> show i VChar i -> show i
VIdent (Ident n) _ -> "%" <> n VIdent (Ident n) _ -> "%" <> n
VFunction (Ident n) vis _ -> toIr vis <> n VFunction (Ident n) vis _ -> toIr vis <> n
VConstant s -> "c" <> show s VConstant s -> "c" <> show s
type Params = [(Ident, LLVMType)] type Params = [(Ident, LLVMType)]
type Args = [(LLVMType, LLVMValue)] type Args = [(LLVMType, LLVMValue)]
@ -114,8 +114,8 @@ data LLVMIr
| Declare LLVMType Ident Params | Declare LLVMType Ident Params
| SetVariable Ident LLVMIr | SetVariable Ident LLVMIr
| Variable Ident | Variable Ident
-- extractvalue <aggregate type> <val>, <idx>{, <idx>}* | -- extractvalue <aggregate type> <val>, <idx>{, <idx>}*
| ExtractValue LLVMType LLVMValue Integer ExtractValue LLVMType LLVMValue Integer
| GetElementPtr LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue | GetElementPtr LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
| GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue
| Add LLVMType LLVMValue LLVMValue | Add LLVMType LLVMValue LLVMValue
@ -136,7 +136,7 @@ data LLVMIr
| Comment String | Comment String
| UnsafeRaw String -- This should generally be avoided, and proper | UnsafeRaw String -- This should generally be avoided, and proper
-- instructions should be used in its place -- instructions should be used in its place
deriving Show deriving (Show)
-- | Converts a list of LLVMIr instructions to a string -- | Converts a list of LLVMIr instructions to a string
llvmIrToString :: [LLVMIr] -> String llvmIrToString :: [LLVMIr] -> String
@ -146,14 +146,15 @@ llvmIrToString = go 0
go _ [] = mempty go _ [] = mempty
go i (x : xs) = do go i (x : xs) = do
let (i', n) = case x of let (i', n) = case x of
Define{} -> (i + 1, 0) Define{} -> (i + 1, 0)
DefineEnd -> (i - 1, 0) DefineEnd -> (i - 1, 0)
_ -> (i, i) _ -> (i, i)
insToString n x <> go i' xs insToString n x <> go i' xs
{- | Converts a LLVM inststruction to a String, allowing for printing etc.
The integer represents the indentation -- \| Converts a LLVM inststruction to a String, allowing for printing etc.
-} -- The integer represents the indentation
{- FOURMOLU_DISABLE -} --
{- FOURMOLU_DISABLE -}
insToString :: Int -> LLVMIr -> String insToString :: Int -> LLVMIr -> String
insToString i l = insToString i l =
replicate i '\t' <> case l of replicate i '\t' <> case l of
@ -261,4 +262,3 @@ llvmIrToString = go 0
lblPfx :: String lblPfx :: String
lblPfx = "lbl_" lblPfx = "lbl_"

View file

@ -1,7 +1,6 @@
module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr, module GA) where module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr) where
import qualified Grammar.Abs as GA (Ident (..)) import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..))
import qualified TypeChecker.TypeCheckerIr as TIR (Ident (..))
type Id = (TIR.Ident, Type) type Id = (TIR.Ident, Type)
@ -26,8 +25,12 @@ data Exp
| ECase ExpT [Branch] | ECase ExpT [Branch]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Pattern = PVar Id | PLit (Lit, Type) | PInj TIR.Ident [Pattern] data Pattern
| PCatch | PEnum TIR.Ident = PVar Id
| PLit (Lit, Type)
| PInj TIR.Ident [Pattern]
| PCatch
| PEnum TIR.Ident
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
data Branch = Branch (Pattern, Type) ExpT data Branch = Branch (Pattern, Type) ExpT
@ -48,4 +51,4 @@ data Type = TLit TIR.Ident | TFun Type Type
flattenType :: Type -> [Type] flattenType :: Type -> [Type]
flattenType (TFun t1 t2) = t1 : flattenType t2 flattenType (TFun t1 t2) = t1 : flattenType t2
flattenType x = [x] flattenType x = [x]

View file

@ -1,31 +1,29 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where module TypeChecker.TypeCheckerHm where
import Auxiliary import Auxiliary
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Identity (runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Bifunctor (second) import Data.Bifunctor (second)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Foldable (traverse_) import Data.Foldable (traverse_)
import Data.Function (on) import Data.Function (on)
import Data.List (foldl') import Data.List (foldl')
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as M import Data.Map qualified as M
import Data.Maybe (fromJust) import Data.Maybe (fromJust)
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Set as S import Data.Set qualified as S
import Debug.Trace (trace) import Data.String
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T import TypeChecker.TypeCheckerIr qualified as T
import TypeChecker.TypeCheckerIr (Ctx (..), Env (..), Error, Infer,
Subst)
initCtx = Ctx mempty initCtx = Ctx mempty
initEnv = Env 0 'a' mempty mempty mempty initEnv = Env 0 'a' mempty mempty mempty
@ -39,7 +37,7 @@ run = runC initEnv initCtx
runC :: Env -> Ctx -> Infer a -> Either Error a runC :: Env -> Ctx -> Infer a -> Either Error a
runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e runC e c = runIdentity . runExceptT . flip runReaderT c . flip evalStateT e
typecheck :: Program -> Either Error T.Program typecheck :: Program -> Either Error (T.Program' Type)
typecheck = run . checkPrg typecheck = run . checkPrg
checkData :: Data -> Infer () checkData :: Data -> Infer ()
@ -50,9 +48,9 @@ checkData d = do
(all isPoly ts) (all isPoly ts)
(throwError $ unwords ["Data type incorrectly declared"]) (throwError $ unwords ["Data type incorrectly declared"])
traverse_ traverse_
( \(Constructor name' t') -> ( \(Inj name' t') ->
if typ == retType t' if typ == retType t'
then insertConstr (coerce name') (toNew t') then insertConstr (coerce name') (t')
else else
throwError $ throwError $
unwords unwords
@ -73,9 +71,9 @@ checkData d = do
retType :: Type -> Type retType :: Type -> Type
retType (TFun _ t2) = retType t2 retType (TFun _ t2) = retType t2
retType a = a retType a = a
checkPrg :: Program -> Infer T.Program checkPrg :: Program -> Infer (T.Program' Type)
checkPrg (Program bs) = do checkPrg (Program bs) = do
preRun bs preRun bs
bs' <- checkDef bs bs' <- checkDef bs
@ -94,25 +92,27 @@ preRun (x : xs) = case x of
<> printTree n <> printTree n
<> "'" <> "'"
) )
insertSig (coerce n) (Just $ toNew t) >> preRun xs insertSig (coerce n) (Just $ t) >> preRun xs
DBind (Bind n _ e) -> do DBind (Bind n _ e) -> do
collect (collectTypeVars e) collect (collectTypeVars e)
s <- gets sigs s <- gets sigs
case M.lookup (coerce n) s of case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs
checkDef :: [Def] -> Infer [T.Def] checkDef :: [Def] -> Infer [T.Def' Type]
checkDef [] = return [] checkDef [] = return []
checkDef (x : xs) = case x of checkDef (x : xs) = case x of
(DBind b) -> do (DBind b) -> do
b' <- checkBind b b' <- checkBind b
fmap (T.DBind b' :) (checkDef xs) fmap (T.DBind b' :) (checkDef xs)
(DData d) -> fmap (T.DData (toNew d) :) (checkDef xs) (DData d) -> fmap ((T.DData (coerceData d)) :) (checkDef xs)
(DSig _) -> checkDef xs (DSig _) -> checkDef xs
where
coerceData (Data t injs) = T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
checkBind :: Bind -> Infer T.Bind checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args)) let lambda = makeLambda e (reverse (coerce args))
e@(_, args_t) <- inferExp lambda e@(_, args_t) <- inferExp lambda
@ -133,41 +133,41 @@ checkBind (Bind name args e) = do
insertSig (coerce name) (Just args_t) insertSig (coerce name) (Just args_t)
return (T.Bind (coerce name, args_t) [] e) return (T.Bind (coerce name, args_t) [] e)
typeEq :: T.Type -> T.Type -> Bool typeEq :: Type -> Type -> Bool
typeEq (T.TFun l r) (T.TFun l' r') = typeEq l l' && typeEq r r' typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r'
typeEq (T.TLit a) (T.TLit b) = a == b typeEq (TLit a) (TLit b) = a == b
typeEq (T.TData name a) (T.TData name' b) = typeEq (TData name a) (TData name' b) =
length a == length b length a == length b
&& name == name' && name == name'
&& and (zipWith typeEq a b) && and (zipWith typeEq a b)
typeEq (T.TAll _ t1) t2 = t1 `typeEq` t2 typeEq (TAll _ t1) t2 = t1 `typeEq` t2
typeEq t1 (T.TAll _ t2) = t1 `typeEq` t2 typeEq t1 (TAll _ t2) = t1 `typeEq` t2
typeEq (T.TVar _) (T.TVar _) = True typeEq (TVar _) (TVar _) = True
typeEq _ _ = False typeEq _ _ = False
skolem :: T.Type -> T.Type skolem :: Type -> Type
skolem (T.TVar (T.MkTVar a)) = T.TLit a skolem (TVar (T.MkTVar a)) = TLit (coerce a)
skolem (T.TAll x t) = T.TAll x (skolem t) skolem (TAll x t) = TAll x (skolem t)
skolem (T.TFun t1 t2) = (T.TFun `on` skolem) t1 t2 skolem (TFun t1 t2) = (TFun `on` skolem) t1 t2
skolem t = t skolem t = t
isMoreSpecificOrEq :: T.Type -> T.Type -> Bool isMoreSpecificOrEq :: Type -> Type -> Bool
isMoreSpecificOrEq t1 (T.TAll _ t2) = isMoreSpecificOrEq t1 t2 isMoreSpecificOrEq t1 (TAll _ t2) = isMoreSpecificOrEq t1 t2
isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) = isMoreSpecificOrEq (TFun a b) (TFun c d) =
isMoreSpecificOrEq a c && isMoreSpecificOrEq b d isMoreSpecificOrEq a c && isMoreSpecificOrEq b d
isMoreSpecificOrEq (T.TData n1 ts1) (T.TData n2 ts2) = isMoreSpecificOrEq (TData n1 ts1) (TData n2 ts2) =
n1 == n2 n1 == n2
&& length ts1 == length ts2 && length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2) && and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq _ (T.TVar _) = True isMoreSpecificOrEq _ (TVar _) = True
isMoreSpecificOrEq a b = a == b isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool isPoly :: Type -> Bool
isPoly (TAll _ _) = True isPoly (TAll _ _) = True
isPoly (TVar _) = True isPoly (TVar _) = True
isPoly _ = False isPoly _ = False
inferExp :: Exp -> Infer T.ExpT inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp e = do inferExp e = do
(s, (e', t)) <- algoW e (s, (e', t)) <- algoW e
let subbed = apply s t let subbed = apply s t
@ -178,7 +178,7 @@ class CollectTVars a where
instance CollectTVars Exp where instance CollectTVars Exp where
collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e
collectTypeVars _ = S.empty collectTypeVars _ = S.empty
instance CollectTVars Type where instance CollectTVars Type where
collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i) collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i)
@ -190,43 +190,12 @@ instance CollectTVars Type where
collect :: Set T.Ident -> Infer () collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st}) collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
class NewType a b where algoW :: Exp -> Infer (Subst, (T.ExpT' Type))
toNew :: a -> b
instance NewType Type T.Type where
toNew = \case
TLit i -> T.TLit $ coerce i
TVar v -> T.TVar $ toNew v
TFun t1 t2 -> (T.TFun `on` toNew) t1 t2
TAll b t -> T.TAll (toNew b) (toNew t)
TData i ts -> T.TData (coerce i) (map toNew ts)
TEVar _ -> error "Should not exist after typechecker"
instance NewType Lit T.Lit where
toNew (LInt i) = T.LInt i
toNew (LChar i) = T.LChar i
instance NewType Data T.Data where
toNew (Data t xs) = T.Data (name $ retType t) (toNew xs)
where
name (TData n _) = coerce n
name _ = error "Bug: Data types should not be able to be typed over non type variables"
instance NewType Constructor T.Constructor where
toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs)
instance NewType TVar T.TVar where
toNew (MkTVar i) = T.MkTVar $ coerce i
instance NewType a b => NewType [a] [b] where
toNew = map toNew
algoW :: Exp -> Infer (Subst, T.ExpT)
algoW = \case algoW = \case
err@(EAnn e t) -> do err@(EAnn e t) -> do
(s1, (e', t')) <- exprErr (algoW e) err (s1, (e', t')) <- exprErr (algoW e) err
unless unless
(toNew t `isMoreSpecificOrEq` t') (t `isMoreSpecificOrEq` t')
( throwError $ ( throwError $
unwords unwords
[ "Annotated type:" [ "Annotated type:"
@ -236,34 +205,34 @@ algoW = \case
] ]
) )
applySt s1 $ do applySt s1 $ do
s2 <- exprErr (unify (toNew t) t') err s2 <- exprErr (unify (t) t') err
let comp = s2 `compose` s1 let comp = s2 `compose` s1
return (comp, apply comp (e', toNew t)) return (comp, apply comp (e', t))
-- \| ------------------ -- \| ------------------
-- \| Γ ⊢ i : Int, ∅ -- \| Γ ⊢ i : Int, ∅
ELit lit -> return (nullSubst, (T.ELit $ toNew lit, litType lit)) ELit lit -> return (nullSubst, (T.ELit $ lit, litType lit))
-- \| x : σ ∈ Γ τ = inst(σ) -- \| x : σ ∈ Γ τ = inst(σ)
-- \| ---------------------- -- \| ----------------------
-- \| Γ ⊢ x : τ, ∅ -- \| Γ ⊢ x : τ, ∅
EVar i -> do EVar i -> do
var <- asks vars var <- asks vars
case M.lookup (coerce i) var of case M.lookup (coerce i) var of
Just t -> inst t >>= \x -> return (nullSubst, (T.EId $ coerce i, x)) Just t -> inst t >>= \x -> return (nullSubst, (T.EVar $ coerce i, x))
Nothing -> do Nothing -> do
sig <- gets sigs sig <- gets sigs
case M.lookup (coerce i) sig of case M.lookup (coerce i) sig of
Just (Just t) -> return (nullSubst, (T.EId $ coerce i, t)) Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t))
Just Nothing -> do Just Nothing -> do
fr <- fresh fr <- fresh
insertSig (coerce i) (Just fr) insertSig (coerce i) (Just fr)
return (nullSubst, (T.EId $ coerce i, fr)) return (nullSubst, (T.EVar $ coerce i, fr))
Nothing -> throwError $ "Unbound variable: " <> printTree i Nothing -> throwError $ "Unbound variable: " <> printTree i
EInj i -> do EInj i -> do
constr <- gets constructors constr <- gets constructors
case M.lookup (coerce i) constr of case M.lookup (coerce i) constr of
Just t -> return (nullSubst, (T.EId $ coerce i, t)) Just t -> return (nullSubst, (T.EVar $ coerce i, t))
Nothing -> Nothing ->
throwError $ throwError $
"Constructor: '" "Constructor: '"
@ -280,7 +249,7 @@ algoW = \case
( withBinding (coerce name) fr $ do ( withBinding (coerce name) fr $ do
(s1, (e', t')) <- exprErr (algoW e) err (s1, (e', t')) <- exprErr (algoW e) err
let varType = apply s1 fr let varType = apply s1 fr
let newArr = T.TFun varType t' let newArr = TFun varType t'
return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr)) return (s1, apply s1 (T.EAbs (coerce name) (e', t'), newArr))
) )
err err
@ -314,7 +283,7 @@ algoW = \case
(s0, (e0', t0)) <- algoW e0 (s0, (e0', t0)) <- algoW e0
applySt s0 $ do applySt s0 $ do
(s1, (e1', t1)) <- algoW e1 (s1, (e1', t1)) <- algoW e1
s2 <- exprErr (unify (apply s1 t0) (T.TFun t1 fr)) err s2 <- exprErr (unify (apply s1 t0) (TFun t1 fr)) err
let t = apply s2 fr let t = apply s2 fr
let comp = s2 `compose` s1 `compose` s0 let comp = s2 `compose` s1 `compose` s0
return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t)) return (comp, apply comp (T.EApp (e0', t0) (e1', t1), t))
@ -346,33 +315,33 @@ makeLambda :: Exp -> [T.Ident] -> Exp
makeLambda = foldl (flip (EAbs . coerce)) makeLambda = foldl (flip (EAbs . coerce))
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
unify :: T.Type -> T.Type -> Infer Subst unify :: Type -> Type -> Infer Subst
unify t0 t1 = do unify t0 t1 = do
case (t0, t1) of case (t0, t1) of
(T.TFun a b, T.TFun c d) -> do (TFun a b, TFun c d) -> do
s1 <- unify a c s1 <- unify a c
s2 <- unify (apply s1 b) (apply s1 d) s2 <- unify (apply s1 b) (apply s1 d)
return $ s1 `compose` s2 return $ s1 `compose` s2
----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! ----------- ----------- TODO: BE CAREFUL!!!! THIS IS PROBABLY WRONG!!! -----------
(T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t (TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t
(t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t (t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t
------------------------------------------------------------------- -------------------------------------------------------------------
(T.TVar (T.MkTVar a), t) -> occurs a t (TVar (T.MkTVar a), t) -> occurs (coerce a) t
(t, T.TVar (T.MkTVar b)) -> occurs b t (t, TVar (T.MkTVar b)) -> occurs (coerce b) t
(T.TAll _ t, b) -> unify t b (TAll _ t, b) -> unify t b
(a, T.TAll _ t) -> unify a t (a, TAll _ t) -> unify a t
(T.TLit a, T.TLit b) -> (TLit a, TLit b) ->
if a == b if a == b
then return M.empty then return M.empty
else else
throwError throwError
. unwords . unwords
$ [ "Can not unify" $ [ "Can not unify"
, "'" <> printTree (T.TLit a) <> "'" , "'" <> printTree (TLit a) <> "'"
, "with" , "with"
, "'" <> printTree (T.TLit b) <> "'" , "'" <> printTree (TLit b) <> "'"
] ]
(T.TData name t, T.TData name' t') -> (TData name t, TData name' t') ->
if name == name' && length t == length t' if name == name' && length t == length t'
then do then do
xs <- zipWithM unify t t' xs <- zipWithM unify t t'
@ -380,7 +349,7 @@ unify t0 t1 = do
else else
throwError $ throwError $
unwords unwords
[ "T.Type constructor:" [ "Type constructor:"
, printTree name , printTree name
, "(" <> printTree t <> ")" , "(" <> printTree t <> ")"
, "does not match with:" , "does not match with:"
@ -398,42 +367,42 @@ unify t0 t1 = do
I.E. { a = a -> b } is an unsolvable constraint since there is no substitution I.E. { a = a -> b } is an unsolvable constraint since there is no substitution
where these are equal where these are equal
-} -}
occurs :: T.Ident -> T.Type -> Infer Subst occurs :: T.Ident -> Type -> Infer Subst
occurs i t@(T.TVar _) = return (M.singleton i t) occurs i t@(TVar _) = return (M.singleton i t)
occurs i t = occurs i t =
if S.member i (free t) if S.member i (free t)
then then
throwError $ throwError $
unwords unwords
[ "Occurs check failed, can't unify" [ "Occurs check failed, can't unify"
, printTree (T.TVar $ T.MkTVar i) , printTree (TVar $ T.MkTVar (coerce i))
, "with" , "with"
, printTree t , printTree t
] ]
else return $ M.singleton i t else return $ M.singleton i t
-- | Generalize a type over all free variables in the substitution set -- | Generalize a type over all free variables in the substitution set
generalize :: Map T.Ident T.Type -> T.Type -> T.Type generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where where
go :: [T.Ident] -> T.Type -> T.Type go :: [T.Ident] -> Type -> Type
go [] t = t go [] t = t
go (x : xs) t = T.TAll (T.MkTVar x) (go xs t) go (x : xs) t = TAll (T.MkTVar (coerce x)) (go xs t)
removeForalls :: T.Type -> T.Type removeForalls :: Type -> Type
removeForalls (T.TAll _ t) = removeForalls t removeForalls (TAll _ t) = removeForalls t
removeForalls (T.TFun t1 t2) = T.TFun (removeForalls t1) (removeForalls t2) removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t removeForalls t = t
{- | Instantiate a polymorphic type. The free type variables are substituted {- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones. with fresh ones.
-} -}
inst :: T.Type -> Infer T.Type inst :: Type -> Infer Type
inst = \case inst = \case
T.TAll (T.MkTVar bound) t -> do TAll (T.MkTVar bound) t -> do
fr <- fresh fr <- fresh
let s = M.singleton bound fr let s = M.singleton (coerce bound) fr
apply s <$> inst t apply s <$> inst t
T.TFun t1 t2 -> T.TFun <$> inst t1 <*> inst t2 TFun t1 t2 -> TFun <$> inst t1 <*> inst t2
rest -> return rest rest -> return rest
-- | Compose two substitution sets -- | Compose two substitution sets
@ -455,41 +424,40 @@ class FreeVars t where
-- | Get all free variables from t -- | Get all free variables from t
free :: t -> Set T.Ident free :: t -> Set T.Ident
instance FreeVars T.Type where instance FreeVars Type where
free :: T.Type -> Set T.Ident free :: Type -> Set T.Ident
free (T.TVar (T.MkTVar a)) = S.singleton a free (TVar (T.MkTVar a)) = S.singleton (coerce a)
free (T.TAll (T.MkTVar bound) t) = S.singleton bound `S.intersection` free t free (TAll (T.MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t
free (T.TLit _) = mempty free (TLit _) = mempty
free (T.TFun a b) = free a `S.union` free b free (TFun a b) = free a `S.union` free b
-- \| Not guaranteed to be correct -- \| Not guaranteed to be correct
free (T.TData _ a) = free (TData _ a) =
foldl' (\acc x -> free x `S.union` acc) S.empty a foldl' (\acc x -> free x `S.union` acc) S.empty a
instance SubstType T.Type where instance SubstType Type where
apply :: Subst -> T.Type -> T.Type apply :: Subst -> Type -> Type
apply sub t = do apply sub t = do
case t of case t of
T.TLit a -> T.TLit a TLit a -> TLit a
T.TVar (T.MkTVar a) -> case M.lookup a sub of TVar (T.MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> T.TVar (T.MkTVar $ coerce a) Nothing -> TVar (T.MkTVar $ coerce a)
Just t -> t Just t -> t
T.TAll (T.MkTVar i) t -> case M.lookup i sub of TAll (T.MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> T.TAll (T.MkTVar i) (apply sub t) Nothing -> TAll (T.MkTVar i) (apply sub t)
Just _ -> apply sub t Just _ -> apply sub t
T.TFun a b -> T.TFun (apply sub a) (apply sub b) TFun a b -> TFun (apply sub a) (apply sub b)
T.TData name a -> T.TData name (map (apply sub) a) TData name a -> TData name (map (apply sub) a)
instance FreeVars (Map T.Ident T.Type) where instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident T.Type -> Set T.Ident free :: Map T.Ident Type -> Set T.Ident
free m = foldl' S.union S.empty (map free $ M.elems m) free m = foldl' S.union S.empty (map free $ M.elems m)
instance SubstType (Map T.Ident T.Type) where instance SubstType (Map T.Ident Type) where
apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type apply :: Subst -> Map T.Ident Type -> Map T.Ident Type
apply s = M.map (apply s) apply s = M.map (apply s)
instance SubstType T.Exp where instance SubstType (T.Exp' Type) where
apply :: Subst -> T.Exp -> T.Exp
apply s = \case apply s = \case
T.EId i -> T.EId i T.EVar i -> T.EVar i
T.ELit lit -> T.ELit lit T.ELit lit -> T.ELit lit
T.ELet (T.Bind (ident, t1) args e1) e2 -> T.ELet (T.Bind (ident, t1) args e1) e2 ->
T.ELet T.ELet
@ -499,19 +467,18 @@ instance SubstType T.Exp where
T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2) T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2)
T.EAbs ident e -> T.EAbs ident (apply s e) T.EAbs ident e -> T.EAbs ident (apply s e)
T.ECase e brnch -> T.ECase (apply s e) (apply s brnch) T.ECase e brnch -> T.ECase (apply s e) (apply s brnch)
T.EInj{} -> error "implement"
instance SubstType T.Branch where instance SubstType (T.Branch' Type) where
apply :: Subst -> T.Branch -> T.Branch
apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e) apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e)
instance SubstType T.Pattern where instance SubstType (T.Pattern' Type) where
apply :: Subst -> T.Pattern -> T.Pattern
apply s = \case apply s = \case
T.PVar (iden, t) -> T.PVar (iden, apply s t) T.PVar (iden, t) -> T.PVar (iden, apply s t)
T.PLit (lit, t) -> T.PLit (lit, apply s t) T.PLit (lit, t) -> T.PLit (lit, apply s t)
T.PInj i ps -> T.PInj i $ apply s ps T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i T.PEnum i -> T.PEnum i
instance SubstType a => SubstType [a] where instance SubstType a => SubstType [a] where
apply s = map (apply s) apply s = map (apply s)
@ -519,7 +486,7 @@ instance SubstType a => SubstType [a] where
instance (SubstType a, SubstType b) => SubstType (a, b) where instance (SubstType a, SubstType b) => SubstType (a, b) where
apply s (a, b) = (apply s a, apply s b) apply s (a, b) = (apply s a, apply s b)
instance SubstType T.Id where instance SubstType (T.Id' Type) where
apply s (name, t) = (name, apply s t) apply s (name, t) = (name, apply s t)
-- | Apply substitutions to the environment. -- | Apply substitutions to the environment.
@ -531,7 +498,7 @@ nullSubst :: Subst
nullSubst = M.empty nullSubst = M.empty
-- | Generate a new fresh variable and increment the state counter -- | Generate a new fresh variable and increment the state counter
fresh :: Infer T.Type fresh :: Infer Type
fresh = do fresh = do
c <- gets nextChar c <- gets nextChar
n <- gets count n <- gets count
@ -545,34 +512,34 @@ fresh = do
fresh fresh
else else
if n == 0 if n == 0
then return . T.TVar . T.MkTVar . T.Ident $ [c] then return . TVar . T.MkTVar $ LIdent [c]
else return . T.TVar . T.MkTVar . T.Ident $ [c] ++ show n else return . TVar . T.MkTVar . LIdent $ [c] ++ show n
next :: Char -> Char next :: Char -> Char
next 'z' = 'a' next 'z' = 'a'
next a = succ a next a = succ a
-- | Run the monadic action with an additional binding -- | Run the monadic action with an additional binding
withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> T.Type -> m a -> m a withBinding :: (Monad m, MonadReader Ctx m) => T.Ident -> Type -> m a -> m a
withBinding i p = local (\st -> st{vars = M.insert i p (vars st)}) withBinding i p = local (\st -> st{vars = M.insert i p (vars st)})
-- | Run the monadic action with several additional bindings -- | Run the monadic action with several additional bindings
withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, T.Type)] -> m a -> m a withBindings :: (Monad m, MonadReader Ctx m) => [(T.Ident, Type)] -> m a -> m a
withBindings xs = withBindings xs =
local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs}) local (\st -> st{vars = foldl' (flip (uncurry M.insert)) (vars st) xs})
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe T.Type -> Infer () insertSig :: T.Ident -> Maybe Type -> Infer ()
insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)}) insertSig i t = modify (\st -> st{sigs = M.insert i t (sigs st)})
-- | Insert a constructor with its data type -- | Insert a constructor with its data type
insertConstr :: T.Ident -> T.Type -> Infer () insertConstr :: T.Ident -> Type -> Infer ()
insertConstr i t = insertConstr i t =
modify (\st -> st{constructors = M.insert i t (constructors st)}) modify (\st -> st{constructors = M.insert i t (constructors st)})
-------- PATTERN MATCHING --------- -------- PATTERN MATCHING ---------
checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type) checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type)
checkCase _ [] = throwError "Atleast one case required" checkCase _ [] = throwError "Atleast one case required"
checkCase expT brnchs = do checkCase expT brnchs = do
(subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs (subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
@ -594,23 +561,23 @@ checkCase expT brnchs = do
let comp = sub2 `compose` sub1 `compose` sub0 let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp injs, apply comp returns_type) return (comp, apply comp injs, apply comp returns_type)
inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type) inferBranch :: Branch -> Infer (Subst, Type, T.Branch' Type, Type)
inferBranch (Branch pat expr) = do inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat newPat@(pat, branchT) <- inferPattern pat
(sub, newExp@(_, exprT)) <- withPattern pat (algoW expr) (sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT) return (sub, apply sub branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
withPattern :: T.Pattern -> Infer a -> Infer a withPattern :: T.Pattern' Type -> Infer a -> Infer a
withPattern p ma = case p of withPattern p ma = case p of
T.PVar (x, t) -> withBinding x t ma T.PVar (x, t) -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma T.PLit _ -> ma
T.PCatch -> ma T.PCatch -> ma
T.PEnum _ -> ma T.PEnum _ -> ma
inferPattern :: Pattern -> Infer (T.Pattern, T.Type) inferPattern :: Pattern -> Infer (T.Pattern' Type, Type)
inferPattern = \case inferPattern = \case
PLit lit -> let lt = litType lit in return (T.PLit (toNew lit, lt), lt) PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt)
PInj constr patterns -> do PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . constructors) t <- gets (M.lookup (coerce constr) . constructors)
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
@ -644,28 +611,28 @@ inferPattern = \case
++ show (typeLength t - 1) ++ show (typeLength t - 1)
++ " arguments but has been given 0" ++ " arguments but has been given 0"
) )
let (T.TData _data _ts) = t -- nasty nasty let (TData _data _ts) = t -- nasty nasty
frs <- mapM (const fresh) _ts frs <- mapM (const fresh) _ts
return (T.PEnum $ coerce p, T.TData _data frs) return (T.PEnum $ coerce p, TData _data frs)
PVar x -> do PVar x -> do
fr <- fresh fr <- fresh
let pvar = T.PVar (coerce x, fr) let pvar = T.PVar (coerce x, fr)
return (pvar, fr) return (pvar, fr)
flattenType :: T.Type -> [T.Type] flattenType :: Type -> [Type]
flattenType (T.TFun a b) = flattenType a <> flattenType b flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a] flattenType a = [a]
typeLength :: T.Type -> Int typeLength :: Type -> Int
typeLength (T.TFun a b) = typeLength a + typeLength b typeLength (TFun a b) = typeLength a + typeLength b
typeLength _ = 1 typeLength _ = 1
litType :: Lit -> T.Type litType :: Lit -> Type
litType (LInt _) = int litType (LInt _) = int
litType (LChar _) = char litType (LChar _) = char
int = T.TLit "Int" int = TLit "Int"
char = T.TLit "Char" char = TLit "Char"
partitionType :: partitionType ::
Int -> -- Number of parameters to apply Int -> -- Number of parameters to apply
@ -676,8 +643,8 @@ partitionType = go []
go acc 0 t = (acc, t) go acc 0 t = (acc, t)
go acc i t = case t of go acc i t = case t of
TAll tvar t' -> second (TAll tvar) $ go acc i t' TAll tvar t' -> second (TAll tvar) $ go acc i t'
TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2 TFun t1 t2 -> go (acc <> [t1]) (i - 1) t2
_ -> error "Number of parameters and type doesn't match" _ -> error "Number of parameters and type doesn't match"
exprErr :: Infer a -> Exp -> Infer a exprErr :: Infer a -> Exp -> Infer a
exprErr ma exp = exprErr ma exp =
@ -691,3 +658,19 @@ unzip4 =
) )
([], [], [], []) ([], [], [], [])
newtype Ctx = Ctx {vars :: Map T.Ident Type}
deriving (Show)
data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Maybe Type)
, constructors :: Map T.Ident Type
, takenTypeVars :: Set T.Ident
}
deriving (Show)
type Error = String
type Subst = Map T.Ident Type
type Infer = StateT Env (ReaderT Ctx (ExceptT Error Identity))

View file

@ -1,24 +1,24 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PatternSynonyms #-}
module TypeChecker.TypeCheckerIr (
module Grammar.Abs,
module TypeChecker.TypeCheckerIr,
) where
module TypeChecker.TypeCheckerIr import Data.String (IsString)
( module Grammar.Abs import Grammar.Abs (Lit (..), TVar (..))
, module TypeChecker.TypeCheckerIr import Grammar.Print
) where import Prelude
import Prelude qualified as C (Eq, Ord, Read, Show)
import Data.String (IsString)
import Grammar.Abs (Lit (..), TVar (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program' t = Program [Def' t] newtype Program' t = Program [Def' t]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Def' t = DBind (Bind' t) data Def' t
| DData (Data' t) = DBind (Bind' t)
deriving (C.Eq, C.Ord, C.Show, C.Read) | DData (Data' t)
deriving (C.Eq, C.Ord, C.Show, C.Read)
data Type data Type
= TLit Ident = TLit Ident
@ -26,24 +26,24 @@ data Type
| TData Ident [Type] | TData Ident [Type]
| TFun Type Type | TFun Type Type
| TAll TVar Type | TAll TVar Type
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Data' t = Data t [Inj' t] data Data' t = Data t [Inj' t]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Inj' t = Inj Ident t data Inj' t = Inj Ident t
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
newtype Ident = Ident String newtype Ident = Ident String
deriving (C.Eq, C.Ord, C.Show, C.Read, IsString) deriving (C.Eq, C.Ord, C.Show, C.Read, IsString)
data Pattern' t data Pattern' t
= PVar (Id' t) -- TODO should be Ident = PVar (Id' t) -- TODO should be Ident
| PLit (Lit, t) -- TODO should be Lit | PLit (Lit, t) -- TODO should be Lit
| PCatch | PCatch
| PEnum Ident | PEnum Ident
| PInj Ident [Pattern' t] -- TODO should be (Pattern' t, t) | PInj Ident [Pattern' t] -- TODO should be (Pattern' t, t)
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Exp' t data Exp' t
= EVar Ident = EVar Ident
@ -52,18 +52,18 @@ data Exp' t
| ELet (Bind' t) (ExpT' t) | ELet (Bind' t) (ExpT' t)
| EApp (ExpT' t) (ExpT' t) | EApp (ExpT' t) (ExpT' t)
| EAdd (ExpT' t) (ExpT' t) | EAdd (ExpT' t) (ExpT' t)
| EAbs Ident (ExpT' t) | EAbs Ident (ExpT' t)
| ECase (ExpT' t) [Branch' t] | ECase (ExpT' t) [Branch' t]
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
type Id' t = (Ident, t) type Id' t = (Ident, t)
type ExpT' t = (Exp' t, t) type ExpT' t = (Exp' t, t)
data Bind' t = Bind (Id' t) [Id' t] (ExpT' t) data Bind' t = Bind (Id' t) [Id' t] (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
data Branch' t = Branch (Pattern' t, t) (ExpT' t) data Branch' t = Branch (Pattern' t, t) (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (C.Eq, C.Ord, C.Show, C.Read)
instance Print Ident where instance Print Ident where
prt i (Ident s) = prt i s prt i (Ident s) = prt i s
@ -72,127 +72,143 @@ instance Print t => Print (Program' t) where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print t => Print (Bind' t) where instance Print t => Print (Bind' t) where
prt i (Bind sig@(name, _) parms rhs) = prPrec i 0 $ concatD prt i (Bind sig@(name, _) parms rhs) =
[ prtSig sig prPrec i 0 $
, prt 0 name concatD
, prtIdPs 0 parms [ prtSig sig
, doc $ showString "=" , prt 0 name
, prt 0 rhs , prtIdPs 0 parms
] , doc $ showString "="
, prt 0 rhs
]
prtSig :: Print t => Id' t -> Doc prtSig :: Print t => Id' t -> Doc
prtSig (name, t) = concatD [ prt 0 name prtSig (name, t) =
, doc $ showString ":" concatD
, prt 0 t [ prt 0 name
, doc $ showString ";" , doc $ showString ":"
] , prt 0 t
, doc $ showString ";"
]
instance Print t => Print (ExpT' t) where instance Print t => Print (ExpT' t) where
prt i (e, t) = concatD [ doc $ showString "(" prt i (e, t) =
, prt i e concatD
, doc $ showString "," [ doc $ showString "("
, prt i t , prt i e
, doc $ showString ")" , doc $ showString ","
] , prt i t
, doc $ showString ")"
]
instance Print t => Print [Bind' t] where instance Print t => Print [Bind' t] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Print t => Int -> [Id' t] -> Doc prtIdPs :: Print t => Int -> [Id' t] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prt i) prtIdPs i = prPrec i 0 . concatD . map (prt i)
instance Print t => Print (Id' t) where instance Print t => Print (Id' t) where
prt i (name, t) = concatD [ doc $ showString "(" prt i (name, t) =
, prt i name concatD
, doc $ showString "," [ doc $ showString "("
, prt i t , prt i name
, doc $ showString ")" , doc $ showString ","
] , prt i t
, doc $ showString ")"
]
instance Print t => Print (Exp' t) where instance Print t => Print (Exp' t) where
prt i = \case prt i = \case
EVar name -> prPrec i 3 $ prt 0 name EVar name -> prPrec i 3 $ prt 0 name
EInj name -> prPrec i 3 $ prt 0 name EInj name -> prPrec i 3 $ prt 0 name
ELit lit -> prPrec i 3 $ prt 0 lit ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e -> prPrec i 3 $ concatD ELet b e ->
[ doc $ showString "let" prPrec i 3 $
, prt 0 b concatD
, doc $ showString "in" [ doc $ showString "let"
, prt 0 e , prt 0 b
] , doc $ showString "in"
EApp e1 e2 -> prPrec i 2 $ concatD , prt 0 e
[ prt 2 e1 ]
, prt 3 e2 EApp e1 e2 ->
] prPrec i 2 $
EAdd e1 e2 -> prPrec i 1 $ concatD concatD
[ prt 1 e1 [ prt 2 e1
, doc $ showString "+" , prt 3 e2
, prt 2 e2 ]
] EAdd e1 e2 ->
EAbs v e -> prPrec i 0 $ concatD prPrec i 1 $
[ doc $ showString "\\" concatD
, prt 0 v [ prt 1 e1
, doc $ showString "." , doc $ showString "+"
, prt 0 e , prt 2 e2
] ]
EAbs v e ->
ECase e branches -> prPrec i 0 $ concatD prPrec i 0 $
[ doc $ showString "case" concatD
, prt 0 e [ doc $ showString "\\"
, doc $ showString "of" , prt 0 v
, doc $ showString "{" , doc $ showString "."
, prt 0 branches , prt 0 e
, doc $ showString "}" ]
] ECase e branches ->
prPrec i 0 $
concatD
[ doc $ showString "case"
, prt 0 e
, doc $ showString "of"
, doc $ showString "{"
, prt 0 branches
, doc $ showString "}"
]
instance Print t => Print (Branch' t) where instance Print t => Print (Branch' t) where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
instance Print t => Print [Branch' t] where instance Print t => Print [Branch' t] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print t => Print (Def' t) where instance Print t => Print (Def' t) where
prt i = \case prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind]) DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_]) DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print t => Print (Data' t) where instance Print t => Print (Data' t) where
prt i = \case prt i = \case
Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")]) Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")])
instance Print t => Print (Inj' t) where instance Print t => Print (Inj' t) where
prt i = \case prt i = \case
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
instance Print t => Print (Pattern' t) where instance Print t => Print (Pattern' t) where
prt i = \case prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name]) PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")]) PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name]) PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print t => Print [Def' t] where instance Print t => Print [Def' t] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print [Type] where instance Print [Type] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ (x:xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
instance Print Type where instance Print Type where
prt i = \case prt i = \case
TLit uident -> prPrec i 1 (concatD [prt 0 uident]) TLit uident -> prPrec i 1 (concatD [prt 0 uident])
TVar tvar -> prPrec i 1 (concatD [prt 0 tvar]) TVar tvar -> prPrec i 1 (concatD [prt 0 tvar])
TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")]) TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
TAll tvar type_ -> prPrec i 0 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_]) TAll tvar type_ -> prPrec i 0 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_])
type Program = Program' Type type Program = Program' Type
type Def = Def' Type type Def = Def' Type
@ -201,9 +217,8 @@ type Bind = Bind' Type
type Branch = Branch' Type type Branch = Branch' Type
type Pattern = Pattern' Type type Pattern = Pattern' Type
type Inj = Inj' Type type Inj = Inj' Type
type Exp = Exp' Type type Exp = Exp' Type
type ExpT = ExpT' Type type ExpT = ExpT' Type
type Id = Id' Type type Id = Id' Type
pattern DBind' id vars expt = DBind (Bind id vars expt) pattern DBind' id vars expt = DBind (Bind id vars expt)
pattern DData' typ injs = DData (Data typ injs) pattern DData' typ injs = DData (Data typ injs)