Implemented case matching on ints in the code generator
This commit is contained in:
parent
7cedc2e28c
commit
287f84377c
2 changed files with 130 additions and 49 deletions
|
|
@ -4,18 +4,19 @@
|
||||||
module Compiler (compile) where
|
module Compiler (compile) where
|
||||||
|
|
||||||
import Control.Monad.State (StateT, execStateT, gets, modify)
|
import Control.Monad.State (StateT, execStateT, gets, modify)
|
||||||
|
import Data.List.Extra (trim)
|
||||||
import Data.Map (Map)
|
import Data.Map (Map)
|
||||||
import qualified Data.Map as Map
|
import qualified Data.Map as Map
|
||||||
import Data.Tuple.Extra (second)
|
import Data.Tuple.Extra (second)
|
||||||
import Grammar.ErrM (Err)
|
import Grammar.ErrM (Err)
|
||||||
import Grammar.Print (printTree)
|
import Grammar.Print (printTree)
|
||||||
import LlvmIr (
|
import LlvmIr (LLVMComp (..), LLVMIr (..),
|
||||||
LLVMIr (..),
|
LLVMType (..), LLVMValue (..),
|
||||||
LLVMType (..),
|
Visibility (..), llvmIrToString)
|
||||||
LLVMValue (..),
|
import System.IO (stdin)
|
||||||
Visibility (..),
|
import System.Process.Extra (CreateProcess (std_in),
|
||||||
llvmIrToString,
|
StdStream (CreatePipe), createProcess,
|
||||||
)
|
readCreateProcess, shell)
|
||||||
import TypeChecker (partitionType)
|
import TypeChecker (partitionType)
|
||||||
import TypeCheckerIr
|
import TypeCheckerIr
|
||||||
|
|
||||||
|
|
@ -24,6 +25,7 @@ data CodeGenerator = CodeGenerator
|
||||||
{ instructions :: [LLVMIr]
|
{ instructions :: [LLVMIr]
|
||||||
, functions :: Map Id FunctionInfo
|
, functions :: Map Id FunctionInfo
|
||||||
, variableCount :: Integer
|
, variableCount :: Integer
|
||||||
|
, labelCount :: Integer
|
||||||
}
|
}
|
||||||
|
|
||||||
-- | A state type synonym
|
-- | A state type synonym
|
||||||
|
|
@ -50,6 +52,12 @@ getVarCount = gets variableCount
|
||||||
getNewVar :: CompilerState Integer
|
getNewVar :: CompilerState Integer
|
||||||
getNewVar = increaseVarCount >> getVarCount
|
getNewVar = increaseVarCount >> getVarCount
|
||||||
|
|
||||||
|
-- | Increses the label count and returns a label from the CodeGenerator state
|
||||||
|
getNewLabel :: CompilerState Integer
|
||||||
|
getNewLabel = do
|
||||||
|
modify (\t -> t{labelCount = labelCount t + 1})
|
||||||
|
gets labelCount
|
||||||
|
|
||||||
{- | Produces a map of functions infos from a list of binds,
|
{- | Produces a map of functions infos from a list of binds,
|
||||||
which contains useful data for code generation.
|
which contains useful data for code generation.
|
||||||
-}
|
-}
|
||||||
|
|
@ -67,6 +75,36 @@ getFunctions xs =
|
||||||
)
|
)
|
||||||
xs
|
xs
|
||||||
|
|
||||||
|
run :: Err String -> IO ()
|
||||||
|
run s = do
|
||||||
|
let s' = case s of
|
||||||
|
Right s -> s
|
||||||
|
Left _ -> error "yo"
|
||||||
|
writeFile "llvm.ll" s'
|
||||||
|
putStrLn . trim =<< readCreateProcess (shell "lli") s'
|
||||||
|
test :: Integer -> Program
|
||||||
|
test v = Program [
|
||||||
|
Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (
|
||||||
|
ECased (EId ("x", TInt)) [
|
||||||
|
Case (CInt 0) (EInt 0),
|
||||||
|
Case (CInt 1) (EInt 1),
|
||||||
|
Case CatchAll (EAdd TInt
|
||||||
|
(EApp TInt (EId (Ident "fibonacci", TInt)) (
|
||||||
|
EAdd TInt (EId (Ident "x", TInt))
|
||||||
|
(EInt (fromIntegral ((maxBound :: Int) * 2)))
|
||||||
|
))
|
||||||
|
(EApp TInt (EId (Ident "fibonacci", TInt)) (
|
||||||
|
EAdd TInt (EId (Ident "x", TInt))
|
||||||
|
(EInt (fromIntegral ((maxBound :: Int) * 2 + 1)))
|
||||||
|
))
|
||||||
|
)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
Bind (Ident "main",TInt) [] (
|
||||||
|
EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
{- | Compiles an AST and produces a LLVM Ir string.
|
{- | Compiles an AST and produces a LLVM Ir string.
|
||||||
An easy way to actually "compile" this output is to
|
An easy way to actually "compile" this output is to
|
||||||
Simply pipe it to LLI
|
Simply pipe it to LLI
|
||||||
|
|
@ -78,6 +116,7 @@ compile (Program prg) = do
|
||||||
{ instructions = defaultStart
|
{ instructions = defaultStart
|
||||||
, functions = getFunctions prg
|
, functions = getFunctions prg
|
||||||
, variableCount = 0
|
, variableCount = 0
|
||||||
|
, labelCount = 0
|
||||||
}
|
}
|
||||||
ins <- instructions <$> execStateT (goDef prg) s
|
ins <- instructions <$> execStateT (goDef prg) s
|
||||||
pure $ llvmIrToString ins
|
pure $ llvmIrToString ins
|
||||||
|
|
@ -112,7 +151,7 @@ compile (Program prg) = do
|
||||||
goDef (Bind (name, t) args exp : xs) = do
|
goDef (Bind (name, t) args exp : xs) = do
|
||||||
emit $ UnsafeRaw "\n"
|
emit $ UnsafeRaw "\n"
|
||||||
emit $ Comment $ show name <> ": " <> show exp
|
emit $ Comment $ show name <> ": " <> show exp
|
||||||
emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args)
|
emit $ Define (I64{-type2LlvmType t_return-}) name (map (second type2LlvmType) args)
|
||||||
functionBody <- exprToValue exp
|
functionBody <- exprToValue exp
|
||||||
if name == "main"
|
if name == "main"
|
||||||
then mapM_ emit (mainContent functionBody)
|
then mapM_ emit (mainContent functionBody)
|
||||||
|
|
@ -131,14 +170,47 @@ compile (Program prg) = do
|
||||||
go (EAbs t ti e) = emitAbs t ti e
|
go (EAbs t ti e) = emitAbs t ti e
|
||||||
go (ELet binds e) = emitLet binds e
|
go (ELet binds e) = emitLet binds e
|
||||||
go (EAnn _ _) = emitEAnn
|
go (EAnn _ _) = emitEAnn
|
||||||
|
go (ECased e c) = emitECased e c
|
||||||
-- go (ESub e1 e2) = emitSub e1 e2
|
-- go (ESub e1 e2) = emitSub e1 e2
|
||||||
-- go (EMul e1 e2) = emitMul e1 e2
|
-- go (EMul e1 e2) = emitMul e1 e2
|
||||||
-- go (EDiv e1 e2) = emitDiv e1 e2
|
-- go (EDiv e1 e2) = emitDiv e1 e2
|
||||||
-- go (EMod e1 e2) = emitMod e1 e2
|
-- go (EMod e1 e2) = emitMod e1 e2
|
||||||
|
|
||||||
--- aux functions ---
|
--- aux functions ---
|
||||||
|
emitECased :: Exp -> [Case] -> CompilerState ()
|
||||||
|
emitECased e cs = do
|
||||||
|
vs <- exprToValue e
|
||||||
|
lbl <- getNewLabel
|
||||||
|
let label = Ident $ "escape_" <> show lbl
|
||||||
|
stackPtr <- getNewVar
|
||||||
|
emit $ SetVariable (Ident $ show stackPtr) (Alloca I64)
|
||||||
|
mapM_ (emitCases label stackPtr vs) cs
|
||||||
|
emit $ Label label
|
||||||
|
res <- getNewVar
|
||||||
|
emit $ SetVariable (Ident $ show res) (Load I64 Ptr (Ident $ show stackPtr))
|
||||||
|
where
|
||||||
|
emitCases :: Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
|
||||||
|
emitCases label stackPtr vs (Case (CInt i) exp) = do
|
||||||
|
ns <- getNewVar
|
||||||
|
lbl_fail <- getNewLabel
|
||||||
|
lbl_succ <- getNewLabel
|
||||||
|
let failed = Ident $ "failed_" <> show lbl_fail
|
||||||
|
let success = Ident $ "success_" <> show lbl_succ
|
||||||
|
emit $ SetVariable (Ident $ show ns) (Icmp LLEq I64 vs (VInteger i))
|
||||||
|
emit $ BrCond (VIdent (Ident $ show ns) I64) success failed
|
||||||
|
emit $ Label success
|
||||||
|
val <- exprToValue exp
|
||||||
|
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
|
||||||
|
emit $ Br label
|
||||||
|
emit $ Label failed
|
||||||
|
emitCases label stackPtr _ (Case CatchAll exp) = do
|
||||||
|
val <- exprToValue exp
|
||||||
|
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
|
||||||
|
emit $ Br label
|
||||||
|
|
||||||
|
|
||||||
emitEAnn :: CompilerState ()
|
emitEAnn :: CompilerState ()
|
||||||
emitEAnn = emit . UnsafeRaw $ "why?"
|
emitEAnn = emit . UnsafeRaw $ "Annotated escaped previous stages"
|
||||||
|
|
||||||
emitAbs :: Type -> Id -> Exp -> CompilerState ()
|
emitAbs :: Type -> Id -> Exp -> CompilerState ()
|
||||||
emitAbs _t tid e = do
|
emitAbs _t tid e = do
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,8 @@ data LLVMIr
|
||||||
| Label Ident
|
| Label Ident
|
||||||
| Call LLVMType Visibility Ident Args
|
| Call LLVMType Visibility Ident Args
|
||||||
| Alloca LLVMType
|
| Alloca LLVMType
|
||||||
| Store LLVMType Ident LLVMType Ident
|
| Store LLVMType LLVMValue LLVMType Ident
|
||||||
|
| Load LLVMType LLVMType Ident
|
||||||
| Bitcast LLVMType Ident LLVMType
|
| Bitcast LLVMType Ident LLVMType
|
||||||
| Ret LLVMType LLVMValue
|
| Ret LLVMType LLVMValue
|
||||||
| Comment String
|
| Comment String
|
||||||
|
|
@ -175,11 +176,16 @@ llvmIrToString = go 0
|
||||||
, ")\n"
|
, ")\n"
|
||||||
]
|
]
|
||||||
(Alloca t) -> unwords ["alloca", show t, "\n"]
|
(Alloca t) -> unwords ["alloca", show t, "\n"]
|
||||||
(Store t1 (Ident id1) t2 (Ident id2)) ->
|
(Store t1 val t2 (Ident id2)) ->
|
||||||
concat
|
concat
|
||||||
[ "store ", show t1, " %", id1
|
[ "store ", show t1, " ", show val
|
||||||
, ", ", show t2 , " %", id2, "\n"
|
, ", ", show t2 , " %", id2, "\n"
|
||||||
]
|
]
|
||||||
|
(Load t1 t2 (Ident addr)) ->
|
||||||
|
concat
|
||||||
|
[ "load ", show t1, ", "
|
||||||
|
, show t2, " %", addr, "\n"
|
||||||
|
]
|
||||||
(Bitcast t1 (Ident i) t2) ->
|
(Bitcast t1 (Ident i) t2) ->
|
||||||
concat
|
concat
|
||||||
[ "bitcast ", show t1, " %"
|
[ "bitcast ", show t1, " %"
|
||||||
|
|
@ -196,13 +202,16 @@ llvmIrToString = go 0
|
||||||
, show v, "\n"
|
, show v, "\n"
|
||||||
]
|
]
|
||||||
(UnsafeRaw s) -> s
|
(UnsafeRaw s) -> s
|
||||||
(Label (Ident s)) -> "\nlabel_" <> s <> ":\n"
|
(Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
|
||||||
(Br (Ident s)) -> "br label %label_" <> s <> "\n"
|
(Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
|
||||||
(BrCond val (Ident s1) (Ident s2)) ->
|
(BrCond val (Ident s1) (Ident s2)) ->
|
||||||
concat
|
concat
|
||||||
[ "br i1 ", show val, ", ", "label %"
|
[ "br i1 ", show val, ", ", "label %"
|
||||||
, "label_", s1, ", ", "label %", "label_", s2, "\n"
|
, lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
|
||||||
]
|
]
|
||||||
(Comment s) -> "; " <> s <> "\n"
|
(Comment s) -> "; " <> s <> "\n"
|
||||||
(Variable (Ident id)) -> "%" <> id
|
(Variable (Ident id)) -> "%" <> id
|
||||||
{- FOURMOLU_ENABLE -}
|
{- FOURMOLU_ENABLE -}
|
||||||
|
|
||||||
|
lblPfx :: String
|
||||||
|
lblPfx = "lbl_"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue