Add closures and fix lets in monomorphizer

This commit is contained in:
Martin Fredin 2023-05-06 22:49:08 +02:00
parent 677a200a15
commit 72e599d5de
26 changed files with 1440 additions and 692 deletions

View file

@ -1,8 +1,11 @@
module Monomorphizer.DataTypeRemover (removeDataTypes) where
import Monomorphizer.MonomorphizerIr qualified as M2
import Monomorphizer.MorbIr qualified as M1
import TypeChecker.TypeCheckerIr (Ident (Ident))
import Data.Bifunctor (Bifunctor (bimap))
import Monomorphizer.MonomorphizerIr (Ident (..))
import qualified Monomorphizer.MonomorphizerIr as M2
import qualified Monomorphizer.MorbIr as M1
import Prelude hiding (exp)
removeDataTypes :: M1.Program -> M2.Program
removeDataTypes (M1.Program defs) = M2.Program (map pDef defs)
@ -18,43 +21,43 @@ pCons :: M1.Inj -> M2.Inj
pCons (M1.Inj ident t) = M2.Inj ident (pType t)
pType :: M1.Type -> M2.Type
pType (M1.TLit ident) = M2.TLit ident
pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2)
pType (M1.TLit ident) = M2.TLit ident
pType (M1.TFun t1 t2) = M2.TFun (pType t1) (pType t2)
pType (M1.TData (Ident "Bool") _) = M2.TLit (Ident "Bool")
pType d = M2.TLit (Ident (newName d)) -- This is the step
pType d = M2.TLit (Ident (newName d)) -- This is the step
newName :: M1.Type -> String
newName (M1.TLit (Ident str)) = str
newName (M1.TFun t1 t2) = newName t1 ++ newName t2
newName (M1.TLit (Ident str)) = str
newName (M1.TFun t1 t2) = newName t1 ++ newName t2
newName (M1.TData (Ident str) args) = str ++ concatMap newName args
pBind :: M1.Bind -> M2.Bind
pBind (M1.Bind id argIds expt) = M2.Bind (pId id) (map pId argIds) (pExpT expt)
pBind (M1.BindC cxt id argIds expt) =
M2.BindC (map pId cxt) (pId id) (map pId argIds) (pExpT expt)
pId :: (Ident, M1.Type) -> (Ident, M2.Type)
pId (ident, t) = (ident, pType t)
pExpT :: M1.ExpT -> M2.ExpT
pExpT :: M1.T M1.Exp -> M2.T M2.Exp
pExpT (exp, t) = (pExp exp, pType t)
pExp :: M1.Exp -> M2.Exp
pExp (M1.EVar ident) = M2.EVar ident
pExp (M1.ELit lit) = M2.ELit (pLit lit)
pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt)
pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2)
pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2)
pExp (M1.EVar ident) = M2.EVar ident
pExp (M1.EVarC as ident) = M2.EVarC (map pId as) ident
pExp (M1.ELit lit) = M2.ELit lit
pExp (M1.ELet bind expt) = M2.ELet (pBind bind) (pExpT expt)
pExp (M1.EApp e1 e2) = M2.EApp (pExpT e1) (pExpT e2)
pExp (M1.EAdd e1 e2) = M2.EAdd (pExpT e1) (pExpT e2)
pExp (M1.ECase expT branches) = M2.ECase (pExpT expT) (map pBranch branches)
pBranch :: M1.Branch -> M2.Branch
pBranch (M1.Branch (patt, t) expt) = M2.Branch (pPattern patt, pType t) (pExpT expt)
pPattern :: M1.Pattern -> M2.Pattern
pPattern (M1.PVar id) = M2.PVar (pId id)
pPattern (M1.PLit (lit, t)) = M2.PLit (pLit lit, pType t)
pPattern (M1.PInj ident patts) = M2.PInj ident (map pPattern patts)
pPattern M1.PCatch = M2.PCatch
pPattern (M1.PEnum ident) = M2.PEnum ident
pPattern (M1.PVar ident) = M2.PVar ident
pPattern (M1.PLit lit) = M2.PLit lit
pPattern (M1.PInj ident patts) = M2.PInj ident (map (bimap pPattern pType) patts)
pPattern M1.PCatch = M2.PCatch
pPattern (M1.PEnum ident) = M2.PEnum ident
pLit :: M1.Lit -> M2.Lit
pLit (M1.LInt v) = M2.LInt v
pLit (M1.LChar c) = M2.LChar c

View file

@ -1,4 +1,5 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{- | For now, converts polymorphic functions to concrete ones based on usage.
Assumes lambdas are lifted.
@ -25,30 +26,35 @@ bind) is added to the resulting set of binds.
-}
module Monomorphizer.Monomorphizer (monomorphize, morphExp, morphBind) where
import Monomorphizer.DataTypeRemover (removeDataTypes)
import Monomorphizer.MonomorphizerIr qualified as O
import Monomorphizer.MorbIr qualified as M
import TypeChecker.TypeCheckerIr (Ident (Ident))
import TypeChecker.TypeCheckerIr qualified as T
import Control.Monad.Reader (
MonadReader (ask, local),
Reader,
asks,
runReader,
)
import Control.Monad.State (
MonadState (get),
StateT (runStateT),
gets,
modify,
)
import Data.Coerce (coerce)
import Data.Map qualified as Map
import Data.Maybe (catMaybes)
import Data.Set qualified as Set
import Grammar.Print (printTree)
import Debug.Trace (trace)
import Control.Monad.Reader (MonadReader (ask, local),
Reader, asks, runReader)
import Control.Monad.State (MonadState (get),
StateT (runStateT), gets,
modify)
import Data.Coerce (coerce)
import qualified Data.Map as Map
import Data.Maybe (catMaybes)
import qualified Data.Set as Set
import Debug.Trace (trace)
import Grammar.Print (printTree)
import Monomorphizer.DataTypeRemover (removeDataTypes)
import qualified Monomorphizer.MonomorphizerIr as O
import qualified Monomorphizer.MorbIr as M
-- import TypeChecker.TypeCheckerIr (Ident (Ident))
import LambdaLifterIr (Ident (..))
-- import TypeChecker.TypeCheckerIr qualified as T
import qualified LambdaLifterIr as L
import Control.Monad.Reader (MonadReader (ask, local),
Reader, asks, runReader)
import Control.Monad.State (MonadState, StateT (runStateT),
gets, modify)
import qualified Data.Map as Map
import Data.Maybe (catMaybes, fromJust)
import qualified Data.Set as Set
import Data.Tuple.Extra (secondM)
import Grammar.Print (printTree)
{- | EnvM is the monad containing the read-only state as well as the
output state containing monomorphized functions and to-be monomorphized
@ -64,18 +70,18 @@ Binds, Polymorphic Data types (monomorphized in a later step) and
Marked bind, which means that it is in the process of monomorphization
and should not be monomorphized again.
-}
data Outputted = Marked | Complete M.Bind | Data M.Type T.Data deriving (Show)
data Outputted = Marked | Complete M.Bind | Data M.Type L.Data deriving (Show)
-- | Static environment.
data Env = Env
{ input :: Map.Map Ident T.Bind
{ input :: Map.Map Ident L.Bind
-- ^ All binds in the program.
, dataDefs :: Map.Map Ident T.Data
, dataDefs :: Map.Map Ident L.Data
-- ^ All constructors mapped to their respective polymorphic data def
-- which includes all other constructors.
, polys :: Map.Map Ident M.Type
, polys :: Map.Map Ident M.Type
-- ^ Maps polymorphic identifiers with concrete types.
, locals :: Set.Set Ident
, locals :: Set.Set Ident
-- ^ Local variables.
}
@ -84,12 +90,13 @@ localExists :: Ident -> EnvM Bool
localExists ident = asks (Set.member ident . locals)
-- | Gets a polymorphic bind from an id.
getInputBind :: Ident -> EnvM (Maybe T.Bind)
getInputBind :: Ident -> EnvM (Maybe L.Bind)
getInputBind ident = asks (Map.lookup ident . input)
-- | Add monomorphic function derived from a polymorphic one, to env.
addOutputBind :: M.Bind -> EnvM ()
addOutputBind b@(M.Bind (ident, _) _ _) = modify (Map.insert ident (Complete b))
addOutputBind b@(M.BindC _ (ident, _) _ _) = modify (Map.insert ident (Complete b))
{- | Marks a global bind as being processed, meaning that when encountered again,
it should not be recursively processed.
@ -106,8 +113,8 @@ isConsMarked :: Ident -> EnvM Bool
isConsMarked ident = gets (Map.member ident)
-- | Finds main bind.
getMain :: EnvM T.Bind
getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of
getMain :: EnvM L.Bind
getMain = asks (\env -> case Map.lookup (Ident "main") (input env) of
Just mainBind -> mainBind
Nothing -> error "main not found in monomorphizer!"
)
@ -116,13 +123,13 @@ getMain = asks (\env -> case Map.lookup (T.Ident "main") (input env) of
error when encountering different structures between the two arguments. Debug:
First argument is the name of the bind.
-}
mapTypes :: Ident -> T.Type -> M.Type -> [(Ident, M.Type)]
mapTypes _ident (T.TLit _) (M.TLit _) = []
mapTypes _ident (T.TVar (T.MkTVar i1)) tm = [(i1, tm)]
mapTypes ident (T.TFun pt1 pt2) (M.TFun mt1 mt2) =
mapTypes :: Ident -> L.Type -> M.Type -> [(Ident, M.Type)]
mapTypes _ident (L.TLit _) (M.TLit _) = []
mapTypes _ident (L.TVar (L.MkTVar i1)) tm = [(i1, tm)]
mapTypes ident (L.TFun pt1 pt2) (M.TFun mt1 mt2) =
mapTypes ident pt1 mt1
++ mapTypes ident pt2 mt2
mapTypes ident (T.TData tIdent pTs) (M.TData mIdent mTs) =
mapTypes ident (L.TData tIdent pTs) (M.TData mIdent mTs) =
if tIdent /= mIdent
then error "the data type names of monomorphic and polymorphic data types does not match"
else foldl (\xs (p, m) -> mapTypes ident p m ++ xs) [] (zip pTs mTs)
@ -130,30 +137,30 @@ mapTypes ident t1 t2 = error $ "in bind: '" ++ printTree ident ++ "', " ++
"structure of types not the same: '" ++ printTree t1 ++ "', '" ++ printTree t2 ++ "'"
-- | Gets the mapped monomorphic type of a polymorphic type in the current context.
getMonoFromPoly :: T.Type -> EnvM M.Type
getMonoFromPoly :: L.Type -> EnvM M.Type
getMonoFromPoly t = do
env <- ask
return $ getMono (polys env) t
where
getMono :: Map.Map Ident M.Type -> T.Type -> M.Type
getMono :: Map.Map Ident M.Type -> L.Type -> M.Type
getMono polys t = case t of
(T.TLit ident) -> M.TLit (coerce ident)
(T.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2)
(T.TVar (T.MkTVar ident)) -> case Map.lookup ident polys of
(L.TLit ident) -> M.TLit ident
(L.TFun t1 t2) -> M.TFun (getMono polys t1) (getMono polys t2)
(L.TVar (L.MkTVar ident)) -> case Map.lookup ident polys of
Just concrete -> concrete
Nothing -> M.TLit (Ident "void")
Nothing -> M.TLit (Ident "void")
-- error $ "type not found! type: " ++ show ident ++ ", error in previous compilation steps"
(T.TData ident args) -> M.TData ident (map (getMono polys) args)
(L.TData ident args) -> M.TData ident (map (getMono polys) args)
{- | If ident not already in env's output, morphed bind to output
(and all referenced binds within this bind).
Returns the annotated bind name.
-}
morphBind :: M.Type -> T.Bind -> EnvM Ident
morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do
morphBind :: M.Type -> L.Bind -> EnvM Ident
morphBind expectedType b@(L.Bind (ident, btype) args (exp, expt)) = do
-- The "new name" is used to find out if it is already marked or not.
let name' = newFuncName expectedType b
bindMarked <- isBindMarked (coerce name')
bindMarked <- isBindMarked name'
local
( \env ->
env
@ -168,26 +175,59 @@ morphBind expectedType b@(T.Bind (ident, btype) args (exp, expt)) = do
else do
-- Mark so that this bind will not be processed in recursive or cyclic
-- function calls
markBind (coerce name')
markBind name'
expt' <- getMonoFromPoly expt
exp' <- morphExp expt' exp
-- Get monomorphic type sof args
args' <- mapM morphArg args
addOutputBind $
M.Bind
(coerce name', expectedType)
(name', expectedType)
args'
(exp', expt')
return name'
morphBind expectedType b@(L.BindC cxt (ident, btype) args (exp, expt)) = do
-- The "new name" is used to find out if it is already marked or not.
let name' = newFuncName expectedType b
bindMarked <- isBindMarked name'
local
( \env ->
env
{ locals = Set.fromList (map fst args)
, polys = Map.fromList (mapTypes ident btype expectedType)
}
)
$ do
-- Return with right name if already marked
if bindMarked
then return name'
else do
-- Mark so that this bind will not be processed in recursive or cyclic
-- function calls
markBind name'
-- Get monomorphic type sof args
args' <- mapM morphArg args
cxt' <- mapM (secondM getMonoFromPoly) cxt
expt' <- getMonoFromPoly expt
exp' <- local (\env -> foldr (addLocal . fst) env cxt)
(morphExp expt' exp)
addOutputBind $
M.BindC cxt'
(name', expectedType)
args'
(exp', expt')
return name'
-- | Monomorphizes arguments of a bind.
morphArg :: (Ident, T.Type) -> EnvM (Ident, M.Type)
morphArg :: (Ident, L.Type) -> EnvM (Ident, M.Type)
morphArg (ident, t) = do
t' <- getMonoFromPoly t
return (ident, t')
-- | Gets the data bind from the name of a constructor.
getInputData :: Ident -> EnvM (Maybe T.Data)
getInputData :: Ident -> EnvM (Maybe L.Data)
getInputData ident = do
env <- ask
return $ Map.lookup ident (dataDefs env)
@ -201,50 +241,50 @@ morphCons expectedType ident newIdent = do
--trace ("Tjofras:" ++ show (newName expectedType ident)) $ return ()
maybeD <- getInputData ident
case maybeD of
Nothing -> error $ "identifier '" ++ show ident ++ "' not found"
-- closures can have unbound variables
Nothing -> pure ()
Just d -> do
modify (\output -> Map.insert newIdent (Data expectedType d) output)
-- | Converts literals from input to output tree.
convertLit :: T.Lit -> M.Lit
convertLit (T.LInt v) = M.LInt v
convertLit (T.LChar v) = M.LChar v
convertLit :: L.Lit -> M.Lit
convertLit (L.LInt v) = M.LInt v
convertLit (L.LChar v) = M.LChar v
-- | Monomorphizes an expression, given an expected type.
morphExp :: M.Type -> T.Exp -> EnvM M.Exp
morphExp :: M.Type -> L.Exp -> EnvM M.Exp
morphExp expectedType exp = case exp of
T.ELit lit -> return $ M.ELit (convertLit lit)
L.ELit lit -> return $ M.ELit lit
-- Constructor
T.EInj ident -> do
L.EInj ident -> do
let ident' = newName (getDataType expectedType) ident
morphCons expectedType ident ident'
return $ M.EVar ident'
T.EApp (e1, _t1) (e2, t2) -> do
L.EApp (e1, _t1) (e2, t2) -> do
t2' <- getMonoFromPoly t2
e2' <- morphExp t2' e2
e1' <- morphExp (M.TFun t2' expectedType) e1
return $ M.EApp (e1', M.TFun t2' expectedType) (e2', t2')
T.EAdd (e1, t1) (e2, t2) -> do
L.EAdd (e1, t1) (e2, t2) -> do
t1' <- getMonoFromPoly t1
t2' <- getMonoFromPoly t2
e1' <- morphExp t1' e1
e2' <- morphExp t2' e2
return $ M.EAdd (e1', expectedType) (e2', expectedType)
T.EAbs ident (exp, t) -> local (\env -> env{locals = Set.insert ident (locals env)}) $ do
t' <- getMonoFromPoly t
morphExp t' exp
T.ECase (exp, t) bs -> do
L.ECase (exp, t) bs -> do
t' <- getMonoFromPoly t
exp' <- morphExp t' exp
bs' <- mapM morphBranch bs
return $ M.ECase (exp', t') (catMaybes bs')
-- Ideally constructors should be EInj, though this code handles them
-- as well.
T.EVar ident -> do
-- FIXME MAKE EVAR AND EINJ SEPARATE!!!
L.EVar ident -> do
isLocal <- localExists ident
if isLocal
then do
return $ M.EVar (coerce ident)
return $ M.EVar ident
else do
bind <- getInputBind ident
case bind of
@ -252,38 +292,51 @@ morphExp expectedType exp = case exp of
Just bind' -> do
-- New bind to process
newBindName <- morphBind expectedType bind'
return $ M.EVar (coerce newBindName)
T.ELet (T.Bind (identB, tB) args (expB, tExpB)) (exp, tExp) ->
if length args > 0 then error "only constants in lets allowed"
else do
return $ M.EVar newBindName
L.EVarC as ident -> do
isLocal <- localExists ident
if isLocal
then do
return $ M.EVar ident
else do
bind <- fromJust <$> getInputBind ident
as' <- mapM (secondM getMonoFromPoly) as
-- New bind to process
newBindName <- morphBind expectedType bind
return $ M.EVarC as' newBindName
-- Ideally constructors should be EInj, though this code handles them
-- as well.
L.ELet (identB, tB) (expB, tExpB) (exp, tExp) -> do
tB' <- getMonoFromPoly tB
tExpB' <- getMonoFromPoly tExpB
tExp' <- getMonoFromPoly tExp
expB' <- morphExp tExpB' expB
exp' <- morphExp tExp' exp
exp' <- local (addLocal identB) (morphExp tExp' exp)
return $ M.ELet (M.Bind (identB, tB') [] (expB', tExpB')) (exp', tExp')
-- | Monomorphizes case-of branches.
morphBranch :: T.Branch -> EnvM (Maybe M.Branch)
morphBranch (T.Branch (p, pt) (e, et)) = do
morphBranch :: L.Branch -> EnvM (Maybe M.Branch)
morphBranch (L.Branch (p, pt) (e, et)) = do
pt' <- getMonoFromPoly pt
et' <- getMonoFromPoly et
env <- ask
maybeMorphedPattern <- morphPattern p pt'
case maybeMorphedPattern of
Nothing -> return Nothing
Just (p', newLocals) ->
Just (p', newLocals) ->
local (const env { locals = Set.union (locals env) newLocals }) $ do
e' <- morphExp et' e
return $ Just (M.Branch (p', pt') (e', et'))
return $ Just (M.Branch p' (e', et'))
morphPattern :: T.Pattern -> M.Type -> EnvM (Maybe (M.Pattern, Set.Set Ident))
morphPattern :: L.Pattern -> M.Type -> EnvM (Maybe (M.T M.Pattern, Set.Set Ident))
morphPattern p expectedType = case p of
T.PVar ident -> return $ Just (M.PVar (ident, expectedType), Set.singleton ident)
T.PLit lit -> return $ Just (M.PLit (convertLit lit, expectedType), Set.empty)
T.PCatch -> return $ Just (M.PCatch, Set.empty)
T.PEnum ident -> return $ Just (M.PEnum (newName expectedType ident), Set.empty)
T.PInj ident pts -> do let newIdent = newName expectedType ident
L.PVar ident -> return $ Just ((M.PVar ident, expectedType), Set.singleton ident)
L.PLit lit -> return $ Just ((M.PLit (convertLit lit), expectedType), Set.empty)
L.PCatch -> return $ Just ((M.PCatch, expectedType), Set.empty)
L.PEnum ident -> return $ Just ((M.PEnum (newName expectedType ident), expectedType), Set.empty)
L.PInj ident pts -> do let newIdent = newName expectedType ident
outEnv <- get
trace ("WOW: " ++ show (newName expectedType ident)) $ return ()
trace ("WOW2: " ++ show (outEnv)) $ return ()
@ -297,13 +350,18 @@ morphPattern p expectedType = case p of
let maybePsSets = sequence psSets
case maybePsSets of
Nothing -> return Nothing
Just psSets' -> return $ Just
(M.PInj newIdent (map fst psSets'), Set.unions $ map snd psSets')
Just psSets' -> return $ Just
((M.PInj newIdent (map fst psSets'), expectedType), Set.unions $ map snd psSets')
else return Nothing
-- | Creates a new identifier for a function with an assigned type.
newFuncName :: M.Type -> T.Bind -> Ident
newFuncName t (T.Bind (ident@(Ident bindName), _) _ _) =
newFuncName :: M.Type -> L.Bind -> Ident
newFuncName t (L.Bind (ident@(Ident bindName), _) _ _) =
if bindName == "main"
then Ident bindName
else newName t ident
newFuncName t (L.BindC _ (ident@(Ident bindName), _) _ _) =
if bindName == "main"
then Ident bindName
else newName t ident
@ -317,8 +375,8 @@ newName t (Ident str) = Ident $ str ++ "$" ++ newName' t
newName' (M.TData (Ident str) ts) = str ++ foldl (\s t -> s ++ "." ++ newName' t) "" ts
-- | Monomorphization step.
monomorphize :: T.Program -> O.Program
monomorphize (T.Program defs) =
monomorphize :: L.Program -> O.Program
monomorphize (L.Program defs) =
removeDataTypes $
M.Program
( getDefsFromOutput
@ -336,7 +394,7 @@ runEnvM :: Output -> Env -> EnvM () -> Output
runEnvM o env (EnvM stateM) = snd $ runReader (runStateT stateM o) env
-- | Creates the environment based on the input binds.
createEnv :: [T.Def] -> Env
createEnv :: [L.Def] -> Env
createEnv defs =
Env
{ input = Map.fromList bindPairs
@ -346,33 +404,34 @@ createEnv defs =
}
where
bindPairs = (map (\b -> (getBindName b, b)) . getBindsFromDefs) defs
dataPairs :: [(Ident, T.Data)]
dataPairs = (foldl (\acc d@(T.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs
dataPairs :: [(Ident, L.Data)]
dataPairs = (foldl (\acc d@(L.Data _ cs) -> map ((,d) . getConsName) cs ++ acc) [] . getDataFromDefs) defs
-- | Gets a top-lefel function name.
getBindName :: T.Bind -> Ident
getBindName (T.Bind (ident, _) _ _) = ident
getBindName :: L.Bind -> Ident
getBindName (L.Bind (ident, _) _ _) = ident
getBindName (L.BindC _ (ident, _) _ _) = ident
-- Helper functions
-- Gets custom data declarations form defs.
getDataFromDefs :: [T.Def] -> [T.Data]
getDataFromDefs :: [L.Def] -> [L.Data]
getDataFromDefs =
foldl
( \bs -> \case
T.DBind _ -> bs
T.DData d -> d : bs
L.DBind _ -> bs
L.DData d -> d : bs
)
[]
getConsName :: T.Inj -> Ident
getConsName (T.Inj ident _) = ident
getConsName :: L.Inj -> Ident
getConsName (L.Inj ident _) = ident
getBindsFromDefs :: [T.Def] -> [T.Bind]
getBindsFromDefs :: [L.Def] -> [L.Bind]
getBindsFromDefs =
foldl
( \bs -> \case
T.DBind b -> b : bs
T.DData _ -> bs
L.DBind b -> b : bs
L.DData _ -> bs
)
[]
@ -384,19 +443,19 @@ getDefsFromOutput o =
(binds, dataInput) = splitBindsAndData o
-- | Splits the output into binds and data declaration components (used in createNewData)
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, T.Data)])
splitBindsAndData :: Output -> ([M.Bind], [(Ident, M.Type, L.Data)])
splitBindsAndData output =
foldl
( \(oBinds, oData) (ident, o) -> case o of
Marked -> error "internal bug in monomorphizer"
Marked -> error "internal bug in monomorphizer"
Complete b -> (b : oBinds, oData)
Data t d -> (oBinds, (ident, t, d) : oData)
Data t d -> (oBinds, (ident, t, d) : oData)
)
([], [])
(Map.toList output)
-- | Converts all found constructors to monomorphic data declarations.
createNewData :: [(Ident, M.Type, T.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
createNewData :: [(Ident, M.Type, L.Data)] -> Map.Map Ident M.Data -> Map.Map Ident M.Data
createNewData [] o = o
createNewData ((consIdent, consType, polyData) : input) o =
createNewData input $
@ -406,14 +465,17 @@ createNewData ((consIdent, consType, polyData) : input) o =
(M.Data newDataType [newCons])
o
where
T.Data (T.TData polyDataIdent _) _ = polyData
L.Data (L.TData polyDataIdent _) _ = polyData
newDataType = getDataType consType
newDataName = newName newDataType polyDataIdent
newCons = M.Inj consIdent consType
-- | Gets the Data Type of a constructor type (a -> Just a becomes Just a).
getDataType :: M.Type -> M.Type
getDataType (M.TFun _t1 t2) = getDataType t2
getDataType (M.TFun _t1 t2) = getDataType t2
getDataType tData@(M.TData _ _) = tData
getDataType _ = error "???"
getDataType _ = error "???"
addLocal :: Ident -> Env -> Env
addLocal x env = env { locals = Set.insert x env.locals }

View file

@ -1,11 +1,14 @@
{-# LANGUAGE LambdaCase #-}
module Monomorphizer.MonomorphizerIr (module Monomorphizer.MonomorphizerIr) where
module Monomorphizer.MonomorphizerIr (
module Monomorphizer.MonomorphizerIr,
module LambdaLifterIr
) where
import Grammar.Print
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..))
type Id = (TIR.Ident, Type)
import Data.List (intercalate)
import Grammar.Print
import LambdaLifterIr (Ident (..), Lit (..))
import Prelude hiding (exp)
newtype Program = Program [Def]
deriving (Show, Ord, Eq)
@ -16,90 +19,80 @@ data Def = DBind Bind | DData Data
data Data = Data Type [Inj]
deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT
data Bind = Bind (T Ident) [T Ident] (T Exp)
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
deriving (Show, Ord, Eq)
type T a = (a, Type)
data Exp
= EVar TIR.Ident
= EVar Ident
| EVarC [T Ident] Ident
| ELit Lit
| ELet Bind ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
| ECase ExpT [Branch]
| ELet Bind (T Exp)
| EApp (T Exp) (T Exp)
| EAdd (T Exp) (T Exp)
| ECase (T Exp) [Branch]
deriving (Show, Ord, Eq)
data Pattern
= PVar Id
| PLit (Lit, Type)
| PInj TIR.Ident [Pattern]
= PVar Ident
| PLit Lit
| PInj Ident [T Pattern]
| PCatch
| PEnum TIR.Ident
| PEnum Ident
deriving (Eq, Ord, Show)
data Branch = Branch (Pattern, Type) ExpT
data Branch = Branch (T Pattern) (T Exp)
deriving (Eq, Ord, Show)
type ExpT = (Exp, Type)
data Inj = Inj TIR.Ident Type
data Inj = Inj Ident Type
deriving (Show, Ord, Eq)
data Lit
= LInt Integer
| LChar Char
deriving (Show, Ord, Eq)
data Type = TLit TIR.Ident | TFun Type Type
data Type = TLit Ident | TFun Type Type
deriving (Show, Ord, Eq)
flattenType :: Type -> [Type]
flattenType (TFun t1 t2) = t1 : flattenType t2
flattenType x = [x]
flattenType x = [x]
instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print (Bind) where
instance Print Bind where
prt i (Bind sig@(name, _) parms rhs) =
prPrec i 0 $
concatD
[ prtSig sig
[ prt 0 sig
, prt 0 name
, prtIdPs 0 parms
, prt 0 parms
, doc $ showString "="
, prt 0 rhs
]
prtSig :: Id -> Doc
prtSig (name, t) =
concatD
[ prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ";"
]
prt i (BindC cxt sig parms rhs) =
prPrec i 0 $
concatD
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
, prt i parms
, doc $ showString "="
, prt i rhs
]
instance Print (ExpT) where
prt i (e, t) =
concatD
[ doc $ showString "("
, prt i e
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print [Bind] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prt i)
instance Print Exp where
prt i = \case
EVar name -> prPrec i 3 $ prt 0 name
EVarC as lident -> doc . showString
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
where
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e ->
prPrec i 3 $
@ -134,16 +127,16 @@ instance Print Exp where
]
instance Print Branch where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp])
instance Print [Branch] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print Def where
prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print Data where
@ -152,23 +145,23 @@ instance Print Data where
instance Print Inj where
prt i = \case
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
Inj uident type_ -> prt i (uident, type_)
instance Print Pattern where
prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit])
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print [Def] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print [Type] where
prt _ [] = concatD []
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
instance Print Type where
@ -176,7 +169,3 @@ instance Print Type where
TLit uident -> prPrec i 1 (concatD [prt 0 uident])
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
instance Print Lit where
prt i = \case
LInt int -> prt i int
LChar char -> prt i char

View file

@ -1,10 +1,14 @@
{-# LANGUAGE LambdaCase #-}
module Monomorphizer.MorbIr where
import Grammar.Print
import TypeChecker.TypeCheckerIr qualified as TIR (Ident (..))
module Monomorphizer.MorbIr (
module Monomorphizer.MorbIr,
module LambdaLifterIr
) where
type Id = (TIR.Ident, Type)
import Data.List (intercalate)
import Grammar.Print
import LambdaLifterIr (Ident (..), Lit (..))
import Prelude hiding (exp)
newtype Program = Program [Def]
deriving (Show, Ord, Eq)
@ -15,91 +19,81 @@ data Def = DBind Bind | DData Data
data Data = Data Type [Inj]
deriving (Show, Ord, Eq)
data Bind = Bind Id [Id] ExpT
data Bind = Bind (T Ident) [T Ident] (T Exp)
| BindC [T Ident] (T Ident) [T Ident] (T Exp)
deriving (Show, Ord, Eq)
type T a = (a, Type)
data Exp
= EVar TIR.Ident
= EVar Ident
| EVarC [T Ident] Ident
| ELit Lit
| ELet Bind ExpT
| EApp ExpT ExpT
| EAdd ExpT ExpT
| ECase ExpT [Branch]
| ELet Bind (T Exp)
| EApp (T Exp) (T Exp)
| EAdd (T Exp) (T Exp)
| ECase (T Exp) [Branch]
deriving (Show, Ord, Eq)
data Pattern
= PVar Id
| PLit (Lit, Type)
| PInj TIR.Ident [Pattern]
= PVar Ident
| PLit Lit
| PInj Ident [T Pattern]
| PCatch
| PEnum TIR.Ident
| PEnum Ident
deriving (Eq, Ord, Show)
data Branch = Branch (Pattern, Type) ExpT
data Branch = Branch (T Pattern) (T Exp)
deriving (Eq, Ord, Show)
type ExpT = (Exp, Type)
data Inj = Inj TIR.Ident Type
data Inj = Inj Ident Type
deriving (Show, Ord, Eq)
data Lit
= LInt Integer
| LChar Char
deriving (Show, Ord, Eq)
data Type = TLit TIR.Ident | TFun Type Type | TData TIR.Ident [Type]
data Type = TLit Ident | TFun Type Type | TData Ident [Type]
deriving (Show, Ord, Eq)
flattenType :: Type -> [Type]
flattenType (TFun t1 t2) = t1 : flattenType t2
flattenType x = [x]
flattenType x = [x]
instance Print Program where
prt i (Program sc) = prPrec i 0 $ prt 0 sc
instance Print (Bind) where
instance Print Bind where
prt i (Bind sig@(name, _) parms rhs) =
prPrec i 0 $
concatD
[ prtSig sig
[ prt 0 sig
, prt 0 name
, prtIdPs 0 parms
, prt 0 parms
, doc $ showString "="
, prt 0 rhs
]
prtSig :: Id -> Doc
prtSig (name, t) =
concatD
[ prt 0 name
, doc $ showString ":"
, prt 0 t
, doc $ showString ";"
]
instance Print (ExpT) where
prt i (e, t) =
concatD
[ doc $ showString "("
, prt i e
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
prt i (BindC cxt sig parms rhs) =
prPrec i 0 $
concatD
[ doc . showString $ "{" ++ intercalate ", " (map (\(x, _) -> printTree x ++ "^") cxt) ++ "}" ++ printTree sig
, prt i parms
, doc $ showString "="
, prt i rhs
]
instance Print [Bind] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
prtIdPs :: Int -> [Id] -> Doc
prtIdPs i = prPrec i 0 . concatD . map (prt i)
instance Print Exp where
prt i = \case
EVar name -> prPrec i 3 $ prt 0 name
EVarC as lident -> doc . showString
$ "{" ++ intercalate ", " (map go as) ++ "}" ++ printTree lident
where
go (x, _) = printTree x ++ "^=" ++ printTree (EVar x)
ELit lit -> prPrec i 3 $ prt 0 lit
ELet b e ->
prPrec i 3 $
@ -134,16 +128,16 @@ instance Print Exp where
]
instance Print Branch where
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
prt i (Branch patt exp) = prPrec i 0 (concatD [prt i patt, doc (showString "=>"), prt 0 exp])
instance Print [Branch] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print Def where
prt i = \case
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
instance Print Data where
@ -152,23 +146,23 @@ instance Print Data where
instance Print Inj where
prt i = \case
Inj uident type_ -> prPrec i 0 (concatD [prt 0 uident, doc (showString ":"), prt 0 type_])
Inj uident type_ -> prt i (uident, type_)
instance Print Pattern where
prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name])
PLit (lit, _) -> prPrec i 1 (concatD [prt 0 lit])
PLit lit -> prPrec i 1 (concatD [prt 0 lit])
PCatch -> prPrec i 1 (concatD [doc (showString "_")])
PEnum name -> prPrec i 1 (concatD [prt 0 name])
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
instance Print [Def] where
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ [] = concatD []
prt _ [x] = concatD [prt 0 x]
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
instance Print [Type] where
prt _ [] = concatD []
prt _ [] = concatD []
prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
instance Print Type where
@ -177,8 +171,4 @@ instance Print Type where
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")])
instance Print Lit where
prt i = \case
LInt int -> prt i int
LChar char -> prt i char