From 8688b303ac9593f3dc107c3778f6392b18ca8aef Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Fri, 10 Feb 2023 11:47:07 +0100 Subject: [PATCH] Fix unnecessary supercombinator issue --- Grammar.cf | 2 +- sample-programs/basic-6 | 3 ++ sample-programs/basic-7 | 5 +++ src/LambdaLifter.hs | 89 ++++++++++++++++++++++++----------------- src/Main.hs | 6 +-- 5 files changed, 62 insertions(+), 43 deletions(-) create mode 100644 sample-programs/basic-6 create mode 100644 sample-programs/basic-7 diff --git a/Grammar.cf b/Grammar.cf index 410d11d..72e01da 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -7,7 +7,7 @@ EInt. Exp3 ::= Integer; ELet. Exp3 ::= "let" [Bind] "in" Exp; EApp. Exp2 ::= Exp2 Exp3; EAdd. Exp1 ::= Exp1 "+" Exp2; -EAbs. Exp ::= "\\" Ident "." Exp; +EAbs. Exp ::= "\\" [Ident] "." Exp; Bind. Bind ::= Ident [Ident] "=" Exp; separator Bind ";"; diff --git a/sample-programs/basic-6 b/sample-programs/basic-6 new file mode 100644 index 0000000..511ae10 --- /dev/null +++ b/sample-programs/basic-6 @@ -0,0 +1,3 @@ + + +f = \x.\y. x+y diff --git a/sample-programs/basic-7 b/sample-programs/basic-7 new file mode 100644 index 0000000..b3769b9 --- /dev/null +++ b/sample-programs/basic-7 @@ -0,0 +1,5 @@ +add x y = x + y; + +apply f x = f x; + +main = apply (add 4) 6; diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs index ac9cee0..e8862a2 100644 --- a/src/LambdaLifter.hs +++ b/src/LambdaLifter.hs @@ -4,7 +4,7 @@ module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where -import Data.List (mapAccumL) +import Data.List (mapAccumL, partition) import Data.Map (Map) import qualified Data.Map as Map import Data.Maybe (fromMaybe) @@ -33,7 +33,7 @@ freeVarsExp lv = \case EId n | Set.member n lv -> (Set.singleton n, AId n) | otherwise -> (mempty, AId n) - EInt i -> (mempty, AInt i) + EInt i -> (mempty, AInt i) EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp e1' e2') where e1' = freeVarsExp lv e1 @@ -43,8 +43,8 @@ freeVarsExp lv = \case where e1' = freeVarsExp lv e1 e2' = freeVarsExp lv e2 - EAbs n e -> (Set.delete n $ freeVarsOf e', AAbs n e') - where e' = freeVarsExp (Set.insert n lv) e + EAbs parms e -> (freeVarsOf e' \\ Set.fromList parms, AAbs parms e') + where e' = freeVarsExp (foldr Set.insert lv parms) e ELet bs e -> (Set.union bsFree eFree, ALet bs' e') where @@ -76,18 +76,20 @@ data AnnExp' = AId Ident | AInt Integer | AApp AnnExp AnnExp | AAdd AnnExp AnnExp - | AAbs Ident AnnExp + | AAbs [Ident] AnnExp | ALet [ABind] AnnExp deriving Show --- | Lift lambdas to let expression of the form @let sc = \x -> rhs@ +-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. +-- Free variables are @v₁ v₂ .. vₙ@ are bound. abstract :: AnnProgram -> Program -abstract prog = Program $ map f prog +abstract prog = Program $ map go prog where - f :: (Ident, [Ident], AnnExp) -> Bind - f (name, pars, rhs@(_, e)) = + go :: (Ident, [Ident], AnnExp) -> Bind + go (name, pars, rhs@(_, e)) = + case e of - AAbs par body -> Bind name (snoc par pars) $ abstractExp body + AAbs pars1 e1 -> Bind name (pars ++ pars1) $ abstractExp e1 _ -> Bind name pars $ abstractExp rhs abstractExp :: AnnExp -> Exp @@ -96,17 +98,21 @@ abstractExp (free, exp) = case exp of AInt i -> EInt i AApp e1 e2 -> EApp (abstractExp e1) (abstractExp e2) AAdd e1 e2 -> EAdd (abstractExp e1) (abstractExp e2) - ALet bs e -> ELet [Bind n xs (abstractExp e1) | ABind n xs e1 <- bs ] $ abstractExp e - AAbs n e -> foldl EApp sc (map EId fvList) + ALet bs e -> ELet (map go bs) $ abstractExp e where - fvList = Set.toList free - bind = Bind "sc" [] e' - e' = foldr EAbs (abstractExp e) (fvList ++ [n]) - sc = ELet [bind] (EId (Ident "sc")) + go (ABind name parms rhs) = Bind name parms $ skipLambdas abstractExp rhs + skipLambdas :: (AnnExp -> Exp) -> AnnExp -> Exp + skipLambdas f (free, ae) = case ae of + AAbs name ae1 -> EAbs name $ skipLambdas f ae1 + _ -> f (free, ae) -snoc :: a -> [a] -> [a] -snoc x xs = xs ++ [x] + -- Lift lambda into let and bind free variables + AAbs parms e -> foldl EApp sc $ map EId freeList + where + freeList = Set.toList free + sc = ELet [Bind "sc" [] rhs] $ EId "sc" + rhs = EAbs (freeList ++ parms) $ abstractExp e -- | Rename all supercombinators and variables rename :: Program -> Program @@ -144,9 +150,9 @@ renameExp env i = \case (i3, es') = mapAccumL (renameExp e_env) i2 es - EAbs n e -> (i2, EAbs (head ns) e') + EAbs parms e -> (i2, EAbs ns e') where - (i1, ns, env') = newNames i [n] + (i1, ns, env') = newNames i parms (i2, e') = renameExp (Map.union env' env ) i1 e @@ -156,10 +162,6 @@ newNames i old_names = (i', new_names, env) (i', new_names) = getNames i old_names env = Map.fromList $ zip old_names new_names - -getName :: Int -> Ident -> (Int, Ident) -getName i (Ident s) = (i + 1, makeName s i) - getNames :: Int -> [Ident] -> (Int, [Ident]) getNames i ns = (i + length ss, zipWith makeName ss [i..]) where @@ -171,16 +173,16 @@ makeName prefix i = Ident (prefix ++ "_" ++ show i) -- | Collects supercombinators by lifting appropriate let expressions collectScs :: Program -> Program -collectScs (Program ds) = Program $ concatMap collectOneSc ds +collectScs (Program scs) = Program $ concatMap collectFromRhs scs where - collectOneSc (Bind name args rhs) = Bind name args rhs' : scs - where (scs, rhs') = collectScsExp rhs + collectFromRhs (Bind name parms rhs) = + let (rhs_scs, rhs') = collectScsExp rhs + in Bind name parms rhs' : rhs_scs + collectScsExp :: Exp -> ([Bind], Exp) collectScsExp = \case - EId n -> ([], EId n) - EInt i -> ([], EInt i) EApp e1 e2 -> (scs1 ++ scs2, EApp e1' e2') @@ -197,17 +199,30 @@ collectScsExp = \case where (scs, e') = collectScsExp e - ELet bs e -> (rhss_scs ++ e_scs ++ local_scs, mkEAbs non_scs' e') + -- Collect supercombinators from binds, the rhss, and the expression, + -- and the rhss. + -- + -- > f = let + -- > sc = rhs + -- > sc1 = rhs1 + -- > ... + -- > in e + -- + ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e') where - (rhss_scs, bs') = mapAccumL collectScs_d [] bs - scs' = [ Bind n xs rhs | Bind n xs rhs <- bs', isEAbs rhs] - non_scs' = [ Bind n xs rhs | Bind n xs rhs <- bs', not $ isEAbs rhs] - local_scs = [ Bind n (xs ++ [x]) e1 | Bind n xs (EAbs x e1) <- scs'] - (e_scs, e') = collectScsExp e + binds_scs = [ Bind n (parms ++ parms1) e1 + | Bind n parms (EAbs parms1 e1) <- scs' + ] + (rhss_scs, binds') = mapAccumL collectScsRhs [] binds + (e_scs, e') = collectScsExp e - collectScs_d scs (Bind n xs rhs) = (scs ++ rhs_scs1, Bind n xs rhs') + (scs', non_scs') = partition (\(Bind _ _ rhs) -> isEAbs rhs) binds' + + collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs') where - (rhs_scs1, rhs') = collectScsExp rhs + (rhs_scs, rhs') = collectScsExp rhs + + isEAbs :: Exp -> Bool isEAbs = \case diff --git a/src/Main.hs b/src/Main.hs index 9af1753..570ac1a 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -3,7 +3,7 @@ module Main where import Grammar.Par (myLexer, pProgram) import Grammar.Print (printTree) -import LambdaLifter (abstract, freeVars, lambdaLift, rename) +import LambdaLifter (lambdaLift) import System.Environment (getArgs) import System.Exit (exitFailure, exitSuccess) @@ -20,10 +20,6 @@ main = getArgs >>= \case Right prg -> do putStrLn "-- Parse" putStrLn $ printTree prg - -- putStrLn "\n-- Abstract" - -- putStrLn . printTree $ (abstract . freeVars) prg - -- putStrLn "\n-- Rename" - -- putStrLn . printTree $ (rename . abstract . freeVars) prg putStrLn "\n-- Lamda lifter" putStrLn . printTree $ lambdaLift prg putStrLn ""