Fixed wrongly typed functions in the code generator.

This commit is contained in:
Samuel Hammersberg 2023-03-28 17:37:29 +02:00
parent e87e2d3870
commit 230a205965
2 changed files with 135 additions and 105 deletions

View file

@ -1,49 +1,55 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (generateCode) where
import Auxiliary (snoc)
import Codegen.LlvmIr as LIR
import Control.Applicative ((<|>))
import Control.Monad (when)
import Control.Monad.State (StateT, execStateT, foldM_,
gets, modify)
import qualified Data.Bifunctor as BI
import Data.Char (ord)
import Data.Coerce (coerce)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Tuple.Extra (dupe, first, second)
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR
import qualified TypeChecker.TypeCheckerIr as TIR
import Auxiliary (snoc)
import Codegen.LlvmIr as LIR
import Control.Applicative ((<|>))
import Control.Monad (when)
import Control.Monad.State (
StateT,
execStateT,
foldM_,
gets,
modify,
)
import Data.Bifunctor qualified as BI
import Data.Char (ord)
import Data.Coerce (coerce)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Tuple.Extra (dupe, first, second)
import Debug.Trace (trace)
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR
import TypeChecker.TypeCheckerIr qualified as TIR
-- | The record used as the code generator state.
--
-- * @instructions@: the LLVM IR emitted so far; @emit@ extends it through
--   @Auxiliary.snoc@.
-- * @functions@: known functions, keyed by 'MIR.Id'; looked up with a
--   @(name, type)@ pair when a variable is turned into a value.
-- * @customTypes@: set of LLVM types; presumably the types derived from
--   data declarations (see 'getTypes') -- TODO confirm in 'initCodeGenerator'.
-- * @constructors@: data constructors, keyed by constructor identifier.
-- * @variableCount@: counter bumped by 'increaseVarCount' and reset to 0
--   after each top-level binding has been compiled.
-- * @labelCount@: presumably the counter behind @getNewLabel@ -- TODO confirm.
data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr]
, functions :: Map MIR.Id FunctionInfo
, customTypes :: Set LLVMType
, constructors :: Map TIR.Ident ConstructorInfo
{ instructions :: [LLVMIr]
, functions :: Map MIR.Id FunctionInfo
, customTypes :: Set LLVMType
, constructors :: Map TIR.Ident ConstructorInfo
, variableCount :: Integer
, labelCount :: Integer
, labelCount :: Integer
}
-- | A state type synonym: a 'StateT' carrying the 'CodeGenerator' record,
-- layered over the 'Err' monad.
type CompilerState a = StateT CodeGenerator Err a
-- | What the code generator records about a callable: its arity and its
-- argument identifiers. Constructors are also viewed through this record
-- when a variable is resolved in 'exprToValue'.
data FunctionInfo = FunctionInfo
{ numArgs :: Int
{ numArgs :: Int
, arguments :: [Id]
}
deriving (Show)
-- | What the code generator records about a data constructor.
--
-- * @numArgsCI@: number of constructor arguments.
-- * @argumentsCI@: argument identifiers (built by @createArgs@).
-- * @numCI@: position of the constructor inside its data declaration,
--   counted from 0 (see 'getConstructors').
-- * @returnTypeCI@: the data type the constructor belongs to.
data ConstructorInfo = ConstructorInfo
{ numArgsCI :: Int
, argumentsCI :: [Id]
, numCI :: Integer
{ numArgsCI :: Int
, argumentsCI :: [Id]
, numCI :: Integer
, returnTypeCI :: MIR.Type
}
deriving (Show)
@ -55,7 +61,7 @@ emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
-- | Increases the variable counter in the CodeGenerator state.
--
-- Before the bump it also emits an IR 'Comment' of the form
-- @increase: <new value>@ into the instruction stream.
-- NOTE(review): a comment per increment looks like debugging output --
-- confirm it is meant to stay in the generated IR.
increaseVarCount :: CompilerState ()
increaseVarCount = do
gets variableCount >>= \s -> emit.Comment $ "increase: " <> show (s + 1)
gets variableCount >>= \s -> emit . Comment $ "increase: " <> show (s + 1)
modify $ \t -> t{variableCount = variableCount t + 1}
-- | Returns the variable count from the CodeGenerator state
@ -94,23 +100,34 @@ getConstructors :: [MIR.Def] -> Map TIR.Ident ConstructorInfo
-- Collects every data constructor of the program into a map keyed by the
-- constructor identifier. Definitions that are not data declarations are
-- skipped.
--
-- For each data declaration the constructors are numbered from 0 in
-- declaration order (numCI). The argument list is the flattened constructor
-- type without its last element, and returnTypeCI is the declared data type
-- itself rather than the last element of the flattened type.
--
-- The fold conses onto the accumulator, so the intermediate list is in
-- reverse order; this does not matter because it is fed to Map.fromList.
--
-- NOTE(review): inside the lambda, `xs` (the constructor's type) shadows the
-- outer `xs` (the remaining definitions) and `id` shadows Prelude's `id`;
-- the trailing `go xs` refers to the outer one. `init` is partial and
-- assumes flattenType never returns an empty list -- TODO confirm.
-- `foldl` is the lazy left fold; `foldl'` would be the usual choice.
getConstructors bs = Map.fromList $ go bs
where
go [] = []
go (MIR.DData (MIR.Data t cons) : xs) = fst
(foldl (\(acc, i) (Inj id xs) ->
(( id, ConstructorInfo
{ numArgsCI = length (init . flattenType $ xs)
, argumentsCI = createArgs (init . flattenType $ xs)
, numCI = i
, returnTypeCI = t --last . flattenType $ xs
}
) : acc, i + 1)) ([], 0) cons) <> go xs
go (MIR.DData (MIR.Data t cons) : xs) =
fst
( foldl
( \(acc, i) (Inj id xs) ->
( ( id
, ConstructorInfo
{ numArgsCI = length (init . flattenType $ xs)
, argumentsCI = createArgs (init . flattenType $ xs)
, numCI = i
, returnTypeCI = t -- last . flattenType $ xs
}
)
: acc
, i + 1
)
)
([], 0)
cons
)
<> go xs
go (_ : xs) = go xs
-- | Collects the LLVM type of every data declaration in the program.
-- Each declared data type is converted with 'type2LlvmType'; all other
-- definitions are ignored. Duplicates collapse because the result is a 'Set'.
getTypes :: [MIR.Def] -> Set LLVMType
getTypes bs = Set.fromList $ go bs
where
go [] = []
go [] = []
go (MIR.DData (MIR.Data t _) : xs) = type2LlvmType t : go xs
go (_:xs) = go xs
go (_ : xs) = go xs
initCodeGenerator :: [MIR.Def] -> CodeGenerator
initCodeGenerator scs =
@ -165,6 +182,7 @@ test v =
eCaseInt x xs = (ECase (MIR.TLit (MIR.Ident "_Int")) x xs, MIR.TLit (MIR.Ident "_Int"))
int x = (ELit (LInt x), MIR.TLit (MIR.Ident "_Int"))
-}
{- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to
Simply pipe it to LLI
@ -172,7 +190,7 @@ test v =
-- Builds the initial generator state from the program's definitions, runs
-- compileScs over them, and renders the accumulated instructions with
-- llvmIrToString.
--
-- NOTE(review): the newer line wraps the definitions in
-- `trace (show scs)`, which dumps the whole monomorphized program through
-- Debug.Trace when the list is forced. This looks like a debugging
-- leftover -- confirm it should not be removed before release.
generateCode :: MIR.Program -> Err String
generateCode (MIR.Program scs) = do
let codegen = initCodeGenerator scs
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
llvmIrToString . instructions <$> execStateT (compileScs (trace (show scs) scs)) codegen
compileScs :: [MIR.Def] -> CompilerState ()
compileScs [] = do
@ -240,16 +258,17 @@ compileScs [] = do
modify $ \s -> s{variableCount = 0}
)
c
-- Clause for a top-level binding: emits one LLVM function definition.
--
-- The older version ignored the binding's type (`_t`) and hard-coded I64 as
-- the return type of both the `Define` and the `Ret`. The newer version
-- derives the return type from the binding: it flattens the type, takes the
-- last element and converts it with type2LlvmType (`t_return`).
--
-- Steps: emit a blank line and a comment showing the name and body, emit the
-- `Define` header with the converted argument types, compile the body to a
-- value, emit the return, close the definition, reset variableCount to 0,
-- and continue with the remaining definitions.
--
-- "main" is special-cased: firstMainContent is emitted right after the
-- header and lastMainContent replaces the plain `Ret`.
-- NOTE(review): `last` is partial and assumes flattenType never returns an
-- empty list -- TODO confirm.
compileScs (MIR.DBind (MIR.Bind (name, _t) args exp) : xs) = do
compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
let t_return = type2LlvmType . last . flattenType $ t
emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args
emit $ Define FastCC I64 {-(type2LlvmType t_return)-} name args'
emit $ Define FastCC t_return name args'
when (name == "main") (mapM_ emit firstMainContent)
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit $ lastMainContent functionBody
else emit $ Ret I64 functionBody
else emit $ Ret t_return functionBody
emit DefineEnd
modify $ \s -> s{variableCount = 0}
compileScs xs
@ -267,8 +286,10 @@ compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
-- Instructions emitted at the very start of the generated `main` function.
--
-- The older version emitted a call to the GC heap initializer
-- (@_ZN2GC4Heap4initEv). The newer version is the empty list; the call is
-- kept only as a comment, so nothing is emitted at the start of `main`.
-- NOTE(review): the declaration of that symbol is still present in
-- defaultStart -- confirm whether disabling the call is temporary.
firstMainContent :: [LLVMIr]
firstMainContent =
[ UnsafeRaw "call void @_ZN2GC4Heap4initEv()\n"
]
[]
-- UnsafeRaw "call void @_ZN2GC4Heap4initEv()\n"
lastMainContent :: LLVMValue -> [LLVMIr]
lastMainContent var =
[ UnsafeRaw $
@ -284,20 +305,21 @@ defaultStart =
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
, UnsafeRaw "declare i32 @_ZN2GC4Heap4initEv()\n"
, UnsafeRaw "declare i32 @_ZN2GC4Heap5allocEm()\n"
, UnsafeRaw "declare i32 @_ZN2GC4Heap7disposeEv()\n"
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
, UnsafeRaw "declare void @_ZN2GC4Heap4initEv()\n"
, UnsafeRaw "declare void @_ZN2GC4Heap5allocEm()\n"
, UnsafeRaw "declare void @_ZN2GC4Heap7disposeEv()\n"
]
-- | Emits the instructions for a typed expression by dispatching on the
-- expression constructor to the matching @emit*@ helper.
--
-- Literals and variables ignore the type annotation; additions,
-- applications and case expressions pass it on. For case expressions every
-- branch is paired with the type of the whole expression.
--
-- NOTE(review): the let clause is 'undefined', so reaching it at run time
-- crashes the compiler instead of reporting an error through 'Err'.
compileExp :: ExpT -> CompilerState ()
compileExp (MIR.ELit lit,t) = emitLit lit
compileExp (MIR.EAdd e1 e2,t) = emitAdd t e1 e2
compileExp (MIR.ELit lit, t) = emitLit lit
compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2
-- compileExp (ESub t e1 e2) = emitSub t e1 e2
compileExp (MIR.EVar name, t) = emitIdent name
compileExp (MIR.EApp e1 e2,t) = emitApp t e1 e2
compileExp (MIR.EVar name, t) = emitIdent name
compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2
-- compileExp (EAbs t ti e) = emitAbs t ti e
compileExp (MIR.ELet binds e,t) = undefined -- emitLet binds (fst e)
compileExp (MIR.ECase e cs,t) = emitECased t e (map (t,) cs)
compileExp (MIR.ELet binds e, t) = undefined -- emitLet binds (fst e)
compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs)
-- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2
@ -319,7 +341,7 @@ emitECased t e cases = do
-- emit $ Label crashLbl
emit . UnsafeRaw $ "call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 6, i64 noundef 6)\n"
emit . UnsafeRaw $ "call i32 @exit(i32 noundef 1)\n"
mapM_ (const increaseVarCount) [0..1]
mapM_ (const increaseVarCount) [0 .. 1]
emit $ Br label
emit $ Label label
res <- getNewVar
@ -349,28 +371,28 @@ emitECased t e cases = do
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
val <- exprToValue exp
enumerateOneM_
(\i c -> do
( \i c -> do
case c of
PVar x -> do
PVar x -> do
emit . Comment $ "ident " <> show x
emit $ SetVariable (fst x) (ExtractValue (CustomType (coerce consId)) (VIdent casted Ptr) i)
PLit (l, t) -> undefined
PInj id ps -> undefined
PCatch -> pure()
PEnum id -> undefined
--case c of
-- CIdent x -> do
-- emit . Comment $ "ident " <> show x
-- emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
-- emit $ Store ty val Ptr stackPtr
-- CCons x cs -> error "nested constructor"
-- CLit l -> do
-- testVar <- getNewVar
-- emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
-- case l of
-- LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l)
-- LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c)
-- CCatch -> emit . Comment $ "Catch all"
PInj id ps -> undefined
PCatch -> pure ()
PEnum id -> undefined
-- case c of
-- CIdent x -> do
-- emit . Comment $ "ident " <> show x
-- emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
-- emit $ Store ty val Ptr stackPtr
-- CCons x cs -> error "nested constructor"
-- CLit l -> do
-- testVar <- getNewVar
-- emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
-- case l of
-- LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l)
-- LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c)
-- CCatch -> emit . Comment $ "Catch all"
)
cs
emit $ Store ty val Ptr stackPtr
@ -379,7 +401,7 @@ emitECased t e cases = do
emitCases rt ty label stackPtr vs (Branch (MIR.PLit i, _) exp) = do
emit $ Comment "Plit"
let i' = case i of
(MIR.LInt i, _) -> VInteger i
(MIR.LInt i, _) -> VInteger i
(MIR.LChar i, _) -> VChar (ord i)
ns <- getNewVar
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
@ -391,7 +413,7 @@ emitECased t e cases = do
emit $ Store ty val Ptr stackPtr
emit $ Br label
emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id,_), _) exp) = do
emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do
emit $ Comment "Pvar"
-- //TODO this is pretty disgusting and would heavily benefit from a rewrite
valPtr <- getNewVar
@ -418,7 +440,7 @@ emitECased t e cases = do
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos
--emitLet :: Bind -> Exp -> CompilerState ()
-- emitLet :: Bind -> Exp -> CompilerState ()
emitLet xs e = do
emit $
Comment $
@ -446,8 +468,7 @@ emitApp rt e1 e2 = appEmitter e1 e2 []
let visibility =
fromMaybe Local $
Global <$ Map.lookup name consts
<|>
Global <$ Map.lookup (name, t) funcs
<|> Global <$ Map.lookup (name, t) funcs
-- this piece of code could probably be improved, i.e remove the double `const Global`
args' = map (first valueGetType . dupe) args
call = Call FastCC (type2LlvmType rt) visibility name args'
@ -466,7 +487,7 @@ emitLit :: MIR.Lit -> CompilerState ()
emitLit i = do
-- !!this should never happen!!
let (i', t) = case i of
(MIR.LInt i'') -> (VInteger i'', I64)
(MIR.LInt i'') -> (VInteger i'', I64)
(MIR.LChar i'') -> (VChar $ ord i'', I8)
varCount <- getNewVar
emit $ Comment "This should not have happened!"
@ -489,16 +510,20 @@ emitSub t e1 e2 = do
exprToValue :: ExpT -> CompilerState LLVMValue
exprToValue = \case
(MIR.ELit i, t) -> pure $ case i of
(MIR.LInt i) -> VInteger i
(MIR.LInt i) -> VInteger i
(MIR.LChar i) -> VChar $ ord i
(MIR.EVar name, t) -> do
funcs <- gets functions
cons <- gets constructors
let res = Map.lookup (name, t) funcs
<|>
(\c -> FunctionInfo { numArgs = numArgsCI c
, arguments = argumentsCI c} )
<$> Map.lookup name cons
let res =
Map.lookup (name, t) funcs
<|> ( \c ->
FunctionInfo
{ numArgs = numArgsCI c
, arguments = argumentsCI c
}
)
<$> Map.lookup name cons
case res of
Just fi -> do
if numArgs fi == 0
@ -519,40 +544,42 @@ exprToValue = \case
-- Converts a monomorphized type to its LLVM counterpart.
--
-- Literal types: "Int" becomes I64, "Char" becomes I8 (the Char case exists
-- only in the newer version), and any other name becomes a CustomType with
-- that identifier.
--
-- Function types: the helper walks the right-nested TFun chain, converting
-- each argument and consing it onto the accumulator; the final non-function
-- type becomes the return type.
-- NOTE(review): because arguments are consed onto the accumulator, the
-- resulting argument list is in reverse order (for a -> b -> r it is
-- [b, a]) -- confirm that consumers of `Function` expect this.
-- The `do` in the TFun clause only introduces the `let`; no monad is
-- involved.
type2LlvmType :: MIR.Type -> LLVMType
type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of
"Int" -> I64
_ -> CustomType id
type2LlvmType (MIR.TFun t xs) = do
"Char" -> I8
_ -> CustomType id
type2LlvmType (MIR.TFun t xs) = do
let (t', xs') = function2LLVMType xs [type2LlvmType t]
Function t' xs'
where
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s)
function2LLVMType x s = (type2LlvmType x, s)
-- | Returns the LLVM type of a typed expression by converting its type
-- annotation with 'type2LlvmType'; the expression itself is not inspected.
getType :: ExpT -> LLVMType
getType (_, t) = type2LlvmType t
getType (_, t) = type2LlvmType t
-- | Produces a single identifier naming a type.
--
-- A literal type yields its own identifier. A function type yields the
-- names of its argument and result (computed recursively) joined by the
-- separator @_$_@.
extractTypeName :: MIR.Type -> TIR.Ident
extractTypeName (MIR.TLit id) = id
extractTypeName (MIR.TFun t xs) = let (TIR.Ident i) = extractTypeName t
(TIR.Ident is) = extractTypeName xs
in TIR.Ident $ i <> "_$_" <> is
extractTypeName (MIR.TFun t xs) =
let (TIR.Ident i) = extractTypeName t
(TIR.Ident is) = extractTypeName xs
in TIR.Ident $ i <> "_$_" <> is
-- | Returns the LLVM type of a value.
--
-- Integers are I64 and characters are I8. Identifiers and function values
-- carry their own type, which is returned unchanged. A constant is typed as
-- an array of I8 whose length is the length of the constant.
valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64
valueGetType (VChar _) = I8
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
valueGetType (VInteger _) = I64
valueGetType (VChar _) = I8
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
valueGetType (VFunction _ _ t) = t
-- | Size in bytes that the code generator assumes for an LLVM type.
--
-- I1 and I8 count as 1 byte, I32 as 4, I64 as 8. Pointers, references and
-- function types count as 8. An array is its element count times the
-- element size.
-- NOTE(review): every CustomType is counted as 8 bytes regardless of its
-- fields; presumably such values are handled through pointers -- TODO confirm.
typeByteSize :: LLVMType -> Integer
typeByteSize I1 = 1
typeByteSize I8 = 1
typeByteSize I32 = 4
typeByteSize I64 = 8
typeByteSize Ptr = 8
typeByteSize (Ref _) = 8
typeByteSize I1 = 1
typeByteSize I8 = 1
typeByteSize I32 = 4
typeByteSize I64 = 8
typeByteSize Ptr = 8
typeByteSize (Ref _) = 8
typeByteSize (Function _ _) = 8
typeByteSize (Array n t) = n * typeByteSize t
typeByteSize (Array n t) = n * typeByteSize t
typeByteSize (CustomType _) = 8
enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m ()

View file

@ -16,7 +16,10 @@ optimize :: String -> IO String
-- Runs the shell command `opt --O3 -S` through readCreateProcess, passing
-- the given string as the process input and returning the process output.
optimize = readCreateProcess (shell "opt --O3 -S")
-- | Compiles LLVM IR text to the executable @output/hello_world@.
--
-- The IR string is passed as the process input; the trailing @-@ argument
-- makes the compiler read it from there and @-x ir@ marks it as LLVM IR.
--
-- The older version invoked @clang@ with a fixed command string. The newer
-- version builds the command with 'unwords', invokes @clang++@ and adds
-- @-fno-exceptions@.
-- NOTE(review): the GC library flags (@-Lsrc/GC/lib/@, @-l:libgcoll.a@) sit
-- behind a @--@ line comment, so they are not part of the command --
-- confirm whether linking the GC is meant to be disabled.
compileClang :: String -> IO String
compileClang = readCreateProcess (shell "clang -x ir -o output/hello_world -")
compileClang = readCreateProcess . shell
$ unwords ["clang++"--, "-Lsrc/GC/lib/", "-l:libgcoll.a"
, "-fno-exceptions -x", "ir" ,"-o" ,"output/hello_world"
, "-"]
-- | Full back-end pipeline: runs 'optimize' on the given IR and feeds the
-- optimized result to 'compileClang'.
compile :: String -> IO String
compile s = optimize s >>= compileClang