We got pattern matching on data types!

This commit is contained in:
Samuel Hammersberg 2023-03-29 14:31:24 +02:00
parent 2860d47f11
commit 100b7b113a
3 changed files with 58 additions and 40 deletions

View file

@ -20,8 +20,6 @@ import Data.Coerce (coerce)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromJust, fromMaybe)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Tuple.Extra (dupe, first, second)
import Debug.Trace (trace)
import Grammar.ErrM (Err)
@ -32,7 +30,7 @@ import TypeChecker.TypeCheckerIr qualified as TIR
data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr]
, functions :: Map MIR.Id FunctionInfo
, customTypes :: Set LLVMType
, customTypes :: Map LLVMType Integer
, constructors :: Map TIR.Ident ConstructorInfo
, variableCount :: Integer
, labelCount :: Integer
@ -60,9 +58,7 @@ emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
-- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState ()
increaseVarCount = do
gets variableCount >>= \s -> emit . Comment $ "increase: " <> show (s + 1)
modify $ \t -> t{variableCount = variableCount t + 1}
increaseVarCount = modify $ \t -> t{variableCount = variableCount t + 1}
-- | Returns the variable count from the CodeGenerator state
getVarCount :: CompilerState Integer
@ -122,12 +118,14 @@ getConstructors bs = Map.fromList $ go bs
<> go xs
go (_ : xs) = go xs
getTypes :: [MIR.Def] -> Set LLVMType
getTypes bs = Set.fromList $ go bs
getTypes :: [MIR.Def] -> Map LLVMType Integer
getTypes bs = Map.fromList $ go bs
where
go [] = []
go (MIR.DData (MIR.Data t _) : xs) = type2LlvmType t : go xs
go (MIR.DData (MIR.Data t ts) : xs) = (type2LlvmType t, biggestVariant ts) : go xs
go (_ : xs) = go xs
variantTypes fi = init $ map type2LlvmType (flattenType fi)
biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
initCodeGenerator :: [MIR.Def] -> CodeGenerator
initCodeGenerator scs =
@ -225,6 +223,7 @@ compileScs [] = do
-- get a pointer of the correct type
ptr' <- getNewVar
emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id))
cTypes <- gets customTypes
enumerateOneM_
( \i (TIR.Ident arg_n, arg_t) -> do
@ -243,7 +242,16 @@ compileScs [] = do
I32
(VInteger i)
)
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
case Map.lookup arg_t' cTypes of
Just s -> do
emit $ Comment "Malloc and store"
heapPtr <- getNewVar
emit $ SetVariable heapPtr (Malloca s)
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
Nothing -> do
emit $ Comment "Just store"
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
)
(argumentsCI ci)
@ -274,12 +282,15 @@ compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
compileScs xs
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
let (TIR.Ident outer_id) = extractTypeName typ
-- //TODO this could be extracted from the customTypes map
let variantTypes fi = init $ map type2LlvmType (flattenType fi)
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8]
typeSets <- gets customTypes
mapM_
( \(Inj inner_id fi) -> do
emit $ LIR.Type inner_id (I8 : variantTypes fi)
let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
emit $ LIR.Type inner_id (I8 : types)
)
ts
compileScs xs
@ -369,32 +380,28 @@ emitECased t e cases = do
emit $ SetVariable castPtr (Alloca rt)
emit $ Store rt vs Ptr castPtr
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
val <- exprToValue exp
enumerateOneM_
( \i c -> do
case c of
PVar x -> do
emit . Comment $ "ident " <> show x
emit $ SetVariable (fst x) (ExtractValue (CustomType (coerce consId)) (VIdent casted Ptr) i)
PVar (x, topT) -> do
let topT' = type2LlvmType topT
let botT' = CustomType (coerce consId)
emit . Comment $ "ident " <> toIr topT'
cTypes <- gets customTypes
if Map.member topT' cTypes
then do
emit . Comment $ "tjabatjena"
deref <- getNewVar
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
emit $ SetVariable x (Load topT' Ptr deref)
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
PLit (_l, _t) -> undefined
PInj _id _ps -> undefined
PCatch -> pure ()
PEnum _id -> undefined
-- case c of
-- CIdent x -> do
-- emit . Comment $ "ident " <> show x
-- emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
-- emit $ Store ty val Ptr stackPtr
-- CCons x cs -> error "nested constructor"
-- CLit l -> do
-- testVar <- getNewVar
-- emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
-- case l of
-- LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l)
-- LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c)
-- CCatch -> emit . Comment $ "Catch all"
)
cs
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
emit $ Label lbl_failPos

View file

@ -225,7 +225,7 @@ llvmIrToString = go 0
(Alloca t) -> unwords ["alloca", toIr t, "\n"]
(Malloca t) ->
concat
[ "call ptr @malloc(i32 ", show t, ")"]
[ "call ptr @malloc(i32 ", show t, ")\n"]
(Store t1 val t2 (Ident id2)) ->
concat
[ "store ", toIr t1, " ", toIr val

View file

@ -1,13 +1,24 @@
id x = x;
const x y = x ;
data Maybe () where {
Just : Int -> Maybe ()
Nothing : Maybe ()
-- a simple list data type containing ints
data List () where {
Cons : Int -> List () -> List ()
Nil : List ()
};
main = case (Just 5) of {
Just a => 10 ;
Nothing => 0 ;
}; --const (id 0) (id 'a') ;
main = sumlength (Cons 1 (Cons 2 (Cons 3 (Cons 4 (Cons 5 Nil)))));
-- take the length of a list
length : List () -> Int ;
length x = case x of {
Cons _ xs => 1 + length xs ;
Nil => 0 ;
};
-- sum a list
sum : List () -> Int ;
sum x = case x of {
Cons a xs => a + sum xs ;
Nil => 0 ;
};
-- sum + length of a list
sumlength: List () -> Int ;
sumlength x = sum x + length x ;