Got higher order functions working.

This commit is contained in:
Samuel Hammersberg 2023-02-16 13:36:45 +01:00
parent 46c6f5b7ab
commit 6d9c42a03e
2 changed files with 41 additions and 20 deletions

View file

@ -126,15 +126,15 @@ compile (Program prg) = do
go :: Exp -> CompilerState () go :: Exp -> CompilerState ()
go (EInt int) = emitInt int go (EInt int) = emitInt int
go (EAdd t e1 e2) = emitAdd t e1 e2 go (EAdd t e1 e2) = emitAdd t e1 e2
-- go (ESub e1 e2) = emitSub e1 e2
-- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2
go (EId (name, _)) = emitIdent name go (EId (name, _)) = emitIdent name
go (EApp t e1 e2) = emitApp t e1 e2 go (EApp t e1 e2) = emitApp t e1 e2
go (EAbs t ti e) = emitAbs t ti e go (EAbs t ti e) = emitAbs t ti e
go (ELet binds e) = emitLet binds e go (ELet binds e) = emitLet binds e
go (EAnn _ _) = emitEAnn go (EAnn _ _) = emitEAnn
-- go (ESub e1 e2) = emitSub e1 e2
-- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2
--- aux functions --- --- aux functions ---
emitEAnn :: CompilerState () emitEAnn :: CompilerState ()
@ -160,18 +160,19 @@ compile (Program prg) = do
emitApp t e1 e2 = appEmitter t e1 e2 [] emitApp t e1 e2 = appEmitter t e1 e2 []
where where
appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState ()
appEmitter _t e1 e2 stack = do appEmitter t e1 e2 stack = do
let newStack = e2 : stack let newStack = e2 : stack
case e1 of case e1 of
EApp t' e1' e2' -> appEmitter t' e1' e2' newStack EApp _ e1' e2' -> appEmitter t e1' e2' newStack
EId id@(name, t') -> do EId id@(name, _) -> do
args <- traverse exprToValue newStack args <- traverse exprToValue newStack
vs <- getNewVar vs <- getNewVar
funcs <- gets functions funcs <- gets functions
let vis = case Map.lookup id funcs of let vis = case Map.lookup id funcs of
Nothing -> Local Nothing -> Local
Just _ -> Global Just _ -> Global
emit $ SetVariable (Ident $ show vs) (Call (type2LlvmType t') vis name (map (I64,) args)) let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args)
emit $ SetVariable (Ident $ show vs) call
x -> do x -> do
emit . Comment $ "The unspeakable happened: " emit . Comment $ "The unspeakable happened: "
emit . Comment $ show x emit . Comment $ show x
@ -247,18 +248,36 @@ compile (Program prg) = do
exprToValue (EId id@(name, t)) = do exprToValue (EId id@(name, t)) = do
funcs <- gets functions funcs <- gets functions
case Map.lookup id funcs of case Map.lookup id funcs of
Just _ -> do Just fi -> do
vc <- getNewVar if numArgs fi == 0
emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) Global name []) then do
return $ VIdent (Ident $ show vc, t) vc <- getNewVar
Nothing -> return $ VIdent id emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) Global name [])
return $ VIdent (Ident $ show vc) (type2LlvmType t)
else return $ VFunction name Global (type2LlvmType t)
Nothing -> return $ VIdent name (type2LlvmType t)
exprToValue e = do exprToValue e = do
go e go e
v <- getVarCount v <- getVarCount
return $ VIdent (Ident $ show v, TInt) return $ VIdent (Ident $ show v) (getType e)
type2LlvmType :: Type -> LLVMType type2LlvmType :: Type -> LLVMType
type2LlvmType = \case type2LlvmType = \case
TInt -> I64 TInt -> I64
TFun t xs -> Function (type2LlvmType t) [type2LlvmType xs] TFun t xs -> Function (type2LlvmType t) [type2LlvmType xs]
t -> CustomType $ Ident ("\"" ++ show t ++ "\"") t -> CustomType $ Ident ("\"" ++ show t ++ "\"")
getType :: Exp -> LLVMType
getType (EInt _) = I64
getType (EAdd t _ _) = type2LlvmType t
getType (EId (_, t)) = type2LlvmType t
getType (EApp t _ _) = type2LlvmType t
getType (EAbs t _ _) = type2LlvmType t
getType (ELet _ e) = getType e
getType (EAnn _ t) = type2LlvmType t
valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (length s) I8
valueGetType (VFunction _ _ t) = t

View file

@ -21,7 +21,7 @@ data LLVMType
| Ptr | Ptr
| Ref LLVMType | Ref LLVMType
| Function LLVMType [LLVMType] | Function LLVMType [LLVMType]
| Array Integer LLVMType | Array Int LLVMType
| CustomType Ident | CustomType Ident
instance Show LLVMType where instance Show LLVMType where
@ -71,13 +71,18 @@ instance Show Visibility where
{- | Represents a LLVM "value", as in an integer, a register variable, {- | Represents a LLVM "value", as in an integer, a register variable,
or a string contstant or a string contstant
-} -}
data LLVMValue = VInteger Integer | VIdent Id | VConstant String data LLVMValue
= VInteger Integer
| VIdent Ident LLVMType
| VConstant String
| VFunction Ident Visibility LLVMType
instance Show LLVMValue where instance Show LLVMValue where
show :: LLVMValue -> String show :: LLVMValue -> String
show v = case v of show v = case v of
VInteger i -> show i VInteger i -> show i
VIdent (n, _) -> "%" <> fromIdent n VIdent (Ident n) _ -> "%" <> n
VFunction (Ident n) vis _ -> show vis <> n
VConstant s -> "c" <> show s VConstant s -> "c" <> show s
type Params = [(Ident, LLVMType)] type Params = [(Ident, LLVMType)]
@ -201,6 +206,3 @@ llvmIrToString = go 0
(Comment s) -> "; " <> s <> "\n" (Comment s) -> "; " <> s <> "\n"
(Variable (Ident id)) -> "%" <> id (Variable (Ident id)) -> "%" <> id
{- FOURMOLU_ENABLE -} {- FOURMOLU_ENABLE -}
fromIdent :: Ident -> String
fromIdent (Ident s) = s