Implemented case matching on ints in the code generator

This commit is contained in:
Samuel Hammersberg 2023-02-18 14:36:46 +01:00
parent 7cedc2e28c
commit 287f84377c
2 changed files with 130 additions and 49 deletions

View file

@ -1,36 +1,38 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
module Compiler (compile) where module Compiler (compile) where
import Control.Monad.State (StateT, execStateT, gets, modify) import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.Map (Map) import Data.List.Extra (trim)
import qualified Data.Map as Map import Data.Map (Map)
import Data.Tuple.Extra (second) import qualified Data.Map as Map
import Grammar.ErrM (Err) import Data.Tuple.Extra (second)
import Grammar.Print (printTree) import Grammar.ErrM (Err)
import LlvmIr ( import Grammar.Print (printTree)
LLVMIr (..), import LlvmIr (LLVMComp (..), LLVMIr (..),
LLVMType (..), LLVMType (..), LLVMValue (..),
LLVMValue (..), Visibility (..), llvmIrToString)
Visibility (..), import System.IO (stdin)
llvmIrToString, import System.Process.Extra (CreateProcess (std_in),
) StdStream (CreatePipe), createProcess,
import TypeChecker (partitionType) readCreateProcess, shell)
import TypeCheckerIr import TypeChecker (partitionType)
import TypeCheckerIr
-- | The state record carried through code generation.
data CodeGenerator = CodeGenerator
    { instructions :: [LLVMIr]
    -- ^ the LLVM IR instructions of the program being generated
    , functions :: Map Id FunctionInfo
    -- ^ information about the program's binds, keyed by their 'Id'
    , variableCount :: Integer
    -- ^ counter read by 'getVarCount'; starts at 0
    , labelCount :: Integer
    -- ^ counter advanced by 'getNewLabel'; starts at 0
    }
-- | The code generator monad: 'CodeGenerator' state layered over 'Err'.
type CompilerState a = StateT CodeGenerator Err a
-- | Data about a single function, stored in the 'functions' map of the state.
data FunctionInfo = FunctionInfo
    { numArgs :: Int
    , arguments :: [Id]
    }
@ -50,6 +52,12 @@ getVarCount = gets variableCount
-- | Runs 'increaseVarCount' and then returns the current variable count.
getNewVar :: CompilerState Integer
getNewVar = do
    increaseVarCount
    getVarCount
