Split up the code generator a bit.

This commit is contained in:
Samuel Hammersberg 2023-03-29 17:34:47 +02:00
parent 36b6a8f781
commit 61f364cd75
6 changed files with 552 additions and 573 deletions

View file

@ -42,6 +42,9 @@ executable language
Monomorphizer.MonomorphizerIr
Codegen.Codegen
Codegen.LlvmIr
Codegen.Auxillary
Codegen.CompilerState
Codegen.Emits
Compiler
Renamer.Renamer
TreeConverter

50
src/Codegen/Auxillary.hs Normal file
View file

@ -0,0 +1,50 @@
module Codegen.Auxillary where
import Codegen.LlvmIr (LLVMType (..), LLVMValue (..))
import Control.Monad (foldM_)
import Monomorphizer.MonomorphizerIr as MIR (ExpT, Type (..))
import TypeChecker.TypeCheckerIr qualified as TIR
-- | Translates a monomorphized type into its LLVM IR counterpart.
--
-- The literal types @Int@ and @Char@ become 'I64' and 'I8'; any other literal
-- type becomes a 'CustomType' carrying the same identifier. A function type
-- becomes 'Function', built from its final result type and its argument types.
type2LlvmType :: MIR.Type -> LLVMType
type2LlvmType (MIR.TLit ident@(TIR.Ident name))
    | name == "Int" = I64
    | name == "Char" = I8
    | otherwise = CustomType ident
type2LlvmType (MIR.TFun argT rest) =
    uncurry Function (collect rest [type2LlvmType argT])
  where
    -- Walks down the right spine of the function type, pushing every argument
    -- type onto the front of the accumulator; the argument list therefore ends
    -- up in reverse order of appearance.
    -- NOTE(review): confirm that consumers of 'Function' expect that order.
    collect :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
    collect (TFun a r) acc = collect r (type2LlvmType a : acc)
    collect result acc = (type2LlvmType result, acc)
-- | LLVM type of a typed expression, derived from the type annotation stored
-- in the second component of the pair.
getType :: ExpT -> LLVMType
getType = type2LlvmType . snd
-- | Builds a single identifier naming a type.
--
-- A literal type yields its own identifier; a function type yields the names
-- of its two sides joined by @_$_@.
extractTypeName :: MIR.Type -> TIR.Ident
extractTypeName (MIR.TLit name) = name
extractTypeName (MIR.TFun from to) =
    TIR.Ident (rawName from <> "_$_" <> rawName to)
  where
    -- Name of a type with the 'TIR.Ident' wrapper removed.
    rawName ty = case extractTypeName ty of
        TIR.Ident s -> s
-- | LLVM type of a value.
--
-- Integers are 'I64', characters are 'I8', a constant is an 'I8' array as long
-- as the constant itself, and identifiers and functions carry their own type.
valueGetType :: LLVMValue -> LLVMType
valueGetType value = case value of
    VInteger _ -> I64
    VChar _ -> I8
    VIdent _ ty -> ty
    VConstant str -> Array (fromIntegral (length str)) I8
    VFunction _ _ ty -> ty
-- | Size in bytes that the code generator assigns to an LLVM type.
--
-- An array is its element size times its length; 'Ptr', 'Ref', 'Function' and
-- 'CustomType' all count as 8 bytes.
typeByteSize :: LLVMType -> Integer
typeByteSize ty = case ty of
    I1 -> 1
    I8 -> 1
    I32 -> 4
    I64 -> 8
    Ptr -> 8
    Ref _ -> 8
    Function _ _ -> 8
    Array len el -> len * typeByteSize el
    CustomType _ -> 8
-- | Runs an action on every element of a list, left to right, also handing
-- the action the element's position. Positions start at 1. All results are
-- discarded.
enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m ()
enumerateOneM_ f xs = foldM_ step 1 xs
  where
    -- Run the action at position i, then continue with the next position.
    step i a = succ i <$ f i a

View file

