Fix unnecessary supercombinator issue

This commit is contained in:
Martin Fredin 2023-02-10 11:47:07 +01:00
parent ece621b0aa
commit 8688b303ac
5 changed files with 62 additions and 43 deletions

View file

@ -7,7 +7,7 @@ EInt. Exp3 ::= Integer;
ELet. Exp3 ::= "let" [Bind] "in" Exp; ELet. Exp3 ::= "let" [Bind] "in" Exp;
EApp. Exp2 ::= Exp2 Exp3; EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2; EAdd. Exp1 ::= Exp1 "+" Exp2;
EAbs. Exp ::= "\\" Ident "." Exp; EAbs. Exp ::= "\\" [Ident] "." Exp;
Bind. Bind ::= Ident [Ident] "=" Exp; Bind. Bind ::= Ident [Ident] "=" Exp;
separator Bind ";"; separator Bind ";";

3
sample-programs/basic-6 Normal file
View file

@ -0,0 +1,3 @@
f = \x.\y. x+y

5
sample-programs/basic-7 Normal file
View file

@ -0,0 +1,5 @@
add x y = x + y;
apply f x = f x;
main = apply (add 4) 6;

View file

@ -4,7 +4,7 @@
module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
import Data.List (mapAccumL) import Data.List (mapAccumL, partition)
import Data.Map (Map) import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Data.Maybe (fromMaybe) import Data.Maybe (fromMaybe)
@ -43,8 +43,8 @@ freeVarsExp lv = \case
where e1' = freeVarsExp lv e1 where e1' = freeVarsExp lv e1
e2' = freeVarsExp lv e2 e2' = freeVarsExp lv e2
EAbs n e -> (Set.delete n $ freeVarsOf e', AAbs n e') EAbs parms e -> (freeVarsOf e' \\ Set.fromList parms, AAbs parms e')
where e' = freeVarsExp (Set.insert n lv) e where e' = freeVarsExp (foldr Set.insert lv parms) e
ELet bs e -> (Set.union bsFree eFree, ALet bs' e') ELet bs e -> (Set.union bsFree eFree, ALet bs' e')
where where
@ -76,18 +76,20 @@ data AnnExp' = AId Ident
| AInt Integer | AInt Integer
| AApp AnnExp AnnExp | AApp AnnExp AnnExp
| AAdd AnnExp AnnExp | AAdd AnnExp AnnExp
| AAbs Ident AnnExp | AAbs [Ident] AnnExp
| ALet [ABind] AnnExp | ALet [ABind] AnnExp
deriving Show deriving Show
-- | Lift lambdas to let expression of the form @let sc = \x -> rhs@ -- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnProgram -> Program abstract :: AnnProgram -> Program
abstract prog = Program $ map f prog abstract prog = Program $ map go prog
where where
f :: (Ident, [Ident], AnnExp) -> Bind go :: (Ident, [Ident], AnnExp) -> Bind
f (name, pars, rhs@(_, e)) = go (name, pars, rhs@(_, e)) =
case e of case e of
AAbs par body -> Bind name (snoc par pars) $ abstractExp body AAbs pars1 e1 -> Bind name (pars ++ pars1) $ abstractExp e1
_ -> Bind name pars $ abstractExp rhs _ -> Bind name pars $ abstractExp rhs
abstractExp :: AnnExp -> Exp abstractExp :: AnnExp -> Exp
@ -96,17 +98,21 @@ abstractExp (free, exp) = case exp of
AInt i -> EInt i AInt i -> EInt i
AApp e1 e2 -> EApp (abstractExp e1) (abstractExp e2) AApp e1 e2 -> EApp (abstractExp e1) (abstractExp e2)
AAdd e1 e2 -> EAdd (abstractExp e1) (abstractExp e2) AAdd e1 e2 -> EAdd (abstractExp e1) (abstractExp e2)
ALet bs e -> ELet [Bind n xs (abstractExp e1) | ABind n xs e1 <- bs ] $ abstractExp e ALet bs e -> ELet (map go bs) $ abstractExp e
AAbs n e -> foldl EApp sc (map EId fvList)
where where
fvList = Set.toList free go (ABind name parms rhs) = Bind name parms $ skipLambdas abstractExp rhs
bind = Bind "sc" [] e'
e' = foldr EAbs (abstractExp e) (fvList ++ [n])
sc = ELet [bind] (EId (Ident "sc"))
skipLambdas :: (AnnExp -> Exp) -> AnnExp -> Exp
skipLambdas f (free, ae) = case ae of
AAbs name ae1 -> EAbs name $ skipLambdas f ae1
_ -> f (free, ae)
snoc :: a -> [a] -> [a] -- Lift lambda into let and bind free variables
snoc x xs = xs ++ [x] AAbs parms e -> foldl EApp sc $ map EId freeList
where
freeList = Set.toList free
sc = ELet [Bind "sc" [] rhs] $ EId "sc"
rhs = EAbs (freeList ++ parms) $ abstractExp e
-- | Rename all supercombinators and variables -- | Rename all supercombinators and variables
rename :: Program -> Program rename :: Program -> Program
@ -144,9 +150,9 @@ renameExp env i = \case
(i3, es') = mapAccumL (renameExp e_env) i2 es (i3, es') = mapAccumL (renameExp e_env) i2 es
EAbs n e -> (i2, EAbs (head ns) e') EAbs parms e -> (i2, EAbs ns e')
where where
(i1, ns, env') = newNames i [n] (i1, ns, env') = newNames i parms
(i2, e') = renameExp (Map.union env' env ) i1 e (i2, e') = renameExp (Map.union env' env ) i1 e
@ -156,10 +162,6 @@ newNames i old_names = (i', new_names, env)
(i', new_names) = getNames i old_names (i', new_names) = getNames i old_names
env = Map.fromList $ zip old_names new_names env = Map.fromList $ zip old_names new_names
getName :: Int -> Ident -> (Int, Ident)
getName i (Ident s) = (i + 1, makeName s i)
getNames :: Int -> [Ident] -> (Int, [Ident]) getNames :: Int -> [Ident] -> (Int, [Ident])
getNames i ns = (i + length ss, zipWith makeName ss [i..]) getNames i ns = (i + length ss, zipWith makeName ss [i..])
where where
@ -171,16 +173,16 @@ makeName prefix i = Ident (prefix ++ "_" ++ show i)
-- | Collects supercombinators by lifting appropriate let expressions -- | Collects supercombinators by lifting appropriate let expressions
collectScs :: Program -> Program collectScs :: Program -> Program
collectScs (Program ds) = Program $ concatMap collectOneSc ds collectScs (Program scs) = Program $ concatMap collectFromRhs scs
where where
collectOneSc (Bind name args rhs) = Bind name args rhs' : scs collectFromRhs (Bind name parms rhs) =
where (scs, rhs') = collectScsExp rhs let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs
collectScsExp :: Exp -> ([Bind], Exp) collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp = \case collectScsExp = \case
EId n -> ([], EId n) EId n -> ([], EId n)
EInt i -> ([], EInt i) EInt i -> ([], EInt i)
EApp e1 e2 -> (scs1 ++ scs2, EApp e1' e2') EApp e1 e2 -> (scs1 ++ scs2, EApp e1' e2')
@ -197,17 +199,30 @@ collectScsExp = \case
where where
(scs, e') = collectScsExp e (scs, e') = collectScsExp e
ELet bs e -> (rhss_scs ++ e_scs ++ local_scs, mkEAbs non_scs' e') -- Collect supercombinators from binds, the rhss, and the expression,
-- and the rhss.
--
-- > f = let
-- > sc = rhs
-- > sc1 = rhs1
-- > ...
-- > in e
--
ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e')
where where
(rhss_scs, bs') = mapAccumL collectScs_d [] bs binds_scs = [ Bind n (parms ++ parms1) e1
scs' = [ Bind n xs rhs | Bind n xs rhs <- bs', isEAbs rhs] | Bind n parms (EAbs parms1 e1) <- scs'
non_scs' = [ Bind n xs rhs | Bind n xs rhs <- bs', not $ isEAbs rhs] ]
local_scs = [ Bind n (xs ++ [x]) e1 | Bind n xs (EAbs x e1) <- scs'] (rhss_scs, binds') = mapAccumL collectScsRhs [] binds
(e_scs, e') = collectScsExp e (e_scs, e') = collectScsExp e
collectScs_d scs (Bind n xs rhs) = (scs ++ rhs_scs1, Bind n xs rhs') (scs', non_scs') = partition (\(Bind _ _ rhs) -> isEAbs rhs) binds'
collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs')
where where
(rhs_scs1, rhs') = collectScsExp rhs (rhs_scs, rhs') = collectScsExp rhs
isEAbs :: Exp -> Bool isEAbs :: Exp -> Bool
isEAbs = \case isEAbs = \case

View file

@ -3,7 +3,7 @@ module Main where
import Grammar.Par (myLexer, pProgram) import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree) import Grammar.Print (printTree)
import LambdaLifter (abstract, freeVars, lambdaLift, rename) import LambdaLifter (lambdaLift)
import System.Environment (getArgs) import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess) import System.Exit (exitFailure, exitSuccess)
@ -20,10 +20,6 @@ main = getArgs >>= \case
Right prg -> do Right prg -> do
putStrLn "-- Parse" putStrLn "-- Parse"
putStrLn $ printTree prg putStrLn $ printTree prg
-- putStrLn "\n-- Abstract"
-- putStrLn . printTree $ (abstract . freeVars) prg
-- putStrLn "\n-- Rename"
-- putStrLn . printTree $ (rename . abstract . freeVars) prg
putStrLn "\n-- Lamda lifter" putStrLn "\n-- Lamda lifter"
putStrLn . printTree $ lambdaLift prg putStrLn . printTree $ lambdaLift prg
putStrLn "" putStrLn ""