We got pattern matching on data types!
This commit is contained in:
parent
2860d47f11
commit
100b7b113a
3 changed files with 58 additions and 40 deletions
|
|
@ -20,8 +20,6 @@ import Data.Coerce (coerce)
|
||||||
import Data.Map (Map)
|
import Data.Map (Map)
|
||||||
import Data.Map qualified as Map
|
import Data.Map qualified as Map
|
||||||
import Data.Maybe (fromJust, fromMaybe)
|
import Data.Maybe (fromJust, fromMaybe)
|
||||||
import Data.Set (Set)
|
|
||||||
import Data.Set qualified as Set
|
|
||||||
import Data.Tuple.Extra (dupe, first, second)
|
import Data.Tuple.Extra (dupe, first, second)
|
||||||
import Debug.Trace (trace)
|
import Debug.Trace (trace)
|
||||||
import Grammar.ErrM (Err)
|
import Grammar.ErrM (Err)
|
||||||
|
|
@ -32,7 +30,7 @@ import TypeChecker.TypeCheckerIr qualified as TIR
|
||||||
data CodeGenerator = CodeGenerator
|
data CodeGenerator = CodeGenerator
|
||||||
{ instructions :: [LLVMIr]
|
{ instructions :: [LLVMIr]
|
||||||
, functions :: Map MIR.Id FunctionInfo
|
, functions :: Map MIR.Id FunctionInfo
|
||||||
, customTypes :: Set LLVMType
|
, customTypes :: Map LLVMType Integer
|
||||||
, constructors :: Map TIR.Ident ConstructorInfo
|
, constructors :: Map TIR.Ident ConstructorInfo
|
||||||
, variableCount :: Integer
|
, variableCount :: Integer
|
||||||
, labelCount :: Integer
|
, labelCount :: Integer
|
||||||
|
|
@ -60,9 +58,7 @@ emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
|
||||||
|
|
||||||
-- | Increases the variable counter in the CodeGenerator state
|
-- | Increases the variable counter in the CodeGenerator state
|
||||||
increaseVarCount :: CompilerState ()
|
increaseVarCount :: CompilerState ()
|
||||||
increaseVarCount = do
|
increaseVarCount = modify $ \t -> t{variableCount = variableCount t + 1}
|
||||||
gets variableCount >>= \s -> emit . Comment $ "increase: " <> show (s + 1)
|
|
||||||
modify $ \t -> t{variableCount = variableCount t + 1}
|
|
||||||
|
|
||||||
-- | Returns the variable count from the CodeGenerator state
|
-- | Returns the variable count from the CodeGenerator state
|
||||||
getVarCount :: CompilerState Integer
|
getVarCount :: CompilerState Integer
|
||||||
|
|
@ -122,12 +118,14 @@ getConstructors bs = Map.fromList $ go bs
|
||||||
<> go xs
|
<> go xs
|
||||||
go (_ : xs) = go xs
|
go (_ : xs) = go xs
|
||||||
|
|
||||||
getTypes :: [MIR.Def] -> Set LLVMType
|
getTypes :: [MIR.Def] -> Map LLVMType Integer
|
||||||
getTypes bs = Set.fromList $ go bs
|
getTypes bs = Map.fromList $ go bs
|
||||||
where
|
where
|
||||||
go [] = []
|
go [] = []
|
||||||
go (MIR.DData (MIR.Data t _) : xs) = type2LlvmType t : go xs
|
go (MIR.DData (MIR.Data t ts) : xs) = (type2LlvmType t, biggestVariant ts) : go xs
|
||||||
go (_ : xs) = go xs
|
go (_ : xs) = go xs
|
||||||
|
variantTypes fi = init $ map type2LlvmType (flattenType fi)
|
||||||
|
biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
|
||||||
|
|
||||||
initCodeGenerator :: [MIR.Def] -> CodeGenerator
|
initCodeGenerator :: [MIR.Def] -> CodeGenerator
|
||||||
initCodeGenerator scs =
|
initCodeGenerator scs =
|
||||||
|
|
@ -225,6 +223,7 @@ compileScs [] = do
|
||||||
-- get a pointer of the correct type
|
-- get a pointer of the correct type
|
||||||
ptr' <- getNewVar
|
ptr' <- getNewVar
|
||||||
emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id))
|
emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id))
|
||||||
|
cTypes <- gets customTypes
|
||||||
|
|
||||||
enumerateOneM_
|
enumerateOneM_
|
||||||
( \i (TIR.Ident arg_n, arg_t) -> do
|
( \i (TIR.Ident arg_n, arg_t) -> do
|
||||||
|
|
@ -243,7 +242,16 @@ compileScs [] = do
|
||||||
I32
|
I32
|
||||||
(VInteger i)
|
(VInteger i)
|
||||||
)
|
)
|
||||||
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
|
case Map.lookup arg_t' cTypes of
|
||||||
|
Just s -> do
|
||||||
|
emit $ Comment "Malloc and store"
|
||||||
|
heapPtr <- getNewVar
|
||||||
|
emit $ SetVariable heapPtr (Malloca s)
|
||||||
|
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr
|
||||||
|
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
|
||||||
|
Nothing -> do
|
||||||
|
emit $ Comment "Just store"
|
||||||
|
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
|
||||||
)
|
)
|
||||||
(argumentsCI ci)
|
(argumentsCI ci)
|
||||||
|
|
||||||
|
|
@ -274,12 +282,15 @@ compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
|
||||||
compileScs xs
|
compileScs xs
|
||||||
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
|
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
|
||||||
let (TIR.Ident outer_id) = extractTypeName typ
|
let (TIR.Ident outer_id) = extractTypeName typ
|
||||||
|
-- //TODO this could be extracted from the customTypes map
|
||||||
let variantTypes fi = init $ map type2LlvmType (flattenType fi)
|
let variantTypes fi = init $ map type2LlvmType (flattenType fi)
|
||||||
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
|
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
|
||||||
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8]
|
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8]
|
||||||
|
typeSets <- gets customTypes
|
||||||
mapM_
|
mapM_
|
||||||
( \(Inj inner_id fi) -> do
|
( \(Inj inner_id fi) -> do
|
||||||
emit $ LIR.Type inner_id (I8 : variantTypes fi)
|
let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
|
||||||
|
emit $ LIR.Type inner_id (I8 : types)
|
||||||
)
|
)
|
||||||
ts
|
ts
|
||||||
compileScs xs
|
compileScs xs
|
||||||
|
|
@ -369,32 +380,28 @@ emitECased t e cases = do
|
||||||
emit $ SetVariable castPtr (Alloca rt)
|
emit $ SetVariable castPtr (Alloca rt)
|
||||||
emit $ Store rt vs Ptr castPtr
|
emit $ Store rt vs Ptr castPtr
|
||||||
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
|
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
|
||||||
val <- exprToValue exp
|
|
||||||
enumerateOneM_
|
enumerateOneM_
|
||||||
( \i c -> do
|
( \i c -> do
|
||||||
case c of
|
case c of
|
||||||
PVar x -> do
|
PVar (x, topT) -> do
|
||||||
emit . Comment $ "ident " <> show x
|
let topT' = type2LlvmType topT
|
||||||
emit $ SetVariable (fst x) (ExtractValue (CustomType (coerce consId)) (VIdent casted Ptr) i)
|
let botT' = CustomType (coerce consId)
|
||||||
|
emit . Comment $ "ident " <> toIr topT'
|
||||||
|
cTypes <- gets customTypes
|
||||||
|
if Map.member topT' cTypes
|
||||||
|
then do
|
||||||
|
emit . Comment $ "tjabatjena"
|
||||||
|
deref <- getNewVar
|
||||||
|
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
|
||||||
|
emit $ SetVariable x (Load topT' Ptr deref)
|
||||||
|
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
|
||||||
PLit (_l, _t) -> undefined
|
PLit (_l, _t) -> undefined
|
||||||
PInj _id _ps -> undefined
|
PInj _id _ps -> undefined
|
||||||
PCatch -> pure ()
|
PCatch -> pure ()
|
||||||
PEnum _id -> undefined
|
PEnum _id -> undefined
|
||||||
-- case c of
|
|
||||||
-- CIdent x -> do
|
|
||||||
-- emit . Comment $ "ident " <> show x
|
|
||||||
-- emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
|
|
||||||
-- emit $ Store ty val Ptr stackPtr
|
|
||||||
-- CCons x cs -> error "nested constructor"
|
|
||||||
-- CLit l -> do
|
|
||||||
-- testVar <- getNewVar
|
|
||||||
-- emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i)
|
|
||||||
-- case l of
|
|
||||||
-- LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l)
|
|
||||||
-- LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c)
|
|
||||||
-- CCatch -> emit . Comment $ "Catch all"
|
|
||||||
)
|
)
|
||||||
cs
|
cs
|
||||||
|
val <- exprToValue exp
|
||||||
emit $ Store ty val Ptr stackPtr
|
emit $ Store ty val Ptr stackPtr
|
||||||
emit $ Br label
|
emit $ Br label
|
||||||
emit $ Label lbl_failPos
|
emit $ Label lbl_failPos
|
||||||
|
|
|
||||||
|
|
@ -225,7 +225,7 @@ llvmIrToString = go 0
|
||||||
(Alloca t) -> unwords ["alloca", toIr t, "\n"]
|
(Alloca t) -> unwords ["alloca", toIr t, "\n"]
|
||||||
(Malloca t) ->
|
(Malloca t) ->
|
||||||
concat
|
concat
|
||||||
[ "call ptr @malloc(i32 ", show t, ")"]
|
[ "call ptr @malloc(i32 ", show t, ")\n"]
|
||||||
(Store t1 val t2 (Ident id2)) ->
|
(Store t1 val t2 (Ident id2)) ->
|
||||||
concat
|
concat
|
||||||
[ "store ", toIr t1, " ", toIr val
|
[ "store ", toIr t1, " ", toIr val
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,24 @@
|
||||||
id x = x;
|
-- a simple list data type containing ints
|
||||||
|
data List () where {
|
||||||
const x y = x ;
|
Cons : Int -> List () -> List ()
|
||||||
|
Nil : List ()
|
||||||
data Maybe () where {
|
|
||||||
Just : Int -> Maybe ()
|
|
||||||
Nothing : Maybe ()
|
|
||||||
};
|
};
|
||||||
|
|
||||||
main = case (Just 5) of {
|
main = sumlength (Cons 1 (Cons 2 (Cons 3 (Cons 4 (Cons 5 Nil)))));
|
||||||
Just a => 10 ;
|
|
||||||
Nothing => 0 ;
|
-- take the length of a list
|
||||||
}; --const (id 0) (id 'a') ;
|
length : List () -> Int ;
|
||||||
|
length x = case x of {
|
||||||
|
Cons _ xs => 1 + length xs ;
|
||||||
|
Nil => 0 ;
|
||||||
|
};
|
||||||
|
-- sum a list
|
||||||
|
sum : List () -> Int ;
|
||||||
|
sum x = case x of {
|
||||||
|
Cons a xs => a + sum xs ;
|
||||||
|
Nil => 0 ;
|
||||||
|
};
|
||||||
|
|
||||||
|
-- sum + length of a list
|
||||||
|
sumlength: List () -> Int ;
|
||||||
|
sumlength x = sum x + length x ;
|
||||||
Loading…
Add table
Add a link
Reference in a new issue