We got pattern matching on data types!

This commit is contained in:
Samuel Hammersberg 2023-03-29 14:31:24 +02:00
parent 2860d47f11
commit 100b7b113a
3 changed files with 58 additions and 40 deletions

View file

@ -20,8 +20,6 @@ import Data.Coerce (coerce)
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as Map import Data.Map qualified as Map
import Data.Maybe (fromJust, fromMaybe) import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Tuple.Extra (dupe, first, second) import Data.Tuple.Extra (dupe, first, second)
import Debug.Trace (trace) import Debug.Trace (trace)
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
@ -32,7 +30,7 @@ import TypeChecker.TypeCheckerIr qualified as TIR
data CodeGenerator = CodeGenerator data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr] { instructions :: [LLVMIr]
, functions :: Map MIR.Id FunctionInfo , functions :: Map MIR.Id FunctionInfo
, customTypes :: Set LLVMType , customTypes :: Map LLVMType Integer
, constructors :: Map TIR.Ident ConstructorInfo , constructors :: Map TIR.Ident ConstructorInfo
, variableCount :: Integer , variableCount :: Integer
, labelCount :: Integer , labelCount :: Integer
@ -60,9 +58,7 @@ emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
-- | Increases the variable counter in the CodeGenerator state -- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState () increaseVarCount :: CompilerState ()
increaseVarCount = do increaseVarCount = modify $ \t -> t{variableCount = variableCount t + 1}
gets variableCount >>= \s -> emit . Comment $ "increase: " <> show (s + 1)
modify $ \t -> t{variableCount = variableCount t + 1}
-- | Returns the variable count from the CodeGenerator state -- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer getVarCount :: CompilerState Integer
@ -122,12 +118,14 @@ getConstructors bs = Map.fromList $ go bs
<> go xs <> go xs
go (_ : xs) = go xs go (_ : xs) = go xs
getTypes :: [MIR.Def] -> Set LLVMType getTypes :: [MIR.Def] -> Map LLVMType Integer
getTypes bs = Set.fromList $ go bs getTypes bs = Map.fromList $ go bs
where where
go [] = [] go [] = []
go (MIR.DData (MIR.Data t _) : xs) = type2LlvmType t : go xs go (MIR.DData (MIR.Data t ts) : xs) = (type2LlvmType t, biggestVariant ts) : go xs
go (_ : xs) = go xs go (_ : xs) = go xs
variantTypes fi = init $ map type2LlvmType (flattenType fi)
biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
initCodeGenerator :: [MIR.Def] -> CodeGenerator initCodeGenerator :: [MIR.Def] -> CodeGenerator
initCodeGenerator scs = initCodeGenerator scs =
@ -225,6 +223,7 @@ compileScs [] = do
-- get a pointer of the correct type -- get a pointer of the correct type
ptr' <- getNewVar ptr' <- getNewVar
emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id)) emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id))
cTypes <- gets customTypes
enumerateOneM_ enumerateOneM_
( \i (TIR.Ident arg_n, arg_t) -> do ( \i (TIR.Ident arg_n, arg_t) -> do
@ -243,6 +242,15 @@ compileScs [] = do
I32 I32
(VInteger i) (VInteger i)
) )
case Map.lookup arg_t' cTypes of
Just s -> do
emit $ Comment "Malloc and store"
heapPtr <- getNewVar
emit $ SetVariable heapPtr (Malloca s)
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
Nothing -> do
emit $ Comment "Just store"
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
) )
(argumentsCI ci) (argumentsCI ci)
@ -274,12 +282,15 @@ compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
compileScs xs compileScs xs
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
let (TIR.Ident outer_id) = extractTypeName typ let (TIR.Ident outer_id) = extractTypeName typ
-- //TODO this could be extracted from the customTypes map
let variantTypes fi = init $ map type2LlvmType (flattenType fi) let variantTypes fi = init $ map type2LlvmType (flattenType fi)
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8] emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8]
typeSets <- gets customTypes
mapM_ mapM_
( \(Inj inner_id fi) -> do ( \(Inj inner_id fi) -> do
emit $ LIR.Type inner_id (I8 : variantTypes fi) let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
emit $ LIR.Type inner_id (I8 : types)
) )
ts ts
compileScs xs compileScs xs
@ -369,32 +380,28 @@ emitECased t e cases = do
emit $ SetVariable castPtr (Alloca rt) emit $ SetVariable castPtr (Alloca rt)
emit $ Store rt vs Ptr castPtr emit $ Store rt vs Ptr castPtr
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr) emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
val <- exprToValue exp
enumerateOneM_ enumerateOneM_
( \i c -> do ( \i c -> do
case c of case c of
PVar x -> do PVar (x, topT) -> do
emit . Comment $ "ident " <> show x let topT' = type2LlvmType topT
emit $ SetVariable (fst x) (ExtractValue (CustomType (coerce consId)) (VIdent casted Ptr) i) let botT' = CustomType (coerce consId)
emit . Comment $ "ident " <> toIr topT'
cTypes <- gets customTypes
if Map.member topT' cTypes
then do
emit . Comment $ "tjabatjena"
deref <- getNewVar
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
emit $ SetVariable x (Load topT' Ptr deref)
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
PLit (_l, _t) -> undefined PLit (_l, _t) -> undefined
PInj _id _ps -> undefined PInj _id _ps -> undefined
PCatch -> pure () PCatch -> pure ()
PEnum _id -> undefined PEnum _id -> undefined
-- case c of
-- CIdent x -> do
-- emit . Comment $ "ident " <> show x
-- emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
-- emit $ Store ty val Ptr stackPtr
-- CCons x cs -> error "nested constructor"
-- CLit l -> do
-- testVar <- getNewVar
-- emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
-- case l of
-- LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l)
-- LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c)
-- CCatch -> emit . Comment $ "Catch all"
) )
cs cs
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
emit $ Label lbl_failPos emit $ Label lbl_failPos

View file

@ -225,7 +225,7 @@ llvmIrToString = go 0
(Alloca t) -> unwords ["alloca", toIr t, "\n"] (Alloca t) -> unwords ["alloca", toIr t, "\n"]
(Malloca t) -> (Malloca t) ->
concat concat
[ "call ptr @malloc(i32 ", show t, ")"] [ "call ptr @malloc(i32 ", show t, ")\n"]
(Store t1 val t2 (Ident id2)) -> (Store t1 val t2 (Ident id2)) ->
concat concat
[ "store ", toIr t1, " ", toIr val [ "store ", toIr t1, " ", toIr val

View file

@ -1,13 +1,24 @@
id x = x; -- a simple list data type containing ints
data List () where {
const x y = x ; Cons : Int -> List () -> List ()
Nil : List ()
data Maybe () where {
Just : Int -> Maybe ()
Nothing : Maybe ()
}; };
main = case (Just 5) of { main = sumlength (Cons 1 (Cons 2 (Cons 3 (Cons 4 (Cons 5 Nil)))));
Just a => 10 ;
Nothing => 0 ; -- take the length of a list
}; --const (id 0) (id 'a') ; length : List () -> Int ;
length x = case x of {
Cons _ xs => 1 + length xs ;
Nil => 0 ;
};
-- sum a list
sum : List () -> Int ;
sum x = case x of {
Cons a xs => a + sum xs ;
Nil => 0 ;
};
-- sum + length of a list
sumlength: List () -> Int ;
sumlength x = sum x + length x ;