diff --git a/language.cabal b/language.cabal index af7178c..e299f24 100644 --- a/language.cabal +++ b/language.cabal @@ -43,6 +43,7 @@ executable language TypeChecker.ReportTEVar TypeChecker.RemoveForall LambdaLifter + LambdaLifterIr Monomorphizer.Monomorphizer Monomorphizer.MonomorphizerIr Monomorphizer.MorbIr @@ -101,6 +102,8 @@ Test-suite language-testsuite TypeChecker.TypeChecker AnnForall ReportForall + LambdaLifterIr + LambdaLifter TypeChecker.TypeCheckerHm TypeChecker.TypeCheckerBidir TypeChecker.ReportTEVar diff --git a/sample-programs/working/addition.chrf b/sample-programs/working/addition.chrf new file mode 100644 index 0000000..7bddab7 --- /dev/null +++ b/sample-programs/working/addition.chrf @@ -0,0 +1,6 @@ + + +add : Int -> Int -> Int -> Int +add x y z = x + y + z + +main = add 8 6 2 diff --git a/sample-programs/working/apply.crf b/sample-programs/working/apply.crf new file mode 100644 index 0000000..61c76ad --- /dev/null +++ b/sample-programs/working/apply.crf @@ -0,0 +1,7 @@ + + + +apply : (Int -> Int) -> Int -> Int +apply f y = f y + +main = apply (\y. y + y) 5 diff --git a/sample-programs/working/closure.crf b/sample-programs/working/closure.crf new file mode 100644 index 0000000..b85ab32 --- /dev/null +++ b/sample-programs/working/closure.crf @@ -0,0 +1,10 @@ + + + + +apply : (Int -> Int) -> Int -> Int +apply f z = f z + +main = + let x = 10 in + apply (\y. y + x) 6 diff --git a/sample-programs/working/foldr.crf b/sample-programs/working/foldr.crf new file mode 100644 index 0000000..da798ac --- /dev/null +++ b/sample-programs/working/foldr.crf @@ -0,0 +1,15 @@ + +data List a where + Nil : List a + Cons : a -> List a -> List a + +foldr : (a -> b -> b) -> b -> List a -> b +foldr f y xs = case xs of + Nil => y + Cons x xs => f x (foldr f y xs) + + +main = let z = 2 in foldr (\x.\y. x + y + z) 0 (Cons 1000 (Cons 100 Nil)) + + + diff --git a/sample-programs/working/lambda-2.crf b/sample-programs/working/lambda-2.crf new file mode 100644 index 0000000..f081d92 --- /dev/null +++ b/sample-programs/working/lambda-2.crf @@ -0,0 +1,25 @@ +data List (a) where + Nil : List (a) + Cons : a -> List (a) -> List (a) + +map : (a -> b) -> List (a) -> List (b) +map f xs = case xs of + Nil => Nil + Cons x xs => Cons (f x) (map f xs) + +add : Int -> Int -> Int +add x y = x + y + +foldr : (a -> b -> b) -> b -> List (a) -> b +foldr f y xs = case xs of + Nil => y + Cons x xs => f x (foldr f y xs) + +f : List (Int) +f = ((\x.\ys. map (\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))) +-- [5, 6] + +main : Int +main = foldr add 0 f + + diff --git a/sample-programs/working/lambda.crf b/sample-programs/working/lambda.crf new file mode 100644 index 0000000..3dcb947 --- /dev/null +++ b/sample-programs/working/lambda.crf @@ -0,0 +1,21 @@ +data List a where + Nil : List a + Cons : a -> List a -> List a + +map : (a -> b) -> List a -> List b +map f xs = case xs of + Nil => Nil + Cons x xs => Cons (f x) (map f xs) + + +f : List Int +f = (\x.\ys. map (\y. y + x) ys) 4 (Cons 1 (Cons 2 Nil)) +-- [5, 6] + +sum : List Int -> Int +sum xs = case xs of + Nil => 0 + Cons x xs => x + sum xs + +main = sum f + diff --git a/sample-programs/working/let.crf b/sample-programs/working/let.crf new file mode 100644 index 0000000..9ed4abe --- /dev/null +++ b/sample-programs/working/let.crf @@ -0,0 +1,3 @@ + + +main = let x = 10 in 6 + x diff --git a/sample-programs/working/map.crf b/sample-programs/working/map.crf new file mode 100644 index 0000000..4e77ad8 --- /dev/null +++ b/sample-programs/working/map.crf @@ -0,0 +1,16 @@ + +data List a where + Nil : List a + Cons : a -> List a -> List a + +map : (a -> b) -> List a -> List b +map f xs = case xs of + Nil => Nil + Cons x xs => Cons (f x) (map f xs) + +sum : List Int -> Int +sum xs = case xs of + Nil => 0 + Cons x xs => x + (sum xs) + +main = let y = 10 in sum (map (\x. x + y) (Cons 2 (Cons 4 Nil))) diff --git a/sample-programs/working/simple.crf b/sample-programs/working/simple.crf new file mode 100644 index 0000000..04d3ef8 --- /dev/null +++ b/sample-programs/working/simple.crf @@ -0,0 +1,7 @@ + + + +f = 10 + + +main = f + 6 diff --git a/src/Codegen/Auxillary.hs b/src/Codegen/Auxillary.hs index c95be39..af31504 100644 --- a/src/Codegen/Auxillary.hs +++ b/src/Codegen/Auxillary.hs @@ -1,25 +1,25 @@ module Codegen.Auxillary where -import Codegen.LlvmIr (LLVMType (..), LLVMValue (..)) -import Control.Monad (foldM_) -import Monomorphizer.MonomorphizerIr as MIR (ExpT, Type (..)) -import TypeChecker.TypeCheckerIr qualified as TIR +import Codegen.LlvmIr (LLVMType (..), LLVMValue (..)) +import Control.Monad (foldM_) +import Monomorphizer.MonomorphizerIr as MIR (Exp, T, Type (..)) +import qualified TypeChecker.TypeCheckerIr as TIR type2LlvmType :: MIR.Type -> LLVMType type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of - "Int" -> I64 + "Int" -> I64 "Char" -> I8 "Bool" -> I1 - _ -> CustomType id + _ -> CustomType id type2LlvmType (MIR.TFun t xs) = do let (t', xs') = function2LLVMType xs [type2LlvmType t] Function t' xs' where function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType]) function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) - function2LLVMType x s = (type2LlvmType x, s) + function2LLVMType x s = (type2LlvmType x, s) -getType :: ExpT -> LLVMType +getType :: T Exp -> LLVMType getType (_, t) = type2LlvmType t extractTypeName :: MIR.Type -> TIR.Ident @@ -30,21 +30,21 @@ extractTypeName (MIR.TFun t xs) = in TIR.Ident $ i <> "_$_" <> is valueGetType :: LLVMValue -> LLVMType -valueGetType (VInteger _) = I64 -valueGetType (VChar _) = I8 -valueGetType (VIdent _ t) = t -valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 +valueGetType (VInteger _) = I64 +valueGetType (VChar _) = I8 +valueGetType (VIdent _ t) = t +valueGetType (VConstant s) = Array (fromIntegral $ length s) I8 valueGetType (VFunction _ _ t) = t typeByteSize :: LLVMType -> Integer -typeByteSize I1 = 1 -typeByteSize I8 = 1 -typeByteSize I32 = 4 -typeByteSize I64 = 8 -typeByteSize Ptr = 8 -typeByteSize (Ref _) = 8 +typeByteSize I1 = 1 +typeByteSize I8 = 1 +typeByteSize I32 = 4 +typeByteSize I64 = 8 +typeByteSize Ptr = 8 +typeByteSize (Ref _) = 8 typeByteSize (Function _ _) = 8 -typeByteSize (Array n t) = n * typeByteSize t +typeByteSize (Array n t) = n * typeByteSize t typeByteSize (CustomType _) = 8 enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m () diff --git a/src/Codegen/Codegen.hs b/src/Codegen/Codegen.hs index be92a35..6f66c36 100644 --- a/src/Codegen/Codegen.hs +++ b/src/Codegen/Codegen.hs @@ -1,18 +1,24 @@ +{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedStrings #-} + module Codegen.Codegen (generateCode) where -import Codegen.CompilerState ( - CodeGenerator (instructions), - initCodeGenerator, - ) -import Codegen.Emits (compileScs) -import Codegen.LlvmIr as LIR (llvmIrToString) -import Control.Monad.State ( - execStateT, - ) -import Data.List (sortBy) -import Grammar.ErrM (Err) -import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..), Def (DBind, DData), Program (..), Type (TLit)) -import TypeChecker.TypeCheckerIr (Ident (..)) +import Codegen.CompilerState (CodeGenerator (..), + StructType (inst), + initCodeGenerator) +import Codegen.Emits (compileScs) +import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw), + llvmIrToString) +import Control.Monad.State (execStateT) +import Data.Functor ((<&>)) +import Data.List (sortBy) +import qualified Data.Map as Map +import Grammar.ErrM (Err) +import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..), + Def (DBind, DData), + Program (..), + Type (TLit)) +import TypeChecker.TypeCheckerIr (Ident (..)) {- | Compiles an AST and produces a LLVM Ir string. An easy way to actually "compile" this output is to @@ -20,16 +26,43 @@ import TypeChecker.TypeCheckerIr (Ident (..)) -} generateCode :: MIR.Program -> Bool -> Err String generateCode (MIR.Program scs) addGc = do - let tree = filter (not . detectPrelude) (sortBy lowData scs) - let codegen = initCodeGenerator addGc tree - llvmIrToString . instructions <$> execStateT (compileScs tree) codegen + let tree = filter (not . detectPrelude) (sortBy lowData scs) + codegen = initCodeGenerator addGc tree + + -- Append instructions + execStateT (compileScs tree) codegen <&> \state -> + llvmIrToString $ defaultStart + ++ (if addGc then gcStart else []) + ++ map inst (Map.elems state.structTypes) + ++ state.instructions detectPrelude :: Def -> Bool -detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True +detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True detectPrelude (DBind (Bind (Ident ('l' : 't' : '$' : _), _) _ _)) = True -detectPrelude _ = False +detectPrelude _ = False lowData :: Def -> Def -> Ordering lowData (DData _) (DBind _) = LT lowData (DBind _) (DData _) = GT -lowData _ _ = EQ \ No newline at end of file +lowData _ _ = EQ + +defaultStart :: [LLVMIr] +defaultStart = + [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n" + , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" + , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" + , UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n" + , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" + , UnsafeRaw "declare i32 @exit(i32 noundef)\n" + , UnsafeRaw "declare ptr @malloc(i32 noundef)\n" + ] + +gcStart :: [LLVMIr] +gcStart = + [ UnsafeRaw "declare external void @cheap_init()\n" + , UnsafeRaw "declare external ptr @cheap_alloc(i64)\n" + , UnsafeRaw "declare external void @cheap_dispose()\n" + , UnsafeRaw "declare external ptr @cheap_the()\n" + , UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n" + , UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n" + ] diff --git a/src/Codegen/CompilerState.hs b/src/Codegen/CompilerState.hs index 523cc54..b455712 100644 --- a/src/Codegen/CompilerState.hs +++ b/src/Codegen/CompilerState.hs @@ -1,46 +1,101 @@ +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedStrings #-} + module Codegen.CompilerState where import Auxiliary (snoc) import Codegen.Auxillary (type2LlvmType, typeByteSize) -import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw), - LLVMType) -import Control.Monad.State (StateT, gets, modify) +import Codegen.LlvmIr as LIR (LLVMIr (SetVariable, Type), + LLVMType (CustomType, Function, I64, Ptr), + LLVMValue (VFunction, VIdent), + Visibility (Global), + typeOf) +import Control.Monad.State (StateT, gets, modify, void) import Data.Map (Map) import qualified Data.Map as Map import Grammar.ErrM (Err) -import Monomorphizer.MonomorphizerIr as MIR +import Monomorphizer.MonomorphizerIr (Ident (..), Inj (..), T, + flattenType) +import qualified Monomorphizer.MonomorphizerIr as MIR import qualified TypeChecker.TypeCheckerIr as TIR -- | The record used as the code generator state data CodeGenerator = CodeGenerator { instructions :: [LLVMIr] - , functions :: Map MIR.Id FunctionInfo + , functions :: Map (T Ident) FunctionInfo , customTypes :: Map LLVMType Integer - , constructors :: Map TIR.Ident ConstructorInfo + , constructors :: Map Ident ConstructorInfo , variableCount :: Integer , labelCount :: Integer , gcEnabled :: Bool + , structTypes :: Map Ident StructType + -- ^ Custom stucture types + , locals :: [(Ident, LocalElem)] + -- ^ Arguments and variables in local environment + , globals :: Map Ident (LLVMType, LLVMValue) } +data StructType = StructType + { ptr :: LLVMType + , typs :: [LLVMType] + , inst :: LLVMIr + } + +data LocalElem = LocalElem + { typ :: LLVMType + , val :: LLVMValue + } + + -- | A state type synonym type CompilerState a = StateT CodeGenerator Err a data FunctionInfo = FunctionInfo { numArgs :: Int - , arguments :: [Id] + , arguments :: [T Ident] } deriving (Show) data ConstructorInfo = ConstructorInfo { numArgsCI :: Int - , argumentsCI :: [Id] + , argumentsCI :: [T Ident] , numCI :: Integer , returnTypeCI :: MIR.Type } deriving (Show) + +addStructType_ :: Ident -> [LLVMType] -> CompilerState () +addStructType_ = fmap void . addStructType + +addStructType :: Ident -> [LLVMType] -> CompilerState LLVMType +addStructType x ts = do + modify $ \s -> s { structTypes = Map.insert x struct s.structTypes } + pure t + where + struct = StructType + { ptr = t + , typs = ts + , inst = Type x ts + } + t = CustomType x + -- | Adds a instruction to the CodeGenerator state emit :: LLVMIr -> CompilerState () -emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t} + +-- Add variable to environment +emit l@(SetVariable x _) = modify $ \t -> + t { instructions = Auxiliary.snoc l t.instructions + , locals = snoc (x, local) + t.locals + } + where + local = LocalElem { typ = typeOf l + , val = VIdent x (typeOf l) + } + +emit l = modify $ \t -> t { instructions = Auxiliary.snoc l t.instructions } -- | Increases the variable counter in the CodeGenerator state increaseVarCount :: CompilerState () @@ -63,16 +118,19 @@ getNewLabel = do {- | Produces a map of functions infos from a list of binds, which contains useful data for code generation. -} -getFunctions :: [MIR.Def] -> Map Id FunctionInfo +getFunctions :: [MIR.Def] -> Map (T Ident) FunctionInfo getFunctions bs = Map.fromList $ go bs where go [] = [] go (MIR.DBind (MIR.Bind id args _) : xs) = - (id, FunctionInfo{numArgs = length args, arguments = args}) - : go xs + (id, FunctionInfo { numArgs = length args + , arguments = args + } + ) + : go xs go (_ : xs) = go xs -createArgs :: [MIR.Type] -> [Id] +createArgs :: [MIR.Type] -> [T Ident] createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs {- | Produces a map of functions infos from a list of binds, @@ -113,35 +171,43 @@ getTypes bs = Map.fromList $ go bs variantTypes fi = init $ map type2LlvmType (flattenType fi) biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) +getGlobals :: [MIR.Def] -> Map Ident (LLVMType, LLVMValue) +getGlobals scs = Map.fromList [ go b | MIR.DBind b <- scs ] + where + go bind | x == "main" = let typ = Function I64 [] + in (x, (typ, VFunction x Global typ)) + | otherwise = (x, (typ, VFunction x Global typ)) + where + typ = Function tr $ Ptr : ts + Function tr ts = type2LlvmType' t + + (x, t) = case bind of + MIR.Bind xt _ _ -> xt + MIR.BindC _ xt _ _ -> xt + + -- Higher order function arguments are replaced with ptr + type2LlvmType' = go [] + where + go acc = \case + MIR.TFun (MIR.TFun _ _) t2 -> go (snoc Ptr acc) t2 + MIR.TFun t1 t2 -> go (snoc (type2LlvmType t1) acc) t2 + t -> Function (type2LlvmType t) acc + + + + initCodeGenerator :: Bool -> [MIR.Def] -> CodeGenerator initCodeGenerator addGc scs = CodeGenerator - { instructions = defaultStart <> if addGc then gcStart else [] + { instructions = [] , functions = getFunctions scs , constructors = getConstructors scs , customTypes = getTypes scs + , structTypes = mempty , variableCount = 0 , labelCount = 0 , gcEnabled = addGc + , locals = mempty + , globals = getGlobals scs } -defaultStart :: [LLVMIr] -defaultStart = - [ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n" - , UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" - , UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n" - , UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n" - , UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n" - , UnsafeRaw "declare i32 @exit(i32 noundef)\n" - , UnsafeRaw "declare ptr @malloc(i32 noundef)\n" - ] - -gcStart :: [LLVMIr] -gcStart = - [ UnsafeRaw "declare external void @cheap_init()\n" - , UnsafeRaw "declare external ptr @cheap_alloc(i64)\n" - , UnsafeRaw "declare external void @cheap_dispose()\n" - , UnsafeRaw "declare external ptr @cheap_the()\n" - , UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n" - , UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n" - ] diff --git a/src/Codegen/Emits.hs b/src/Codegen/Emits.hs index bc19f87..9c6f59f 100644 --- a/src/Codegen/Emits.hs +++ b/src/Codegen/Emits.hs @@ -1,36 +1,40 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE DuplicateRecordFields #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedStrings #-} module Codegen.Emits where -import Codegen.Auxillary -import Codegen.CompilerState -import Codegen.LlvmIr as LIR -import Control.Applicative ((<|>)) -import Control.Monad (when) -import Control.Monad.State (gets, modify) -import Data.Bifunctor qualified as BI -import Data.Char (ord) -import Data.Coerce (coerce) -import Data.Map qualified as Map -import Data.Maybe (fromJust, fromMaybe, isNothing) -import Data.Tuple.Extra (dupe, first, second) -import Debug.Trace (trace, traceShow) -import Grammar.Print -import Monomorphizer.MonomorphizerIr as MIR -import TypeChecker.TypeCheckerIr qualified as TIR +import Auxiliary (snoc) +import Codegen.Auxillary +import Codegen.CompilerState +import Codegen.LlvmIr as LIR +import Control.Applicative (Applicative (liftA2), (<|>)) +import Control.Monad (forM_, when, zipWithM_) +import Control.Monad.Extra (whenJust) +import Control.Monad.State (gets, modify) +import Data.Char (ord) +import Data.Coerce (coerce) +import Data.Foldable.Extra (notNull) +import qualified Data.Map as Map +import Data.Maybe (fromJust, fromMaybe, isNothing) +import Data.Tuple.Extra (second) +import Grammar.Print (printTree) +import Monomorphizer.MonomorphizerIr -compileScs :: [MIR.Def] -> CompilerState () + +compileScs :: [Def] -> CompilerState () compileScs [] = do emit $ UnsafeRaw "\n" + mapM_ createConstructor =<< gets (Map.toList . constructors) -- as a last step create all the constructors -- //TODO maybe merge this with the data type match? - c <- gets (Map.toList . constructors) - mapM_ - ( \(id, ci) -> do - let t = returnTypeCI ci - let t' = type2LlvmType t - let x = BI.second type2LlvmType <$> argumentsCI ci + where + createConstructor (id, ci) = do + let t = returnTypeCI ci + t' = type2LlvmType t + x = (mkCxtName, Ptr) : map (second type2LlvmType) ci.argumentsCI emit $ Define FastCC t' id x top <- getNewVar ptr <- getNewVar @@ -56,7 +60,7 @@ compileScs [] = do cTypes <- gets customTypes enumerateOneM_ - ( \i (TIR.Ident arg_n, arg_t) -> do + ( \i (Ident arg_n, arg_t) -> do let arg_t' = type2LlvmType arg_t emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i) elemPtr <- getNewVar @@ -78,11 +82,11 @@ compileScs [] = do heapPtr <- getNewVar useGc <- gets gcEnabled emit $ SetVariable heapPtr (if useGc then GcMalloc s else Malloc s) - emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr + emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr heapPtr emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr Nothing -> do emit $ Comment "Just store" - emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr + emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr ) (argumentsCI ci) @@ -95,34 +99,83 @@ compileScs [] = do emit $ UnsafeRaw "\n" modify $ \s -> s{variableCount = 0} - ) - c -compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do - let t_return = type2LlvmType . last . flattenType $ t + +compileScs (DBind bind : xs) = do emit $ UnsafeRaw "\n" - emit . Comment $ show name <> ": " <> show exp - let args' = map (second type2LlvmType) args + emit . Comment $ show name <> ": " <> show (fst exp) + + Function t_return t_args <- gets $ fst + . fromJust + . Map.lookup name + . globals + + let args' = zip (mkCxtName : map fst args) t_args + emit $ Define FastCC t_return name args' - useGc <- gets gcEnabled - when (name == "main") (mapM_ emit (firstMainContent useGc)) - functionBody <- exprToValue exp - if name == "main" - then mapM_ emit $ lastMainContent useGc functionBody - else emit $ Ret t_return functionBody + modify $ \s -> s { locals = foldr insertArg s.locals args' } + + -- Dereference ptr arguments + when (notNull args') $ + forM_ (tail args') $ \(x, t) -> when (t == Ptr) $ do + let t_deref = + let + Function t ts = type2LlvmType . fromJust $ lookup x args + in + Function t (Ptr : ts) + + emit . SetVariable (mkDerefName x) + $ Load t_deref Ptr x + + whenJust mcxt loadFreeVars + + gcEnabled <- gets gcEnabled + when isMain $ mapM_ emit (firstMainContent gcEnabled) + + result <- exprToValue exp + + if isMain + then mapM_ emit $ lastMainContent gcEnabled result + else emit $ Ret t_return result + emit DefineEnd - modify $ \s -> s{variableCount = 0} + -- Reset variable count and empty locals + modify $ \s -> s { variableCount = 0, locals = mempty } compileScs xs -compileScs (MIR.DData (MIR.Data typ ts) : xs) = do - let (TIR.Ident outer_id) = extractTypeName typ + where + loadFreeVars cxt = do + emit $ Comment "Load free variables" + zipWithM_ go cxt' [1 ..] + where + go (x, t) i = do + vc <- getNewVar + emit . SetVariable vc + $ GetElementPtrInbounds (CustomType $ mkClosureName name) Ptr (VIdent mkCxtName Ptr) + I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices + emit . SetVariable x $ Load t Ptr vc + cxt' = map (second type2LlvmType) cxt + + isMain = name == "main" + + (name, args, exp, mcxt) = case bind of + Bind (name, _) args exp -> (name, args, exp, Nothing) + BindC cxt (name, _) args exp -> (name, args, exp, Just cxt) + + + insertArg (x, t) = snoc (x, LocalElem { val = VIdent x t, typ = t }) + +compileScs (DData (Data typ ts) : xs) = do + let (Ident outer_id) = extractTypeName typ -- //TODO this could be extracted from the customTypes map let variantTypes fi = init $ map type2LlvmType (flattenType fi) let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) - emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8] + -- Add data type (e.g. %List) to top of the file + addStructType_ (Ident outer_id) [I8, Array biggestVariant I8] typeSets <- gets customTypes mapM_ ( \(Inj inner_id fi) -> do let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi - emit $ LIR.Type inner_id (I8 : types) + -- Add constructor type (e.g. %Cons) to top of the file + addStructType_ inner_id (I8 : types) ) ts compileScs xs @@ -149,16 +202,16 @@ lastMainContent False var = , Ret I64 (VInteger 0) ] -compileExp :: ExpT -> CompilerState () -compileExp (MIR.ELit lit, _t) = emitLit lit -compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2 -compileExp (MIR.EVar name, _t) = emitIdent name -compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2 -compileExp (MIR.ELet bind e, _) = emitLet bind e -compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs) +compileExp :: T Exp -> CompilerState () +compileExp (ELit lit, _t) = emitLit lit +compileExp (EAdd e1 e2, t) = emitAdd t e1 e2 +compileExp (EVar name, _t) = emitIdent name +compileExp (EApp e1 e2, t) = emitApp t e1 e2 +compileExp (ELet bind e, _) = emitLet bind e +compileExp (ECase e cs, t) = emitECased t e (map (t,) cs) -emitLet :: MIR.Bind -> ExpT -> CompilerState () -emitLet (MIR.Bind id [] innerExp) e = do +emitLet :: Bind -> T Exp -> CompilerState () +emitLet (Bind id [] innerExp) e = do evaled <- exprToValue innerExp tempVar <- getNewVar let t = type2LlvmType . snd $ innerExp @@ -168,14 +221,14 @@ emitLet (MIR.Bind id [] innerExp) e = do compileExp e emitLet b _ = error $ "Non empty argument list in let-bind " <> show b -emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Branch)] -> CompilerState () +emitECased :: Type -> T Exp -> [(Type, Branch)] -> CompilerState () emitECased t e cases = do let cs = snd <$> cases let ty = type2LlvmType t let rt = type2LlvmType (snd e) vs <- exprToValue e lbl <- getNewLabel - let label = TIR.Ident $ "escape_" <> show lbl + let label = Ident $ "escape_" <> show lbl stackPtr <- getNewVar emit $ SetVariable stackPtr (Alloca ty) mapM_ (emitCases rt ty label stackPtr vs) cs @@ -192,14 +245,14 @@ emitECased t e cases = do res <- getNewVar emit $ SetVariable res (Load ty Ptr stackPtr) where - emitCases :: LLVMType -> LLVMType -> TIR.Ident -> TIR.Ident -> LLVMValue -> Branch -> CompilerState () - emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, _t) exp) = do + emitCases :: LLVMType -> LLVMType -> Ident -> Ident -> LLVMValue -> Branch -> CompilerState () + emitCases rt ty label stackPtr vs (Branch (PInj consId cs, _t) exp) = do emit $ Comment "Inj" cons <- gets constructors let r = fromJust $ Map.lookup consId cons - lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel - lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel + lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel + lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel consVal <- getNewVar emit $ SetVariable consVal (ExtractValue rt vs 0) @@ -215,10 +268,10 @@ emitECased t e cases = do emit $ Store rt vs Ptr castPtr emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr) enumerateOneM_ - ( \i c -> do + ( \i (c, t) -> do case c of - PVar (x, topT) -> do - let topT' = type2LlvmType topT + PVar x -> do + let topT' = type2LlvmType t let botT' = CustomType (coerce consId) emit . Comment $ "ident " <> toIr topT' cTypes <- gets customTypes @@ -228,7 +281,7 @@ emitECased t e cases = do emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i) emit $ SetVariable x (Load topT' Ptr deref) else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i) - PLit (_l, _t) -> error "Nested pattern matching to be implemented" + PLit _l -> error "Nested pattern matching to be implemented" PInj _id _ps -> error "Nested pattern matching to be implemented" PCatch -> pure () PEnum _id -> error "Nested pattern matching to be implemented" @@ -238,22 +291,22 @@ emitECased t e cases = do emit $ Store ty val Ptr stackPtr emit $ Br label emit $ Label lbl_failPos - emitCases _rt ty label stackPtr vs (Branch (MIR.PLit (i, ct), t) exp) = do + emitCases _rt ty label stackPtr vs (Branch (PLit i, t) exp) = do emit $ Comment "Plit" let i' = case i of - MIR.LInt i -> VInteger i - MIR.LChar i -> VChar (ord i) + LInt i -> VInteger i + LChar i -> VChar (ord i) ns <- getNewVar - lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel - lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel - emit $ SetVariable ns (Icmp LLEq (type2LlvmType ct) vs i') + lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel + lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel + emit $ SetVariable ns (Icmp LLEq (type2LlvmType t) vs i') emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos emit $ Label lbl_succPos val <- exprToValue exp emit $ Store ty val Ptr stackPtr emit $ Br label emit $ Label lbl_failPos - emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do + emitCases rt ty label stackPtr vs (Branch (PVar id, _) exp) = do emit $ Comment "Pvar" -- //TODO this is pretty disgusting and would heavily benefit from a rewrite valPtr <- getNewVar @@ -263,20 +316,20 @@ emitECased t e cases = do val <- exprToValue exp emit $ Store ty val Ptr stackPtr emit $ Br label - lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel + lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel emit $ Label lbl_failPos - emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "True$Bool"), t) exp) = do - emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 1, TLit "Bool"), t) exp) - emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "False$Bool"), _) exp) = do - emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 0, TLit "Bool"), t) exp) - emitCases rt ty label stackPtr vs br@(Branch (MIR.PEnum consId, _) exp) = do + emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "True$Bool"), t) exp) = do + emitCases rt ty label stackPtr vs (Branch (PLit $ LInt 1, t) exp) + emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "False$Bool"), _) exp) = do + emitCases rt ty label stackPtr vs (Branch (PLit (LInt 0), t) exp) + emitCases rt ty label stackPtr vs br@(Branch (PEnum consId, _) exp) = do emit $ Comment "Penum" cons <- gets constructors let r = Map.lookup consId cons when (isNothing r) (error $ "Constructor: '" ++ printTree consId ++ "' does not exist in cons state:\n" ++ show cons ++ "\nin pattern\n'" ++ printTree br ++ "'\n") - lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel - lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel + lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel + lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel consVal <- getNewVar emit $ SetVariable consVal (ExtractValue rt vs 0) @@ -295,98 +348,167 @@ emitECased t e cases = do emit $ Store ty val Ptr stackPtr emit $ Br label emit $ Label lbl_failPos - emitCases _ ty label stackPtr _ (Branch (MIR.PCatch, _) exp) = do + emitCases _ ty label stackPtr _ (Branch (PCatch, _) exp) = do emit $ Comment "Pcatch" val <- exprToValue exp emit $ Store ty val Ptr stackPtr emit $ Br label - lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel + lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel emit $ Label lbl_failPos -emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState () -emitApp rt e1 e2 = appEmitter e1 e2 [] - where - appEmitter :: ExpT -> ExpT -> [ExpT] -> CompilerState () - appEmitter e1 e2 stack = do - let newStack = e2 : stack - case e1 of - (MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack - (MIR.EVar name, t) -> do - args <- traverse exprToValue newStack - vs <- getNewVar - funcs <- gets functions - consts <- gets constructors - let visibility = - fromMaybe Local $ - Global <$ Map.lookup name consts - <|> Global <$ Map.lookup (name, t) funcs - -- this piece of code could probably be improved, i.e remove the double `const Global` - args' = map (first valueGetType . dupe) args - let call = - case name of - TIR.Ident ('l' : 't' : '$' : _) -> Icmp LLSlt I64 (snd (head args')) (snd (args' !! 1)) - TIR.Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) -> Sub I64 (snd (head args')) (snd (args' !! 1)) - _ -> Call FastCC (type2LlvmType rt) visibility name args' - emit $ Comment $ show rt - emit $ SetVariable vs call - x -> error $ "The unspeakable happened: " <> show x +emitApp :: Type -> T Exp -> T Exp -> CompilerState () +emitApp rt e1 e2 = do + ((EVar name, t), args) <- go (EApp e1 e2, rt) + vs <- getNewVar + funcs <- gets functions + consts <- gets constructors + let visibility = + fromMaybe Local $ + Global <$ Map.lookup name consts + <|> Global <$ Map.lookup (name, t) funcs + -- this piece of code could probably be improved, i.e remove the double `const Global` -emitIdent :: TIR.Ident -> CompilerState () + call <- case name of + Ident ('l' : 't' : '$' : _) -> + pure $ Icmp LLSlt I64 (snd (head args)) (snd (args !! 1)) + Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) -> + pure $ Sub I64 (snd (head args)) (snd (args !! 1)) + + -- FIXME + _ -> do + let closure_call LocalElem { typ = Ptr, val } = (mkDerefName name, (Ptr, val) : args) + + (name, args) <- gets $ maybe (name, (Ptr, VNull) : args) closure_call + . lookup name + . locals + + pure $ Call FastCC (type2LlvmType rt) visibility name args + + emit $ Comment $ show (type2LlvmType rt) + emit $ SetVariable vs call + + where + + go :: T Exp -> CompilerState (T Exp, [(LLVMType, LLVMValue)]) + go et@(e, _) = case e of + EApp e1 e2@(_, t) -> do + (x, as) <- go e1 + a <- exprToValue e2 + let t' = type2LlvmType' t + pure (x, snoc (t', a) as) + _ -> pure (et, []) + + type2LlvmType' = \case + TFun _ _ -> Ptr + t -> type2LlvmType t + +emitIdent :: Ident -> CompilerState () emitIdent id = do -- !!this should never happen!! emit $ Comment "This should not have happened!" emit $ Variable id emit $ UnsafeRaw "\n" -emitLit :: MIR.Lit -> CompilerState () +emitLit :: Lit -> CompilerState () emitLit i = do -- !!this should never happen!! let (i', t) = case i of - (MIR.LInt i'') -> (VInteger i'', I64) - (MIR.LChar i'') -> (VChar $ ord i'', I8) + (LInt i'') -> (VInteger i'', I64) + (LChar i'') -> (VChar $ ord i'', I8) varCount <- getNewVar emit $ Comment "This should not have happened!" emit $ SetVariable varCount (Add t i' (VInteger 0)) -emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState () +emitAdd :: Type -> T Exp -> T Exp -> CompilerState () emitAdd t e1 e2 = do v1 <- exprToValue e1 v2 <- exprToValue e2 v <- getNewVar emit $ SetVariable v (Add (type2LlvmType t) v1 v2) -exprToValue :: ExpT -> CompilerState LLVMValue -exprToValue = \case - (MIR.ELit i, _t) -> pure $ case i of - (MIR.LInt i) -> VInteger i - (MIR.LChar i) -> VChar $ ord i - (MIR.EVar (TIR.Ident "True$Bool"), _t) -> pure $ VInteger 1 - (MIR.EVar (TIR.Ident "False$Bool"), _t) -> pure $ VInteger 0 - (MIR.EVar name, t) -> do - funcs <- gets functions - cons <- gets constructors - let res = - Map.lookup (name, t) funcs - <|> ( \c -> - FunctionInfo - { numArgs = numArgsCI c - , arguments = argumentsCI c - } - ) - <$> Map.lookup name cons - case res of - Just fi -> do - if numArgs fi == 0 - then do - vc <- getNewVar - emit $ - SetVariable - vc - (Call FastCC (type2LlvmType t) Global name []) - pure $ VIdent vc (type2LlvmType t) - else pure $ VFunction name Global (type2LlvmType t) - Nothing -> pure $ VIdent name (type2LlvmType t) - e -> do - compileExp e + +exprToValue :: T Exp -> CompilerState LLVMValue +exprToValue et@(e, t) = case e of + ELit (LInt i) -> pure $ VInteger i + ELit (LChar c) -> pure . VChar $ ord c + + EVar "True$Bool" -> pure $ VInteger 1 + EVar "False$Bool" -> pure $ VInteger 0 + + EVar name -> gets (Map.lookup name . globals) >>= \case + Just (typ@(Function _ ts), val) | length ts > 1 -> do + type_struct <- addStructType (mkClosureName name) [typ] + emit $ Comment "Allocating structure" + emit . SetVariable name $ Alloca type_struct + emit $ Store typ val Ptr name + pure $ VIdent name Ptr + + Just _ | name == "main" -> do + vc <- getNewVar + emit $ SetVariable vc (Call FastCC I64 Global name []) + pure $ VIdent vc I64 + + + Just (Function t_return [_], _) -> do + vc <- getNewVar + emit $ SetVariable vc (Call FastCC t_return Global name [(Ptr, VNull)]) + pure $ VIdent vc t_return + + Just _ -> error "Bad" + + Nothing -> gets (Map.lookup name . constructors) >>= \case + + Just ConstructorInfo {numArgsCI} + | numArgsCI == 0 -> do + vc <- getNewVar + emit $ SetVariable vc call + pure $ VIdent vc (type2LlvmType t) + | otherwise -> pure $ VFunction name Global (type2LlvmType t) + where + call = Call FastCC (type2LlvmType t) Global name [] + + Nothing -> gets $ val + . fromJust + . lookup name + . locals + + EVarC cxt name -> do + let cxt' = flip map cxt $ \(x, t) -> let t' = type2LlvmType t + in (t', VIdent x t') + cxt'' <- gets $ (:cxt') + . fromJust + . Map.lookup name + . globals + + -- Create a new type for function pointer and arguments + type_struct <- addStructType (mkClosureName name) $ map fst cxt'' + emit $ Comment "Allocating structure" + emit . SetVariable name $ Alloca type_struct + + let ptr_struct = VIdent name Ptr + storeArg (t, v) i = do + vc <- getNewVar + emit . SetVariable vc + $ GetElementPtrInbounds type_struct Ptr ptr_struct + I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices + emit $ Store t v Ptr vc + + -- Store arguments in structure + zipWithM_ storeArg cxt'' [0 ..] + pure ptr_struct + + _ -> do + compileExp et v <- getVarCount - pure $ VIdent (TIR.Ident $ show v) (getType e) + pure $ VIdent (Ident $ show v) (getType et) + + +mkClosureName :: Ident -> Ident +mkClosureName (Ident s) = Ident $ "Closure_" ++ s + +mkDerefName :: Ident -> Ident +mkDerefName (Ident s) = Ident $ s ++ "_deref" + +mkCxtName :: Ident +mkCxtName = Ident "cxt" + diff --git a/src/Codegen/LlvmIr.hs b/src/Codegen/LlvmIr.hs index cc77cf9..0e0a6ce 100644 --- a/src/Codegen/LlvmIr.hs +++ b/src/Codegen/LlvmIr.hs @@ -9,17 +9,18 @@ module Codegen.LlvmIr ( Visibility (..), CallingConvention (..), ToIr (..), + typeOf ) where -import Data.List (intercalate) -import TypeChecker.TypeCheckerIr (Ident (..)) +import Data.List (intercalate) +import TypeChecker.TypeCheckerIr (Ident (..)) data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving (Show, Eq, Ord) instance ToIr CallingConvention where toIr :: CallingConvention -> String toIr TailCC = "tailcc" toIr FastCC = "fastcc" - toIr CCC = "ccc" + toIr CCC = "ccc" toIr ColdCC = "coldcc" -- | A datatype which represents some basic LLVM types @@ -38,6 +39,9 @@ data LLVMType class ToIr a where toIr :: a -> String +instance ToIr a => ToIr [a] where + toIr = concatMap toIr + instance ToIr LLVMType where toIr :: LLVMType -> String toIr = \case @@ -66,8 +70,8 @@ data LLVMComp instance ToIr LLVMComp where toIr :: LLVMComp -> String toIr = \case - LLEq -> "eq" - LLNe -> "ne" + LLEq -> "eq" + LLNe -> "ne" LLUgt -> "ugt" LLUge -> "uge" LLUlt -> "ult" @@ -80,7 +84,7 @@ instance ToIr LLVMComp where data Visibility = Local | Global deriving (Show, Eq, Ord) instance ToIr Visibility where toIr :: Visibility -> String - toIr Local = "%" + toIr Local = "%" toIr Global = "@" {- | Represents a LLVM "value", as in an integer, a register variable, @@ -92,16 +96,18 @@ data LLVMValue | VIdent Ident LLVMType | VConstant String | VFunction Ident Visibility LLVMType + | VNull deriving (Show, Eq, Ord) instance ToIr LLVMValue where toIr :: LLVMValue -> String toIr v = case v of - VInteger i -> show i - VChar i -> show i - VIdent (Ident n) _ -> "%" <> n + VInteger i -> show i + VChar i -> show i + VIdent (Ident n) _ -> "%" <> n VFunction (Ident n) vis _ -> toIr vis <> n - VConstant s -> "c" <> show s + VConstant s -> "c" <> show s + VNull -> "null" type Params = [(Ident, LLVMType)] type Args = [(LLVMType, LLVMValue)] @@ -139,6 +145,21 @@ data LLVMIr -- instructions should be used in its place deriving (Show, Eq, Ord) + +-- TODO add missing clauses +typeOf :: LLVMIr -> LLVMType +typeOf = \case + Add t _ _ -> t + Sub t _ _ -> t + Mul t _ _ -> t + Div t _ _ -> t + Load t _ _ -> t + Store t _ _ _ -> t + Type x _ -> CustomType x + SetVariable _ ir -> typeOf ir + + + -- | Converts a list of LLVMIr instructions to a string llvmIrToString :: [LLVMIr] -> String llvmIrToString = go 0 @@ -147,9 +168,9 @@ llvmIrToString = go 0 go _ [] = mempty go i (x : xs) = do let (i', n) = case x of - Define{} -> (i + 1, 0) + Define{} -> (i + 1, 0) DefineEnd -> (i - 1, 0) - _ -> (i, i) + _ -> (i, i) insToString n x <> go i' xs -- \| Converts a LLVM inststruction to a String, allowing for printing etc. @@ -224,10 +245,10 @@ llvmIrToString = go 0 , ")\n" ] (Alloca t) -> unwords ["alloca", toIr t, "\n"] - (Malloc t) -> + (Malloc t) -> concat [ "call ptr @malloc(i64 ", show t, ")\n"] - (GcMalloc t) -> + (GcMalloc t) -> concat [ "call ptr @cheap_alloc(i64 ", show t, ")\n"] (Store t1 val t2 (Ident id2)) -> diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs index 5581814..9369442 100644 --- a/src/LambdaLifter.hs +++ b/src/LambdaLifter.hs @@ -11,9 +11,11 @@ import Control.Monad.State (MonadState (get, put), State, evalState) import Data.Function (on) import Data.List (delete, mapAccumL, (\\)) +import Data.Tuple.Extra (first, second) +import LambdaLifterIr (T) +import qualified LambdaLifterIr as L import Prelude hiding (exp) -import TypeChecker.TypeCheckerIr - +import TypeChecker.TypeCheckerIr hiding (T) -- | Lift lambdas and let expression into supercombinators. -- Three phases: @@ -21,12 +23,13 @@ import TypeChecker.TypeCheckerIr -- @abstract@ converts lambdas into let expressions. -- @collectScs@ moves every non-constant let expression to a top-level function. -- -lambdaLift :: Program -> Program -lambdaLift (Program ds) = Program (datatypes ++ binds) +lambdaLift :: Program -> L.Program +lambdaLift (Program ds) = L.Program (datatypes ++ binds) where - datatypes = flip filter ds $ \case DData _ -> True - _ -> False - binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds] + datatypes = [L.DData (toLirData d) | DData d <- ds] + + binds = map L.DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds] + -- | Annotate free variables freeVars :: [Bind] -> [ABind] @@ -36,7 +39,7 @@ freeVars binds = [ let ae = freeVarsExp [] e | Bind n xs e <- binds ] -freeVarsExp :: Frees -> ExpT -> Ann AExpT +freeVarsExp :: Frees -> T Exp -> Ann (T AExp) freeVarsExp localVars (ae, t) = case ae of EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)] , term = (AVar n, t) @@ -121,27 +124,47 @@ data Ann a = Ann , term :: a } deriving (Show, Eq) -data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq) -data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq) - -type AExpT = (AExp, Type) +data ABind = ABind (T Ident) [T Ident] (Ann (T AExp)) deriving (Show, Eq) +data ABranch = ABranch (Pattern, Type) (Ann (T AExp)) deriving (Show, Eq) data AExp = AVar Ident | AInj Ident | ALit Lit - | ALet (Ann ABind) (Ann AExpT) - | AApp (Ann AExpT) (Ann AExpT) - | AAdd (Ann AExpT) (Ann AExpT) - | AAbs Ident (Ann AExpT) - | ACase (Ann AExpT) [Ann ABranch] + | ALet (Ann ABind) (Ann (T AExp)) + | AApp (Ann (T AExp)) (Ann (T AExp)) + | AAdd (Ann (T AExp)) (Ann (T AExp)) + | AAbs Ident (Ann (T AExp)) + | ACase (Ann (T AExp)) [Ann ABranch] deriving (Show, Eq) -abstract :: [ABind] -> [Bind] + + +data BBind = BBind (T Ident) [T Ident] (T BExp) + | BBindCxt [T Ident] (T Ident) [T Ident] (T BExp) + deriving (Eq, Ord, Show) + + +data BBranch = BBranch (T Pattern) (T BExp) + deriving (Eq, Ord, Show) + +data BExp + = BVar Ident + | BVarC [T Ident] Ident + | BInj Ident + | BLit Lit + | BLet BBind (T BExp) + | BApp (T BExp)(T BExp) + | BAdd (T BExp)(T BExp) + | BCase (T BExp) [BBranch] + deriving (Eq, Ord, Show) + + +abstract :: [ABind] -> [BBind] abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0 -abstractAnnBind :: Ann ABind -> State Int Bind +abstractAnnBind :: Ann ABind -> State Int BBind abstractAnnBind Ann { term = ABind name vars annae } = - Bind name (vars' <|| vars) <$> abstractAnnExp annae' + BBind name (vars' <|| vars) <$> abstractAnnExp annae' where (annae', vars') = go [] annae where @@ -149,24 +172,27 @@ abstractAnnBind Ann { term = ABind name vars annae } = Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae ae -> (ae, acc) -abstractAnnExp :: Ann AExpT -> State Int ExpT +abstractAnnExp :: Ann (T AExp) -> State Int (T BExp) abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of - AVar n -> pure (EVar n, typ) - AInj n -> pure (EInj n, typ) - ALit lit -> pure (ELit lit, typ) - AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2 - AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2 + AVar n -> pure (BVar n, typ) + AInj n -> pure (BInj n, typ) + ALit lit -> pure (BLit lit, typ) + AApp annae1 annae2 -> (, typ) <$> onM BApp abstractAnnExp annae1 annae2 + AAdd annae1 annae2 -> (, typ) <$> onM BAdd abstractAnnExp annae1 annae2 - -- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc AAbs x annae' -> do i <- nextNumber rhs <- abstractAnnExp annae'' let sc_name = Ident ("sc_" ++ show i) - e@(_, t) = foldl applyFree (EVar sc_name, typ) frees - pure (ELet (Bind (sc_name, typ) vars rhs) e ,t) + sc | null frees = (BVar sc_name, typ) + | otherwise = (BVarC frees sc_name, typ) + bind | null frees = BBind (sc_name, typ) vars rhs + | otherwise = BBindCxt frees (sc_name, typ) vars rhs + + pure (BLet bind sc ,typ) where - vars = frees <| (x, t_x) <|| ys + vars = [(x, t_x)] <|| ys t_x = case typ of TFun t _ -> t _ -> error "Impossible" @@ -176,54 +202,48 @@ abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae ae -> (ae, acc) - - applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type) - applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e') - where - t_e' = case t_e of TFun _ t -> t - _ -> error "Impossible" - ACase annae' bs -> do bs <- mapM go bs e <- abstractAnnExp annae' - pure (ECase e bs, typ) + pure (BCase e bs, typ) where - go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae + go Ann { term = ABranch p annae } = BBranch p <$> abstractAnnExp annae ALet b annae' -> - (, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae') + (, typ) <$> liftA2 BLet (abstractAnnBind b) (abstractAnnExp annae') -- | Collects supercombinators by lifting non-constant let expressions -collectScs :: [Bind] -> [Bind] +collectScs :: [BBind] -> [L.Bind] collectScs = concatMap collectFromRhs where - collectFromRhs (Bind name parms rhs) = + collectFromRhs (BBind name parms rhs) = let (rhs_scs, rhs') = collectScsExp rhs - in Bind name parms rhs' : rhs_scs + in L.Bind name parms rhs' : rhs_scs + collectFromRhs (BBindCxt cxt name parms rhs) = + let (rhs_scs, rhs') = collectScsExp rhs + in L.BindC cxt name parms rhs' : rhs_scs -collectScsExp :: ExpT -> ([Bind], ExpT) -collectScsExp expT@(exp, typ) = case exp of - EVar _ -> ([], expT) - EInj _ -> ([], expT) - ELit _ -> ([], expT) +collectScsExp :: T BExp -> ([L.Bind], T L.Exp) +collectScsExp (exp, typ) = case exp of + BVar x -> ([], (L.EVar x, typ)) + BVarC as x -> ([], (L.EVarC as x, typ)) + BInj k -> ([], (L.EInj k, typ)) + BLit lit -> ([], (L.ELit lit, typ)) - EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ)) + BApp e1 e2 -> (scs1 ++ scs2, (L.EApp e1' e2', typ)) where (scs1, e1') = collectScsExp e1 (scs2, e2') = collectScsExp e2 - EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ)) + BAdd e1 e2 -> (scs1 ++ scs2, (L.EAdd e1' e2', typ)) where (scs1, e1') = collectScsExp e1 (scs2, e2') = collectScsExp e2 - EAbs par e -> (scs, (EAbs par e', typ)) - where - (scs, e') = collectScsExp e - ECase e branches -> (scs ++ scs_e, (ECase e' branches', typ)) + BCase e branches -> (scs ++ scs_e, (L.ECase e' branches', typ)) where (scs, branches') = mapAccumL f [] branches (scs_e, e') = collectScsExp e @@ -234,15 +254,24 @@ collectScsExp expT@(exp, typ) = case exp of -- -- > f = let sc x y = rhs in e -- - ELet (Bind name parms rhs) e - | null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et')) + BLet (BBind name parms rhs) e + | null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et')) | otherwise -> (bind : rhs_scs ++ et_scs, et') where - bind = Bind name parms rhs' + bind = L.Bind name parms rhs' (rhs_scs, rhs') = collectScsExp rhs (et_scs, et') = collectScsExp e -collectScsBranch (Branch patt exp) = (scs, Branch patt exp') + + BLet (BBindCxt cxt name parms rhs) e + | null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et')) + | otherwise -> (bind : rhs_scs ++ et_scs, et') + where + bind = L.BindC cxt name parms rhs' + (rhs_scs, rhs') = collectScsExp rhs + (et_scs, et') = collectScsExp e + +collectScsBranch (BBranch patt exp) = (scs, L.Branch (first toLirPattern patt) exp') where (scs, exp') = collectScsExp exp nextNumber :: State Int Int @@ -259,3 +288,19 @@ xs <| x | elem x xs = xs (<||) :: Eq a => [a] -> [a] -> [a] xs <|| ys = foldl (<|) xs ys + + +toLirData (Data t injs) = L.Data t (map toLirInj injs) +toLirInj (Inj n t) = L.Inj n t + +toLirPattern :: Pattern -> L.Pattern +toLirPattern = \case + PVar x -> L.PVar x + PLit lit -> L.PLit lit + PCatch -> L.PCatch + PEnum k -> L.PEnum k + PInj k ps -> L.PInj k (map (first toLirPattern) ps) + + + + diff --git a/src/LambdaLifterIr.hs b/src/LambdaLifterIr.hs new file mode 100644 index 0000000..9ba57f7 --- /dev/null +++ b/src/LambdaLifterIr.hs @@ -0,0 +1,140 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} + +module LambdaLifterIr ( + module Grammar.Abs, + module LambdaLifterIr, + module TypeChecker.TypeCheckerIr +) where + +import Data.List (intercalate) +import Grammar.Abs (Lit (..)) +import Grammar.Print +import Prelude hiding (exp) +import qualified Prelude as C (Eq, Ord, Show) +import TypeChecker.TypeCheckerIr (Ident (..), TVar (..), Type (..)) + +newtype Program = Program [Def] + deriving (C.Eq, C.Ord, C.Show) + +data Def + = DBind Bind + | DData Data + deriving (C.Eq, C.Ord, C.Show) + +data Data = Data Type [Inj] + deriving (C.Eq, C.Ord, C.Show) + +data Inj = Inj Ident Type + deriving (C.Eq, C.Ord, C.Show) + +data Pattern + = PVar Ident + | PLit Lit + | PCatch + | PEnum Ident + | PInj Ident [(Pattern, Type)] + deriving (C.Eq, C.Ord, C.Show) + +data Exp + = EVar Ident + | EVarC [T Ident] Ident + | EInj Ident + | ELit Lit + | ELet (T Ident) (T Exp) (T Exp) + | EApp (T Exp)(T Exp) + | EAdd (T Exp)(T Exp) + | ECase (T Exp) [Branch] + deriving (C.Eq, C.Ord, C.Show) + + +type T a = (a, Type) + +data Bind = Bind (T Ident) [T Ident] (T Exp) + | BindC [T Ident] (T Ident) [T Ident] (T Exp) + deriving (C.Eq, C.Ord, C.Show) + +data Branch = Branch (T Pattern) (T Exp) + deriving (C.Eq, C.Ord, C.Show) + +instance Print Program where + prt i (Program sc) = prt i sc + +instance Print Bind where + prt i (Bind sig parms rhs) = concatD + [ prt i sig + , prt i parms + , doc $ showString "=" + , prt i rhs + ] + prt i (BindC cxt sig parms rhs) = + prPrec i 0 $ + concatD + [ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig + , prt i parms + , doc $ showString "=" + , prt i rhs + ] + +instance Print [Bind] where + prt _ [] = concatD [] + prt i [x] = concatD [prt i x] + prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs] + +instance Print Exp where + prt i = \case + EVar lident -> prPrec i 3 (concatD [prt 0 lident]) + EVarC as lident -> doc . showString + $ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident + where + go (x, _) = printTree x ++ "^=" ++ printTree (EVar x) + EInj uident -> prPrec i 3 (concatD [prt 0 uident]) + ELit lit -> prPrec i 3 (concatD [prt 0 lit]) + EApp exp1 exp2 -> prPrec i 2 (concatD [prt 2 exp1, prt 3 exp2]) + EAdd exp1 exp2 -> prPrec i 1 (concatD [prt 1 exp1, doc (showString "+"), prt 2 exp2]) + ELet lident exp1 exp2 -> prPrec i 0 (concatD [doc (showString "let"), prt 0 lident, doc (showString "="), prt 0 exp1 , doc (showString "in"), prt 0 exp2]) + ECase exp branchs -> prPrec i 0 (concatD [doc (showString "case"), prt 0 exp, doc (showString "of"), doc (showString "{"), prt 0 branchs, doc (showString "}")]) + + +instance Print Branch where + prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) + +instance Print [Branch] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +instance Print Def where + prt i = \case + DBind bind -> prPrec i 0 (concatD [prt 0 bind]) + DData data_ -> prPrec i 0 (concatD [prt 0 data_]) + +instance Print Data where + prt i = \case + Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")]) + +instance Print Inj where + prt i = \case + Inj uident type_ -> prt i (uident, type_) + +instance Print [Inj] where + prt _ [] = concatD [] + prt i [x] = prt i x + prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs] + +instance Print Pattern where + prt i = \case + PVar name -> prPrec i 1 (concatD [prt 0 name]) + PLit lit -> prPrec i 1 (concatD [prt 0 lit]) + PCatch -> prPrec i 1 (concatD [doc (showString "_")]) + PEnum name -> prPrec i 1 (concatD [prt 0 name]) + PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) + +instance Print [Def] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + +pattern DBind' id vars expt = DBind (Bind id vars expt) +pattern DData' typ injs = DData (Data typ injs) + diff --git a/src/Monomorphizer/DataTypeRemover.hs b/src/Monomorphizer/DataTypeRemover.hs index e4caef0..c3e2eb5 100644 --- a/src/Monomorphizer/DataTypeRemover.hs +++ b/src/Monomorphizer/DataTypeRemover.hs @@ -1,8 +1,11 @@ + module Monomorphizer.DataTypeRemover (removeDataTypes) where -import Monomorphizer.MonomorphizerIr qualified as M2 -import Monomorphizer.MorbIr qualified as M1 -import TypeChecker.TypeCheckerIr (Ident (Ident)) +import Data.Bifunctor (Bifunctor (bimap)) +import Monomorphizer.MonomorphizerIr (Ident (..)) +import qualified Monomorphizer.MonomorphizerIr as M2 +import qualified Monomorphizer.MorbIr as M1 +import Prelude hiding (exp) removeDataTypes :: M1.Program -> M2.Program removeDataTypes (M1.Program defs) = M2.Program (map pDef defs) @@ -18,43 +21,43 @@ pCons :: M1.Inj -> M2.Inj pCons (M1.Inj ident t) = M2.Inj ident (pType t) pType :: M1.Type -> M2.Type -pType (M1.TLit ident) = M2.TLit ident -pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2) +pType (M1.TLit ident) = M2.TLit ident +pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2) pType (M1.TData (Ident "Bool") _) = M2.TLit (Ident "Bool") -pType d = M2.TLit (Ident (newName d)) -- This is the step +pType d = M2.TLit (Ident (newName d)) -- This is the step newName :: M1.Type -> String -newName (M1.TLit (Ident str)) = str -newName (M1.TFun t1 t2) = newName t1 ++ newName t2 +newName (M1.TLit (Ident str)) = str +newName (M1.TFun t1 t2) = newName t1 ++ newName t2 newName (M1.TData (Ident str) args) = str ++ concatMap newName args pBind :: M1.Bind -> M2.Bind pBind (M1.Bind id argIds expt) = M2.Bind (pId id) (map pId argIds) (pExpT expt) +pBind (M1.BindC cxt id argIds expt) = + M2.BindC (map pId cxt) (pId id) (map pId argIds) (pExpT expt) pId :: (Ident, M1.Type) -> (Ident, M2.Type) pId (ident, t) = (ident, pType t) -pExpT :: M1.ExpT -> M2.ExpT +pExpT :: M1.T M1.Exp -> M2.T M2.Exp pExpT (exp, t) = (pExp exp, pType t) pExp :: M1.Exp -> M2.Exp -pExp (M1.EVar ident) = M2.EVar ident -pExp (M1.ELit lit) = M2.ELit (pLit lit) -pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt) -pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2) -pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2) +pExp (M1.EVar ident) = M2.EVar ident +pExp (M1.EVarC as ident) = M2.EVarC (map pId as) ident +pExp (M1.ELit lit) = M2.ELit lit +pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt) +pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2) +pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2) pExp (M1.ECase expT branches) = M2.ECase (pExpT expT) (map pBranch branches) pBranch :: M1.Branch -> M2.Branch pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt) pPattern :: M1.Pattern -> M2.Pattern -pPattern (M1.PVar id) = M2.PVar (pId id) -pPattern (M1.PLit (lit, t)) = M2.PLit (pLit lit, pType t) -pPattern (M1.PInj ident patts) = M2.PInj ident (map pPattern patts) -pPattern M1.PCatch = M2.PCatch -pPattern (M1.PEnum ident) = M2.PEnum ident +pPattern (M1.PVar ident) = M2.PVar ident +pPattern (M1.PLit lit) = M2.PLit lit +pPattern (M1.PInj ident patts) = M2.PInj ident (map (bimap pPattern pType) patts) +pPattern M1.PCatch = M2.PCatch +pPattern (M1.PEnum ident) = M2.PEnum ident -pLit :: M1.Lit -> M2.Lit -pLit (M1.LInt v) = M2.LInt v -pLit (M1.LChar c) = M2.LChar c diff --git a/src/Monomorphizer/Monomorphizer.hs b/src/Monomorphizer/Monomorphizer.hs index 3a8bd9e..4b25aaa 100644 --- a/src/Monomorphizer/Monomorphizer.hs +++ b/src/Monomorphizer/Monomorphizer.hs @@ -1,4 +1,5 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} {- | For now, converts polymorphic functions to concrete ones based on usage. Assumes lambdas are lifted. @@ -25,30 +26,35 @@ bind) is added to the resulting set of binds. -} module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where -import Monomorphizer.DataTypeRemover (removeDataTypes) -import Monomorphizer.MonomorphizerIr qualified as O -import Monomorphizer.MorbIr qualified as M -import TypeChecker.TypeCheckerIr (Ident (Ident)) -import TypeChecker.TypeCheckerIr qualified as T -import Control.Monad.Reader ( - MonadReader (ask, local), - Reader, - asks, - runReader, - ) -import Control.Monad.State ( - MonadState (get), - StateT (runStateT), - gets, - modify, - ) -import Data.Coerce (coerce) -import Data.Map qualified as Map -import Data.Maybe (catMaybes) -import Data.Set qualified as Set -import Grammar.Print (printTree) -import Debug.Trace (trace) +import Control.Monad.Reader (MonadReader (ask, local), + Reader, asks, runReader) +import Control.Monad.State (MonadState (get), + StateT (runStateT), gets, + modify) +import Data.Coerce (coerce) +import qualified Data.Map as Map +import Data.Maybe (catMaybes) +import qualified Data.Set as Set +import Debug.Trace (trace) +import Grammar.Print (printTree) +import Monomorphizer.DataTypeRemover (removeDataTypes) +import qualified Monomorphizer.MonomorphizerIr as O +import qualified Monomorphizer.MorbIr as M +-- import TypeChecker.TypeCheckerIr (Ident (Ident)) +import LambdaLifterIr (Ident (..)) +-- import TypeChecker.TypeCheckerIr qualified as T +import qualified LambdaLifterIr as L + +import Control.Monad.Reader (MonadReader (ask, local), + Reader, asks, runReader) +import Control.Monad.State (MonadState, StateT (runStateT), + gets, modify) +import qualified Data.Map as Map +import Data.Maybe (catMaybes, fromJust) +import qualified Data.Set as Set +import Data.Tuple.Extra (secondM) +import Grammar.Print (printTree) {- | EnvM is the monad containing the read-only state as well as the output state containing monomorphized functions and to-be monomorphized @@ -64,18 +70,18 @@ Binds, Polymorphic Data types (monomorphized in a later step) and Marked bind, which means that it is in the process of monomorphization and should not be monomorphized again. -} -data Outputted = Marked | Complete M.Bind | Data M.Type T.Data deriving (Show) +data Outputted = Marked | Complete M.Bind | Data M.Type L.Data deriving (Show) -- | Static environment. data Env = Env - { input :: Map.Map Ident T.Bind + { input :: Map.Map Ident L.Bind -- ^ All binds in the program. - , dataDefs :: Map.Map Ident T.Data + , dataDefs :: Map.Map Ident L.Data -- ^ All constructors mapped to their respective polymorphic data def -- which includes all other constructors. - , polys :: Map.Map Ident M.Type + , polys :: Map.Map Ident M.Type -- ^ Maps polymorphic identifiers with concrete types. - , locals :: Set.Set Ident + , locals :: Set.Set Ident -- ^ Local variables. } @@ -84,12 +90,13 @@ localExists :: Ident -> EnvM Bool localExists ident = asks (Set.member ident . locals) -- | Gets a polymorphic bind from an id. -getInputBind :: Ident -> EnvM (Maybe T.Bind) +getInputBind :: Ident -> EnvM (Maybe L.Bind) getInputBind ident = asks (Map.lookup ident . input) -- | Add monomorphic function derived from a polymorphic one, to env. addOutputBind :: M.Bind -> EnvM () addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b)) +addOutputBind b@(M.BindC _ (ident, _) _ _) = modify (Map.insert ident (Complete b)) {- | Marks a global bind as being processed, meaning that when encountered again, it should not be recursively processed. @@ -106,8 +113,8 @@ isConsMarked :: Ident -> EnvM Bool isConsMarked ident = gets (Map.member ident) -- | Finds main bind. -getMain :: EnvM T.Bind -getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of +getMain :: EnvM L.Bind +getMain = asks (\env -> case Map.lookup (Ident "main") (input env) of Just mainBind -> mainBind Nothing -> error "main not found in monomorphizer!" ) @@ -116,13 +123,13 @@ getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of error when encountering different structures between the two arguments. Debug: First argument is the name of the bind. -} -mapTypes :: Ident -> T.Type -> M.Type -> [(Ident, M.Type)] -mapTypes _ident (T.TLit _) (M.TLit _) = [] -mapTypes _ident (T.TVar (T.MkTVar i1)) tm = [(i1, tm)] -mapTypes ident (T.TFun pt1 pt2) (M.TFun mt1 mt2) = +mapTypes :: Ident -> L.Type -> M.Type -> [(Ident, M.Type)] +mapTypes _ident (L.TLit _) (M.TLit _) = [] +mapTypes _ident (L.TVar (L.MkTVar i1)) tm = [(i1, tm)] +mapTypes ident (L.TFun pt1 pt2) (M.TFun mt1 mt2) = mapTypes ident pt1 mt1 ++ mapTypes ident pt2 mt2 -mapTypes ident (T.TData tIdent pTs) (M.TData mIdent mTs) = +mapTypes ident (L.TData tIdent pTs) (M.TData mIdent mTs) = if tIdent /= mIdent then error "the data type names of monomorphic and polymorphic data types does not match" else foldl (\xs (p, m) -> mapTypes ident p m ++ xs) [] (zip pTs mTs) @@ -130,30 +137,30 @@ mapTypes ident t1 t2 = error $ "in bind: '" ++ printTree ident ++ "', " ++ "structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'" -- | Gets the mapped monomorphic type of a polymorphic type in the current context. -getMonoFromPoly :: T.Type -> EnvM M.Type +getMonoFromPoly :: L.Type -> EnvM M.Type getMonoFromPoly t = do env <- ask return $ getMono (polys env) t where - getMono :: Map.Map Ident M.Type -> T.Type -> M.Type + getMono :: Map.Map Ident M.Type -> L.Type -> M.Type getMono polys t = case t of - (T.TLit ident) -> M.TLit (coerce ident) - (T.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2) - (T.TVar (T.MkTVar ident)) -> case Map.lookup ident polys of + (L.TLit ident) -> M.TLit ident + (L.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2) + (L.TVar (L.MkTVar ident)) -> case Map.lookup ident polys of Just concrete -> concrete - Nothing -> M.TLit (Ident "void") + Nothing -> M.TLit (Ident "void") -- error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps" - (T.TData ident args) -> M.TData ident (map (getMono polys) args) + (L.TData ident args) -> M.TData ident (map (getMono polys) args) {- | If ident not already in env's output, morphed bind to output (and all referenced binds within this bind). Returns the annotated bind name. -} -morphBind :: M.Type -> T.Bind -> EnvM Ident -morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do +morphBind :: M.Type -> L.Bind -> EnvM Ident +morphBind expectedType b@(L.Bind (ident, btype) args (exp, expt)) = do -- The "new name" is used to find out if it is already marked or not. let name' = newFuncName expectedType b - bindMarked <- isBindMarked (coerce name') + bindMarked <- isBindMarked name' local ( \env -> env @@ -168,26 +175,59 @@ morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do else do -- Mark so that this bind will not be processed in recursive or cyclic -- function calls - markBind (coerce name') + markBind name' expt' <- getMonoFromPoly expt exp' <- morphExp expt' exp -- Get monomorphic type sof args args' <- mapM morphArg args addOutputBind $ M.Bind - (coerce name', expectedType) + (name', expectedType) args' (exp', expt') return name' +morphBind expectedType b@(L.BindC cxt (ident, btype) args (exp, expt)) = do + -- The "new name" is used to find out if it is already marked or not. + let name' = newFuncName expectedType b + bindMarked <- isBindMarked name' + local + ( \env -> + env + { locals = Set.fromList (map fst args) + , polys = Map.fromList (mapTypes ident btype expectedType) + } + ) + $ do + -- Return with right name if already marked + if bindMarked + then return name' + else do + -- Mark so that this bind will not be processed in recursive or cyclic + -- function calls + markBind name' + -- Get monomorphic type sof args + args' <- mapM morphArg args + cxt' <- mapM (secondM getMonoFromPoly) cxt + expt' <- getMonoFromPoly expt + exp' <- local (\env -> foldr (addLocal . fst) env cxt) + (morphExp expt' exp) + addOutputBind $ + M.BindC cxt' + (name', expectedType) + args' + (exp', expt') + return name' + + -- | Monomorphizes arguments of a bind. -morphArg :: (Ident, T.Type) -> EnvM (Ident, M.Type) +morphArg :: (Ident, L.Type) -> EnvM (Ident, M.Type) morphArg (ident, t) = do t' <- getMonoFromPoly t return (ident, t') -- | Gets the data bind from the name of a constructor. -getInputData :: Ident -> EnvM (Maybe T.Data) +getInputData :: Ident -> EnvM (Maybe L.Data) getInputData ident = do env <- ask return $ Map.lookup ident (dataDefs env) @@ -201,50 +241,50 @@ morphCons expectedType ident newIdent = do --trace ("Tjofras:" ++ show (newName expectedType ident)) $ return () maybeD <- getInputData ident case maybeD of - Nothing -> error $ "identifier '" ++ show ident ++ "' not found" + -- closures can have unbound variables + Nothing -> pure () Just d -> do modify (\output -> Map.insert newIdent (Data expectedType d) output) -- | Converts literals from input to output tree. -convertLit :: T.Lit -> M.Lit -convertLit (T.LInt v) = M.LInt v -convertLit (T.LChar v) = M.LChar v +convertLit :: L.Lit -> M.Lit +convertLit (L.LInt v) = M.LInt v +convertLit (L.LChar v) = M.LChar v + -- | Monomorphizes an expression, given an expected type. -morphExp :: M.Type -> T.Exp -> EnvM M.Exp +morphExp :: M.Type -> L.Exp -> EnvM M.Exp morphExp expectedType exp = case exp of - T.ELit lit -> return $ M.ELit (convertLit lit) + L.ELit lit -> return $ M.ELit lit -- Constructor - T.EInj ident -> do + L.EInj ident -> do let ident' = newName (getDataType expectedType) ident morphCons expectedType ident ident' return $ M.EVar ident' - T.EApp (e1, _t1) (e2, t2) -> do + L.EApp (e1, _t1) (e2, t2) -> do t2' <- getMonoFromPoly t2 e2' <- morphExp t2' e2 e1' <- morphExp (M.TFun t2' expectedType) e1 return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2') - T.EAdd (e1, t1) (e2, t2) -> do + L.EAdd (e1, t1) (e2, t2) -> do t1' <- getMonoFromPoly t1 t2' <- getMonoFromPoly t2 e1' <- morphExp t1' e1 e2' <- morphExp t2' e2 return $ M.EAdd (e1', expectedType) (e2', expectedType) - T.EAbs ident (exp, t) -> local (\env -> env{locals = Set.insert ident (locals env)}) $ do - t' <- getMonoFromPoly t - morphExp t' exp - T.ECase (exp, t) bs -> do + L.ECase (exp, t) bs -> do t' <- getMonoFromPoly t exp' <- morphExp t' exp bs' <- mapM morphBranch bs return $ M.ECase (exp', t') (catMaybes bs') -- Ideally constructors should be EInj, though this code handles them -- as well. - T.EVar ident -> do + -- FIXME MAKE EVAR AND EINJ SEPARATE!!! + L.EVar ident -> do isLocal <- localExists ident if isLocal then do - return $ M.EVar (coerce ident) + return $ M.EVar ident else do bind <- getInputBind ident case bind of @@ -252,38 +292,51 @@ morphExp expectedType exp = case exp of Just bind' -> do -- New bind to process newBindName <- morphBind expectedType bind' - return $ M.EVar (coerce newBindName) - T.ELet (T.Bind (identB, tB) args (expB, tExpB)) (exp, tExp) -> - if length args > 0 then error "only constants in lets allowed" - else do + return $ M.EVar newBindName + L.EVarC as ident -> do + isLocal <- localExists ident + if isLocal + then do + return $ M.EVar ident + else do + bind <- fromJust <$> getInputBind ident + as' <- mapM (secondM getMonoFromPoly) as + -- New bind to process + newBindName <- morphBind expectedType bind + return $ M.EVarC as' newBindName + -- Ideally constructors should be EInj, though this code handles them + -- as well. + + + L.ELet (identB, tB) (expB, tExpB) (exp, tExp) -> do tB' <- getMonoFromPoly tB tExpB' <- getMonoFromPoly tExpB tExp' <- getMonoFromPoly tExp expB' <- morphExp tExpB' expB - exp' <- morphExp tExp' exp + exp' <- local (addLocal identB) (morphExp tExp' exp) return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (exp', tExp') -- | Monomorphizes case-of branches. -morphBranch :: T.Branch -> EnvM (Maybe M.Branch) -morphBranch (T.Branch (p, pt) (e, et)) = do +morphBranch :: L.Branch -> EnvM (Maybe M.Branch) +morphBranch (L.Branch (p, pt) (e, et)) = do pt' <- getMonoFromPoly pt et' <- getMonoFromPoly et env <- ask maybeMorphedPattern <- morphPattern p pt' case maybeMorphedPattern of Nothing -> return Nothing - Just (p', newLocals) -> + Just (p', newLocals) -> local (const env { locals = Set.union (locals env) newLocals }) $ do e' <- morphExp et' e - return $ Just (M.Branch (p', pt') (e', et')) + return $ Just (M.Branch p' (e', et')) -morphPattern :: T.Pattern -> M.Type -> EnvM (Maybe (M.Pattern, Set.Set Ident)) +morphPattern :: L.Pattern -> M.Type -> EnvM (Maybe (M.T M.Pattern, Set.Set Ident)) morphPattern p expectedType = case p of - T.PVar ident -> return $ Just (M.PVar (ident, expectedType), Set.singleton ident) - T.PLit lit -> return $ Just (M.PLit (convertLit lit, expectedType), Set.empty) - T.PCatch -> return $ Just (M.PCatch, Set.empty) - T.PEnum ident -> return $ Just (M.PEnum (newName expectedType ident), Set.empty) - T.PInj ident pts -> do let newIdent = newName expectedType ident + L.PVar ident -> return $ Just ((M.PVar ident, expectedType), Set.singleton ident) + L.PLit lit -> return $ Just ((M.PLit (convertLit lit), expectedType), Set.empty) + L.PCatch -> return $ Just ((M.PCatch, expectedType), Set.empty) + L.PEnum ident -> return $ Just ((M.PEnum (newName expectedType ident), expectedType), Set.empty) + L.PInj ident pts -> do let newIdent = newName expectedType ident outEnv <- get trace ("WOW: " ++ show (newName expectedType ident)) $ return () trace ("WOW2: " ++ show (outEnv)) $ return () @@ -297,13 +350,18 @@ morphPattern p expectedType = case p of let maybePsSets = sequence psSets case maybePsSets of Nothing -> return Nothing - Just psSets' -> return $ Just - (M.PInj newIdent (map fst psSets'), Set.unions $ map snd psSets') + Just psSets' -> return $ Just + ((M.PInj newIdent (map fst psSets'), expectedType), Set.unions $ map snd psSets') else return Nothing -- | Creates a new identifier for a function with an assigned type. -newFuncName :: M.Type -> T.Bind -> Ident -newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) = +newFuncName :: M.Type -> L.Bind -> Ident +newFuncName t (L.Bind (ident@(Ident bindName), _) _ _) = + if bindName == "main" + then Ident bindName + else newName t ident + +newFuncName t (L.BindC _ (ident@(Ident bindName), _) _ _) = if bindName == "main" then Ident bindName else newName t ident @@ -317,8 +375,8 @@ newName t (Ident str) = Ident $ str ++ "$" ++ newName' t newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts -- | Monomorphization step. -monomorphize :: T.Program -> O.Program -monomorphize (T.Program defs) = +monomorphize :: L.Program -> O.Program +monomorphize (L.Program defs) = removeDataTypes $ M.Program ( getDefsFromOutput @@ -336,7 +394,7 @@ runEnvM :: Output -> Env -> EnvM () -> Output runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env -- | Creates the environment based on the input binds. -createEnv :: [T.Def] -> Env +createEnv :: [L.Def] -> Env createEnv defs = Env { input = Map.fromList bindPairs @@ -346,33 +404,34 @@ createEnv defs = } where bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs - dataPairs :: [(Ident, T.Data)] - dataPairs = (foldl (\acc d@(T.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs + dataPairs :: [(Ident, L.Data)] + dataPairs = (foldl (\acc d@(L.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs -- | Gets a top-lefel function name. -getBindName :: T.Bind -> Ident -getBindName (T.Bind (ident, _) _ _) = ident +getBindName :: L.Bind -> Ident +getBindName (L.Bind (ident, _) _ _) = ident +getBindName (L.BindC _ (ident, _) _ _) = ident -- Helper functions -- Gets custom data declarations form defs. -getDataFromDefs :: [T.Def] -> [T.Data] +getDataFromDefs :: [L.Def] -> [L.Data] getDataFromDefs = foldl ( \bs -> \case - T.DBind _ -> bs - T.DData d -> d : bs + L.DBind _ -> bs + L.DData d -> d : bs ) [] -getConsName :: T.Inj -> Ident -getConsName (T.Inj ident _) = ident +getConsName :: L.Inj -> Ident +getConsName (L.Inj ident _) = ident -getBindsFromDefs :: [T.Def] -> [T.Bind] +getBindsFromDefs :: [L.Def] -> [L.Bind] getBindsFromDefs = foldl ( \bs -> \case - T.DBind b -> b : bs - T.DData _ -> bs + L.DBind b -> b : bs + L.DData _ -> bs ) [] @@ -384,19 +443,19 @@ getDefsFromOutput o = (binds, dataInput) = splitBindsAndData o -- | Splits the output into binds and data declaration components (used in createNewData) -splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, T.Data)]) +splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, L.Data)]) splitBindsAndData output = foldl ( \(oBinds, oData) (ident, o) -> case o of - Marked -> error "internal bug in monomorphizer" + Marked -> error "internal bug in monomorphizer" Complete b -> (b : oBinds, oData) - Data t d -> (oBinds, (ident, t, d) : oData) + Data t d -> (oBinds, (ident, t, d) : oData) ) ([], []) (Map.toList output) -- | Converts all found constructors to monomorphic data declarations. -createNewData :: [(Ident, M.Type, T.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data +createNewData :: [(Ident, M.Type, L.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data createNewData [] o = o createNewData ((consIdent, consType, polyData) : input) o = createNewData input $ @@ -406,14 +465,17 @@ createNewData ((consIdent, consType, polyData) : input) o = (M.Data newDataType [newCons]) o where - T.Data (T.TData polyDataIdent _) _ = polyData + L.Data (L.TData polyDataIdent _) _ = polyData newDataType = getDataType consType newDataName = newName newDataType polyDataIdent newCons = M.Inj consIdent consType -- | Gets the Data Type of a constructor type (a -> Just a becomes Just a). getDataType :: M.Type -> M.Type -getDataType (M.TFun _t1 t2) = getDataType t2 +getDataType (M.TFun _t1 t2) = getDataType t2 getDataType tData@(M.TData _ _) = tData -getDataType _ = error "???" +getDataType _ = error "???" + +addLocal :: Ident -> Env -> Env +addLocal x env = env { locals = Set.insert x env.locals } diff --git a/src/Monomorphizer/MonomorphizerIr.hs b/src/Monomorphizer/MonomorphizerIr.hs index 052cdc1..59ad067 100644 --- a/src/Monomorphizer/MonomorphizerIr.hs +++ b/src/Monomorphizer/MonomorphizerIr.hs @@ -1,11 +1,14 @@ {-# LANGUAGE LambdaCase #-} -module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr) where +module Monomorphizer.MonomorphizerIr ( + module Monomorphizer.MonomorphizerIr, + module LambdaLifterIr +) where -import Grammar.Print -import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..)) - -type Id = (TIR.Ident, Type) +import Data.List (intercalate) +import Grammar.Print +import LambdaLifterIr (Ident (..), Lit (..)) +import Prelude hiding (exp) newtype Program = Program [Def] deriving (Show, Ord, Eq) @@ -16,90 +19,80 @@ data Def = DBind Bind | DData Data data Data = Data Type [Inj] deriving (Show, Ord, Eq) -data Bind = Bind Id [Id] ExpT +data Bind = Bind (T Ident) [T Ident] (T Exp) + | BindC [T Ident] (T Ident) [T Ident] (T Exp) deriving (Show, Ord, Eq) +type T a = (a, Type) + data Exp - = EVar TIR.Ident + = EVar Ident + | EVarC [T Ident] Ident | ELit Lit - | ELet Bind ExpT - | EApp ExpT ExpT - | EAdd ExpT ExpT - | ECase ExpT [Branch] + | ELet Bind (T Exp) + | EApp (T Exp) (T Exp) + | EAdd (T Exp) (T Exp) + | ECase (T Exp) [Branch] deriving (Show, Ord, Eq) data Pattern - = PVar Id - | PLit (Lit, Type) - | PInj TIR.Ident [Pattern] + = PVar Ident + | PLit Lit + | PInj Ident [T Pattern] | PCatch - | PEnum TIR.Ident + | PEnum Ident deriving (Eq, Ord, Show) -data Branch = Branch (Pattern, Type) ExpT +data Branch = Branch (T Pattern) (T Exp) deriving (Eq, Ord, Show) -type ExpT = (Exp, Type) - -data Inj = Inj TIR.Ident Type +data Inj = Inj Ident Type deriving (Show, Ord, Eq) -data Lit - = LInt Integer - | LChar Char - deriving (Show, Ord, Eq) - -data Type = TLit TIR.Ident | TFun Type Type +data Type = TLit Ident | TFun Type Type deriving (Show, Ord, Eq) flattenType :: Type -> [Type] flattenType (TFun t1 t2) = t1 : flattenType t2 -flattenType x = [x] +flattenType x = [x] instance Print Program where prt i (Program sc) = prPrec i 0 $ prt 0 sc -instance Print (Bind) where +instance Print Bind where prt i (Bind sig@(name, _) parms rhs) = prPrec i 0 $ concatD - [ prtSig sig + [ prt 0 sig , prt 0 name - , prtIdPs 0 parms + , prt 0 parms , doc $ showString "=" , prt 0 rhs ] -prtSig :: Id -> Doc -prtSig (name, t) = - concatD - [ prt 0 name - , doc $ showString ":" - , prt 0 t - , doc $ showString ";" - ] + prt i (BindC cxt sig parms rhs) = + prPrec i 0 $ + concatD + [ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig + , prt i parms + , doc $ showString "=" + , prt i rhs + ] -instance Print (ExpT) where - prt i (e, t) = - concatD - [ doc $ showString "(" - , prt i e - , doc $ showString "," - , prt i t - , doc $ showString ")" - ] instance Print [Bind] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] -prtIdPs :: Int -> [Id] -> Doc -prtIdPs i = prPrec i 0 . concatD . map (prt i) instance Print Exp where prt i = \case EVar name -> prPrec i 3 $ prt 0 name + EVarC as lident -> doc . showString + $ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident + where + go (x, _) = printTree x ++ "^=" ++ printTree (EVar x) ELit lit -> prPrec i 3 $ prt 0 lit ELet b e -> prPrec i 3 $ @@ -134,16 +127,16 @@ instance Print Exp where ] instance Print Branch where - prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) + prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp]) instance Print [Branch] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] instance Print Def where prt i = \case - DBind bind -> prPrec i 0 (concatD [prt 0 bind]) + DBind bind -> prPrec i 0 (concatD [prt 0 bind]) DData data_ -> prPrec i 0 (concatD [prt 0 data_]) instance Print Data where @@ -152,23 +145,23 @@ instance Print Data where instance Print Inj where prt i = \case - Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) + Inj uident type_ -> prt i (uident, type_) instance Print Pattern where prt i = \case PVar name -> prPrec i 1 (concatD [prt 0 name]) - PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) + PLit lit -> prPrec i 1 (concatD [prt 0 lit]) PCatch -> prPrec i 1 (concatD [doc (showString "_")]) PEnum name -> prPrec i 1 (concatD [prt 0 name]) PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) instance Print [Def] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] instance Print [Type] where - prt _ [] = concatD [] + prt _ [] = concatD [] prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] instance Print Type where @@ -176,7 +169,3 @@ instance Print Type where TLit uident -> prPrec i 1 (concatD [prt 0 uident]) TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) -instance Print Lit where - prt i = \case - LInt int -> prt i int - LChar char -> prt i char diff --git a/src/Monomorphizer/MorbIr.hs b/src/Monomorphizer/MorbIr.hs index 3e5db6b..35af47c 100644 --- a/src/Monomorphizer/MorbIr.hs +++ b/src/Monomorphizer/MorbIr.hs @@ -1,10 +1,14 @@ {-# LANGUAGE LambdaCase #-} -module Monomorphizer.MorbIr where -import Grammar.Print -import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..)) +module Monomorphizer.MorbIr ( + module Monomorphizer.MorbIr, + module LambdaLifterIr +) where -type Id = (TIR.Ident, Type) +import Data.List (intercalate) +import Grammar.Print +import LambdaLifterIr (Ident (..), Lit (..)) +import Prelude hiding (exp) newtype Program = Program [Def] deriving (Show, Ord, Eq) @@ -15,91 +19,81 @@ data Def = DBind Bind | DData Data data Data = Data Type [Inj] deriving (Show, Ord, Eq) -data Bind = Bind Id [Id] ExpT +data Bind = Bind (T Ident) [T Ident] (T Exp) + | BindC [T Ident] (T Ident) [T Ident] (T Exp) deriving (Show, Ord, Eq) + +type T a = (a, Type) + data Exp - = EVar TIR.Ident + = EVar Ident + | EVarC [T Ident] Ident | ELit Lit - | ELet Bind ExpT - | EApp ExpT ExpT - | EAdd ExpT ExpT - | ECase ExpT [Branch] + | ELet Bind (T Exp) + | EApp (T Exp) (T Exp) + | EAdd (T Exp) (T Exp) + | ECase (T Exp) [Branch] deriving (Show, Ord, Eq) data Pattern - = PVar Id - | PLit (Lit, Type) - | PInj TIR.Ident [Pattern] + = PVar Ident + | PLit Lit + | PInj Ident [T Pattern] | PCatch - | PEnum TIR.Ident + | PEnum Ident deriving (Eq, Ord, Show) -data Branch = Branch (Pattern, Type) ExpT + +data Branch = Branch (T Pattern) (T Exp) deriving (Eq, Ord, Show) -type ExpT = (Exp, Type) - -data Inj = Inj TIR.Ident Type +data Inj = Inj Ident Type deriving (Show, Ord, Eq) -data Lit - = LInt Integer - | LChar Char - deriving (Show, Ord, Eq) - -data Type = TLit TIR.Ident | TFun Type Type | TData TIR.Ident [Type] +data Type = TLit Ident | TFun Type Type | TData Ident [Type] deriving (Show, Ord, Eq) flattenType :: Type -> [Type] flattenType (TFun t1 t2) = t1 : flattenType t2 -flattenType x = [x] +flattenType x = [x] instance Print Program where prt i (Program sc) = prPrec i 0 $ prt 0 sc -instance Print (Bind) where +instance Print Bind where prt i (Bind sig@(name, _) parms rhs) = prPrec i 0 $ concatD - [ prtSig sig + [ prt 0 sig , prt 0 name - , prtIdPs 0 parms + , prt 0 parms , doc $ showString "=" , prt 0 rhs ] -prtSig :: Id -> Doc -prtSig (name, t) = - concatD - [ prt 0 name - , doc $ showString ":" - , prt 0 t - , doc $ showString ";" - ] - -instance Print (ExpT) where - prt i (e, t) = - concatD - [ doc $ showString "(" - , prt i e - , doc $ showString "," - , prt i t - , doc $ showString ")" - ] + prt i (BindC cxt sig parms rhs) = + prPrec i 0 $ + concatD + [ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig + , prt i parms + , doc $ showString "=" + , prt i rhs + ] instance Print [Bind] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] -prtIdPs :: Int -> [Id] -> Doc -prtIdPs i = prPrec i 0 . concatD . map (prt i) - instance Print Exp where prt i = \case EVar name -> prPrec i 3 $ prt 0 name + EVarC as lident -> doc . showString + $ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident + where + go (x, _) = printTree x ++ "^=" ++ printTree (EVar x) ELit lit -> prPrec i 3 $ prt 0 lit ELet b e -> prPrec i 3 $ @@ -134,16 +128,16 @@ instance Print Exp where ] instance Print Branch where - prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) + prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp]) instance Print [Branch] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] instance Print Def where prt i = \case - DBind bind -> prPrec i 0 (concatD [prt 0 bind]) + DBind bind -> prPrec i 0 (concatD [prt 0 bind]) DData data_ -> prPrec i 0 (concatD [prt 0 data_]) instance Print Data where @@ -152,23 +146,23 @@ instance Print Data where instance Print Inj where prt i = \case - Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) + Inj uident type_ -> prt i (uident, type_) instance Print Pattern where prt i = \case PVar name -> prPrec i 1 (concatD [prt 0 name]) - PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) + PLit lit -> prPrec i 1 (concatD [prt 0 lit]) PCatch -> prPrec i 1 (concatD [doc (showString "_")]) PEnum name -> prPrec i 1 (concatD [prt 0 name]) PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) instance Print [Def] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] instance Print [Type] where - prt _ [] = concatD [] + prt _ [] = concatD [] prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] instance Print Type where @@ -177,8 +171,4 @@ instance Print Type where TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")]) -instance Print Lit where - prt i = \case - LInt int -> prt i int - LChar char -> prt i char diff --git a/src/TypeChecker/ReportTEVar.hs b/src/TypeChecker/ReportTEVar.hs index 62cd301..c15967a 100644 --- a/src/TypeChecker/ReportTEVar.hs +++ b/src/TypeChecker/ReportTEVar.hs @@ -2,15 +2,15 @@ module TypeChecker.ReportTEVar where -import Auxiliary (onM) -import Control.Applicative (Applicative (liftA2), liftA3) -import Control.Monad.Except (MonadError (throwError)) -import Data.Coerce (coerce) -import Data.Tuple.Extra (secondM) -import Grammar.Abs qualified as G -import Grammar.ErrM (Err) -import Grammar.Print (printTree) -import TypeChecker.TypeCheckerIr hiding (Type (..)) +import Auxiliary (onM) +import Control.Applicative (Applicative (liftA2), liftA3) +import Control.Monad.Except (MonadError (throwError)) +import Data.Coerce (coerce) +import Data.Tuple.Extra (secondM) +import qualified Grammar.Abs as G +import Grammar.ErrM (Err) +import Grammar.Print (printTree) +import TypeChecker.TypeCheckerIr hiding (Type (..)) data Type = TLit Ident @@ -18,7 +18,7 @@ data Type | TData Ident [Type] | TFun Type Type | TAll TVar Type - deriving (Eq, Ord, Show, Read) + deriving (Eq, Ord, Show) class ReportTEVar a b where reportTEVar :: a -> Err b @@ -29,20 +29,20 @@ instance ReportTEVar (Program' G.Type) (Program' Type) where instance ReportTEVar (Def' G.Type) (Def' Type) where reportTEVar = \case DBind bind -> DBind <$> reportTEVar bind - DData dat -> DData <$> reportTEVar dat + DData dat -> DData <$> reportTEVar dat instance ReportTEVar (Bind' G.Type) (Bind' Type) where reportTEVar (Bind id vars rhs) = liftA3 Bind (reportTEVar id) (reportTEVar vars) (reportTEVar rhs) instance ReportTEVar (Exp' G.Type) (Exp' Type) where reportTEVar exp = case exp of - EVar name -> pure $ EVar name - EInj name -> pure $ EInj name - ELit lit -> pure $ ELit lit - ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e) - EApp e1 e2 -> onM EApp reportTEVar e1 e2 - EAdd e1 e2 -> onM EAdd reportTEVar e1 e2 - EAbs name e -> EAbs name <$> reportTEVar e + EVar name -> pure $ EVar name + EInj name -> pure $ EInj name + ELit lit -> pure $ ELit lit + ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e) + EApp e1 e2 -> onM EApp reportTEVar e1 e2 + EAdd e1 e2 -> onM EAdd reportTEVar e1 e2 + EAbs name e -> EAbs name <$> reportTEVar e ECase e branches -> liftA2 ECase (reportTEVar e) (reportTEVar branches) instance ReportTEVar (Branch' G.Type) (Branch' Type) where @@ -53,10 +53,10 @@ instance ReportTEVar (Pattern' G.Type, G.Type) (Pattern' Type, Type) where instance ReportTEVar (Pattern' G.Type) (Pattern' Type) where reportTEVar = \case - PVar name -> pure $ PVar name - PLit lit -> pure $ PLit lit - PCatch -> pure PCatch - PEnum name -> pure $ PEnum name + PVar name -> pure $ PVar name + PLit lit -> pure $ PLit lit + PCatch -> pure PCatch + PEnum name -> pure $ PEnum name PInj name ps -> PInj name <$> reportTEVar ps instance ReportTEVar (Data' G.Type) (Data' Type) where @@ -65,10 +65,10 @@ instance ReportTEVar (Data' G.Type) (Data' Type) where instance ReportTEVar (Inj' G.Type) (Inj' Type) where reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ -instance ReportTEVar (Id' G.Type) (Id' Type) where +instance ReportTEVar (a, G.Type) (a, Type) where reportTEVar = secondM reportTEVar -instance ReportTEVar (ExpT' G.Type) (ExpT' Type) where +instance ReportTEVar (T' Exp' G.Type) (T' Exp' Type) where reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ) instance ReportTEVar a b => ReportTEVar [a] [b] where @@ -76,9 +76,9 @@ instance ReportTEVar a b => ReportTEVar [a] [b] where instance ReportTEVar G.Type Type where reportTEVar = \case - G.TLit lit -> pure $ TLit (coerce lit) - G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i) - G.TData name typs -> TData (coerce name) <$> reportTEVar typs - G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2) + G.TLit lit -> pure $ TLit (coerce lit) + G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i) + G.TData name typs -> TData (coerce name) <$> reportTEVar typs + G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2) G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t - G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar) + G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar) diff --git a/src/TypeChecker/TypeCheckerBidir.hs b/src/TypeChecker/TypeCheckerBidir.hs index 04a8d91..184243f 100644 --- a/src/TypeChecker/TypeCheckerBidir.hs +++ b/src/TypeChecker/TypeCheckerBidir.hs @@ -31,6 +31,7 @@ import Grammar.ErrM import Grammar.Print (printTree) import Prelude hiding (exp) import qualified TypeChecker.TypeCheckerIr as T +import TypeChecker.TypeCheckerIr (T, T') -- Implementation is derived from the paper (Dunfield and Krishnaswami 2013) -- https://doi.org/10.1145/2500365.2500582 @@ -172,7 +173,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars -- | Γ ⊢ e ↑ A ⊣ Δ -- Under input context Γ, e checks against input type A, with output context ∆ -check :: Exp -> Type -> Tc (T.ExpT' Type) +check :: Exp -> Type -> Tc (T' T.Exp' Type) -- Γ,α ⊢ e ↑ A ⊣ Δ,α,Θ -- ------------------- ∀I @@ -212,12 +213,6 @@ check (ECase scrut pi) c = do e' <- check e c pure (T.Branch p' e') apply (T.ECase (scrut', a) pi', c) - where - go (pi, b) (Branch p e) = do - p' <- checkPattern p =<< apply a - e'@(_, b') <- infer e - subtype b' b - apply (T.Branch p' e' : pi, b') -- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ @@ -229,9 +224,6 @@ check e b = do subtype a b' apply (e', b) - - - checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) checkPattern patt t_patt = case patt of @@ -297,7 +289,7 @@ checkPattern patt t_patt = case patt of -- | Γ ⊢ e ↓ A ⊣ Δ -- Under input context Γ, e infers output type A, with output context ∆ -infer :: Exp -> Tc (T.ExpT' Type) +infer :: Exp -> Tc (T' T.Exp' Type) infer (ELit lit) = apply (T.ELit lit, litType lit) -- Γ ∋ (x : A) Γ ⊢ rec(x) @@ -391,7 +383,7 @@ infer (ECase scrut pi) = do -- | Γ ⊢ A • e ⇓ C ⊣ Δ -- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ -- Instantiate existential type variables until there is an arrow type. -applyInfer :: Type -> Exp -> Tc (T.ExpT' Type, Type) +applyInfer :: Type -> Exp -> Tc (T' T.Exp' Type, Type) -- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ -- ------------------------ ∀App diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index f4ec70a..7834ecd 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -1,32 +1,32 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedRecordDot #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE QualifiedDo #-} -{-# OPTIONS_GHC -Wno-incomplete-patterns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QualifiedDo #-} -- | A module for type checking and inference using algorithm W, Hindley-Milner module TypeChecker.TypeCheckerHm where -import Auxiliary (int, litType, maybeToRightM, unzip4) -import Auxiliary qualified as Aux -import Control.Monad.Except -import Control.Monad.Identity (Identity, runIdentity) -import Control.Monad.Reader -import Control.Monad.State -import Control.Monad.Writer -import Data.Coerce (coerce) -import Data.Function (on) -import Data.List (foldl', nub, sortOn) -import Data.List.Extra (unsnoc) -import Data.Map (Map) -import Data.Map qualified as M -import Data.Maybe (fromJust) -import Data.Set (Set) -import Data.Set qualified as S -import Debug.Trace (trace, traceShow) -import Grammar.Abs -import Grammar.Print (printTree) -import TypeChecker.TypeCheckerIr qualified as T +import Auxiliary (int, litType, maybeToRightM, unzip4) +import qualified Auxiliary as Aux +import Control.Monad.Except +import Control.Monad.Identity (Identity, runIdentity) +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Writer +import Data.Coerce (coerce) +import Data.Function (on) +import Data.List (foldl', nub, sortOn) +import Data.List.Extra (unsnoc) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Maybe (fromJust) +import Data.Set (Set) +import qualified Data.Set as S +import Debug.Trace (trace, traceShow) +import Grammar.Abs +import Grammar.Print (printTree) +import qualified TypeChecker.TypeCheckerIr as T +import TypeChecker.TypeCheckerIr (T, T') {- TODO @@ -41,7 +41,7 @@ typecheck :: Program -> Either String (T.Program' Type, [Warning]) typecheck = onLeft msg . run . checkPrg where onLeft :: (Error -> String) -> Either Error a -> Either String a - onLeft f (Left x) = Left $ f x + onLeft f (Left x) = Left $ f x onLeft _ (Right x) = Right x checkPrg :: Program -> Infer (T.Program' Type) @@ -68,13 +68,13 @@ prettify s (T.Program defs) = T.Program $ map (go s) defs replace :: Map T.Ident T.Ident -> Type -> Type replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of - Just t -> TVar . MkTVar . LIdent $ coerce t + Just t -> TVar . MkTVar . LIdent $ coerce t Nothing -> def replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2 replace m (TData name ts) = TData name (map (replace m) ts) replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of Just found -> TAll (MkTVar $ coerce found) (replace m t) - Nothing -> def + Nothing -> def replace _ t = t bindCount :: [Def] -> Infer [(Int, Def)] @@ -128,7 +128,7 @@ preRun (x : xs) = case x of s <- gets sigs case M.lookup (coerce n) s of Nothing -> insertSig (coerce n) Nothing >> preRun xs - Just _ -> preRun xs + Just _ -> preRun xs DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs where -- Check if function body / signature has been declared already @@ -150,11 +150,11 @@ checkDef (x : xs) = case x of T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs freeOrdered :: Type -> [T.Ident] -freeOrdered (TVar (MkTVar a)) = return (coerce a) +freeOrdered (TVar (MkTVar a)) = return (coerce a) freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t -freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b -freeOrdered (TData _ a) = concatMap freeOrdered a -freeOrdered _ = mempty +freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b +freeOrdered (TData _ a) = concatMap freeOrdered a +freeOrdered _ = mempty -- Much cleaner implementation, unfortunately one minor bug -- checkBind :: Bind -> Infer (T.Bind' Type) @@ -257,13 +257,13 @@ checkInj (Inj c inj_typ) name tvars toTVar :: Type -> Either Error TVar toTVar = \case TVar tvar -> pure tvar - _ -> uncatchableErr "Not a type variable" + _ -> uncatchableErr "Not a type variable" returnType :: Type -> Type returnType (TFun _ t2) = returnType t2 -returnType a = a +returnType a = a -inferExp :: Exp -> Infer (T.ExpT' Type) +inferExp :: Exp -> Infer (T' T.Exp' Type) inferExp e = do (s, (e', t)) <- algoW e let subbed = apply s t @@ -274,7 +274,7 @@ class CollectTVars a where instance CollectTVars Exp where collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e - collectTVars _ = S.empty + collectTVars _ = S.empty instance CollectTVars Type where collectTVars (TVar (MkTVar i)) = S.singleton (coerce i) @@ -287,7 +287,7 @@ instance CollectTVars Type where collect :: Set T.Ident -> Infer () collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st}) -algoW :: Exp -> Infer (Subst, T.ExpT' Type) +algoW :: Exp -> Infer (Subst, T' T.Exp' Type) algoW = \case err@(EAnn e t) -> do (sub0, (e', t')) <- exprErr (algoW e) err @@ -600,12 +600,12 @@ generalize :: Map T.Ident Type -> Type -> Type generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) where go :: [T.Ident] -> Type -> Type - go [] t = t + go [] t = t go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t) removeForalls :: Type -> Type - removeForalls (TAll _ t) = removeForalls t + removeForalls (TAll _ t) = removeForalls t removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2) - removeForalls t = t + removeForalls t = t {- | Instantiate a polymorphic type. The free type variables are substituted with fresh ones. @@ -643,7 +643,7 @@ fresh = do ungo :: [TVar] -> Type -> Type -> Bool ungo tvars t1 t2 = case run (go tvars t1 t2) of Right (b, _) -> b - _ -> False + _ -> False -- TODO: Fix the following -- Maybe locally using the Infer monad can cause trouble. -- Since the fresh count starts from zero @@ -656,7 +656,7 @@ fresh = do skipForalls :: Type -> Type skipForalls = \case TAll _ t -> skipForalls t - t -> t + t -> t freshen :: Type -> Infer Type freshen t = do @@ -705,10 +705,10 @@ instance SubstType Type where TLit _ -> t TVar (MkTVar a) -> case M.lookup (coerce a) sub of Nothing -> TVar (MkTVar $ coerce a) - Just t -> t + Just t -> t TAll (MkTVar i) t -> case M.lookup (coerce i) sub of Nothing -> TAll (MkTVar i) (apply sub t) - Just _ -> apply sub t + Just _ -> apply sub t TFun a b -> TFun (apply sub a) (apply sub b) TData name a -> TData name (apply sub a) TEVar (MkTEVar _) -> t @@ -724,7 +724,7 @@ instance SubstType (Map T.Ident Type) where instance SubstType (Map T.Ident (Maybe Type)) where apply s = M.map (fmap $ apply s) -instance SubstType (T.ExpT' Type) where +instance SubstType (T' T.Exp' Type) where apply s (e, t) = (apply s e, apply s t) instance SubstType (T.Exp' Type) where @@ -753,10 +753,10 @@ instance SubstType (T.Branch' Type) where instance SubstType (T.Pattern' Type) where apply s = \case T.PVar iden -> T.PVar iden - T.PLit lit -> T.PLit lit + T.PLit lit -> T.PLit lit T.PInj i ps -> T.PInj i $ apply s ps - T.PCatch -> T.PCatch - T.PEnum i -> T.PEnum i + T.PCatch -> T.PCatch + T.PEnum i -> T.PEnum i instance SubstType (T.Pattern' Type, Type) where apply s (p, t) = (apply s p, apply s t) @@ -764,7 +764,7 @@ instance SubstType (T.Pattern' Type, Type) where instance SubstType a => SubstType [a] where apply s = map (apply s) -instance SubstType (T.Id' Type) where +instance SubstType (T T.Ident Type) where apply s (name, t) = (name, apply s t) -- | Represents the empty substition set @@ -797,11 +797,11 @@ withBindings xs = -- | Run the monadic action with a pattern withPattern :: (Monad m, MonadReader Ctx m) => (T.Pattern' Type, Type) -> m a -> m a withPattern (p, t) ma = case p of - T.PVar x -> withBinding x t ma + T.PVar x -> withBinding x t ma T.PInj _ ps -> foldl' (flip withPattern) ma ps - T.PLit _ -> ma - T.PCatch -> ma - T.PEnum _ -> ma + T.PLit _ -> ma + T.PCatch -> ma + T.PEnum _ -> ma -- | Insert a function signature into the environment insertSig :: T.Ident -> Maybe Type -> Infer () @@ -826,11 +826,11 @@ existInj n = gets (M.lookup n . injections) flattenType :: Type -> [Type] flattenType (TFun a b) = flattenType a <> flattenType b -flattenType a = [a] +flattenType a = [a] typeLength :: Type -> Int typeLength (TFun _ b) = 1 + typeLength b -typeLength _ = 1 +typeLength _ = 1 {- | Catch an error if possible and add the given expression as addition to the error message @@ -913,11 +913,11 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type} deriving (Show) data Env = Env - { count :: Int - , nextChar :: Char - , sigs :: Map T.Ident (Maybe Type) + { count :: Int + , nextChar :: Char + , sigs :: Map T.Ident (Maybe Type) , takenTypeVars :: Set T.Ident - , injections :: Map T.Ident Type + , injections :: Map T.Ident Type , declaredBinds :: Set T.Ident } deriving (Show) diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index a956ff3..9dea744 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -1,6 +1,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PatternSynonyms #-} + module TypeChecker.TypeCheckerIr ( module Grammar.Abs, module TypeChecker.TypeCheckerIr, @@ -10,31 +11,30 @@ import Data.String (IsString) import Grammar.Abs (Lit (..)) import Grammar.Print import Prelude -import qualified Prelude as C (Eq, Ord, Read, Show) newtype Program' t = Program [Def' t] - deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + deriving (Eq, Ord, Show, Functor) data Def' t = DBind (Bind' t) | DData (Data' t) - deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + deriving (Eq, Ord, Show, Functor) data Type = TLit Ident | TVar TVar | TData Ident [Type] | TFun Type Type - deriving (Eq, Ord, Show, Read) + deriving (Eq, Ord, Show) data Data' t = Data t [Inj' t] - deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + deriving (Eq, Ord, Show, Functor) data Inj' t = Inj Ident t - deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + deriving (Eq, Ord, Show, Functor) newtype Ident = Ident String - deriving (C.Eq, C.Ord, C.Show, C.Read, IsString) + deriving (Eq, Ord, Show, IsString) data Pattern' t = PVar Ident @@ -42,30 +42,31 @@ data Pattern' t | PCatch | PEnum Ident | PInj Ident [(Pattern' t, t)] - deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + deriving (Eq, Ord, Show, Functor) data Exp' t = EVar Ident | EInj Ident | ELit Lit - | ELet (Bind' t) (ExpT' t) - | EApp (ExpT' t) (ExpT' t) - | EAdd (ExpT' t) (ExpT' t) - | EAbs Ident (ExpT' t) - | ECase (ExpT' t) [Branch' t] - deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) + | ELet (Bind' t) (T' Exp' t) + | EApp (T' Exp' t) (T' Exp' t) + | EAdd (T' Exp' t) (T' Exp' t) + | EAbs Ident (T' Exp' t) + | ECase (T' Exp' t) [Branch' t] + deriving (Eq, Ord, Show, Functor) newtype TVar = MkTVar Ident - deriving (C.Eq, C.Ord, C.Show, C.Read) + deriving (Eq, Ord, Show) -type Id' t = (Ident, t) -type ExpT' t = (Exp' t, t) +type T' a t = (a t, t) +type T a t = (a, t) -data Bind' t = Bind (Id' t) [Id' t] (ExpT' t) - deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) -data Branch' t = Branch (Pattern' t, t) (ExpT' t) - deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) +data Bind' t = Bind (T Ident t) [T Ident t] (T' Exp' t) + deriving (Eq, Ord, Show, Functor) + +data Branch' t = Branch (T' Pattern' t) (T' Exp' t) + deriving (Eq, Ord, Show, Functor) instance Print Ident where prt _ (Ident s) = doc $ showString s @@ -81,22 +82,22 @@ instance Print t => Print (Bind' t) where , prt i rhs ] -prtSig :: Print t => Id' t -> Doc -prtSig (name, t) = +prtSig :: Print t => T Ident t -> Doc +prtSig (x, t) = concatD - [ prt 0 name + [ prt 0 x , doc $ showString ":" , prt 0 t ] -instance Print t => Print (ExpT' t) where - prt i (e, t) = +instance (Print a, Print t) => Print (T a t) where + prt i (x, t) = concatD - [ doc $ showString "(" - , prt i e - , doc $ showString ":" - , prt 0 t - , doc $ showString ")" + [ -- doc $ showString "(" + {- , -} prt i x +-- , doc $ showString ":" +-- , prt 0 t +-- , doc $ showString ")" ] instance Print t => Print [Bind' t] where @@ -104,15 +105,6 @@ instance Print t => Print [Bind' t] where prt i [x] = concatD [prt i x] prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs] -instance Print t => Print (Id' t) where - prt i (name, t) = - concatD - [ doc $ showString "(" - , prt i name - , doc $ showString "," - , prt i t - , doc $ showString ")" - ] instance Print t => Print (Exp' t) where prt i = \case @@ -151,9 +143,6 @@ instance Print t => Print [Inj' t] where prt i [x] = prt i x prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs] -instance Print t => Print (Pattern' t, t) where - prt i (p, t) = prPrec i 1 (concatD [prt i p, prt i t]) - instance Print t => Print (Pattern' t) where prt i = \case PVar name -> prPrec i 1 (concatD [prt 0 name]) @@ -189,8 +178,6 @@ type Branch = Branch' Type type Pattern = Pattern' Type type Inj = Inj' Type type Exp = Exp' Type -type ExpT = ExpT' Type -type Id = Id' Type pattern TVar' s = TVar (MkTVar s) pattern DBind' id vars expt = DBind (Bind id vars expt) pattern DData' typ injs = DData (Data typ injs) diff --git a/test_map2.ll b/test_map2.ll new file mode 100644 index 0000000..ae37f18 --- /dev/null +++ b/test_map2.ll @@ -0,0 +1,185 @@ +target triple = "x86_64-pc-linux-gnu" +target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128" +@.str = private unnamed_addr constant [3 x i8] c"%i +", align 1 +@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c"Non-exhaustive patterns in case at %i:%i +" +declare i32 @printf(ptr noalias nocapture, ...) +declare i32 @exit(i32 noundef) +declare ptr @malloc(i32 noundef) +%List = type { i8, [23 x i8] } +%Cons = type { i8, i64, %List* } +%Nil = type { i8 } +; NYTT: kontexttyp +%Closure_sc_0 = type { i64 (i64)*, i64 } + +; Ident "sum$List_Int": (ECase (EVar (Ident "$4xs"),TLit (Ident "List")) [Branch (PEnum (Ident "Nil"),TLit (Ident "List")) (ELit (LInt 0),TLit (Ident "Int")),Branch (PInj (Ident "Cons") [PVar (Ident "$5x",TLit (Ident "Int")),PVar (Ident "$6xs",TLit (Ident "List"))],TLit (Ident "List")) (EAdd (EVar (Ident "$5x"),TLit (Ident "Int")) (EApp (EVar (Ident "sum$List_Int"),TFun (TLit (Ident "List")) (TLit (Ident "Int"))) (EVar (Ident "$6xs"),TLit (Ident "List")),TLit (Ident "Int")),TLit (Ident "Int"))],TLit (Ident "Int")) +define fastcc i64 @sum$List_Int(%List %$4xs) { + %1 = alloca i64 + ; Penum + %2 = extractvalue %List %$4xs, 0 + %3 = icmp eq i8 %2, 1 + br i1 %3, label %lbl_success_3, label %lbl_failed_2 + +lbl_success_3: + %4 = alloca %List + store %List %$4xs, ptr %4 + %5 = load %Nil, ptr %4 + store i64 0, ptr %1 + br label %lbl_escape_1 + +lbl_failed_2: + ; Inj + %6 = extractvalue %List %$4xs, 0 + %7 = icmp eq i8 %6, 0 + br i1 %7, label %lbl_success_5, label %lbl_failed_4 + +lbl_success_5: + %8 = alloca %List + store %List %$4xs, ptr %8 + %9 = load %Cons, ptr %8 + ; ident i64 + %$5x = extractvalue %Cons %9, 1 + ; ident %List + %10 = extractvalue %Cons %9, 2 + %$6xs = load %List, ptr %10 + ; TLit (Ident "Int") + %11 = call fastcc i64 @sum$List_Int(%List %$6xs) + %12 = add i64 %$5x, %11 + store i64 %12, ptr %1 + br label %lbl_escape_1 + +lbl_failed_4: + call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 12, i64 noundef 6) + call i32 @exit(i32 noundef 1) + br label %lbl_escape_1 + +lbl_escape_1: + %15 = load i64, ptr %1 + ret i64 %15 +} + +; Ident "sc_0$Int_Int": (EAdd (EVar (Ident "$7x"),TLit (Ident "Int")) (ELit (LInt 10),TLit (Ident "Int")),TLit (Ident "Int")) +; ÄNDRAT: lägg till kontextpekare +define fastcc i64 @sc_0$Int_Int(ptr %closure_sc_0, i64 %$7x) { + ; NYTT: Ladda alla fria variabler + %fri_variabel_ptr = getelementptr inbounds %Closure_sc_0, ptr %closure_sc_0, i32 0, i32 1 + %fri_variabel = load i64, ptr %fri_variabel_ptr + ; ÄNDRAT: %fri_variabel istället för 2 + %1 = add i64 %$7x, %fri_variabel + ret i64 %1 +} + +; Ident "map$Int_Int_List_List": (ECase (EVar (Ident "$1xs"),TLit (Ident "List")) [Branch (PEnum (Ident "Nil"),TLit (Ident "List")) (EVar (Ident "Nil"),TLit (Ident "List")),Branch (PInj (Ident "Cons") [PVar (Ident "$2x",TLit (Ident "Int")),PVar (Ident "$3xs",TLit (Ident "List"))],TLit (Ident "List")) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EApp (EVar (Ident "$0f"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (EVar (Ident "$2x"),TLit (Ident "Int")),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "map$Int_Int_List_List"),TFun (TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EVar (Ident "$0f"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EVar (Ident "$3xs"),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List"))],TLit (Ident "List")) +; ÄNDRAT: ptr istället för i64 (i64)* +define fastcc %List @map$Int_Int_List_List(ptr %$0f, %List %$1xs) { + ; NYTT: ta fram funktionspekaren + %$0f_deref = load i64(i64)*, ptr %$0f + + %1 = alloca %List + ; Penum + %2 = extractvalue %List %$1xs, 0 + %3 = icmp eq i8 %2, 1 + br i1 %3, label %lbl_success_8, label %lbl_failed_7 + +lbl_success_8: + %4 = alloca %List + store %List %$1xs, ptr %4 + %5 = load %Nil, ptr %4 + %6 = call fastcc %List @Nil() + store %List %6, ptr %1 + br label %lbl_escape_6 + +lbl_failed_7: + ; Inj + %7 = extractvalue %List %$1xs, 0 + %8 = icmp eq i8 %7, 0 + br i1 %8, label %lbl_success_10, label %lbl_failed_9 + +lbl_success_10: + %9 = alloca %List + store %List %$1xs, ptr %9 + %10 = load %Cons, ptr %9 + ; ident i64 + %$2x = extractvalue %Cons %10, 1 + ; ident %List + %11 = extractvalue %Cons %10, 2 + %$3xs = load %List, ptr %11 + ; TLit (Ident "Int") + ; ÄNDRAT använd deref + %12 = call fastcc i64 %$0f_deref(ptr %$0f, i64 %$2x) + ; TLit (Ident "List") + ; ÄNDRAT ptr istället för 64 (64)* och skicka med ptr + %13 = call fastcc %List @map$Int_Int_List_List(ptr %$0f, %List %$3xs) + ; TLit (Ident "List") + %14 = call fastcc %List @Cons(i64 %12, %List %13) + store %List %14, ptr %1 + br label %lbl_escape_6 + +lbl_failed_9: + call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 14, i64 noundef 6) + call i32 @exit(i32 noundef 1) + br label %lbl_escape_6 + +lbl_escape_6: + %17 = load %List, ptr %1 + ret %List %17 +} + + +; Ident "main": (EApp (EVar (Ident "sum$List_Int"),TFun (TLit (Ident "List")) (TLit (Ident "Int"))) (EApp (EApp (EVar (Ident "map$Int_Int_List_List"),TFun (TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EVar (Ident "sc_0$Int_Int"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (ELit (LInt 1),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (ELit (LInt 2),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EVar (Ident "Nil"),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "Int")) +define fastcc i64 @main() { + %1 = call fastcc %List @Nil() + ; TLit (Ident "List") + %2 = call fastcc %List @Cons(i64 2, %List %1) + ; TLit (Ident "List") + %3 = call fastcc %List @Cons(i64 1, %List %2) + ; TLit (Ident "List") + + ; NYTT: spara funktionspekaren och 100 i kontexten + %closure_sc_0 = alloca %Closure_sc_0 + store i64(i64)* @sc_0$Int_Int, ptr %closure_sc_0 + %fri_variabel_ptr = getelementptr inbounds %Closure_sc_0, ptr %closure_sc_0, i32 0, i32 1 + store i64 100, ptr %fri_variabel_ptr + + + ; store %Closure_sc_0 {i64 (i64)* @sc_0$Int_Int, 100}, ptr %closure_sc_0 + + ; ÄNDRAT ptr %closure_sc_0 istället för i64 (i64)* @sc_0$Int_Int + %4 = call fastcc %List @map$Int_Int_List_List(ptr %closure_sc_0, %List %3) + ; TLit (Ident "Int") + %5 = call fastcc i64 @sum$List_Int(%List %4) + call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef %5) + ret i64 0 +} + +define fastcc %List @Cons(i64 %arg_0, %List %arg_1) { + %1 = alloca %List + %2 = getelementptr %List, %List* %1, i64 0, i32 0 + store i8 0, i8* %2 + %3 = bitcast %List* %1 to %Cons* + ; i64 arg_0 1 + %4 = getelementptr %Cons, %Cons* %3, i64 0, i32 1 + ; Just store + store i64 %arg_0, ptr %4 + ; %List arg_1 2 + %5 = getelementptr %Cons, %Cons* %3, i64 0, i32 2 + ; Malloc and store + %6 = call ptr @malloc(i64 24) + store %List %arg_1, ptr %6 + store %List* %6, ptr %5 + ; Return the newly constructed value + %7 = load %List, ptr %1 + ret %List %7 +} + +define fastcc %List @Nil() { + %1 = alloca %List + %2 = getelementptr %List, %List* %1, i64 0, i32 0 + store i8 1, i8* %2 + %3 = bitcast %List* %1 to %Nil* + ; Return the newly constructed value + %4 = load %List, ptr %1 + ret %List %4 +} +