Add closures and fix lets in monomorphizer

This commit is contained in:
Martin Fredin 2023-05-06 22:49:08 +02:00
parent 677a200a15
commit 72e599d5de
26 changed files with 1440 additions and 692 deletions

View file

@ -43,6 +43,7 @@ executable language
TypeChecker.ReportTEVar TypeChecker.ReportTEVar
TypeChecker.RemoveForall TypeChecker.RemoveForall
LambdaLifter LambdaLifter
LambdaLifterIr
Monomorphizer.Monomorphizer Monomorphizer.Monomorphizer
Monomorphizer.MonomorphizerIr Monomorphizer.MonomorphizerIr
Monomorphizer.MorbIr Monomorphizer.MorbIr
@ -101,6 +102,8 @@ Test-suite language-testsuite
TypeChecker.TypeChecker TypeChecker.TypeChecker
AnnForall AnnForall
ReportForall ReportForall
LambdaLifterIr
LambdaLifter
TypeChecker.TypeCheckerHm TypeChecker.TypeCheckerHm
TypeChecker.TypeCheckerBidir TypeChecker.TypeCheckerBidir
TypeChecker.ReportTEVar TypeChecker.ReportTEVar

View file

@ -0,0 +1,6 @@
add : Int -> Int -> Int -> Int
add x y z = x + y + z
main = add 8 6 2

View file

@ -0,0 +1,7 @@
apply : (Int -> Int) -> Int -> Int
apply f y = f y
main = apply (\y. y + y) 5

View file

@ -0,0 +1,10 @@
apply : (Int -> Int) -> Int -> Int
apply f z = f z
main =
let x = 10 in
apply (\y. y + x) 6

View file

@ -0,0 +1,15 @@
data List a where
Nil : List a
Cons : a -> List a -> List a
foldr : (a -> b -> b) -> b -> List a -> b
foldr f y xs = case xs of
Nil => y
Cons x xs => f x (foldr f y xs)
main = let z = 2 in foldr (\x.\y. x + y + z) 0 (Cons 1000 (Cons 100 Nil))

View file

@ -0,0 +1,25 @@
data List (a) where
Nil : List (a)
Cons : a -> List (a) -> List (a)
map : (a -> b) -> List (a) -> List (b)
map f xs = case xs of
Nil => Nil
Cons x xs => Cons (f x) (map f xs)
add : Int -> Int -> Int
add x y = x + y
foldr : (a -> b -> b) -> b -> List (a) -> b
foldr f y xs = case xs of
Nil => y
Cons x xs => f x (foldr f y xs)
f : List (Int)
f = ((\x.\ys. map (\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil)))
-- [5, 6]
main : Int
main = foldr add 0 f

View file

@ -0,0 +1,21 @@
data List a where
Nil : List a
Cons : a -> List a -> List a
map : (a -> b) -> List a -> List b
map f xs = case xs of
Nil => Nil
Cons x xs => Cons (f x) (map f xs)
f : List Int
f = (\x.\ys. map (\y. y + x) ys) 4 (Cons 1 (Cons 2 Nil))
-- [5, 6]
sum : List Int -> Int
sum xs = case xs of
Nil => 0
Cons x xs => x + sum xs
main = sum f

View file

@ -0,0 +1,3 @@
main = let x = 10 in 6 + x

View file

@ -0,0 +1,16 @@
data List a where
Nil : List a
Cons : a -> List a -> List a
map : (a -> b) -> List a -> List b
map f xs = case xs of
Nil => Nil
Cons x xs => Cons (f x) (map f xs)
sum : List Int -> Int
sum xs = case xs of
Nil => 0
Cons x xs => x + (sum xs)
main = let y = 10 in sum (map (\x. x + y) (Cons 2 (Cons 4 Nil)))

View file

@ -0,0 +1,7 @@
f = 10
main = f + 6

View file

@ -1,25 +1,25 @@
module Codegen.Auxillary where module Codegen.Auxillary where
import Codegen.LlvmIr (LLVMType (..), LLVMValue (..)) import Codegen.LlvmIr (LLVMType (..), LLVMValue (..))
import Control.Monad (foldM_) import Control.Monad (foldM_)
import Monomorphizer.MonomorphizerIr as MIR (ExpT, Type (..)) import Monomorphizer.MonomorphizerIr as MIR (Exp, T, Type (..))
import TypeChecker.TypeCheckerIr qualified as TIR import qualified TypeChecker.TypeCheckerIr as TIR
type2LlvmType :: MIR.Type -> LLVMType type2LlvmType :: MIR.Type -> LLVMType
type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of
"Int" -> I64 "Int" -> I64
"Char" -> I8 "Char" -> I8
"Bool" -> I1 "Bool" -> I1
_ -> CustomType id _ -> CustomType id
type2LlvmType (MIR.TFun t xs) = do type2LlvmType (MIR.TFun t xs) = do
let (t', xs') = function2LLVMType xs [type2LlvmType t] let (t', xs') = function2LLVMType xs [type2LlvmType t]
Function t' xs' Function t' xs'
where where
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s) function2LLVMType x s = (type2LlvmType x, s)
getType :: ExpT -> LLVMType getType :: T Exp -> LLVMType
getType (_, t) = type2LlvmType t getType (_, t) = type2LlvmType t
extractTypeName :: MIR.Type -> TIR.Ident extractTypeName :: MIR.Type -> TIR.Ident
@ -30,21 +30,21 @@ extractTypeName (MIR.TFun t xs) =
in TIR.Ident $ i <> "_$_" <> is in TIR.Ident $ i <> "_$_" <> is
valueGetType :: LLVMValue -> LLVMType valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64 valueGetType (VInteger _) = I64
valueGetType (VChar _) = I8 valueGetType (VChar _) = I8
valueGetType (VIdent _ t) = t valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
valueGetType (VFunction _ _ t) = t valueGetType (VFunction _ _ t) = t
typeByteSize :: LLVMType -> Integer typeByteSize :: LLVMType -> Integer
typeByteSize I1 = 1 typeByteSize I1 = 1
typeByteSize I8 = 1 typeByteSize I8 = 1
typeByteSize I32 = 4 typeByteSize I32 = 4
typeByteSize I64 = 8 typeByteSize I64 = 8
typeByteSize Ptr = 8 typeByteSize Ptr = 8
typeByteSize (Ref _) = 8 typeByteSize (Ref _) = 8
typeByteSize (Function _ _) = 8 typeByteSize (Function _ _) = 8
typeByteSize (Array n t) = n * typeByteSize t typeByteSize (Array n t) = n * typeByteSize t
typeByteSize (CustomType _) = 8 typeByteSize (CustomType _) = 8
enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m () enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m ()

View file

@ -1,18 +1,24 @@
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (generateCode) where module Codegen.Codegen (generateCode) where
import Codegen.CompilerState ( import Codegen.CompilerState (CodeGenerator (..),
CodeGenerator (instructions), StructType (inst),
initCodeGenerator, initCodeGenerator)
) import Codegen.Emits (compileScs)
import Codegen.Emits (compileScs) import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw),
import Codegen.LlvmIr as LIR (llvmIrToString) llvmIrToString)
import Control.Monad.State ( import Control.Monad.State (execStateT)
execStateT, import Data.Functor ((<&>))
) import Data.List (sortBy)
import Data.List (sortBy) import qualified Data.Map as Map
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..), Def (DBind, DData), Program (..), Type (TLit)) import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..),
import TypeChecker.TypeCheckerIr (Ident (..)) Def (DBind, DData),
Program (..),
Type (TLit))
import TypeChecker.TypeCheckerIr (Ident (..))
{- | Compiles an AST and produces a LLVM Ir string. {- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to An easy way to actually "compile" this output is to
@ -20,16 +26,43 @@ import TypeChecker.TypeCheckerIr (Ident (..))
-} -}
generateCode :: MIR.Program -> Bool -> Err String generateCode :: MIR.Program -> Bool -> Err String
generateCode (MIR.Program scs) addGc = do generateCode (MIR.Program scs) addGc = do
let tree = filter (not . detectPrelude) (sortBy lowData scs) let tree = filter (not . detectPrelude) (sortBy lowData scs)
let codegen = initCodeGenerator addGc tree codegen = initCodeGenerator addGc tree
llvmIrToString . instructions <$> execStateT (compileScs tree) codegen
-- Append instructions
execStateT (compileScs tree) codegen <&> \state ->
llvmIrToString $ defaultStart
++ (if addGc then gcStart else [])
++ map inst (Map.elems state.structTypes)
++ state.instructions
detectPrelude :: Def -> Bool detectPrelude :: Def -> Bool
detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True
detectPrelude (DBind (Bind (Ident ('l' : 't' : '$' : _), _) _ _)) = True detectPrelude (DBind (Bind (Ident ('l' : 't' : '$' : _), _) _ _)) = True
detectPrelude _ = False detectPrelude _ = False
lowData :: Def -> Def -> Ordering lowData :: Def -> Def -> Ordering
lowData (DData _) (DBind _) = LT lowData (DData _) (DBind _) = LT
lowData (DBind _) (DData _) = GT lowData (DBind _) (DData _) = GT
lowData _ _ = EQ lowData _ _ = EQ
defaultStart :: [LLVMIr]
defaultStart =
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
]
gcStart :: [LLVMIr]
gcStart =
[ UnsafeRaw "declare external void @cheap_init()\n"
, UnsafeRaw "declare external ptr @cheap_alloc(i64)\n"
, UnsafeRaw "declare external void @cheap_dispose()\n"
, UnsafeRaw "declare external ptr @cheap_the()\n"
, UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n"
, UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n"
]

View file

@ -1,46 +1,101 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.CompilerState where module Codegen.CompilerState where
import Auxiliary (snoc) import Auxiliary (snoc)
import Codegen.Auxillary (type2LlvmType, typeByteSize) import Codegen.Auxillary (type2LlvmType, typeByteSize)
import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw), import Codegen.LlvmIr as LIR (LLVMIr (SetVariable, Type),
LLVMType) LLVMType (CustomType, Function, I64, Ptr),
import Control.Monad.State (StateT, gets, modify) LLVMValue (VFunction, VIdent),
Visibility (Global),
typeOf)
import Control.Monad.State (StateT, gets, modify, void)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR import Monomorphizer.MonomorphizerIr (Ident (..), Inj (..), T,
flattenType)
import qualified Monomorphizer.MonomorphizerIr as MIR
import qualified TypeChecker.TypeCheckerIr as TIR import qualified TypeChecker.TypeCheckerIr as TIR
-- | The record used as the code generator state -- | The record used as the code generator state
data CodeGenerator = CodeGenerator data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr] { instructions :: [LLVMIr]
, functions :: Map MIR.Id FunctionInfo , functions :: Map (T Ident) FunctionInfo
, customTypes :: Map LLVMType Integer , customTypes :: Map LLVMType Integer
, constructors :: Map TIR.Ident ConstructorInfo , constructors :: Map Ident ConstructorInfo
, variableCount :: Integer , variableCount :: Integer
, labelCount :: Integer , labelCount :: Integer
, gcEnabled :: Bool , gcEnabled :: Bool
, structTypes :: Map Ident StructType
-- ^ Custom stucture types
, locals :: [(Ident, LocalElem)]
-- ^ Arguments and variables in local environment
, globals :: Map Ident (LLVMType, LLVMValue)
} }
data StructType = StructType
{ ptr :: LLVMType
, typs :: [LLVMType]
, inst :: LLVMIr
}
data LocalElem = LocalElem
{ typ :: LLVMType
, val :: LLVMValue
}
-- | A state type synonym -- | A state type synonym
type CompilerState a = StateT CodeGenerator Err a type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo data FunctionInfo = FunctionInfo
{ numArgs :: Int { numArgs :: Int
, arguments :: [Id] , arguments :: [T Ident]
} }
deriving (Show) deriving (Show)
data ConstructorInfo = ConstructorInfo data ConstructorInfo = ConstructorInfo
{ numArgsCI :: Int { numArgsCI :: Int
, argumentsCI :: [Id] , argumentsCI :: [T Ident]
, numCI :: Integer , numCI :: Integer
, returnTypeCI :: MIR.Type , returnTypeCI :: MIR.Type
} }
deriving (Show) deriving (Show)
addStructType_ :: Ident -> [LLVMType] -> CompilerState ()
addStructType_ = fmap void . addStructType
addStructType :: Ident -> [LLVMType] -> CompilerState LLVMType
addStructType x ts = do
modify $ \s -> s { structTypes = Map.insert x struct s.structTypes }
pure t
where
struct = StructType
{ ptr = t
, typs = ts
, inst = Type x ts
}
t = CustomType x
-- | Adds a instruction to the CodeGenerator state -- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState () emit :: LLVMIr -> CompilerState ()
emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
-- Add variable to environment
emit l@(SetVariable x _) = modify $ \t ->
t { instructions = Auxiliary.snoc l t.instructions
, locals = snoc (x, local)
t.locals
}
where
local = LocalElem { typ = typeOf l
, val = VIdent x (typeOf l)
}
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l t.instructions }
-- | Increases the variable counter in the CodeGenerator state -- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState () increaseVarCount :: CompilerState ()
@ -63,16 +118,19 @@ getNewLabel = do
{- | Produces a map of functions infos from a list of binds, {- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation. which contains useful data for code generation.
-} -}
getFunctions :: [MIR.Def] -> Map Id FunctionInfo getFunctions :: [MIR.Def] -> Map (T Ident) FunctionInfo
getFunctions bs = Map.fromList $ go bs getFunctions bs = Map.fromList $ go bs
where where
go [] = [] go [] = []
go (MIR.DBind (MIR.Bind id args _) : xs) = go (MIR.DBind (MIR.Bind id args _) : xs) =
(id, FunctionInfo{numArgs = length args, arguments = args}) (id, FunctionInfo { numArgs = length args
: go xs , arguments = args
}
)
: go xs
go (_ : xs) = go xs go (_ : xs) = go xs
createArgs :: [MIR.Type] -> [Id] createArgs :: [MIR.Type] -> [T Ident]
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
{- | Produces a map of functions infos from a list of binds, {- | Produces a map of functions infos from a list of binds,
@ -113,35 +171,43 @@ getTypes bs = Map.fromList $ go bs
variantTypes fi = init $ map type2LlvmType (flattenType fi) variantTypes fi = init $ map type2LlvmType (flattenType fi)
biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
getGlobals :: [MIR.Def] -> Map Ident (LLVMType, LLVMValue)
getGlobals scs = Map.fromList [ go b | MIR.DBind b <- scs ]
where
go bind | x == "main" = let typ = Function I64 []
in (x, (typ, VFunction x Global typ))
| otherwise = (x, (typ, VFunction x Global typ))
where
typ = Function tr $ Ptr : ts
Function tr ts = type2LlvmType' t
(x, t) = case bind of
MIR.Bind xt _ _ -> xt
MIR.BindC _ xt _ _ -> xt
-- Higher order function arguments are replaced with ptr
type2LlvmType' = go []
where
go acc = \case
MIR.TFun (MIR.TFun _ _) t2 -> go (snoc Ptr acc) t2
MIR.TFun t1 t2 -> go (snoc (type2LlvmType t1) acc) t2
t -> Function (type2LlvmType t) acc
initCodeGenerator :: Bool -> [MIR.Def] -> CodeGenerator initCodeGenerator :: Bool -> [MIR.Def] -> CodeGenerator
initCodeGenerator addGc scs = initCodeGenerator addGc scs =
CodeGenerator CodeGenerator
{ instructions = defaultStart <> if addGc then gcStart else [] { instructions = []
, functions = getFunctions scs , functions = getFunctions scs
, constructors = getConstructors scs , constructors = getConstructors scs
, customTypes = getTypes scs , customTypes = getTypes scs
, structTypes = mempty
, variableCount = 0 , variableCount = 0
, labelCount = 0 , labelCount = 0
, gcEnabled = addGc , gcEnabled = addGc
, locals = mempty
, globals = getGlobals scs
} }
defaultStart :: [LLVMIr]
defaultStart =
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
]
gcStart :: [LLVMIr]
gcStart =
[ UnsafeRaw "declare external void @cheap_init()\n"
, UnsafeRaw "declare external ptr @cheap_alloc(i64)\n"
, UnsafeRaw "declare external void @cheap_dispose()\n"
, UnsafeRaw "declare external ptr @cheap_the()\n"
, UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n"
, UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n"
]

View file

@ -1,36 +1,40 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.Emits where module Codegen.Emits where
import Codegen.Auxillary import Auxiliary (snoc)
import Codegen.CompilerState import Codegen.Auxillary
import Codegen.LlvmIr as LIR import Codegen.CompilerState
import Control.Applicative ((<|>)) import Codegen.LlvmIr as LIR
import Control.Monad (when) import Control.Applicative (Applicative (liftA2), (<|>))
import Control.Monad.State (gets, modify) import Control.Monad (forM_, when, zipWithM_)
import Data.Bifunctor qualified as BI import Control.Monad.Extra (whenJust)
import Data.Char (ord) import Control.Monad.State (gets, modify)
import Data.Coerce (coerce) import Data.Char (ord)
import Data.Map qualified as Map import Data.Coerce (coerce)
import Data.Maybe (fromJust, fromMaybe, isNothing) import Data.Foldable.Extra (notNull)
import Data.Tuple.Extra (dupe, first, second) import qualified Data.Map as Map
import Debug.Trace (trace, traceShow) import Data.Maybe (fromJust, fromMaybe, isNothing)
import Grammar.Print import Data.Tuple.Extra (second)
import Monomorphizer.MonomorphizerIr as MIR import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as TIR import Monomorphizer.MonomorphizerIr
compileScs :: [MIR.Def] -> CompilerState ()
compileScs :: [Def] -> CompilerState ()
compileScs [] = do compileScs [] = do
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
mapM_ createConstructor =<< gets (Map.toList . constructors)
-- as a last step create all the constructors -- as a last step create all the constructors
-- //TODO maybe merge this with the data type match? -- //TODO maybe merge this with the data type match?
c <- gets (Map.toList . constructors) where
mapM_ createConstructor (id, ci) = do
( \(id, ci) -> do let t = returnTypeCI ci
let t = returnTypeCI ci t' = type2LlvmType t
let t' = type2LlvmType t x = (mkCxtName, Ptr) : map (second type2LlvmType) ci.argumentsCI
let x = BI.second type2LlvmType <$> argumentsCI ci
emit $ Define FastCC t' id x emit $ Define FastCC t' id x
top <- getNewVar top <- getNewVar
ptr <- getNewVar ptr <- getNewVar
@ -56,7 +60,7 @@ compileScs [] = do
cTypes <- gets customTypes cTypes <- gets customTypes
enumerateOneM_ enumerateOneM_
( \i (TIR.Ident arg_n, arg_t) -> do ( \i (Ident arg_n, arg_t) -> do
let arg_t' = type2LlvmType arg_t let arg_t' = type2LlvmType arg_t
emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i) emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
elemPtr <- getNewVar elemPtr <- getNewVar
@ -78,11 +82,11 @@ compileScs [] = do
heapPtr <- getNewVar heapPtr <- getNewVar
useGc <- gets gcEnabled useGc <- gets gcEnabled
emit $ SetVariable heapPtr (if useGc then GcMalloc s else Malloc s) emit $ SetVariable heapPtr (if useGc then GcMalloc s else Malloc s)
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr heapPtr
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
Nothing -> do Nothing -> do
emit $ Comment "Just store" emit $ Comment "Just store"
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
) )
(argumentsCI ci) (argumentsCI ci)
@ -95,34 +99,83 @@ compileScs [] = do
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
modify $ \s -> s{variableCount = 0} modify $ \s -> s{variableCount = 0}
)
c compileScs (DBind bind : xs) = do
compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
let t_return = type2LlvmType . last . flattenType $ t
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp emit . Comment $ show name <> ": " <> show (fst exp)
let args' = map (second type2LlvmType) args
Function t_return t_args <- gets $ fst
. fromJust
. Map.lookup name
. globals
let args' = zip (mkCxtName : map fst args) t_args
emit $ Define FastCC t_return name args' emit $ Define FastCC t_return name args'
useGc <- gets gcEnabled modify $ \s -> s { locals = foldr insertArg s.locals args' }
when (name == "main") (mapM_ emit (firstMainContent useGc))
functionBody <- exprToValue exp -- Dereference ptr arguments
if name == "main" when (notNull args') $
then mapM_ emit $ lastMainContent useGc functionBody forM_ (tail args') $ \(x, t) -> when (t == Ptr) $ do
else emit $ Ret t_return functionBody let t_deref =
let
Function t ts = type2LlvmType . fromJust $ lookup x args
in
Function t (Ptr : ts)
emit . SetVariable (mkDerefName x)
$ Load t_deref Ptr x
whenJust mcxt loadFreeVars
gcEnabled <- gets gcEnabled
when isMain $ mapM_ emit (firstMainContent gcEnabled)
result <- exprToValue exp
if isMain
then mapM_ emit $ lastMainContent gcEnabled result
else emit $ Ret t_return result
emit DefineEnd emit DefineEnd
modify $ \s -> s{variableCount = 0} -- Reset variable count and empty locals
modify $ \s -> s { variableCount = 0, locals = mempty }
compileScs xs compileScs xs
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do where
let (TIR.Ident outer_id) = extractTypeName typ loadFreeVars cxt = do
emit $ Comment "Load free variables"
zipWithM_ go cxt' [1 ..]
where
go (x, t) i = do
vc <- getNewVar
emit . SetVariable vc
$ GetElementPtrInbounds (CustomType $ mkClosureName name) Ptr (VIdent mkCxtName Ptr)
I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices
emit . SetVariable x $ Load t Ptr vc
cxt' = map (second type2LlvmType) cxt
isMain = name == "main"
(name, args, exp, mcxt) = case bind of
Bind (name, _) args exp -> (name, args, exp, Nothing)
BindC cxt (name, _) args exp -> (name, args, exp, Just cxt)
insertArg (x, t) = snoc (x, LocalElem { val = VIdent x t, typ = t })
compileScs (DData (Data typ ts) : xs) = do
let (Ident outer_id) = extractTypeName typ
-- //TODO this could be extracted from the customTypes map -- //TODO this could be extracted from the customTypes map
let variantTypes fi = init $ map type2LlvmType (flattenType fi) let variantTypes fi = init $ map type2LlvmType (flattenType fi)
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8] -- Add data type (e.g. %List) to top of the file
addStructType_ (Ident outer_id) [I8, Array biggestVariant I8]
typeSets <- gets customTypes typeSets <- gets customTypes
mapM_ mapM_
( \(Inj inner_id fi) -> do ( \(Inj inner_id fi) -> do
let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
emit $ LIR.Type inner_id (I8 : types) -- Add constructor type (e.g. %Cons) to top of the file
addStructType_ inner_id (I8 : types)
) )
ts ts
compileScs xs compileScs xs
@ -149,16 +202,16 @@ lastMainContent False var =
, Ret I64 (VInteger 0) , Ret I64 (VInteger 0)
] ]
compileExp :: ExpT -> CompilerState () compileExp :: T Exp -> CompilerState ()
compileExp (MIR.ELit lit, _t) = emitLit lit compileExp (ELit lit, _t) = emitLit lit
compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2 compileExp (EAdd e1 e2, t) = emitAdd t e1 e2
compileExp (MIR.EVar name, _t) = emitIdent name compileExp (EVar name, _t) = emitIdent name
compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2 compileExp (EApp e1 e2, t) = emitApp t e1 e2
compileExp (MIR.ELet bind e, _) = emitLet bind e compileExp (ELet bind e, _) = emitLet bind e
compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs) compileExp (ECase e cs, t) = emitECased t e (map (t,) cs)
emitLet :: MIR.Bind -> ExpT -> CompilerState () emitLet :: Bind -> T Exp -> CompilerState ()
emitLet (MIR.Bind id [] innerExp) e = do emitLet (Bind id [] innerExp) e = do
evaled <- exprToValue innerExp evaled <- exprToValue innerExp
tempVar <- getNewVar tempVar <- getNewVar
let t = type2LlvmType . snd $ innerExp let t = type2LlvmType . snd $ innerExp
@ -168,14 +221,14 @@ emitLet (MIR.Bind id [] innerExp) e = do
compileExp e compileExp e
emitLet b _ = error $ "Non empty argument list in let-bind " <> show b emitLet b _ = error $ "Non empty argument list in let-bind " <> show b
emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Branch)] -> CompilerState () emitECased :: Type -> T Exp -> [(Type, Branch)] -> CompilerState ()
emitECased t e cases = do emitECased t e cases = do
let cs = snd <$> cases let cs = snd <$> cases
let ty = type2LlvmType t let ty = type2LlvmType t
let rt = type2LlvmType (snd e) let rt = type2LlvmType (snd e)
vs <- exprToValue e vs <- exprToValue e
lbl <- getNewLabel lbl <- getNewLabel
let label = TIR.Ident $ "escape_" <> show lbl let label = Ident $ "escape_" <> show lbl
stackPtr <- getNewVar stackPtr <- getNewVar
emit $ SetVariable stackPtr (Alloca ty) emit $ SetVariable stackPtr (Alloca ty)
mapM_ (emitCases rt ty label stackPtr vs) cs mapM_ (emitCases rt ty label stackPtr vs) cs
@ -192,14 +245,14 @@ emitECased t e cases = do
res <- getNewVar res <- getNewVar
emit $ SetVariable res (Load ty Ptr stackPtr) emit $ SetVariable res (Load ty Ptr stackPtr)
where where
emitCases :: LLVMType -> LLVMType -> TIR.Ident -> TIR.Ident -> LLVMValue -> Branch -> CompilerState () emitCases :: LLVMType -> LLVMType -> Ident -> Ident -> LLVMValue -> Branch -> CompilerState ()
emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, _t) exp) = do emitCases rt ty label stackPtr vs (Branch (PInj consId cs, _t) exp) = do
emit $ Comment "Inj" emit $ Comment "Inj"
cons <- gets constructors cons <- gets constructors
let r = fromJust $ Map.lookup consId cons let r = fromJust $ Map.lookup consId cons
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
consVal <- getNewVar consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0) emit $ SetVariable consVal (ExtractValue rt vs 0)
@ -215,10 +268,10 @@ emitECased t e cases = do
emit $ Store rt vs Ptr castPtr emit $ Store rt vs Ptr castPtr
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr) emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
enumerateOneM_ enumerateOneM_
( \i c -> do ( \i (c, t) -> do
case c of case c of
PVar (x, topT) -> do PVar x -> do
let topT' = type2LlvmType topT let topT' = type2LlvmType t
let botT' = CustomType (coerce consId) let botT' = CustomType (coerce consId)
emit . Comment $ "ident " <> toIr topT' emit . Comment $ "ident " <> toIr topT'
cTypes <- gets customTypes cTypes <- gets customTypes
@ -228,7 +281,7 @@ emitECased t e cases = do
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i) emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
emit $ SetVariable x (Load topT' Ptr deref) emit $ SetVariable x (Load topT' Ptr deref)
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i) else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
PLit (_l, _t) -> error "Nested pattern matching to be implemented" PLit _l -> error "Nested pattern matching to be implemented"
PInj _id _ps -> error "Nested pattern matching to be implemented" PInj _id _ps -> error "Nested pattern matching to be implemented"
PCatch -> pure () PCatch -> pure ()
PEnum _id -> error "Nested pattern matching to be implemented" PEnum _id -> error "Nested pattern matching to be implemented"
@ -238,22 +291,22 @@ emitECased t e cases = do
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitCases _rt ty label stackPtr vs (Branch (MIR.PLit (i, ct), t) exp) = do emitCases _rt ty label stackPtr vs (Branch (PLit i, t) exp) = do
emit $ Comment "Plit" emit $ Comment "Plit"
let i' = case i of let i' = case i of
MIR.LInt i -> VInteger i LInt i -> VInteger i
MIR.LChar i -> VChar (ord i) LChar i -> VChar (ord i)
ns <- getNewVar ns <- getNewVar
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable ns (Icmp LLEq (type2LlvmType ct) vs i') emit $ SetVariable ns (Icmp LLEq (type2LlvmType t) vs i')
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos emit $ Label lbl_succPos
val <- exprToValue exp val <- exprToValue exp
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do emitCases rt ty label stackPtr vs (Branch (PVar id, _) exp) = do
emit $ Comment "Pvar" emit $ Comment "Pvar"
-- //TODO this is pretty disgusting and would heavily benefit from a rewrite -- //TODO this is pretty disgusting and would heavily benefit from a rewrite
valPtr <- getNewVar valPtr <- getNewVar
@ -263,20 +316,20 @@ emitECased t e cases = do
val <- exprToValue exp val <- exprToValue exp
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "True$Bool"), t) exp) = do emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "True$Bool"), t) exp) = do
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 1, TLit "Bool"), t) exp) emitCases rt ty label stackPtr vs (Branch (PLit $ LInt 1, t) exp)
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "False$Bool"), _) exp) = do emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "False$Bool"), _) exp) = do
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 0, TLit "Bool"), t) exp) emitCases rt ty label stackPtr vs (Branch (PLit (LInt 0), t) exp)
emitCases rt ty label stackPtr vs br@(Branch (MIR.PEnum consId, _) exp) = do emitCases rt ty label stackPtr vs br@(Branch (PEnum consId, _) exp) = do
emit $ Comment "Penum" emit $ Comment "Penum"
cons <- gets constructors cons <- gets constructors
let r = Map.lookup consId cons let r = Map.lookup consId cons
when (isNothing r) (error $ "Constructor: '" ++ printTree consId ++ "' does not exist in cons state:\n" ++ show cons ++ "\nin pattern\n'" ++ printTree br ++ "'\n") when (isNothing r) (error $ "Constructor: '" ++ printTree consId ++ "' does not exist in cons state:\n" ++ show cons ++ "\nin pattern\n'" ++ printTree br ++ "'\n")
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
consVal <- getNewVar consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0) emit $ SetVariable consVal (ExtractValue rt vs 0)
@ -295,98 +348,167 @@ emitECased t e cases = do
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitCases _ ty label stackPtr _ (Branch (MIR.PCatch, _) exp) = do emitCases _ ty label stackPtr _ (Branch (PCatch, _) exp) = do
emit $ Comment "Pcatch" emit $ Comment "Pcatch"
val <- exprToValue exp val <- exprToValue exp
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState () emitApp :: Type -> T Exp -> T Exp -> CompilerState ()
emitApp rt e1 e2 = appEmitter e1 e2 [] emitApp rt e1 e2 = do
where ((EVar name, t), args) <- go (EApp e1 e2, rt)
appEmitter :: ExpT -> ExpT -> [ExpT] -> CompilerState () vs <- getNewVar
appEmitter e1 e2 stack = do funcs <- gets functions
let newStack = e2 : stack consts <- gets constructors
case e1 of let visibility =
(MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack fromMaybe Local $
(MIR.EVar name, t) -> do Global <$ Map.lookup name consts
args <- traverse exprToValue newStack <|> Global <$ Map.lookup (name, t) funcs
vs <- getNewVar -- this piece of code could probably be improved, i.e remove the double `const Global`
funcs <- gets functions
consts <- gets constructors
let visibility =
fromMaybe Local $
Global <$ Map.lookup name consts
<|> Global <$ Map.lookup (name, t) funcs
-- this piece of code could probably be improved, i.e remove the double `const Global`
args' = map (first valueGetType . dupe) args
let call =
case name of
TIR.Ident ('l' : 't' : '$' : _) -> Icmp LLSlt I64 (snd (head args')) (snd (args' !! 1))
TIR.Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) -> Sub I64 (snd (head args')) (snd (args' !! 1))
_ -> Call FastCC (type2LlvmType rt) visibility name args'
emit $ Comment $ show rt
emit $ SetVariable vs call
x -> error $ "The unspeakable happened: " <> show x
emitIdent :: TIR.Ident -> CompilerState () call <- case name of
Ident ('l' : 't' : '$' : _) ->
pure $ Icmp LLSlt I64 (snd (head args)) (snd (args !! 1))
Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) ->
pure $ Sub I64 (snd (head args)) (snd (args !! 1))
-- FIXME
_ -> do
let closure_call LocalElem { typ = Ptr, val } = (mkDerefName name, (Ptr, val) : args)
(name, args) <- gets $ maybe (name, (Ptr, VNull) : args) closure_call
. lookup name
. locals
pure $ Call FastCC (type2LlvmType rt) visibility name args
emit $ Comment $ show (type2LlvmType rt)
emit $ SetVariable vs call
where
go :: T Exp -> CompilerState (T Exp, [(LLVMType, LLVMValue)])
go et@(e, _) = case e of
EApp e1 e2@(_, t) -> do
(x, as) <- go e1
a <- exprToValue e2
let t' = type2LlvmType' t
pure (x, snoc (t', a) as)
_ -> pure (et, [])
type2LlvmType' = \case
TFun _ _ -> Ptr
t -> type2LlvmType t
emitIdent :: Ident -> CompilerState ()
emitIdent id = do emitIdent id = do
-- !!this should never happen!! -- !!this should never happen!!
emit $ Comment "This should not have happened!" emit $ Comment "This should not have happened!"
emit $ Variable id emit $ Variable id
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
emitLit :: MIR.Lit -> CompilerState () emitLit :: Lit -> CompilerState ()
emitLit i = do emitLit i = do
-- !!this should never happen!! -- !!this should never happen!!
let (i', t) = case i of let (i', t) = case i of
(MIR.LInt i'') -> (VInteger i'', I64) (LInt i'') -> (VInteger i'', I64)
(MIR.LChar i'') -> (VChar $ ord i'', I8) (LChar i'') -> (VChar $ ord i'', I8)
varCount <- getNewVar varCount <- getNewVar
emit $ Comment "This should not have happened!" emit $ Comment "This should not have happened!"
emit $ SetVariable varCount (Add t i' (VInteger 0)) emit $ SetVariable varCount (Add t i' (VInteger 0))
emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState () emitAdd :: Type -> T Exp -> T Exp -> CompilerState ()
emitAdd t e1 e2 = do emitAdd t e1 e2 = do
v1 <- exprToValue e1 v1 <- exprToValue e1
v2 <- exprToValue e2 v2 <- exprToValue e2
v <- getNewVar v <- getNewVar
emit $ SetVariable v (Add (type2LlvmType t) v1 v2) emit $ SetVariable v (Add (type2LlvmType t) v1 v2)
exprToValue :: ExpT -> CompilerState LLVMValue
exprToValue = \case exprToValue :: T Exp -> CompilerState LLVMValue
(MIR.ELit i, _t) -> pure $ case i of exprToValue et@(e, t) = case e of
(MIR.LInt i) -> VInteger i ELit (LInt i) -> pure $ VInteger i
(MIR.LChar i) -> VChar $ ord i ELit (LChar c) -> pure . VChar $ ord c
(MIR.EVar (TIR.Ident "True$Bool"), _t) -> pure $ VInteger 1
(MIR.EVar (TIR.Ident "False$Bool"), _t) -> pure $ VInteger 0 EVar "True$Bool" -> pure $ VInteger 1
(MIR.EVar name, t) -> do EVar "False$Bool" -> pure $ VInteger 0
funcs <- gets functions
cons <- gets constructors EVar name -> gets (Map.lookup name . globals) >>= \case
let res = Just (typ@(Function _ ts), val) | length ts > 1 -> do
Map.lookup (name, t) funcs type_struct <- addStructType (mkClosureName name) [typ]
<|> ( \c -> emit $ Comment "Allocating structure"
FunctionInfo emit . SetVariable name $ Alloca type_struct
{ numArgs = numArgsCI c emit $ Store typ val Ptr name
, arguments = argumentsCI c pure $ VIdent name Ptr
}
) Just _ | name == "main" -> do
<$> Map.lookup name cons vc <- getNewVar
case res of emit $ SetVariable vc (Call FastCC I64 Global name [])
Just fi -> do pure $ VIdent vc I64
if numArgs fi == 0
then do
vc <- getNewVar Just (Function t_return [_], _) -> do
emit $ vc <- getNewVar
SetVariable emit $ SetVariable vc (Call FastCC t_return Global name [(Ptr, VNull)])
vc pure $ VIdent vc t_return
(Call FastCC (type2LlvmType t) Global name [])
pure $ VIdent vc (type2LlvmType t) Just _ -> error "Bad"
else pure $ VFunction name Global (type2LlvmType t)
Nothing -> pure $ VIdent name (type2LlvmType t) Nothing -> gets (Map.lookup name . constructors) >>= \case
e -> do
compileExp e Just ConstructorInfo {numArgsCI}
| numArgsCI == 0 -> do
vc <- getNewVar
emit $ SetVariable vc call
pure $ VIdent vc (type2LlvmType t)
| otherwise -> pure $ VFunction name Global (type2LlvmType t)
where
call = Call FastCC (type2LlvmType t) Global name []
Nothing -> gets $ val
. fromJust
. lookup name
. locals
EVarC cxt name -> do
let cxt' = flip map cxt $ \(x, t) -> let t' = type2LlvmType t
in (t', VIdent x t')
cxt'' <- gets $ (:cxt')
. fromJust
. Map.lookup name
. globals
-- Create a new type for function pointer and arguments
type_struct <- addStructType (mkClosureName name) $ map fst cxt''
emit $ Comment "Allocating structure"
emit . SetVariable name $ Alloca type_struct
let ptr_struct = VIdent name Ptr
storeArg (t, v) i = do
vc <- getNewVar
emit . SetVariable vc
$ GetElementPtrInbounds type_struct Ptr ptr_struct
I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices
emit $ Store t v Ptr vc
-- Store arguments in structure
zipWithM_ storeArg cxt'' [0 ..]
pure ptr_struct
_ -> do
compileExp et
v <- getVarCount v <- getVarCount
pure $ VIdent (TIR.Ident $ show v) (getType e) pure $ VIdent (Ident $ show v) (getType et)
mkClosureName :: Ident -> Ident
mkClosureName (Ident s) = Ident $ "Closure_" ++ s
mkDerefName :: Ident -> Ident
mkDerefName (Ident s) = Ident $ s ++ "_deref"
mkCxtName :: Ident
mkCxtName = Ident "cxt"

View file

@ -9,17 +9,18 @@ module Codegen.LlvmIr (
Visibility (..), Visibility (..),
CallingConvention (..), CallingConvention (..),
ToIr (..), ToIr (..),
typeOf
) where ) where
import Data.List (intercalate) import Data.List (intercalate)
import TypeChecker.TypeCheckerIr (Ident (..)) import TypeChecker.TypeCheckerIr (Ident (..))
data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving (Show, Eq, Ord) data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving (Show, Eq, Ord)
instance ToIr CallingConvention where instance ToIr CallingConvention where
toIr :: CallingConvention -> String toIr :: CallingConvention -> String
toIr TailCC = "tailcc" toIr TailCC = "tailcc"
toIr FastCC = "fastcc" toIr FastCC = "fastcc"
toIr CCC = "ccc" toIr CCC = "ccc"
toIr ColdCC = "coldcc" toIr ColdCC = "coldcc"
-- | A datatype which represents some basic LLVM types -- | A datatype which represents some basic LLVM types
@ -38,6 +39,9 @@ data LLVMType
class ToIr a where class ToIr a where
toIr :: a -> String toIr :: a -> String
instance ToIr a => ToIr [a] where
toIr = concatMap toIr
instance ToIr LLVMType where instance ToIr LLVMType where
toIr :: LLVMType -> String toIr :: LLVMType -> String
toIr = \case toIr = \case
@ -66,8 +70,8 @@ data LLVMComp
instance ToIr LLVMComp where instance ToIr LLVMComp where
toIr :: LLVMComp -> String toIr :: LLVMComp -> String
toIr = \case toIr = \case
LLEq -> "eq" LLEq -> "eq"
LLNe -> "ne" LLNe -> "ne"
LLUgt -> "ugt" LLUgt -> "ugt"
LLUge -> "uge" LLUge -> "uge"
LLUlt -> "ult" LLUlt -> "ult"
@ -80,7 +84,7 @@ instance ToIr LLVMComp where
data Visibility = Local | Global deriving (Show, Eq, Ord) data Visibility = Local | Global deriving (Show, Eq, Ord)
instance ToIr Visibility where instance ToIr Visibility where
toIr :: Visibility -> String toIr :: Visibility -> String
toIr Local = "%" toIr Local = "%"
toIr Global = "@" toIr Global = "@"
{- | Represents a LLVM "value", as in an integer, a register variable, {- | Represents a LLVM "value", as in an integer, a register variable,
@ -92,16 +96,18 @@ data LLVMValue
| VIdent Ident LLVMType | VIdent Ident LLVMType
| VConstant String | VConstant String
| VFunction Ident Visibility LLVMType | VFunction Ident Visibility LLVMType
| VNull
deriving (Show, Eq, Ord) deriving (Show, Eq, Ord)
instance ToIr LLVMValue where instance ToIr LLVMValue where
toIr :: LLVMValue -> String toIr :: LLVMValue -> String
toIr v = case v of toIr v = case v of
VInteger i -> show i VInteger i -> show i
VChar i -> show i VChar i -> show i
VIdent (Ident n) _ -> "%" <> n VIdent (Ident n) _ -> "%" <> n
VFunction (Ident n) vis _ -> toIr vis <> n VFunction (Ident n) vis _ -> toIr vis <> n
VConstant s -> "c" <> show s VConstant s -> "c" <> show s
VNull -> "null"
type Params = [(Ident, LLVMType)] type Params = [(Ident, LLVMType)]
type Args = [(LLVMType, LLVMValue)] type Args = [(LLVMType, LLVMValue)]
@ -139,6 +145,21 @@ data LLVMIr
-- instructions should be used in its place -- instructions should be used in its place
deriving (Show, Eq, Ord) deriving (Show, Eq, Ord)
-- TODO add missing clauses
typeOf :: LLVMIr -> LLVMType
typeOf = \case
Add t _ _ -> t
Sub t _ _ -> t
Mul t _ _ -> t
Div t _ _ -> t
Load t _ _ -> t
Store t _ _ _ -> t
Type x _ -> CustomType x
SetVariable _ ir -> typeOf ir
-- | Converts a list of LLVMIr instructions to a string -- | Converts a list of LLVMIr instructions to a string
llvmIrToString :: [LLVMIr] -> String llvmIrToString :: [LLVMIr] -> String
llvmIrToString = go 0 llvmIrToString = go 0
@ -147,9 +168,9 @@ llvmIrToString = go 0
go _ [] = mempty go _ [] = mempty
go i (x : xs) = do go i (x : xs) = do
let (i', n) = case x of let (i', n) = case x of
Define{} -> (i + 1, 0) Define{} -> (i + 1, 0)
DefineEnd -> (i - 1, 0) DefineEnd -> (i - 1, 0)
_ -> (i, i) _ -> (i, i)
insToString n x <> go i' xs insToString n x <> go i' xs
-- \| Converts a LLVM inststruction to a String, allowing for printing etc. -- \| Converts a LLVM inststruction to a String, allowing for printing etc.
@ -224,10 +245,10 @@ llvmIrToString = go 0
, ")\n" , ")\n"
] ]
(Alloca t) -> unwords ["alloca", toIr t, "\n"] (Alloca t) -> unwords ["alloca", toIr t, "\n"]
(Malloc t) -> (Malloc t) ->
concat concat
[ "call ptr @malloc(i64 ", show t, ")\n"] [ "call ptr @malloc(i64 ", show t, ")\n"]
(GcMalloc t) -> (GcMalloc t) ->
concat concat
[ "call ptr @cheap_alloc(i64 ", show t, ")\n"] [ "call ptr @cheap_alloc(i64 ", show t, ")\n"]
(Store t1 val t2 (Ident id2)) -> (Store t1 val t2 (Ident id2)) ->

View file

@ -11,9 +11,11 @@ import Control.Monad.State (MonadState (get, put), State,
evalState) evalState)
import Data.Function (on) import Data.Function (on)
import Data.List (delete, mapAccumL, (\\)) import Data.List (delete, mapAccumL, (\\))
import Data.Tuple.Extra (first, second)
import LambdaLifterIr (T)
import qualified LambdaLifterIr as L
import Prelude hiding (exp) import Prelude hiding (exp)
import TypeChecker.TypeCheckerIr import TypeChecker.TypeCheckerIr hiding (T)
-- | Lift lambdas and let expression into supercombinators. -- | Lift lambdas and let expression into supercombinators.
-- Three phases: -- Three phases:
@ -21,12 +23,13 @@ import TypeChecker.TypeCheckerIr
-- @abstract@ converts lambdas into let expressions. -- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function. -- @collectScs@ moves every non-constant let expression to a top-level function.
-- --
lambdaLift :: Program -> Program lambdaLift :: Program -> L.Program
lambdaLift (Program ds) = Program (datatypes ++ binds) lambdaLift (Program ds) = L.Program (datatypes ++ binds)
where where
datatypes = flip filter ds $ \case DData _ -> True datatypes = [L.DData (toLirData d) | DData d <- ds]
_ -> False
binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds] binds = map L.DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
-- | Annotate free variables -- | Annotate free variables
freeVars :: [Bind] -> [ABind] freeVars :: [Bind] -> [ABind]
@ -36,7 +39,7 @@ freeVars binds = [ let ae = freeVarsExp [] e
| Bind n xs e <- binds | Bind n xs e <- binds
] ]
freeVarsExp :: Frees -> ExpT -> Ann AExpT freeVarsExp :: Frees -> T Exp -> Ann (T AExp)
freeVarsExp localVars (ae, t) = case ae of freeVarsExp localVars (ae, t) = case ae of
EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)] EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)]
, term = (AVar n, t) , term = (AVar n, t)
@ -121,27 +124,47 @@ data Ann a = Ann
, term :: a , term :: a
} deriving (Show, Eq) } deriving (Show, Eq)
data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq) data ABind = ABind (T Ident) [T Ident] (Ann (T AExp)) deriving (Show, Eq)
data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq) data ABranch = ABranch (Pattern, Type) (Ann (T AExp)) deriving (Show, Eq)
type AExpT = (AExp, Type)
data AExp = AVar Ident data AExp = AVar Ident
| AInj Ident | AInj Ident
| ALit Lit | ALit Lit
| ALet (Ann ABind) (Ann AExpT) | ALet (Ann ABind) (Ann (T AExp))
| AApp (Ann AExpT) (Ann AExpT) | AApp (Ann (T AExp)) (Ann (T AExp))
| AAdd (Ann AExpT) (Ann AExpT) | AAdd (Ann (T AExp)) (Ann (T AExp))
| AAbs Ident (Ann AExpT) | AAbs Ident (Ann (T AExp))
| ACase (Ann AExpT) [Ann ABranch] | ACase (Ann (T AExp)) [Ann ABranch]
deriving (Show, Eq) deriving (Show, Eq)
abstract :: [ABind] -> [Bind]
data BBind = BBind (T Ident) [T Ident] (T BExp)
| BBindCxt [T Ident] (T Ident) [T Ident] (T BExp)
deriving (Eq, Ord, Show)
data BBranch = BBranch (T Pattern) (T BExp)
deriving (Eq, Ord, Show)
data BExp
= BVar Ident
| BVarC [T Ident] Ident
| BInj Ident
| BLit Lit
| BLet BBind (T BExp)
| BApp (T BExp)(T BExp)
| BAdd (T BExp)(T BExp)
| BCase (T BExp) [BBranch]
deriving (Eq, Ord, Show)
abstract :: [ABind] -> [BBind]
abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0 abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0
abstractAnnBind :: Ann ABind -> State Int Bind abstractAnnBind :: Ann ABind -> State Int BBind
abstractAnnBind Ann { term = ABind name vars annae } = abstractAnnBind Ann { term = ABind name vars annae } =
Bind name (vars' <|| vars) <$> abstractAnnExp annae' BBind name (vars' <|| vars) <$> abstractAnnExp annae'
where where
(annae', vars') = go [] annae (annae', vars') = go [] annae
where where
@ -149,24 +172,27 @@ abstractAnnBind Ann { term = ABind name vars annae } =
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc) ae -> (ae, acc)
abstractAnnExp :: Ann AExpT -> State Int ExpT abstractAnnExp :: Ann (T AExp) -> State Int (T BExp)
abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
AVar n -> pure (EVar n, typ) AVar n -> pure (BVar n, typ)
AInj n -> pure (EInj n, typ) AInj n -> pure (BInj n, typ)
ALit lit -> pure (ELit lit, typ) ALit lit -> pure (BLit lit, typ)
AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2 AApp annae1 annae2 -> (, typ) <$> onM BApp abstractAnnExp annae1 annae2
AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2 AAdd annae1 annae2 -> (, typ) <$> onM BAdd abstractAnnExp annae1 annae2
-- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc
AAbs x annae' -> do AAbs x annae' -> do
i <- nextNumber i <- nextNumber
rhs <- abstractAnnExp annae'' rhs <- abstractAnnExp annae''
let sc_name = Ident ("sc_" ++ show i) let sc_name = Ident ("sc_" ++ show i)
e@(_, t) = foldl applyFree (EVar sc_name, typ) frees sc | null frees = (BVar sc_name, typ)
pure (ELet (Bind (sc_name, typ) vars rhs) e ,t) | otherwise = (BVarC frees sc_name, typ)
bind | null frees = BBind (sc_name, typ) vars rhs
| otherwise = BBindCxt frees (sc_name, typ) vars rhs
pure (BLet bind sc ,typ)
where where
vars = frees <| (x, t_x) <|| ys vars = [(x, t_x)] <|| ys
t_x = case typ of TFun t _ -> t t_x = case typ of TFun t _ -> t
_ -> error "Impossible" _ -> error "Impossible"
@ -176,54 +202,48 @@ abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc) ae -> (ae, acc)
applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type)
applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e')
where
t_e' = case t_e of TFun _ t -> t
_ -> error "Impossible"
ACase annae' bs -> do ACase annae' bs -> do
bs <- mapM go bs bs <- mapM go bs
e <- abstractAnnExp annae' e <- abstractAnnExp annae'
pure (ECase e bs, typ) pure (BCase e bs, typ)
where where
go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae go Ann { term = ABranch p annae } = BBranch p <$> abstractAnnExp annae
ALet b annae' -> ALet b annae' ->
(, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae') (, typ) <$> liftA2 BLet (abstractAnnBind b) (abstractAnnExp annae')
-- | Collects supercombinators by lifting non-constant let expressions -- | Collects supercombinators by lifting non-constant let expressions
collectScs :: [Bind] -> [Bind] collectScs :: [BBind] -> [L.Bind]
collectScs = concatMap collectFromRhs collectScs = concatMap collectFromRhs
where where
collectFromRhs (Bind name parms rhs) = collectFromRhs (BBind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs in L.Bind name parms rhs' : rhs_scs
collectFromRhs (BBindCxt cxt name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in L.BindC cxt name parms rhs' : rhs_scs
collectScsExp :: ExpT -> ([Bind], ExpT) collectScsExp :: T BExp -> ([L.Bind], T L.Exp)
collectScsExp expT@(exp, typ) = case exp of collectScsExp (exp, typ) = case exp of
EVar _ -> ([], expT) BVar x -> ([], (L.EVar x, typ))
EInj _ -> ([], expT) BVarC as x -> ([], (L.EVarC as x, typ))
ELit _ -> ([], expT) BInj k -> ([], (L.EInj k, typ))
BLit lit -> ([], (L.ELit lit, typ))
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ)) BApp e1 e2 -> (scs1 ++ scs2, (L.EApp e1' e2', typ))
where where
(scs1, e1') = collectScsExp e1 (scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2 (scs2, e2') = collectScsExp e2
EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ)) BAdd e1 e2 -> (scs1 ++ scs2, (L.EAdd e1' e2', typ))
where where
(scs1, e1') = collectScsExp e1 (scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2 (scs2, e2') = collectScsExp e2
EAbs par e -> (scs, (EAbs par e', typ))
where
(scs, e') = collectScsExp e
ECase e branches -> (scs ++ scs_e, (ECase e' branches', typ)) BCase e branches -> (scs ++ scs_e, (L.ECase e' branches', typ))
where where
(scs, branches') = mapAccumL f [] branches (scs, branches') = mapAccumL f [] branches
(scs_e, e') = collectScsExp e (scs_e, e') = collectScsExp e
@ -234,15 +254,24 @@ collectScsExp expT@(exp, typ) = case exp of
-- --
-- > f = let sc x y = rhs in e -- > f = let sc x y = rhs in e
-- --
ELet (Bind name parms rhs) e BLet (BBind name parms rhs) e
| null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et')) | null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
| otherwise -> (bind : rhs_scs ++ et_scs, et') | otherwise -> (bind : rhs_scs ++ et_scs, et')
where where
bind = Bind name parms rhs' bind = L.Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs (rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e (et_scs, et') = collectScsExp e
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
BLet (BBindCxt cxt name parms rhs) e
| null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
| otherwise -> (bind : rhs_scs ++ et_scs, et')
where
bind = L.BindC cxt name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e
collectScsBranch (BBranch patt exp) = (scs, L.Branch (first toLirPattern patt) exp')
where (scs, exp') = collectScsExp exp where (scs, exp') = collectScsExp exp
nextNumber :: State Int Int nextNumber :: State Int Int
@ -259,3 +288,19 @@ xs <| x | elem x xs = xs
(<||) :: Eq a => [a] -> [a] -> [a] (<||) :: Eq a => [a] -> [a] -> [a]
xs <|| ys = foldl (<|) xs ys xs <|| ys = foldl (<|) xs ys
toLirData (Data t injs) = L.Data t (map toLirInj injs)
toLirInj (Inj n t) = L.Inj n t
toLirPattern :: Pattern -> L.Pattern
toLirPattern = \case
PVar x -> L.PVar x
PLit lit -> L.PLit lit
PCatch -> L.PCatch
PEnum k -> L.PEnum k
PInj k ps -> L.PInj k (map (first toLirPattern) ps)

140
src/LambdaLifterIr.hs Normal file
View file

@ -0,0 +1,140 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
module LambdaLifterIr (
module Grammar.Abs,
module LambdaLifterIr,
module TypeChecker.TypeCheckerIr
) where
import Data.List (intercalate)
import Grammar.Abs (Lit (..))
import Grammar.Print
import Prelude hiding (exp)
import qualified Prelude as C (Eq, Ord, Show)
import TypeChecker.TypeCheckerIr (Ident (..), TVar (..), Type (..))
newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show)
data Def
= DBind Bind
| DData Data
deriving (C.Eq, C.Ord, C.Show)
data Data = Data Type [Inj]
deriving (C.Eq, C.Ord, C.Show)
data Inj = Inj Ident Type
deriving (C.Eq, C.Ord, C.Show)
data Pattern
= PVar Ident
| PLit Lit
| PCatch
| PEnum Ident
| PInj Ident [(Pattern, Type)]
deriving (C.Eq, C.Ord, C.Show)
data Exp
= EVar Ident
| EVarC [T Ident] Ident
| EInj Ident
| ELit Lit
| ELet (T Ident) (T Exp) (T Exp)
| EApp (T Exp)(T Exp)
| EAdd (T Exp)(T Exp)
| ECase (T Exp) [Branch]
deriving (C.Eq, C.Ord, C.Show)
type T a = (a, Type)
data Bind = Bind (T Ident) [T Ident] (T Exp)
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
deriving (C.Eq, C.Ord, C.Show)
data Branch = Branch (T Pattern) (T Exp)
deriving (C.Eq, C.Ord, C.Show)
instance Print Program where
prt i (Program sc) = prt i sc
instance Print Bind where
prt i (Bind sig parms rhs) = concatD
[ prt i sig
, prt i parms
, doc $ showString "="
, prt i rhs
]
prt i (BindC cxt sig parms rhs) =
prPrec i 0 $
concatD
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
, prt i parms
, doc $ showString "="
, prt i rhs
]
instance Print [Bind] where
prt _ [] = concatD []
prt i [x] = concatD [prt i x]
prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs]
instance Print Exp where
prt i = \case
EVar lident -> prPrec i 3 (concatD [prt 0 lident])
EVarC as lident -> doc . showString
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
where
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
EInj uident -> prPrec i 3 (concatD [prt 0 uident])
ELit lit -> prPrec i 3 (concatD [prt 0 lit])
EApp exp1 exp2 -> prPrec i 2 (concatD [prt 2 exp1, prt 3 exp2])
EAdd exp1 exp2 -> prPrec i 1 (concatD [prt 1 exp1, doc (showString "+"), prt 2 exp2])
ELet lident exp1 exp2 -> prPrec i 0 (concatD [doc (showString "let"), prt 0 lident, doc (showString "="), prt 0 exp1 , doc (showString "in"), prt 0 exp2])
ECase exp branchs -> prPrec i 0 (concatD [doc (showString "case"), prt 0 exp, doc (showString "of"), doc (showString "{"), prt 0 branchs, doc (showString "}")])
instance Print Branch where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
instance Print [Branch] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print Def where
prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print Data where
prt i = \case
Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")])
instance Print Inj where
prt i = \case
Inj uident type_ -> prt i (uident, type_)
instance Print [Inj] where
prt _ [] = concatD []
prt i [x] = prt i x
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
instance Print Pattern where
prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print [Def] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
pattern DBind' id vars expt = DBind (Bind id vars expt)
pattern DData' typ injs = DData (Data typ injs)

View file

@ -1,8 +1,11 @@
module Monomorphizer.DataTypeRemover (removeDataTypes) where module Monomorphizer.DataTypeRemover (removeDataTypes) where
import Monomorphizer.MonomorphizerIr qualified as M2 import Data.Bifunctor (Bifunctor (bimap))
import Monomorphizer.MorbIr qualified as M1 import Monomorphizer.MonomorphizerIr (Ident (..))
import TypeChecker.TypeCheckerIr (Ident (Ident)) import qualified Monomorphizer.MonomorphizerIr as M2
import qualified Monomorphizer.MorbIr as M1
import Prelude hiding (exp)
removeDataTypes :: M1.Program -> M2.Program removeDataTypes :: M1.Program -> M2.Program
removeDataTypes (M1.Program defs) = M2.Program (map pDef defs) removeDataTypes (M1.Program defs) = M2.Program (map pDef defs)
@ -18,43 +21,43 @@ pCons :: M1.Inj -> M2.Inj
pCons (M1.Inj ident t) = M2.Inj ident (pType t) pCons (M1.Inj ident t) = M2.Inj ident (pType t)
pType :: M1.Type -> M2.Type pType :: M1.Type -> M2.Type
pType (M1.TLit ident) = M2.TLit ident pType (M1.TLit ident) = M2.TLit ident
pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2) pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2)
pType (M1.TData (Ident "Bool") _) = M2.TLit (Ident "Bool") pType (M1.TData (Ident "Bool") _) = M2.TLit (Ident "Bool")
pType d = M2.TLit (Ident (newName d)) -- This is the step pType d = M2.TLit (Ident (newName d)) -- This is the step
newName :: M1.Type -> String newName :: M1.Type -> String
newName (M1.TLit (Ident str)) = str newName (M1.TLit (Ident str)) = str
newName (M1.TFun t1 t2) = newName t1 ++ newName t2 newName (M1.TFun t1 t2) = newName t1 ++ newName t2
newName (M1.TData (Ident str) args) = str ++ concatMap newName args newName (M1.TData (Ident str) args) = str ++ concatMap newName args
pBind :: M1.Bind -> M2.Bind pBind :: M1.Bind -> M2.Bind
pBind (M1.Bind id argIds expt) = M2.Bind (pId id) (map pId argIds) (pExpT expt) pBind (M1.Bind id argIds expt) = M2.Bind (pId id) (map pId argIds) (pExpT expt)
pBind (M1.BindC cxt id argIds expt) =
M2.BindC (map pId cxt) (pId id) (map pId argIds) (pExpT expt)
pId :: (Ident, M1.Type) -> (Ident, M2.Type) pId :: (Ident, M1.Type) -> (Ident, M2.Type)
pId (ident, t) = (ident, pType t) pId (ident, t) = (ident, pType t)
pExpT :: M1.ExpT -> M2.ExpT pExpT :: M1.T M1.Exp -> M2.T M2.Exp
pExpT (exp, t) = (pExp exp, pType t) pExpT (exp, t) = (pExp exp, pType t)
pExp :: M1.Exp -> M2.Exp pExp :: M1.Exp -> M2.Exp
pExp (M1.EVar ident) = M2.EVar ident pExp (M1.EVar ident) = M2.EVar ident
pExp (M1.ELit lit) = M2.ELit (pLit lit) pExp (M1.EVarC as ident) = M2.EVarC (map pId as) ident
pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt) pExp (M1.ELit lit) = M2.ELit lit
pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2) pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt)
pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2) pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2)
pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2)
pExp (M1.ECase expT branches) = M2.ECase (pExpT expT) (map pBranch branches) pExp (M1.ECase expT branches) = M2.ECase (pExpT expT) (map pBranch branches)
pBranch :: M1.Branch -> M2.Branch pBranch :: M1.Branch -> M2.Branch
pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt) pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt)
pPattern :: M1.Pattern -> M2.Pattern pPattern :: M1.Pattern -> M2.Pattern
pPattern (M1.PVar id) = M2.PVar (pId id) pPattern (M1.PVar ident) = M2.PVar ident
pPattern (M1.PLit (lit, t)) = M2.PLit (pLit lit, pType t) pPattern (M1.PLit lit) = M2.PLit lit
pPattern (M1.PInj ident patts) = M2.PInj ident (map pPattern patts) pPattern (M1.PInj ident patts) = M2.PInj ident (map (bimap pPattern pType) patts)
pPattern M1.PCatch = M2.PCatch pPattern M1.PCatch = M2.PCatch
pPattern (M1.PEnum ident) = M2.PEnum ident pPattern (M1.PEnum ident) = M2.PEnum ident
pLit :: M1.Lit -> M2.Lit
pLit (M1.LInt v) = M2.LInt v
pLit (M1.LChar c) = M2.LChar c

View file

@ -1,4 +1,5 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{- | For now, converts polymorphic functions to concrete ones based on usage. {- | For now, converts polymorphic functions to concrete ones based on usage.
Assumes lambdas are lifted. Assumes lambdas are lifted.
@ -25,30 +26,35 @@ bind) is added to the resulting set of binds.
-} -}
module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where
import Monomorphizer.DataTypeRemover (removeDataTypes)
import Monomorphizer.MonomorphizerIr qualified as O
import Monomorphizer.MorbIr qualified as M
import TypeChecker.TypeCheckerIr (Ident (Ident))
import TypeChecker.TypeCheckerIr qualified as T
import Control.Monad.Reader ( import Control.Monad.Reader (MonadReader (ask, local),
MonadReader (ask, local), Reader, asks, runReader)
Reader, import Control.Monad.State (MonadState (get),
asks, StateT (runStateT), gets,
runReader, modify)
) import Data.Coerce (coerce)
import Control.Monad.State ( import qualified Data.Map as Map
MonadState (get), import Data.Maybe (catMaybes)
StateT (runStateT), import qualified Data.Set as Set
gets, import Debug.Trace (trace)
modify, import Grammar.Print (printTree)
) import Monomorphizer.DataTypeRemover (removeDataTypes)
import Data.Coerce (coerce) import qualified Monomorphizer.MonomorphizerIr as O
import Data.Map qualified as Map import qualified Monomorphizer.MorbIr as M
import Data.Maybe (catMaybes) -- import TypeChecker.TypeCheckerIr (Ident (Ident))
import Data.Set qualified as Set import LambdaLifterIr (Ident (..))
import Grammar.Print (printTree) -- import TypeChecker.TypeCheckerIr qualified as T
import Debug.Trace (trace) import qualified LambdaLifterIr as L
import Control.Monad.Reader (MonadReader (ask, local),
Reader, asks, runReader)
import Control.Monad.State (MonadState, StateT (runStateT),
gets, modify)
import qualified Data.Map as Map
import Data.Maybe (catMaybes, fromJust)
import qualified Data.Set as Set
import Data.Tuple.Extra (secondM)
import Grammar.Print (printTree)
{- | EnvM is the monad containing the read-only state as well as the {- | EnvM is the monad containing the read-only state as well as the
output state containing monomorphized functions and to-be monomorphized output state containing monomorphized functions and to-be monomorphized
@ -64,18 +70,18 @@ Binds, Polymorphic Data types (monomorphized in a later step) and
Marked bind, which means that it is in the process of monomorphization Marked bind, which means that it is in the process of monomorphization
and should not be monomorphized again. and should not be monomorphized again.
-} -}
data Outputted = Marked | Complete M.Bind | Data M.Type T.Data deriving (Show) data Outputted = Marked | Complete M.Bind | Data M.Type L.Data deriving (Show)
-- | Static environment. -- | Static environment.
data Env = Env data Env = Env
{ input :: Map.Map Ident T.Bind { input :: Map.Map Ident L.Bind
-- ^ All binds in the program. -- ^ All binds in the program.
, dataDefs :: Map.Map Ident T.Data , dataDefs :: Map.Map Ident L.Data
-- ^ All constructors mapped to their respective polymorphic data def -- ^ All constructors mapped to their respective polymorphic data def
-- which includes all other constructors. -- which includes all other constructors.
, polys :: Map.Map Ident M.Type , polys :: Map.Map Ident M.Type
-- ^ Maps polymorphic identifiers with concrete types. -- ^ Maps polymorphic identifiers with concrete types.
, locals :: Set.Set Ident , locals :: Set.Set Ident
-- ^ Local variables. -- ^ Local variables.
} }
@ -84,12 +90,13 @@ localExists :: Ident -> EnvM Bool
localExists ident = asks (Set.member ident . locals) localExists ident = asks (Set.member ident . locals)
-- | Gets a polymorphic bind from an id. -- | Gets a polymorphic bind from an id.
getInputBind :: Ident -> EnvM (Maybe T.Bind) getInputBind :: Ident -> EnvM (Maybe L.Bind)
getInputBind ident = asks (Map.lookup ident . input) getInputBind ident = asks (Map.lookup ident . input)
-- | Add monomorphic function derived from a polymorphic one, to env. -- | Add monomorphic function derived from a polymorphic one, to env.
addOutputBind :: M.Bind -> EnvM () addOutputBind :: M.Bind -> EnvM ()
addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b)) addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b))
addOutputBind b@(M.BindC _ (ident, _) _ _) = modify (Map.insert ident (Complete b))
{- | Marks a global bind as being processed, meaning that when encountered again, {- | Marks a global bind as being processed, meaning that when encountered again,
it should not be recursively processed. it should not be recursively processed.
@ -106,8 +113,8 @@ isConsMarked :: Ident -> EnvM Bool
isConsMarked ident = gets (Map.member ident) isConsMarked ident = gets (Map.member ident)
-- | Finds main bind. -- | Finds main bind.
getMain :: EnvM T.Bind getMain :: EnvM L.Bind
getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of getMain = asks (\env -> case Map.lookup (Ident "main") (input env) of
Just mainBind -> mainBind Just mainBind -> mainBind
Nothing -> error "main not found in monomorphizer!" Nothing -> error "main not found in monomorphizer!"
) )
@ -116,13 +123,13 @@ getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of
error when encountering different structures between the two arguments. Debug: error when encountering different structures between the two arguments. Debug:
First argument is the name of the bind. First argument is the name of the bind.
-} -}
mapTypes :: Ident -> T.Type -> M.Type -> [(Ident, M.Type)] mapTypes :: Ident -> L.Type -> M.Type -> [(Ident, M.Type)]
mapTypes _ident (T.TLit _) (M.TLit _) = [] mapTypes _ident (L.TLit _) (M.TLit _) = []
mapTypes _ident (T.TVar (T.MkTVar i1)) tm = [(i1, tm)] mapTypes _ident (L.TVar (L.MkTVar i1)) tm = [(i1, tm)]
mapTypes ident (T.TFun pt1 pt2) (M.TFun mt1 mt2) = mapTypes ident (L.TFun pt1 pt2) (M.TFun mt1 mt2) =
mapTypes ident pt1 mt1 mapTypes ident pt1 mt1
++ mapTypes ident pt2 mt2 ++ mapTypes ident pt2 mt2
mapTypes ident (T.TData tIdent pTs) (M.TData mIdent mTs) = mapTypes ident (L.TData tIdent pTs) (M.TData mIdent mTs) =
if tIdent /= mIdent if tIdent /= mIdent
then error "the data type names of monomorphic and polymorphic data types does not match" then error "the data type names of monomorphic and polymorphic data types does not match"
else foldl (\xs (p, m) -> mapTypes ident p m ++ xs) [] (zip pTs mTs) else foldl (\xs (p, m) -> mapTypes ident p m ++ xs) [] (zip pTs mTs)
@ -130,30 +137,30 @@ mapTypes ident t1 t2 = error $ "in bind: '" ++ printTree ident ++ "', " ++
"structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'" "structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'"
-- | Gets the mapped monomorphic type of a polymorphic type in the current context. -- | Gets the mapped monomorphic type of a polymorphic type in the current context.
getMonoFromPoly :: T.Type -> EnvM M.Type getMonoFromPoly :: L.Type -> EnvM M.Type
getMonoFromPoly t = do getMonoFromPoly t = do
env <- ask env <- ask
return $ getMono (polys env) t return $ getMono (polys env) t
where where
getMono :: Map.Map Ident M.Type -> T.Type -> M.Type getMono :: Map.Map Ident M.Type -> L.Type -> M.Type
getMono polys t = case t of getMono polys t = case t of
(T.TLit ident) -> M.TLit (coerce ident) (L.TLit ident) -> M.TLit ident
(T.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2) (L.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2)
(T.TVar (T.MkTVar ident)) -> case Map.lookup ident polys of (L.TVar (L.MkTVar ident)) -> case Map.lookup ident polys of
Just concrete -> concrete Just concrete -> concrete
Nothing -> M.TLit (Ident "void") Nothing -> M.TLit (Ident "void")
-- error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps" -- error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps"
(T.TData ident args) -> M.TData ident (map (getMono polys) args) (L.TData ident args) -> M.TData ident (map (getMono polys) args)
{- | If ident not already in env's output, morphed bind to output {- | If ident not already in env's output, morphed bind to output
(and all referenced binds within this bind). (and all referenced binds within this bind).
Returns the annotated bind name. Returns the annotated bind name.
-} -}
morphBind :: M.Type -> T.Bind -> EnvM Ident morphBind :: M.Type -> L.Bind -> EnvM Ident
morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do morphBind expectedType b@(L.Bind (ident, btype) args (exp, expt)) = do
-- The "new name" is used to find out if it is already marked or not. -- The "new name" is used to find out if it is already marked or not.
let name' = newFuncName expectedType b let name' = newFuncName expectedType b
bindMarked <- isBindMarked (coerce name') bindMarked <- isBindMarked name'
local local
( \env -> ( \env ->
env env
@ -168,26 +175,59 @@ morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do
else do else do
-- Mark so that this bind will not be processed in recursive or cyclic -- Mark so that this bind will not be processed in recursive or cyclic
-- function calls -- function calls
markBind (coerce name') markBind name'
expt' <- getMonoFromPoly expt expt' <- getMonoFromPoly expt
exp' <- morphExp expt' exp exp' <- morphExp expt' exp
-- Get monomorphic type sof args -- Get monomorphic type sof args
args' <- mapM morphArg args args' <- mapM morphArg args
addOutputBind $ addOutputBind $
M.Bind M.Bind
(coerce name', expectedType) (name', expectedType)
args' args'
(exp', expt') (exp', expt')
return name' return name'
morphBind expectedType b@(L.BindC cxt (ident, btype) args (exp, expt)) = do
-- The "new name" is used to find out if it is already marked or not.
let name' = newFuncName expectedType b
bindMarked <- isBindMarked name'
local
( \env ->
env
{ locals = Set.fromList (map fst args)
, polys = Map.fromList (mapTypes ident btype expectedType)
}
)
$ do
-- Return with right name if already marked
if bindMarked
then return name'
else do
-- Mark so that this bind will not be processed in recursive or cyclic
-- function calls
markBind name'
-- Get monomorphic type sof args
args' <- mapM morphArg args
cxt' <- mapM (secondM getMonoFromPoly) cxt
expt' <- getMonoFromPoly expt
exp' <- local (\env -> foldr (addLocal . fst) env cxt)
(morphExp expt' exp)
addOutputBind $
M.BindC cxt'
(name', expectedType)
args'
(exp', expt')
return name'
-- | Monomorphizes arguments of a bind. -- | Monomorphizes arguments of a bind.
morphArg :: (Ident, T.Type) -> EnvM (Ident, M.Type) morphArg :: (Ident, L.Type) -> EnvM (Ident, M.Type)
morphArg (ident, t) = do morphArg (ident, t) = do
t' <- getMonoFromPoly t t' <- getMonoFromPoly t
return (ident, t') return (ident, t')
-- | Gets the data bind from the name of a constructor. -- | Gets the data bind from the name of a constructor.
getInputData :: Ident -> EnvM (Maybe T.Data) getInputData :: Ident -> EnvM (Maybe L.Data)
getInputData ident = do getInputData ident = do
env <- ask env <- ask
return $ Map.lookup ident (dataDefs env) return $ Map.lookup ident (dataDefs env)
@ -201,50 +241,50 @@ morphCons expectedType ident newIdent = do
--trace ("Tjofras:" ++ show (newName expectedType ident)) $ return () --trace ("Tjofras:" ++ show (newName expectedType ident)) $ return ()
maybeD <- getInputData ident maybeD <- getInputData ident
case maybeD of case maybeD of
Nothing -> error $ "identifier '" ++ show ident ++ "' not found" -- closures can have unbound variables
Nothing -> pure ()
Just d -> do Just d -> do
modify (\output -> Map.insert newIdent (Data expectedType d) output) modify (\output -> Map.insert newIdent (Data expectedType d) output)
-- | Converts literals from input to output tree. -- | Converts literals from input to output tree.
convertLit :: T.Lit -> M.Lit convertLit :: L.Lit -> M.Lit
convertLit (T.LInt v) = M.LInt v convertLit (L.LInt v) = M.LInt v
convertLit (T.LChar v) = M.LChar v convertLit (L.LChar v) = M.LChar v
-- | Monomorphizes an expression, given an expected type. -- | Monomorphizes an expression, given an expected type.
morphExp :: M.Type -> T.Exp -> EnvM M.Exp morphExp :: M.Type -> L.Exp -> EnvM M.Exp
morphExp expectedType exp = case exp of morphExp expectedType exp = case exp of
T.ELit lit -> return $ M.ELit (convertLit lit) L.ELit lit -> return $ M.ELit lit
-- Constructor -- Constructor
T.EInj ident -> do L.EInj ident -> do
let ident' = newName (getDataType expectedType) ident let ident' = newName (getDataType expectedType) ident
morphCons expectedType ident ident' morphCons expectedType ident ident'
return $ M.EVar ident' return $ M.EVar ident'
T.EApp (e1, _t1) (e2, t2) -> do L.EApp (e1, _t1) (e2, t2) -> do
t2' <- getMonoFromPoly t2 t2' <- getMonoFromPoly t2
e2' <- morphExp t2' e2 e2' <- morphExp t2' e2
e1' <- morphExp (M.TFun t2' expectedType) e1 e1' <- morphExp (M.TFun t2' expectedType) e1
return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2') return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2')
T.EAdd (e1, t1) (e2, t2) -> do L.EAdd (e1, t1) (e2, t2) -> do
t1' <- getMonoFromPoly t1 t1' <- getMonoFromPoly t1
t2' <- getMonoFromPoly t2 t2' <- getMonoFromPoly t2
e1' <- morphExp t1' e1 e1' <- morphExp t1' e1
e2' <- morphExp t2' e2 e2' <- morphExp t2' e2
return $ M.EAdd (e1', expectedType) (e2', expectedType) return $ M.EAdd (e1', expectedType) (e2', expectedType)
T.EAbs ident (exp, t) -> local (\env -> env{locals = Set.insert ident (locals env)}) $ do L.ECase (exp, t) bs -> do
t' <- getMonoFromPoly t
morphExp t' exp
T.ECase (exp, t) bs -> do
t' <- getMonoFromPoly t t' <- getMonoFromPoly t
exp' <- morphExp t' exp exp' <- morphExp t' exp
bs' <- mapM morphBranch bs bs' <- mapM morphBranch bs
return $ M.ECase (exp', t') (catMaybes bs') return $ M.ECase (exp', t') (catMaybes bs')
-- Ideally constructors should be EInj, though this code handles them -- Ideally constructors should be EInj, though this code handles them
-- as well. -- as well.
T.EVar ident -> do -- FIXME MAKE EVAR AND EINJ SEPARATE!!!
L.EVar ident -> do
isLocal <- localExists ident isLocal <- localExists ident
if isLocal if isLocal
then do then do
return $ M.EVar (coerce ident) return $ M.EVar ident
else do else do
bind <- getInputBind ident bind <- getInputBind ident
case bind of case bind of
@ -252,38 +292,51 @@ morphExp expectedType exp = case exp of
Just bind' -> do Just bind' -> do
-- New bind to process -- New bind to process
newBindName <- morphBind expectedType bind' newBindName <- morphBind expectedType bind'
return $ M.EVar (coerce newBindName) return $ M.EVar newBindName
T.ELet (T.Bind (identB, tB) args (expB, tExpB)) (exp, tExp) -> L.EVarC as ident -> do
if length args > 0 then error "only constants in lets allowed" isLocal <- localExists ident
else do if isLocal
then do
return $ M.EVar ident
else do
bind <- fromJust <$> getInputBind ident
as' <- mapM (secondM getMonoFromPoly) as
-- New bind to process
newBindName <- morphBind expectedType bind
return $ M.EVarC as' newBindName
-- Ideally constructors should be EInj, though this code handles them
-- as well.
L.ELet (identB, tB) (expB, tExpB) (exp, tExp) -> do
tB' <- getMonoFromPoly tB tB' <- getMonoFromPoly tB
tExpB' <- getMonoFromPoly tExpB tExpB' <- getMonoFromPoly tExpB
tExp' <- getMonoFromPoly tExp tExp' <- getMonoFromPoly tExp
expB' <- morphExp tExpB' expB expB' <- morphExp tExpB' expB
exp' <- morphExp tExp' exp exp' <- local (addLocal identB) (morphExp tExp' exp)
return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (exp', tExp') return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (exp', tExp')
-- | Monomorphizes case-of branches. -- | Monomorphizes case-of branches.
morphBranch :: T.Branch -> EnvM (Maybe M.Branch) morphBranch :: L.Branch -> EnvM (Maybe M.Branch)
morphBranch (T.Branch (p, pt) (e, et)) = do morphBranch (L.Branch (p, pt) (e, et)) = do
pt' <- getMonoFromPoly pt pt' <- getMonoFromPoly pt
et' <- getMonoFromPoly et et' <- getMonoFromPoly et
env <- ask env <- ask
maybeMorphedPattern <- morphPattern p pt' maybeMorphedPattern <- morphPattern p pt'
case maybeMorphedPattern of case maybeMorphedPattern of
Nothing -> return Nothing Nothing -> return Nothing
Just (p', newLocals) -> Just (p', newLocals) ->
local (const env { locals = Set.union (locals env) newLocals }) $ do local (const env { locals = Set.union (locals env) newLocals }) $ do
e' <- morphExp et' e e' <- morphExp et' e
return $ Just (M.Branch (p', pt') (e', et')) return $ Just (M.Branch p' (e', et'))
morphPattern :: T.Pattern -> M.Type -> EnvM (Maybe (M.Pattern, Set.Set Ident)) morphPattern :: L.Pattern -> M.Type -> EnvM (Maybe (M.T M.Pattern, Set.Set Ident))
morphPattern p expectedType = case p of morphPattern p expectedType = case p of
T.PVar ident -> return $ Just (M.PVar (ident, expectedType), Set.singleton ident) L.PVar ident -> return $ Just ((M.PVar ident, expectedType), Set.singleton ident)
T.PLit lit -> return $ Just (M.PLit (convertLit lit, expectedType), Set.empty) L.PLit lit -> return $ Just ((M.PLit (convertLit lit), expectedType), Set.empty)
T.PCatch -> return $ Just (M.PCatch, Set.empty) L.PCatch -> return $ Just ((M.PCatch, expectedType), Set.empty)
T.PEnum ident -> return $ Just (M.PEnum (newName expectedType ident), Set.empty) L.PEnum ident -> return $ Just ((M.PEnum (newName expectedType ident), expectedType), Set.empty)
T.PInj ident pts -> do let newIdent = newName expectedType ident L.PInj ident pts -> do let newIdent = newName expectedType ident
outEnv <- get outEnv <- get
trace ("WOW: " ++ show (newName expectedType ident)) $ return () trace ("WOW: " ++ show (newName expectedType ident)) $ return ()
trace ("WOW2: " ++ show (outEnv)) $ return () trace ("WOW2: " ++ show (outEnv)) $ return ()
@ -297,13 +350,18 @@ morphPattern p expectedType = case p of
let maybePsSets = sequence psSets let maybePsSets = sequence psSets
case maybePsSets of case maybePsSets of
Nothing -> return Nothing Nothing -> return Nothing
Just psSets' -> return $ Just Just psSets' -> return $ Just
(M.PInj newIdent (map fst psSets'), Set.unions $ map snd psSets') ((M.PInj newIdent (map fst psSets'), expectedType), Set.unions $ map snd psSets')
else return Nothing else return Nothing
-- | Creates a new identifier for a function with an assigned type. -- | Creates a new identifier for a function with an assigned type.
newFuncName :: M.Type -> T.Bind -> Ident newFuncName :: M.Type -> L.Bind -> Ident
newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) = newFuncName t (L.Bind (ident@(Ident bindName), _) _ _) =
if bindName == "main"
then Ident bindName
else newName t ident
newFuncName t (L.BindC _ (ident@(Ident bindName), _) _ _) =
if bindName == "main" if bindName == "main"
then Ident bindName then Ident bindName
else newName t ident else newName t ident
@ -317,8 +375,8 @@ newName t (Ident str) = Ident $ str ++ "$" ++ newName' t
newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts
-- | Monomorphization step. -- | Monomorphization step.
monomorphize :: T.Program -> O.Program monomorphize :: L.Program -> O.Program
monomorphize (T.Program defs) = monomorphize (L.Program defs) =
removeDataTypes $ removeDataTypes $
M.Program M.Program
( getDefsFromOutput ( getDefsFromOutput
@ -336,7 +394,7 @@ runEnvM :: Output -> Env -> EnvM () -> Output
runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env
-- | Creates the environment based on the input binds. -- | Creates the environment based on the input binds.
createEnv :: [T.Def] -> Env createEnv :: [L.Def] -> Env
createEnv defs = createEnv defs =
Env Env
{ input = Map.fromList bindPairs { input = Map.fromList bindPairs
@ -346,33 +404,34 @@ createEnv defs =
} }
where where
bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs
dataPairs :: [(Ident, T.Data)] dataPairs :: [(Ident, L.Data)]
dataPairs = (foldl (\acc d@(T.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs dataPairs = (foldl (\acc d@(L.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs
-- | Gets a top-lefel function name. -- | Gets a top-lefel function name.
getBindName :: T.Bind -> Ident getBindName :: L.Bind -> Ident
getBindName (T.Bind (ident, _) _ _) = ident getBindName (L.Bind (ident, _) _ _) = ident
getBindName (L.BindC _ (ident, _) _ _) = ident
-- Helper functions -- Helper functions
-- Gets custom data declarations form defs. -- Gets custom data declarations form defs.
getDataFromDefs :: [T.Def] -> [T.Data] getDataFromDefs :: [L.Def] -> [L.Data]
getDataFromDefs = getDataFromDefs =
foldl foldl
( \bs -> \case ( \bs -> \case
T.DBind _ -> bs L.DBind _ -> bs
T.DData d -> d : bs L.DData d -> d : bs
) )
[] []
getConsName :: T.Inj -> Ident getConsName :: L.Inj -> Ident
getConsName (T.Inj ident _) = ident getConsName (L.Inj ident _) = ident
getBindsFromDefs :: [T.Def] -> [T.Bind] getBindsFromDefs :: [L.Def] -> [L.Bind]
getBindsFromDefs = getBindsFromDefs =
foldl foldl
( \bs -> \case ( \bs -> \case
T.DBind b -> b : bs L.DBind b -> b : bs
T.DData _ -> bs L.DData _ -> bs
) )
[] []
@ -384,19 +443,19 @@ getDefsFromOutput o =
(binds, dataInput) = splitBindsAndData o (binds, dataInput) = splitBindsAndData o
-- | Splits the output into binds and data declaration components (used in createNewData) -- | Splits the output into binds and data declaration components (used in createNewData)
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, T.Data)]) splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, L.Data)])
splitBindsAndData output = splitBindsAndData output =
foldl foldl
( \(oBinds, oData) (ident, o) -> case o of ( \(oBinds, oData) (ident, o) -> case o of
Marked -> error "internal bug in monomorphizer" Marked -> error "internal bug in monomorphizer"
Complete b -> (b : oBinds, oData) Complete b -> (b : oBinds, oData)
Data t d -> (oBinds, (ident, t, d) : oData) Data t d -> (oBinds, (ident, t, d) : oData)
) )
([], []) ([], [])
(Map.toList output) (Map.toList output)
-- | Converts all found constructors to monomorphic data declarations. -- | Converts all found constructors to monomorphic data declarations.
createNewData :: [(Ident, M.Type, T.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data createNewData :: [(Ident, M.Type, L.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
createNewData [] o = o createNewData [] o = o
createNewData ((consIdent, consType, polyData) : input) o = createNewData ((consIdent, consType, polyData) : input) o =
createNewData input $ createNewData input $
@ -406,14 +465,17 @@ createNewData ((consIdent, consType, polyData) : input) o =
(M.Data newDataType [newCons]) (M.Data newDataType [newCons])
o o
where where
T.Data (T.TData polyDataIdent _) _ = polyData L.Data (L.TData polyDataIdent _) _ = polyData
newDataType = getDataType consType newDataType = getDataType consType
newDataName = newName newDataType polyDataIdent newDataName = newName newDataType polyDataIdent
newCons = M.Inj consIdent consType newCons = M.Inj consIdent consType
-- | Gets the Data Type of a constructor type (a -> Just a becomes Just a). -- | Gets the Data Type of a constructor type (a -> Just a becomes Just a).
getDataType :: M.Type -> M.Type getDataType :: M.Type -> M.Type
getDataType (M.TFun _t1 t2) = getDataType t2 getDataType (M.TFun _t1 t2) = getDataType t2
getDataType tData@(M.TData _ _) = tData getDataType tData@(M.TData _ _) = tData
getDataType _ = error "???" getDataType _ = error "???"
addLocal :: Ident -> Env -> Env
addLocal x env = env { locals = Set.insert x env.locals }

View file

@ -1,11 +1,14 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr) where module Monomorphizer.MonomorphizerIr (
module Monomorphizer.MonomorphizerIr,
module LambdaLifterIr
) where
import Grammar.Print import Data.List (intercalate)
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..)) import Grammar.Print
import LambdaLifterIr (Ident (..), Lit (..))
type Id = (TIR.Ident, Type) import Prelude hiding (exp)
newtype Program = Program [Def] newtype Program = Program [Def]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
@ -16,90 +19,80 @@ data Def = DBind Bind | DData Data
data Data = Data Type [Inj] data Data = Data Type [Inj]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT data Bind = Bind (T Ident) [T Ident] (T Exp)
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
type T a = (a, Type)
data Exp data Exp
= EVar TIR.Ident = EVar Ident
| EVarC [T Ident] Ident
| ELit Lit | ELit Lit
| ELet Bind ExpT | ELet Bind (T Exp)
| EApp ExpT ExpT | EApp (T Exp) (T Exp)
| EAdd ExpT ExpT | EAdd (T Exp) (T Exp)
| ECase ExpT [Branch] | ECase (T Exp) [Branch]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Pattern data Pattern
= PVar Id = PVar Ident
| PLit (Lit, Type) | PLit Lit
| PInj TIR.Ident [Pattern] | PInj Ident [T Pattern]
| PCatch | PCatch
| PEnum TIR.Ident | PEnum Ident
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
data Branch = Branch (Pattern, Type) ExpT data Branch = Branch (T Pattern) (T Exp)
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
type ExpT = (Exp, Type) data Inj = Inj Ident Type
data Inj = Inj TIR.Ident Type
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Lit data Type = TLit Ident | TFun Type Type
= LInt Integer
| LChar Char
deriving (Show, Ord, Eq)
data Type = TLit TIR.Ident | TFun Type Type
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
flattenType :: Type -> [Type] flattenType :: Type -> [Type]
flattenType (TFun t1 t2) = t1 : flattenType t2 flattenType (TFun t1 t2) = t1 : flattenType t2
flattenType x = [x] flattenType x = [x]
instance Print Program where instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print (Bind) where instance Print Bind where
prt i (Bind sig@(name, _) parms rhs) = prt i (Bind sig@(name, _) parms rhs) =
prPrec i 0 $ prPrec i 0 $
concatD concatD
[ prtSig sig [ prt 0 sig
, prt 0 name , prt 0 name
, prtIdPs 0 parms , prt 0 parms
, doc $ showString "=" , doc $ showString "="
, prt 0 rhs , prt 0 rhs
] ]
prtSig :: Id -> Doc prt i (BindC cxt sig parms rhs) =
prtSig (name, t) = prPrec i 0 $
concatD concatD
[ prt 0 name [ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
, doc $ showString ":" , prt i parms
, prt 0 t , doc $ showString "="
, doc $ showString ";" , prt i rhs
] ]
instance Print (ExpT) where
prt i (e, t) =
concatD
[ doc $ showString "("
, prt i e
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print [Bind] where instance Print [Bind] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prt i)
instance Print Exp where instance Print Exp where
prt i = \case prt i = \case
EVar name -> prPrec i 3 $ prt 0 name EVar name -> prPrec i 3 $ prt 0 name
EVarC as lident -> doc . showString
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
where
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
ELit lit -> prPrec i 3 $ prt 0 lit ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e -> ELet b e ->
prPrec i 3 $ prPrec i 3 $
@ -134,16 +127,16 @@ instance Print Exp where
] ]
instance Print Branch where instance Print Branch where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp])
instance Print [Branch] where instance Print [Branch] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print Def where instance Print Def where
prt i = \case prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind]) DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_]) DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print Data where instance Print Data where
@ -152,23 +145,23 @@ instance Print Data where
instance Print Inj where instance Print Inj where
prt i = \case prt i = \case
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) Inj uident type_ -> prt i (uident, type_)
instance Print Pattern where instance Print Pattern where
prt i = \case prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name]) PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) PLit lit -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")]) PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name]) PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print [Def] where instance Print [Def] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print [Type] where instance Print [Type] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
instance Print Type where instance Print Type where
@ -176,7 +169,3 @@ instance Print Type where
TLit uident -> prPrec i 1 (concatD [prt 0 uident]) TLit uident -> prPrec i 1 (concatD [prt 0 uident])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
instance Print Lit where
prt i = \case
LInt int -> prt i int
LChar char -> prt i char

View file

@ -1,10 +1,14 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module Monomorphizer.MorbIr where
import Grammar.Print module Monomorphizer.MorbIr (
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..)) module Monomorphizer.MorbIr,
module LambdaLifterIr
) where
type Id = (TIR.Ident, Type) import Data.List (intercalate)
import Grammar.Print
import LambdaLifterIr (Ident (..), Lit (..))
import Prelude hiding (exp)
newtype Program = Program [Def] newtype Program = Program [Def]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
@ -15,91 +19,81 @@ data Def = DBind Bind | DData Data
data Data = Data Type [Inj] data Data = Data Type [Inj]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT data Bind = Bind (T Ident) [T Ident] (T Exp)
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
type T a = (a, Type)
data Exp data Exp
= EVar TIR.Ident = EVar Ident
| EVarC [T Ident] Ident
| ELit Lit | ELit Lit
| ELet Bind ExpT | ELet Bind (T Exp)
| EApp ExpT ExpT | EApp (T Exp) (T Exp)
| EAdd ExpT ExpT | EAdd (T Exp) (T Exp)
| ECase ExpT [Branch] | ECase (T Exp) [Branch]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Pattern data Pattern
= PVar Id = PVar Ident
| PLit (Lit, Type) | PLit Lit
| PInj TIR.Ident [Pattern] | PInj Ident [T Pattern]
| PCatch | PCatch
| PEnum TIR.Ident | PEnum Ident
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
data Branch = Branch (Pattern, Type) ExpT
data Branch = Branch (T Pattern) (T Exp)
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
type ExpT = (Exp, Type) data Inj = Inj Ident Type
data Inj = Inj TIR.Ident Type
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Lit data Type = TLit Ident | TFun Type Type | TData Ident [Type]
= LInt Integer
| LChar Char
deriving (Show, Ord, Eq)
data Type = TLit TIR.Ident | TFun Type Type | TData TIR.Ident [Type]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
flattenType :: Type -> [Type] flattenType :: Type -> [Type]
flattenType (TFun t1 t2) = t1 : flattenType t2 flattenType (TFun t1 t2) = t1 : flattenType t2
flattenType x = [x] flattenType x = [x]
instance Print Program where instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print (Bind) where instance Print Bind where
prt i (Bind sig@(name, _) parms rhs) = prt i (Bind sig@(name, _) parms rhs) =
prPrec i 0 $ prPrec i 0 $
concatD concatD
[ prtSig sig [ prt 0 sig
, prt 0 name , prt 0 name
, prtIdPs 0 parms , prt 0 parms
, doc $ showString "=" , doc $ showString "="
, prt 0 rhs , prt 0 rhs
] ]
prtSig :: Id -> Doc prt i (BindC cxt sig parms rhs) =
prtSig (name, t) = prPrec i 0 $
concatD concatD
[ prt 0 name [ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
, doc $ showString ":" , prt i parms
, prt 0 t , doc $ showString "="
, doc $ showString ";" , prt i rhs
] ]
instance Print (ExpT) where
prt i (e, t) =
concatD
[ doc $ showString "("
, prt i e
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print [Bind] where instance Print [Bind] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prt i)
instance Print Exp where instance Print Exp where
prt i = \case prt i = \case
EVar name -> prPrec i 3 $ prt 0 name EVar name -> prPrec i 3 $ prt 0 name
EVarC as lident -> doc . showString
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
where
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
ELit lit -> prPrec i 3 $ prt 0 lit ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e -> ELet b e ->
prPrec i 3 $ prPrec i 3 $
@ -134,16 +128,16 @@ instance Print Exp where
] ]
instance Print Branch where instance Print Branch where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp])
instance Print [Branch] where instance Print [Branch] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print Def where instance Print Def where
prt i = \case prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind]) DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_]) DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print Data where instance Print Data where
@ -152,23 +146,23 @@ instance Print Data where
instance Print Inj where instance Print Inj where
prt i = \case prt i = \case
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) Inj uident type_ -> prt i (uident, type_)
instance Print Pattern where instance Print Pattern where
prt i = \case prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name]) PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) PLit lit -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")]) PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name]) PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print [Def] where instance Print [Def] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print [Type] where instance Print [Type] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
instance Print Type where instance Print Type where
@ -177,8 +171,4 @@ instance Print Type where
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")]) TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")])
instance Print Lit where
prt i = \case
LInt int -> prt i int
LChar char -> prt i char

View file

@ -2,15 +2,15 @@
module TypeChecker.ReportTEVar where module TypeChecker.ReportTEVar where
import Auxiliary (onM) import Auxiliary (onM)
import Control.Applicative (Applicative (liftA2), liftA3) import Control.Applicative (Applicative (liftA2), liftA3)
import Control.Monad.Except (MonadError (throwError)) import Control.Monad.Except (MonadError (throwError))
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Tuple.Extra (secondM) import Data.Tuple.Extra (secondM)
import Grammar.Abs qualified as G import qualified Grammar.Abs as G
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr hiding (Type (..)) import TypeChecker.TypeCheckerIr hiding (Type (..))
data Type data Type
= TLit Ident = TLit Ident
@ -18,7 +18,7 @@ data Type
| TData Ident [Type] | TData Ident [Type]
| TFun Type Type | TFun Type Type
| TAll TVar Type | TAll TVar Type
deriving (Eq, Ord, Show, Read) deriving (Eq, Ord, Show)
class ReportTEVar a b where class ReportTEVar a b where
reportTEVar :: a -> Err b reportTEVar :: a -> Err b
@ -29,20 +29,20 @@ instance ReportTEVar (Program' G.Type) (Program' Type) where
instance ReportTEVar (Def' G.Type) (Def' Type) where instance ReportTEVar (Def' G.Type) (Def' Type) where
reportTEVar = \case reportTEVar = \case
DBind bind -> DBind <$> reportTEVar bind DBind bind -> DBind <$> reportTEVar bind
DData dat -> DData <$> reportTEVar dat DData dat -> DData <$> reportTEVar dat
instance ReportTEVar (Bind' G.Type) (Bind' Type) where instance ReportTEVar (Bind' G.Type) (Bind' Type) where
reportTEVar (Bind id vars rhs) = liftA3 Bind (reportTEVar id) (reportTEVar vars) (reportTEVar rhs) reportTEVar (Bind id vars rhs) = liftA3 Bind (reportTEVar id) (reportTEVar vars) (reportTEVar rhs)
instance ReportTEVar (Exp' G.Type) (Exp' Type) where instance ReportTEVar (Exp' G.Type) (Exp' Type) where
reportTEVar exp = case exp of reportTEVar exp = case exp of
EVar name -> pure $ EVar name EVar name -> pure $ EVar name
EInj name -> pure $ EInj name EInj name -> pure $ EInj name
ELit lit -> pure $ ELit lit ELit lit -> pure $ ELit lit
ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e) ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e)
EApp e1 e2 -> onM EApp reportTEVar e1 e2 EApp e1 e2 -> onM EApp reportTEVar e1 e2
EAdd e1 e2 -> onM EAdd reportTEVar e1 e2 EAdd e1 e2 -> onM EAdd reportTEVar e1 e2
EAbs name e -> EAbs name <$> reportTEVar e EAbs name e -> EAbs name <$> reportTEVar e
ECase e branches -> liftA2 ECase (reportTEVar e) (reportTEVar branches) ECase e branches -> liftA2 ECase (reportTEVar e) (reportTEVar branches)
instance ReportTEVar (Branch' G.Type) (Branch' Type) where instance ReportTEVar (Branch' G.Type) (Branch' Type) where
@ -53,10 +53,10 @@ instance ReportTEVar (Pattern' G.Type, G.Type) (Pattern' Type, Type) where
instance ReportTEVar (Pattern' G.Type) (Pattern' Type) where instance ReportTEVar (Pattern' G.Type) (Pattern' Type) where
reportTEVar = \case reportTEVar = \case
PVar name -> pure $ PVar name PVar name -> pure $ PVar name
PLit lit -> pure $ PLit lit PLit lit -> pure $ PLit lit
PCatch -> pure PCatch PCatch -> pure PCatch
PEnum name -> pure $ PEnum name PEnum name -> pure $ PEnum name
PInj name ps -> PInj name <$> reportTEVar ps PInj name ps -> PInj name <$> reportTEVar ps
instance ReportTEVar (Data' G.Type) (Data' Type) where instance ReportTEVar (Data' G.Type) (Data' Type) where
@ -65,10 +65,10 @@ instance ReportTEVar (Data' G.Type) (Data' Type) where
instance ReportTEVar (Inj' G.Type) (Inj' Type) where instance ReportTEVar (Inj' G.Type) (Inj' Type) where
reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ
instance ReportTEVar (Id' G.Type) (Id' Type) where instance ReportTEVar (a, G.Type) (a, Type) where
reportTEVar = secondM reportTEVar reportTEVar = secondM reportTEVar
instance ReportTEVar (ExpT' G.Type) (ExpT' Type) where instance ReportTEVar (T' Exp' G.Type) (T' Exp' Type) where
reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ) reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ)
instance ReportTEVar a b => ReportTEVar [a] [b] where instance ReportTEVar a b => ReportTEVar [a] [b] where
@ -76,9 +76,9 @@ instance ReportTEVar a b => ReportTEVar [a] [b] where
instance ReportTEVar G.Type Type where instance ReportTEVar G.Type Type where
reportTEVar = \case reportTEVar = \case
G.TLit lit -> pure $ TLit (coerce lit) G.TLit lit -> pure $ TLit (coerce lit)
G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i) G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i)
G.TData name typs -> TData (coerce name) <$> reportTEVar typs G.TData name typs -> TData (coerce name) <$> reportTEVar typs
G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2) G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2)
G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t
G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar) G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar)

View file

@ -31,6 +31,7 @@ import Grammar.ErrM
import Grammar.Print (printTree) import Grammar.Print (printTree)
import Prelude hiding (exp) import Prelude hiding (exp)
import qualified TypeChecker.TypeCheckerIr as T import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (T, T')
-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013) -- Implementation is derived from the paper (Dunfield and Krishnaswami 2013)
-- https://doi.org/10.1145/2500365.2500582 -- https://doi.org/10.1145/2500365.2500582
@ -172,7 +173,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars
-- | Γ ⊢ e ↑ A ⊣ Δ -- | Γ ⊢ e ↑ A ⊣ Δ
-- Under input context Γ, e checks against input type A, with output context ∆ -- Under input context Γ, e checks against input type A, with output context ∆
check :: Exp -> Type -> Tc (T.ExpT' Type) check :: Exp -> Type -> Tc (T' T.Exp' Type)
-- Γ,α ⊢ e ↑ A ⊣ Δ,α -- Γ,α ⊢ e ↑ A ⊣ Δ,α
-- ------------------- ∀I -- ------------------- ∀I
@ -212,12 +213,6 @@ check (ECase scrut pi) c = do
e' <- check e c e' <- check e c
pure (T.Branch p' e') pure (T.Branch p' e')
apply (T.ECase (scrut', a) pi', c) apply (T.ECase (scrut', a) pi', c)
where
go (pi, b) (Branch p e) = do
p' <- checkPattern p =<< apply a
e'@(_, b') <- infer e
subtype b' b
apply (T.Branch p' e' : pi, b')
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ -- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
@ -229,9 +224,6 @@ check e b = do
subtype a b' subtype a b'
apply (e', b) apply (e', b)
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of checkPattern patt t_patt = case patt of
@ -297,7 +289,7 @@ checkPattern patt t_patt = case patt of
-- | Γ ⊢ e ↓ A ⊣ Δ -- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆ -- Under input context Γ, e infers output type A, with output context ∆
infer :: Exp -> Tc (T.ExpT' Type) infer :: Exp -> Tc (T' T.Exp' Type)
infer (ELit lit) = apply (T.ELit lit, litType lit) infer (ELit lit) = apply (T.ELit lit, litType lit)
-- Γ ∋ (x : A) Γ ⊢ rec(x) -- Γ ∋ (x : A) Γ ⊢ rec(x)
@ -391,7 +383,7 @@ infer (ECase scrut pi) = do
-- | Γ ⊢ A • e ⇓ C ⊣ Δ -- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ -- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
-- Instantiate existential type variables until there is an arrow type. -- Instantiate existential type variables until there is an arrow type.
applyInfer :: Type -> Exp -> Tc (T.ExpT' Type, Type) applyInfer :: Type -> Exp -> Tc (T' T.Exp' Type, Type)
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ -- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
-- ------------------------ ∀App -- ------------------------ ∀App

View file

@ -1,32 +1,32 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-} {-# LANGUAGE QualifiedDo #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner -- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where module TypeChecker.TypeCheckerHm where
import Auxiliary (int, litType, maybeToRightM, unzip4) import Auxiliary (int, litType, maybeToRightM, unzip4)
import Auxiliary qualified as Aux import qualified Auxiliary as Aux
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Control.Monad.Writer import Control.Monad.Writer
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Function (on) import Data.Function (on)
import Data.List (foldl', nub, sortOn) import Data.List (foldl', nub, sortOn)
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as M import qualified Data.Map as M
import Data.Maybe (fromJust) import Data.Maybe (fromJust)
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as S import qualified Data.Set as S
import Debug.Trace (trace, traceShow) import Debug.Trace (trace, traceShow)
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (T, T')
{- {-
TODO TODO
@ -41,7 +41,7 @@ typecheck :: Program -> Either String (T.Program' Type, [Warning])
typecheck = onLeft msg . run . checkPrg typecheck = onLeft msg . run . checkPrg
where where
onLeft :: (Error -> String) -> Either Error a -> Either String a onLeft :: (Error -> String) -> Either Error a -> Either String a
onLeft f (Left x) = Left $ f x onLeft f (Left x) = Left $ f x
onLeft _ (Right x) = Right x onLeft _ (Right x) = Right x
checkPrg :: Program -> Infer (T.Program' Type) checkPrg :: Program -> Infer (T.Program' Type)
@ -68,13 +68,13 @@ prettify s (T.Program defs) = T.Program $ map (go s) defs
replace :: Map T.Ident T.Ident -> Type -> Type replace :: Map T.Ident T.Ident -> Type -> Type
replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of
Just t -> TVar . MkTVar . LIdent $ coerce t Just t -> TVar . MkTVar . LIdent $ coerce t
Nothing -> def Nothing -> def
replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2 replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2
replace m (TData name ts) = TData name (map (replace m) ts) replace m (TData name ts) = TData name (map (replace m) ts)
replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of
Just found -> TAll (MkTVar $ coerce found) (replace m t) Just found -> TAll (MkTVar $ coerce found) (replace m t)
Nothing -> def Nothing -> def
replace _ t = t replace _ t = t
bindCount :: [Def] -> Infer [(Int, Def)] bindCount :: [Def] -> Infer [(Int, Def)]
@ -128,7 +128,7 @@ preRun (x : xs) = case x of
s <- gets sigs s <- gets sigs
case M.lookup (coerce n) s of case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
where where
-- Check if function body / signature has been declared already -- Check if function body / signature has been declared already
@ -150,11 +150,11 @@ checkDef (x : xs) = case x of
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
freeOrdered :: Type -> [T.Ident] freeOrdered :: Type -> [T.Ident]
freeOrdered (TVar (MkTVar a)) = return (coerce a) freeOrdered (TVar (MkTVar a)) = return (coerce a)
freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
freeOrdered (TData _ a) = concatMap freeOrdered a freeOrdered (TData _ a) = concatMap freeOrdered a
freeOrdered _ = mempty freeOrdered _ = mempty
-- Much cleaner implementation, unfortunately one minor bug -- Much cleaner implementation, unfortunately one minor bug
-- checkBind :: Bind -> Infer (T.Bind' Type) -- checkBind :: Bind -> Infer (T.Bind' Type)
@ -257,13 +257,13 @@ checkInj (Inj c inj_typ) name tvars
toTVar :: Type -> Either Error TVar toTVar :: Type -> Either Error TVar
toTVar = \case toTVar = \case
TVar tvar -> pure tvar TVar tvar -> pure tvar
_ -> uncatchableErr "Not a type variable" _ -> uncatchableErr "Not a type variable"
returnType :: Type -> Type returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2 returnType (TFun _ t2) = returnType t2
returnType a = a returnType a = a
inferExp :: Exp -> Infer (T.ExpT' Type) inferExp :: Exp -> Infer (T' T.Exp' Type)
inferExp e = do inferExp e = do
(s, (e', t)) <- algoW e (s, (e', t)) <- algoW e
let subbed = apply s t let subbed = apply s t
@ -274,7 +274,7 @@ class CollectTVars a where
instance CollectTVars Exp where instance CollectTVars Exp where
collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e
collectTVars _ = S.empty collectTVars _ = S.empty
instance CollectTVars Type where instance CollectTVars Type where
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i) collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
@ -287,7 +287,7 @@ instance CollectTVars Type where
collect :: Set T.Ident -> Infer () collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st}) collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
algoW :: Exp -> Infer (Subst, T.ExpT' Type) algoW :: Exp -> Infer (Subst, T' T.Exp' Type)
algoW = \case algoW = \case
err@(EAnn e t) -> do err@(EAnn e t) -> do
(sub0, (e', t')) <- exprErr (algoW e) err (sub0, (e', t')) <- exprErr (algoW e) err
@ -600,12 +600,12 @@ generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where where
go :: [T.Ident] -> Type -> Type go :: [T.Ident] -> Type -> Type
go [] t = t go [] t = t
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t) go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
removeForalls :: Type -> Type removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2) removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t removeForalls t = t
{- | Instantiate a polymorphic type. The free type variables are substituted {- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones. with fresh ones.
@ -643,7 +643,7 @@ fresh = do
ungo :: [TVar] -> Type -> Type -> Bool ungo :: [TVar] -> Type -> Type -> Bool
ungo tvars t1 t2 = case run (go tvars t1 t2) of ungo tvars t1 t2 = case run (go tvars t1 t2) of
Right (b, _) -> b Right (b, _) -> b
_ -> False _ -> False
-- TODO: Fix the following -- TODO: Fix the following
-- Maybe locally using the Infer monad can cause trouble. -- Maybe locally using the Infer monad can cause trouble.
-- Since the fresh count starts from zero -- Since the fresh count starts from zero
@ -656,7 +656,7 @@ fresh = do
skipForalls :: Type -> Type skipForalls :: Type -> Type
skipForalls = \case skipForalls = \case
TAll _ t -> skipForalls t TAll _ t -> skipForalls t
t -> t t -> t
freshen :: Type -> Infer Type freshen :: Type -> Infer Type
freshen t = do freshen t = do
@ -705,10 +705,10 @@ instance SubstType Type where
TLit _ -> t TLit _ -> t
TVar (MkTVar a) -> case M.lookup (coerce a) sub of TVar (MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (MkTVar $ coerce a) Nothing -> TVar (MkTVar $ coerce a)
Just t -> t Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (MkTVar i) (apply sub t) Nothing -> TAll (MkTVar i) (apply sub t)
Just _ -> apply sub t Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b) TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a) TData name a -> TData name (apply sub a)
TEVar (MkTEVar _) -> t TEVar (MkTEVar _) -> t
@ -724,7 +724,7 @@ instance SubstType (Map T.Ident Type) where
instance SubstType (Map T.Ident (Maybe Type)) where instance SubstType (Map T.Ident (Maybe Type)) where
apply s = M.map (fmap $ apply s) apply s = M.map (fmap $ apply s)
instance SubstType (T.ExpT' Type) where instance SubstType (T' T.Exp' Type) where
apply s (e, t) = (apply s e, apply s t) apply s (e, t) = (apply s e, apply s t)
instance SubstType (T.Exp' Type) where instance SubstType (T.Exp' Type) where
@ -753,10 +753,10 @@ instance SubstType (T.Branch' Type) where
instance SubstType (T.Pattern' Type) where instance SubstType (T.Pattern' Type) where
apply s = \case apply s = \case
T.PVar iden -> T.PVar iden T.PVar iden -> T.PVar iden
T.PLit lit -> T.PLit lit T.PLit lit -> T.PLit lit
T.PInj i ps -> T.PInj i $ apply s ps T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i T.PEnum i -> T.PEnum i
instance SubstType (T.Pattern' Type, Type) where instance SubstType (T.Pattern' Type, Type) where
apply s (p, t) = (apply s p, apply s t) apply s (p, t) = (apply s p, apply s t)
@ -764,7 +764,7 @@ instance SubstType (T.Pattern' Type, Type) where
instance SubstType a => SubstType [a] where instance SubstType a => SubstType [a] where
apply s = map (apply s) apply s = map (apply s)
instance SubstType (T.Id' Type) where instance SubstType (T T.Ident Type) where
apply s (name, t) = (name, apply s t) apply s (name, t) = (name, apply s t)
-- | Represents the empty substition set -- | Represents the empty substition set
@ -797,11 +797,11 @@ withBindings xs =
-- | Run the monadic action with a pattern -- | Run the monadic action with a pattern
withPattern :: (Monad m, MonadReader Ctx m) => (T.Pattern' Type, Type) -> m a -> m a withPattern :: (Monad m, MonadReader Ctx m) => (T.Pattern' Type, Type) -> m a -> m a
withPattern (p, t) ma = case p of withPattern (p, t) ma = case p of
T.PVar x -> withBinding x t ma T.PVar x -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma T.PLit _ -> ma
T.PCatch -> ma T.PCatch -> ma
T.PEnum _ -> ma T.PEnum _ -> ma
-- | Insert a function signature into the environment -- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe Type -> Infer () insertSig :: T.Ident -> Maybe Type -> Infer ()
@ -826,11 +826,11 @@ existInj n = gets (M.lookup n . injections)
flattenType :: Type -> [Type] flattenType :: Type -> [Type]
flattenType (TFun a b) = flattenType a <> flattenType b flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a] flattenType a = [a]
typeLength :: Type -> Int typeLength :: Type -> Int
typeLength (TFun _ b) = 1 + typeLength b typeLength (TFun _ b) = 1 + typeLength b
typeLength _ = 1 typeLength _ = 1
{- | Catch an error if possible and add the given {- | Catch an error if possible and add the given
expression as addition to the error message expression as addition to the error message
@ -913,11 +913,11 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
deriving (Show) deriving (Show)
data Env = Env data Env = Env
{ count :: Int { count :: Int
, nextChar :: Char , nextChar :: Char
, sigs :: Map T.Ident (Maybe Type) , sigs :: Map T.Ident (Maybe Type)
, takenTypeVars :: Set T.Ident , takenTypeVars :: Set T.Ident
, injections :: Map T.Ident Type , injections :: Map T.Ident Type
, declaredBinds :: Set T.Ident , declaredBinds :: Set T.Ident
} }
deriving (Show) deriving (Show)

View file

@ -1,6 +1,7 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PatternSynonyms #-}
module TypeChecker.TypeCheckerIr ( module TypeChecker.TypeCheckerIr (
module Grammar.Abs, module Grammar.Abs,
module TypeChecker.TypeCheckerIr, module TypeChecker.TypeCheckerIr,
@ -10,31 +11,30 @@ import Data.String (IsString)
import Grammar.Abs (Lit (..)) import Grammar.Abs (Lit (..))
import Grammar.Print import Grammar.Print
import Prelude import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program' t = Program [Def' t] newtype Program' t = Program [Def' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
data Def' t data Def' t
= DBind (Bind' t) = DBind (Bind' t)
| DData (Data' t) | DData (Data' t)
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
data Type data Type
= TLit Ident = TLit Ident
| TVar TVar | TVar TVar
| TData Ident [Type] | TData Ident [Type]
| TFun Type Type | TFun Type Type
deriving (Eq, Ord, Show, Read) deriving (Eq, Ord, Show)
data Data' t = Data t [Inj' t] data Data' t = Data t [Inj' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
data Inj' t = Inj Ident t data Inj' t = Inj Ident t
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
newtype Ident = Ident String newtype Ident = Ident String
deriving (C.Eq, C.Ord, C.Show, C.Read, IsString) deriving (Eq, Ord, Show, IsString)
data Pattern' t data Pattern' t
= PVar Ident = PVar Ident
@ -42,30 +42,31 @@ data Pattern' t
| PCatch | PCatch
| PEnum Ident | PEnum Ident
| PInj Ident [(Pattern' t, t)] | PInj Ident [(Pattern' t, t)]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
data Exp' t data Exp' t
= EVar Ident = EVar Ident
| EInj Ident | EInj Ident
| ELit Lit | ELit Lit
| ELet (Bind' t) (ExpT' t) | ELet (Bind' t) (T' Exp' t)
| EApp (ExpT' t) (ExpT' t) | EApp (T' Exp' t) (T' Exp' t)
| EAdd (ExpT' t) (ExpT' t) | EAdd (T' Exp' t) (T' Exp' t)
| EAbs Ident (ExpT' t) | EAbs Ident (T' Exp' t)
| ECase (ExpT' t) [Branch' t] | ECase (T' Exp' t) [Branch' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
newtype TVar = MkTVar Ident newtype TVar = MkTVar Ident
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (Eq, Ord, Show)
type Id' t = (Ident, t) type T' a t = (a t, t)
type ExpT' t = (Exp' t, t) type T a t = (a, t)
data Bind' t = Bind (Id' t) [Id' t] (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
data Branch' t = Branch (Pattern' t, t) (ExpT' t) data Bind' t = Bind (T Ident t) [T Ident t] (T' Exp' t)
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
data Branch' t = Branch (T' Pattern' t) (T' Exp' t)
deriving (Eq, Ord, Show, Functor)
instance Print Ident where instance Print Ident where
prt _ (Ident s) = doc $ showString s prt _ (Ident s) = doc $ showString s
@ -81,22 +82,22 @@ instance Print t => Print (Bind' t) where
, prt i rhs , prt i rhs
] ]
prtSig :: Print t => Id' t -> Doc prtSig :: Print t => T Ident t -> Doc
prtSig (name, t) = prtSig (x, t) =
concatD concatD
[ prt 0 name [ prt 0 x
, doc $ showString ":" , doc $ showString ":"
, prt 0 t , prt 0 t
] ]
instance Print t => Print (ExpT' t) where instance (Print a, Print t) => Print (T a t) where
prt i (e, t) = prt i (x, t) =
concatD concatD
[ doc $ showString "(" [ -- doc $ showString "("
, prt i e {- , -} prt i x
, doc $ showString ":" -- , doc $ showString ":"
, prt 0 t -- , prt 0 t
, doc $ showString ")" -- , doc $ showString ")"
] ]
instance Print t => Print [Bind' t] where instance Print t => Print [Bind' t] where
@ -104,15 +105,6 @@ instance Print t => Print [Bind' t] where
prt i [x] = concatD [prt i x] prt i [x] = concatD [prt i x]
prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs] prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs]
instance Print t => Print (Id' t) where
prt i (name, t) =
concatD
[ doc $ showString "("
, prt i name
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print t => Print (Exp' t) where instance Print t => Print (Exp' t) where
prt i = \case prt i = \case
@ -151,9 +143,6 @@ instance Print t => Print [Inj' t] where
prt i [x] = prt i x prt i [x] = prt i x
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs] prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
instance Print t => Print (Pattern' t, t) where
prt i (p, t) = prPrec i 1 (concatD [prt i p, prt i t])
instance Print t => Print (Pattern' t) where instance Print t => Print (Pattern' t) where
prt i = \case prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name]) PVar name -> prPrec i 1 (concatD [prt 0 name])
@ -189,8 +178,6 @@ type Branch = Branch' Type
type Pattern = Pattern' Type type Pattern = Pattern' Type
type Inj = Inj' Type type Inj = Inj' Type
type Exp = Exp' Type type Exp = Exp' Type
type ExpT = ExpT' Type
type Id = Id' Type
pattern TVar' s = TVar (MkTVar s) pattern TVar' s = TVar (MkTVar s)
pattern DBind' id vars expt = DBind (Bind id vars expt) pattern DBind' id vars expt = DBind (Bind id vars expt)
pattern DData' typ injs = DData (Data typ injs) pattern DData' typ injs = DData (Data typ injs)

185
test_map2.ll Normal file
View file

@ -0,0 +1,185 @@
target triple = "x86_64-pc-linux-gnu"
target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
@.str = private unnamed_addr constant [3 x i8] c"%i
", align 1
@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c"Non-exhaustive patterns in case at %i:%i
"
declare i32 @printf(ptr noalias nocapture, ...)
declare i32 @exit(i32 noundef)
declare ptr @malloc(i32 noundef)
%List = type { i8, [23 x i8] }
%Cons = type { i8, i64, %List* }
%Nil = type { i8 }
; NYTT: kontexttyp
%Closure_sc_0 = type { i64 (i64)*, i64 }
; Ident "sum$List_Int": (ECase (EVar (Ident "$4xs"),TLit (Ident "List")) [Branch (PEnum (Ident "Nil"),TLit (Ident "List")) (ELit (LInt 0),TLit (Ident "Int")),Branch (PInj (Ident "Cons") [PVar (Ident "$5x",TLit (Ident "Int")),PVar (Ident "$6xs",TLit (Ident "List"))],TLit (Ident "List")) (EAdd (EVar (Ident "$5x"),TLit (Ident "Int")) (EApp (EVar (Ident "sum$List_Int"),TFun (TLit (Ident "List")) (TLit (Ident "Int"))) (EVar (Ident "$6xs"),TLit (Ident "List")),TLit (Ident "Int")),TLit (Ident "Int"))],TLit (Ident "Int"))
define fastcc i64 @sum$List_Int(%List %$4xs) {
%1 = alloca i64
; Penum
%2 = extractvalue %List %$4xs, 0
%3 = icmp eq i8 %2, 1
br i1 %3, label %lbl_success_3, label %lbl_failed_2
lbl_success_3:
%4 = alloca %List
store %List %$4xs, ptr %4
%5 = load %Nil, ptr %4
store i64 0, ptr %1
br label %lbl_escape_1
lbl_failed_2:
; Inj
%6 = extractvalue %List %$4xs, 0
%7 = icmp eq i8 %6, 0
br i1 %7, label %lbl_success_5, label %lbl_failed_4
lbl_success_5:
%8 = alloca %List
store %List %$4xs, ptr %8
%9 = load %Cons, ptr %8
; ident i64
%$5x = extractvalue %Cons %9, 1
; ident %List
%10 = extractvalue %Cons %9, 2
%$6xs = load %List, ptr %10
; TLit (Ident "Int")
%11 = call fastcc i64 @sum$List_Int(%List %$6xs)
%12 = add i64 %$5x, %11
store i64 %12, ptr %1
br label %lbl_escape_1
lbl_failed_4:
call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 12, i64 noundef 6)
call i32 @exit(i32 noundef 1)
br label %lbl_escape_1
lbl_escape_1:
%15 = load i64, ptr %1
ret i64 %15
}
; Ident "sc_0$Int_Int": (EAdd (EVar (Ident "$7x"),TLit (Ident "Int")) (ELit (LInt 10),TLit (Ident "Int")),TLit (Ident "Int"))
; ÄNDRAT: lägg till kontextpekare
define fastcc i64 @sc_0$Int_Int(ptr %closure_sc_0, i64 %$7x) {
; NYTT: Ladda alla fria variabler
%fri_variabel_ptr = getelementptr inbounds %Closure_sc_0, ptr %closure_sc_0, i32 0, i32 1
%fri_variabel = load i64, ptr %fri_variabel_ptr
; ÄNDRAT: %fri_variabel istället för 2
%1 = add i64 %$7x, %fri_variabel
ret i64 %1
}
; Ident "map$Int_Int_List_List": (ECase (EVar (Ident "$1xs"),TLit (Ident "List")) [Branch (PEnum (Ident "Nil"),TLit (Ident "List")) (EVar (Ident "Nil"),TLit (Ident "List")),Branch (PInj (Ident "Cons") [PVar (Ident "$2x",TLit (Ident "Int")),PVar (Ident "$3xs",TLit (Ident "List"))],TLit (Ident "List")) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EApp (EVar (Ident "$0f"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (EVar (Ident "$2x"),TLit (Ident "Int")),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "map$Int_Int_List_List"),TFun (TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EVar (Ident "$0f"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EVar (Ident "$3xs"),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List"))],TLit (Ident "List"))
; ÄNDRAT: ptr istället för i64 (i64)*
define fastcc %List @map$Int_Int_List_List(ptr %$0f, %List %$1xs) {
; NYTT: ta fram funktionspekaren
%$0f_deref = load i64(i64)*, ptr %$0f
%1 = alloca %List
; Penum
%2 = extractvalue %List %$1xs, 0
%3 = icmp eq i8 %2, 1
br i1 %3, label %lbl_success_8, label %lbl_failed_7
lbl_success_8:
%4 = alloca %List
store %List %$1xs, ptr %4
%5 = load %Nil, ptr %4
%6 = call fastcc %List @Nil()
store %List %6, ptr %1
br label %lbl_escape_6
lbl_failed_7:
; Inj
%7 = extractvalue %List %$1xs, 0
%8 = icmp eq i8 %7, 0
br i1 %8, label %lbl_success_10, label %lbl_failed_9
lbl_success_10:
%9 = alloca %List
store %List %$1xs, ptr %9
%10 = load %Cons, ptr %9
; ident i64
%$2x = extractvalue %Cons %10, 1
; ident %List
%11 = extractvalue %Cons %10, 2
%$3xs = load %List, ptr %11
; TLit (Ident "Int")
; ÄNDRAT använd deref
%12 = call fastcc i64 %$0f_deref(ptr %$0f, i64 %$2x)
; TLit (Ident "List")
; ÄNDRAT ptr istället för 64 (64)* och skicka med ptr
%13 = call fastcc %List @map$Int_Int_List_List(ptr %$0f, %List %$3xs)
; TLit (Ident "List")
%14 = call fastcc %List @Cons(i64 %12, %List %13)
store %List %14, ptr %1
br label %lbl_escape_6
lbl_failed_9:
call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 14, i64 noundef 6)
call i32 @exit(i32 noundef 1)
br label %lbl_escape_6
lbl_escape_6:
%17 = load %List, ptr %1
ret %List %17
}
; Ident "main": (EApp (EVar (Ident "sum$List_Int"),TFun (TLit (Ident "List")) (TLit (Ident "Int"))) (EApp (EApp (EVar (Ident "map$Int_Int_List_List"),TFun (TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EVar (Ident "sc_0$Int_Int"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (ELit (LInt 1),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (ELit (LInt 2),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EVar (Ident "Nil"),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "Int"))
define fastcc i64 @main() {
%1 = call fastcc %List @Nil()
; TLit (Ident "List")
%2 = call fastcc %List @Cons(i64 2, %List %1)
; TLit (Ident "List")
%3 = call fastcc %List @Cons(i64 1, %List %2)
; TLit (Ident "List")
; NYTT: spara funktionspekaren och 100 i kontexten
%closure_sc_0 = alloca %Closure_sc_0
store i64(i64)* @sc_0$Int_Int, ptr %closure_sc_0
%fri_variabel_ptr = getelementptr inbounds %Closure_sc_0, ptr %closure_sc_0, i32 0, i32 1
store i64 100, ptr %fri_variabel_ptr
; store %Closure_sc_0 {i64 (i64)* @sc_0$Int_Int, 100}, ptr %closure_sc_0
; ÄNDRAT ptr %closure_sc_0 istället för i64 (i64)* @sc_0$Int_Int
%4 = call fastcc %List @map$Int_Int_List_List(ptr %closure_sc_0, %List %3)
; TLit (Ident "Int")
%5 = call fastcc i64 @sum$List_Int(%List %4)
call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef %5)
ret i64 0
}
define fastcc %List @Cons(i64 %arg_0, %List %arg_1) {
%1 = alloca %List
%2 = getelementptr %List, %List* %1, i64 0, i32 0
store i8 0, i8* %2
%3 = bitcast %List* %1 to %Cons*
; i64 arg_0 1
%4 = getelementptr %Cons, %Cons* %3, i64 0, i32 1
; Just store
store i64 %arg_0, ptr %4
; %List arg_1 2
%5 = getelementptr %Cons, %Cons* %3, i64 0, i32 2
; Malloc and store
%6 = call ptr @malloc(i64 24)
store %List %arg_1, ptr %6
store %List* %6, ptr %5
; Return the newly constructed value
%7 = load %List, ptr %1
ret %List %7
}
define fastcc %List @Nil() {
%1 = alloca %List
%2 = getelementptr %List, %List* %1, i64 0, i32 0
store i8 1, i8* %2
%3 = bitcast %List* %1 to %Nil*
; Return the newly constructed value
%4 = load %List, ptr %1
ret %List %4
}