From 230a205965854266c2c38d5421866d2d6934846d Mon Sep 17 00:00:00 2001 From: Samuel Hammersberg Date: Tue, 28 Mar 2023 17:37:29 +0200 Subject: [PATCH] Fixed wrongly typed functions in the code generator. --- src/Codegen/Codegen.hs | 235 +++++++++++++++++++++++------------------ src/Compiler.hs | 5 +- 2 files changed, 135 insertions(+), 105 deletions(-) diff --git a/src/Codegen/Codegen.hs b/src/Codegen/Codegen.hs index 0cb08a8..ffe1f91 100644 --- a/src/Codegen/Codegen.hs +++ b/src/Codegen/Codegen.hs @@ -1,49 +1,55 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} module Codegen.Codegen (generateCode) where -import Auxiliary (snoc) -import Codegen.LlvmIr as LIR -import Control.Applicative ((<|>)) -import Control.Monad (when) -import Control.Monad.State (StateT, execStateT, foldM_, - gets, modify) -import qualified Data.Bifunctor as BI -import Data.Char (ord) -import Data.Coerce (coerce) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Maybe (fromJust, fromMaybe) -import Data.Set (Set) -import qualified Data.Set as Set -import Data.Tuple.Extra (dupe, first, second) -import Grammar.ErrM (Err) -import Monomorphizer.MonomorphizerIr as MIR -import qualified TypeChecker.TypeCheckerIr as TIR +import Auxiliary (snoc) +import Codegen.LlvmIr as LIR +import Control.Applicative ((<|>)) +import Control.Monad (when) +import Control.Monad.State ( + StateT, + execStateT, + foldM_, + gets, + modify, + ) +import Data.Bifunctor qualified as BI +import Data.Char (ord) +import Data.Coerce (coerce) +import Data.Map (Map) +import Data.Map qualified as Map +import Data.Maybe (fromJust, fromMaybe) +import Data.Set (Set) +import Data.Set qualified as Set +import Data.Tuple.Extra (dupe, first, second) +import Debug.Trace (trace) +import Grammar.ErrM (Err) +import Monomorphizer.MonomorphizerIr as MIR +import TypeChecker.TypeCheckerIr qualified as TIR -- | The record used as the code generator state data CodeGenerator = CodeGenerator - { instructions :: [LLVMIr] - , functions :: Map MIR.Id FunctionInfo - , customTypes :: Set LLVMType - , constructors :: Map TIR.Ident ConstructorInfo + { instructions :: [LLVMIr] + , functions :: Map MIR.Id FunctionInfo + , customTypes :: Set LLVMType + , constructors :: Map TIR.Ident ConstructorInfo , variableCount :: Integer - , labelCount :: Integer + , labelCount :: Integer } -- | A state type synonym type CompilerState a = StateT CodeGenerator Err a data FunctionInfo = FunctionInfo - { numArgs :: Int + { numArgs :: Int , arguments :: [Id] } deriving (Show) data ConstructorInfo = ConstructorInfo - { numArgsCI :: Int - , argumentsCI :: [Id] - , numCI :: Integer + { numArgsCI :: Int + , argumentsCI :: [Id] + , numCI :: Integer , returnTypeCI :: MIR.Type } deriving (Show) @@ -55,7 +61,7 @@ emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t} -- | Increases the variable counter in the CodeGenerator state increaseVarCount :: CompilerState () increaseVarCount = do - gets variableCount >>= \s -> emit.Comment $ "increase: " <> show (s + 1) + gets variableCount >>= \s -> emit . Comment $ "increase: " <> show (s + 1) modify $ \t -> t{variableCount = variableCount t + 1} -- | Returns the variable count from the CodeGenerator state @@ -94,23 +100,34 @@ getConstructors :: [MIR.Def] -> Map TIR.Ident ConstructorInfo getConstructors bs = Map.fromList $ go bs where go [] = [] - go (MIR.DData (MIR.Data t cons) : xs) = fst - (foldl (\(acc, i) (Inj id xs) -> - (( id, ConstructorInfo - { numArgsCI = length (init . flattenType $ xs) - , argumentsCI = createArgs (init . flattenType $ xs) - , numCI = i - , returnTypeCI = t --last . flattenType $ xs - } - ) : acc, i + 1)) ([], 0) cons) <> go xs + go (MIR.DData (MIR.Data t cons) : xs) = + fst + ( foldl + ( \(acc, i) (Inj id xs) -> + ( ( id + , ConstructorInfo + { numArgsCI = length (init . flattenType $ xs) + , argumentsCI = createArgs (init . flattenType $ xs) + , numCI = i + , returnTypeCI = t -- last . flattenType $ xs + } + ) + : acc + , i + 1 + ) + ) + ([], 0) + cons + ) + <> go xs go (_ : xs) = go xs getTypes :: [MIR.Def] -> Set LLVMType getTypes bs = Set.fromList $ go bs where - go [] = [] + go [] = [] go (MIR.DData (MIR.Data t _) : xs) = type2LlvmType t : go xs - go (_:xs) = go xs + go (_ : xs) = go xs initCodeGenerator :: [MIR.Def] -> CodeGenerator initCodeGenerator scs = @@ -165,6 +182,7 @@ test v = eCaseInt x xs = (ECase (MIR.TLit (MIR.Ident "_Int")) x xs, MIR.TLit (MIR.Ident "_Int")) int x = (ELit (LInt x), MIR.TLit (MIR.Ident "_Int")) -} + {- | Compiles an AST and produces a LLVM Ir string. An easy way to actually "compile" this output is to Simply pipe it to LLI @@ -172,7 +190,7 @@ test v = generateCode :: MIR.Program -> Err String generateCode (MIR.Program scs) = do let codegen = initCodeGenerator scs - llvmIrToString . instructions <$> execStateT (compileScs scs) codegen + llvmIrToString . instructions <$> execStateT (compileScs (trace (show scs) scs)) codegen compileScs :: [MIR.Def] -> CompilerState () compileScs [] = do @@ -240,16 +258,17 @@ compileScs [] = do modify $ \s -> s{variableCount = 0} ) c -compileScs (MIR.DBind (MIR.Bind (name, _t) args exp) : xs) = do +compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do + let t_return = type2LlvmType . last . flattenType $ t emit $ UnsafeRaw "\n" emit . Comment $ show name <> ": " <> show exp let args' = map (second type2LlvmType) args - emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args' + emit $ Define FastCC t_return name args' when (name == "main") (mapM_ emit firstMainContent) functionBody <- exprToValue exp if name == "main" then mapM_ emit $ lastMainContent functionBody - else emit $ Ret I64 functionBody + else emit $ Ret t_return functionBody emit DefineEnd modify $ \s -> s{variableCount = 0} compileScs xs @@ -267,8 +286,10 @@ compileScs (MIR.DData (MIR.Data typ ts) : xs) = do firstMainContent :: [LLVMIr] firstMainContent = - [ UnsafeRaw "call void @_ZN2GC4Heap4initEv()\n" - ] + [] + +-- UnsafeRaw "call void @_ZN2GC4Heap4initEv()\n" + lastMainContent :: LLVMValue -> [LLVMIr] lastMainContent var = [ UnsafeRaw $ @@ -284,20 +305,21 @@ defaultStart = , UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n" , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" , UnsafeRaw "declare i32 @exit(i32 noundef)\n" - , UnsafeRaw "declare i32 @_ZN2GC4Heap4initEv()\n" - , UnsafeRaw "declare i32 @_ZN2GC4Heap5allocEm()\n" - , UnsafeRaw "declare i32 @_ZN2GC4Heap7disposeEv()\n" + , UnsafeRaw "declare ptr @malloc(i32 noundef)\n" + , UnsafeRaw "declare void @_ZN2GC4Heap4initEv()\n" + , UnsafeRaw "declare void @_ZN2GC4Heap5allocEm()\n" + , UnsafeRaw "declare void @_ZN2GC4Heap7disposeEv()\n" ] compileExp :: ExpT -> CompilerState () -compileExp (MIR.ELit lit,t) = emitLit lit -compileExp (MIR.EAdd e1 e2,t) = emitAdd t e1 e2 +compileExp (MIR.ELit lit, t) = emitLit lit +compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2 -- compileExp (ESub t e1 e2) = emitSub t e1 e2 -compileExp (MIR.EVar name, t) = emitIdent name -compileExp (MIR.EApp e1 e2,t) = emitApp t e1 e2 +compileExp (MIR.EVar name, t) = emitIdent name +compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2 -- compileExp (EAbs t ti e) = emitAbs t ti e -compileExp (MIR.ELet binds e,t) = undefined -- emitLet binds (fst e) -compileExp (MIR.ECase e cs,t) = emitECased t e (map (t,) cs) +compileExp (MIR.ELet binds e, t) = undefined -- emitLet binds (fst e) +compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs) -- go (EMul e1 e2) = emitMul e1 e2 -- go (EDiv e1 e2) = emitDiv e1 e2 @@ -319,7 +341,7 @@ emitECased t e cases = do -- emit $ Label crashLbl emit . UnsafeRaw $ "call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 6, i64 noundef 6)\n" emit . UnsafeRaw $ "call i32 @exit(i32 noundef 1)\n" - mapM_ (const increaseVarCount) [0..1] + mapM_ (const increaseVarCount) [0 .. 1] emit $ Br label emit $ Label label res <- getNewVar @@ -349,28 +371,28 @@ emitECased t e cases = do emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr) val <- exprToValue exp enumerateOneM_ - (\i c -> do + ( \i c -> do case c of - PVar x -> do + PVar x -> do emit . Comment $ "ident " <> show x emit $ SetVariable (fst x) (ExtractValue (CustomType (coerce consId)) (VIdent casted Ptr) i) PLit (l, t) -> undefined - PInj id ps -> undefined - PCatch -> pure() - PEnum id -> undefined - --case c of - -- CIdent x -> do - -- emit . Comment $ "ident " <> show x - -- emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i) - -- emit $ Store ty val Ptr stackPtr - -- CCons x cs -> error "nested constructor" - -- CLit l -> do - -- testVar <- getNewVar - -- emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i) - -- case l of - -- LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l) - -- LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c) - -- CCatch -> emit . Comment $ "Catch all" + PInj id ps -> undefined + PCatch -> pure () + PEnum id -> undefined + -- case c of + -- CIdent x -> do + -- emit . Comment $ "ident " <> show x + -- emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i) + -- emit $ Store ty val Ptr stackPtr + -- CCons x cs -> error "nested constructor" + -- CLit l -> do + -- testVar <- getNewVar + -- emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i) + -- case l of + -- LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l) + -- LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c) + -- CCatch -> emit . Comment $ "Catch all" ) cs emit $ Store ty val Ptr stackPtr @@ -379,7 +401,7 @@ emitECased t e cases = do emitCases rt ty label stackPtr vs (Branch (MIR.PLit i, _) exp) = do emit $ Comment "Plit" let i' = case i of - (MIR.LInt i, _) -> VInteger i + (MIR.LInt i, _) -> VInteger i (MIR.LChar i, _) -> VChar (ord i) ns <- getNewVar lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel @@ -391,7 +413,7 @@ emitECased t e cases = do emit $ Store ty val Ptr stackPtr emit $ Br label emit $ Label lbl_failPos - emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id,_), _) exp) = do + emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do emit $ Comment "Pvar" -- //TODO this is pretty disgusting and would heavily benefit from a rewrite valPtr <- getNewVar @@ -418,7 +440,7 @@ emitECased t e cases = do lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel emit $ Label lbl_failPos ---emitLet :: Bind -> Exp -> CompilerState () +-- emitLet :: Bind -> Exp -> CompilerState () emitLet xs e = do emit $ Comment $ @@ -446,8 +468,7 @@ emitApp rt e1 e2 = appEmitter e1 e2 [] let visibility = fromMaybe Local $ Global <$ Map.lookup name consts - <|> - Global <$ Map.lookup (name, t) funcs + <|> Global <$ Map.lookup (name, t) funcs -- this piece of code could probably be improved, i.e remove the double `const Global` args' = map (first valueGetType . dupe) args call = Call FastCC (type2LlvmType rt) visibility name args' @@ -466,7 +487,7 @@ emitLit :: MIR.Lit -> CompilerState () emitLit i = do -- !!this should never happen!! let (i', t) = case i of - (MIR.LInt i'') -> (VInteger i'', I64) + (MIR.LInt i'') -> (VInteger i'', I64) (MIR.LChar i'') -> (VChar $ ord i'', I8) varCount <- getNewVar emit $ Comment "This should not have happened!" @@ -489,16 +510,20 @@ emitSub t e1 e2 = do exprToValue :: ExpT -> CompilerState LLVMValue exprToValue = \case (MIR.ELit i, t) -> pure $ case i of - (MIR.LInt i) -> VInteger i + (MIR.LInt i) -> VInteger i (MIR.LChar i) -> VChar $ ord i (MIR.EVar name, t) -> do funcs <- gets functions cons <- gets constructors - let res = Map.lookup (name, t) funcs - <|> - (\c -> FunctionInfo { numArgs = numArgsCI c - , arguments = argumentsCI c} ) - <$> Map.lookup name cons + let res = + Map.lookup (name, t) funcs + <|> ( \c -> + FunctionInfo + { numArgs = numArgsCI c + , arguments = argumentsCI c + } + ) + <$> Map.lookup name cons case res of Just fi -> do if numArgs fi == 0 @@ -519,40 +544,42 @@ exprToValue = \case type2LlvmType :: MIR.Type -> LLVMType type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of "Int" -> I64 - _ -> CustomType id -type2LlvmType (MIR.TFun t xs) = do + "Char" -> I8 + _ -> CustomType id +type2LlvmType (MIR.TFun t xs) = do let (t', xs') = function2LLVMType xs [type2LlvmType t] Function t' xs' where function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) - function2LLVMType x s = (type2LlvmType x, s) + function2LLVMType x s = (type2LlvmType x, s) getType :: ExpT -> LLVMType -getType (_, t) = type2LlvmType t +getType (_, t) = type2LlvmType t extractTypeName :: MIR.Type -> TIR.Ident extractTypeName (MIR.TLit id) = id -extractTypeName (MIR.TFun t xs) = let (TIR.Ident i) = extractTypeName t - (TIR.Ident is) = extractTypeName xs - in TIR.Ident $ i <> "_$_" <> is +extractTypeName (MIR.TFun t xs) = + let (TIR.Ident i) = extractTypeName t + (TIR.Ident is) = extractTypeName xs + in TIR.Ident $ i <> "_$_" <> is valueGetType :: LLVMValue -> LLVMType -valueGetType (VInteger _) = I64 -valueGetType (VChar _) = I8 -valueGetType (VIdent _ t) = t -valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 +valueGetType (VInteger _) = I64 +valueGetType (VChar _) = I8 +valueGetType (VIdent _ t) = t +valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 valueGetType (VFunction _ _ t) = t typeByteSize :: LLVMType -> Integer -typeByteSize I1 = 1 -typeByteSize I8 = 1 -typeByteSize I32 = 4 -typeByteSize I64 = 8 -typeByteSize Ptr = 8 -typeByteSize (Ref _) = 8 +typeByteSize I1 = 1 +typeByteSize I8 = 1 +typeByteSize I32 = 4 +typeByteSize I64 = 8 +typeByteSize Ptr = 8 +typeByteSize (Ref _) = 8 typeByteSize (Function _ _) = 8 -typeByteSize (Array n t) = n * typeByteSize t +typeByteSize (Array n t) = n * typeByteSize t typeByteSize (CustomType _) = 8 enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m () diff --git a/src/Compiler.hs b/src/Compiler.hs index 180914f..a10a642 100644 --- a/src/Compiler.hs +++ b/src/Compiler.hs @@ -16,7 +16,10 @@ optimize :: String -> IO String optimize = readCreateProcess (shell "opt --O3 -S") compileClang :: String -> IO String -compileClang = readCreateProcess (shell "clang -x ir -o output/hello_world -") +compileClang = readCreateProcess . shell + $ unwords ["clang++"--, "-Lsrc/GC/lib/", "-l:libgcoll.a" + , "-fno-exceptions -x", "ir" ,"-o" ,"output/hello_world" + , "-"] compile :: String -> IO String compile s = optimize s >>= compileClang