Add closures and fix lets in monomorphizer

This commit is contained in:
Martin Fredin 2023-05-06 22:49:08 +02:00
parent 677a200a15
commit 72e599d5de
26 changed files with 1440 additions and 692 deletions

View file

@ -1,25 +1,25 @@
module Codegen.Auxillary where
import Codegen.LlvmIr (LLVMType (..), LLVMValue (..))
import Control.Monad (foldM_)
import Monomorphizer.MonomorphizerIr as MIR (ExpT, Type (..))
import TypeChecker.TypeCheckerIr qualified as TIR
import Codegen.LlvmIr (LLVMType (..), LLVMValue (..))
import Control.Monad (foldM_)
import Monomorphizer.MonomorphizerIr as MIR (Exp, T, Type (..))
import qualified TypeChecker.TypeCheckerIr as TIR
type2LlvmType :: MIR.Type -> LLVMType
type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of
"Int" -> I64
"Int" -> I64
"Char" -> I8
"Bool" -> I1
_ -> CustomType id
_ -> CustomType id
type2LlvmType (MIR.TFun t xs) = do
let (t', xs') = function2LLVMType xs [type2LlvmType t]
Function t' xs'
where
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s)
function2LLVMType x s = (type2LlvmType x, s)
getType :: ExpT -> LLVMType
getType :: T Exp -> LLVMType
getType (_, t) = type2LlvmType t
extractTypeName :: MIR.Type -> TIR.Ident
@ -30,21 +30,21 @@ extractTypeName (MIR.TFun t xs) =
in TIR.Ident $ i <> "_$_" <> is
valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64
valueGetType (VChar _) = I8
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
valueGetType (VInteger _) = I64
valueGetType (VChar _) = I8
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
valueGetType (VFunction _ _ t) = t
typeByteSize :: LLVMType -> Integer
typeByteSize I1 = 1
typeByteSize I8 = 1
typeByteSize I32 = 4
typeByteSize I64 = 8
typeByteSize Ptr = 8
typeByteSize (Ref _) = 8
typeByteSize I1 = 1
typeByteSize I8 = 1
typeByteSize I32 = 4
typeByteSize I64 = 8
typeByteSize Ptr = 8
typeByteSize (Ref _) = 8
typeByteSize (Function _ _) = 8
typeByteSize (Array n t) = n * typeByteSize t
typeByteSize (Array n t) = n * typeByteSize t
typeByteSize (CustomType _) = 8
enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m ()

View file

@ -1,18 +1,24 @@
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (generateCode) where
import Codegen.CompilerState (
CodeGenerator (instructions),
initCodeGenerator,
)
import Codegen.Emits (compileScs)
import Codegen.LlvmIr as LIR (llvmIrToString)
import Control.Monad.State (
execStateT,
)
import Data.List (sortBy)
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..), Def (DBind, DData), Program (..), Type (TLit))
import TypeChecker.TypeCheckerIr (Ident (..))
import Codegen.CompilerState (CodeGenerator (..),
StructType (inst),
initCodeGenerator)
import Codegen.Emits (compileScs)
import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw),
llvmIrToString)
import Control.Monad.State (execStateT)
import Data.Functor ((<&>))
import Data.List (sortBy)
import qualified Data.Map as Map
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..),
Def (DBind, DData),
Program (..),
Type (TLit))
import TypeChecker.TypeCheckerIr (Ident (..))
{- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to
@ -20,16 +26,43 @@ import TypeChecker.TypeCheckerIr (Ident (..))
-}
generateCode :: MIR.Program -> Bool -> Err String
generateCode (MIR.Program scs) addGc = do
let tree = filter (not . detectPrelude) (sortBy lowData scs)
let codegen = initCodeGenerator addGc tree
llvmIrToString . instructions <$> execStateT (compileScs tree) codegen
let tree = filter (not . detectPrelude) (sortBy lowData scs)
codegen = initCodeGenerator addGc tree
-- Append instructions
execStateT (compileScs tree) codegen <&> \state ->
llvmIrToString $ defaultStart
++ (if addGc then gcStart else [])
++ map inst (Map.elems state.structTypes)
++ state.instructions
detectPrelude :: Def -> Bool
detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True
detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True
detectPrelude (DBind (Bind (Ident ('l' : 't' : '$' : _), _) _ _)) = True
detectPrelude _ = False
detectPrelude _ = False
lowData :: Def -> Def -> Ordering
lowData (DData _) (DBind _) = LT
lowData (DBind _) (DData _) = GT
lowData _ _ = EQ
lowData _ _ = EQ
defaultStart :: [LLVMIr]
defaultStart =
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
]
gcStart :: [LLVMIr]
gcStart =
[ UnsafeRaw "declare external void @cheap_init()\n"
, UnsafeRaw "declare external ptr @cheap_alloc(i64)\n"
, UnsafeRaw "declare external void @cheap_dispose()\n"
, UnsafeRaw "declare external ptr @cheap_the()\n"
, UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n"
, UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n"
]

