Rearrange code

This commit is contained in:
Martin Fredin 2023-02-18 14:36:59 +01:00
parent 3efb27ac0c
commit 4ab6681f68
2 changed files with 262 additions and 280 deletions

View file

@ -3,12 +3,12 @@
module Compiler (compile) where module Compiler (compile) where
import Auxiliary (snoc)
import Control.Monad.State (StateT, execStateT, gets, modify) import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Tuple.Extra (second) import Data.Tuple.Extra (dupe, first, second)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Print (printTree)
import LlvmIr (LLVMIr (..), LLVMType (..), import LlvmIr (LLVMIr (..), LLVMType (..),
LLVMValue (..), Visibility (..), LLVMValue (..), Visibility (..),
llvmIrToString) llvmIrToString)
@ -32,11 +32,11 @@ data FunctionInfo = FunctionInfo
-- | Adds a instruction to the CodeGenerator state -- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState () emit :: LLVMIr -> CompilerState ()
emit l = modify (\t -> t{instructions = instructions t ++ [l]}) emit l = modify $ \t -> t { instructions = snoc l $ instructions t }
-- | Increases the variable counter in the CodeGenerator state -- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState () increaseVarCount :: CompilerState ()
increaseVarCount = modify (\t -> t{variableCount = variableCount t + 1}) increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
-- | Returns the variable count from the CodeGenerator state -- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer getVarCount :: CompilerState Integer
@ -46,212 +46,198 @@ getVarCount = gets variableCount
getNewVar :: CompilerState Integer getNewVar :: CompilerState Integer
getNewVar = increaseVarCount >> getVarCount getNewVar = increaseVarCount >> getVarCount
{- | Produces a map of functions infos from a list of binds, -- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation. -- which contains useful data for code generation.
-}
getFunctions :: [Bind] -> Map Id FunctionInfo getFunctions :: [Bind] -> Map Id FunctionInfo
getFunctions xs = getFunctions bs = Map.fromList $ map go bs
Map.fromList $
map
( \(Bind id args _) ->
( id
, FunctionInfo
{ numArgs = length args
, arguments = args
}
)
)
xs
{- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to
Simply pipe it to LLI
-}
compile :: Program -> Err String
compile (Program prg) = do
let s =
CodeGenerator
{ instructions = defaultStart
, functions = getFunctions prg
, variableCount = 0
}
ins <- instructions <$> execStateT (goDef prg) s
pure $ llvmIrToString ins
where where
mainContent :: LLVMValue -> [LLVMIr] go (Bind id args _) =
mainContent var = (id, FunctionInfo { numArgs=length args, arguments=args })
[ UnsafeRaw $
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- , Label (Ident "b_1")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Br (Ident "end")
-- , Label (Ident "b_2")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Br (Ident "end")
-- , Label (Ident "end")
Ret I64 (VInteger 0)
]
defaultStart :: [LLVMIr]
defaultStart =
[ Comment (show $ printTree (Program prg))
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
]
goDef :: [Bind] -> CompilerState ()
goDef [] = return ()
goDef (Bind (name, t) args exp : xs) = do
emit $ UnsafeRaw "\n"
emit $ Comment $ show name <> ": " <> show exp
emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args)
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit (mainContent functionBody)
else emit $ Ret I64 functionBody
emit DefineEnd
modify (\s -> s{variableCount = 0})
goDef xs
where
t_return = snd $ partitionType (length args) t
go :: Exp -> CompilerState () initCodeGenerator :: [Bind] -> CodeGenerator
go (EInt int) = emitInt int initCodeGenerator scs = CodeGenerator { instructions = defaultStart
go (EAdd t e1 e2) = emitAdd t e1 e2 , functions = getFunctions scs
go (EId (name, _)) = emitIdent name , variableCount = 0
go (EApp t e1 e2) = emitApp t e1 e2 }
go (EAbs t ti e) = emitAbs t ti e
go (ELet bind e) = emitLet bind e
-- go (ESub e1 e2) = emitSub e1 e2
-- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2
--- aux functions --- -- | Compiles an AST and produces a LLVM Ir string.
emitAbs :: Type -> Id -> Exp -> CompilerState () -- An easy way to actually "compile" this output is to
emitAbs _t tid e = do -- Simply pipe it to lli
emit . Comment $ compile :: Program -> Err String
"Lambda escaped previous stages: \\" <> show tid <> " . " <> show e compile (Program scs) = do
emitLet :: Bind -> Exp -> CompilerState () let codegen = initCodeGenerator scs
emitLet b e = do llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
emit $
Comment $
concat
[ "ELet ("
, show b
, " = "
, show e
, ") is not implemented!"
]
emitApp :: Type -> Exp -> Exp -> CompilerState () compileScs :: [Bind] -> CompilerState ()
emitApp t e1 e2 = appEmitter t e1 e2 [] compileScs [] = pure ()
where compileScs (Bind (name, t) args exp : xs) = do
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () emit $ UnsafeRaw "\n"
appEmitter t e1 e2 stack = do emit . Comment $ show name <> ": " <> show exp
let newStack = e2 : stack let args' = map (second type2LlvmType) args
case e1 of emit $ Define (type2LlvmType t_return) name args'
EApp _ e1' e2' -> appEmitter t e1' e2' newStack functionBody <- exprToValue exp
EId id@(name, _) -> do if name == "main"
args <- traverse exprToValue newStack then mapM_ emit $ mainContent functionBody
vs <- getNewVar else emit $ Ret I64 functionBody
funcs <- gets functions emit DefineEnd
let vis = case Map.lookup id funcs of modify $ \s -> s { variableCount = 0 }
Nothing -> Local compileScs xs
Just _ -> Global where
let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args) t_return = snd $ partitionType (length args) t
emit $ SetVariable (Ident $ show vs) call
x -> do
emit . Comment $ "The unspeakable happened: "
emit . Comment $ show x
emitIdent :: Ident -> CompilerState () mainContent :: LLVMValue -> [LLVMIr]
emitIdent id = do mainContent var =
-- !!this should never happen!! [ UnsafeRaw $
emit $ Comment "This should not have happened!" "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
emit $ Variable id , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
emit $ UnsafeRaw "\n" -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2")
-- , Label (Ident "b_1")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n"
-- , Br (Ident "end")
-- , Label (Ident "b_2")
-- , UnsafeRaw
-- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n"
-- , Br (Ident "end")
-- , Label (Ident "end")
Ret I64 (VInteger 0)
]
emitInt :: Integer -> CompilerState () defaultStart :: [LLVMIr]
emitInt i = do defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
-- !!this should never happen!! , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
varCount <- getNewVar ]
emit $ Comment "This should not have happened!"
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
emitAdd :: Type -> Exp -> Exp -> CompilerState () compileExp :: Exp -> CompilerState ()
emitAdd t e1 e2 = do compileExp = \case
v1 <- exprToValue e1 EInt i -> emitInt i
v2 <- exprToValue e2 EAdd t e1 e2 -> emitAdd t e1 e2
v <- getNewVar EId (name, _) -> emitIdent name
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) EApp t e1 e2 -> emitApp t e1 e2
EAbs t ti e -> emitAbs t ti e
ELet bind e -> emitLet bind e
-- emitMul :: Exp -> Exp -> CompilerState () --- aux functions ---
-- emitMul e1 e2 = do emitAbs :: Type -> Id -> Exp -> CompilerState ()
-- (v1,v2) <- binExprToValues e1 e2 emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Mul I64 v1 v2
-- emitMod :: Exp -> Exp -> CompilerState () emitLet :: Bind -> Exp -> CompilerState ()
-- emitMod e1 e2 = do emitLet b e = emit . Comment $ concat [ "ELet ("
-- -- `let m a b = rem (abs $ b + a) b` , show b
-- (v1,v2) <- binExprToValues e1 e2 , " = "
-- increaseVarCount , show e
-- vadd <- gets variableCount , ") is not implemented!"
-- emit $ SetVariable $ Ident $ show vadd ]
-- emit $ Add I64 v1 v2
--
-- increaseVarCount
-- vabs <- gets variableCount
-- emit $ SetVariable $ Ident $ show vabs
-- emit $ Call I64 (Ident "llvm.abs.i64")
-- [ (I64, VIdent (Ident $ show vadd))
-- , (I1, VInteger 1)
-- ]
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
-- emitDiv :: Exp -> Exp -> CompilerState () emitApp :: Type -> Exp -> Exp -> CompilerState ()
-- emitDiv e1 e2 = do emitApp t e1 e2 = appEmitter t e1 e2 []
-- (v1,v2) <- binExprToValues e1 e2 where
-- increaseVarCount appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
-- v <- gets variableCount appEmitter t e1 e2 stack = do
-- emit $ SetVariable $ Ident $ show v let newStack = e2 : stack
-- emit $ Div I64 v1 v2 case e1 of
EApp _ e1' e2' -> appEmitter t e1' e2' newStack
EId id@(name, _) -> do
args <- traverse exprToValue newStack
vs <- getNewVar
funcs <- gets functions
let visibility = maybe Local (const Global) $ Map.lookup id funcs
args' = map (first valueGetType . dupe) args
call = Call (type2LlvmType t) visibility name args'
emit $ SetVariable (Ident $ show vs) call
x -> do
emit . Comment $ "The unspeakable happened: "
emit . Comment $ show x
-- emitSub :: Exp -> Exp -> CompilerState () emitIdent :: Ident -> CompilerState ()
-- emitSub e1 e2 = do emitIdent id = do
-- (v1,v2) <- binExprToValues e1 e2 -- !!this should never happen!!
-- increaseVarCount emit $ Comment "This should not have happened!"
-- v <- gets variableCount emit $ Variable id
-- emit $ SetVariable $ Ident $ show v emit $ UnsafeRaw "\n"
-- emit $ Sub I64 v1 v2
exprToValue :: Exp -> CompilerState LLVMValue emitInt :: Integer -> CompilerState ()
exprToValue (EInt i) = return $ VInteger i emitInt i = do
exprToValue (EId id@(name, t)) = do -- !!this should never happen!!
varCount <- getNewVar
emit $ Comment "This should not have happened!"
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
emitAdd :: Type -> Exp -> Exp -> CompilerState ()
emitAdd t e1 e2 = do
v1 <- exprToValue e1
v2 <- exprToValue e2
v <- getNewVar
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
-- emitMul :: Exp -> Exp -> CompilerState ()
-- emitMul e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Mul I64 v1 v2
-- emitMod :: Exp -> Exp -> CompilerState ()
-- emitMod e1 e2 = do
-- -- `let m a b = rem (abs $ b + a) b`
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- vadd <- gets variableCount
-- emit $ SetVariable $ Ident $ show vadd
-- emit $ Add I64 v1 v2
--
-- increaseVarCount
-- vabs <- gets variableCount
-- emit $ SetVariable $ Ident $ show vabs
-- emit $ Call I64 (Ident "llvm.abs.i64")
-- [ (I64, VIdent (Ident $ show vadd))
-- , (I1, VInteger 1)
-- ]
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
-- emitDiv :: Exp -> Exp -> CompilerState ()
-- emitDiv e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Div I64 v1 v2
-- emitSub :: Exp -> Exp -> CompilerState ()
-- emitSub e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount
-- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v
-- emit $ Sub I64 v1 v2
exprToValue :: Exp -> CompilerState LLVMValue
exprToValue = \case
EInt i -> pure $ VInteger i
EId id@(name, t) -> do
funcs <- gets functions funcs <- gets functions
case Map.lookup id funcs of case Map.lookup id funcs of
Just fi -> do Just fi -> do
if numArgs fi == 0 if numArgs fi == 0
then do then do
vc <- getNewVar vc <- getNewVar
emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) Global name []) emit $ SetVariable (Ident $ show vc)
return $ VIdent (Ident $ show vc) (type2LlvmType t) (Call (type2LlvmType t) Global name [])
else return $ VFunction name Global (type2LlvmType t) pure $ VIdent (Ident $ show vc) (type2LlvmType t)
Nothing -> return $ VIdent name (type2LlvmType t) else pure $ VFunction name Global (type2LlvmType t)
exprToValue e = do Nothing -> pure $ VIdent name (type2LlvmType t)
go e
e -> do
compileExp e
v <- getVarCount v <- getVarCount
return $ VIdent (Ident $ show v) (getType e) pure $ VIdent (Ident $ show v) (getType e)
type2LlvmType :: Type -> LLVMType type2LlvmType :: Type -> LLVMType
type2LlvmType = \case type2LlvmType = \case

