Rearrange code

This commit is contained in:
Martin Fredin 2023-02-18 14:36:59 +01:00
parent 3efb27ac0c
commit 4ab6681f68
2 changed files with 262 additions and 280 deletions

View file

@ -3,12 +3,12 @@
module Compiler (compile) where module Compiler (compile) where
import Auxiliary (snoc)
import Control.Monad.State (StateT, execStateT, gets, modify) import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Tuple.Extra (second) import Data.Tuple.Extra (dupe, first, second)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Print (printTree)
import LlvmIr (LLVMIr (..), LLVMType (..), import LlvmIr (LLVMIr (..), LLVMType (..),
LLVMValue (..), Visibility (..), LLVMValue (..), Visibility (..),
llvmIrToString) llvmIrToString)
@ -32,11 +32,11 @@ data FunctionInfo = FunctionInfo
-- | Adds a instruction to the CodeGenerator state -- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState () emit :: LLVMIr -> CompilerState ()
emit l = modify (\t -> t{instructions = instructions t ++ [l]}) emit l = modify $ \t -> t { instructions = snoc l $ instructions t }
-- | Increases the variable counter in the CodeGenerator state -- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState () increaseVarCount :: CompilerState ()
increaseVarCount = modify (\t -> t{variableCount = variableCount t + 1}) increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 }
-- | Returns the variable count from the CodeGenerator state -- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer getVarCount :: CompilerState Integer
@ -46,40 +46,49 @@ getVarCount = gets variableCount
getNewVar :: CompilerState Integer getNewVar :: CompilerState Integer
getNewVar = increaseVarCount >> getVarCount getNewVar = increaseVarCount >> getVarCount
{- | Produces a map of functions infos from a list of binds, -- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation. -- which contains useful data for code generation.
-}
getFunctions :: [Bind] -> Map Id FunctionInfo getFunctions :: [Bind] -> Map Id FunctionInfo
getFunctions xs = getFunctions bs = Map.fromList $ map go bs
Map.fromList $ where
map go (Bind id args _) =
( \(Bind id args _) -> (id, FunctionInfo { numArgs=length args, arguments=args })
( id
, FunctionInfo
{ numArgs = length args
, arguments = args
}
)
)
xs
{- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to
Simply pipe it to LLI initCodeGenerator :: [Bind] -> CodeGenerator
-} initCodeGenerator scs = CodeGenerator { instructions = defaultStart
compile :: Program -> Err String , functions = getFunctions scs
compile (Program prg) = do
let s =
CodeGenerator
{ instructions = defaultStart
, functions = getFunctions prg
, variableCount = 0 , variableCount = 0
} }
ins <- instructions <$> execStateT (goDef prg) s
pure $ llvmIrToString ins -- | Compiles an AST and produces a LLVM Ir string.
-- An easy way to actually "compile" this output is to
-- Simply pipe it to lli
compile :: Program -> Err String
compile (Program scs) = do
let codegen = initCodeGenerator scs
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
compileScs :: [Bind] -> CompilerState ()
compileScs [] = pure ()
compileScs (Bind (name, t) args exp : xs) = do
emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args
emit $ Define (type2LlvmType t_return) name args'
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit $ mainContent functionBody
else emit $ Ret I64 functionBody
emit DefineEnd
modify $ \s -> s { variableCount = 0 }
compileScs xs
where where
mainContent :: LLVMValue -> [LLVMIr] t_return = snd $ partitionType (length args) t
mainContent var =
mainContent :: LLVMValue -> [LLVMIr]
mainContent var =
[ UnsafeRaw $ [ UnsafeRaw $
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n"
, -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
@ -96,60 +105,34 @@ compile (Program prg) = do
Ret I64 (VInteger 0) Ret I64 (VInteger 0)
] ]
defaultStart :: [LLVMIr] defaultStart :: [LLVMIr]
defaultStart = defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
[ Comment (show $ printTree (Program prg))
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
] ]
goDef :: [Bind] -> CompilerState () compileExp :: Exp -> CompilerState ()
goDef [] = return () compileExp = \case
goDef (Bind (name, t) args exp : xs) = do EInt i -> emitInt i
emit $ UnsafeRaw "\n" EAdd t e1 e2 -> emitAdd t e1 e2
emit $ Comment $ show name <> ": " <> show exp EId (name, _) -> emitIdent name
emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args) EApp t e1 e2 -> emitApp t e1 e2
functionBody <- exprToValue exp EAbs t ti e -> emitAbs t ti e
if name == "main" ELet bind e -> emitLet bind e
then mapM_ emit (mainContent functionBody)
else emit $ Ret I64 functionBody
emit DefineEnd
modify (\s -> s{variableCount = 0})
goDef xs
where
t_return = snd $ partitionType (length args) t
go :: Exp -> CompilerState () --- aux functions ---
go (EInt int) = emitInt int emitAbs :: Type -> Id -> Exp -> CompilerState ()
go (EAdd t e1 e2) = emitAdd t e1 e2 emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
go (EId (name, _)) = emitIdent name
go (EApp t e1 e2) = emitApp t e1 e2
go (EAbs t ti e) = emitAbs t ti e
go (ELet bind e) = emitLet bind e
-- go (ESub e1 e2) = emitSub e1 e2
-- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2
--- aux functions --- emitLet :: Bind -> Exp -> CompilerState ()
emitAbs :: Type -> Id -> Exp -> CompilerState () emitLet b e = emit . Comment $ concat [ "ELet ("
emitAbs _t tid e = do
emit . Comment $
"Lambda escaped previous stages: \\" <> show tid <> " . " <> show e
emitLet :: Bind -> Exp -> CompilerState ()
emitLet b e = do
emit $
Comment $
concat
[ "ELet ("
, show b , show b
, " = " , " = "
, show e , show e
, ") is not implemented!" , ") is not implemented!"
] ]
emitApp :: Type -> Exp -> Exp -> CompilerState () emitApp :: Type -> Exp -> Exp -> CompilerState ()
emitApp t e1 e2 = appEmitter t e1 e2 [] emitApp t e1 e2 = appEmitter t e1 e2 []
where where
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
appEmitter t e1 e2 stack = do appEmitter t e1 e2 stack = do
@ -160,98 +143,101 @@ compile (Program prg) = do
args <- traverse exprToValue newStack args <- traverse exprToValue newStack
vs <- getNewVar vs <- getNewVar
funcs <- gets functions funcs <- gets functions
let vis = case Map.lookup id funcs of let visibility = maybe Local (const Global) $ Map.lookup id funcs
Nothing -> Local args' = map (first valueGetType . dupe) args
Just _ -> Global call = Call (type2LlvmType t) visibility name args'
let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args)
emit $ SetVariable (Ident $ show vs) call emit $ SetVariable (Ident $ show vs) call
x -> do x -> do
emit . Comment $ "The unspeakable happened: " emit . Comment $ "The unspeakable happened: "
emit . Comment $ show x emit . Comment $ show x
emitIdent :: Ident -> CompilerState () emitIdent :: Ident -> CompilerState ()
emitIdent id = do emitIdent id = do
-- !!this should never happen!! -- !!this should never happen!!
emit $ Comment "This should not have happened!" emit $ Comment "This should not have happened!"
emit $ Variable id emit $ Variable id
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
emitInt :: Integer -> CompilerState () emitInt :: Integer -> CompilerState ()
emitInt i = do emitInt i = do
-- !!this should never happen!! -- !!this should never happen!!
varCount <- getNewVar varCount <- getNewVar
emit $ Comment "This should not have happened!" emit $ Comment "This should not have happened!"
emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0))
emitAdd :: Type -> Exp -> Exp -> CompilerState () emitAdd :: Type -> Exp -> Exp -> CompilerState ()
emitAdd t e1 e2 = do emitAdd t e1 e2 = do
v1 <- exprToValue e1 v1 <- exprToValue e1
v2 <- exprToValue e2 v2 <- exprToValue e2
v <- getNewVar v <- getNewVar
emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2)
-- emitMul :: Exp -> Exp -> CompilerState () -- emitMul :: Exp -> Exp -> CompilerState ()
-- emitMul e1 e2 = do -- emitMul e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2 -- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount -- increaseVarCount
-- v <- gets variableCount -- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v -- emit $ SetVariable $ Ident $ show v
-- emit $ Mul I64 v1 v2 -- emit $ Mul I64 v1 v2
-- emitMod :: Exp -> Exp -> CompilerState () -- emitMod :: Exp -> Exp -> CompilerState ()
-- emitMod e1 e2 = do -- emitMod e1 e2 = do
-- -- `let m a b = rem (abs $ b + a) b` -- -- `let m a b = rem (abs $ b + a) b`
-- (v1,v2) <- binExprToValues e1 e2 -- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount -- increaseVarCount
-- vadd <- gets variableCount -- vadd <- gets variableCount
-- emit $ SetVariable $ Ident $ show vadd -- emit $ SetVariable $ Ident $ show vadd
-- emit $ Add I64 v1 v2 -- emit $ Add I64 v1 v2
-- --
-- increaseVarCount -- increaseVarCount
-- vabs <- gets variableCount -- vabs <- gets variableCount
-- emit $ SetVariable $ Ident $ show vabs -- emit $ SetVariable $ Ident $ show vabs
-- emit $ Call I64 (Ident "llvm.abs.i64") -- emit $ Call I64 (Ident "llvm.abs.i64")
-- [ (I64, VIdent (Ident $ show vadd)) -- [ (I64, VIdent (Ident $ show vadd))
-- , (I1, VInteger 1) -- , (I1, VInteger 1)
-- ] -- ]
-- increaseVarCount -- increaseVarCount
-- v <- gets variableCount -- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v -- emit $ SetVariable $ Ident $ show v
-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 -- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2
-- emitDiv :: Exp -> Exp -> CompilerState () -- emitDiv :: Exp -> Exp -> CompilerState ()
-- emitDiv e1 e2 = do -- emitDiv e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2 -- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount -- increaseVarCount
-- v <- gets variableCount -- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v -- emit $ SetVariable $ Ident $ show v
-- emit $ Div I64 v1 v2 -- emit $ Div I64 v1 v2
-- emitSub :: Exp -> Exp -> CompilerState () -- emitSub :: Exp -> Exp -> CompilerState ()
-- emitSub e1 e2 = do -- emitSub e1 e2 = do
-- (v1,v2) <- binExprToValues e1 e2 -- (v1,v2) <- binExprToValues e1 e2
-- increaseVarCount -- increaseVarCount
-- v <- gets variableCount -- v <- gets variableCount
-- emit $ SetVariable $ Ident $ show v -- emit $ SetVariable $ Ident $ show v
-- emit $ Sub I64 v1 v2 -- emit $ Sub I64 v1 v2
exprToValue :: Exp -> CompilerState LLVMValue exprToValue :: Exp -> CompilerState LLVMValue
exprToValue (EInt i) = return $ VInteger i exprToValue = \case
exprToValue (EId id@(name, t)) = do EInt i -> pure $ VInteger i
EId id@(name, t) -> do
funcs <- gets functions funcs <- gets functions
case Map.lookup id funcs of case Map.lookup id funcs of
Just fi -> do Just fi -> do
if numArgs fi == 0 if numArgs fi == 0
then do then do
vc <- getNewVar vc <- getNewVar
emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) Global name []) emit $ SetVariable (Ident $ show vc)
return $ VIdent (Ident $ show vc) (type2LlvmType t) (Call (type2LlvmType t) Global name [])
else return $ VFunction name Global (type2LlvmType t) pure $ VIdent (Ident $ show vc) (type2LlvmType t)
Nothing -> return $ VIdent name (type2LlvmType t) else pure $ VFunction name Global (type2LlvmType t)
exprToValue e = do Nothing -> pure $ VIdent name (type2LlvmType t)
go e
e -> do
compileExp e
v <- getVarCount v <- getVarCount
return $ VIdent (Ident $ show v) (getType e) pure $ VIdent (Ident $ show v) (getType e)
type2LlvmType :: Type -> LLVMType type2LlvmType :: Type -> LLVMType
type2LlvmType = \case type2LlvmType = \case

