Fix unnecessary supercombinator issue

This commit is contained in:
Martin Fredin 2023-02-10 11:47:07 +01:00
parent ece621b0aa
commit 8688b303ac
5 changed files with 62 additions and 43 deletions

View file

@ -7,7 +7,7 @@ EInt. Exp3 ::= Integer;
ELet. Exp3 ::= "let" [Bind] "in" Exp;
EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2;
EAbs. Exp ::= "\\" Ident "." Exp;
EAbs. Exp ::= "\\" [Ident] "." Exp;
Bind. Bind ::= Ident [Ident] "=" Exp;
separator Bind ";";

3
sample-programs/basic-6 Normal file
View file

@ -0,0 +1,3 @@
f = \x.\y. x+y

5
sample-programs/basic-7 Normal file
View file

@ -0,0 +1,5 @@
add x y = x + y;
apply f x = f x;
main = apply (add 4) 6;

View file

@ -4,7 +4,7 @@
module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
import Data.List (mapAccumL)
import Data.List (mapAccumL, partition)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
@ -33,7 +33,7 @@ freeVarsExp lv = \case
EId n | Set.member n lv -> (Set.singleton n, AId n)
| otherwise -> (mempty, AId n)
EInt i -> (mempty, AInt i)
EInt i -> (mempty, AInt i)
EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp e1' e2')
where e1' = freeVarsExp lv e1
@ -43,8 +43,8 @@ freeVarsExp lv = \case
where e1' = freeVarsExp lv e1
e2' = freeVarsExp lv e2
EAbs n e -> (Set.delete n $ freeVarsOf e', AAbs n e')
where e' = freeVarsExp (Set.insert n lv) e
EAbs parms e -> (freeVarsOf e' \\ Set.fromList parms, AAbs parms e')
where e' = freeVarsExp (foldr Set.insert lv parms) e
ELet bs e -> (Set.union bsFree eFree, ALet bs' e')
where
@ -76,18 +76,20 @@ data AnnExp' = AId Ident
| AInt Integer
| AApp AnnExp AnnExp
| AAdd AnnExp AnnExp
| AAbs Ident AnnExp
| AAbs [Ident] AnnExp
| ALet [ABind] AnnExp
deriving Show
-- | Lift lambdas to let expression of the form @let sc = \x -> rhs@
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnProgram -> Program
abstract prog = Program $ map f prog
abstract prog = Program $ map go prog
where
f :: (Ident, [Ident], AnnExp) -> Bind
f (name, pars, rhs@(_, e)) =
go :: (Ident, [Ident], AnnExp) -> Bind
go (name, pars, rhs@(_, e)) =
case e of
AAbs par body -> Bind name (snoc par pars) $ abstractExp body
AAbs pars1 e1 -> Bind name (pars ++ pars1) $ abstractExp e1
_ -> Bind name pars $ abstractExp rhs
abstractExp :: AnnExp -> Exp
@ -96,17 +98,21 @@ abstractExp (free, exp) = case exp of
AInt i -> EInt i
AApp e1 e2 -> EApp (abstractExp e1) (abstractExp e2)
AAdd e1 e2 -> EAdd (abstractExp e1) (abstractExp e2)
ALet bs e -> ELet [Bind n xs (abstractExp e1) | ABind n xs e1 <- bs ] $ abstractExp e
AAbs n e -> foldl EApp sc (map EId fvList)
ALet bs e -> ELet (map go bs) $ abstractExp e
where
fvList = Set.toList free
bind = Bind "sc" [] e'
e' = foldr EAbs (abstractExp e) (fvList ++ [n])
sc = ELet [bind] (EId (Ident "sc"))
go (ABind name parms rhs) = Bind name parms $ skipLambdas abstractExp rhs
skipLambdas :: (AnnExp -> Exp) -> AnnExp -> Exp
skipLambdas f (free, ae) = case ae of
AAbs name ae1 -> EAbs name $ skipLambdas f ae1
_ -> f (free, ae)
snoc :: a -> [a] -> [a]
snoc x xs = xs ++ [x]
-- Lift lambda into let and bind free variables
AAbs parms e -> foldl EApp sc $ map EId freeList
where
freeList = Set.toList free
sc = ELet [Bind "sc" [] rhs] $ EId "sc"
rhs = EAbs (freeList ++ parms) $ abstractExp e
-- | Rename all supercombinators and variables
rename :: Program -> Program
@ -144,9 +150,9 @@ renameExp env i = \case
(i3, es') = mapAccumL (renameExp e_env) i2 es
EAbs n e -> (i2, EAbs (head ns) e')
EAbs parms e -> (i2, EAbs ns e')
where
(i1, ns, env') = newNames i [n]
(i1, ns, env') = newNames i parms
(i2, e') = renameExp (Map.union env' env ) i1 e
@ -156,10 +162,6 @@ newNames i old_names = (i', new_names, env)
(i', new_names) = getNames i old_names
env = Map.fromList $ zip old_names new_names
getName :: Int -> Ident -> (Int, Ident)
getName i (Ident s) = (i + 1, makeName s i)
getNames :: Int -> [Ident] -> (Int, [Ident])
getNames i ns = (i + length ss, zipWith makeName ss [i..])
where
@ -171,16 +173,16 @@ makeName prefix i = Ident (prefix ++ "_" ++ show i)
-- | Collects supercombinators by lifting appropriate let expressions
collectScs :: Program -> Program
collectScs (Program ds) = Program $ concatMap collectOneSc ds
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
where
collectOneSc (Bind name args rhs) = Bind name args rhs' : scs
where (scs, rhs') = collectScsExp rhs
collectFromRhs (Bind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs
collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp = \case
EId n -> ([], EId n)
EInt i -> ([], EInt i)
EApp e1 e2 -> (scs1 ++ scs2, EApp e1' e2')
@ -197,17 +199,30 @@ collectScsExp = \case
where
(scs, e') = collectScsExp e
ELet bs e -> (rhss_scs ++ e_scs ++ local_scs, mkEAbs non_scs' e')
-- Collect supercombinators from binds, the rhss, and the expression,
-- and the rhss.
--
-- > f = let
-- > sc = rhs
-- > sc1 = rhs1
-- > ...
-- > in e
--
ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e')
where
(rhss_scs, bs') = mapAccumL collectScs_d [] bs
scs' = [ Bind n xs rhs | Bind n xs rhs <- bs', isEAbs rhs]
non_scs' = [ Bind n xs rhs | Bind n xs rhs <- bs', not $ isEAbs rhs]
local_scs = [ Bind n (xs ++ [x]) e1 | Bind n xs (EAbs x e1) <- scs']
(e_scs, e') = collectScsExp e
binds_scs = [ Bind n (parms ++ parms1) e1
| Bind n parms (EAbs parms1 e1) <- scs'
]
(rhss_scs, binds') = mapAccumL collectScsRhs [] binds
(e_scs, e') = collectScsExp e
collectScs_d scs (Bind n xs rhs) = (scs ++ rhs_scs1, Bind n xs rhs')
(scs', non_scs') = partition (\(Bind _ _ rhs) -> isEAbs rhs) binds'
collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs')
where
(rhs_scs1, rhs') = collectScsExp rhs
(rhs_scs, rhs') = collectScsExp rhs
isEAbs :: Exp -> Bool
isEAbs = \case

View file

@ -3,7 +3,7 @@ module Main where
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import LambdaLifter (abstract, freeVars, lambdaLift, rename)
import LambdaLifter (lambdaLift)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
@ -20,10 +20,6 @@ main = getArgs >>= \case
Right prg -> do
putStrLn "-- Parse"
putStrLn $ printTree prg
-- putStrLn "\n-- Abstract"
-- putStrLn . printTree $ (abstract . freeVars) prg
-- putStrLn "\n-- Rename"
-- putStrLn . printTree $ (rename . abstract . freeVars) prg
putStrLn "\n-- Lamda lifter"
putStrLn . printTree $ lambdaLift prg
putStrLn ""