From 4ab6681f681df9d5671106e8c5a43e971bb9f4f2 Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Sat, 18 Feb 2023 14:36:59 +0100 Subject: [PATCH] Rearrange code --- src/Compiler.hs | 362 +++++++++++++++++++++++------------------------- src/LlvmIr.hs | 180 ++++++++++++------------ 2 files changed, 262 insertions(+), 280 deletions(-) diff --git a/src/Compiler.hs b/src/Compiler.hs index 8cbeb58..fd6b6bc 100644 --- a/src/Compiler.hs +++ b/src/Compiler.hs @@ -3,12 +3,12 @@ module Compiler (compile) where +import Auxiliary (snoc) import Control.Monad.State (StateT, execStateT, gets, modify) import Data.Map (Map) import qualified Data.Map as Map -import Data.Tuple.Extra (second) +import Data.Tuple.Extra (dupe, first, second) import Grammar.ErrM (Err) -import Grammar.Print (printTree) import LlvmIr (LLVMIr (..), LLVMType (..), LLVMValue (..), Visibility (..), llvmIrToString) @@ -32,11 +32,11 @@ data FunctionInfo = FunctionInfo -- | Adds a instruction to the CodeGenerator state emit :: LLVMIr -> CompilerState () -emit l = modify (\t -> t{instructions = instructions t ++ [l]}) +emit l = modify $ \t -> t { instructions = snoc l $ instructions t } -- | Increases the variable counter in the CodeGenerator state increaseVarCount :: CompilerState () -increaseVarCount = modify (\t -> t{variableCount = variableCount t + 1}) +increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } -- | Returns the variable count from the CodeGenerator state getVarCount :: CompilerState Integer @@ -46,212 +46,198 @@ getVarCount = gets variableCount getNewVar :: CompilerState Integer getNewVar = increaseVarCount >> getVarCount -{- | Produces a map of functions infos from a list of binds, - which contains useful data for code generation. --} +-- | Produces a map of functions infos from a list of binds, +-- which contains useful data for code generation. getFunctions :: [Bind] -> Map Id FunctionInfo -getFunctions xs = - Map.fromList $ - map - ( \(Bind id args _) -> - ( id - , FunctionInfo - { numArgs = length args - , arguments = args - } - ) - ) - xs - -{- | Compiles an AST and produces a LLVM Ir string. - An easy way to actually "compile" this output is to - Simply pipe it to LLI --} -compile :: Program -> Err String -compile (Program prg) = do - let s = - CodeGenerator - { instructions = defaultStart - , functions = getFunctions prg - , variableCount = 0 - } - ins <- instructions <$> execStateT (goDef prg) s - pure $ llvmIrToString ins +getFunctions bs = Map.fromList $ map go bs where - mainContent :: LLVMValue -> [LLVMIr] - mainContent var = - [ UnsafeRaw $ - "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" - , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) - -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") - -- , Label (Ident "b_1") - -- , UnsafeRaw - -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" - -- , Br (Ident "end") - -- , Label (Ident "b_2") - -- , UnsafeRaw - -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" - -- , Br (Ident "end") - -- , Label (Ident "end") - Ret I64 (VInteger 0) - ] + go (Bind id args _) = + (id, FunctionInfo { numArgs=length args, arguments=args }) - defaultStart :: [LLVMIr] - defaultStart = - [ Comment (show $ printTree (Program prg)) - , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" - , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" - ] - goDef :: [Bind] -> CompilerState () - goDef [] = return () - goDef (Bind (name, t) args exp : xs) = do - emit $ UnsafeRaw "\n" - emit $ Comment $ show name <> ": " <> show exp - emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args) - functionBody <- exprToValue exp - if name == "main" - then mapM_ emit (mainContent functionBody) - else emit $ Ret I64 functionBody - emit DefineEnd - modify (\s -> s{variableCount = 0}) - goDef xs - where - t_return = snd $ partitionType (length args) t - go :: Exp -> CompilerState () - go (EInt int) = emitInt int - go (EAdd t e1 e2) = emitAdd t e1 e2 - go (EId (name, _)) = emitIdent name - go (EApp t e1 e2) = emitApp t e1 e2 - go (EAbs t ti e) = emitAbs t ti e - go (ELet bind e) = emitLet bind e - -- go (ESub e1 e2) = emitSub e1 e2 - -- go (EMul e1 e2) = emitMul e1 e2 - -- go (EDiv e1 e2) = emitDiv e1 e2 - -- go (EMod e1 e2) = emitMod e1 e2 +initCodeGenerator :: [Bind] -> CodeGenerator +initCodeGenerator scs = CodeGenerator { instructions = defaultStart + , functions = getFunctions scs + , variableCount = 0 + } - --- aux functions --- - emitAbs :: Type -> Id -> Exp -> CompilerState () - emitAbs _t tid e = do - emit . Comment $ - "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e - emitLet :: Bind -> Exp -> CompilerState () - emitLet b e = do - emit $ - Comment $ - concat - [ "ELet (" - , show b - , " = " - , show e - , ") is not implemented!" - ] +-- | Compiles an AST and produces a LLVM Ir string. +-- An easy way to actually "compile" this output is to +-- Simply pipe it to lli +compile :: Program -> Err String +compile (Program scs) = do + let codegen = initCodeGenerator scs + llvmIrToString . instructions <$> execStateT (compileScs scs) codegen - emitApp :: Type -> Exp -> Exp -> CompilerState () - emitApp t e1 e2 = appEmitter t e1 e2 [] - where - appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () - appEmitter t e1 e2 stack = do - let newStack = e2 : stack - case e1 of - EApp _ e1' e2' -> appEmitter t e1' e2' newStack - EId id@(name, _) -> do - args <- traverse exprToValue newStack - vs <- getNewVar - funcs <- gets functions - let vis = case Map.lookup id funcs of - Nothing -> Local - Just _ -> Global - let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args) - emit $ SetVariable (Ident $ show vs) call - x -> do - emit . Comment $ "The unspeakable happened: " - emit . Comment $ show x +compileScs :: [Bind] -> CompilerState () +compileScs [] = pure () +compileScs (Bind (name, t) args exp : xs) = do + emit $ UnsafeRaw "\n" + emit . Comment $ show name <> ": " <> show exp + let args' = map (second type2LlvmType) args + emit $ Define (type2LlvmType t_return) name args' + functionBody <- exprToValue exp + if name == "main" + then mapM_ emit $ mainContent functionBody + else emit $ Ret I64 functionBody + emit DefineEnd + modify $ \s -> s { variableCount = 0 } + compileScs xs + where + t_return = snd $ partitionType (length args) t - emitIdent :: Ident -> CompilerState () - emitIdent id = do - -- !!this should never happen!! - emit $ Comment "This should not have happened!" - emit $ Variable id - emit $ UnsafeRaw "\n" +mainContent :: LLVMValue -> [LLVMIr] +mainContent var = + [ UnsafeRaw $ + "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" + , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) + -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") + -- , Label (Ident "b_1") + -- , UnsafeRaw + -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" + -- , Br (Ident "end") + -- , Label (Ident "b_2") + -- , UnsafeRaw + -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" + -- , Br (Ident "end") + -- , Label (Ident "end") + Ret I64 (VInteger 0) + ] - emitInt :: Integer -> CompilerState () - emitInt i = do - -- !!this should never happen!! - varCount <- getNewVar - emit $ Comment "This should not have happened!" - emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) +defaultStart :: [LLVMIr] +defaultStart = [ UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" + , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" + ] - emitAdd :: Type -> Exp -> Exp -> CompilerState () - emitAdd t e1 e2 = do - v1 <- exprToValue e1 - v2 <- exprToValue e2 - v <- getNewVar - emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) +compileExp :: Exp -> CompilerState () +compileExp = \case + EInt i -> emitInt i + EAdd t e1 e2 -> emitAdd t e1 e2 + EId (name, _) -> emitIdent name + EApp t e1 e2 -> emitApp t e1 e2 + EAbs t ti e -> emitAbs t ti e + ELet bind e -> emitLet bind e - -- emitMul :: Exp -> Exp -> CompilerState () - -- emitMul e1 e2 = do - -- (v1,v2) <- binExprToValues e1 e2 - -- increaseVarCount - -- v <- gets variableCount - -- emit $ SetVariable $ Ident $ show v - -- emit $ Mul I64 v1 v2 +--- aux functions --- +emitAbs :: Type -> Id -> Exp -> CompilerState () +emitAbs _t tid e = emit . Comment $ "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e - -- emitMod :: Exp -> Exp -> CompilerState () - -- emitMod e1 e2 = do - -- -- `let m a b = rem (abs $ b + a) b` - -- (v1,v2) <- binExprToValues e1 e2 - -- increaseVarCount - -- vadd <- gets variableCount - -- emit $ SetVariable $ Ident $ show vadd - -- emit $ Add I64 v1 v2 - -- - -- increaseVarCount - -- vabs <- gets variableCount - -- emit $ SetVariable $ Ident $ show vabs - -- emit $ Call I64 (Ident "llvm.abs.i64") - -- [ (I64, VIdent (Ident $ show vadd)) - -- , (I1, VInteger 1) - -- ] - -- increaseVarCount - -- v <- gets variableCount - -- emit $ SetVariable $ Ident $ show v - -- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 +emitLet :: Bind -> Exp -> CompilerState () +emitLet b e = emit . Comment $ concat [ "ELet (" + , show b + , " = " + , show e + , ") is not implemented!" + ] - -- emitDiv :: Exp -> Exp -> CompilerState () - -- emitDiv e1 e2 = do - -- (v1,v2) <- binExprToValues e1 e2 - -- increaseVarCount - -- v <- gets variableCount - -- emit $ SetVariable $ Ident $ show v - -- emit $ Div I64 v1 v2 +emitApp :: Type -> Exp -> Exp -> CompilerState () +emitApp t e1 e2 = appEmitter t e1 e2 [] + where + appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () + appEmitter t e1 e2 stack = do + let newStack = e2 : stack + case e1 of + EApp _ e1' e2' -> appEmitter t e1' e2' newStack + EId id@(name, _) -> do + args <- traverse exprToValue newStack + vs <- getNewVar + funcs <- gets functions + let visibility = maybe Local (const Global) $ Map.lookup id funcs + args' = map (first valueGetType . dupe) args + call = Call (type2LlvmType t) visibility name args' + emit $ SetVariable (Ident $ show vs) call + x -> do + emit . Comment $ "The unspeakable happened: " + emit . Comment $ show x - -- emitSub :: Exp -> Exp -> CompilerState () - -- emitSub e1 e2 = do - -- (v1,v2) <- binExprToValues e1 e2 - -- increaseVarCount - -- v <- gets variableCount - -- emit $ SetVariable $ Ident $ show v - -- emit $ Sub I64 v1 v2 +emitIdent :: Ident -> CompilerState () +emitIdent id = do + -- !!this should never happen!! + emit $ Comment "This should not have happened!" + emit $ Variable id + emit $ UnsafeRaw "\n" - exprToValue :: Exp -> CompilerState LLVMValue - exprToValue (EInt i) = return $ VInteger i - exprToValue (EId id@(name, t)) = do +emitInt :: Integer -> CompilerState () +emitInt i = do + -- !!this should never happen!! + varCount <- getNewVar + emit $ Comment "This should not have happened!" + emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) + +emitAdd :: Type -> Exp -> Exp -> CompilerState () +emitAdd t e1 e2 = do + v1 <- exprToValue e1 + v2 <- exprToValue e2 + v <- getNewVar + emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) + +-- emitMul :: Exp -> Exp -> CompilerState () +-- emitMul e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Mul I64 v1 v2 + +-- emitMod :: Exp -> Exp -> CompilerState () +-- emitMod e1 e2 = do +-- -- `let m a b = rem (abs $ b + a) b` +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- vadd <- gets variableCount +-- emit $ SetVariable $ Ident $ show vadd +-- emit $ Add I64 v1 v2 +-- +-- increaseVarCount +-- vabs <- gets variableCount +-- emit $ SetVariable $ Ident $ show vabs +-- emit $ Call I64 (Ident "llvm.abs.i64") +-- [ (I64, VIdent (Ident $ show vadd)) +-- , (I1, VInteger 1) +-- ] +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 + +-- emitDiv :: Exp -> Exp -> CompilerState () +-- emitDiv e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Div I64 v1 v2 + +-- emitSub :: Exp -> Exp -> CompilerState () +-- emitSub e1 e2 = do +-- (v1,v2) <- binExprToValues e1 e2 +-- increaseVarCount +-- v <- gets variableCount +-- emit $ SetVariable $ Ident $ show v +-- emit $ Sub I64 v1 v2 + +exprToValue :: Exp -> CompilerState LLVMValue +exprToValue = \case + EInt i -> pure $ VInteger i + + EId id@(name, t) -> do funcs <- gets functions case Map.lookup id funcs of Just fi -> do if numArgs fi == 0 then do vc <- getNewVar - emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) Global name []) - return $ VIdent (Ident $ show vc) (type2LlvmType t) - else return $ VFunction name Global (type2LlvmType t) - Nothing -> return $ VIdent name (type2LlvmType t) - exprToValue e = do - go e + emit $ SetVariable (Ident $ show vc) + (Call (type2LlvmType t) Global name []) + pure $ VIdent (Ident $ show vc) (type2LlvmType t) + else pure $ VFunction name Global (type2LlvmType t) + Nothing -> pure $ VIdent name (type2LlvmType t) + + e -> do + compileExp e v <- getVarCount - return $ VIdent (Ident $ show v) (getType e) + pure $ VIdent (Ident $ show v) (getType e) type2LlvmType :: Type -> LLVMType type2LlvmType = \case diff --git a/src/LlvmIr.hs b/src/LlvmIr.hs index b29f296..d340ddc 100644 --- a/src/LlvmIr.hs +++ b/src/LlvmIr.hs @@ -9,8 +9,8 @@ module LlvmIr ( Visibility (..), ) where -import Data.List (intercalate) -import TypeCheckerIr +import Data.List (intercalate) +import TypeCheckerIr -- | A datatype which represents some basic LLVM types data LLVMType @@ -51,8 +51,8 @@ data LLVMComp instance Show LLVMComp where show :: LLVMComp -> String show = \case - LLEq -> "eq" - LLNe -> "ne" + LLEq -> "eq" + LLNe -> "ne" LLUgt -> "ugt" LLUge -> "uge" LLUlt -> "ult" @@ -65,12 +65,11 @@ instance Show LLVMComp where data Visibility = Local | Global instance Show Visibility where show :: Visibility -> String - show Local = "%" + show Local = "%" show Global = "@" -{- | Represents a LLVM "value", as in an integer, a register variable, - or a string contstant --} +-- | Represents a LLVM "value", as in an integer, a register variable, +-- or a string contstant data LLVMValue = VInteger Integer | VIdent Ident LLVMType @@ -80,10 +79,10 @@ data LLVMValue instance Show LLVMValue where show :: LLVMValue -> String show v = case v of - VInteger i -> show i - VIdent (Ident n) _ -> "%" <> n + VInteger i -> show i + VIdent (Ident n) _ -> "%" <> n VFunction (Ident n) vis _ -> show vis <> n - VConstant s -> "c" <> show s + VConstant s -> "c" <> show s type Params = [(Ident, LLVMType)] type Args = [(LLVMType, LLVMValue)] @@ -122,87 +121,84 @@ llvmIrToString = go 0 go _ [] = mempty go i (x : xs) = do let (i', n) = case x of - Define{} -> (i + 1, 0) + Define{} -> (i + 1, 0) DefineEnd -> (i - 1, 0) - _ -> (i, i) + _ -> (i, i) insToString n x <> go i' xs -{- | Converts a LLVM inststruction to a String, allowing for printing etc. - The integer represents the indentation --} -{- FOURMOLU_DISABLE -} - insToString :: Int -> LLVMIr -> String - insToString i l = - replicate i '\t' <> case l of - (Define t (Ident i) params) -> - concat - [ "define ", show t, " @", i - , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) - , ") {\n" - ] - DefineEnd -> "}\n" - (Declare _t (Ident _i) _params) -> undefined - (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] - (Add t v1 v2) -> - concat - [ "add ", show t, " ", show v1 - , ", ", show v2, "\n" - ] - (Sub t v1 v2) -> - concat - [ "sub ", show t, " ", show v1, ", " - , show v2, "\n" - ] - (Div t v1 v2) -> - concat - [ "sdiv ", show t, " ", show v1, ", " - , show v2, "\n" - ] - (Mul t v1 v2) -> - concat - [ "mul ", show t, " ", show v1 - , ", ", show v2, "\n" - ] - (Srem t v1 v2) -> - concat - [ "srem ", show t, " ", show v1, ", " - , show v2, "\n" - ] - (Call t vis (Ident i) arg) -> - concat - [ "call ", show t, " ", show vis, i, "(" - , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg - , ")\n" - ] - (Alloca t) -> unwords ["alloca", show t, "\n"] - (Store t1 (Ident id1) t2 (Ident id2)) -> - concat - [ "store ", show t1, " %", id1 - , ", ", show t2 , " %", id2, "\n" - ] - (Bitcast t1 (Ident i) t2) -> - concat - [ "bitcast ", show t1, " %" - , i, " to ", show t2, "\n" - ] - (Icmp comp t v1 v2) -> - concat - [ "icmp ", show comp, " ", show t - , " ", show v1, ", ", show v2, "\n" - ] - (Ret t v) -> - concat - [ "ret ", show t, " " - , show v, "\n" - ] - (UnsafeRaw s) -> s - (Label (Ident s)) -> "\nlabel_" <> s <> ":\n" - (Br (Ident s)) -> "br label %label_" <> s <> "\n" - (BrCond val (Ident s1) (Ident s2)) -> - concat - [ "br i1 ", show val, ", ", "label %" - , "label_", s1, ", ", "label %", "label_", s2, "\n" - ] - (Comment s) -> "; " <> s <> "\n" - (Variable (Ident id)) -> "%" <> id -{- FOURMOLU_ENABLE -} +-- | Converts a LLVM inststruction to a String, allowing for printing etc. +-- The integer represents the indentation +insToString :: Int -> LLVMIr -> String +insToString i l = + replicate i '\t' <> case l of + (Define t (Ident i) params) -> + concat + [ "define ", show t, " @", i + , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) + , ") {\n" + ] + DefineEnd -> "}\n" + (Declare _t (Ident _i) _params) -> undefined + (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] + (Add t v1 v2) -> + concat + [ "add ", show t, " ", show v1 + , ", ", show v2, "\n" + ] + (Sub t v1 v2) -> + concat + [ "sub ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Div t v1 v2) -> + concat + [ "sdiv ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Mul t v1 v2) -> + concat + [ "mul ", show t, " ", show v1 + , ", ", show v2, "\n" + ] + (Srem t v1 v2) -> + concat + [ "srem ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Call t vis (Ident i) arg) -> + concat + [ "call ", show t, " ", show vis, i, "(" + , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg + , ")\n" + ] + (Alloca t) -> unwords ["alloca", show t, "\n"] + (Store t1 (Ident id1) t2 (Ident id2)) -> + concat + [ "store ", show t1, " %", id1 + , ", ", show t2 , " %", id2, "\n" + ] + (Bitcast t1 (Ident i) t2) -> + concat + [ "bitcast ", show t1, " %" + , i, " to ", show t2, "\n" + ] + (Icmp comp t v1 v2) -> + concat + [ "icmp ", show comp, " ", show t + , " ", show v1, ", ", show v2, "\n" + ] + (Ret t v) -> + concat + [ "ret ", show t, " " + , show v, "\n" + ] + (UnsafeRaw s) -> s + (Label (Ident s)) -> "\nlabel_" <> s <> ":\n" + (Br (Ident s)) -> "br label %label_" <> s <> "\n" + (BrCond val (Ident s1) (Ident s2)) -> + concat + [ "br i1 ", show val, ", ", "label %" + , "label_", s1, ", ", "label %", "label_", s2, "\n" + ] + (Comment s) -> "; " <> s <> "\n" + (Variable (Ident id)) -> "%" <> id