-- | Increases the label count by one and returns the new count from the CodeGenerator state
getNewLabel :: CompilerState Integer
getNewLabel =
    modify (\st -> st{labelCount = succ (labelCount st)}) >> gets labelCount
{- | Produces a map of function infos from a list of binds, {- | Produces a map of function infos from a list of binds,
which contains useful data for code generation. which contains useful data for code generation.
-} -}
@ -67,6 +75,36 @@ getFunctions xs =
) )
xs xs
{- | Runs the result of a compilation through @lli@.

On success the LLVM IR is written to @llvm.ll@, passed to @lli@ on its
standard input, and the trimmed standard output of @lli@ is printed.

On failure an 'IOError' carrying the compiler's message is raised before
anything is written, so a failed compilation no longer leaves an empty
@llvm.ll@ behind.
-}
run :: Err String -> IO ()
run result = do
    llvmSource <- case result of
        Right src -> pure src
        -- Surface the actual compiler error instead of discarding it.
        Left err -> ioError . userError $ "Compiler.run: compilation failed: " <> err
    writeFile "llvm.ll" llvmSource
    putStrLn . trim =<< readCreateProcess (shell "lli") llvmSource
{- | A sample program for exercising case matching on integers.

It defines a recursive @fibonacci@ whose body cases on @x@ (0 gives 0,
1 gives 1, anything else gives @fibonacci (x + (-2)) + fibonacci (x + (-1))@)
and a @main@ that applies @fibonacci@ to the given integer.
-}
test :: Integer -> Program
test v =
    Program
        [ Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] fibBody
        , Bind (Ident "main", TInt) [] (callFib (EInt v)) -- (EInt 92)
        ]
  where
    -- The case expression forming the body of @fibonacci@.
    fibBody =
        ECased
            (EId ("x", TInt))
            [ Case (CInt 0) (EInt 0)
            , Case (CInt 1) (EInt 1)
            , Case CatchAll (EAdd TInt (callFib (xPlus (-2))) (callFib (xPlus (-1))))
            ]

    -- Applies @fibonacci@ to an argument expression.
    callFib = EApp TInt (EId (Ident "fibonacci", TInt))

    -- The expression @x + k@.
    xPlus k = EAdd TInt (EId (Ident "x", TInt)) (EInt k)
{- | Compiles an AST and produces an LLVM IR string. {- | Compiles an AST and produces an LLVM IR string.
An easy way to actually "compile" this output is to An easy way to actually "compile" this output is to
simply pipe it to LLI simply pipe it to LLI
@ -78,6 +116,7 @@ compile (Program prg) = do
{ instructions = defaultStart { instructions = defaultStart
, functions = getFunctions prg , functions = getFunctions prg
, variableCount = 0 , variableCount = 0
, labelCount = 0
} }
ins <- instructions <$> execStateT (goDef prg) s ins <- instructions <$> execStateT (goDef prg) s
pure $ llvmIrToString ins pure $ llvmIrToString ins
@ -112,7 +151,7 @@ compile (Program prg) = do
goDef (Bind (name, t) args exp : xs) = do goDef (Bind (name, t) args exp : xs) = do
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
emit $ Comment $ show name <> ": " <> show exp emit $ Comment $ show name <> ": " <> show exp
emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args) emit $ Define (I64{-type2LlvmType t_return-}) name (map (second type2LlvmType) args)
functionBody <- exprToValue exp functionBody <- exprToValue exp
if name == "main" if name == "main"
then mapM_ emit (mainContent functionBody) then mapM_ emit (mainContent functionBody)
@ -124,21 +163,54 @@ compile (Program prg) = do
t_return = snd $ partitionType (length args) t t_return = snd $ partitionType (length args) t
-- | Emits the instructions for one expression by dispatching on its constructor.
go :: Exp -> CompilerState ()
go = \case
    EInt int -> emitInt int
    EAdd t e1 e2 -> emitAdd t e1 e2
    EId (name, _) -> emitIdent name
    EApp t e1 e2 -> emitApp t e1 e2
    EAbs t ti e -> emitAbs t ti e
    ELet binds e -> emitLet binds e
    EAnn _ _ -> emitEAnn
    ECased e c -> emitECased e c
    -- ESub e1 e2 -> emitSub e1 e2
    -- EMul e1 e2 -> emitMul e1 e2
    -- EDiv e1 e2 -> emitDiv e1 e2
    -- EMod e1 e2 -> emitMod e1 e2
--- aux functions --- --- aux functions ---
{- | Emits the code for a case expression.

The scrutinee is evaluated once. Each case stores its result in a stack slot
and branches to a shared escape label, after which the slot is loaded into a
fresh variable.

Only integer patterns ('CInt') and 'CatchAll' are handled. Cases listed after
the first 'CatchAll' can never match and are not emitted, since instructions
following its unconditional branch would not belong to any labelled block.
-}
emitECased :: Exp -> [Case] -> CompilerState ()
emitECased e cs = do
    vs <- exprToValue e
    lbl <- getNewLabel
    let label = Ident $ "escape_" <> show lbl
    stackPtr <- getNewVar
    emit $ SetVariable (Ident $ show stackPtr) (Alloca I64)
    mapM_ (emitCases label stackPtr vs) reachable
    -- Every basic block has to end in a terminator. Without a catch-all the
    -- block left open by the last failed comparison (or by the alloca when
    -- there are no cases at all) would run straight into the escape label,
    -- so close it explicitly: reaching this point means no pattern matched.
    case fromCatchAll of
        [] -> emit $ UnsafeRaw "unreachable\n"
        _ -> pure ()
    emit $ Label label
    res <- getNewVar
    emit $ SetVariable (Ident $ show res) (Load I64 Ptr (Ident $ show stackPtr))
  where
    -- Split the cases at the first catch-all; it and everything before it are
    -- the only cases that can ever be selected.
    (beforeCatchAll, fromCatchAll) = break isCatchAll cs
    reachable = beforeCatchAll <> take 1 fromCatchAll

    isCatchAll :: Case -> Bool
    isCatchAll (Case CatchAll _) = True
    isCatchAll _ = False

    emitCases :: Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
    emitCases label stackPtr vs (Case (CInt i) body) = do
        ns <- getNewVar
        lbl_fail <- getNewLabel
        lbl_succ <- getNewLabel
        let failed = Ident $ "failed_" <> show lbl_fail
        let success = Ident $ "success_" <> show lbl_succ
        emit $ SetVariable (Ident $ show ns) (Icmp LLEq I64 vs (VInteger i))
        -- NOTE(review): the comparison result is tagged I64 here although the
        -- branch is printed as "br i1"; the tag is not printed — confirm.
        emit $ BrCond (VIdent (Ident $ show ns) I64) success failed
        emit $ Label success
        val <- exprToValue body
        emit $ Store I64 val Ptr (Ident . show $ stackPtr)
        emit $ Br label
        -- The failed label opens the block in which the next case is tried.
        emit $ Label failed
    emitCases label stackPtr _ (Case CatchAll body) = do
        val <- exprToValue body
        emit $ Store I64 val Ptr (Ident . show $ stackPtr)
        emit $ Br label
-- | Emits a raw marker text; used by 'go' when it meets an 'EAnn' expression.
emitEAnn :: CompilerState ()
emitEAnn = emit (UnsafeRaw "Annotated escaped previous stages")
emitAbs :: Type -> Id -> Exp -> CompilerState () emitAbs :: Type -> Id -> Exp -> CompilerState ()
emitAbs _t tid e = do emitAbs _t tid e = do
@ -170,7 +242,7 @@ compile (Program prg) = do
funcs <- gets functions funcs <- gets functions
let vis = case Map.lookup id funcs of let vis = case Map.lookup id funcs of
Nothing -> Local Nothing -> Local
Just _ -> Global Just _ -> Global
let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args) let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args)
emit $ SetVariable (Ident $ show vs) call emit $ SetVariable (Ident $ show vs) call
x -> do x -> do
@ -271,19 +343,19 @@ type2LlvmType = \case
where where
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s) function2LLVMType x s = (type2LlvmType x, s)
-- | Gives the LLVM type of the value an expression produces.
getType :: Exp -> LLVMType
getType (EInt _) = I64
getType (EAdd t _ _) = type2LlvmType t
getType (EId (_, t)) = type2LlvmType t
getType (EApp t _ _) = type2LlvmType t
getType (EAbs t _ _) = type2LlvmType t
getType (ELet _ e) = getType e
getType (EAnn _ t) = type2LlvmType t
-- 'emitECased' keeps the result of every case in an I64 stack slot and loads
-- it back as I64, so that is the type a case expression produces.
getType (ECased _ _) = I64
{- | Gives the LLVM type of a value: integers are I64, a constant is an I8
array as long as its string, and identifiers and functions carry their own type.
-}
valueGetType :: LLVMValue -> LLVMType
valueGetType = \case
    VInteger _ -> I64
    VIdent _ t -> t
    VConstant s -> Array (length s) I8
    VFunction _ _ t -> t

