From 6d9c42a03ee93433e36a68493a640e12d8d77f73 Mon Sep 17 00:00:00 2001 From: Samuel Hammersberg Date: Thu, 16 Feb 2023 13:36:45 +0100 Subject: [PATCH] Got higher order functions working. --- src/Compiler.hs | 47 +++++++++++++++++++++++++++++++++-------------- src/LlvmIr.hs | 14 ++++++++------ 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/src/Compiler.hs b/src/Compiler.hs index 1425a1a..a31c3d8 100644 --- a/src/Compiler.hs +++ b/src/Compiler.hs @@ -126,15 +126,15 @@ compile (Program prg) = do go :: Exp -> CompilerState () go (EInt int) = emitInt int go (EAdd t e1 e2) = emitAdd t e1 e2 - -- go (ESub e1 e2) = emitSub e1 e2 - -- go (EMul e1 e2) = emitMul e1 e2 - -- go (EDiv e1 e2) = emitDiv e1 e2 - -- go (EMod e1 e2) = emitMod e1 e2 go (EId (name, _)) = emitIdent name go (EApp t e1 e2) = emitApp t e1 e2 go (EAbs t ti e) = emitAbs t ti e go (ELet binds e) = emitLet binds e go (EAnn _ _) = emitEAnn + -- go (ESub e1 e2) = emitSub e1 e2 + -- go (EMul e1 e2) = emitMul e1 e2 + -- go (EDiv e1 e2) = emitDiv e1 e2 + -- go (EMod e1 e2) = emitMod e1 e2 --- aux functions --- emitEAnn :: CompilerState () @@ -160,18 +160,19 @@ compile (Program prg) = do emitApp t e1 e2 = appEmitter t e1 e2 [] where appEmitter :: Type -> Exp -> Exp -> [Exp] -> CompilerState () - appEmitter _t e1 e2 stack = do + appEmitter t e1 e2 stack = do let newStack = e2 : stack case e1 of - EApp t' e1' e2' -> appEmitter t' e1' e2' newStack - EId id@(name, t') -> do + EApp _ e1' e2' -> appEmitter t e1' e2' newStack + EId id@(name, _) -> do args <- traverse exprToValue newStack vs <- getNewVar funcs <- gets functions let vis = case Map.lookup id funcs of Nothing -> Local Just _ -> Global - emit $ SetVariable (Ident $ show vs) (Call (type2LlvmType t') vis name (map (I64,) args)) + let call = Call (type2LlvmType t) vis name ((\x -> (valueGetType x, x)) <$> args) + emit $ SetVariable (Ident $ show vs) call x -> do emit . Comment $ "The unspeakable happened: " emit . Comment $ show x @@ -247,18 +248,36 @@ compile (Program prg) = do exprToValue (EId id@(name, t)) = do funcs <- gets functions case Map.lookup id funcs of - Just _ -> do - vc <- getNewVar - emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) Global name []) - return $ VIdent (Ident $ show vc, t) - Nothing -> return $ VIdent id + Just fi -> do + if numArgs fi == 0 + then do + vc <- getNewVar + emit $ SetVariable (Ident $ show vc) (Call (type2LlvmType t) Global name []) + return $ VIdent (Ident $ show vc) (type2LlvmType t) + else return $ VFunction name Global (type2LlvmType t) + Nothing -> return $ VIdent name (type2LlvmType t) exprToValue e = do go e v <- getVarCount - return $ VIdent (Ident $ show v, TInt) + return $ VIdent (Ident $ show v) (getType e) type2LlvmType :: Type -> LLVMType type2LlvmType = \case TInt -> I64 TFun t xs -> Function (type2LlvmType t) [type2LlvmType xs] t -> CustomType $ Ident ("\"" ++ show t ++ "\"") + +getType :: Exp -> LLVMType +getType (EInt _) = I64 +getType (EAdd t _ _) = type2LlvmType t +getType (EId (_, t)) = type2LlvmType t +getType (EApp t _ _) = type2LlvmType t +getType (EAbs t _ _) = type2LlvmType t +getType (ELet _ e) = getType e +getType (EAnn _ t) = type2LlvmType t + +valueGetType :: LLVMValue -> LLVMType +valueGetType (VInteger _) = I64 +valueGetType (VIdent _ t) = t +valueGetType (VConstant s) = Array (length s) I8 +valueGetType (VFunction _ _ t) = t diff --git a/src/LlvmIr.hs b/src/LlvmIr.hs index f8a70fe..b29f296 100644 --- a/src/LlvmIr.hs +++ b/src/LlvmIr.hs @@ -21,7 +21,7 @@ data LLVMType | Ptr | Ref LLVMType | Function LLVMType [LLVMType] - | Array Integer LLVMType + | Array Int LLVMType | CustomType Ident instance Show LLVMType where @@ -71,13 +71,18 @@ instance Show Visibility where {- | Represents a LLVM "value", as in an integer, a register variable, or a string contstant -} -data LLVMValue = VInteger Integer | VIdent Id | VConstant String +data LLVMValue + = VInteger Integer + | VIdent Ident LLVMType + | VConstant String + | VFunction Ident Visibility LLVMType instance Show LLVMValue where show :: LLVMValue -> String show v = case v of VInteger i -> show i - VIdent (n, _) -> "%" <> fromIdent n + VIdent (Ident n) _ -> "%" <> n + VFunction (Ident n) vis _ -> show vis <> n VConstant s -> "c" <> show s type Params = [(Ident, LLVMType)] @@ -201,6 +206,3 @@ llvmIrToString = go 0 (Comment s) -> "; " <> s <> "\n" (Variable (Ident id)) -> "%" <> id {- FOURMOLU_ENABLE -} - -fromIdent :: Ident -> String -fromIdent (Ident s) = s