Implemented case matching on ints in the code generator

This commit is contained in:
Samuel Hammersberg 2023-02-18 14:36:46 +01:00
parent 7cedc2e28c
commit 287f84377c
2 changed files with 130 additions and 49 deletions

View file

@ -1,36 +1,38 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Compiler (compile) where
import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Tuple.Extra (second)
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
import LlvmIr (
LLVMIr (..),
LLVMType (..),
LLVMValue (..),
Visibility (..),
llvmIrToString,
)
import TypeChecker (partitionType)
import TypeCheckerIr
import Control.Monad.State (StateT, execStateT, gets, modify)
import Data.List.Extra (trim)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Tuple.Extra (second)
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
import LlvmIr (LLVMComp (..), LLVMIr (..),
LLVMType (..), LLVMValue (..),
Visibility (..), llvmIrToString)
import System.IO (stdin)
import System.Process.Extra (CreateProcess (std_in),
StdStream (CreatePipe), createProcess,
readCreateProcess, shell)
import TypeChecker (partitionType)
import TypeCheckerIr
-- | The record used as the code generator state
data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr]
, functions :: Map Id FunctionInfo
{ instructions :: [LLVMIr]
, functions :: Map Id FunctionInfo
, variableCount :: Integer
, labelCount :: Integer
}
-- | A state type synonym
type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo
{ numArgs :: Int
{ numArgs :: Int
, arguments :: [Id]
}
@ -50,6 +52,12 @@ getVarCount = gets variableCount
getNewVar :: CompilerState Integer
getNewVar = increaseVarCount >> getVarCount
-- | Bumps the label counter in the CodeGenerator state and yields the new value
getNewLabel :: CompilerState Integer
getNewLabel = modify bump >> gets labelCount
  where
    -- Record update that advances the label counter by one.
    bump st = st{labelCount = labelCount st + 1}
{- | Produces a map of function infos from a list of binds,
which contains useful data for code generation.
-}
@ -67,6 +75,36 @@ getFunctions xs =
)
xs
{- | Takes the result of code generation, writes the LLVM IR to @llvm.ll@
and feeds the same IR to @lli@ on standard input, printing the trimmed
output of the interpreter.

If code generation failed, an 'IOError' carrying the original error
message is raised before any file is written or process is started.
-}
run :: Err String -> IO ()
run result = case result of
    -- Surface the real failure reason instead of a placeholder message.
    Left err -> ioError . userError $ "Compiler.run: code generation failed: " <> err
    Right ir -> do
        writeFile "llvm.ll" ir
        putStrLn . trim =<< readCreateProcess (shell "lli") ir
{- | A hard-coded sample program for manually exercising the code generator.

It defines a recursive @fibonacci@ by case matching on its argument
(@0@, @1@ and a catch-all) and a @main@ that applies @fibonacci@ to the
given value (e.g. @92@).
-}
test :: Integer -> Program
test v =
    Program
        [ Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] $
            ECased
                argX
                [ Case (CInt 0) (EInt 0)
                , Case (CInt 1) (EInt 1)
                , Case CatchAll $
                    EAdd TInt (fibOf (-2)) (fibOf (-1))
                ]
        , Bind (Ident "main", TInt) [] $
            EApp TInt fibonacci (EInt v)
        ]
  where
    -- Reference to the function being defined, used for the recursive calls.
    fibonacci = EId (Ident "fibonacci", TInt)
    -- The single parameter of @fibonacci@.
    argX = EId (Ident "x", TInt)
    -- @fibonacci (x + offset)@; the offsets are written as plain negative
    -- literals instead of relying on 'Int' overflow wrap-around to obtain them.
    fibOf offset = EApp TInt fibonacci (EAdd TInt argX (EInt offset))
{- | Compiles an AST and produces an LLVM IR string.
An easy way to actually "compile" this output is to
simply pipe it to lli
@ -78,6 +116,7 @@ compile (Program prg) = do
{ instructions = defaultStart
, functions = getFunctions prg
, variableCount = 0
, labelCount = 0
}
ins <- instructions <$> execStateT (goDef prg) s
pure $ llvmIrToString ins
@ -112,7 +151,7 @@ compile (Program prg) = do
goDef (Bind (name, t) args exp : xs) = do
emit $ UnsafeRaw "\n"
emit $ Comment $ show name <> ": " <> show exp
emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args)
emit $ Define (I64{-type2LlvmType t_return-}) name (map (second type2LlvmType) args)
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit (mainContent functionBody)
@ -124,21 +163,54 @@ compile (Program prg) = do
t_return = snd $ partitionType (length args) t
go :: Exp -> CompilerState ()
go (EInt int) = emitInt int
go (EAdd t e1 e2) = emitAdd t e1 e2
go (EInt int) = emitInt int
go (EAdd t e1 e2) = emitAdd t e1 e2
go (EId (name, _)) = emitIdent name
go (EApp t e1 e2) = emitApp t e1 e2
go (EAbs t ti e) = emitAbs t ti e
go (ELet binds e) = emitLet binds e
go (EAnn _ _) = emitEAnn
go (EApp t e1 e2) = emitApp t e1 e2
go (EAbs t ti e) = emitAbs t ti e
go (ELet binds e) = emitLet binds e
go (EAnn _ _) = emitEAnn
go (ECased e c) = emitECased e c
-- go (ESub e1 e2) = emitSub e1 e2
-- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2
--- aux functions ---
{- | Emits the instructions for a case expression.
The scrutinee is evaluated once, a stack slot is allocated for the result,
every case is emitted in order, and after the shared escape label the
result is loaded from the stack slot into a fresh variable.
-}
emitECased :: Exp -> [Case] -> CompilerState ()
emitECased e cs = do
vs <- exprToValue e
lbl <- getNewLabel
-- Every case branch jumps to this label once it has stored its result.
let label = Ident $ "escape_" <> show lbl
-- Stack slot that the matching branch writes its value into.
stackPtr <- getNewVar
emit $ SetVariable (Ident $ show stackPtr) (Alloca I64)
mapM_ (emitCases label stackPtr vs) cs
-- NOTE(review): if cs contains no CatchAll, the last emitted "failed" label
-- is directly followed by this escape label with nothing in between;
-- confirm that the resulting IR is accepted.
emit $ Label label
res <- getNewVar
emit $ SetVariable (Ident $ show res) (Load I64 Ptr (Ident $ show stackPtr))
where
-- Emits a single case; takes the escape label, the number of the stack slot
-- variable, the scrutinee value and the case itself.
emitCases :: Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
-- Integer pattern: compare the scrutinee with the literal and branch.
emitCases label stackPtr vs (Case (CInt i) exp) = do
ns <- getNewVar
lbl_fail <- getNewLabel
lbl_succ <- getNewLabel
let failed = Ident $ "failed_" <> show lbl_fail
let success = Ident $ "success_" <> show lbl_succ
emit $ SetVariable (Ident $ show ns) (Icmp LLEq I64 vs (VInteger i))
emit $ BrCond (VIdent (Ident $ show ns) I64) success failed
-- On a match: evaluate the body, store it in the slot and jump to the escape label.
emit $ Label success
val <- exprToValue exp
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
emit $ Br label
-- On a mismatch: the instructions of the next case follow this label.
emit $ Label failed
-- Catch-all pattern: store the body's value unconditionally and jump to the escape label.
emitCases label stackPtr _ (Case CatchAll exp) = do
val <- exprToValue exp
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
emit $ Br label
emitEAnn :: CompilerState ()
emitEAnn = emit . UnsafeRaw $ "why?"
emitEAnn = emit . UnsafeRaw $ "Annotated escaped previous stages"
emitAbs :: Type -> Id -> Exp -> CompilerState ()
emitAbs _t tid e = do
@ -170,7 +242,7 @@ compile (Program prg) = do
funcs <- gets functions
let vis = case Map.lookup id funcs of
Nothing -> Local
Just _ -> Global
Just _ -> Global
let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args)
emit $ SetVariable (Ident $ show vs) call
x -> do
@ -271,19 +343,19 @@ type2LlvmType = \case
where
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s)
function2LLVMType x s = (type2LlvmType x, s)
getType :: Exp -> LLVMType
getType (EInt _) = I64
getType (EInt _) = I64
getType (EAdd t _ _) = type2LlvmType t
getType (EId (_, t)) = type2LlvmType t
getType (EApp t _ _) = type2LlvmType t
getType (EAbs t _ _) = type2LlvmType t
getType (ELet _ e) = getType e
getType (EAnn _ t) = type2LlvmType t
getType (ELet _ e) = getType e
getType (EAnn _ t) = type2LlvmType t
valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (length s) I8
valueGetType (VInteger _) = I64
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (length s) I8
valueGetType (VFunction _ _ t) = t

View file

@ -9,8 +9,8 @@ module LlvmIr (
Visibility (..),
) where
import Data.List (intercalate)
import TypeCheckerIr
import Data.List (intercalate)
import TypeCheckerIr
-- | A datatype which represents some basic LLVM types
data LLVMType
@ -65,7 +65,7 @@ instance Show LLVMComp where
data Visibility = Local | Global
instance Show Visibility where
show :: Visibility -> String
show Local = "%"
show Local = "%"
show Global = "@"
{- | Represents a LLVM "value", as in an integer, a register variable,
@ -80,10 +80,10 @@ data LLVMValue
instance Show LLVMValue where
show :: LLVMValue -> String
show v = case v of
VInteger i -> show i
VIdent (Ident n) _ -> "%" <> n
VInteger i -> show i
VIdent (Ident n) _ -> "%" <> n
VFunction (Ident n) vis _ -> show vis <> n
VConstant s -> "c" <> show s
VConstant s -> "c" <> show s
type Params = [(Ident, LLVMType)]
type Args = [(LLVMType, LLVMValue)]
@ -106,7 +106,8 @@ data LLVMIr
| Label Ident
| Call LLVMType Visibility Ident Args
| Alloca LLVMType
| Store LLVMType Ident LLVMType Ident
| Store LLVMType LLVMValue LLVMType Ident
| Load LLVMType LLVMType Ident
| Bitcast LLVMType Ident LLVMType
| Ret LLVMType LLVMValue
| Comment String
@ -122,9 +123,9 @@ llvmIrToString = go 0
go _ [] = mempty
go i (x : xs) = do
let (i', n) = case x of
Define{} -> (i + 1, 0)
Define{} -> (i + 1, 0)
DefineEnd -> (i - 1, 0)
_ -> (i, i)
_ -> (i, i)
insToString n x <> go i' xs
{- | Converts an LLVM instruction to a String, allowing for printing etc.
@ -175,11 +176,16 @@ llvmIrToString = go 0
, ")\n"
]
(Alloca t) -> unwords ["alloca", show t, "\n"]
(Store t1 (Ident id1) t2 (Ident id2)) ->
(Store t1 val t2 (Ident id2)) ->
concat
[ "store ", show t1, " %", id1
[ "store ", show t1, " ", show val
, ", ", show t2 , " %", id2, "\n"
]
(Load t1 t2 (Ident addr)) ->
concat
[ "load ", show t1, ", "
, show t2, " %", addr, "\n"
]
(Bitcast t1 (Ident i) t2) ->
concat
[ "bitcast ", show t1, " %"
@ -196,13 +202,16 @@ llvmIrToString = go 0
, show v, "\n"
]
(UnsafeRaw s) -> s
(Label (Ident s)) -> "\nlabel_" <> s <> ":\n"
(Br (Ident s)) -> "br label %label_" <> s <> "\n"
(Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
(Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
(BrCond val (Ident s1) (Ident s2)) ->
concat
[ "br i1 ", show val, ", ", "label %"
, "label_", s1, ", ", "label %", "label_", s2, "\n"
, lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
]
(Comment s) -> "; " <> s <> "\n"
(Variable (Ident id)) -> "%" <> id
{- FOURMOLU_ENABLE -}
-- | Prefix put in front of every label name when labels and branch targets are rendered.
lblPfx :: String
lblPfx = "lbl_"