Merge closures mostly done. Desugaring cases is a problem.
This commit is contained in:
commit
019ed0d45a
29 changed files with 1484 additions and 757 deletions
|
|
@ -43,6 +43,7 @@ executable language
|
||||||
TypeChecker.ReportTEVar
|
TypeChecker.ReportTEVar
|
||||||
TypeChecker.RemoveForall
|
TypeChecker.RemoveForall
|
||||||
LambdaLifter
|
LambdaLifter
|
||||||
|
LambdaLifterIr
|
||||||
Monomorphizer.Monomorphizer
|
Monomorphizer.Monomorphizer
|
||||||
Monomorphizer.MonomorphizerIr
|
Monomorphizer.MonomorphizerIr
|
||||||
Monomorphizer.MorbIr
|
Monomorphizer.MorbIr
|
||||||
|
|
@ -101,6 +102,8 @@ Test-suite language-testsuite
|
||||||
TypeChecker.TypeChecker
|
TypeChecker.TypeChecker
|
||||||
AnnForall
|
AnnForall
|
||||||
ReportForall
|
ReportForall
|
||||||
|
LambdaLifterIr
|
||||||
|
LambdaLifter
|
||||||
TypeChecker.TypeCheckerHm
|
TypeChecker.TypeCheckerHm
|
||||||
TypeChecker.TypeCheckerBidir
|
TypeChecker.TypeCheckerBidir
|
||||||
TypeChecker.ReportTEVar
|
TypeChecker.ReportTEVar
|
||||||
|
|
|
||||||
6
sample-programs/working/addition.chrf
Normal file
6
sample-programs/working/addition.chrf
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
|
||||||
|
|
||||||
|
add : Int -> Int -> Int -> Int
|
||||||
|
add x y z = x + y + z
|
||||||
|
|
||||||
|
main = add 8 6 2
|
||||||
7
sample-programs/working/apply.crf
Normal file
7
sample-programs/working/apply.crf
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
apply : (Int -> Int) -> Int -> Int
|
||||||
|
apply f y = f y
|
||||||
|
|
||||||
|
main = apply (\y. y + y) 5
|
||||||
10
sample-programs/working/closure.crf
Normal file
10
sample-programs/working/closure.crf
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
apply : (Int -> Int) -> Int -> Int
|
||||||
|
apply f z = f z
|
||||||
|
|
||||||
|
main =
|
||||||
|
let x = 10 in
|
||||||
|
apply (\y. y + x) 6
|
||||||
15
sample-programs/working/foldr.crf
Normal file
15
sample-programs/working/foldr.crf
Normal file
|
|
@ -0,0 +1,15 @@
|
||||||
|
|
||||||
|
data List a where
|
||||||
|
Nil : List a
|
||||||
|
Cons : a -> List a -> List a
|
||||||
|
|
||||||
|
foldr : (a -> b -> b) -> b -> List a -> b
|
||||||
|
foldr f y xs = case xs of
|
||||||
|
Nil => y
|
||||||
|
Cons x xs => f x (foldr f y xs)
|
||||||
|
|
||||||
|
|
||||||
|
main = let z = 2 in foldr (\x.\y. x + y + z) 0 (Cons 1000 (Cons 100 Nil))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
25
sample-programs/working/lambda-2.crf
Normal file
25
sample-programs/working/lambda-2.crf
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
data List (a) where
|
||||||
|
Nil : List (a)
|
||||||
|
Cons : a -> List (a) -> List (a)
|
||||||
|
|
||||||
|
map : (a -> b) -> List (a) -> List (b)
|
||||||
|
map f xs = case xs of
|
||||||
|
Nil => Nil
|
||||||
|
Cons x xs => Cons (f x) (map f xs)
|
||||||
|
|
||||||
|
add : Int -> Int -> Int
|
||||||
|
add x y = x + y
|
||||||
|
|
||||||
|
foldr : (a -> b -> b) -> b -> List (a) -> b
|
||||||
|
foldr f y xs = case xs of
|
||||||
|
Nil => y
|
||||||
|
Cons x xs => f x (foldr f y xs)
|
||||||
|
|
||||||
|
f : List (Int)
|
||||||
|
f = ((\x.\ys. map (\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil)))
|
||||||
|
-- [5, 6]
|
||||||
|
|
||||||
|
main : Int
|
||||||
|
main = foldr add 0 f
|
||||||
|
|
||||||
|
|
||||||
21
sample-programs/working/lambda.crf
Normal file
21
sample-programs/working/lambda.crf
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
data List a where
|
||||||
|
Nil : List a
|
||||||
|
Cons : a -> List a -> List a
|
||||||
|
|
||||||
|
map : (a -> b) -> List a -> List b
|
||||||
|
map f xs = case xs of
|
||||||
|
Nil => Nil
|
||||||
|
Cons x xs => Cons (f x) (map f xs)
|
||||||
|
|
||||||
|
|
||||||
|
f : List Int
|
||||||
|
f = (\x.\ys. map (\y. y + x) ys) 4 (Cons 1 (Cons 2 Nil))
|
||||||
|
-- [5, 6]
|
||||||
|
|
||||||
|
sum : List Int -> Int
|
||||||
|
sum xs = case xs of
|
||||||
|
Nil => 0
|
||||||
|
Cons x xs => x + sum xs
|
||||||
|
|
||||||
|
main = sum f
|
||||||
|
|
||||||
3
sample-programs/working/let.crf
Normal file
3
sample-programs/working/let.crf
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
|
||||||
|
|
||||||
|
main = let x = 10 in 6 + x
|
||||||
16
sample-programs/working/map.crf
Normal file
16
sample-programs/working/map.crf
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
|
||||||
|
data List a where
|
||||||
|
Nil : List a
|
||||||
|
Cons : a -> List a -> List a
|
||||||
|
|
||||||
|
map : (a -> b) -> List a -> List b
|
||||||
|
map f xs = case xs of
|
||||||
|
Nil => Nil
|
||||||
|
Cons x xs => Cons (f x) (map f xs)
|
||||||
|
|
||||||
|
sum : List Int -> Int
|
||||||
|
sum xs = case xs of
|
||||||
|
Nil => 0
|
||||||
|
Cons x xs => x + (sum xs)
|
||||||
|
|
||||||
|
main = let y = 10 in sum (map (\x. x + y) (Cons 2 (Cons 4 Nil)))
|
||||||
7
sample-programs/working/simple.crf
Normal file
7
sample-programs/working/simple.crf
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
f = 10
|
||||||
|
|
||||||
|
|
||||||
|
main = f + 6
|
||||||
|
|
@ -2,8 +2,8 @@ module Codegen.Auxillary where
|
||||||
|
|
||||||
import Codegen.LlvmIr (LLVMType (..), LLVMValue (..))
|
import Codegen.LlvmIr (LLVMType (..), LLVMValue (..))
|
||||||
import Control.Monad (foldM_)
|
import Control.Monad (foldM_)
|
||||||
import Monomorphizer.MonomorphizerIr as MIR (ExpT, Type (..))
|
import Monomorphizer.MonomorphizerIr as MIR (Exp, T, Type (..))
|
||||||
import TypeChecker.TypeCheckerIr qualified as TIR
|
import qualified TypeChecker.TypeCheckerIr as TIR
|
||||||
|
|
||||||
type2LlvmType :: MIR.Type -> LLVMType
|
type2LlvmType :: MIR.Type -> LLVMType
|
||||||
type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of
|
type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of
|
||||||
|
|
@ -19,7 +19,7 @@ type2LlvmType (MIR.TFun t xs) = do
|
||||||
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
|
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
|
||||||
function2LLVMType x s = (type2LlvmType x, s)
|
function2LLVMType x s = (type2LlvmType x, s)
|
||||||
|
|
||||||
getType :: ExpT -> LLVMType
|
getType :: T Exp -> LLVMType
|
||||||
getType (_, t) = type2LlvmType t
|
getType (_, t) = type2LlvmType t
|
||||||
|
|
||||||
extractTypeName :: MIR.Type -> TIR.Ident
|
extractTypeName :: MIR.Type -> TIR.Ident
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,23 @@
|
||||||
|
{-# LANGUAGE OverloadedRecordDot #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
module Codegen.Codegen (generateCode) where
|
module Codegen.Codegen (generateCode) where
|
||||||
|
|
||||||
import Codegen.CompilerState (
|
import Codegen.CompilerState (CodeGenerator (..),
|
||||||
CodeGenerator (instructions),
|
StructType (inst),
|
||||||
initCodeGenerator,
|
initCodeGenerator)
|
||||||
)
|
|
||||||
import Codegen.Emits (compileScs)
|
import Codegen.Emits (compileScs)
|
||||||
import Codegen.LlvmIr as LIR (llvmIrToString)
|
import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw),
|
||||||
import Control.Monad.State (
|
llvmIrToString)
|
||||||
execStateT,
|
import Control.Monad.State (execStateT)
|
||||||
)
|
import Data.Functor ((<&>))
|
||||||
import Data.List (sortBy)
|
import Data.List (sortBy)
|
||||||
|
import qualified Data.Map as Map
|
||||||
import Grammar.ErrM (Err)
|
import Grammar.ErrM (Err)
|
||||||
import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..), Def (DBind, DData), Program (..), Type (TLit))
|
import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..),
|
||||||
|
Def (DBind, DData),
|
||||||
|
Program (..),
|
||||||
|
Type (TLit))
|
||||||
import TypeChecker.TypeCheckerIr (Ident (..))
|
import TypeChecker.TypeCheckerIr (Ident (..))
|
||||||
|
|
||||||
{- | Compiles an AST and produces a LLVM Ir string.
|
{- | Compiles an AST and produces a LLVM Ir string.
|
||||||
|
|
@ -21,8 +27,14 @@ import TypeChecker.TypeCheckerIr (Ident (..))
|
||||||
generateCode :: MIR.Program -> Bool -> Err String
|
generateCode :: MIR.Program -> Bool -> Err String
|
||||||
generateCode (MIR.Program scs) addGc = do
|
generateCode (MIR.Program scs) addGc = do
|
||||||
let tree = filter (not . detectPrelude) (sortBy lowData scs)
|
let tree = filter (not . detectPrelude) (sortBy lowData scs)
|
||||||
let codegen = initCodeGenerator addGc tree
|
codegen = initCodeGenerator addGc tree
|
||||||
llvmIrToString . instructions <$> execStateT (compileScs tree) codegen
|
|
||||||
|
-- Append instructions
|
||||||
|
execStateT (compileScs tree) codegen <&> \state ->
|
||||||
|
llvmIrToString $ defaultStart
|
||||||
|
++ (if addGc then gcStart else [])
|
||||||
|
++ map inst (Map.elems state.structTypes)
|
||||||
|
++ state.instructions
|
||||||
|
|
||||||
detectPrelude :: Def -> Bool
|
detectPrelude :: Def -> Bool
|
||||||
detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True
|
detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True
|
||||||
|
|
@ -33,3 +45,24 @@ lowData :: Def -> Def -> Ordering
|
||||||
lowData (DData _) (DBind _) = LT
|
lowData (DData _) (DBind _) = LT
|
||||||
lowData (DBind _) (DData _) = GT
|
lowData (DBind _) (DData _) = GT
|
||||||
lowData _ _ = EQ
|
lowData _ _ = EQ
|
||||||
|
|
||||||
|
defaultStart :: [LLVMIr]
|
||||||
|
defaultStart =
|
||||||
|
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
|
||||||
|
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
|
||||||
|
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
|
||||||
|
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
|
||||||
|
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
|
||||||
|
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
|
||||||
|
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
|
||||||
|
]
|
||||||
|
|
||||||
|
gcStart :: [LLVMIr]
|
||||||
|
gcStart =
|
||||||
|
[ UnsafeRaw "declare external void @cheap_init()\n"
|
||||||
|
, UnsafeRaw "declare external ptr @cheap_alloc(i64)\n"
|
||||||
|
, UnsafeRaw "declare external void @cheap_dispose()\n"
|
||||||
|
, UnsafeRaw "declare external ptr @cheap_the()\n"
|
||||||
|
, UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n"
|
||||||
|
, UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n"
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,46 +1,101 @@
|
||||||
|
{-# LANGUAGE DuplicateRecordFields #-}
|
||||||
|
{-# LANGUAGE LambdaCase #-}
|
||||||
|
{-# LANGUAGE OverloadedRecordDot #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
module Codegen.CompilerState where
|
module Codegen.CompilerState where
|
||||||
|
|
||||||
import Auxiliary (snoc)
|
import Auxiliary (snoc)
|
||||||
import Codegen.Auxillary (type2LlvmType, typeByteSize)
|
import Codegen.Auxillary (type2LlvmType, typeByteSize)
|
||||||
import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw),
|
import Codegen.LlvmIr as LIR (LLVMIr (SetVariable, Type),
|
||||||
LLVMType)
|
LLVMType (CustomType, Function, I64, Ptr),
|
||||||
import Control.Monad.State (StateT, gets, modify)
|
LLVMValue (VFunction, VIdent),
|
||||||
|
Visibility (Global),
|
||||||
|
typeOf)
|
||||||
|
import Control.Monad.State (StateT, gets, modify, void)
|
||||||
import Data.Map (Map)
|
import Data.Map (Map)
|
||||||
import qualified Data.Map as Map
|
import qualified Data.Map as Map
|
||||||
import Grammar.ErrM (Err)
|
import Grammar.ErrM (Err)
|
||||||
import Monomorphizer.MonomorphizerIr as MIR
|
import Monomorphizer.MonomorphizerIr (Ident (..), Inj (..), T,
|
||||||
|
flattenType)
|
||||||
|
import qualified Monomorphizer.MonomorphizerIr as MIR
|
||||||
import qualified TypeChecker.TypeCheckerIr as TIR
|
import qualified TypeChecker.TypeCheckerIr as TIR
|
||||||
|
|
||||||
-- | The record used as the code generator state
|
-- | The record used as the code generator state
|
||||||
data CodeGenerator = CodeGenerator
|
data CodeGenerator = CodeGenerator
|
||||||
{ instructions :: [LLVMIr]
|
{ instructions :: [LLVMIr]
|
||||||
, functions :: Map MIR.Id FunctionInfo
|
, functions :: Map (T Ident) FunctionInfo
|
||||||
, customTypes :: Map LLVMType Integer
|
, customTypes :: Map LLVMType Integer
|
||||||
, constructors :: Map TIR.Ident ConstructorInfo
|
, constructors :: Map Ident ConstructorInfo
|
||||||
, variableCount :: Integer
|
, variableCount :: Integer
|
||||||
, labelCount :: Integer
|
, labelCount :: Integer
|
||||||
, gcEnabled :: Bool
|
, gcEnabled :: Bool
|
||||||
|
, structTypes :: Map Ident StructType
|
||||||
|
-- ^ Custom stucture types
|
||||||
|
, locals :: [(Ident, LocalElem)]
|
||||||
|
-- ^ Arguments and variables in local environment
|
||||||
|
, globals :: Map Ident (LLVMType, LLVMValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data StructType = StructType
|
||||||
|
{ ptr :: LLVMType
|
||||||
|
, typs :: [LLVMType]
|
||||||
|
, inst :: LLVMIr
|
||||||
|
}
|
||||||
|
|
||||||
|
data LocalElem = LocalElem
|
||||||
|
{ typ :: LLVMType
|
||||||
|
, val :: LLVMValue
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
-- | A state type synonym
|
-- | A state type synonym
|
||||||
type CompilerState a = StateT CodeGenerator Err a
|
type CompilerState a = StateT CodeGenerator Err a
|
||||||
|
|
||||||
data FunctionInfo = FunctionInfo
|
data FunctionInfo = FunctionInfo
|
||||||
{ numArgs :: Int
|
{ numArgs :: Int
|
||||||
, arguments :: [Id]
|
, arguments :: [T Ident]
|
||||||
}
|
}
|
||||||
deriving (Show)
|
deriving (Show)
|
||||||
data ConstructorInfo = ConstructorInfo
|
data ConstructorInfo = ConstructorInfo
|
||||||
{ numArgsCI :: Int
|
{ numArgsCI :: Int
|
||||||
, argumentsCI :: [Id]
|
, argumentsCI :: [T Ident]
|
||||||
, numCI :: Integer
|
, numCI :: Integer
|
||||||
, returnTypeCI :: MIR.Type
|
, returnTypeCI :: MIR.Type
|
||||||
}
|
}
|
||||||
deriving (Show)
|
deriving (Show)
|
||||||
|
|
||||||
|
|
||||||
|
addStructType_ :: Ident -> [LLVMType] -> CompilerState ()
|
||||||
|
addStructType_ = fmap void . addStructType
|
||||||
|
|
||||||
|
addStructType :: Ident -> [LLVMType] -> CompilerState LLVMType
|
||||||
|
addStructType x ts = do
|
||||||
|
modify $ \s -> s { structTypes = Map.insert x struct s.structTypes }
|
||||||
|
pure t
|
||||||
|
where
|
||||||
|
struct = StructType
|
||||||
|
{ ptr = t
|
||||||
|
, typs = ts
|
||||||
|
, inst = Type x ts
|
||||||
|
}
|
||||||
|
t = CustomType x
|
||||||
|
|
||||||
-- | Adds a instruction to the CodeGenerator state
|
-- | Adds a instruction to the CodeGenerator state
|
||||||
emit :: LLVMIr -> CompilerState ()
|
emit :: LLVMIr -> CompilerState ()
|
||||||
emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
|
|
||||||
|
-- Add variable to environment
|
||||||
|
emit l@(SetVariable x _) = modify $ \t ->
|
||||||
|
t { instructions = Auxiliary.snoc l t.instructions
|
||||||
|
, locals = snoc (x, local)
|
||||||
|
t.locals
|
||||||
|
}
|
||||||
|
where
|
||||||
|
local = LocalElem { typ = typeOf l
|
||||||
|
, val = VIdent x (typeOf l)
|
||||||
|
}
|
||||||
|
|
||||||
|
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l t.instructions }
|
||||||
|
|
||||||
-- | Increases the variable counter in the CodeGenerator state
|
-- | Increases the variable counter in the CodeGenerator state
|
||||||
increaseVarCount :: CompilerState ()
|
increaseVarCount :: CompilerState ()
|
||||||
|
|
@ -63,16 +118,19 @@ getNewLabel = do
|
||||||
{- | Produces a map of functions infos from a list of binds,
|
{- | Produces a map of functions infos from a list of binds,
|
||||||
which contains useful data for code generation.
|
which contains useful data for code generation.
|
||||||
-}
|
-}
|
||||||
getFunctions :: [MIR.Def] -> Map Id FunctionInfo
|
getFunctions :: [MIR.Def] -> Map (T Ident) FunctionInfo
|
||||||
getFunctions bs = Map.fromList $ go bs
|
getFunctions bs = Map.fromList $ go bs
|
||||||
where
|
where
|
||||||
go [] = []
|
go [] = []
|
||||||
go (MIR.DBind (MIR.Bind id args _) : xs) =
|
go (MIR.DBind (MIR.Bind id args _) : xs) =
|
||||||
(id, FunctionInfo{numArgs = length args, arguments = args})
|
(id, FunctionInfo { numArgs = length args
|
||||||
|
, arguments = args
|
||||||
|
}
|
||||||
|
)
|
||||||
: go xs
|
: go xs
|
||||||
go (_ : xs) = go xs
|
go (_ : xs) = go xs
|
||||||
|
|
||||||
createArgs :: [MIR.Type] -> [Id]
|
createArgs :: [MIR.Type] -> [T Ident]
|
||||||
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
|
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
|
||||||
|
|
||||||
{- | Produces a map of functions infos from a list of binds,
|
{- | Produces a map of functions infos from a list of binds,
|
||||||
|
|
@ -113,35 +171,43 @@ getTypes bs = Map.fromList $ go bs
|
||||||
variantTypes fi = init $ map type2LlvmType (flattenType fi)
|
variantTypes fi = init $ map type2LlvmType (flattenType fi)
|
||||||
biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
|
biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
|
||||||
|
|
||||||
|
getGlobals :: [MIR.Def] -> Map Ident (LLVMType, LLVMValue)
|
||||||
|
getGlobals scs = Map.fromList [ go b | MIR.DBind b <- scs ]
|
||||||
|
where
|
||||||
|
go bind | x == "main" = let typ = Function I64 []
|
||||||
|
in (x, (typ, VFunction x Global typ))
|
||||||
|
| otherwise = (x, (typ, VFunction x Global typ))
|
||||||
|
where
|
||||||
|
typ = Function tr $ Ptr : ts
|
||||||
|
Function tr ts = type2LlvmType' t
|
||||||
|
|
||||||
|
(x, t) = case bind of
|
||||||
|
MIR.Bind xt _ _ -> xt
|
||||||
|
MIR.BindC _ xt _ _ -> xt
|
||||||
|
|
||||||
|
-- Higher order function arguments are replaced with ptr
|
||||||
|
type2LlvmType' = go []
|
||||||
|
where
|
||||||
|
go acc = \case
|
||||||
|
MIR.TFun (MIR.TFun _ _) t2 -> go (snoc Ptr acc) t2
|
||||||
|
MIR.TFun t1 t2 -> go (snoc (type2LlvmType t1) acc) t2
|
||||||
|
t -> Function (type2LlvmType t) acc
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
initCodeGenerator :: Bool -> [MIR.Def] -> CodeGenerator
|
initCodeGenerator :: Bool -> [MIR.Def] -> CodeGenerator
|
||||||
initCodeGenerator addGc scs =
|
initCodeGenerator addGc scs =
|
||||||
CodeGenerator
|
CodeGenerator
|
||||||
{ instructions = defaultStart <> if addGc then gcStart else []
|
{ instructions = []
|
||||||
, functions = getFunctions scs
|
, functions = getFunctions scs
|
||||||
, constructors = getConstructors scs
|
, constructors = getConstructors scs
|
||||||
, customTypes = getTypes scs
|
, customTypes = getTypes scs
|
||||||
|
, structTypes = mempty
|
||||||
, variableCount = 0
|
, variableCount = 0
|
||||||
, labelCount = 0
|
, labelCount = 0
|
||||||
, gcEnabled = addGc
|
, gcEnabled = addGc
|
||||||
|
, locals = mempty
|
||||||
|
, globals = getGlobals scs
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultStart :: [LLVMIr]
|
|
||||||
defaultStart =
|
|
||||||
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
|
|
||||||
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
|
|
||||||
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
|
|
||||||
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
|
|
||||||
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
|
|
||||||
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
|
|
||||||
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
|
|
||||||
]
|
|
||||||
|
|
||||||
gcStart :: [LLVMIr]
|
|
||||||
gcStart =
|
|
||||||
[ UnsafeRaw "declare external void @cheap_init()\n"
|
|
||||||
, UnsafeRaw "declare external ptr @cheap_alloc(i64)\n"
|
|
||||||
, UnsafeRaw "declare external void @cheap_dispose()\n"
|
|
||||||
, UnsafeRaw "declare external ptr @cheap_the()\n"
|
|
||||||
, UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n"
|
|
||||||
, UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n"
|
|
||||||
]
|
|
||||||
|
|
|
||||||
|
|
@ -1,36 +1,40 @@
|
||||||
|
{-# LANGUAGE DuplicateRecordFields #-}
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
|
{-# LANGUAGE NamedFieldPuns #-}
|
||||||
|
{-# LANGUAGE OverloadedRecordDot #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
module Codegen.Emits where
|
module Codegen.Emits where
|
||||||
|
|
||||||
|
import Auxiliary (snoc)
|
||||||
import Codegen.Auxillary
|
import Codegen.Auxillary
|
||||||
import Codegen.CompilerState
|
import Codegen.CompilerState
|
||||||
import Codegen.LlvmIr as LIR
|
import Codegen.LlvmIr as LIR
|
||||||
import Control.Applicative ((<|>))
|
import Control.Applicative (Applicative (liftA2), (<|>))
|
||||||
import Control.Monad (when)
|
import Control.Monad (forM_, when, zipWithM_)
|
||||||
|
import Control.Monad.Extra (whenJust)
|
||||||
import Control.Monad.State (gets, modify)
|
import Control.Monad.State (gets, modify)
|
||||||
import Data.Bifunctor qualified as BI
|
|
||||||
import Data.Char (ord)
|
import Data.Char (ord)
|
||||||
import Data.Coerce (coerce)
|
import Data.Coerce (coerce)
|
||||||
import Data.Map qualified as Map
|
import Data.Foldable.Extra (notNull)
|
||||||
|
import qualified Data.Map as Map
|
||||||
import Data.Maybe (fromJust, fromMaybe, isNothing)
|
import Data.Maybe (fromJust, fromMaybe, isNothing)
|
||||||
import Data.Tuple.Extra (dupe, first, second)
|
import Data.Tuple.Extra (second)
|
||||||
import Debug.Trace (trace, traceShow)
|
import Grammar.Print (printTree)
|
||||||
import Grammar.Print
|
import Monomorphizer.MonomorphizerIr
|
||||||
import Monomorphizer.MonomorphizerIr as MIR
|
|
||||||
import TypeChecker.TypeCheckerIr qualified as TIR
|
|
||||||
|
|
||||||
compileScs :: [MIR.Def] -> CompilerState ()
|
|
||||||
|
compileScs :: [Def] -> CompilerState ()
|
||||||
compileScs [] = do
|
compileScs [] = do
|
||||||
emit $ UnsafeRaw "\n"
|
emit $ UnsafeRaw "\n"
|
||||||
|
mapM_ createConstructor =<< gets (Map.toList . constructors)
|
||||||
-- as a last step create all the constructors
|
-- as a last step create all the constructors
|
||||||
-- //TODO maybe merge this with the data type match?
|
-- //TODO maybe merge this with the data type match?
|
||||||
c <- gets (Map.toList . constructors)
|
where
|
||||||
mapM_
|
createConstructor (id, ci) = do
|
||||||
( \(id, ci) -> do
|
|
||||||
let t = returnTypeCI ci
|
let t = returnTypeCI ci
|
||||||
let t' = type2LlvmType t
|
t' = type2LlvmType t
|
||||||
let x = BI.second type2LlvmType <$> argumentsCI ci
|
x = (mkCxtName, Ptr) : map (second type2LlvmType) ci.argumentsCI
|
||||||
emit $ Define FastCC t' id x
|
emit $ Define FastCC t' id x
|
||||||
top <- getNewVar
|
top <- getNewVar
|
||||||
ptr <- getNewVar
|
ptr <- getNewVar
|
||||||
|
|
@ -56,7 +60,7 @@ compileScs [] = do
|
||||||
cTypes <- gets customTypes
|
cTypes <- gets customTypes
|
||||||
|
|
||||||
enumerateOneM_
|
enumerateOneM_
|
||||||
( \i (TIR.Ident arg_n, arg_t) -> do
|
( \i (Ident arg_n, arg_t) -> do
|
||||||
let arg_t' = type2LlvmType arg_t
|
let arg_t' = type2LlvmType arg_t
|
||||||
emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
|
emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
|
||||||
elemPtr <- getNewVar
|
elemPtr <- getNewVar
|
||||||
|
|
@ -78,11 +82,11 @@ compileScs [] = do
|
||||||
heapPtr <- getNewVar
|
heapPtr <- getNewVar
|
||||||
useGc <- gets gcEnabled
|
useGc <- gets gcEnabled
|
||||||
emit $ SetVariable heapPtr (if useGc then GcMalloc s else Malloc s)
|
emit $ SetVariable heapPtr (if useGc then GcMalloc s else Malloc s)
|
||||||
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr
|
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr heapPtr
|
||||||
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
|
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
|
||||||
Nothing -> do
|
Nothing -> do
|
||||||
emit $ Comment "Just store"
|
emit $ Comment "Just store"
|
||||||
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
|
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
|
||||||
)
|
)
|
||||||
(argumentsCI ci)
|
(argumentsCI ci)
|
||||||
|
|
||||||
|
|
@ -95,34 +99,83 @@ compileScs [] = do
|
||||||
emit $ UnsafeRaw "\n"
|
emit $ UnsafeRaw "\n"
|
||||||
|
|
||||||
modify $ \s -> s{variableCount = 0}
|
modify $ \s -> s{variableCount = 0}
|
||||||
)
|
|
||||||
c
|
compileScs (DBind bind : xs) = do
|
||||||
compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
|
|
||||||
let t_return = type2LlvmType . last . flattenType $ t
|
|
||||||
emit $ UnsafeRaw "\n"
|
emit $ UnsafeRaw "\n"
|
||||||
emit . Comment $ show name <> ": " <> show exp
|
emit . Comment $ show name <> ": " <> show (fst exp)
|
||||||
let args' = map (second type2LlvmType) args
|
|
||||||
|
Function t_return t_args <- gets $ fst
|
||||||
|
. fromJust
|
||||||
|
. Map.lookup name
|
||||||
|
. globals
|
||||||
|
|
||||||
|
let args' = zip (mkCxtName : map fst args) t_args
|
||||||
|
|
||||||
emit $ Define FastCC t_return name args'
|
emit $ Define FastCC t_return name args'
|
||||||
useGc <- gets gcEnabled
|
modify $ \s -> s { locals = foldr insertArg s.locals args' }
|
||||||
when (name == "main") (mapM_ emit (firstMainContent useGc))
|
|
||||||
functionBody <- exprToValue exp
|
-- Dereference ptr arguments
|
||||||
if name == "main"
|
when (notNull args') $
|
||||||
then mapM_ emit $ lastMainContent useGc functionBody
|
forM_ (tail args') $ \(x, t) -> when (t == Ptr) $ do
|
||||||
else emit $ Ret t_return functionBody
|
let t_deref =
|
||||||
|
let
|
||||||
|
Function t ts = type2LlvmType . fromJust $ lookup x args
|
||||||
|
in
|
||||||
|
Function t (Ptr : ts)
|
||||||
|
|
||||||
|
emit . SetVariable (mkDerefName x)
|
||||||
|
$ Load t_deref Ptr x
|
||||||
|
|
||||||
|
whenJust mcxt loadFreeVars
|
||||||
|
|
||||||
|
gcEnabled <- gets gcEnabled
|
||||||
|
when isMain $ mapM_ emit (firstMainContent gcEnabled)
|
||||||
|
|
||||||
|
result <- exprToValue exp
|
||||||
|
|
||||||
|
if isMain
|
||||||
|
then mapM_ emit $ lastMainContent gcEnabled result
|
||||||
|
else emit $ Ret t_return result
|
||||||
|
|
||||||
emit DefineEnd
|
emit DefineEnd
|
||||||
modify $ \s -> s{variableCount = 0}
|
-- Reset variable count and empty locals
|
||||||
|
modify $ \s -> s { variableCount = 0, locals = mempty }
|
||||||
compileScs xs
|
compileScs xs
|
||||||
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
|
where
|
||||||
let (TIR.Ident outer_id) = extractTypeName typ
|
loadFreeVars cxt = do
|
||||||
|
emit $ Comment "Load free variables"
|
||||||
|
zipWithM_ go cxt' [1 ..]
|
||||||
|
where
|
||||||
|
go (x, t) i = do
|
||||||
|
vc <- getNewVar
|
||||||
|
emit . SetVariable vc
|
||||||
|
$ GetElementPtrInbounds (CustomType $ mkClosureName name) Ptr (VIdent mkCxtName Ptr)
|
||||||
|
I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices
|
||||||
|
emit . SetVariable x $ Load t Ptr vc
|
||||||
|
cxt' = map (second type2LlvmType) cxt
|
||||||
|
|
||||||
|
isMain = name == "main"
|
||||||
|
|
||||||
|
(name, args, exp, mcxt) = case bind of
|
||||||
|
Bind (name, _) args exp -> (name, args, exp, Nothing)
|
||||||
|
BindC cxt (name, _) args exp -> (name, args, exp, Just cxt)
|
||||||
|
|
||||||
|
|
||||||
|
insertArg (x, t) = snoc (x, LocalElem { val = VIdent x t, typ = t })
|
||||||
|
|
||||||
|
compileScs (DData (Data typ ts) : xs) = do
|
||||||
|
let (Ident outer_id) = extractTypeName typ
|
||||||
-- //TODO this could be extracted from the customTypes map
|
-- //TODO this could be extracted from the customTypes map
|
||||||
let variantTypes fi = init $ map type2LlvmType (flattenType fi)
|
let variantTypes fi = init $ map type2LlvmType (flattenType fi)
|
||||||
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
|
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
|
||||||
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8]
|
-- Add data type (e.g. %List) to top of the file
|
||||||
|
addStructType_ (Ident outer_id) [I8, Array biggestVariant I8]
|
||||||
typeSets <- gets customTypes
|
typeSets <- gets customTypes
|
||||||
mapM_
|
mapM_
|
||||||
( \(Inj inner_id fi) -> do
|
( \(Inj inner_id fi) -> do
|
||||||
let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
|
let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
|
||||||
emit $ LIR.Type inner_id (I8 : types)
|
-- Add constructor type (e.g. %Cons) to top of the file
|
||||||
|
addStructType_ inner_id (I8 : types)
|
||||||
)
|
)
|
||||||
ts
|
ts
|
||||||
compileScs xs
|
compileScs xs
|
||||||
|
|
@ -149,16 +202,16 @@ lastMainContent False var =
|
||||||
, Ret I64 (VInteger 0)
|
, Ret I64 (VInteger 0)
|
||||||
]
|
]
|
||||||
|
|
||||||
compileExp :: ExpT -> CompilerState ()
|
compileExp :: T Exp -> CompilerState ()
|
||||||
compileExp (MIR.ELit lit, _t) = emitLit lit
|
compileExp (ELit lit, _t) = emitLit lit
|
||||||
compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2
|
compileExp (EAdd e1 e2, t) = emitAdd t e1 e2
|
||||||
compileExp (MIR.EVar name, _t) = emitIdent name
|
compileExp (EVar name, _t) = emitIdent name
|
||||||
compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2
|
compileExp (EApp e1 e2, t) = emitApp t e1 e2
|
||||||
compileExp (MIR.ELet bind e, _) = emitLet bind e
|
compileExp (ELet bind e, _) = emitLet bind e
|
||||||
compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs)
|
compileExp (ECase e cs, t) = emitECased t e (map (t,) cs)
|
||||||
|
|
||||||
emitLet :: MIR.Bind -> ExpT -> CompilerState ()
|
emitLet :: Bind -> T Exp -> CompilerState ()
|
||||||
emitLet (MIR.Bind id [] innerExp) e = do
|
emitLet (Bind id [] innerExp) e = do
|
||||||
evaled <- exprToValue innerExp
|
evaled <- exprToValue innerExp
|
||||||
tempVar <- getNewVar
|
tempVar <- getNewVar
|
||||||
let t = type2LlvmType . snd $ innerExp
|
let t = type2LlvmType . snd $ innerExp
|
||||||
|
|
@ -168,14 +221,14 @@ emitLet (MIR.Bind id [] innerExp) e = do
|
||||||
compileExp e
|
compileExp e
|
||||||
emitLet b _ = error $ "Non empty argument list in let-bind " <> show b
|
emitLet b _ = error $ "Non empty argument list in let-bind " <> show b
|
||||||
|
|
||||||
emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Branch)] -> CompilerState ()
|
emitECased :: Type -> T Exp -> [(Type, Branch)] -> CompilerState ()
|
||||||
emitECased t e cases = do
|
emitECased t e cases = do
|
||||||
let cs = snd <$> cases
|
let cs = snd <$> cases
|
||||||
let ty = type2LlvmType t
|
let ty = type2LlvmType t
|
||||||
let rt = type2LlvmType (snd e)
|
let rt = type2LlvmType (snd e)
|
||||||
vs <- exprToValue e
|
vs <- exprToValue e
|
||||||
lbl <- getNewLabel
|
lbl <- getNewLabel
|
||||||
let label = TIR.Ident $ "escape_" <> show lbl
|
let label = Ident $ "escape_" <> show lbl
|
||||||
stackPtr <- getNewVar
|
stackPtr <- getNewVar
|
||||||
emit $ SetVariable stackPtr (Alloca ty)
|
emit $ SetVariable stackPtr (Alloca ty)
|
||||||
mapM_ (emitCases rt ty label stackPtr vs) cs
|
mapM_ (emitCases rt ty label stackPtr vs) cs
|
||||||
|
|
@ -192,14 +245,14 @@ emitECased t e cases = do
|
||||||
res <- getNewVar
|
res <- getNewVar
|
||||||
emit $ SetVariable res (Load ty Ptr stackPtr)
|
emit $ SetVariable res (Load ty Ptr stackPtr)
|
||||||
where
|
where
|
||||||
emitCases :: LLVMType -> LLVMType -> TIR.Ident -> TIR.Ident -> LLVMValue -> Branch -> CompilerState ()
|
emitCases :: LLVMType -> LLVMType -> Ident -> Ident -> LLVMValue -> Branch -> CompilerState ()
|
||||||
emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, _t) exp) = do
|
emitCases rt ty label stackPtr vs (Branch (PInj consId cs, _t) exp) = do
|
||||||
emit $ Comment "Inj"
|
emit $ Comment "Inj"
|
||||||
cons <- gets constructors
|
cons <- gets constructors
|
||||||
let r = fromJust $ Map.lookup consId cons
|
let r = fromJust $ Map.lookup consId cons
|
||||||
|
|
||||||
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
|
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||||
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
|
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
|
||||||
|
|
||||||
consVal <- getNewVar
|
consVal <- getNewVar
|
||||||
emit $ SetVariable consVal (ExtractValue rt vs 0)
|
emit $ SetVariable consVal (ExtractValue rt vs 0)
|
||||||
|
|
@ -215,10 +268,10 @@ emitECased t e cases = do
|
||||||
emit $ Store rt vs Ptr castPtr
|
emit $ Store rt vs Ptr castPtr
|
||||||
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
|
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
|
||||||
enumerateOneM_
|
enumerateOneM_
|
||||||
( \i c -> do
|
( \i (c, t) -> do
|
||||||
case c of
|
case c of
|
||||||
PVar (x, topT) -> do
|
PVar x -> do
|
||||||
let topT' = type2LlvmType topT
|
let topT' = type2LlvmType t
|
||||||
let botT' = CustomType (coerce consId)
|
let botT' = CustomType (coerce consId)
|
||||||
emit . Comment $ "ident " <> toIr topT'
|
emit . Comment $ "ident " <> toIr topT'
|
||||||
cTypes <- gets customTypes
|
cTypes <- gets customTypes
|
||||||
|
|
@ -228,7 +281,7 @@ emitECased t e cases = do
|
||||||
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
|
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
|
||||||
emit $ SetVariable x (Load topT' Ptr deref)
|
emit $ SetVariable x (Load topT' Ptr deref)
|
||||||
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
|
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
|
||||||
PLit (_l, _t) -> error "Nested pattern matching to be implemented"
|
PLit _l -> error "Nested pattern matching to be implemented"
|
||||||
PInj _id _ps -> error "Nested pattern matching to be implemented"
|
PInj _id _ps -> error "Nested pattern matching to be implemented"
|
||||||
PCatch -> pure ()
|
PCatch -> pure ()
|
||||||
PEnum _id -> error "Nested pattern matching to be implemented"
|
PEnum _id -> error "Nested pattern matching to be implemented"
|
||||||
|
|
@ -238,22 +291,22 @@ emitECased t e cases = do
|
||||||
emit $ Store ty val Ptr stackPtr
|
emit $ Store ty val Ptr stackPtr
|
||||||
emit $ Br label
|
emit $ Br label
|
||||||
emit $ Label lbl_failPos
|
emit $ Label lbl_failPos
|
||||||
emitCases _rt ty label stackPtr vs (Branch (MIR.PLit (i, ct), t) exp) = do
|
emitCases _rt ty label stackPtr vs (Branch (PLit i, t) exp) = do
|
||||||
emit $ Comment "Plit"
|
emit $ Comment "Plit"
|
||||||
let i' = case i of
|
let i' = case i of
|
||||||
MIR.LInt i -> VInteger i
|
LInt i -> VInteger i
|
||||||
MIR.LChar i -> VChar (ord i)
|
LChar i -> VChar (ord i)
|
||||||
ns <- getNewVar
|
ns <- getNewVar
|
||||||
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
|
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||||
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
|
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
|
||||||
emit $ SetVariable ns (Icmp LLEq (type2LlvmType ct) vs i')
|
emit $ SetVariable ns (Icmp LLEq (type2LlvmType t) vs i')
|
||||||
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
|
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
|
||||||
emit $ Label lbl_succPos
|
emit $ Label lbl_succPos
|
||||||
val <- exprToValue exp
|
val <- exprToValue exp
|
||||||
emit $ Store ty val Ptr stackPtr
|
emit $ Store ty val Ptr stackPtr
|
||||||
emit $ Br label
|
emit $ Br label
|
||||||
emit $ Label lbl_failPos
|
emit $ Label lbl_failPos
|
||||||
emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do
|
emitCases rt ty label stackPtr vs (Branch (PVar id, _) exp) = do
|
||||||
emit $ Comment "Pvar"
|
emit $ Comment "Pvar"
|
||||||
-- //TODO this is pretty disgusting and would heavily benefit from a rewrite
|
-- //TODO this is pretty disgusting and would heavily benefit from a rewrite
|
||||||
valPtr <- getNewVar
|
valPtr <- getNewVar
|
||||||
|
|
@ -263,20 +316,20 @@ emitECased t e cases = do
|
||||||
val <- exprToValue exp
|
val <- exprToValue exp
|
||||||
emit $ Store ty val Ptr stackPtr
|
emit $ Store ty val Ptr stackPtr
|
||||||
emit $ Br label
|
emit $ Br label
|
||||||
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
|
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||||
emit $ Label lbl_failPos
|
emit $ Label lbl_failPos
|
||||||
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "True$Bool"), t) exp) = do
|
emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "True$Bool"), t) exp) = do
|
||||||
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 1, TLit "Bool"), t) exp)
|
emitCases rt ty label stackPtr vs (Branch (PLit $ LInt 1, t) exp)
|
||||||
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "False$Bool"), _) exp) = do
|
emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "False$Bool"), _) exp) = do
|
||||||
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 0, TLit "Bool"), t) exp)
|
emitCases rt ty label stackPtr vs (Branch (PLit (LInt 0), t) exp)
|
||||||
emitCases rt ty label stackPtr vs br@(Branch (MIR.PEnum consId, _) exp) = do
|
emitCases rt ty label stackPtr vs br@(Branch (PEnum consId, _) exp) = do
|
||||||
emit $ Comment "Penum"
|
emit $ Comment "Penum"
|
||||||
cons <- gets constructors
|
cons <- gets constructors
|
||||||
let r = Map.lookup consId cons
|
let r = Map.lookup consId cons
|
||||||
when (isNothing r) (error $ "Constructor: '" ++ printTree consId ++ "' does not exist in cons state:\n" ++ show cons ++ "\nin pattern\n'" ++ printTree br ++ "'\n")
|
when (isNothing r) (error $ "Constructor: '" ++ printTree consId ++ "' does not exist in cons state:\n" ++ show cons ++ "\nin pattern\n'" ++ printTree br ++ "'\n")
|
||||||
|
|
||||||
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
|
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||||
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
|
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
|
||||||
|
|
||||||
consVal <- getNewVar
|
consVal <- getNewVar
|
||||||
emit $ SetVariable consVal (ExtractValue rt vs 0)
|
emit $ SetVariable consVal (ExtractValue rt vs 0)
|
||||||
|
|
@ -295,24 +348,17 @@ emitECased t e cases = do
|
||||||
emit $ Store ty val Ptr stackPtr
|
emit $ Store ty val Ptr stackPtr
|
||||||
emit $ Br label
|
emit $ Br label
|
||||||
emit $ Label lbl_failPos
|
emit $ Label lbl_failPos
|
||||||
emitCases _ ty label stackPtr _ (Branch (MIR.PCatch, _) exp) = do
|
emitCases _ ty label stackPtr _ (Branch (PCatch, _) exp) = do
|
||||||
emit $ Comment "Pcatch"
|
emit $ Comment "Pcatch"
|
||||||
val <- exprToValue exp
|
val <- exprToValue exp
|
||||||
emit $ Store ty val Ptr stackPtr
|
emit $ Store ty val Ptr stackPtr
|
||||||
emit $ Br label
|
emit $ Br label
|
||||||
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
|
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||||
emit $ Label lbl_failPos
|
emit $ Label lbl_failPos
|
||||||
|
|
||||||
emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
|
emitApp :: Type -> T Exp -> T Exp -> CompilerState ()
|
||||||
emitApp rt e1 e2 = appEmitter e1 e2 []
|
emitApp rt e1 e2 = do
|
||||||
where
|
((EVar name, t), args) <- go (EApp e1 e2, rt)
|
||||||
appEmitter :: ExpT -> ExpT -> [ExpT] -> CompilerState ()
|
|
||||||
appEmitter e1 e2 stack = do
|
|
||||||
let newStack = e2 : stack
|
|
||||||
case e1 of
|
|
||||||
(MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack
|
|
||||||
(MIR.EVar name, t) -> do
|
|
||||||
args <- traverse exprToValue newStack
|
|
||||||
vs <- getNewVar
|
vs <- getNewVar
|
||||||
funcs <- gets functions
|
funcs <- gets functions
|
||||||
consts <- gets constructors
|
consts <- gets constructors
|
||||||
|
|
@ -321,72 +367,147 @@ emitApp rt e1 e2 = appEmitter e1 e2 []
|
||||||
Global <$ Map.lookup name consts
|
Global <$ Map.lookup name consts
|
||||||
<|> Global <$ Map.lookup (name, t) funcs
|
<|> Global <$ Map.lookup (name, t) funcs
|
||||||
-- this piece of code could probably be improved, i.e remove the double `const Global`
|
-- this piece of code could probably be improved, i.e remove the double `const Global`
|
||||||
args' = map (first valueGetType . dupe) args
|
|
||||||
let call =
|
|
||||||
case name of
|
|
||||||
TIR.Ident ('l' : 't' : '$' : _) -> Icmp LLSlt I64 (snd (head args')) (snd (args' !! 1))
|
|
||||||
TIR.Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) -> Sub I64 (snd (head args')) (snd (args' !! 1))
|
|
||||||
_ -> Call FastCC (type2LlvmType rt) visibility name args'
|
|
||||||
emit $ Comment $ show rt
|
|
||||||
emit $ SetVariable vs call
|
|
||||||
x -> error $ "The unspeakable happened: " <> show x
|
|
||||||
|
|
||||||
emitIdent :: TIR.Ident -> CompilerState ()
|
call <- case name of
|
||||||
|
Ident ('l' : 't' : '$' : _) ->
|
||||||
|
pure $ Icmp LLSlt I64 (snd (head args)) (snd (args !! 1))
|
||||||
|
Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) ->
|
||||||
|
pure $ Sub I64 (snd (head args)) (snd (args !! 1))
|
||||||
|
|
||||||
|
-- FIXME
|
||||||
|
_ -> do
|
||||||
|
let closure_call LocalElem { typ = Ptr, val } = (mkDerefName name, (Ptr, val) : args)
|
||||||
|
|
||||||
|
(name, args) <- gets $ maybe (name, (Ptr, VNull) : args) closure_call
|
||||||
|
. lookup name
|
||||||
|
. locals
|
||||||
|
|
||||||
|
pure $ Call FastCC (type2LlvmType rt) visibility name args
|
||||||
|
|
||||||
|
emit $ Comment $ show (type2LlvmType rt)
|
||||||
|
emit $ SetVariable vs call
|
||||||
|
|
||||||
|
where
|
||||||
|
|
||||||
|
go :: T Exp -> CompilerState (T Exp, [(LLVMType, LLVMValue)])
|
||||||
|
go et@(e, _) = case e of
|
||||||
|
EApp e1 e2@(_, t) -> do
|
||||||
|
(x, as) <- go e1
|
||||||
|
a <- exprToValue e2
|
||||||
|
let t' = type2LlvmType' t
|
||||||
|
pure (x, snoc (t', a) as)
|
||||||
|
_ -> pure (et, [])
|
||||||
|
|
||||||
|
type2LlvmType' = \case
|
||||||
|
TFun _ _ -> Ptr
|
||||||
|
t -> type2LlvmType t
|
||||||
|
|
||||||
|
emitIdent :: Ident -> CompilerState ()
|
||||||
emitIdent id = do
|
emitIdent id = do
|
||||||
-- !!this should never happen!!
|
-- !!this should never happen!!
|
||||||
emit $ Comment "This should not have happened!"
|
emit $ Comment "This should not have happened!"
|
||||||
emit $ Variable id
|
emit $ Variable id
|
||||||
emit $ UnsafeRaw "\n"
|
emit $ UnsafeRaw "\n"
|
||||||
|
|
||||||
emitLit :: MIR.Lit -> CompilerState ()
|
emitLit :: Lit -> CompilerState ()
|
||||||
emitLit i = do
|
emitLit i = do
|
||||||
-- !!this should never happen!!
|
-- !!this should never happen!!
|
||||||
let (i', t) = case i of
|
let (i', t) = case i of
|
||||||
(MIR.LInt i'') -> (VInteger i'', I64)
|
(LInt i'') -> (VInteger i'', I64)
|
||||||
(MIR.LChar i'') -> (VChar $ ord i'', I8)
|
(LChar i'') -> (VChar $ ord i'', I8)
|
||||||
varCount <- getNewVar
|
varCount <- getNewVar
|
||||||
emit $ Comment "This should not have happened!"
|
emit $ Comment "This should not have happened!"
|
||||||
emit $ SetVariable varCount (Add t i' (VInteger 0))
|
emit $ SetVariable varCount (Add t i' (VInteger 0))
|
||||||
|
|
||||||
emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
|
emitAdd :: Type -> T Exp -> T Exp -> CompilerState ()
|
||||||
emitAdd t e1 e2 = do
|
emitAdd t e1 e2 = do
|
||||||
v1 <- exprToValue e1
|
v1 <- exprToValue e1
|
||||||
v2 <- exprToValue e2
|
v2 <- exprToValue e2
|
||||||
v <- getNewVar
|
v <- getNewVar
|
||||||
emit $ SetVariable v (Add (type2LlvmType t) v1 v2)
|
emit $ SetVariable v (Add (type2LlvmType t) v1 v2)
|
||||||
|
|
||||||
exprToValue :: ExpT -> CompilerState LLVMValue
|
|
||||||
exprToValue = \case
|
exprToValue :: T Exp -> CompilerState LLVMValue
|
||||||
(MIR.ELit i, _t) -> pure $ case i of
|
exprToValue et@(e, t) = case e of
|
||||||
(MIR.LInt i) -> VInteger i
|
ELit (LInt i) -> pure $ VInteger i
|
||||||
(MIR.LChar i) -> VChar $ ord i
|
ELit (LChar c) -> pure . VChar $ ord c
|
||||||
(MIR.EVar (TIR.Ident "True$Bool"), _t) -> pure $ VInteger 1
|
|
||||||
(MIR.EVar (TIR.Ident "False$Bool"), _t) -> pure $ VInteger 0
|
EVar "True$Bool" -> pure $ VInteger 1
|
||||||
(MIR.EVar name, t) -> do
|
EVar "False$Bool" -> pure $ VInteger 0
|
||||||
funcs <- gets functions
|
|
||||||
cons <- gets constructors
|
EVar name -> gets (Map.lookup name . globals) >>= \case
|
||||||
let res =
|
Just (typ@(Function _ ts), val) | length ts > 1 -> do
|
||||||
Map.lookup (name, t) funcs
|
type_struct <- addStructType (mkClosureName name) [typ]
|
||||||
<|> ( \c ->
|
emit $ Comment "Allocating structure"
|
||||||
FunctionInfo
|
emit . SetVariable name $ Alloca type_struct
|
||||||
{ numArgs = numArgsCI c
|
emit $ Store typ val Ptr name
|
||||||
, arguments = argumentsCI c
|
pure $ VIdent name Ptr
|
||||||
}
|
|
||||||
)
|
Just _ | name == "main" -> do
|
||||||
<$> Map.lookup name cons
|
|
||||||
case res of
|
|
||||||
Just fi -> do
|
|
||||||
if numArgs fi == 0
|
|
||||||
then do
|
|
||||||
vc <- getNewVar
|
vc <- getNewVar
|
||||||
emit $
|
emit $ SetVariable vc (Call FastCC I64 Global name [])
|
||||||
SetVariable
|
pure $ VIdent vc I64
|
||||||
vc
|
|
||||||
(Call FastCC (type2LlvmType t) Global name [])
|
|
||||||
|
Just (Function t_return [_], _) -> do
|
||||||
|
vc <- getNewVar
|
||||||
|
emit $ SetVariable vc (Call FastCC t_return Global name [(Ptr, VNull)])
|
||||||
|
pure $ VIdent vc t_return
|
||||||
|
|
||||||
|
Just _ -> error "Bad"
|
||||||
|
|
||||||
|
Nothing -> gets (Map.lookup name . constructors) >>= \case
|
||||||
|
|
||||||
|
Just ConstructorInfo {numArgsCI}
|
||||||
|
| numArgsCI == 0 -> do
|
||||||
|
vc <- getNewVar
|
||||||
|
emit $ SetVariable vc call
|
||||||
pure $ VIdent vc (type2LlvmType t)
|
pure $ VIdent vc (type2LlvmType t)
|
||||||
else pure $ VFunction name Global (type2LlvmType t)
|
| otherwise -> pure $ VFunction name Global (type2LlvmType t)
|
||||||
Nothing -> pure $ VIdent name (type2LlvmType t)
|
where
|
||||||
e -> do
|
call = Call FastCC (type2LlvmType t) Global name []
|
||||||
compileExp e
|
|
||||||
|
Nothing -> gets $ val
|
||||||
|
. fromJust
|
||||||
|
. lookup name
|
||||||
|
. locals
|
||||||
|
|
||||||
|
EVarC cxt name -> do
|
||||||
|
let cxt' = flip map cxt $ \(x, t) -> let t' = type2LlvmType t
|
||||||
|
in (t', VIdent x t')
|
||||||
|
cxt'' <- gets $ (:cxt')
|
||||||
|
. fromJust
|
||||||
|
. Map.lookup name
|
||||||
|
. globals
|
||||||
|
|
||||||
|
-- Create a new type for function pointer and arguments
|
||||||
|
type_struct <- addStructType (mkClosureName name) $ map fst cxt''
|
||||||
|
emit $ Comment "Allocating structure"
|
||||||
|
emit . SetVariable name $ Alloca type_struct
|
||||||
|
|
||||||
|
let ptr_struct = VIdent name Ptr
|
||||||
|
storeArg (t, v) i = do
|
||||||
|
vc <- getNewVar
|
||||||
|
emit . SetVariable vc
|
||||||
|
$ GetElementPtrInbounds type_struct Ptr ptr_struct
|
||||||
|
I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices
|
||||||
|
emit $ Store t v Ptr vc
|
||||||
|
|
||||||
|
-- Store arguments in structure
|
||||||
|
zipWithM_ storeArg cxt'' [0 ..]
|
||||||
|
pure ptr_struct
|
||||||
|
|
||||||
|
_ -> do
|
||||||
|
compileExp et
|
||||||
v <- getVarCount
|
v <- getVarCount
|
||||||
pure $ VIdent (TIR.Ident $ show v) (getType e)
|
pure $ VIdent (Ident $ show v) (getType et)
|
||||||
|
|
||||||
|
|
||||||
|
mkClosureName :: Ident -> Ident
|
||||||
|
mkClosureName (Ident s) = Ident $ "Closure_" ++ s
|
||||||
|
|
||||||
|
mkDerefName :: Ident -> Ident
|
||||||
|
mkDerefName (Ident s) = Ident $ s ++ "_deref"
|
||||||
|
|
||||||
|
mkCxtName :: Ident
|
||||||
|
mkCxtName = Ident "cxt"
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ module Codegen.LlvmIr (
|
||||||
Visibility (..),
|
Visibility (..),
|
||||||
CallingConvention (..),
|
CallingConvention (..),
|
||||||
ToIr (..),
|
ToIr (..),
|
||||||
|
typeOf
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import Data.List (intercalate)
|
import Data.List (intercalate)
|
||||||
|
|
@ -38,6 +39,9 @@ data LLVMType
|
||||||
class ToIr a where
|
class ToIr a where
|
||||||
toIr :: a -> String
|
toIr :: a -> String
|
||||||
|
|
||||||
|
instance ToIr a => ToIr [a] where
|
||||||
|
toIr = concatMap toIr
|
||||||
|
|
||||||
instance ToIr LLVMType where
|
instance ToIr LLVMType where
|
||||||
toIr :: LLVMType -> String
|
toIr :: LLVMType -> String
|
||||||
toIr = \case
|
toIr = \case
|
||||||
|
|
@ -92,6 +96,7 @@ data LLVMValue
|
||||||
| VIdent Ident LLVMType
|
| VIdent Ident LLVMType
|
||||||
| VConstant String
|
| VConstant String
|
||||||
| VFunction Ident Visibility LLVMType
|
| VFunction Ident Visibility LLVMType
|
||||||
|
| VNull
|
||||||
deriving (Show, Eq, Ord)
|
deriving (Show, Eq, Ord)
|
||||||
|
|
||||||
instance ToIr LLVMValue where
|
instance ToIr LLVMValue where
|
||||||
|
|
@ -102,6 +107,7 @@ instance ToIr LLVMValue where
|
||||||
VIdent (Ident n) _ -> "%" <> n
|
VIdent (Ident n) _ -> "%" <> n
|
||||||
VFunction (Ident n) vis _ -> toIr vis <> n
|
VFunction (Ident n) vis _ -> toIr vis <> n
|
||||||
VConstant s -> "c" <> show s
|
VConstant s -> "c" <> show s
|
||||||
|
VNull -> "null"
|
||||||
|
|
||||||
type Params = [(Ident, LLVMType)]
|
type Params = [(Ident, LLVMType)]
|
||||||
type Args = [(LLVMType, LLVMValue)]
|
type Args = [(LLVMType, LLVMValue)]
|
||||||
|
|
@ -139,6 +145,21 @@ data LLVMIr
|
||||||
-- instructions should be used in its place
|
-- instructions should be used in its place
|
||||||
deriving (Show, Eq, Ord)
|
deriving (Show, Eq, Ord)
|
||||||
|
|
||||||
|
|
||||||
|
-- TODO add missing clauses
|
||||||
|
typeOf :: LLVMIr -> LLVMType
|
||||||
|
typeOf = \case
|
||||||
|
Add t _ _ -> t
|
||||||
|
Sub t _ _ -> t
|
||||||
|
Mul t _ _ -> t
|
||||||
|
Div t _ _ -> t
|
||||||
|
Load t _ _ -> t
|
||||||
|
Store t _ _ _ -> t
|
||||||
|
Type x _ -> CustomType x
|
||||||
|
SetVariable _ ir -> typeOf ir
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
-- | Converts a list of LLVMIr instructions to a string
|
-- | Converts a list of LLVMIr instructions to a string
|
||||||
llvmIrToString :: [LLVMIr] -> String
|
llvmIrToString :: [LLVMIr] -> String
|
||||||
llvmIrToString = go 0
|
llvmIrToString = go 0
|
||||||
|
|
|
||||||
|
|
@ -11,9 +11,11 @@ import Control.Monad.State (MonadState (get, put), State,
|
||||||
evalState)
|
evalState)
|
||||||
import Data.Function (on)
|
import Data.Function (on)
|
||||||
import Data.List (delete, mapAccumL, (\\))
|
import Data.List (delete, mapAccumL, (\\))
|
||||||
|
import Data.Tuple.Extra (first, second)
|
||||||
|
import LambdaLifterIr (T)
|
||||||
|
import qualified LambdaLifterIr as L
|
||||||
import Prelude hiding (exp)
|
import Prelude hiding (exp)
|
||||||
import TypeChecker.TypeCheckerIr
|
import TypeChecker.TypeCheckerIr hiding (T)
|
||||||
|
|
||||||
|
|
||||||
-- | Lift lambdas and let expression into supercombinators.
|
-- | Lift lambdas and let expression into supercombinators.
|
||||||
-- Three phases:
|
-- Three phases:
|
||||||
|
|
@ -21,12 +23,13 @@ import TypeChecker.TypeCheckerIr
|
||||||
-- @abstract@ converts lambdas into let expressions.
|
-- @abstract@ converts lambdas into let expressions.
|
||||||
-- @collectScs@ moves every non-constant let expression to a top-level function.
|
-- @collectScs@ moves every non-constant let expression to a top-level function.
|
||||||
--
|
--
|
||||||
lambdaLift :: Program -> Program
|
lambdaLift :: Program -> L.Program
|
||||||
lambdaLift (Program ds) = Program (datatypes ++ binds)
|
lambdaLift (Program ds) = L.Program (datatypes ++ binds)
|
||||||
where
|
where
|
||||||
datatypes = flip filter ds $ \case DData _ -> True
|
datatypes = [L.DData (toLirData d) | DData d <- ds]
|
||||||
_ -> False
|
|
||||||
binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
|
binds = map L.DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
|
||||||
|
|
||||||
|
|
||||||
-- | Annotate free variables
|
-- | Annotate free variables
|
||||||
freeVars :: [Bind] -> [ABind]
|
freeVars :: [Bind] -> [ABind]
|
||||||
|
|
@ -36,7 +39,7 @@ freeVars binds = [ let ae = freeVarsExp [] e
|
||||||
| Bind n xs e <- binds
|
| Bind n xs e <- binds
|
||||||
]
|
]
|
||||||
|
|
||||||
freeVarsExp :: Frees -> ExpT -> Ann AExpT
|
freeVarsExp :: Frees -> T Exp -> Ann (T AExp)
|
||||||
freeVarsExp localVars (ae, t) = case ae of
|
freeVarsExp localVars (ae, t) = case ae of
|
||||||
EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)]
|
EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)]
|
||||||
, term = (AVar n, t)
|
, term = (AVar n, t)
|
||||||
|
|
@ -121,27 +124,47 @@ data Ann a = Ann
|
||||||
, term :: a
|
, term :: a
|
||||||
} deriving (Show, Eq)
|
} deriving (Show, Eq)
|
||||||
|
|
||||||
data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq)
|
data ABind = ABind (T Ident) [T Ident] (Ann (T AExp)) deriving (Show, Eq)
|
||||||
data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq)
|
data ABranch = ABranch (Pattern, Type) (Ann (T AExp)) deriving (Show, Eq)
|
||||||
|
|
||||||
type AExpT = (AExp, Type)
|
|
||||||
|
|
||||||
data AExp = AVar Ident
|
data AExp = AVar Ident
|
||||||
| AInj Ident
|
| AInj Ident
|
||||||
| ALit Lit
|
| ALit Lit
|
||||||
| ALet (Ann ABind) (Ann AExpT)
|
| ALet (Ann ABind) (Ann (T AExp))
|
||||||
| AApp (Ann AExpT) (Ann AExpT)
|
| AApp (Ann (T AExp)) (Ann (T AExp))
|
||||||
| AAdd (Ann AExpT) (Ann AExpT)
|
| AAdd (Ann (T AExp)) (Ann (T AExp))
|
||||||
| AAbs Ident (Ann AExpT)
|
| AAbs Ident (Ann (T AExp))
|
||||||
| ACase (Ann AExpT) [Ann ABranch]
|
| ACase (Ann (T AExp)) [Ann ABranch]
|
||||||
deriving (Show, Eq)
|
deriving (Show, Eq)
|
||||||
|
|
||||||
abstract :: [ABind] -> [Bind]
|
|
||||||
|
|
||||||
|
data BBind = BBind (T Ident) [T Ident] (T BExp)
|
||||||
|
| BBindCxt [T Ident] (T Ident) [T Ident] (T BExp)
|
||||||
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
|
|
||||||
|
data BBranch = BBranch (T Pattern) (T BExp)
|
||||||
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
|
data BExp
|
||||||
|
= BVar Ident
|
||||||
|
| BVarC [T Ident] Ident
|
||||||
|
| BInj Ident
|
||||||
|
| BLit Lit
|
||||||
|
| BLet BBind (T BExp)
|
||||||
|
| BApp (T BExp)(T BExp)
|
||||||
|
| BAdd (T BExp)(T BExp)
|
||||||
|
| BCase (T BExp) [BBranch]
|
||||||
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
|
|
||||||
|
abstract :: [ABind] -> [BBind]
|
||||||
abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0
|
abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0
|
||||||
|
|
||||||
abstractAnnBind :: Ann ABind -> State Int Bind
|
abstractAnnBind :: Ann ABind -> State Int BBind
|
||||||
abstractAnnBind Ann { term = ABind name vars annae } =
|
abstractAnnBind Ann { term = ABind name vars annae } =
|
||||||
Bind name (vars' <|| vars) <$> abstractAnnExp annae'
|
BBind name (vars' <|| vars) <$> abstractAnnExp annae'
|
||||||
where
|
where
|
||||||
(annae', vars') = go [] annae
|
(annae', vars') = go [] annae
|
||||||
where
|
where
|
||||||
|
|
@ -149,24 +172,27 @@ abstractAnnBind Ann { term = ABind name vars annae } =
|
||||||
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
|
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
|
||||||
ae -> (ae, acc)
|
ae -> (ae, acc)
|
||||||
|
|
||||||
abstractAnnExp :: Ann AExpT -> State Int ExpT
|
abstractAnnExp :: Ann (T AExp) -> State Int (T BExp)
|
||||||
abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
|
abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
|
||||||
AVar n -> pure (EVar n, typ)
|
AVar n -> pure (BVar n, typ)
|
||||||
AInj n -> pure (EInj n, typ)
|
AInj n -> pure (BInj n, typ)
|
||||||
ALit lit -> pure (ELit lit, typ)
|
ALit lit -> pure (BLit lit, typ)
|
||||||
AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2
|
AApp annae1 annae2 -> (, typ) <$> onM BApp abstractAnnExp annae1 annae2
|
||||||
AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2
|
AAdd annae1 annae2 -> (, typ) <$> onM BAdd abstractAnnExp annae1 annae2
|
||||||
|
|
||||||
-- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc
|
|
||||||
AAbs x annae' -> do
|
AAbs x annae' -> do
|
||||||
i <- nextNumber
|
i <- nextNumber
|
||||||
rhs <- abstractAnnExp annae''
|
rhs <- abstractAnnExp annae''
|
||||||
let sc_name = Ident ("sc_" ++ show i)
|
let sc_name = Ident ("sc_" ++ show i)
|
||||||
e@(_, t) = foldl applyFree (EVar sc_name, typ) frees
|
sc | null frees = (BVar sc_name, typ)
|
||||||
pure (ELet (Bind (sc_name, typ) vars rhs) e ,t)
|
| otherwise = (BVarC frees sc_name, typ)
|
||||||
|
bind | null frees = BBind (sc_name, typ) vars rhs
|
||||||
|
| otherwise = BBindCxt frees (sc_name, typ) vars rhs
|
||||||
|
|
||||||
|
pure (BLet bind sc ,typ)
|
||||||
|
|
||||||
where
|
where
|
||||||
vars = frees <| (x, t_x) <|| ys
|
vars = [(x, t_x)] <|| ys
|
||||||
t_x = case typ of TFun t _ -> t
|
t_x = case typ of TFun t _ -> t
|
||||||
_ -> error "Impossible"
|
_ -> error "Impossible"
|
||||||
|
|
||||||
|
|
@ -176,54 +202,47 @@ abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
|
||||||
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
|
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
|
||||||
ae -> (ae, acc)
|
ae -> (ae, acc)
|
||||||
|
|
||||||
|
|
||||||
applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type)
|
|
||||||
applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e')
|
|
||||||
where
|
|
||||||
t_e' = case t_e of TFun _ t -> t
|
|
||||||
_ -> error "Impossible"
|
|
||||||
|
|
||||||
ACase annae' bs -> do
|
ACase annae' bs -> do
|
||||||
bs <- mapM go bs
|
bs <- mapM go bs
|
||||||
e <- abstractAnnExp annae'
|
e <- abstractAnnExp annae'
|
||||||
pure (ECase e bs, typ)
|
pure (BCase e bs, typ)
|
||||||
where
|
where
|
||||||
go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae
|
go Ann { term = ABranch p annae } = BBranch p <$> abstractAnnExp annae
|
||||||
|
|
||||||
ALet b annae' ->
|
ALet b annae' ->
|
||||||
(, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae')
|
(, typ) <$> liftA2 BLet (abstractAnnBind b) (abstractAnnExp annae')
|
||||||
|
|
||||||
|
|
||||||
-- | Collects supercombinators by lifting non-constant let expressions
|
-- | Collects supercombinators by lifting non-constant let expressions
|
||||||
collectScs :: [Bind] -> [Bind]
|
collectScs :: [BBind] -> [L.Bind]
|
||||||
collectScs = concatMap collectFromRhs
|
collectScs = concatMap collectFromRhs
|
||||||
where
|
where
|
||||||
collectFromRhs (Bind name parms rhs) =
|
collectFromRhs (BBind name parms rhs) =
|
||||||
let (rhs_scs, rhs') = collectScsExp rhs
|
let (rhs_scs, rhs') = collectScsExp rhs
|
||||||
in Bind name parms rhs' : rhs_scs
|
in L.Bind name parms rhs' : rhs_scs
|
||||||
|
collectFromRhs (BBindCxt cxt name parms rhs) =
|
||||||
|
let (rhs_scs, rhs') = collectScsExp rhs
|
||||||
|
in L.BindC cxt name parms rhs' : rhs_scs
|
||||||
|
|
||||||
|
|
||||||
collectScsExp :: ExpT -> ([Bind], ExpT)
|
collectScsExp :: T BExp -> ([L.Bind], T L.Exp)
|
||||||
collectScsExp expT@(exp, typ) = case exp of
|
collectScsExp (exp, typ) = case exp of
|
||||||
EVar _ -> ([], expT)
|
BVar x -> ([], (L.EVar x, typ))
|
||||||
EInj _ -> ([], expT)
|
BVarC as x -> ([], (L.EVarC as x, typ))
|
||||||
ELit _ -> ([], expT)
|
BInj k -> ([], (L.EInj k, typ))
|
||||||
|
BLit lit -> ([], (L.ELit lit, typ))
|
||||||
|
|
||||||
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
|
BApp e1 e2 -> (scs1 ++ scs2, (L.EApp e1' e2', typ))
|
||||||
where
|
where
|
||||||
(scs1, e1') = collectScsExp e1
|
(scs1, e1') = collectScsExp e1
|
||||||
(scs2, e2') = collectScsExp e2
|
(scs2, e2') = collectScsExp e2
|
||||||
|
|
||||||
EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ))
|
BAdd e1 e2 -> (scs1 ++ scs2, (L.EAdd e1' e2', typ))
|
||||||
where
|
where
|
||||||
(scs1, e1') = collectScsExp e1
|
(scs1, e1') = collectScsExp e1
|
||||||
(scs2, e2') = collectScsExp e2
|
(scs2, e2') = collectScsExp e2
|
||||||
|
|
||||||
EAbs par e -> (scs, (EAbs par e', typ))
|
BCase e branches -> (scs ++ scs_e, (L.ECase e' branches', typ))
|
||||||
where
|
|
||||||
(scs, e') = collectScsExp e
|
|
||||||
|
|
||||||
ECase e branches -> (scs ++ scs_e, (ECase e' branches', typ))
|
|
||||||
where
|
where
|
||||||
(scs, branches') = mapAccumL f [] branches
|
(scs, branches') = mapAccumL f [] branches
|
||||||
(scs_e, e') = collectScsExp e
|
(scs_e, e') = collectScsExp e
|
||||||
|
|
@ -234,15 +253,24 @@ collectScsExp expT@(exp, typ) = case exp of
|
||||||
--
|
--
|
||||||
-- > f = let sc x y = rhs in e
|
-- > f = let sc x y = rhs in e
|
||||||
--
|
--
|
||||||
ELet (Bind name parms rhs) e
|
BLet (BBind name parms rhs) e
|
||||||
| null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et'))
|
| null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
|
||||||
| otherwise -> (bind : rhs_scs ++ et_scs, et')
|
| otherwise -> (bind : rhs_scs ++ et_scs, et')
|
||||||
where
|
where
|
||||||
bind = Bind name parms rhs'
|
bind = L.Bind name parms rhs'
|
||||||
(rhs_scs, rhs') = collectScsExp rhs
|
(rhs_scs, rhs') = collectScsExp rhs
|
||||||
(et_scs, et') = collectScsExp e
|
(et_scs, et') = collectScsExp e
|
||||||
|
|
||||||
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
|
|
||||||
|
BLet (BBindCxt cxt name parms rhs) e
|
||||||
|
| null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
|
||||||
|
| otherwise -> (bind : rhs_scs ++ et_scs, et')
|
||||||
|
where
|
||||||
|
bind = L.BindC cxt name parms rhs'
|
||||||
|
(rhs_scs, rhs') = collectScsExp rhs
|
||||||
|
(et_scs, et') = collectScsExp e
|
||||||
|
|
||||||
|
collectScsBranch (BBranch patt exp) = (scs, L.Branch (first toLirPattern patt) exp')
|
||||||
where (scs, exp') = collectScsExp exp
|
where (scs, exp') = collectScsExp exp
|
||||||
|
|
||||||
nextNumber :: State Int Int
|
nextNumber :: State Int Int
|
||||||
|
|
@ -259,3 +287,13 @@ xs <| x | elem x xs = xs
|
||||||
(<||) :: Eq a => [a] -> [a] -> [a]
|
(<||) :: Eq a => [a] -> [a] -> [a]
|
||||||
xs <|| ys = foldl (<|) xs ys
|
xs <|| ys = foldl (<|) xs ys
|
||||||
|
|
||||||
|
toLirData (Data t injs) = L.Data t (map toLirInj injs)
|
||||||
|
toLirInj (Inj n t) = L.Inj n t
|
||||||
|
|
||||||
|
toLirPattern :: Pattern -> L.Pattern
|
||||||
|
toLirPattern = \case
|
||||||
|
PVar x -> L.PVar x
|
||||||
|
PLit lit -> L.PLit lit
|
||||||
|
PCatch -> L.PCatch
|
||||||
|
PEnum k -> L.PEnum k
|
||||||
|
PInj k ps -> L.PInj k (map (first toLirPattern) ps)
|
||||||
|
|
|
||||||
140
src/LambdaLifterIr.hs
Normal file
140
src/LambdaLifterIr.hs
Normal file
|
|
@ -0,0 +1,140 @@
|
||||||
|
{-# LANGUAGE LambdaCase #-}
|
||||||
|
{-# LANGUAGE PatternSynonyms #-}
|
||||||
|
|
||||||
|
module LambdaLifterIr (
|
||||||
|
module Grammar.Abs,
|
||||||
|
module LambdaLifterIr,
|
||||||
|
module TypeChecker.TypeCheckerIr
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Data.List (intercalate)
|
||||||
|
import Grammar.Abs (Lit (..))
|
||||||
|
import Grammar.Print
|
||||||
|
import Prelude hiding (exp)
|
||||||
|
import qualified Prelude as C (Eq, Ord, Show)
|
||||||
|
import TypeChecker.TypeCheckerIr (Ident (..), TVar (..), Type (..))
|
||||||
|
|
||||||
|
newtype Program = Program [Def]
|
||||||
|
deriving (C.Eq, C.Ord, C.Show)
|
||||||
|
|
||||||
|
data Def
|
||||||
|
= DBind Bind
|
||||||
|
| DData Data
|
||||||
|
deriving (C.Eq, C.Ord, C.Show)
|
||||||
|
|
||||||
|
data Data = Data Type [Inj]
|
||||||
|
deriving (C.Eq, C.Ord, C.Show)
|
||||||
|
|
||||||
|
data Inj = Inj Ident Type
|
||||||
|
deriving (C.Eq, C.Ord, C.Show)
|
||||||
|
|
||||||
|
data Pattern
|
||||||
|
= PVar Ident
|
||||||
|
| PLit Lit
|
||||||
|
| PCatch
|
||||||
|
| PEnum Ident
|
||||||
|
| PInj Ident [(Pattern, Type)]
|
||||||
|
deriving (C.Eq, C.Ord, C.Show)
|
||||||
|
|
||||||
|
data Exp
|
||||||
|
= EVar Ident
|
||||||
|
| EVarC [T Ident] Ident
|
||||||
|
| EInj Ident
|
||||||
|
| ELit Lit
|
||||||
|
| ELet (T Ident) (T Exp) (T Exp)
|
||||||
|
| EApp (T Exp)(T Exp)
|
||||||
|
| EAdd (T Exp)(T Exp)
|
||||||
|
| ECase (T Exp) [Branch]
|
||||||
|
deriving (C.Eq, C.Ord, C.Show)
|
||||||
|
|
||||||
|
|
||||||
|
type T a = (a, Type)
|
||||||
|
|
||||||
|
data Bind = Bind (T Ident) [T Ident] (T Exp)
|
||||||
|
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
|
||||||
|
deriving (C.Eq, C.Ord, C.Show)
|
||||||
|
|
||||||
|
data Branch = Branch (T Pattern) (T Exp)
|
||||||
|
deriving (C.Eq, C.Ord, C.Show)
|
||||||
|
|
||||||
|
instance Print Program where
|
||||||
|
prt i (Program sc) = prt i sc
|
||||||
|
|
||||||
|
instance Print Bind where
|
||||||
|
prt i (Bind sig parms rhs) = concatD
|
||||||
|
[ prt i sig
|
||||||
|
, prt i parms
|
||||||
|
, doc $ showString "="
|
||||||
|
, prt i rhs
|
||||||
|
]
|
||||||
|
prt i (BindC cxt sig parms rhs) =
|
||||||
|
prPrec i 0 $
|
||||||
|
concatD
|
||||||
|
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
|
||||||
|
, prt i parms
|
||||||
|
, doc $ showString "="
|
||||||
|
, prt i rhs
|
||||||
|
]
|
||||||
|
|
||||||
|
instance Print [Bind] where
|
||||||
|
prt _ [] = concatD []
|
||||||
|
prt i [x] = concatD [prt i x]
|
||||||
|
prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs]
|
||||||
|
|
||||||
|
instance Print Exp where
|
||||||
|
prt i = \case
|
||||||
|
EVar lident -> prPrec i 3 (concatD [prt 0 lident])
|
||||||
|
EVarC as lident -> doc . showString
|
||||||
|
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
|
||||||
|
where
|
||||||
|
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
|
||||||
|
EInj uident -> prPrec i 3 (concatD [prt 0 uident])
|
||||||
|
ELit lit -> prPrec i 3 (concatD [prt 0 lit])
|
||||||
|
EApp exp1 exp2 -> prPrec i 2 (concatD [prt 2 exp1, prt 3 exp2])
|
||||||
|
EAdd exp1 exp2 -> prPrec i 1 (concatD [prt 1 exp1, doc (showString "+"), prt 2 exp2])
|
||||||
|
ELet lident exp1 exp2 -> prPrec i 0 (concatD [doc (showString "let"), prt 0 lident, doc (showString "="), prt 0 exp1 , doc (showString "in"), prt 0 exp2])
|
||||||
|
ECase exp branchs -> prPrec i 0 (concatD [doc (showString "case"), prt 0 exp, doc (showString "of"), doc (showString "{"), prt 0 branchs, doc (showString "}")])
|
||||||
|
|
||||||
|
|
||||||
|
instance Print Branch where
|
||||||
|
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
|
||||||
|
|
||||||
|
instance Print [Branch] where
|
||||||
|
prt _ [] = concatD []
|
||||||
|
prt _ [x] = concatD [prt 0 x]
|
||||||
|
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||||
|
|
||||||
|
instance Print Def where
|
||||||
|
prt i = \case
|
||||||
|
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
|
||||||
|
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
|
||||||
|
|
||||||
|
instance Print Data where
|
||||||
|
prt i = \case
|
||||||
|
Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")])
|
||||||
|
|
||||||
|
instance Print Inj where
|
||||||
|
prt i = \case
|
||||||
|
Inj uident type_ -> prt i (uident, type_)
|
||||||
|
|
||||||
|
instance Print [Inj] where
|
||||||
|
prt _ [] = concatD []
|
||||||
|
prt i [x] = prt i x
|
||||||
|
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
|
||||||
|
|
||||||
|
instance Print Pattern where
|
||||||
|
prt i = \case
|
||||||
|
PVar name -> prPrec i 1 (concatD [prt 0 name])
|
||||||
|
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
|
||||||
|
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
|
||||||
|
PEnum name -> prPrec i 1 (concatD [prt 0 name])
|
||||||
|
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
|
||||||
|
|
||||||
|
instance Print [Def] where
|
||||||
|
prt _ [] = concatD []
|
||||||
|
prt _ [x] = concatD [prt 0 x]
|
||||||
|
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||||
|
|
||||||
|
pattern DBind' id vars expt = DBind (Bind id vars expt)
|
||||||
|
pattern DData' typ injs = DData (Data typ injs)
|
||||||
|
|
||||||
25
src/Main.hs
25
src/Main.hs
|
|
@ -19,27 +19,18 @@ import Monomorphizer.Monomorphizer (monomorphize)
|
||||||
import OrderDefs (orderDefs)
|
import OrderDefs (orderDefs)
|
||||||
import Renamer.Renamer (rename)
|
import Renamer.Renamer (rename)
|
||||||
import ReportForall (reportForall)
|
import ReportForall (reportForall)
|
||||||
import System.Console.GetOpt (
|
import System.Console.GetOpt (ArgDescr (NoArg, ReqArg),
|
||||||
ArgDescr (NoArg, ReqArg),
|
|
||||||
ArgOrder (RequireOrder),
|
ArgOrder (RequireOrder),
|
||||||
OptDescr (Option),
|
OptDescr (Option), getOpt,
|
||||||
getOpt,
|
usageInfo)
|
||||||
usageInfo,
|
import System.Directory (createDirectory, doesPathExist,
|
||||||
)
|
|
||||||
import System.Directory (
|
|
||||||
createDirectory,
|
|
||||||
doesPathExist,
|
|
||||||
getDirectoryContents,
|
getDirectoryContents,
|
||||||
removeDirectoryRecursive,
|
removeDirectoryRecursive,
|
||||||
setCurrentDirectory,
|
setCurrentDirectory)
|
||||||
)
|
|
||||||
import System.Environment (getArgs)
|
import System.Environment (getArgs)
|
||||||
import System.Exit (
|
import System.Exit (ExitCode (ExitFailure),
|
||||||
ExitCode (ExitFailure),
|
exitFailure, exitSuccess,
|
||||||
exitFailure,
|
exitWith)
|
||||||
exitSuccess,
|
|
||||||
exitWith,
|
|
||||||
)
|
|
||||||
import System.IO (stderr)
|
import System.IO (stderr)
|
||||||
import System.Process (spawnCommand, waitForProcess)
|
import System.Process (spawnCommand, waitForProcess)
|
||||||
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck)
|
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,11 @@
|
||||||
|
|
||||||
module Monomorphizer.DataTypeRemover (removeDataTypes) where
|
module Monomorphizer.DataTypeRemover (removeDataTypes) where
|
||||||
|
|
||||||
import Monomorphizer.MonomorphizerIr qualified as M2
|
import Data.Bifunctor (Bifunctor (bimap))
|
||||||
import Monomorphizer.MorbIr qualified as M1
|
import Monomorphizer.MonomorphizerIr (Ident (..))
|
||||||
import TypeChecker.TypeCheckerIr (Ident (Ident))
|
import qualified Monomorphizer.MonomorphizerIr as M2
|
||||||
|
import qualified Monomorphizer.MorbIr as M1
|
||||||
|
import Prelude hiding (exp)
|
||||||
|
|
||||||
removeDataTypes :: M1.Program -> M2.Program
|
removeDataTypes :: M1.Program -> M2.Program
|
||||||
removeDataTypes (M1.Program defs) = M2.Program (map pDef defs)
|
removeDataTypes (M1.Program defs) = M2.Program (map pDef defs)
|
||||||
|
|
@ -30,16 +33,19 @@ newName (M1.TData (Ident str) args) = str ++ concatMap newName args
|
||||||
|
|
||||||
pBind :: M1.Bind -> M2.Bind
|
pBind :: M1.Bind -> M2.Bind
|
||||||
pBind (M1.Bind id argIds expt) = M2.Bind (pId id) (map pId argIds) (pExpT expt)
|
pBind (M1.Bind id argIds expt) = M2.Bind (pId id) (map pId argIds) (pExpT expt)
|
||||||
|
pBind (M1.BindC cxt id argIds expt) =
|
||||||
|
M2.BindC (map pId cxt) (pId id) (map pId argIds) (pExpT expt)
|
||||||
|
|
||||||
pId :: (Ident, M1.Type) -> (Ident, M2.Type)
|
pId :: (Ident, M1.Type) -> (Ident, M2.Type)
|
||||||
pId (ident, t) = (ident, pType t)
|
pId (ident, t) = (ident, pType t)
|
||||||
|
|
||||||
pExpT :: M1.ExpT -> M2.ExpT
|
pExpT :: M1.T M1.Exp -> M2.T M2.Exp
|
||||||
pExpT (exp, t) = (pExp exp, pType t)
|
pExpT (exp, t) = (pExp exp, pType t)
|
||||||
|
|
||||||
pExp :: M1.Exp -> M2.Exp
|
pExp :: M1.Exp -> M2.Exp
|
||||||
pExp (M1.EVar ident) = M2.EVar ident
|
pExp (M1.EVar ident) = M2.EVar ident
|
||||||
pExp (M1.ELit lit) = M2.ELit (pLit lit)
|
pExp (M1.EVarC as ident) = M2.EVarC (map pId as) ident
|
||||||
|
pExp (M1.ELit lit) = M2.ELit lit
|
||||||
pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt)
|
pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt)
|
||||||
pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2)
|
pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2)
|
||||||
pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2)
|
pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2)
|
||||||
|
|
@ -49,12 +55,9 @@ pBranch :: M1.Branch -> M2.Branch
|
||||||
pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt)
|
pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt)
|
||||||
|
|
||||||
pPattern :: M1.Pattern -> M2.Pattern
|
pPattern :: M1.Pattern -> M2.Pattern
|
||||||
pPattern (M1.PVar id) = M2.PVar (pId id)
|
pPattern (M1.PVar ident) = M2.PVar ident
|
||||||
pPattern (M1.PLit (lit, t)) = M2.PLit (pLit lit, pType t)
|
pPattern (M1.PLit lit) = M2.PLit lit
|
||||||
pPattern (M1.PInj ident patts) = M2.PInj ident (map pPattern patts)
|
pPattern (M1.PInj ident patts) = M2.PInj ident (map (bimap pPattern pType) patts)
|
||||||
pPattern M1.PCatch = M2.PCatch
|
pPattern M1.PCatch = M2.PCatch
|
||||||
pPattern (M1.PEnum ident) = M2.PEnum ident
|
pPattern (M1.PEnum ident) = M2.PEnum ident
|
||||||
|
|
||||||
pLit :: M1.Lit -> M2.Lit
|
|
||||||
pLit (M1.LInt v) = M2.LInt v
|
|
||||||
pLit (M1.LChar c) = M2.LChar c
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
|
{-# LANGUAGE OverloadedRecordDot #-}
|
||||||
|
|
||||||
{- | For now, converts polymorphic functions to concrete ones based on usage.
|
{- | For now, converts polymorphic functions to concrete ones based on usage.
|
||||||
Assumes lambdas are lifted.
|
Assumes lambdas are lifted.
|
||||||
|
|
@ -25,30 +26,35 @@ bind) is added to the resulting set of binds.
|
||||||
-}
|
-}
|
||||||
module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where
|
module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where
|
||||||
|
|
||||||
import Monomorphizer.DataTypeRemover (removeDataTypes)
|
|
||||||
import Monomorphizer.MonomorphizerIr qualified as O
|
|
||||||
import Monomorphizer.MorbIr qualified as M
|
|
||||||
import TypeChecker.TypeCheckerIr (Ident (Ident))
|
|
||||||
import TypeChecker.TypeCheckerIr qualified as T
|
|
||||||
|
|
||||||
import Control.Monad.Reader (
|
import Control.Monad.Reader (MonadReader (ask, local),
|
||||||
MonadReader (ask, local),
|
Reader, asks, runReader)
|
||||||
Reader,
|
import Control.Monad.State (MonadState (get),
|
||||||
asks,
|
StateT (runStateT), gets,
|
||||||
runReader,
|
modify)
|
||||||
)
|
|
||||||
import Control.Monad.State (
|
|
||||||
MonadState (get),
|
|
||||||
StateT (runStateT),
|
|
||||||
gets,
|
|
||||||
modify,
|
|
||||||
)
|
|
||||||
import Data.Coerce (coerce)
|
import Data.Coerce (coerce)
|
||||||
import Data.Map qualified as Map
|
import qualified Data.Map as Map
|
||||||
import Data.Maybe (catMaybes)
|
import Data.Maybe (catMaybes)
|
||||||
import Data.Set qualified as Set
|
import qualified Data.Set as Set
|
||||||
import Grammar.Print (printTree)
|
|
||||||
import Debug.Trace (trace)
|
import Debug.Trace (trace)
|
||||||
|
import Grammar.Print (printTree)
|
||||||
|
import Monomorphizer.DataTypeRemover (removeDataTypes)
|
||||||
|
import qualified Monomorphizer.MonomorphizerIr as O
|
||||||
|
import qualified Monomorphizer.MorbIr as M
|
||||||
|
-- import TypeChecker.TypeCheckerIr (Ident (Ident))
|
||||||
|
import LambdaLifterIr (Ident (..))
|
||||||
|
-- import TypeChecker.TypeCheckerIr qualified as T
|
||||||
|
import qualified LambdaLifterIr as L
|
||||||
|
|
||||||
|
import Control.Monad.Reader (MonadReader (ask, local),
|
||||||
|
Reader, asks, runReader)
|
||||||
|
import Control.Monad.State (MonadState, StateT (runStateT),
|
||||||
|
gets, modify)
|
||||||
|
import qualified Data.Map as Map
|
||||||
|
import Data.Maybe (catMaybes, fromJust)
|
||||||
|
import qualified Data.Set as Set
|
||||||
|
import Data.Tuple.Extra (secondM)
|
||||||
|
import Grammar.Print (printTree)
|
||||||
|
|
||||||
{- | EnvM is the monad containing the read-only state as well as the
|
{- | EnvM is the monad containing the read-only state as well as the
|
||||||
output state containing monomorphized functions and to-be monomorphized
|
output state containing monomorphized functions and to-be monomorphized
|
||||||
|
|
@ -64,13 +70,13 @@ Binds, Polymorphic Data types (monomorphized in a later step) and
|
||||||
Marked bind, which means that it is in the process of monomorphization
|
Marked bind, which means that it is in the process of monomorphization
|
||||||
and should not be monomorphized again.
|
and should not be monomorphized again.
|
||||||
-}
|
-}
|
||||||
data Outputted = Marked | Complete M.Bind | Data M.Type T.Data deriving (Show)
|
data Outputted = Marked | Complete M.Bind | Data M.Type L.Data deriving (Show)
|
||||||
|
|
||||||
-- | Static environment.
|
-- | Static environment.
|
||||||
data Env = Env
|
data Env = Env
|
||||||
{ input :: Map.Map Ident T.Bind
|
{ input :: Map.Map Ident L.Bind
|
||||||
-- ^ All binds in the program.
|
-- ^ All binds in the program.
|
||||||
, dataDefs :: Map.Map Ident T.Data
|
, dataDefs :: Map.Map Ident L.Data
|
||||||
-- ^ All constructors mapped to their respective polymorphic data def
|
-- ^ All constructors mapped to their respective polymorphic data def
|
||||||
-- which includes all other constructors.
|
-- which includes all other constructors.
|
||||||
, polys :: Map.Map Ident M.Type
|
, polys :: Map.Map Ident M.Type
|
||||||
|
|
@ -84,12 +90,13 @@ localExists :: Ident -> EnvM Bool
|
||||||
localExists ident = asks (Set.member ident . locals)
|
localExists ident = asks (Set.member ident . locals)
|
||||||
|
|
||||||
-- | Gets a polymorphic bind from an id.
|
-- | Gets a polymorphic bind from an id.
|
||||||
getInputBind :: Ident -> EnvM (Maybe T.Bind)
|
getInputBind :: Ident -> EnvM (Maybe L.Bind)
|
||||||
getInputBind ident = asks (Map.lookup ident . input)
|
getInputBind ident = asks (Map.lookup ident . input)
|
||||||
|
|
||||||
-- | Add monomorphic function derived from a polymorphic one, to env.
|
-- | Add monomorphic function derived from a polymorphic one, to env.
|
||||||
addOutputBind :: M.Bind -> EnvM ()
|
addOutputBind :: M.Bind -> EnvM ()
|
||||||
addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b))
|
addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b))
|
||||||
|
addOutputBind b@(M.BindC _ (ident, _) _ _) = modify (Map.insert ident (Complete b))
|
||||||
|
|
||||||
{- | Marks a global bind as being processed, meaning that when encountered again,
|
{- | Marks a global bind as being processed, meaning that when encountered again,
|
||||||
it should not be recursively processed.
|
it should not be recursively processed.
|
||||||
|
|
@ -106,8 +113,8 @@ isConsMarked :: Ident -> EnvM Bool
|
||||||
isConsMarked ident = gets (Map.member ident)
|
isConsMarked ident = gets (Map.member ident)
|
||||||
|
|
||||||
-- | Finds main bind.
|
-- | Finds main bind.
|
||||||
getMain :: EnvM T.Bind
|
getMain :: EnvM L.Bind
|
||||||
getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of
|
getMain = asks (\env -> case Map.lookup (Ident "main") (input env) of
|
||||||
Just mainBind -> mainBind
|
Just mainBind -> mainBind
|
||||||
Nothing -> error "main not found in monomorphizer!"
|
Nothing -> error "main not found in monomorphizer!"
|
||||||
)
|
)
|
||||||
|
|
@ -116,13 +123,13 @@ getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of
|
||||||
error when encountering different structures between the two arguments. Debug:
|
error when encountering different structures between the two arguments. Debug:
|
||||||
First argument is the name of the bind.
|
First argument is the name of the bind.
|
||||||
-}
|
-}
|
||||||
mapTypes :: Ident -> T.Type -> M.Type -> [(Ident, M.Type)]
|
mapTypes :: Ident -> L.Type -> M.Type -> [(Ident, M.Type)]
|
||||||
mapTypes _ident (T.TLit _) (M.TLit _) = []
|
mapTypes _ident (L.TLit _) (M.TLit _) = []
|
||||||
mapTypes _ident (T.TVar (T.MkTVar i1)) tm = [(i1, tm)]
|
mapTypes _ident (L.TVar (L.MkTVar i1)) tm = [(i1, tm)]
|
||||||
mapTypes ident (T.TFun pt1 pt2) (M.TFun mt1 mt2) =
|
mapTypes ident (L.TFun pt1 pt2) (M.TFun mt1 mt2) =
|
||||||
mapTypes ident pt1 mt1
|
mapTypes ident pt1 mt1
|
||||||
++ mapTypes ident pt2 mt2
|
++ mapTypes ident pt2 mt2
|
||||||
mapTypes ident (T.TData tIdent pTs) (M.TData mIdent mTs) =
|
mapTypes ident (L.TData tIdent pTs) (M.TData mIdent mTs) =
|
||||||
if tIdent /= mIdent
|
if tIdent /= mIdent
|
||||||
then error "the data type names of monomorphic and polymorphic data types does not match"
|
then error "the data type names of monomorphic and polymorphic data types does not match"
|
||||||
else foldl (\xs (p, m) -> mapTypes ident p m ++ xs) [] (zip pTs mTs)
|
else foldl (\xs (p, m) -> mapTypes ident p m ++ xs) [] (zip pTs mTs)
|
||||||
|
|
@ -130,30 +137,30 @@ mapTypes ident t1 t2 = error $ "in bind: '" ++ printTree ident ++ "', " ++
|
||||||
"structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'"
|
"structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'"
|
||||||
|
|
||||||
-- | Gets the mapped monomorphic type of a polymorphic type in the current context.
|
-- | Gets the mapped monomorphic type of a polymorphic type in the current context.
|
||||||
getMonoFromPoly :: T.Type -> EnvM M.Type
|
getMonoFromPoly :: L.Type -> EnvM M.Type
|
||||||
getMonoFromPoly t = do
|
getMonoFromPoly t = do
|
||||||
env <- ask
|
env <- ask
|
||||||
return $ getMono (polys env) t
|
return $ getMono (polys env) t
|
||||||
where
|
where
|
||||||
getMono :: Map.Map Ident M.Type -> T.Type -> M.Type
|
getMono :: Map.Map Ident M.Type -> L.Type -> M.Type
|
||||||
getMono polys t = case t of
|
getMono polys t = case t of
|
||||||
(T.TLit ident) -> M.TLit (coerce ident)
|
(L.TLit ident) -> M.TLit ident
|
||||||
(T.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2)
|
(L.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2)
|
||||||
(T.TVar (T.MkTVar ident)) -> case Map.lookup ident polys of
|
(L.TVar (L.MkTVar ident)) -> case Map.lookup ident polys of
|
||||||
Just concrete -> concrete
|
Just concrete -> concrete
|
||||||
Nothing -> M.TLit (Ident "void")
|
Nothing -> M.TLit (Ident "void")
|
||||||
-- error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps"
|
-- error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps"
|
||||||
(T.TData ident args) -> M.TData ident (map (getMono polys) args)
|
(L.TData ident args) -> M.TData ident (map (getMono polys) args)
|
||||||
|
|
||||||
{- | If ident not already in env's output, morphed bind to output
|
{- | If ident not already in env's output, morphed bind to output
|
||||||
(and all referenced binds within this bind).
|
(and all referenced binds within this bind).
|
||||||
Returns the annotated bind name.
|
Returns the annotated bind name.
|
||||||
-}
|
-}
|
||||||
morphBind :: M.Type -> T.Bind -> EnvM Ident
|
morphBind :: M.Type -> L.Bind -> EnvM Ident
|
||||||
morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do
|
morphBind expectedType b@(L.Bind (ident, btype) args (exp, expt)) = do
|
||||||
-- The "new name" is used to find out if it is already marked or not.
|
-- The "new name" is used to find out if it is already marked or not.
|
||||||
let name' = newFuncName expectedType b
|
let name' = newFuncName expectedType b
|
||||||
bindMarked <- isBindMarked (coerce name')
|
bindMarked <- isBindMarked name'
|
||||||
local
|
local
|
||||||
( \env ->
|
( \env ->
|
||||||
env
|
env
|
||||||
|
|
@ -168,26 +175,59 @@ morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do
|
||||||
else do
|
else do
|
||||||
-- Mark so that this bind will not be processed in recursive or cyclic
|
-- Mark so that this bind will not be processed in recursive or cyclic
|
||||||
-- function calls
|
-- function calls
|
||||||
markBind (coerce name')
|
markBind name'
|
||||||
expt' <- getMonoFromPoly expt
|
expt' <- getMonoFromPoly expt
|
||||||
exp' <- morphExp expt' exp
|
exp' <- morphExp expt' exp
|
||||||
-- Get monomorphic type sof args
|
-- Get monomorphic type sof args
|
||||||
args' <- mapM morphArg args
|
args' <- mapM morphArg args
|
||||||
addOutputBind $
|
addOutputBind $
|
||||||
M.Bind
|
M.Bind
|
||||||
(coerce name', expectedType)
|
(name', expectedType)
|
||||||
args'
|
args'
|
||||||
(exp', expt')
|
(exp', expt')
|
||||||
return name'
|
return name'
|
||||||
|
|
||||||
|
morphBind expectedType b@(L.BindC cxt (ident, btype) args (exp, expt)) = do
|
||||||
|
-- The "new name" is used to find out if it is already marked or not.
|
||||||
|
let name' = newFuncName expectedType b
|
||||||
|
bindMarked <- isBindMarked name'
|
||||||
|
local
|
||||||
|
( \env ->
|
||||||
|
env
|
||||||
|
{ locals = Set.fromList (map fst args)
|
||||||
|
, polys = Map.fromList (mapTypes ident btype expectedType)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
$ do
|
||||||
|
-- Return with right name if already marked
|
||||||
|
if bindMarked
|
||||||
|
then return name'
|
||||||
|
else do
|
||||||
|
-- Mark so that this bind will not be processed in recursive or cyclic
|
||||||
|
-- function calls
|
||||||
|
markBind name'
|
||||||
|
-- Get monomorphic type sof args
|
||||||
|
args' <- mapM morphArg args
|
||||||
|
cxt' <- mapM (secondM getMonoFromPoly) cxt
|
||||||
|
expt' <- getMonoFromPoly expt
|
||||||
|
exp' <- local (\env -> foldr (addLocal . fst) env cxt)
|
||||||
|
(morphExp expt' exp)
|
||||||
|
addOutputBind $
|
||||||
|
M.BindC cxt'
|
||||||
|
(name', expectedType)
|
||||||
|
args'
|
||||||
|
(exp', expt')
|
||||||
|
return name'
|
||||||
|
|
||||||
|
|
||||||
-- | Monomorphizes arguments of a bind.
|
-- | Monomorphizes arguments of a bind.
|
||||||
morphArg :: (Ident, T.Type) -> EnvM (Ident, M.Type)
|
morphArg :: (Ident, L.Type) -> EnvM (Ident, M.Type)
|
||||||
morphArg (ident, t) = do
|
morphArg (ident, t) = do
|
||||||
t' <- getMonoFromPoly t
|
t' <- getMonoFromPoly t
|
||||||
return (ident, t')
|
return (ident, t')
|
||||||
|
|
||||||
-- | Gets the data bind from the name of a constructor.
|
-- | Gets the data bind from the name of a constructor.
|
||||||
getInputData :: Ident -> EnvM (Maybe T.Data)
|
getInputData :: Ident -> EnvM (Maybe L.Data)
|
||||||
getInputData ident = do
|
getInputData ident = do
|
||||||
env <- ask
|
env <- ask
|
||||||
return $ Map.lookup ident (dataDefs env)
|
return $ Map.lookup ident (dataDefs env)
|
||||||
|
|
@ -201,50 +241,50 @@ morphCons expectedType ident newIdent = do
|
||||||
--trace ("Tjofras:" ++ show (newName expectedType ident)) $ return ()
|
--trace ("Tjofras:" ++ show (newName expectedType ident)) $ return ()
|
||||||
maybeD <- getInputData ident
|
maybeD <- getInputData ident
|
||||||
case maybeD of
|
case maybeD of
|
||||||
Nothing -> error $ "identifier '" ++ show ident ++ "' not found"
|
-- closures can have unbound variables
|
||||||
|
Nothing -> pure ()
|
||||||
Just d -> do
|
Just d -> do
|
||||||
modify (\output -> Map.insert newIdent (Data expectedType d) output)
|
modify (\output -> Map.insert newIdent (Data expectedType d) output)
|
||||||
|
|
||||||
-- | Converts literals from input to output tree.
|
-- | Converts literals from input to output tree.
|
||||||
convertLit :: T.Lit -> M.Lit
|
convertLit :: L.Lit -> M.Lit
|
||||||
convertLit (T.LInt v) = M.LInt v
|
convertLit (L.LInt v) = M.LInt v
|
||||||
convertLit (T.LChar v) = M.LChar v
|
convertLit (L.LChar v) = M.LChar v
|
||||||
|
|
||||||
|
|
||||||
-- | Monomorphizes an expression, given an expected type.
|
-- | Monomorphizes an expression, given an expected type.
|
||||||
morphExp :: M.Type -> T.Exp -> EnvM M.Exp
|
morphExp :: M.Type -> L.Exp -> EnvM M.Exp
|
||||||
morphExp expectedType exp = case exp of
|
morphExp expectedType exp = case exp of
|
||||||
T.ELit lit -> return $ M.ELit (convertLit lit)
|
L.ELit lit -> return $ M.ELit lit
|
||||||
-- Constructor
|
-- Constructor
|
||||||
T.EInj ident -> do
|
L.EInj ident -> do
|
||||||
let ident' = newName (getDataType expectedType) ident
|
let ident' = newName (getDataType expectedType) ident
|
||||||
morphCons expectedType ident ident'
|
morphCons expectedType ident ident'
|
||||||
return $ M.EVar ident'
|
return $ M.EVar ident'
|
||||||
T.EApp (e1, _t1) (e2, t2) -> do
|
L.EApp (e1, _t1) (e2, t2) -> do
|
||||||
t2' <- getMonoFromPoly t2
|
t2' <- getMonoFromPoly t2
|
||||||
e2' <- morphExp t2' e2
|
e2' <- morphExp t2' e2
|
||||||
e1' <- morphExp (M.TFun t2' expectedType) e1
|
e1' <- morphExp (M.TFun t2' expectedType) e1
|
||||||
return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2')
|
return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2')
|
||||||
T.EAdd (e1, t1) (e2, t2) -> do
|
L.EAdd (e1, t1) (e2, t2) -> do
|
||||||
t1' <- getMonoFromPoly t1
|
t1' <- getMonoFromPoly t1
|
||||||
t2' <- getMonoFromPoly t2
|
t2' <- getMonoFromPoly t2
|
||||||
e1' <- morphExp t1' e1
|
e1' <- morphExp t1' e1
|
||||||
e2' <- morphExp t2' e2
|
e2' <- morphExp t2' e2
|
||||||
return $ M.EAdd (e1', expectedType) (e2', expectedType)
|
return $ M.EAdd (e1', expectedType) (e2', expectedType)
|
||||||
T.EAbs ident (exp, t) -> local (\env -> env{locals = Set.insert ident (locals env)}) $ do
|
L.ECase (exp, t) bs -> do
|
||||||
t' <- getMonoFromPoly t
|
|
||||||
morphExp t' exp
|
|
||||||
T.ECase (exp, t) bs -> do
|
|
||||||
t' <- getMonoFromPoly t
|
t' <- getMonoFromPoly t
|
||||||
exp' <- morphExp t' exp
|
exp' <- morphExp t' exp
|
||||||
bs' <- mapM morphBranch bs
|
bs' <- mapM morphBranch bs
|
||||||
return $ M.ECase (exp', t') (catMaybes bs')
|
return $ M.ECase (exp', t') (catMaybes bs')
|
||||||
-- Ideally constructors should be EInj, though this code handles them
|
-- Ideally constructors should be EInj, though this code handles them
|
||||||
-- as well.
|
-- as well.
|
||||||
T.EVar ident -> do
|
-- FIXME MAKE EVAR AND EINJ SEPARATE!!!
|
||||||
|
L.EVar ident -> do
|
||||||
isLocal <- localExists ident
|
isLocal <- localExists ident
|
||||||
if isLocal
|
if isLocal
|
||||||
then do
|
then do
|
||||||
return $ M.EVar (coerce ident)
|
return $ M.EVar ident
|
||||||
else do
|
else do
|
||||||
bind <- getInputBind ident
|
bind <- getInputBind ident
|
||||||
case bind of
|
case bind of
|
||||||
|
|
@ -252,20 +292,33 @@ morphExp expectedType exp = case exp of
|
||||||
Just bind' -> do
|
Just bind' -> do
|
||||||
-- New bind to process
|
-- New bind to process
|
||||||
newBindName <- morphBind expectedType bind'
|
newBindName <- morphBind expectedType bind'
|
||||||
return $ M.EVar (coerce newBindName)
|
return $ M.EVar newBindName
|
||||||
T.ELet (T.Bind (identB, tB) args (expB, tExpB)) (exp, tExp) ->
|
L.EVarC as ident -> do
|
||||||
if length args > 0 then error "only constants in lets allowed"
|
isLocal <- localExists ident
|
||||||
|
if isLocal
|
||||||
|
then do
|
||||||
|
return $ M.EVar ident
|
||||||
else do
|
else do
|
||||||
|
bind <- fromJust <$> getInputBind ident
|
||||||
|
as' <- mapM (secondM getMonoFromPoly) as
|
||||||
|
-- New bind to process
|
||||||
|
newBindName <- morphBind expectedType bind
|
||||||
|
return $ M.EVarC as' newBindName
|
||||||
|
-- Ideally constructors should be EInj, though this code handles them
|
||||||
|
-- as well.
|
||||||
|
|
||||||
|
|
||||||
|
L.ELet (identB, tB) (expB, tExpB) (exp, tExp) -> do
|
||||||
tB' <- getMonoFromPoly tB
|
tB' <- getMonoFromPoly tB
|
||||||
tExpB' <- getMonoFromPoly tExpB
|
tExpB' <- getMonoFromPoly tExpB
|
||||||
tExp' <- getMonoFromPoly tExp
|
tExp' <- getMonoFromPoly tExp
|
||||||
expB' <- morphExp tExpB' expB
|
expB' <- morphExp tExpB' expB
|
||||||
exp' <- morphExp tExp' exp
|
exp' <- local (addLocal identB) (morphExp tExp' exp)
|
||||||
return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (exp', tExp')
|
return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (exp', tExp')
|
||||||
|
|
||||||
-- | Monomorphizes case-of branches.
|
-- | Monomorphizes case-of branches.
|
||||||
morphBranch :: T.Branch -> EnvM (Maybe M.Branch)
|
morphBranch :: L.Branch -> EnvM (Maybe M.Branch)
|
||||||
morphBranch (T.Branch (p, pt) (e, et)) = do
|
morphBranch (L.Branch (p, pt) (e, et)) = do
|
||||||
pt' <- getMonoFromPoly pt
|
pt' <- getMonoFromPoly pt
|
||||||
et' <- getMonoFromPoly et
|
et' <- getMonoFromPoly et
|
||||||
env <- ask
|
env <- ask
|
||||||
|
|
@ -275,15 +328,15 @@ morphBranch (T.Branch (p, pt) (e, et)) = do
|
||||||
Just (p', newLocals) ->
|
Just (p', newLocals) ->
|
||||||
local (const env { locals = Set.union (locals env) newLocals }) $ do
|
local (const env { locals = Set.union (locals env) newLocals }) $ do
|
||||||
e' <- morphExp et' e
|
e' <- morphExp et' e
|
||||||
return $ Just (M.Branch (p', pt') (e', et'))
|
return $ Just (M.Branch p' (e', et'))
|
||||||
|
|
||||||
morphPattern :: T.Pattern -> M.Type -> EnvM (Maybe (M.Pattern, Set.Set Ident))
|
morphPattern :: L.Pattern -> M.Type -> EnvM (Maybe (M.T M.Pattern, Set.Set Ident))
|
||||||
morphPattern p expectedType = case p of
|
morphPattern p expectedType = case p of
|
||||||
T.PVar ident -> return $ Just (M.PVar (ident, expectedType), Set.singleton ident)
|
L.PVar ident -> return $ Just ((M.PVar ident, expectedType), Set.singleton ident)
|
||||||
T.PLit lit -> return $ Just (M.PLit (convertLit lit, expectedType), Set.empty)
|
L.PLit lit -> return $ Just ((M.PLit (convertLit lit), expectedType), Set.empty)
|
||||||
T.PCatch -> return $ Just (M.PCatch, Set.empty)
|
L.PCatch -> return $ Just ((M.PCatch, expectedType), Set.empty)
|
||||||
T.PEnum ident -> return $ Just (M.PEnum (newName expectedType ident), Set.empty)
|
L.PEnum ident -> return $ Just ((M.PEnum (newName expectedType ident), expectedType), Set.empty)
|
||||||
T.PInj ident pts -> do let newIdent = newName expectedType ident
|
L.PInj ident pts -> do let newIdent = newName expectedType ident
|
||||||
outEnv <- get
|
outEnv <- get
|
||||||
trace ("WOW: " ++ show (newName expectedType ident)) $ return ()
|
trace ("WOW: " ++ show (newName expectedType ident)) $ return ()
|
||||||
trace ("WOW2: " ++ show (outEnv)) $ return ()
|
trace ("WOW2: " ++ show (outEnv)) $ return ()
|
||||||
|
|
@ -298,12 +351,17 @@ morphPattern p expectedType = case p of
|
||||||
case maybePsSets of
|
case maybePsSets of
|
||||||
Nothing -> return Nothing
|
Nothing -> return Nothing
|
||||||
Just psSets' -> return $ Just
|
Just psSets' -> return $ Just
|
||||||
(M.PInj newIdent (map fst psSets'), Set.unions $ map snd psSets')
|
((M.PInj newIdent (map fst psSets'), expectedType), Set.unions $ map snd psSets')
|
||||||
else return Nothing
|
else return Nothing
|
||||||
|
|
||||||
-- | Creates a new identifier for a function with an assigned type.
|
-- | Creates a new identifier for a function with an assigned type.
|
||||||
newFuncName :: M.Type -> T.Bind -> Ident
|
newFuncName :: M.Type -> L.Bind -> Ident
|
||||||
newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) =
|
newFuncName t (L.Bind (ident@(Ident bindName), _) _ _) =
|
||||||
|
if bindName == "main"
|
||||||
|
then Ident bindName
|
||||||
|
else newName t ident
|
||||||
|
|
||||||
|
newFuncName t (L.BindC _ (ident@(Ident bindName), _) _ _) =
|
||||||
if bindName == "main"
|
if bindName == "main"
|
||||||
then Ident bindName
|
then Ident bindName
|
||||||
else newName t ident
|
else newName t ident
|
||||||
|
|
@ -317,8 +375,8 @@ newName t (Ident str) = Ident $ str ++ "$" ++ newName' t
|
||||||
newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts
|
newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts
|
||||||
|
|
||||||
-- | Monomorphization step.
|
-- | Monomorphization step.
|
||||||
monomorphize :: T.Program -> O.Program
|
monomorphize :: L.Program -> O.Program
|
||||||
monomorphize (T.Program defs) =
|
monomorphize (L.Program defs) =
|
||||||
removeDataTypes $
|
removeDataTypes $
|
||||||
M.Program
|
M.Program
|
||||||
( getDefsFromOutput
|
( getDefsFromOutput
|
||||||
|
|
@ -336,7 +394,7 @@ runEnvM :: Output -> Env -> EnvM () -> Output
|
||||||
runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env
|
runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env
|
||||||
|
|
||||||
-- | Creates the environment based on the input binds.
|
-- | Creates the environment based on the input binds.
|
||||||
createEnv :: [T.Def] -> Env
|
createEnv :: [L.Def] -> Env
|
||||||
createEnv defs =
|
createEnv defs =
|
||||||
Env
|
Env
|
||||||
{ input = Map.fromList bindPairs
|
{ input = Map.fromList bindPairs
|
||||||
|
|
@ -346,33 +404,34 @@ createEnv defs =
|
||||||
}
|
}
|
||||||
where
|
where
|
||||||
bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs
|
bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs
|
||||||
dataPairs :: [(Ident, T.Data)]
|
dataPairs :: [(Ident, L.Data)]
|
||||||
dataPairs = (foldl (\acc d@(T.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs
|
dataPairs = (foldl (\acc d@(L.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs
|
||||||
|
|
||||||
-- | Gets a top-lefel function name.
|
-- | Gets a top-lefel function name.
|
||||||
getBindName :: T.Bind -> Ident
|
getBindName :: L.Bind -> Ident
|
||||||
getBindName (T.Bind (ident, _) _ _) = ident
|
getBindName (L.Bind (ident, _) _ _) = ident
|
||||||
|
getBindName (L.BindC _ (ident, _) _ _) = ident
|
||||||
|
|
||||||
-- Helper functions
|
-- Helper functions
|
||||||
-- Gets custom data declarations form defs.
|
-- Gets custom data declarations form defs.
|
||||||
getDataFromDefs :: [T.Def] -> [T.Data]
|
getDataFromDefs :: [L.Def] -> [L.Data]
|
||||||
getDataFromDefs =
|
getDataFromDefs =
|
||||||
foldl
|
foldl
|
||||||
( \bs -> \case
|
( \bs -> \case
|
||||||
T.DBind _ -> bs
|
L.DBind _ -> bs
|
||||||
T.DData d -> d : bs
|
L.DData d -> d : bs
|
||||||
)
|
)
|
||||||
[]
|
[]
|
||||||
|
|
||||||
getConsName :: T.Inj -> Ident
|
getConsName :: L.Inj -> Ident
|
||||||
getConsName (T.Inj ident _) = ident
|
getConsName (L.Inj ident _) = ident
|
||||||
|
|
||||||
getBindsFromDefs :: [T.Def] -> [T.Bind]
|
getBindsFromDefs :: [L.Def] -> [L.Bind]
|
||||||
getBindsFromDefs =
|
getBindsFromDefs =
|
||||||
foldl
|
foldl
|
||||||
( \bs -> \case
|
( \bs -> \case
|
||||||
T.DBind b -> b : bs
|
L.DBind b -> b : bs
|
||||||
T.DData _ -> bs
|
L.DData _ -> bs
|
||||||
)
|
)
|
||||||
[]
|
[]
|
||||||
|
|
||||||
|
|
@ -384,7 +443,7 @@ getDefsFromOutput o =
|
||||||
(binds, dataInput) = splitBindsAndData o
|
(binds, dataInput) = splitBindsAndData o
|
||||||
|
|
||||||
-- | Splits the output into binds and data declaration components (used in createNewData)
|
-- | Splits the output into binds and data declaration components (used in createNewData)
|
||||||
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, T.Data)])
|
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, L.Data)])
|
||||||
splitBindsAndData output =
|
splitBindsAndData output =
|
||||||
foldl
|
foldl
|
||||||
( \(oBinds, oData) (ident, o) -> case o of
|
( \(oBinds, oData) (ident, o) -> case o of
|
||||||
|
|
@ -396,7 +455,7 @@ splitBindsAndData output =
|
||||||
(Map.toList output)
|
(Map.toList output)
|
||||||
|
|
||||||
-- | Converts all found constructors to monomorphic data declarations.
|
-- | Converts all found constructors to monomorphic data declarations.
|
||||||
createNewData :: [(Ident, M.Type, T.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
|
createNewData :: [(Ident, M.Type, L.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
|
||||||
createNewData [] o = o
|
createNewData [] o = o
|
||||||
createNewData ((consIdent, consType, polyData) : input) o =
|
createNewData ((consIdent, consType, polyData) : input) o =
|
||||||
createNewData input $
|
createNewData input $
|
||||||
|
|
@ -406,7 +465,7 @@ createNewData ((consIdent, consType, polyData) : input) o =
|
||||||
(M.Data newDataType [newCons])
|
(M.Data newDataType [newCons])
|
||||||
o
|
o
|
||||||
where
|
where
|
||||||
T.Data (T.TData polyDataIdent _) _ = polyData
|
L.Data (L.TData polyDataIdent _) _ = polyData
|
||||||
newDataType = getDataType consType
|
newDataType = getDataType consType
|
||||||
newDataName = newName newDataType polyDataIdent
|
newDataName = newName newDataType polyDataIdent
|
||||||
newCons = M.Inj consIdent consType
|
newCons = M.Inj consIdent consType
|
||||||
|
|
@ -417,3 +476,6 @@ getDataType (M.TFun _t1 t2) = getDataType t2
|
||||||
getDataType tData@(M.TData _ _) = tData
|
getDataType tData@(M.TData _ _) = tData
|
||||||
getDataType _ = error "???"
|
getDataType _ = error "???"
|
||||||
|
|
||||||
|
|
||||||
|
addLocal :: Ident -> Env -> Env
|
||||||
|
addLocal x env = env { locals = Set.insert x env.locals }
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,14 @@
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
|
|
||||||
module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr) where
|
module Monomorphizer.MonomorphizerIr (
|
||||||
|
module Monomorphizer.MonomorphizerIr,
|
||||||
|
module LambdaLifterIr
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Data.List (intercalate)
|
||||||
import Grammar.Print
|
import Grammar.Print
|
||||||
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..))
|
import LambdaLifterIr (Ident (..), Lit (..))
|
||||||
|
import Prelude hiding (exp)
|
||||||
type Id = (TIR.Ident, Type)
|
|
||||||
|
|
||||||
newtype Program = Program [Def]
|
newtype Program = Program [Def]
|
||||||
deriving (Show, Ord, Eq)
|
deriving (Show, Ord, Eq)
|
||||||
|
|
@ -16,40 +19,37 @@ data Def = DBind Bind | DData Data
|
||||||
data Data = Data Type [Inj]
|
data Data = Data Type [Inj]
|
||||||
deriving (Show, Ord, Eq)
|
deriving (Show, Ord, Eq)
|
||||||
|
|
||||||
data Bind = Bind Id [Id] ExpT
|
data Bind = Bind (T Ident) [T Ident] (T Exp)
|
||||||
|
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
|
||||||
deriving (Show, Ord, Eq)
|
deriving (Show, Ord, Eq)
|
||||||
|
|
||||||
|
type T a = (a, Type)
|
||||||
|
|
||||||
data Exp
|
data Exp
|
||||||
= EVar TIR.Ident
|
= EVar Ident
|
||||||
|
| EVarC [T Ident] Ident
|
||||||
| ELit Lit
|
| ELit Lit
|
||||||
| ELet Bind ExpT
|
| ELet Bind (T Exp)
|
||||||
| EApp ExpT ExpT
|
| EApp (T Exp) (T Exp)
|
||||||
| EAdd ExpT ExpT
|
| EAdd (T Exp) (T Exp)
|
||||||
| ECase ExpT [Branch]
|
| ECase (T Exp) [Branch]
|
||||||
deriving (Show, Ord, Eq)
|
deriving (Show, Ord, Eq)
|
||||||
|
|
||||||
data Pattern
|
data Pattern
|
||||||
= PVar Id
|
= PVar Ident
|
||||||
| PLit (Lit, Type)
|
| PLit Lit
|
||||||
| PInj TIR.Ident [Pattern]
|
| PInj Ident [T Pattern]
|
||||||
| PCatch
|
| PCatch
|
||||||
| PEnum TIR.Ident
|
| PEnum Ident
|
||||||
deriving (Eq, Ord, Show)
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
data Branch = Branch (Pattern, Type) ExpT
|
data Branch = Branch (T Pattern) (T Exp)
|
||||||
deriving (Eq, Ord, Show)
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
type ExpT = (Exp, Type)
|
data Inj = Inj Ident Type
|
||||||
|
|
||||||
data Inj = Inj TIR.Ident Type
|
|
||||||
deriving (Show, Ord, Eq)
|
deriving (Show, Ord, Eq)
|
||||||
|
|
||||||
data Lit
|
data Type = TLit Ident | TFun Type Type
|
||||||
= LInt Integer
|
|
||||||
| LChar Char
|
|
||||||
deriving (Show, Ord, Eq)
|
|
||||||
|
|
||||||
data Type = TLit TIR.Ident | TFun Type Type
|
|
||||||
deriving (Show, Ord, Eq)
|
deriving (Show, Ord, Eq)
|
||||||
|
|
||||||
flattenType :: Type -> [Type]
|
flattenType :: Type -> [Type]
|
||||||
|
|
@ -59,47 +59,40 @@ flattenType x = [x]
|
||||||
instance Print Program where
|
instance Print Program where
|
||||||
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
||||||
|
|
||||||
instance Print (Bind) where
|
instance Print Bind where
|
||||||
prt i (Bind sig@(name, _) parms rhs) =
|
prt i (Bind sig@(name, _) parms rhs) =
|
||||||
prPrec i 0 $
|
prPrec i 0 $
|
||||||
concatD
|
concatD
|
||||||
[ prtSig sig
|
[ prt 0 sig
|
||||||
, prt 0 name
|
, prt 0 name
|
||||||
, prtIdPs 0 parms
|
, prt 0 parms
|
||||||
, doc $ showString "="
|
, doc $ showString "="
|
||||||
, prt 0 rhs
|
, prt 0 rhs
|
||||||
]
|
]
|
||||||
|
|
||||||
prtSig :: Id -> Doc
|
prt i (BindC cxt sig parms rhs) =
|
||||||
prtSig (name, t) =
|
prPrec i 0 $
|
||||||
concatD
|
concatD
|
||||||
[ prt 0 name
|
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
|
||||||
, doc $ showString ":"
|
, prt i parms
|
||||||
, prt 0 t
|
, doc $ showString "="
|
||||||
, doc $ showString ";"
|
, prt i rhs
|
||||||
]
|
]
|
||||||
|
|
||||||
instance Print (ExpT) where
|
|
||||||
prt i (e, t) =
|
|
||||||
concatD
|
|
||||||
[ doc $ showString "("
|
|
||||||
, prt i e
|
|
||||||
, doc $ showString ","
|
|
||||||
, prt i t
|
|
||||||
, doc $ showString ")"
|
|
||||||
]
|
|
||||||
|
|
||||||
instance Print [Bind] where
|
instance Print [Bind] where
|
||||||
prt _ [] = concatD []
|
prt _ [] = concatD []
|
||||||
prt _ [x] = concatD [prt 0 x]
|
prt _ [x] = concatD [prt 0 x]
|
||||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||||
|
|
||||||
prtIdPs :: Int -> [Id] -> Doc
|
|
||||||
prtIdPs i = prPrec i 0 . concatD . map (prt i)
|
|
||||||
|
|
||||||
instance Print Exp where
|
instance Print Exp where
|
||||||
prt i = \case
|
prt i = \case
|
||||||
EVar name -> prPrec i 3 $ prt 0 name
|
EVar name -> prPrec i 3 $ prt 0 name
|
||||||
|
EVarC as lident -> doc . showString
|
||||||
|
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
|
||||||
|
where
|
||||||
|
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
|
||||||
ELit lit -> prPrec i 3 $ prt 0 lit
|
ELit lit -> prPrec i 3 $ prt 0 lit
|
||||||
ELet b e ->
|
ELet b e ->
|
||||||
prPrec i 3 $
|
prPrec i 3 $
|
||||||
|
|
@ -134,7 +127,7 @@ instance Print Exp where
|
||||||
]
|
]
|
||||||
|
|
||||||
instance Print Branch where
|
instance Print Branch where
|
||||||
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
|
prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp])
|
||||||
|
|
||||||
instance Print [Branch] where
|
instance Print [Branch] where
|
||||||
prt _ [] = concatD []
|
prt _ [] = concatD []
|
||||||
|
|
@ -152,12 +145,12 @@ instance Print Data where
|
||||||
|
|
||||||
instance Print Inj where
|
instance Print Inj where
|
||||||
prt i = \case
|
prt i = \case
|
||||||
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
|
Inj uident type_ -> prt i (uident, type_)
|
||||||
|
|
||||||
instance Print Pattern where
|
instance Print Pattern where
|
||||||
prt i = \case
|
prt i = \case
|
||||||
PVar name -> prPrec i 1 (concatD [prt 0 name])
|
PVar name -> prPrec i 1 (concatD [prt 0 name])
|
||||||
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit])
|
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
|
||||||
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
|
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
|
||||||
PEnum name -> prPrec i 1 (concatD [prt 0 name])
|
PEnum name -> prPrec i 1 (concatD [prt 0 name])
|
||||||
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
|
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
|
||||||
|
|
@ -175,8 +168,3 @@ instance Print Type where
|
||||||
prt i = \case
|
prt i = \case
|
||||||
TLit uident -> prPrec i 1 (concatD [prt 0 uident])
|
TLit uident -> prPrec i 1 (concatD [prt 0 uident])
|
||||||
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
|
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
|
||||||
|
|
||||||
instance Print Lit where
|
|
||||||
prt i = \case
|
|
||||||
LInt int -> prt i int
|
|
||||||
LChar char -> prt i char
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,14 @@
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
module Monomorphizer.MorbIr where
|
|
||||||
|
|
||||||
|
module Monomorphizer.MorbIr (
|
||||||
|
module Monomorphizer.MorbIr,
|
||||||
|
module LambdaLifterIr
|
||||||
|
) where
|
||||||
|
|
||||||
|
import Data.List (intercalate)
|
||||||
import Grammar.Print
|
import Grammar.Print
|
||||||
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..))
|
import LambdaLifterIr (Ident (..), Lit (..))
|
||||||
|
import Prelude hiding (exp)
|
||||||
type Id = (TIR.Ident, Type)
|
|
||||||
|
|
||||||
newtype Program = Program [Def]
|
newtype Program = Program [Def]
|
||||||
deriving (Show, Ord, Eq)
|
deriving (Show, Ord, Eq)
|
||||||
|
|
@ -15,40 +19,39 @@ data Def = DBind Bind | DData Data
|
||||||
data Data = Data Type [Inj]
|
data Data = Data Type [Inj]
|
||||||
deriving (Show, Ord, Eq)
|
deriving (Show, Ord, Eq)
|
||||||
|
|
||||||
data Bind = Bind Id [Id] ExpT
|
data Bind = Bind (T Ident) [T Ident] (T Exp)
|
||||||
|
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
|
||||||
deriving (Show, Ord, Eq)
|
deriving (Show, Ord, Eq)
|
||||||
|
|
||||||
|
|
||||||
|
type T a = (a, Type)
|
||||||
|
|
||||||
data Exp
|
data Exp
|
||||||
= EVar TIR.Ident
|
= EVar Ident
|
||||||
|
| EVarC [T Ident] Ident
|
||||||
| ELit Lit
|
| ELit Lit
|
||||||
| ELet Bind ExpT
|
| ELet Bind (T Exp)
|
||||||
| EApp ExpT ExpT
|
| EApp (T Exp) (T Exp)
|
||||||
| EAdd ExpT ExpT
|
| EAdd (T Exp) (T Exp)
|
||||||
| ECase ExpT [Branch]
|
| ECase (T Exp) [Branch]
|
||||||
deriving (Show, Ord, Eq)
|
deriving (Show, Ord, Eq)
|
||||||
|
|
||||||
data Pattern
|
data Pattern
|
||||||
= PVar Id
|
= PVar Ident
|
||||||
| PLit (Lit, Type)
|
| PLit Lit
|
||||||
| PInj TIR.Ident [Pattern]
|
| PInj Ident [T Pattern]
|
||||||
| PCatch
|
| PCatch
|
||||||
| PEnum TIR.Ident
|
| PEnum Ident
|
||||||
deriving (Eq, Ord, Show)
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
data Branch = Branch (Pattern, Type) ExpT
|
|
||||||
|
data Branch = Branch (T Pattern) (T Exp)
|
||||||
deriving (Eq, Ord, Show)
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
type ExpT = (Exp, Type)
|
data Inj = Inj Ident Type
|
||||||
|
|
||||||
data Inj = Inj TIR.Ident Type
|
|
||||||
deriving (Show, Ord, Eq)
|
deriving (Show, Ord, Eq)
|
||||||
|
|
||||||
data Lit
|
data Type = TLit Ident | TFun Type Type | TData Ident [Type]
|
||||||
= LInt Integer
|
|
||||||
| LChar Char
|
|
||||||
deriving (Show, Ord, Eq)
|
|
||||||
|
|
||||||
data Type = TLit TIR.Ident | TFun Type Type | TData TIR.Ident [Type]
|
|
||||||
|
|
||||||
deriving (Show, Ord, Eq)
|
deriving (Show, Ord, Eq)
|
||||||
|
|
||||||
|
|
@ -59,34 +62,24 @@ flattenType x = [x]
|
||||||
instance Print Program where
|
instance Print Program where
|
||||||
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
||||||
|
|
||||||
instance Print (Bind) where
|
instance Print Bind where
|
||||||
prt i (Bind sig@(name, _) parms rhs) =
|
prt i (Bind sig@(name, _) parms rhs) =
|
||||||
prPrec i 0 $
|
prPrec i 0 $
|
||||||
concatD
|
concatD
|
||||||
[ prtSig sig
|
[ prt 0 sig
|
||||||
, prt 0 name
|
, prt 0 name
|
||||||
, prtIdPs 0 parms
|
, prt 0 parms
|
||||||
, doc $ showString "="
|
, doc $ showString "="
|
||||||
, prt 0 rhs
|
, prt 0 rhs
|
||||||
]
|
]
|
||||||
|
|
||||||
prtSig :: Id -> Doc
|
prt i (BindC cxt sig parms rhs) =
|
||||||
prtSig (name, t) =
|
prPrec i 0 $
|
||||||
concatD
|
concatD
|
||||||
[ prt 0 name
|
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
|
||||||
, doc $ showString ":"
|
, prt i parms
|
||||||
, prt 0 t
|
, doc $ showString "="
|
||||||
, doc $ showString ";"
|
, prt i rhs
|
||||||
]
|
|
||||||
|
|
||||||
instance Print (ExpT) where
|
|
||||||
prt i (e, t) =
|
|
||||||
concatD
|
|
||||||
[ doc $ showString "("
|
|
||||||
, prt i e
|
|
||||||
, doc $ showString ","
|
|
||||||
, prt i t
|
|
||||||
, doc $ showString ")"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
instance Print [Bind] where
|
instance Print [Bind] where
|
||||||
|
|
@ -94,12 +87,13 @@ instance Print [Bind] where
|
||||||
prt _ [x] = concatD [prt 0 x]
|
prt _ [x] = concatD [prt 0 x]
|
||||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||||
|
|
||||||
prtIdPs :: Int -> [Id] -> Doc
|
|
||||||
prtIdPs i = prPrec i 0 . concatD . map (prt i)
|
|
||||||
|
|
||||||
instance Print Exp where
|
instance Print Exp where
|
||||||
prt i = \case
|
prt i = \case
|
||||||
EVar name -> prPrec i 3 $ prt 0 name
|
EVar name -> prPrec i 3 $ prt 0 name
|
||||||
|
EVarC as lident -> doc . showString
|
||||||
|
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
|
||||||
|
where
|
||||||
|
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
|
||||||
ELit lit -> prPrec i 3 $ prt 0 lit
|
ELit lit -> prPrec i 3 $ prt 0 lit
|
||||||
ELet b e ->
|
ELet b e ->
|
||||||
prPrec i 3 $
|
prPrec i 3 $
|
||||||
|
|
@ -134,7 +128,7 @@ instance Print Exp where
|
||||||
]
|
]
|
||||||
|
|
||||||
instance Print Branch where
|
instance Print Branch where
|
||||||
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
|
prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp])
|
||||||
|
|
||||||
instance Print [Branch] where
|
instance Print [Branch] where
|
||||||
prt _ [] = concatD []
|
prt _ [] = concatD []
|
||||||
|
|
@ -152,12 +146,12 @@ instance Print Data where
|
||||||
|
|
||||||
instance Print Inj where
|
instance Print Inj where
|
||||||
prt i = \case
|
prt i = \case
|
||||||
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
|
Inj uident type_ -> prt i (uident, type_)
|
||||||
|
|
||||||
instance Print Pattern where
|
instance Print Pattern where
|
||||||
prt i = \case
|
prt i = \case
|
||||||
PVar name -> prPrec i 1 (concatD [prt 0 name])
|
PVar name -> prPrec i 1 (concatD [prt 0 name])
|
||||||
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit])
|
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
|
||||||
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
|
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
|
||||||
PEnum name -> prPrec i 1 (concatD [prt 0 name])
|
PEnum name -> prPrec i 1 (concatD [prt 0 name])
|
||||||
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
|
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
|
||||||
|
|
@ -176,9 +170,3 @@ instance Print Type where
|
||||||
TLit uident -> prPrec i 1 (concatD [prt 0 uident])
|
TLit uident -> prPrec i 1 (concatD [prt 0 uident])
|
||||||
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
|
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
|
||||||
TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")])
|
TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")])
|
||||||
|
|
||||||
instance Print Lit where
|
|
||||||
prt i = \case
|
|
||||||
LInt int -> prt i int
|
|
||||||
LChar char -> prt i char
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import Control.Applicative (Applicative (liftA2), liftA3)
|
||||||
import Control.Monad.Except (MonadError (throwError))
|
import Control.Monad.Except (MonadError (throwError))
|
||||||
import Data.Coerce (coerce)
|
import Data.Coerce (coerce)
|
||||||
import Data.Tuple.Extra (secondM)
|
import Data.Tuple.Extra (secondM)
|
||||||
import Grammar.Abs qualified as G
|
import qualified Grammar.Abs as G
|
||||||
import Grammar.ErrM (Err)
|
import Grammar.ErrM (Err)
|
||||||
import Grammar.Print (printTree)
|
import Grammar.Print (printTree)
|
||||||
import TypeChecker.TypeCheckerIr hiding (Type (..))
|
import TypeChecker.TypeCheckerIr hiding (Type (..))
|
||||||
|
|
@ -18,7 +18,7 @@ data Type
|
||||||
| TData Ident [Type]
|
| TData Ident [Type]
|
||||||
| TFun Type Type
|
| TFun Type Type
|
||||||
| TAll TVar Type
|
| TAll TVar Type
|
||||||
deriving (Eq, Ord, Show, Read)
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
class ReportTEVar a b where
|
class ReportTEVar a b where
|
||||||
reportTEVar :: a -> Err b
|
reportTEVar :: a -> Err b
|
||||||
|
|
@ -65,10 +65,10 @@ instance ReportTEVar (Data' G.Type) (Data' Type) where
|
||||||
instance ReportTEVar (Inj' G.Type) (Inj' Type) where
|
instance ReportTEVar (Inj' G.Type) (Inj' Type) where
|
||||||
reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ
|
reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ
|
||||||
|
|
||||||
instance ReportTEVar (Id' G.Type) (Id' Type) where
|
instance ReportTEVar (a, G.Type) (a, Type) where
|
||||||
reportTEVar = secondM reportTEVar
|
reportTEVar = secondM reportTEVar
|
||||||
|
|
||||||
instance ReportTEVar (ExpT' G.Type) (ExpT' Type) where
|
instance ReportTEVar (T' Exp' G.Type) (T' Exp' Type) where
|
||||||
reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ)
|
reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ)
|
||||||
|
|
||||||
instance ReportTEVar a b => ReportTEVar [a] [b] where
|
instance ReportTEVar a b => ReportTEVar [a] [b] where
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ import Grammar.ErrM
|
||||||
import Grammar.Print (printTree)
|
import Grammar.Print (printTree)
|
||||||
import Prelude hiding (exp)
|
import Prelude hiding (exp)
|
||||||
import qualified TypeChecker.TypeCheckerIr as T
|
import qualified TypeChecker.TypeCheckerIr as T
|
||||||
|
import TypeChecker.TypeCheckerIr (T, T')
|
||||||
|
|
||||||
-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013)
|
-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013)
|
||||||
-- https://doi.org/10.1145/2500365.2500582
|
-- https://doi.org/10.1145/2500365.2500582
|
||||||
|
|
@ -172,7 +173,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars
|
||||||
|
|
||||||
-- | Γ ⊢ e ↑ A ⊣ Δ
|
-- | Γ ⊢ e ↑ A ⊣ Δ
|
||||||
-- Under input context Γ, e checks against input type A, with output context ∆
|
-- Under input context Γ, e checks against input type A, with output context ∆
|
||||||
check :: Exp -> Type -> Tc (T.ExpT' Type)
|
check :: Exp -> Type -> Tc (T' T.Exp' Type)
|
||||||
|
|
||||||
-- Γ,α ⊢ e ↑ A ⊣ Δ,α,Θ
|
-- Γ,α ⊢ e ↑ A ⊣ Δ,α,Θ
|
||||||
-- ------------------- ∀I
|
-- ------------------- ∀I
|
||||||
|
|
@ -212,12 +213,6 @@ check (ECase scrut pi) c = do
|
||||||
e' <- check e c
|
e' <- check e c
|
||||||
pure (T.Branch p' e')
|
pure (T.Branch p' e')
|
||||||
apply (T.ECase (scrut', a) pi', c)
|
apply (T.ECase (scrut', a) pi', c)
|
||||||
where
|
|
||||||
go (pi, b) (Branch p e) = do
|
|
||||||
p' <- checkPattern p =<< apply a
|
|
||||||
e'@(_, b') <- infer e
|
|
||||||
subtype b' b
|
|
||||||
apply (T.Branch p' e' : pi, b')
|
|
||||||
|
|
||||||
|
|
||||||
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
|
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
|
||||||
|
|
@ -229,9 +224,6 @@ check e b = do
|
||||||
subtype a b'
|
subtype a b'
|
||||||
apply (e', b)
|
apply (e', b)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
|
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
|
||||||
checkPattern patt t_patt = case patt of
|
checkPattern patt t_patt = case patt of
|
||||||
|
|
||||||
|
|
@ -297,7 +289,7 @@ checkPattern patt t_patt = case patt of
|
||||||
|
|
||||||
-- | Γ ⊢ e ↓ A ⊣ Δ
|
-- | Γ ⊢ e ↓ A ⊣ Δ
|
||||||
-- Under input context Γ, e infers output type A, with output context ∆
|
-- Under input context Γ, e infers output type A, with output context ∆
|
||||||
infer :: Exp -> Tc (T.ExpT' Type)
|
infer :: Exp -> Tc (T' T.Exp' Type)
|
||||||
infer (ELit lit) = apply (T.ELit lit, litType lit)
|
infer (ELit lit) = apply (T.ELit lit, litType lit)
|
||||||
|
|
||||||
-- Γ ∋ (x : A) Γ ⊢ rec(x)
|
-- Γ ∋ (x : A) Γ ⊢ rec(x)
|
||||||
|
|
@ -391,7 +383,7 @@ infer (ECase scrut pi) = do
|
||||||
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
|
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
|
||||||
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
|
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
|
||||||
-- Instantiate existential type variables until there is an arrow type.
|
-- Instantiate existential type variables until there is an arrow type.
|
||||||
applyInfer :: Type -> Exp -> Tc (T.ExpT' Type, Type)
|
applyInfer :: Type -> Exp -> Tc (T' T.Exp' Type, Type)
|
||||||
|
|
||||||
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
|
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
|
||||||
-- ------------------------ ∀App
|
-- ------------------------ ∀App
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@
|
||||||
module TypeChecker.TypeCheckerHm where
|
module TypeChecker.TypeCheckerHm where
|
||||||
|
|
||||||
import Auxiliary (int, litType, maybeToRightM, unzip4)
|
import Auxiliary (int, litType, maybeToRightM, unzip4)
|
||||||
import Auxiliary qualified as Aux
|
import qualified Auxiliary as Aux
|
||||||
import Control.Monad.Except
|
import Control.Monad.Except
|
||||||
import Control.Monad.Identity (Identity, runIdentity)
|
import Control.Monad.Identity (Identity, runIdentity)
|
||||||
import Control.Monad.Reader
|
import Control.Monad.Reader
|
||||||
|
|
@ -19,14 +19,15 @@ import Data.Function (on)
|
||||||
import Data.List (foldl', nub, sortOn)
|
import Data.List (foldl', nub, sortOn)
|
||||||
import Data.List.Extra (unsnoc)
|
import Data.List.Extra (unsnoc)
|
||||||
import Data.Map (Map)
|
import Data.Map (Map)
|
||||||
import Data.Map qualified as M
|
import qualified Data.Map as M
|
||||||
import Data.Maybe (fromJust)
|
import Data.Maybe (fromJust)
|
||||||
import Data.Set (Set)
|
import Data.Set (Set)
|
||||||
import Data.Set qualified as S
|
import qualified Data.Set as S
|
||||||
import Debug.Trace (trace, traceShow)
|
import Debug.Trace (trace, traceShow)
|
||||||
import Grammar.Abs
|
import Grammar.Abs
|
||||||
import Grammar.Print (printTree)
|
import Grammar.Print (printTree)
|
||||||
import TypeChecker.TypeCheckerIr qualified as T
|
import TypeChecker.TypeCheckerIr (T, T')
|
||||||
|
import qualified TypeChecker.TypeCheckerIr as T
|
||||||
|
|
||||||
{-
|
{-
|
||||||
TODO
|
TODO
|
||||||
|
|
@ -265,7 +266,7 @@ returnType :: Type -> Type
|
||||||
returnType (TFun _ t2) = returnType t2
|
returnType (TFun _ t2) = returnType t2
|
||||||
returnType a = a
|
returnType a = a
|
||||||
|
|
||||||
inferExp :: Exp -> Infer (T.ExpT' Type)
|
inferExp :: Exp -> Infer (T' T.Exp' Type)
|
||||||
inferExp e = do
|
inferExp e = do
|
||||||
(s, (e', t)) <- algoW e
|
(s, (e', t)) <- algoW e
|
||||||
let subbed = apply s t
|
let subbed = apply s t
|
||||||
|
|
@ -289,7 +290,7 @@ instance CollectTVars Type where
|
||||||
collect :: Set T.Ident -> Infer ()
|
collect :: Set T.Ident -> Infer ()
|
||||||
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
|
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
|
||||||
|
|
||||||
algoW :: Exp -> Infer (Subst, T.ExpT' Type)
|
algoW :: Exp -> Infer (Subst, T' T.Exp' Type)
|
||||||
algoW = \case
|
algoW = \case
|
||||||
err@(EAnn e t) -> do
|
err@(EAnn e t) -> do
|
||||||
(sub0, (e', t')) <- exprErr (algoW e) err
|
(sub0, (e', t')) <- exprErr (algoW e) err
|
||||||
|
|
@ -721,7 +722,7 @@ instance SubstType (Map T.Ident Type) where
|
||||||
instance SubstType (Map T.Ident (Maybe Type)) where
|
instance SubstType (Map T.Ident (Maybe Type)) where
|
||||||
apply s = M.map (fmap $ apply s)
|
apply s = M.map (fmap $ apply s)
|
||||||
|
|
||||||
instance SubstType (T.ExpT' Type) where
|
instance SubstType (T' T.Exp' Type) where
|
||||||
apply s (e, t) = (apply s e, apply s t)
|
apply s (e, t) = (apply s e, apply s t)
|
||||||
|
|
||||||
instance SubstType (T.Exp' Type) where
|
instance SubstType (T.Exp' Type) where
|
||||||
|
|
@ -761,7 +762,7 @@ instance SubstType (T.Pattern' Type, Type) where
|
||||||
instance SubstType a => SubstType [a] where
|
instance SubstType a => SubstType [a] where
|
||||||
apply s = map (apply s)
|
apply s = map (apply s)
|
||||||
|
|
||||||
instance SubstType (T.Id' Type) where
|
instance SubstType (T T.Ident Type) where
|
||||||
apply s (name, t) = (name, apply s t)
|
apply s (name, t) = (name, apply s t)
|
||||||
|
|
||||||
-- | Represents the empty substition set
|
-- | Represents the empty substition set
|
||||||
|
|
|
||||||
|
|
@ -10,31 +10,30 @@ import Data.String (IsString)
|
||||||
import Grammar.Abs (Lit (..))
|
import Grammar.Abs (Lit (..))
|
||||||
import Grammar.Print
|
import Grammar.Print
|
||||||
import Prelude
|
import Prelude
|
||||||
import qualified Prelude as C (Eq, Ord, Read, Show)
|
|
||||||
|
|
||||||
newtype Program' t = Program [Def' t]
|
newtype Program' t = Program [Def' t]
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
deriving (Eq, Ord, Show, Functor)
|
||||||
|
|
||||||
data Def' t
|
data Def' t
|
||||||
= DBind (Bind' t)
|
= DBind (Bind' t)
|
||||||
| DData (Data' t)
|
| DData (Data' t)
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
deriving (Eq, Ord, Show, Functor)
|
||||||
|
|
||||||
data Type
|
data Type
|
||||||
= TLit Ident
|
= TLit Ident
|
||||||
| TVar TVar
|
| TVar TVar
|
||||||
| TData Ident [Type]
|
| TData Ident [Type]
|
||||||
| TFun Type Type
|
| TFun Type Type
|
||||||
deriving (Eq, Ord, Show, Read)
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
data Data' t = Data t [Inj' t]
|
data Data' t = Data t [Inj' t]
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
deriving (Eq, Ord, Show, Functor)
|
||||||
|
|
||||||
data Inj' t = Inj Ident t
|
data Inj' t = Inj Ident t
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
deriving (Eq, Ord, Show, Functor)
|
||||||
|
|
||||||
newtype Ident = Ident String
|
newtype Ident = Ident String
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read, IsString)
|
deriving (Eq, Ord, Show, IsString)
|
||||||
|
|
||||||
data Pattern' t
|
data Pattern' t
|
||||||
= PVar Ident
|
= PVar Ident
|
||||||
|
|
@ -42,30 +41,31 @@ data Pattern' t
|
||||||
| PCatch
|
| PCatch
|
||||||
| PEnum Ident
|
| PEnum Ident
|
||||||
| PInj Ident [(Pattern' t, t)]
|
| PInj Ident [(Pattern' t, t)]
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
deriving (Eq, Ord, Show, Functor)
|
||||||
|
|
||||||
data Exp' t
|
data Exp' t
|
||||||
= EVar Ident
|
= EVar Ident
|
||||||
| EInj Ident
|
| EInj Ident
|
||||||
| ELit Lit
|
| ELit Lit
|
||||||
| ELet (Bind' t) (ExpT' t)
|
| ELet (Bind' t) (T' Exp' t)
|
||||||
| EApp (ExpT' t) (ExpT' t)
|
| EApp (T' Exp' t) (T' Exp' t)
|
||||||
| EAdd (ExpT' t) (ExpT' t)
|
| EAdd (T' Exp' t) (T' Exp' t)
|
||||||
| EAbs Ident (ExpT' t)
|
| EAbs Ident (T' Exp' t)
|
||||||
| ECase (ExpT' t) [Branch' t]
|
| ECase (T' Exp' t) [Branch' t]
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
deriving (Eq, Ord, Show, Functor)
|
||||||
|
|
||||||
newtype TVar = MkTVar Ident
|
newtype TVar = MkTVar Ident
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
deriving (Eq, Ord, Show)
|
||||||
|
|
||||||
type Id' t = (Ident, t)
|
type T' a t = (a t, t)
|
||||||
type ExpT' t = (Exp' t, t)
|
type T a t = (a, t)
|
||||||
|
|
||||||
data Bind' t = Bind (Id' t) [Id' t] (ExpT' t)
|
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
|
||||||
|
|
||||||
data Branch' t = Branch (Pattern' t, t) (ExpT' t)
|
data Bind' t = Bind (T Ident t) [T Ident t] (T' Exp' t)
|
||||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
deriving (Eq, Ord, Show, Functor)
|
||||||
|
|
||||||
|
data Branch' t = Branch (T' Pattern' t) (T' Exp' t)
|
||||||
|
deriving (Eq, Ord, Show, Functor)
|
||||||
|
|
||||||
instance Print Ident where
|
instance Print Ident where
|
||||||
prt _ (Ident s) = doc $ showString s
|
prt _ (Ident s) = doc $ showString s
|
||||||
|
|
@ -81,22 +81,22 @@ instance Print t => Print (Bind' t) where
|
||||||
, prt i rhs
|
, prt i rhs
|
||||||
]
|
]
|
||||||
|
|
||||||
prtSig :: Print t => Id' t -> Doc
|
prtSig :: Print t => T Ident t -> Doc
|
||||||
prtSig (name, t) =
|
prtSig (x, t) =
|
||||||
concatD
|
concatD
|
||||||
[ prt 0 name
|
[ prt 0 x
|
||||||
, doc $ showString ":"
|
, doc $ showString ":"
|
||||||
, prt 0 t
|
, prt 0 t
|
||||||
]
|
]
|
||||||
|
|
||||||
instance Print t => Print (ExpT' t) where
|
instance (Print a, Print t) => Print (T a t) where
|
||||||
prt i (e, t) =
|
prt i (x, t) =
|
||||||
concatD
|
concatD
|
||||||
[ doc $ showString "("
|
[ -- doc $ showString "("
|
||||||
, prt i e
|
{- , -} prt i x
|
||||||
, doc $ showString ":"
|
-- , doc $ showString ":"
|
||||||
, prt 0 t
|
-- , prt 0 t
|
||||||
, doc $ showString ")"
|
-- , doc $ showString ")"
|
||||||
]
|
]
|
||||||
|
|
||||||
instance Print t => Print [Bind' t] where
|
instance Print t => Print [Bind' t] where
|
||||||
|
|
@ -104,16 +104,6 @@ instance Print t => Print [Bind' t] where
|
||||||
prt i [x] = concatD [prt i x]
|
prt i [x] = concatD [prt i x]
|
||||||
prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs]
|
prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs]
|
||||||
|
|
||||||
instance Print t => Print (Id' t) where
|
|
||||||
prt i (name, t) =
|
|
||||||
concatD
|
|
||||||
[ doc $ showString "("
|
|
||||||
, prt i name
|
|
||||||
, doc $ showString ","
|
|
||||||
, prt i t
|
|
||||||
, doc $ showString ")"
|
|
||||||
]
|
|
||||||
|
|
||||||
instance Print t => Print (Exp' t) where
|
instance Print t => Print (Exp' t) where
|
||||||
prt i = \case
|
prt i = \case
|
||||||
EVar lident -> prPrec i 3 (concatD [prt 0 lident])
|
EVar lident -> prPrec i 3 (concatD [prt 0 lident])
|
||||||
|
|
@ -151,9 +141,6 @@ instance Print t => Print [Inj' t] where
|
||||||
prt i [x] = prt i x
|
prt i [x] = prt i x
|
||||||
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
|
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
|
||||||
|
|
||||||
instance Print t => Print (Pattern' t, t) where
|
|
||||||
prt i (p, t) = prPrec i 1 (concatD [prt i p, prt i t])
|
|
||||||
|
|
||||||
instance Print t => Print (Pattern' t) where
|
instance Print t => Print (Pattern' t) where
|
||||||
prt i = \case
|
prt i = \case
|
||||||
PVar name -> prPrec i 1 (concatD [prt 0 name])
|
PVar name -> prPrec i 1 (concatD [prt 0 name])
|
||||||
|
|
@ -189,8 +176,6 @@ type Branch = Branch' Type
|
||||||
type Pattern = Pattern' Type
|
type Pattern = Pattern' Type
|
||||||
type Inj = Inj' Type
|
type Inj = Inj' Type
|
||||||
type Exp = Exp' Type
|
type Exp = Exp' Type
|
||||||
type ExpT = ExpT' Type
|
|
||||||
type Id = Id' Type
|
|
||||||
pattern TVar' s = TVar (MkTVar s)
|
pattern TVar' s = TVar (MkTVar s)
|
||||||
pattern DBind' id vars expt = DBind (Bind id vars expt)
|
pattern DBind' id vars expt = DBind (Bind id vars expt)
|
||||||
pattern DData' typ injs = DData (Data typ injs)
|
pattern DData' typ injs = DData (Data typ injs)
|
||||||
|
|
|
||||||
185
test_map2.ll
Normal file
185
test_map2.ll
Normal file
|
|
@ -0,0 +1,185 @@
|
||||||
|
target triple = "x86_64-pc-linux-gnu"
|
||||||
|
target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
|
||||||
|
@.str = private unnamed_addr constant [3 x i8] c"%i
|
||||||
|
", align 1
|
||||||
|
@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c"Non-exhaustive patterns in case at %i:%i
|
||||||
|
"
|
||||||
|
declare i32 @printf(ptr noalias nocapture, ...)
|
||||||
|
declare i32 @exit(i32 noundef)
|
||||||
|
declare ptr @malloc(i32 noundef)
|
||||||
|
%List = type { i8, [23 x i8] }
|
||||||
|
%Cons = type { i8, i64, %List* }
|
||||||
|
%Nil = type { i8 }
|
||||||
|
; NYTT: kontexttyp
|
||||||
|
%Closure_sc_0 = type { i64 (i64)*, i64 }
|
||||||
|
|
||||||
|
; Ident "sum$List_Int": (ECase (EVar (Ident "$4xs"),TLit (Ident "List")) [Branch (PEnum (Ident "Nil"),TLit (Ident "List")) (ELit (LInt 0),TLit (Ident "Int")),Branch (PInj (Ident "Cons") [PVar (Ident "$5x",TLit (Ident "Int")),PVar (Ident "$6xs",TLit (Ident "List"))],TLit (Ident "List")) (EAdd (EVar (Ident "$5x"),TLit (Ident "Int")) (EApp (EVar (Ident "sum$List_Int"),TFun (TLit (Ident "List")) (TLit (Ident "Int"))) (EVar (Ident "$6xs"),TLit (Ident "List")),TLit (Ident "Int")),TLit (Ident "Int"))],TLit (Ident "Int"))
|
||||||
|
define fastcc i64 @sum$List_Int(%List %$4xs) {
|
||||||
|
%1 = alloca i64
|
||||||
|
; Penum
|
||||||
|
%2 = extractvalue %List %$4xs, 0
|
||||||
|
%3 = icmp eq i8 %2, 1
|
||||||
|
br i1 %3, label %lbl_success_3, label %lbl_failed_2
|
||||||
|
|
||||||
|
lbl_success_3:
|
||||||
|
%4 = alloca %List
|
||||||
|
store %List %$4xs, ptr %4
|
||||||
|
%5 = load %Nil, ptr %4
|
||||||
|
store i64 0, ptr %1
|
||||||
|
br label %lbl_escape_1
|
||||||
|
|
||||||
|
lbl_failed_2:
|
||||||
|
; Inj
|
||||||
|
%6 = extractvalue %List %$4xs, 0
|
||||||
|
%7 = icmp eq i8 %6, 0
|
||||||
|
br i1 %7, label %lbl_success_5, label %lbl_failed_4
|
||||||
|
|
||||||
|
lbl_success_5:
|
||||||
|
%8 = alloca %List
|
||||||
|
store %List %$4xs, ptr %8
|
||||||
|
%9 = load %Cons, ptr %8
|
||||||
|
; ident i64
|
||||||
|
%$5x = extractvalue %Cons %9, 1
|
||||||
|
; ident %List
|
||||||
|
%10 = extractvalue %Cons %9, 2
|
||||||
|
%$6xs = load %List, ptr %10
|
||||||
|
; TLit (Ident "Int")
|
||||||
|
%11 = call fastcc i64 @sum$List_Int(%List %$6xs)
|
||||||
|
%12 = add i64 %$5x, %11
|
||||||
|
store i64 %12, ptr %1
|
||||||
|
br label %lbl_escape_1
|
||||||
|
|
||||||
|
lbl_failed_4:
|
||||||
|
call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 12, i64 noundef 6)
|
||||||
|
call i32 @exit(i32 noundef 1)
|
||||||
|
br label %lbl_escape_1
|
||||||
|
|
||||||
|
lbl_escape_1:
|
||||||
|
%15 = load i64, ptr %1
|
||||||
|
ret i64 %15
|
||||||
|
}
|
||||||
|
|
||||||
|
; Ident "sc_0$Int_Int": (EAdd (EVar (Ident "$7x"),TLit (Ident "Int")) (ELit (LInt 10),TLit (Ident "Int")),TLit (Ident "Int"))
|
||||||
|
; ÄNDRAT: lägg till kontextpekare
|
||||||
|
define fastcc i64 @sc_0$Int_Int(ptr %closure_sc_0, i64 %$7x) {
|
||||||
|
; NYTT: Ladda alla fria variabler
|
||||||
|
%fri_variabel_ptr = getelementptr inbounds %Closure_sc_0, ptr %closure_sc_0, i32 0, i32 1
|
||||||
|
%fri_variabel = load i64, ptr %fri_variabel_ptr
|
||||||
|
; ÄNDRAT: %fri_variabel istället för 2
|
||||||
|
%1 = add i64 %$7x, %fri_variabel
|
||||||
|
ret i64 %1
|
||||||
|
}
|
||||||
|
|
||||||
|
; Ident "map$Int_Int_List_List": (ECase (EVar (Ident "$1xs"),TLit (Ident "List")) [Branch (PEnum (Ident "Nil"),TLit (Ident "List")) (EVar (Ident "Nil"),TLit (Ident "List")),Branch (PInj (Ident "Cons") [PVar (Ident "$2x",TLit (Ident "Int")),PVar (Ident "$3xs",TLit (Ident "List"))],TLit (Ident "List")) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EApp (EVar (Ident "$0f"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (EVar (Ident "$2x"),TLit (Ident "Int")),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "map$Int_Int_List_List"),TFun (TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EVar (Ident "$0f"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EVar (Ident "$3xs"),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List"))],TLit (Ident "List"))
|
||||||
|
; ÄNDRAT: ptr istället för i64 (i64)*
|
||||||
|
define fastcc %List @map$Int_Int_List_List(ptr %$0f, %List %$1xs) {
|
||||||
|
; NYTT: ta fram funktionspekaren
|
||||||
|
%$0f_deref = load i64(i64)*, ptr %$0f
|
||||||
|
|
||||||
|
%1 = alloca %List
|
||||||
|
; Penum
|
||||||
|
%2 = extractvalue %List %$1xs, 0
|
||||||
|
%3 = icmp eq i8 %2, 1
|
||||||
|
br i1 %3, label %lbl_success_8, label %lbl_failed_7
|
||||||
|
|
||||||
|
lbl_success_8:
|
||||||
|
%4 = alloca %List
|
||||||
|
store %List %$1xs, ptr %4
|
||||||
|
%5 = load %Nil, ptr %4
|
||||||
|
%6 = call fastcc %List @Nil()
|
||||||
|
store %List %6, ptr %1
|
||||||
|
br label %lbl_escape_6
|
||||||
|
|
||||||
|
lbl_failed_7:
|
||||||
|
; Inj
|
||||||
|
%7 = extractvalue %List %$1xs, 0
|
||||||
|
%8 = icmp eq i8 %7, 0
|
||||||
|
br i1 %8, label %lbl_success_10, label %lbl_failed_9
|
||||||
|
|
||||||
|
lbl_success_10:
|
||||||
|
%9 = alloca %List
|
||||||
|
store %List %$1xs, ptr %9
|
||||||
|
%10 = load %Cons, ptr %9
|
||||||
|
; ident i64
|
||||||
|
%$2x = extractvalue %Cons %10, 1
|
||||||
|
; ident %List
|
||||||
|
%11 = extractvalue %Cons %10, 2
|
||||||
|
%$3xs = load %List, ptr %11
|
||||||
|
; TLit (Ident "Int")
|
||||||
|
; ÄNDRAT använd deref
|
||||||
|
%12 = call fastcc i64 %$0f_deref(ptr %$0f, i64 %$2x)
|
||||||
|
; TLit (Ident "List")
|
||||||
|
; ÄNDRAT ptr istället för 64 (64)* och skicka med ptr
|
||||||
|
%13 = call fastcc %List @map$Int_Int_List_List(ptr %$0f, %List %$3xs)
|
||||||
|
; TLit (Ident "List")
|
||||||
|
%14 = call fastcc %List @Cons(i64 %12, %List %13)
|
||||||
|
store %List %14, ptr %1
|
||||||
|
br label %lbl_escape_6
|
||||||
|
|
||||||
|
lbl_failed_9:
|
||||||
|
call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 14, i64 noundef 6)
|
||||||
|
call i32 @exit(i32 noundef 1)
|
||||||
|
br label %lbl_escape_6
|
||||||
|
|
||||||
|
lbl_escape_6:
|
||||||
|
%17 = load %List, ptr %1
|
||||||
|
ret %List %17
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
; Ident "main": (EApp (EVar (Ident "sum$List_Int"),TFun (TLit (Ident "List")) (TLit (Ident "Int"))) (EApp (EApp (EVar (Ident "map$Int_Int_List_List"),TFun (TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EVar (Ident "sc_0$Int_Int"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (ELit (LInt 1),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (ELit (LInt 2),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EVar (Ident "Nil"),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "Int"))
|
||||||
|
define fastcc i64 @main() {
|
||||||
|
%1 = call fastcc %List @Nil()
|
||||||
|
; TLit (Ident "List")
|
||||||
|
%2 = call fastcc %List @Cons(i64 2, %List %1)
|
||||||
|
; TLit (Ident "List")
|
||||||
|
%3 = call fastcc %List @Cons(i64 1, %List %2)
|
||||||
|
; TLit (Ident "List")
|
||||||
|
|
||||||
|
; NYTT: spara funktionspekaren och 100 i kontexten
|
||||||
|
%closure_sc_0 = alloca %Closure_sc_0
|
||||||
|
store i64(i64)* @sc_0$Int_Int, ptr %closure_sc_0
|
||||||
|
%fri_variabel_ptr = getelementptr inbounds %Closure_sc_0, ptr %closure_sc_0, i32 0, i32 1
|
||||||
|
store i64 100, ptr %fri_variabel_ptr
|
||||||
|
|
||||||
|
|
||||||
|
; store %Closure_sc_0 {i64 (i64)* @sc_0$Int_Int, 100}, ptr %closure_sc_0
|
||||||
|
|
||||||
|
; ÄNDRAT ptr %closure_sc_0 istället för i64 (i64)* @sc_0$Int_Int
|
||||||
|
%4 = call fastcc %List @map$Int_Int_List_List(ptr %closure_sc_0, %List %3)
|
||||||
|
; TLit (Ident "Int")
|
||||||
|
%5 = call fastcc i64 @sum$List_Int(%List %4)
|
||||||
|
call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef %5)
|
||||||
|
ret i64 0
|
||||||
|
}
|
||||||
|
|
||||||
|
define fastcc %List @Cons(i64 %arg_0, %List %arg_1) {
|
||||||
|
%1 = alloca %List
|
||||||
|
%2 = getelementptr %List, %List* %1, i64 0, i32 0
|
||||||
|
store i8 0, i8* %2
|
||||||
|
%3 = bitcast %List* %1 to %Cons*
|
||||||
|
; i64 arg_0 1
|
||||||
|
%4 = getelementptr %Cons, %Cons* %3, i64 0, i32 1
|
||||||
|
; Just store
|
||||||
|
store i64 %arg_0, ptr %4
|
||||||
|
; %List arg_1 2
|
||||||
|
%5 = getelementptr %Cons, %Cons* %3, i64 0, i32 2
|
||||||
|
; Malloc and store
|
||||||
|
%6 = call ptr @malloc(i64 24)
|
||||||
|
store %List %arg_1, ptr %6
|
||||||
|
store %List* %6, ptr %5
|
||||||
|
; Return the newly constructed value
|
||||||
|
%7 = load %List, ptr %1
|
||||||
|
ret %List %7
|
||||||
|
}
|
||||||
|
|
||||||
|
define fastcc %List @Nil() {
|
||||||
|
%1 = alloca %List
|
||||||
|
%2 = getelementptr %List, %List* %1, i64 0, i32 0
|
||||||
|
store i8 1, i8* %2
|
||||||
|
%3 = bitcast %List* %1 to %Nil*
|
||||||
|
; Return the newly constructed value
|
||||||
|
%4 = load %List, ptr %1
|
||||||
|
ret %List %4
|
||||||
|
}
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue