Merge closures mostly done. Desugaring cases is a problem.

This commit is contained in:
Martin Fredin 2023-05-06 23:38:56 +02:00
commit 019ed0d45a
29 changed files with 1484 additions and 757 deletions

View file

@ -43,6 +43,7 @@ executable language
TypeChecker.ReportTEVar TypeChecker.ReportTEVar
TypeChecker.RemoveForall TypeChecker.RemoveForall
LambdaLifter LambdaLifter
LambdaLifterIr
Monomorphizer.Monomorphizer Monomorphizer.Monomorphizer
Monomorphizer.MonomorphizerIr Monomorphizer.MonomorphizerIr
Monomorphizer.MorbIr Monomorphizer.MorbIr
@ -101,6 +102,8 @@ Test-suite language-testsuite
TypeChecker.TypeChecker TypeChecker.TypeChecker
AnnForall AnnForall
ReportForall ReportForall
LambdaLifterIr
LambdaLifter
TypeChecker.TypeCheckerHm TypeChecker.TypeCheckerHm
TypeChecker.TypeCheckerBidir TypeChecker.TypeCheckerBidir
TypeChecker.ReportTEVar TypeChecker.ReportTEVar

View file

@ -0,0 +1,6 @@
add : Int -> Int -> Int -> Int
add x y z = x + y + z
main = add 8 6 2

View file

@ -0,0 +1,7 @@
apply : (Int -> Int) -> Int -> Int
apply f y = f y
main = apply (\y. y + y) 5

View file

@ -0,0 +1,10 @@
apply : (Int -> Int) -> Int -> Int
apply f z = f z
main =
let x = 10 in
apply (\y. y + x) 6

View file

@ -0,0 +1,15 @@
data List a where
Nil : List a
Cons : a -> List a -> List a
foldr : (a -> b -> b) -> b -> List a -> b
foldr f y xs = case xs of
Nil => y
Cons x xs => f x (foldr f y xs)
main = let z = 2 in foldr (\x.\y. x + y + z) 0 (Cons 1000 (Cons 100 Nil))

View file

@ -0,0 +1,25 @@
data List (a) where
Nil : List (a)
Cons : a -> List (a) -> List (a)
map : (a -> b) -> List (a) -> List (b)
map f xs = case xs of
Nil => Nil
Cons x xs => Cons (f x) (map f xs)
add : Int -> Int -> Int
add x y = x + y
foldr : (a -> b -> b) -> b -> List (a) -> b
foldr f y xs = case xs of
Nil => y
Cons x xs => f x (foldr f y xs)
f : List (Int)
f = ((\x.\ys. map (\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil)))
-- [5, 6]
main : Int
main = foldr add 0 f

View file

@ -0,0 +1,21 @@
data List a where
Nil : List a
Cons : a -> List a -> List a
map : (a -> b) -> List a -> List b
map f xs = case xs of
Nil => Nil
Cons x xs => Cons (f x) (map f xs)
f : List Int
f = (\x.\ys. map (\y. y + x) ys) 4 (Cons 1 (Cons 2 Nil))
-- [5, 6]
sum : List Int -> Int
sum xs = case xs of
Nil => 0
Cons x xs => x + sum xs
main = sum f

View file

@ -0,0 +1,3 @@
main = let x = 10 in 6 + x

View file

@ -0,0 +1,16 @@
data List a where
Nil : List a
Cons : a -> List a -> List a
map : (a -> b) -> List a -> List b
map f xs = case xs of
Nil => Nil
Cons x xs => Cons (f x) (map f xs)
sum : List Int -> Int
sum xs = case xs of
Nil => 0
Cons x xs => x + (sum xs)
main = let y = 10 in sum (map (\x. x + y) (Cons 2 (Cons 4 Nil)))

View file

@ -0,0 +1,7 @@
f = 10
main = f + 6

View file

@ -2,8 +2,8 @@ module Codegen.Auxillary where
import Codegen.LlvmIr (LLVMType (..), LLVMValue (..)) import Codegen.LlvmIr (LLVMType (..), LLVMValue (..))
import Control.Monad (foldM_) import Control.Monad (foldM_)
import Monomorphizer.MonomorphizerIr as MIR (ExpT, Type (..)) import Monomorphizer.MonomorphizerIr as MIR (Exp, T, Type (..))
import TypeChecker.TypeCheckerIr qualified as TIR import qualified TypeChecker.TypeCheckerIr as TIR
type2LlvmType :: MIR.Type -> LLVMType type2LlvmType :: MIR.Type -> LLVMType
type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of
@ -19,7 +19,7 @@ type2LlvmType (MIR.TFun t xs) = do
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s) function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s) function2LLVMType x s = (type2LlvmType x, s)
getType :: ExpT -> LLVMType getType :: T Exp -> LLVMType
getType (_, t) = type2LlvmType t getType (_, t) = type2LlvmType t
extractTypeName :: MIR.Type -> TIR.Ident extractTypeName :: MIR.Type -> TIR.Ident

View file

@ -1,17 +1,23 @@
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (generateCode) where module Codegen.Codegen (generateCode) where
import Codegen.CompilerState ( import Codegen.CompilerState (CodeGenerator (..),
CodeGenerator (instructions), StructType (inst),
initCodeGenerator, initCodeGenerator)
)
import Codegen.Emits (compileScs) import Codegen.Emits (compileScs)
import Codegen.LlvmIr as LIR (llvmIrToString) import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw),
import Control.Monad.State ( llvmIrToString)
execStateT, import Control.Monad.State (execStateT)
) import Data.Functor ((<&>))
import Data.List (sortBy) import Data.List (sortBy)
import qualified Data.Map as Map
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..), Def (DBind, DData), Program (..), Type (TLit)) import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..),
Def (DBind, DData),
Program (..),
Type (TLit))
import TypeChecker.TypeCheckerIr (Ident (..)) import TypeChecker.TypeCheckerIr (Ident (..))
{- | Compiles an AST and produces a LLVM Ir string. {- | Compiles an AST and produces a LLVM Ir string.
@ -21,8 +27,14 @@ import TypeChecker.TypeCheckerIr (Ident (..))
generateCode :: MIR.Program -> Bool -> Err String generateCode :: MIR.Program -> Bool -> Err String
generateCode (MIR.Program scs) addGc = do generateCode (MIR.Program scs) addGc = do
let tree = filter (not . detectPrelude) (sortBy lowData scs) let tree = filter (not . detectPrelude) (sortBy lowData scs)
let codegen = initCodeGenerator addGc tree codegen = initCodeGenerator addGc tree
llvmIrToString . instructions <$> execStateT (compileScs tree) codegen
-- Append instructions
execStateT (compileScs tree) codegen <&> \state ->
llvmIrToString $ defaultStart
++ (if addGc then gcStart else [])
++ map inst (Map.elems state.structTypes)
++ state.instructions
detectPrelude :: Def -> Bool detectPrelude :: Def -> Bool
detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True
@ -33,3 +45,24 @@ lowData :: Def -> Def -> Ordering
lowData (DData _) (DBind _) = LT lowData (DData _) (DBind _) = LT
lowData (DBind _) (DData _) = GT lowData (DBind _) (DData _) = GT
lowData _ _ = EQ lowData _ _ = EQ
defaultStart :: [LLVMIr]
defaultStart =
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
]
gcStart :: [LLVMIr]
gcStart =
[ UnsafeRaw "declare external void @cheap_init()\n"
, UnsafeRaw "declare external ptr @cheap_alloc(i64)\n"
, UnsafeRaw "declare external void @cheap_dispose()\n"
, UnsafeRaw "declare external ptr @cheap_the()\n"
, UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n"
, UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n"
]

View file

@ -1,46 +1,101 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.CompilerState where module Codegen.CompilerState where
import Auxiliary (snoc) import Auxiliary (snoc)
import Codegen.Auxillary (type2LlvmType, typeByteSize) import Codegen.Auxillary (type2LlvmType, typeByteSize)
import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw), import Codegen.LlvmIr as LIR (LLVMIr (SetVariable, Type),
LLVMType) LLVMType (CustomType, Function, I64, Ptr),
import Control.Monad.State (StateT, gets, modify) LLVMValue (VFunction, VIdent),
Visibility (Global),
typeOf)
import Control.Monad.State (StateT, gets, modify, void)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR import Monomorphizer.MonomorphizerIr (Ident (..), Inj (..), T,
flattenType)
import qualified Monomorphizer.MonomorphizerIr as MIR
import qualified TypeChecker.TypeCheckerIr as TIR import qualified TypeChecker.TypeCheckerIr as TIR
-- | The record used as the code generator state -- | The record used as the code generator state
data CodeGenerator = CodeGenerator data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr] { instructions :: [LLVMIr]
, functions :: Map MIR.Id FunctionInfo , functions :: Map (T Ident) FunctionInfo
, customTypes :: Map LLVMType Integer , customTypes :: Map LLVMType Integer
, constructors :: Map TIR.Ident ConstructorInfo , constructors :: Map Ident ConstructorInfo
, variableCount :: Integer , variableCount :: Integer
, labelCount :: Integer , labelCount :: Integer
, gcEnabled :: Bool , gcEnabled :: Bool
, structTypes :: Map Ident StructType
-- ^ Custom stucture types
, locals :: [(Ident, LocalElem)]
-- ^ Arguments and variables in local environment
, globals :: Map Ident (LLVMType, LLVMValue)
} }
data StructType = StructType
{ ptr :: LLVMType
, typs :: [LLVMType]
, inst :: LLVMIr
}
data LocalElem = LocalElem
{ typ :: LLVMType
, val :: LLVMValue
}
-- | A state type synonym -- | A state type synonym
type CompilerState a = StateT CodeGenerator Err a type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo data FunctionInfo = FunctionInfo
{ numArgs :: Int { numArgs :: Int
, arguments :: [Id] , arguments :: [T Ident]
} }
deriving (Show) deriving (Show)
data ConstructorInfo = ConstructorInfo data ConstructorInfo = ConstructorInfo
{ numArgsCI :: Int { numArgsCI :: Int
, argumentsCI :: [Id] , argumentsCI :: [T Ident]
, numCI :: Integer , numCI :: Integer
, returnTypeCI :: MIR.Type , returnTypeCI :: MIR.Type
} }
deriving (Show) deriving (Show)
addStructType_ :: Ident -> [LLVMType] -> CompilerState ()
addStructType_ = fmap void . addStructType
addStructType :: Ident -> [LLVMType] -> CompilerState LLVMType
addStructType x ts = do
modify $ \s -> s { structTypes = Map.insert x struct s.structTypes }
pure t
where
struct = StructType
{ ptr = t
, typs = ts
, inst = Type x ts
}
t = CustomType x
-- | Adds a instruction to the CodeGenerator state -- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState () emit :: LLVMIr -> CompilerState ()
emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
-- Add variable to environment
emit l@(SetVariable x _) = modify $ \t ->
t { instructions = Auxiliary.snoc l t.instructions
, locals = snoc (x, local)
t.locals
}
where
local = LocalElem { typ = typeOf l
, val = VIdent x (typeOf l)
}
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l t.instructions }
-- | Increases the variable counter in the CodeGenerator state -- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState () increaseVarCount :: CompilerState ()
@ -63,16 +118,19 @@ getNewLabel = do
{- | Produces a map of functions infos from a list of binds, {- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation. which contains useful data for code generation.
-} -}
getFunctions :: [MIR.Def] -> Map Id FunctionInfo getFunctions :: [MIR.Def] -> Map (T Ident) FunctionInfo
getFunctions bs = Map.fromList $ go bs getFunctions bs = Map.fromList $ go bs
where where
go [] = [] go [] = []
go (MIR.DBind (MIR.Bind id args _) : xs) = go (MIR.DBind (MIR.Bind id args _) : xs) =
(id, FunctionInfo{numArgs = length args, arguments = args}) (id, FunctionInfo { numArgs = length args
, arguments = args
}
)
: go xs : go xs
go (_ : xs) = go xs go (_ : xs) = go xs
createArgs :: [MIR.Type] -> [Id] createArgs :: [MIR.Type] -> [T Ident]
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
{- | Produces a map of functions infos from a list of binds, {- | Produces a map of functions infos from a list of binds,
@ -113,35 +171,43 @@ getTypes bs = Map.fromList $ go bs
variantTypes fi = init $ map type2LlvmType (flattenType fi) variantTypes fi = init $ map type2LlvmType (flattenType fi)
biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
getGlobals :: [MIR.Def] -> Map Ident (LLVMType, LLVMValue)
getGlobals scs = Map.fromList [ go b | MIR.DBind b <- scs ]
where
go bind | x == "main" = let typ = Function I64 []
in (x, (typ, VFunction x Global typ))
| otherwise = (x, (typ, VFunction x Global typ))
where
typ = Function tr $ Ptr : ts
Function tr ts = type2LlvmType' t
(x, t) = case bind of
MIR.Bind xt _ _ -> xt
MIR.BindC _ xt _ _ -> xt
-- Higher order function arguments are replaced with ptr
type2LlvmType' = go []
where
go acc = \case
MIR.TFun (MIR.TFun _ _) t2 -> go (snoc Ptr acc) t2
MIR.TFun t1 t2 -> go (snoc (type2LlvmType t1) acc) t2
t -> Function (type2LlvmType t) acc
initCodeGenerator :: Bool -> [MIR.Def] -> CodeGenerator initCodeGenerator :: Bool -> [MIR.Def] -> CodeGenerator
initCodeGenerator addGc scs = initCodeGenerator addGc scs =
CodeGenerator CodeGenerator
{ instructions = defaultStart <> if addGc then gcStart else [] { instructions = []
, functions = getFunctions scs , functions = getFunctions scs
, constructors = getConstructors scs , constructors = getConstructors scs
, customTypes = getTypes scs , customTypes = getTypes scs
, structTypes = mempty
, variableCount = 0 , variableCount = 0
, labelCount = 0 , labelCount = 0
, gcEnabled = addGc , gcEnabled = addGc
, locals = mempty
, globals = getGlobals scs
} }
defaultStart :: [LLVMIr]
defaultStart =
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
]
gcStart :: [LLVMIr]
gcStart =
[ UnsafeRaw "declare external void @cheap_init()\n"
, UnsafeRaw "declare external ptr @cheap_alloc(i64)\n"
, UnsafeRaw "declare external void @cheap_dispose()\n"
, UnsafeRaw "declare external ptr @cheap_the()\n"
, UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n"
, UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n"
]

View file

@ -1,36 +1,40 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
module Codegen.Emits where module Codegen.Emits where
import Auxiliary (snoc)
import Codegen.Auxillary import Codegen.Auxillary
import Codegen.CompilerState import Codegen.CompilerState
import Codegen.LlvmIr as LIR import Codegen.LlvmIr as LIR
import Control.Applicative ((<|>)) import Control.Applicative (Applicative (liftA2), (<|>))
import Control.Monad (when) import Control.Monad (forM_, when, zipWithM_)
import Control.Monad.Extra (whenJust)
import Control.Monad.State (gets, modify) import Control.Monad.State (gets, modify)
import Data.Bifunctor qualified as BI
import Data.Char (ord) import Data.Char (ord)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Map qualified as Map import Data.Foldable.Extra (notNull)
import qualified Data.Map as Map
import Data.Maybe (fromJust, fromMaybe, isNothing) import Data.Maybe (fromJust, fromMaybe, isNothing)
import Data.Tuple.Extra (dupe, first, second) import Data.Tuple.Extra (second)
import Debug.Trace (trace, traceShow) import Grammar.Print (printTree)
import Grammar.Print import Monomorphizer.MonomorphizerIr
import Monomorphizer.MonomorphizerIr as MIR
import TypeChecker.TypeCheckerIr qualified as TIR
compileScs :: [MIR.Def] -> CompilerState ()
compileScs :: [Def] -> CompilerState ()
compileScs [] = do compileScs [] = do
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
mapM_ createConstructor =<< gets (Map.toList . constructors)
-- as a last step create all the constructors -- as a last step create all the constructors
-- //TODO maybe merge this with the data type match? -- //TODO maybe merge this with the data type match?
c <- gets (Map.toList . constructors) where
mapM_ createConstructor (id, ci) = do
( \(id, ci) -> do
let t = returnTypeCI ci let t = returnTypeCI ci
let t' = type2LlvmType t t' = type2LlvmType t
let x = BI.second type2LlvmType <$> argumentsCI ci x = (mkCxtName, Ptr) : map (second type2LlvmType) ci.argumentsCI
emit $ Define FastCC t' id x emit $ Define FastCC t' id x
top <- getNewVar top <- getNewVar
ptr <- getNewVar ptr <- getNewVar
@ -56,7 +60,7 @@ compileScs [] = do
cTypes <- gets customTypes cTypes <- gets customTypes
enumerateOneM_ enumerateOneM_
( \i (TIR.Ident arg_n, arg_t) -> do ( \i (Ident arg_n, arg_t) -> do
let arg_t' = type2LlvmType arg_t let arg_t' = type2LlvmType arg_t
emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i) emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
elemPtr <- getNewVar elemPtr <- getNewVar
@ -78,11 +82,11 @@ compileScs [] = do
heapPtr <- getNewVar heapPtr <- getNewVar
useGc <- gets gcEnabled useGc <- gets gcEnabled
emit $ SetVariable heapPtr (if useGc then GcMalloc s else Malloc s) emit $ SetVariable heapPtr (if useGc then GcMalloc s else Malloc s)
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr heapPtr
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
Nothing -> do Nothing -> do
emit $ Comment "Just store" emit $ Comment "Just store"
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
) )
(argumentsCI ci) (argumentsCI ci)
@ -95,34 +99,83 @@ compileScs [] = do
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
modify $ \s -> s{variableCount = 0} modify $ \s -> s{variableCount = 0}
)
c compileScs (DBind bind : xs) = do
compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
let t_return = type2LlvmType . last . flattenType $ t
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp emit . Comment $ show name <> ": " <> show (fst exp)
let args' = map (second type2LlvmType) args
Function t_return t_args <- gets $ fst
. fromJust
. Map.lookup name
. globals
let args' = zip (mkCxtName : map fst args) t_args
emit $ Define FastCC t_return name args' emit $ Define FastCC t_return name args'
useGc <- gets gcEnabled modify $ \s -> s { locals = foldr insertArg s.locals args' }
when (name == "main") (mapM_ emit (firstMainContent useGc))
functionBody <- exprToValue exp -- Dereference ptr arguments
if name == "main" when (notNull args') $
then mapM_ emit $ lastMainContent useGc functionBody forM_ (tail args') $ \(x, t) -> when (t == Ptr) $ do
else emit $ Ret t_return functionBody let t_deref =
let
Function t ts = type2LlvmType . fromJust $ lookup x args
in
Function t (Ptr : ts)
emit . SetVariable (mkDerefName x)
$ Load t_deref Ptr x
whenJust mcxt loadFreeVars
gcEnabled <- gets gcEnabled
when isMain $ mapM_ emit (firstMainContent gcEnabled)
result <- exprToValue exp
if isMain
then mapM_ emit $ lastMainContent gcEnabled result
else emit $ Ret t_return result
emit DefineEnd emit DefineEnd
modify $ \s -> s{variableCount = 0} -- Reset variable count and empty locals
modify $ \s -> s { variableCount = 0, locals = mempty }
compileScs xs compileScs xs
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do where
let (TIR.Ident outer_id) = extractTypeName typ loadFreeVars cxt = do
emit $ Comment "Load free variables"
zipWithM_ go cxt' [1 ..]
where
go (x, t) i = do
vc <- getNewVar
emit . SetVariable vc
$ GetElementPtrInbounds (CustomType $ mkClosureName name) Ptr (VIdent mkCxtName Ptr)
I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices
emit . SetVariable x $ Load t Ptr vc
cxt' = map (second type2LlvmType) cxt
isMain = name == "main"
(name, args, exp, mcxt) = case bind of
Bind (name, _) args exp -> (name, args, exp, Nothing)
BindC cxt (name, _) args exp -> (name, args, exp, Just cxt)
insertArg (x, t) = snoc (x, LocalElem { val = VIdent x t, typ = t })
compileScs (DData (Data typ ts) : xs) = do
let (Ident outer_id) = extractTypeName typ
-- //TODO this could be extracted from the customTypes map -- //TODO this could be extracted from the customTypes map
let variantTypes fi = init $ map type2LlvmType (flattenType fi) let variantTypes fi = init $ map type2LlvmType (flattenType fi)
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts) let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8] -- Add data type (e.g. %List) to top of the file
addStructType_ (Ident outer_id) [I8, Array biggestVariant I8]
typeSets <- gets customTypes typeSets <- gets customTypes
mapM_ mapM_
( \(Inj inner_id fi) -> do ( \(Inj inner_id fi) -> do
let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
emit $ LIR.Type inner_id (I8 : types) -- Add constructor type (e.g. %Cons) to top of the file
addStructType_ inner_id (I8 : types)
) )
ts ts
compileScs xs compileScs xs
@ -149,16 +202,16 @@ lastMainContent False var =
, Ret I64 (VInteger 0) , Ret I64 (VInteger 0)
] ]
compileExp :: ExpT -> CompilerState () compileExp :: T Exp -> CompilerState ()
compileExp (MIR.ELit lit, _t) = emitLit lit compileExp (ELit lit, _t) = emitLit lit
compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2 compileExp (EAdd e1 e2, t) = emitAdd t e1 e2
compileExp (MIR.EVar name, _t) = emitIdent name compileExp (EVar name, _t) = emitIdent name
compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2 compileExp (EApp e1 e2, t) = emitApp t e1 e2
compileExp (MIR.ELet bind e, _) = emitLet bind e compileExp (ELet bind e, _) = emitLet bind e
compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs) compileExp (ECase e cs, t) = emitECased t e (map (t,) cs)
emitLet :: MIR.Bind -> ExpT -> CompilerState () emitLet :: Bind -> T Exp -> CompilerState ()
emitLet (MIR.Bind id [] innerExp) e = do emitLet (Bind id [] innerExp) e = do
evaled <- exprToValue innerExp evaled <- exprToValue innerExp
tempVar <- getNewVar tempVar <- getNewVar
let t = type2LlvmType . snd $ innerExp let t = type2LlvmType . snd $ innerExp
@ -168,14 +221,14 @@ emitLet (MIR.Bind id [] innerExp) e = do
compileExp e compileExp e
emitLet b _ = error $ "Non empty argument list in let-bind " <> show b emitLet b _ = error $ "Non empty argument list in let-bind " <> show b
emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Branch)] -> CompilerState () emitECased :: Type -> T Exp -> [(Type, Branch)] -> CompilerState ()
emitECased t e cases = do emitECased t e cases = do
let cs = snd <$> cases let cs = snd <$> cases
let ty = type2LlvmType t let ty = type2LlvmType t
let rt = type2LlvmType (snd e) let rt = type2LlvmType (snd e)
vs <- exprToValue e vs <- exprToValue e
lbl <- getNewLabel lbl <- getNewLabel
let label = TIR.Ident $ "escape_" <> show lbl let label = Ident $ "escape_" <> show lbl
stackPtr <- getNewVar stackPtr <- getNewVar
emit $ SetVariable stackPtr (Alloca ty) emit $ SetVariable stackPtr (Alloca ty)
mapM_ (emitCases rt ty label stackPtr vs) cs mapM_ (emitCases rt ty label stackPtr vs) cs
@ -192,14 +245,14 @@ emitECased t e cases = do
res <- getNewVar res <- getNewVar
emit $ SetVariable res (Load ty Ptr stackPtr) emit $ SetVariable res (Load ty Ptr stackPtr)
where where
emitCases :: LLVMType -> LLVMType -> TIR.Ident -> TIR.Ident -> LLVMValue -> Branch -> CompilerState () emitCases :: LLVMType -> LLVMType -> Ident -> Ident -> LLVMValue -> Branch -> CompilerState ()
emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, _t) exp) = do emitCases rt ty label stackPtr vs (Branch (PInj consId cs, _t) exp) = do
emit $ Comment "Inj" emit $ Comment "Inj"
cons <- gets constructors cons <- gets constructors
let r = fromJust $ Map.lookup consId cons let r = fromJust $ Map.lookup consId cons
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
consVal <- getNewVar consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0) emit $ SetVariable consVal (ExtractValue rt vs 0)
@ -215,10 +268,10 @@ emitECased t e cases = do
emit $ Store rt vs Ptr castPtr emit $ Store rt vs Ptr castPtr
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr) emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
enumerateOneM_ enumerateOneM_
( \i c -> do ( \i (c, t) -> do
case c of case c of
PVar (x, topT) -> do PVar x -> do
let topT' = type2LlvmType topT let topT' = type2LlvmType t
let botT' = CustomType (coerce consId) let botT' = CustomType (coerce consId)
emit . Comment $ "ident " <> toIr topT' emit . Comment $ "ident " <> toIr topT'
cTypes <- gets customTypes cTypes <- gets customTypes
@ -228,7 +281,7 @@ emitECased t e cases = do
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i) emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
emit $ SetVariable x (Load topT' Ptr deref) emit $ SetVariable x (Load topT' Ptr deref)
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i) else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
PLit (_l, _t) -> error "Nested pattern matching to be implemented" PLit _l -> error "Nested pattern matching to be implemented"
PInj _id _ps -> error "Nested pattern matching to be implemented" PInj _id _ps -> error "Nested pattern matching to be implemented"
PCatch -> pure () PCatch -> pure ()
PEnum _id -> error "Nested pattern matching to be implemented" PEnum _id -> error "Nested pattern matching to be implemented"
@ -238,22 +291,22 @@ emitECased t e cases = do
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitCases _rt ty label stackPtr vs (Branch (MIR.PLit (i, ct), t) exp) = do emitCases _rt ty label stackPtr vs (Branch (PLit i, t) exp) = do
emit $ Comment "Plit" emit $ Comment "Plit"
let i' = case i of let i' = case i of
MIR.LInt i -> VInteger i LInt i -> VInteger i
MIR.LChar i -> VChar (ord i) LChar i -> VChar (ord i)
ns <- getNewVar ns <- getNewVar
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable ns (Icmp LLEq (type2LlvmType ct) vs i') emit $ SetVariable ns (Icmp LLEq (type2LlvmType t) vs i')
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos emit $ Label lbl_succPos
val <- exprToValue exp val <- exprToValue exp
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do emitCases rt ty label stackPtr vs (Branch (PVar id, _) exp) = do
emit $ Comment "Pvar" emit $ Comment "Pvar"
-- //TODO this is pretty disgusting and would heavily benefit from a rewrite -- //TODO this is pretty disgusting and would heavily benefit from a rewrite
valPtr <- getNewVar valPtr <- getNewVar
@ -263,20 +316,20 @@ emitECased t e cases = do
val <- exprToValue exp val <- exprToValue exp
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "True$Bool"), t) exp) = do emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "True$Bool"), t) exp) = do
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 1, TLit "Bool"), t) exp) emitCases rt ty label stackPtr vs (Branch (PLit $ LInt 1, t) exp)
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "False$Bool"), _) exp) = do emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "False$Bool"), _) exp) = do
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 0, TLit "Bool"), t) exp) emitCases rt ty label stackPtr vs (Branch (PLit (LInt 0), t) exp)
emitCases rt ty label stackPtr vs br@(Branch (MIR.PEnum consId, _) exp) = do emitCases rt ty label stackPtr vs br@(Branch (PEnum consId, _) exp) = do
emit $ Comment "Penum" emit $ Comment "Penum"
cons <- gets constructors cons <- gets constructors
let r = Map.lookup consId cons let r = Map.lookup consId cons
when (isNothing r) (error $ "Constructor: '" ++ printTree consId ++ "' does not exist in cons state:\n" ++ show cons ++ "\nin pattern\n'" ++ printTree br ++ "'\n") when (isNothing r) (error $ "Constructor: '" ++ printTree consId ++ "' does not exist in cons state:\n" ++ show cons ++ "\nin pattern\n'" ++ printTree br ++ "'\n")
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
consVal <- getNewVar consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0) emit $ SetVariable consVal (ExtractValue rt vs 0)
@ -295,24 +348,17 @@ emitECased t e cases = do
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitCases _ ty label stackPtr _ (Branch (MIR.PCatch, _) exp) = do emitCases _ ty label stackPtr _ (Branch (PCatch, _) exp) = do
emit $ Comment "Pcatch" emit $ Comment "Pcatch"
val <- exprToValue exp val <- exprToValue exp
emit $ Store ty val Ptr stackPtr emit $ Store ty val Ptr stackPtr
emit $ Br label emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos emit $ Label lbl_failPos
emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState () emitApp :: Type -> T Exp -> T Exp -> CompilerState ()
emitApp rt e1 e2 = appEmitter e1 e2 [] emitApp rt e1 e2 = do
where ((EVar name, t), args) <- go (EApp e1 e2, rt)
appEmitter :: ExpT -> ExpT -> [ExpT] -> CompilerState ()
appEmitter e1 e2 stack = do
let newStack = e2 : stack
case e1 of
(MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack
(MIR.EVar name, t) -> do
args <- traverse exprToValue newStack
vs <- getNewVar vs <- getNewVar
funcs <- gets functions funcs <- gets functions
consts <- gets constructors consts <- gets constructors
@ -321,72 +367,147 @@ emitApp rt e1 e2 = appEmitter e1 e2 []
Global <$ Map.lookup name consts Global <$ Map.lookup name consts
<|> Global <$ Map.lookup (name, t) funcs <|> Global <$ Map.lookup (name, t) funcs
-- this piece of code could probably be improved, i.e remove the double `const Global` -- this piece of code could probably be improved, i.e remove the double `const Global`
args' = map (first valueGetType . dupe) args
let call =
case name of
TIR.Ident ('l' : 't' : '$' : _) -> Icmp LLSlt I64 (snd (head args')) (snd (args' !! 1))
TIR.Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) -> Sub I64 (snd (head args')) (snd (args' !! 1))
_ -> Call FastCC (type2LlvmType rt) visibility name args'
emit $ Comment $ show rt
emit $ SetVariable vs call
x -> error $ "The unspeakable happened: " <> show x
emitIdent :: TIR.Ident -> CompilerState () call <- case name of
Ident ('l' : 't' : '$' : _) ->
pure $ Icmp LLSlt I64 (snd (head args)) (snd (args !! 1))
Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) ->
pure $ Sub I64 (snd (head args)) (snd (args !! 1))
-- FIXME
_ -> do
let closure_call LocalElem { typ = Ptr, val } = (mkDerefName name, (Ptr, val) : args)
(name, args) <- gets $ maybe (name, (Ptr, VNull) : args) closure_call
. lookup name
. locals
pure $ Call FastCC (type2LlvmType rt) visibility name args
emit $ Comment $ show (type2LlvmType rt)
emit $ SetVariable vs call
where
go :: T Exp -> CompilerState (T Exp, [(LLVMType, LLVMValue)])
go et@(e, _) = case e of
EApp e1 e2@(_, t) -> do
(x, as) <- go e1
a <- exprToValue e2
let t' = type2LlvmType' t
pure (x, snoc (t', a) as)
_ -> pure (et, [])
type2LlvmType' = \case
TFun _ _ -> Ptr
t -> type2LlvmType t
emitIdent :: Ident -> CompilerState ()
emitIdent id = do emitIdent id = do
-- !!this should never happen!! -- !!this should never happen!!
emit $ Comment "This should not have happened!" emit $ Comment "This should not have happened!"
emit $ Variable id emit $ Variable id
emit $ UnsafeRaw "\n" emit $ UnsafeRaw "\n"
emitLit :: MIR.Lit -> CompilerState () emitLit :: Lit -> CompilerState ()
emitLit i = do emitLit i = do
-- !!this should never happen!! -- !!this should never happen!!
let (i', t) = case i of let (i', t) = case i of
(MIR.LInt i'') -> (VInteger i'', I64) (LInt i'') -> (VInteger i'', I64)
(MIR.LChar i'') -> (VChar $ ord i'', I8) (LChar i'') -> (VChar $ ord i'', I8)
varCount <- getNewVar varCount <- getNewVar
emit $ Comment "This should not have happened!" emit $ Comment "This should not have happened!"
emit $ SetVariable varCount (Add t i' (VInteger 0)) emit $ SetVariable varCount (Add t i' (VInteger 0))
emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState () emitAdd :: Type -> T Exp -> T Exp -> CompilerState ()
emitAdd t e1 e2 = do emitAdd t e1 e2 = do
v1 <- exprToValue e1 v1 <- exprToValue e1
v2 <- exprToValue e2 v2 <- exprToValue e2
v <- getNewVar v <- getNewVar
emit $ SetVariable v (Add (type2LlvmType t) v1 v2) emit $ SetVariable v (Add (type2LlvmType t) v1 v2)
exprToValue :: ExpT -> CompilerState LLVMValue
exprToValue = \case exprToValue :: T Exp -> CompilerState LLVMValue
(MIR.ELit i, _t) -> pure $ case i of exprToValue et@(e, t) = case e of
(MIR.LInt i) -> VInteger i ELit (LInt i) -> pure $ VInteger i
(MIR.LChar i) -> VChar $ ord i ELit (LChar c) -> pure . VChar $ ord c
(MIR.EVar (TIR.Ident "True$Bool"), _t) -> pure $ VInteger 1
(MIR.EVar (TIR.Ident "False$Bool"), _t) -> pure $ VInteger 0 EVar "True$Bool" -> pure $ VInteger 1
(MIR.EVar name, t) -> do EVar "False$Bool" -> pure $ VInteger 0
funcs <- gets functions
cons <- gets constructors EVar name -> gets (Map.lookup name . globals) >>= \case
let res = Just (typ@(Function _ ts), val) | length ts > 1 -> do
Map.lookup (name, t) funcs type_struct <- addStructType (mkClosureName name) [typ]
<|> ( \c -> emit $ Comment "Allocating structure"
FunctionInfo emit . SetVariable name $ Alloca type_struct
{ numArgs = numArgsCI c emit $ Store typ val Ptr name
, arguments = argumentsCI c pure $ VIdent name Ptr
}
) Just _ | name == "main" -> do
<$> Map.lookup name cons
case res of
Just fi -> do
if numArgs fi == 0
then do
vc <- getNewVar vc <- getNewVar
emit $ emit $ SetVariable vc (Call FastCC I64 Global name [])
SetVariable pure $ VIdent vc I64
vc
(Call FastCC (type2LlvmType t) Global name [])
Just (Function t_return [_], _) -> do
vc <- getNewVar
emit $ SetVariable vc (Call FastCC t_return Global name [(Ptr, VNull)])
pure $ VIdent vc t_return
Just _ -> error "Bad"
Nothing -> gets (Map.lookup name . constructors) >>= \case
Just ConstructorInfo {numArgsCI}
| numArgsCI == 0 -> do
vc <- getNewVar
emit $ SetVariable vc call
pure $ VIdent vc (type2LlvmType t) pure $ VIdent vc (type2LlvmType t)
else pure $ VFunction name Global (type2LlvmType t) | otherwise -> pure $ VFunction name Global (type2LlvmType t)
Nothing -> pure $ VIdent name (type2LlvmType t) where
e -> do call = Call FastCC (type2LlvmType t) Global name []
compileExp e
Nothing -> gets $ val
. fromJust
. lookup name
. locals
EVarC cxt name -> do
let cxt' = flip map cxt $ \(x, t) -> let t' = type2LlvmType t
in (t', VIdent x t')
cxt'' <- gets $ (:cxt')
. fromJust
. Map.lookup name
. globals
-- Create a new type for function pointer and arguments
type_struct <- addStructType (mkClosureName name) $ map fst cxt''
emit $ Comment "Allocating structure"
emit . SetVariable name $ Alloca type_struct
let ptr_struct = VIdent name Ptr
storeArg (t, v) i = do
vc <- getNewVar
emit . SetVariable vc
$ GetElementPtrInbounds type_struct Ptr ptr_struct
I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices
emit $ Store t v Ptr vc
-- Store arguments in structure
zipWithM_ storeArg cxt'' [0 ..]
pure ptr_struct
_ -> do
compileExp et
v <- getVarCount v <- getVarCount
pure $ VIdent (TIR.Ident $ show v) (getType e) pure $ VIdent (Ident $ show v) (getType et)
mkClosureName :: Ident -> Ident
mkClosureName (Ident s) = Ident $ "Closure_" ++ s
mkDerefName :: Ident -> Ident
mkDerefName (Ident s) = Ident $ s ++ "_deref"
mkCxtName :: Ident
mkCxtName = Ident "cxt"

View file

@ -9,6 +9,7 @@ module Codegen.LlvmIr (
Visibility (..), Visibility (..),
CallingConvention (..), CallingConvention (..),
ToIr (..), ToIr (..),
typeOf
) where ) where
import Data.List (intercalate) import Data.List (intercalate)
@ -38,6 +39,9 @@ data LLVMType
class ToIr a where class ToIr a where
toIr :: a -> String toIr :: a -> String
instance ToIr a => ToIr [a] where
toIr = concatMap toIr
instance ToIr LLVMType where instance ToIr LLVMType where
toIr :: LLVMType -> String toIr :: LLVMType -> String
toIr = \case toIr = \case
@ -92,6 +96,7 @@ data LLVMValue
| VIdent Ident LLVMType | VIdent Ident LLVMType
| VConstant String | VConstant String
| VFunction Ident Visibility LLVMType | VFunction Ident Visibility LLVMType
| VNull
deriving (Show, Eq, Ord) deriving (Show, Eq, Ord)
instance ToIr LLVMValue where instance ToIr LLVMValue where
@ -102,6 +107,7 @@ instance ToIr LLVMValue where
VIdent (Ident n) _ -> "%" <> n VIdent (Ident n) _ -> "%" <> n
VFunction (Ident n) vis _ -> toIr vis <> n VFunction (Ident n) vis _ -> toIr vis <> n
VConstant s -> "c" <> show s VConstant s -> "c" <> show s
VNull -> "null"
type Params = [(Ident, LLVMType)] type Params = [(Ident, LLVMType)]
type Args = [(LLVMType, LLVMValue)] type Args = [(LLVMType, LLVMValue)]
@ -139,6 +145,21 @@ data LLVMIr
-- instructions should be used in its place -- instructions should be used in its place
deriving (Show, Eq, Ord) deriving (Show, Eq, Ord)
-- TODO add missing clauses
typeOf :: LLVMIr -> LLVMType
typeOf = \case
Add t _ _ -> t
Sub t _ _ -> t
Mul t _ _ -> t
Div t _ _ -> t
Load t _ _ -> t
Store t _ _ _ -> t
Type x _ -> CustomType x
SetVariable _ ir -> typeOf ir
-- | Converts a list of LLVMIr instructions to a string -- | Converts a list of LLVMIr instructions to a string
llvmIrToString :: [LLVMIr] -> String llvmIrToString :: [LLVMIr] -> String
llvmIrToString = go 0 llvmIrToString = go 0

View file

@ -11,9 +11,11 @@ import Control.Monad.State (MonadState (get, put), State,
evalState) evalState)
import Data.Function (on) import Data.Function (on)
import Data.List (delete, mapAccumL, (\\)) import Data.List (delete, mapAccumL, (\\))
import Data.Tuple.Extra (first, second)
import LambdaLifterIr (T)
import qualified LambdaLifterIr as L
import Prelude hiding (exp) import Prelude hiding (exp)
import TypeChecker.TypeCheckerIr import TypeChecker.TypeCheckerIr hiding (T)
-- | Lift lambdas and let expression into supercombinators. -- | Lift lambdas and let expression into supercombinators.
-- Three phases: -- Three phases:
@ -21,12 +23,13 @@ import TypeChecker.TypeCheckerIr
-- @abstract@ converts lambdas into let expressions. -- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function. -- @collectScs@ moves every non-constant let expression to a top-level function.
-- --
lambdaLift :: Program -> Program lambdaLift :: Program -> L.Program
lambdaLift (Program ds) = Program (datatypes ++ binds) lambdaLift (Program ds) = L.Program (datatypes ++ binds)
where where
datatypes = flip filter ds $ \case DData _ -> True datatypes = [L.DData (toLirData d) | DData d <- ds]
_ -> False
binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds] binds = map L.DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
-- | Annotate free variables -- | Annotate free variables
freeVars :: [Bind] -> [ABind] freeVars :: [Bind] -> [ABind]
@ -36,7 +39,7 @@ freeVars binds = [ let ae = freeVarsExp [] e
| Bind n xs e <- binds | Bind n xs e <- binds
] ]
freeVarsExp :: Frees -> ExpT -> Ann AExpT freeVarsExp :: Frees -> T Exp -> Ann (T AExp)
freeVarsExp localVars (ae, t) = case ae of freeVarsExp localVars (ae, t) = case ae of
EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)] EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)]
, term = (AVar n, t) , term = (AVar n, t)
@ -121,27 +124,47 @@ data Ann a = Ann
, term :: a , term :: a
} deriving (Show, Eq) } deriving (Show, Eq)
data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq) data ABind = ABind (T Ident) [T Ident] (Ann (T AExp)) deriving (Show, Eq)
data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq) data ABranch = ABranch (Pattern, Type) (Ann (T AExp)) deriving (Show, Eq)
type AExpT = (AExp, Type)
data AExp = AVar Ident data AExp = AVar Ident
| AInj Ident | AInj Ident
| ALit Lit | ALit Lit
| ALet (Ann ABind) (Ann AExpT) | ALet (Ann ABind) (Ann (T AExp))
| AApp (Ann AExpT) (Ann AExpT) | AApp (Ann (T AExp)) (Ann (T AExp))
| AAdd (Ann AExpT) (Ann AExpT) | AAdd (Ann (T AExp)) (Ann (T AExp))
| AAbs Ident (Ann AExpT) | AAbs Ident (Ann (T AExp))
| ACase (Ann AExpT) [Ann ABranch] | ACase (Ann (T AExp)) [Ann ABranch]
deriving (Show, Eq) deriving (Show, Eq)
abstract :: [ABind] -> [Bind]
data BBind = BBind (T Ident) [T Ident] (T BExp)
| BBindCxt [T Ident] (T Ident) [T Ident] (T BExp)
deriving (Eq, Ord, Show)
data BBranch = BBranch (T Pattern) (T BExp)
deriving (Eq, Ord, Show)
data BExp
= BVar Ident
| BVarC [T Ident] Ident
| BInj Ident
| BLit Lit
| BLet BBind (T BExp)
| BApp (T BExp)(T BExp)
| BAdd (T BExp)(T BExp)
| BCase (T BExp) [BBranch]
deriving (Eq, Ord, Show)
abstract :: [ABind] -> [BBind]
abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0 abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0
abstractAnnBind :: Ann ABind -> State Int Bind abstractAnnBind :: Ann ABind -> State Int BBind
abstractAnnBind Ann { term = ABind name vars annae } = abstractAnnBind Ann { term = ABind name vars annae } =
Bind name (vars' <|| vars) <$> abstractAnnExp annae' BBind name (vars' <|| vars) <$> abstractAnnExp annae'
where where
(annae', vars') = go [] annae (annae', vars') = go [] annae
where where
@ -149,24 +172,27 @@ abstractAnnBind Ann { term = ABind name vars annae } =
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc) ae -> (ae, acc)
abstractAnnExp :: Ann AExpT -> State Int ExpT abstractAnnExp :: Ann (T AExp) -> State Int (T BExp)
abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
AVar n -> pure (EVar n, typ) AVar n -> pure (BVar n, typ)
AInj n -> pure (EInj n, typ) AInj n -> pure (BInj n, typ)
ALit lit -> pure (ELit lit, typ) ALit lit -> pure (BLit lit, typ)
AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2 AApp annae1 annae2 -> (, typ) <$> onM BApp abstractAnnExp annae1 annae2
AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2 AAdd annae1 annae2 -> (, typ) <$> onM BAdd abstractAnnExp annae1 annae2
-- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc
AAbs x annae' -> do AAbs x annae' -> do
i <- nextNumber i <- nextNumber
rhs <- abstractAnnExp annae'' rhs <- abstractAnnExp annae''
let sc_name = Ident ("sc_" ++ show i) let sc_name = Ident ("sc_" ++ show i)
e@(_, t) = foldl applyFree (EVar sc_name, typ) frees sc | null frees = (BVar sc_name, typ)
pure (ELet (Bind (sc_name, typ) vars rhs) e ,t) | otherwise = (BVarC frees sc_name, typ)
bind | null frees = BBind (sc_name, typ) vars rhs
| otherwise = BBindCxt frees (sc_name, typ) vars rhs
pure (BLet bind sc ,typ)
where where
vars = frees <| (x, t_x) <|| ys vars = [(x, t_x)] <|| ys
t_x = case typ of TFun t _ -> t t_x = case typ of TFun t _ -> t
_ -> error "Impossible" _ -> error "Impossible"
@ -176,54 +202,47 @@ abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc) ae -> (ae, acc)
applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type)
applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e')
where
t_e' = case t_e of TFun _ t -> t
_ -> error "Impossible"
ACase annae' bs -> do ACase annae' bs -> do
bs <- mapM go bs bs <- mapM go bs
e <- abstractAnnExp annae' e <- abstractAnnExp annae'
pure (ECase e bs, typ) pure (BCase e bs, typ)
where where
go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae go Ann { term = ABranch p annae } = BBranch p <$> abstractAnnExp annae
ALet b annae' -> ALet b annae' ->
(, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae') (, typ) <$> liftA2 BLet (abstractAnnBind b) (abstractAnnExp annae')
-- | Collects supercombinators by lifting non-constant let expressions -- | Collects supercombinators by lifting non-constant let expressions
collectScs :: [Bind] -> [Bind] collectScs :: [BBind] -> [L.Bind]
collectScs = concatMap collectFromRhs collectScs = concatMap collectFromRhs
where where
collectFromRhs (Bind name parms rhs) = collectFromRhs (BBind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs in L.Bind name parms rhs' : rhs_scs
collectFromRhs (BBindCxt cxt name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in L.BindC cxt name parms rhs' : rhs_scs
collectScsExp :: ExpT -> ([Bind], ExpT) collectScsExp :: T BExp -> ([L.Bind], T L.Exp)
collectScsExp expT@(exp, typ) = case exp of collectScsExp (exp, typ) = case exp of
EVar _ -> ([], expT) BVar x -> ([], (L.EVar x, typ))
EInj _ -> ([], expT) BVarC as x -> ([], (L.EVarC as x, typ))
ELit _ -> ([], expT) BInj k -> ([], (L.EInj k, typ))
BLit lit -> ([], (L.ELit lit, typ))
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ)) BApp e1 e2 -> (scs1 ++ scs2, (L.EApp e1' e2', typ))
where where
(scs1, e1') = collectScsExp e1 (scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2 (scs2, e2') = collectScsExp e2
EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ)) BAdd e1 e2 -> (scs1 ++ scs2, (L.EAdd e1' e2', typ))
where where
(scs1, e1') = collectScsExp e1 (scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2 (scs2, e2') = collectScsExp e2
EAbs par e -> (scs, (EAbs par e', typ)) BCase e branches -> (scs ++ scs_e, (L.ECase e' branches', typ))
where
(scs, e') = collectScsExp e
ECase e branches -> (scs ++ scs_e, (ECase e' branches', typ))
where where
(scs, branches') = mapAccumL f [] branches (scs, branches') = mapAccumL f [] branches
(scs_e, e') = collectScsExp e (scs_e, e') = collectScsExp e
@ -234,15 +253,24 @@ collectScsExp expT@(exp, typ) = case exp of
-- --
-- > f = let sc x y = rhs in e -- > f = let sc x y = rhs in e
-- --
ELet (Bind name parms rhs) e BLet (BBind name parms rhs) e
| null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et')) | null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
| otherwise -> (bind : rhs_scs ++ et_scs, et') | otherwise -> (bind : rhs_scs ++ et_scs, et')
where where
bind = Bind name parms rhs' bind = L.Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs (rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e (et_scs, et') = collectScsExp e
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
BLet (BBindCxt cxt name parms rhs) e
| null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
| otherwise -> (bind : rhs_scs ++ et_scs, et')
where
bind = L.BindC cxt name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e
collectScsBranch (BBranch patt exp) = (scs, L.Branch (first toLirPattern patt) exp')
where (scs, exp') = collectScsExp exp where (scs, exp') = collectScsExp exp
nextNumber :: State Int Int nextNumber :: State Int Int
@ -259,3 +287,13 @@ xs <| x | elem x xs = xs
(<||) :: Eq a => [a] -> [a] -> [a] (<||) :: Eq a => [a] -> [a] -> [a]
xs <|| ys = foldl (<|) xs ys xs <|| ys = foldl (<|) xs ys
toLirData (Data t injs) = L.Data t (map toLirInj injs)
toLirInj (Inj n t) = L.Inj n t
toLirPattern :: Pattern -> L.Pattern
toLirPattern = \case
PVar x -> L.PVar x
PLit lit -> L.PLit lit
PCatch -> L.PCatch
PEnum k -> L.PEnum k
PInj k ps -> L.PInj k (map (first toLirPattern) ps)

140
src/LambdaLifterIr.hs Normal file
View file

@ -0,0 +1,140 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
module LambdaLifterIr (
module Grammar.Abs,
module LambdaLifterIr,
module TypeChecker.TypeCheckerIr
) where
import Data.List (intercalate)
import Grammar.Abs (Lit (..))
import Grammar.Print
import Prelude hiding (exp)
import qualified Prelude as C (Eq, Ord, Show)
import TypeChecker.TypeCheckerIr (Ident (..), TVar (..), Type (..))
newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show)
data Def
= DBind Bind
| DData Data
deriving (C.Eq, C.Ord, C.Show)
data Data = Data Type [Inj]
deriving (C.Eq, C.Ord, C.Show)
data Inj = Inj Ident Type
deriving (C.Eq, C.Ord, C.Show)
data Pattern
= PVar Ident
| PLit Lit
| PCatch
| PEnum Ident
| PInj Ident [(Pattern, Type)]
deriving (C.Eq, C.Ord, C.Show)
data Exp
= EVar Ident
| EVarC [T Ident] Ident
| EInj Ident
| ELit Lit
| ELet (T Ident) (T Exp) (T Exp)
| EApp (T Exp)(T Exp)
| EAdd (T Exp)(T Exp)
| ECase (T Exp) [Branch]
deriving (C.Eq, C.Ord, C.Show)
type T a = (a, Type)
data Bind = Bind (T Ident) [T Ident] (T Exp)
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
deriving (C.Eq, C.Ord, C.Show)
data Branch = Branch (T Pattern) (T Exp)
deriving (C.Eq, C.Ord, C.Show)
instance Print Program where
prt i (Program sc) = prt i sc
instance Print Bind where
prt i (Bind sig parms rhs) = concatD
[ prt i sig
, prt i parms
, doc $ showString "="
, prt i rhs
]
prt i (BindC cxt sig parms rhs) =
prPrec i 0 $
concatD
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
, prt i parms
, doc $ showString "="
, prt i rhs
]
instance Print [Bind] where
prt _ [] = concatD []
prt i [x] = concatD [prt i x]
prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs]
instance Print Exp where
prt i = \case
EVar lident -> prPrec i 3 (concatD [prt 0 lident])
EVarC as lident -> doc . showString
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
where
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
EInj uident -> prPrec i 3 (concatD [prt 0 uident])
ELit lit -> prPrec i 3 (concatD [prt 0 lit])
EApp exp1 exp2 -> prPrec i 2 (concatD [prt 2 exp1, prt 3 exp2])
EAdd exp1 exp2 -> prPrec i 1 (concatD [prt 1 exp1, doc (showString "+"), prt 2 exp2])
ELet lident exp1 exp2 -> prPrec i 0 (concatD [doc (showString "let"), prt 0 lident, doc (showString "="), prt 0 exp1 , doc (showString "in"), prt 0 exp2])
ECase exp branchs -> prPrec i 0 (concatD [doc (showString "case"), prt 0 exp, doc (showString "of"), doc (showString "{"), prt 0 branchs, doc (showString "}")])
instance Print Branch where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
instance Print [Branch] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print Def where
prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print Data where
prt i = \case
Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")])
instance Print Inj where
prt i = \case
Inj uident type_ -> prt i (uident, type_)
instance Print [Inj] where
prt _ [] = concatD []
prt i [x] = prt i x
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
instance Print Pattern where
prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print [Def] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
pattern DBind' id vars expt = DBind (Bind id vars expt)
pattern DData' typ injs = DData (Data typ injs)

View file

@ -19,27 +19,18 @@ import Monomorphizer.Monomorphizer (monomorphize)
import OrderDefs (orderDefs) import OrderDefs (orderDefs)
import Renamer.Renamer (rename) import Renamer.Renamer (rename)
import ReportForall (reportForall) import ReportForall (reportForall)
import System.Console.GetOpt ( import System.Console.GetOpt (ArgDescr (NoArg, ReqArg),
ArgDescr (NoArg, ReqArg),
ArgOrder (RequireOrder), ArgOrder (RequireOrder),
OptDescr (Option), OptDescr (Option), getOpt,
getOpt, usageInfo)
usageInfo, import System.Directory (createDirectory, doesPathExist,
)
import System.Directory (
createDirectory,
doesPathExist,
getDirectoryContents, getDirectoryContents,
removeDirectoryRecursive, removeDirectoryRecursive,
setCurrentDirectory, setCurrentDirectory)
)
import System.Environment (getArgs) import System.Environment (getArgs)
import System.Exit ( import System.Exit (ExitCode (ExitFailure),
ExitCode (ExitFailure), exitFailure, exitSuccess,
exitFailure, exitWith)
exitSuccess,
exitWith,
)
import System.IO (stderr) import System.IO (stderr)
import System.Process (spawnCommand, waitForProcess) import System.Process (spawnCommand, waitForProcess)
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck) import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck)

View file

@ -1,8 +1,11 @@
module Monomorphizer.DataTypeRemover (removeDataTypes) where module Monomorphizer.DataTypeRemover (removeDataTypes) where
import Monomorphizer.MonomorphizerIr qualified as M2 import Data.Bifunctor (Bifunctor (bimap))
import Monomorphizer.MorbIr qualified as M1 import Monomorphizer.MonomorphizerIr (Ident (..))
import TypeChecker.TypeCheckerIr (Ident (Ident)) import qualified Monomorphizer.MonomorphizerIr as M2
import qualified Monomorphizer.MorbIr as M1
import Prelude hiding (exp)
removeDataTypes :: M1.Program -> M2.Program removeDataTypes :: M1.Program -> M2.Program
removeDataTypes (M1.Program defs) = M2.Program (map pDef defs) removeDataTypes (M1.Program defs) = M2.Program (map pDef defs)
@ -30,16 +33,19 @@ newName (M1.TData (Ident str) args) = str ++ concatMap newName args
pBind :: M1.Bind -> M2.Bind pBind :: M1.Bind -> M2.Bind
pBind (M1.Bind id argIds expt) = M2.Bind (pId id) (map pId argIds) (pExpT expt) pBind (M1.Bind id argIds expt) = M2.Bind (pId id) (map pId argIds) (pExpT expt)
pBind (M1.BindC cxt id argIds expt) =
M2.BindC (map pId cxt) (pId id) (map pId argIds) (pExpT expt)
pId :: (Ident, M1.Type) -> (Ident, M2.Type) pId :: (Ident, M1.Type) -> (Ident, M2.Type)
pId (ident, t) = (ident, pType t) pId (ident, t) = (ident, pType t)
pExpT :: M1.ExpT -> M2.ExpT pExpT :: M1.T M1.Exp -> M2.T M2.Exp
pExpT (exp, t) = (pExp exp, pType t) pExpT (exp, t) = (pExp exp, pType t)
pExp :: M1.Exp -> M2.Exp pExp :: M1.Exp -> M2.Exp
pExp (M1.EVar ident) = M2.EVar ident pExp (M1.EVar ident) = M2.EVar ident
pExp (M1.ELit lit) = M2.ELit (pLit lit) pExp (M1.EVarC as ident) = M2.EVarC (map pId as) ident
pExp (M1.ELit lit) = M2.ELit lit
pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt) pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt)
pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2) pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2)
pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2) pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2)
@ -49,12 +55,9 @@ pBranch :: M1.Branch -> M2.Branch
pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt) pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt)
pPattern :: M1.Pattern -> M2.Pattern pPattern :: M1.Pattern -> M2.Pattern
pPattern (M1.PVar id) = M2.PVar (pId id) pPattern (M1.PVar ident) = M2.PVar ident
pPattern (M1.PLit (lit, t)) = M2.PLit (pLit lit, pType t) pPattern (M1.PLit lit) = M2.PLit lit
pPattern (M1.PInj ident patts) = M2.PInj ident (map pPattern patts) pPattern (M1.PInj ident patts) = M2.PInj ident (map (bimap pPattern pType) patts)
pPattern M1.PCatch = M2.PCatch pPattern M1.PCatch = M2.PCatch
pPattern (M1.PEnum ident) = M2.PEnum ident pPattern (M1.PEnum ident) = M2.PEnum ident
pLit :: M1.Lit -> M2.Lit
pLit (M1.LInt v) = M2.LInt v
pLit (M1.LChar c) = M2.LChar c

View file

@ -1,4 +1,5 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{- | For now, converts polymorphic functions to concrete ones based on usage. {- | For now, converts polymorphic functions to concrete ones based on usage.
Assumes lambdas are lifted. Assumes lambdas are lifted.
@ -25,30 +26,35 @@ bind) is added to the resulting set of binds.
-} -}
module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where
import Monomorphizer.DataTypeRemover (removeDataTypes)
import Monomorphizer.MonomorphizerIr qualified as O
import Monomorphizer.MorbIr qualified as M
import TypeChecker.TypeCheckerIr (Ident (Ident))
import TypeChecker.TypeCheckerIr qualified as T
import Control.Monad.Reader ( import Control.Monad.Reader (MonadReader (ask, local),
MonadReader (ask, local), Reader, asks, runReader)
Reader, import Control.Monad.State (MonadState (get),
asks, StateT (runStateT), gets,
runReader, modify)
)
import Control.Monad.State (
MonadState (get),
StateT (runStateT),
gets,
modify,
)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Map qualified as Map import qualified Data.Map as Map
import Data.Maybe (catMaybes) import Data.Maybe (catMaybes)
import Data.Set qualified as Set import qualified Data.Set as Set
import Grammar.Print (printTree)
import Debug.Trace (trace) import Debug.Trace (trace)
import Grammar.Print (printTree)
import Monomorphizer.DataTypeRemover (removeDataTypes)
import qualified Monomorphizer.MonomorphizerIr as O
import qualified Monomorphizer.MorbIr as M
-- import TypeChecker.TypeCheckerIr (Ident (Ident))
import LambdaLifterIr (Ident (..))
-- import TypeChecker.TypeCheckerIr qualified as T
import qualified LambdaLifterIr as L
import Control.Monad.Reader (MonadReader (ask, local),
Reader, asks, runReader)
import Control.Monad.State (MonadState, StateT (runStateT),
gets, modify)
import qualified Data.Map as Map
import Data.Maybe (catMaybes, fromJust)
import qualified Data.Set as Set
import Data.Tuple.Extra (secondM)
import Grammar.Print (printTree)
{- | EnvM is the monad containing the read-only state as well as the {- | EnvM is the monad containing the read-only state as well as the
output state containing monomorphized functions and to-be monomorphized output state containing monomorphized functions and to-be monomorphized
@ -64,13 +70,13 @@ Binds, Polymorphic Data types (monomorphized in a later step) and
Marked bind, which means that it is in the process of monomorphization Marked bind, which means that it is in the process of monomorphization
and should not be monomorphized again. and should not be monomorphized again.
-} -}
data Outputted = Marked | Complete M.Bind | Data M.Type T.Data deriving (Show) data Outputted = Marked | Complete M.Bind | Data M.Type L.Data deriving (Show)
-- | Static environment. -- | Static environment.
data Env = Env data Env = Env
{ input :: Map.Map Ident T.Bind { input :: Map.Map Ident L.Bind
-- ^ All binds in the program. -- ^ All binds in the program.
, dataDefs :: Map.Map Ident T.Data , dataDefs :: Map.Map Ident L.Data
-- ^ All constructors mapped to their respective polymorphic data def -- ^ All constructors mapped to their respective polymorphic data def
-- which includes all other constructors. -- which includes all other constructors.
, polys :: Map.Map Ident M.Type , polys :: Map.Map Ident M.Type
@ -84,12 +90,13 @@ localExists :: Ident -> EnvM Bool
localExists ident = asks (Set.member ident . locals) localExists ident = asks (Set.member ident . locals)
-- | Gets a polymorphic bind from an id. -- | Gets a polymorphic bind from an id.
getInputBind :: Ident -> EnvM (Maybe T.Bind) getInputBind :: Ident -> EnvM (Maybe L.Bind)
getInputBind ident = asks (Map.lookup ident . input) getInputBind ident = asks (Map.lookup ident . input)
-- | Add monomorphic function derived from a polymorphic one, to env. -- | Add monomorphic function derived from a polymorphic one, to env.
addOutputBind :: M.Bind -> EnvM () addOutputBind :: M.Bind -> EnvM ()
addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b)) addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b))
addOutputBind b@(M.BindC _ (ident, _) _ _) = modify (Map.insert ident (Complete b))
{- | Marks a global bind as being processed, meaning that when encountered again, {- | Marks a global bind as being processed, meaning that when encountered again,
it should not be recursively processed. it should not be recursively processed.
@ -106,8 +113,8 @@ isConsMarked :: Ident -> EnvM Bool
isConsMarked ident = gets (Map.member ident) isConsMarked ident = gets (Map.member ident)
-- | Finds main bind. -- | Finds main bind.
getMain :: EnvM T.Bind getMain :: EnvM L.Bind
getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of getMain = asks (\env -> case Map.lookup (Ident "main") (input env) of
Just mainBind -> mainBind Just mainBind -> mainBind
Nothing -> error "main not found in monomorphizer!" Nothing -> error "main not found in monomorphizer!"
) )
@ -116,13 +123,13 @@ getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of
error when encountering different structures between the two arguments. Debug: error when encountering different structures between the two arguments. Debug:
First argument is the name of the bind. First argument is the name of the bind.
-} -}
mapTypes :: Ident -> T.Type -> M.Type -> [(Ident, M.Type)] mapTypes :: Ident -> L.Type -> M.Type -> [(Ident, M.Type)]
mapTypes _ident (T.TLit _) (M.TLit _) = [] mapTypes _ident (L.TLit _) (M.TLit _) = []
mapTypes _ident (T.TVar (T.MkTVar i1)) tm = [(i1, tm)] mapTypes _ident (L.TVar (L.MkTVar i1)) tm = [(i1, tm)]
mapTypes ident (T.TFun pt1 pt2) (M.TFun mt1 mt2) = mapTypes ident (L.TFun pt1 pt2) (M.TFun mt1 mt2) =
mapTypes ident pt1 mt1 mapTypes ident pt1 mt1
++ mapTypes ident pt2 mt2 ++ mapTypes ident pt2 mt2
mapTypes ident (T.TData tIdent pTs) (M.TData mIdent mTs) = mapTypes ident (L.TData tIdent pTs) (M.TData mIdent mTs) =
if tIdent /= mIdent if tIdent /= mIdent
then error "the data type names of monomorphic and polymorphic data types does not match" then error "the data type names of monomorphic and polymorphic data types does not match"
else foldl (\xs (p, m) -> mapTypes ident p m ++ xs) [] (zip pTs mTs) else foldl (\xs (p, m) -> mapTypes ident p m ++ xs) [] (zip pTs mTs)
@ -130,30 +137,30 @@ mapTypes ident t1 t2 = error $ "in bind: '" ++ printTree ident ++ "', " ++
"structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'" "structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'"
-- | Gets the mapped monomorphic type of a polymorphic type in the current context. -- | Gets the mapped monomorphic type of a polymorphic type in the current context.
getMonoFromPoly :: T.Type -> EnvM M.Type getMonoFromPoly :: L.Type -> EnvM M.Type
getMonoFromPoly t = do getMonoFromPoly t = do
env <- ask env <- ask
return $ getMono (polys env) t return $ getMono (polys env) t
where where
getMono :: Map.Map Ident M.Type -> T.Type -> M.Type getMono :: Map.Map Ident M.Type -> L.Type -> M.Type
getMono polys t = case t of getMono polys t = case t of
(T.TLit ident) -> M.TLit (coerce ident) (L.TLit ident) -> M.TLit ident
(T.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2) (L.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2)
(T.TVar (T.MkTVar ident)) -> case Map.lookup ident polys of (L.TVar (L.MkTVar ident)) -> case Map.lookup ident polys of
Just concrete -> concrete Just concrete -> concrete
Nothing -> M.TLit (Ident "void") Nothing -> M.TLit (Ident "void")
-- error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps" -- error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps"
(T.TData ident args) -> M.TData ident (map (getMono polys) args) (L.TData ident args) -> M.TData ident (map (getMono polys) args)
{- | If ident not already in env's output, morphed bind to output {- | If ident not already in env's output, morphed bind to output
(and all referenced binds within this bind). (and all referenced binds within this bind).
Returns the annotated bind name. Returns the annotated bind name.
-} -}
morphBind :: M.Type -> T.Bind -> EnvM Ident morphBind :: M.Type -> L.Bind -> EnvM Ident
morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do morphBind expectedType b@(L.Bind (ident, btype) args (exp, expt)) = do
-- The "new name" is used to find out if it is already marked or not. -- The "new name" is used to find out if it is already marked or not.
let name' = newFuncName expectedType b let name' = newFuncName expectedType b
bindMarked <- isBindMarked (coerce name') bindMarked <- isBindMarked name'
local local
( \env -> ( \env ->
env env
@ -168,26 +175,59 @@ morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do
else do else do
-- Mark so that this bind will not be processed in recursive or cyclic -- Mark so that this bind will not be processed in recursive or cyclic
-- function calls -- function calls
markBind (coerce name') markBind name'
expt' <- getMonoFromPoly expt expt' <- getMonoFromPoly expt
exp' <- morphExp expt' exp exp' <- morphExp expt' exp
-- Get monomorphic type sof args -- Get monomorphic type sof args
args' <- mapM morphArg args args' <- mapM morphArg args
addOutputBind $ addOutputBind $
M.Bind M.Bind
(coerce name', expectedType) (name', expectedType)
args' args'
(exp', expt') (exp', expt')
return name' return name'
morphBind expectedType b@(L.BindC cxt (ident, btype) args (exp, expt)) = do
-- The "new name" is used to find out if it is already marked or not.
let name' = newFuncName expectedType b
bindMarked <- isBindMarked name'
local
( \env ->
env
{ locals = Set.fromList (map fst args)
, polys = Map.fromList (mapTypes ident btype expectedType)
}
)
$ do
-- Return with right name if already marked
if bindMarked
then return name'
else do
-- Mark so that this bind will not be processed in recursive or cyclic
-- function calls
markBind name'
-- Get monomorphic type sof args
args' <- mapM morphArg args
cxt' <- mapM (secondM getMonoFromPoly) cxt
expt' <- getMonoFromPoly expt
exp' <- local (\env -> foldr (addLocal . fst) env cxt)
(morphExp expt' exp)
addOutputBind $
M.BindC cxt'
(name', expectedType)
args'
(exp', expt')
return name'
-- | Monomorphizes arguments of a bind. -- | Monomorphizes arguments of a bind.
morphArg :: (Ident, T.Type) -> EnvM (Ident, M.Type) morphArg :: (Ident, L.Type) -> EnvM (Ident, M.Type)
morphArg (ident, t) = do morphArg (ident, t) = do
t' <- getMonoFromPoly t t' <- getMonoFromPoly t
return (ident, t') return (ident, t')
-- | Gets the data bind from the name of a constructor. -- | Gets the data bind from the name of a constructor.
getInputData :: Ident -> EnvM (Maybe T.Data) getInputData :: Ident -> EnvM (Maybe L.Data)
getInputData ident = do getInputData ident = do
env <- ask env <- ask
return $ Map.lookup ident (dataDefs env) return $ Map.lookup ident (dataDefs env)
@ -201,50 +241,50 @@ morphCons expectedType ident newIdent = do
--trace ("Tjofras:" ++ show (newName expectedType ident)) $ return () --trace ("Tjofras:" ++ show (newName expectedType ident)) $ return ()
maybeD <- getInputData ident maybeD <- getInputData ident
case maybeD of case maybeD of
Nothing -> error $ "identifier '" ++ show ident ++ "' not found" -- closures can have unbound variables
Nothing -> pure ()
Just d -> do Just d -> do
modify (\output -> Map.insert newIdent (Data expectedType d) output) modify (\output -> Map.insert newIdent (Data expectedType d) output)
-- | Converts literals from input to output tree. -- | Converts literals from input to output tree.
convertLit :: T.Lit -> M.Lit convertLit :: L.Lit -> M.Lit
convertLit (T.LInt v) = M.LInt v convertLit (L.LInt v) = M.LInt v
convertLit (T.LChar v) = M.LChar v convertLit (L.LChar v) = M.LChar v
-- | Monomorphizes an expression, given an expected type. -- | Monomorphizes an expression, given an expected type.
morphExp :: M.Type -> T.Exp -> EnvM M.Exp morphExp :: M.Type -> L.Exp -> EnvM M.Exp
morphExp expectedType exp = case exp of morphExp expectedType exp = case exp of
T.ELit lit -> return $ M.ELit (convertLit lit) L.ELit lit -> return $ M.ELit lit
-- Constructor -- Constructor
T.EInj ident -> do L.EInj ident -> do
let ident' = newName (getDataType expectedType) ident let ident' = newName (getDataType expectedType) ident
morphCons expectedType ident ident' morphCons expectedType ident ident'
return $ M.EVar ident' return $ M.EVar ident'
T.EApp (e1, _t1) (e2, t2) -> do L.EApp (e1, _t1) (e2, t2) -> do
t2' <- getMonoFromPoly t2 t2' <- getMonoFromPoly t2
e2' <- morphExp t2' e2 e2' <- morphExp t2' e2
e1' <- morphExp (M.TFun t2' expectedType) e1 e1' <- morphExp (M.TFun t2' expectedType) e1
return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2') return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2')
T.EAdd (e1, t1) (e2, t2) -> do L.EAdd (e1, t1) (e2, t2) -> do
t1' <- getMonoFromPoly t1 t1' <- getMonoFromPoly t1
t2' <- getMonoFromPoly t2 t2' <- getMonoFromPoly t2
e1' <- morphExp t1' e1 e1' <- morphExp t1' e1
e2' <- morphExp t2' e2 e2' <- morphExp t2' e2
return $ M.EAdd (e1', expectedType) (e2', expectedType) return $ M.EAdd (e1', expectedType) (e2', expectedType)
T.EAbs ident (exp, t) -> local (\env -> env{locals = Set.insert ident (locals env)}) $ do L.ECase (exp, t) bs -> do
t' <- getMonoFromPoly t
morphExp t' exp
T.ECase (exp, t) bs -> do
t' <- getMonoFromPoly t t' <- getMonoFromPoly t
exp' <- morphExp t' exp exp' <- morphExp t' exp
bs' <- mapM morphBranch bs bs' <- mapM morphBranch bs
return $ M.ECase (exp', t') (catMaybes bs') return $ M.ECase (exp', t') (catMaybes bs')
-- Ideally constructors should be EInj, though this code handles them -- Ideally constructors should be EInj, though this code handles them
-- as well. -- as well.
T.EVar ident -> do -- FIXME MAKE EVAR AND EINJ SEPARATE!!!
L.EVar ident -> do
isLocal <- localExists ident isLocal <- localExists ident
if isLocal if isLocal
then do then do
return $ M.EVar (coerce ident) return $ M.EVar ident
else do else do
bind <- getInputBind ident bind <- getInputBind ident
case bind of case bind of
@ -252,20 +292,33 @@ morphExp expectedType exp = case exp of
Just bind' -> do Just bind' -> do
-- New bind to process -- New bind to process
newBindName <- morphBind expectedType bind' newBindName <- morphBind expectedType bind'
return $ M.EVar (coerce newBindName) return $ M.EVar newBindName
T.ELet (T.Bind (identB, tB) args (expB, tExpB)) (exp, tExp) -> L.EVarC as ident -> do
if length args > 0 then error "only constants in lets allowed" isLocal <- localExists ident
if isLocal
then do
return $ M.EVar ident
else do else do
bind <- fromJust <$> getInputBind ident
as' <- mapM (secondM getMonoFromPoly) as
-- New bind to process
newBindName <- morphBind expectedType bind
return $ M.EVarC as' newBindName
-- Ideally constructors should be EInj, though this code handles them
-- as well.
L.ELet (identB, tB) (expB, tExpB) (exp, tExp) -> do
tB' <- getMonoFromPoly tB tB' <- getMonoFromPoly tB
tExpB' <- getMonoFromPoly tExpB tExpB' <- getMonoFromPoly tExpB
tExp' <- getMonoFromPoly tExp tExp' <- getMonoFromPoly tExp
expB' <- morphExp tExpB' expB expB' <- morphExp tExpB' expB
exp' <- morphExp tExp' exp exp' <- local (addLocal identB) (morphExp tExp' exp)
return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (exp', tExp') return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (exp', tExp')
-- | Monomorphizes case-of branches. -- | Monomorphizes case-of branches.
morphBranch :: T.Branch -> EnvM (Maybe M.Branch) morphBranch :: L.Branch -> EnvM (Maybe M.Branch)
morphBranch (T.Branch (p, pt) (e, et)) = do morphBranch (L.Branch (p, pt) (e, et)) = do
pt' <- getMonoFromPoly pt pt' <- getMonoFromPoly pt
et' <- getMonoFromPoly et et' <- getMonoFromPoly et
env <- ask env <- ask
@ -275,15 +328,15 @@ morphBranch (T.Branch (p, pt) (e, et)) = do
Just (p', newLocals) -> Just (p', newLocals) ->
local (const env { locals = Set.union (locals env) newLocals }) $ do local (const env { locals = Set.union (locals env) newLocals }) $ do
e' <- morphExp et' e e' <- morphExp et' e
return $ Just (M.Branch (p', pt') (e', et')) return $ Just (M.Branch p' (e', et'))
morphPattern :: T.Pattern -> M.Type -> EnvM (Maybe (M.Pattern, Set.Set Ident)) morphPattern :: L.Pattern -> M.Type -> EnvM (Maybe (M.T M.Pattern, Set.Set Ident))
morphPattern p expectedType = case p of morphPattern p expectedType = case p of
T.PVar ident -> return $ Just (M.PVar (ident, expectedType), Set.singleton ident) L.PVar ident -> return $ Just ((M.PVar ident, expectedType), Set.singleton ident)
T.PLit lit -> return $ Just (M.PLit (convertLit lit, expectedType), Set.empty) L.PLit lit -> return $ Just ((M.PLit (convertLit lit), expectedType), Set.empty)
T.PCatch -> return $ Just (M.PCatch, Set.empty) L.PCatch -> return $ Just ((M.PCatch, expectedType), Set.empty)
T.PEnum ident -> return $ Just (M.PEnum (newName expectedType ident), Set.empty) L.PEnum ident -> return $ Just ((M.PEnum (newName expectedType ident), expectedType), Set.empty)
T.PInj ident pts -> do let newIdent = newName expectedType ident L.PInj ident pts -> do let newIdent = newName expectedType ident
outEnv <- get outEnv <- get
trace ("WOW: " ++ show (newName expectedType ident)) $ return () trace ("WOW: " ++ show (newName expectedType ident)) $ return ()
trace ("WOW2: " ++ show (outEnv)) $ return () trace ("WOW2: " ++ show (outEnv)) $ return ()
@ -298,12 +351,17 @@ morphPattern p expectedType = case p of
case maybePsSets of case maybePsSets of
Nothing -> return Nothing Nothing -> return Nothing
Just psSets' -> return $ Just Just psSets' -> return $ Just
(M.PInj newIdent (map fst psSets'), Set.unions $ map snd psSets') ((M.PInj newIdent (map fst psSets'), expectedType), Set.unions $ map snd psSets')
else return Nothing else return Nothing
-- | Creates a new identifier for a function with an assigned type. -- | Creates a new identifier for a function with an assigned type.
newFuncName :: M.Type -> T.Bind -> Ident newFuncName :: M.Type -> L.Bind -> Ident
newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) = newFuncName t (L.Bind (ident@(Ident bindName), _) _ _) =
if bindName == "main"
then Ident bindName
else newName t ident
newFuncName t (L.BindC _ (ident@(Ident bindName), _) _ _) =
if bindName == "main" if bindName == "main"
then Ident bindName then Ident bindName
else newName t ident else newName t ident
@ -317,8 +375,8 @@ newName t (Ident str) = Ident $ str ++ "$" ++ newName' t
newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts
-- | Monomorphization step. -- | Monomorphization step.
monomorphize :: T.Program -> O.Program monomorphize :: L.Program -> O.Program
monomorphize (T.Program defs) = monomorphize (L.Program defs) =
removeDataTypes $ removeDataTypes $
M.Program M.Program
( getDefsFromOutput ( getDefsFromOutput
@ -336,7 +394,7 @@ runEnvM :: Output -> Env -> EnvM () -> Output
runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env
-- | Creates the environment based on the input binds. -- | Creates the environment based on the input binds.
createEnv :: [T.Def] -> Env createEnv :: [L.Def] -> Env
createEnv defs = createEnv defs =
Env Env
{ input = Map.fromList bindPairs { input = Map.fromList bindPairs
@ -346,33 +404,34 @@ createEnv defs =
} }
where where
bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs
dataPairs :: [(Ident, T.Data)] dataPairs :: [(Ident, L.Data)]
dataPairs = (foldl (\acc d@(T.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs dataPairs = (foldl (\acc d@(L.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs
-- | Gets a top-lefel function name. -- | Gets a top-lefel function name.
getBindName :: T.Bind -> Ident getBindName :: L.Bind -> Ident
getBindName (T.Bind (ident, _) _ _) = ident getBindName (L.Bind (ident, _) _ _) = ident
getBindName (L.BindC _ (ident, _) _ _) = ident
-- Helper functions -- Helper functions
-- Gets custom data declarations form defs. -- Gets custom data declarations form defs.
getDataFromDefs :: [T.Def] -> [T.Data] getDataFromDefs :: [L.Def] -> [L.Data]
getDataFromDefs = getDataFromDefs =
foldl foldl
( \bs -> \case ( \bs -> \case
T.DBind _ -> bs L.DBind _ -> bs
T.DData d -> d : bs L.DData d -> d : bs
) )
[] []
getConsName :: T.Inj -> Ident getConsName :: L.Inj -> Ident
getConsName (T.Inj ident _) = ident getConsName (L.Inj ident _) = ident
getBindsFromDefs :: [T.Def] -> [T.Bind] getBindsFromDefs :: [L.Def] -> [L.Bind]
getBindsFromDefs = getBindsFromDefs =
foldl foldl
( \bs -> \case ( \bs -> \case
T.DBind b -> b : bs L.DBind b -> b : bs
T.DData _ -> bs L.DData _ -> bs
) )
[] []
@ -384,7 +443,7 @@ getDefsFromOutput o =
(binds, dataInput) = splitBindsAndData o (binds, dataInput) = splitBindsAndData o
-- | Splits the output into binds and data declaration components (used in createNewData) -- | Splits the output into binds and data declaration components (used in createNewData)
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, T.Data)]) splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, L.Data)])
splitBindsAndData output = splitBindsAndData output =
foldl foldl
( \(oBinds, oData) (ident, o) -> case o of ( \(oBinds, oData) (ident, o) -> case o of
@ -396,7 +455,7 @@ splitBindsAndData output =
(Map.toList output) (Map.toList output)
-- | Converts all found constructors to monomorphic data declarations. -- | Converts all found constructors to monomorphic data declarations.
createNewData :: [(Ident, M.Type, T.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data createNewData :: [(Ident, M.Type, L.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
createNewData [] o = o createNewData [] o = o
createNewData ((consIdent, consType, polyData) : input) o = createNewData ((consIdent, consType, polyData) : input) o =
createNewData input $ createNewData input $
@ -406,7 +465,7 @@ createNewData ((consIdent, consType, polyData) : input) o =
(M.Data newDataType [newCons]) (M.Data newDataType [newCons])
o o
where where
T.Data (T.TData polyDataIdent _) _ = polyData L.Data (L.TData polyDataIdent _) _ = polyData
newDataType = getDataType consType newDataType = getDataType consType
newDataName = newName newDataType polyDataIdent newDataName = newName newDataType polyDataIdent
newCons = M.Inj consIdent consType newCons = M.Inj consIdent consType
@ -417,3 +476,6 @@ getDataType (M.TFun _t1 t2) = getDataType t2
getDataType tData@(M.TData _ _) = tData getDataType tData@(M.TData _ _) = tData
getDataType _ = error "???" getDataType _ = error "???"
addLocal :: Ident -> Env -> Env
addLocal x env = env { locals = Set.insert x env.locals }

View file

@ -1,11 +1,14 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr) where module Monomorphizer.MonomorphizerIr (
module Monomorphizer.MonomorphizerIr,
module LambdaLifterIr
) where
import Data.List (intercalate)
import Grammar.Print import Grammar.Print
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..)) import LambdaLifterIr (Ident (..), Lit (..))
import Prelude hiding (exp)
type Id = (TIR.Ident, Type)
newtype Program = Program [Def] newtype Program = Program [Def]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
@ -16,40 +19,37 @@ data Def = DBind Bind | DData Data
data Data = Data Type [Inj] data Data = Data Type [Inj]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT data Bind = Bind (T Ident) [T Ident] (T Exp)
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
type T a = (a, Type)
data Exp data Exp
= EVar TIR.Ident = EVar Ident
| EVarC [T Ident] Ident
| ELit Lit | ELit Lit
| ELet Bind ExpT | ELet Bind (T Exp)
| EApp ExpT ExpT | EApp (T Exp) (T Exp)
| EAdd ExpT ExpT | EAdd (T Exp) (T Exp)
| ECase ExpT [Branch] | ECase (T Exp) [Branch]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Pattern data Pattern
= PVar Id = PVar Ident
| PLit (Lit, Type) | PLit Lit
| PInj TIR.Ident [Pattern] | PInj Ident [T Pattern]
| PCatch | PCatch
| PEnum TIR.Ident | PEnum Ident
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
data Branch = Branch (Pattern, Type) ExpT data Branch = Branch (T Pattern) (T Exp)
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
type ExpT = (Exp, Type) data Inj = Inj Ident Type
data Inj = Inj TIR.Ident Type
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Lit data Type = TLit Ident | TFun Type Type
= LInt Integer
| LChar Char
deriving (Show, Ord, Eq)
data Type = TLit TIR.Ident | TFun Type Type
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
flattenType :: Type -> [Type] flattenType :: Type -> [Type]
@ -59,47 +59,40 @@ flattenType x = [x]
instance Print Program where instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print (Bind) where instance Print Bind where
prt i (Bind sig@(name, _) parms rhs) = prt i (Bind sig@(name, _) parms rhs) =
prPrec i 0 $ prPrec i 0 $
concatD concatD
[ prtSig sig [ prt 0 sig
, prt 0 name , prt 0 name
, prtIdPs 0 parms , prt 0 parms
, doc $ showString "=" , doc $ showString "="
, prt 0 rhs , prt 0 rhs
] ]
prtSig :: Id -> Doc prt i (BindC cxt sig parms rhs) =
prtSig (name, t) = prPrec i 0 $
concatD concatD
[ prt 0 name [ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
, doc $ showString ":" , prt i parms
, prt 0 t , doc $ showString "="
, doc $ showString ";" , prt i rhs
] ]
instance Print (ExpT) where
prt i (e, t) =
concatD
[ doc $ showString "("
, prt i e
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print [Bind] where instance Print [Bind] where
prt _ [] = concatD [] prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prt i)
instance Print Exp where instance Print Exp where
prt i = \case prt i = \case
EVar name -> prPrec i 3 $ prt 0 name EVar name -> prPrec i 3 $ prt 0 name
EVarC as lident -> doc . showString
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
where
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
ELit lit -> prPrec i 3 $ prt 0 lit ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e -> ELet b e ->
prPrec i 3 $ prPrec i 3 $
@ -134,7 +127,7 @@ instance Print Exp where
] ]
instance Print Branch where instance Print Branch where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp])
instance Print [Branch] where instance Print [Branch] where
prt _ [] = concatD [] prt _ [] = concatD []
@ -152,12 +145,12 @@ instance Print Data where
instance Print Inj where instance Print Inj where
prt i = \case prt i = \case
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) Inj uident type_ -> prt i (uident, type_)
instance Print Pattern where instance Print Pattern where
prt i = \case prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name]) PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) PLit lit -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")]) PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name]) PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
@ -175,8 +168,3 @@ instance Print Type where
prt i = \case prt i = \case
TLit uident -> prPrec i 1 (concatD [prt 0 uident]) TLit uident -> prPrec i 1 (concatD [prt 0 uident])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
instance Print Lit where
prt i = \case
LInt int -> prt i int
LChar char -> prt i char

View file

@ -1,10 +1,14 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module Monomorphizer.MorbIr where
module Monomorphizer.MorbIr (
module Monomorphizer.MorbIr,
module LambdaLifterIr
) where
import Data.List (intercalate)
import Grammar.Print import Grammar.Print
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..)) import LambdaLifterIr (Ident (..), Lit (..))
import Prelude hiding (exp)
type Id = (TIR.Ident, Type)
newtype Program = Program [Def] newtype Program = Program [Def]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
@ -15,40 +19,39 @@ data Def = DBind Bind | DData Data
data Data = Data Type [Inj] data Data = Data Type [Inj]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT data Bind = Bind (T Ident) [T Ident] (T Exp)
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
type T a = (a, Type)
data Exp data Exp
= EVar TIR.Ident = EVar Ident
| EVarC [T Ident] Ident
| ELit Lit | ELit Lit
| ELet Bind ExpT | ELet Bind (T Exp)
| EApp ExpT ExpT | EApp (T Exp) (T Exp)
| EAdd ExpT ExpT | EAdd (T Exp) (T Exp)
| ECase ExpT [Branch] | ECase (T Exp) [Branch]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Pattern data Pattern
= PVar Id = PVar Ident
| PLit (Lit, Type) | PLit Lit
| PInj TIR.Ident [Pattern] | PInj Ident [T Pattern]
| PCatch | PCatch
| PEnum TIR.Ident | PEnum Ident
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
data Branch = Branch (Pattern, Type) ExpT
data Branch = Branch (T Pattern) (T Exp)
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
type ExpT = (Exp, Type) data Inj = Inj Ident Type
data Inj = Inj TIR.Ident Type
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
data Lit data Type = TLit Ident | TFun Type Type | TData Ident [Type]
= LInt Integer
| LChar Char
deriving (Show, Ord, Eq)
data Type = TLit TIR.Ident | TFun Type Type | TData TIR.Ident [Type]
deriving (Show, Ord, Eq) deriving (Show, Ord, Eq)
@ -59,34 +62,24 @@ flattenType x = [x]
instance Print Program where instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print (Bind) where instance Print Bind where
prt i (Bind sig@(name, _) parms rhs) = prt i (Bind sig@(name, _) parms rhs) =
prPrec i 0 $ prPrec i 0 $
concatD concatD
[ prtSig sig [ prt 0 sig
, prt 0 name , prt 0 name
, prtIdPs 0 parms , prt 0 parms
, doc $ showString "=" , doc $ showString "="
, prt 0 rhs , prt 0 rhs
] ]
prtSig :: Id -> Doc prt i (BindC cxt sig parms rhs) =
prtSig (name, t) = prPrec i 0 $
concatD concatD
[ prt 0 name [ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
, doc $ showString ":" , prt i parms
, prt 0 t , doc $ showString "="
, doc $ showString ";" , prt i rhs
]
instance Print (ExpT) where
prt i (e, t) =
concatD
[ doc $ showString "("
, prt i e
, doc $ showString ","
, prt i t
, doc $ showString ")"
] ]
instance Print [Bind] where instance Print [Bind] where
@ -94,12 +87,13 @@ instance Print [Bind] where
prt _ [x] = concatD [prt 0 x] prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prt i)
instance Print Exp where instance Print Exp where
prt i = \case prt i = \case
EVar name -> prPrec i 3 $ prt 0 name EVar name -> prPrec i 3 $ prt 0 name
EVarC as lident -> doc . showString
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
where
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
ELit lit -> prPrec i 3 $ prt 0 lit ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e -> ELet b e ->
prPrec i 3 $ prPrec i 3 $
@ -134,7 +128,7 @@ instance Print Exp where
] ]
instance Print Branch where instance Print Branch where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp])
instance Print [Branch] where instance Print [Branch] where
prt _ [] = concatD [] prt _ [] = concatD []
@ -152,12 +146,12 @@ instance Print Data where
instance Print Inj where instance Print Inj where
prt i = \case prt i = \case
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_]) Inj uident type_ -> prt i (uident, type_)
instance Print Pattern where instance Print Pattern where
prt i = \case prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name]) PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit]) PLit lit -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")]) PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name]) PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
@ -176,9 +170,3 @@ instance Print Type where
TLit uident -> prPrec i 1 (concatD [prt 0 uident]) TLit uident -> prPrec i 1 (concatD [prt 0 uident])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")]) TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")])
instance Print Lit where
prt i = \case
LInt int -> prt i int
LChar char -> prt i char

View file

@ -7,7 +7,7 @@ import Control.Applicative (Applicative (liftA2), liftA3)
import Control.Monad.Except (MonadError (throwError)) import Control.Monad.Except (MonadError (throwError))
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Tuple.Extra (secondM) import Data.Tuple.Extra (secondM)
import Grammar.Abs qualified as G import qualified Grammar.Abs as G
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr hiding (Type (..)) import TypeChecker.TypeCheckerIr hiding (Type (..))
@ -18,7 +18,7 @@ data Type
| TData Ident [Type] | TData Ident [Type]
| TFun Type Type | TFun Type Type
| TAll TVar Type | TAll TVar Type
deriving (Eq, Ord, Show, Read) deriving (Eq, Ord, Show)
class ReportTEVar a b where class ReportTEVar a b where
reportTEVar :: a -> Err b reportTEVar :: a -> Err b
@ -65,10 +65,10 @@ instance ReportTEVar (Data' G.Type) (Data' Type) where
instance ReportTEVar (Inj' G.Type) (Inj' Type) where instance ReportTEVar (Inj' G.Type) (Inj' Type) where
reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ
instance ReportTEVar (Id' G.Type) (Id' Type) where instance ReportTEVar (a, G.Type) (a, Type) where
reportTEVar = secondM reportTEVar reportTEVar = secondM reportTEVar
instance ReportTEVar (ExpT' G.Type) (ExpT' Type) where instance ReportTEVar (T' Exp' G.Type) (T' Exp' Type) where
reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ) reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ)
instance ReportTEVar a b => ReportTEVar [a] [b] where instance ReportTEVar a b => ReportTEVar [a] [b] where

View file

@ -31,6 +31,7 @@ import Grammar.ErrM
import Grammar.Print (printTree) import Grammar.Print (printTree)
import Prelude hiding (exp) import Prelude hiding (exp)
import qualified TypeChecker.TypeCheckerIr as T import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (T, T')
-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013) -- Implementation is derived from the paper (Dunfield and Krishnaswami 2013)
-- https://doi.org/10.1145/2500365.2500582 -- https://doi.org/10.1145/2500365.2500582
@ -172,7 +173,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars
-- | Γ ⊢ e ↑ A ⊣ Δ -- | Γ ⊢ e ↑ A ⊣ Δ
-- Under input context Γ, e checks against input type A, with output context ∆ -- Under input context Γ, e checks against input type A, with output context ∆
check :: Exp -> Type -> Tc (T.ExpT' Type) check :: Exp -> Type -> Tc (T' T.Exp' Type)
-- Γ,α ⊢ e ↑ A ⊣ Δ,α -- Γ,α ⊢ e ↑ A ⊣ Δ,α
-- ------------------- ∀I -- ------------------- ∀I
@ -212,12 +213,6 @@ check (ECase scrut pi) c = do
e' <- check e c e' <- check e c
pure (T.Branch p' e') pure (T.Branch p' e')
apply (T.ECase (scrut', a) pi', c) apply (T.ECase (scrut', a) pi', c)
where
go (pi, b) (Branch p e) = do
p' <- checkPattern p =<< apply a
e'@(_, b') <- infer e
subtype b' b
apply (T.Branch p' e' : pi, b')
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ -- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
@ -229,9 +224,6 @@ check e b = do
subtype a b' subtype a b'
apply (e', b) apply (e', b)
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type) checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of checkPattern patt t_patt = case patt of
@ -297,7 +289,7 @@ checkPattern patt t_patt = case patt of
-- | Γ ⊢ e ↓ A ⊣ Δ -- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆ -- Under input context Γ, e infers output type A, with output context ∆
infer :: Exp -> Tc (T.ExpT' Type) infer :: Exp -> Tc (T' T.Exp' Type)
infer (ELit lit) = apply (T.ELit lit, litType lit) infer (ELit lit) = apply (T.ELit lit, litType lit)
-- Γ ∋ (x : A) Γ ⊢ rec(x) -- Γ ∋ (x : A) Γ ⊢ rec(x)
@ -391,7 +383,7 @@ infer (ECase scrut pi) = do
-- | Γ ⊢ A • e ⇓ C ⊣ Δ -- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆ -- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
-- Instantiate existential type variables until there is an arrow type. -- Instantiate existential type variables until there is an arrow type.
applyInfer :: Type -> Exp -> Tc (T.ExpT' Type, Type) applyInfer :: Type -> Exp -> Tc (T' T.Exp' Type, Type)
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ -- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
-- ------------------------ ∀App -- ------------------------ ∀App

View file

@ -8,7 +8,7 @@
module TypeChecker.TypeCheckerHm where module TypeChecker.TypeCheckerHm where
import Auxiliary (int, litType, maybeToRightM, unzip4) import Auxiliary (int, litType, maybeToRightM, unzip4)
import Auxiliary qualified as Aux import qualified Auxiliary as Aux
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
@ -19,14 +19,15 @@ import Data.Function (on)
import Data.List (foldl', nub, sortOn) import Data.List (foldl', nub, sortOn)
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as M import qualified Data.Map as M
import Data.Maybe (fromJust) import Data.Maybe (fromJust)
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as S import qualified Data.Set as S
import Debug.Trace (trace, traceShow) import Debug.Trace (trace, traceShow)
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T import TypeChecker.TypeCheckerIr (T, T')
import qualified TypeChecker.TypeCheckerIr as T
{- {-
TODO TODO
@ -265,7 +266,7 @@ returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2 returnType (TFun _ t2) = returnType t2
returnType a = a returnType a = a
inferExp :: Exp -> Infer (T.ExpT' Type) inferExp :: Exp -> Infer (T' T.Exp' Type)
inferExp e = do inferExp e = do
(s, (e', t)) <- algoW e (s, (e', t)) <- algoW e
let subbed = apply s t let subbed = apply s t
@ -289,7 +290,7 @@ instance CollectTVars Type where
collect :: Set T.Ident -> Infer () collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st}) collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
algoW :: Exp -> Infer (Subst, T.ExpT' Type) algoW :: Exp -> Infer (Subst, T' T.Exp' Type)
algoW = \case algoW = \case
err@(EAnn e t) -> do err@(EAnn e t) -> do
(sub0, (e', t')) <- exprErr (algoW e) err (sub0, (e', t')) <- exprErr (algoW e) err
@ -721,7 +722,7 @@ instance SubstType (Map T.Ident Type) where
instance SubstType (Map T.Ident (Maybe Type)) where instance SubstType (Map T.Ident (Maybe Type)) where
apply s = M.map (fmap $ apply s) apply s = M.map (fmap $ apply s)
instance SubstType (T.ExpT' Type) where instance SubstType (T' T.Exp' Type) where
apply s (e, t) = (apply s e, apply s t) apply s (e, t) = (apply s e, apply s t)
instance SubstType (T.Exp' Type) where instance SubstType (T.Exp' Type) where
@ -761,7 +762,7 @@ instance SubstType (T.Pattern' Type, Type) where
instance SubstType a => SubstType [a] where instance SubstType a => SubstType [a] where
apply s = map (apply s) apply s = map (apply s)
instance SubstType (T.Id' Type) where instance SubstType (T T.Ident Type) where
apply s (name, t) = (name, apply s t) apply s (name, t) = (name, apply s t)
-- | Represents the empty substition set -- | Represents the empty substition set

View file

@ -10,31 +10,30 @@ import Data.String (IsString)
import Grammar.Abs (Lit (..)) import Grammar.Abs (Lit (..))
import Grammar.Print import Grammar.Print
import Prelude import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program' t = Program [Def' t] newtype Program' t = Program [Def' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
data Def' t data Def' t
= DBind (Bind' t) = DBind (Bind' t)
| DData (Data' t) | DData (Data' t)
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
data Type data Type
= TLit Ident = TLit Ident
| TVar TVar | TVar TVar
| TData Ident [Type] | TData Ident [Type]
| TFun Type Type | TFun Type Type
deriving (Eq, Ord, Show, Read) deriving (Eq, Ord, Show)
data Data' t = Data t [Inj' t] data Data' t = Data t [Inj' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
data Inj' t = Inj Ident t data Inj' t = Inj Ident t
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
newtype Ident = Ident String newtype Ident = Ident String
deriving (C.Eq, C.Ord, C.Show, C.Read, IsString) deriving (Eq, Ord, Show, IsString)
data Pattern' t data Pattern' t
= PVar Ident = PVar Ident
@ -42,30 +41,31 @@ data Pattern' t
| PCatch | PCatch
| PEnum Ident | PEnum Ident
| PInj Ident [(Pattern' t, t)] | PInj Ident [(Pattern' t, t)]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
data Exp' t data Exp' t
= EVar Ident = EVar Ident
| EInj Ident | EInj Ident
| ELit Lit | ELit Lit
| ELet (Bind' t) (ExpT' t) | ELet (Bind' t) (T' Exp' t)
| EApp (ExpT' t) (ExpT' t) | EApp (T' Exp' t) (T' Exp' t)
| EAdd (ExpT' t) (ExpT' t) | EAdd (T' Exp' t) (T' Exp' t)
| EAbs Ident (ExpT' t) | EAbs Ident (T' Exp' t)
| ECase (ExpT' t) [Branch' t] | ECase (T' Exp' t) [Branch' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
newtype TVar = MkTVar Ident newtype TVar = MkTVar Ident
deriving (C.Eq, C.Ord, C.Show, C.Read) deriving (Eq, Ord, Show)
type Id' t = (Ident, t) type T' a t = (a t, t)
type ExpT' t = (Exp' t, t) type T a t = (a, t)
data Bind' t = Bind (Id' t) [Id' t] (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
data Branch' t = Branch (Pattern' t, t) (ExpT' t) data Bind' t = Bind (T Ident t) [T Ident t] (T' Exp' t)
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor) deriving (Eq, Ord, Show, Functor)
data Branch' t = Branch (T' Pattern' t) (T' Exp' t)
deriving (Eq, Ord, Show, Functor)
instance Print Ident where instance Print Ident where
prt _ (Ident s) = doc $ showString s prt _ (Ident s) = doc $ showString s
@ -81,22 +81,22 @@ instance Print t => Print (Bind' t) where
, prt i rhs , prt i rhs
] ]
prtSig :: Print t => Id' t -> Doc prtSig :: Print t => T Ident t -> Doc
prtSig (name, t) = prtSig (x, t) =
concatD concatD
[ prt 0 name [ prt 0 x
, doc $ showString ":" , doc $ showString ":"
, prt 0 t , prt 0 t
] ]
instance Print t => Print (ExpT' t) where instance (Print a, Print t) => Print (T a t) where
prt i (e, t) = prt i (x, t) =
concatD concatD
[ doc $ showString "(" [ -- doc $ showString "("
, prt i e {- , -} prt i x
, doc $ showString ":" -- , doc $ showString ":"
, prt 0 t -- , prt 0 t
, doc $ showString ")" -- , doc $ showString ")"
] ]
instance Print t => Print [Bind' t] where instance Print t => Print [Bind' t] where
@ -104,16 +104,6 @@ instance Print t => Print [Bind' t] where
prt i [x] = concatD [prt i x] prt i [x] = concatD [prt i x]
prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs] prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs]
instance Print t => Print (Id' t) where
prt i (name, t) =
concatD
[ doc $ showString "("
, prt i name
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print t => Print (Exp' t) where instance Print t => Print (Exp' t) where
prt i = \case prt i = \case
EVar lident -> prPrec i 3 (concatD [prt 0 lident]) EVar lident -> prPrec i 3 (concatD [prt 0 lident])
@ -151,9 +141,6 @@ instance Print t => Print [Inj' t] where
prt i [x] = prt i x prt i [x] = prt i x
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs] prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
instance Print t => Print (Pattern' t, t) where
prt i (p, t) = prPrec i 1 (concatD [prt i p, prt i t])
instance Print t => Print (Pattern' t) where instance Print t => Print (Pattern' t) where
prt i = \case prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name]) PVar name -> prPrec i 1 (concatD [prt 0 name])
@ -189,8 +176,6 @@ type Branch = Branch' Type
type Pattern = Pattern' Type type Pattern = Pattern' Type
type Inj = Inj' Type type Inj = Inj' Type
type Exp = Exp' Type type Exp = Exp' Type
type ExpT = ExpT' Type
type Id = Id' Type
pattern TVar' s = TVar (MkTVar s) pattern TVar' s = TVar (MkTVar s)
pattern DBind' id vars expt = DBind (Bind id vars expt) pattern DBind' id vars expt = DBind (Bind id vars expt)
pattern DData' typ injs = DData (Data typ injs) pattern DData' typ injs = DData (Data typ injs)

185
test_map2.ll Normal file
View file

@ -0,0 +1,185 @@
target triple = "x86_64-pc-linux-gnu"
target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
@.str = private unnamed_addr constant [3 x i8] c"%i
", align 1
@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c"Non-exhaustive patterns in case at %i:%i
"
declare i32 @printf(ptr noalias nocapture, ...)
declare i32 @exit(i32 noundef)
declare ptr @malloc(i32 noundef)
%List = type { i8, [23 x i8] }
%Cons = type { i8, i64, %List* }
%Nil = type { i8 }
; NYTT: kontexttyp
%Closure_sc_0 = type { i64 (i64)*, i64 }
; Ident "sum$List_Int": (ECase (EVar (Ident "$4xs"),TLit (Ident "List")) [Branch (PEnum (Ident "Nil"),TLit (Ident "List")) (ELit (LInt 0),TLit (Ident "Int")),Branch (PInj (Ident "Cons") [PVar (Ident "$5x",TLit (Ident "Int")),PVar (Ident "$6xs",TLit (Ident "List"))],TLit (Ident "List")) (EAdd (EVar (Ident "$5x"),TLit (Ident "Int")) (EApp (EVar (Ident "sum$List_Int"),TFun (TLit (Ident "List")) (TLit (Ident "Int"))) (EVar (Ident "$6xs"),TLit (Ident "List")),TLit (Ident "Int")),TLit (Ident "Int"))],TLit (Ident "Int"))
define fastcc i64 @sum$List_Int(%List %$4xs) {
%1 = alloca i64
; Penum
%2 = extractvalue %List %$4xs, 0
%3 = icmp eq i8 %2, 1
br i1 %3, label %lbl_success_3, label %lbl_failed_2
lbl_success_3:
%4 = alloca %List
store %List %$4xs, ptr %4
%5 = load %Nil, ptr %4
store i64 0, ptr %1
br label %lbl_escape_1
lbl_failed_2:
; Inj
%6 = extractvalue %List %$4xs, 0
%7 = icmp eq i8 %6, 0
br i1 %7, label %lbl_success_5, label %lbl_failed_4
lbl_success_5:
%8 = alloca %List
store %List %$4xs, ptr %8
%9 = load %Cons, ptr %8
; ident i64
%$5x = extractvalue %Cons %9, 1
; ident %List
%10 = extractvalue %Cons %9, 2
%$6xs = load %List, ptr %10
; TLit (Ident "Int")
%11 = call fastcc i64 @sum$List_Int(%List %$6xs)
%12 = add i64 %$5x, %11
store i64 %12, ptr %1
br label %lbl_escape_1
lbl_failed_4:
call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 12, i64 noundef 6)
call i32 @exit(i32 noundef 1)
br label %lbl_escape_1
lbl_escape_1:
%15 = load i64, ptr %1
ret i64 %15
}
; Ident "sc_0$Int_Int": (EAdd (EVar (Ident "$7x"),TLit (Ident "Int")) (ELit (LInt 10),TLit (Ident "Int")),TLit (Ident "Int"))
; ÄNDRAT: lägg till kontextpekare
define fastcc i64 @sc_0$Int_Int(ptr %closure_sc_0, i64 %$7x) {
; NYTT: Ladda alla fria variabler
%fri_variabel_ptr = getelementptr inbounds %Closure_sc_0, ptr %closure_sc_0, i32 0, i32 1
%fri_variabel = load i64, ptr %fri_variabel_ptr
; ÄNDRAT: %fri_variabel istället för 2
%1 = add i64 %$7x, %fri_variabel
ret i64 %1
}
; Ident "map$Int_Int_List_List": (ECase (EVar (Ident "$1xs"),TLit (Ident "List")) [Branch (PEnum (Ident "Nil"),TLit (Ident "List")) (EVar (Ident "Nil"),TLit (Ident "List")),Branch (PInj (Ident "Cons") [PVar (Ident "$2x",TLit (Ident "Int")),PVar (Ident "$3xs",TLit (Ident "List"))],TLit (Ident "List")) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EApp (EVar (Ident "$0f"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (EVar (Ident "$2x"),TLit (Ident "Int")),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "map$Int_Int_List_List"),TFun (TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EVar (Ident "$0f"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EVar (Ident "$3xs"),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List"))],TLit (Ident "List"))
; ÄNDRAT: ptr istället för i64 (i64)*
define fastcc %List @map$Int_Int_List_List(ptr %$0f, %List %$1xs) {
; NYTT: ta fram funktionspekaren
%$0f_deref = load i64(i64)*, ptr %$0f
%1 = alloca %List
; Penum
%2 = extractvalue %List %$1xs, 0
%3 = icmp eq i8 %2, 1
br i1 %3, label %lbl_success_8, label %lbl_failed_7
lbl_success_8:
%4 = alloca %List
store %List %$1xs, ptr %4
%5 = load %Nil, ptr %4
%6 = call fastcc %List @Nil()
store %List %6, ptr %1
br label %lbl_escape_6
lbl_failed_7:
; Inj
%7 = extractvalue %List %$1xs, 0
%8 = icmp eq i8 %7, 0
br i1 %8, label %lbl_success_10, label %lbl_failed_9
lbl_success_10:
%9 = alloca %List
store %List %$1xs, ptr %9
%10 = load %Cons, ptr %9
; ident i64
%$2x = extractvalue %Cons %10, 1
; ident %List
%11 = extractvalue %Cons %10, 2
%$3xs = load %List, ptr %11
; TLit (Ident "Int")
; ÄNDRAT använd deref
%12 = call fastcc i64 %$0f_deref(ptr %$0f, i64 %$2x)
; TLit (Ident "List")
; ÄNDRAT ptr istället för 64 (64)* och skicka med ptr
%13 = call fastcc %List @map$Int_Int_List_List(ptr %$0f, %List %$3xs)
; TLit (Ident "List")
%14 = call fastcc %List @Cons(i64 %12, %List %13)
store %List %14, ptr %1
br label %lbl_escape_6
lbl_failed_9:
call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 14, i64 noundef 6)
call i32 @exit(i32 noundef 1)
br label %lbl_escape_6
lbl_escape_6:
%17 = load %List, ptr %1
ret %List %17
}
; Ident "main": (EApp (EVar (Ident "sum$List_Int"),TFun (TLit (Ident "List")) (TLit (Ident "Int"))) (EApp (EApp (EVar (Ident "map$Int_Int_List_List"),TFun (TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EVar (Ident "sc_0$Int_Int"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (ELit (LInt 1),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (ELit (LInt 2),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EVar (Ident "Nil"),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "Int"))
define fastcc i64 @main() {
%1 = call fastcc %List @Nil()
; TLit (Ident "List")
%2 = call fastcc %List @Cons(i64 2, %List %1)
; TLit (Ident "List")
%3 = call fastcc %List @Cons(i64 1, %List %2)
; TLit (Ident "List")
; NYTT: spara funktionspekaren och 100 i kontexten
%closure_sc_0 = alloca %Closure_sc_0
store i64(i64)* @sc_0$Int_Int, ptr %closure_sc_0
%fri_variabel_ptr = getelementptr inbounds %Closure_sc_0, ptr %closure_sc_0, i32 0, i32 1
store i64 100, ptr %fri_variabel_ptr
; store %Closure_sc_0 {i64 (i64)* @sc_0$Int_Int, 100}, ptr %closure_sc_0
; ÄNDRAT ptr %closure_sc_0 istället för i64 (i64)* @sc_0$Int_Int
%4 = call fastcc %List @map$Int_Int_List_List(ptr %closure_sc_0, %List %3)
; TLit (Ident "Int")
%5 = call fastcc i64 @sum$List_Int(%List %4)
call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef %5)
ret i64 0
}
define fastcc %List @Cons(i64 %arg_0, %List %arg_1) {
%1 = alloca %List
%2 = getelementptr %List, %List* %1, i64 0, i32 0
store i8 0, i8* %2
%3 = bitcast %List* %1 to %Cons*
; i64 arg_0 1
%4 = getelementptr %Cons, %Cons* %3, i64 0, i32 1
; Just store
store i64 %arg_0, ptr %4
; %List arg_1 2
%5 = getelementptr %Cons, %Cons* %3, i64 0, i32 2
; Malloc and store
%6 = call ptr @malloc(i64 24)
store %List %arg_1, ptr %6
store %List* %6, ptr %5
; Return the newly constructed value
%7 = load %List, ptr %1
ret %List %7
}
define fastcc %List @Nil() {
%1 = alloca %List
%2 = getelementptr %List, %List* %1, i64 0, i32 0
store i8 1, i8* %2
%3 = bitcast %List* %1 to %Nil*
; Return the newly constructed value
%4 = load %List, ptr %1
ret %List %4
}