Implemented case matching on ints in the code generator
This commit is contained in:
parent
7cedc2e28c
commit
287f84377c
2 changed files with 130 additions and 49 deletions
142
src/Compiler.hs
142
src/Compiler.hs
|
|
@ -1,36 +1,38 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Compiler (compile) where
|
||||
|
||||
import Control.Monad.State (StateT, execStateT, gets, modify)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
import Data.Tuple.Extra (second)
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Print (printTree)
|
||||
import LlvmIr (
|
||||
LLVMIr (..),
|
||||
LLVMType (..),
|
||||
LLVMValue (..),
|
||||
Visibility (..),
|
||||
llvmIrToString,
|
||||
)
|
||||
import TypeChecker (partitionType)
|
||||
import TypeCheckerIr
|
||||
import Control.Monad.State (StateT, execStateT, gets, modify)
|
||||
import Data.List.Extra (trim)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
import Data.Tuple.Extra (second)
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Print (printTree)
|
||||
import LlvmIr (LLVMComp (..), LLVMIr (..),
|
||||
LLVMType (..), LLVMValue (..),
|
||||
Visibility (..), llvmIrToString)
|
||||
import System.IO (stdin)
|
||||
import System.Process.Extra (CreateProcess (std_in),
|
||||
StdStream (CreatePipe), createProcess,
|
||||
readCreateProcess, shell)
|
||||
import TypeChecker (partitionType)
|
||||
import TypeCheckerIr
|
||||
|
||||
-- | The record used as the code generator state
|
||||
data CodeGenerator = CodeGenerator
|
||||
{ instructions :: [LLVMIr]
|
||||
, functions :: Map Id FunctionInfo
|
||||
{ instructions :: [LLVMIr]
|
||||
, functions :: Map Id FunctionInfo
|
||||
, variableCount :: Integer
|
||||
, labelCount :: Integer
|
||||
}
|
||||
|
||||
-- | A state type synonym
|
||||
type CompilerState a = StateT CodeGenerator Err a
|
||||
|
||||
data FunctionInfo = FunctionInfo
|
||||
{ numArgs :: Int
|
||||
{ numArgs :: Int
|
||||
, arguments :: [Id]
|
||||
}
|
||||
|
||||
|
|
@ -50,6 +52,12 @@ getVarCount = gets variableCount
|
|||
getNewVar :: CompilerState Integer
|
||||
getNewVar = increaseVarCount >> getVarCount
|
||||
|
||||
-- | Increses the label count and returns a label from the CodeGenerator state
|
||||
getNewLabel :: CompilerState Integer
|
||||
getNewLabel = do
|
||||
modify (\t -> t{labelCount = labelCount t + 1})
|
||||
gets labelCount
|
||||
|
||||
{- | Produces a map of functions infos from a list of binds,
|
||||
which contains useful data for code generation.
|
||||
-}
|
||||
|
|
@ -67,6 +75,36 @@ getFunctions xs =
|
|||
)
|
||||
xs
|
||||
|
||||
run :: Err String -> IO ()
|
||||
run s = do
|
||||
let s' = case s of
|
||||
Right s -> s
|
||||
Left _ -> error "yo"
|
||||
writeFile "llvm.ll" s'
|
||||
putStrLn . trim =<< readCreateProcess (shell "lli") s'
|
||||
test :: Integer -> Program
|
||||
test v = Program [
|
||||
Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (
|
||||
ECased (EId ("x", TInt)) [
|
||||
Case (CInt 0) (EInt 0),
|
||||
Case (CInt 1) (EInt 1),
|
||||
Case CatchAll (EAdd TInt
|
||||
(EApp TInt (EId (Ident "fibonacci", TInt)) (
|
||||
EAdd TInt (EId (Ident "x", TInt))
|
||||
(EInt (fromIntegral ((maxBound :: Int) * 2)))
|
||||
))
|
||||
(EApp TInt (EId (Ident "fibonacci", TInt)) (
|
||||
EAdd TInt (EId (Ident "x", TInt))
|
||||
(EInt (fromIntegral ((maxBound :: Int) * 2 + 1)))
|
||||
))
|
||||
)
|
||||
]
|
||||
),
|
||||
Bind (Ident "main",TInt) [] (
|
||||
EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92)
|
||||
)
|
||||
]
|
||||
|
||||
{- | Compiles an AST and produces a LLVM Ir string.
|
||||
An easy way to actually "compile" this output is to
|
||||
Simply pipe it to LLI
|
||||
|
|
@ -78,6 +116,7 @@ compile (Program prg) = do
|
|||
{ instructions = defaultStart
|
||||
, functions = getFunctions prg
|
||||
, variableCount = 0
|
||||
, labelCount = 0
|
||||
}
|
||||
ins <- instructions <$> execStateT (goDef prg) s
|
||||
pure $ llvmIrToString ins
|
||||
|
|
@ -112,7 +151,7 @@ compile (Program prg) = do
|
|||
goDef (Bind (name, t) args exp : xs) = do
|
||||
emit $ UnsafeRaw "\n"
|
||||
emit $ Comment $ show name <> ": " <> show exp
|
||||
emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args)
|
||||
emit $ Define (I64{-type2LlvmType t_return-}) name (map (second type2LlvmType) args)
|
||||
functionBody <- exprToValue exp
|
||||
if name == "main"
|
||||
then mapM_ emit (mainContent functionBody)
|
||||
|
|
@ -124,21 +163,54 @@ compile (Program prg) = do
|
|||
t_return = snd $ partitionType (length args) t
|
||||
|
||||
go :: Exp -> CompilerState ()
|
||||
go (EInt int) = emitInt int
|
||||
go (EAdd t e1 e2) = emitAdd t e1 e2
|
||||
go (EInt int) = emitInt int
|
||||
go (EAdd t e1 e2) = emitAdd t e1 e2
|
||||
go (EId (name, _)) = emitIdent name
|
||||
go (EApp t e1 e2) = emitApp t e1 e2
|
||||
go (EAbs t ti e) = emitAbs t ti e
|
||||
go (ELet binds e) = emitLet binds e
|
||||
go (EAnn _ _) = emitEAnn
|
||||
go (EApp t e1 e2) = emitApp t e1 e2
|
||||
go (EAbs t ti e) = emitAbs t ti e
|
||||
go (ELet binds e) = emitLet binds e
|
||||
go (EAnn _ _) = emitEAnn
|
||||
go (ECased e c) = emitECased e c
|
||||
-- go (ESub e1 e2) = emitSub e1 e2
|
||||
-- go (EMul e1 e2) = emitMul e1 e2
|
||||
-- go (EDiv e1 e2) = emitDiv e1 e2
|
||||
-- go (EMod e1 e2) = emitMod e1 e2
|
||||
|
||||
--- aux functions ---
|
||||
emitECased :: Exp -> [Case] -> CompilerState ()
|
||||
emitECased e cs = do
|
||||
vs <- exprToValue e
|
||||
lbl <- getNewLabel
|
||||
let label = Ident $ "escape_" <> show lbl
|
||||
stackPtr <- getNewVar
|
||||
emit $ SetVariable (Ident $ show stackPtr) (Alloca I64)
|
||||
mapM_ (emitCases label stackPtr vs) cs
|
||||
emit $ Label label
|
||||
res <- getNewVar
|
||||
emit $ SetVariable (Ident $ show res) (Load I64 Ptr (Ident $ show stackPtr))
|
||||
where
|
||||
emitCases :: Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
|
||||
emitCases label stackPtr vs (Case (CInt i) exp) = do
|
||||
ns <- getNewVar
|
||||
lbl_fail <- getNewLabel
|
||||
lbl_succ <- getNewLabel
|
||||
let failed = Ident $ "failed_" <> show lbl_fail
|
||||
let success = Ident $ "success_" <> show lbl_succ
|
||||
emit $ SetVariable (Ident $ show ns) (Icmp LLEq I64 vs (VInteger i))
|
||||
emit $ BrCond (VIdent (Ident $ show ns) I64) success failed
|
||||
emit $ Label success
|
||||
val <- exprToValue exp
|
||||
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
|
||||
emit $ Br label
|
||||
emit $ Label failed
|
||||
emitCases label stackPtr _ (Case CatchAll exp) = do
|
||||
val <- exprToValue exp
|
||||
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
|
||||
emit $ Br label
|
||||
|
||||
|
||||
emitEAnn :: CompilerState ()
|
||||
emitEAnn = emit . UnsafeRaw $ "why?"
|
||||
emitEAnn = emit . UnsafeRaw $ "Annotated escaped previous stages"
|
||||
|
||||
emitAbs :: Type -> Id -> Exp -> CompilerState ()
|
||||
emitAbs _t tid e = do
|
||||
|
|
@ -170,7 +242,7 @@ compile (Program prg) = do
|
|||
funcs <- gets functions
|
||||
let vis = case Map.lookup id funcs of
|
||||
Nothing -> Local
|
||||
Just _ -> Global
|
||||
Just _ -> Global
|
||||
let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args)
|
||||
emit $ SetVariable (Ident $ show vs) call
|
||||
x -> do
|
||||
|
|
@ -271,19 +343,19 @@ type2LlvmType = \case
|
|||
where
|
||||
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
|
||||
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
|
||||
function2LLVMType x s = (type2LlvmType x, s)
|
||||
function2LLVMType x s = (type2LlvmType x, s)
|
||||
|
||||
getType :: Exp -> LLVMType
|
||||
getType (EInt _) = I64
|
||||
getType (EInt _) = I64
|
||||
getType (EAdd t _ _) = type2LlvmType t
|
||||
getType (EId (_, t)) = type2LlvmType t
|
||||
getType (EApp t _ _) = type2LlvmType t
|
||||
getType (EAbs t _ _) = type2LlvmType t
|
||||
getType (ELet _ e) = getType e
|
||||
getType (EAnn _ t) = type2LlvmType t
|
||||
getType (ELet _ e) = getType e
|
||||
getType (EAnn _ t) = type2LlvmType t
|
||||
|
||||
valueGetType :: LLVMValue -> LLVMType
|
||||
valueGetType (VInteger _) = I64
|
||||
valueGetType (VIdent _ t) = t
|
||||
valueGetType (VConstant s) = Array (length s) I8
|
||||
valueGetType (VInteger _) = I64
|
||||
valueGetType (VIdent _ t) = t
|
||||
valueGetType (VConstant s) = Array (length s) I8
|
||||
valueGetType (VFunction _ _ t) = t
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ module LlvmIr (
|
|||
Visibility (..),
|
||||
) where
|
||||
|
||||
import Data.List (intercalate)
|
||||
import TypeCheckerIr
|
||||
import Data.List (intercalate)
|
||||
import TypeCheckerIr
|
||||
|
||||
-- | A datatype which represents some basic LLVM types
|
||||
data LLVMType
|
||||
|
|
@ -65,7 +65,7 @@ instance Show LLVMComp where
|
|||
data Visibility = Local | Global
|
||||
instance Show Visibility where
|
||||
show :: Visibility -> String
|
||||
show Local = "%"
|
||||
show Local = "%"
|
||||
show Global = "@"
|
||||
|
||||
{- | Represents a LLVM "value", as in an integer, a register variable,
|
||||
|
|
@ -80,10 +80,10 @@ data LLVMValue
|
|||
instance Show LLVMValue where
|
||||
show :: LLVMValue -> String
|
||||
show v = case v of
|
||||
VInteger i -> show i
|
||||
VIdent (Ident n) _ -> "%" <> n
|
||||
VInteger i -> show i
|
||||
VIdent (Ident n) _ -> "%" <> n
|
||||
VFunction (Ident n) vis _ -> show vis <> n
|
||||
VConstant s -> "c" <> show s
|
||||
VConstant s -> "c" <> show s
|
||||
|
||||
type Params = [(Ident, LLVMType)]
|
||||
type Args = [(LLVMType, LLVMValue)]
|
||||
|
|
@ -106,7 +106,8 @@ data LLVMIr
|
|||
| Label Ident
|
||||
| Call LLVMType Visibility Ident Args
|
||||
| Alloca LLVMType
|
||||
| Store LLVMType Ident LLVMType Ident
|
||||
| Store LLVMType LLVMValue LLVMType Ident
|
||||
| Load LLVMType LLVMType Ident
|
||||
| Bitcast LLVMType Ident LLVMType
|
||||
| Ret LLVMType LLVMValue
|
||||
| Comment String
|
||||
|
|
@ -122,9 +123,9 @@ llvmIrToString = go 0
|
|||
go _ [] = mempty
|
||||
go i (x : xs) = do
|
||||
let (i', n) = case x of
|
||||
Define{} -> (i + 1, 0)
|
||||
Define{} -> (i + 1, 0)
|
||||
DefineEnd -> (i - 1, 0)
|
||||
_ -> (i, i)
|
||||
_ -> (i, i)
|
||||
insToString n x <> go i' xs
|
||||
|
||||
{- | Converts a LLVM inststruction to a String, allowing for printing etc.
|
||||
|
|
@ -175,11 +176,16 @@ llvmIrToString = go 0
|
|||
, ")\n"
|
||||
]
|
||||
(Alloca t) -> unwords ["alloca", show t, "\n"]
|
||||
(Store t1 (Ident id1) t2 (Ident id2)) ->
|
||||
(Store t1 val t2 (Ident id2)) ->
|
||||
concat
|
||||
[ "store ", show t1, " %", id1
|
||||
[ "store ", show t1, " ", show val
|
||||
, ", ", show t2 , " %", id2, "\n"
|
||||
]
|
||||
(Load t1 t2 (Ident addr)) ->
|
||||
concat
|
||||
[ "load ", show t1, ", "
|
||||
, show t2, " %", addr, "\n"
|
||||
]
|
||||
(Bitcast t1 (Ident i) t2) ->
|
||||
concat
|
||||
[ "bitcast ", show t1, " %"
|
||||
|
|
@ -196,13 +202,16 @@ llvmIrToString = go 0
|
|||
, show v, "\n"
|
||||
]
|
||||
(UnsafeRaw s) -> s
|
||||
(Label (Ident s)) -> "\nlabel_" <> s <> ":\n"
|
||||
(Br (Ident s)) -> "br label %label_" <> s <> "\n"
|
||||
(Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
|
||||
(Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
|
||||
(BrCond val (Ident s1) (Ident s2)) ->
|
||||
concat
|
||||
[ "br i1 ", show val, ", ", "label %"
|
||||
, "label_", s1, ", ", "label %", "label_", s2, "\n"
|
||||
, lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
|
||||
]
|
||||
(Comment s) -> "; " <> s <> "\n"
|
||||
(Variable (Ident id)) -> "%" <> id
|
||||
{- FOURMOLU_ENABLE -}
|
||||
|
||||
lblPfx :: String
|
||||
lblPfx = "lbl_"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue