churf/src/Monomorphizer/Monomorphizer.hs
2023-05-06 22:49:08 +02:00

481 lines
19 KiB
Haskell

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{- | For now, converts polymorphic functions to concrete ones based on usage.
Assumes lambdas are lifted.
This step of compilation is as follows:
Split all function bindings into monomorphic and polymorphic binds. The
monomorphic bindings will be part of this compilation step.
Apply the following monomorphization function on all monomorphic binds, with
their type as an additional argument.
The function that transforms Binds operates on both monomorphic and
polymorphic functions, creates a context in which all possible polymorphic types
are mapped to concrete types, created using the additional argument.
Expressions are then recursively processed. The type of these expressions
are changed to using the mapped generic types. The expected type provided
in the recursion is changed depending on the different nodes.
When an external bind is encountered (through EVar or EVarC), it is checked
whether it already exists among the outputted binds. If it does, nothing
further is evaluated.
If not, the bind transformer function is called on it with the
expected type in this context. The result of this computation (a monomorphic
bind) is added to the resulting set of binds.
-}
module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where
import Control.Monad.Reader (MonadReader (ask, local),
Reader, asks, runReader)
import Control.Monad.State (MonadState (get),
StateT (runStateT), gets,
modify)
import Data.Coerce (coerce)
import qualified Data.Map as Map
import Data.Maybe (catMaybes)
import qualified Data.Set as Set
import Debug.Trace (trace)
import Grammar.Print (printTree)
import Monomorphizer.DataTypeRemover (removeDataTypes)
import qualified Monomorphizer.MonomorphizerIr as O
import qualified Monomorphizer.MorbIr as M
-- import TypeChecker.TypeCheckerIr (Ident (Ident))
import LambdaLifterIr (Ident (..))
-- import TypeChecker.TypeCheckerIr qualified as T
import qualified LambdaLifterIr as L
import Control.Monad.Reader (MonadReader (ask, local),
Reader, asks, runReader)
import Control.Monad.State (MonadState, StateT (runStateT),
gets, modify)
import qualified Data.Map as Map
import Data.Maybe (catMaybes, fromJust)
import qualified Data.Set as Set
import Data.Tuple.Extra (secondM)
import Grammar.Print (printTree)
{- | 'EnvM' is the monad in which monomorphization runs. It is a 'StateT'
over 'Output' (the monomorphized binds, the binds currently marked as being
processed, and the data type declarations that are monomorphized in a later
step) layered on a 'Reader' that carries the read-only 'Env'.
-}
newtype EnvM a = EnvM (StateT Output (Reader Env) a)
    deriving (Functor, Applicative, Monad, MonadState Output, MonadReader Env)
-- | Everything emitted so far, keyed by the name it was emitted under.
type Output = Map.Map Ident Outputted
-- | One entry of the 'Output' map, describing outputted top-level information.
data Outputted
    = Marked
    -- ^ The bind is in the process of being monomorphized and must not be
    -- monomorphized again.
    | Complete M.Bind
    -- ^ A finished monomorphic bind.
    | Data M.Type L.Data
    -- ^ A used constructor: the type it was used at together with the
    -- polymorphic data definition it belongs to. These are turned into
    -- monomorphic data declarations in a later step.
    deriving (Show)
-- | Static (read-only) environment carried by the 'Reader' layer of 'EnvM'.
data Env = Env
    { input :: Map.Map Ident L.Bind
    -- ^ All binds in the program, keyed by name.
    , dataDefs :: Map.Map Ident L.Data
    -- ^ All constructors mapped to their respective polymorphic data
    -- definition, which includes all other constructors of that type.
    , polys :: Map.Map Ident M.Type
    -- ^ Maps polymorphic (type variable) identifiers to concrete types.
    , locals :: Set.Set Ident
    -- ^ Local variables of the scope currently being processed.
    }
-- | Tells whether the identifier names a local variable in the current scope.
localExists :: Ident -> EnvM Bool
localExists ident = do
    scope <- asks locals
    pure (ident `Set.member` scope)
-- | Looks up a top-level bind of the input program by its name.
getInputBind :: Ident -> EnvM (Maybe L.Bind)
getInputBind ident = Map.lookup ident <$> asks input
{- | Stores a finished monomorphic bind in the output under its own name,
replacing whatever entry (for instance a 'Marked' one) was there before.
-}
addOutputBind :: M.Bind -> EnvM ()
addOutputBind bind = modify (Map.insert (nameOf bind) (Complete bind))
  where
    nameOf (M.Bind (ident, _) _ _) = ident
    nameOf (M.BindC _ (ident, _) _ _) = ident
{- | Flags a global bind as currently being processed, so that meeting it
again (recursion, cyclic calls) does not trigger another round of processing.
-}
markBind :: Ident -> EnvM ()
markBind ident = modify $ \out -> Map.insert ident Marked out
{- | Tells whether the output already has an entry under this name. Any kind
of entry counts, not only 'Marked' ones.
-}
isBindMarked :: Ident -> EnvM Bool
isBindMarked = gets . Map.member
-- | Tells whether a constructor with this name has an entry in the output.
isConsMarked :: Ident -> EnvM Bool
isConsMarked ident = gets $ \out -> ident `Map.member` out
{- | Fetches the bind called @main@ from the input program. Raises a runtime
error if the program has no such bind.
-}
getMain :: EnvM L.Bind
getMain = asks (maybe missing id . Map.lookup (Ident "main") . input)
  where
    missing = error "main not found in monomorphizer!"
{- | Walks a polymorphic type and a monomorphic type in lockstep and collects,
for every type variable met in the former, the concrete type found at the same
position in the latter. Raises a runtime error when the two types do not have
the same structure. The first argument is the name of the bind and is only
used in the error message.
-}
mapTypes :: Ident -> L.Type -> M.Type -> [(Ident, M.Type)]
mapTypes ident poly mono = case (poly, mono) of
    (L.TLit _, M.TLit _) -> []
    (L.TVar (L.MkTVar tvar), _) -> [(tvar, mono)]
    (L.TFun pArg pRes, M.TFun mArg mRes) ->
        mapTypes ident pArg mArg ++ mapTypes ident pRes mRes
    (L.TData pName pArgs, M.TData mName mArgs)
        | pName /= mName ->
            error "the data type names of monomorphic and polymorphic data types does not match"
        | otherwise ->
            -- Mappings of later type arguments come first in the result.
            concatMap (uncurry (mapTypes ident)) (reverse (zip pArgs mArgs))
    _ ->
        error $
            "in bind: '"
                ++ printTree ident
                ++ "', "
                ++ "structure of types not the same: '"
                ++ printTree poly
                ++ "', '"
                ++ printTree mono
                ++ "'"
{- | Translates a polymorphic type into a monomorphic one by replacing every
type variable with the concrete type recorded for it in the current context.
A type variable without a recorded type becomes the literal type @void@.
-}
getMonoFromPoly :: L.Type -> EnvM M.Type
getMonoFromPoly polyType = asks (\env -> resolve (polys env) polyType)
  where
    resolve :: Map.Map Ident M.Type -> L.Type -> M.Type
    resolve table = go
      where
        go (L.TLit ident) = M.TLit ident
        go (L.TFun from to) = M.TFun (go from) (go to)
        go (L.TVar (L.MkTVar ident)) =
            Map.findWithDefault (M.TLit (Ident "void")) ident table
        go (L.TData ident args) = M.TData ident (map go args)
{- | Monomorphizes a bind at the given expected type, unless the output
already has an entry under the resulting name. The morphed bind, and every
bind reached while morphing its body, ends up in the output.
Returns the type-annotated name of the bind.
-}
morphBind :: M.Type -> L.Bind -> EnvM Ident
morphBind expectedType bind = case bind of
    L.Bind (ident, btype) args (body, bodyT) ->
        inBindScope ident btype args $ do
            bodyT' <- getMonoFromPoly bodyT
            body' <- morphExp bodyT' body
            -- Monomorphic types of the arguments.
            args' <- mapM morphArg args
            pure $ M.Bind (name', expectedType) args' (body', bodyT')
    L.BindC cxt (ident, btype) args (body, bodyT) ->
        inBindScope ident btype args $ do
            -- Monomorphic types of the arguments.
            args' <- mapM morphArg args
            cxt' <- mapM (secondM getMonoFromPoly) cxt
            bodyT' <- getMonoFromPoly bodyT
            -- The context variables are locals while the body is processed.
            body' <-
                local
                    (\env -> foldr (addLocal . fst) env cxt)
                    (morphExp bodyT' body)
            pure $ M.BindC cxt' (name', expectedType) args' (body', bodyT')
  where
    -- The annotated name doubles as the key used to detect earlier work.
    name' = newFuncName expectedType bind

    -- Runs the builder with the bind's arguments as the only locals and its
    -- type variables mapped against the expected type, marking the name
    -- first so that recursive or cyclic references do not process it again.
    inBindScope :: Ident -> L.Type -> [(Ident, L.Type)] -> EnvM M.Bind -> EnvM Ident
    inBindScope ident btype args build = do
        bindMarked <- isBindMarked name'
        local scoped $
            if bindMarked
                then pure name'
                else do
                    markBind name'
                    build >>= addOutputBind
                    pure name'
      where
        scoped env =
            env
                { locals = Set.fromList (map fst args)
                , polys = Map.fromList (mapTypes ident btype expectedType)
                }
-- | Replaces the type of a bind argument with its monomorphic counterpart.
morphArg :: (Ident, L.Type) -> EnvM (Ident, M.Type)
morphArg (ident, t) = (,) ident <$> getMonoFromPoly t
-- | Finds the data definition that declares the constructor with this name.
getInputData :: Ident -> EnvM (Maybe L.Data)
getInputData ident = asks (Map.lookup ident . dataDefs)
{- | Registers a constructor, looked up by its global name, in the output.
Constructors may appear as expressions in the tree, or as patterns in
case-expressions. 'ident' is the general name used for the lookup, while
'newIdent' is the unique name the entry is stored under. Nothing is stored
when no data definition declares 'ident' (closures can have unbound
variables).
-}
morphCons :: M.Type -> Ident -> Ident -> EnvM ()
morphCons expectedType ident newIdent =
    getInputData ident >>= maybe (pure ()) register
  where
    register d = modify (Map.insert newIdent (Data expectedType d))
-- | Maps a literal of the input tree to the same literal of the output tree.
convertLit :: L.Lit -> M.Lit
convertLit lit = case lit of
    L.LInt v -> M.LInt v
    L.LChar v -> M.LChar v
{- | Monomorphizes an expression, given the type it is expected to have.
References to global binds trigger monomorphization of those binds at the
expected type, and are replaced by the resulting annotated name. A reference
to a name that is neither local nor a known global bind raises an
"unbound variable" runtime error.
-}
morphExp :: M.Type -> L.Exp -> EnvM M.Exp
morphExp expectedType exp = case exp of
    L.ELit lit -> return $ M.ELit lit
    -- Constructor: registered under a name derived from its data type.
    L.EInj ident -> do
        let ident' = newName (getDataType expectedType) ident
        morphCons expectedType ident ident'
        return $ M.EVar ident'
    L.EApp (e1, _t1) (e2, t2) -> do
        t2' <- getMonoFromPoly t2
        e2' <- morphExp t2' e2
        -- The function is expected to go from the argument type to the
        -- type expected of the whole application.
        let funType = M.TFun t2' expectedType
        e1' <- morphExp funType e1
        return $ M.EApp (e1', funType) (e2', t2')
    L.EAdd (e1, t1) (e2, t2) -> do
        t1' <- getMonoFromPoly t1
        t2' <- getMonoFromPoly t2
        e1' <- morphExp t1' e1
        e2' <- morphExp t2' e2
        -- NOTE(review): the operands are annotated with the expected type of
        -- the sum rather than with t1'/t2'; presumably intentional, confirm.
        return $ M.EAdd (e1', expectedType) (e2', expectedType)
    L.ECase (scrutinee, t) bs -> do
        t' <- getMonoFromPoly t
        scrutinee' <- morphExp t' scrutinee
        bs' <- mapM morphBranch bs
        -- Branches whose pattern could not be morphed are dropped.
        return $ M.ECase (scrutinee', t') (catMaybes bs')
    -- Ideally constructors should be EInj, though this code handles them
    -- as well.
    -- FIXME MAKE EVAR AND EINJ SEPARATE!!!
    L.EVar ident -> do
        isLocal <- localExists ident
        if isLocal
            then return $ M.EVar ident
            else do
                bind <- requireBind ident
                -- New bind to process
                newBindName <- morphBind expectedType bind
                return $ M.EVar newBindName
    L.EVarC as ident -> do
        isLocal <- localExists ident
        if isLocal
            then return $ M.EVar ident
            else do
                bind <- requireBind ident
                as' <- mapM (secondM getMonoFromPoly) as
                -- New bind to process
                newBindName <- morphBind expectedType bind
                return $ M.EVarC as' newBindName
    L.ELet (identB, tB) (expB, tExpB) (body, tBody) -> do
        tB' <- getMonoFromPoly tB
        tExpB' <- getMonoFromPoly tExpB
        tBody' <- getMonoFromPoly tBody
        expB' <- morphExp tExpB' expB
        -- The let-bound name is a local only inside the body.
        body' <- local (addLocal identB) (morphExp tBody' body)
        return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (body', tBody')
  where
    -- Looks up a global bind, failing with a descriptive error instead of
    -- an anonymous 'fromJust' crash when it does not exist.
    requireBind :: Ident -> EnvM L.Bind
    requireBind ident =
        getInputBind ident
            >>= maybe
                (error $ "unbound variable: '" ++ printTree ident ++ "'")
                return
{- | Monomorphizes one branch of a case-expression. Yields 'Nothing' when the
pattern of the branch cannot be morphed; otherwise the branch body is
processed with the variables bound by the pattern added to the locals.
-}
morphBranch :: L.Branch -> EnvM (Maybe M.Branch)
morphBranch (L.Branch (p, pt) (e, et)) = do
    pt' <- getMonoFromPoly pt
    et' <- getMonoFromPoly et
    morphed <- morphPattern p pt'
    case morphed of
        Nothing -> pure Nothing
        Just (p', bound) -> do
            e' <- local (bindAll bound) (morphExp et' e)
            pure (Just (M.Branch p' (e', et')))
  where
    bindAll bound env = env { locals = Set.union (locals env) bound }
{- | Monomorphizes a pattern that is expected to have the given type.
On success, returns the typed pattern together with the set of variables it
binds. Returns 'Nothing' for a constructor pattern whose monomorphic
constructor has no entry in the output, or whose sub-patterns cannot all be
morphed.
-}
morphPattern :: L.Pattern -> M.Type -> EnvM (Maybe (M.T M.Pattern, Set.Set Ident))
morphPattern p expectedType = case p of
    L.PVar ident -> return $ Just ((M.PVar ident, expectedType), Set.singleton ident)
    L.PLit lit -> return $ Just ((M.PLit (convertLit lit), expectedType), Set.empty)
    L.PCatch -> return $ Just ((M.PCatch, expectedType), Set.empty)
    L.PEnum ident -> return $ Just ((M.PEnum (newName expectedType ident), expectedType), Set.empty)
    L.PInj ident pts -> do
        let newIdent = newName expectedType ident
        isMarked <- isConsMarked newIdent
        if isMarked
            then do
                ts' <- mapM (getMonoFromPoly . snd) pts
                let pts' = zip (map fst pts) ts'
                psSets <- mapM (uncurry morphPattern) pts'
                -- A single sub-pattern that cannot be morphed makes the
                -- whole constructor pattern fail.
                let combine subs =
                        ( (M.PInj newIdent (map fst subs), expectedType)
                        , Set.unions $ map snd subs
                        )
                return (combine <$> sequence psSets)
            else return Nothing
{- | Builds the name a bind gets once it has been assigned a type. The bind
called @main@ keeps its name; every other bind gets the type-annotated name
produced by 'newName'.
-}
newFuncName :: M.Type -> L.Bind -> Ident
newFuncName t bind = case bind of
    L.Bind (ident, _) _ _ -> rename ident
    L.BindC _ (ident, _) _ _ -> rename ident
  where
    rename ident@(Ident bindName)
        | bindName == "main" = Ident bindName
        | otherwise = newName t ident
{- | Appends a textual rendering of the type to the identifier, separated by
@$@. Literal types render as their name, function types as the renderings of
both sides joined by @_@, and data types as their name followed by each type
argument prefixed with a dot.

NOTE(review): function types are rendered without grouping, so
@(a -> b) -> c@ and @a -> (b -> c)@ produce the same text.
-}
newName :: M.Type -> Ident -> Ident
newName t (Ident base) = Ident (base ++ "$" ++ mangle t)
  where
    mangle :: M.Type -> String
    mangle (M.TLit (Ident lit)) = lit
    mangle (M.TFun from to) = mangle from ++ "_" ++ mangle to
    mangle (M.TData (Ident name) args) =
        name ++ concatMap (\arg -> "." ++ mangle arg) args
{- | Monomorphization step. Starts from the @main@ bind, which is morphed at
the literal type @Int@, collects everything reachable from it, and hands the
resulting program to 'removeDataTypes'.
-}
monomorphize :: L.Program -> O.Program
monomorphize (L.Program defs) = removeDataTypes (M.Program outputDefs)
  where
    outputDefs =
        getDefsFromOutput (runEnvM Map.empty (createEnv defs) morphFromMain)

    morphFromMain :: EnvM ()
    morphFromMain = do
        mainBind <- getMain
        -- Only the output state matters here; the returned name is unused.
        _ <- morphBind (M.TLit (Ident "Int")) mainBind
        pure ()
{- | Runs the computation from the given initial output and environment, and
returns the final output.
-}
runEnvM :: Output -> Env -> EnvM () -> Output
runEnvM o env (EnvM stateM) = finalOutput
  where
    (_, finalOutput) = runReader (runStateT stateM o) env
{- | Builds the initial environment from the definitions of the program:
binds are indexed by name, and every constructor name is mapped to the data
definition declaring it. No type variables are mapped and no locals exist
yet.
-}
createEnv :: [L.Def] -> Env
createEnv defs =
    Env
        { input = Map.fromList [(getBindName b, b) | b <- getBindsFromDefs defs]
        , dataDefs = Map.fromList consPairs
        , polys = Map.empty
        , locals = Set.empty
        }
  where
    consPairs :: [(Ident, L.Data)]
    consPairs = concatMap consOf (reverse (getDataFromDefs defs))

    -- Pairs each constructor name with the whole definition it belongs to.
    consOf d@(L.Data _ cs) = [(getConsName c, d) | c <- cs]
-- | Gets the name of a top-level function.
getBindName :: L.Bind -> Ident
getBindName bind = case bind of
    L.Bind (ident, _) _ _ -> ident
    L.BindC _ (ident, _) _ _ -> ident
-- Helper functions

{- | Picks the custom data declarations out of a list of definitions. The
result lists them in the reverse of their order in the input.
-}
getDataFromDefs :: [L.Def] -> [L.Data]
getDataFromDefs = go []
  where
    go acc [] = acc
    go acc (def : rest) = case def of
        L.DBind _ -> go acc rest
        L.DData d -> go (d : acc) rest
-- | Gets the name of a constructor.
getConsName :: L.Inj -> Ident
getConsName inj = case inj of
    L.Inj ident _ -> ident
{- | Picks the binds out of a list of definitions. The result lists them in
the reverse of their order in the input.
-}
getBindsFromDefs :: [L.Def] -> [L.Bind]
getBindsFromDefs = go []
  where
    go acc [] = acc
    go acc (def : rest) = case def of
        L.DBind b -> go (b : acc) rest
        L.DData _ -> go acc rest
{- | Turns the final output into definitions: first all completed binds, then
the monomorphic data declarations built from the constructors that were used.
-}
getDefsFromOutput :: Output -> [M.Def]
getDefsFromOutput o = bindDefs ++ dataDecls
  where
    (binds, dataInput) = splitBindsAndData o
    bindDefs = map M.DBind binds
    dataDecls = map M.DData (Map.elems (createNewData dataInput Map.empty))
{- | Separates the output into completed binds and constructor entries (the
latter are the input of 'createNewData'). Both lists come out in descending
key order. Meeting an entry that is still 'Marked' is an internal error.
-}
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, L.Data)])
splitBindsAndData = Map.foldlWithKey collect ([], [])
  where
    collect (oBinds, oData) ident entry = case entry of
        Marked -> error "internal bug in monomorphizer"
        Complete b -> (b : oBinds, oData)
        Data t d -> (oBinds, (ident, t, d) : oData)
{- | Converts all found constructors into monomorphic data declarations, added
to the given map. Constructors whose data types get the same monomorphic name
are gathered in one declaration.
-}
createNewData :: [(Ident, M.Type, L.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
createNewData entries o = foldl addCons o entries
  where
    addCons acc (consIdent, consType, polyData) =
        Map.insertWith merge newDataName (M.Data newDataType [newCons]) acc
      where
        L.Data (L.TData polyDataIdent _) _ = polyData
        newDataType = getDataType consType
        newDataName = newName newDataType polyDataIdent
        newCons = M.Inj consIdent consType
        -- The declaration already exists: prepend this constructor to it.
        merge _ (M.Data _ cs) = M.Data newDataType (newCons : cs)
{- | Gets the data type of a constructor type, that is the final result type
of a chain of function types (@a -> Just a@ becomes @Just a@). Raises a
runtime error naming the offending type when that result is not a data type.
-}
getDataType :: M.Type -> M.Type
getDataType (M.TFun _t1 t2) = getDataType t2
getDataType tData@(M.TData _ _) = tData
getDataType t =
    error $
        "getDataType: expected a data type or a function returning one, got '"
            ++ printTree t
            ++ "'"
-- | Adds an identifier to the local variables of the environment.
addLocal :: Ident -> Env -> Env
addLocal x env = env { locals = Set.insert x (locals env) }