236 lines
6.9 KiB
Haskell
236 lines
6.9 KiB
Haskell
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
|
|
|
|
module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
|
|
|
|
import Data.List (mapAccumL, partition)
|
|
import Data.Map (Map)
|
|
import qualified Data.Map as Map
|
|
import Data.Maybe (fromMaybe)
|
|
import Data.Set (Set, (\\))
|
|
import qualified Data.Set as Set
|
|
import Data.Tuple.Extra (uncurry3)
|
|
import Grammar.Abs
|
|
import Prelude hiding (exp)
|
|
|
|
|
|
|
|
-- | Lift lambdas and let expressions into supercombinators.
--
-- Runs the passes in order: 'freeVars', then 'abstract', then 'rename',
-- and finally 'collectScs'.
lambdaLift :: Program -> Program
lambdaLift prog = collectScs (rename (abstract (freeVars prog)))
|
|
|
|
|
|
-- | Annotate every top-level binding with free-variable information.
--
-- Each right-hand side is analysed with the binding's own parameters as
-- the initial set of local variables.
freeVars :: Program -> AnnProgram
freeVars (Program ds) = map annotate ds
  where
    annotate (Bind n xs e) = (n, xs, freeVarsExp (Set.fromList xs) e)
|
|
|
|
-- | Annotate an expression, and all of its sub-expressions, with the set
-- of free variables occurring in it.
--
-- The first argument is the set of local variables currently in scope.
-- Only identifiers from that set are ever reported as free; any other
-- identifier contributes nothing to the annotation.
freeVarsExp :: Set Ident -> Exp -> AnnExp
freeVarsExp lv = \case
    EId n
        | Set.member n lv -> (Set.singleton n, AId n)
        | otherwise -> (mempty, AId n)

    EInt i -> (mempty, AInt i)

    EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp e1' e2')
      where
        e1' = freeVarsExp lv e1
        e2' = freeVarsExp lv e2

    EAdd e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd e1' e2')
      where
        e1' = freeVarsExp lv e1
        e2' = freeVarsExp lv e2

    -- The lambda's parameters are local inside the body but are not free
    -- in the lambda itself.
    EAbs parms e -> (freeVarsOf e' \\ Set.fromList parms, AAbs parms e')
      where
        e' = freeVarsExp (foldr Set.insert lv parms) e

    ELet bs e -> (Set.union bsFree eFree, ALet bs' e')
      where
        (ns, xs, es) = fromBinders bs
        nsSet = Set.fromList ns

        -- The bound names are in scope in the body and in every
        -- right-hand side.
        e_lv = Set.union lv nsSet

        -- Each right-hand side additionally sees its own binding's
        -- parameters, mirroring how 'freeVars' treats top-level bindings.
        annotateRhs ps = freeVarsExp (Set.union e_lv (Set.fromList ps))
        es' = zipWith annotateRhs xs es

        -- A binding's parameters are bound by that binding, so they are
        -- not free in the let as a whole.
        rhsFree ps rhs' = freeVarsOf rhs' \\ Set.fromList ps

        -- 'Set.unions' is total, so a let without bindings yields the
        -- empty set instead of crashing.
        freeInValues = Set.unions (zipWith rhsFree xs es')

        bsFree = freeInValues \\ nsSet
        e' = freeVarsExp e_lv e
        eFree = freeVarsOf e' \\ nsSet
        bs' = zipWith3 ABind ns xs es'
|
|
|
|
|
|
-- | Extract the free-variable set from an annotated expression.
freeVarsOf :: AnnExp -> Set Ident
freeVarsOf (vars, _) = vars
|
|
|
|
-- | Split a list of bindings into three parallel lists: the bound names,
-- the parameter lists, and the right-hand sides.
fromBinders :: [Bind] -> ([Ident], [[Ident]], [Exp])
fromBinders bs = unzip3 triples
  where
    triples = do
        Bind n xs e <- bs
        pure (n, xs, e)
|
|
|
|
-- AST annotated with free variables

-- | A program as a list of @(name, parameters, annotated body)@ triples.
type AnnProgram = [(Ident, [Ident], AnnExp)]

-- | An expression node paired with its set of free variables.
type AnnExp = (Set Ident, AnnExp')

-- | A let binding: bound name, parameters, and annotated right-hand side.
data ABind = ABind Ident [Ident] AnnExp
    deriving Show

-- | Expression nodes whose children are all annotated expressions.
data AnnExp'
    = AId Ident             -- ^ identifier
    | AInt Integer          -- ^ integer literal
    | AApp AnnExp AnnExp    -- ^ application
    | AAdd AnnExp AnnExp    -- ^ addition
    | AAbs [Ident] AnnExp   -- ^ lambda: parameters and body
    | ALet [ABind] AnnExp   -- ^ let: bindings and body
    deriving Show
|
|
|
|
-- | Turn every lambda into a let expression of the form
-- @let sc = \\v₁ .. vₙ x₁ .. xₘ -> e in sc v₁ .. vₙ@, where @v₁ .. vₙ@
-- are the free variables of the lambda and @x₁ .. xₘ@ its parameters.
--
-- A top-level binding whose body is directly a lambda is not lifted:
-- the lambda's parameters are appended to the binding's own parameters.
abstract :: AnnProgram -> Program
abstract = Program . map abstractSc
  where
    abstractSc :: (Ident, [Ident], AnnExp) -> Bind
    abstractSc (name, pars, body) = case snd body of
        AAbs lamPars lamBody -> Bind name (pars ++ lamPars) (abstractExp lamBody)
        _ -> Bind name pars (abstractExp body)
|
|
|
|
-- | Strip the free-variable annotations from an expression, replacing
-- each lambda with a let-bound @sc@ applied to the lambda's free
-- variables.
--
-- Lambdas that sit directly at the top of a let binding's right-hand
-- side are kept as lambdas; only what is underneath them is transformed.
abstractExp :: AnnExp -> Exp
abstractExp (fvs, node) = case node of
    AId name -> EId name
    AInt lit -> EInt lit
    AApp fun arg -> EApp (abstractExp fun) (abstractExp arg)
    AAdd lhs rhs -> EAdd (abstractExp lhs) (abstractExp rhs)
    ALet binds body -> ELet (map abstractBind binds) (abstractExp body)

    -- Lift lambda into let and bind free variables
    AAbs params body -> foldl EApp lifted (map EId captured)
      where
        captured = Set.toList fvs
        lifted = ELet [Bind "sc" [] lambda] (EId "sc")
        -- The captured variables become leading parameters.
        lambda = EAbs (captured ++ params) (abstractExp body)
  where
    abstractBind (ABind name params rhs) = Bind name params (keepLambdas rhs)

    -- Preserve the leading chain of lambdas and transform the first
    -- non-lambda expression below it.
    keepLambdas (inner, ae) = case ae of
        AAbs params body -> EAbs params (keepLambdas body)
        _ -> abstractExp (inner, ae)
|
|
|
|
-- | Give fresh names to the variables of a program.
--
-- The names of the top-level bindings themselves are kept; their
-- parameters are renamed here and their bodies are renamed by
-- 'renameExp'. A single counter, starting at 0, is threaded through all
-- bindings in order.
rename :: Program -> Program
rename (Program ds) = Program (map (uncurry3 Bind) renamed)
  where
    (_, renamed) = mapAccumL step 0 ds

    step counter (Bind name params body) =
        let (afterParams, params', env) = newNames counter params
            (afterBody, body') = renameExp env afterParams body
         in (afterBody, (name, params', body'))
|
|
|
|
-- | Rename the variables of an expression.
--
-- Takes the current renaming environment (old name to new name) and the
-- next unused counter value; returns the counter after renaming together
-- with the renamed expression. Identifiers that are not in the
-- environment are left unchanged.
renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp)
renameExp env i = \case
    EId n -> (i, EId . fromMaybe n $ Map.lookup n env)

    EInt i1 -> (i, EInt i1)

    EApp e1 e2 -> (i2, EApp e1' e2')
      where
        (i1, e1') = renameExp env i e1
        (i2, e2') = renameExp env i1 e2

    EAdd e1 e2 -> (i2, EAdd e1' e2')
      where
        (i1, e1') = renameExp env i e1
        (i2, e2') = renameExp env i1 e2

    ELet bs e -> (i3, ELet bs' e')
      where
        (ns, xs, es) = fromBinders bs

        -- The counter is threaded through the body first, then the bound
        -- names, then the right-hand sides. The body is renamed with
        -- @e_env@ although @env'@ is only created afterwards; this is
        -- well defined because the counter result never inspects the
        -- environment.
        (i1, e') = renameExp e_env i e
        (i2, ns', env') = newNames i1 ns
        e_env = Map.union env' env
        (i3, bs') = mapAccumL renameBind i2 (zip3 ns' xs es)

        -- A binding's parameters get fresh names that are visible only
        -- in that binding's right-hand side. Without this, a parameter
        -- shadowing an already renamed outer variable would have its
        -- uses rewritten to the outer variable's new name.
        renameBind j (n', ps, rhs) = (j2, Bind n' ps' rhs')
          where
            (j1, ps', penv) = newNames j ps
            (j2, rhs') = renameExp (Map.union penv e_env) j1 rhs

    EAbs parms e -> (i2, EAbs ns e')
      where
        (i1, ns, env') = newNames i parms
        -- 'Map.union' is left-biased, so the new names shadow outer ones.
        (i2, e') = renameExp (Map.union env' env) i1 e
|
|
|
|
|
|
-- | Create fresh names for the given identifiers, starting at counter
-- @i@.
--
-- Returns the next unused counter value, the fresh names in the same
-- order as the input, and a map from each old name to its fresh name.
newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident)
newNames i old_names =
    let (next, fresh) = getNames i old_names
     in (next, fresh, Map.fromList (zip old_names fresh))
|
|
|
|
-- | Number the given identifiers consecutively starting at @i@, using
-- each identifier's own text as the prefix for 'makeName'.
--
-- Returns the first counter value that was not used, and the new names.
getNames :: Int -> [Ident] -> (Int, [Ident])
getNames i ns = (i + length ns, numbered)
  where
    numbered = [ makeName s k | (Ident s, k) <- zip ns [i ..] ]
|
|
|
|
-- | Build the identifier @prefix_i@ from a prefix and a number.
makeName :: String -> Int -> Ident
makeName prefix i = Ident (concat [prefix, "_", show i])
|
|
|
|
|
|
-- | Collect supercombinators by lifting the appropriate let bindings to
-- the top level.
--
-- Each original binding is emitted first, immediately followed by the
-- supercombinators collected from its right-hand side.
collectScs :: Program -> Program
collectScs (Program scs) = Program (concatMap flatten scs)
  where
    flatten (Bind name parms rhs) = Bind name parms rhs' : lifted
      where
        (lifted, rhs') = collectScsExp rhs
|
|
|
|
|
|
-- | Pull supercombinators out of an expression.
--
-- Returns the collected supercombinators together with the expression
-- that remains once they have been removed.
collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp expr = case expr of
    EId n -> ([], EId n)
    EInt i -> ([], EInt i)
    EApp e1 e2 -> binary EApp e1 e2
    EAdd e1 e2 -> binary EAdd e1 e2

    EAbs x e ->
        let (scs, e') = collectScsExp e
         in (scs, EAbs x e')

    -- Collect supercombinators from the bindings themselves, from their
    -- right-hand sides, and from the body, in that order.
    --
    -- > f = let
    -- >       sc  = rhs
    -- >       sc1 = rhs1
    -- >       ...
    -- >     in e
    --
    ELet binds e -> (lifted ++ fromRhss ++ fromBody, mkEAbs kept e')
      where
        (rhsScss, binds') = unzip (map collectRhs binds)
        fromRhss = concat rhsScss
        (fromBody, e') = collectScsExp e

        -- Bindings whose right-hand side is a lambda become
        -- supercombinators; all others stay in the let.
        (lambdaBinds, kept) = partition isLambdaBind binds'
        isLambdaBind (Bind _ _ rhs) = isEAbs rhs

        -- The lambda's parameters are appended to the binding's own.
        lifted =
            [ Bind n (parms ++ parms1) e1
            | Bind n parms (EAbs parms1 e1) <- lambdaBinds
            ]

        collectRhs (Bind n xs rhs) =
            let (scs, rhs') = collectScsExp rhs
             in (scs, Bind n xs rhs')
  where
    -- Shared handling of the two-operand nodes: left operand's
    -- supercombinators come before the right operand's.
    binary mk l r = (scsL ++ scsR, mk l' r')
      where
        (scsL, l') = collectScsExp l
        (scsR, r') = collectScsExp r
|
|
|
|
|
|
|
|
-- | Whether the expression is a lambda abstraction at its top level.
isEAbs :: Exp -> Bool
isEAbs (EAbs _ _) = True
isEAbs _ = False
|
|
|
|
-- | Wrap an expression in a let over the given bindings, or return the
-- expression unchanged when there are no bindings.
--
-- Note that, despite its name, this builds an 'ELet' and not an 'EAbs'.
mkEAbs :: [Bind] -> Exp -> Exp
mkEAbs bs e
    | null bs = e
    | otherwise = ELet bs e
|
|
|
|
|