Fixed scoping of function pointers.

This commit is contained in:
Samuel Hammersberg 2023-02-16 11:17:45 +01:00
parent 5680334fde
commit 46c6f5b7ab
2 changed files with 45 additions and 26 deletions

View file

@ -13,6 +13,7 @@ import LlvmIr (
LLVMIr (..), LLVMIr (..),
LLVMType (..), LLVMType (..),
LLVMValue (..), LLVMValue (..),
Visibility (..),
llvmIrToString, llvmIrToString,
) )
import TypeChecker (partitionType) import TypeChecker (partitionType)
@ -108,7 +109,7 @@ compile (Program prg) = do
goDef :: [Bind] -> CompilerState () goDef :: [Bind] -> CompilerState ()
goDef [] = return () goDef [] = return ()
goDef (Bind id@(name, t) args exp : xs) = do goDef (Bind (name, t) args exp : xs) = do
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
emit $ Comment $ show name <> ": " <> show exp emit $ Comment $ show name <> ": " <> show exp
emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args) emit $ Define (type2LlvmType t_return) name (map (second type2LlvmType) args)
@ -159,14 +160,18 @@ compile (Program prg) = do
emitApp t e1 e2 = appEmitter t e1 e2 [] emitApp t e1 e2 = appEmitter t e1 e2 []
where where
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
appEmitter t e1 e2 stack = do appEmitter _t e1 e2 stack = do
let newStack = e2 : stack let newStack = e2 : stack
case e1 of case e1 of
EApp t' e1' e2' -> appEmitter t' e1' e2' newStack EApp t' e1' e2' -> appEmitter t' e1' e2' newStack
EId (name, _) -> do EId id@(name, t') -> do
args <- traverse exprToValue newStack args <- traverse exprToValue newStack
vs <- getNewVar vs <- getNewVar
emit $ SetVariable (Ident $ show vs) (Call (type2LlvmType t) name (map (I64,) args)) funcs <- gets functions
let vis = case Map.lookup id funcs of
Nothing -> Local
Just _ -> Global
emit $ SetVariable (Ident $ show vs) (Call (type2LlvmType t') vis name (map (I64,) args))
x -> do x -> do
emit . Comment $ "The unspeakable happened: " emit . Comment $ "The unspeakable happened: "
emit . Comment $ show x emit . Comment $ show x
@ -244,7 +249,7 @@ compile (Program prg) = do
case Map.lookup id funcs of case Map.lookup id funcs of
Just _ -> do Just _ -> do
vc <- getNewVar vc <- getNewVar
emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) name []) emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) Global name [])
return $ VIdent (Ident $ show vc, t) return $ VIdent (Ident $ show vc, t)
Nothing -> return $ VIdent id Nothing -> return $ VIdent id
exprToValue e = do exprToValue e = do
@ -255,5 +260,5 @@ compile (Program prg) = do
type2LlvmType :: Type -> LLVMType type2LlvmType :: Type -> LLVMType
type2LlvmType = \case type2LlvmType = \case
TInt -> I64 TInt -> I64
TFun t _ -> type2LlvmType t TFun t xs -> Function (type2LlvmType t) [type2LlvmType xs]
t -> CustomType $ Ident ("\"" ++ show t ++ "\"") t -> CustomType $ Ident ("\"" ++ show t ++ "\"")

View file

@ -1,11 +1,17 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module LlvmIr (LLVMType (..), LLVMIr (..), llvmIrToString, LLVMValue (..), LLVMComp (..)) where module LlvmIr (
LLVMType (..),
LLVMIr (..),
llvmIrToString,
LLVMValue (..),
LLVMComp (..),
Visibility (..),
) where
import Data.List (intercalate) import Data.List (intercalate)
import TypeCheckerIr import TypeCheckerIr
-- | A datatype which represents some basic LLVM types -- | A datatype which represents some basic LLVM types
data LLVMType data LLVMType
= I1 = I1
@ -14,6 +20,7 @@ data LLVMType
| I64 | I64
| Ptr | Ptr
| Ref LLVMType | Ref LLVMType
| Function LLVMType [LLVMType]
| Array Integer LLVMType | Array Integer LLVMType
| CustomType Ident | CustomType Ident
@ -26,6 +33,7 @@ instance Show LLVMType where
I64 -> "i64" I64 -> "i64"
Ptr -> "ptr" Ptr -> "ptr"
Ref ty -> show ty <> "*" Ref ty -> show ty <> "*"
Function t xs -> show t <> " (" <> intercalate ", " (map show xs) <> ")*"
Array n ty -> concat ["[", show n, " x ", show ty, "]"] Array n ty -> concat ["[", show n, " x ", show ty, "]"]
CustomType (Ident ty) -> ty CustomType (Ident ty) -> ty
@ -54,6 +62,12 @@ instance Show LLVMComp where
LLSlt -> "slt" LLSlt -> "slt"
LLSle -> "sle" LLSle -> "sle"
data Visibility = Local | Global
instance Show Visibility where
show :: Visibility -> String
show Local = "%"
show Global = "@"
{- | Represents a LLVM "value", as in an integer, a register variable, {- | Represents a LLVM "value", as in an integer, a register variable,
or a string contstant or a string contstant
-} -}
@ -85,7 +99,7 @@ data LLVMIr
| Br Ident | Br Ident
| BrCond LLVMValue Ident Ident | BrCond LLVMValue Ident Ident
| Label Ident | Label Ident
| Call LLVMType Ident Args | Call LLVMType Visibility Ident Args
| Alloca LLVMType | Alloca LLVMType
| Store LLVMType Ident LLVMType Ident | Store LLVMType Ident LLVMType Ident
| Bitcast LLVMType Ident LLVMType | Bitcast LLVMType Ident LLVMType
@ -149,9 +163,9 @@ llvmIrToString = go 0
[ "srem ", show t, " ", show v1, ", " [ "srem ", show t, " ", show v1, ", "
, show v2, "\n" , show v2, "\n"
] ]
(Call t (Ident i) arg) -> (Call t vis (Ident i) arg) ->
concat concat
[ "call ", show t, " @", i, "(" [ "call ", show t, " ", show vis, i, "("
, intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg , intercalate ", " $ Prelude.map (\(x, y) -> show x <> " " <> show y) arg
, ")\n" , ")\n"
] ]