481 lines
19 KiB
Haskell
481 lines
19 KiB
Haskell
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE OverloadedRecordDot #-}
|
|
|
|
{- | For now, converts polymorphic functions to concrete ones based on usage.
Assumes lambdas are lifted.

This step of compilation is as follows:

Split all function bindings into monomorphic and polymorphic binds. The
monomorphic bindings will be part of this compilation step.
Apply the following monomorphization function on all monomorphic binds, with
their type as an additional argument.

The function that transforms binds operates on both monomorphic and
polymorphic functions. It creates a context in which all possible polymorphic
types are mapped to concrete types, derived from the additional argument.
Expressions are then recursively processed. The types of these expressions
are changed to use the mapped concrete types. The expected type provided
in the recursion changes depending on the node being processed.

When an external bind is encountered (with EVar), it is checked whether it
already exists among the outputted binds. If it does, nothing further is
evaluated. If not, the bind transformer function is called on it with the
expected type in this context. The result of this computation (a monomorphic
bind) is added to the resulting set of binds.
-}
|
|
module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where
|
|
|
|
|
|
import Control.Monad.Reader (MonadReader (ask, local),
|
|
Reader, asks, runReader)
|
|
import Control.Monad.State (MonadState (get),
|
|
StateT (runStateT), gets,
|
|
modify)
|
|
import Data.Coerce (coerce)
|
|
import qualified Data.Map as Map
|
|
import Data.Maybe (catMaybes)
|
|
import qualified Data.Set as Set
|
|
import Debug.Trace (trace)
|
|
import Grammar.Print (printTree)
|
|
import Monomorphizer.DataTypeRemover (removeDataTypes)
|
|
import qualified Monomorphizer.MonomorphizerIr as O
|
|
import qualified Monomorphizer.MorbIr as M
|
|
-- import TypeChecker.TypeCheckerIr (Ident (Ident))
|
|
import LambdaLifterIr (Ident (..))
|
|
-- import TypeChecker.TypeCheckerIr qualified as T
|
|
import qualified LambdaLifterIr as L
|
|
|
|
import Control.Monad.Reader (MonadReader (ask, local),
|
|
Reader, asks, runReader)
|
|
import Control.Monad.State (MonadState, StateT (runStateT),
|
|
gets, modify)
|
|
import qualified Data.Map as Map
|
|
import Data.Maybe (catMaybes, fromJust)
|
|
import qualified Data.Set as Set
|
|
import Data.Tuple.Extra (secondM)
|
|
import Grammar.Print (printTree)
|
|
|
|
{- | The monad used throughout the monomorphizer: a state layer holding the
'Output' produced so far (monomorphized functions and data type declarations
still to be monomorphized), on top of a reader layer holding the read-only
'Env'.
-}
newtype EnvM a = EnvM (StateT Output (Reader Env) a)
    deriving
        ( Functor
        , Applicative
        , Monad
        , MonadState Output
        , MonadReader Env
        )
|
|
|
|
-- | Output state of 'EnvM': what has been emitted so far, keyed by identifier.
type Output = Map.Map Ident Outputted
|
|
|
|
-- | A single piece of outputted top-level information.
data Outputted
    = Marked
    -- ^ The bind is in the process of monomorphization and should not be
    -- monomorphized again.
    | Complete M.Bind
    -- ^ A finished monomorphic bind.
    | Data M.Type L.Data
    -- ^ A constructor use: its concrete type together with the polymorphic
    -- data declaration it belongs to (monomorphized in a later step).
    deriving (Show)
|
|
|
|
-- | Read-only environment of the monomorphizer.
data Env = Env
    { -- | All binds in the program.
      input :: Map.Map Ident L.Bind
    , -- | All constructors mapped to their respective polymorphic data
      -- definition, which includes all other constructors.
      dataDefs :: Map.Map Ident L.Data
    , -- | Maps polymorphic identifiers to concrete types.
      polys :: Map.Map Ident M.Type
    , -- | Local variables.
      locals :: Set.Set Ident
    }
|
|
|
|
-- | Tells whether the identifier is one of the local variables of the current context.
localExists :: Ident -> EnvM Bool
localExists ident = do
    scope <- asks locals
    return (ident `Set.member` scope)
|
|
|
|
-- | Looks up an input bind by its identifier; 'Nothing' when no such bind exists.
getInputBind :: Ident -> EnvM (Maybe L.Bind)
getInputBind ident = Map.lookup ident <$> asks input
|
|
|
|
{- | Stores a finished monomorphic bind in the output under its own name,
replacing whatever entry was there before.
-}
addOutputBind :: M.Bind -> EnvM ()
addOutputBind bind = case bind of
    M.Bind (ident, _) _ _ -> store ident
    M.BindC _ (ident, _) _ _ -> store ident
  where
    store ident = modify (Map.insert ident (Complete bind))
