diff --git a/Grammar.cf b/Grammar.cf index a55e8c4..7d52004 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -13,7 +13,7 @@ Data. Data ::= "data" Constr "where" "{" [Constructor] "}" ; Constructor. Constructor ::= Ident ":" Type ; separator nonempty Constructor "" ; -TMono. Type1 ::= "_" Ident ; +TMono. Type1 ::= Ident ; TPol. Type1 ::= "'" Ident ; TConstr. Type1 ::= Constr ; TArr. Type ::= Type1 "->" Type ; diff --git a/language.cabal b/language.cabal index e190a7e..f74cb18 100644 --- a/language.cabal +++ b/language.cabal @@ -37,6 +37,8 @@ executable language Renamer.Renamer TypeChecker.TypeChecker TypeChecker.TypeCheckerIr + Monomorphizer.Monomorphizer + Monomorphizer.MonomorphizerIr -- Interpreter Codegen.Codegen Codegen.LlvmIr diff --git a/sample-programs/basic-1 b/sample-programs/basic-1 index 57ce1d9..4177ccf 100644 --- a/sample-programs/basic-1 +++ b/sample-programs/basic-1 @@ -1,26 +1,29 @@ posMul : _Int -> _Int -> _Int; -posMul a b = case b of { +posMul a b = a + b; {-case b of { 0 => 0; _ => a + posMul a (b - 1) -}; - -facc : _Int -> _Int; -facc a = case a of { - 1 => 1; - _ => posMul a (facc (a - 1)) -}; - -minimization : (_Int -> _Int) -> _Int -> _Int; -minimization p x = case p x of { - 1 => x; - _ => minimization p (x + 1) -}; - -checkFac : _Int -> _Int; -checkFac x = case facc x of { - 0 => 1; - _ => 0 -}; +};-} main : _Int; -main = minimization checkFac 1 \ No newline at end of file +main = posMul 5 10; +-- +-- facc : _Int -> _Int; +-- facc a = case a of { +-- 1 => 1; +-- _ => posMul a (facc (a - 1)) +-- }; +-- +-- minimization : (_Int -> _Int) -> _Int -> _Int; +-- minimization p x = case p x of { +-- 1 => x; +-- _ => minimization p (x + 1) +-- }; +-- +-- checkFac : _Int -> _Int; +-- checkFac x = case facc x of { +-- 0 => 1; +-- _ => 0 +-- }; +-- +-- main : _Int; +-- main = minimization checkFac 1 \ No newline at end of file diff --git a/src/Codegen/Codegen.hs b/src/Codegen/Codegen.hs index 9d3b034..b67f0c5 100644 --- a/src/Codegen/Codegen.hs +++ b/src/Codegen/Codegen.hs @@ -1,443 +1,448 @@ -module Codegen.Codegen where --- {-# LANGUAGE LambdaCase #-} --- {-# LANGUAGE OverloadedStrings #-} --- --- module Codegen.Codegen (generateCode) where --- --- import Auxiliary (snoc) --- import Codegen.LlvmIr (CallingConvention (..), --- LLVMComp (..), LLVMIr (..), --- LLVMType (..), LLVMValue (..), --- Visibility (..), llvmIrToString) --- import Control.Monad.State (StateT, execStateT, foldM_, gets, --- modify) --- import qualified Data.Bifunctor as BI --- import Data.List.Extra (trim) --- import Data.Map (Map) --- import qualified Data.Map as Map --- import Data.Tuple.Extra (dupe, first, second) --- import qualified Grammar.Abs as GA --- import Grammar.ErrM (Err) --- import System.Process.Extra (readCreateProcess, shell) --- import TypeChecker.TypeCheckerIr (Bind (..), Case (..), Exp (..), Id, --- Ident (..), Program (..), Type (..)) --- -- | The record used as the code generator state --- data CodeGenerator = CodeGenerator --- { instructions :: [LLVMIr] --- , functions :: Map Id FunctionInfo --- , constructors :: Map Id ConstructorInfo --- , variableCount :: Integer --- , labelCount :: Integer --- } --- --- -- | A state type synonym --- type CompilerState a = StateT CodeGenerator Err a --- --- data FunctionInfo = FunctionInfo --- { numArgs :: Int --- , arguments :: [Id] --- } --- data ConstructorInfo = ConstructorInfo --- { numArgsCI :: Int --- , argumentsCI :: [Id] --- , numCI :: Integer --- } --- --- --- -- | Adds a instruction to the CodeGenerator state --- emit :: LLVMIr -> CompilerState () --- emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t } --- --- -- | Increases the variable counter in the CodeGenerator state --- increaseVarCount :: CompilerState () --- increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } --- --- -- | Returns the variable count from the CodeGenerator state --- getVarCount :: CompilerState Integer --- getVarCount = gets variableCount --- --- -- | Increases the variable count and returns it from the CodeGenerator state --- getNewVar :: CompilerState Integer --- getNewVar = increaseVarCount >> getVarCount --- --- -- | Increses the label count and returns a label from the CodeGenerator state --- getNewLabel :: CompilerState Integer --- getNewLabel = do --- modify (\t -> t{labelCount = labelCount t + 1}) --- gets labelCount --- --- -- | Produces a map of functions infos from a list of binds, --- -- which contains useful data for code generation. --- getFunctions :: [Bind] -> Map Id FunctionInfo --- getFunctions bs = Map.fromList $ go bs --- where --- go [] = [] --- go (Bind id args _ : xs) = --- (id, FunctionInfo { numArgs=length args, arguments=args }) --- : go xs --- go (DataStructure n cons : xs) = do --- map (\(id, xs) -> ((id, TPol n), FunctionInfo { --- numArgs=length xs, arguments=createArgs xs --- })) cons --- <> go xs --- --- createArgs :: [Type] -> [Id] --- createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs --- --- -- | Produces a map of functions infos from a list of binds, --- -- which contains useful data for code generation. --- getConstructors :: [Bind] -> Map Id ConstructorInfo --- getConstructors bs = Map.fromList $ go bs --- where --- go [] = [] --- go (DataStructure (Ident n) cons : xs) = do --- fst (foldl (\(acc,i) (Ident id, xs) -> (((Ident (n <> "_" <> id), TPol (Ident n)), ConstructorInfo { --- numArgsCI=length xs, --- argumentsCI=createArgs xs, --- numCI=i --- }) : acc, i+1)) ([],0) cons) --- <> go xs --- go (_: xs) = go xs --- --- initCodeGenerator :: [Bind] -> CodeGenerator --- initCodeGenerator scs = CodeGenerator { instructions = defaultStart --- , functions = getFunctions scs --- , constructors = getConstructors scs --- , variableCount = 0 --- , labelCount = 0 --- } --- --- run :: Err String -> IO () --- run s = do --- let s' = case s of --- Right s -> s --- Left _ -> error "yo" --- writeFile "output/llvm.ll" s' --- putStrLn . trim =<< readCreateProcess (shell "lli") s' --- --- test :: Integer -> Program --- test v = Program [ --- DataStructure (Ident "Craig") [ --- (Ident "Bob", [TInt])--, --- --(Ident "Alice", [TInt, TInt]) --- ], --- Bind (Ident "fibonacci", TInt) [(Ident "x", TInt)] (EId ("x",TInt)), --- Bind (Ident "main", TInt) [] ( --- EApp (TPol "Craig") (EId (Ident "Craig_Bob", TPol "Craig")) (EInt v) -- (EInt 92) --- ) --- ] --- --- {- | Compiles an AST and produces a LLVM Ir string. --- An easy way to actually "compile" this output is to --- Simply pipe it to LLI --- -} --- generateCode :: Program -> Err String --- generateCode (Program scs) = do --- let codegen = initCodeGenerator scs --- llvmIrToString . instructions <$> execStateT (compileScs scs) codegen --- --- compileScs :: [Bind] -> CompilerState () --- compileScs [] = do --- -- as a last step create all the constructors --- c <- gets (Map.toList . constructors) --- mapM_ (\((id, t), ci) -> do --- let t' = type2LlvmType t --- let x = BI.second type2LlvmType <$> argumentsCI ci --- emit $ Define FastCC t' id x --- top <- Ident . show <$> getNewVar --- ptr <- Ident . show <$> getNewVar --- -- allocated the primary type --- emit $ SetVariable top (Alloca t') --- --- -- set the first byte to the index of the constructor --- emit $ SetVariable ptr $ --- GetElementPtrInbounds t' (Ref t') --- (VIdent top I8) I32 (VInteger 0) I32 (VInteger 0) --- emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr --- --- -- get a pointer of the correct type --- ptr' <- Ident . show <$> getNewVar --- emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id)) --- --- --emit $ UnsafeRaw "\n" --- --- foldM_ (\i (Ident arg_n, arg_t)-> do --- let arg_t' = type2LlvmType arg_t --- emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i ) --- elemPtr <- Ident . show <$> getNewVar --- emit $ SetVariable elemPtr ( --- GetElementPtrInbounds (CustomType id) (Ref (CustomType id)) --- (VIdent ptr' Ptr) I32 --- (VInteger 0) I32 (VInteger i)) --- emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr --- -- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1 --- -- store i32 42, i32* %2 --- pure $ i + 1-- + typeByteSize arg_t' --- ) 1 (argumentsCI ci) --- --- --emit $ UnsafeRaw "\n" --- --- -- load and return the constructed value --- load <- Ident . show <$> getNewVar --- emit $ SetVariable load (Load t' Ptr top) --- emit $ Ret t' (VIdent load t') --- emit DefineEnd --- --- modify $ \s -> s { variableCount = 0 } --- ) c --- compileScs (Bind (name, _t) args exp : xs) = do --- emit $ UnsafeRaw "\n" --- emit . Comment $ show name <> ": " <> show exp --- let args' = map (second type2LlvmType) args --- emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args' --- functionBody <- exprToValue exp --- if name == "main" --- then mapM_ emit $ mainContent functionBody --- else emit $ Ret I64 functionBody --- emit DefineEnd --- modify $ \s -> s { variableCount = 0 } --- compileScs xs --- compileScs (DataStructure id@(Ident outer_id) ts : xs) = do --- let biggest_variant = maximum ((\(_, t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts) --- emit $ Type id [I8, Array biggest_variant I8] --- mapM_ (\(Ident inner_id, fi) -> do --- emit $ Type (Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi) --- ) ts --- compileScs xs --- --- -- where --- -- _t_return = snd $ partitionType (length args) t --- --- mainContent :: LLVMValue -> [LLVMIr] --- mainContent var = --- [ UnsafeRaw $ --- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" --- , -- , SetVariable (Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) --- -- , BrCond (VIdent (Ident "p")) (Ident "b_1") (Ident "b_2") --- -- , Label (Ident "b_1") --- -- , UnsafeRaw --- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" --- -- , Br (Ident "end") --- -- , Label (Ident "b_2") --- -- , UnsafeRaw --- -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" --- -- , Br (Ident "end") --- -- , Label (Ident "end") --- Ret I64 (VInteger 0) --- ] --- --- defaultStart :: [LLVMIr] --- defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n" --- , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" --- , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" --- , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" --- ] --- --- compileExp :: Exp -> CompilerState () --- compileExp (EInt int) = emitInt int --- compileExp (EAdd t e1 e2) = emitAdd t e1 e2 --- compileExp (ESub t e1 e2) = emitSub t e1 e2 --- compileExp (EId (name, _)) = emitIdent name --- compileExp (EApp t e1 e2) = emitApp t e1 e2 --- compileExp (EAbs t ti e) = emitAbs t ti e --- compileExp (ELet binds e) = emitLet binds e --- compileExp (ECase t e cs) = emitECased t e cs --- -- go (EMul e1 e2) = emitMul e1 e2 --- -- go (EDiv e1 e2) = emitDiv e1 e2 --- -- go (EMod e1 e2) = emitMod e1 e2 --- --- --- aux functions --- --- emitECased :: Type -> Exp -> [(Type, Case)] -> CompilerState () --- emitECased t e cases = do --- let cs = snd <$> cases --- let ty = type2LlvmType t --- vs <- exprToValue e --- lbl <- getNewLabel --- let label = Ident $ "escape_" <> show lbl --- stackPtr <- getNewVar --- emit $ SetVariable (Ident $ show stackPtr) (Alloca ty) --- mapM_ (emitCases ty label stackPtr vs) cs --- emit $ Label label --- res <- getNewVar --- emit $ SetVariable (Ident $ show res) (Load ty Ptr (Ident $ show stackPtr)) --- where --- emitCases :: LLVMType -> Ident -> Integer -> LLVMValue -> Case -> CompilerState () --- emitCases ty label stackPtr vs (Case (GA.CInt i) exp) = do --- ns <- getNewVar --- lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel --- lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel --- emit $ SetVariable (Ident $ show ns) (Icmp LLEq ty vs (VInteger i)) --- emit $ BrCond (VIdent (Ident $ show ns) ty) lbl_succPos lbl_failPos --- emit $ Label lbl_succPos --- val <- exprToValue exp --- emit $ Store ty val Ptr (Ident . show $ stackPtr) --- emit $ Br label --- emit $ Label lbl_failPos --- emitCases ty label stackPtr _ (Case GA.CatchAll exp) = do --- val <- exprToValue exp --- emit $ Store ty val Ptr (Ident . show $ stackPtr) --- emit $ Br label --- --- --- emitAbs :: Type -> Id -> Exp -> CompilerState () --- emitAbs _t tid e = do --- emit . Comment $ --- "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e --- emitLet :: Bind -> Exp -> CompilerState () --- emitLet xs e = do --- emit $ --- Comment $ --- concat --- [ "ELet (" --- , show xs --- , " = " --- , show e --- , ") is not implemented!" --- ] --- --- emitApp :: Type -> Exp -> Exp -> CompilerState () --- emitApp t e1 e2 = appEmitter t e1 e2 [] --- where --- appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () --- appEmitter t e1 e2 stack = do --- let newStack = e2 : stack --- case e1 of --- EApp _ e1' e2' -> appEmitter t e1' e2' newStack --- EId id@(name, _) -> do --- args <- traverse exprToValue newStack --- vs <- getNewVar --- funcs <- gets functions --- let visibility = maybe Local (const Global) $ Map.lookup id funcs --- args' = map (first valueGetType . dupe) args --- call = Call FastCC (type2LlvmType t) visibility name args' --- emit $ SetVariable (Ident $ show vs) call --- x -> do --- emit . Comment $ "The unspeakable happened: " --- emit . Comment $ show x --- --- emitIdent :: Ident -> CompilerState () --- emitIdent id = do --- -- !!this should never happen!! --- emit $ Comment "This should not have happened!" --- emit $ Variable id --- emit $ UnsafeRaw "\n" --- --- emitInt :: Integer -> CompilerState () --- emitInt i = do --- -- !!this should never happen!! --- varCount <- getNewVar --- emit $ Comment "This should not have happened!" --- emit $ SetVariable (Ident (show varCount)) (Add I64 (VInteger i) (VInteger 0)) --- --- emitAdd :: Type -> Exp -> Exp -> CompilerState () --- emitAdd t e1 e2 = do --- v1 <- exprToValue e1 --- v2 <- exprToValue e2 --- v <- getNewVar --- emit $ SetVariable (Ident $ show v) (Add (type2LlvmType t) v1 v2) --- --- emitSub :: Type -> Exp -> Exp -> CompilerState () --- emitSub t e1 e2 = do --- v1 <- exprToValue e1 --- v2 <- exprToValue e2 --- v <- getNewVar --- emit $ SetVariable (Ident $ show v) (Sub (type2LlvmType t) v1 v2) --- --- -- emitMul :: Exp -> Exp -> CompilerState () --- -- emitMul e1 e2 = do --- -- (v1,v2) <- binExprToValues e1 e2 --- -- increaseVarCount --- -- v <- gets variableCount --- -- emit $ SetVariable $ Ident $ show v --- -- emit $ Mul I64 v1 v2 --- --- -- emitMod :: Exp -> Exp -> CompilerState () --- -- emitMod e1 e2 = do --- -- -- `let m a b = rem (abs $ b + a) b` --- -- (v1,v2) <- binExprToValues e1 e2 --- -- increaseVarCount --- -- vadd <- gets variableCount --- -- emit $ SetVariable $ Ident $ show vadd --- -- emit $ Add I64 v1 v2 --- -- --- -- increaseVarCount --- -- vabs <- gets variableCount --- -- emit $ SetVariable $ Ident $ show vabs --- -- emit $ Call I64 (Ident "llvm.abs.i64") --- -- [ (I64, VIdent (Ident $ show vadd)) --- -- , (I1, VInteger 1) --- -- ] --- -- increaseVarCount --- -- v <- gets variableCount --- -- emit $ SetVariable $ Ident $ show v --- -- emit $ Srem I64 (VIdent (Ident $ show vabs)) v2 --- --- -- emitDiv :: Exp -> Exp -> CompilerState () --- -- emitDiv e1 e2 = do --- -- (v1,v2) <- binExprToValues e1 e2 --- -- increaseVarCount --- -- v <- gets variableCount --- -- emit $ SetVariable $ Ident $ show v --- -- emit $ Div I64 v1 v2 --- --- exprToValue :: Exp -> CompilerState LLVMValue --- exprToValue = \case --- EInt i -> pure $ VInteger i --- --- EId id@(name, t) -> do --- funcs <- gets functions --- case Map.lookup id funcs of --- Just fi -> do --- if numArgs fi == 0 --- then do --- vc <- getNewVar --- emit $ SetVariable (Ident $ show vc) --- (Call FastCC (type2LlvmType t) Global name []) --- pure $ VIdent (Ident $ show vc) (type2LlvmType t) --- else pure $ VFunction name Global (type2LlvmType t) --- Nothing -> pure $ VIdent name (type2LlvmType t) --- --- e -> do --- compileExp e --- v <- getVarCount --- pure $ VIdent (Ident $ show v) (getType e) --- --- type2LlvmType :: Type -> LLVMType --- type2LlvmType = \case --- TInt -> I64 --- TFun t xs -> do --- let (t', xs') = function2LLVMType xs [type2LlvmType t] --- Function t' xs' --- TPol t -> CustomType t --- where --- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) --- function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) --- function2LLVMType x s = (type2LlvmType x, s) --- --- getType :: Exp -> LLVMType --- getType (EInt _) = I64 --- getType (EAdd t _ _) = type2LlvmType t --- getType (ESub t _ _) = type2LlvmType t --- getType (EId (_, t)) = type2LlvmType t --- getType (EApp t _ _) = type2LlvmType t --- getType (EAbs t _ _) = type2LlvmType t --- getType (ELet _ e) = getType e --- getType (ECase t _ _) = type2LlvmType t --- --- valueGetType :: LLVMValue -> LLVMType --- valueGetType (VInteger _) = I64 --- valueGetType (VIdent _ t) = t --- valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 --- valueGetType (VFunction _ _ t) = t --- --- typeByteSize :: LLVMType -> Integer --- typeByteSize I1 = 1 --- typeByteSize I8 = 1 --- typeByteSize I32 = 4 --- typeByteSize I64 = 8 --- typeByteSize Ptr = 8 --- typeByteSize (Ref _) = 8 --- typeByteSize (Function _ _) = 8 --- typeByteSize (Array n t) = n * typeByteSize t --- typeByteSize (CustomType _) = 8 --- +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + +module Codegen.Codegen (generateCode) where +import Auxiliary (snoc) +import Codegen.LlvmIr (CallingConvention (..), + LLVMComp (..), LLVMIr (..), + LLVMType (..), LLVMValue (..), + Visibility (..), llvmIrToString) +import Codegen.LlvmIr as LIR +import Control.Monad.State (StateT, execStateT, foldM_, + gets, modify) +import qualified Data.Bifunctor as BI +import Data.List.Extra (trim) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Tuple.Extra (dupe, first, second) +import qualified Grammar.Abs as GA +import Grammar.ErrM (Err) +import Monomorphizer.MonomorphizerIr as MIR +import System.Process.Extra (readCreateProcess, shell) +-- | The record used as the code generator state +data CodeGenerator = CodeGenerator + { instructions :: [LLVMIr] + , functions :: Map Id FunctionInfo + , constructors :: Map Id ConstructorInfo + , variableCount :: Integer + , labelCount :: Integer + } + +-- | A state type synonym +type CompilerState a = StateT CodeGenerator Err a + +data FunctionInfo = FunctionInfo + { numArgs :: Int + , arguments :: [Id] + } +data ConstructorInfo = ConstructorInfo + { numArgsCI :: Int + , argumentsCI :: [Id] + , numCI :: Integer + } + + +-- | Adds a instruction to the CodeGenerator state +emit :: LLVMIr -> CompilerState () +emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t } + +-- | Increases the variable counter in the CodeGenerator state +increaseVarCount :: CompilerState () +increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } + +-- | Returns the variable count from the CodeGenerator state +getVarCount :: CompilerState Integer +getVarCount = gets variableCount + +-- | Increases the variable count and returns it from the CodeGenerator state +getNewVar :: CompilerState Integer +getNewVar = increaseVarCount >> getVarCount + +-- | Increses the label count and returns a label from the CodeGenerator state +getNewLabel :: CompilerState Integer +getNewLabel = do + modify (\t -> t{labelCount = labelCount t + 1}) + gets labelCount + +-- | Produces a map of functions infos from a list of binds, +-- which contains useful data for code generation. +getFunctions :: [Bind] -> Map Id FunctionInfo +getFunctions bs = Map.fromList $ go bs + where + go [] = [] + go (Bind id args _ : xs) = + (id, FunctionInfo { numArgs=length args, arguments=args }) + : go xs + go (DataType n cons : xs) = do + map (\(Constructor id xs) -> ((id, MIR.Type n), FunctionInfo { + numArgs=length xs, arguments=createArgs xs + })) cons + <> go xs + +createArgs :: [Type] -> [Id] +createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(GA.Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs + +-- | Produces a map of functions infos from a list of binds, +-- which contains useful data for code generation. +getConstructors :: [Bind] -> Map Id ConstructorInfo +getConstructors bs = Map.fromList $ go bs + where + go [] = [] + go (DataType (GA.Ident n) cons : xs) = do + fst (foldl (\(acc,i) (Constructor (GA.Ident id) xs) -> (((GA.Ident (n <> "_" <> id), MIR.Type (GA.Ident n)), ConstructorInfo { + numArgsCI=length xs, + argumentsCI=createArgs xs, + numCI=i + }) : acc, i+1)) ([],0) cons) + <> go xs + go (_: xs) = go xs + +initCodeGenerator :: [Bind] -> CodeGenerator +initCodeGenerator scs = CodeGenerator { instructions = defaultStart + , functions = getFunctions scs + , constructors = getConstructors scs + , variableCount = 0 + , labelCount = 0 + } + +run :: Err String -> IO () +run s = do + let s' = case s of + Right s -> s + Left _ -> error "yo" + writeFile "output/llvm.ll" s' + putStrLn . trim =<< readCreateProcess (shell "lli") s' + +test :: Integer -> Program +test v = Program [ + DataType (GA.Ident "Craig") [ + Constructor (GA.Ident "Bob") [MIR.Type (GA.Ident "_Int")]--, + --(GA.Ident "Alice", [TInt, TInt]) + ], + Bind (GA.Ident "fibonacci", MIR.Type (GA.Ident "_Int")) [(GA.Ident "x", MIR.Type (GA.Ident "_Int"))] (EId ("x", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")), + Bind (GA.Ident "main", MIR.Type (GA.Ident "_Int")) [] + (EApp (MIR.Type (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))-- (EInt 92) + ] + +{- | Compiles an AST and produces a LLVM Ir string. + An easy way to actually "compile" this output is to + Simply pipe it to LLI +-} +generateCode :: Program -> Err String +generateCode (Program scs) = do + let codegen = initCodeGenerator scs + llvmIrToString . instructions <$> execStateT (compileScs scs) codegen + +compileScs :: [Bind] -> CompilerState () +compileScs [] = do + -- as a last step create all the constructors + c <- gets (Map.toList . constructors) + mapM_ (\((id, t), ci) -> do + let t' = type2LlvmType t + let x = BI.second type2LlvmType <$> argumentsCI ci + emit $ Define FastCC t' id x + top <- GA.Ident . show <$> getNewVar + ptr <- GA.Ident . show <$> getNewVar + -- allocated the primary type + emit $ SetVariable top (Alloca t') + + -- set the first byte to the index of the constructor + emit $ SetVariable ptr $ + GetElementPtrInbounds t' (Ref t') + (VIdent top I8) I32 (VInteger 0) I32 (VInteger 0) + emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr + + -- get a pointer of the correct type + ptr' <- GA.Ident . show <$> getNewVar + emit $ SetVariable ptr' (Bitcast (Ref t') ptr (Ref $ CustomType id)) + + --emit $ UnsafeRaw "\n" + + foldM_ (\i (GA.Ident arg_n, arg_t)-> do + let arg_t' = type2LlvmType arg_t + emit $ Comment (show arg_t' <>" "<> arg_n <> " " <> show i ) + elemPtr <- GA.Ident . show <$> getNewVar + emit $ SetVariable elemPtr ( + GetElementPtrInbounds (CustomType id) (Ref (CustomType id)) + (VIdent ptr' Ptr) I32 + (VInteger 0) I32 (VInteger i)) + emit $ Store arg_t' (VIdent (GA.Ident arg_n) arg_t') Ptr elemPtr + -- %2 = getelementptr inbounds %Foo_AInteger, %Foo_AInteger* %1, i32 0, i32 1 + -- store i32 42, i32* %2 + pure $ i + 1-- + typeByteSize arg_t' + ) 1 (argumentsCI ci) + + --emit $ UnsafeRaw "\n" + + -- load and return the constructed value + load <- GA.Ident . show <$> getNewVar + emit $ SetVariable load (Load t' Ptr top) + emit $ Ret t' (VIdent load t') + emit DefineEnd + + modify $ \s -> s { variableCount = 0 } + ) c +compileScs (Bind (name, _t) args exp : xs) = do + emit $ UnsafeRaw "\n" + emit . Comment $ show name <> ": " <> show exp + let args' = map (second type2LlvmType) args + emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args' + functionBody <- exprToValue (fst exp) + if name == "main" + then mapM_ emit $ mainContent functionBody + else emit $ Ret I64 functionBody + emit DefineEnd + modify $ \s -> s { variableCount = 0 } + compileScs xs +compileScs (DataType id@(GA.Ident outer_id) ts : xs) = do + let biggest_variant = maximum ((\(Constructor _ t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts) + emit $ LIR.Type id [I8, Array biggest_variant I8] + mapM_ (\(Constructor (GA.Ident inner_id) fi) -> do + emit $ LIR.Type (GA.Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi) + ) ts + compileScs xs + + -- where + -- _t_return = snd $ partitionType (length args) t + +mainContent :: LLVMValue -> [LLVMIr] +mainContent var = + [ UnsafeRaw $ + "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> show var <> ")\n" + , -- , SetVariable (GA.Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) + -- , BrCond (VIdent (GA.Ident "p")) (GA.Ident "b_1") (GA.Ident "b_2") + -- , Label (GA.Ident "b_1") + -- , UnsafeRaw + -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 1)\n" + -- , Br (GA.Ident "end") + -- , Label (GA.Ident "b_2") + -- , UnsafeRaw + -- "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef 2)\n" + -- , Br (GA.Ident "end") + -- , Label (GA.Ident "end") + Ret I64 (VInteger 0) + ] + +defaultStart :: [LLVMIr] +defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n" + , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" + , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" + , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" + ] + +compileExp :: Exp -> CompilerState () +compileExp (ELit lit) = emitLit lit +compileExp (EAdd t e1 e2) = emitAdd t (fst e1) (fst e2) +--compileExp (ESub t e1 e2) = emitSub t e1 e2 +compileExp (EId (name, _)) = emitIdent name +compileExp (EApp t e1 e2) = emitApp t (fst e1) (fst e2) +--compileExp (EAbs t ti e) = emitAbs t ti e +compileExp (ELet _ binds e) = undefined emitLet binds (fst e) +compileExp (ECase t e cs) = emitECased t (fst e) (map (t,) cs) + -- go (EMul e1 e2) = emitMul e1 e2 + -- go (EDiv e1 e2) = emitDiv e1 e2 + -- go (EMod e1 e2) = emitMod e1 e2 + +--- aux functions --- +emitECased :: Type -> Exp -> [(Type, Injection)] -> CompilerState () +emitECased t e cases = do + let cs = snd <$> cases + let ty = type2LlvmType t + vs <- exprToValue e + lbl <- getNewLabel + let label = GA.Ident $ "escape_" <> show lbl + stackPtr <- getNewVar + emit $ SetVariable (GA.Ident $ show stackPtr) (Alloca ty) + mapM_ (emitCases ty label stackPtr vs) cs + emit $ Label label + res <- getNewVar + emit $ SetVariable (GA.Ident $ show res) (Load ty Ptr (GA.Ident $ show stackPtr)) + where + emitCases :: LLVMType -> GA.Ident -> Integer -> LLVMValue -> Injection -> CompilerState () + emitCases ty label stackPtr vs (Injection (MIR.CLit i) exp) = do + let i' = case i of + LInt i -> VInteger i + LChar i -> VChar i + ns <- getNewVar + lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel + lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel + emit $ SetVariable (GA.Ident $ show ns) (Icmp LLEq ty vs i') + emit $ BrCond (VIdent (GA.Ident $ show ns) ty) lbl_succPos lbl_failPos + emit $ Label lbl_succPos + val <- exprToValue (fst exp) + emit $ Store ty val Ptr (GA.Ident . show $ stackPtr) + emit $ Br label + emit $ Label lbl_failPos + emitCases ty label stackPtr _ (Injection MIR.CatchAll exp) = do + val <- exprToValue (fst exp) + emit $ Store ty val Ptr (GA.Ident . show $ stackPtr) + emit $ Br label + + +emitAbs :: Type -> Id -> Exp -> CompilerState () +emitAbs _t tid e = do + emit . Comment $ + "Lambda escaped previous stages: \\" <> show tid <> " . " <> show e +emitLet :: Bind -> Exp -> CompilerState () +emitLet xs e = do + emit $ + Comment $ + concat + [ "ELet (" + , show xs + , " = " + , show e + , ") is not implemented!" + ] + +emitApp :: Type -> Exp -> Exp -> CompilerState () +emitApp t e1 e2 = appEmitter t e1 e2 [] + where + appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () + appEmitter t e1 e2 stack = do + let newStack = e2 : stack + case e1 of + EApp _ (e1', _) (e2', _) -> appEmitter t e1' e2' newStack + EId id@(GA.Ident name,_ ) -> do + args <- traverse exprToValue newStack + vs <- getNewVar + funcs <- gets functions + let visibility = maybe Local (const Global) $ Map.lookup id funcs + args' = map (first valueGetType . dupe) args + call = Call FastCC (type2LlvmType t) visibility (GA.Ident name) args' + emit $ SetVariable (GA.Ident $ show vs) call + x -> do + emit . Comment $ "The unspeakable happened: " + emit . Comment $ show x + +emitIdent :: GA.Ident -> CompilerState () +emitIdent id = do + -- !!this should never happen!! + emit $ Comment "This should not have happened!" + emit $ Variable id + emit $ UnsafeRaw "\n" + +emitLit :: Lit -> CompilerState () +emitLit i = do + -- !!this should never happen!! + let (i',t) = case i of + (LInt i'') -> (VInteger i'',I64) + (LChar i'') -> (VChar i'', I8) + varCount <- getNewVar + emit $ Comment "This should not have happened!" + emit $ SetVariable (GA.Ident (show varCount)) (Add t i' (VInteger 0)) + + +emitAdd :: Type -> Exp -> Exp -> CompilerState () +emitAdd t e1 e2 = do + v1 <- exprToValue e1 + v2 <- exprToValue e2 + v <- getNewVar + emit $ SetVariable (GA.Ident $ show v) (Add (type2LlvmType t) v1 v2) + +emitSub :: Type -> Exp -> Exp -> CompilerState () +emitSub t e1 e2 = do + v1 <- exprToValue e1 + v2 <- exprToValue e2 + v <- getNewVar + emit $ SetVariable (GA.Ident $ show v) (Sub (type2LlvmType t) v1 v2) + + -- emitMul :: Exp -> Exp -> CompilerState () + -- emitMul e1 e2 = do + -- (v1,v2) <- binExprToValues e1 e2 + -- increaseVarCount + -- v <- gets variableCount + -- emit $ SetVariable $ GA.Ident $ show v + -- emit $ Mul I64 v1 v2 + + -- emitMod :: Exp -> Exp -> CompilerState () + -- emitMod e1 e2 = do + -- -- `let m a b = rem (abs $ b + a) b` + -- (v1,v2) <- binExprToValues e1 e2 + -- increaseVarCount + -- vadd <- gets variableCount + -- emit $ SetVariable $ GA.Ident $ show vadd + -- emit $ Add I64 v1 v2 + -- + -- increaseVarCount + -- vabs <- gets variableCount + -- emit $ SetVariable $ GA.Ident $ show vabs + -- emit $ Call I64 (GA.Ident "llvm.abs.i64") + -- [ (I64, VIdent (GA.Ident $ show vadd)) + -- , (I1, VInteger 1) + -- ] + -- increaseVarCount + -- v <- gets variableCount + -- emit $ SetVariable $ GA.Ident $ show v + -- emit $ Srem I64 (VIdent (GA.Ident $ show vabs)) v2 + + -- emitDiv :: Exp -> Exp -> CompilerState () + -- emitDiv e1 e2 = do + -- (v1,v2) <- binExprToValues e1 e2 + -- increaseVarCount + -- v <- gets variableCount + -- emit $ SetVariable $ GA.Ident $ show v + -- emit $ Div I64 v1 v2 + +exprToValue :: Exp -> CompilerState LLVMValue +exprToValue = \case + ELit i -> pure $ case i of + (LInt i) -> VInteger i + (LChar i) -> VChar i + EId id@(name, t) -> do + funcs <- gets functions + case Map.lookup id funcs of + Just fi -> do + if numArgs fi == 0 + then do + vc <- getNewVar + emit $ SetVariable (GA.Ident $ show vc) + (Call FastCC (type2LlvmType t) Global name []) + pure $ VIdent (GA.Ident $ show vc) (type2LlvmType t) + else pure $ VFunction name Global (type2LlvmType t) + Nothing -> pure $ VIdent name (type2LlvmType t) + e -> do + compileExp e + v <- getVarCount + pure $ VIdent (GA.Ident $ show v) (getType e) + +type2LlvmType :: Type -> LLVMType +type2LlvmType (MIR.Type (GA.Ident t)) = case t of + "_Int" -> I64 + t -> CustomType (GA.Ident t) + -- TInt -> I64 + -- TFun t xs -> do + -- let (t', xs') = function2LLVMType xs [type2LlvmType t] + -- Function t' xs' + -- TPol t -> CustomType t + --where + -- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) + -- function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) + -- function2LLVMType x s = (type2LlvmType x, s) + +getType :: Exp -> LLVMType +getType (ELit l) = I64 +getType (EAdd t _ _) = type2LlvmType t +--getType (ESub t _ _) = type2LlvmType t +getType (EId (_, t)) = type2LlvmType t +getType (EApp t _ _) = type2LlvmType t +--getType (EAbs t _ _) = type2LlvmType t +getType (ELet (_, t) _ e) = type2LlvmType t +getType (ECase t _ _) = type2LlvmType t + +valueGetType :: LLVMValue -> LLVMType +valueGetType (VInteger _) = I64 +valueGetType (VIdent _ t) = t +valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 +valueGetType (VFunction _ _ t) = t + +typeByteSize :: LLVMType -> Integer +typeByteSize I1 = 1 +typeByteSize I8 = 1 +typeByteSize I32 = 4 +typeByteSize I64 = 8 +typeByteSize Ptr = 8 +typeByteSize (Ref _) = 8 +typeByteSize (Function _ _) = 8 +typeByteSize (Array n t) = n * typeByteSize t +typeByteSize (CustomType _) = 8 diff --git a/src/Codegen/LlvmIr.hs b/src/Codegen/LlvmIr.hs index 4a649c3..ab2ed90 100644 --- a/src/Codegen/LlvmIr.hs +++ b/src/Codegen/LlvmIr.hs @@ -1,241 +1,241 @@ -module Codegen.LlvmIr where --- {-# LANGUAGE LambdaCase #-} --- --- module Codegen.LlvmIr ( --- LLVMType (..), --- LLVMIr (..), --- llvmIrToString, --- LLVMValue (..), --- LLVMComp (..), --- Visibility (..), --- CallingConvention (..) --- ) where --- --- import Data.List (intercalate) --- import TypeChecker.TypeCheckerIr --- --- data CallingConvention = TailCC | FastCC | CCC | ColdCC --- instance Show CallingConvention where --- show :: CallingConvention -> String --- show TailCC = "tailcc" --- show FastCC = "fastcc" --- show CCC = "ccc" --- show ColdCC = "coldcc" --- --- -- | A datatype which represents some basic LLVM types --- data LLVMType --- = I1 --- | I8 --- | I32 --- | I64 --- | Ptr --- | Ref LLVMType --- | Function LLVMType [LLVMType] --- | Array Integer LLVMType --- | CustomType Ident --- --- instance Show LLVMType where --- show :: LLVMType -> String --- show = \case --- I1 -> "i1" --- I8 -> "i8" --- I32 -> "i32" --- I64 -> "i64" --- Ptr -> "ptr" --- Ref ty -> show ty <> "*" --- Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*" --- Array n ty -> concat ["[", show n, " x ", show ty, "]"] --- CustomType (Ident ty) -> "%" <> ty --- --- data LLVMComp --- = LLEq --- | LLNe --- | LLUgt --- | LLUge --- | LLUlt --- | LLUle --- | LLSgt --- | LLSge --- | LLSlt --- | LLSle --- instance Show LLVMComp where --- show :: LLVMComp -> String --- show = \case --- LLEq -> "eq" --- LLNe -> "ne" --- LLUgt -> "ugt" --- LLUge -> "uge" --- LLUlt -> "ult" --- LLUle -> "ule" --- LLSgt -> "sgt" --- LLSge -> "sge" --- LLSlt -> "slt" --- LLSle -> "sle" --- --- data Visibility = Local | Global --- instance Show Visibility where --- show :: Visibility -> String --- show Local = "%" --- show Global = "@" --- --- -- | Represents a LLVM "value", as in an integer, a register variable, --- -- or a string contstant --- data LLVMValue --- = VInteger Integer --- | VIdent Ident LLVMType --- | VConstant String --- | VFunction Ident Visibility LLVMType --- --- instance Show LLVMValue where --- show :: LLVMValue -> String --- show v = case v of --- VInteger i -> show i --- VIdent (Ident n) _ -> "%" <> n --- VFunction (Ident n) vis _ -> show vis <> n --- VConstant s -> "c" <> show s --- --- type Params = [(Ident, LLVMType)] --- type Args = [(LLVMType, LLVMValue)] --- --- -- | A datatype which represents different instructions in LLVM --- data LLVMIr --- = Type Ident [LLVMType] --- | Define CallingConvention LLVMType Ident Params --- | DefineEnd --- | Declare LLVMType Ident Params --- | SetVariable Ident LLVMIr --- | Variable Ident --- | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue --- | Add LLVMType LLVMValue LLVMValue --- | Sub LLVMType LLVMValue LLVMValue --- | Div LLVMType LLVMValue LLVMValue --- | Mul LLVMType LLVMValue LLVMValue --- | Srem LLVMType LLVMValue LLVMValue --- | Icmp LLVMComp LLVMType LLVMValue LLVMValue --- | Br Ident --- | BrCond LLVMValue Ident Ident --- | Label Ident --- | Call CallingConvention LLVMType Visibility Ident Args --- | Alloca LLVMType --- | Store LLVMType LLVMValue LLVMType Ident --- | Load LLVMType LLVMType Ident --- | Bitcast LLVMType Ident LLVMType --- | Ret LLVMType LLVMValue --- | Comment String --- | UnsafeRaw String -- This should generally be avoided, and proper --- -- instructions should be used in its place --- deriving (Show) --- --- -- | Converts a list of LLVMIr instructions to a string --- llvmIrToString :: [LLVMIr] -> String --- llvmIrToString = go 0 --- where --- go :: Int -> [LLVMIr] -> String --- go _ [] = mempty --- go i (x : xs) = do --- let (i', n) = case x of --- Define{} -> (i + 1, 0) --- DefineEnd -> (i - 1, 0) --- _ -> (i, i) --- insToString n x <> go i' xs --- --- {- | Converts a LLVM inststruction to a String, allowing for printing etc. --- The integer represents the indentation --- -} --- {- FOURMOLU_DISABLE -} --- insToString :: Int -> LLVMIr -> String --- insToString i l = --- replicate i '\t' <> case l of --- (GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do --- -- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0 --- concat --- [ "getelementptr inbounds ", show t1, ", " , show t2 --- , " ", show p, ", ", show t3, " ", show v1, --- ", ", show t4, " ", show v2, "\n" ] --- (Type (Ident n) types) -> --- concat --- [ "%", n, " = type { " --- , intercalate ", " (map show types) --- , " }\n" --- ] --- (Define c t (Ident i) params) -> --- concat --- [ "define ", show c, " ", show t, " @", i --- , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) --- , ") {\n" --- ] --- DefineEnd -> "}\n" --- (Declare _t (Ident _i) _params) -> undefined --- (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] --- (Add t v1 v2) -> --- concat --- [ "add ", show t, " ", show v1 --- , ", ", show v2, "\n" --- ] --- (Sub t v1 v2) -> --- concat --- [ "sub ", show t, " ", show v1, ", " --- , show v2, "\n" --- ] --- (Div t v1 v2) -> --- concat --- [ "sdiv ", show t, " ", show v1, ", " --- , show v2, "\n" --- ] --- (Mul t v1 v2) -> --- concat --- [ "mul ", show t, " ", show v1 --- , ", ", show v2, "\n" --- ] --- (Srem t v1 v2) -> --- concat --- [ "srem ", show t, " ", show v1, ", " --- , show v2, "\n" --- ] --- (Call c t vis (Ident i) arg) -> --- concat --- [ "call ", show c, " ", show t, " ", show vis, i, "(" --- , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg --- , ")\n" --- ] --- (Alloca t) -> unwords ["alloca", show t, "\n"] --- (Store t1 val t2 (Ident id2)) -> --- concat --- [ "store ", show t1, " ", show val --- , ", ", show t2 , " %", id2, "\n" --- ] --- (Load t1 t2 (Ident addr)) -> --- concat --- [ "load ", show t1, ", " --- , show t2, " %", addr, "\n" --- ] --- (Bitcast t1 (Ident i) t2) -> --- concat --- [ "bitcast ", show t1, " %" --- , i, " to ", show t2, "\n" --- ] --- (Icmp comp t v1 v2) -> --- concat --- [ "icmp ", show comp, " ", show t --- , " ", show v1, ", ", show v2, "\n" --- ] --- (Ret t v) -> --- concat --- [ "ret ", show t, " " --- , show v, "\n" --- ] --- (UnsafeRaw s) -> s --- (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n" --- (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n" --- (BrCond val (Ident s1) (Ident s2)) -> --- concat --- [ "br i1 ", show val, ", ", "label %" --- , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n" --- ] --- (Comment s) -> "; " <> s <> "\n" --- (Variable (Ident id)) -> "%" <> id --- {- FOURMOLU_ENABLE -} --- --- lblPfx :: String --- lblPfx = "lbl_" --- +{-# LANGUAGE LambdaCase #-} + +module Codegen.LlvmIr ( + LLVMType (..), + LLVMIr (..), + llvmIrToString, + LLVMValue (..), + LLVMComp (..), + Visibility (..), + CallingConvention (..) +) where + +import Data.List (intercalate) +import Grammar.Abs (Ident (..)) + +data CallingConvention = TailCC | FastCC | CCC | ColdCC +instance Show CallingConvention where + show :: CallingConvention -> String + show TailCC = "tailcc" + show FastCC = "fastcc" + show CCC = "ccc" + show ColdCC = "coldcc" + +-- | A datatype which represents some basic LLVM types +data LLVMType + = I1 + | I8 + | I32 + | I64 + | Ptr + | Ref LLVMType + | Function LLVMType [LLVMType] + | Array Integer LLVMType + | CustomType Ident + +instance Show LLVMType where + show :: LLVMType -> String + show = \case + I1 -> "i1" + I8 -> "i8" + I32 -> "i32" + I64 -> "i64" + Ptr -> "ptr" + Ref ty -> show ty <> "*" + Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*" + Array n ty -> concat ["[", show n, " x ", show ty, "]"] + CustomType (Ident ty) -> "%" <> ty + +data LLVMComp + = LLEq + | LLNe + | LLUgt + | LLUge + | LLUlt + | LLUle + | LLSgt + | LLSge + | LLSlt + | LLSle +instance Show LLVMComp where + show :: LLVMComp -> String + show = \case + LLEq -> "eq" + LLNe -> "ne" + LLUgt -> "ugt" + LLUge -> "uge" + LLUlt -> "ult" + LLUle -> "ule" + LLSgt -> "sgt" + LLSge -> "sge" + LLSlt -> "slt" + LLSle -> "sle" + +data Visibility = Local | Global +instance Show Visibility where + show :: Visibility -> String + show Local = "%" + show Global = "@" + +-- | Represents a LLVM "value", as in an integer, a register variable, +-- or a string contstant +data LLVMValue + = VInteger Integer + | VChar Char + | VIdent Ident LLVMType + | VConstant String + | VFunction Ident Visibility LLVMType + +instance Show LLVMValue where + show :: LLVMValue -> String + show v = case v of + VInteger i -> show i + VChar i -> show i + VIdent (Ident n) _ -> "%" <> n + VFunction (Ident n) vis _ -> show vis <> n + VConstant s -> "c" <> show s + +type Params = [(Ident, LLVMType)] +type Args = [(LLVMType, LLVMValue)] + +-- | A datatype which represents different instructions in LLVM +data LLVMIr + = Type Ident [LLVMType] + | Define CallingConvention LLVMType Ident Params + | DefineEnd + | Declare LLVMType Ident Params + | SetVariable Ident LLVMIr + | Variable Ident + | GetElementPtrInbounds LLVMType LLVMType LLVMValue LLVMType LLVMValue LLVMType LLVMValue + | Add LLVMType LLVMValue LLVMValue + | Sub LLVMType LLVMValue LLVMValue + | Div LLVMType LLVMValue LLVMValue + | Mul LLVMType LLVMValue LLVMValue + | Srem LLVMType LLVMValue LLVMValue + | Icmp LLVMComp LLVMType LLVMValue LLVMValue + | Br Ident + | BrCond LLVMValue Ident Ident + | Label Ident + | Call CallingConvention LLVMType Visibility Ident Args + | Alloca LLVMType + | Store LLVMType LLVMValue LLVMType Ident + | Load LLVMType LLVMType Ident + | Bitcast LLVMType Ident LLVMType + | Ret LLVMType LLVMValue + | Comment String + | UnsafeRaw String -- This should generally be avoided, and proper + -- instructions should be used in its place + deriving (Show) + +-- | Converts a list of LLVMIr instructions to a string +llvmIrToString :: [LLVMIr] -> String +llvmIrToString = go 0 + where + go :: Int -> [LLVMIr] -> String + go _ [] = mempty + go i (x : xs) = do + let (i', n) = case x of + Define{} -> (i + 1, 0) + DefineEnd -> (i - 1, 0) + _ -> (i, i) + insToString n x <> go i' xs + {- | Converts a LLVM inststruction to a String, allowing for printing etc. + The integer represents the indentation + -} + {- FOURMOLU_DISABLE -} + insToString :: Int -> LLVMIr -> String + insToString i l = + replicate i '\t' <> case l of + (GetElementPtrInbounds t1 t2 p t3 v1 t4 v2) -> do + -- getelementptr inbounds %Foo, %Foo* %x, i32 0, i32 0 + concat + [ "getelementptr inbounds ", show t1, ", " , show t2 + , " ", show p, ", ", show t3, " ", show v1, + ", ", show t4, " ", show v2, "\n" ] + (Type (Ident n) types) -> + concat + [ "%", n, " = type { " + , intercalate ", " (map show types) + , " }\n" + ] + (Define c t (Ident i) params) -> + concat + [ "define ", show c, " ", show t, " @", i + , "(", intercalate ", " (map (\(Ident y, x) -> unwords [show x, "%" <> y]) params) + , ") {\n" + ] + DefineEnd -> "}\n" + (Declare _t (Ident _i) _params) -> undefined + (SetVariable (Ident i) ir) -> concat ["%", i, " = ", insToString 0 ir] + (Add t v1 v2) -> + concat + [ "add ", show t, " ", show v1 + , ", ", show v2, "\n" + ] + (Sub t v1 v2) -> + concat + [ "sub ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Div t v1 v2) -> + concat + [ "sdiv ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Mul t v1 v2) -> + concat + [ "mul ", show t, " ", show v1 + , ", ", show v2, "\n" + ] + (Srem t v1 v2) -> + concat + [ "srem ", show t, " ", show v1, ", " + , show v2, "\n" + ] + (Call c t vis (Ident i) arg) -> + concat + [ "call ", show c, " ", show t, " ", show vis, i, "(" + , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg + , ")\n" + ] + (Alloca t) -> unwords ["alloca", show t, "\n"] + (Store t1 val t2 (Ident id2)) -> + concat + [ "store ", show t1, " ", show val + , ", ", show t2 , " %", id2, "\n" + ] + (Load t1 t2 (Ident addr)) -> + concat + [ "load ", show t1, ", " + , show t2, " %", addr, "\n" + ] + (Bitcast t1 (Ident i) t2) -> + concat + [ "bitcast ", show t1, " %" + , i, " to ", show t2, "\n" + ] + (Icmp comp t v1 v2) -> + concat + [ "icmp ", show comp, " ", show t + , " ", show v1, ", ", show v2, "\n" + ] + (Ret t v) -> + concat + [ "ret ", show t, " " + , show v, "\n" + ] + (UnsafeRaw s) -> s + (Label (Ident s)) -> "\n" <> lblPfx <> s <> ":\n" + (Br (Ident s)) -> "br label %" <> lblPfx <> s <> "\n" + (BrCond val (Ident s1) (Ident s2)) -> + concat + [ "br i1 ", show val, ", ", "label %" + , lblPfx, s1, ", ", "label %", lblPfx, s2, "\n" + ] + (Comment s) -> "; " <> s <> "\n" + (Variable (Ident id)) -> "%" <> id +{- FOURMOLU_ENABLE -} + +lblPfx :: String +lblPfx = "lbl_" + diff --git a/src/Monomorphizer/Monomorphizer.hs b/src/Monomorphizer/Monomorphizer.hs new file mode 100644 index 0000000..58a0abc --- /dev/null +++ b/src/Monomorphizer/Monomorphizer.hs @@ -0,0 +1 @@ +module Monomorphizer.Monomorphizer where diff --git a/src/Monomorphizer/MonomorphizerIr.hs b/src/Monomorphizer/MonomorphizerIr.hs new file mode 100644 index 0000000..5bcd5f0 --- /dev/null +++ b/src/Monomorphizer/MonomorphizerIr.hs @@ -0,0 +1,36 @@ +module Monomorphizer.MonomorphizerIr where +import Grammar.Abs (Ident) + +newtype Program = Program [Bind] + deriving (Show, Ord, Eq) + +data Bind = Bind Id [Id] ExpT | DataType Ident [Constructor] + deriving (Show, Ord, Eq) + +data Exp + = EId Id + | ELit Lit + | ELet Id ExpT ExpT + | EApp Type ExpT ExpT + | EAdd Type ExpT ExpT + | ECase Type ExpT [Injection] + deriving (Show, Ord, Eq) + +data Injection = Injection Case ExpT + deriving (Show, Ord, Eq) + +data Case = CLit Lit | CatchAll + deriving (Show, Ord, Eq) + +data Constructor = Constructor Ident [Type] + deriving (Show, Ord, Eq) + +type Id = (Ident, Type) +type ExpT = (Exp, Type) + +data Lit = LInt Integer + | LChar Char + deriving (Show, Ord, Eq) + +newtype Type = Type Ident + deriving (Show, Ord, Eq)