@ -1,184 +1,16 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (generateCode) where
import Auxiliary (snoc)
import Codegen.LlvmIr as LIR
import Control.Applicative ((<|>))
import Control.Monad (when)
import Control.Monad.State (
StateT,
execStateT,
foldM_,
gets,
modify,
import Codegen.CompilerState (
CodeGenerator (instructions),
initCodeGenerator,
)
import Codegen.Emits (compileScs)
import Codegen.LlvmIr as LIR (llvmIrToString)
import Control.Monad.State (
execStateT,
)
import Data.Bifunctor qualified as BI
import Data.Char (ord)
import Data.Coerce (coerce)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromJust, fromMaybe)
import Data.Tuple.Extra (dupe, first, second)
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR
import TypeChecker.TypeCheckerIr qualified as TIR
-- | The record used as the code generator state
data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr]
, functions :: Map MIR.Id FunctionInfo
, customTypes :: Map LLVMType Integer
, constructors :: Map TIR.Ident ConstructorInfo
, variableCount :: Integer
, labelCount :: Integer
}
-- | A state type synonym
type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo
{ numArgs :: Int
, arguments :: [Id]
}
deriving (Show)
data ConstructorInfo = ConstructorInfo
{ numArgsCI :: Int
, argumentsCI :: [Id]
, numCI :: Integer
, returnTypeCI :: MIR.Type
}
deriving (Show)
-- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState ()
emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
-- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState ()
increaseVarCount = modify $ \t -> t{variableCount = variableCount t + 1}
-- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer
getVarCount = gets variableCount
-- | Increases the variable count and returns it from the CodeGenerator state
getNewVar :: CompilerState TIR.Ident
getNewVar = TIR.Ident . show <$> (increaseVarCount >> getVarCount)
-- | Increses the label count and returns a label from the CodeGenerator state
getNewLabel :: CompilerState Integer
getNewLabel = do
modify (\t -> t{labelCount = labelCount t + 1})
gets labelCount
{- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation.
-}
getFunctions :: [MIR.Def] -> Map Id FunctionInfo
getFunctions bs = Map.fromList $ go bs
where
go [] = []
go (MIR.DBind (MIR.Bind id args _) : xs) =
(id, FunctionInfo{numArgs = length args, arguments = args})
: go xs
go (_ : xs) = go xs
createArgs :: [MIR.Type] -> [Id]
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
{- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation.
-}
getConstructors :: [MIR.Def] -> Map TIR.Ident ConstructorInfo
getConstructors bs = Map.fromList $ go bs
where
go [] = []
go (MIR.DData (MIR.Data t cons) : xs) =
fst
( foldl
( \(acc, i) (Inj id xs) ->
( ( id
, ConstructorInfo
{ numArgsCI = length (init . flattenType $ xs)
, argumentsCI = createArgs (init . flattenType $ xs)
, numCI = i
, returnTypeCI = t -- last . flattenType $ xs
}
)
: acc
, i + 1
)
)
([], 0)
cons
)
<> go xs
go (_ : xs) = go xs
getTypes :: [MIR.Def] -> Map LLVMType Integer
getTypes bs = Map.fromList $ go bs
where
go [] = []
go (MIR.DData (MIR.Data t ts) : xs) = (type2LlvmType t, biggestVariant ts) : go xs
go (_ : xs) = go xs
variantTypes fi = init $ map type2LlvmType (flattenType fi)
biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
initCodeGenerator :: [MIR.Def] -> CodeGenerator
initCodeGenerator scs =
CodeGenerator
{ instructions = defaultStart
, functions = getFunctions scs
, constructors = getConstructors scs
, customTypes = getTypes scs
, variableCount = 0
, labelCount = 0
}
{-
run :: Err String -> IO ()
run s = do
let s' = case s of
Right s -> s
Left _ -> error "yo"
writeFile "output/llvm.ll" s'
putStrLn . trim =<< readCreateProcess (shell "lli") s'
test :: Integer -> Program
test v =
Program
[ DataType
(TIR.Ident "Craig")
[ Constructor (TIR.Ident "Bob") [MIR.Type (TIR.Ident "_Int")]
, Constructor (TIR.Ident "Betty") [MIR.Type (TIR.Ident "_Int")]
]
, DataType
(TIR.Ident "Alice")
[ Constructor (TIR.Ident "Eve") [MIR.Type (TIR.Ident "_Int")] -- ,
-- (TIR.Ident "Alice", [TInt, TInt])
]
, Bind (TIR.Ident "fibonacci", MIR.Type (TIR.Ident "_Int")) [(TIR.Ident "x", MIR.Type (TIR.Ident "_Int"))] (EId ("x", MIR.Type (TIR.Ident "Craig")), MIR.Type (TIR.Ident "Craig"))
, Bind (TIR.Ident "main", MIR.Type (TIR.Ident "_Int")) []
-- (EApp (MIR.Type (TIR.Ident "Craig")) (EId (TIR.Ident "Craig_Bob", MIR.Type (TIR.Ident "Craig")), MIR.Type (TIR.Ident "Craig")) (ELit (LInt v), MIR.Type (TIR.Ident "_Int")), MIR.Type (TIR.Ident "Craig"))-- (EInt 92)
$
eCaseInt
(EApp (MIR.TLit (TIR.Ident "Craig")) (EId (TIR.Ident "Craig_Bob", MIR.TLit (TIR.Ident "Craig")), MIR.TLit (TIR.Ident "Craig")) (ELit (LInt v), MIR.Type (TIR.Ident "_Int")), MIR.Type (TIR.Ident "Craig"))
[ injectionCons "Craig_Bob" "Craig" [CIdent (TIR.Ident "x")] (EId (TIR.Ident "x", MIR.Type (TIR.Ident "_Int")), MIR.Type (TIR.Ident "_Int"))
, injectionCons "Craig_Betty" "Craig" [CLit (LInt 5)] (int 2)
, Injection (CIdent (TIR.Ident "z")) (int 3)
, -- , injectionInt 5 (int 6)
injectionCatchAll (int 10)
]
]
where
injectionCons x y xs = Injection (CCons (TIR.Ident x, MIR.Type (TIR.Ident y)) xs)
injectionInt x = Injection (CLit (LInt x))
injectionCatchAll = Injection CatchAll
eCaseInt x xs = (ECase (MIR.TLit (MIR.Ident "_Int")) x xs, MIR.TLit (MIR.Ident "_Int"))
int x = (ELit (LInt x), MIR.TLit (MIR.Ident "_Int"))
-}
import Monomorphizer.MonomorphizerIr as MIR (Program (..))
{- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to
@ -188,397 +20,3 @@ generateCode :: MIR.Program -> Err String
generateCode (MIR.Program scs) = do
let codegen = initCodeGenerator scs
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
compileScs :: [MIR.Def] -> CompilerState ()
compileScs [] = do
emit $ UnsafeRaw "\n"
-- as a last step create all the constructors
-- //TODO maybe merge this with the data type match?
c <- gets (Map.toList . constructors)
mapM_
( \(id, ci) -> do
let t = returnTypeCI ci
let t' = type2LlvmType t
let x = BI.second type2LlvmType <$> argumentsCI ci
emit $ Define FastCC t' id x
top <- getNewVar
ptr <- getNewVar
-- allocated the primary type
emit $ SetVariable top (Alloca t')
-- set the first byte to the index of the constructor
emit $
SetVariable ptr $
GetElementPtr
t'
(Ref t')
(VIdent top I8)
I64
(VInteger 0)
I32
(VInteger 0)
emit $ Store I8 (VInteger $ numCI ci) (Ref I8) ptr
-- get a pointer of the correct type
ptr' <- getNewVar
emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id))
cTypes <- gets customTypes
enumerateOneM_
( \i (TIR.Ident arg_n, arg_t) -> do
let arg_t' = type2LlvmType arg_t
emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
elemPtr <- getNewVar
emit $
SetVariable
elemPtr
( GetElementPtr
(CustomType id)
(Ref (CustomType id))
(VIdent ptr' Ptr)
I64
(VInteger 0)
I32
(VInteger i)
)
case Map.lookup arg_t' cTypes of
Just s -> do
emit $ Comment "Malloc and store"
heapPtr <- getNewVar
emit $ SetVariable heapPtr (Malloca s)
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
Nothing -> do
emit $ Comment "Just store"
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
)
(argumentsCI ci)
-- load and return the constructed value
emit $ Comment "Return the newly constructed value"
load <- getNewVar
emit $ SetVariable load (Load t' Ptr top)
emit $ Ret t' (VIdent load t')
emit DefineEnd
emit $ UnsafeRaw "\n"
modify $ \s -> s{variableCount = 0}
)
c
compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
let t_return = type2LlvmType . last . flattenType $ t
emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args
emit $ Define FastCC t_return name args'
when (name == "main") (mapM_ emit firstMainContent)
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit $ lastMainContent functionBody
else emit $ Ret t_return functionBody
emit DefineEnd
modify $ \s -> s{variableCount = 0}
compileScs xs
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
let (TIR.Ident outer_id) = extractTypeName typ
-- //TODO this could be extracted from the customTypes map
let variantTypes fi = init $ map type2LlvmType (flattenType fi)
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8]
typeSets <- gets customTypes
mapM_
( \(Inj inner_id fi) -> do
let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
emit $ LIR.Type inner_id (I8 : types)
)
ts
compileScs xs
firstMainContent :: [LLVMIr]
firstMainContent =
[]
-- UnsafeRaw "call void @_ZN2GC4Heap4initEv()\n"
lastMainContent :: LLVMValue -> [LLVMIr]
lastMainContent var =
[ UnsafeRaw $
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n"
, Ret I64 (VInteger 0)
]
defaultStart :: [LLVMIr]
defaultStart =
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
, UnsafeRaw "declare void @_ZN2GC4Heap4initEv()\n"
, UnsafeRaw "declare void @_ZN2GC4Heap5allocEm()\n"
, UnsafeRaw "declare void @_ZN2GC4Heap7disposeEv()\n"
]
compileExp :: ExpT -> CompilerState ()
compileExp (MIR.ELit lit, _t) = emitLit lit
compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2
-- compileExp (ESub t e1 e2) = emitSub t e1 e2
compileExp (MIR.EVar name, _t) = emitIdent name
compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2
-- compileExp (EAbs t ti e) = emitAbs t ti e
compileExp (MIR.ELet bind e, _) = emitLet bind e
compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs)
-- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2
--- aux functions ---
emitLet :: MIR.Bind -> ExpT -> CompilerState ()
emitLet (MIR.Bind id [] innerExp) e = do
evaled <- exprToValue innerExp
tempVar <- getNewVar
let t = type2LlvmType . snd $ innerExp
emit $ SetVariable tempVar (Alloca t)
emit $ Store (type2LlvmType . snd $ innerExp) evaled Ptr tempVar
emit $ SetVariable (fst id) (Load t Ptr tempVar)
compileExp e
emitLet b _ = error $ "Non empty argument list in let-bind " <> show b
emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Branch)] -> CompilerState ()
emitECased t e cases = do
let cs = snd <$> cases
let ty = type2LlvmType t
let rt = type2LlvmType (snd e)
vs <- exprToValue e
lbl <- getNewLabel
let label = TIR.Ident $ "escape_" <> show lbl
stackPtr <- getNewVar
emit $ SetVariable stackPtr (Alloca ty)
mapM_ (emitCases rt ty label stackPtr vs) cs
-- crashLbl <- TIR.Ident . ("crash_" <>) . show <$> getNewLabel
-- emit $ Label crashLbl
emit . UnsafeRaw $ "call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 6, i64 noundef 6)\n"
emit . UnsafeRaw $ "call i32 @exit(i32 noundef 1)\n"
mapM_ (const increaseVarCount) [0 .. 1]
emit $ Br label
emit $ Label label
res <- getNewVar
emit $ SetVariable res (Load ty Ptr stackPtr)
where
emitCases :: LLVMType -> LLVMType -> TIR.Ident -> TIR.Ident -> LLVMValue -> Branch -> CompilerState ()
emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, _t) exp) = do
emit $ Comment "Inj"
cons <- gets constructors
let r = fromJust $ Map.lookup consId cons
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0)
consCheck <- getNewVar
emit $ SetVariable consCheck (Icmp LLEq I8 (VIdent consVal I8) (VInteger $ numCI r))
emit $ BrCond (VIdent consCheck ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos
castPtr <- getNewVar
casted <- getNewVar
emit $ SetVariable castPtr (Alloca rt)
emit $ Store rt vs Ptr castPtr
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
enumerateOneM_
( \i c -> do
case c of
PVar (x, topT) -> do
let topT' = type2LlvmType topT
let botT' = CustomType (coerce consId)
emit . Comment $ "ident " <> toIr topT'
cTypes <- gets customTypes
if Map.member topT' cTypes
then do
deref <- getNewVar
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
emit $ SetVariable x (Load topT' Ptr deref)
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
PLit (_l, _t) -> undefined
PInj _id _ps -> undefined
PCatch -> pure ()
PEnum _id -> undefined
)
cs
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
emit $ Label lbl_failPos
emitCases _rt ty label stackPtr vs (Branch (MIR.PLit i, t) exp) = do
emit $ Comment "Plit"
let i' = case i of
(MIR.LInt i, _) -> VInteger i
(MIR.LChar i, _) -> VChar (ord i)
ns <- getNewVar
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable ns (Icmp LLEq (type2LlvmType t) vs i')
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do
emit $ Comment "Pvar"
-- //TODO this is pretty disgusting and would heavily benefit from a rewrite
valPtr <- getNewVar
emit $ SetVariable valPtr (Alloca rt)
emit $ Store rt vs Ptr valPtr
emit $ SetVariable id (Load rt Ptr valPtr)
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos
emitCases _rt ty label stackPtr _vs (Branch (MIR.PEnum _id, _) exp) = do
-- //TODO Penum wrong, acts as a catch all
emit $ Comment "Penum"
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos
emitCases _ ty label stackPtr _ (Branch (MIR.PCatch, _) exp) = do
emit $ Comment "Pcatch"
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos
emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitApp rt e1 e2 = appEmitter e1 e2 []
where
appEmitter :: ExpT -> ExpT -> [ExpT] -> CompilerState ()
appEmitter e1 e2 stack = do
let newStack = e2 : stack
case e1 of
(MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack
(MIR.EVar name, t) -> do
args <- traverse exprToValue newStack
vs <- getNewVar
funcs <- gets functions
consts <- gets constructors
let visibility =
fromMaybe Local $
Global <$ Map.lookup name consts
<|> Global <$ Map.lookup (name, t) funcs
-- this piece of code could probably be improved, i.e remove the double `const Global`
args' = map (first valueGetType . dupe) args
call = Call FastCC (type2LlvmType rt) visibility name args'
emit $ Comment $ show rt
emit $ SetVariable vs call
x -> error $ "The unspeakable happened: " <> show x
emitIdent :: TIR.Ident -> CompilerState ()
emitIdent id = do
-- !!this should never happen!!
emit $ Comment "This should not have happened!"
emit $ Variable id
emit $ UnsafeRaw "\n"
emitLit :: MIR.Lit -> CompilerState ()
emitLit i = do
-- !!this should never happen!!
let (i', t) = case i of
(MIR.LInt i'') -> (VInteger i'', I64)
(MIR.LChar i'') -> (VChar $ ord i'', I8)
varCount <- getNewVar
emit $ Comment "This should not have happened!"
emit $ SetVariable varCount (Add t i' (VInteger 0))
emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitAdd t e1 e2 = do
v1 <- exprToValue e1
v2 <- exprToValue e2
v <- getNewVar
emit $ SetVariable v (Add (type2LlvmType t) v1 v2)
exprToValue :: ExpT -> CompilerState LLVMValue
exprToValue = \case
(MIR.ELit i, _t) -> pure $ case i of
(MIR.LInt i) -> VInteger i
(MIR.LChar i) -> VChar $ ord i
(MIR.EVar name, t) -> do
funcs <- gets functions
cons <- gets constructors
let res =
Map.lookup (name, t) funcs
<|> ( \c ->
FunctionInfo
{ numArgs = numArgsCI c
, arguments = argumentsCI c
}
)
<$> Map.lookup name cons
case res of
Just fi -> do
if numArgs fi == 0
then do
vc <- getNewVar
emit $
SetVariable
vc
(Call FastCC (type2LlvmType t) Global name [])
pure $ VIdent vc (type2LlvmType t)
else pure $ VFunction name Global (type2LlvmType t)
Nothing -> pure $ VIdent name (type2LlvmType t)
e -> do
compileExp e
v <- getVarCount
pure $ VIdent (TIR.Ident $ show v) (getType e)
type2LlvmType :: MIR.Type -> LLVMType
type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of
"Int" -> I64
"Char" -> I8
_ -> CustomType id
type2LlvmType (MIR.TFun t xs) = do
let (t', xs') = function2LLVMType xs [type2LlvmType t]
Function t' xs'
where
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s)
getType :: ExpT -> LLVMType
getType (_, t) = type2LlvmType t
extractTypeName :: MIR.Type -> TIR.Ident
extractTypeName (MIR.TLit id) = id
extractTypeName (MIR.TFun t xs) =
let (TIR.Ident i) = extractTypeName t
(TIR.Ident is) = extractTypeName xs
in TIR.Ident $ i <> "_$_" <> is
valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64
valueGetType (VChar _) = I8
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
valueGetType (VFunction _ _ t) = t
typeByteSize :: LLVMType -> Integer
typeByteSize I1 = 1
typeByteSize I8 = 1
typeByteSize I32 = 4
typeByteSize I64 = 8
typeByteSize Ptr = 8
typeByteSize (Ref _) = 8
typeByteSize (Function _ _) = 8
typeByteSize (Array n t) = n * typeByteSize t
typeByteSize (CustomType _) = 8
enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m ()
enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1

View file

@ -0,0 +1,141 @@
module Codegen.CompilerState where
import Auxiliary (snoc)
import Codegen.Auxillary (type2LlvmType, typeByteSize)
import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw), LLVMType)
import Control.Monad.State (
StateT,
gets,
modify,
)
import Data.Map (Map)
import Data.Map qualified as Map
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR
import TypeChecker.TypeCheckerIr qualified as TIR
-- | The record used as the code generator state
data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr]
-- ^ IR produced so far; seeded with 'defaultStart' and extended by 'emit'
, functions :: Map MIR.Id FunctionInfo
-- ^ one entry per top-level bind, as collected by 'getFunctions'
, customTypes :: Map LLVMType Integer
-- ^ LLVM type of every data declaration paired with the size computed by 'getTypes'
, constructors :: Map TIR.Ident ConstructorInfo
-- ^ one entry per data constructor, as collected by 'getConstructors'
, variableCount :: Integer
-- ^ counter behind 'getNewVar'; starts at 0 and is bumped by 'increaseVarCount'
, labelCount :: Integer
-- ^ counter behind 'getNewLabel'; starts at 0
}
-- | A state type synonym: a 'StateT' carrying the 'CodeGenerator' record,
-- layered over the 'Err' monad
type CompilerState a = StateT CodeGenerator Err a
-- | What the code generator records about a top-level bind (see 'getFunctions')
data FunctionInfo = FunctionInfo
{ numArgs :: Int
-- ^ number of arguments of the bind
, arguments :: [Id]
-- ^ the arguments themselves
}
deriving (Show)
-- | What the code generator records about a data constructor (see 'getConstructors')
data ConstructorInfo = ConstructorInfo
{ numArgsCI :: Int
-- ^ number of arguments the constructor takes
, argumentsCI :: [Id]
-- ^ generated argument ids (named by 'createArgs') paired with their types
, numCI :: Integer
-- ^ position of the constructor inside its data declaration, counted from 0
, returnTypeCI :: MIR.Type
-- ^ type of the data declaration the constructor belongs to
}
deriving (Show)
-- | Adds an instruction to the CodeGenerator state (through 'Auxiliary.snoc')
emit :: LLVMIr -> CompilerState ()
emit instr = modify addInstr
  where
    addInstr st = st{instructions = Auxiliary.snoc instr (instructions st)}
-- | Bumps the variable counter of the CodeGenerator state by one
increaseVarCount :: CompilerState ()
increaseVarCount = modify bump
  where
    bump st = st{variableCount = succ (variableCount st)}
-- | Returns the current variable count from the CodeGenerator state,
-- without changing it
getVarCount :: CompilerState Integer
getVarCount = gets variableCount
-- | Produces a fresh variable name: bumps the variable counter and returns
-- the new count, rendered with 'show', as an identifier
getNewVar :: CompilerState TIR.Ident
getNewVar = do
    increaseVarCount
    TIR.Ident . show <$> getVarCount
-- | Increases the label count of the CodeGenerator state and returns the new
-- count
getNewLabel :: CompilerState Integer
getNewLabel = modify bump >> gets labelCount
  where
    bump st = st{labelCount = labelCount st + 1}
{- | Builds the function table from a list of definitions.
    Every bind contributes one entry, keyed by its id, that records the
    bind's arguments and how many there are. Other definitions are ignored.
-}
getFunctions :: [MIR.Def] -> Map Id FunctionInfo
getFunctions defs =
    Map.fromList
        [ (name, FunctionInfo{numArgs = length params, arguments = params})
        | MIR.DBind (MIR.Bind name params _) <- defs
        ]
-- | Pairs every type with a generated argument name: @arg_0@, @arg_1@, ...
-- in list order.
--
-- Written with 'zipWith' instead of a lazy left fold that appended to the
-- accumulator, which was quadratic in the number of arguments.
createArgs :: [MIR.Type] -> [Id]
createArgs = zipWith mkArg [0 :: Integer ..]
  where
    mkArg n t = (TIR.Ident ("arg_" <> show n), t)
{- | Produces a map of constructor infos from a list of definitions.
    Every constructor of every data declaration contributes one entry, keyed
    by the constructor's name. The entry records the constructor's generated
    arguments (all but the last element of its flattened type), its position
    inside the declaration counted from 0, and the declared data type.
    Definitions that are not data declarations are ignored.
-}
getConstructors :: [MIR.Def] -> Map TIR.Ident ConstructorInfo
getConstructors defs = Map.fromList (concatMap fromDef defs)
  where
    -- The entries of one declaration are listed last constructor first, as
    -- before, so 'Map.fromList' resolves a repeated name the same way.
    fromDef (MIR.DData (MIR.Data t cons)) = reverse (zipWith (mkEntry t) [0 ..] cons)
    fromDef _ = []

    mkEntry t i (Inj name conType) =
        let argTypes = init (flattenType conType)
         in ( name
            , ConstructorInfo
                { numArgsCI = length argTypes
                , argumentsCI = createArgs argTypes
                , numCI = i
                , returnTypeCI = t -- last . flattenType $ conType
                }
            )
-- | Maps the LLVM type of every data declaration to a size in bytes: 8 plus
-- the size of its largest constructor, where a constructor's size is the sum
-- of 'typeByteSize' over its argument types (all but the last element of its
-- flattened type). Definitions that are not data declarations are ignored.
getTypes :: [MIR.Def] -> Map LLVMType Integer
getTypes defs =
    Map.fromList
        [ (type2LlvmType t, biggestVariant cons)
        | MIR.DData (MIR.Data t cons) <- defs
        ]
  where
    variantTypes fi = init $ map type2LlvmType (flattenType fi)
    variantSize (Inj _ fi) = sum (typeByteSize <$> variantTypes fi)
    -- 'maximum' is partial: seeding it with 0 keeps a declaration without
    -- constructors from crashing and leaves every other result unchanged,
    -- since sizes are never negative.
    biggestVariant cons = 8 + maximum (0 : map variantSize cons)
-- | Initial code generator state for a program: the instruction list holds
-- only 'defaultStart', the lookup tables are filled from the given
-- definitions, and both counters start at zero.
initCodeGenerator :: [MIR.Def] -> CodeGenerator
initCodeGenerator defs =
    CodeGenerator
        { variableCount = 0
        , labelCount = 0
        , instructions = defaultStart
        , customTypes = getTypes defs
        , constructors = getConstructors defs
        , functions = getFunctions defs
        }
-- | The IR that every generated module starts with: target triple and data
-- layout, the two format strings handed to @printf@, and declarations of the
-- external functions.
--
-- Both format strings end in an explicit NUL (@\\00@) and their array lengths
-- count it. @printf@ reads its format up to a NUL byte, and the constants
-- used to hold exactly the visible characters with no terminator. The newline
-- is written with LLVM's @\\0A@ escape instead of a raw line break inside the
-- literal.
--
-- NOTE(review): the data layout asks for Mach-O mangling (@m:o@) while the
-- triple targets Linux - confirm this pairing is intended.
defaultStart :: [LLVMIr]
defaultStart =
    [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
    , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
    , UnsafeRaw "@.str = private unnamed_addr constant [4 x i8] c\"%i\\0A\\00\", align 1\n"
    , UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [42 x i8] c\"Non-exhaustive patterns in case at %i:%i\\0A\\00\"\n"
    , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
    , UnsafeRaw "declare i32 @exit(i32 noundef)\n"
    , UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
    , UnsafeRaw "declare void @_ZN2GC4Heap4initEv()\n"
    , UnsafeRaw "declare void @_ZN2GC4Heap5allocEm()\n"
    , UnsafeRaw "declare void @_ZN2GC4Heap7disposeEv()\n"
    ]

348
src/Codegen/Emits.hs Normal file
View file

@ -0,0 +1,348 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.Emits where
import Codegen.Auxillary
import Codegen.CompilerState
import Codegen.LlvmIr as LIR
import Control.Applicative ((<|>))
import Control.Monad (when)
import Control.Monad.State (
gets,
modify,
)
import Data.Bifunctor qualified as BI
import Data.Char (ord)
import Data.Coerce (coerce)
import Data.Map qualified as Map
import Data.Maybe (fromJust, fromMaybe)
import Data.Tuple.Extra (dupe, first, second)
import Monomorphizer.MonomorphizerIr as MIR
import TypeChecker.TypeCheckerIr qualified as TIR
-- | Emits LLVM IR for a list of top-level definitions, one definition per
-- recursive step.
--
-- * A bind becomes a function definition. Its body is compiled to a value
--   that is returned, except in @main@, where 'firstMainContent' is emitted
--   before the body and 'lastMainContent' replaces the plain return.
-- * A data declaration becomes one LLVM type for the data type itself plus
--   one LLVM type per constructor.
-- * Once the list is exhausted, a constructor function is emitted for every
--   entry of the 'constructors' map.
compileScs :: [MIR.Def] -> CompilerState ()
compileScs [] = do
emit $ UnsafeRaw "\n"
-- as a last step create all the constructors
-- //TODO maybe merge this with the data type match?
c <- gets (Map.toList . constructors)
mapM_
( \(id, ci) -> do
let t = returnTypeCI ci
let t' = type2LlvmType t
let x = BI.second type2LlvmType <$> argumentsCI ci
emit $ Define FastCC t' id x
top <- getNewVar
ptr <- getNewVar
-- allocated the primary type
emit $ SetVariable top (Alloca t')
-- set the first byte to the index of the constructor
emit $
SetVariable ptr $
GetElementPtr
t'
(Ref t')
(VIdent top I8)
I64
(VInteger 0)
I32
(VInteger 0)
emit $ Store I8 (VInteger $ numCI ci) (Ref I8) ptr
-- get a pointer of the correct type
ptr' <- getNewVar
emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id))
cTypes <- gets customTypes
-- store every argument in its field; fields are counted from 1 because
-- field 0 holds the constructor index written above
enumerateOneM_
( \i (TIR.Ident arg_n, arg_t) -> do
let arg_t' = type2LlvmType arg_t
emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
elemPtr <- getNewVar
emit $
SetVariable
elemPtr
( GetElementPtr
(CustomType id)
(Ref (CustomType id))
(VIdent ptr' Ptr)
I64
(VInteger 0)
I32
(VInteger i)
)
-- an argument whose type is in the customTypes map is stored in memory
-- obtained through Malloca (sized by the map entry) and the field receives
-- the resulting pointer; any other argument is stored in the field directly
case Map.lookup arg_t' cTypes of
Just s -> do
emit $ Comment "Malloc and store"
heapPtr <- getNewVar
emit $ SetVariable heapPtr (Malloca s)
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
Nothing -> do
emit $ Comment "Just store"
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
)
(argumentsCI ci)
-- load and return the constructed value
emit $ Comment "Return the newly constructed value"
load <- getNewVar
emit $ SetVariable load (Load t' Ptr top)
emit $ Ret t' (VIdent load t')
emit DefineEnd
emit $ UnsafeRaw "\n"
-- variable numbering starts over for the next function
modify $ \s -> s{variableCount = 0}
)
c
compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
-- the return type is the last element of the bind's flattened type
let t_return = type2LlvmType . last . flattenType $ t
emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args
emit $ Define FastCC t_return name args'
when (name == "main") (mapM_ emit firstMainContent)
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit $ lastMainContent functionBody
else emit $ Ret t_return functionBody
emit DefineEnd
-- variable numbering starts over for the next function
modify $ \s -> s{variableCount = 0}
compileScs xs
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
let (TIR.Ident outer_id) = extractTypeName typ
-- //TODO this could be extracted from the customTypes map
let variantTypes fi = init $ map type2LlvmType (flattenType fi)
-- NOTE(review): maximum is partial, so a data declaration without
-- constructors would crash here - confirm such declarations cannot reach
-- the code generator
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
-- the data type itself: one I8 followed by an I8 array of biggestVariant elements
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8]
typeSets <- gets customTypes
mapM_
( \(Inj inner_id fi) -> do
-- a field whose type is in the customTypes map is declared as a reference to it
let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
emit $ LIR.Type inner_id (I8 : types)
)
ts
compileScs xs
-- | Instructions emitted at the very start of @main@, before its body.
-- Currently none.
firstMainContent :: [LLVMIr]
firstMainContent = []
-- | Instructions that close @main@: a raw @printf@ call that prints the given
-- value through the @.str@ format, followed by a return of the constant 0.
lastMainContent :: LLVMValue -> [LLVMIr]
lastMainContent var = [UnsafeRaw printCall, Ret I64 (VInteger 0)]
  where
    printCall =
        "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n"
-- | Emits the instructions for a typed expression by dispatching on the
-- expression's constructor to the matching @emit*@ function.
compileExp :: ExpT -> CompilerState ()
compileExp (expr, t) = case expr of
    MIR.ELit lit -> emitLit lit
    MIR.EAdd e1 e2 -> emitAdd t e1 e2
    MIR.EVar name -> emitIdent name
    MIR.EApp e1 e2 -> emitApp t e1 e2
    MIR.ELet bind e -> emitLet bind e
    MIR.ECase e cs -> emitECased t e (map (t,) cs)
-- | Emits a let-expression whose bind takes no arguments.
--
-- The bound expression is evaluated, stored in a freshly allocated slot and
-- loaded back under the bind's name; the body is compiled afterwards. A bind
-- with arguments aborts with 'error'.
emitLet :: MIR.Bind -> ExpT -> CompilerState ()
emitLet bind body = case bind of
    MIR.Bind binder [] bound -> do
        value <- exprToValue bound
        slot <- getNewVar
        let slotT = getType bound
        mapM_
            emit
            [ SetVariable slot (Alloca slotT)
            , Store slotT value Ptr slot
            , SetVariable (fst binder) (Load slotT Ptr slot)
            ]
        compileExp body
    _ -> error $ "Non empty argument list in let-bind " <> show bind
-- | Emits the code for a case expression.
--
-- The scrutinee @e@ is evaluated once, a stack slot of the result type is
-- allocated, and every branch is emitted in order by @emitCases@. A branch
-- that matches stores its value in the slot and jumps to the shared
-- @escape_N@ label, after which the result is loaded from the slot. Code that
-- runs past every branch prints the non-exhaustive-patterns message and
-- calls @exit@.
--
-- The type paired with each branch in @cases@ is dropped; only the branches
-- are used.
emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Branch)] -> CompilerState ()
emitECased t e cases = do
let cs = snd <$> cases
let ty = type2LlvmType t
let rt = type2LlvmType (snd e)
vs <- exprToValue e
lbl <- getNewLabel
let label = TIR.Ident $ "escape_" <> show lbl
stackPtr <- getNewVar
emit $ SetVariable stackPtr (Alloca ty)
mapM_ (emitCases rt ty label stackPtr vs) cs
-- crashLbl <- TIR.Ident . ("crash_" <>) . show <$> getNewLabel
-- emit $ Label crashLbl
emit . UnsafeRaw $ "call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 6, i64 noundef 6)\n"
emit . UnsafeRaw $ "call i32 @exit(i32 noundef 1)\n"
-- bumps the variable counter twice
-- NOTE(review): presumably this keeps the numbering in step with the two
-- unnamed results of the raw calls above - confirm
mapM_ (const increaseVarCount) [0 .. 1]
emit $ Br label
emit $ Label label
res <- getNewVar
emit $ SetVariable res (Load ty Ptr stackPtr)
where
-- Emits a single branch. Arguments: LLVM type of the scrutinee, LLVM type of
-- the result, the escape label, the result slot, the scrutinee value, and
-- the branch itself.
emitCases :: LLVMType -> LLVMType -> TIR.Ident -> TIR.Ident -> LLVMValue -> Branch -> CompilerState ()
emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, _t) exp) = do
emit $ Comment "Inj"
cons <- gets constructors
-- NOTE(review): fromJust crashes if consId is missing from the constructors map
let r = fromJust $ Map.lookup consId cons
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
-- field 0 of the scrutinee is compared with this constructor's index
consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0)
consCheck <- getNewVar
emit $ SetVariable consCheck (Icmp LLEq I8 (VIdent consVal I8) (VInteger $ numCI r))
emit $ BrCond (VIdent consCheck ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos
-- store the scrutinee and load it back as the constructor's own type
castPtr <- getNewVar
casted <- getNewVar
emit $ SetVariable castPtr (Alloca rt)
emit $ Store rt vs Ptr castPtr
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
-- bind the sub-patterns to the constructor's fields, counted from 1
enumerateOneM_
( \i c -> do
case c of
PVar (x, topT) -> do
let topT' = type2LlvmType topT
let botT' = CustomType (coerce consId)
emit . Comment $ "ident " <> toIr topT'
cTypes <- gets customTypes
-- a field whose type is in the customTypes map is extracted and then loaded
-- through; any other field is extracted directly
if Map.member topT' cTypes
then do
deref <- getNewVar
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
emit $ SetVariable x (Load topT' Ptr deref)
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
-- literal, nested-constructor and enum sub-patterns are not implemented and
-- crash with undefined
PLit (_l, _t) -> undefined
PInj _id _ps -> undefined
PCatch -> pure ()
PEnum _id -> undefined
)
cs
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
-- the code emitted after this branch continues from the failure label
emit $ Label lbl_failPos
emitCases _rt ty label stackPtr vs (Branch (MIR.PLit i, t) exp) = do
emit $ Comment "Plit"
let i' = case i of
(MIR.LInt i, _) -> VInteger i
(MIR.LChar i, _) -> VChar (ord i)
ns <- getNewVar
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
-- compare the scrutinee with the literal and branch on the outcome
emit $ SetVariable ns (Icmp LLEq (type2LlvmType t) vs i')
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do
emit $ Comment "Pvar"
-- //TODO this is pretty disgusting and would heavily benefit from a rewrite
-- binds the scrutinee to the pattern variable by storing it and loading it back
valPtr <- getNewVar
emit $ SetVariable valPtr (Alloca rt)
emit $ Store rt vs Ptr valPtr
emit $ SetVariable id (Load rt Ptr valPtr)
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
-- no test is emitted for this pattern; the label only opens a new block for
-- whatever code follows
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos
emitCases _rt ty label stackPtr _vs (Branch (MIR.PEnum _id, _) exp) = do
-- //TODO Penum wrong, acts as a catch all
emit $ Comment "Penum"
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos
emitCases _ ty label stackPtr _ (Branch (MIR.PCatch, _) exp) = do
emit $ Comment "Pcatch"
-- the catch-all stores its value and jumps to the escape label without any test
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos
-- | Emits a call for an application.
--
-- Nested applications are unwound until the head of the chain is reached,
-- collecting the arguments on the way so that they end up in application
-- order. The head has to be a variable: the arguments are evaluated and a
-- single call with the result type @rt@ is emitted. The callee is 'Global'
-- when its name is a known constructor or function, 'Local' otherwise. Any
-- other head aborts with 'error'.
emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitApp rt e1 e2 = unwind e1 [e2]
  where
    unwind :: ExpT -> [ExpT] -> CompilerState ()
    unwind callee pending = case callee of
        (MIR.EApp inner arg, _) -> unwind inner (arg : pending)
        (MIR.EVar name, t) -> do
            args <- traverse exprToValue pending
            vs <- getNewVar
            funcs <- gets functions
            consts <- gets constructors
            let asConstructor = Global <$ Map.lookup name consts
                asFunction = Global <$ Map.lookup (name, t) funcs
                visibility = fromMaybe Local (asConstructor <|> asFunction)
                typedArgs = first valueGetType . dupe <$> args
            emit $ Comment $ show rt
            emit $ SetVariable vs (Call FastCC (type2LlvmType rt) visibility name typedArgs)
        x -> error $ "The unspeakable happened: " <> show x
-- | Emits a bare variable reference, preceded by a warning comment and
-- followed by a newline.
-- !!this should never happen!!
emitIdent :: TIR.Ident -> CompilerState ()
emitIdent name =
    mapM_
        emit
        [ Comment "This should not have happened!"
        , Variable name
        , UnsafeRaw "\n"
        ]
-- | Emits a literal as a fresh variable, by adding 0 to the literal, after a
-- warning comment.
-- !!this should never happen!!
emitLit :: MIR.Lit -> CompilerState ()
emitLit lit = do
    target <- getNewVar
    emit $ Comment "This should not have happened!"
    emit $ SetVariable target (Add ty value (VInteger 0))
  where
    -- the literal as an LLVM value together with its LLVM type
    (value, ty) = case lit of
        MIR.LInt n -> (VInteger n, I64)
        MIR.LChar c -> (VChar (ord c), I8)
-- | Emits an addition: both operands are evaluated, left one first, and
-- their sum at the given type is assigned to a fresh variable.
emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitAdd t e1 e2 = do
    lhs <- exprToValue e1
    rhs <- exprToValue e2
    result <- getNewVar
    emit . SetVariable result $ Add (type2LlvmType t) lhs rhs
-- | Reduces a typed expression to an LLVM value, emitting whatever
-- instructions are needed to compute it.
--
-- * A literal becomes an immediate value; nothing is emitted.
-- * A variable naming a known function or constructor becomes a call when it
--   takes no arguments, and a global function value otherwise. Functions are
--   looked up before constructors. Any other variable becomes a plain
--   identifier.
-- * Every other expression is compiled with 'compileExp'; the value is the
--   variable numbered by the variable counter afterwards.
exprToValue :: ExpT -> CompilerState LLVMValue
exprToValue expr = case expr of
    (MIR.ELit lit, _) -> pure (litValue lit)
    (MIR.EVar name, t) -> do
        funcs <- gets functions
        cons <- gets constructors
        let llvmT = type2LlvmType t
            arity =
                (numArgs <$> Map.lookup (name, t) funcs)
                    <|> (numArgsCI <$> Map.lookup name cons)
        case arity of
            Just 0 -> do
                vc <- getNewVar
                emit $ SetVariable vc (Call FastCC llvmT Global name [])
                pure $ VIdent vc llvmT
            Just _ -> pure $ VFunction name Global llvmT
            Nothing -> pure $ VIdent name llvmT
    _ -> do
        compileExp expr
        v <- getVarCount
        pure $ VIdent (TIR.Ident $ show v) (getType expr)
  where
    litValue (MIR.LInt n) = VInteger n
    litValue (MIR.LChar c) = VChar $ ord c

View file

@ -31,7 +31,7 @@ bind x f = case x of {
-- represents minus one :)
minusOne : Int ;
minusOne = 9223372036854775807 + 9223372036854775807 + 1;
{-
---- LIST STUFF ----
-- a simple list data type containing ints
data List () where {
@ -69,4 +69,3 @@ repeat x n = case n of {
0 => Nil ;
n => Cons x (repeat x (n + minusOne)) ;
};
-}