View file

@ -9,8 +9,8 @@ module LlvmIr (
Visibility (..), Visibility (..),
) where ) where
import Data.List (intercalate) import Data.List (intercalate)
import TypeCheckerIr import TypeCheckerIr
-- | A datatype which represents some basic LLVM types -- | A datatype which represents some basic LLVM types
data LLVMType data LLVMType
@ -65,7 +65,7 @@ instance Show LLVMComp where
data Visibility = Local | Global data Visibility = Local | Global
-- | Shows 'Local' as a percent sign and 'Global' as an at sign.
instance Show Visibility where
    show :: Visibility -> String
    show vis = case vis of
        Local -> "%"
        Global -> "@"
{- | Represents a LLVM "value", as in an integer, a register variable, {- | Represents a LLVM "value", as in an integer, a register variable,
@ -80,10 +80,10 @@ data LLVMValue
{- | Renders a value as text: an integer as its digits, an identifier with a
percent sign in front, a function with its visibility sigil in front, and a
constant as the letter c followed by the shown string.
-}
instance Show LLVMValue where
    show :: LLVMValue -> String
    show (VInteger i) = show i
    show (VIdent (Ident n) _) = "%" <> n
    show (VFunction (Ident n) vis _) = show vis <> n
    show (VConstant s) = "c" <> show s
-- | A list of names paired with their LLVM types (presumably the parameter list of a definition — TODO confirm).
type Params = [(Ident, LLVMType)]

-- | The arguments of a 'Call': each value paired with its LLVM type.
type Args = [(LLVMType, LLVMValue)]
@ -106,7 +106,8 @@ data LLVMIr
| Label Ident | Label Ident
| Call LLVMType Visibility Ident Args | Call LLVMType Visibility Ident Args
| Alloca LLVMType | Alloca LLVMType
| Store LLVMType Ident LLVMType Ident | Store LLVMType LLVMValue LLVMType Ident
| Load LLVMType LLVMType Ident
| Bitcast LLVMType Ident LLVMType | Bitcast LLVMType Ident LLVMType
| Ret LLVMType LLVMValue | Ret LLVMType LLVMValue
| Comment String | Comment String
@ -122,9 +123,9 @@ llvmIrToString = go 0
go _ [] = mempty go _ [] = mempty
go i (x : xs) = do go i (x : xs) = do
let (i', n) = case x of let (i', n) = case x of
Define{} -> (i + 1, 0) Define{} -> (i + 1, 0)
DefineEnd -> (i - 1, 0) DefineEnd -> (i - 1, 0)
_ -> (i, i) _ -> (i, i)
insToString n x <> go i' xs insToString n x <> go i' xs
{- | Converts an LLVM instruction to a String, allowing for printing etc. {- | Converts an LLVM instruction to a String, allowing for printing etc.
@ -175,11 +176,16 @@ llvmIrToString = go 0
, ")\n" , ")\n"
] ]
(Alloca t) -> unwords ["alloca", show t, "\n"] (Alloca t) -> unwords ["alloca", show t, "\n"]
(Store t1 (Ident id1) t2 (Ident id2)) -> (Store t1 val t2 (Ident id2)) ->
concat concat
[ "store ", show t1, " %", id1 [ "store ", show t1, " ", show val
, ", ", show t2 , " %", id2, "\n" , ", ", show t2 , " %", id2, "\n"
] ]
(Load t1 t2 (Ident addr)) ->
concat
[ "load ", show t1, ", "
, show t2, " %", addr, "\n"
]
(Bitcast t1 (Ident i) t2) -> (Bitcast t1 (Ident i) t2) ->
concat concat
[ "bitcast ", show t1, " %" [ "bitcast ", show t1, " %"
@ -196,13 +202,16 @@ llvmIrToString = go 0
, show v, "\n" , show v, "\n"
] ]
(UnsafeRaw s) -> s (UnsafeRaw s) -> s
(Label (Ident s)) -> "\nlabel_" <> s <> ":\n" (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
(Br (Ident s)) -> "br label %label_" <> s <> "\n" (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
(BrCond val (Ident s1) (Ident s2)) -> (BrCond val (Ident s1) (Ident s2)) ->
concat concat
[ "br i1 ", show val, ", ", "label %" [ "br i1 ", show val, ", ", "label %"
, "label_", s1, ", ", "label %", "label_", s2, "\n" , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
] ]
(Comment s) -> "; " <> s <> "\n" (Comment s) -> "; " <> s <> "\n"
(Variable (Ident id)) -> "%" <> id (Variable (Ident id)) -> "%" <> id
{- FOURMOLU_ENABLE -} {- FOURMOLU_ENABLE -}
-- | Prefix put in front of every label name, both where the label is defined and where it is branched to.
lblPfx :: String
lblPfx = "lbl_"