Merge closures mostly done. Desugaring cases is a problem.

This commit is contained in:
Martin Fredin 2023-05-06 23:38:56 +02:00
commit 019ed0d45a
29 changed files with 1484 additions and 757 deletions

View file

@ -43,6 +43,7 @@ executable language
TypeChecker.ReportTEVar
TypeChecker.RemoveForall
LambdaLifter
LambdaLifterIr
Monomorphizer.Monomorphizer
Monomorphizer.MonomorphizerIr
Monomorphizer.MorbIr
@ -101,6 +102,8 @@ Test-suite language-testsuite
TypeChecker.TypeChecker
AnnForall
ReportForall
LambdaLifterIr
LambdaLifter
TypeChecker.TypeCheckerHm
TypeChecker.TypeCheckerBidir
TypeChecker.ReportTEVar

View file

@ -15,4 +15,4 @@ revRange x = case x of
sum xs = case xs of
Cons x ys => x + sum ys
Nil => 0
Nil => 0

View file

@ -0,0 +1,6 @@
add : Int -> Int -> Int -> Int
add x y z = x + y + z
main = add 8 6 2

View file

@ -0,0 +1,7 @@
apply : (Int -> Int) -> Int -> Int
apply f y = f y
main = apply (\y. y + y) 5

View file

@ -0,0 +1,10 @@
apply : (Int -> Int) -> Int -> Int
apply f z = f z
main =
let x = 10 in
apply (\y. y + x) 6

View file

@ -0,0 +1,15 @@
data List a where
Nil : List a
Cons : a -> List a -> List a
foldr : (a -> b -> b) -> b -> List a -> b
foldr f y xs = case xs of
Nil => y
Cons x xs => f x (foldr f y xs)
main = let z = 2 in foldr (\x.\y. x + y + z) 0 (Cons 1000 (Cons 100 Nil))

View file

@ -0,0 +1,25 @@
data List (a) where
Nil : List (a)
Cons : a -> List (a) -> List (a)
map : (a -> b) -> List (a) -> List (b)
map f xs = case xs of
Nil => Nil
Cons x xs => Cons (f x) (map f xs)
add : Int -> Int -> Int
add x y = x + y
foldr : (a -> b -> b) -> b -> List (a) -> b
foldr f y xs = case xs of
Nil => y
Cons x xs => f x (foldr f y xs)
f : List (Int)
f = ((\x.\ys. map (\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil)))
-- [5, 6]
main : Int
main = foldr add 0 f

View file

@ -0,0 +1,21 @@
data List a where
Nil : List a
Cons : a -> List a -> List a
map : (a -> b) -> List a -> List b
map f xs = case xs of
Nil => Nil
Cons x xs => Cons (f x) (map f xs)
f : List Int
f = (\x.\ys. map (\y. y + x) ys) 4 (Cons 1 (Cons 2 Nil))
-- [5, 6]
sum : List Int -> Int
sum xs = case xs of
Nil => 0
Cons x xs => x + sum xs
main = sum f

View file

@ -0,0 +1,3 @@
main = let x = 10 in 6 + x

View file

@ -0,0 +1,16 @@
data List a where
Nil : List a
Cons : a -> List a -> List a
map : (a -> b) -> List a -> List b
map f xs = case xs of
Nil => Nil
Cons x xs => Cons (f x) (map f xs)
sum : List Int -> Int
sum xs = case xs of
Nil => 0
Cons x xs => x + (sum xs)
main = let y = 10 in sum (map (\x. x + y) (Cons 2 (Cons 4 Nil)))

View file

@ -0,0 +1,7 @@
f = 10
main = f + 6

View file

@ -1,25 +1,25 @@
module Codegen.Auxillary where
import Codegen.LlvmIr (LLVMType (..), LLVMValue (..))
import Control.Monad (foldM_)
import Monomorphizer.MonomorphizerIr as MIR (ExpT, Type (..))
import TypeChecker.TypeCheckerIr qualified as TIR
import Codegen.LlvmIr (LLVMType (..), LLVMValue (..))
import Control.Monad (foldM_)
import Monomorphizer.MonomorphizerIr as MIR (Exp, T, Type (..))
import qualified TypeChecker.TypeCheckerIr as TIR
type2LlvmType :: MIR.Type -> LLVMType
type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of
"Int" -> I64
"Int" -> I64
"Char" -> I8
"Bool" -> I1
_ -> CustomType id
_ -> CustomType id
type2LlvmType (MIR.TFun t xs) = do
let (t', xs') = function2LLVMType xs [type2LlvmType t]
Function t' xs'
where
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
function2LLVMType x s = (type2LlvmType x, s)
function2LLVMType x s = (type2LlvmType x, s)
getType :: ExpT -> LLVMType
getType :: T Exp -> LLVMType
getType (_, t) = type2LlvmType t
extractTypeName :: MIR.Type -> TIR.Ident
@ -30,21 +30,21 @@ extractTypeName (MIR.TFun t xs) =
in TIR.Ident $ i <> "_$_" <> is
valueGetType :: LLVMValue -> LLVMType
valueGetType (VInteger _) = I64
valueGetType (VChar _) = I8
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
valueGetType (VInteger _) = I64
valueGetType (VChar _) = I8
valueGetType (VIdent _ t) = t
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
valueGetType (VFunction _ _ t) = t
typeByteSize :: LLVMType -> Integer
typeByteSize I1 = 1
typeByteSize I8 = 1
typeByteSize I32 = 4
typeByteSize I64 = 8
typeByteSize Ptr = 8
typeByteSize (Ref _) = 8
typeByteSize I1 = 1
typeByteSize I8 = 1
typeByteSize I32 = 4
typeByteSize I64 = 8
typeByteSize Ptr = 8
typeByteSize (Ref _) = 8
typeByteSize (Function _ _) = 8
typeByteSize (Array n t) = n * typeByteSize t
typeByteSize (Array n t) = n * typeByteSize t
typeByteSize (CustomType _) = 8
enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m ()

View file

@ -1,18 +1,24 @@
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.Codegen (generateCode) where
import Codegen.CompilerState (
CodeGenerator (instructions),
initCodeGenerator,
)
import Codegen.Emits (compileScs)
import Codegen.LlvmIr as LIR (llvmIrToString)
import Control.Monad.State (
execStateT,
)
import Data.List (sortBy)
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..), Def (DBind, DData), Program (..), Type (TLit))
import TypeChecker.TypeCheckerIr (Ident (..))
import Codegen.CompilerState (CodeGenerator (..),
StructType (inst),
initCodeGenerator)
import Codegen.Emits (compileScs)
import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw),
llvmIrToString)
import Control.Monad.State (execStateT)
import Data.Functor ((<&>))
import Data.List (sortBy)
import qualified Data.Map as Map
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..),
Def (DBind, DData),
Program (..),
Type (TLit))
import TypeChecker.TypeCheckerIr (Ident (..))
{- | Compiles an AST and produces a LLVM Ir string.
An easy way to actually "compile" this output is to
@ -20,16 +26,43 @@ import TypeChecker.TypeCheckerIr (Ident (..))
-}
generateCode :: MIR.Program -> Bool -> Err String
generateCode (MIR.Program scs) addGc = do
let tree = filter (not . detectPrelude) (sortBy lowData scs)
let codegen = initCodeGenerator addGc tree
llvmIrToString . instructions <$> execStateT (compileScs tree) codegen
let tree = filter (not . detectPrelude) (sortBy lowData scs)
codegen = initCodeGenerator addGc tree
-- Append instructions
execStateT (compileScs tree) codegen <&> \state ->
llvmIrToString $ defaultStart
++ (if addGc then gcStart else [])
++ map inst (Map.elems state.structTypes)
++ state.instructions
detectPrelude :: Def -> Bool
detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True
detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True
detectPrelude (DBind (Bind (Ident ('l' : 't' : '$' : _), _) _ _)) = True
detectPrelude _ = False
detectPrelude _ = False
lowData :: Def -> Def -> Ordering
lowData (DData _) (DBind _) = LT
lowData (DBind _) (DData _) = GT
lowData _ _ = EQ
lowData _ _ = EQ
defaultStart :: [LLVMIr]
defaultStart =
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
]
gcStart :: [LLVMIr]
gcStart =
[ UnsafeRaw "declare external void @cheap_init()\n"
, UnsafeRaw "declare external ptr @cheap_alloc(i64)\n"
, UnsafeRaw "declare external void @cheap_dispose()\n"
, UnsafeRaw "declare external ptr @cheap_the()\n"
, UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n"
, UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n"
]

View file

@ -1,46 +1,101 @@
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.CompilerState where
import Auxiliary (snoc)
import Codegen.Auxillary (type2LlvmType, typeByteSize)
import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw),
LLVMType)
import Control.Monad.State (StateT, gets, modify)
import Codegen.LlvmIr as LIR (LLVMIr (SetVariable, Type),
LLVMType (CustomType, Function, I64, Ptr),
LLVMValue (VFunction, VIdent),
Visibility (Global),
typeOf)
import Control.Monad.State (StateT, gets, modify, void)
import Data.Map (Map)
import qualified Data.Map as Map
import Grammar.ErrM (Err)
import Monomorphizer.MonomorphizerIr as MIR
import Monomorphizer.MonomorphizerIr (Ident (..), Inj (..), T,
flattenType)
import qualified Monomorphizer.MonomorphizerIr as MIR
import qualified TypeChecker.TypeCheckerIr as TIR
-- | The record used as the code generator state
data CodeGenerator = CodeGenerator
{ instructions :: [LLVMIr]
, functions :: Map MIR.Id FunctionInfo
, functions :: Map (T Ident) FunctionInfo
, customTypes :: Map LLVMType Integer
, constructors :: Map TIR.Ident ConstructorInfo
, constructors :: Map Ident ConstructorInfo
, variableCount :: Integer
, labelCount :: Integer
, gcEnabled :: Bool
, structTypes :: Map Ident StructType
-- ^ Custom stucture types
, locals :: [(Ident, LocalElem)]
-- ^ Arguments and variables in local environment
, globals :: Map Ident (LLVMType, LLVMValue)
}
data StructType = StructType
{ ptr :: LLVMType
, typs :: [LLVMType]
, inst :: LLVMIr
}
data LocalElem = LocalElem
{ typ :: LLVMType
, val :: LLVMValue
}
-- | A state type synonym
type CompilerState a = StateT CodeGenerator Err a
data FunctionInfo = FunctionInfo
{ numArgs :: Int
, arguments :: [Id]
, arguments :: [T Ident]
}
deriving (Show)
data ConstructorInfo = ConstructorInfo
{ numArgsCI :: Int
, argumentsCI :: [Id]
, argumentsCI :: [T Ident]
, numCI :: Integer
, returnTypeCI :: MIR.Type
}
deriving (Show)
addStructType_ :: Ident -> [LLVMType] -> CompilerState ()
addStructType_ = fmap void . addStructType
addStructType :: Ident -> [LLVMType] -> CompilerState LLVMType
addStructType x ts = do
modify $ \s -> s { structTypes = Map.insert x struct s.structTypes }
pure t
where
struct = StructType
{ ptr = t
, typs = ts
, inst = Type x ts
}
t = CustomType x
-- | Adds a instruction to the CodeGenerator state
emit :: LLVMIr -> CompilerState ()
emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
-- Add variable to environment
emit l@(SetVariable x _) = modify $ \t ->
t { instructions = Auxiliary.snoc l t.instructions
, locals = snoc (x, local)
t.locals
}
where
local = LocalElem { typ = typeOf l
, val = VIdent x (typeOf l)
}
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l t.instructions }
-- | Increases the variable counter in the CodeGenerator state
increaseVarCount :: CompilerState ()
@ -63,16 +118,19 @@ getNewLabel = do
{- | Produces a map of functions infos from a list of binds,
which contains useful data for code generation.
-}
getFunctions :: [MIR.Def] -> Map Id FunctionInfo
getFunctions :: [MIR.Def] -> Map (T Ident) FunctionInfo
getFunctions bs = Map.fromList $ go bs
where
go [] = []
go (MIR.DBind (MIR.Bind id args _) : xs) =
(id, FunctionInfo{numArgs = length args, arguments = args})
: go xs
(id, FunctionInfo { numArgs = length args
, arguments = args
}
)
: go xs
go (_ : xs) = go xs
createArgs :: [MIR.Type] -> [Id]
createArgs :: [MIR.Type] -> [T Ident]
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
{- | Produces a map of functions infos from a list of binds,
@ -113,35 +171,43 @@ getTypes bs = Map.fromList $ go bs
variantTypes fi = init $ map type2LlvmType (flattenType fi)
biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
getGlobals :: [MIR.Def] -> Map Ident (LLVMType, LLVMValue)
getGlobals scs = Map.fromList [ go b | MIR.DBind b <- scs ]
where
go bind | x == "main" = let typ = Function I64 []
in (x, (typ, VFunction x Global typ))
| otherwise = (x, (typ, VFunction x Global typ))
where
typ = Function tr $ Ptr : ts
Function tr ts = type2LlvmType' t
(x, t) = case bind of
MIR.Bind xt _ _ -> xt
MIR.BindC _ xt _ _ -> xt
-- Higher order function arguments are replaced with ptr
type2LlvmType' = go []
where
go acc = \case
MIR.TFun (MIR.TFun _ _) t2 -> go (snoc Ptr acc) t2
MIR.TFun t1 t2 -> go (snoc (type2LlvmType t1) acc) t2
t -> Function (type2LlvmType t) acc
initCodeGenerator :: Bool -> [MIR.Def] -> CodeGenerator
initCodeGenerator addGc scs =
CodeGenerator
{ instructions = defaultStart <> if addGc then gcStart else []
{ instructions = []
, functions = getFunctions scs
, constructors = getConstructors scs
, customTypes = getTypes scs
, structTypes = mempty
, variableCount = 0
, labelCount = 0
, gcEnabled = addGc
, locals = mempty
, globals = getGlobals scs
}
defaultStart :: [LLVMIr]
defaultStart =
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
]
gcStart :: [LLVMIr]
gcStart =
[ UnsafeRaw "declare external void @cheap_init()\n"
, UnsafeRaw "declare external ptr @cheap_alloc(i64)\n"
, UnsafeRaw "declare external void @cheap_dispose()\n"
, UnsafeRaw "declare external ptr @cheap_the()\n"
, UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n"
, UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n"
]

View file

@ -1,36 +1,40 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module Codegen.Emits where
import Codegen.Auxillary
import Codegen.CompilerState
import Codegen.LlvmIr as LIR
import Control.Applicative ((<|>))
import Control.Monad (when)
import Control.Monad.State (gets, modify)
import Data.Bifunctor qualified as BI
import Data.Char (ord)
import Data.Coerce (coerce)
import Data.Map qualified as Map
import Data.Maybe (fromJust, fromMaybe, isNothing)
import Data.Tuple.Extra (dupe, first, second)
import Debug.Trace (trace, traceShow)
import Grammar.Print
import Monomorphizer.MonomorphizerIr as MIR
import TypeChecker.TypeCheckerIr qualified as TIR
import Auxiliary (snoc)
import Codegen.Auxillary
import Codegen.CompilerState
import Codegen.LlvmIr as LIR
import Control.Applicative (Applicative (liftA2), (<|>))
import Control.Monad (forM_, when, zipWithM_)
import Control.Monad.Extra (whenJust)
import Control.Monad.State (gets, modify)
import Data.Char (ord)
import Data.Coerce (coerce)
import Data.Foldable.Extra (notNull)
import qualified Data.Map as Map
import Data.Maybe (fromJust, fromMaybe, isNothing)
import Data.Tuple.Extra (second)
import Grammar.Print (printTree)
import Monomorphizer.MonomorphizerIr
compileScs :: [MIR.Def] -> CompilerState ()
compileScs :: [Def] -> CompilerState ()
compileScs [] = do
emit $ UnsafeRaw "\n"
mapM_ createConstructor =<< gets (Map.toList . constructors)
-- as a last step create all the constructors
-- //TODO maybe merge this with the data type match?
c <- gets (Map.toList . constructors)
mapM_
( \(id, ci) -> do
let t = returnTypeCI ci
let t' = type2LlvmType t
let x = BI.second type2LlvmType <$> argumentsCI ci
where
createConstructor (id, ci) = do
let t = returnTypeCI ci
t' = type2LlvmType t
x = (mkCxtName, Ptr) : map (second type2LlvmType) ci.argumentsCI
emit $ Define FastCC t' id x
top <- getNewVar
ptr <- getNewVar
@ -56,7 +60,7 @@ compileScs [] = do
cTypes <- gets customTypes
enumerateOneM_
( \i (TIR.Ident arg_n, arg_t) -> do
( \i (Ident arg_n, arg_t) -> do
let arg_t' = type2LlvmType arg_t
emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
elemPtr <- getNewVar
@ -78,11 +82,11 @@ compileScs [] = do
heapPtr <- getNewVar
useGc <- gets gcEnabled
emit $ SetVariable heapPtr (if useGc then GcMalloc s else Malloc s)
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr heapPtr
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
Nothing -> do
emit $ Comment "Just store"
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
)
(argumentsCI ci)
@ -95,34 +99,83 @@ compileScs [] = do
emit $ UnsafeRaw "\n"
modify $ \s -> s{variableCount = 0}
)
c
compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
let t_return = type2LlvmType . last . flattenType $ t
compileScs (DBind bind : xs) = do
emit $ UnsafeRaw "\n"
emit . Comment $ show name <> ": " <> show exp
let args' = map (second type2LlvmType) args
emit . Comment $ show name <> ": " <> show (fst exp)
Function t_return t_args <- gets $ fst
. fromJust
. Map.lookup name
. globals
let args' = zip (mkCxtName : map fst args) t_args
emit $ Define FastCC t_return name args'
useGc <- gets gcEnabled
when (name == "main") (mapM_ emit (firstMainContent useGc))
functionBody <- exprToValue exp
if name == "main"
then mapM_ emit $ lastMainContent useGc functionBody
else emit $ Ret t_return functionBody
modify $ \s -> s { locals = foldr insertArg s.locals args' }
-- Dereference ptr arguments
when (notNull args') $
forM_ (tail args') $ \(x, t) -> when (t == Ptr) $ do
let t_deref =
let
Function t ts = type2LlvmType . fromJust $ lookup x args
in
Function t (Ptr : ts)
emit . SetVariable (mkDerefName x)
$ Load t_deref Ptr x
whenJust mcxt loadFreeVars
gcEnabled <- gets gcEnabled
when isMain $ mapM_ emit (firstMainContent gcEnabled)
result <- exprToValue exp
if isMain
then mapM_ emit $ lastMainContent gcEnabled result
else emit $ Ret t_return result
emit DefineEnd
modify $ \s -> s{variableCount = 0}
-- Reset variable count and empty locals
modify $ \s -> s { variableCount = 0, locals = mempty }
compileScs xs
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
let (TIR.Ident outer_id) = extractTypeName typ
where
loadFreeVars cxt = do
emit $ Comment "Load free variables"
zipWithM_ go cxt' [1 ..]
where
go (x, t) i = do
vc <- getNewVar
emit . SetVariable vc
$ GetElementPtrInbounds (CustomType $ mkClosureName name) Ptr (VIdent mkCxtName Ptr)
I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices
emit . SetVariable x $ Load t Ptr vc
cxt' = map (second type2LlvmType) cxt
isMain = name == "main"
(name, args, exp, mcxt) = case bind of
Bind (name, _) args exp -> (name, args, exp, Nothing)
BindC cxt (name, _) args exp -> (name, args, exp, Just cxt)
insertArg (x, t) = snoc (x, LocalElem { val = VIdent x t, typ = t })
compileScs (DData (Data typ ts) : xs) = do
let (Ident outer_id) = extractTypeName typ
-- //TODO this could be extracted from the customTypes map
let variantTypes fi = init $ map type2LlvmType (flattenType fi)
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8]
-- Add data type (e.g. %List) to top of the file
addStructType_ (Ident outer_id) [I8, Array biggestVariant I8]
typeSets <- gets customTypes
mapM_
( \(Inj inner_id fi) -> do
let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
emit $ LIR.Type inner_id (I8 : types)
-- Add constructor type (e.g. %Cons) to top of the file
addStructType_ inner_id (I8 : types)
)
ts
compileScs xs
@ -149,16 +202,16 @@ lastMainContent False var =
, Ret I64 (VInteger 0)
]
compileExp :: ExpT -> CompilerState ()
compileExp (MIR.ELit lit, _t) = emitLit lit
compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2
compileExp (MIR.EVar name, _t) = emitIdent name
compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2
compileExp (MIR.ELet bind e, _) = emitLet bind e
compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs)
compileExp :: T Exp -> CompilerState ()
compileExp (ELit lit, _t) = emitLit lit
compileExp (EAdd e1 e2, t) = emitAdd t e1 e2
compileExp (EVar name, _t) = emitIdent name
compileExp (EApp e1 e2, t) = emitApp t e1 e2
compileExp (ELet bind e, _) = emitLet bind e
compileExp (ECase e cs, t) = emitECased t e (map (t,) cs)
emitLet :: MIR.Bind -> ExpT -> CompilerState ()
emitLet (MIR.Bind id [] innerExp) e = do
emitLet :: Bind -> T Exp -> CompilerState ()
emitLet (Bind id [] innerExp) e = do
evaled <- exprToValue innerExp
tempVar <- getNewVar
let t = type2LlvmType . snd $ innerExp
@ -168,14 +221,14 @@ emitLet (MIR.Bind id [] innerExp) e = do
compileExp e
emitLet b _ = error $ "Non empty argument list in let-bind " <> show b
emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Branch)] -> CompilerState ()
emitECased :: Type -> T Exp -> [(Type, Branch)] -> CompilerState ()
emitECased t e cases = do
let cs = snd <$> cases
let ty = type2LlvmType t
let rt = type2LlvmType (snd e)
vs <- exprToValue e
lbl <- getNewLabel
let label = TIR.Ident $ "escape_" <> show lbl
let label = Ident $ "escape_" <> show lbl
stackPtr <- getNewVar
emit $ SetVariable stackPtr (Alloca ty)
mapM_ (emitCases rt ty label stackPtr vs) cs
@ -192,14 +245,14 @@ emitECased t e cases = do
res <- getNewVar
emit $ SetVariable res (Load ty Ptr stackPtr)
where
emitCases :: LLVMType -> LLVMType -> TIR.Ident -> TIR.Ident -> LLVMValue -> Branch -> CompilerState ()
emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, _t) exp) = do
emitCases :: LLVMType -> LLVMType -> Ident -> Ident -> LLVMValue -> Branch -> CompilerState ()
emitCases rt ty label stackPtr vs (Branch (PInj consId cs, _t) exp) = do
emit $ Comment "Inj"
cons <- gets constructors
let r = fromJust $ Map.lookup consId cons
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0)
@ -215,10 +268,10 @@ emitECased t e cases = do
emit $ Store rt vs Ptr castPtr
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
enumerateOneM_
( \i c -> do
( \i (c, t) -> do
case c of
PVar (x, topT) -> do
let topT' = type2LlvmType topT
PVar x -> do
let topT' = type2LlvmType t
let botT' = CustomType (coerce consId)
emit . Comment $ "ident " <> toIr topT'
cTypes <- gets customTypes
@ -228,7 +281,7 @@ emitECased t e cases = do
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
emit $ SetVariable x (Load topT' Ptr deref)
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
PLit (_l, _t) -> error "Nested pattern matching to be implemented"
PLit _l -> error "Nested pattern matching to be implemented"
PInj _id _ps -> error "Nested pattern matching to be implemented"
PCatch -> pure ()
PEnum _id -> error "Nested pattern matching to be implemented"
@ -238,22 +291,22 @@ emitECased t e cases = do
emit $ Store ty val Ptr stackPtr
emit $ Br label
emit $ Label lbl_failPos
emitCases _rt ty label stackPtr vs (Branch (MIR.PLit (i, ct), t) exp) = do
emitCases _rt ty label stackPtr vs (Branch (PLit i, t) exp) = do
emit $ Comment "Plit"
let i' = case i of
MIR.LInt i -> VInteger i
MIR.LChar i -> VChar (ord i)
LInt i -> VInteger i
LChar i -> VChar (ord i)
ns <- getNewVar
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable ns (Icmp LLEq (type2LlvmType ct) vs i')
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
emit $ SetVariable ns (Icmp LLEq (type2LlvmType t) vs i')
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
emit $ Label lbl_succPos
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do
emitCases rt ty label stackPtr vs (Branch (PVar id, _) exp) = do
emit $ Comment "Pvar"
-- //TODO this is pretty disgusting and would heavily benefit from a rewrite
valPtr <- getNewVar
@ -263,20 +316,20 @@ emitECased t e cases = do
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "True$Bool"), t) exp) = do
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 1, TLit "Bool"), t) exp)
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "False$Bool"), _) exp) = do
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 0, TLit "Bool"), t) exp)
emitCases rt ty label stackPtr vs br@(Branch (MIR.PEnum consId, _) exp) = do
emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "True$Bool"), t) exp) = do
emitCases rt ty label stackPtr vs (Branch (PLit $ LInt 1, t) exp)
emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "False$Bool"), _) exp) = do
emitCases rt ty label stackPtr vs (Branch (PLit (LInt 0), t) exp)
emitCases rt ty label stackPtr vs br@(Branch (PEnum consId, _) exp) = do
emit $ Comment "Penum"
cons <- gets constructors
let r = Map.lookup consId cons
when (isNothing r) (error $ "Constructor: '" ++ printTree consId ++ "' does not exist in cons state:\n" ++ show cons ++ "\nin pattern\n'" ++ printTree br ++ "'\n")
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
consVal <- getNewVar
emit $ SetVariable consVal (ExtractValue rt vs 0)
@ -295,98 +348,166 @@ emitECased t e cases = do
emit $ Store ty val Ptr stackPtr
emit $ Br label
emit $ Label lbl_failPos
emitCases _ ty label stackPtr _ (Branch (MIR.PCatch, _) exp) = do
emitCases _ ty label stackPtr _ (Branch (PCatch, _) exp) = do
emit $ Comment "Pcatch"
val <- exprToValue exp
emit $ Store ty val Ptr stackPtr
emit $ Br label
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
emit $ Label lbl_failPos
emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitApp rt e1 e2 = appEmitter e1 e2 []
where
appEmitter :: ExpT -> ExpT -> [ExpT] -> CompilerState ()
appEmitter e1 e2 stack = do
let newStack = e2 : stack
case e1 of
(MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack
(MIR.EVar name, t) -> do
args <- traverse exprToValue newStack
vs <- getNewVar
funcs <- gets functions
consts <- gets constructors
let visibility =
fromMaybe Local $
Global <$ Map.lookup name consts
<|> Global <$ Map.lookup (name, t) funcs
-- this piece of code could probably be improved, i.e remove the double `const Global`
args' = map (first valueGetType . dupe) args
let call =
case name of
TIR.Ident ('l' : 't' : '$' : _) -> Icmp LLSlt I64 (snd (head args')) (snd (args' !! 1))
TIR.Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) -> Sub I64 (snd (head args')) (snd (args' !! 1))
_ -> Call FastCC (type2LlvmType rt) visibility name args'
emit $ Comment $ show rt
emit $ SetVariable vs call
x -> error $ "The unspeakable happened: " <> show x
emitApp :: Type -> T Exp -> T Exp -> CompilerState ()
emitApp rt e1 e2 = do
((EVar name, t), args) <- go (EApp e1 e2, rt)
vs <- getNewVar
funcs <- gets functions
consts <- gets constructors
let visibility =
fromMaybe Local $
Global <$ Map.lookup name consts
<|> Global <$ Map.lookup (name, t) funcs
-- this piece of code could probably be improved, i.e remove the double `const Global`
emitIdent :: TIR.Ident -> CompilerState ()
call <- case name of
Ident ('l' : 't' : '$' : _) ->
pure $ Icmp LLSlt I64 (snd (head args)) (snd (args !! 1))
Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) ->
pure $ Sub I64 (snd (head args)) (snd (args !! 1))
-- FIXME
_ -> do
let closure_call LocalElem { typ = Ptr, val } = (mkDerefName name, (Ptr, val) : args)
(name, args) <- gets $ maybe (name, (Ptr, VNull) : args) closure_call
. lookup name
. locals
pure $ Call FastCC (type2LlvmType rt) visibility name args
emit $ Comment $ show (type2LlvmType rt)
emit $ SetVariable vs call
where
go :: T Exp -> CompilerState (T Exp, [(LLVMType, LLVMValue)])
go et@(e, _) = case e of
EApp e1 e2@(_, t) -> do
(x, as) <- go e1
a <- exprToValue e2
let t' = type2LlvmType' t
pure (x, snoc (t', a) as)
_ -> pure (et, [])
type2LlvmType' = \case
TFun _ _ -> Ptr
t -> type2LlvmType t
emitIdent :: Ident -> CompilerState ()
emitIdent id = do
-- !!this should never happen!!
emit $ Comment "This should not have happened!"
emit $ Variable id
emit $ UnsafeRaw "\n"
emitLit :: MIR.Lit -> CompilerState ()
emitLit :: Lit -> CompilerState ()
emitLit i = do
-- !!this should never happen!!
let (i', t) = case i of
(MIR.LInt i'') -> (VInteger i'', I64)
(MIR.LChar i'') -> (VChar $ ord i'', I8)
(LInt i'') -> (VInteger i'', I64)
(LChar i'') -> (VChar $ ord i'', I8)
varCount <- getNewVar
emit $ Comment "This should not have happened!"
emit $ SetVariable varCount (Add t i' (VInteger 0))
emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
emitAdd :: Type -> T Exp -> T Exp -> CompilerState ()
emitAdd t e1 e2 = do
v1 <- exprToValue e1
v2 <- exprToValue e2
v <- getNewVar
emit $ SetVariable v (Add (type2LlvmType t) v1 v2)
exprToValue :: ExpT -> CompilerState LLVMValue
exprToValue = \case
(MIR.ELit i, _t) -> pure $ case i of
(MIR.LInt i) -> VInteger i
(MIR.LChar i) -> VChar $ ord i
(MIR.EVar (TIR.Ident "True$Bool"), _t) -> pure $ VInteger 1
(MIR.EVar (TIR.Ident "False$Bool"), _t) -> pure $ VInteger 0
(MIR.EVar name, t) -> do
funcs <- gets functions
cons <- gets constructors
let res =
Map.lookup (name, t) funcs
<|> ( \c ->
FunctionInfo
{ numArgs = numArgsCI c
, arguments = argumentsCI c
}
)
<$> Map.lookup name cons
case res of
Just fi -> do
if numArgs fi == 0
then do
vc <- getNewVar
emit $
SetVariable
vc
(Call FastCC (type2LlvmType t) Global name [])
pure $ VIdent vc (type2LlvmType t)
else pure $ VFunction name Global (type2LlvmType t)
Nothing -> pure $ VIdent name (type2LlvmType t)
e -> do
compileExp e
exprToValue :: T Exp -> CompilerState LLVMValue
exprToValue et@(e, t) = case e of
ELit (LInt i) -> pure $ VInteger i
ELit (LChar c) -> pure . VChar $ ord c
EVar "True$Bool" -> pure $ VInteger 1
EVar "False$Bool" -> pure $ VInteger 0
EVar name -> gets (Map.lookup name . globals) >>= \case
Just (typ@(Function _ ts), val) | length ts > 1 -> do
type_struct <- addStructType (mkClosureName name) [typ]
emit $ Comment "Allocating structure"
emit . SetVariable name $ Alloca type_struct
emit $ Store typ val Ptr name
pure $ VIdent name Ptr
Just _ | name == "main" -> do
vc <- getNewVar
emit $ SetVariable vc (Call FastCC I64 Global name [])
pure $ VIdent vc I64
Just (Function t_return [_], _) -> do
vc <- getNewVar
emit $ SetVariable vc (Call FastCC t_return Global name [(Ptr, VNull)])
pure $ VIdent vc t_return
Just _ -> error "Bad"
Nothing -> gets (Map.lookup name . constructors) >>= \case
Just ConstructorInfo {numArgsCI}
| numArgsCI == 0 -> do
vc <- getNewVar
emit $ SetVariable vc call
pure $ VIdent vc (type2LlvmType t)
| otherwise -> pure $ VFunction name Global (type2LlvmType t)
where
call = Call FastCC (type2LlvmType t) Global name []
Nothing -> gets $ val
. fromJust
. lookup name
. locals
EVarC cxt name -> do
let cxt' = flip map cxt $ \(x, t) -> let t' = type2LlvmType t
in (t', VIdent x t')
cxt'' <- gets $ (:cxt')
. fromJust
. Map.lookup name
. globals
-- Create a new type for function pointer and arguments
type_struct <- addStructType (mkClosureName name) $ map fst cxt''
emit $ Comment "Allocating structure"
emit . SetVariable name $ Alloca type_struct
let ptr_struct = VIdent name Ptr
storeArg (t, v) i = do
vc <- getNewVar
emit . SetVariable vc
$ GetElementPtrInbounds type_struct Ptr ptr_struct
I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices
emit $ Store t v Ptr vc
-- Store arguments in structure
zipWithM_ storeArg cxt'' [0 ..]
pure ptr_struct
_ -> do
compileExp et
v <- getVarCount
pure $ VIdent (TIR.Ident $ show v) (getType e)
pure $ VIdent (Ident $ show v) (getType et)
mkClosureName :: Ident -> Ident
mkClosureName (Ident s) = Ident $ "Closure_" ++ s
mkDerefName :: Ident -> Ident
mkDerefName (Ident s) = Ident $ s ++ "_deref"
mkCxtName :: Ident
mkCxtName = Ident "cxt"

View file

@ -9,17 +9,18 @@ module Codegen.LlvmIr (
Visibility (..),
CallingConvention (..),
ToIr (..),
typeOf
) where
import Data.List (intercalate)
import TypeChecker.TypeCheckerIr (Ident (..))
import Data.List (intercalate)
import TypeChecker.TypeCheckerIr (Ident (..))
data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving (Show, Eq, Ord)
instance ToIr CallingConvention where
toIr :: CallingConvention -> String
toIr TailCC = "tailcc"
toIr FastCC = "fastcc"
toIr CCC = "ccc"
toIr CCC = "ccc"
toIr ColdCC = "coldcc"
-- | A datatype which represents some basic LLVM types
@ -38,6 +39,9 @@ data LLVMType
class ToIr a where
toIr :: a -> String
instance ToIr a => ToIr [a] where
toIr = concatMap toIr
instance ToIr LLVMType where
toIr :: LLVMType -> String
toIr = \case
@ -66,8 +70,8 @@ data LLVMComp
instance ToIr LLVMComp where
toIr :: LLVMComp -> String
toIr = \case
LLEq -> "eq"
LLNe -> "ne"
LLEq -> "eq"
LLNe -> "ne"
LLUgt -> "ugt"
LLUge -> "uge"
LLUlt -> "ult"
@ -80,7 +84,7 @@ instance ToIr LLVMComp where
data Visibility = Local | Global deriving (Show, Eq, Ord)
instance ToIr Visibility where
toIr :: Visibility -> String
toIr Local = "%"
toIr Local = "%"
toIr Global = "@"
{- | Represents a LLVM "value", as in an integer, a register variable,
@ -92,16 +96,18 @@ data LLVMValue
| VIdent Ident LLVMType
| VConstant String
| VFunction Ident Visibility LLVMType
| VNull
deriving (Show, Eq, Ord)
instance ToIr LLVMValue where
toIr :: LLVMValue -> String
toIr v = case v of
VInteger i -> show i
VChar i -> show i
VIdent (Ident n) _ -> "%" <> n
VInteger i -> show i
VChar i -> show i
VIdent (Ident n) _ -> "%" <> n
VFunction (Ident n) vis _ -> toIr vis <> n
VConstant s -> "c" <> show s
VConstant s -> "c" <> show s
VNull -> "null"
type Params = [(Ident, LLVMType)]
type Args = [(LLVMType, LLVMValue)]
@ -139,6 +145,21 @@ data LLVMIr
-- instructions should be used in its place
deriving (Show, Eq, Ord)
-- TODO add missing clauses
typeOf :: LLVMIr -> LLVMType
typeOf = \case
Add t _ _ -> t
Sub t _ _ -> t
Mul t _ _ -> t
Div t _ _ -> t
Load t _ _ -> t
Store t _ _ _ -> t
Type x _ -> CustomType x
SetVariable _ ir -> typeOf ir
-- | Converts a list of LLVMIr instructions to a string
llvmIrToString :: [LLVMIr] -> String
llvmIrToString = go 0
@ -147,9 +168,9 @@ llvmIrToString = go 0
go _ [] = mempty
go i (x : xs) = do
let (i', n) = case x of
Define{} -> (i + 1, 0)
Define{} -> (i + 1, 0)
DefineEnd -> (i - 1, 0)
_ -> (i, i)
_ -> (i, i)
insToString n x <> go i' xs
-- \| Converts a LLVM inststruction to a String, allowing for printing etc.
@ -224,10 +245,10 @@ llvmIrToString = go 0
, ")\n"
]
(Alloca t) -> unwords ["alloca", toIr t, "\n"]
(Malloc t) ->
(Malloc t) ->
concat
[ "call ptr @malloc(i64 ", show t, ")\n"]
(GcMalloc t) ->
(GcMalloc t) ->
concat
[ "call ptr @cheap_alloc(i64 ", show t, ")\n"]
(Store t1 val t2 (Ident id2)) ->

View file

@ -1,9 +1,9 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Desugar.Desugar (desugar) where
import Grammar.Abs
import Grammar.Abs
{-
@ -17,21 +17,21 @@ desugar (Program defs) = Program (map desugarDef defs)
desugarVarName :: VarName -> LIdent
desugarVarName (VSymbol (Symbol i)) = LIdent $ fixName i
desugarVarName (VIdent i) = i
desugarVarName (VIdent i) = i
desugarDef :: Def -> Def
desugarDef = \case
DBind b -> DBind (desugarBind b)
DBind b -> DBind (desugarBind b)
DSig sig -> DSig (desugarSig sig)
DData d -> DData (desugarData d)
DData d -> DData (desugarData d)
desugarBind :: Bind -> Bind
desugarBind (BindS name args e) = Bind (desugarVarName name) args (desugarExp e)
desugarBind (Bind name args e) = Bind name args (desugarExp e)
desugarBind (Bind name args e) = Bind name args (desugarExp e)
desugarSig :: Sig -> Sig
desugarSig (SigS ident typ) = Sig (desugarVarName ident) (desugarType typ)
desugarSig (Sig ident typ) = Sig ident (desugarType typ)
desugarSig (Sig ident typ) = Sig ident (desugarType typ)
desugarData :: Data -> Data
desugarData (Data typ injs) = Data (desugarType typ) (map desugarInj injs)
@ -45,7 +45,7 @@ desugarType = \case
let (name : tvars) = flatten t1 ++ [t2]
in case name of
TIdent ident -> TData ident (map desugarType tvars)
_ -> error "desugarType is not implemented correctly"
_ -> error "desugarType is not implemented correctly"
TLit l -> TLit l
TVar v -> TVar v
(TAll i t) -> TAll i (desugarType t)
@ -55,7 +55,7 @@ desugarType = \case
where
flatten :: Type -> [Type]
flatten (TApp a b) = flatten a <> flatten b
flatten a = [a]
flatten a = [a]
desugarInj :: Inj -> Inj
desugarInj (Inj ident typ) = Inj ident (desugarType typ)
@ -80,14 +80,14 @@ desugarBranch (Branch p e) = Branch (desugarPattern p) (desugarExp e)
desugarPattern :: Pattern -> Pattern
desugarPattern = \case
PVar ident -> PVar ident
PLit lit -> PLit (desugarLit lit)
PCatch -> PCatch
PEnum ident -> PEnum ident
PVar ident -> PVar ident
PLit lit -> PLit (desugarLit lit)
PCatch -> PCatch
PEnum ident -> PEnum ident
PInj ident patterns -> PInj ident (map desugarPattern patterns)
desugarLit :: Lit -> Lit
desugarLit (LInt i) = LInt i
desugarLit (LInt i) = LInt i
desugarLit (LChar c) = LChar c
fixName :: String -> String
@ -115,4 +115,4 @@ fixName = concatMap mapSymbols
':' -> "$semicolon$"
'[' -> "$lbracket$"
']' -> "$rbracket$"
c -> c : ""
c -> c : ""

View file

@ -11,9 +11,11 @@ import Control.Monad.State (MonadState (get, put), State,
evalState)
import Data.Function (on)
import Data.List (delete, mapAccumL, (\\))
import Data.Tuple.Extra (first, second)
import LambdaLifterIr (T)
import qualified LambdaLifterIr as L
import Prelude hiding (exp)
import TypeChecker.TypeCheckerIr
import TypeChecker.TypeCheckerIr hiding (T)
-- | Lift lambdas and let expression into supercombinators.
-- Three phases:
@ -21,12 +23,13 @@ import TypeChecker.TypeCheckerIr
-- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function.
--
lambdaLift :: Program -> Program
lambdaLift (Program ds) = Program (datatypes ++ binds)
lambdaLift :: Program -> L.Program
lambdaLift (Program ds) = L.Program (datatypes ++ binds)
where
datatypes = flip filter ds $ \case DData _ -> True
_ -> False
binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
datatypes = [L.DData (toLirData d) | DData d <- ds]
binds = map L.DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
-- | Annotate free variables
freeVars :: [Bind] -> [ABind]
@ -36,7 +39,7 @@ freeVars binds = [ let ae = freeVarsExp [] e
| Bind n xs e <- binds
]
freeVarsExp :: Frees -> ExpT -> Ann AExpT
freeVarsExp :: Frees -> T Exp -> Ann (T AExp)
freeVarsExp localVars (ae, t) = case ae of
EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)]
, term = (AVar n, t)
@ -121,27 +124,47 @@ data Ann a = Ann
, term :: a
} deriving (Show, Eq)
data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq)
data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq)
type AExpT = (AExp, Type)
data ABind = ABind (T Ident) [T Ident] (Ann (T AExp)) deriving (Show, Eq)
data ABranch = ABranch (Pattern, Type) (Ann (T AExp)) deriving (Show, Eq)
data AExp = AVar Ident
| AInj Ident
| ALit Lit
| ALet (Ann ABind) (Ann AExpT)
| AApp (Ann AExpT) (Ann AExpT)
| AAdd (Ann AExpT) (Ann AExpT)
| AAbs Ident (Ann AExpT)
| ACase (Ann AExpT) [Ann ABranch]
| ALet (Ann ABind) (Ann (T AExp))
| AApp (Ann (T AExp)) (Ann (T AExp))
| AAdd (Ann (T AExp)) (Ann (T AExp))
| AAbs Ident (Ann (T AExp))
| ACase (Ann (T AExp)) [Ann ABranch]
deriving (Show, Eq)
abstract :: [ABind] -> [Bind]
data BBind = BBind (T Ident) [T Ident] (T BExp)
| BBindCxt [T Ident] (T Ident) [T Ident] (T BExp)
deriving (Eq, Ord, Show)
data BBranch = BBranch (T Pattern) (T BExp)
deriving (Eq, Ord, Show)
data BExp
= BVar Ident
| BVarC [T Ident] Ident
| BInj Ident
| BLit Lit
| BLet BBind (T BExp)
| BApp (T BExp)(T BExp)
| BAdd (T BExp)(T BExp)
| BCase (T BExp) [BBranch]
deriving (Eq, Ord, Show)
abstract :: [ABind] -> [BBind]
abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0
abstractAnnBind :: Ann ABind -> State Int Bind
abstractAnnBind :: Ann ABind -> State Int BBind
abstractAnnBind Ann { term = ABind name vars annae } =
Bind name (vars' <|| vars) <$> abstractAnnExp annae'
BBind name (vars' <|| vars) <$> abstractAnnExp annae'
where
(annae', vars') = go [] annae
where
@ -149,24 +172,27 @@ abstractAnnBind Ann { term = ABind name vars annae } =
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc)
abstractAnnExp :: Ann AExpT -> State Int ExpT
abstractAnnExp :: Ann (T AExp) -> State Int (T BExp)
abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
AVar n -> pure (EVar n, typ)
AInj n -> pure (EInj n, typ)
ALit lit -> pure (ELit lit, typ)
AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2
AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2
AVar n -> pure (BVar n, typ)
AInj n -> pure (BInj n, typ)
ALit lit -> pure (BLit lit, typ)
AApp annae1 annae2 -> (, typ) <$> onM BApp abstractAnnExp annae1 annae2
AAdd annae1 annae2 -> (, typ) <$> onM BAdd abstractAnnExp annae1 annae2
-- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc
AAbs x annae' -> do
i <- nextNumber
rhs <- abstractAnnExp annae''
let sc_name = Ident ("sc_" ++ show i)
e@(_, t) = foldl applyFree (EVar sc_name, typ) frees
pure (ELet (Bind (sc_name, typ) vars rhs) e ,t)
sc | null frees = (BVar sc_name, typ)
| otherwise = (BVarC frees sc_name, typ)
bind | null frees = BBind (sc_name, typ) vars rhs
| otherwise = BBindCxt frees (sc_name, typ) vars rhs
pure (BLet bind sc ,typ)
where
vars = frees <| (x, t_x) <|| ys
vars = [(x, t_x)] <|| ys
t_x = case typ of TFun t _ -> t
_ -> error "Impossible"
@ -176,54 +202,47 @@ abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc)
applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type)
applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e')
where
t_e' = case t_e of TFun _ t -> t
_ -> error "Impossible"
ACase annae' bs -> do
bs <- mapM go bs
e <- abstractAnnExp annae'
pure (ECase e bs, typ)
pure (BCase e bs, typ)
where
go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae
go Ann { term = ABranch p annae } = BBranch p <$> abstractAnnExp annae
ALet b annae' ->
(, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae')
(, typ) <$> liftA2 BLet (abstractAnnBind b) (abstractAnnExp annae')
-- | Collects supercombinators by lifting non-constant let expressions
collectScs :: [Bind] -> [Bind]
collectScs :: [BBind] -> [L.Bind]
collectScs = concatMap collectFromRhs
where
collectFromRhs (Bind name parms rhs) =
collectFromRhs (BBind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs
in L.Bind name parms rhs' : rhs_scs
collectFromRhs (BBindCxt cxt name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in L.BindC cxt name parms rhs' : rhs_scs
collectScsExp :: ExpT -> ([Bind], ExpT)
collectScsExp expT@(exp, typ) = case exp of
EVar _ -> ([], expT)
EInj _ -> ([], expT)
ELit _ -> ([], expT)
collectScsExp :: T BExp -> ([L.Bind], T L.Exp)
collectScsExp (exp, typ) = case exp of
BVar x -> ([], (L.EVar x, typ))
BVarC as x -> ([], (L.EVarC as x, typ))
BInj k -> ([], (L.EInj k, typ))
BLit lit -> ([], (L.ELit lit, typ))
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
BApp e1 e2 -> (scs1 ++ scs2, (L.EApp e1' e2', typ))
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ))
BAdd e1 e2 -> (scs1 ++ scs2, (L.EAdd e1' e2', typ))
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAbs par e -> (scs, (EAbs par e', typ))
where
(scs, e') = collectScsExp e
ECase e branches -> (scs ++ scs_e, (ECase e' branches', typ))
BCase e branches -> (scs ++ scs_e, (L.ECase e' branches', typ))
where
(scs, branches') = mapAccumL f [] branches
(scs_e, e') = collectScsExp e
@ -234,15 +253,24 @@ collectScsExp expT@(exp, typ) = case exp of
--
-- > f = let sc x y = rhs in e
--
ELet (Bind name parms rhs) e
| null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et'))
BLet (BBind name parms rhs) e
| null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
| otherwise -> (bind : rhs_scs ++ et_scs, et')
where
bind = Bind name parms rhs'
bind = L.Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
BLet (BBindCxt cxt name parms rhs) e
| null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
| otherwise -> (bind : rhs_scs ++ et_scs, et')
where
bind = L.BindC cxt name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e
collectScsBranch (BBranch patt exp) = (scs, L.Branch (first toLirPattern patt) exp')
where (scs, exp') = collectScsExp exp
nextNumber :: State Int Int
@ -259,3 +287,13 @@ xs <| x | elem x xs = xs
(<||) :: Eq a => [a] -> [a] -> [a]
xs <|| ys = foldl (<|) xs ys
toLirData (Data t injs) = L.Data t (map toLirInj injs)
toLirInj (Inj n t) = L.Inj n t
toLirPattern :: Pattern -> L.Pattern
toLirPattern = \case
PVar x -> L.PVar x
PLit lit -> L.PLit lit
PCatch -> L.PCatch
PEnum k -> L.PEnum k
PInj k ps -> L.PInj k (map (first toLirPattern) ps)

140
src/LambdaLifterIr.hs Normal file
View file

@ -0,0 +1,140 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
module LambdaLifterIr (
module Grammar.Abs,
module LambdaLifterIr,
module TypeChecker.TypeCheckerIr
) where
import Data.List (intercalate)
import Grammar.Abs (Lit (..))
import Grammar.Print
import Prelude hiding (exp)
import qualified Prelude as C (Eq, Ord, Show)
import TypeChecker.TypeCheckerIr (Ident (..), TVar (..), Type (..))
newtype Program = Program [Def]
deriving (C.Eq, C.Ord, C.Show)
data Def
= DBind Bind
| DData Data
deriving (C.Eq, C.Ord, C.Show)
data Data = Data Type [Inj]
deriving (C.Eq, C.Ord, C.Show)
data Inj = Inj Ident Type
deriving (C.Eq, C.Ord, C.Show)
data Pattern
= PVar Ident
| PLit Lit
| PCatch
| PEnum Ident
| PInj Ident [(Pattern, Type)]
deriving (C.Eq, C.Ord, C.Show)
data Exp
= EVar Ident
| EVarC [T Ident] Ident
| EInj Ident
| ELit Lit
| ELet (T Ident) (T Exp) (T Exp)
| EApp (T Exp)(T Exp)
| EAdd (T Exp)(T Exp)
| ECase (T Exp) [Branch]
deriving (C.Eq, C.Ord, C.Show)
type T a = (a, Type)
data Bind = Bind (T Ident) [T Ident] (T Exp)
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
deriving (C.Eq, C.Ord, C.Show)
data Branch = Branch (T Pattern) (T Exp)
deriving (C.Eq, C.Ord, C.Show)
instance Print Program where
prt i (Program sc) = prt i sc
instance Print Bind where
prt i (Bind sig parms rhs) = concatD
[ prt i sig
, prt i parms
, doc $ showString "="
, prt i rhs
]
prt i (BindC cxt sig parms rhs) =
prPrec i 0 $
concatD
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
, prt i parms
, doc $ showString "="
, prt i rhs
]
instance Print [Bind] where
prt _ [] = concatD []
prt i [x] = concatD [prt i x]
prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs]
instance Print Exp where
prt i = \case
EVar lident -> prPrec i 3 (concatD [prt 0 lident])
EVarC as lident -> doc . showString
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
where
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
EInj uident -> prPrec i 3 (concatD [prt 0 uident])
ELit lit -> prPrec i 3 (concatD [prt 0 lit])
EApp exp1 exp2 -> prPrec i 2 (concatD [prt 2 exp1, prt 3 exp2])
EAdd exp1 exp2 -> prPrec i 1 (concatD [prt 1 exp1, doc (showString "+"), prt 2 exp2])
ELet lident exp1 exp2 -> prPrec i 0 (concatD [doc (showString "let"), prt 0 lident, doc (showString "="), prt 0 exp1 , doc (showString "in"), prt 0 exp2])
ECase exp branchs -> prPrec i 0 (concatD [doc (showString "case"), prt 0 exp, doc (showString "of"), doc (showString "{"), prt 0 branchs, doc (showString "}")])
instance Print Branch where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
instance Print [Branch] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print Def where
prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print Data where
prt i = \case
Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")])
instance Print Inj where
prt i = \case
Inj uident type_ -> prt i (uident, type_)
instance Print [Inj] where
prt _ [] = concatD []
prt i [x] = prt i x
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
instance Print Pattern where
prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print [Def] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
pattern DBind' id vars expt = DBind (Bind id vars expt)
pattern DData' typ injs = DData (Data typ injs)

View file

@ -2,47 +2,38 @@
module Main where
import AnnForall (annotateForall)
import Codegen.Codegen (generateCode)
import Compiler (compile)
import Control.Monad (when, (<=<))
import Data.List.Extra (isSuffixOf)
import Data.Maybe (fromJust, isNothing)
import Desugar.Desugar (desugar)
import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (Print, printTree)
import LambdaLifter (lambdaLift)
import Monomorphizer.Monomorphizer (monomorphize)
import OrderDefs (orderDefs)
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import System.Console.GetOpt (
ArgDescr (NoArg, ReqArg),
ArgOrder (RequireOrder),
OptDescr (Option),
getOpt,
usageInfo,
)
import System.Directory (
createDirectory,
doesPathExist,
getDirectoryContents,
removeDirectoryRecursive,
setCurrentDirectory,
)
import System.Environment (getArgs)
import System.Exit (
ExitCode (ExitFailure),
exitFailure,
exitSuccess,
exitWith,
)
import System.IO (stderr)
import System.Process (spawnCommand, waitForProcess)
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck)
import AnnForall (annotateForall)
import Codegen.Codegen (generateCode)
import Compiler (compile)
import Control.Monad (when, (<=<))
import Data.List.Extra (isSuffixOf)
import Data.Maybe (fromJust, isNothing)
import Desugar.Desugar (desugar)
import GHC.IO.Handle.Text (hPutStrLn)
import Grammar.ErrM (Err)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (Print, printTree)
import LambdaLifter (lambdaLift)
import Monomorphizer.Monomorphizer (monomorphize)
import OrderDefs (orderDefs)
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import System.Console.GetOpt (ArgDescr (NoArg, ReqArg),
ArgOrder (RequireOrder),
OptDescr (Option), getOpt,
usageInfo)
import System.Directory (createDirectory, doesPathExist,
getDirectoryContents,
removeDirectoryRecursive,
setCurrentDirectory)
import System.Environment (getArgs)
import System.Exit (ExitCode (ExitFailure),
exitFailure, exitSuccess,
exitWith)
import System.IO (stderr)
import System.Process (spawnCommand, waitForProcess)
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm), typecheck)
main :: IO ()
main = getArgs >>= parseArgs >>= uncurry main'
@ -94,12 +85,12 @@ chooseTypechecker s options = options{typechecker = tc}
tc = case s of
"hm" -> pure Hm
"bi" -> pure Bi
_ -> Nothing
_ -> Nothing
data Options = Options
{ help :: Bool
, debug :: Bool
, gc :: Bool
{ help :: Bool
, debug :: Bool
, gc :: Bool
, typechecker :: Maybe TypeChecker
}

View file

@ -1,8 +1,11 @@
module Monomorphizer.DataTypeRemover (removeDataTypes) where
import Monomorphizer.MonomorphizerIr qualified as M2
import Monomorphizer.MorbIr qualified as M1
import TypeChecker.TypeCheckerIr (Ident (Ident))
import Data.Bifunctor (Bifunctor (bimap))
import Monomorphizer.MonomorphizerIr (Ident (..))
import qualified Monomorphizer.MonomorphizerIr as M2
import qualified Monomorphizer.MorbIr as M1
import Prelude hiding (exp)
removeDataTypes :: M1.Program -> M2.Program
removeDataTypes (M1.Program defs) = M2.Program (map pDef defs)
@ -18,43 +21,43 @@ pCons :: M1.Inj -> M2.Inj
pCons (M1.Inj ident t) = M2.Inj ident (pType t)
pType :: M1.Type -> M2.Type
pType (M1.TLit ident) = M2.TLit ident
pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2)
pType (M1.TLit ident) = M2.TLit ident
pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2)
pType (M1.TData (Ident "Bool") _) = M2.TLit (Ident "Bool")
pType d = M2.TLit (Ident (newName d)) -- This is the step
pType d = M2.TLit (Ident (newName d)) -- This is the step
newName :: M1.Type -> String
newName (M1.TLit (Ident str)) = str
newName (M1.TFun t1 t2) = newName t1 ++ newName t2
newName (M1.TLit (Ident str)) = str
newName (M1.TFun t1 t2) = newName t1 ++ newName t2
newName (M1.TData (Ident str) args) = str ++ concatMap newName args
pBind :: M1.Bind -> M2.Bind
pBind (M1.Bind id argIds expt) = M2.Bind (pId id) (map pId argIds) (pExpT expt)
pBind (M1.BindC cxt id argIds expt) =
M2.BindC (map pId cxt) (pId id) (map pId argIds) (pExpT expt)
pId :: (Ident, M1.Type) -> (Ident, M2.Type)
pId (ident, t) = (ident, pType t)
pExpT :: M1.ExpT -> M2.ExpT
pExpT :: M1.T M1.Exp -> M2.T M2.Exp
pExpT (exp, t) = (pExp exp, pType t)
pExp :: M1.Exp -> M2.Exp
pExp (M1.EVar ident) = M2.EVar ident
pExp (M1.ELit lit) = M2.ELit (pLit lit)
pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt)
pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2)
pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2)
pExp (M1.EVar ident) = M2.EVar ident
pExp (M1.EVarC as ident) = M2.EVarC (map pId as) ident
pExp (M1.ELit lit) = M2.ELit lit
pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt)
pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2)
pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2)
pExp (M1.ECase expT branches) = M2.ECase (pExpT expT) (map pBranch branches)
pBranch :: M1.Branch -> M2.Branch
pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt)
pPattern :: M1.Pattern -> M2.Pattern
pPattern (M1.PVar id) = M2.PVar (pId id)
pPattern (M1.PLit (lit, t)) = M2.PLit (pLit lit, pType t)
pPattern (M1.PInj ident patts) = M2.PInj ident (map pPattern patts)
pPattern M1.PCatch = M2.PCatch
pPattern (M1.PEnum ident) = M2.PEnum ident
pPattern (M1.PVar ident) = M2.PVar ident
pPattern (M1.PLit lit) = M2.PLit lit
pPattern (M1.PInj ident patts) = M2.PInj ident (map (bimap pPattern pType) patts)
pPattern M1.PCatch = M2.PCatch
pPattern (M1.PEnum ident) = M2.PEnum ident
pLit :: M1.Lit -> M2.Lit
pLit (M1.LInt v) = M2.LInt v
pLit (M1.LChar c) = M2.LChar c

View file

@ -1,4 +1,5 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{- | For now, converts polymorphic functions to concrete ones based on usage.
Assumes lambdas are lifted.
@ -25,30 +26,35 @@ bind) is added to the resulting set of binds.
-}
module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where
import Monomorphizer.DataTypeRemover (removeDataTypes)
import Monomorphizer.MonomorphizerIr qualified as O
import Monomorphizer.MorbIr qualified as M
import TypeChecker.TypeCheckerIr (Ident (Ident))
import TypeChecker.TypeCheckerIr qualified as T
import Control.Monad.Reader (
MonadReader (ask, local),
Reader,
asks,
runReader,
)
import Control.Monad.State (
MonadState (get),
StateT (runStateT),
gets,
modify,
)
import Data.Coerce (coerce)
import Data.Map qualified as Map
import Data.Maybe (catMaybes)
import Data.Set qualified as Set
import Grammar.Print (printTree)
import Debug.Trace (trace)
import Control.Monad.Reader (MonadReader (ask, local),
Reader, asks, runReader)
import Control.Monad.State (MonadState (get),
StateT (runStateT), gets,
modify)
import Data.Coerce (coerce)
import qualified Data.Map as Map
import Data.Maybe (catMaybes)
import qualified Data.Set as Set
import Debug.Trace (trace)
import Grammar.Print (printTree)
import Monomorphizer.DataTypeRemover (removeDataTypes)
import qualified Monomorphizer.MonomorphizerIr as O
import qualified Monomorphizer.MorbIr as M
-- import TypeChecker.TypeCheckerIr (Ident (Ident))
import LambdaLifterIr (Ident (..))
-- import TypeChecker.TypeCheckerIr qualified as T
import qualified LambdaLifterIr as L
import Control.Monad.Reader (MonadReader (ask, local),
Reader, asks, runReader)
import Control.Monad.State (MonadState, StateT (runStateT),
gets, modify)
import qualified Data.Map as Map
import Data.Maybe (catMaybes, fromJust)
import qualified Data.Set as Set
import Data.Tuple.Extra (secondM)
import Grammar.Print (printTree)
{- | EnvM is the monad containing the read-only state as well as the
output state containing monomorphized functions and to-be monomorphized
@ -64,18 +70,18 @@ Binds, Polymorphic Data types (monomorphized in a later step) and
Marked bind, which means that it is in the process of monomorphization
and should not be monomorphized again.
-}
data Outputted = Marked | Complete M.Bind | Data M.Type T.Data deriving (Show)
data Outputted = Marked | Complete M.Bind | Data M.Type L.Data deriving (Show)
-- | Static environment.
data Env = Env
{ input :: Map.Map Ident T.Bind
{ input :: Map.Map Ident L.Bind
-- ^ All binds in the program.
, dataDefs :: Map.Map Ident T.Data
, dataDefs :: Map.Map Ident L.Data
-- ^ All constructors mapped to their respective polymorphic data def
-- which includes all other constructors.
, polys :: Map.Map Ident M.Type
, polys :: Map.Map Ident M.Type
-- ^ Maps polymorphic identifiers with concrete types.
, locals :: Set.Set Ident
, locals :: Set.Set Ident
-- ^ Local variables.
}
@ -84,12 +90,13 @@ localExists :: Ident -> EnvM Bool
localExists ident = asks (Set.member ident . locals)
-- | Gets a polymorphic bind from an id.
getInputBind :: Ident -> EnvM (Maybe T.Bind)
getInputBind :: Ident -> EnvM (Maybe L.Bind)
getInputBind ident = asks (Map.lookup ident . input)
-- | Add monomorphic function derived from a polymorphic one, to env.
addOutputBind :: M.Bind -> EnvM ()
addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b))
addOutputBind b@(M.BindC _ (ident, _) _ _) = modify (Map.insert ident (Complete b))
{- | Marks a global bind as being processed, meaning that when encountered again,
it should not be recursively processed.
@ -106,8 +113,8 @@ isConsMarked :: Ident -> EnvM Bool
isConsMarked ident = gets (Map.member ident)
-- | Finds main bind.
getMain :: EnvM T.Bind
getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of
getMain :: EnvM L.Bind
getMain = asks (\env -> case Map.lookup (Ident "main") (input env) of
Just mainBind -> mainBind
Nothing -> error "main not found in monomorphizer!"
)
@ -116,13 +123,13 @@ getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of
error when encountering different structures between the two arguments. Debug:
First argument is the name of the bind.
-}
mapTypes :: Ident -> T.Type -> M.Type -> [(Ident, M.Type)]
mapTypes _ident (T.TLit _) (M.TLit _) = []
mapTypes _ident (T.TVar (T.MkTVar i1)) tm = [(i1, tm)]
mapTypes ident (T.TFun pt1 pt2) (M.TFun mt1 mt2) =
mapTypes :: Ident -> L.Type -> M.Type -> [(Ident, M.Type)]
mapTypes _ident (L.TLit _) (M.TLit _) = []
mapTypes _ident (L.TVar (L.MkTVar i1)) tm = [(i1, tm)]
mapTypes ident (L.TFun pt1 pt2) (M.TFun mt1 mt2) =
mapTypes ident pt1 mt1
++ mapTypes ident pt2 mt2
mapTypes ident (T.TData tIdent pTs) (M.TData mIdent mTs) =
mapTypes ident (L.TData tIdent pTs) (M.TData mIdent mTs) =
if tIdent /= mIdent
then error "the data type names of monomorphic and polymorphic data types does not match"
else foldl (\xs (p, m) -> mapTypes ident p m ++ xs) [] (zip pTs mTs)
@ -130,30 +137,30 @@ mapTypes ident t1 t2 = error $ "in bind: '" ++ printTree ident ++ "', " ++
"structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'"
-- | Gets the mapped monomorphic type of a polymorphic type in the current context.
getMonoFromPoly :: T.Type -> EnvM M.Type
getMonoFromPoly :: L.Type -> EnvM M.Type
getMonoFromPoly t = do
env <- ask
return $ getMono (polys env) t
where
getMono :: Map.Map Ident M.Type -> T.Type -> M.Type
getMono :: Map.Map Ident M.Type -> L.Type -> M.Type
getMono polys t = case t of
(T.TLit ident) -> M.TLit (coerce ident)
(T.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2)
(T.TVar (T.MkTVar ident)) -> case Map.lookup ident polys of
(L.TLit ident) -> M.TLit ident
(L.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2)
(L.TVar (L.MkTVar ident)) -> case Map.lookup ident polys of
Just concrete -> concrete
Nothing -> M.TLit (Ident "void")
Nothing -> M.TLit (Ident "void")
-- error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps"
(T.TData ident args) -> M.TData ident (map (getMono polys) args)
(L.TData ident args) -> M.TData ident (map (getMono polys) args)
{- | If ident not already in env's output, morphed bind to output
(and all referenced binds within this bind).
Returns the annotated bind name.
-}
morphBind :: M.Type -> T.Bind -> EnvM Ident
morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do
morphBind :: M.Type -> L.Bind -> EnvM Ident
morphBind expectedType b@(L.Bind (ident, btype) args (exp, expt)) = do
-- The "new name" is used to find out if it is already marked or not.
let name' = newFuncName expectedType b
bindMarked <- isBindMarked (coerce name')
bindMarked <- isBindMarked name'
local
( \env ->
env
@ -168,26 +175,59 @@ morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do
else do
-- Mark so that this bind will not be processed in recursive or cyclic
-- function calls
markBind (coerce name')
markBind name'
expt' <- getMonoFromPoly expt
exp' <- morphExp expt' exp
-- Get monomorphic type sof args
args' <- mapM morphArg args
addOutputBind $
M.Bind
(coerce name', expectedType)
(name', expectedType)
args'
(exp', expt')
return name'
morphBind expectedType b@(L.BindC cxt (ident, btype) args (exp, expt)) = do
-- The "new name" is used to find out if it is already marked or not.
let name' = newFuncName expectedType b
bindMarked <- isBindMarked name'
local
( \env ->
env
{ locals = Set.fromList (map fst args)
, polys = Map.fromList (mapTypes ident btype expectedType)
}
)
$ do
-- Return with right name if already marked
if bindMarked
then return name'
else do
-- Mark so that this bind will not be processed in recursive or cyclic
-- function calls
markBind name'
-- Get monomorphic type sof args
args' <- mapM morphArg args
cxt' <- mapM (secondM getMonoFromPoly) cxt
expt' <- getMonoFromPoly expt
exp' <- local (\env -> foldr (addLocal . fst) env cxt)
(morphExp expt' exp)
addOutputBind $
M.BindC cxt'
(name', expectedType)
args'
(exp', expt')
return name'
-- | Monomorphizes arguments of a bind.
morphArg :: (Ident, T.Type) -> EnvM (Ident, M.Type)
morphArg :: (Ident, L.Type) -> EnvM (Ident, M.Type)
morphArg (ident, t) = do
t' <- getMonoFromPoly t
return (ident, t')
-- | Gets the data bind from the name of a constructor.
getInputData :: Ident -> EnvM (Maybe T.Data)
getInputData :: Ident -> EnvM (Maybe L.Data)
getInputData ident = do
env <- ask
return $ Map.lookup ident (dataDefs env)
@ -201,50 +241,50 @@ morphCons expectedType ident newIdent = do
--trace ("Tjofras:" ++ show (newName expectedType ident)) $ return ()
maybeD <- getInputData ident
case maybeD of
Nothing -> error $ "identifier '" ++ show ident ++ "' not found"
-- closures can have unbound variables
Nothing -> pure ()
Just d -> do
modify (\output -> Map.insert newIdent (Data expectedType d) output)
-- | Converts literals from input to output tree.
convertLit :: T.Lit -> M.Lit
convertLit (T.LInt v) = M.LInt v
convertLit (T.LChar v) = M.LChar v
convertLit :: L.Lit -> M.Lit
convertLit (L.LInt v) = M.LInt v
convertLit (L.LChar v) = M.LChar v
-- | Monomorphizes an expression, given an expected type.
morphExp :: M.Type -> T.Exp -> EnvM M.Exp
morphExp :: M.Type -> L.Exp -> EnvM M.Exp
morphExp expectedType exp = case exp of
T.ELit lit -> return $ M.ELit (convertLit lit)
L.ELit lit -> return $ M.ELit lit
-- Constructor
T.EInj ident -> do
L.EInj ident -> do
let ident' = newName (getDataType expectedType) ident
morphCons expectedType ident ident'
return $ M.EVar ident'
T.EApp (e1, _t1) (e2, t2) -> do
L.EApp (e1, _t1) (e2, t2) -> do
t2' <- getMonoFromPoly t2
e2' <- morphExp t2' e2
e1' <- morphExp (M.TFun t2' expectedType) e1
return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2')
T.EAdd (e1, t1) (e2, t2) -> do
L.EAdd (e1, t1) (e2, t2) -> do
t1' <- getMonoFromPoly t1
t2' <- getMonoFromPoly t2
e1' <- morphExp t1' e1
e2' <- morphExp t2' e2
return $ M.EAdd (e1', expectedType) (e2', expectedType)
T.EAbs ident (exp, t) -> local (\env -> env{locals = Set.insert ident (locals env)}) $ do
t' <- getMonoFromPoly t
morphExp t' exp
T.ECase (exp, t) bs -> do
L.ECase (exp, t) bs -> do
t' <- getMonoFromPoly t
exp' <- morphExp t' exp
bs' <- mapM morphBranch bs
return $ M.ECase (exp', t') (catMaybes bs')
-- Ideally constructors should be EInj, though this code handles them
-- as well.
T.EVar ident -> do
-- FIXME MAKE EVAR AND EINJ SEPARATE!!!
L.EVar ident -> do
isLocal <- localExists ident
if isLocal
then do
return $ M.EVar (coerce ident)
return $ M.EVar ident
else do
bind <- getInputBind ident
case bind of
@ -252,38 +292,51 @@ morphExp expectedType exp = case exp of
Just bind' -> do
-- New bind to process
newBindName <- morphBind expectedType bind'
return $ M.EVar (coerce newBindName)
T.ELet (T.Bind (identB, tB) args (expB, tExpB)) (exp, tExp) ->
if length args > 0 then error "only constants in lets allowed"
else do
return $ M.EVar newBindName
L.EVarC as ident -> do
isLocal <- localExists ident
if isLocal
then do
return $ M.EVar ident
else do
bind <- fromJust <$> getInputBind ident
as' <- mapM (secondM getMonoFromPoly) as
-- New bind to process
newBindName <- morphBind expectedType bind
return $ M.EVarC as' newBindName
-- Ideally constructors should be EInj, though this code handles them
-- as well.
L.ELet (identB, tB) (expB, tExpB) (exp, tExp) -> do
tB' <- getMonoFromPoly tB
tExpB' <- getMonoFromPoly tExpB
tExp' <- getMonoFromPoly tExp
expB' <- morphExp tExpB' expB
exp' <- morphExp tExp' exp
exp' <- local (addLocal identB) (morphExp tExp' exp)
return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (exp', tExp')
-- | Monomorphizes case-of branches.
morphBranch :: T.Branch -> EnvM (Maybe M.Branch)
morphBranch (T.Branch (p, pt) (e, et)) = do
morphBranch :: L.Branch -> EnvM (Maybe M.Branch)
morphBranch (L.Branch (p, pt) (e, et)) = do
pt' <- getMonoFromPoly pt
et' <- getMonoFromPoly et
env <- ask
maybeMorphedPattern <- morphPattern p pt'
case maybeMorphedPattern of
Nothing -> return Nothing
Just (p', newLocals) ->
Just (p', newLocals) ->
local (const env { locals = Set.union (locals env) newLocals }) $ do
e' <- morphExp et' e
return $ Just (M.Branch (p', pt') (e', et'))
return $ Just (M.Branch p' (e', et'))
morphPattern :: T.Pattern -> M.Type -> EnvM (Maybe (M.Pattern, Set.Set Ident))
morphPattern :: L.Pattern -> M.Type -> EnvM (Maybe (M.T M.Pattern, Set.Set Ident))
morphPattern p expectedType = case p of
T.PVar ident -> return $ Just (M.PVar (ident, expectedType), Set.singleton ident)
T.PLit lit -> return $ Just (M.PLit (convertLit lit, expectedType), Set.empty)
T.PCatch -> return $ Just (M.PCatch, Set.empty)
T.PEnum ident -> return $ Just (M.PEnum (newName expectedType ident), Set.empty)
T.PInj ident pts -> do let newIdent = newName expectedType ident
L.PVar ident -> return $ Just ((M.PVar ident, expectedType), Set.singleton ident)
L.PLit lit -> return $ Just ((M.PLit (convertLit lit), expectedType), Set.empty)
L.PCatch -> return $ Just ((M.PCatch, expectedType), Set.empty)
L.PEnum ident -> return $ Just ((M.PEnum (newName expectedType ident), expectedType), Set.empty)
L.PInj ident pts -> do let newIdent = newName expectedType ident
outEnv <- get
trace ("WOW: " ++ show (newName expectedType ident)) $ return ()
trace ("WOW2: " ++ show (outEnv)) $ return ()
@ -297,13 +350,18 @@ morphPattern p expectedType = case p of
let maybePsSets = sequence psSets
case maybePsSets of
Nothing -> return Nothing
Just psSets' -> return $ Just
(M.PInj newIdent (map fst psSets'), Set.unions $ map snd psSets')
Just psSets' -> return $ Just
((M.PInj newIdent (map fst psSets'), expectedType), Set.unions $ map snd psSets')
else return Nothing
-- | Creates a new identifier for a function with an assigned type.
newFuncName :: M.Type -> T.Bind -> Ident
newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) =
newFuncName :: M.Type -> L.Bind -> Ident
newFuncName t (L.Bind (ident@(Ident bindName), _) _ _) =
if bindName == "main"
then Ident bindName
else newName t ident
newFuncName t (L.BindC _ (ident@(Ident bindName), _) _ _) =
if bindName == "main"
then Ident bindName
else newName t ident
@ -317,8 +375,8 @@ newName t (Ident str) = Ident $ str ++ "$" ++ newName' t
newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts
-- | Monomorphization step.
monomorphize :: T.Program -> O.Program
monomorphize (T.Program defs) =
monomorphize :: L.Program -> O.Program
monomorphize (L.Program defs) =
removeDataTypes $
M.Program
( getDefsFromOutput
@ -336,7 +394,7 @@ runEnvM :: Output -> Env -> EnvM () -> Output
runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env
-- | Creates the environment based on the input binds.
createEnv :: [T.Def] -> Env
createEnv :: [L.Def] -> Env
createEnv defs =
Env
{ input = Map.fromList bindPairs
@ -346,33 +404,34 @@ createEnv defs =
}
where
bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs
dataPairs :: [(Ident, T.Data)]
dataPairs = (foldl (\acc d@(T.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs
dataPairs :: [(Ident, L.Data)]
dataPairs = (foldl (\acc d@(L.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs
-- | Gets a top-lefel function name.
getBindName :: T.Bind -> Ident
getBindName (T.Bind (ident, _) _ _) = ident
getBindName :: L.Bind -> Ident
getBindName (L.Bind (ident, _) _ _) = ident
getBindName (L.BindC _ (ident, _) _ _) = ident
-- Helper functions
-- Gets custom data declarations form defs.
getDataFromDefs :: [T.Def] -> [T.Data]
getDataFromDefs :: [L.Def] -> [L.Data]
getDataFromDefs =
foldl
( \bs -> \case
T.DBind _ -> bs
T.DData d -> d : bs
L.DBind _ -> bs
L.DData d -> d : bs
)
[]
getConsName :: T.Inj -> Ident
getConsName (T.Inj ident _) = ident
getConsName :: L.Inj -> Ident
getConsName (L.Inj ident _) = ident
getBindsFromDefs :: [T.Def] -> [T.Bind]
getBindsFromDefs :: [L.Def] -> [L.Bind]
getBindsFromDefs =
foldl
( \bs -> \case
T.DBind b -> b : bs
T.DData _ -> bs
L.DBind b -> b : bs
L.DData _ -> bs
)
[]
@ -384,19 +443,19 @@ getDefsFromOutput o =
(binds, dataInput) = splitBindsAndData o
-- | Splits the output into binds and data declaration components (used in createNewData)
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, T.Data)])
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, L.Data)])
splitBindsAndData output =
foldl
( \(oBinds, oData) (ident, o) -> case o of
Marked -> error "internal bug in monomorphizer"
Marked -> error "internal bug in monomorphizer"
Complete b -> (b : oBinds, oData)
Data t d -> (oBinds, (ident, t, d) : oData)
Data t d -> (oBinds, (ident, t, d) : oData)
)
([], [])
(Map.toList output)
-- | Converts all found constructors to monomorphic data declarations.
createNewData :: [(Ident, M.Type, T.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
createNewData :: [(Ident, M.Type, L.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
createNewData [] o = o
createNewData ((consIdent, consType, polyData) : input) o =
createNewData input $
@ -406,14 +465,17 @@ createNewData ((consIdent, consType, polyData) : input) o =
(M.Data newDataType [newCons])
o
where
T.Data (T.TData polyDataIdent _) _ = polyData
L.Data (L.TData polyDataIdent _) _ = polyData
newDataType = getDataType consType
newDataName = newName newDataType polyDataIdent
newCons = M.Inj consIdent consType
-- | Gets the Data Type of a constructor type (a -> Just a becomes Just a).
getDataType :: M.Type -> M.Type
getDataType (M.TFun _t1 t2) = getDataType t2
getDataType (M.TFun _t1 t2) = getDataType t2
getDataType tData@(M.TData _ _) = tData
getDataType _ = error "???"
getDataType _ = error "???"
addLocal :: Ident -> Env -> Env
addLocal x env = env { locals = Set.insert x env.locals }

View file

@ -1,11 +1,14 @@
{-# LANGUAGE LambdaCase #-}
module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr) where
module Monomorphizer.MonomorphizerIr (
module Monomorphizer.MonomorphizerIr,
module LambdaLifterIr
) where
import Grammar.Print
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..))
type Id = (TIR.Ident, Type)
import Data.List (intercalate)
import Grammar.Print
import LambdaLifterIr (Ident (..), Lit (..))
import Prelude hiding (exp)
newtype Program = Program [Def]
deriving (Show, Ord, Eq)
@ -16,90 +19,80 @@ data Def = DBind Bind | DData Data
data Data = Data Type [Inj]
deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT
data Bind = Bind (T Ident) [T Ident] (T Exp)
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
deriving (Show, Ord, Eq)
type T a = (a, Type)
data Exp
= EVar TIR.Ident
= EVar Ident
| EVarC [T Ident] Ident
| ELit Lit
| ELet Bind ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
| ECase ExpT [Branch]
| ELet Bind (T Exp)
| EApp (T Exp) (T Exp)
| EAdd (T Exp) (T Exp)
| ECase (T Exp) [Branch]
deriving (Show, Ord, Eq)
data Pattern
= PVar Id
| PLit (Lit, Type)
| PInj TIR.Ident [Pattern]
= PVar Ident
| PLit Lit
| PInj Ident [T Pattern]
| PCatch
| PEnum TIR.Ident
| PEnum Ident
deriving (Eq, Ord, Show)
data Branch = Branch (Pattern, Type) ExpT
data Branch = Branch (T Pattern) (T Exp)
deriving (Eq, Ord, Show)
type ExpT = (Exp, Type)
data Inj = Inj TIR.Ident Type
data Inj = Inj Ident Type
deriving (Show, Ord, Eq)
data Lit
= LInt Integer
| LChar Char
deriving (Show, Ord, Eq)
data Type = TLit TIR.Ident | TFun Type Type
data Type = TLit Ident | TFun Type Type
deriving (Show, Ord, Eq)
flattenType :: Type -> [Type]
flattenType (TFun t1 t2) = t1 : flattenType t2
flattenType x = [x]
flattenType x = [x]
instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print (Bind) where
instance Print Bind where
prt i (Bind sig@(name, _) parms rhs) =
prPrec i 0 $
concatD
[ prtSig sig
[ prt 0 sig
, prt 0 name
, prtIdPs 0 parms
, prt 0 parms
, doc $ showString "="
, prt 0 rhs
]
prtSig :: Id -> Doc
prtSig (name, t) =
concatD
[ prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ";"
]
prt i (BindC cxt sig parms rhs) =
prPrec i 0 $
concatD
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
, prt i parms
, doc $ showString "="
, prt i rhs
]
instance Print (ExpT) where
prt i (e, t) =
concatD
[ doc $ showString "("
, prt i e
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print [Bind] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prt i)
instance Print Exp where
prt i = \case
EVar name -> prPrec i 3 $ prt 0 name
EVarC as lident -> doc . showString
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
where
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e ->
prPrec i 3 $
@ -134,16 +127,16 @@ instance Print Exp where
]
instance Print Branch where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp])
instance Print [Branch] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print Def where
prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print Data where
@ -152,31 +145,26 @@ instance Print Data where
instance Print Inj where
prt i = \case
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
Inj uident type_ -> prt i (uident, type_)
instance Print Pattern where
prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit])
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print [Def] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print [Type] where
prt _ [] = concatD []
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
instance Print Type where
prt i = \case
TLit uident -> prPrec i 1 (concatD [prt 0 uident])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
instance Print Lit where
prt i = \case
LInt int -> prt i int
LChar char -> prt i char

View file

@ -1,10 +1,14 @@
{-# LANGUAGE LambdaCase #-}
module Monomorphizer.MorbIr where
import Grammar.Print
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..))
module Monomorphizer.MorbIr (
module Monomorphizer.MorbIr,
module LambdaLifterIr
) where
type Id = (TIR.Ident, Type)
import Data.List (intercalate)
import Grammar.Print
import LambdaLifterIr (Ident (..), Lit (..))
import Prelude hiding (exp)
newtype Program = Program [Def]
deriving (Show, Ord, Eq)
@ -15,91 +19,81 @@ data Def = DBind Bind | DData Data
data Data = Data Type [Inj]
deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT
data Bind = Bind (T Ident) [T Ident] (T Exp)
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
deriving (Show, Ord, Eq)
type T a = (a, Type)
data Exp
= EVar TIR.Ident
= EVar Ident
| EVarC [T Ident] Ident
| ELit Lit
| ELet Bind ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
| ECase ExpT [Branch]
| ELet Bind (T Exp)
| EApp (T Exp) (T Exp)
| EAdd (T Exp) (T Exp)
| ECase (T Exp) [Branch]
deriving (Show, Ord, Eq)
data Pattern
= PVar Id
| PLit (Lit, Type)
| PInj TIR.Ident [Pattern]
= PVar Ident
| PLit Lit
| PInj Ident [T Pattern]
| PCatch
| PEnum TIR.Ident
| PEnum Ident
deriving (Eq, Ord, Show)
data Branch = Branch (Pattern, Type) ExpT
data Branch = Branch (T Pattern) (T Exp)
deriving (Eq, Ord, Show)
type ExpT = (Exp, Type)
data Inj = Inj TIR.Ident Type
data Inj = Inj Ident Type
deriving (Show, Ord, Eq)
data Lit
= LInt Integer
| LChar Char
deriving (Show, Ord, Eq)
data Type = TLit TIR.Ident | TFun Type Type | TData TIR.Ident [Type]
data Type = TLit Ident | TFun Type Type | TData Ident [Type]
deriving (Show, Ord, Eq)
flattenType :: Type -> [Type]
flattenType (TFun t1 t2) = t1 : flattenType t2
flattenType x = [x]
flattenType x = [x]
instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print (Bind) where
instance Print Bind where
prt i (Bind sig@(name, _) parms rhs) =
prPrec i 0 $
concatD
[ prtSig sig
[ prt 0 sig
, prt 0 name
, prtIdPs 0 parms
, prt 0 parms
, doc $ showString "="
, prt 0 rhs
]
prtSig :: Id -> Doc
prtSig (name, t) =
concatD
[ prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ";"
]
instance Print (ExpT) where
prt i (e, t) =
concatD
[ doc $ showString "("
, prt i e
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
prt i (BindC cxt sig parms rhs) =
prPrec i 0 $
concatD
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
, prt i parms
, doc $ showString "="
, prt i rhs
]
instance Print [Bind] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prt i)
instance Print Exp where
prt i = \case
EVar name -> prPrec i 3 $ prt 0 name
EVarC as lident -> doc . showString
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
where
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e ->
prPrec i 3 $
@ -134,16 +128,16 @@ instance Print Exp where
]
instance Print Branch where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp])
instance Print [Branch] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print Def where
prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print Data where
@ -152,23 +146,23 @@ instance Print Data where
instance Print Inj where
prt i = \case
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
Inj uident type_ -> prt i (uident, type_)
instance Print Pattern where
prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit])
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print [Def] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print [Type] where
prt _ [] = concatD []
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
instance Print Type where
@ -176,9 +170,3 @@ instance Print Type where
TLit uident -> prPrec i 1 (concatD [prt 0 uident])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")])
instance Print Lit where
prt i = \case
LInt int -> prt i int
LChar char -> prt i char

View file

@ -2,15 +2,15 @@
module TypeChecker.ReportTEVar where
import Auxiliary (onM)
import Control.Applicative (Applicative (liftA2), liftA3)
import Control.Monad.Except (MonadError (throwError))
import Data.Coerce (coerce)
import Data.Tuple.Extra (secondM)
import Grammar.Abs qualified as G
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr hiding (Type (..))
import Auxiliary (onM)
import Control.Applicative (Applicative (liftA2), liftA3)
import Control.Monad.Except (MonadError (throwError))
import Data.Coerce (coerce)
import Data.Tuple.Extra (secondM)
import qualified Grammar.Abs as G
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr hiding (Type (..))
data Type
= TLit Ident
@ -18,7 +18,7 @@ data Type
| TData Ident [Type]
| TFun Type Type
| TAll TVar Type
deriving (Eq, Ord, Show, Read)
deriving (Eq, Ord, Show)
class ReportTEVar a b where
reportTEVar :: a -> Err b
@ -29,20 +29,20 @@ instance ReportTEVar (Program' G.Type) (Program' Type) where
instance ReportTEVar (Def' G.Type) (Def' Type) where
reportTEVar = \case
DBind bind -> DBind <$> reportTEVar bind
DData dat -> DData <$> reportTEVar dat
DData dat -> DData <$> reportTEVar dat
instance ReportTEVar (Bind' G.Type) (Bind' Type) where
reportTEVar (Bind id vars rhs) = liftA3 Bind (reportTEVar id) (reportTEVar vars) (reportTEVar rhs)
instance ReportTEVar (Exp' G.Type) (Exp' Type) where
reportTEVar exp = case exp of
EVar name -> pure $ EVar name
EInj name -> pure $ EInj name
ELit lit -> pure $ ELit lit
ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e)
EApp e1 e2 -> onM EApp reportTEVar e1 e2
EAdd e1 e2 -> onM EAdd reportTEVar e1 e2
EAbs name e -> EAbs name <$> reportTEVar e
EVar name -> pure $ EVar name
EInj name -> pure $ EInj name
ELit lit -> pure $ ELit lit
ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e)
EApp e1 e2 -> onM EApp reportTEVar e1 e2
EAdd e1 e2 -> onM EAdd reportTEVar e1 e2
EAbs name e -> EAbs name <$> reportTEVar e
ECase e branches -> liftA2 ECase (reportTEVar e) (reportTEVar branches)
instance ReportTEVar (Branch' G.Type) (Branch' Type) where
@ -53,10 +53,10 @@ instance ReportTEVar (Pattern' G.Type, G.Type) (Pattern' Type, Type) where
instance ReportTEVar (Pattern' G.Type) (Pattern' Type) where
reportTEVar = \case
PVar name -> pure $ PVar name
PLit lit -> pure $ PLit lit
PCatch -> pure PCatch
PEnum name -> pure $ PEnum name
PVar name -> pure $ PVar name
PLit lit -> pure $ PLit lit
PCatch -> pure PCatch
PEnum name -> pure $ PEnum name
PInj name ps -> PInj name <$> reportTEVar ps
instance ReportTEVar (Data' G.Type) (Data' Type) where
@ -65,10 +65,10 @@ instance ReportTEVar (Data' G.Type) (Data' Type) where
instance ReportTEVar (Inj' G.Type) (Inj' Type) where
reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ
instance ReportTEVar (Id' G.Type) (Id' Type) where
instance ReportTEVar (a, G.Type) (a, Type) where
reportTEVar = secondM reportTEVar
instance ReportTEVar (ExpT' G.Type) (ExpT' Type) where
instance ReportTEVar (T' Exp' G.Type) (T' Exp' Type) where
reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ)
instance ReportTEVar a b => ReportTEVar [a] [b] where
@ -76,9 +76,9 @@ instance ReportTEVar a b => ReportTEVar [a] [b] where
instance ReportTEVar G.Type Type where
reportTEVar = \case
G.TLit lit -> pure $ TLit (coerce lit)
G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i)
G.TData name typs -> TData (coerce name) <$> reportTEVar typs
G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2)
G.TLit lit -> pure $ TLit (coerce lit)
G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i)
G.TData name typs -> TData (coerce name) <$> reportTEVar typs
G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2)
G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t
G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar)
G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar)

View file

@ -31,6 +31,7 @@ import Grammar.ErrM
import Grammar.Print (printTree)
import Prelude hiding (exp)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (T, T')
-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013)
-- https://doi.org/10.1145/2500365.2500582
@ -172,7 +173,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars
-- | Γ ⊢ e ↑ A ⊣ Δ
-- Under input context Γ, e checks against input type A, with output context ∆
check :: Exp -> Type -> Tc (T.ExpT' Type)
check :: Exp -> Type -> Tc (T' T.Exp' Type)
-- Γ,α ⊢ e ↑ A ⊣ Δ,α
-- ------------------- ∀I
@ -212,12 +213,6 @@ check (ECase scrut pi) c = do
e' <- check e c
pure (T.Branch p' e')
apply (T.ECase (scrut', a) pi', c)
where
go (pi, b) (Branch p e) = do
p' <- checkPattern p =<< apply a
e'@(_, b') <- infer e
subtype b' b
apply (T.Branch p' e' : pi, b')
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
@ -229,9 +224,6 @@ check e b = do
subtype a b'
apply (e', b)
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of
@ -297,7 +289,7 @@ checkPattern patt t_patt = case patt of
-- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆
infer :: Exp -> Tc (T.ExpT' Type)
infer :: Exp -> Tc (T' T.Exp' Type)
infer (ELit lit) = apply (T.ELit lit, litType lit)
-- Γ ∋ (x : A) Γ ⊢ rec(x)
@ -391,7 +383,7 @@ infer (ECase scrut pi) = do
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
-- Instantiate existential type variables until there is an arrow type.
applyInfer :: Type -> Exp -> Tc (T.ExpT' Type, Type)
applyInfer :: Type -> Exp -> Tc (T' T.Exp' Type, Type)
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
-- ------------------------ ∀App

View file

@ -1,32 +1,33 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where
import Auxiliary (int, litType, maybeToRightM, unzip4)
import Auxiliary qualified as Aux
import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl', nub, sortOn)
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace, traceShow)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T
import Auxiliary (int, litType, maybeToRightM, unzip4)
import qualified Auxiliary as Aux
import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl', nub, sortOn)
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace (trace, traceShow)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr (T, T')
import qualified TypeChecker.TypeCheckerIr as T
{-
TODO
@ -41,7 +42,7 @@ typecheck :: Program -> Either String (T.Program' Type, [Warning])
typecheck = onLeft msg . run . checkPrg
where
onLeft :: (Error -> String) -> Either Error a -> Either String a
onLeft f (Left x) = Left $ f x
onLeft f (Left x) = Left $ f x
onLeft _ (Right x) = Right x
checkPrg :: Program -> Infer (T.Program' Type)
@ -68,13 +69,13 @@ prettify s (T.Program defs) = T.Program $ map (go s) defs
replace :: Map T.Ident T.Ident -> Type -> Type
replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of
Just t -> TVar . MkTVar . LIdent $ coerce t
Just t -> TVar . MkTVar . LIdent $ coerce t
Nothing -> def
replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2
replace m (TData name ts) = TData name (map (replace m) ts)
replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of
Just found -> TAll (MkTVar $ coerce found) (replace m t)
Nothing -> def
Nothing -> def
replace _ t = t
bindCount :: [Def] -> Infer [(Int, Def)]
@ -128,7 +129,7 @@ preRun (x : xs) = case x of
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
where
-- Check if function body / signature has been declared already
@ -150,11 +151,11 @@ checkDef (x : xs) = case x of
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
freeOrdered :: Type -> [T.Ident]
freeOrdered (TVar (MkTVar a)) = return (coerce a)
freeOrdered (TVar (MkTVar a)) = return (coerce a)
freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
freeOrdered (TData _ a) = concatMap freeOrdered a
freeOrdered _ = mempty
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
freeOrdered (TData _ a) = concatMap freeOrdered a
freeOrdered _ = mempty
-- Much cleaner implementation, unfortunately one minor bug
-- checkBind :: Bind -> Infer (T.Bind' Type)
@ -259,13 +260,13 @@ checkInj (Inj c inj_typ) name tvars
toTVar :: Type -> Either Error TVar
toTVar = \case
TVar tvar -> pure tvar
_ -> uncatchableErr "Not a type variable"
_ -> uncatchableErr "Not a type variable"
returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2
returnType a = a
returnType a = a
inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp :: Exp -> Infer (T' T.Exp' Type)
inferExp e = do
(s, (e', t)) <- algoW e
let subbed = apply s t
@ -276,7 +277,7 @@ class CollectTVars a where
instance CollectTVars Exp where
collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e
collectTVars _ = S.empty
collectTVars _ = S.empty
instance CollectTVars Type where
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
@ -289,7 +290,7 @@ instance CollectTVars Type where
collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
algoW :: Exp -> Infer (Subst, T.ExpT' Type)
algoW :: Exp -> Infer (Subst, T' T.Exp' Type)
algoW = \case
err@(EAnn e t) -> do
(sub0, (e', t')) <- exprErr (algoW e) err
@ -602,12 +603,12 @@ generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where
go :: [T.Ident] -> Type -> Type
go [] t = t
go [] t = t
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t
removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t
removeForalls t = t
{- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones.
@ -655,7 +656,7 @@ fresh = do
ungo :: [TVar] -> Type -> Type -> Bool
ungo tvars t1 t2 = case run (go tvars t1 t2) of
Right (b, _) -> b
_ -> False
_ -> False
-- TODO: Fix the following
-- Maybe locally using the Infer monad can cause trouble.
-- Since the fresh count starts from zero
@ -700,15 +701,15 @@ instance SubstType Type where
TLit _ -> t
TVar (MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (MkTVar $ coerce a)
Just t -> t
Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (MkTVar i) (apply sub t)
Just _ -> apply sub t
Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
Nothing -> TEVar (MkTEVar $ coerce a)
Just t -> t
Just t -> t
instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident
@ -721,7 +722,7 @@ instance SubstType (Map T.Ident Type) where
instance SubstType (Map T.Ident (Maybe Type)) where
apply s = M.map (fmap $ apply s)
instance SubstType (T.ExpT' Type) where
instance SubstType (T' T.Exp' Type) where
apply s (e, t) = (apply s e, apply s t)
instance SubstType (T.Exp' Type) where
@ -750,10 +751,10 @@ instance SubstType (T.Branch' Type) where
instance SubstType (T.Pattern' Type) where
apply s = \case
T.PVar iden -> T.PVar iden
T.PLit lit -> T.PLit lit
T.PLit lit -> T.PLit lit
T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
instance SubstType (T.Pattern' Type, Type) where
apply s (p, t) = (apply s p, apply s t)
@ -761,7 +762,7 @@ instance SubstType (T.Pattern' Type, Type) where
instance SubstType a => SubstType [a] where
apply s = map (apply s)
instance SubstType (T.Id' Type) where
instance SubstType (T T.Ident Type) where
apply s (name, t) = (name, apply s t)
-- | Represents the empty substition set
@ -794,11 +795,11 @@ withBindings xs =
-- | Run the monadic action with a pattern
withPattern :: (Monad m, MonadReader Ctx m) => (T.Pattern' Type, Type) -> m a -> m a
withPattern (p, t) ma = case p of
T.PVar x -> withBinding x t ma
T.PVar x -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma
T.PCatch -> ma
T.PEnum _ -> ma
T.PLit _ -> ma
T.PCatch -> ma
T.PEnum _ -> ma
-- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe Type -> Infer ()
@ -823,11 +824,11 @@ existInj n = gets (M.lookup n . injections)
flattenType :: Type -> [Type]
flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a]
flattenType a = [a]
typeLength :: Type -> Int
typeLength (TFun _ b) = 1 + typeLength b
typeLength _ = 1
typeLength _ = 1
{- | Catch an error if possible and add the given
expression as addition to the error message
@ -910,11 +911,11 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
deriving (Show)
data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Maybe Type)
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Maybe Type)
, takenTypeVars :: Set T.Ident
, injections :: Map T.Ident Type
, injections :: Map T.Ident Type
, declaredBinds :: Set T.Ident
}
deriving (Show)

View file

@ -10,31 +10,30 @@ import Data.String (IsString)
import Grammar.Abs (Lit (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program' t = Program [Def' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
deriving (Eq, Ord, Show, Functor)
data Def' t
= DBind (Bind' t)
| DData (Data' t)
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
deriving (Eq, Ord, Show, Functor)
data Type
= TLit Ident
| TVar TVar
| TData Ident [Type]
| TFun Type Type
deriving (Eq, Ord, Show, Read)
deriving (Eq, Ord, Show)
data Data' t = Data t [Inj' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
deriving (Eq, Ord, Show, Functor)
data Inj' t = Inj Ident t
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
deriving (Eq, Ord, Show, Functor)
newtype Ident = Ident String
deriving (C.Eq, C.Ord, C.Show, C.Read, IsString)
deriving (Eq, Ord, Show, IsString)
data Pattern' t
= PVar Ident
@ -42,30 +41,31 @@ data Pattern' t
| PCatch
| PEnum Ident
| PInj Ident [(Pattern' t, t)]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
deriving (Eq, Ord, Show, Functor)
data Exp' t
= EVar Ident
| EInj Ident
| ELit Lit
| ELet (Bind' t) (ExpT' t)
| EApp (ExpT' t) (ExpT' t)
| EAdd (ExpT' t) (ExpT' t)
| EAbs Ident (ExpT' t)
| ECase (ExpT' t) [Branch' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
| ELet (Bind' t) (T' Exp' t)
| EApp (T' Exp' t) (T' Exp' t)
| EAdd (T' Exp' t) (T' Exp' t)
| EAbs Ident (T' Exp' t)
| ECase (T' Exp' t) [Branch' t]
deriving (Eq, Ord, Show, Functor)
newtype TVar = MkTVar Ident
deriving (C.Eq, C.Ord, C.Show, C.Read)
deriving (Eq, Ord, Show)
type Id' t = (Ident, t)
type ExpT' t = (Exp' t, t)
type T' a t = (a t, t)
type T a t = (a, t)
data Bind' t = Bind (Id' t) [Id' t] (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
data Branch' t = Branch (Pattern' t, t) (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
data Bind' t = Bind (T Ident t) [T Ident t] (T' Exp' t)
deriving (Eq, Ord, Show, Functor)
data Branch' t = Branch (T' Pattern' t) (T' Exp' t)
deriving (Eq, Ord, Show, Functor)
instance Print Ident where
prt _ (Ident s) = doc $ showString s
@ -81,22 +81,22 @@ instance Print t => Print (Bind' t) where
, prt i rhs
]
prtSig :: Print t => Id' t -> Doc
prtSig (name, t) =
prtSig :: Print t => T Ident t -> Doc
prtSig (x, t) =
concatD
[ prt 0 name
[ prt 0 x
, doc $ showString ":"
, prt 0 t
]
instance Print t => Print (ExpT' t) where
prt i (e, t) =
instance (Print a, Print t) => Print (T a t) where
prt i (x, t) =
concatD
[ doc $ showString "("
, prt i e
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
[ -- doc $ showString "("
{- , -} prt i x
-- , doc $ showString ":"
-- , prt 0 t
-- , doc $ showString ")"
]
instance Print t => Print [Bind' t] where
@ -104,16 +104,6 @@ instance Print t => Print [Bind' t] where
prt i [x] = concatD [prt i x]
prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs]
instance Print t => Print (Id' t) where
prt i (name, t) =
concatD
[ doc $ showString "("
, prt i name
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print t => Print (Exp' t) where
prt i = \case
EVar lident -> prPrec i 3 (concatD [prt 0 lident])
@ -151,9 +141,6 @@ instance Print t => Print [Inj' t] where
prt i [x] = prt i x
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
instance Print t => Print (Pattern' t, t) where
prt i (p, t) = prPrec i 1 (concatD [prt i p, prt i t])
instance Print t => Print (Pattern' t) where
prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name])
@ -189,8 +176,6 @@ type Branch = Branch' Type
type Pattern = Pattern' Type
type Inj = Inj' Type
type Exp = Exp' Type
type ExpT = ExpT' Type
type Id = Id' Type
pattern TVar' s = TVar (MkTVar s)
pattern DBind' id vars expt = DBind (Bind id vars expt)
pattern DData' typ injs = DData (Data typ injs)

185
test_map2.ll Normal file
View file

@ -0,0 +1,185 @@
target triple = "x86_64-pc-linux-gnu"
target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
@.str = private unnamed_addr constant [3 x i8] c"%i
", align 1
@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c"Non-exhaustive patterns in case at %i:%i
"
declare i32 @printf(ptr noalias nocapture, ...)
declare i32 @exit(i32 noundef)
declare ptr @malloc(i32 noundef)
%List = type { i8, [23 x i8] }
%Cons = type { i8, i64, %List* }
%Nil = type { i8 }
; NYTT: kontexttyp
%Closure_sc_0 = type { i64 (i64)*, i64 }
; Ident "sum$List_Int": (ECase (EVar (Ident "$4xs"),TLit (Ident "List")) [Branch (PEnum (Ident "Nil"),TLit (Ident "List")) (ELit (LInt 0),TLit (Ident "Int")),Branch (PInj (Ident "Cons") [PVar (Ident "$5x",TLit (Ident "Int")),PVar (Ident "$6xs",TLit (Ident "List"))],TLit (Ident "List")) (EAdd (EVar (Ident "$5x"),TLit (Ident "Int")) (EApp (EVar (Ident "sum$List_Int"),TFun (TLit (Ident "List")) (TLit (Ident "Int"))) (EVar (Ident "$6xs"),TLit (Ident "List")),TLit (Ident "Int")),TLit (Ident "Int"))],TLit (Ident "Int"))
define fastcc i64 @sum$List_Int(%List %$4xs) {
%1 = alloca i64
; Penum
%2 = extractvalue %List %$4xs, 0
%3 = icmp eq i8 %2, 1
br i1 %3, label %lbl_success_3, label %lbl_failed_2
lbl_success_3:
%4 = alloca %List
store %List %$4xs, ptr %4
%5 = load %Nil, ptr %4
store i64 0, ptr %1
br label %lbl_escape_1
lbl_failed_2:
; Inj
%6 = extractvalue %List %$4xs, 0
%7 = icmp eq i8 %6, 0
br i1 %7, label %lbl_success_5, label %lbl_failed_4
lbl_success_5:
%8 = alloca %List
store %List %$4xs, ptr %8
%9 = load %Cons, ptr %8
; ident i64
%$5x = extractvalue %Cons %9, 1
; ident %List
%10 = extractvalue %Cons %9, 2
%$6xs = load %List, ptr %10
; TLit (Ident "Int")
%11 = call fastcc i64 @sum$List_Int(%List %$6xs)
%12 = add i64 %$5x, %11
store i64 %12, ptr %1
br label %lbl_escape_1
lbl_failed_4:
call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 12, i64 noundef 6)
call i32 @exit(i32 noundef 1)
br label %lbl_escape_1
lbl_escape_1:
%15 = load i64, ptr %1
ret i64 %15
}
; Ident "sc_0$Int_Int": (EAdd (EVar (Ident "$7x"),TLit (Ident "Int")) (ELit (LInt 10),TLit (Ident "Int")),TLit (Ident "Int"))
; ÄNDRAT: lägg till kontextpekare
define fastcc i64 @sc_0$Int_Int(ptr %closure_sc_0, i64 %$7x) {
; NYTT: Ladda alla fria variabler
%fri_variabel_ptr = getelementptr inbounds %Closure_sc_0, ptr %closure_sc_0, i32 0, i32 1
%fri_variabel = load i64, ptr %fri_variabel_ptr
; ÄNDRAT: %fri_variabel istället för 2
%1 = add i64 %$7x, %fri_variabel
ret i64 %1
}
; Ident "map$Int_Int_List_List": (ECase (EVar (Ident "$1xs"),TLit (Ident "List")) [Branch (PEnum (Ident "Nil"),TLit (Ident "List")) (EVar (Ident "Nil"),TLit (Ident "List")),Branch (PInj (Ident "Cons") [PVar (Ident "$2x",TLit (Ident "Int")),PVar (Ident "$3xs",TLit (Ident "List"))],TLit (Ident "List")) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EApp (EVar (Ident "$0f"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (EVar (Ident "$2x"),TLit (Ident "Int")),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "map$Int_Int_List_List"),TFun (TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EVar (Ident "$0f"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EVar (Ident "$3xs"),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List"))],TLit (Ident "List"))
; ÄNDRAT: ptr istället för i64 (i64)*
define fastcc %List @map$Int_Int_List_List(ptr %$0f, %List %$1xs) {
; NYTT: ta fram funktionspekaren
%$0f_deref = load i64(i64)*, ptr %$0f
%1 = alloca %List
; Penum
%2 = extractvalue %List %$1xs, 0
%3 = icmp eq i8 %2, 1
br i1 %3, label %lbl_success_8, label %lbl_failed_7
lbl_success_8:
%4 = alloca %List
store %List %$1xs, ptr %4
%5 = load %Nil, ptr %4
%6 = call fastcc %List @Nil()
store %List %6, ptr %1
br label %lbl_escape_6
lbl_failed_7:
; Inj
%7 = extractvalue %List %$1xs, 0
%8 = icmp eq i8 %7, 0
br i1 %8, label %lbl_success_10, label %lbl_failed_9
lbl_success_10:
%9 = alloca %List
store %List %$1xs, ptr %9
%10 = load %Cons, ptr %9
; ident i64
%$2x = extractvalue %Cons %10, 1
; ident %List
%11 = extractvalue %Cons %10, 2
%$3xs = load %List, ptr %11
; TLit (Ident "Int")
; ÄNDRAT använd deref
%12 = call fastcc i64 %$0f_deref(ptr %$0f, i64 %$2x)
; TLit (Ident "List")
; ÄNDRAT ptr istället för 64 (64)* och skicka med ptr
%13 = call fastcc %List @map$Int_Int_List_List(ptr %$0f, %List %$3xs)
; TLit (Ident "List")
%14 = call fastcc %List @Cons(i64 %12, %List %13)
store %List %14, ptr %1
br label %lbl_escape_6
lbl_failed_9:
call i32 (ptr, ...) @printf(ptr noundef @.non_exhaustive_patterns, i64 noundef 14, i64 noundef 6)
call i32 @exit(i32 noundef 1)
br label %lbl_escape_6
lbl_escape_6:
%17 = load %List, ptr %1
ret %List %17
}
; Ident "main": (EApp (EVar (Ident "sum$List_Int"),TFun (TLit (Ident "List")) (TLit (Ident "Int"))) (EApp (EApp (EVar (Ident "map$Int_Int_List_List"),TFun (TFun (TLit (Ident "Int")) (TLit (Ident "Int"))) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (EVar (Ident "sc_0$Int_Int"),TFun (TLit (Ident "Int")) (TLit (Ident "Int"))),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (ELit (LInt 1),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EApp (EApp (EVar (Ident "Cons"),TFun (TLit (Ident "Int")) (TFun (TLit (Ident "List")) (TLit (Ident "List")))) (ELit (LInt 2),TLit (Ident "Int")),TFun (TLit (Ident "List")) (TLit (Ident "List"))) (EVar (Ident "Nil"),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "List")),TLit (Ident "Int"))
define fastcc i64 @main() {
%1 = call fastcc %List @Nil()
; TLit (Ident "List")
%2 = call fastcc %List @Cons(i64 2, %List %1)
; TLit (Ident "List")
%3 = call fastcc %List @Cons(i64 1, %List %2)
; TLit (Ident "List")
; NYTT: spara funktionspekaren och 100 i kontexten
%closure_sc_0 = alloca %Closure_sc_0
store i64(i64)* @sc_0$Int_Int, ptr %closure_sc_0
%fri_variabel_ptr = getelementptr inbounds %Closure_sc_0, ptr %closure_sc_0, i32 0, i32 1
store i64 100, ptr %fri_variabel_ptr
; store %Closure_sc_0 {i64 (i64)* @sc_0$Int_Int, 100}, ptr %closure_sc_0
; ÄNDRAT ptr %closure_sc_0 istället för i64 (i64)* @sc_0$Int_Int
%4 = call fastcc %List @map$Int_Int_List_List(ptr %closure_sc_0, %List %3)
; TLit (Ident "Int")
%5 = call fastcc i64 @sum$List_Int(%List %4)
call i32 (ptr, ...) @printf(ptr noundef @.str, i64 noundef %5)
ret i64 0
}
define fastcc %List @Cons(i64 %arg_0, %List %arg_1) {
%1 = alloca %List
%2 = getelementptr %List, %List* %1, i64 0, i32 0
store i8 0, i8* %2
%3 = bitcast %List* %1 to %Cons*
; i64 arg_0 1
%4 = getelementptr %Cons, %Cons* %3, i64 0, i32 1
; Just store
store i64 %arg_0, ptr %4
; %List arg_1 2
%5 = getelementptr %Cons, %Cons* %3, i64 0, i32 2
; Malloc and store
%6 = call ptr @malloc(i64 24)
store %List %arg_1, ptr %6
store %List* %6, ptr %5
; Return the newly constructed value
%7 = load %List, ptr %1
ret %List %7
}
define fastcc %List @Nil() {
%1 = alloca %List
%2 = getelementptr %List, %List* %1, i64 0, i32 0
store i8 1, i8* %2
%3 = bitcast %List* %1 to %Nil*
; Return the newly constructed value
%4 = load %List, ptr %1
ret %List %4
}