|
|
|
|
{- | Flags a global bind as being processed, so that it is not recursively
processed when it is encountered again.
-}
markBind :: Ident -> EnvM ()
markBind ident = modify $ \emitted -> Map.insert ident Marked emitted
|
|
|
|
-- | Tells whether the bind already has an entry (of any kind) in the output.
isBindMarked :: Ident -> EnvM Bool
isBindMarked ident = gets $ \emitted -> ident `Map.member` emitted
|
|
|
|
-- | Tells whether the constructor already has an entry in the output.
isConsMarked :: Ident -> EnvM Bool
isConsMarked ident = gets (\emitted -> Map.member ident emitted)
|
|
|
|
-- | Fetches the bind named @main@ from the input; it is an error if there is none.
getMain :: EnvM L.Bind
getMain = do
    binds <- asks input
    return $
        maybe
            (error "main not found in monomorphizer!")
            id
            (Map.lookup (Ident "main") binds)
|
|
|
|
{- | Pairs every type variable of the polymorphic type with the part of the
monomorphic type found at the same position. Raises a runtime error when the
two types do not have the same structure.

The first argument is the name of the bind and is only used in the error
message.
-}
mapTypes :: Ident -> L.Type -> M.Type -> [(Ident, M.Type)]
mapTypes ident poly mono = case (poly, mono) of
    (L.TLit _, M.TLit _) -> []
    (L.TVar (L.MkTVar tvar), _) -> [(tvar, mono)]
    (L.TFun polyArg polyRes, M.TFun monoArg monoRes) ->
        mapTypes ident polyArg monoArg ++ mapTypes ident polyRes monoRes
    (L.TData polyName polyArgs, M.TData monoName monoArgs)
        | polyName /= monoName ->
            error "the data type names of monomorphic and polymorphic data types does not match"
        | otherwise ->
            -- Mappings of later type arguments come first in the result.
            concatMap (uncurry (mapTypes ident)) (reverse (zip polyArgs monoArgs))
    _ ->
        error $
            "in bind: '" ++ printTree ident ++ "', "
                ++ "structure of types not the same: '"
                ++ printTree poly
                ++ "', '"
                ++ printTree mono
                ++ "'"
|
|
|
|
{- | Translates a polymorphic type to its monomorphic counterpart using the
type variable mapping of the current context. A type variable without a
mapping becomes the literal type @void@.
-}
getMonoFromPoly :: L.Type -> EnvM M.Type
getMonoFromPoly polyType = asks (\env -> resolve (polys env) polyType)
  where
    resolve :: Map.Map Ident M.Type -> L.Type -> M.Type
    resolve mapping = go
      where
        go (L.TLit ident) = M.TLit ident
        go (L.TFun from to) = M.TFun (go from) (go to)
        go (L.TVar (L.MkTVar ident)) =
            Map.findWithDefault (M.TLit (Ident "void")) ident mapping
        go (L.TData ident args) = M.TData ident (map go args)
|
|
|
|
{- | Monomorphizes a top-level bind at the given concrete type, unless an entry
with the resulting name is already present in the output. The body is
processed with 'morphExp', which in turn handles the binds it references.

Returns the name of the monomorphic bind.
-}
morphBind :: M.Type -> L.Bind -> EnvM Ident
morphBind expectedType bind = do
    -- The specialised name is what the "already handled" check is keyed on.
    seen <- isBindMarked monoName
    if seen
        then return monoName
        else local enterBind $ do
            -- Mark before descending, so that recursive or cyclic function
            -- calls do not process this bind again.
            markBind monoName
            monoBind <- specialise
            addOutputBind monoBind
            return monoName
  where
    monoName = newFuncName expectedType bind

    (ident, btype, args) = case bind of
        L.Bind (i, t) params _ -> (i, t, params)
        L.BindC _ (i, t) params _ -> (i, t, params)

    -- Inside the bind only its own arguments are local, and its type
    -- variables are mapped according to the expected type.
    enterBind env =
        env
            { locals = Set.fromList (map fst args)
            , polys = Map.fromList (mapTypes ident btype expectedType)
            }

    specialise :: EnvM M.Bind
    specialise = case bind of
        L.Bind _ _ (body, bodyType) -> do
            bodyType' <- getMonoFromPoly bodyType
            body' <- morphExp bodyType' body
            -- Monomorphic types of the arguments.
            args' <- mapM morphArg args
            return $ M.Bind (monoName, expectedType) args' (body', bodyType')
        L.BindC cxt _ _ (body, bodyType) -> do
            -- Monomorphic types of the arguments.
            args' <- mapM morphArg args
            cxt' <- mapM (secondM getMonoFromPoly) cxt
            bodyType' <- getMonoFromPoly bodyType
            -- The variables of the closure context are local in the body.
            body' <-
                local
                    (\env -> foldr (addLocal . fst) env cxt)
                    (morphExp bodyType' body)
            return $ M.BindC cxt' (monoName, expectedType) args' (body', bodyType')
|
|
|
|
|
|
-- | Monomorphizes the type of one argument of a bind, keeping its name.
morphArg :: (Ident, L.Type) -> EnvM (Ident, M.Type)
morphArg (ident, polyType) = (,) ident <$> getMonoFromPoly polyType
|
|
|
|
-- | Finds the data declaration that the named constructor belongs to, if any.
getInputData :: Ident -> EnvM (Maybe L.Data)
getInputData ident = asks (Map.lookup ident . dataDefs)
|
|
|
|
{- | Registers a constructor use in the output, looked up by its global name.
Constructors may appear as expressions in the tree, or as patterns in
case-expressions.

'ident' is the general name used for the lookup, 'newIdent' the unique name
the entry is stored under. Nothing happens when 'ident' is not a known
constructor (closures can have unbound variables).
-}
morphCons :: M.Type -> Ident -> Ident -> EnvM ()
morphCons expectedType ident newIdent =
    getInputData ident >>= maybe (pure ()) record
  where
    record dataDef = modify (Map.insert newIdent (Data expectedType dataDef))
|
|
|
|
-- | Translates a literal of the input tree to the same literal of the output tree.
convertLit :: L.Lit -> M.Lit
convertLit = \case
    L.LInt v -> M.LInt v
    L.LChar v -> M.LChar v
|
|
|
|
|
|
{- | Monomorphizes an expression, given the concrete type it is expected to have.

References to non-local binds ('L.EVar' / 'L.EVarC') trigger 'morphBind' on
the referenced bind at the expected type and are replaced by the name of the
resulting monomorphic bind. Referencing a name that is neither local nor a
known bind is a runtime error ("unbound variable").
-}
morphExp :: M.Type -> L.Exp -> EnvM M.Exp
morphExp expectedType expr = case expr of
    L.ELit lit -> return $ M.ELit lit
    -- Constructor
    L.EInj ident -> do
        let ident' = newName (getDataType expectedType) ident
        morphCons expectedType ident ident'
        return $ M.EVar ident'
    L.EApp (e1, _t1) (e2, t2) -> do
        t2' <- getMonoFromPoly t2
        e2' <- morphExp t2' e2
        -- The function is expected to map the argument type to the type
        -- expected of the whole application.
        e1' <- morphExp (M.TFun t2' expectedType) e1
        return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2')
    L.EAdd (e1, t1) (e2, t2) -> do
        t1' <- getMonoFromPoly t1
        t2' <- getMonoFromPoly t2
        e1' <- morphExp t1' e1
        e2' <- morphExp t2' e2
        -- Both operands are annotated with the type expected of the sum.
        return $ M.EAdd (e1', expectedType) (e2', expectedType)
    L.ECase (scrutinee, t) bs -> do
        t' <- getMonoFromPoly t
        scrutinee' <- morphExp t' scrutinee
        bs' <- mapM morphBranch bs
        -- Branches for which 'morphBranch' yields 'Nothing' are dropped.
        return $ M.ECase (scrutinee', t') (catMaybes bs')
    -- Ideally constructors should be EInj, though this code handles them
    -- as well.
    -- FIXME MAKE EVAR AND EINJ SEPARATE!!!
    L.EVar ident -> do
        isLocal <- localExists ident
        if isLocal
            then return $ M.EVar ident
            else do
                bind <- lookupGlobal ident
                -- New bind to process
                newBindName <- morphBind expectedType bind
                return $ M.EVar newBindName
    L.EVarC as ident -> do
        isLocal <- localExists ident
        if isLocal
            then return $ M.EVar ident
            else do
                bind <- lookupGlobal ident
                as' <- mapM (secondM getMonoFromPoly) as
                -- New bind to process
                newBindName <- morphBind expectedType bind
                return $ M.EVarC as' newBindName
    L.ELet (identB, tB) (expB, tExpB) (body, tExp) -> do
        tB' <- getMonoFromPoly tB
        tExpB' <- getMonoFromPoly tExpB
        tExp' <- getMonoFromPoly tExp
        expB' <- morphExp tExpB' expB
        -- The let-bound name is local only in the body of the let.
        body' <- local (addLocal identB) (morphExp tExp' body)
        return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (body', tExp')
  where
    -- Fetches a global bind, failing with a descriptive message instead of
    -- a bare pattern-match/fromJust crash when the name is unknown.
    lookupGlobal :: Ident -> EnvM L.Bind
    lookupGlobal ident =
        getInputBind ident
            >>= maybe
                (error ("unbound variable: '" ++ printTree ident ++ "'"))
                return
|
|
|
|
{- | Monomorphizes one branch of a case-expression. Yields 'Nothing' when the
pattern of the branch cannot be monomorphized (see 'morphPattern').
-}
morphBranch :: L.Branch -> EnvM (Maybe M.Branch)
morphBranch (L.Branch (pat, patType) (body, bodyType)) = do
    patType' <- getMonoFromPoly patType
    bodyType' <- getMonoFromPoly bodyType
    morphPattern pat patType' >>= \case
        Nothing -> return Nothing
        Just (pat', bound) -> do
            -- Variables bound by the pattern are local in the branch body.
            body' <-
                local
                    (\env -> env { locals = Set.union (locals env) bound })
                    (morphExp bodyType' body)
            return $ Just (M.Branch pat' (body', bodyType'))
|
|
|
|
{- | Monomorphizes a pattern against its expected concrete type.

On success, returns the typed pattern together with the set of variables it
binds. A constructor pattern ('L.PInj') yields 'Nothing' when the constructor,
specialised to the expected type, has no entry in the output, or when any of
its sub-patterns yields 'Nothing'.
-}
morphPattern :: L.Pattern -> M.Type -> EnvM (Maybe (M.T M.Pattern, Set.Set Ident))
morphPattern p expectedType = case p of
    L.PVar ident -> return $ Just ((M.PVar ident, expectedType), Set.singleton ident)
    L.PLit lit -> return $ Just ((M.PLit (convertLit lit), expectedType), Set.empty)
    L.PCatch -> return $ Just ((M.PCatch, expectedType), Set.empty)
    L.PEnum ident -> return $ Just ((M.PEnum (newName expectedType ident), expectedType), Set.empty)
    L.PInj ident pts -> do
        let newIdent = newName expectedType ident
        isMarked <- isConsMarked newIdent
        if isMarked
            then do
                ts' <- mapM (getMonoFromPoly . snd) pts
                let pts' = zip (map fst pts) ts'
                psSets <- mapM (uncurry morphPattern) pts'
                -- A single failing sub-pattern fails the whole pattern.
                case sequence psSets of
                    Nothing -> return Nothing
                    Just psSets' ->
                        return $
                            Just
                                ( (M.PInj newIdent (map fst psSets'), expectedType)
                                , Set.unions $ map snd psSets'
                                )
            else return Nothing
|
|
|
|
{- | Builds the identifier of a function specialised to the given type. The
bind named @main@ keeps its name unchanged.
-}
newFuncName :: M.Type -> L.Bind -> Ident
newFuncName t bind
    | bindName == "main" = Ident bindName
    | otherwise = newName t ident
  where
    ident@(Ident bindName) = case bind of
        L.Bind (name, _) _ _ -> name
        L.BindC _ (name, _) _ _ -> name
|
|
|
|
{- | Appends a textual encoding of the type to the identifier, separated by
@$@. Function types are joined with @_@, and every type argument of a data
type is prefixed with @.@.
-}
newName :: M.Type -> Ident -> Ident
newName t (Ident base) = Ident (base ++ "$" ++ mangle t)
  where
    mangle :: M.Type -> String
    mangle = \case
        M.TLit (Ident lit) -> lit
        M.TFun arg res -> mangle arg ++ "_" ++ mangle res
        M.TData (Ident name) params ->
            name ++ concatMap (\param -> "." ++ mangle param) params
|
|
|
|
{- | The monomorphization step: starts from @main@ at type @Int@, collects
everything that gets emitted along the way, and hands the result to
'removeDataTypes'.
-}
monomorphize :: L.Program -> O.Program
monomorphize (L.Program defs) = removeDataTypes (M.Program outputDefs)
  where
    outputDefs = getDefsFromOutput (runEnvM Map.empty (createEnv defs) fromMain)

    fromMain :: EnvM ()
    fromMain = do
        mainBind <- getMain
        _ <- morphBind (M.TLit (Ident "Int")) mainBind
        return ()
|
|
|
|
-- | Runs the computation from the given initial output and environment, and returns the final output.
runEnvM :: Output -> Env -> EnvM () -> Output
runEnvM o env (EnvM stateM) = finalOutput
  where
    (_, finalOutput) = runReader (runStateT stateM o) env
|
|
|
|
{- | Builds the initial environment from the top-level definitions: binds are
indexed by name, every constructor is mapped to the data declaration it
belongs to, and there are no type mappings or local variables yet.
-}
createEnv :: [L.Def] -> Env
createEnv defs =
    Env
        { input = Map.fromList [(getBindName b, b) | b <- getBindsFromDefs defs]
        , dataDefs = Map.fromList consToData
        , polys = Map.empty
        , locals = Set.empty
        }
  where
    consToData :: [(Ident, L.Data)]
    consToData =
        concatMap
            (\d@(L.Data _ cs) -> [(getConsName c, d) | c <- cs])
            (reverse (getDataFromDefs defs))
|
|
|
|
-- | Gets the name of a top-level function.
getBindName :: L.Bind -> Ident
getBindName = \case
    L.Bind (ident, _) _ _ -> ident
    L.BindC _ (ident, _) _ _ -> ident
|
|
|
|
-- Helper functions

-- | Gets the custom data declarations from the definitions, in reverse order of appearance.
getDataFromDefs :: [L.Def] -> [L.Data]
getDataFromDefs = reverse . concatMap pick
  where
    pick = \case
        L.DBind _ -> []
        L.DData d -> [d]
|
|
|
|
-- | Gets the name of a constructor.
getConsName :: L.Inj -> Ident
getConsName inj = case inj of
    L.Inj ident _ -> ident
|
|
|
|
-- | Gets the binds from the definitions, in reverse order of appearance.
getBindsFromDefs :: [L.Def] -> [L.Bind]
getBindsFromDefs = reverse . concatMap pick
  where
    pick = \case
        L.DBind b -> [b]
        L.DData _ -> []
|
|
|
|
{- | Turns the final output into top-level definitions: first all binds, then
the data declarations built from the recorded constructor uses.
-}
getDefsFromOutput :: Output -> [M.Def]
getDefsFromOutput o = bindDefs ++ dataDecls
  where
    (binds, dataInput) = splitBindsAndData o
    bindDefs = map M.DBind binds
    dataDecls = map M.DData (Map.elems (createNewData dataInput Map.empty))
|
|
|
|
{- | Separates the output into its binds and its constructor entries (the
latter are the input of 'createNewData'). A leftover 'Marked' entry is an
internal error.
-}
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, L.Data)])
splitBindsAndData = Map.foldlWithKey classify ([], [])
  where
    classify (oBinds, oData) ident = \case
        Marked -> error "internal bug in monomorphizer"
        Complete b -> (b : oBinds, oData)
        Data t d -> (oBinds, (ident, t, d) : oData)
|
|
|
|
{- | Converts all found constructors to monomorphic data declarations. Each
constructor is added to the declaration stored under the specialised name of
its data type; that declaration is created when it does not exist yet.
-}
createNewData :: [(Ident, M.Type, L.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
createNewData entries o = foldl addCons o entries
  where
    addCons decls (consIdent, consType, polyData) =
        Map.insertWith merge newDataName (M.Data newDataType [newCons]) decls
      where
        L.Data (L.TData polyDataIdent _) _ = polyData
        newDataType = getDataType consType
        newDataName = newName newDataType polyDataIdent
        newCons = M.Inj consIdent consType
        -- Keep the constructors already collected and put the new one in front.
        merge _ (M.Data _ cs) = M.Data newDataType (newCons : cs)
|
|
|
|
{- | Gets the data type of a constructor type, i.e. the final result type
after all arguments (@a -> Just a@ becomes @Just a@).

Raises a runtime error naming the offending type when the final result is not
a data type.
-}
getDataType :: M.Type -> M.Type
getDataType (M.TFun _t1 t2) = getDataType t2
getDataType tData@(M.TData _ _) = tData
getDataType t =
    error $
        "getDataType: expected a data type or a function returning one, got: '"
            ++ printTree t
            ++ "'"
|
|
|
|
|
|
-- | Extends the set of local variables of the environment with one identifier.
addLocal :: Ident -> Env -> Env
addLocal x env = env { locals = Set.insert x (locals env) }
|