diff --git a/src/Compiler.hs b/src/Compiler.hs index 0820523..92f4a23 100644 --- a/src/Compiler.hs +++ b/src/Compiler.hs @@ -1,36 +1,38 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} module Compiler (compile) where -import Control.Monad.State (StateT, execStateT, gets, modify) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Tuple.Extra (second) -import Grammar.ErrM (Err) -import Grammar.Print (printTree) -import LlvmIr ( - LLVMIr (..), - LLVMType (..), - LLVMValue (..), - Visibility (..), - llvmIrToString, - ) -import TypeChecker (partitionType) -import TypeCheckerIr +import Control.Monad.State (StateT, execStateT, gets, modify) +import Data.List.Extra (trim) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Tuple.Extra (second) +import Grammar.ErrM (Err) +import Grammar.Print (printTree) +import LlvmIr (LLVMComp (..), LLVMIr (..), + LLVMType (..), LLVMValue (..), + Visibility (..), llvmIrToString) +import System.IO (stdin) +import System.Process.Extra (CreateProcess (std_in), + StdStream (CreatePipe), createProcess, + readCreateProcess, shell) +import TypeChecker (partitionType) +import TypeCheckerIr -- | The record used as the code generator state data CodeGenerator = CodeGenerator - { instructions :: [LLVMIr] - , functions :: Map Id FunctionInfo + { instructions :: [LLVMIr] + , functions :: Map Id FunctionInfo , variableCount :: Integer + , labelCount :: Integer {- number of labels handed out so far; starts at 0 in compile and is bumped by getNewLabel -} } -- | A state type synonym type CompilerState a = StateT CodeGenerator Err a data FunctionInfo = FunctionInfo - { numArgs :: Int + { numArgs :: Int , arguments :: [Id] } @@ -50,6 +52,12 @@ getVarCount = gets variableCount getNewVar :: CompilerState Integer getNewVar = increaseVarCount >> getVarCount +-- | Increases the label count in the CodeGenerator state and returns the new count, to be used as a fresh label number +getNewLabel :: CompilerState Integer +getNewLabel = do + modify (\t -> t{labelCount = labelCount t + 1}) + gets labelCount + {- | 
Produces a map of functions infos from a list of binds, which contains useful data for code generation. -} @@ -67,6 +75,36 @@ getFunctions xs = ) xs +run :: Err String -> IO () {- Debug helper: writes the generated IR to llvm.ll, feeds it to lli on stdin and prints lli's trimmed output -} +run s = do + let s' = case s of + Right ir -> ir + Left err -> error ("Compiler.run: compilation failed: " <> show err) + writeFile "llvm.ll" s' + putStrLn . trim =<< readCreateProcess (shell "lli") s' +test :: Integer -> Program {- Sample program: a recursive fibonacci that main applies to v -} +test v = Program [ + Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] ( + ECased (EId (Ident "x", TInt)) [ + Case (CInt 0) (EInt 0), + Case (CInt 1) (EInt 1), + Case CatchAll (EAdd TInt + (EApp TInt (EId (Ident "fibonacci", TInt)) ( + EAdd TInt (EId (Ident "x", TInt)) + (EInt (-2)) {- x - 2, written as a literal instead of relying on Int overflow -} + )) + (EApp TInt (EId (Ident "fibonacci", TInt)) ( + EAdd TInt (EId (Ident "x", TInt)) + (EInt (-1)) {- x - 1, written as a literal instead of relying on Int overflow -} + )) + ) + ] + ), + Bind (Ident "main",TInt) [] ( + EApp TInt (EId (Ident "fibonacci", TInt)) (EInt v) -- (EInt 92) + ) + ] + {- | Compiles an AST and produces a LLVM Ir string. An easy way to actually "compile" this output is to Simply pipe it to LLI @@ -78,6 +116,7 @@ compile (Program prg) = do { instructions = defaultStart , functions = getFunctions prg , variableCount = 0 + , labelCount = 0 } ins <- instructions <$> execStateT (goDef prg) s pure $ llvmIrToString ins @@ -112,7 +151,7 @@ compile (Program prg) = do goDef (Bind (name, t) args exp : xs) = do emit $ UnsafeRaw "\n" emit $ Comment $ show name <> ": " <> show exp - emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args) + emit $ Define (I64{-type2LlvmType t_return-}) name (map (second type2LlvmType) args) functionBody <- exprToValue exp if name == "main" then mapM_ emit (mainContent functionBody) @@ -124,21 +163,54 @@ compile (Program prg) = do t_return = snd $ partitionType (length args) t go :: Exp -> CompilerState () - go (EInt int) = emitInt int - go (EAdd t e1 e2) = emitAdd t e1 e2 + go (EInt int) = emitInt int + go (EAdd t e1 e2) = emitAdd t e1 e2 go (EId (name, _)) = emitIdent 
name - go (EApp t e1 e2) = emitApp t e1 e2 - go (EAbs t ti e) = emitAbs t ti e - go (ELet binds e) = emitLet binds e - go (EAnn _ _) = emitEAnn + go (EApp t e1 e2) = emitApp t e1 e2 + go (EAbs t ti e) = emitAbs t ti e + go (ELet binds e) = emitLet binds e + go (EAnn _ _) = emitEAnn + go (ECased e c) = emitECased e c -- go (ESub e1 e2) = emitSub e1 e2 -- go (EMul e1 e2) = emitMul e1 e2 -- go (EDiv e1 e2) = emitDiv e1 e2 -- go (EMod e1 e2) = emitMod e1 e2 --- aux functions --- + emitECased :: Exp -> [Case] -> CompilerState () + emitECased e cs = do {- Lowers a case expression: every arm stores its value into one I64 stack slot and branches to the escape label, where the slot is loaded -} + vs <- exprToValue e + lbl <- getNewLabel + let label = Ident $ "escape_" <> show lbl + stackPtr <- getNewVar + emit $ SetVariable (Ident $ show stackPtr) (Alloca I64) + mapM_ (emitCases label stackPtr vs) cs >> case reverse cs of { Case CatchAll _ : _ -> pure (); _ -> emit (Br label) } {- a trailing failed_N block (or an empty case list) must still end in a terminator, so branch to the escape label unless the last arm is a CatchAll, which already branches there -} + emit $ Label label + res <- getNewVar + emit $ SetVariable (Ident $ show res) (Load I64 Ptr (Ident $ show stackPtr)) + where + emitCases :: Ident -> Integer -> LLVMValue -> Case -> CompilerState () + emitCases label stackPtr vs (Case (CInt i) exp) = do {- integer arm: compare the scrutinee with i, run the arm on success, continue at the failed label otherwise -} + ns <- getNewVar + lbl_fail <- getNewLabel + lbl_succ <- getNewLabel + let failed = Ident $ "failed_" <> show lbl_fail + let success = Ident $ "success_" <> show lbl_succ + emit $ SetVariable (Ident $ show ns) (Icmp LLEq I64 vs (VInteger i)) + emit $ BrCond (VIdent (Ident $ show ns) I64) success failed + emit $ Label success + val <- exprToValue exp + emit $ Store I64 val Ptr (Ident . show $ stackPtr) + emit $ Br label + emit $ Label failed + emitCases label stackPtr _ (Case CatchAll exp) = do {- catch-all arm: unconditionally evaluate, store and branch to the escape label -} + val <- exprToValue exp + emit $ Store I64 val Ptr (Ident . show $ stackPtr) + emit $ Br label + + emitEAnn :: CompilerState () - emitEAnn = emit . UnsafeRaw $ "why?" + emitEAnn = emit . 
UnsafeRaw $ "Annotated escaped previous stages" emitAbs :: Type -> Id -> Exp -> CompilerState () emitAbs _t tid e = do @@ -170,7 +242,7 @@ compile (Program prg) = do funcs <- gets functions let vis = case Map.lookup id funcs of Nothing -> Local - Just _ -> Global + Just _ -> Global let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args) emit $ SetVariable (Ident $ show vs) call x -> do @@ -271,19 +343,20 @@ type2LlvmType = \case where function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) - function2LLVMType x s = (type2LlvmType x, s) + function2LLVMType x s = (type2LlvmType x, s) getType :: Exp -> LLVMType -getType (EInt _) = I64 +getType (EInt _) = I64 getType (EAdd t _ _) = type2LlvmType t getType (EId (_, t)) = type2LlvmType t getType (EApp t _ _) = type2LlvmType t getType (EAbs t _ _) = type2LlvmType t -getType (ELet _ e) = getType e -getType (EAnn _ t) = type2LlvmType t +getType (ELet _ e) = getType e +getType (EAnn _ t) = type2LlvmType t +getType (ECased _ _) = I64 {- emitECased always stores and loads the case result as I64; without this clause getType is partial for the newly compiled ECased -} valueGetType :: LLVMValue -> LLVMType -valueGetType (VInteger _) = I64 -valueGetType (VIdent _ t) = t -valueGetType (VConstant s) = Array (length s) I8 +valueGetType (VInteger _) = I64 +valueGetType (VIdent _ t) = t +valueGetType (VConstant s) = Array (length s) I8 valueGetType (VFunction _ _ t) = t diff --git a/src/LlvmIr.hs b/src/LlvmIr.hs index b29f296..281fc34 100644 --- a/src/LlvmIr.hs +++ b/src/LlvmIr.hs @@ -9,8 +9,8 @@ module LlvmIr ( Visibility (..), ) where -import Data.List (intercalate) -import TypeCheckerIr +import Data.List (intercalate) +import TypeCheckerIr -- | A datatype which represents some basic LLVM types data LLVMType @@ -65,7 +65,7 @@ instance Show LLVMComp where data Visibility = Local | Global instance Show Visibility where show :: Visibility -> String - show Local = "%" + show Local = "%" show Global = "@" {- | Represents a LLVM "value", as in an integer, a register variable, @@ -80,10 
+80,10 @@ data LLVMValue instance Show LLVMValue where show :: LLVMValue -> String show v = case v of - VInteger i -> show i - VIdent (Ident n) _ -> "%" <> n + VInteger i -> show i + VIdent (Ident n) _ -> "%" <> n VFunction (Ident n) vis _ -> show vis <> n - VConstant s -> "c" <> show s + VConstant s -> "c" <> show s type Params = [(Ident, LLVMType)] type Args = [(LLVMType, LLVMValue)] @@ -106,7 +106,8 @@ data LLVMIr | Label Ident | Call LLVMType Visibility Ident Args | Alloca LLVMType - | Store LLVMType Ident LLVMType Ident + | Store LLVMType LLVMValue LLVMType Ident {- value type, value to store, pointer type, destination address identifier (rendered with a % prefix) -} + | Load LLVMType LLVMType Ident {- loaded value type, pointer type, source address identifier (rendered with a % prefix) -} | Bitcast LLVMType Ident LLVMType | Ret LLVMType LLVMValue | Comment String @@ -122,9 +123,9 @@ llvmIrToString = go 0 go _ [] = mempty go i (x : xs) = do let (i', n) = case x of - Define{} -> (i + 1, 0) + Define{} -> (i + 1, 0) DefineEnd -> (i - 1, 0) - _ -> (i, i) + _ -> (i, i) insToString n x <> go i' xs {- | Converts a LLVM inststruction to a String, allowing for printing etc. @@ -175,11 +176,16 @@ llvmIrToString = go 0 , ")\n" ] (Alloca t) -> unwords ["alloca", show t, "\n"] - (Store t1 (Ident id1) t2 (Ident id2)) -> + (Store t1 val t2 (Ident id2)) -> {- the stored operand is rendered through Show LLVMValue, so it is no longer forced to be a %-register -} concat - [ "store ", show t1, " %", id1 + [ "store ", show t1, " ", show val , ", ", show t2 , " %", id2, "\n" ] + (Load t1 t2 (Ident addr)) -> {- renders as: load <t1>, <t2> %<addr> -} + concat + [ "load ", show t1, ", " + , show t2, " %", addr, "\n" + ] (Bitcast t1 (Ident i) t2) -> concat [ "bitcast ", show t1, " %" @@ -196,13 +202,16 @@ llvmIrToString = go 0 , show v, "\n" ] (UnsafeRaw s) -> s - (Label (Ident s)) -> "\nlabel_" <> s <> ":\n" - (Br (Ident s)) -> "br label %label_" <> s <> "\n" + (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n" + (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n" (BrCond val (Ident s1) (Ident s2)) -> concat [ "br i1 ", show val, ", ", "label %" - , "label_", s1, ", ", "label %", "label_", s2, "\n" + , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n" ] (Comment s) -> "; " <> s <> "\n" (Variable (Ident id)) -> "%" <> id {- 
FOURMOLU_ENABLE -} + +lblPfx :: String +lblPfx = "lbl_" {- prefix prepended to every label name when rendering Label, Br and BrCond -}