Fixed wrongly typed functions in the code generator.

This commit is contained in:
Samuel Hammersberg 2023-03-28 17:37:29 +02:00
parent e87e2d3870
commit 230a205965
2 changed files with 135 additions and 105 deletions

View file

@ -7,20 +7,26 @@ import Auxiliary (snoc)
import Codegen.LlvmIr as LIR import Codegen.LlvmIr as LIR
import Control.Applicative ((<|>)) import Control.Applicative ((<|>))
import Control.Monad (when) import Control.Monad (when)
import Control.Monad.State (StateT, execStateT, foldM_, import Control.Monad.State (
gets, modify) StateT,
import qualified Data.Bifunctor as BI execStateT,
foldM_,
gets,
modify,
)
import Data.Bifunctor qualified as BI
import Data.Char (ord) import Data.Char (ord)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import Data.Map qualified as Map
import Data.Maybe (fromJust, fromMaybe) import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Set as Set import Data.Set qualified as Set
import Data.Tuple.Extra (dupe, first, second) import Data.Tuple.Extra (dupe, first, second)
import Debug.Trace (trace)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR import Monomorphizer.MonomorphizerIr as MIR
import qualified TypeChecker.TypeCheckerIr as TIR import TypeChecker.TypeCheckerIr qualified as TIR
-- | The record used as the code generator state -- | The record used as the code generator state
data CodeGenerator = CodeGenerator data CodeGenerator = CodeGenerator
@ -94,15 +100,26 @@ getConstructors :: [MIR.Def] -> Map TIR.Ident ConstructorInfo
getConstructors bs = Map.fromList $ go bs getConstructors bs = Map.fromList $ go bs
where where
go [] = [] go [] = []
go (MIR.DData (MIR.Data t cons) : xs) = fst go (MIR.DData (MIR.Data t cons) : xs) =
(foldl (\(acc, i) (Inj id xs) -> fst
(( id, ConstructorInfo ( foldl
( \(acc, i) (Inj id xs) ->
( ( id
, ConstructorInfo
{ numArgsCI = length (init . flattenType $ xs) { numArgsCI = length (init . flattenType $ xs)
, argumentsCI = createArgs (init . flattenType $ xs) , argumentsCI = createArgs (init . flattenType $ xs)
, numCI = i , numCI = i
, returnTypeCI = t -- last . flattenType $ xs , returnTypeCI = t -- last . flattenType $ xs
} }
) : acc, i + 1)) ([], 0) cons) <> go xs )
: acc
, i + 1
)
)
([], 0)
cons
)
<> go xs
go (_ : xs) = go xs go (_ : xs) = go xs
getTypes :: [MIR.Def] -> Set LLVMType getTypes :: [MIR.Def] -> Set LLVMType
@ -165,6 +182,7 @@ test v =
eCaseInt x xs = (ECase (MIR.TLit (MIR.Ident "_Int")) x xs, MIR.TLit (MIR.Ident "_Int")) eCaseInt x xs = (ECase (MIR.TLit (MIR.Ident "_Int")) x xs, MIR.TLit (MIR.Ident "_Int"))
int x = (ELit (LInt x), MIR.TLit (MIR.Ident "_Int")) int x = (ELit (LInt x), MIR.TLit (MIR.Ident "_Int"))
-} -}
{- | Compiles an AST and produces a LLVM Ir string. {- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to An easy way to actually "compile" this output is to
Simply pipe it to LLI Simply pipe it to LLI
@ -172,7 +190,7 @@ test v =
generateCode :: MIR.Program -> Err String generateCode :: MIR.Program -> Err String
generateCode (MIR.Program scs) = do generateCode (MIR.Program scs) = do
let codegen = initCodeGenerator scs let codegen = initCodeGenerator scs
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen llvmIrToString . instructions <$> execStateT (compileScs (trace (show scs) scs)) codegen
compileScs :: [MIR.Def] -> CompilerState () compileScs :: [MIR.Def] -> CompilerState ()
compileScs [] = do compileScs [] = do
@ -240,16 +258,17 @@ compileScs [] = do
modify $ \s -> s{variableCount = 0} modify $ \s -> s{variableCount = 0}
) )
c c
compileScs (MIR.DBind (MIR.Bind (name, _t) args exp) : xs) = do compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
let t_return = type2LlvmType . last . flattenType $ t
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args let args' = map (second type2LlvmType) args
emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args' emit $ Define FastCC t_return name args'
when (name == "main") (mapM_ emit firstMainContent) when (name == "main") (mapM_ emit firstMainContent)
functionBody <- exprToValue exp functionBody <- exprToValue exp
if name == "main" if name == "main"
then mapM_ emit $ lastMainContent functionBody then mapM_ emit $ lastMainContent functionBody
else emit $ Ret I64 functionBody else emit $ Ret t_return functionBody
emit DefineEnd emit DefineEnd
modify $ \s -> s{variableCount = 0} modify $ \s -> s{variableCount = 0}
compileScs xs compileScs xs
@ -267,8 +286,10 @@ compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
firstMainContent :: [LLVMIr] firstMainContent :: [LLVMIr]
firstMainContent = firstMainContent =
[ UnsafeRaw "call void @_ZN2GC4Heap4initEv()\n" []
]
-- UnsafeRaw "call void @_ZN2GC4Heap4initEv()\n"
lastMainContent :: LLVMValue -> [LLVMIr] lastMainContent :: LLVMValue -> [LLVMIr]
lastMainContent var = lastMainContent var =
[ UnsafeRaw $ [ UnsafeRaw $
@ -284,9 +305,10 @@ defaultStart =
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n" , UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
, UnsafeRaw "declare i32 @exit(i32 noundef)\n" , UnsafeRaw "declare i32 @exit(i32 noundef)\n"
, UnsafeRaw "declare i32 @_ZN2GC4Heap4initEv()\n" , UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
, UnsafeRaw "declare i32 @_ZN2GC4Heap5allocEm()\n" , UnsafeRaw "declare void @_ZN2GC4Heap4initEv()\n"
, UnsafeRaw "declare i32 @_ZN2GC4Heap7disposeEv()\n" , UnsafeRaw "declare void @_ZN2GC4Heap5allocEm()\n"
, UnsafeRaw "declare void @_ZN2GC4Heap7disposeEv()\n"
] ]
compileExp :: ExpT -> CompilerState () compileExp :: ExpT -> CompilerState ()
@ -446,8 +468,7 @@ emitApp rt e1 e2 = appEmitter e1 e2 []
let visibility = let visibility =
fromMaybe Local $ fromMaybe Local $
Global <$ Map.lookup name consts Global <$ Map.lookup name consts
<|> <|> Global <$ Map.lookup (name, t) funcs
Global <$ Map.lookup (name, t) funcs
-- this piece of code could probably be improved, i.e remove the double `const Global` -- this piece of code could probably be improved, i.e remove the double `const Global`
args' = map (first valueGetType . dupe) args args' = map (first valueGetType . dupe) args
call = Call FastCC (type2LlvmType rt) visibility name args' call = Call FastCC (type2LlvmType rt) visibility name args'
@ -494,10 +515,14 @@ exprToValue = \case
(MIR.EVar name, t) -> do (MIR.EVar name, t) -> do
funcs <- gets functions funcs <- gets functions
cons <- gets constructors cons <- gets constructors
let res = Map.lookup (name, t) funcs let res =
<|> Map.lookup (name, t) funcs
(\c -> FunctionInfo { numArgs = numArgsCI c <|> ( \c ->
, arguments = argumentsCI c} ) FunctionInfo
{ numArgs = numArgsCI c
, arguments = argumentsCI c
}
)
<$> Map.lookup name cons <$> Map.lookup name cons
case res of case res of
Just fi -> do Just fi -> do
@ -519,6 +544,7 @@ exprToValue = \case
type2LlvmType :: MIR.Type -> LLVMType type2LlvmType :: MIR.Type -> LLVMType
type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of
"Int" -> I64 "Int" -> I64
"Char" -> I8
_ -> CustomType id _ -> CustomType id
type2LlvmType (MIR.TFun t xs) = do type2LlvmType (MIR.TFun t xs) = do
let (t', xs') = function2LLVMType xs [type2LlvmType t] let (t', xs') = function2LLVMType xs [type2LlvmType t]
@ -533,7 +559,8 @@ getType (_, t) = type2LlvmType t
extractTypeName :: MIR.Type -> TIR.Ident extractTypeName :: MIR.Type -> TIR.Ident
extractTypeName (MIR.TLit id) = id extractTypeName (MIR.TLit id) = id
extractTypeName (MIR.TFun t xs) = let (TIR.Ident i) = extractTypeName t extractTypeName (MIR.TFun t xs) =
let (TIR.Ident i) = extractTypeName t
(TIR.Ident is) = extractTypeName xs (TIR.Ident is) = extractTypeName xs
in TIR.Ident $ i <> "_$_" <> is in TIR.Ident $ i <> "_$_" <> is

View file

@ -16,7 +16,10 @@ optimize :: String -> IO String
optimize = readCreateProcess (shell "opt --O3 -S") optimize = readCreateProcess (shell "opt --O3 -S")
compileClang :: String -> IO String compileClang :: String -> IO String
compileClang = readCreateProcess (shell "clang -x ir -o output/hello_world -") compileClang = readCreateProcess . shell
$ unwords ["clang++"--, "-Lsrc/GC/lib/", "-l:libgcoll.a"
, "-fno-exceptions -x", "ir" ,"-o" ,"output/hello_world"
, "-"]
compile :: String -> IO String compile :: String -> IO String
compile s = optimize s >>= compileClang compile s = optimize s >>= compileClang