We got pattern matching on data types!
This commit is contained in:
parent
2860d47f11
commit
100b7b113a
3 changed files with 58 additions and 40 deletions
|
|
@ -20,8 +20,6 @@ import Data.Coerce (coerce)
|
|||
import Data.Map (Map)
|
||||
import Data.Map qualified as Map
|
||||
import Data.Maybe (fromJust, fromMaybe)
|
||||
import Data.Set (Set)
|
||||
import Data.Set qualified as Set
|
||||
import Data.Tuple.Extra (dupe, first, second)
|
||||
import Debug.Trace (trace)
|
||||
import Grammar.ErrM (Err)
|
||||
|
|
@ -32,7 +30,7 @@ import TypeChecker.TypeCheckerIr qualified as TIR
|
|||
data CodeGenerator = CodeGenerator
|
||||
{ instructions :: [LLVMIr]
|
||||
, functions :: Map MIR.Id FunctionInfo
|
||||
, customTypes :: Set LLVMType
|
||||
, customTypes :: Map LLVMType Integer
|
||||
, constructors :: Map TIR.Ident ConstructorInfo
|
||||
, variableCount :: Integer
|
||||
, labelCount :: Integer
|
||||
|
|
@ -60,9 +58,7 @@ emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
|
|||
|
||||
-- | Increases the variable counter in the CodeGenerator state
|
||||
increaseVarCount :: CompilerState ()
|
||||
increaseVarCount = do
|
||||
gets variableCount >>= \s -> emit . Comment $ "increase: " <> show (s + 1)
|
||||
modify $ \t -> t{variableCount = variableCount t + 1}
|
||||
increaseVarCount = modify $ \t -> t{variableCount = variableCount t + 1}
|
||||
|
||||
-- | Returns the variable count from the CodeGenerator state
|
||||
getVarCount :: CompilerState Integer
|
||||
|
|
@ -122,12 +118,14 @@ getConstructors bs = Map.fromList $ go bs
|
|||
<> go xs
|
||||
go (_ : xs) = go xs
|
||||
|
||||
getTypes :: [MIR.Def] -> Set LLVMType
|
||||
getTypes bs = Set.fromList $ go bs
|
||||
getTypes :: [MIR.Def] -> Map LLVMType Integer
|
||||
getTypes bs = Map.fromList $ go bs
|
||||
where
|
||||
go [] = []
|
||||
go (MIR.DData (MIR.Data t _) : xs) = type2LlvmType t : go xs
|
||||
go (MIR.DData (MIR.Data t ts) : xs) = (type2LlvmType t, biggestVariant ts) : go xs
|
||||
go (_ : xs) = go xs
|
||||
variantTypes fi = init $ map type2LlvmType (flattenType fi)
|
||||
biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
|
||||
|
||||
initCodeGenerator :: [MIR.Def] -> CodeGenerator
|
||||
initCodeGenerator scs =
|
||||
|
|
@ -225,6 +223,7 @@ compileScs [] = do
|
|||
-- get a pointer of the correct type
|
||||
ptr' <- getNewVar
|
||||
emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id))
|
||||
cTypes <- gets customTypes
|
||||
|
||||
enumerateOneM_
|
||||
( \i (TIR.Ident arg_n, arg_t) -> do
|
||||
|
|
@ -243,7 +242,16 @@ compileScs [] = do
|
|||
I32
|
||||
(VInteger i)
|
||||
)
|
||||
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
|
||||
case Map.lookup arg_t' cTypes of
|
||||
Just s -> do
|
||||
emit $ Comment "Malloc and store"
|
||||
heapPtr <- getNewVar
|
||||
emit $ SetVariable heapPtr (Malloca s)
|
||||
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr
|
||||
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
|
||||
Nothing -> do
|
||||
emit $ Comment "Just store"
|
||||
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
|
||||
)
|
||||
(argumentsCI ci)
|
||||
|
||||
|
|
@ -274,12 +282,15 @@ compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
|
|||
compileScs xs
|
||||
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
|
||||
let (TIR.Ident outer_id) = extractTypeName typ
|
||||
-- //TODO this could be extracted from the customTypes map
|
||||
let variantTypes fi = init $ map type2LlvmType (flattenType fi)
|
||||
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
|
||||
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8]
|
||||
typeSets <- gets customTypes
|
||||
mapM_
|
||||
( \(Inj inner_id fi) -> do
|
||||
emit $ LIR.Type inner_id (I8 : variantTypes fi)
|
||||
let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
|
||||
emit $ LIR.Type inner_id (I8 : types)
|
||||
)
|
||||
ts
|
||||
compileScs xs
|
||||
|
|
@ -369,32 +380,28 @@ emitECased t e cases = do
|
|||
emit $ SetVariable castPtr (Alloca rt)
|
||||
emit $ Store rt vs Ptr castPtr
|
||||
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
|
||||
val <- exprToValue exp
|
||||
enumerateOneM_
|
||||
( \i c -> do
|
||||
case c of
|
||||
PVar x -> do
|
||||
emit . Comment $ "ident " <> show x
|
||||
emit $ SetVariable (fst x) (ExtractValue (CustomType (coerce consId)) (VIdent casted Ptr) i)
|
||||
PVar (x, topT) -> do
|
||||
let topT' = type2LlvmType topT
|
||||
let botT' = CustomType (coerce consId)
|
||||
emit . Comment $ "ident " <> toIr topT'
|
||||
cTypes <- gets customTypes
|
||||
if Map.member topT' cTypes
|
||||
then do
|
||||
emit . Comment $ "tjabatjena"
|
||||
deref <- getNewVar
|
||||
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
|
||||
emit $ SetVariable x (Load topT' Ptr deref)
|
||||
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
|
||||
PLit (_l, _t) -> undefined
|
||||
PInj _id _ps -> undefined
|
||||
PCatch -> pure ()
|
||||
PEnum _id -> undefined
|
||||
-- case c of
|
||||
-- CIdent x -> do
|
||||
-- emit . Comment $ "ident " <> show x
|
||||
-- emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
|
||||
-- emit $ Store ty val Ptr stackPtr
|
||||
-- CCons x cs -> error "nested constructor"
|
||||
-- CLit l -> do
|
||||
-- testVar <- getNewVar
|
||||
-- emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
|
||||
-- case l of
|
||||
-- LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l)
|
||||
-- LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c)
|
||||
-- CCatch -> emit . Comment $ "Catch all"
|
||||
)
|
||||
cs
|
||||
val <- exprToValue exp
|
||||
emit $ Store ty val Ptr stackPtr
|
||||
emit $ Br label
|
||||
emit $ Label lbl_failPos
|
||||
|
|
|
|||
|
|
@ -225,7 +225,7 @@ llvmIrToString = go 0
|
|||
(Alloca t) -> unwords ["alloca", toIr t, "\n"]
|
||||
(Malloca t) ->
|
||||
concat
|
||||
[ "call ptr @malloc(i32 ", show t, ")"]
|
||||
[ "call ptr @malloc(i32 ", show t, ")\n"]
|
||||
(Store t1 val t2 (Ident id2)) ->
|
||||
concat
|
||||
[ "store ", toIr t1, " ", toIr val
|
||||
|
|
|
|||
|
|
@ -1,13 +1,24 @@
|
|||
id x = x;
|
||||
|
||||
const x y = x ;
|
||||
|
||||
data Maybe () where {
|
||||
Just : Int -> Maybe ()
|
||||
Nothing : Maybe ()
|
||||
-- a simple list data type containing ints
|
||||
data List () where {
|
||||
Cons : Int -> List () -> List ()
|
||||
Nil : List ()
|
||||
};
|
||||
|
||||
main = case (Just 5) of {
|
||||
Just a => 10 ;
|
||||
Nothing => 0 ;
|
||||
}; --const (id 0) (id 'a') ;
|
||||
main = sumlength (Cons 1 (Cons 2 (Cons 3 (Cons 4 (Cons 5 Nil)))));
|
||||
|
||||
-- take the length of a list
|
||||
length : List () -> Int ;
|
||||
length x = case x of {
|
||||
Cons _ xs => 1 + length xs ;
|
||||
Nil => 0 ;
|
||||
};
|
||||
-- sum a list
|
||||
sum : List () -> Int ;
|
||||
sum x = case x of {
|
||||
Cons a xs => a + sum xs ;
|
||||
Nil => 0 ;
|
||||
};
|
||||
|
||||
-- sum + length of a list
|
||||
sumlength: List () -> Int ;
|
||||
sumlength x = sum x + length x ;
|
||||
Loading…
Add table
Add a link
Reference in a new issue