View file

@ -68,9 +68,8 @@ instance Show Visibility where
show Local = "%" show Local = "%"
show Global = "@" show Global = "@"
{- | Represents a LLVM "value", as in an integer, a register variable, -- | Represents a LLVM "value", as in an integer, a register variable,
or a string contstant -- or a string contstant
-}
data LLVMValue data LLVMValue
= VInteger Integer = VInteger Integer
| VIdent Ident LLVMType | VIdent Ident LLVMType
@ -127,12 +126,10 @@ llvmIrToString = go 0
_ -> (i, i) _ -> (i, i)
insToString n x <> go i' xs insToString n x <> go i' xs
{- | Converts a LLVM inststruction to a String, allowing for printing etc. -- | Converts a LLVM inststruction to a String, allowing for printing etc.
The integer represents the indentation -- The integer represents the indentation
-} insToString :: Int -> LLVMIr -> String
{- FOURMOLU_DISABLE -} insToString i l =
insToString :: Int -> LLVMIr -> String
insToString i l =
replicate i '\t' <> case l of replicate i '\t' <> case l of
(Define t (Ident i) params) -> (Define t (Ident i) params) ->
concat concat
@ -205,4 +202,3 @@ llvmIrToString = go 0
] ]
(Comment s) -> "; " <> s <> "\n" (Comment s) -> "; " <> s <> "\n"
(Variable (Ident id)) -> "%" <> id (Variable (Ident id)) -> "%" <> id
{- FOURMOLU_ENABLE -}