{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where import Data.List (mapAccumL, partition) import Data.Map (Map) import qualified Data.Map as Map import Data.Maybe (fromMaybe) import Data.Set (Set, (\\)) import qualified Data.Set as Set import Data.Tuple.Extra (uncurry3) import Grammar.Abs import Prelude hiding (exp) -- | Lift lambdas and let expression into supercombinators. lambdaLift :: Program -> Program lambdaLift = collectScs . rename . abstract . freeVars -- | Annotate free variables freeVars :: Program -> AnnProgram freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e) | Bind n xs e <- ds ] freeVarsExp :: Set Ident -> Exp -> AnnExp freeVarsExp lv = \case EId n | Set.member n lv -> (Set.singleton n, AId n) | otherwise -> (mempty, AId n) EInt i -> (mempty, AInt i) EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp e1' e2') where e1' = freeVarsExp lv e1 e2' = freeVarsExp lv e2 EAdd e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd e1' e2') where e1' = freeVarsExp lv e1 e2' = freeVarsExp lv e2 EAbs parms e -> (freeVarsOf e' \\ Set.fromList parms, AAbs parms e') where e' = freeVarsExp (foldr Set.insert lv parms) e ELet bs e -> (Set.union bsFree eFree, ALet bs' e') where bsFree = freeInValues \\ nsSet eFree = freeVarsOf e' \\ nsSet bs' = zipWith3 ABind ns xs es' e' = freeVarsExp e_lv e (ns, xs, es) = fromBinders bs nsSet = Set.fromList ns e_lv = Set.union lv nsSet es' = map (freeVarsExp e_lv) es freeInValues = foldr1 Set.union (map freeVarsOf es') freeVarsOf :: AnnExp -> Set Ident freeVarsOf = fst fromBinders :: [Bind] -> ([Ident], [[Ident]], [Exp]) fromBinders bs = unzip3 [ (n, xs, e) | Bind n xs e <- bs ] -- AST annotated with free variables type AnnProgram = [(Ident, [Ident], AnnExp)] type AnnExp = (Set Ident, AnnExp') data ABind = ABind Ident [Ident] AnnExp deriving Show data AnnExp' = AId Ident | AInt Integer | AApp AnnExp AnnExp | AAdd AnnExp AnnExp | AAbs [Ident] AnnExp | ALet [ABind] AnnExp deriving Show -- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. -- Free variables are @v₁ v₂ .. vₙ@ are bound. abstract :: AnnProgram -> Program abstract prog = Program $ map go prog where go :: (Ident, [Ident], AnnExp) -> Bind go (name, pars, rhs@(_, e)) = case e of AAbs pars1 e1 -> Bind name (pars ++ pars1) $ abstractExp e1 _ -> Bind name pars $ abstractExp rhs abstractExp :: AnnExp -> Exp abstractExp (free, exp) = case exp of AId n -> EId n AInt i -> EInt i AApp e1 e2 -> EApp (abstractExp e1) (abstractExp e2) AAdd e1 e2 -> EAdd (abstractExp e1) (abstractExp e2) ALet bs e -> ELet (map go bs) $ abstractExp e where go (ABind name parms rhs) = Bind name parms $ skipLambdas abstractExp rhs skipLambdas :: (AnnExp -> Exp) -> AnnExp -> Exp skipLambdas f (free, ae) = case ae of AAbs name ae1 -> EAbs name $ skipLambdas f ae1 _ -> f (free, ae) -- Lift lambda into let and bind free variables AAbs parms e -> foldl EApp sc $ map EId freeList where freeList = Set.toList free sc = ELet [Bind "sc" [] rhs] $ EId "sc" rhs = EAbs (freeList ++ parms) $ abstractExp e -- | Rename all supercombinators and variables rename :: Program -> Program rename (Program ds) = Program $ map (uncurry3 Bind) tuples where tuples = snd (mapAccumL renameSc 0 ds) renameSc i (Bind n xs e) = (i2, (n, xs', e')) where (i1, xs', env) = newNames i xs (i2, e') = renameExp env i1 e renameExp :: Map Ident Ident -> Int -> Exp -> (Int, Exp) renameExp env i = \case EId n -> (i, EId . fromMaybe n $ Map.lookup n env) EInt i1 -> (i, EInt i1) EApp e1 e2 -> (i2, EApp e1' e2') where (i1, e1') = renameExp env i e1 (i2, e2') = renameExp env i1 e2 EAdd e1 e2 -> (i2, EAdd e1' e2') where (i1, e1') = renameExp env i e1 (i2, e2') = renameExp env i1 e2 ELet bs e -> (i3, ELet (zipWith3 Bind ns' xs es') e') where (i1, e') = renameExp e_env i e (ns, xs, es) = fromBinders bs (i2, ns', env') = newNames i1 ns e_env = Map.union env' env (i3, es') = mapAccumL (renameExp e_env) i2 es EAbs parms e -> (i2, EAbs ns e') where (i1, ns, env') = newNames i parms (i2, e') = renameExp (Map.union env' env ) i1 e newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident) newNames i old_names = (i', new_names, env) where (i', new_names) = getNames i old_names env = Map.fromList $ zip old_names new_names getNames :: Int -> [Ident] -> (Int, [Ident]) getNames i ns = (i + length ss, zipWith makeName ss [i..]) where ss = map (\(Ident s) -> s) ns makeName :: String -> Int -> Ident makeName prefix i = Ident (prefix ++ "_" ++ show i) -- | Collects supercombinators by lifting appropriate let expressions collectScs :: Program -> Program collectScs (Program scs) = Program $ concatMap collectFromRhs scs where collectFromRhs (Bind name parms rhs) = let (rhs_scs, rhs') = collectScsExp rhs in Bind name parms rhs' : rhs_scs collectScsExp :: Exp -> ([Bind], Exp) collectScsExp = \case EId n -> ([], EId n) EInt i -> ([], EInt i) EApp e1 e2 -> (scs1 ++ scs2, EApp e1' e2') where (scs1, e1') = collectScsExp e1 (scs2, e2') = collectScsExp e2 EAdd e1 e2 -> (scs1 ++ scs2, EAdd e1' e2') where (scs1, e1') = collectScsExp e1 (scs2, e2') = collectScsExp e2 EAbs x e -> (scs, EAbs x e') where (scs, e') = collectScsExp e -- Collect supercombinators from binds, the rhss, and the expression, -- and the rhss. -- -- > f = let -- > sc = rhs -- > sc1 = rhs1 -- > ... -- > in e -- ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e') where binds_scs = [ Bind n (parms ++ parms1) e1 | Bind n parms (EAbs parms1 e1) <- scs' ] (rhss_scs, binds') = mapAccumL collectScsRhs [] binds (e_scs, e') = collectScsExp e (scs', non_scs') = partition (\(Bind _ _ rhs) -> isEAbs rhs) binds' collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs') where (rhs_scs, rhs') = collectScsExp rhs isEAbs :: Exp -> Bool isEAbs = \case EAbs {} -> True _ -> False mkEAbs :: [Bind] -> Exp -> Exp mkEAbs [] e = e mkEAbs bs e = ELet bs e