created dummy monomorphizer

This commit is contained in:
sebastianselander 2023-03-23 17:20:19 +01:00
parent 42c8ebc7b6
commit e3df4192bb
6 changed files with 279 additions and 393 deletions

View file

@ -34,7 +34,6 @@ executable language
TypeChecker.TypeChecker TypeChecker.TypeChecker
TypeChecker.TypeCheckerIr TypeChecker.TypeCheckerIr
Renamer.Renamer Renamer.Renamer
LambdaLifter.LambdaLifter
Codegen.Codegen Codegen.Codegen
Codegen.LlvmIr Codegen.LlvmIr

View file

@ -1,56 +1,69 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (generateCode) where module Codegen.Codegen (generateCode) where
import Auxiliary (snoc)
import Codegen.LlvmIr (CallingConvention (..), import Auxiliary (snoc)
LLVMComp (..), LLVMIr (..), import Codegen.LlvmIr (
LLVMType (..), LLVMValue (..), CallingConvention (..),
Visibility (..), llvmIrToString) LLVMComp (..),
import Codegen.LlvmIr as LIR LLVMIr (..),
import Control.Applicative ((<|>)) LLVMType (..),
import Control.Monad.State (StateT, execStateT, foldM_, LLVMValue (..),
gets, modify) Visibility (..),
import qualified Data.Bifunctor as BI llvmIrToString,
import Data.List.Extra (trim) )
import Data.Map (Map) import Codegen.LlvmIr as LIR
import qualified Data.Map as Map import Control.Applicative ((<|>))
import Data.Maybe (fromJust, fromMaybe) import Control.Monad.State (
import Data.Tuple.Extra (dupe, first, second) StateT,
import qualified Grammar.Abs as GA execStateT,
import Grammar.ErrM (Err) foldM_,
import Monomorphizer.MonomorphizerIr as MIR gets,
import System.Process.Extra (readCreateProcess, shell) modify,
)
import Data.Bifunctor qualified as BI
import Data.List.Extra (trim)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromJust, fromMaybe)
import Data.Tuple.Extra (dupe, first, second)
import Grammar.Abs qualified as GA
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR
import System.Process.Extra (readCreateProcess, shell)
-- | The record used as the code generator state -- | The record used as the code generator state
data CodeGenerator = CodeGenerator data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr] { instructions :: [LLVMIr]
, functions :: Map Id FunctionInfo , functions :: Map Id FunctionInfo
, constructors :: Map Id ConstructorInfo , constructors :: Map Id ConstructorInfo
, variableCount :: Integer , variableCount :: Integer
, labelCount :: Integer , labelCount :: Integer
} }
-- | A state type synonym -- | A state type synonym
type CompilerState a = StateT CodeGenerator Err a type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo data FunctionInfo = FunctionInfo
{ numArgs :: Int { numArgs :: Int
, arguments :: [Id] , arguments :: [Id]
} deriving Show }
deriving (Show)
data ConstructorInfo = ConstructorInfo data ConstructorInfo = ConstructorInfo
{ numArgsCI :: Int { numArgsCI :: Int
, argumentsCI :: [Id] , argumentsCI :: [Id]
, numCI :: Integer , numCI :: Integer
} deriving Show }
deriving (Show)
-- | Adds an instruction to the CodeGenerator state -- | Adds an instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState () emit :: LLVMIr -> CompilerState ()
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l $ instructions t } emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
-- | Increases the variable counter in the CodeGenerator state -- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState () increaseVarCount :: CompilerState ()
increaseVarCount = modify $ \t -> t { variableCount = variableCount t + 1 } increaseVarCount = modify $ \t -> t{variableCount = variableCount t + 1}
-- | Returns the variable count from the CodeGenerator state -- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer getVarCount :: CompilerState Integer
@ -66,76 +79,106 @@ getNewLabel = do
modify (\t -> t{labelCount = labelCount t + 1}) modify (\t -> t{labelCount = labelCount t + 1})
gets labelCount gets labelCount
-- | Produces a map of function infos from a list of binds, {- | Produces a map of function infos from a list of binds,
-- which contains useful data for code generation. which contains useful data for code generation.
-}
getFunctions :: [Bind] -> Map Id FunctionInfo getFunctions :: [Bind] -> Map Id FunctionInfo
getFunctions bs = Map.fromList $ go bs getFunctions bs = Map.fromList $ go bs
where where
go [] = [] go [] = []
go (Bind id args _ : xs) = go (Bind id args _ : xs) =
(id, FunctionInfo { numArgs=length args, arguments=args }) (id, FunctionInfo{numArgs = length args, arguments = args})
: go xs : go xs
go (DataType n cons : xs) = do go (DataType n cons : xs) =
map (\(Constructor id xs) -> ((id, MIR.Type n), FunctionInfo { do
numArgs=length xs, arguments=createArgs xs map
})) cons ( \(Constructor id xs) ->
<> go xs ( (id, MIR.Type n)
, FunctionInfo
{ numArgs = length xs
, arguments = createArgs xs
}
)
)
cons
<> go xs
createArgs :: [Type] -> [Id] createArgs :: [Type] -> [Id]
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(GA.Ident ("arg_" <> show l) , t)],l+1)) ([], 0) xs createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(GA.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
-- | Produces a map of constructor infos from a list of binds, {- | Produces a map of constructor infos from a list of binds,
-- which contains useful data for code generation. which contains useful data for code generation.
-}
getConstructors :: [Bind] -> Map Id ConstructorInfo getConstructors :: [Bind] -> Map Id ConstructorInfo
getConstructors bs = Map.fromList $ go bs getConstructors bs = Map.fromList $ go bs
where where
go [] = [] go [] = []
go (DataType (GA.Ident n) cons : xs) = do go (DataType (GA.Ident n) cons : xs) =
fst (foldl (\(acc,i) (Constructor (GA.Ident id) xs) -> (((GA.Ident (n <> "_" <> id), MIR.Type (GA.Ident n)), ConstructorInfo { do
numArgsCI=length xs, fst
argumentsCI=createArgs xs, ( foldl
numCI=i ( \(acc, i) (Constructor (GA.Ident id) xs) ->
}) : acc, i+1)) ([],0) cons) ( ( (GA.Ident (n <> "_" <> id), MIR.Type (GA.Ident n))
<> go xs , ConstructorInfo
go (_: xs) = go xs { numArgsCI = length xs
, argumentsCI = createArgs xs
, numCI = i
}
)
: acc
, i + 1
)
)
([], 0)
cons
)
<> go xs
go (_ : xs) = go xs
initCodeGenerator :: [Bind] -> CodeGenerator initCodeGenerator :: [Bind] -> CodeGenerator
initCodeGenerator scs = CodeGenerator { instructions = defaultStart initCodeGenerator scs =
, functions = getFunctions scs CodeGenerator
, constructors = getConstructors scs { instructions = defaultStart
, variableCount = 0 , functions = getFunctions scs
, labelCount = 0 , constructors = getConstructors scs
} , variableCount = 0
, labelCount = 0
}
run :: Err String -> IO () run :: Err String -> IO ()
run s = do run s = do
let s' = case s of let s' = case s of
Right s -> s Right s -> s
Left _ -> error "yo" Left _ -> error "yo"
writeFile "output/llvm.ll" s' writeFile "output/llvm.ll" s'
putStrLn . trim =<< readCreateProcess (shell "lli") s' putStrLn . trim =<< readCreateProcess (shell "lli") s'
test :: Integer -> Program test :: Integer -> Program
test v = Program test v =
[ DataType (GA.Ident "Craig") [ Program
Constructor (GA.Ident "Bob") [MIR.Type (GA.Ident "_Int")], [ DataType
Constructor (GA.Ident "Betty") [MIR.Type (GA.Ident "_Int")] (GA.Ident "Craig")
] [ Constructor (GA.Ident "Bob") [MIR.Type (GA.Ident "_Int")]
, DataType (GA.Ident "Alice") [ , Constructor (GA.Ident "Betty") [MIR.Type (GA.Ident "_Int")]
Constructor (GA.Ident "Eve") [MIR.Type (GA.Ident "_Int")]--,
--(GA.Ident "Alice", [TInt, TInt])
]
, Bind (GA.Ident "fibonacci", MIR.Type (GA.Ident "_Int")) [(GA.Ident "x", MIR.Type (GA.Ident "_Int"))] (EId ("x", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig"))
, Bind (GA.Ident "main", MIR.Type (GA.Ident "_Int")) []
--(EApp (MIR.Type (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))-- (EInt 92)
$ eCaseInt (EApp (MIR.Type (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))
[ injectionCons "Craig_Bob" "Craig" [CIdent (GA.Ident "x")] (EId (GA.Ident "x", MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "_Int"))
, injectionCons "Craig_Betty" "Craig" [CLit (LInt 5)] (int 2)
, Injection (CIdent (GA.Ident "z")) (int 3)
--, injectionInt 5 (int 6)
, injectionCatchAll (int 10)
] ]
] , DataType
(GA.Ident "Alice")
[ Constructor (GA.Ident "Eve") [MIR.Type (GA.Ident "_Int")] -- ,
-- (GA.Ident "Alice", [TInt, TInt])
]
, Bind (GA.Ident "fibonacci", MIR.Type (GA.Ident "_Int")) [(GA.Ident "x", MIR.Type (GA.Ident "_Int"))] (EId ("x", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig"))
, Bind (GA.Ident "main", MIR.Type (GA.Ident "_Int")) []
-- (EApp (MIR.Type (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))-- (EInt 92)
$
eCaseInt
(EApp (MIR.Type (GA.Ident "Craig")) (EId (GA.Ident "Craig_Bob", MIR.Type (GA.Ident "Craig")), MIR.Type (GA.Ident "Craig")) (ELit (LInt v), MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "Craig"))
[ injectionCons "Craig_Bob" "Craig" [CIdent (GA.Ident "x")] (EId (GA.Ident "x", MIR.Type (GA.Ident "_Int")), MIR.Type (GA.Ident "_Int"))
, injectionCons "Craig_Betty" "Craig" [CLit (LInt 5)] (int 2)
, Injection (CIdent (GA.Ident "z")) (int 3)
, -- , injectionInt 5 (int 6)
injectionCatchAll (int 10)
]
]
where where
injectionCons x y xs = Injection (CCons (GA.Ident x, MIR.Type (GA.Ident y)) xs) injectionCons x y xs = Injection (CCons (GA.Ident x, MIR.Type (GA.Ident y)) xs)
injectionInt x = Injection (CLit (LInt x)) injectionInt x = Injection (CLit (LInt x))
@ -153,11 +196,12 @@ generateCode (Program scs) = do
llvmIrToString . instructions <$> execStateT (compileScs scs) codegen llvmIrToString . instructions <$> execStateT (compileScs scs) codegen
compileScs :: [Bind] -> CompilerState () compileScs :: [Bind] -> CompilerState ()
compileScs [] = do compileScs [] = do
-- as a last step create all the constructors -- as a last step create all the constructors
-- //TODO maybe merge this with the data type match? -- //TODO maybe merge this with the data type match?
c <- gets (Map.toList . constructors) c <- gets (Map.toList . constructors)
mapM_ (\((id, t), ci) -> do mapM_
( \((id, t), ci) -> do
let t' = type2LlvmType t let t' = type2LlvmType t
let x = BI.second type2LlvmType <$> argumentsCI ci let x = BI.second type2LlvmType <$> argumentsCI ci
emit $ Define FastCC t' id x emit $ Define FastCC t' id x
@ -166,32 +210,47 @@ compileScs [] = do
-- allocated the primary type -- allocated the primary type
emit $ SetVariable top (Alloca t') emit $ SetVariable top (Alloca t')
-- set the first byte to the index of the constructor -- set the first byte to the index of the constructor
emit $ SetVariable ptr $ emit $
GetElementPtr t' (Ref t') (VIdent top I8) SetVariable ptr $
I64 (VInteger 0) GetElementPtr
I32 (VInteger 0) t'
emit $ Store I8 (VInteger $ numCI ci ) (Ref I8) ptr (Ref t')
(VIdent top I8)
I64
(VInteger 0)
I32
(VInteger 0)
emit $ Store I8 (VInteger $ numCI ci) (Ref I8) ptr
-- get a pointer of the correct type -- get a pointer of the correct type
ptr' <- getNewVar ptr' <- getNewVar
emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id)) emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id))
--emit $ UnsafeRaw "\n" -- emit $ UnsafeRaw "\n"
enumerateOneM_ (\i (GA.Ident arg_n, arg_t) -> do enumerateOneM_
let arg_t' = type2LlvmType arg_t ( \i (GA.Ident arg_n, arg_t) -> do
emit $ Comment (toIr arg_t' <>" "<> arg_n <> " " <> show i ) let arg_t' = type2LlvmType arg_t
elemPtr <- getNewVar emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
emit $ SetVariable elemPtr ( elemPtr <- getNewVar
GetElementPtr (CustomType id) (Ref (CustomType id)) emit $
(VIdent ptr' Ptr) SetVariable
I64 (VInteger 0) elemPtr
I32 (VInteger i)) ( GetElementPtr
emit $ Store arg_t' (VIdent (GA.Ident arg_n) arg_t') Ptr elemPtr (CustomType id)
) (argumentsCI ci) (Ref (CustomType id))
(VIdent ptr' Ptr)
I64
(VInteger 0)
I32
(VInteger i)
)
emit $ Store arg_t' (VIdent (GA.Ident arg_n) arg_t') Ptr elemPtr
)
(argumentsCI ci)
--emit $ UnsafeRaw "\n" -- emit $ UnsafeRaw "\n"
-- load and return the constructed value -- load and return the constructed value
emit $ Comment "Return the newly constructed value" emit $ Comment "Return the newly constructed value"
@ -200,8 +259,9 @@ compileScs [] = do
emit $ Ret t' (VIdent load t') emit $ Ret t' (VIdent load t')
emit DefineEnd emit DefineEnd
modify $ \s -> s { variableCount = 0 } modify $ \s -> s{variableCount = 0}
) c )
c
compileScs (Bind (name, _t) args exp : xs) = do compileScs (Bind (name, _t) args exp : xs) = do
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp emit . Comment $ show name <> ": " <> show exp
@ -212,18 +272,20 @@ compileScs (Bind (name, _t) args exp : xs) = do
then mapM_ emit $ mainContent functionBody then mapM_ emit $ mainContent functionBody
else emit $ Ret I64 functionBody else emit $ Ret I64 functionBody
emit DefineEnd emit DefineEnd
modify $ \s -> s { variableCount = 0 } modify $ \s -> s{variableCount = 0}
compileScs xs compileScs xs
compileScs (DataType id@(GA.Ident outer_id) ts : xs) = do compileScs (DataType id@(GA.Ident outer_id) ts : xs) = do
let biggestVariant = maximum ((\(Constructor _ t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts) let biggestVariant = maximum ((\(Constructor _ t) -> sum $ typeByteSize . type2LlvmType <$> t) <$> ts)
emit $ LIR.Type id [I8, Array biggestVariant I8] emit $ LIR.Type id [I8, Array biggestVariant I8]
mapM_ (\(Constructor (GA.Ident inner_id) fi) -> do mapM_
( \(Constructor (GA.Ident inner_id) fi) -> do
emit $ LIR.Type (GA.Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi) emit $ LIR.Type (GA.Ident $ outer_id <> "_" <> inner_id) (I8 : map type2LlvmType fi)
) ts )
ts
compileScs xs compileScs xs
-- where -- where
-- _t_return = snd $ partitionType (length args) t -- _t_return = snd $ partitionType (length args) t
mainContent :: LLVMValue -> [LLVMIr] mainContent :: LLVMValue -> [LLVMIr]
mainContent var = mainContent var =
@ -233,7 +295,7 @@ mainContent var =
-- " %3 = bitcast %Craig* %2 to i72*\n" <> -- " %3 = bitcast %Craig* %2 to i72*\n" <>
-- " %4 = load i72, ptr %3\n" <> -- " %4 = load i72, ptr %3\n" <>
-- " call i32 (ptr, ...) @printf(ptr noundef @.str, i72 noundef %4)\n" -- " call i32 (ptr, ...) @printf(ptr noundef @.str, i72 noundef %4)\n"
"call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n" "call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef " <> toIr var <> ")\n"
, -- , SetVariable (GA.Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2)) , -- , SetVariable (GA.Ident "p") (Icmp LLEq I64 (VInteger 2) (VInteger 2))
-- , BrCond (VIdent (GA.Ident "p")) (GA.Ident "b_1") (GA.Ident "b_2") -- , BrCond (VIdent (GA.Ident "p")) (GA.Ident "b_1") (GA.Ident "b_2")
-- , Label (GA.Ident "b_1") -- , Label (GA.Ident "b_1")
@ -249,24 +311,26 @@ mainContent var =
] ]
defaultStart :: [LLVMIr] defaultStart :: [LLVMIr]
defaultStart = [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n" defaultStart =
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%x\n\", align 1\n" , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%x\n\", align 1\n"
] , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
]
compileExp :: Exp -> CompilerState () compileExp :: Exp -> CompilerState ()
compileExp (ELit lit) = emitLit lit compileExp (ELit lit) = emitLit lit
compileExp (EAdd t e1 e2) = emitAdd t (fst e1) (fst e2) compileExp (EAdd t e1 e2) = emitAdd t (fst e1) (fst e2)
--compileExp (ESub t e1 e2) = emitSub t e1 e2 -- compileExp (ESub t e1 e2) = emitSub t e1 e2
compileExp (EId (name, _)) = emitIdent name compileExp (EId (name, _)) = emitIdent name
compileExp (EApp t e1 e2) = emitApp t (fst e1) (fst e2) compileExp (EApp t e1 e2) = emitApp t (fst e1) (fst e2)
--compileExp (EAbs t ti e) = emitAbs t ti e -- compileExp (EAbs t ti e) = emitAbs t ti e
compileExp (ELet _ binds e) = undefined emitLet binds (fst e) compileExp (ELet _ binds e) = undefined emitLet binds (fst e)
compileExp (ECase t e cs) = emitECased t e (map (t,) cs) compileExp (ECase t e cs) = emitECased t e (map (t,) cs)
-- go (EMul e1 e2) = emitMul e1 e2
-- go (EDiv e1 e2) = emitDiv e1 e2 -- go (EMul e1 e2) = emitMul e1 e2
-- go (EMod e1 e2) = emitMod e1 e2 -- go (EDiv e1 e2) = emitDiv e1 e2
-- go (EMod e1 e2) = emitMod e1 e2
--- aux functions --- --- aux functions ---
emitECased :: Type -> ExpT -> [(Type, Injection)] -> CompilerState () emitECased :: Type -> ExpT -> [(Type, Injection)] -> CompilerState ()
@ -309,31 +373,33 @@ emitECased t e cases = do
emit $ SetVariable casted (Load (CustomType (fst consId)) Ptr castedPtr) emit $ SetVariable casted (Load (CustomType (fst consId)) Ptr castedPtr)
val <- exprToValue (fst exp) val <- exprToValue (fst exp)
enumerateOneM_ (\i c -> do enumerateOneM_
( \i c -> do
case c of case c of
CIdent x -> do CIdent x -> do
emit . Comment $ "ident " <> show x emit . Comment $ "ident " <> show x
emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i) emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
CCons x cs -> error "nested constructor" CCons x cs -> error "nested constructor"
CLit l -> do CLit l -> do
testVar <- getNewVar testVar <- getNewVar
emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i) emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
case l of case l of
LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l) LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l)
LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c) LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c)
CatchAll -> emit . Comment $ "Catch all" CatchAll -> emit . Comment $ "Catch all"
emit . Comment $ "return this " <> toIr val emit . Comment $ "return this " <> toIr val
emit . Comment . show $ c emit . Comment . show $ c
emit . Comment . show $ i emit . Comment . show $ i
) cs )
cs
-- emit $ Store ty val Ptr stackPtr -- emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Injection (MIR.CLit i) exp) = do emitCases rt ty label stackPtr vs (Injection (MIR.CLit i) exp) = do
let i' = case i of let i' = case i of
LInt i -> VInteger i LInt i -> VInteger i
LChar i -> VChar i LChar i -> VChar i
ns <- getNewVar ns <- getNewVar
lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> GA.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel lbl_succPos <- (\x -> GA.Ident $ "success_" <> show x) <$> getNewLabel
@ -359,7 +425,6 @@ emitECased t e cases = do
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
emitLet :: Bind -> Exp -> CompilerState () emitLet :: Bind -> Exp -> CompilerState ()
emitLet xs e = do emitLet xs e = do
emit $ emit $
@ -380,18 +445,18 @@ emitApp t e1 e2 = appEmitter t e1 e2 []
let newStack = e2 : stack let newStack = e2 : stack
case e1 of case e1 of
EApp _ (e1', _) (e2', _) -> appEmitter t e1' e2' newStack EApp _ (e1', _) (e2', _) -> appEmitter t e1' e2' newStack
EId id@(GA.Ident name,_ ) -> do EId id@(GA.Ident name, _) -> do
args <- traverse exprToValue newStack args <- traverse exprToValue newStack
vs <- getNewVar vs <- getNewVar
funcs <- gets functions funcs <- gets functions
consts <- gets constructors consts <- gets constructors
let visibility = fromMaybe Local $ let visibility =
Global <$ Map.lookup id consts fromMaybe Local $
<|> Global <$ Map.lookup id consts
Global <$ Map.lookup id funcs <|> Global <$ Map.lookup id funcs
-- this piece of code could probably be improved, i.e remove the double `const Global` -- this piece of code could probably be improved, i.e remove the double `const Global`
args' = map (first valueGetType . dupe) args args' = map (first valueGetType . dupe) args
call = Call FastCC (type2LlvmType t) visibility (GA.Ident name) args' call = Call FastCC (type2LlvmType t) visibility (GA.Ident name) args'
emit $ SetVariable vs call emit $ SetVariable vs call
x -> error $ "The unspeakable happened: " <> show x x -> error $ "The unspeakable happened: " <> show x
@ -405,14 +470,13 @@ emitIdent id = do
emitLit :: Lit -> CompilerState () emitLit :: Lit -> CompilerState ()
emitLit i = do emitLit i = do
-- !!this should never happen!! -- !!this should never happen!!
let (i',t) = case i of let (i', t) = case i of
(LInt i'') -> (VInteger i'',I64) (LInt i'') -> (VInteger i'', I64)
(LChar i'') -> (VChar i'', I8) (LChar i'') -> (VChar i'', I8)
varCount <- getNewVar varCount <- getNewVar
emit $ Comment "This should not have happened!" emit $ Comment "This should not have happened!"
emit $ SetVariable (GA.Ident (show varCount)) (Add t i' (VInteger 0)) emit $ SetVariable (GA.Ident (show varCount)) (Add t i' (VInteger 0))
emitAdd :: Type -> Exp -> Exp -> CompilerState () emitAdd :: Type -> Exp -> Exp -> CompilerState ()
emitAdd t e1 e2 = do emitAdd t e1 e2 = do
v1 <- exprToValue e1 v1 <- exprToValue e1
@ -430,8 +494,8 @@ emitSub t e1 e2 = do
exprToValue :: Exp -> CompilerState LLVMValue exprToValue :: Exp -> CompilerState LLVMValue
exprToValue = \case exprToValue = \case
ELit i -> pure $ case i of ELit i -> pure $ case i of
(LInt i) -> VInteger i (LInt i) -> VInteger i
(LChar i) -> VChar i (LChar i) -> VChar i
EId id@(name, t) -> do EId id@(name, t) -> do
funcs <- gets functions funcs <- gets functions
case Map.lookup id funcs of case Map.lookup id funcs of
@ -439,8 +503,10 @@ exprToValue = \case
if numArgs fi == 0 if numArgs fi == 0
then do then do
vc <- getNewVar vc <- getNewVar
emit $ SetVariable vc emit $
(Call FastCC (type2LlvmType t) Global name []) SetVariable
vc
(Call FastCC (type2LlvmType t) Global name [])
pure $ VIdent vc (type2LlvmType t) pure $ VIdent vc (type2LlvmType t)
else pure $ VFunction name Global (type2LlvmType t) else pure $ VFunction name Global (type2LlvmType t)
Nothing -> pure $ VIdent name (type2LlvmType t) Nothing -> pure $ VIdent name (type2LlvmType t)
@ -452,45 +518,45 @@ exprToValue = \case
type2LlvmType :: Type -> LLVMType type2LlvmType :: Type -> LLVMType
type2LlvmType (MIR.Type (GA.Ident t)) = case t of type2LlvmType (MIR.Type (GA.Ident t)) = case t of
"_Int" -> I64 "_Int" -> I64
t -> CustomType (GA.Ident t) t -> CustomType (GA.Ident t)
-- TInt -> I64
-- TFun t xs -> do -- TInt -> I64
-- let (t', xs') = function2LLVMType xs [type2LlvmType t] -- TFun t xs -> do
-- Function t' xs' -- let (t', xs') = function2LLVMType xs [type2LlvmType t]
-- TPol t -> CustomType t -- Function t' xs'
--where -- TPol t -> CustomType t
-- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) -- where
-- function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) -- function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
-- function2LLVMType x s = (type2LlvmType x, s) -- function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
-- function2LLVMType x s = (type2LlvmType x, s)
getType :: Exp -> LLVMType getType :: Exp -> LLVMType
getType (ELit l) = I64 getType (ELit l) = I64
getType (EAdd t _ _) = type2LlvmType t getType (EAdd t _ _) = type2LlvmType t
--getType (ESub t _ _) = type2LlvmType t -- getType (ESub t _ _) = type2LlvmType t
getType (EId (_, t)) = type2LlvmType t getType (EId (_, t)) = type2LlvmType t
getType (EApp t _ _) = type2LlvmType t getType (EApp t _ _) = type2LlvmType t
--getType (EAbs t _ _) = type2LlvmType t -- getType (EAbs t _ _) = type2LlvmType t
getType (ELet (_, t) _ e) = type2LlvmType t getType (ELet (_, t) _ e) = type2LlvmType t
getType (ECase t _ _) = type2LlvmType t getType (ECase t _ _) = type2LlvmType t
valueGetType :: LLVMValue -> LLVMType valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64 valueGetType (VInteger _) = I64
valueGetType (VChar _) = I8 valueGetType (VChar _) = I8
valueGetType (VIdent _ t) = t valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
valueGetType (VFunction _ _ t) = t valueGetType (VFunction _ _ t) = t
typeByteSize :: LLVMType -> Integer typeByteSize :: LLVMType -> Integer
typeByteSize I1 = 1 typeByteSize I1 = 1
typeByteSize I8 = 1 typeByteSize I8 = 1
typeByteSize I32 = 4 typeByteSize I32 = 4
typeByteSize I64 = 8 typeByteSize I64 = 8
typeByteSize Ptr = 8 typeByteSize Ptr = 8
typeByteSize (Ref _) = 8 typeByteSize (Ref _) = 8
typeByteSize (Function _ _) = 8 typeByteSize (Function _ _) = 8
typeByteSize (Array n t) = n * typeByteSize t typeByteSize (Array n t) = n * typeByteSize t
typeByteSize (CustomType _) = 8 typeByteSize (CustomType _) = 8
enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m () enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m ()
enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1 enumerateOneM_ f = foldM_ (\i a -> f i a >> pure (i + 1)) 1

