diff --git a/src/Codegen/Codegen.hs b/src/Codegen/Codegen.hs index 9827571..6cb510d 100644 --- a/src/Codegen/Codegen.hs +++ b/src/Codegen/Codegen.hs @@ -20,8 +20,6 @@ import Data.Coerce (coerce) import Data.Map (Map) import Data.Map qualified as Map import Data.Maybe (fromJust, fromMaybe) -import Data.Set (Set) -import Data.Set qualified as Set import Data.Tuple.Extra (dupe, first, second) import Debug.Trace (trace) import Grammar.ErrM (Err) @@ -32,7 +30,7 @@ import TypeChecker.TypeCheckerIr qualified as TIR data CodeGenerator = CodeGenerator { instructions :: [LLVMIr] , functions :: Map MIR.Id FunctionInfo - , customTypes :: Set LLVMType + , customTypes :: Map LLVMType Integer , constructors :: Map TIR.Ident ConstructorInfo , variableCount :: Integer , labelCount :: Integer @@ -60,9 +58,7 @@ emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t} -- | Increases the variable counter in the CodeGenerator state increaseVarCount :: CompilerState () -increaseVarCount = do - gets variableCount >>= \s -> emit . Comment $ "increase: " <> show (s + 1) - modify $ \t -> t{variableCount = variableCount t + 1} +increaseVarCount = modify $ \t -> t{variableCount = variableCount t + 1} -- | Returns the variable count from the CodeGenerator state getVarCount :: CompilerState Integer @@ -122,12 +118,14 @@ getConstructors bs = Map.fromList $ go bs <> go xs go (_ : xs) = go xs -getTypes :: [MIR.Def] -> Set LLVMType -getTypes bs = Set.fromList $ go bs +getTypes :: [MIR.Def] -> Map LLVMType Integer +getTypes bs = Map.fromList $ go bs where go [] = [] - go (MIR.DData (MIR.Data t _) : xs) = type2LlvmType t : go xs + go (MIR.DData (MIR.Data t ts) : xs) = (type2LlvmType t, biggestVariant ts) : go xs go (_ : xs) = go xs + variantTypes fi = init $ map type2LlvmType (flattenType fi) + biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) initCodeGenerator :: [MIR.Def] -> CodeGenerator initCodeGenerator scs = @@ -225,6 +223,7 @@ compileScs [] = do -- get a pointer of the correct type ptr' <- getNewVar emit $ SetVariable ptr' (Bitcast (Ref t') (VIdent top Ptr) (Ref $ CustomType id)) + cTypes <- gets customTypes enumerateOneM_ ( \i (TIR.Ident arg_n, arg_t) -> do @@ -243,7 +242,16 @@ compileScs [] = do I32 (VInteger i) ) - emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr + case Map.lookup arg_t' cTypes of + Just s -> do + emit $ Comment "Malloc and store" + heapPtr <- getNewVar + emit $ SetVariable heapPtr (Malloca s) + emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr + emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr + Nothing -> do + emit $ Comment "Just store" + emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr ) (argumentsCI ci) @@ -274,12 +282,15 @@ compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do compileScs xs compileScs (MIR.DData (MIR.Data typ ts) : xs) = do let (TIR.Ident outer_id) = extractTypeName typ + -- //TODO this could be extracted from the customTypes map let variantTypes fi = init $ map type2LlvmType (flattenType fi) let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8] + typeSets <- gets customTypes mapM_ ( \(Inj inner_id fi) -> do - emit $ LIR.Type inner_id (I8 : variantTypes fi) + let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi + emit $ LIR.Type inner_id (I8 : types) ) ts compileScs xs @@ -369,32 +380,28 @@ emitECased t e cases = do emit $ SetVariable castPtr (Alloca rt) emit $ Store rt vs Ptr castPtr emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr) - val <- exprToValue exp enumerateOneM_ ( \i c -> do case c of - PVar x -> do - emit . Comment $ "ident " <> show x - emit $ SetVariable (fst x) (ExtractValue (CustomType (coerce consId)) (VIdent casted Ptr) i) + PVar (x, topT) -> do + let topT' = type2LlvmType topT + let botT' = CustomType (coerce consId) + emit . Comment $ "ident " <> toIr topT' + cTypes <- gets customTypes + if Map.member topT' cTypes + then do + emit . Comment $ "tjabatjena" + deref <- getNewVar + emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i) + emit $ SetVariable x (Load topT' Ptr deref) + else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i) PLit (_l, _t) -> undefined PInj _id _ps -> undefined PCatch -> pure () PEnum _id -> undefined - -- case c of - -- CIdent x -> do - -- emit . Comment $ "ident " <> show x - -- emit $ SetVariable x (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i) - -- emit $ Store ty val Ptr stackPtr - -- CCons x cs -> error "nested constructor" - -- CLit l -> do - -- testVar <- getNewVar - -- emit $ SetVariable testVar (ExtractValue (CustomType (fst consId)) (VIdent casted Ptr) i) - -- case l of - -- LInt l -> emit $ Icmp LLEq I64 (VIdent testVar Ptr) (VInteger l) - -- LChar c -> emit $ Icmp LLEq I8 (VIdent testVar Ptr) (VChar c) - -- CCatch -> emit . Comment $ "Catch all" ) cs + val <- exprToValue exp emit $ Store ty val Ptr stackPtr emit $ Br label emit $ Label lbl_failPos diff --git a/src/Codegen/LlvmIr.hs b/src/Codegen/LlvmIr.hs index 4a309c7..0ef6ac0 100644 --- a/src/Codegen/LlvmIr.hs +++ b/src/Codegen/LlvmIr.hs @@ -225,7 +225,7 @@ llvmIrToString = go 0 (Alloca t) -> unwords ["alloca", toIr t, "\n"] (Malloca t) -> concat - [ "call ptr @malloc(i32 ", show t, ")"] + [ "call ptr @malloc(i32 ", show t, ")\n"] (Store t1 val t2 (Ident id2)) -> concat [ "store ", toIr t1, " ", toIr val diff --git a/test_program.crf b/test_program.crf index 64aa2e7..cf754ca 100644 --- a/test_program.crf +++ b/test_program.crf @@ -1,13 +1,24 @@ -id x = x; - -const x y = x ; - -data Maybe () where { - Just : Int -> Maybe () - Nothing : Maybe () +-- a simple list data type containing ints +data List () where { + Cons : Int -> List () -> List () + Nil : List () }; -main = case (Just 5) of { - Just a => 10 ; - Nothing => 0 ; -}; --const (id 0) (id 'a') ; +main = sumlength (Cons 1 (Cons 2 (Cons 3 (Cons 4 (Cons 5 Nil))))); + +-- take the length of a list +length : List () -> Int ; +length x = case x of { + Cons _ xs => 1 + length xs ; + Nil => 0 ; +}; +-- sum a list +sum : List () -> Int ; +sum x = case x of { + Cons a xs => a + sum xs ; + Nil => 0 ; +}; + +-- sum + length of a list +sumlength: List () -> Int ; +sumlength x = sum x + length x ; \ No newline at end of file