View file

@ -9,8 +9,8 @@ module LlvmIr (
Visibility (..), Visibility (..),
) where ) where
import Data.List (intercalate) import Data.List (intercalate)
import TypeCheckerIr import TypeCheckerIr
-- | A datatype which represents some basic LLVM types -- | A datatype which represents some basic LLVM types
data LLVMType data LLVMType
@ -51,8 +51,8 @@ data LLVMComp
instance Show LLVMComp where instance Show LLVMComp where
show :: LLVMComp -> String show :: LLVMComp -> String
show = \case show = \case
LLEq -> "eq" LLEq -> "eq"
LLNe -> "ne" LLNe -> "ne"
LLUgt -> "ugt" LLUgt -> "ugt"
LLUge -> "uge" LLUge -> "uge"
LLUlt -> "ult" LLUlt -> "ult"
@ -65,12 +65,11 @@ instance Show LLVMComp where
data Visibility = Local | Global data Visibility = Local | Global
instance Show Visibility where instance Show Visibility where
show :: Visibility -> String show :: Visibility -> String
show Local = "%" show Local = "%"
show Global = "@" show Global = "@"
{- | Represents a LLVM "value", as in an integer, a register variable, -- | Represents a LLVM "value", as in an integer, a register variable,
or a string contstant -- or a string contstant
-}
data LLVMValue data LLVMValue
= VInteger Integer = VInteger Integer
| VIdent Ident LLVMType | VIdent Ident LLVMType
@ -80,10 +79,10 @@ data LLVMValue
instance Show LLVMValue where instance Show LLVMValue where
show :: LLVMValue -> String show :: LLVMValue -> String
show v = case v of show v = case v of
VInteger i -> show i VInteger i -> show i
VIdent (Ident n) _ -> "%" <> n VIdent (Ident n) _ -> "%" <> n
VFunction (Ident n) vis _ -> show vis <> n VFunction (Ident n) vis _ -> show vis <> n
VConstant s -> "c" <> show s VConstant s -> "c" <> show s
type Params = [(Ident, LLVMType)] type Params = [(Ident, LLVMType)]
type Args = [(LLVMType, LLVMValue)] type Args = [(LLVMType, LLVMValue)]
@ -122,87 +121,84 @@ llvmIrToString = go 0
go _ [] = mempty go _ [] = mempty
go i (x : xs) = do go i (x : xs) = do
let (i', n) = case x of let (i', n) = case x of
Define{} -> (i + 1, 0) Define{} -> (i + 1, 0)
DefineEnd -> (i - 1, 0) DefineEnd -> (i - 1, 0)
_ -> (i, i) _ -> (i, i)
insToString n x <> go i' xs insToString n x <> go i' xs
{- | Converts a LLVM inststruction to a String, allowing for printing etc. -- | Converts a LLVM inststruction to a String, allowing for printing etc.
The integer represents the indentation -- The integer represents the indentation
-} insToString :: Int -> LLVMIr -> String
{- FOURMOLU_DISABLE -} insToString i l =
insToString :: Int -> LLVMIr -> String replicate i '\t' <> case l of
insToString i l = (Define t (Ident i) params) ->
replicate i '\t' <> case l of concat
(Define t (Ident i) params) -> [ "define ", show t, " @", i
concat , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params)
[ "define ", show t, " @", i , ") {\n"
, "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) ]
, ") {\n" DefineEnd -> "}\n"
] (Declare _t (Ident _i) _params) -> undefined
DefineEnd -> "}\n" (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir]
(Declare _t (Ident _i) _params) -> undefined (Add t v1 v2) ->
(SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] concat
(Add t v1 v2) -> [ "add ", show t, " ", show v1
concat , ", ", show v2, "\n"
[ "add ", show t, " ", show v1 ]
, ", ", show v2, "\n" (Sub t v1 v2) ->
] concat
(Sub t v1 v2) -> [ "sub ", show t, " ", show v1, ", "
concat , show v2, "\n"
[ "sub ", show t, " ", show v1, ", " ]
, show v2, "\n" (Div t v1 v2) ->
] concat
(Div t v1 v2) -> [ "sdiv ", show t, " ", show v1, ", "
concat , show v2, "\n"
[ "sdiv ", show t, " ", show v1, ", " ]
, show v2, "\n" (Mul t v1 v2) ->
] concat
(Mul t v1 v2) -> [ "mul ", show t, " ", show v1
concat , ", ", show v2, "\n"
[ "mul ", show t, " ", show v1 ]
, ", ", show v2, "\n" (Srem t v1 v2) ->
] concat
(Srem t v1 v2) -> [ "srem ", show t, " ", show v1, ", "
concat , show v2, "\n"
[ "srem ", show t, " ", show v1, ", " ]
, show v2, "\n" (Call t vis (Ident i) arg) ->
] concat
(Call t vis (Ident i) arg) -> [ "call ", show t, " ", show vis, i, "("
concat , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg
[ "call ", show t, " ", show vis, i, "(" , ")\n"
, intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg ]
, ")\n" (Alloca t) -> unwords ["alloca", show t, "\n"]
] (Store t1 (Ident id1) t2 (Ident id2)) ->
(Alloca t) -> unwords ["alloca", show t, "\n"] concat
(Store t1 (Ident id1) t2 (Ident id2)) -> [ "store ", show t1, " %", id1
concat , ", ", show t2 , " %", id2, "\n"
[ "store ", show t1, " %", id1 ]
, ", ", show t2 , " %", id2, "\n" (Bitcast t1 (Ident i) t2) ->
] concat
(Bitcast t1 (Ident i) t2) -> [ "bitcast ", show t1, " %"
concat , i, " to ", show t2, "\n"
[ "bitcast ", show t1, " %" ]
, i, " to ", show t2, "\n" (Icmp comp t v1 v2) ->
] concat
(Icmp comp t v1 v2) -> [ "icmp ", show comp, " ", show t
concat , " ", show v1, ", ", show v2, "\n"
[ "icmp ", show comp, " ", show t ]
, " ", show v1, ", ", show v2, "\n" (Ret t v) ->
] concat
(Ret t v) -> [ "ret ", show t, " "
concat , show v, "\n"
[ "ret ", show t, " " ]
, show v, "\n" (UnsafeRaw s) -> s
] (Label (Ident s)) -> "\nlabel_" <> s <> ":\n"
(UnsafeRaw s) -> s (Br (Ident s)) -> "br label %label_" <> s <> "\n"
(Label (Ident s)) -> "\nlabel_" <> s <> ":\n" (BrCond val (Ident s1) (Ident s2)) ->
(Br (Ident s)) -> "br label %label_" <> s <> "\n" concat
(BrCond val (Ident s1) (Ident s2)) -> [ "br i1 ", show val, ", ", "label %"
concat , "label_", s1, ", ", "label %", "label_", s2, "\n"
[ "br i1 ", show val, ", ", "label %" ]
, "label_", s1, ", ", "label %", "label_", s2, "\n" (Comment s) -> "; " <> s <> "\n"
] (Variable (Ident id)) -> "%" <> id
(Comment s) -> "; " <> s <> "\n"
(Variable (Ident id)) -> "%" <> id
{- FOURMOLU_ENABLE -}