Implemented case matching on ints in the code generator

This commit is contained in:
Samuel Hammersberg 2023-02-18 14:36:46 +01:00
parent 7cedc2e28c
commit 287f84377c
2 changed files with 130 additions and 49 deletions

View file

@ -4,18 +4,19 @@
module Compiler (compile) where module Compiler (compile) where
import Control.Monad.State (StateT, execStateT, gets, modify) import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.List.Extra (trim)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Tuple.Extra (second) import Data.Tuple.Extra (second)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Print (printTree) import Grammar.Print (printTree)
import LlvmIr ( import LlvmIr (LLVMComp (..), LLVMIr (..),
LLVMIr (..), LLVMType (..), LLVMValue (..),
LLVMType (..), Visibility (..), llvmIrToString)
LLVMValue (..), import System.IO (stdin)
Visibility (..), import System.Process.Extra (CreateProcess (std_in),
llvmIrToString, StdStream (CreatePipe), createProcess,
) readCreateProcess, shell)
import TypeChecker (partitionType) import TypeChecker (partitionType)
import TypeCheckerIr import TypeCheckerIr
@ -24,6 +25,7 @@ data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr] { instructions :: [LLVMIr]
, functions :: Map Id FunctionInfo , functions :: Map Id FunctionInfo
, variableCount :: Integer , variableCount :: Integer
, labelCount :: Integer
} }
-- | A state type synonym -- | A state type synonym
@ -50,6 +52,12 @@ getVarCount = gets variableCount
getNewVar :: CompilerState Integer getNewVar :: CompilerState Integer
getNewVar = increaseVarCount >> getVarCount getNewVar = increaseVarCount >> getVarCount
-- | Increses the label count and returns a label from the CodeGenerator state
getNewLabel :: CompilerState Integer
getNewLabel = do
modify (\t -> t{labelCount = labelCount t + 1})
gets labelCount
{- | Produces a map of functions infos from a list of binds, {- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation. which contains useful data for code generation.
-} -}
@ -67,6 +75,36 @@ getFunctions xs =
) )
xs xs
run :: Err String -> IO ()
run s = do
let s' = case s of
Right s -> s
Left _ -> error "yo"
writeFile "llvm.ll" s'
putStrLn . trim =<< readCreateProcess (shell "lli") s'
test :: Integer -> Program
test v = Program [
Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (
ECased (EId ("x", TInt)) [
Case (CInt 0) (EInt 0),
Case (CInt 1) (EInt 1),
Case CatchAll (EAdd TInt
(EApp TInt (EId (Ident "fibonacci", TInt)) (
EAdd TInt (EId (Ident "x", TInt))
(EInt (fromIntegral ((maxBound :: Int) * 2)))
))
(EApp TInt (EId (Ident "fibonacci", TInt)) (
EAdd TInt (EId (Ident "x", TInt))
(EInt (fromIntegral ((maxBound :: Int) * 2 + 1)))
))
)
]
),
Bind (Ident "main",TInt) [] (
EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92)
)
]
{- | Compiles an AST and produces a LLVM Ir string. {- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to An easy way to actually "compile" this output is to
Simply pipe it to LLI Simply pipe it to LLI
@ -78,6 +116,7 @@ compile (Program prg) = do
{ instructions = defaultStart { instructions = defaultStart
, functions = getFunctions prg , functions = getFunctions prg
, variableCount = 0 , variableCount = 0
, labelCount = 0
} }
ins <- instructions <$> execStateT (goDef prg) s ins <- instructions <$> execStateT (goDef prg) s
pure $ llvmIrToString ins pure $ llvmIrToString ins
@ -112,7 +151,7 @@ compile (Program prg) = do
goDef (Bind (name, t) args exp : xs) = do goDef (Bind (name, t) args exp : xs) = do
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
emit $ Comment $ show name <> ": " <> show exp emit $ Comment $ show name <> ": " <> show exp
emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args) emit $ Define (I64{-type2LlvmType t_return-}) name (map (second type2LlvmType) args)
functionBody <- exprToValue exp functionBody <- exprToValue exp
if name == "main" if name == "main"
then mapM_ emit (mainContent functionBody) then mapM_ emit (mainContent functionBody)
@ -131,14 +170,47 @@ compile (Program prg) = do
go (EAbs t ti e) = emitAbs t ti e go (EAbs t ti e) = emitAbs t ti e
go (ELet binds e) = emitLet binds e go (ELet binds e) = emitLet binds e
go (EAnn _ _) = emitEAnn go (EAnn _ _) = emitEAnn
go (ECased e c) = emitECased e c
-- go (ESub e1 e2) = emitSub e1 e2 -- go (ESub e1 e2) = emitSub e1 e2
-- go (EMul e1 e2) = emitMul e1 e2 -- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2 -- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2 -- go (EMod e1 e2) = emitMod e1 e2
--- aux functions --- --- aux functions ---
emitECased :: Exp -> [Case] -> CompilerState ()
emitECased e cs = do
vs <- exprToValue e
lbl <- getNewLabel
let label = Ident $ "escape_" <> show lbl
stackPtr <- getNewVar
emit $ SetVariable (Ident $ show stackPtr) (Alloca I64)
mapM_ (emitCases label stackPtr vs) cs
emit $ Label label
res <- getNewVar
emit $ SetVariable (Ident $ show res) (Load I64 Ptr (Ident $ show stackPtr))
where
emitCases :: Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
emitCases label stackPtr vs (Case (CInt i) exp) = do
ns <- getNewVar
lbl_fail <- getNewLabel
lbl_succ <- getNewLabel
let failed = Ident $ "failed_" <> show lbl_fail
let success = Ident $ "success_" <> show lbl_succ
emit $ SetVariable (Ident $ show ns) (Icmp LLEq I64 vs (VInteger i))
emit $ BrCond (VIdent (Ident $ show ns) I64) success failed
emit $ Label success
val <- exprToValue exp
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
emit $ Br label
emit $ Label failed
emitCases label stackPtr _ (Case CatchAll exp) = do
val <- exprToValue exp
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
emit $ Br label
emitEAnn :: CompilerState () emitEAnn :: CompilerState ()
emitEAnn = emit . UnsafeRaw $ "why?" emitEAnn = emit . UnsafeRaw $ "Annotated escaped previous stages"
emitAbs :: Type -> Id -> Exp -> CompilerState () emitAbs :: Type -> Id -> Exp -> CompilerState ()
emitAbs _t tid e = do emitAbs _t tid e = do

View file

@ -106,7 +106,8 @@ data LLVMIr
| Label Ident | Label Ident
| Call LLVMType Visibility Ident Args | Call LLVMType Visibility Ident Args
| Alloca LLVMType | Alloca LLVMType
| Store LLVMType Ident LLVMType Ident | Store LLVMType LLVMValue LLVMType Ident
| Load LLVMType LLVMType Ident
| Bitcast LLVMType Ident LLVMType | Bitcast LLVMType Ident LLVMType
| Ret LLVMType LLVMValue | Ret LLVMType LLVMValue
| Comment String | Comment String
@ -175,11 +176,16 @@ llvmIrToString = go 0
, ")\n" , ")\n"
] ]
(Alloca t) -> unwords ["alloca", show t, "\n"] (Alloca t) -> unwords ["alloca", show t, "\n"]
(Store t1 (Ident id1) t2 (Ident id2)) -> (Store t1 val t2 (Ident id2)) ->
concat concat
[ "store ", show t1, " %", id1 [ "store ", show t1, " ", show val
, ", ", show t2 , " %", id2, "\n" , ", ", show t2 , " %", id2, "\n"
] ]
(Load t1 t2 (Ident addr)) ->
concat
[ "load ", show t1, ", "
, show t2, " %", addr, "\n"
]
(Bitcast t1 (Ident i) t2) -> (Bitcast t1 (Ident i) t2) ->
concat concat
[ "bitcast ", show t1, " %" [ "bitcast ", show t1, " %"
@ -196,13 +202,16 @@ llvmIrToString = go 0
, show v, "\n" , show v, "\n"
] ]
(UnsafeRaw s) -> s (UnsafeRaw s) -> s
(Label (Ident s)) -> "\nlabel_" <> s <> ":\n" (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
(Br (Ident s)) -> "br label %label_" <> s <> "\n" (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
(BrCond val (Ident s1) (Ident s2)) -> (BrCond val (Ident s1) (Ident s2)) ->
concat concat
[ "br i1 ", show val, ", ", "label %" [ "br i1 ", show val, ", ", "label %"
, "label_", s1, ", ", "label %", "label_", s2, "\n" , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
] ]
(Comment s) -> "; " <> s <> "\n" (Comment s) -> "; " <> s <> "\n"
(Variable (Ident id)) -> "%" <> id (Variable (Ident id)) -> "%" <> id
{- FOURMOLU_ENABLE -} {- FOURMOLU_ENABLE -}
lblPfx :: String
lblPfx = "lbl_"