diff --git a/language.cabal b/language.cabal index 922f873..ddf0fa0 100644 --- a/language.cabal +++ b/language.cabal @@ -42,6 +42,9 @@ executable language Monomorphizer.MonomorphizerIr Codegen.Codegen Codegen.LlvmIr + Codegen.Auxillary + Codegen.CompilerState + Codegen.Emits Compiler Renamer.Renamer TreeConverter diff --git a/src/Codegen/Auxillary.hs b/src/Codegen/Auxillary.hs new file mode 100644 index 0000000..c95f4cb --- /dev/null +++ b/src/Codegen/Auxillary.hs @@ -0,0 +1,50 @@ +module Codegen.Auxillary where + +import Codegen.LlvmIr (LLVMType (..), LLVMValue (..)) +import Control.Monad (foldM_) +import Monomorphizer.MonomorphizerIr as MIR (ExpT, Type (..)) +import TypeChecker.TypeCheckerIr qualified as TIR + +type2LlvmType :: MIR.Type -> LLVMType +type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of + "Int" -> I64 + "Char" -> I8 + _ -> CustomType id +type2LlvmType (MIR.TFun t xs) = do + let (t', xs') = function2LLVMType xs [type2LlvmType t] + Function t' xs' + where + function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) + function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) + function2LLVMType x s = (type2LlvmType x, s) + +getType :: ExpT -> LLVMType +getType (_, t) = type2LlvmType t + +extractTypeName :: MIR.Type -> TIR.Ident +extractTypeName (MIR.TLit id) = id +extractTypeName (MIR.TFun t xs) = + let (TIR.Ident i) = extractTypeName t + (TIR.Ident is) = extractTypeName xs + in TIR.Ident $ i <> "_$_" <> is + +valueGetType :: LLVMValue -> LLVMType +valueGetType (VInteger _) = I64 +valueGetType (VChar _) = I8 +valueGetType (VIdent _ t) = t +valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 +valueGetType (VFunction _ _ t) = t + +typeByteSize :: LLVMType -> Integer +typeByteSize I1 = 1 +typeByteSize I8 = 1 +typeByteSize I32 = 4 +typeByteSize I64 = 8 +typeByteSize Ptr = 8 +typeByteSize (Ref _) = 8 +typeByteSize (Function _ _) = 8 +typeByteSize (Array n t) = n * typeByteSize t +typeByteSize (CustomType _) = 8 + +enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m () +enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1 diff --git a/src/Codegen/Codegen.hs b/src/Codegen/Codegen.hs index eaf8e25..bf35f4f 100644 --- a/src/Codegen/Codegen.hs +++ b/src/Codegen/Codegen.hs @@ -1,184 +1,16 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} - module Codegen.Codegen (generateCode) where -import Auxiliary (snoc) -import Codegen.LlvmIr as LIR -import Control.Applicative ((<|>)) -import Control.Monad (when) -import Control.Monad.State ( - StateT, - execStateT, - foldM_, - gets, - modify, +import Codegen.CompilerState ( + CodeGenerator (instructions), + initCodeGenerator, + ) +import Codegen.Emits (compileScs) +import Codegen.LlvmIr as LIR (llvmIrToString) +import Control.Monad.State ( + execStateT, ) -import Data.Bifunctor qualified as BI -import Data.Char (ord) -import Data.Coerce (coerce) -import Data.Map (Map) -import Data.Map qualified as Map -import Data.Maybe (fromJust, fromMaybe) -import Data.Tuple.Extra (dupe, first, second) import Grammar.ErrM (Err) -import Monomorphizer.MonomorphizerIr as MIR -import TypeChecker.TypeCheckerIr qualified as TIR - --- | The record used as the code generator state -data CodeGenerator = CodeGenerator - { instructions :: [LLVMIr] - , functions :: Map MIR.Id FunctionInfo - , customTypes :: Map LLVMType Integer - , constructors :: Map TIR.Ident ConstructorInfo - , variableCount :: Integer - , labelCount :: Integer - } - --- | A state type synonym -type CompilerState a = StateT CodeGenerator Err a - -data FunctionInfo = FunctionInfo - { numArgs :: Int - , arguments :: [Id] - } - deriving (Show) -data ConstructorInfo = ConstructorInfo - { numArgsCI :: Int - , argumentsCI :: [Id] - , numCI :: Integer - , returnTypeCI :: MIR.Type - } - deriving (Show) - --- | Adds a instruction to the CodeGenerator state -emit :: LLVMIr -> CompilerState () -emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t} - --- | Increases the variable counter in the CodeGenerator state -increaseVarCount :: CompilerState () -increaseVarCount = modify $ \t -> t{variableCount = variableCount t + 1} - --- | Returns the variable count from the CodeGenerator state -getVarCount :: CompilerState Integer -getVarCount = gets variableCount - --- | Increases the variable count and returns it from the CodeGenerator state -getNewVar :: CompilerState TIR.Ident -getNewVar = TIR.Ident . show <$> (increaseVarCount >> getVarCount) - --- | Increses the label count and returns a label from the CodeGenerator state -getNewLabel :: CompilerState Integer -getNewLabel = do - modify (\t -> t{labelCount = labelCount t + 1}) - gets labelCount - -{- | Produces a map of functions infos from a list of binds, - which contains useful data for code generation. --} -getFunctions :: [MIR.Def] -> Map Id FunctionInfo -getFunctions bs = Map.fromList $ go bs - where - go [] = [] - go (MIR.DBind (MIR.Bind id args _) : xs) = - (id, FunctionInfo{numArgs = length args, arguments = args}) - : go xs - go (_ : xs) = go xs - -createArgs :: [MIR.Type] -> [Id] -createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs - -{- | Produces a map of functions infos from a list of binds, - which contains useful data for code generation. --} -getConstructors :: [MIR.Def] -> Map TIR.Ident ConstructorInfo -getConstructors bs = Map.fromList $ go bs - where - go [] = [] - go (MIR.DData (MIR.Data t cons) : xs) = - fst - ( foldl - ( \(acc, i) (Inj id xs) -> - ( ( id - , ConstructorInfo - { numArgsCI = length (init . flattenType $ xs) - , argumentsCI = createArgs (init . flattenType $ xs) - , numCI = i - , returnTypeCI = t -- last . flattenType $ xs - } - ) - : acc - , i + 1 - ) - ) - ([], 0) - cons - ) - <> go xs - go (_ : xs) = go xs - -getTypes :: [MIR.Def] -> Map LLVMType Integer -getTypes bs = Map.fromList $ go bs - where - go [] = [] - go (MIR.DData (MIR.Data t ts) : xs) = (type2LlvmType t, biggestVariant ts) : go xs - go (_ : xs) = go xs - variantTypes fi = init $ map type2LlvmType (flattenType fi) - biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) - -initCodeGenerator :: [MIR.Def] -> CodeGenerator -initCodeGenerator scs = - CodeGenerator - { instructions = defaultStart - , functions = getFunctions scs - , constructors = getConstructors scs - , customTypes = getTypes scs - , variableCount = 0 - , labelCount = 0 - } - -{- -run :: Err String -> IO () -run s = do - let s' = case s of - Right s -> s - Left _ -> error "yo" - writeFile "output/llvm.ll" s' - putStrLn . trim =<< readCreateProcess (shell "lli") s' - -test :: Integer -> Program -test v = - Program - [ DataType - (TIR.Ident "Craig") - [ Constructor (TIR.Ident "Bob") [MIR.Type (TIR.Ident "_Int")] - , Constructor (TIR.Ident "Betty") [MIR.Type (TIR.Ident "_Int")] - ] - , DataType - (TIR.Ident "Alice") - [ Constructor (TIR.Ident "Eve") [MIR.Type (TIR.Ident "_Int")] -- , - -- (TIR.Ident "Alice", [TInt, TInt]) - ] - , Bind (TIR.Ident "fibonacci", MIR.Type (TIR.Ident "_Int")) [(TIR.Ident "x", MIR.Type (TIR.Ident "_Int"))] (EId ("x", MIR.Type (TIR.Ident "Craig")), MIR.Type (TIR.Ident "Craig")) - , Bind (TIR.Ident "main", MIR.Type (TIR.Ident "_Int")) [] - -- (EApp (MIR.Type (TIR.Ident "Craig")) (EId (TIR.Ident "Craig_Bob", MIR.Type (TIR.Ident "Craig")), MIR.Type (TIR.Ident "Craig")) (ELit (LInt v), MIR.Type (TIR.Ident "_Int")), MIR.Type (TIR.Ident "Craig"))-- (EInt 92) - $ - eCaseInt - (EApp (MIR.TLit (TIR.Ident "Craig")) (EId (TIR.Ident "Craig_Bob", MIR.TLit (TIR.Ident "Craig")), MIR.TLit (TIR.Ident "Craig")) (ELit (LInt v), MIR.Type (TIR.Ident "_Int")), MIR.Type (TIR.Ident "Craig")) - [ injectionCons "Craig_Bob" "Craig" [CIdent (TIR.Ident "x")] (EId (TIR.Ident "x", MIR.Type (TIR.Ident "_Int")), MIR.Type (TIR.Ident "_Int")) - , injectionCons "Craig_Betty" "Craig" [CLit (LInt 5)] (int 2) - , Injection (CIdent (TIR.Ident "z")) (int 3) - , -- , injectionInt 5 (int 6) - injectionCatchAll (int 10) - ] - ] - where - injectionCons x y xs = Injection (CCons (TIR.Ident x, MIR.Type (TIR.Ident y)) xs) - injectionInt x = Injection (CLit (LInt x)) - injectionCatchAll = Injection CatchAll - eCaseInt x xs = (ECase (MIR.TLit (MIR.Ident "_Int")) x xs, MIR.TLit (MIR.Ident "_Int")) - int x = (ELit (LInt x), MIR.TLit (MIR.Ident "_Int")) --} +import Monomorphizer.MonomorphizerIr as MIR (Program (..)) {- | Compiles an AST and produces a LLVM Ir string. An easy way to actually "compile" this output is to @@ -188,397 +20,3 @@ generateCode :: MIR.Program -> Err String generateCode (MIR.Program scs) = do let codegen = initCodeGenerator scs llvmIrToString . instructions <$> execStateT (compileScs scs) codegen - -compileScs :: [MIR.Def] -> CompilerState () -compileScs [] = do - emit $ UnsafeRaw "\n" - -- as a last step create all the constructors - -- //TODO maybe merge this with the data type match? - c <- gets (Map.toList . constructors) - mapM_ - ( \(id, ci) -> do - let t = returnTypeCI ci - let t' = type2LlvmType t - let x = BI.second type2LlvmType <$> argumentsCI ci - emit $ Define FastCC t' id x - top <- getNewVar - ptr <- getNewVar - -- allocated the primary type - emit $ SetVariable top (Alloca t') - - -- set the first byte to the index of the constructor - emit $ - SetVariable ptr $ - GetElementPtr - t' - (Ref t') - (VIdent top I8) - I64 - (VInteger 0) - I32 - (VInteger 0) - emit $ Store I8 (VInteger $ numCI ci) (Ref I8) ptr - - -- get a pointer of the correct type - ptr' <- getNewVar - emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id)) - cTypes <- gets customTypes - - enumerateOneM_ - ( \i (TIR.Ident arg_n, arg_t) -> do - let arg_t' = type2LlvmType arg_t - emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i) - elemPtr <- getNewVar - emit $ - SetVariable - elemPtr - ( GetElementPtr - (CustomType id) - (Ref (CustomType id)) - (VIdent ptr' Ptr) - I64 - (VInteger 0) - I32 - (VInteger i) - ) - case Map.lookup arg_t' cTypes of - Just s -> do - emit $ Comment "Malloc and store" - heapPtr <- getNewVar - emit $ SetVariable heapPtr (Malloca s) - emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr - emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr - Nothing -> do - emit $ Comment "Just store" - emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr - ) - (argumentsCI ci) - - -- load and return the constructed value - emit $ Comment "Return the newly constructed value" - load <- getNewVar - emit $ SetVariable load (Load t' Ptr top) - emit $ Ret t' (VIdent load t') - emit DefineEnd - emit $ UnsafeRaw "\n" - - modify $ \s -> s{variableCount = 0} - ) - c -compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do - let t_return = type2LlvmType . last . flattenType $ t - emit $ UnsafeRaw "\n" - emit . Comment $ show name <> ": " <> show exp - let args' = map (second type2LlvmType) args - emit $ Define FastCC t_return name args' - when (name == "main") (mapM_ emit firstMainContent) - functionBody <- exprToValue exp - if name == "main" - then mapM_ emit $ lastMainContent functionBody - else emit $ Ret t_return functionBody - emit DefineEnd - modify $ \s -> s{variableCount = 0} - compileScs xs -compileScs (MIR.DData (MIR.Data typ ts) : xs) = do - let (TIR.Ident outer_id) = extractTypeName typ - -- //TODO this could be extracted from the customTypes map - let variantTypes fi = init $ map type2LlvmType (flattenType fi) - let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) - emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8] - typeSets <- gets customTypes - mapM_ - ( \(Inj inner_id fi) -> do - let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi - emit $ LIR.Type inner_id (I8 : types) - ) - ts - compileScs xs - -firstMainContent :: [LLVMIr] -firstMainContent = - [] - --- UnsafeRaw "call void @_ZN2GC4Heap4initEv()\n" - -lastMainContent :: LLVMValue -> [LLVMIr] -lastMainContent var = - [ UnsafeRaw $ - "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n" - , Ret I64 (VInteger 0) - ] - -defaultStart :: [LLVMIr] -defaultStart = - [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n" - , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" - , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" - , UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n" - , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" - , UnsafeRaw "declare i32 @exit(i32 noundef)\n" - , UnsafeRaw "declare ptr @malloc(i32 noundef)\n" - , UnsafeRaw "declare void @_ZN2GC4Heap4initEv()\n" - , UnsafeRaw "declare void @_ZN2GC4Heap5allocEm()\n" - , UnsafeRaw "declare void @_ZN2GC4Heap7disposeEv()\n" - ] - -compileExp :: ExpT -> CompilerState () -compileExp (MIR.ELit lit, _t) = emitLit lit -compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2 --- compileExp (ESub t e1 e2) = emitSub t e1 e2 -compileExp (MIR.EVar name, _t) = emitIdent name -compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2 --- compileExp (EAbs t ti e) = emitAbs t ti e -compileExp (MIR.ELet bind e, _) = emitLet bind e -compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs) - --- go (EMul e1 e2) = emitMul e1 e2 --- go (EDiv e1 e2) = emitDiv e1 e2 --- go (EMod e1 e2) = emitMod e1 e2 - ---- aux functions --- -emitLet :: MIR.Bind -> ExpT -> CompilerState () -emitLet (MIR.Bind id [] innerExp) e = do - evaled <- exprToValue innerExp - tempVar <- getNewVar - let t = type2LlvmType . snd $ innerExp - emit $ SetVariable tempVar (Alloca t) - emit $ Store (type2LlvmType . snd $ innerExp) evaled Ptr tempVar - emit $ SetVariable (fst id) (Load t Ptr tempVar) - compileExp e -emitLet b _ = error $ "Non empty argument list in let-bind " <> show b - -emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Branch)] -> CompilerState () -emitECased t e cases = do - let cs = snd <$> cases - let ty = type2LlvmType t - let rt = type2LlvmType (snd e) - vs <- exprToValue e - lbl <- getNewLabel - let label = TIR.Ident $ "escape_" <> show lbl - stackPtr <- getNewVar - emit $ SetVariable stackPtr (Alloca ty) - mapM_ (emitCases rt ty label stackPtr vs) cs - -- crashLbl <- TIR.Ident . ("crash_" <>) . show <$> getNewLabel - -- emit $ Label crashLbl - emit . UnsafeRaw $ "call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 6, i64 noundef 6)\n" - emit . UnsafeRaw $ "call i32 @exit(i32 noundef 1)\n" - mapM_ (const increaseVarCount) [0 .. 1] - emit $ Br label - emit $ Label label - res <- getNewVar - emit $ SetVariable res (Load ty Ptr stackPtr) - where - emitCases :: LLVMType -> LLVMType -> TIR.Ident -> TIR.Ident -> LLVMValue -> Branch -> CompilerState () - emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, _t) exp) = do - emit $ Comment "Inj" - cons <- gets constructors - let r = fromJust $ Map.lookup consId cons - - lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel - lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel - - consVal <- getNewVar - emit $ SetVariable consVal (ExtractValue rt vs 0) - - consCheck <- getNewVar - emit $ SetVariable consCheck (Icmp LLEq I8 (VIdent consVal I8) (VInteger $ numCI r)) - emit $ BrCond (VIdent consCheck ty) lbl_succPos lbl_failPos - emit $ Label lbl_succPos - - castPtr <- getNewVar - casted <- getNewVar - emit $ SetVariable castPtr (Alloca rt) - emit $ Store rt vs Ptr castPtr - emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr) - enumerateOneM_ - ( \i c -> do - case c of - PVar (x, topT) -> do - let topT' = type2LlvmType topT - let botT' = CustomType (coerce consId) - emit . Comment $ "ident " <> toIr topT' - cTypes <- gets customTypes - if Map.member topT' cTypes - then do - deref <- getNewVar - emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i) - emit $ SetVariable x (Load topT' Ptr deref) - else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i) - PLit (_l, _t) -> undefined - PInj _id _ps -> undefined - PCatch -> pure () - PEnum _id -> undefined - ) - cs - val <- exprToValue exp - emit $ Store ty val Ptr stackPtr - emit $ Br label - emit $ Label lbl_failPos - emitCases _rt ty label stackPtr vs (Branch (MIR.PLit i, t) exp) = do - emit $ Comment "Plit" - let i' = case i of - (MIR.LInt i, _) -> VInteger i - (MIR.LChar i, _) -> VChar (ord i) - ns <- getNewVar - lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel - lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel - emit $ SetVariable ns (Icmp LLEq (type2LlvmType t) vs i') - emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos - emit $ Label lbl_succPos - val <- exprToValue exp - emit $ Store ty val Ptr stackPtr - emit $ Br label - emit $ Label lbl_failPos - emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do - emit $ Comment "Pvar" - -- //TODO this is pretty disgusting and would heavily benefit from a rewrite - valPtr <- getNewVar - emit $ SetVariable valPtr (Alloca rt) - emit $ Store rt vs Ptr valPtr - emit $ SetVariable id (Load rt Ptr valPtr) - val <- exprToValue exp - emit $ Store ty val Ptr stackPtr - emit $ Br label - lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel - emit $ Label lbl_failPos - emitCases _rt ty label stackPtr _vs (Branch (MIR.PEnum _id, _) exp) = do - -- //TODO Penum wrong, acts as a catch all - emit $ Comment "Penum" - val <- exprToValue exp - emit $ Store ty val Ptr stackPtr - emit $ Br label - lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel - emit $ Label lbl_failPos - emitCases _ ty label stackPtr _ (Branch (MIR.PCatch, _) exp) = do - emit $ Comment "Pcatch" - val <- exprToValue exp - emit $ Store ty val Ptr stackPtr - emit $ Br label - lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel - emit $ Label lbl_failPos - -emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState () -emitApp rt e1 e2 = appEmitter e1 e2 [] - where - appEmitter :: ExpT -> ExpT -> [ExpT] -> CompilerState () - appEmitter e1 e2 stack = do - let newStack = e2 : stack - case e1 of - (MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack - (MIR.EVar name, t) -> do - args <- traverse exprToValue newStack - vs <- getNewVar - funcs <- gets functions - consts <- gets constructors - let visibility = - fromMaybe Local $ - Global <$ Map.lookup name consts - <|> Global <$ Map.lookup (name, t) funcs - -- this piece of code could probably be improved, i.e remove the double `const Global` - args' = map (first valueGetType . dupe) args - call = Call FastCC (type2LlvmType rt) visibility name args' - emit $ Comment $ show rt - emit $ SetVariable vs call - x -> error $ "The unspeakable happened: " <> show x - -emitIdent :: TIR.Ident -> CompilerState () -emitIdent id = do - -- !!this should never happen!! - emit $ Comment "This should not have happened!" - emit $ Variable id - emit $ UnsafeRaw "\n" - -emitLit :: MIR.Lit -> CompilerState () -emitLit i = do - -- !!this should never happen!! - let (i', t) = case i of - (MIR.LInt i'') -> (VInteger i'', I64) - (MIR.LChar i'') -> (VChar $ ord i'', I8) - varCount <- getNewVar - emit $ Comment "This should not have happened!" - emit $ SetVariable varCount (Add t i' (VInteger 0)) - -emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState () -emitAdd t e1 e2 = do - v1 <- exprToValue e1 - v2 <- exprToValue e2 - v <- getNewVar - emit $ SetVariable v (Add (type2LlvmType t) v1 v2) - -exprToValue :: ExpT -> CompilerState LLVMValue -exprToValue = \case - (MIR.ELit i, _t) -> pure $ case i of - (MIR.LInt i) -> VInteger i - (MIR.LChar i) -> VChar $ ord i - (MIR.EVar name, t) -> do - funcs <- gets functions - cons <- gets constructors - let res = - Map.lookup (name, t) funcs - <|> ( \c -> - FunctionInfo - { numArgs = numArgsCI c - , arguments = argumentsCI c - } - ) - <$> Map.lookup name cons - case res of - Just fi -> do - if numArgs fi == 0 - then do - vc <- getNewVar - emit $ - SetVariable - vc - (Call FastCC (type2LlvmType t) Global name []) - pure $ VIdent vc (type2LlvmType t) - else pure $ VFunction name Global (type2LlvmType t) - Nothing -> pure $ VIdent name (type2LlvmType t) - e -> do - compileExp e - v <- getVarCount - pure $ VIdent (TIR.Ident $ show v) (getType e) - -type2LlvmType :: MIR.Type -> LLVMType -type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of - "Int" -> I64 - "Char" -> I8 - _ -> CustomType id -type2LlvmType (MIR.TFun t xs) = do - let (t', xs') = function2LLVMType xs [type2LlvmType t] - Function t' xs' - where - function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) - function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) - function2LLVMType x s = (type2LlvmType x, s) - -getType :: ExpT -> LLVMType -getType (_, t) = type2LlvmType t - -extractTypeName :: MIR.Type -> TIR.Ident -extractTypeName (MIR.TLit id) = id -extractTypeName (MIR.TFun t xs) = - let (TIR.Ident i) = extractTypeName t - (TIR.Ident is) = extractTypeName xs - in TIR.Ident $ i <> "_$_" <> is - -valueGetType :: LLVMValue -> LLVMType -valueGetType (VInteger _) = I64 -valueGetType (VChar _) = I8 -valueGetType (VIdent _ t) = t -valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 -valueGetType (VFunction _ _ t) = t - -typeByteSize :: LLVMType -> Integer -typeByteSize I1 = 1 -typeByteSize I8 = 1 -typeByteSize I32 = 4 -typeByteSize I64 = 8 -typeByteSize Ptr = 8 -typeByteSize (Ref _) = 8 -typeByteSize (Function _ _) = 8 -typeByteSize (Array n t) = n * typeByteSize t -typeByteSize (CustomType _) = 8 - -enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m () -enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1 diff --git a/src/Codegen/CompilerState.hs b/src/Codegen/CompilerState.hs new file mode 100644 index 0000000..a6c100a --- /dev/null +++ b/src/Codegen/CompilerState.hs @@ -0,0 +1,141 @@ +module Codegen.CompilerState where + +import Auxiliary (snoc) +import Codegen.Auxillary (type2LlvmType, typeByteSize) +import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw), LLVMType) +import Control.Monad.State ( + StateT, + gets, + modify, + ) +import Data.Map (Map) +import Data.Map qualified as Map +import Grammar.ErrM (Err) +import Monomorphizer.MonomorphizerIr as MIR +import TypeChecker.TypeCheckerIr qualified as TIR + +-- | The record used as the code generator state +data CodeGenerator = CodeGenerator + { instructions :: [LLVMIr] + , functions :: Map MIR.Id FunctionInfo + , customTypes :: Map LLVMType Integer + , constructors :: Map TIR.Ident ConstructorInfo + , variableCount :: Integer + , labelCount :: Integer + } + +-- | A state type synonym +type CompilerState a = StateT CodeGenerator Err a + +data FunctionInfo = FunctionInfo + { numArgs :: Int + , arguments :: [Id] + } + deriving (Show) +data ConstructorInfo = ConstructorInfo + { numArgsCI :: Int + , argumentsCI :: [Id] + , numCI :: Integer + , returnTypeCI :: MIR.Type + } + deriving (Show) + +-- | Adds a instruction to the CodeGenerator state +emit :: LLVMIr -> CompilerState () +emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t} + +-- | Increases the variable counter in the CodeGenerator state +increaseVarCount :: CompilerState () +increaseVarCount = modify $ \t -> t{variableCount = variableCount t + 1} + +-- | Returns the variable count from the CodeGenerator state +getVarCount :: CompilerState Integer +getVarCount = gets variableCount + +-- | Increases the variable count and returns it from the CodeGenerator state +getNewVar :: CompilerState TIR.Ident +getNewVar = TIR.Ident . show <$> (increaseVarCount >> getVarCount) + +-- | Increses the label count and returns a label from the CodeGenerator state +getNewLabel :: CompilerState Integer +getNewLabel = do + modify (\t -> t{labelCount = labelCount t + 1}) + gets labelCount + +{- | Produces a map of functions infos from a list of binds, + which contains useful data for code generation. +-} +getFunctions :: [MIR.Def] -> Map Id FunctionInfo +getFunctions bs = Map.fromList $ go bs + where + go [] = [] + go (MIR.DBind (MIR.Bind id args _) : xs) = + (id, FunctionInfo{numArgs = length args, arguments = args}) + : go xs + go (_ : xs) = go xs + +createArgs :: [MIR.Type] -> [Id] +createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs + +{- | Produces a map of functions infos from a list of binds, + which contains useful data for code generation. +-} +getConstructors :: [MIR.Def] -> Map TIR.Ident ConstructorInfo +getConstructors bs = Map.fromList $ go bs + where + go [] = [] + go (MIR.DData (MIR.Data t cons) : xs) = + fst + ( foldl + ( \(acc, i) (Inj id xs) -> + ( ( id + , ConstructorInfo + { numArgsCI = length (init . flattenType $ xs) + , argumentsCI = createArgs (init . flattenType $ xs) + , numCI = i + , returnTypeCI = t -- last . flattenType $ xs + } + ) + : acc + , i + 1 + ) + ) + ([], 0) + cons + ) + <> go xs + go (_ : xs) = go xs + +getTypes :: [MIR.Def] -> Map LLVMType Integer +getTypes bs = Map.fromList $ go bs + where + go [] = [] + go (MIR.DData (MIR.Data t ts) : xs) = (type2LlvmType t, biggestVariant ts) : go xs + go (_ : xs) = go xs + variantTypes fi = init $ map type2LlvmType (flattenType fi) + biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) + +initCodeGenerator :: [MIR.Def] -> CodeGenerator +initCodeGenerator scs = + CodeGenerator + { instructions = defaultStart + , functions = getFunctions scs + , constructors = getConstructors scs + , customTypes = getTypes scs + , variableCount = 0 + , labelCount = 0 + } + +defaultStart :: [LLVMIr] +defaultStart = + [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n" + , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" + , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" + , UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n" + , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" + , UnsafeRaw "declare i32 @exit(i32 noundef)\n" + , UnsafeRaw "declare ptr @malloc(i32 noundef)\n" + , UnsafeRaw "declare void @_ZN2GC4Heap4initEv()\n" + , UnsafeRaw "declare void @_ZN2GC4Heap5allocEm()\n" + , UnsafeRaw "declare void @_ZN2GC4Heap7disposeEv()\n" + ] \ No newline at end of file diff --git a/src/Codegen/Emits.hs b/src/Codegen/Emits.hs new file mode 100644 index 0000000..c41e340 --- /dev/null +++ b/src/Codegen/Emits.hs @@ -0,0 +1,348 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} + +module Codegen.Emits where + +import Codegen.Auxillary +import Codegen.CompilerState +import Codegen.LlvmIr as LIR +import Control.Applicative ((<|>)) +import Control.Monad (when) +import Control.Monad.State ( + gets, + modify, + ) +import Data.Bifunctor qualified as BI +import Data.Char (ord) +import Data.Coerce (coerce) +import Data.Map qualified as Map +import Data.Maybe (fromJust, fromMaybe) +import Data.Tuple.Extra (dupe, first, second) +import Monomorphizer.MonomorphizerIr as MIR +import TypeChecker.TypeCheckerIr qualified as TIR + +compileScs :: [MIR.Def] -> CompilerState () +compileScs [] = do + emit $ UnsafeRaw "\n" + -- as a last step create all the constructors + -- //TODO maybe merge this with the data type match? + c <- gets (Map.toList . constructors) + mapM_ + ( \(id, ci) -> do + let t = returnTypeCI ci + let t' = type2LlvmType t + let x = BI.second type2LlvmType <$> argumentsCI ci + emit $ Define FastCC t' id x + top <- getNewVar + ptr <- getNewVar + -- allocated the primary type + emit $ SetVariable top (Alloca t') + + -- set the first byte to the index of the constructor + emit $ + SetVariable ptr $ + GetElementPtr + t' + (Ref t') + (VIdent top I8) + I64 + (VInteger 0) + I32 + (VInteger 0) + emit $ Store I8 (VInteger $ numCI ci) (Ref I8) ptr + + -- get a pointer of the correct type + ptr' <- getNewVar + emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id)) + cTypes <- gets customTypes + + enumerateOneM_ + ( \i (TIR.Ident arg_n, arg_t) -> do + let arg_t' = type2LlvmType arg_t + emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i) + elemPtr <- getNewVar + emit $ + SetVariable + elemPtr + ( GetElementPtr + (CustomType id) + (Ref (CustomType id)) + (VIdent ptr' Ptr) + I64 + (VInteger 0) + I32 + (VInteger i) + ) + case Map.lookup arg_t' cTypes of + Just s -> do + emit $ Comment "Malloc and store" + heapPtr <- getNewVar + emit $ SetVariable heapPtr (Malloca s) + emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr + emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr + Nothing -> do + emit $ Comment "Just store" + emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr + ) + (argumentsCI ci) + + -- load and return the constructed value + emit $ Comment "Return the newly constructed value" + load <- getNewVar + emit $ SetVariable load (Load t' Ptr top) + emit $ Ret t' (VIdent load t') + emit DefineEnd + emit $ UnsafeRaw "\n" + + modify $ \s -> s{variableCount = 0} + ) + c +compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do + let t_return = type2LlvmType . last . flattenType $ t + emit $ UnsafeRaw "\n" + emit . Comment $ show name <> ": " <> show exp + let args' = map (second type2LlvmType) args + emit $ Define FastCC t_return name args' + when (name == "main") (mapM_ emit firstMainContent) + functionBody <- exprToValue exp + if name == "main" + then mapM_ emit $ lastMainContent functionBody + else emit $ Ret t_return functionBody + emit DefineEnd + modify $ \s -> s{variableCount = 0} + compileScs xs +compileScs (MIR.DData (MIR.Data typ ts) : xs) = do + let (TIR.Ident outer_id) = extractTypeName typ + -- //TODO this could be extracted from the customTypes map + let variantTypes fi = init $ map type2LlvmType (flattenType fi) + let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) + emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8] + typeSets <- gets customTypes + mapM_ + ( \(Inj inner_id fi) -> do + let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi + emit $ LIR.Type inner_id (I8 : types) + ) + ts + compileScs xs + +firstMainContent :: [LLVMIr] +firstMainContent = [] + +lastMainContent :: LLVMValue -> [LLVMIr] +lastMainContent var = + [ UnsafeRaw $ + "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n" + , Ret I64 (VInteger 0) + ] + +compileExp :: ExpT -> CompilerState () +compileExp (MIR.ELit lit, _t) = emitLit lit +compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2 +compileExp (MIR.EVar name, _t) = emitIdent name +compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2 +compileExp (MIR.ELet bind e, _) = emitLet bind e +compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs) + +emitLet :: MIR.Bind -> ExpT -> CompilerState () +emitLet (MIR.Bind id [] innerExp) e = do + evaled <- exprToValue innerExp + tempVar <- getNewVar + let t = type2LlvmType . snd $ innerExp + emit $ SetVariable tempVar (Alloca t) + emit $ Store (type2LlvmType . snd $ innerExp) evaled Ptr tempVar + emit $ SetVariable (fst id) (Load t Ptr tempVar) + compileExp e +emitLet b _ = error $ "Non empty argument list in let-bind " <> show b + +emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Branch)] -> CompilerState () +emitECased t e cases = do + let cs = snd <$> cases + let ty = type2LlvmType t + let rt = type2LlvmType (snd e) + vs <- exprToValue e + lbl <- getNewLabel + let label = TIR.Ident $ "escape_" <> show lbl + stackPtr <- getNewVar + emit $ SetVariable stackPtr (Alloca ty) + mapM_ (emitCases rt ty label stackPtr vs) cs + -- crashLbl <- TIR.Ident . ("crash_" <>) . show <$> getNewLabel + -- emit $ Label crashLbl + emit . UnsafeRaw $ "call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 6, i64 noundef 6)\n" + emit . UnsafeRaw $ "call i32 @exit(i32 noundef 1)\n" + mapM_ (const increaseVarCount) [0 .. 1] + emit $ Br label + emit $ Label label + res <- getNewVar + emit $ SetVariable res (Load ty Ptr stackPtr) + where + emitCases :: LLVMType -> LLVMType -> TIR.Ident -> TIR.Ident -> LLVMValue -> Branch -> CompilerState () + emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, _t) exp) = do + emit $ Comment "Inj" + cons <- gets constructors + let r = fromJust $ Map.lookup consId cons + + lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel + lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel + + consVal <- getNewVar + emit $ SetVariable consVal (ExtractValue rt vs 0) + + consCheck <- getNewVar + emit $ SetVariable consCheck (Icmp LLEq I8 (VIdent consVal I8) (VInteger $ numCI r)) + emit $ BrCond (VIdent consCheck ty) lbl_succPos lbl_failPos + emit $ Label lbl_succPos + + castPtr <- getNewVar + casted <- getNewVar + emit $ SetVariable castPtr (Alloca rt) + emit $ Store rt vs Ptr castPtr + emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr) + enumerateOneM_ + ( \i c -> do + case c of + PVar (x, topT) -> do + let topT' = type2LlvmType topT + let botT' = CustomType (coerce consId) + emit . Comment $ "ident " <> toIr topT' + cTypes <- gets customTypes + if Map.member topT' cTypes + then do + deref <- getNewVar + emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i) + emit $ SetVariable x (Load topT' Ptr deref) + else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i) + PLit (_l, _t) -> undefined + PInj _id _ps -> undefined + PCatch -> pure () + PEnum _id -> undefined + ) + cs + val <- exprToValue exp + emit $ Store ty val Ptr stackPtr + emit $ Br label + emit $ Label lbl_failPos + emitCases _rt ty label stackPtr vs (Branch (MIR.PLit i, t) exp) = do + emit $ Comment "Plit" + let i' = case i of + (MIR.LInt i, _) -> VInteger i + (MIR.LChar i, _) -> VChar (ord i) + ns <- getNewVar + lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel + lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel + emit $ SetVariable ns (Icmp LLEq (type2LlvmType t) vs i') + emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos + emit $ Label lbl_succPos + val <- exprToValue exp + emit $ Store ty val Ptr stackPtr + emit $ Br label + emit $ Label lbl_failPos + emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do + emit $ Comment "Pvar" + -- //TODO this is pretty disgusting and would heavily benefit from a rewrite + valPtr <- getNewVar + emit $ SetVariable valPtr (Alloca rt) + emit $ Store rt vs Ptr valPtr + emit $ SetVariable id (Load rt Ptr valPtr) + val <- exprToValue exp + emit $ Store ty val Ptr stackPtr + emit $ Br label + lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel + emit $ Label lbl_failPos + emitCases _rt ty label stackPtr _vs (Branch (MIR.PEnum _id, _) exp) = do + -- //TODO Penum wrong, acts as a catch all + emit $ Comment "Penum" + val <- exprToValue exp + emit $ Store ty val Ptr stackPtr + emit $ Br label + lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel + emit $ Label lbl_failPos + emitCases _ ty label stackPtr _ (Branch (MIR.PCatch, _) exp) = do + emit $ Comment "Pcatch" + val <- exprToValue exp + emit $ Store ty val Ptr stackPtr + emit $ Br label + lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel + emit $ Label lbl_failPos + +emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState () +emitApp rt e1 e2 = appEmitter e1 e2 [] + where + appEmitter :: ExpT -> ExpT -> [ExpT] -> CompilerState () + appEmitter e1 e2 stack = do + let newStack = e2 : stack + case e1 of + (MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack + (MIR.EVar name, t) -> do + args <- traverse exprToValue newStack + vs <- getNewVar + funcs <- gets functions + consts <- gets constructors + let visibility = + fromMaybe Local $ + Global <$ Map.lookup name consts + <|> Global <$ Map.lookup (name, t) funcs + -- this piece of code could probably be improved, i.e remove the double `const Global` + args' = map (first valueGetType . dupe) args + call = Call FastCC (type2LlvmType rt) visibility name args' + emit $ Comment $ show rt + emit $ SetVariable vs call + x -> error $ "The unspeakable happened: " <> show x + +emitIdent :: TIR.Ident -> CompilerState () +emitIdent id = do + -- !!this should never happen!! + emit $ Comment "This should not have happened!" + emit $ Variable id + emit $ UnsafeRaw "\n" + +emitLit :: MIR.Lit -> CompilerState () +emitLit i = do + -- !!this should never happen!! + let (i', t) = case i of + (MIR.LInt i'') -> (VInteger i'', I64) + (MIR.LChar i'') -> (VChar $ ord i'', I8) + varCount <- getNewVar + emit $ Comment "This should not have happened!" + emit $ SetVariable varCount (Add t i' (VInteger 0)) + +emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState () +emitAdd t e1 e2 = do + v1 <- exprToValue e1 + v2 <- exprToValue e2 + v <- getNewVar + emit $ SetVariable v (Add (type2LlvmType t) v1 v2) + +exprToValue :: ExpT -> CompilerState LLVMValue +exprToValue = \case + (MIR.ELit i, _t) -> pure $ case i of + (MIR.LInt i) -> VInteger i + (MIR.LChar i) -> VChar $ ord i + (MIR.EVar name, t) -> do + funcs <- gets functions + cons <- gets constructors + let res = + Map.lookup (name, t) funcs + <|> ( \c -> + FunctionInfo + { numArgs = numArgsCI c + , arguments = argumentsCI c + } + ) + <$> Map.lookup name cons + case res of + Just fi -> do + if numArgs fi == 0 + then do + vc <- getNewVar + emit $ + SetVariable + vc + (Call FastCC (type2LlvmType t) Global name []) + pure $ VIdent vc (type2LlvmType t) + else pure $ VFunction name Global (type2LlvmType t) + Nothing -> pure $ VIdent name (type2LlvmType t) + e -> do + compileExp e + v <- getVarCount + pure $ VIdent (TIR.Ident $ show v) (getType e) diff --git a/test_program.crf b/test_program.crf index c5b3f9d..14cd86c 100644 --- a/test_program.crf +++ b/test_program.crf @@ -31,7 +31,7 @@ bind x f = case x of { -- represents minus one :) minusOne : Int ; minusOne = 9223372036854775807 + 9223372036854775807 + 1; -{- + ---- LIST STUFF ---- -- a simple list data type containing ints data List () where { @@ -69,4 +69,3 @@ repeat x n = case n of { 0 => Nil ; n => Cons x (repeat x (n + minusOne)) ; }; --} \ No newline at end of file