View file

@ -1,46 +1,101 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.CompilerState where
import Auxiliary (snoc)
import Codegen.Auxillary (type2LlvmType, typeByteSize)
import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw),
LLVMType)
import Control.Monad.State (StateT, gets, modify)
import Codegen.LlvmIr as LIR (LLVMIr (SetVariable, Type),
LLVMType (CustomType, Function, I64, Ptr),
LLVMValue (VFunction, VIdent),
Visibility (Global),
typeOf)
import Control.Monad.State (StateT, gets, modify, void)
import Data.Map (Map)
import qualified Data.Map as Map
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR
import Monomorphizer.MonomorphizerIr (Ident (..), Inj (..), T,
flattenType)
import qualified Monomorphizer.MonomorphizerIr as MIR
import qualified TypeChecker.TypeCheckerIr as TIR
-- | The record used as the code generator state
data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr]
, functions :: Map MIR.Id FunctionInfo
, functions :: Map (T Ident) FunctionInfo
, customTypes :: Map LLVMType Integer
, constructors :: Map TIR.Ident ConstructorInfo
, constructors :: Map Ident ConstructorInfo
, variableCount :: Integer
, labelCount :: Integer
, gcEnabled :: Bool
, structTypes :: Map Ident StructType
-- ^ Custom stucture types
, locals :: [(Ident, LocalElem)]
-- ^ Arguments and variables in local environment
, globals :: Map Ident (LLVMType, LLVMValue)
}
data StructType = StructType
{ ptr :: LLVMType
, typs :: [LLVMType]
, inst :: LLVMIr
}
data LocalElem = LocalElem
{ typ :: LLVMType
, val :: LLVMValue
}
-- | A state type synonym
type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo
{ numArgs :: Int
, arguments :: [Id]
, arguments :: [T Ident]
}
deriving (Show)
data ConstructorInfo = ConstructorInfo
{ numArgsCI :: Int
, argumentsCI :: [Id]
, argumentsCI :: [T Ident]
, numCI :: Integer
, returnTypeCI :: MIR.Type
}
deriving (Show)
addStructType_ :: Ident -> [LLVMType] -> CompilerState ()
addStructType_ = fmap void . addStructType
addStructType :: Ident -> [LLVMType] -> CompilerState LLVMType
addStructType x ts = do
modify $ \s -> s { structTypes = Map.insert x struct s.structTypes }
pure t
where
struct = StructType
{ ptr = t
, typs = ts
, inst = Type x ts
}
t = CustomType x
-- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState ()
emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
-- Add variable to environment
emit l@(SetVariable x _) = modify $ \t ->
t { instructions = Auxiliary.snoc l t.instructions
, locals = snoc (x, local)
t.locals
}
where
local = LocalElem { typ = typeOf l
, val = VIdent x (typeOf l)
}
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l t.instructions }
-- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState ()
@ -63,16 +118,19 @@ getNewLabel = do
{- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation.
-}
getFunctions :: [MIR.Def] -> Map Id FunctionInfo
getFunctions :: [MIR.Def] -> Map (T Ident) FunctionInfo
getFunctions bs = Map.fromList $ go bs
where
go [] = []
go (MIR.DBind (MIR.Bind id args _) : xs) =
(id, FunctionInfo{numArgs = length args, arguments = args})
: go xs
(id, FunctionInfo { numArgs = length args
, arguments = args
}
)
: go xs
go (_ : xs) = go xs
createArgs :: [MIR.Type] -> [Id]
createArgs :: [MIR.Type] -> [T Ident]
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
{- | Produces a map of functions infos from a list of binds,
@ -113,35 +171,43 @@ getTypes bs = Map.fromList $ go bs
variantTypes fi = init $ map type2LlvmType (flattenType fi)
biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
getGlobals :: [MIR.Def] -> Map Ident (LLVMType, LLVMValue)
getGlobals scs = Map.fromList [ go b | MIR.DBind b <- scs ]
where
go bind | x == "main" = let typ = Function I64 []
in (x, (typ, VFunction x Global typ))
| otherwise = (x, (typ, VFunction x Global typ))
where
typ = Function tr $ Ptr : ts
Function tr ts = type2LlvmType' t
(x, t) = case bind of
MIR.Bind xt _ _ -> xt
MIR.BindC _ xt _ _ -> xt
-- Higher order function arguments are replaced with ptr
type2LlvmType' = go []
where
go acc = \case
MIR.TFun (MIR.TFun _ _) t2 -> go (snoc Ptr acc) t2
MIR.TFun t1 t2 -> go (snoc (type2LlvmType t1) acc) t2
t -> Function (type2LlvmType t) acc
initCodeGenerator :: Bool -> [MIR.Def] -> CodeGenerator
initCodeGenerator addGc scs =
CodeGenerator
{ instructions = defaultStart <> if addGc then gcStart else []
{ instructions = []
, functions = getFunctions scs
, constructors = getConstructors scs
, customTypes = getTypes scs
, structTypes = mempty
, variableCount = 0
, labelCount = 0
, gcEnabled = addGc
, locals = mempty
, globals = getGlobals scs
}
defaultStart :: [LLVMIr]
defaultStart =
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
]
gcStart :: [LLVMIr]
gcStart =
[ UnsafeRaw "declare external void @cheap_init()\n"
, UnsafeRaw "declare external ptr @cheap_alloc(i64)\n"
, UnsafeRaw "declare external void @cheap_dispose()\n"
, UnsafeRaw "declare external ptr @cheap_the()\n"
, UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n"
, UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n"
]

View file

@ -1,36 +1,40 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.Emits where
import Codegen.Auxillary
import Codegen.CompilerState
import Codegen.LlvmIr as LIR
import Control.Applicative ((<|>))
import Control.Monad (when)
import Control.Monad.State (gets, modify)
import Data.Bifunctor qualified as BI
import Data.Char (ord)
import Data.Coerce (coerce)
import Data.Map qualified as Map
import Data.Maybe (fromJust, fromMaybe, isNothing)
import Data.Tuple.Extra (dupe, first, second)
import Debug.Trace (trace, traceShow)
import Grammar.Print
import Monomorphizer.MonomorphizerIr as MIR
import TypeChecker.TypeCheckerIr qualified as TIR
import Auxiliary (snoc)
import Codegen.Auxillary
import Codegen.CompilerState
import Codegen.LlvmIr as LIR
import Control.Applicative (Applicative (liftA2), (<|>))
import Control.Monad (forM_, when, zipWithM_)
import Control.Monad.Extra (whenJust)
import Control.Monad.State (gets, modify)
import Data.Char (ord)
import Data.Coerce (coerce)
import Data.Foldable.Extra (notNull)
import qualified Data.Map as Map
import Data.Maybe (fromJust, fromMaybe, isNothing)
import Data.Tuple.Extra (second)
import Grammar.Print (printTree)
import Monomorphizer.MonomorphizerIr
compileScs :: [MIR.Def] -> CompilerState ()
compileScs :: [Def] -> CompilerState ()
compileScs [] = do
emit $ UnsafeRaw "\n"
mapM_ createConstructor =<< gets (Map.toList . constructors)
-- as a last step create all the constructors
-- //TODO maybe merge this with the data type match?
c <- gets (Map.toList . constructors)
mapM_
( \(id, ci) -> do
let t = returnTypeCI ci
let t' = type2LlvmType t
let x = BI.second type2LlvmType <$> argumentsCI ci
where
createConstructor (id, ci) = do
let t = returnTypeCI ci
t' = type2LlvmType t
x = (mkCxtName, Ptr) : map (second type2LlvmType) ci.argumentsCI
emit $ Define FastCC t' id x
top <- getNewVar
ptr <- getNewVar
@ -56,7 +60,7 @@ compileScs [] = do
cTypes <- gets customTypes
enumerateOneM_
( \i (TIR.Ident arg_n, arg_t) -> do
( \i (Ident arg_n, arg_t) -> do
let arg_t' = type2LlvmType arg_t
emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
elemPtr <- getNewVar
@ -78,11 +82,11 @@ compileScs [] = do
heapPtr <- getNewVar
useGc <- gets gcEnabled
emit $ SetVariable heapPtr (if useGc then GcMalloc s else Malloc s)
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr heapPtr
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
Nothing -> do
emit $ Comment "Just store"
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
)
(argumentsCI ci)
@ -95,34 +99,83 @@ compileScs [] = do
emit $ UnsafeRaw "\n"
modify $ \s -> s{variableCount = 0}
)
c
compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
let t_return = type2LlvmType . last . flattenType $ t
compileScs (DBind bind : xs) = do
emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args
emit . Comment $ show name <> ": " <> show (fst exp)
Function t_return t_args <- gets $ fst
. fromJust
. Map.lookup name
. globals
let args' = zip (mkCxtName : map fst args) t_args
emit $ Define FastCC t_return name args'
useGc <- gets gcEnabled
when (name == "main") (mapM_ emit (firstMainContent useGc))
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit $ lastMainContent useGc functionBody
else emit $ Ret t_return functionBody
modify $ \s -> s { locals = foldr insertArg s.locals args' }
-- Dereference ptr arguments
when (notNull args') $
forM_ (tail args') $ \(x, t) -> when (t == Ptr) $ do
let t_deref =
let
Function t ts = type2LlvmType . fromJust $ lookup x args
in
Function t (Ptr : ts)
emit . SetVariable (mkDerefName x)
$ Load t_deref Ptr x
whenJust mcxt loadFreeVars
gcEnabled <- gets gcEnabled
when isMain $ mapM_ emit (firstMainContent gcEnabled)
result <- exprToValue exp
if isMain
then mapM_ emit $ lastMainContent gcEnabled result
else emit $ Ret t_return result
emit DefineEnd
modify $ \s -> s{variableCount = 0}
-- Reset variable count and empty locals
modify $ \s -> s { variableCount = 0, locals = mempty }
compileScs xs
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
let (TIR.Ident outer_id) = extractTypeName typ
where
loadFreeVars cxt = do
emit $ Comment "Load free variables"
zipWithM_ go cxt' [1 ..]
where
go (x, t) i = do
vc <- getNewVar
emit . SetVariable vc
$ GetElementPtrInbounds (CustomType $ mkClosureName name) Ptr (VIdent mkCxtName Ptr)
I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices
emit . SetVariable x $ Load t Ptr vc
cxt' = map (second type2LlvmType) cxt
isMain = name == "main"
(name, args, exp, mcxt) = case bind of
Bind (name, _) args exp -> (name, args, exp, Nothing)
BindC cxt (name, _) args exp -> (name, args, exp, Just cxt)
insertArg (x, t) = snoc (x, LocalElem { val = VIdent x t, typ = t })
compileScs (DData (Data typ ts) : xs) = do
let (Ident outer_id) = extractTypeName typ
-- //TODO this could be extracted from the customTypes map
let variantTypes fi = init $ map type2LlvmType (flattenType fi)
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8]
-- Add data type (e.g. %List) to top of the file
addStructType_ (Ident outer_id) [I8, Array biggestVariant I8]
typeSets <- gets customTypes
mapM_
( \(Inj inner_id fi) -> do
let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
emit $ LIR.Type inner_id (I8 : types)
-- Add constructor type (e.g. %Cons) to top of the file
addStructType_ inner_id (I8 : types)
)
ts
compileScs xs
@ -149,16 +202,16 @@ lastMainContent False var =
, Ret I64 (VInteger 0)
]
compileExp :: ExpT -> CompilerState ()
compileExp (MIR.ELit lit, _t) = emitLit lit
compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2
compileExp (MIR.EVar name, _t) = emitIdent name
compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2
compileExp (MIR.ELet bind e, _) = emitLet bind e
compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs)
compileExp :: T Exp -> CompilerState ()
compileExp (ELit lit, _t) = emitLit lit
compileExp (EAdd e1 e2, t) = emitAdd t e1 e2
compileExp (EVar name, _t) = emitIdent name
compileExp (EApp e1 e2, t) = emitApp t e1 e2
compileExp (ELet bind e, _) = emitLet bind e
compileExp (ECase e cs, t) = emitECased t e (map (t,) cs)
emitLet :: MIR.Bind -> ExpT -> CompilerState ()
emitLet (MIR.Bind id [] innerExp) e = do
emitLet :: Bind -> T Exp -> CompilerState ()
emitLet (Bind id [] innerExp) e = do
evaled <- exprToValue innerExp
tempVar <- getNewVar
let t = type2LlvmType . snd $ innerExp
@ -168,14 +221,14 @@ emitLet (MIR.Bind id [] innerExp) e = do
compileExp e
emitLet b _ = error $ "Non empty argument list in let-bind " <> show b
emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Branch)] -> CompilerState ()
emitECased :: Type -> T Exp -> [(Type, Branch)] -> CompilerState ()
emitECased t e cases = do
let cs = snd <$> cases
let ty = type2LlvmType t
let rt = type2LlvmType (snd e)
vs <- exprToValue e
lbl <- getNewLabel
let label = TIR.Ident $ "escape_" <> show lbl
let label = Ident $ "escape_" <> show lbl
stackPtr <- getNewVar
emit $ SetVariable stackPtr (Alloca ty)
mapM_ (emitCases rt ty label stackPtr vs) cs
@ -192,14 +245,14 @@ emitECased t e cases = do
res <- getNewVar
emit $ SetVariable res (Load ty Ptr stackPtr)
where
emitCases :: LLVMType -> LLVMType -> TIR.Ident -> TIR.Ident -> LLVMValue -> Branch -> CompilerState ()
emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, _t) exp) = do
emitCases :: LLVMType -> LLVMType -> Ident -> Ident -> LLVMValue -> Branch -> CompilerState ()
emitCases rt ty label stackPtr vs (Branch (PInj consId cs, _t) exp) = do
emit $ Comment "Inj"
cons <- gets constructors
let r = fromJust $ Map.lookup consId cons
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0)
@ -215,10 +268,10 @@ emitECased t e cases = do
emit $ Store rt vs Ptr castPtr
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
enumerateOneM_
( \i c -> do
( \i (c, t) -> do
case c of
PVar (x, topT) -> do
let topT' = type2LlvmType topT
PVar x -> do
let topT' = type2LlvmType t
let botT' = CustomType (coerce consId)
emit . Comment $ "ident " <> toIr topT'
cTypes <- gets customTypes
@ -228,7 +281,7 @@ emitECased t e cases = do
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
emit $ SetVariable x (Load topT' Ptr deref)
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
PLit (_l, _t) -> error "Nested pattern matching to be implemented"
PLit _l -> error "Nested pattern matching to be implemented"
PInj _id _ps -> error "Nested pattern matching to be implemented"
PCatch -> pure ()
PEnum _id -> error "Nested pattern matching to be implemented"
@ -238,22 +291,22 @@ emitECased t e cases = do
emit $ Store ty val Ptr stackPtr
emit $ Br label
emit $ Label lbl_failPos
emitCases _rt ty label stackPtr vs (Branch (MIR.PLit (i, ct), t) exp) = do
emitCases _rt ty label stackPtr vs (Branch (PLit i, t) exp) = do
emit $ Comment "Plit"
let i' = case i of
MIR.LInt i -> VInteger i
MIR.LChar i -> VChar (ord i)
LInt i -> VInteger i
LChar i -> VChar (ord i)
ns <- getNewVar
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable ns (Icmp LLEq (type2LlvmType ct) vs i')
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable ns (Icmp LLEq (type2LlvmType t) vs i')
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do
emitCases rt ty label stackPtr vs (Branch (PVar id, _) exp) = do
emit $ Comment "Pvar"
-- //TODO this is pretty disgusting and would heavily benefit from a rewrite
valPtr <- getNewVar
@ -263,20 +316,20 @@ emitECased t e cases = do
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "True$Bool"), t) exp) = do
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 1, TLit "Bool"), t) exp)
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "False$Bool"), _) exp) = do
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 0, TLit "Bool"), t) exp)
emitCases rt ty label stackPtr vs br@(Branch (MIR.PEnum consId, _) exp) = do
emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "True$Bool"), t) exp) = do
emitCases rt ty label stackPtr vs (Branch (PLit $ LInt 1, t) exp)
emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "False$Bool"), _) exp) = do
emitCases rt ty label stackPtr vs (Branch (PLit (LInt 0), t) exp)
emitCases rt ty label stackPtr vs br@(Branch (PEnum consId, _) exp) = do
emit $ Comment "Penum"
cons <- gets constructors
let r = Map.lookup consId cons
when (isNothing r) (error $ "Constructor: '" ++ printTree consId ++ "' does not exist in cons state:\n" ++ show cons ++ "\nin pattern\n'" ++ printTree br ++ "'\n")
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0)
@ -295,98 +348,167 @@ emitECased t e cases = do
emit $ Store ty val Ptr stackPtr
emit $ Br label
emit $ Label lbl_failPos
emitCases _ ty label stackPtr _ (Branch (MIR.PCatch, _) exp) = do
emitCases _ ty label stackPtr _ (Branch (PCatch, _) exp) = do
emit $ Comment "Pcatch"
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos
emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitApp rt e1 e2 = appEmitter e1 e2 []
where
appEmitter :: ExpT -> ExpT -> [ExpT] -> CompilerState ()
appEmitter e1 e2 stack = do
let newStack = e2 : stack
case e1 of
(MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack
(MIR.EVar name, t) -> do
args <- traverse exprToValue newStack
vs <- getNewVar
funcs <- gets functions
consts <- gets constructors
let visibility =
fromMaybe Local $
Global <$ Map.lookup name consts
<|> Global <$ Map.lookup (name, t) funcs
-- this piece of code could probably be improved, i.e remove the double `const Global`
args' = map (first valueGetType . dupe) args
let call =
case name of
TIR.Ident ('l' : 't' : '$' : _) -> Icmp LLSlt I64 (snd (head args')) (snd (args' !! 1))
TIR.Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) -> Sub I64 (snd (head args')) (snd (args' !! 1))
_ -> Call FastCC (type2LlvmType rt) visibility name args'
emit $ Comment $ show rt
emit $ SetVariable vs call
x -> error $ "The unspeakable happened: " <> show x
emitApp :: Type -> T Exp -> T Exp -> CompilerState ()
emitApp rt e1 e2 = do
((EVar name, t), args) <- go (EApp e1 e2, rt)
vs <- getNewVar
funcs <- gets functions
consts <- gets constructors
let visibility =
fromMaybe Local $
Global <$ Map.lookup name consts
<|> Global <$ Map.lookup (name, t) funcs
-- this piece of code could probably be improved, i.e remove the double `const Global`
emitIdent :: TIR.Ident -> CompilerState ()
call <- case name of
Ident ('l' : 't' : '$' : _) ->
pure $ Icmp LLSlt I64 (snd (head args)) (snd (args !! 1))
Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) ->
pure $ Sub I64 (snd (head args)) (snd (args !! 1))
-- FIXME
_ -> do
let closure_call LocalElem { typ = Ptr, val } = (mkDerefName name, (Ptr, val) : args)
(name, args) <- gets $ maybe (name, (Ptr, VNull) : args) closure_call
. lookup name
. locals
pure $ Call FastCC (type2LlvmType rt) visibility name args
emit $ Comment $ show (type2LlvmType rt)
emit $ SetVariable vs call
where
go :: T Exp -> CompilerState (T Exp, [(LLVMType, LLVMValue)])
go et@(e, _) = case e of
EApp e1 e2@(_, t) -> do
(x, as) <- go e1
a <- exprToValue e2
let t' = type2LlvmType' t
pure (x, snoc (t', a) as)
_ -> pure (et, [])
type2LlvmType' = \case
TFun _ _ -> Ptr
t -> type2LlvmType t
emitIdent :: Ident -> CompilerState ()
emitIdent id = do
-- !!this should never happen!!
emit $ Comment "This should not have happened!"
emit $ Variable id
emit $ UnsafeRaw "\n"
emitLit :: MIR.Lit -> CompilerState ()
emitLit :: Lit -> CompilerState ()
emitLit i = do
-- !!this should never happen!!
let (i', t) = case i of
(MIR.LInt i'') -> (VInteger i'', I64)
(MIR.LChar i'') -> (VChar $ ord i'', I8)
(LInt i'') -> (VInteger i'', I64)
(LChar i'') -> (VChar $ ord i'', I8)
varCount <- getNewVar
emit $ Comment "This should not have happened!"
emit $ SetVariable varCount (Add t i' (VInteger 0))
emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitAdd :: Type -> T Exp -> T Exp -> CompilerState ()
emitAdd t e1 e2 = do
v1 <- exprToValue e1
v2 <- exprToValue e2
v <- getNewVar
emit $ SetVariable v (Add (type2LlvmType t) v1 v2)
exprToValue :: ExpT -> CompilerState LLVMValue
exprToValue = \case
(MIR.ELit i, _t) -> pure $ case i of
(MIR.LInt i) -> VInteger i
(MIR.LChar i) -> VChar $ ord i
(MIR.EVar (TIR.Ident "True$Bool"), _t) -> pure $ VInteger 1
(MIR.EVar (TIR.Ident "False$Bool"), _t) -> pure $ VInteger 0
(MIR.EVar name, t) -> do
funcs <- gets functions
cons <- gets constructors
let res =
Map.lookup (name, t) funcs
<|> ( \c ->
FunctionInfo
{ numArgs = numArgsCI c
, arguments = argumentsCI c
}
)
<$> Map.lookup name cons
case res of
Just fi -> do
if numArgs fi == 0
then do
vc <- getNewVar
emit $
SetVariable
vc
(Call FastCC (type2LlvmType t) Global name [])
pure $ VIdent vc (type2LlvmType t)
else pure $ VFunction name Global (type2LlvmType t)
Nothing -> pure $ VIdent name (type2LlvmType t)
e -> do
compileExp e
exprToValue :: T Exp -> CompilerState LLVMValue
exprToValue et@(e, t) = case e of
ELit (LInt i) -> pure $ VInteger i
ELit (LChar c) -> pure . VChar $ ord c
EVar "True$Bool" -> pure $ VInteger 1
EVar "False$Bool" -> pure $ VInteger 0
EVar name -> gets (Map.lookup name . globals) >>= \case
Just (typ@(Function _ ts), val) | length ts > 1 -> do
type_struct <- addStructType (mkClosureName name) [typ]
emit $ Comment "Allocating structure"
emit . SetVariable name $ Alloca type_struct
emit $ Store typ val Ptr name
pure $ VIdent name Ptr
Just _ | name == "main" -> do
vc <- getNewVar
emit $ SetVariable vc (Call FastCC I64 Global name [])
pure $ VIdent vc I64
Just (Function t_return [_], _) -> do
vc <- getNewVar
emit $ SetVariable vc (Call FastCC t_return Global name [(Ptr, VNull)])
pure $ VIdent vc t_return
Just _ -> error "Bad"
Nothing -> gets (Map.lookup name . constructors) >>= \case
Just ConstructorInfo {numArgsCI}
| numArgsCI == 0 -> do
vc <- getNewVar
emit $ SetVariable vc call
pure $ VIdent vc (type2LlvmType t)
| otherwise -> pure $ VFunction name Global (type2LlvmType t)
where
call = Call FastCC (type2LlvmType t) Global name []
Nothing -> gets $ val
. fromJust
. lookup name
. locals
EVarC cxt name -> do
let cxt' = flip map cxt $ \(x, t) -> let t' = type2LlvmType t
in (t', VIdent x t')
cxt'' <- gets $ (:cxt')
. fromJust
. Map.lookup name
. globals
-- Create a new type for function pointer and arguments
type_struct <- addStructType (mkClosureName name) $ map fst cxt''
emit $ Comment "Allocating structure"
emit . SetVariable name $ Alloca type_struct
let ptr_struct = VIdent name Ptr
storeArg (t, v) i = do
vc <- getNewVar
emit . SetVariable vc
$ GetElementPtrInbounds type_struct Ptr ptr_struct
I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices
emit $ Store t v Ptr vc
-- Store arguments in structure
zipWithM_ storeArg cxt'' [0 ..]
pure ptr_struct
_ -> do
compileExp et
v <- getVarCount
pure $ VIdent (TIR.Ident $ show v) (getType e)
pure $ VIdent (Ident $ show v) (getType et)
mkClosureName :: Ident -> Ident
mkClosureName (Ident s) = Ident $ "Closure_" ++ s
mkDerefName :: Ident -> Ident
mkDerefName (Ident s) = Ident $ s ++ "_deref"
mkCxtName :: Ident
mkCxtName = Ident "cxt"

View file

@ -9,17 +9,18 @@ module Codegen.LlvmIr (
Visibility (..),
CallingConvention (..),
ToIr (..),
typeOf
) where
import Data.List (intercalate)
import TypeChecker.TypeCheckerIr (Ident (..))
import Data.List (intercalate)
import TypeChecker.TypeCheckerIr (Ident (..))
data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving (Show, Eq, Ord)
instance ToIr CallingConvention where
toIr :: CallingConvention -> String
toIr TailCC = "tailcc"
toIr FastCC = "fastcc"
toIr CCC = "ccc"
toIr CCC = "ccc"
toIr ColdCC = "coldcc"
-- | A datatype which represents some basic LLVM types
@ -38,6 +39,9 @@ data LLVMType
class ToIr a where
toIr :: a -> String
instance ToIr a => ToIr [a] where
toIr = concatMap toIr
instance ToIr LLVMType where
toIr :: LLVMType -> String
toIr = \case
@ -66,8 +70,8 @@ data LLVMComp
instance ToIr LLVMComp where
toIr :: LLVMComp -> String
toIr = \case
LLEq -> "eq"
LLNe -> "ne"
LLEq -> "eq"
LLNe -> "ne"
LLUgt -> "ugt"
LLUge -> "uge"
LLUlt -> "ult"
@ -80,7 +84,7 @@ instance ToIr LLVMComp where
data Visibility = Local | Global deriving (Show, Eq, Ord)
instance ToIr Visibility where
toIr :: Visibility -> String
toIr Local = "%"
toIr Local = "%"
toIr Global = "@"
{- | Represents a LLVM "value", as in an integer, a register variable,
@ -92,16 +96,18 @@ data LLVMValue
| VIdent Ident LLVMType
| VConstant String
| VFunction Ident Visibility LLVMType
| VNull
deriving (Show, Eq, Ord)
instance ToIr LLVMValue where
toIr :: LLVMValue -> String
toIr v = case v of
VInteger i -> show i
VChar i -> show i
VIdent (Ident n) _ -> "%" <> n
VInteger i -> show i
VChar i -> show i
VIdent (Ident n) _ -> "%" <> n
VFunction (Ident n) vis _ -> toIr vis <> n
VConstant s -> "c" <> show s
VConstant s -> "c" <> show s
VNull -> "null"
type Params = [(Ident, LLVMType)]
type Args = [(LLVMType, LLVMValue)]
@ -139,6 +145,21 @@ data LLVMIr
-- instructions should be used in its place
deriving (Show, Eq, Ord)
-- TODO add missing clauses
typeOf :: LLVMIr -> LLVMType
typeOf = \case
Add t _ _ -> t
Sub t _ _ -> t
Mul t _ _ -> t
Div t _ _ -> t
Load t _ _ -> t
Store t _ _ _ -> t
Type x _ -> CustomType x
SetVariable _ ir -> typeOf ir
-- | Converts a list of LLVMIr instructions to a string
llvmIrToString :: [LLVMIr] -> String
llvmIrToString = go 0
@ -147,9 +168,9 @@ llvmIrToString = go 0
go _ [] = mempty
go i (x : xs) = do
let (i', n) = case x of
Define{} -> (i + 1, 0)
Define{} -> (i + 1, 0)
DefineEnd -> (i - 1, 0)
_ -> (i, i)
_ -> (i, i)
insToString n x <> go i' xs
-- \| Converts a LLVM inststruction to a String, allowing for printing etc.
@ -224,10 +245,10 @@ llvmIrToString = go 0
, ")\n"
]
(Alloca t) -> unwords ["alloca", toIr t, "\n"]
(Malloc t) ->
(Malloc t) ->
concat
[ "call ptr @malloc(i64 ", show t, ")\n"]
(GcMalloc t) ->
(GcMalloc t) ->
concat
[ "call ptr @cheap_alloc(i64 ", show t, ")\n"]
(Store t1 val t2 (Ident id2)) ->