View file

@ -1,194 +0,0 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter.LambdaLifter where
import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State, evalState)
import Data.Set (Set)
import Data.Set qualified as Set
import Renamer.Renamer
import TypeChecker.TypeChecker (partitionType)
import TypeChecker.TypeCheckerIr
import Prelude hiding (exp)
{- | Turn lambdas and let expressions into supercombinators.

The transformation is a pipeline of three passes:

  1. 'freeVars' annotates every expression with its free variables.
  2. 'abstract' rewrites each lambda as a let expression.
  3. 'collectScs' moves every let binding that takes parameters to the top level.
-}
lambdaLift :: Program -> Program
lambdaLift program = collectScs (abstract (freeVars program))
{- | Annotate every top-level binding with free-variable information.
The parameters of a binding form the initial set of local variables
used when analysing its body.
-}
freeVars :: Program -> AnnProgram
freeVars (Program defs) =
    [annotate name params body | Bind name params body <- defs]
  where
    annotate name params body =
        (name, params, freeVarsExp (Set.fromList (fst <$> params)) body)
{- | Annotate an expression (and all of its sub-expressions) with the set of
free variables it mentions.

An identifier is only reported as free when it is a member of the given set
of local variables; any other identifier contributes nothing.
-}
freeVarsExp :: Set Ident -> ExpT -> AnnExpT
freeVarsExp scope (expr, ty) = case expr of
    EId ident ->
        ( if ident `Set.member` scope then Set.singleton ident else Set.empty
        , (AId ident, ty)
        )
    ELit lit -> (Set.empty, (ALit lit, ty))
    EApp fun arg -> binary AApp fun arg
    EAdd lhs rhs -> binary AAdd lhs rhs
    -- The lambda parameter is local inside the body and bound outside of it.
    EAbs param body ->
        let body' = freeVarsExp (Set.insert param scope) body
         in (Set.delete param (freeVarsOf body'), (AAbs param body', ty))
    -- Union of the free variables of the bound expression and of the body;
    -- the bound name itself is removed from both.
    ELet (Bind (name, bindTy) params rhs) body ->
        let inner = Set.insert name scope
            rhs' = freeVarsExp inner rhs
            body' = freeVarsExp inner body
            rhsFrees = Set.delete name (freeVarsOf rhs')
            bodyFrees = Set.delete name (freeVarsOf body')
         in ( Set.union rhsFrees bodyFrees
            , (ALet (ABind (name, bindTy) params rhs') body', ty)
            )
  where
    -- Shared handling for the two-operand nodes: annotate both operands in
    -- the current scope and merge their free variables.
    binary con l r =
        let l' = freeVarsExp scope l
            r' = freeVarsExp scope r
         in (Set.union (freeVarsOf l') (freeVarsOf r'), (con l' r', ty))
-- | Project the free-variable set out of an annotated expression.
freeVarsOf :: AnnExpT -> Set Ident
freeVarsOf (frees, _) = frees
-- AST annotated with free variables
-- An annotated program: one (name, parameters, annotated body) triple per binding.
type AnnProgram = [(Id, [Id], AnnExpT)]
-- A typed expression paired with the set of free variables computed for it.
type AnnExpT = (Set Ident, AnnExpT')
-- Annotated counterpart of a binding: name, parameters and annotated right-hand side.
data ABind = ABind Id [Id] AnnExpT deriving (Show)
-- An annotated expression node together with its type.
type AnnExpT' = (AnnExp, Type)
-- Expression nodes whose children carry free-variable annotations.
-- Each constructor is produced by freeVarsExp from the E-prefixed
-- constructor of the same name (EId -> AId, ELit -> ALit, ...).
data AnnExp
= AId Ident
| ALit Lit
| ALet ABind AnnExpT
| AApp AnnExpT AnnExpT
| AAdd AnnExpT AnnExpT
| AAbs Ident AnnExpT
deriving (Show)
{- | Rewrite every lambda into a let expression of the form
@let sc = \v₁ .. vₙ x -> e in sc v₁ .. vₙ@, where @v₁ .. vₙ@ are the free
variables of the lambda, which thereby become bound parameters.

Lambdas at the very top of a binding's right-hand side are not lifted; their
parameters are appended to the binding's own parameter list instead.
A counter starting at 0 is threaded through to number the generated names.
-}
abstract :: AnnProgram -> Program
abstract annotated = Program (evalState (traverse liftBinding annotated) 0)
  where
    liftBinding :: (Id, [Id], AnnExpT) -> State Int Bind
    liftBinding (name, params, rhs) = do
        let (body, lambdaParams) = flattenLambdasAnn rhs
        body' <- abstractExp body
        pure (Bind name (params ++ lambdaParams) body')
{- | Peel nested lambdas off an annotated expression and collect their
parameters, outermost first:

@\x.\y.\z. ae → (ae, [x,y,z])@

A lambda is only peeled while the expression's type is a function type; the
parameter is typed with the argument side of that type, and it is removed
from the free-variable set of the body that remains.
-}
flattenLambdasAnn :: AnnExpT -> (AnnExpT, [Id])
flattenLambdasAnn = peel []
  where
    peel :: [Id] -> AnnExpT -> (AnnExpT, [Id])
    peel params annotated@(_, (expr, ty)) = case (expr, ty) of
        (AAbs param (bodyFrees, body), TFun paramTy _) ->
            peel (snoc (param, paramTy) params) (Set.delete param bodyFrees, body)
        _ -> (annotated, params)
{- | Convert an annotated expression back into a plain typed expression while
replacing every lambda by a let-bound function named @sc_N@ that is applied
to the lambda's free variables. @N@ is drawn from the 'State' counter.

Lambdas that sit directly at the top of a let binding's right-hand side are
not lifted; they are turned into extra parameters of that binding.
-}
abstractExp :: AnnExpT -> State Int ExpT
abstractExp (frees, (expr, ty)) = case expr of
    AId ident -> pure (typed (EId ident))
    ALit lit -> pure (typed (ELit lit))
    AApp fun arg -> typed <$> (EApp <$> abstractExp fun <*> abstractExp arg)
    AAdd lhs rhs -> typed <$> (EAdd <$> abstractExp lhs <*> abstractExp rhs)
    ALet bind body -> typed <$> (ELet <$> abstractBind bind <*> abstractExp body)
    -- Lift the lambda into a let and re-apply it to its free variables.
    AAbs param body -> do
        n <- nextNumber
        rhs <- abstractExp body
        let captured = Set.toList frees
            -- NOTE(review): presumably snoc appends, so the captured
            -- variables precede the lambda's own parameter; confirm
            -- against Auxiliary.snoc.
            scArgs = snoc param captured
            scParams = zip scArgs (fst (partitionType (length scArgs) ty))
            scName = Ident ("sc_" ++ show n)
            lifted = typed (ELet (Bind (scName, ty) scParams rhs) (EId scName, ty))
        pure (foldl applyVar lifted captured)
  where
    -- Attach the type of the expression being rewritten.
    typed e = (e, ty)

    -- Rewrite a let binding: leading lambdas of the right-hand side are kept
    -- by skipLambdas and then folded into the parameter list.
    abstractBind (ABind name params rhs) = do
        (rhs', lambdaParams) <- flattenLambdas <$> skipLambdas abstractExp rhs
        pure (Bind name (params ++ lambdaParams) rhs')

    -- Walk past leading lambdas unchanged and apply the continuation to the
    -- first non-lambda expression underneath them.
    skipLambdas :: (AnnExpT -> State Int ExpT) -> AnnExpT -> State Int ExpT
    skipLambdas k (fs, (ae, aty)) = case ae of
        AAbs par inner -> (\inner' -> (EAbs par inner', aty)) <$> skipLambdas k inner
        _ -> k (fs, (ae, aty))

    -- Apply an expression of function type to one variable, peeling one
    -- argument off the type.
    applyVar (fun, funTy) name =
        let (argTy : _, resTy) = partitionType 1 funTy
         in (EApp (fun, funTy) (EId name, argTy), resTy)
-- | Return the current counter value and advance the counter by one.
nextNumber :: State Int Int
nextNumber = get >>= \current -> current <$ put (succ current)
{- | Collect supercombinators: every top-level binding is followed in the
result by the let bindings that were lifted out of its right-hand side.
-}
collectScs :: Program -> Program
collectScs (Program binds) = Program (foldMap hoist binds)
  where
    hoist (Bind name params rhs) = Bind name params rhs' : lifted
      where
        (lifted, rhs') = collectScsExp rhs
{- | Split an expression into the bindings lifted out of it and the
expression that remains.

A let binding without parameters stays in place; a let binding with
parameters is removed from the expression and returned in the list:

> f = let sc x y = rhs in e
-}
collectScsExp :: ExpT -> ([Bind], ExpT)
collectScsExp whole@(expr, ty) = case expr of
    EId _ -> ([], whole)
    ELit _ -> ([], whole)
    EApp fun arg -> binary EApp fun arg
    EAdd lhs rhs -> binary EAdd lhs rhs
    EAbs param body ->
        let (lifted, body') = collectScsExp body
         in (lifted, (EAbs param body', ty))
    ELet (Bind name params rhs) body
        -- constant binding: keep the let, typed like its body
        | null params -> (nested, (ELet bind body', snd body'))
        -- function binding: hoist it and leave only the body behind
        | otherwise -> (bind : nested, body')
      where
        (rhsScs, rhs') = collectScsExp rhs
        (bodyScs, body') = collectScsExp body
        bind = Bind name params rhs'
        nested = rhsScs ++ bodyScs
  where
    -- Shared handling for the two-operand nodes.
    binary con l r =
        let (lScs, l') = collectScsExp l
            (rScs, r') = collectScsExp r
         in (lScs ++ rScs, (con l' r', ty))
{- | Peel leading lambdas off an expression and collect their parameters,
outermost first: @\x.\y.\z. e → (e, [x,y,z])@.
Each parameter is typed with the first argument type of the enclosing
lambda's type.
-}
flattenLambdas :: ExpT -> (ExpT, [Id])
flattenLambdas = peel []
  where
    peel params whole@(expr, ty) = case expr of
        EAbs name body ->
            let paramTy : _ = fst (partitionType 1 ty)
             in peel (snoc (name, paramTy) params) body
        _ -> (whole, params)

View file

@ -2,17 +2,16 @@
module Main where module Main where
-- import Codegen.Codegen (generateCode) import Codegen.Codegen (generateCode)
import GHC.IO.Handle.Text (hPutStrLn) import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Par (myLexer, pProgram) import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree) import Grammar.Print (printTree)
import Monomorphizer.Monomorphizer (monomorphize)
-- import Interpreter (interpret)
import Control.Monad (when) import Control.Monad (when)
import Data.List.Extra (isSuffixOf) import Data.List.Extra (isSuffixOf)
-- import LambdaLifter.LambdaLifter (lambdaLift)
import Renamer.Renamer (rename) import Renamer.Renamer (rename)
import System.Directory ( import System.Directory (
createDirectory, createDirectory,
@ -54,9 +53,9 @@ main' debug s = do
-- let lifted = lambdaLift typechecked -- let lifted = lambdaLift typechecked
-- printToErr $ printTree lifted -- printToErr $ printTree lifted
-- --
-- printToErr "\n -- Printing compiler output to stdout --" printToErr "\n -- Printing compiler output to stdout --"
-- compiled <- fromCompilerErr $ generateCode lifted compiled <- fromCompilerErr $ generateCode (monomorphize typechecked)
-- putStrLn compiled putStrLn compiled
-- check <- doesPathExist "output" -- check <- doesPathExist "output"
-- when check (removeDirectoryRecursive "output") -- when check (removeDirectoryRecursive "output")

View file

@ -1 +1,17 @@
module Monomorphizer.Monomorphizer where module Monomorphizer.Monomorphizer (monomorphize) where
import Monomorphizer.MonomorphizerIr
import TypeChecker.TypeCheckerIr qualified as T
{- | Translate a type-checked program into the monomorphizer IR by
translating each of its definitions with 'monoDefs'.
-}
monomorphize :: T.Program -> Program
monomorphize (T.Program defs) = Program (monoDefs defs)
-- | Translate a list of definitions one by one, preserving their order.
monoDefs :: [T.Def] -> [Def]
monoDefs defs = [monoDef def | def <- defs]
{- | Translate a single definition: bindings go through 'monoBind',
data declarations are carried over unchanged.
-}
monoDef :: T.Def -> Def
monoDef def = case def of
    T.DBind bind -> DBind (monoBind bind)
    T.DData dat -> DData dat
{- | Re-wrap a type-checker binding as a monomorphizer binding; name,
parameters and body are passed through untouched.
-}
monoBind :: T.Bind -> Bind
monoBind bind = case bind of
    T.Bind name params body -> Bind name params body

View file

@ -1,14 +1,19 @@
module Monomorphizer.MonomorphizerIr where module Monomorphizer.MonomorphizerIr where
import Grammar.Abs (Ident)
newtype Program = Program [Bind] import Grammar.Abs (Data, Ident, Init)
import TypeChecker.TypeCheckerIr (ExpT, Id, Indexed)
newtype Program = Program [Def]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT | DataType Ident [Constructor] data Def = DBind Bind | DData Data
deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Exp data Exp
= EId Id = EId Id
| ELit Lit | ELit Lit
| ELet Id ExpT ExpT | ELet Id ExpT ExpT
| EApp Type ExpT ExpT | EApp Type ExpT ExpT
@ -16,20 +21,15 @@ data Exp
| ECase Type ExpT [Injection] | ECase Type ExpT [Injection]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Injection = Injection Case ExpT data Injection = Injection (Init, Type) ExpT
deriving (Show, Ord, Eq) deriving (Eq, Ord, Show)
data Case = CLit Lit | CCons Id [Case] | CIdent Ident | CatchAll
deriving (Show, Ord, Eq)
data Constructor = Constructor Ident [Type] data Constructor = Constructor Ident [Type]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
type Id = (Ident, Type) data Lit
type ExpT = (Exp, Type) = LInt Integer
| LChar Char
data Lit = LInt Integer
| LChar Char
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
newtype Type = Type Ident newtype Type = Type Ident