Implemented case matching on ints in the code generator
This commit is contained in:
parent
7cedc2e28c
commit
287f84377c
2 changed files with 130 additions and 49 deletions
|
|
@ -4,18 +4,19 @@
|
|||
module Compiler (compile) where
|
||||
|
||||
import Control.Monad.State (StateT, execStateT, gets, modify)
|
||||
import Data.List.Extra (trim)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
import Data.Tuple.Extra (second)
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Print (printTree)
|
||||
import LlvmIr (
|
||||
LLVMIr (..),
|
||||
LLVMType (..),
|
||||
LLVMValue (..),
|
||||
Visibility (..),
|
||||
llvmIrToString,
|
||||
)
|
||||
import LlvmIr (LLVMComp (..), LLVMIr (..),
|
||||
LLVMType (..), LLVMValue (..),
|
||||
Visibility (..), llvmIrToString)
|
||||
import System.IO (stdin)
|
||||
import System.Process.Extra (CreateProcess (std_in),
|
||||
StdStream (CreatePipe), createProcess,
|
||||
readCreateProcess, shell)
|
||||
import TypeChecker (partitionType)
|
||||
import TypeCheckerIr
|
||||
|
||||
|
|
@ -24,6 +25,7 @@ data CodeGenerator = CodeGenerator
|
|||
{ instructions :: [LLVMIr]
|
||||
, functions :: Map Id FunctionInfo
|
||||
, variableCount :: Integer
|
||||
, labelCount :: Integer
|
||||
}
|
||||
|
||||
-- | A state type synonym
|
||||
|
|
@ -50,6 +52,12 @@ getVarCount = gets variableCount
|
|||
getNewVar :: CompilerState Integer
|
||||
getNewVar = increaseVarCount >> getVarCount
|
||||
|
||||
-- | Increses the label count and returns a label from the CodeGenerator state
|
||||
getNewLabel :: CompilerState Integer
|
||||
getNewLabel = do
|
||||
modify (\t -> t{labelCount = labelCount t + 1})
|
||||
gets labelCount
|
||||
|
||||
{- | Produces a map of functions infos from a list of binds,
|
||||
which contains useful data for code generation.
|
||||
-}
|
||||
|
|
@ -67,6 +75,36 @@ getFunctions xs =
|
|||
)
|
||||
xs
|
||||
|
||||
run :: Err String -> IO ()
|
||||
run s = do
|
||||
let s' = case s of
|
||||
Right s -> s
|
||||
Left _ -> error "yo"
|
||||
writeFile "llvm.ll" s'
|
||||
putStrLn . trim =<< readCreateProcess (shell "lli") s'
|
||||
test :: Integer -> Program
|
||||
test v = Program [
|
||||
Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (
|
||||
ECased (EId ("x", TInt)) [
|
||||
Case (CInt 0) (EInt 0),
|
||||
Case (CInt 1) (EInt 1),
|
||||
Case CatchAll (EAdd TInt
|
||||
(EApp TInt (EId (Ident "fibonacci", TInt)) (
|
||||
EAdd TInt (EId (Ident "x", TInt))
|
||||
(EInt (fromIntegral ((maxBound :: Int) * 2)))
|
||||
))
|
||||
(EApp TInt (EId (Ident "fibonacci", TInt)) (
|
||||
EAdd TInt (EId (Ident "x", TInt))
|
||||
(EInt (fromIntegral ((maxBound :: Int) * 2 + 1)))
|
||||
))
|
||||
)
|
||||
]
|
||||
),
|
||||
Bind (Ident "main",TInt) [] (
|
||||
EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92)
|
||||
)
|
||||
]
|
||||
|
||||
{- | Compiles an AST and produces a LLVM Ir string.
|
||||
An easy way to actually "compile" this output is to
|
||||
Simply pipe it to LLI
|
||||
|
|
@ -78,6 +116,7 @@ compile (Program prg) = do
|
|||
{ instructions = defaultStart
|
||||
, functions = getFunctions prg
|
||||
, variableCount = 0
|
||||
, labelCount = 0
|
||||
}
|
||||
ins <- instructions <$> execStateT (goDef prg) s
|
||||
pure $ llvmIrToString ins
|
||||
|
|
@ -112,7 +151,7 @@ compile (Program prg) = do
|
|||
goDef (Bind (name, t) args exp : xs) = do
|
||||
emit $ UnsafeRaw "\n"
|
||||
emit $ Comment $ show name <> ": " <> show exp
|
||||
emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args)
|
||||
emit $ Define (I64{-type2LlvmType t_return-}) name (map (second type2LlvmType) args)
|
||||
functionBody <- exprToValue exp
|
||||
if name == "main"
|
||||
then mapM_ emit (mainContent functionBody)
|
||||
|
|
@ -131,14 +170,47 @@ compile (Program prg) = do
|
|||
go (EAbs t ti e) = emitAbs t ti e
|
||||
go (ELet binds e) = emitLet binds e
|
||||
go (EAnn _ _) = emitEAnn
|
||||
go (ECased e c) = emitECased e c
|
||||
-- go (ESub e1 e2) = emitSub e1 e2
|
||||
-- go (EMul e1 e2) = emitMul e1 e2
|
||||
-- go (EDiv e1 e2) = emitDiv e1 e2
|
||||
-- go (EMod e1 e2) = emitMod e1 e2
|
||||
|
||||
--- aux functions ---
|
||||
emitECased :: Exp -> [Case] -> CompilerState ()
|
||||
emitECased e cs = do
|
||||
vs <- exprToValue e
|
||||
lbl <- getNewLabel
|
||||
let label = Ident $ "escape_" <> show lbl
|
||||
stackPtr <- getNewVar
|
||||
emit $ SetVariable (Ident $ show stackPtr) (Alloca I64)
|
||||
mapM_ (emitCases label stackPtr vs) cs
|
||||
emit $ Label label
|
||||
res <- getNewVar
|
||||
emit $ SetVariable (Ident $ show res) (Load I64 Ptr (Ident $ show stackPtr))
|
||||
where
|
||||
emitCases :: Ident -> Integer -> LLVMValue -> Case -> CompilerState ()
|
||||
emitCases label stackPtr vs (Case (CInt i) exp) = do
|
||||
ns <- getNewVar
|
||||
lbl_fail <- getNewLabel
|
||||
lbl_succ <- getNewLabel
|
||||
let failed = Ident $ "failed_" <> show lbl_fail
|
||||
let success = Ident $ "success_" <> show lbl_succ
|
||||
emit $ SetVariable (Ident $ show ns) (Icmp LLEq I64 vs (VInteger i))
|
||||
emit $ BrCond (VIdent (Ident $ show ns) I64) success failed
|
||||
emit $ Label success
|
||||
val <- exprToValue exp
|
||||
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
|
||||
emit $ Br label
|
||||
emit $ Label failed
|
||||
emitCases label stackPtr _ (Case CatchAll exp) = do
|
||||
val <- exprToValue exp
|
||||
emit $ Store I64 val Ptr (Ident . show $ stackPtr)
|
||||
emit $ Br label
|
||||
|
||||
|
||||
emitEAnn :: CompilerState ()
|
||||
emitEAnn = emit . UnsafeRaw $ "why?"
|
||||
emitEAnn = emit . UnsafeRaw $ "Annotated escaped previous stages"
|
||||
|
||||
emitAbs :: Type -> Id -> Exp -> CompilerState ()
|
||||
emitAbs _t tid e = do
|
||||
|
|
|
|||
|
|
@ -106,7 +106,8 @@ data LLVMIr
|
|||
| Label Ident
|
||||
| Call LLVMType Visibility Ident Args
|
||||
| Alloca LLVMType
|
||||
| Store LLVMType Ident LLVMType Ident
|
||||
| Store LLVMType LLVMValue LLVMType Ident
|
||||
| Load LLVMType LLVMType Ident
|
||||
| Bitcast LLVMType Ident LLVMType
|
||||
| Ret LLVMType LLVMValue
|
||||
| Comment String
|
||||
|
|
@ -175,11 +176,16 @@ llvmIrToString = go 0
|
|||
, ")\n"
|
||||
]
|
||||
(Alloca t) -> unwords ["alloca", show t, "\n"]
|
||||
(Store t1 (Ident id1) t2 (Ident id2)) ->
|
||||
(Store t1 val t2 (Ident id2)) ->
|
||||
concat
|
||||
[ "store ", show t1, " %", id1
|
||||
[ "store ", show t1, " ", show val
|
||||
, ", ", show t2 , " %", id2, "\n"
|
||||
]
|
||||
(Load t1 t2 (Ident addr)) ->
|
||||
concat
|
||||
[ "load ", show t1, ", "
|
||||
, show t2, " %", addr, "\n"
|
||||
]
|
||||
(Bitcast t1 (Ident i) t2) ->
|
||||
concat
|
||||
[ "bitcast ", show t1, " %"
|
||||
|
|
@ -196,13 +202,16 @@ llvmIrToString = go 0
|
|||
, show v, "\n"
|
||||
]
|
||||
(UnsafeRaw s) -> s
|
||||
(Label (Ident s)) -> "\nlabel_" <> s <> ":\n"
|
||||
(Br (Ident s)) -> "br label %label_" <> s <> "\n"
|
||||
(Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n"
|
||||
(Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n"
|
||||
(BrCond val (Ident s1) (Ident s2)) ->
|
||||
concat
|
||||
[ "br i1 ", show val, ", ", "label %"
|
||||
, "label_", s1, ", ", "label %", "label_", s2, "\n"
|
||||
, lblPfx, s1, ", ", "label %", lblPfx, s2, "\n"
|
||||
]
|
||||
(Comment s) -> "; " <> s <> "\n"
|
||||
(Variable (Ident id)) -> "%" <> id
|
||||
{- FOURMOLU_ENABLE -}
|
||||
|
||||
lblPfx :: String
|
||||
lblPfx = "lbl_"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue