Add closures and fix lets in monomorphizer
This commit is contained in:
parent
677a200a15
commit
72e599d5de
26 changed files with 1440 additions and 692 deletions
|
|
@ -1,25 +1,25 @@
|
|||
module Codegen.Auxillary where
|
||||
|
||||
import Codegen.LlvmIr (LLVMType (..), LLVMValue (..))
|
||||
import Control.Monad (foldM_)
|
||||
import Monomorphizer.MonomorphizerIr as MIR (ExpT, Type (..))
|
||||
import TypeChecker.TypeCheckerIr qualified as TIR
|
||||
import Codegen.LlvmIr (LLVMType (..), LLVMValue (..))
|
||||
import Control.Monad (foldM_)
|
||||
import Monomorphizer.MonomorphizerIr as MIR (Exp, T, Type (..))
|
||||
import qualified TypeChecker.TypeCheckerIr as TIR
|
||||
|
||||
type2LlvmType :: MIR.Type -> LLVMType
|
||||
type2LlvmType (MIR.TLit id@(TIR.Ident name)) = case name of
|
||||
"Int" -> I64
|
||||
"Int" -> I64
|
||||
"Char" -> I8
|
||||
"Bool" -> I1
|
||||
_ -> CustomType id
|
||||
_ -> CustomType id
|
||||
type2LlvmType (MIR.TFun t xs) = do
|
||||
let (t', xs') = function2LLVMType xs [type2LlvmType t]
|
||||
Function t' xs'
|
||||
where
|
||||
function2LLVMType :: Type -> [LLVMType] -> (LLVMType, [LLVMType])
|
||||
function2LLVMType (TFun t xs) s = function2LLVMType xs (type2LlvmType t : s)
|
||||
function2LLVMType x s = (type2LlvmType x, s)
|
||||
function2LLVMType x s = (type2LlvmType x, s)
|
||||
|
||||
getType :: ExpT -> LLVMType
|
||||
getType :: T Exp -> LLVMType
|
||||
getType (_, t) = type2LlvmType t
|
||||
|
||||
extractTypeName :: MIR.Type -> TIR.Ident
|
||||
|
|
@ -30,21 +30,21 @@ extractTypeName (MIR.TFun t xs) =
|
|||
in TIR.Ident $ i <> "_$_" <> is
|
||||
|
||||
valueGetType :: LLVMValue -> LLVMType
|
||||
valueGetType (VInteger _) = I64
|
||||
valueGetType (VChar _) = I8
|
||||
valueGetType (VIdent _ t) = t
|
||||
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
|
||||
valueGetType (VInteger _) = I64
|
||||
valueGetType (VChar _) = I8
|
||||
valueGetType (VIdent _ t) = t
|
||||
valueGetType (VConstant s) = Array (fromIntegral $ length s) I8
|
||||
valueGetType (VFunction _ _ t) = t
|
||||
|
||||
typeByteSize :: LLVMType -> Integer
|
||||
typeByteSize I1 = 1
|
||||
typeByteSize I8 = 1
|
||||
typeByteSize I32 = 4
|
||||
typeByteSize I64 = 8
|
||||
typeByteSize Ptr = 8
|
||||
typeByteSize (Ref _) = 8
|
||||
typeByteSize I1 = 1
|
||||
typeByteSize I8 = 1
|
||||
typeByteSize I32 = 4
|
||||
typeByteSize I64 = 8
|
||||
typeByteSize Ptr = 8
|
||||
typeByteSize (Ref _) = 8
|
||||
typeByteSize (Function _ _) = 8
|
||||
typeByteSize (Array n t) = n * typeByteSize t
|
||||
typeByteSize (Array n t) = n * typeByteSize t
|
||||
typeByteSize (CustomType _) = 8
|
||||
|
||||
enumerateOneM_ :: Monad m => (Integer -> a -> m b) -> [a] -> m ()
|
||||
|
|
|
|||
|
|
@ -1,18 +1,24 @@
|
|||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Codegen.Codegen (generateCode) where
|
||||
|
||||
import Codegen.CompilerState (
|
||||
CodeGenerator (instructions),
|
||||
initCodeGenerator,
|
||||
)
|
||||
import Codegen.Emits (compileScs)
|
||||
import Codegen.LlvmIr as LIR (llvmIrToString)
|
||||
import Control.Monad.State (
|
||||
execStateT,
|
||||
)
|
||||
import Data.List (sortBy)
|
||||
import Grammar.ErrM (Err)
|
||||
import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..), Def (DBind, DData), Program (..), Type (TLit))
|
||||
import TypeChecker.TypeCheckerIr (Ident (..))
|
||||
import Codegen.CompilerState (CodeGenerator (..),
|
||||
StructType (inst),
|
||||
initCodeGenerator)
|
||||
import Codegen.Emits (compileScs)
|
||||
import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw),
|
||||
llvmIrToString)
|
||||
import Control.Monad.State (execStateT)
|
||||
import Data.Functor ((<&>))
|
||||
import Data.List (sortBy)
|
||||
import qualified Data.Map as Map
|
||||
import Grammar.ErrM (Err)
|
||||
import Monomorphizer.MonomorphizerIr as MIR (Bind (..), Data (..),
|
||||
Def (DBind, DData),
|
||||
Program (..),
|
||||
Type (TLit))
|
||||
import TypeChecker.TypeCheckerIr (Ident (..))
|
||||
|
||||
{- | Compiles an AST and produces a LLVM Ir string.
|
||||
An easy way to actually "compile" this output is to
|
||||
|
|
@ -20,16 +26,43 @@ import TypeChecker.TypeCheckerIr (Ident (..))
|
|||
-}
|
||||
generateCode :: MIR.Program -> Bool -> Err String
|
||||
generateCode (MIR.Program scs) addGc = do
|
||||
let tree = filter (not . detectPrelude) (sortBy lowData scs)
|
||||
let codegen = initCodeGenerator addGc tree
|
||||
llvmIrToString . instructions <$> execStateT (compileScs tree) codegen
|
||||
let tree = filter (not . detectPrelude) (sortBy lowData scs)
|
||||
codegen = initCodeGenerator addGc tree
|
||||
|
||||
-- Append instructions
|
||||
execStateT (compileScs tree) codegen <&> \state ->
|
||||
llvmIrToString $ defaultStart
|
||||
++ (if addGc then gcStart else [])
|
||||
++ map inst (Map.elems state.structTypes)
|
||||
++ state.instructions
|
||||
|
||||
detectPrelude :: Def -> Bool
|
||||
detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True
|
||||
detectPrelude (DData (Data (TLit (Ident "Bool")) _)) = True
|
||||
detectPrelude (DBind (Bind (Ident ('l' : 't' : '$' : _), _) _ _)) = True
|
||||
detectPrelude _ = False
|
||||
detectPrelude _ = False
|
||||
|
||||
lowData :: Def -> Def -> Ordering
|
||||
lowData (DData _) (DBind _) = LT
|
||||
lowData (DBind _) (DData _) = GT
|
||||
lowData _ _ = EQ
|
||||
lowData _ _ = EQ
|
||||
|
||||
defaultStart :: [LLVMIr]
|
||||
defaultStart =
|
||||
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
|
||||
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
|
||||
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
|
||||
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
|
||||
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
|
||||
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
|
||||
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
|
||||
]
|
||||
|
||||
gcStart :: [LLVMIr]
|
||||
gcStart =
|
||||
[ UnsafeRaw "declare external void @cheap_init()\n"
|
||||
, UnsafeRaw "declare external ptr @cheap_alloc(i64)\n"
|
||||
, UnsafeRaw "declare external void @cheap_dispose()\n"
|
||||
, UnsafeRaw "declare external ptr @cheap_the()\n"
|
||||
, UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n"
|
||||
, UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,46 +1,101 @@
|
|||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Codegen.CompilerState where
|
||||
|
||||
import Auxiliary (snoc)
|
||||
import Codegen.Auxillary (type2LlvmType, typeByteSize)
|
||||
import Codegen.LlvmIr as LIR (LLVMIr (UnsafeRaw),
|
||||
LLVMType)
|
||||
import Control.Monad.State (StateT, gets, modify)
|
||||
import Codegen.LlvmIr as LIR (LLVMIr (SetVariable, Type),
|
||||
LLVMType (CustomType, Function, I64, Ptr),
|
||||
LLVMValue (VFunction, VIdent),
|
||||
Visibility (Global),
|
||||
typeOf)
|
||||
import Control.Monad.State (StateT, gets, modify, void)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
import Grammar.ErrM (Err)
|
||||
import Monomorphizer.MonomorphizerIr as MIR
|
||||
import Monomorphizer.MonomorphizerIr (Ident (..), Inj (..), T,
|
||||
flattenType)
|
||||
import qualified Monomorphizer.MonomorphizerIr as MIR
|
||||
import qualified TypeChecker.TypeCheckerIr as TIR
|
||||
|
||||
-- | The record used as the code generator state
|
||||
data CodeGenerator = CodeGenerator
|
||||
{ instructions :: [LLVMIr]
|
||||
, functions :: Map MIR.Id FunctionInfo
|
||||
, functions :: Map (T Ident) FunctionInfo
|
||||
, customTypes :: Map LLVMType Integer
|
||||
, constructors :: Map TIR.Ident ConstructorInfo
|
||||
, constructors :: Map Ident ConstructorInfo
|
||||
, variableCount :: Integer
|
||||
, labelCount :: Integer
|
||||
, gcEnabled :: Bool
|
||||
, structTypes :: Map Ident StructType
|
||||
-- ^ Custom stucture types
|
||||
, locals :: [(Ident, LocalElem)]
|
||||
-- ^ Arguments and variables in local environment
|
||||
, globals :: Map Ident (LLVMType, LLVMValue)
|
||||
}
|
||||
|
||||
data StructType = StructType
|
||||
{ ptr :: LLVMType
|
||||
, typs :: [LLVMType]
|
||||
, inst :: LLVMIr
|
||||
}
|
||||
|
||||
data LocalElem = LocalElem
|
||||
{ typ :: LLVMType
|
||||
, val :: LLVMValue
|
||||
}
|
||||
|
||||
|
||||
-- | A state type synonym
|
||||
type CompilerState a = StateT CodeGenerator Err a
|
||||
|
||||
data FunctionInfo = FunctionInfo
|
||||
{ numArgs :: Int
|
||||
, arguments :: [Id]
|
||||
, arguments :: [T Ident]
|
||||
}
|
||||
deriving (Show)
|
||||
data ConstructorInfo = ConstructorInfo
|
||||
{ numArgsCI :: Int
|
||||
, argumentsCI :: [Id]
|
||||
, argumentsCI :: [T Ident]
|
||||
, numCI :: Integer
|
||||
, returnTypeCI :: MIR.Type
|
||||
}
|
||||
deriving (Show)
|
||||
|
||||
|
||||
addStructType_ :: Ident -> [LLVMType] -> CompilerState ()
|
||||
addStructType_ = fmap void . addStructType
|
||||
|
||||
addStructType :: Ident -> [LLVMType] -> CompilerState LLVMType
|
||||
addStructType x ts = do
|
||||
modify $ \s -> s { structTypes = Map.insert x struct s.structTypes }
|
||||
pure t
|
||||
where
|
||||
struct = StructType
|
||||
{ ptr = t
|
||||
, typs = ts
|
||||
, inst = Type x ts
|
||||
}
|
||||
t = CustomType x
|
||||
|
||||
-- | Adds a instruction to the CodeGenerator state
|
||||
emit :: LLVMIr -> CompilerState ()
|
||||
emit l = modify $ \t -> t{instructions = Auxiliary.snoc l $ instructions t}
|
||||
|
||||
-- Add variable to environment
|
||||
emit l@(SetVariable x _) = modify $ \t ->
|
||||
t { instructions = Auxiliary.snoc l t.instructions
|
||||
, locals = snoc (x, local)
|
||||
t.locals
|
||||
}
|
||||
where
|
||||
local = LocalElem { typ = typeOf l
|
||||
, val = VIdent x (typeOf l)
|
||||
}
|
||||
|
||||
emit l = modify $ \t -> t { instructions = Auxiliary.snoc l t.instructions }
|
||||
|
||||
-- | Increases the variable counter in the CodeGenerator state
|
||||
increaseVarCount :: CompilerState ()
|
||||
|
|
@ -63,16 +118,19 @@ getNewLabel = do
|
|||
{- | Produces a map of functions infos from a list of binds,
|
||||
which contains useful data for code generation.
|
||||
-}
|
||||
getFunctions :: [MIR.Def] -> Map Id FunctionInfo
|
||||
getFunctions :: [MIR.Def] -> Map (T Ident) FunctionInfo
|
||||
getFunctions bs = Map.fromList $ go bs
|
||||
where
|
||||
go [] = []
|
||||
go (MIR.DBind (MIR.Bind id args _) : xs) =
|
||||
(id, FunctionInfo{numArgs = length args, arguments = args})
|
||||
: go xs
|
||||
(id, FunctionInfo { numArgs = length args
|
||||
, arguments = args
|
||||
}
|
||||
)
|
||||
: go xs
|
||||
go (_ : xs) = go xs
|
||||
|
||||
createArgs :: [MIR.Type] -> [Id]
|
||||
createArgs :: [MIR.Type] -> [T Ident]
|
||||
createArgs xs = fst $ foldl (\(acc, l) t -> (acc ++ [(TIR.Ident ("arg_" <> show l), t)], l + 1)) ([], 0) xs
|
||||
|
||||
{- | Produces a map of functions infos from a list of binds,
|
||||
|
|
@ -113,35 +171,43 @@ getTypes bs = Map.fromList $ go bs
|
|||
variantTypes fi = init $ map type2LlvmType (flattenType fi)
|
||||
biggestVariant ts = 8 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
|
||||
|
||||
getGlobals :: [MIR.Def] -> Map Ident (LLVMType, LLVMValue)
|
||||
getGlobals scs = Map.fromList [ go b | MIR.DBind b <- scs ]
|
||||
where
|
||||
go bind | x == "main" = let typ = Function I64 []
|
||||
in (x, (typ, VFunction x Global typ))
|
||||
| otherwise = (x, (typ, VFunction x Global typ))
|
||||
where
|
||||
typ = Function tr $ Ptr : ts
|
||||
Function tr ts = type2LlvmType' t
|
||||
|
||||
(x, t) = case bind of
|
||||
MIR.Bind xt _ _ -> xt
|
||||
MIR.BindC _ xt _ _ -> xt
|
||||
|
||||
-- Higher order function arguments are replaced with ptr
|
||||
type2LlvmType' = go []
|
||||
where
|
||||
go acc = \case
|
||||
MIR.TFun (MIR.TFun _ _) t2 -> go (snoc Ptr acc) t2
|
||||
MIR.TFun t1 t2 -> go (snoc (type2LlvmType t1) acc) t2
|
||||
t -> Function (type2LlvmType t) acc
|
||||
|
||||
|
||||
|
||||
|
||||
initCodeGenerator :: Bool -> [MIR.Def] -> CodeGenerator
|
||||
initCodeGenerator addGc scs =
|
||||
CodeGenerator
|
||||
{ instructions = defaultStart <> if addGc then gcStart else []
|
||||
{ instructions = []
|
||||
, functions = getFunctions scs
|
||||
, constructors = getConstructors scs
|
||||
, customTypes = getTypes scs
|
||||
, structTypes = mempty
|
||||
, variableCount = 0
|
||||
, labelCount = 0
|
||||
, gcEnabled = addGc
|
||||
, locals = mempty
|
||||
, globals = getGlobals scs
|
||||
}
|
||||
|
||||
defaultStart :: [LLVMIr]
|
||||
defaultStart =
|
||||
[ UnsafeRaw "target triple = \"x86_64-pc-linux-gnu\"\n"
|
||||
, UnsafeRaw "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
|
||||
, UnsafeRaw "@.str = private unnamed_addr constant [3 x i8] c\"%i\n\", align 1\n"
|
||||
, UnsafeRaw "@.non_exhaustive_patterns = private unnamed_addr constant [41 x i8] c\"Non-exhaustive patterns in case at %i:%i\n\"\n"
|
||||
, UnsafeRaw "declare i32 @printf(ptr noalias nocapture, ...)\n"
|
||||
, UnsafeRaw "declare i32 @exit(i32 noundef)\n"
|
||||
, UnsafeRaw "declare ptr @malloc(i32 noundef)\n"
|
||||
]
|
||||
|
||||
gcStart :: [LLVMIr]
|
||||
gcStart =
|
||||
[ UnsafeRaw "declare external void @cheap_init()\n"
|
||||
, UnsafeRaw "declare external ptr @cheap_alloc(i64)\n"
|
||||
, UnsafeRaw "declare external void @cheap_dispose()\n"
|
||||
, UnsafeRaw "declare external ptr @cheap_the()\n"
|
||||
, UnsafeRaw "declare external void @cheap_set_profiler(ptr, i1)\n"
|
||||
, UnsafeRaw "declare external void @cheap_profiler_log_options(ptr, i64)\n"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,36 +1,40 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE DuplicateRecordFields #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE NamedFieldPuns #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module Codegen.Emits where
|
||||
|
||||
import Codegen.Auxillary
|
||||
import Codegen.CompilerState
|
||||
import Codegen.LlvmIr as LIR
|
||||
import Control.Applicative ((<|>))
|
||||
import Control.Monad (when)
|
||||
import Control.Monad.State (gets, modify)
|
||||
import Data.Bifunctor qualified as BI
|
||||
import Data.Char (ord)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Map qualified as Map
|
||||
import Data.Maybe (fromJust, fromMaybe, isNothing)
|
||||
import Data.Tuple.Extra (dupe, first, second)
|
||||
import Debug.Trace (trace, traceShow)
|
||||
import Grammar.Print
|
||||
import Monomorphizer.MonomorphizerIr as MIR
|
||||
import TypeChecker.TypeCheckerIr qualified as TIR
|
||||
import Auxiliary (snoc)
|
||||
import Codegen.Auxillary
|
||||
import Codegen.CompilerState
|
||||
import Codegen.LlvmIr as LIR
|
||||
import Control.Applicative (Applicative (liftA2), (<|>))
|
||||
import Control.Monad (forM_, when, zipWithM_)
|
||||
import Control.Monad.Extra (whenJust)
|
||||
import Control.Monad.State (gets, modify)
|
||||
import Data.Char (ord)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Foldable.Extra (notNull)
|
||||
import qualified Data.Map as Map
|
||||
import Data.Maybe (fromJust, fromMaybe, isNothing)
|
||||
import Data.Tuple.Extra (second)
|
||||
import Grammar.Print (printTree)
|
||||
import Monomorphizer.MonomorphizerIr
|
||||
|
||||
compileScs :: [MIR.Def] -> CompilerState ()
|
||||
|
||||
compileScs :: [Def] -> CompilerState ()
|
||||
compileScs [] = do
|
||||
emit $ UnsafeRaw "\n"
|
||||
mapM_ createConstructor =<< gets (Map.toList . constructors)
|
||||
-- as a last step create all the constructors
|
||||
-- //TODO maybe merge this with the data type match?
|
||||
c <- gets (Map.toList . constructors)
|
||||
mapM_
|
||||
( \(id, ci) -> do
|
||||
let t = returnTypeCI ci
|
||||
let t' = type2LlvmType t
|
||||
let x = BI.second type2LlvmType <$> argumentsCI ci
|
||||
where
|
||||
createConstructor (id, ci) = do
|
||||
let t = returnTypeCI ci
|
||||
t' = type2LlvmType t
|
||||
x = (mkCxtName, Ptr) : map (second type2LlvmType) ci.argumentsCI
|
||||
emit $ Define FastCC t' id x
|
||||
top <- getNewVar
|
||||
ptr <- getNewVar
|
||||
|
|
@ -56,7 +60,7 @@ compileScs [] = do
|
|||
cTypes <- gets customTypes
|
||||
|
||||
enumerateOneM_
|
||||
( \i (TIR.Ident arg_n, arg_t) -> do
|
||||
( \i (Ident arg_n, arg_t) -> do
|
||||
let arg_t' = type2LlvmType arg_t
|
||||
emit $ Comment (toIr arg_t' <> " " <> arg_n <> " " <> show i)
|
||||
elemPtr <- getNewVar
|
||||
|
|
@ -78,11 +82,11 @@ compileScs [] = do
|
|||
heapPtr <- getNewVar
|
||||
useGc <- gets gcEnabled
|
||||
emit $ SetVariable heapPtr (if useGc then GcMalloc s else Malloc s)
|
||||
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr heapPtr
|
||||
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr heapPtr
|
||||
emit $ Store (Ref arg_t') (VIdent heapPtr arg_t') Ptr elemPtr
|
||||
Nothing -> do
|
||||
emit $ Comment "Just store"
|
||||
emit $ Store arg_t' (VIdent (TIR.Ident arg_n) arg_t') Ptr elemPtr
|
||||
emit $ Store arg_t' (VIdent (Ident arg_n) arg_t') Ptr elemPtr
|
||||
)
|
||||
(argumentsCI ci)
|
||||
|
||||
|
|
@ -95,34 +99,83 @@ compileScs [] = do
|
|||
emit $ UnsafeRaw "\n"
|
||||
|
||||
modify $ \s -> s{variableCount = 0}
|
||||
)
|
||||
c
|
||||
compileScs (MIR.DBind (MIR.Bind (name, t) args exp) : xs) = do
|
||||
let t_return = type2LlvmType . last . flattenType $ t
|
||||
|
||||
compileScs (DBind bind : xs) = do
|
||||
emit $ UnsafeRaw "\n"
|
||||
emit . Comment $ show name <> ": " <> show exp
|
||||
let args' = map (second type2LlvmType) args
|
||||
emit . Comment $ show name <> ": " <> show (fst exp)
|
||||
|
||||
Function t_return t_args <- gets $ fst
|
||||
. fromJust
|
||||
. Map.lookup name
|
||||
. globals
|
||||
|
||||
let args' = zip (mkCxtName : map fst args) t_args
|
||||
|
||||
emit $ Define FastCC t_return name args'
|
||||
useGc <- gets gcEnabled
|
||||
when (name == "main") (mapM_ emit (firstMainContent useGc))
|
||||
functionBody <- exprToValue exp
|
||||
if name == "main"
|
||||
then mapM_ emit $ lastMainContent useGc functionBody
|
||||
else emit $ Ret t_return functionBody
|
||||
modify $ \s -> s { locals = foldr insertArg s.locals args' }
|
||||
|
||||
-- Dereference ptr arguments
|
||||
when (notNull args') $
|
||||
forM_ (tail args') $ \(x, t) -> when (t == Ptr) $ do
|
||||
let t_deref =
|
||||
let
|
||||
Function t ts = type2LlvmType . fromJust $ lookup x args
|
||||
in
|
||||
Function t (Ptr : ts)
|
||||
|
||||
emit . SetVariable (mkDerefName x)
|
||||
$ Load t_deref Ptr x
|
||||
|
||||
whenJust mcxt loadFreeVars
|
||||
|
||||
gcEnabled <- gets gcEnabled
|
||||
when isMain $ mapM_ emit (firstMainContent gcEnabled)
|
||||
|
||||
result <- exprToValue exp
|
||||
|
||||
if isMain
|
||||
then mapM_ emit $ lastMainContent gcEnabled result
|
||||
else emit $ Ret t_return result
|
||||
|
||||
emit DefineEnd
|
||||
modify $ \s -> s{variableCount = 0}
|
||||
-- Reset variable count and empty locals
|
||||
modify $ \s -> s { variableCount = 0, locals = mempty }
|
||||
compileScs xs
|
||||
compileScs (MIR.DData (MIR.Data typ ts) : xs) = do
|
||||
let (TIR.Ident outer_id) = extractTypeName typ
|
||||
where
|
||||
loadFreeVars cxt = do
|
||||
emit $ Comment "Load free variables"
|
||||
zipWithM_ go cxt' [1 ..]
|
||||
where
|
||||
go (x, t) i = do
|
||||
vc <- getNewVar
|
||||
emit . SetVariable vc
|
||||
$ GetElementPtrInbounds (CustomType $ mkClosureName name) Ptr (VIdent mkCxtName Ptr)
|
||||
I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices
|
||||
emit . SetVariable x $ Load t Ptr vc
|
||||
cxt' = map (second type2LlvmType) cxt
|
||||
|
||||
isMain = name == "main"
|
||||
|
||||
(name, args, exp, mcxt) = case bind of
|
||||
Bind (name, _) args exp -> (name, args, exp, Nothing)
|
||||
BindC cxt (name, _) args exp -> (name, args, exp, Just cxt)
|
||||
|
||||
|
||||
insertArg (x, t) = snoc (x, LocalElem { val = VIdent x t, typ = t })
|
||||
|
||||
compileScs (DData (Data typ ts) : xs) = do
|
||||
let (Ident outer_id) = extractTypeName typ
|
||||
-- //TODO this could be extracted from the customTypes map
|
||||
let variantTypes fi = init $ map type2LlvmType (flattenType fi)
|
||||
let biggestVariant = 7 + maximum (sum . (\(Inj _ fi) -> typeByteSize <$> variantTypes fi) <$> ts)
|
||||
emit $ LIR.Type (TIR.Ident outer_id) [I8, Array biggestVariant I8]
|
||||
-- Add data type (e.g. %List) to top of the file
|
||||
addStructType_ (Ident outer_id) [I8, Array biggestVariant I8]
|
||||
typeSets <- gets customTypes
|
||||
mapM_
|
||||
( \(Inj inner_id fi) -> do
|
||||
let types = (\s -> if Map.member s typeSets then Ref s else s) <$> variantTypes fi
|
||||
emit $ LIR.Type inner_id (I8 : types)
|
||||
-- Add constructor type (e.g. %Cons) to top of the file
|
||||
addStructType_ inner_id (I8 : types)
|
||||
)
|
||||
ts
|
||||
compileScs xs
|
||||
|
|
@ -149,16 +202,16 @@ lastMainContent False var =
|
|||
, Ret I64 (VInteger 0)
|
||||
]
|
||||
|
||||
compileExp :: ExpT -> CompilerState ()
|
||||
compileExp (MIR.ELit lit, _t) = emitLit lit
|
||||
compileExp (MIR.EAdd e1 e2, t) = emitAdd t e1 e2
|
||||
compileExp (MIR.EVar name, _t) = emitIdent name
|
||||
compileExp (MIR.EApp e1 e2, t) = emitApp t e1 e2
|
||||
compileExp (MIR.ELet bind e, _) = emitLet bind e
|
||||
compileExp (MIR.ECase e cs, t) = emitECased t e (map (t,) cs)
|
||||
compileExp :: T Exp -> CompilerState ()
|
||||
compileExp (ELit lit, _t) = emitLit lit
|
||||
compileExp (EAdd e1 e2, t) = emitAdd t e1 e2
|
||||
compileExp (EVar name, _t) = emitIdent name
|
||||
compileExp (EApp e1 e2, t) = emitApp t e1 e2
|
||||
compileExp (ELet bind e, _) = emitLet bind e
|
||||
compileExp (ECase e cs, t) = emitECased t e (map (t,) cs)
|
||||
|
||||
emitLet :: MIR.Bind -> ExpT -> CompilerState ()
|
||||
emitLet (MIR.Bind id [] innerExp) e = do
|
||||
emitLet :: Bind -> T Exp -> CompilerState ()
|
||||
emitLet (Bind id [] innerExp) e = do
|
||||
evaled <- exprToValue innerExp
|
||||
tempVar <- getNewVar
|
||||
let t = type2LlvmType . snd $ innerExp
|
||||
|
|
@ -168,14 +221,14 @@ emitLet (MIR.Bind id [] innerExp) e = do
|
|||
compileExp e
|
||||
emitLet b _ = error $ "Non empty argument list in let-bind " <> show b
|
||||
|
||||
emitECased :: MIR.Type -> ExpT -> [(MIR.Type, Branch)] -> CompilerState ()
|
||||
emitECased :: Type -> T Exp -> [(Type, Branch)] -> CompilerState ()
|
||||
emitECased t e cases = do
|
||||
let cs = snd <$> cases
|
||||
let ty = type2LlvmType t
|
||||
let rt = type2LlvmType (snd e)
|
||||
vs <- exprToValue e
|
||||
lbl <- getNewLabel
|
||||
let label = TIR.Ident $ "escape_" <> show lbl
|
||||
let label = Ident $ "escape_" <> show lbl
|
||||
stackPtr <- getNewVar
|
||||
emit $ SetVariable stackPtr (Alloca ty)
|
||||
mapM_ (emitCases rt ty label stackPtr vs) cs
|
||||
|
|
@ -192,14 +245,14 @@ emitECased t e cases = do
|
|||
res <- getNewVar
|
||||
emit $ SetVariable res (Load ty Ptr stackPtr)
|
||||
where
|
||||
emitCases :: LLVMType -> LLVMType -> TIR.Ident -> TIR.Ident -> LLVMValue -> Branch -> CompilerState ()
|
||||
emitCases rt ty label stackPtr vs (Branch (MIR.PInj consId cs, _t) exp) = do
|
||||
emitCases :: LLVMType -> LLVMType -> Ident -> Ident -> LLVMValue -> Branch -> CompilerState ()
|
||||
emitCases rt ty label stackPtr vs (Branch (PInj consId cs, _t) exp) = do
|
||||
emit $ Comment "Inj"
|
||||
cons <- gets constructors
|
||||
let r = fromJust $ Map.lookup consId cons
|
||||
|
||||
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
|
||||
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
|
||||
|
||||
consVal <- getNewVar
|
||||
emit $ SetVariable consVal (ExtractValue rt vs 0)
|
||||
|
|
@ -215,10 +268,10 @@ emitECased t e cases = do
|
|||
emit $ Store rt vs Ptr castPtr
|
||||
emit $ SetVariable casted (Load (CustomType (coerce consId)) Ptr castPtr)
|
||||
enumerateOneM_
|
||||
( \i c -> do
|
||||
( \i (c, t) -> do
|
||||
case c of
|
||||
PVar (x, topT) -> do
|
||||
let topT' = type2LlvmType topT
|
||||
PVar x -> do
|
||||
let topT' = type2LlvmType t
|
||||
let botT' = CustomType (coerce consId)
|
||||
emit . Comment $ "ident " <> toIr topT'
|
||||
cTypes <- gets customTypes
|
||||
|
|
@ -228,7 +281,7 @@ emitECased t e cases = do
|
|||
emit $ SetVariable deref (ExtractValue botT' (VIdent casted Ptr) i)
|
||||
emit $ SetVariable x (Load topT' Ptr deref)
|
||||
else emit $ SetVariable x (ExtractValue botT' (VIdent casted Ptr) i)
|
||||
PLit (_l, _t) -> error "Nested pattern matching to be implemented"
|
||||
PLit _l -> error "Nested pattern matching to be implemented"
|
||||
PInj _id _ps -> error "Nested pattern matching to be implemented"
|
||||
PCatch -> pure ()
|
||||
PEnum _id -> error "Nested pattern matching to be implemented"
|
||||
|
|
@ -238,22 +291,22 @@ emitECased t e cases = do
|
|||
emit $ Store ty val Ptr stackPtr
|
||||
emit $ Br label
|
||||
emit $ Label lbl_failPos
|
||||
emitCases _rt ty label stackPtr vs (Branch (MIR.PLit (i, ct), t) exp) = do
|
||||
emitCases _rt ty label stackPtr vs (Branch (PLit i, t) exp) = do
|
||||
emit $ Comment "Plit"
|
||||
let i' = case i of
|
||||
MIR.LInt i -> VInteger i
|
||||
MIR.LChar i -> VChar (ord i)
|
||||
LInt i -> VInteger i
|
||||
LChar i -> VChar (ord i)
|
||||
ns <- getNewVar
|
||||
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
|
||||
emit $ SetVariable ns (Icmp LLEq (type2LlvmType ct) vs i')
|
||||
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
|
||||
emit $ SetVariable ns (Icmp LLEq (type2LlvmType t) vs i')
|
||||
emit $ BrCond (VIdent ns ty) lbl_succPos lbl_failPos
|
||||
emit $ Label lbl_succPos
|
||||
val <- exprToValue exp
|
||||
emit $ Store ty val Ptr stackPtr
|
||||
emit $ Br label
|
||||
emit $ Label lbl_failPos
|
||||
emitCases rt ty label stackPtr vs (Branch (MIR.PVar (id, _), _) exp) = do
|
||||
emitCases rt ty label stackPtr vs (Branch (PVar id, _) exp) = do
|
||||
emit $ Comment "Pvar"
|
||||
-- //TODO this is pretty disgusting and would heavily benefit from a rewrite
|
||||
valPtr <- getNewVar
|
||||
|
|
@ -263,20 +316,20 @@ emitECased t e cases = do
|
|||
val <- exprToValue exp
|
||||
emit $ Store ty val Ptr stackPtr
|
||||
emit $ Br label
|
||||
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
emit $ Label lbl_failPos
|
||||
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "True$Bool"), t) exp) = do
|
||||
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 1, TLit "Bool"), t) exp)
|
||||
emitCases rt ty label stackPtr vs (Branch (MIR.PEnum (TIR.Ident "False$Bool"), _) exp) = do
|
||||
emitCases rt ty label stackPtr vs (Branch (MIR.PLit (MIR.LInt 0, TLit "Bool"), t) exp)
|
||||
emitCases rt ty label stackPtr vs br@(Branch (MIR.PEnum consId, _) exp) = do
|
||||
emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "True$Bool"), t) exp) = do
|
||||
emitCases rt ty label stackPtr vs (Branch (PLit $ LInt 1, t) exp)
|
||||
emitCases rt ty label stackPtr vs (Branch (PEnum (Ident "False$Bool"), _) exp) = do
|
||||
emitCases rt ty label stackPtr vs (Branch (PLit (LInt 0), t) exp)
|
||||
emitCases rt ty label stackPtr vs br@(Branch (PEnum consId, _) exp) = do
|
||||
emit $ Comment "Penum"
|
||||
cons <- gets constructors
|
||||
let r = Map.lookup consId cons
|
||||
when (isNothing r) (error $ "Constructor: '" ++ printTree consId ++ "' does not exist in cons state:\n" ++ show cons ++ "\nin pattern\n'" ++ printTree br ++ "'\n")
|
||||
|
||||
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_succPos <- (\x -> TIR.Ident $ "success_" <> show x) <$> getNewLabel
|
||||
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_succPos <- (\x -> Ident $ "success_" <> show x) <$> getNewLabel
|
||||
|
||||
consVal <- getNewVar
|
||||
emit $ SetVariable consVal (ExtractValue rt vs 0)
|
||||
|
|
@ -295,98 +348,167 @@ emitECased t e cases = do
|
|||
emit $ Store ty val Ptr stackPtr
|
||||
emit $ Br label
|
||||
emit $ Label lbl_failPos
|
||||
emitCases _ ty label stackPtr _ (Branch (MIR.PCatch, _) exp) = do
|
||||
emitCases _ ty label stackPtr _ (Branch (PCatch, _) exp) = do
|
||||
emit $ Comment "Pcatch"
|
||||
val <- exprToValue exp
|
||||
emit $ Store ty val Ptr stackPtr
|
||||
emit $ Br label
|
||||
lbl_failPos <- (\x -> TIR.Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
lbl_failPos <- (\x -> Ident $ "failed_" <> show x) <$> getNewLabel
|
||||
emit $ Label lbl_failPos
|
||||
|
||||
emitApp :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
|
||||
emitApp rt e1 e2 = appEmitter e1 e2 []
|
||||
where
|
||||
appEmitter :: ExpT -> ExpT -> [ExpT] -> CompilerState ()
|
||||
appEmitter e1 e2 stack = do
|
||||
let newStack = e2 : stack
|
||||
case e1 of
|
||||
(MIR.EApp e1' e2', _) -> appEmitter e1' e2' newStack
|
||||
(MIR.EVar name, t) -> do
|
||||
args <- traverse exprToValue newStack
|
||||
vs <- getNewVar
|
||||
funcs <- gets functions
|
||||
consts <- gets constructors
|
||||
let visibility =
|
||||
fromMaybe Local $
|
||||
Global <$ Map.lookup name consts
|
||||
<|> Global <$ Map.lookup (name, t) funcs
|
||||
-- this piece of code could probably be improved, i.e remove the double `const Global`
|
||||
args' = map (first valueGetType . dupe) args
|
||||
let call =
|
||||
case name of
|
||||
TIR.Ident ('l' : 't' : '$' : _) -> Icmp LLSlt I64 (snd (head args')) (snd (args' !! 1))
|
||||
TIR.Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) -> Sub I64 (snd (head args')) (snd (args' !! 1))
|
||||
_ -> Call FastCC (type2LlvmType rt) visibility name args'
|
||||
emit $ Comment $ show rt
|
||||
emit $ SetVariable vs call
|
||||
x -> error $ "The unspeakable happened: " <> show x
|
||||
emitApp :: Type -> T Exp -> T Exp -> CompilerState ()
|
||||
emitApp rt e1 e2 = do
|
||||
((EVar name, t), args) <- go (EApp e1 e2, rt)
|
||||
vs <- getNewVar
|
||||
funcs <- gets functions
|
||||
consts <- gets constructors
|
||||
let visibility =
|
||||
fromMaybe Local $
|
||||
Global <$ Map.lookup name consts
|
||||
<|> Global <$ Map.lookup (name, t) funcs
|
||||
-- this piece of code could probably be improved, i.e remove the double `const Global`
|
||||
|
||||
emitIdent :: TIR.Ident -> CompilerState ()
|
||||
call <- case name of
|
||||
Ident ('l' : 't' : '$' : _) ->
|
||||
pure $ Icmp LLSlt I64 (snd (head args)) (snd (args !! 1))
|
||||
Ident ('$' : 'm' : 'i' : 'n' : 'u' : 's' : '$' : '$' : _) ->
|
||||
pure $ Sub I64 (snd (head args)) (snd (args !! 1))
|
||||
|
||||
-- FIXME
|
||||
_ -> do
|
||||
let closure_call LocalElem { typ = Ptr, val } = (mkDerefName name, (Ptr, val) : args)
|
||||
|
||||
(name, args) <- gets $ maybe (name, (Ptr, VNull) : args) closure_call
|
||||
. lookup name
|
||||
. locals
|
||||
|
||||
pure $ Call FastCC (type2LlvmType rt) visibility name args
|
||||
|
||||
emit $ Comment $ show (type2LlvmType rt)
|
||||
emit $ SetVariable vs call
|
||||
|
||||
where
|
||||
|
||||
go :: T Exp -> CompilerState (T Exp, [(LLVMType, LLVMValue)])
|
||||
go et@(e, _) = case e of
|
||||
EApp e1 e2@(_, t) -> do
|
||||
(x, as) <- go e1
|
||||
a <- exprToValue e2
|
||||
let t' = type2LlvmType' t
|
||||
pure (x, snoc (t', a) as)
|
||||
_ -> pure (et, [])
|
||||
|
||||
type2LlvmType' = \case
|
||||
TFun _ _ -> Ptr
|
||||
t -> type2LlvmType t
|
||||
|
||||
emitIdent :: Ident -> CompilerState ()
|
||||
emitIdent id = do
|
||||
-- !!this should never happen!!
|
||||
emit $ Comment "This should not have happened!"
|
||||
emit $ Variable id
|
||||
emit $ UnsafeRaw "\n"
|
||||
|
||||
emitLit :: MIR.Lit -> CompilerState ()
|
||||
emitLit :: Lit -> CompilerState ()
|
||||
emitLit i = do
|
||||
-- !!this should never happen!!
|
||||
let (i', t) = case i of
|
||||
(MIR.LInt i'') -> (VInteger i'', I64)
|
||||
(MIR.LChar i'') -> (VChar $ ord i'', I8)
|
||||
(LInt i'') -> (VInteger i'', I64)
|
||||
(LChar i'') -> (VChar $ ord i'', I8)
|
||||
varCount <- getNewVar
|
||||
emit $ Comment "This should not have happened!"
|
||||
emit $ SetVariable varCount (Add t i' (VInteger 0))
|
||||
|
||||
emitAdd :: MIR.Type -> ExpT -> ExpT -> CompilerState ()
|
||||
emitAdd :: Type -> T Exp -> T Exp -> CompilerState ()
|
||||
emitAdd t e1 e2 = do
|
||||
v1 <- exprToValue e1
|
||||
v2 <- exprToValue e2
|
||||
v <- getNewVar
|
||||
emit $ SetVariable v (Add (type2LlvmType t) v1 v2)
|
||||
|
||||
exprToValue :: ExpT -> CompilerState LLVMValue
|
||||
exprToValue = \case
|
||||
(MIR.ELit i, _t) -> pure $ case i of
|
||||
(MIR.LInt i) -> VInteger i
|
||||
(MIR.LChar i) -> VChar $ ord i
|
||||
(MIR.EVar (TIR.Ident "True$Bool"), _t) -> pure $ VInteger 1
|
||||
(MIR.EVar (TIR.Ident "False$Bool"), _t) -> pure $ VInteger 0
|
||||
(MIR.EVar name, t) -> do
|
||||
funcs <- gets functions
|
||||
cons <- gets constructors
|
||||
let res =
|
||||
Map.lookup (name, t) funcs
|
||||
<|> ( \c ->
|
||||
FunctionInfo
|
||||
{ numArgs = numArgsCI c
|
||||
, arguments = argumentsCI c
|
||||
}
|
||||
)
|
||||
<$> Map.lookup name cons
|
||||
case res of
|
||||
Just fi -> do
|
||||
if numArgs fi == 0
|
||||
then do
|
||||
vc <- getNewVar
|
||||
emit $
|
||||
SetVariable
|
||||
vc
|
||||
(Call FastCC (type2LlvmType t) Global name [])
|
||||
pure $ VIdent vc (type2LlvmType t)
|
||||
else pure $ VFunction name Global (type2LlvmType t)
|
||||
Nothing -> pure $ VIdent name (type2LlvmType t)
|
||||
e -> do
|
||||
compileExp e
|
||||
|
||||
exprToValue :: T Exp -> CompilerState LLVMValue
|
||||
exprToValue et@(e, t) = case e of
|
||||
ELit (LInt i) -> pure $ VInteger i
|
||||
ELit (LChar c) -> pure . VChar $ ord c
|
||||
|
||||
EVar "True$Bool" -> pure $ VInteger 1
|
||||
EVar "False$Bool" -> pure $ VInteger 0
|
||||
|
||||
EVar name -> gets (Map.lookup name . globals) >>= \case
|
||||
Just (typ@(Function _ ts), val) | length ts > 1 -> do
|
||||
type_struct <- addStructType (mkClosureName name) [typ]
|
||||
emit $ Comment "Allocating structure"
|
||||
emit . SetVariable name $ Alloca type_struct
|
||||
emit $ Store typ val Ptr name
|
||||
pure $ VIdent name Ptr
|
||||
|
||||
Just _ | name == "main" -> do
|
||||
vc <- getNewVar
|
||||
emit $ SetVariable vc (Call FastCC I64 Global name [])
|
||||
pure $ VIdent vc I64
|
||||
|
||||
|
||||
Just (Function t_return [_], _) -> do
|
||||
vc <- getNewVar
|
||||
emit $ SetVariable vc (Call FastCC t_return Global name [(Ptr, VNull)])
|
||||
pure $ VIdent vc t_return
|
||||
|
||||
Just _ -> error "Bad"
|
||||
|
||||
Nothing -> gets (Map.lookup name . constructors) >>= \case
|
||||
|
||||
Just ConstructorInfo {numArgsCI}
|
||||
| numArgsCI == 0 -> do
|
||||
vc <- getNewVar
|
||||
emit $ SetVariable vc call
|
||||
pure $ VIdent vc (type2LlvmType t)
|
||||
| otherwise -> pure $ VFunction name Global (type2LlvmType t)
|
||||
where
|
||||
call = Call FastCC (type2LlvmType t) Global name []
|
||||
|
||||
Nothing -> gets $ val
|
||||
. fromJust
|
||||
. lookup name
|
||||
. locals
|
||||
|
||||
EVarC cxt name -> do
|
||||
let cxt' = flip map cxt $ \(x, t) -> let t' = type2LlvmType t
|
||||
in (t', VIdent x t')
|
||||
cxt'' <- gets $ (:cxt')
|
||||
. fromJust
|
||||
. Map.lookup name
|
||||
. globals
|
||||
|
||||
-- Create a new type for function pointer and arguments
|
||||
type_struct <- addStructType (mkClosureName name) $ map fst cxt''
|
||||
emit $ Comment "Allocating structure"
|
||||
emit . SetVariable name $ Alloca type_struct
|
||||
|
||||
let ptr_struct = VIdent name Ptr
|
||||
storeArg (t, v) i = do
|
||||
vc <- getNewVar
|
||||
emit . SetVariable vc
|
||||
$ GetElementPtrInbounds type_struct Ptr ptr_struct
|
||||
I32 (VInteger 0) I32 (VInteger i) -- TODO fix indices
|
||||
emit $ Store t v Ptr vc
|
||||
|
||||
-- Store arguments in structure
|
||||
zipWithM_ storeArg cxt'' [0 ..]
|
||||
pure ptr_struct
|
||||
|
||||
_ -> do
|
||||
compileExp et
|
||||
v <- getVarCount
|
||||
pure $ VIdent (TIR.Ident $ show v) (getType e)
|
||||
pure $ VIdent (Ident $ show v) (getType et)
|
||||
|
||||
|
||||
mkClosureName :: Ident -> Ident
|
||||
mkClosureName (Ident s) = Ident $ "Closure_" ++ s
|
||||
|
||||
mkDerefName :: Ident -> Ident
|
||||
mkDerefName (Ident s) = Ident $ s ++ "_deref"
|
||||
|
||||
mkCxtName :: Ident
|
||||
mkCxtName = Ident "cxt"
|
||||
|
||||
|
|
|
|||
|
|
@ -9,17 +9,18 @@ module Codegen.LlvmIr (
|
|||
Visibility (..),
|
||||
CallingConvention (..),
|
||||
ToIr (..),
|
||||
typeOf
|
||||
) where
|
||||
|
||||
import Data.List (intercalate)
|
||||
import TypeChecker.TypeCheckerIr (Ident (..))
|
||||
import Data.List (intercalate)
|
||||
import TypeChecker.TypeCheckerIr (Ident (..))
|
||||
|
||||
data CallingConvention = TailCC | FastCC | CCC | ColdCC deriving (Show, Eq, Ord)
|
||||
instance ToIr CallingConvention where
|
||||
toIr :: CallingConvention -> String
|
||||
toIr TailCC = "tailcc"
|
||||
toIr FastCC = "fastcc"
|
||||
toIr CCC = "ccc"
|
||||
toIr CCC = "ccc"
|
||||
toIr ColdCC = "coldcc"
|
||||
|
||||
-- | A datatype which represents some basic LLVM types
|
||||
|
|
@ -38,6 +39,9 @@ data LLVMType
|
|||
class ToIr a where
|
||||
toIr :: a -> String
|
||||
|
||||
instance ToIr a => ToIr [a] where
|
||||
toIr = concatMap toIr
|
||||
|
||||
instance ToIr LLVMType where
|
||||
toIr :: LLVMType -> String
|
||||
toIr = \case
|
||||
|
|
@ -66,8 +70,8 @@ data LLVMComp
|
|||
instance ToIr LLVMComp where
|
||||
toIr :: LLVMComp -> String
|
||||
toIr = \case
|
||||
LLEq -> "eq"
|
||||
LLNe -> "ne"
|
||||
LLEq -> "eq"
|
||||
LLNe -> "ne"
|
||||
LLUgt -> "ugt"
|
||||
LLUge -> "uge"
|
||||
LLUlt -> "ult"
|
||||
|
|
@ -80,7 +84,7 @@ instance ToIr LLVMComp where
|
|||
data Visibility = Local | Global deriving (Show, Eq, Ord)
|
||||
instance ToIr Visibility where
|
||||
toIr :: Visibility -> String
|
||||
toIr Local = "%"
|
||||
toIr Local = "%"
|
||||
toIr Global = "@"
|
||||
|
||||
{- | Represents a LLVM "value", as in an integer, a register variable,
|
||||
|
|
@ -92,16 +96,18 @@ data LLVMValue
|
|||
| VIdent Ident LLVMType
|
||||
| VConstant String
|
||||
| VFunction Ident Visibility LLVMType
|
||||
| VNull
|
||||
deriving (Show, Eq, Ord)
|
||||
|
||||
instance ToIr LLVMValue where
|
||||
toIr :: LLVMValue -> String
|
||||
toIr v = case v of
|
||||
VInteger i -> show i
|
||||
VChar i -> show i
|
||||
VIdent (Ident n) _ -> "%" <> n
|
||||
VInteger i -> show i
|
||||
VChar i -> show i
|
||||
VIdent (Ident n) _ -> "%" <> n
|
||||
VFunction (Ident n) vis _ -> toIr vis <> n
|
||||
VConstant s -> "c" <> show s
|
||||
VConstant s -> "c" <> show s
|
||||
VNull -> "null"
|
||||
|
||||
type Params = [(Ident, LLVMType)]
|
||||
type Args = [(LLVMType, LLVMValue)]
|
||||
|
|
@ -139,6 +145,21 @@ data LLVMIr
|
|||
-- instructions should be used in its place
|
||||
deriving (Show, Eq, Ord)
|
||||
|
||||
|
||||
-- TODO add missing clauses
|
||||
typeOf :: LLVMIr -> LLVMType
|
||||
typeOf = \case
|
||||
Add t _ _ -> t
|
||||
Sub t _ _ -> t
|
||||
Mul t _ _ -> t
|
||||
Div t _ _ -> t
|
||||
Load t _ _ -> t
|
||||
Store t _ _ _ -> t
|
||||
Type x _ -> CustomType x
|
||||
SetVariable _ ir -> typeOf ir
|
||||
|
||||
|
||||
|
||||
-- | Converts a list of LLVMIr instructions to a string
|
||||
llvmIrToString :: [LLVMIr] -> String
|
||||
llvmIrToString = go 0
|
||||
|
|
@ -147,9 +168,9 @@ llvmIrToString = go 0
|
|||
go _ [] = mempty
|
||||
go i (x : xs) = do
|
||||
let (i', n) = case x of
|
||||
Define{} -> (i + 1, 0)
|
||||
Define{} -> (i + 1, 0)
|
||||
DefineEnd -> (i - 1, 0)
|
||||
_ -> (i, i)
|
||||
_ -> (i, i)
|
||||
insToString n x <> go i' xs
|
||||
|
||||
-- \| Converts a LLVM inststruction to a String, allowing for printing etc.
|
||||
|
|
@ -224,10 +245,10 @@ llvmIrToString = go 0
|
|||
, ")\n"
|
||||
]
|
||||
(Alloca t) -> unwords ["alloca", toIr t, "\n"]
|
||||
(Malloc t) ->
|
||||
(Malloc t) ->
|
||||
concat
|
||||
[ "call ptr @malloc(i64 ", show t, ")\n"]
|
||||
(GcMalloc t) ->
|
||||
(GcMalloc t) ->
|
||||
concat
|
||||
[ "call ptr @cheap_alloc(i64 ", show t, ")\n"]
|
||||
(Store t1 val t2 (Ident id2)) ->
|
||||
|
|
|
|||
|
|
@ -11,9 +11,11 @@ import Control.Monad.State (MonadState (get, put), State,
|
|||
evalState)
|
||||
import Data.Function (on)
|
||||
import Data.List (delete, mapAccumL, (\\))
|
||||
import Data.Tuple.Extra (first, second)
|
||||
import LambdaLifterIr (T)
|
||||
import qualified LambdaLifterIr as L
|
||||
import Prelude hiding (exp)
|
||||
import TypeChecker.TypeCheckerIr
|
||||
|
||||
import TypeChecker.TypeCheckerIr hiding (T)
|
||||
|
||||
-- | Lift lambdas and let expression into supercombinators.
|
||||
-- Three phases:
|
||||
|
|
@ -21,12 +23,13 @@ import TypeChecker.TypeCheckerIr
|
|||
-- @abstract@ converts lambdas into let expressions.
|
||||
-- @collectScs@ moves every non-constant let expression to a top-level function.
|
||||
--
|
||||
lambdaLift :: Program -> Program
|
||||
lambdaLift (Program ds) = Program (datatypes ++ binds)
|
||||
lambdaLift :: Program -> L.Program
|
||||
lambdaLift (Program ds) = L.Program (datatypes ++ binds)
|
||||
where
|
||||
datatypes = flip filter ds $ \case DData _ -> True
|
||||
_ -> False
|
||||
binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
|
||||
datatypes = [L.DData (toLirData d) | DData d <- ds]
|
||||
|
||||
binds = map L.DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
|
||||
|
||||
|
||||
-- | Annotate free variables
|
||||
freeVars :: [Bind] -> [ABind]
|
||||
|
|
@ -36,7 +39,7 @@ freeVars binds = [ let ae = freeVarsExp [] e
|
|||
| Bind n xs e <- binds
|
||||
]
|
||||
|
||||
freeVarsExp :: Frees -> ExpT -> Ann AExpT
|
||||
freeVarsExp :: Frees -> T Exp -> Ann (T AExp)
|
||||
freeVarsExp localVars (ae, t) = case ae of
|
||||
EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)]
|
||||
, term = (AVar n, t)
|
||||
|
|
@ -121,27 +124,47 @@ data Ann a = Ann
|
|||
, term :: a
|
||||
} deriving (Show, Eq)
|
||||
|
||||
data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq)
|
||||
data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq)
|
||||
|
||||
type AExpT = (AExp, Type)
|
||||
data ABind = ABind (T Ident) [T Ident] (Ann (T AExp)) deriving (Show, Eq)
|
||||
data ABranch = ABranch (Pattern, Type) (Ann (T AExp)) deriving (Show, Eq)
|
||||
|
||||
data AExp = AVar Ident
|
||||
| AInj Ident
|
||||
| ALit Lit
|
||||
| ALet (Ann ABind) (Ann AExpT)
|
||||
| AApp (Ann AExpT) (Ann AExpT)
|
||||
| AAdd (Ann AExpT) (Ann AExpT)
|
||||
| AAbs Ident (Ann AExpT)
|
||||
| ACase (Ann AExpT) [Ann ABranch]
|
||||
| ALet (Ann ABind) (Ann (T AExp))
|
||||
| AApp (Ann (T AExp)) (Ann (T AExp))
|
||||
| AAdd (Ann (T AExp)) (Ann (T AExp))
|
||||
| AAbs Ident (Ann (T AExp))
|
||||
| ACase (Ann (T AExp)) [Ann ABranch]
|
||||
deriving (Show, Eq)
|
||||
|
||||
abstract :: [ABind] -> [Bind]
|
||||
|
||||
|
||||
data BBind = BBind (T Ident) [T Ident] (T BExp)
|
||||
| BBindCxt [T Ident] (T Ident) [T Ident] (T BExp)
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
|
||||
data BBranch = BBranch (T Pattern) (T BExp)
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
data BExp
|
||||
= BVar Ident
|
||||
| BVarC [T Ident] Ident
|
||||
| BInj Ident
|
||||
| BLit Lit
|
||||
| BLet BBind (T BExp)
|
||||
| BApp (T BExp)(T BExp)
|
||||
| BAdd (T BExp)(T BExp)
|
||||
| BCase (T BExp) [BBranch]
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
|
||||
abstract :: [ABind] -> [BBind]
|
||||
abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0
|
||||
|
||||
abstractAnnBind :: Ann ABind -> State Int Bind
|
||||
abstractAnnBind :: Ann ABind -> State Int BBind
|
||||
abstractAnnBind Ann { term = ABind name vars annae } =
|
||||
Bind name (vars' <|| vars) <$> abstractAnnExp annae'
|
||||
BBind name (vars' <|| vars) <$> abstractAnnExp annae'
|
||||
where
|
||||
(annae', vars') = go [] annae
|
||||
where
|
||||
|
|
@ -149,24 +172,27 @@ abstractAnnBind Ann { term = ABind name vars annae } =
|
|||
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
|
||||
ae -> (ae, acc)
|
||||
|
||||
abstractAnnExp :: Ann AExpT -> State Int ExpT
|
||||
abstractAnnExp :: Ann (T AExp) -> State Int (T BExp)
|
||||
abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
|
||||
AVar n -> pure (EVar n, typ)
|
||||
AInj n -> pure (EInj n, typ)
|
||||
ALit lit -> pure (ELit lit, typ)
|
||||
AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2
|
||||
AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2
|
||||
AVar n -> pure (BVar n, typ)
|
||||
AInj n -> pure (BInj n, typ)
|
||||
ALit lit -> pure (BLit lit, typ)
|
||||
AApp annae1 annae2 -> (, typ) <$> onM BApp abstractAnnExp annae1 annae2
|
||||
AAdd annae1 annae2 -> (, typ) <$> onM BAdd abstractAnnExp annae1 annae2
|
||||
|
||||
-- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc
|
||||
AAbs x annae' -> do
|
||||
i <- nextNumber
|
||||
rhs <- abstractAnnExp annae''
|
||||
let sc_name = Ident ("sc_" ++ show i)
|
||||
e@(_, t) = foldl applyFree (EVar sc_name, typ) frees
|
||||
pure (ELet (Bind (sc_name, typ) vars rhs) e ,t)
|
||||
sc | null frees = (BVar sc_name, typ)
|
||||
| otherwise = (BVarC frees sc_name, typ)
|
||||
bind | null frees = BBind (sc_name, typ) vars rhs
|
||||
| otherwise = BBindCxt frees (sc_name, typ) vars rhs
|
||||
|
||||
pure (BLet bind sc ,typ)
|
||||
|
||||
where
|
||||
vars = frees <| (x, t_x) <|| ys
|
||||
vars = [(x, t_x)] <|| ys
|
||||
t_x = case typ of TFun t _ -> t
|
||||
_ -> error "Impossible"
|
||||
|
||||
|
|
@ -176,54 +202,48 @@ abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
|
|||
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
|
||||
ae -> (ae, acc)
|
||||
|
||||
|
||||
applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type)
|
||||
applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e')
|
||||
where
|
||||
t_e' = case t_e of TFun _ t -> t
|
||||
_ -> error "Impossible"
|
||||
|
||||
ACase annae' bs -> do
|
||||
bs <- mapM go bs
|
||||
e <- abstractAnnExp annae'
|
||||
pure (ECase e bs, typ)
|
||||
pure (BCase e bs, typ)
|
||||
where
|
||||
go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae
|
||||
go Ann { term = ABranch p annae } = BBranch p <$> abstractAnnExp annae
|
||||
|
||||
ALet b annae' ->
|
||||
(, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae')
|
||||
(, typ) <$> liftA2 BLet (abstractAnnBind b) (abstractAnnExp annae')
|
||||
|
||||
|
||||
-- | Collects supercombinators by lifting non-constant let expressions
|
||||
collectScs :: [Bind] -> [Bind]
|
||||
collectScs :: [BBind] -> [L.Bind]
|
||||
collectScs = concatMap collectFromRhs
|
||||
where
|
||||
collectFromRhs (Bind name parms rhs) =
|
||||
collectFromRhs (BBind name parms rhs) =
|
||||
let (rhs_scs, rhs') = collectScsExp rhs
|
||||
in Bind name parms rhs' : rhs_scs
|
||||
in L.Bind name parms rhs' : rhs_scs
|
||||
collectFromRhs (BBindCxt cxt name parms rhs) =
|
||||
let (rhs_scs, rhs') = collectScsExp rhs
|
||||
in L.BindC cxt name parms rhs' : rhs_scs
|
||||
|
||||
|
||||
collectScsExp :: ExpT -> ([Bind], ExpT)
|
||||
collectScsExp expT@(exp, typ) = case exp of
|
||||
EVar _ -> ([], expT)
|
||||
EInj _ -> ([], expT)
|
||||
ELit _ -> ([], expT)
|
||||
collectScsExp :: T BExp -> ([L.Bind], T L.Exp)
|
||||
collectScsExp (exp, typ) = case exp of
|
||||
BVar x -> ([], (L.EVar x, typ))
|
||||
BVarC as x -> ([], (L.EVarC as x, typ))
|
||||
BInj k -> ([], (L.EInj k, typ))
|
||||
BLit lit -> ([], (L.ELit lit, typ))
|
||||
|
||||
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
|
||||
BApp e1 e2 -> (scs1 ++ scs2, (L.EApp e1' e2', typ))
|
||||
where
|
||||
(scs1, e1') = collectScsExp e1
|
||||
(scs2, e2') = collectScsExp e2
|
||||
|
||||
EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ))
|
||||
BAdd e1 e2 -> (scs1 ++ scs2, (L.EAdd e1' e2', typ))
|
||||
where
|
||||
(scs1, e1') = collectScsExp e1
|
||||
(scs2, e2') = collectScsExp e2
|
||||
|
||||
EAbs par e -> (scs, (EAbs par e', typ))
|
||||
where
|
||||
(scs, e') = collectScsExp e
|
||||
|
||||
ECase e branches -> (scs ++ scs_e, (ECase e' branches', typ))
|
||||
BCase e branches -> (scs ++ scs_e, (L.ECase e' branches', typ))
|
||||
where
|
||||
(scs, branches') = mapAccumL f [] branches
|
||||
(scs_e, e') = collectScsExp e
|
||||
|
|
@ -234,15 +254,24 @@ collectScsExp expT@(exp, typ) = case exp of
|
|||
--
|
||||
-- > f = let sc x y = rhs in e
|
||||
--
|
||||
ELet (Bind name parms rhs) e
|
||||
| null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et'))
|
||||
BLet (BBind name parms rhs) e
|
||||
| null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
|
||||
| otherwise -> (bind : rhs_scs ++ et_scs, et')
|
||||
where
|
||||
bind = Bind name parms rhs'
|
||||
bind = L.Bind name parms rhs'
|
||||
(rhs_scs, rhs') = collectScsExp rhs
|
||||
(et_scs, et') = collectScsExp e
|
||||
|
||||
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
|
||||
|
||||
BLet (BBindCxt cxt name parms rhs) e
|
||||
| null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
|
||||
| otherwise -> (bind : rhs_scs ++ et_scs, et')
|
||||
where
|
||||
bind = L.BindC cxt name parms rhs'
|
||||
(rhs_scs, rhs') = collectScsExp rhs
|
||||
(et_scs, et') = collectScsExp e
|
||||
|
||||
collectScsBranch (BBranch patt exp) = (scs, L.Branch (first toLirPattern patt) exp')
|
||||
where (scs, exp') = collectScsExp exp
|
||||
|
||||
nextNumber :: State Int Int
|
||||
|
|
@ -259,3 +288,19 @@ xs <| x | elem x xs = xs
|
|||
(<||) :: Eq a => [a] -> [a] -> [a]
|
||||
xs <|| ys = foldl (<|) xs ys
|
||||
|
||||
|
||||
|
||||
toLirData (Data t injs) = L.Data t (map toLirInj injs)
|
||||
toLirInj (Inj n t) = L.Inj n t
|
||||
|
||||
toLirPattern :: Pattern -> L.Pattern
|
||||
toLirPattern = \case
|
||||
PVar x -> L.PVar x
|
||||
PLit lit -> L.PLit lit
|
||||
PCatch -> L.PCatch
|
||||
PEnum k -> L.PEnum k
|
||||
PInj k ps -> L.PInj k (map (first toLirPattern) ps)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
140
src/LambdaLifterIr.hs
Normal file
140
src/LambdaLifterIr.hs
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
|
||||
module LambdaLifterIr (
|
||||
module Grammar.Abs,
|
||||
module LambdaLifterIr,
|
||||
module TypeChecker.TypeCheckerIr
|
||||
) where
|
||||
|
||||
import Data.List (intercalate)
|
||||
import Grammar.Abs (Lit (..))
|
||||
import Grammar.Print
|
||||
import Prelude hiding (exp)
|
||||
import qualified Prelude as C (Eq, Ord, Show)
|
||||
import TypeChecker.TypeCheckerIr (Ident (..), TVar (..), Type (..))
|
||||
|
||||
newtype Program = Program [Def]
|
||||
deriving (C.Eq, C.Ord, C.Show)
|
||||
|
||||
data Def
|
||||
= DBind Bind
|
||||
| DData Data
|
||||
deriving (C.Eq, C.Ord, C.Show)
|
||||
|
||||
data Data = Data Type [Inj]
|
||||
deriving (C.Eq, C.Ord, C.Show)
|
||||
|
||||
data Inj = Inj Ident Type
|
||||
deriving (C.Eq, C.Ord, C.Show)
|
||||
|
||||
data Pattern
|
||||
= PVar Ident
|
||||
| PLit Lit
|
||||
| PCatch
|
||||
| PEnum Ident
|
||||
| PInj Ident [(Pattern, Type)]
|
||||
deriving (C.Eq, C.Ord, C.Show)
|
||||
|
||||
data Exp
|
||||
= EVar Ident
|
||||
| EVarC [T Ident] Ident
|
||||
| EInj Ident
|
||||
| ELit Lit
|
||||
| ELet (T Ident) (T Exp) (T Exp)
|
||||
| EApp (T Exp)(T Exp)
|
||||
| EAdd (T Exp)(T Exp)
|
||||
| ECase (T Exp) [Branch]
|
||||
deriving (C.Eq, C.Ord, C.Show)
|
||||
|
||||
|
||||
type T a = (a, Type)
|
||||
|
||||
data Bind = Bind (T Ident) [T Ident] (T Exp)
|
||||
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
|
||||
deriving (C.Eq, C.Ord, C.Show)
|
||||
|
||||
data Branch = Branch (T Pattern) (T Exp)
|
||||
deriving (C.Eq, C.Ord, C.Show)
|
||||
|
||||
instance Print Program where
|
||||
prt i (Program sc) = prt i sc
|
||||
|
||||
instance Print Bind where
|
||||
prt i (Bind sig parms rhs) = concatD
|
||||
[ prt i sig
|
||||
, prt i parms
|
||||
, doc $ showString "="
|
||||
, prt i rhs
|
||||
]
|
||||
prt i (BindC cxt sig parms rhs) =
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
|
||||
, prt i parms
|
||||
, doc $ showString "="
|
||||
, prt i rhs
|
||||
]
|
||||
|
||||
instance Print [Bind] where
|
||||
prt _ [] = concatD []
|
||||
prt i [x] = concatD [prt i x]
|
||||
prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs]
|
||||
|
||||
instance Print Exp where
|
||||
prt i = \case
|
||||
EVar lident -> prPrec i 3 (concatD [prt 0 lident])
|
||||
EVarC as lident -> doc . showString
|
||||
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
|
||||
where
|
||||
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
|
||||
EInj uident -> prPrec i 3 (concatD [prt 0 uident])
|
||||
ELit lit -> prPrec i 3 (concatD [prt 0 lit])
|
||||
EApp exp1 exp2 -> prPrec i 2 (concatD [prt 2 exp1, prt 3 exp2])
|
||||
EAdd exp1 exp2 -> prPrec i 1 (concatD [prt 1 exp1, doc (showString "+"), prt 2 exp2])
|
||||
ELet lident exp1 exp2 -> prPrec i 0 (concatD [doc (showString "let"), prt 0 lident, doc (showString "="), prt 0 exp1 , doc (showString "in"), prt 0 exp2])
|
||||
ECase exp branchs -> prPrec i 0 (concatD [doc (showString "case"), prt 0 exp, doc (showString "of"), doc (showString "{"), prt 0 branchs, doc (showString "}")])
|
||||
|
||||
|
||||
instance Print Branch where
|
||||
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
|
||||
|
||||
instance Print [Branch] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
instance Print Def where
|
||||
prt i = \case
|
||||
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
|
||||
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
|
||||
|
||||
instance Print Data where
|
||||
prt i = \case
|
||||
Data type_ injs -> prPrec i 0 (concatD [doc (showString "data"), prt 0 type_, doc (showString "where"), doc (showString "{"), prt 0 injs, doc (showString "}")])
|
||||
|
||||
instance Print Inj where
|
||||
prt i = \case
|
||||
Inj uident type_ -> prt i (uident, type_)
|
||||
|
||||
instance Print [Inj] where
|
||||
prt _ [] = concatD []
|
||||
prt i [x] = prt i x
|
||||
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
|
||||
|
||||
instance Print Pattern where
|
||||
prt i = \case
|
||||
PVar name -> prPrec i 1 (concatD [prt 0 name])
|
||||
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
|
||||
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
|
||||
PEnum name -> prPrec i 1 (concatD [prt 0 name])
|
||||
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
|
||||
|
||||
instance Print [Def] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
pattern DBind' id vars expt = DBind (Bind id vars expt)
|
||||
pattern DData' typ injs = DData (Data typ injs)
|
||||
|
||||
|
|
@ -1,8 +1,11 @@
|
|||
|
||||
module Monomorphizer.DataTypeRemover (removeDataTypes) where
|
||||
|
||||
import Monomorphizer.MonomorphizerIr qualified as M2
|
||||
import Monomorphizer.MorbIr qualified as M1
|
||||
import TypeChecker.TypeCheckerIr (Ident (Ident))
|
||||
import Data.Bifunctor (Bifunctor (bimap))
|
||||
import Monomorphizer.MonomorphizerIr (Ident (..))
|
||||
import qualified Monomorphizer.MonomorphizerIr as M2
|
||||
import qualified Monomorphizer.MorbIr as M1
|
||||
import Prelude hiding (exp)
|
||||
|
||||
removeDataTypes :: M1.Program -> M2.Program
|
||||
removeDataTypes (M1.Program defs) = M2.Program (map pDef defs)
|
||||
|
|
@ -18,43 +21,43 @@ pCons :: M1.Inj -> M2.Inj
|
|||
pCons (M1.Inj ident t) = M2.Inj ident (pType t)
|
||||
|
||||
pType :: M1.Type -> M2.Type
|
||||
pType (M1.TLit ident) = M2.TLit ident
|
||||
pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2)
|
||||
pType (M1.TLit ident) = M2.TLit ident
|
||||
pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2)
|
||||
pType (M1.TData (Ident "Bool") _) = M2.TLit (Ident "Bool")
|
||||
pType d = M2.TLit (Ident (newName d)) -- This is the step
|
||||
pType d = M2.TLit (Ident (newName d)) -- This is the step
|
||||
|
||||
newName :: M1.Type -> String
|
||||
newName (M1.TLit (Ident str)) = str
|
||||
newName (M1.TFun t1 t2) = newName t1 ++ newName t2
|
||||
newName (M1.TLit (Ident str)) = str
|
||||
newName (M1.TFun t1 t2) = newName t1 ++ newName t2
|
||||
newName (M1.TData (Ident str) args) = str ++ concatMap newName args
|
||||
|
||||
pBind :: M1.Bind -> M2.Bind
|
||||
pBind (M1.Bind id argIds expt) = M2.Bind (pId id) (map pId argIds) (pExpT expt)
|
||||
pBind (M1.BindC cxt id argIds expt) =
|
||||
M2.BindC (map pId cxt) (pId id) (map pId argIds) (pExpT expt)
|
||||
|
||||
pId :: (Ident, M1.Type) -> (Ident, M2.Type)
|
||||
pId (ident, t) = (ident, pType t)
|
||||
|
||||
pExpT :: M1.ExpT -> M2.ExpT
|
||||
pExpT :: M1.T M1.Exp -> M2.T M2.Exp
|
||||
pExpT (exp, t) = (pExp exp, pType t)
|
||||
|
||||
pExp :: M1.Exp -> M2.Exp
|
||||
pExp (M1.EVar ident) = M2.EVar ident
|
||||
pExp (M1.ELit lit) = M2.ELit (pLit lit)
|
||||
pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt)
|
||||
pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2)
|
||||
pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2)
|
||||
pExp (M1.EVar ident) = M2.EVar ident
|
||||
pExp (M1.EVarC as ident) = M2.EVarC (map pId as) ident
|
||||
pExp (M1.ELit lit) = M2.ELit lit
|
||||
pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt)
|
||||
pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2)
|
||||
pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2)
|
||||
pExp (M1.ECase expT branches) = M2.ECase (pExpT expT) (map pBranch branches)
|
||||
|
||||
pBranch :: M1.Branch -> M2.Branch
|
||||
pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt)
|
||||
|
||||
pPattern :: M1.Pattern -> M2.Pattern
|
||||
pPattern (M1.PVar id) = M2.PVar (pId id)
|
||||
pPattern (M1.PLit (lit, t)) = M2.PLit (pLit lit, pType t)
|
||||
pPattern (M1.PInj ident patts) = M2.PInj ident (map pPattern patts)
|
||||
pPattern M1.PCatch = M2.PCatch
|
||||
pPattern (M1.PEnum ident) = M2.PEnum ident
|
||||
pPattern (M1.PVar ident) = M2.PVar ident
|
||||
pPattern (M1.PLit lit) = M2.PLit lit
|
||||
pPattern (M1.PInj ident patts) = M2.PInj ident (map (bimap pPattern pType) patts)
|
||||
pPattern M1.PCatch = M2.PCatch
|
||||
pPattern (M1.PEnum ident) = M2.PEnum ident
|
||||
|
||||
pLit :: M1.Lit -> M2.Lit
|
||||
pLit (M1.LInt v) = M2.LInt v
|
||||
pLit (M1.LChar c) = M2.LChar c
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
|
||||
{- | For now, converts polymorphic functions to concrete ones based on usage.
|
||||
Assumes lambdas are lifted.
|
||||
|
|
@ -25,30 +26,35 @@ bind) is added to the resulting set of binds.
|
|||
-}
|
||||
module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where
|
||||
|
||||
import Monomorphizer.DataTypeRemover (removeDataTypes)
|
||||
import Monomorphizer.MonomorphizerIr qualified as O
|
||||
import Monomorphizer.MorbIr qualified as M
|
||||
import TypeChecker.TypeCheckerIr (Ident (Ident))
|
||||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
|
||||
import Control.Monad.Reader (
|
||||
MonadReader (ask, local),
|
||||
Reader,
|
||||
asks,
|
||||
runReader,
|
||||
)
|
||||
import Control.Monad.State (
|
||||
MonadState (get),
|
||||
StateT (runStateT),
|
||||
gets,
|
||||
modify,
|
||||
)
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Map qualified as Map
|
||||
import Data.Maybe (catMaybes)
|
||||
import Data.Set qualified as Set
|
||||
import Grammar.Print (printTree)
|
||||
import Debug.Trace (trace)
|
||||
import Control.Monad.Reader (MonadReader (ask, local),
|
||||
Reader, asks, runReader)
|
||||
import Control.Monad.State (MonadState (get),
|
||||
StateT (runStateT), gets,
|
||||
modify)
|
||||
import Data.Coerce (coerce)
|
||||
import qualified Data.Map as Map
|
||||
import Data.Maybe (catMaybes)
|
||||
import qualified Data.Set as Set
|
||||
import Debug.Trace (trace)
|
||||
import Grammar.Print (printTree)
|
||||
import Monomorphizer.DataTypeRemover (removeDataTypes)
|
||||
import qualified Monomorphizer.MonomorphizerIr as O
|
||||
import qualified Monomorphizer.MorbIr as M
|
||||
-- import TypeChecker.TypeCheckerIr (Ident (Ident))
|
||||
import LambdaLifterIr (Ident (..))
|
||||
-- import TypeChecker.TypeCheckerIr qualified as T
|
||||
import qualified LambdaLifterIr as L
|
||||
|
||||
import Control.Monad.Reader (MonadReader (ask, local),
|
||||
Reader, asks, runReader)
|
||||
import Control.Monad.State (MonadState, StateT (runStateT),
|
||||
gets, modify)
|
||||
import qualified Data.Map as Map
|
||||
import Data.Maybe (catMaybes, fromJust)
|
||||
import qualified Data.Set as Set
|
||||
import Data.Tuple.Extra (secondM)
|
||||
import Grammar.Print (printTree)
|
||||
|
||||
{- | EnvM is the monad containing the read-only state as well as the
|
||||
output state containing monomorphized functions and to-be monomorphized
|
||||
|
|
@ -64,18 +70,18 @@ Binds, Polymorphic Data types (monomorphized in a later step) and
|
|||
Marked bind, which means that it is in the process of monomorphization
|
||||
and should not be monomorphized again.
|
||||
-}
|
||||
data Outputted = Marked | Complete M.Bind | Data M.Type T.Data deriving (Show)
|
||||
data Outputted = Marked | Complete M.Bind | Data M.Type L.Data deriving (Show)
|
||||
|
||||
-- | Static environment.
|
||||
data Env = Env
|
||||
{ input :: Map.Map Ident T.Bind
|
||||
{ input :: Map.Map Ident L.Bind
|
||||
-- ^ All binds in the program.
|
||||
, dataDefs :: Map.Map Ident T.Data
|
||||
, dataDefs :: Map.Map Ident L.Data
|
||||
-- ^ All constructors mapped to their respective polymorphic data def
|
||||
-- which includes all other constructors.
|
||||
, polys :: Map.Map Ident M.Type
|
||||
, polys :: Map.Map Ident M.Type
|
||||
-- ^ Maps polymorphic identifiers with concrete types.
|
||||
, locals :: Set.Set Ident
|
||||
, locals :: Set.Set Ident
|
||||
-- ^ Local variables.
|
||||
}
|
||||
|
||||
|
|
@ -84,12 +90,13 @@ localExists :: Ident -> EnvM Bool
|
|||
localExists ident = asks (Set.member ident . locals)
|
||||
|
||||
-- | Gets a polymorphic bind from an id.
|
||||
getInputBind :: Ident -> EnvM (Maybe T.Bind)
|
||||
getInputBind :: Ident -> EnvM (Maybe L.Bind)
|
||||
getInputBind ident = asks (Map.lookup ident . input)
|
||||
|
||||
-- | Add monomorphic function derived from a polymorphic one, to env.
|
||||
addOutputBind :: M.Bind -> EnvM ()
|
||||
addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b))
|
||||
addOutputBind b@(M.BindC _ (ident, _) _ _) = modify (Map.insert ident (Complete b))
|
||||
|
||||
{- | Marks a global bind as being processed, meaning that when encountered again,
|
||||
it should not be recursively processed.
|
||||
|
|
@ -106,8 +113,8 @@ isConsMarked :: Ident -> EnvM Bool
|
|||
isConsMarked ident = gets (Map.member ident)
|
||||
|
||||
-- | Finds main bind.
|
||||
getMain :: EnvM T.Bind
|
||||
getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of
|
||||
getMain :: EnvM L.Bind
|
||||
getMain = asks (\env -> case Map.lookup (Ident "main") (input env) of
|
||||
Just mainBind -> mainBind
|
||||
Nothing -> error "main not found in monomorphizer!"
|
||||
)
|
||||
|
|
@ -116,13 +123,13 @@ getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of
|
|||
error when encountering different structures between the two arguments. Debug:
|
||||
First argument is the name of the bind.
|
||||
-}
|
||||
mapTypes :: Ident -> T.Type -> M.Type -> [(Ident, M.Type)]
|
||||
mapTypes _ident (T.TLit _) (M.TLit _) = []
|
||||
mapTypes _ident (T.TVar (T.MkTVar i1)) tm = [(i1, tm)]
|
||||
mapTypes ident (T.TFun pt1 pt2) (M.TFun mt1 mt2) =
|
||||
mapTypes :: Ident -> L.Type -> M.Type -> [(Ident, M.Type)]
|
||||
mapTypes _ident (L.TLit _) (M.TLit _) = []
|
||||
mapTypes _ident (L.TVar (L.MkTVar i1)) tm = [(i1, tm)]
|
||||
mapTypes ident (L.TFun pt1 pt2) (M.TFun mt1 mt2) =
|
||||
mapTypes ident pt1 mt1
|
||||
++ mapTypes ident pt2 mt2
|
||||
mapTypes ident (T.TData tIdent pTs) (M.TData mIdent mTs) =
|
||||
mapTypes ident (L.TData tIdent pTs) (M.TData mIdent mTs) =
|
||||
if tIdent /= mIdent
|
||||
then error "the data type names of monomorphic and polymorphic data types does not match"
|
||||
else foldl (\xs (p, m) -> mapTypes ident p m ++ xs) [] (zip pTs mTs)
|
||||
|
|
@ -130,30 +137,30 @@ mapTypes ident t1 t2 = error $ "in bind: '" ++ printTree ident ++ "', " ++
|
|||
"structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'"
|
||||
|
||||
-- | Gets the mapped monomorphic type of a polymorphic type in the current context.
|
||||
getMonoFromPoly :: T.Type -> EnvM M.Type
|
||||
getMonoFromPoly :: L.Type -> EnvM M.Type
|
||||
getMonoFromPoly t = do
|
||||
env <- ask
|
||||
return $ getMono (polys env) t
|
||||
where
|
||||
getMono :: Map.Map Ident M.Type -> T.Type -> M.Type
|
||||
getMono :: Map.Map Ident M.Type -> L.Type -> M.Type
|
||||
getMono polys t = case t of
|
||||
(T.TLit ident) -> M.TLit (coerce ident)
|
||||
(T.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2)
|
||||
(T.TVar (T.MkTVar ident)) -> case Map.lookup ident polys of
|
||||
(L.TLit ident) -> M.TLit ident
|
||||
(L.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2)
|
||||
(L.TVar (L.MkTVar ident)) -> case Map.lookup ident polys of
|
||||
Just concrete -> concrete
|
||||
Nothing -> M.TLit (Ident "void")
|
||||
Nothing -> M.TLit (Ident "void")
|
||||
-- error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps"
|
||||
(T.TData ident args) -> M.TData ident (map (getMono polys) args)
|
||||
(L.TData ident args) -> M.TData ident (map (getMono polys) args)
|
||||
|
||||
{- | If ident not already in env's output, morphed bind to output
|
||||
(and all referenced binds within this bind).
|
||||
Returns the annotated bind name.
|
||||
-}
|
||||
morphBind :: M.Type -> T.Bind -> EnvM Ident
|
||||
morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do
|
||||
morphBind :: M.Type -> L.Bind -> EnvM Ident
|
||||
morphBind expectedType b@(L.Bind (ident, btype) args (exp, expt)) = do
|
||||
-- The "new name" is used to find out if it is already marked or not.
|
||||
let name' = newFuncName expectedType b
|
||||
bindMarked <- isBindMarked (coerce name')
|
||||
bindMarked <- isBindMarked name'
|
||||
local
|
||||
( \env ->
|
||||
env
|
||||
|
|
@ -168,26 +175,59 @@ morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do
|
|||
else do
|
||||
-- Mark so that this bind will not be processed in recursive or cyclic
|
||||
-- function calls
|
||||
markBind (coerce name')
|
||||
markBind name'
|
||||
expt' <- getMonoFromPoly expt
|
||||
exp' <- morphExp expt' exp
|
||||
-- Get monomorphic type sof args
|
||||
args' <- mapM morphArg args
|
||||
addOutputBind $
|
||||
M.Bind
|
||||
(coerce name', expectedType)
|
||||
(name', expectedType)
|
||||
args'
|
||||
(exp', expt')
|
||||
return name'
|
||||
|
||||
morphBind expectedType b@(L.BindC cxt (ident, btype) args (exp, expt)) = do
|
||||
-- The "new name" is used to find out if it is already marked or not.
|
||||
let name' = newFuncName expectedType b
|
||||
bindMarked <- isBindMarked name'
|
||||
local
|
||||
( \env ->
|
||||
env
|
||||
{ locals = Set.fromList (map fst args)
|
||||
, polys = Map.fromList (mapTypes ident btype expectedType)
|
||||
}
|
||||
)
|
||||
$ do
|
||||
-- Return with right name if already marked
|
||||
if bindMarked
|
||||
then return name'
|
||||
else do
|
||||
-- Mark so that this bind will not be processed in recursive or cyclic
|
||||
-- function calls
|
||||
markBind name'
|
||||
-- Get monomorphic type sof args
|
||||
args' <- mapM morphArg args
|
||||
cxt' <- mapM (secondM getMonoFromPoly) cxt
|
||||
expt' <- getMonoFromPoly expt
|
||||
exp' <- local (\env -> foldr (addLocal . fst) env cxt)
|
||||
(morphExp expt' exp)
|
||||
addOutputBind $
|
||||
M.BindC cxt'
|
||||
(name', expectedType)
|
||||
args'
|
||||
(exp', expt')
|
||||
return name'
|
||||
|
||||
|
||||
-- | Monomorphizes arguments of a bind.
|
||||
morphArg :: (Ident, T.Type) -> EnvM (Ident, M.Type)
|
||||
morphArg :: (Ident, L.Type) -> EnvM (Ident, M.Type)
|
||||
morphArg (ident, t) = do
|
||||
t' <- getMonoFromPoly t
|
||||
return (ident, t')
|
||||
|
||||
-- | Gets the data bind from the name of a constructor.
|
||||
getInputData :: Ident -> EnvM (Maybe T.Data)
|
||||
getInputData :: Ident -> EnvM (Maybe L.Data)
|
||||
getInputData ident = do
|
||||
env <- ask
|
||||
return $ Map.lookup ident (dataDefs env)
|
||||
|
|
@ -201,50 +241,50 @@ morphCons expectedType ident newIdent = do
|
|||
--trace ("Tjofras:" ++ show (newName expectedType ident)) $ return ()
|
||||
maybeD <- getInputData ident
|
||||
case maybeD of
|
||||
Nothing -> error $ "identifier '" ++ show ident ++ "' not found"
|
||||
-- closures can have unbound variables
|
||||
Nothing -> pure ()
|
||||
Just d -> do
|
||||
modify (\output -> Map.insert newIdent (Data expectedType d) output)
|
||||
|
||||
-- | Converts literals from input to output tree.
|
||||
convertLit :: T.Lit -> M.Lit
|
||||
convertLit (T.LInt v) = M.LInt v
|
||||
convertLit (T.LChar v) = M.LChar v
|
||||
convertLit :: L.Lit -> M.Lit
|
||||
convertLit (L.LInt v) = M.LInt v
|
||||
convertLit (L.LChar v) = M.LChar v
|
||||
|
||||
|
||||
-- | Monomorphizes an expression, given an expected type.
|
||||
morphExp :: M.Type -> T.Exp -> EnvM M.Exp
|
||||
morphExp :: M.Type -> L.Exp -> EnvM M.Exp
|
||||
morphExp expectedType exp = case exp of
|
||||
T.ELit lit -> return $ M.ELit (convertLit lit)
|
||||
L.ELit lit -> return $ M.ELit lit
|
||||
-- Constructor
|
||||
T.EInj ident -> do
|
||||
L.EInj ident -> do
|
||||
let ident' = newName (getDataType expectedType) ident
|
||||
morphCons expectedType ident ident'
|
||||
return $ M.EVar ident'
|
||||
T.EApp (e1, _t1) (e2, t2) -> do
|
||||
L.EApp (e1, _t1) (e2, t2) -> do
|
||||
t2' <- getMonoFromPoly t2
|
||||
e2' <- morphExp t2' e2
|
||||
e1' <- morphExp (M.TFun t2' expectedType) e1
|
||||
return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2')
|
||||
T.EAdd (e1, t1) (e2, t2) -> do
|
||||
L.EAdd (e1, t1) (e2, t2) -> do
|
||||
t1' <- getMonoFromPoly t1
|
||||
t2' <- getMonoFromPoly t2
|
||||
e1' <- morphExp t1' e1
|
||||
e2' <- morphExp t2' e2
|
||||
return $ M.EAdd (e1', expectedType) (e2', expectedType)
|
||||
T.EAbs ident (exp, t) -> local (\env -> env{locals = Set.insert ident (locals env)}) $ do
|
||||
t' <- getMonoFromPoly t
|
||||
morphExp t' exp
|
||||
T.ECase (exp, t) bs -> do
|
||||
L.ECase (exp, t) bs -> do
|
||||
t' <- getMonoFromPoly t
|
||||
exp' <- morphExp t' exp
|
||||
bs' <- mapM morphBranch bs
|
||||
return $ M.ECase (exp', t') (catMaybes bs')
|
||||
-- Ideally constructors should be EInj, though this code handles them
|
||||
-- as well.
|
||||
T.EVar ident -> do
|
||||
-- FIXME MAKE EVAR AND EINJ SEPARATE!!!
|
||||
L.EVar ident -> do
|
||||
isLocal <- localExists ident
|
||||
if isLocal
|
||||
then do
|
||||
return $ M.EVar (coerce ident)
|
||||
return $ M.EVar ident
|
||||
else do
|
||||
bind <- getInputBind ident
|
||||
case bind of
|
||||
|
|
@ -252,38 +292,51 @@ morphExp expectedType exp = case exp of
|
|||
Just bind' -> do
|
||||
-- New bind to process
|
||||
newBindName <- morphBind expectedType bind'
|
||||
return $ M.EVar (coerce newBindName)
|
||||
T.ELet (T.Bind (identB, tB) args (expB, tExpB)) (exp, tExp) ->
|
||||
if length args > 0 then error "only constants in lets allowed"
|
||||
else do
|
||||
return $ M.EVar newBindName
|
||||
L.EVarC as ident -> do
|
||||
isLocal <- localExists ident
|
||||
if isLocal
|
||||
then do
|
||||
return $ M.EVar ident
|
||||
else do
|
||||
bind <- fromJust <$> getInputBind ident
|
||||
as' <- mapM (secondM getMonoFromPoly) as
|
||||
-- New bind to process
|
||||
newBindName <- morphBind expectedType bind
|
||||
return $ M.EVarC as' newBindName
|
||||
-- Ideally constructors should be EInj, though this code handles them
|
||||
-- as well.
|
||||
|
||||
|
||||
L.ELet (identB, tB) (expB, tExpB) (exp, tExp) -> do
|
||||
tB' <- getMonoFromPoly tB
|
||||
tExpB' <- getMonoFromPoly tExpB
|
||||
tExp' <- getMonoFromPoly tExp
|
||||
expB' <- morphExp tExpB' expB
|
||||
exp' <- morphExp tExp' exp
|
||||
exp' <- local (addLocal identB) (morphExp tExp' exp)
|
||||
return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (exp', tExp')
|
||||
|
||||
-- | Monomorphizes case-of branches.
|
||||
morphBranch :: T.Branch -> EnvM (Maybe M.Branch)
|
||||
morphBranch (T.Branch (p, pt) (e, et)) = do
|
||||
morphBranch :: L.Branch -> EnvM (Maybe M.Branch)
|
||||
morphBranch (L.Branch (p, pt) (e, et)) = do
|
||||
pt' <- getMonoFromPoly pt
|
||||
et' <- getMonoFromPoly et
|
||||
env <- ask
|
||||
maybeMorphedPattern <- morphPattern p pt'
|
||||
case maybeMorphedPattern of
|
||||
Nothing -> return Nothing
|
||||
Just (p', newLocals) ->
|
||||
Just (p', newLocals) ->
|
||||
local (const env { locals = Set.union (locals env) newLocals }) $ do
|
||||
e' <- morphExp et' e
|
||||
return $ Just (M.Branch (p', pt') (e', et'))
|
||||
return $ Just (M.Branch p' (e', et'))
|
||||
|
||||
morphPattern :: T.Pattern -> M.Type -> EnvM (Maybe (M.Pattern, Set.Set Ident))
|
||||
morphPattern :: L.Pattern -> M.Type -> EnvM (Maybe (M.T M.Pattern, Set.Set Ident))
|
||||
morphPattern p expectedType = case p of
|
||||
T.PVar ident -> return $ Just (M.PVar (ident, expectedType), Set.singleton ident)
|
||||
T.PLit lit -> return $ Just (M.PLit (convertLit lit, expectedType), Set.empty)
|
||||
T.PCatch -> return $ Just (M.PCatch, Set.empty)
|
||||
T.PEnum ident -> return $ Just (M.PEnum (newName expectedType ident), Set.empty)
|
||||
T.PInj ident pts -> do let newIdent = newName expectedType ident
|
||||
L.PVar ident -> return $ Just ((M.PVar ident, expectedType), Set.singleton ident)
|
||||
L.PLit lit -> return $ Just ((M.PLit (convertLit lit), expectedType), Set.empty)
|
||||
L.PCatch -> return $ Just ((M.PCatch, expectedType), Set.empty)
|
||||
L.PEnum ident -> return $ Just ((M.PEnum (newName expectedType ident), expectedType), Set.empty)
|
||||
L.PInj ident pts -> do let newIdent = newName expectedType ident
|
||||
outEnv <- get
|
||||
trace ("WOW: " ++ show (newName expectedType ident)) $ return ()
|
||||
trace ("WOW2: " ++ show (outEnv)) $ return ()
|
||||
|
|
@ -297,13 +350,18 @@ morphPattern p expectedType = case p of
|
|||
let maybePsSets = sequence psSets
|
||||
case maybePsSets of
|
||||
Nothing -> return Nothing
|
||||
Just psSets' -> return $ Just
|
||||
(M.PInj newIdent (map fst psSets'), Set.unions $ map snd psSets')
|
||||
Just psSets' -> return $ Just
|
||||
((M.PInj newIdent (map fst psSets'), expectedType), Set.unions $ map snd psSets')
|
||||
else return Nothing
|
||||
|
||||
-- | Creates a new identifier for a function with an assigned type.
|
||||
newFuncName :: M.Type -> T.Bind -> Ident
|
||||
newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) =
|
||||
newFuncName :: M.Type -> L.Bind -> Ident
|
||||
newFuncName t (L.Bind (ident@(Ident bindName), _) _ _) =
|
||||
if bindName == "main"
|
||||
then Ident bindName
|
||||
else newName t ident
|
||||
|
||||
newFuncName t (L.BindC _ (ident@(Ident bindName), _) _ _) =
|
||||
if bindName == "main"
|
||||
then Ident bindName
|
||||
else newName t ident
|
||||
|
|
@ -317,8 +375,8 @@ newName t (Ident str) = Ident $ str ++ "$" ++ newName' t
|
|||
newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts
|
||||
|
||||
-- | Monomorphization step.
|
||||
monomorphize :: T.Program -> O.Program
|
||||
monomorphize (T.Program defs) =
|
||||
monomorphize :: L.Program -> O.Program
|
||||
monomorphize (L.Program defs) =
|
||||
removeDataTypes $
|
||||
M.Program
|
||||
( getDefsFromOutput
|
||||
|
|
@ -336,7 +394,7 @@ runEnvM :: Output -> Env -> EnvM () -> Output
|
|||
runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env
|
||||
|
||||
-- | Creates the environment based on the input binds.
|
||||
createEnv :: [T.Def] -> Env
|
||||
createEnv :: [L.Def] -> Env
|
||||
createEnv defs =
|
||||
Env
|
||||
{ input = Map.fromList bindPairs
|
||||
|
|
@ -346,33 +404,34 @@ createEnv defs =
|
|||
}
|
||||
where
|
||||
bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs
|
||||
dataPairs :: [(Ident, T.Data)]
|
||||
dataPairs = (foldl (\acc d@(T.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs
|
||||
dataPairs :: [(Ident, L.Data)]
|
||||
dataPairs = (foldl (\acc d@(L.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs
|
||||
|
||||
-- | Gets a top-lefel function name.
|
||||
getBindName :: T.Bind -> Ident
|
||||
getBindName (T.Bind (ident, _) _ _) = ident
|
||||
getBindName :: L.Bind -> Ident
|
||||
getBindName (L.Bind (ident, _) _ _) = ident
|
||||
getBindName (L.BindC _ (ident, _) _ _) = ident
|
||||
|
||||
-- Helper functions
|
||||
-- Gets custom data declarations form defs.
|
||||
getDataFromDefs :: [T.Def] -> [T.Data]
|
||||
getDataFromDefs :: [L.Def] -> [L.Data]
|
||||
getDataFromDefs =
|
||||
foldl
|
||||
( \bs -> \case
|
||||
T.DBind _ -> bs
|
||||
T.DData d -> d : bs
|
||||
L.DBind _ -> bs
|
||||
L.DData d -> d : bs
|
||||
)
|
||||
[]
|
||||
|
||||
getConsName :: T.Inj -> Ident
|
||||
getConsName (T.Inj ident _) = ident
|
||||
getConsName :: L.Inj -> Ident
|
||||
getConsName (L.Inj ident _) = ident
|
||||
|
||||
getBindsFromDefs :: [T.Def] -> [T.Bind]
|
||||
getBindsFromDefs :: [L.Def] -> [L.Bind]
|
||||
getBindsFromDefs =
|
||||
foldl
|
||||
( \bs -> \case
|
||||
T.DBind b -> b : bs
|
||||
T.DData _ -> bs
|
||||
L.DBind b -> b : bs
|
||||
L.DData _ -> bs
|
||||
)
|
||||
[]
|
||||
|
||||
|
|
@ -384,19 +443,19 @@ getDefsFromOutput o =
|
|||
(binds, dataInput) = splitBindsAndData o
|
||||
|
||||
-- | Splits the output into binds and data declaration components (used in createNewData)
|
||||
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, T.Data)])
|
||||
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, L.Data)])
|
||||
splitBindsAndData output =
|
||||
foldl
|
||||
( \(oBinds, oData) (ident, o) -> case o of
|
||||
Marked -> error "internal bug in monomorphizer"
|
||||
Marked -> error "internal bug in monomorphizer"
|
||||
Complete b -> (b : oBinds, oData)
|
||||
Data t d -> (oBinds, (ident, t, d) : oData)
|
||||
Data t d -> (oBinds, (ident, t, d) : oData)
|
||||
)
|
||||
([], [])
|
||||
(Map.toList output)
|
||||
|
||||
-- | Converts all found constructors to monomorphic data declarations.
|
||||
createNewData :: [(Ident, M.Type, T.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
|
||||
createNewData :: [(Ident, M.Type, L.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
|
||||
createNewData [] o = o
|
||||
createNewData ((consIdent, consType, polyData) : input) o =
|
||||
createNewData input $
|
||||
|
|
@ -406,14 +465,17 @@ createNewData ((consIdent, consType, polyData) : input) o =
|
|||
(M.Data newDataType [newCons])
|
||||
o
|
||||
where
|
||||
T.Data (T.TData polyDataIdent _) _ = polyData
|
||||
L.Data (L.TData polyDataIdent _) _ = polyData
|
||||
newDataType = getDataType consType
|
||||
newDataName = newName newDataType polyDataIdent
|
||||
newCons = M.Inj consIdent consType
|
||||
|
||||
-- | Gets the Data Type of a constructor type (a -> Just a becomes Just a).
|
||||
getDataType :: M.Type -> M.Type
|
||||
getDataType (M.TFun _t1 t2) = getDataType t2
|
||||
getDataType (M.TFun _t1 t2) = getDataType t2
|
||||
getDataType tData@(M.TData _ _) = tData
|
||||
getDataType _ = error "???"
|
||||
getDataType _ = error "???"
|
||||
|
||||
|
||||
addLocal :: Ident -> Env -> Env
|
||||
addLocal x env = env { locals = Set.insert x env.locals }
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
|
||||
module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr) where
|
||||
module Monomorphizer.MonomorphizerIr (
|
||||
module Monomorphizer.MonomorphizerIr,
|
||||
module LambdaLifterIr
|
||||
) where
|
||||
|
||||
import Grammar.Print
|
||||
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..))
|
||||
|
||||
type Id = (TIR.Ident, Type)
|
||||
import Data.List (intercalate)
|
||||
import Grammar.Print
|
||||
import LambdaLifterIr (Ident (..), Lit (..))
|
||||
import Prelude hiding (exp)
|
||||
|
||||
newtype Program = Program [Def]
|
||||
deriving (Show, Ord, Eq)
|
||||
|
|
@ -16,90 +19,80 @@ data Def = DBind Bind | DData Data
|
|||
data Data = Data Type [Inj]
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Bind = Bind Id [Id] ExpT
|
||||
data Bind = Bind (T Ident) [T Ident] (T Exp)
|
||||
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
type T a = (a, Type)
|
||||
|
||||
data Exp
|
||||
= EVar TIR.Ident
|
||||
= EVar Ident
|
||||
| EVarC [T Ident] Ident
|
||||
| ELit Lit
|
||||
| ELet Bind ExpT
|
||||
| EApp ExpT ExpT
|
||||
| EAdd ExpT ExpT
|
||||
| ECase ExpT [Branch]
|
||||
| ELet Bind (T Exp)
|
||||
| EApp (T Exp) (T Exp)
|
||||
| EAdd (T Exp) (T Exp)
|
||||
| ECase (T Exp) [Branch]
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Pattern
|
||||
= PVar Id
|
||||
| PLit (Lit, Type)
|
||||
| PInj TIR.Ident [Pattern]
|
||||
= PVar Ident
|
||||
| PLit Lit
|
||||
| PInj Ident [T Pattern]
|
||||
| PCatch
|
||||
| PEnum TIR.Ident
|
||||
| PEnum Ident
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
data Branch = Branch (Pattern, Type) ExpT
|
||||
data Branch = Branch (T Pattern) (T Exp)
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
type ExpT = (Exp, Type)
|
||||
|
||||
data Inj = Inj TIR.Ident Type
|
||||
data Inj = Inj Ident Type
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Lit
|
||||
= LInt Integer
|
||||
| LChar Char
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Type = TLit TIR.Ident | TFun Type Type
|
||||
data Type = TLit Ident | TFun Type Type
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
flattenType :: Type -> [Type]
|
||||
flattenType (TFun t1 t2) = t1 : flattenType t2
|
||||
flattenType x = [x]
|
||||
flattenType x = [x]
|
||||
|
||||
instance Print Program where
|
||||
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
||||
|
||||
instance Print (Bind) where
|
||||
instance Print Bind where
|
||||
prt i (Bind sig@(name, _) parms rhs) =
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ prtSig sig
|
||||
[ prt 0 sig
|
||||
, prt 0 name
|
||||
, prtIdPs 0 parms
|
||||
, prt 0 parms
|
||||
, doc $ showString "="
|
||||
, prt 0 rhs
|
||||
]
|
||||
|
||||
prtSig :: Id -> Doc
|
||||
prtSig (name, t) =
|
||||
concatD
|
||||
[ prt 0 name
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
, doc $ showString ";"
|
||||
]
|
||||
prt i (BindC cxt sig parms rhs) =
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
|
||||
, prt i parms
|
||||
, doc $ showString "="
|
||||
, prt i rhs
|
||||
]
|
||||
|
||||
instance Print (ExpT) where
|
||||
prt i (e, t) =
|
||||
concatD
|
||||
[ doc $ showString "("
|
||||
, prt i e
|
||||
, doc $ showString ","
|
||||
, prt i t
|
||||
, doc $ showString ")"
|
||||
]
|
||||
|
||||
instance Print [Bind] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
prtIdPs :: Int -> [Id] -> Doc
|
||||
prtIdPs i = prPrec i 0 . concatD . map (prt i)
|
||||
|
||||
instance Print Exp where
|
||||
prt i = \case
|
||||
EVar name -> prPrec i 3 $ prt 0 name
|
||||
EVarC as lident -> doc . showString
|
||||
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
|
||||
where
|
||||
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
|
||||
ELit lit -> prPrec i 3 $ prt 0 lit
|
||||
ELet b e ->
|
||||
prPrec i 3 $
|
||||
|
|
@ -134,16 +127,16 @@ instance Print Exp where
|
|||
]
|
||||
|
||||
instance Print Branch where
|
||||
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
|
||||
prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp])
|
||||
|
||||
instance Print [Branch] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
instance Print Def where
|
||||
prt i = \case
|
||||
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
|
||||
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
|
||||
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
|
||||
|
||||
instance Print Data where
|
||||
|
|
@ -152,23 +145,23 @@ instance Print Data where
|
|||
|
||||
instance Print Inj where
|
||||
prt i = \case
|
||||
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
|
||||
Inj uident type_ -> prt i (uident, type_)
|
||||
|
||||
instance Print Pattern where
|
||||
prt i = \case
|
||||
PVar name -> prPrec i 1 (concatD [prt 0 name])
|
||||
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit])
|
||||
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
|
||||
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
|
||||
PEnum name -> prPrec i 1 (concatD [prt 0 name])
|
||||
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
|
||||
|
||||
instance Print [Def] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
instance Print [Type] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [] = concatD []
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
|
||||
|
||||
instance Print Type where
|
||||
|
|
@ -176,7 +169,3 @@ instance Print Type where
|
|||
TLit uident -> prPrec i 1 (concatD [prt 0 uident])
|
||||
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
|
||||
|
||||
instance Print Lit where
|
||||
prt i = \case
|
||||
LInt int -> prt i int
|
||||
LChar char -> prt i char
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
module Monomorphizer.MorbIr where
|
||||
|
||||
import Grammar.Print
|
||||
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..))
|
||||
module Monomorphizer.MorbIr (
|
||||
module Monomorphizer.MorbIr,
|
||||
module LambdaLifterIr
|
||||
) where
|
||||
|
||||
type Id = (TIR.Ident, Type)
|
||||
import Data.List (intercalate)
|
||||
import Grammar.Print
|
||||
import LambdaLifterIr (Ident (..), Lit (..))
|
||||
import Prelude hiding (exp)
|
||||
|
||||
newtype Program = Program [Def]
|
||||
deriving (Show, Ord, Eq)
|
||||
|
|
@ -15,91 +19,81 @@ data Def = DBind Bind | DData Data
|
|||
data Data = Data Type [Inj]
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Bind = Bind Id [Id] ExpT
|
||||
data Bind = Bind (T Ident) [T Ident] (T Exp)
|
||||
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
|
||||
type T a = (a, Type)
|
||||
|
||||
data Exp
|
||||
= EVar TIR.Ident
|
||||
= EVar Ident
|
||||
| EVarC [T Ident] Ident
|
||||
| ELit Lit
|
||||
| ELet Bind ExpT
|
||||
| EApp ExpT ExpT
|
||||
| EAdd ExpT ExpT
|
||||
| ECase ExpT [Branch]
|
||||
| ELet Bind (T Exp)
|
||||
| EApp (T Exp) (T Exp)
|
||||
| EAdd (T Exp) (T Exp)
|
||||
| ECase (T Exp) [Branch]
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Pattern
|
||||
= PVar Id
|
||||
| PLit (Lit, Type)
|
||||
| PInj TIR.Ident [Pattern]
|
||||
= PVar Ident
|
||||
| PLit Lit
|
||||
| PInj Ident [T Pattern]
|
||||
| PCatch
|
||||
| PEnum TIR.Ident
|
||||
| PEnum Ident
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
data Branch = Branch (Pattern, Type) ExpT
|
||||
|
||||
data Branch = Branch (T Pattern) (T Exp)
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
type ExpT = (Exp, Type)
|
||||
|
||||
data Inj = Inj TIR.Ident Type
|
||||
data Inj = Inj Ident Type
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Lit
|
||||
= LInt Integer
|
||||
| LChar Char
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
data Type = TLit TIR.Ident | TFun Type Type | TData TIR.Ident [Type]
|
||||
data Type = TLit Ident | TFun Type Type | TData Ident [Type]
|
||||
|
||||
deriving (Show, Ord, Eq)
|
||||
|
||||
flattenType :: Type -> [Type]
|
||||
flattenType (TFun t1 t2) = t1 : flattenType t2
|
||||
flattenType x = [x]
|
||||
flattenType x = [x]
|
||||
|
||||
instance Print Program where
|
||||
prt i (Program sc) = prPrec i 0 $ prt 0 sc
|
||||
|
||||
instance Print (Bind) where
|
||||
instance Print Bind where
|
||||
prt i (Bind sig@(name, _) parms rhs) =
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ prtSig sig
|
||||
[ prt 0 sig
|
||||
, prt 0 name
|
||||
, prtIdPs 0 parms
|
||||
, prt 0 parms
|
||||
, doc $ showString "="
|
||||
, prt 0 rhs
|
||||
]
|
||||
|
||||
prtSig :: Id -> Doc
|
||||
prtSig (name, t) =
|
||||
concatD
|
||||
[ prt 0 name
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
, doc $ showString ";"
|
||||
]
|
||||
|
||||
instance Print (ExpT) where
|
||||
prt i (e, t) =
|
||||
concatD
|
||||
[ doc $ showString "("
|
||||
, prt i e
|
||||
, doc $ showString ","
|
||||
, prt i t
|
||||
, doc $ showString ")"
|
||||
]
|
||||
prt i (BindC cxt sig parms rhs) =
|
||||
prPrec i 0 $
|
||||
concatD
|
||||
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
|
||||
, prt i parms
|
||||
, doc $ showString "="
|
||||
, prt i rhs
|
||||
]
|
||||
|
||||
instance Print [Bind] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
prtIdPs :: Int -> [Id] -> Doc
|
||||
prtIdPs i = prPrec i 0 . concatD . map (prt i)
|
||||
|
||||
instance Print Exp where
|
||||
prt i = \case
|
||||
EVar name -> prPrec i 3 $ prt 0 name
|
||||
EVarC as lident -> doc . showString
|
||||
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
|
||||
where
|
||||
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
|
||||
ELit lit -> prPrec i 3 $ prt 0 lit
|
||||
ELet b e ->
|
||||
prPrec i 3 $
|
||||
|
|
@ -134,16 +128,16 @@ instance Print Exp where
|
|||
]
|
||||
|
||||
instance Print Branch where
|
||||
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
|
||||
prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp])
|
||||
|
||||
instance Print [Branch] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
instance Print Def where
|
||||
prt i = \case
|
||||
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
|
||||
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
|
||||
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
|
||||
|
||||
instance Print Data where
|
||||
|
|
@ -152,23 +146,23 @@ instance Print Data where
|
|||
|
||||
instance Print Inj where
|
||||
prt i = \case
|
||||
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
|
||||
Inj uident type_ -> prt i (uident, type_)
|
||||
|
||||
instance Print Pattern where
|
||||
prt i = \case
|
||||
PVar name -> prPrec i 1 (concatD [prt 0 name])
|
||||
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit])
|
||||
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
|
||||
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
|
||||
PEnum name -> prPrec i 1 (concatD [prt 0 name])
|
||||
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
|
||||
|
||||
instance Print [Def] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
instance Print [Type] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [] = concatD []
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
|
||||
|
||||
instance Print Type where
|
||||
|
|
@ -177,8 +171,4 @@ instance Print Type where
|
|||
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
|
||||
TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")])
|
||||
|
||||
instance Print Lit where
|
||||
prt i = \case
|
||||
LInt int -> prt i int
|
||||
LChar char -> prt i char
|
||||
|
||||
|
|
|
|||
|
|
@ -2,15 +2,15 @@
|
|||
|
||||
module TypeChecker.ReportTEVar where
|
||||
|
||||
import Auxiliary (onM)
|
||||
import Control.Applicative (Applicative (liftA2), liftA3)
|
||||
import Control.Monad.Except (MonadError (throwError))
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Tuple.Extra (secondM)
|
||||
import Grammar.Abs qualified as G
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Print (printTree)
|
||||
import TypeChecker.TypeCheckerIr hiding (Type (..))
|
||||
import Auxiliary (onM)
|
||||
import Control.Applicative (Applicative (liftA2), liftA3)
|
||||
import Control.Monad.Except (MonadError (throwError))
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Tuple.Extra (secondM)
|
||||
import qualified Grammar.Abs as G
|
||||
import Grammar.ErrM (Err)
|
||||
import Grammar.Print (printTree)
|
||||
import TypeChecker.TypeCheckerIr hiding (Type (..))
|
||||
|
||||
data Type
|
||||
= TLit Ident
|
||||
|
|
@ -18,7 +18,7 @@ data Type
|
|||
| TData Ident [Type]
|
||||
| TFun Type Type
|
||||
| TAll TVar Type
|
||||
deriving (Eq, Ord, Show, Read)
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
class ReportTEVar a b where
|
||||
reportTEVar :: a -> Err b
|
||||
|
|
@ -29,20 +29,20 @@ instance ReportTEVar (Program' G.Type) (Program' Type) where
|
|||
instance ReportTEVar (Def' G.Type) (Def' Type) where
|
||||
reportTEVar = \case
|
||||
DBind bind -> DBind <$> reportTEVar bind
|
||||
DData dat -> DData <$> reportTEVar dat
|
||||
DData dat -> DData <$> reportTEVar dat
|
||||
|
||||
instance ReportTEVar (Bind' G.Type) (Bind' Type) where
|
||||
reportTEVar (Bind id vars rhs) = liftA3 Bind (reportTEVar id) (reportTEVar vars) (reportTEVar rhs)
|
||||
|
||||
instance ReportTEVar (Exp' G.Type) (Exp' Type) where
|
||||
reportTEVar exp = case exp of
|
||||
EVar name -> pure $ EVar name
|
||||
EInj name -> pure $ EInj name
|
||||
ELit lit -> pure $ ELit lit
|
||||
ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e)
|
||||
EApp e1 e2 -> onM EApp reportTEVar e1 e2
|
||||
EAdd e1 e2 -> onM EAdd reportTEVar e1 e2
|
||||
EAbs name e -> EAbs name <$> reportTEVar e
|
||||
EVar name -> pure $ EVar name
|
||||
EInj name -> pure $ EInj name
|
||||
ELit lit -> pure $ ELit lit
|
||||
ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e)
|
||||
EApp e1 e2 -> onM EApp reportTEVar e1 e2
|
||||
EAdd e1 e2 -> onM EAdd reportTEVar e1 e2
|
||||
EAbs name e -> EAbs name <$> reportTEVar e
|
||||
ECase e branches -> liftA2 ECase (reportTEVar e) (reportTEVar branches)
|
||||
|
||||
instance ReportTEVar (Branch' G.Type) (Branch' Type) where
|
||||
|
|
@ -53,10 +53,10 @@ instance ReportTEVar (Pattern' G.Type, G.Type) (Pattern' Type, Type) where
|
|||
|
||||
instance ReportTEVar (Pattern' G.Type) (Pattern' Type) where
|
||||
reportTEVar = \case
|
||||
PVar name -> pure $ PVar name
|
||||
PLit lit -> pure $ PLit lit
|
||||
PCatch -> pure PCatch
|
||||
PEnum name -> pure $ PEnum name
|
||||
PVar name -> pure $ PVar name
|
||||
PLit lit -> pure $ PLit lit
|
||||
PCatch -> pure PCatch
|
||||
PEnum name -> pure $ PEnum name
|
||||
PInj name ps -> PInj name <$> reportTEVar ps
|
||||
|
||||
instance ReportTEVar (Data' G.Type) (Data' Type) where
|
||||
|
|
@ -65,10 +65,10 @@ instance ReportTEVar (Data' G.Type) (Data' Type) where
|
|||
instance ReportTEVar (Inj' G.Type) (Inj' Type) where
|
||||
reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ
|
||||
|
||||
instance ReportTEVar (Id' G.Type) (Id' Type) where
|
||||
instance ReportTEVar (a, G.Type) (a, Type) where
|
||||
reportTEVar = secondM reportTEVar
|
||||
|
||||
instance ReportTEVar (ExpT' G.Type) (ExpT' Type) where
|
||||
instance ReportTEVar (T' Exp' G.Type) (T' Exp' Type) where
|
||||
reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ)
|
||||
|
||||
instance ReportTEVar a b => ReportTEVar [a] [b] where
|
||||
|
|
@ -76,9 +76,9 @@ instance ReportTEVar a b => ReportTEVar [a] [b] where
|
|||
|
||||
instance ReportTEVar G.Type Type where
|
||||
reportTEVar = \case
|
||||
G.TLit lit -> pure $ TLit (coerce lit)
|
||||
G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i)
|
||||
G.TData name typs -> TData (coerce name) <$> reportTEVar typs
|
||||
G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2)
|
||||
G.TLit lit -> pure $ TLit (coerce lit)
|
||||
G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i)
|
||||
G.TData name typs -> TData (coerce name) <$> reportTEVar typs
|
||||
G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2)
|
||||
G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t
|
||||
G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar)
|
||||
G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ import Grammar.ErrM
|
|||
import Grammar.Print (printTree)
|
||||
import Prelude hiding (exp)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
import TypeChecker.TypeCheckerIr (T, T')
|
||||
|
||||
-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013)
|
||||
-- https://doi.org/10.1145/2500365.2500582
|
||||
|
|
@ -172,7 +173,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars
|
|||
|
||||
-- | Γ ⊢ e ↑ A ⊣ Δ
|
||||
-- Under input context Γ, e checks against input type A, with output context ∆
|
||||
check :: Exp -> Type -> Tc (T.ExpT' Type)
|
||||
check :: Exp -> Type -> Tc (T' T.Exp' Type)
|
||||
|
||||
-- Γ,α ⊢ e ↑ A ⊣ Δ,α,Θ
|
||||
-- ------------------- ∀I
|
||||
|
|
@ -212,12 +213,6 @@ check (ECase scrut pi) c = do
|
|||
e' <- check e c
|
||||
pure (T.Branch p' e')
|
||||
apply (T.ECase (scrut', a) pi', c)
|
||||
where
|
||||
go (pi, b) (Branch p e) = do
|
||||
p' <- checkPattern p =<< apply a
|
||||
e'@(_, b') <- infer e
|
||||
subtype b' b
|
||||
apply (T.Branch p' e' : pi, b')
|
||||
|
||||
|
||||
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
|
||||
|
|
@ -229,9 +224,6 @@ check e b = do
|
|||
subtype a b'
|
||||
apply (e', b)
|
||||
|
||||
|
||||
|
||||
|
||||
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
|
||||
checkPattern patt t_patt = case patt of
|
||||
|
||||
|
|
@ -297,7 +289,7 @@ checkPattern patt t_patt = case patt of
|
|||
|
||||
-- | Γ ⊢ e ↓ A ⊣ Δ
|
||||
-- Under input context Γ, e infers output type A, with output context ∆
|
||||
infer :: Exp -> Tc (T.ExpT' Type)
|
||||
infer :: Exp -> Tc (T' T.Exp' Type)
|
||||
infer (ELit lit) = apply (T.ELit lit, litType lit)
|
||||
|
||||
-- Γ ∋ (x : A) Γ ⊢ rec(x)
|
||||
|
|
@ -391,7 +383,7 @@ infer (ECase scrut pi) = do
|
|||
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
|
||||
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
|
||||
-- Instantiate existential type variables until there is an arrow type.
|
||||
applyInfer :: Type -> Exp -> Tc (T.ExpT' Type, Type)
|
||||
applyInfer :: Type -> Exp -> Tc (T' T.Exp' Type, Type)
|
||||
|
||||
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
|
||||
-- ------------------------ ∀App
|
||||
|
|
|
|||
|
|
@ -1,32 +1,32 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QualifiedDo #-}
|
||||
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QualifiedDo #-}
|
||||
|
||||
-- | A module for type checking and inference using algorithm W, Hindley-Milner
|
||||
module TypeChecker.TypeCheckerHm where
|
||||
|
||||
import Auxiliary (int, litType, maybeToRightM, unzip4)
|
||||
import Auxiliary qualified as Aux
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Identity (Identity, runIdentity)
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Control.Monad.Writer
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Function (on)
|
||||
import Data.List (foldl', nub, sortOn)
|
||||
import Data.List.Extra (unsnoc)
|
||||
import Data.Map (Map)
|
||||
import Data.Map qualified as M
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.Set (Set)
|
||||
import Data.Set qualified as S
|
||||
import Debug.Trace (trace, traceShow)
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
import Auxiliary (int, litType, maybeToRightM, unzip4)
|
||||
import qualified Auxiliary as Aux
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Identity (Identity, runIdentity)
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Control.Monad.Writer
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Function (on)
|
||||
import Data.List (foldl', nub, sortOn)
|
||||
import Data.List.Extra (unsnoc)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as M
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Debug.Trace (trace, traceShow)
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
import TypeChecker.TypeCheckerIr (T, T')
|
||||
|
||||
{-
|
||||
TODO
|
||||
|
|
@ -41,7 +41,7 @@ typecheck :: Program -> Either String (T.Program' Type, [Warning])
|
|||
typecheck = onLeft msg . run . checkPrg
|
||||
where
|
||||
onLeft :: (Error -> String) -> Either Error a -> Either String a
|
||||
onLeft f (Left x) = Left $ f x
|
||||
onLeft f (Left x) = Left $ f x
|
||||
onLeft _ (Right x) = Right x
|
||||
|
||||
checkPrg :: Program -> Infer (T.Program' Type)
|
||||
|
|
@ -68,13 +68,13 @@ prettify s (T.Program defs) = T.Program $ map (go s) defs
|
|||
|
||||
replace :: Map T.Ident T.Ident -> Type -> Type
|
||||
replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of
|
||||
Just t -> TVar . MkTVar . LIdent $ coerce t
|
||||
Just t -> TVar . MkTVar . LIdent $ coerce t
|
||||
Nothing -> def
|
||||
replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2
|
||||
replace m (TData name ts) = TData name (map (replace m) ts)
|
||||
replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of
|
||||
Just found -> TAll (MkTVar $ coerce found) (replace m t)
|
||||
Nothing -> def
|
||||
Nothing -> def
|
||||
replace _ t = t
|
||||
|
||||
bindCount :: [Def] -> Infer [(Int, Def)]
|
||||
|
|
@ -128,7 +128,7 @@ preRun (x : xs) = case x of
|
|||
s <- gets sigs
|
||||
case M.lookup (coerce n) s of
|
||||
Nothing -> insertSig (coerce n) Nothing >> preRun xs
|
||||
Just _ -> preRun xs
|
||||
Just _ -> preRun xs
|
||||
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
|
||||
where
|
||||
-- Check if function body / signature has been declared already
|
||||
|
|
@ -150,11 +150,11 @@ checkDef (x : xs) = case x of
|
|||
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
|
||||
|
||||
freeOrdered :: Type -> [T.Ident]
|
||||
freeOrdered (TVar (MkTVar a)) = return (coerce a)
|
||||
freeOrdered (TVar (MkTVar a)) = return (coerce a)
|
||||
freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t
|
||||
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
|
||||
freeOrdered (TData _ a) = concatMap freeOrdered a
|
||||
freeOrdered _ = mempty
|
||||
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
|
||||
freeOrdered (TData _ a) = concatMap freeOrdered a
|
||||
freeOrdered _ = mempty
|
||||
|
||||
-- Much cleaner implementation, unfortunately one minor bug
|
||||
-- checkBind :: Bind -> Infer (T.Bind' Type)
|
||||
|
|
@ -257,13 +257,13 @@ checkInj (Inj c inj_typ) name tvars
|
|||
toTVar :: Type -> Either Error TVar
|
||||
toTVar = \case
|
||||
TVar tvar -> pure tvar
|
||||
_ -> uncatchableErr "Not a type variable"
|
||||
_ -> uncatchableErr "Not a type variable"
|
||||
|
||||
returnType :: Type -> Type
|
||||
returnType (TFun _ t2) = returnType t2
|
||||
returnType a = a
|
||||
returnType a = a
|
||||
|
||||
inferExp :: Exp -> Infer (T.ExpT' Type)
|
||||
inferExp :: Exp -> Infer (T' T.Exp' Type)
|
||||
inferExp e = do
|
||||
(s, (e', t)) <- algoW e
|
||||
let subbed = apply s t
|
||||
|
|
@ -274,7 +274,7 @@ class CollectTVars a where
|
|||
|
||||
instance CollectTVars Exp where
|
||||
collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e
|
||||
collectTVars _ = S.empty
|
||||
collectTVars _ = S.empty
|
||||
|
||||
instance CollectTVars Type where
|
||||
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
|
||||
|
|
@ -287,7 +287,7 @@ instance CollectTVars Type where
|
|||
collect :: Set T.Ident -> Infer ()
|
||||
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
|
||||
|
||||
algoW :: Exp -> Infer (Subst, T.ExpT' Type)
|
||||
algoW :: Exp -> Infer (Subst, T' T.Exp' Type)
|
||||
algoW = \case
|
||||
err@(EAnn e t) -> do
|
||||
(sub0, (e', t')) <- exprErr (algoW e) err
|
||||
|
|
@ -600,12 +600,12 @@ generalize :: Map T.Ident Type -> Type -> Type
|
|||
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
|
||||
where
|
||||
go :: [T.Ident] -> Type -> Type
|
||||
go [] t = t
|
||||
go [] t = t
|
||||
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
|
||||
removeForalls :: Type -> Type
|
||||
removeForalls (TAll _ t) = removeForalls t
|
||||
removeForalls (TAll _ t) = removeForalls t
|
||||
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
|
||||
removeForalls t = t
|
||||
removeForalls t = t
|
||||
|
||||
{- | Instantiate a polymorphic type. The free type variables are substituted
|
||||
with fresh ones.
|
||||
|
|
@ -643,7 +643,7 @@ fresh = do
|
|||
ungo :: [TVar] -> Type -> Type -> Bool
|
||||
ungo tvars t1 t2 = case run (go tvars t1 t2) of
|
||||
Right (b, _) -> b
|
||||
_ -> False
|
||||
_ -> False
|
||||
-- TODO: Fix the following
|
||||
-- Maybe locally using the Infer monad can cause trouble.
|
||||
-- Since the fresh count starts from zero
|
||||
|
|
@ -656,7 +656,7 @@ fresh = do
|
|||
skipForalls :: Type -> Type
|
||||
skipForalls = \case
|
||||
TAll _ t -> skipForalls t
|
||||
t -> t
|
||||
t -> t
|
||||
|
||||
freshen :: Type -> Infer Type
|
||||
freshen t = do
|
||||
|
|
@ -705,10 +705,10 @@ instance SubstType Type where
|
|||
TLit _ -> t
|
||||
TVar (MkTVar a) -> case M.lookup (coerce a) sub of
|
||||
Nothing -> TVar (MkTVar $ coerce a)
|
||||
Just t -> t
|
||||
Just t -> t
|
||||
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
|
||||
Nothing -> TAll (MkTVar i) (apply sub t)
|
||||
Just _ -> apply sub t
|
||||
Just _ -> apply sub t
|
||||
TFun a b -> TFun (apply sub a) (apply sub b)
|
||||
TData name a -> TData name (apply sub a)
|
||||
TEVar (MkTEVar _) -> t
|
||||
|
|
@ -724,7 +724,7 @@ instance SubstType (Map T.Ident Type) where
|
|||
instance SubstType (Map T.Ident (Maybe Type)) where
|
||||
apply s = M.map (fmap $ apply s)
|
||||
|
||||
instance SubstType (T.ExpT' Type) where
|
||||
instance SubstType (T' T.Exp' Type) where
|
||||
apply s (e, t) = (apply s e, apply s t)
|
||||
|
||||
instance SubstType (T.Exp' Type) where
|
||||
|
|
@ -753,10 +753,10 @@ instance SubstType (T.Branch' Type) where
|
|||
instance SubstType (T.Pattern' Type) where
|
||||
apply s = \case
|
||||
T.PVar iden -> T.PVar iden
|
||||
T.PLit lit -> T.PLit lit
|
||||
T.PLit lit -> T.PLit lit
|
||||
T.PInj i ps -> T.PInj i $ apply s ps
|
||||
T.PCatch -> T.PCatch
|
||||
T.PEnum i -> T.PEnum i
|
||||
T.PCatch -> T.PCatch
|
||||
T.PEnum i -> T.PEnum i
|
||||
|
||||
instance SubstType (T.Pattern' Type, Type) where
|
||||
apply s (p, t) = (apply s p, apply s t)
|
||||
|
|
@ -764,7 +764,7 @@ instance SubstType (T.Pattern' Type, Type) where
|
|||
instance SubstType a => SubstType [a] where
|
||||
apply s = map (apply s)
|
||||
|
||||
instance SubstType (T.Id' Type) where
|
||||
instance SubstType (T T.Ident Type) where
|
||||
apply s (name, t) = (name, apply s t)
|
||||
|
||||
-- | Represents the empty substition set
|
||||
|
|
@ -797,11 +797,11 @@ withBindings xs =
|
|||
-- | Run the monadic action with a pattern
|
||||
withPattern :: (Monad m, MonadReader Ctx m) => (T.Pattern' Type, Type) -> m a -> m a
|
||||
withPattern (p, t) ma = case p of
|
||||
T.PVar x -> withBinding x t ma
|
||||
T.PVar x -> withBinding x t ma
|
||||
T.PInj _ ps -> foldl' (flip withPattern) ma ps
|
||||
T.PLit _ -> ma
|
||||
T.PCatch -> ma
|
||||
T.PEnum _ -> ma
|
||||
T.PLit _ -> ma
|
||||
T.PCatch -> ma
|
||||
T.PEnum _ -> ma
|
||||
|
||||
-- | Insert a function signature into the environment
|
||||
insertSig :: T.Ident -> Maybe Type -> Infer ()
|
||||
|
|
@ -826,11 +826,11 @@ existInj n = gets (M.lookup n . injections)
|
|||
|
||||
flattenType :: Type -> [Type]
|
||||
flattenType (TFun a b) = flattenType a <> flattenType b
|
||||
flattenType a = [a]
|
||||
flattenType a = [a]
|
||||
|
||||
typeLength :: Type -> Int
|
||||
typeLength (TFun _ b) = 1 + typeLength b
|
||||
typeLength _ = 1
|
||||
typeLength _ = 1
|
||||
|
||||
{- | Catch an error if possible and add the given
|
||||
expression as addition to the error message
|
||||
|
|
@ -913,11 +913,11 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
|
|||
deriving (Show)
|
||||
|
||||
data Env = Env
|
||||
{ count :: Int
|
||||
, nextChar :: Char
|
||||
, sigs :: Map T.Ident (Maybe Type)
|
||||
{ count :: Int
|
||||
, nextChar :: Char
|
||||
, sigs :: Map T.Ident (Maybe Type)
|
||||
, takenTypeVars :: Set T.Ident
|
||||
, injections :: Map T.Ident Type
|
||||
, injections :: Map T.Ident Type
|
||||
, declaredBinds :: Set T.Ident
|
||||
}
|
||||
deriving (Show)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
|
||||
|
||||
module TypeChecker.TypeCheckerIr (
|
||||
module Grammar.Abs,
|
||||
module TypeChecker.TypeCheckerIr,
|
||||
|
|
@ -10,31 +11,30 @@ import Data.String (IsString)
|
|||
import Grammar.Abs (Lit (..))
|
||||
import Grammar.Print
|
||||
import Prelude
|
||||
import qualified Prelude as C (Eq, Ord, Read, Show)
|
||||
|
||||
newtype Program' t = Program [Def' t]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
||||
deriving (Eq, Ord, Show, Functor)
|
||||
|
||||
data Def' t
|
||||
= DBind (Bind' t)
|
||||
| DData (Data' t)
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
||||
deriving (Eq, Ord, Show, Functor)
|
||||
|
||||
data Type
|
||||
= TLit Ident
|
||||
| TVar TVar
|
||||
| TData Ident [Type]
|
||||
| TFun Type Type
|
||||
deriving (Eq, Ord, Show, Read)
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
data Data' t = Data t [Inj' t]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
||||
deriving (Eq, Ord, Show, Functor)
|
||||
|
||||
data Inj' t = Inj Ident t
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
||||
deriving (Eq, Ord, Show, Functor)
|
||||
|
||||
newtype Ident = Ident String
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, IsString)
|
||||
deriving (Eq, Ord, Show, IsString)
|
||||
|
||||
data Pattern' t
|
||||
= PVar Ident
|
||||
|
|
@ -42,30 +42,31 @@ data Pattern' t
|
|||
| PCatch
|
||||
| PEnum Ident
|
||||
| PInj Ident [(Pattern' t, t)]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
||||
deriving (Eq, Ord, Show, Functor)
|
||||
|
||||
data Exp' t
|
||||
= EVar Ident
|
||||
| EInj Ident
|
||||
| ELit Lit
|
||||
| ELet (Bind' t) (ExpT' t)
|
||||
| EApp (ExpT' t) (ExpT' t)
|
||||
| EAdd (ExpT' t) (ExpT' t)
|
||||
| EAbs Ident (ExpT' t)
|
||||
| ECase (ExpT' t) [Branch' t]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
||||
| ELet (Bind' t) (T' Exp' t)
|
||||
| EApp (T' Exp' t) (T' Exp' t)
|
||||
| EAdd (T' Exp' t) (T' Exp' t)
|
||||
| EAbs Ident (T' Exp' t)
|
||||
| ECase (T' Exp' t) [Branch' t]
|
||||
deriving (Eq, Ord, Show, Functor)
|
||||
|
||||
newtype TVar = MkTVar Ident
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
type Id' t = (Ident, t)
|
||||
type ExpT' t = (Exp' t, t)
|
||||
type T' a t = (a t, t)
|
||||
type T a t = (a, t)
|
||||
|
||||
data Bind' t = Bind (Id' t) [Id' t] (ExpT' t)
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
||||
|
||||
data Branch' t = Branch (Pattern' t, t) (ExpT' t)
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
||||
data Bind' t = Bind (T Ident t) [T Ident t] (T' Exp' t)
|
||||
deriving (Eq, Ord, Show, Functor)
|
||||
|
||||
data Branch' t = Branch (T' Pattern' t) (T' Exp' t)
|
||||
deriving (Eq, Ord, Show, Functor)
|
||||
|
||||
instance Print Ident where
|
||||
prt _ (Ident s) = doc $ showString s
|
||||
|
|
@ -81,22 +82,22 @@ instance Print t => Print (Bind' t) where
|
|||
, prt i rhs
|
||||
]
|
||||
|
||||
prtSig :: Print t => Id' t -> Doc
|
||||
prtSig (name, t) =
|
||||
prtSig :: Print t => T Ident t -> Doc
|
||||
prtSig (x, t) =
|
||||
concatD
|
||||
[ prt 0 name
|
||||
[ prt 0 x
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
]
|
||||
|
||||
instance Print t => Print (ExpT' t) where
|
||||
prt i (e, t) =
|
||||
instance (Print a, Print t) => Print (T a t) where
|
||||
prt i (x, t) =
|
||||
concatD
|
||||
[ doc $ showString "("
|
||||
, prt i e
|
||||
, doc $ showString ":"
|
||||
, prt 0 t
|
||||
, doc $ showString ")"
|
||||
[ -- doc $ showString "("
|
||||
{- , -} prt i x
|
||||
-- , doc $ showString ":"
|
||||
-- , prt 0 t
|
||||
-- , doc $ showString ")"
|
||||
]
|
||||
|
||||
instance Print t => Print [Bind' t] where
|
||||
|
|
@ -104,15 +105,6 @@ instance Print t => Print [Bind' t] where
|
|||
prt i [x] = concatD [prt i x]
|
||||
prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs]
|
||||
|
||||
instance Print t => Print (Id' t) where
|
||||
prt i (name, t) =
|
||||
concatD
|
||||
[ doc $ showString "("
|
||||
, prt i name
|
||||
, doc $ showString ","
|
||||
, prt i t
|
||||
, doc $ showString ")"
|
||||
]
|
||||
|
||||
instance Print t => Print (Exp' t) where
|
||||
prt i = \case
|
||||
|
|
@ -151,9 +143,6 @@ instance Print t => Print [Inj' t] where
|
|||
prt i [x] = prt i x
|
||||
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
|
||||
|
||||
instance Print t => Print (Pattern' t, t) where
|
||||
prt i (p, t) = prPrec i 1 (concatD [prt i p, prt i t])
|
||||
|
||||
instance Print t => Print (Pattern' t) where
|
||||
prt i = \case
|
||||
PVar name -> prPrec i 1 (concatD [prt 0 name])
|
||||
|
|
@ -189,8 +178,6 @@ type Branch = Branch' Type
|
|||
type Pattern = Pattern' Type
|
||||
type Inj = Inj' Type
|
||||
type Exp = Exp' Type
|
||||
type ExpT = ExpT' Type
|
||||
type Id = Id' Type
|
||||
pattern TVar' s = TVar (MkTVar s)
|
||||
pattern DBind' id vars expt = DBind (Bind id vars expt)
|
||||
pattern DData' typ injs = DData (Data typ injs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue