Revert back to one lambda par, and fix issues with lambda lifter

This commit is contained in:
Martin Fredin 2023-02-11 09:59:26 +01:00
parent 78a3ed56ea
commit e212c79a44
4 changed files with 73 additions and 35 deletions

View file

@ -7,7 +7,7 @@ EInt. Exp3 ::= Integer;
ELet. Exp3 ::= "let" [Bind] "in" Exp; ELet. Exp3 ::= "let" [Bind] "in" Exp;
EApp. Exp2 ::= Exp2 Exp3; EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2; EAdd. Exp1 ::= Exp1 "+" Exp2;
EAbs. Exp ::= "\\" [Ident] "." Exp; EAbs. Exp ::= "\\" Ident "." Exp;
Bind. Bind ::= Ident [Ident] "=" Exp; Bind. Bind ::= Ident [Ident] "=" Exp;
separator Bind ";"; separator Bind ";";

View file

@ -1,5 +1,11 @@
module Auxiliary (module Auxiliary) where module Auxiliary (module Auxiliary) where
import Control.Monad.Error.Class (liftEither)
import Control.Monad.Except (MonadError)
import Data.Either.Combinators (maybeToRight)
snoc :: a -> [a] -> [a] snoc :: a -> [a] -> [a]
snoc x xs = xs ++ [x] snoc x xs = xs ++ [x]
maybeToRightM :: MonadError l m => l -> Maybe r -> m r
maybeToRightM err = liftEither . maybeToRight err

View file

@ -4,15 +4,17 @@
module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
import Data.List (mapAccumL, partition) import Auxiliary (snoc)
import Data.Map (Map) import Data.Foldable.Extra (notNull)
import qualified Data.Map as Map import Data.List (mapAccumL, mapAccumR, partition)
import Data.Maybe (fromMaybe) import Data.Map (Map)
import Data.Set (Set, (\\)) import qualified Data.Map as Map
import qualified Data.Set as Set import Data.Maybe (fromMaybe, mapMaybe)
import Data.Tuple.Extra (uncurry3) import Data.Set (Set, (\\))
import qualified Data.Set as Set
import Data.Tuple.Extra (uncurry3)
import Grammar.Abs import Grammar.Abs
import Prelude hiding (exp) import Prelude hiding (exp)
-- | Lift lambdas and let expression into supercombinators. -- | Lift lambdas and let expression into supercombinators.
@ -44,9 +46,9 @@ freeVarsExp localVars = \case
e1' = freeVarsExp localVars e1 e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2 e2' = freeVarsExp localVars e2
EAbs parms e -> (freeVarsOf e' \\ Set.fromList parms, AAbs parms e') EAbs par e -> (Set.delete par $ freeVarsOf e', AAbs par e')
where where
e' = freeVarsExp (foldr Set.insert localVars parms) e e' = freeVarsExp (Set.insert par localVars) e
-- Sum free variables present in binders and the expression -- Sum free variables present in binders and the expression
ELet binders e -> (Set.union binders_frees e_free, ALet binders' e') ELet binders e -> (Set.union binders_frees e_free, ALet binders' e')
@ -82,7 +84,7 @@ data AnnExp' = AId Ident
| AInt Integer | AInt Integer
| AApp AnnExp AnnExp | AApp AnnExp AnnExp
| AAdd AnnExp AnnExp | AAdd AnnExp AnnExp
| AAbs [Ident] AnnExp | AAbs Ident AnnExp
| ALet [ABind] AnnExp | ALet [ABind] AnnExp
deriving Show deriving Show
@ -93,10 +95,24 @@ abstract prog = Program $ map go prog
where where
go :: (Ident, [Ident], AnnExp) -> Bind go :: (Ident, [Ident], AnnExp) -> Bind
go (name, pars, rhs@(_, e)) = go (name, pars, rhs@(_, e)) =
case e of case e of
AAbs pars1 e1 -> Bind name (pars ++ pars1) $ abstractExp e1 AAbs par e1 -> Bind name (snoc par pars ++ pars2) $ abstractExp e2
_ -> Bind name pars $ abstractExp rhs where
(e2, pars2) = flattenLambdasAnn e1
_ -> Bind name pars $ abstractExp rhs
-- | Flatten nested lambdas and collect the parameters
-- @\x.\y.\z. ae → (ae, [x,y,z])@
flattenLambdasAnn :: AnnExp -> (AnnExp, [Ident])
flattenLambdasAnn ae = go (ae, [])
where
go :: (AnnExp, [Ident]) -> (AnnExp, [Ident])
go ((free, e), acc) =
case e of
AAbs par (free1, e1) -> go ((Set.delete par free1, e1), snoc par acc)
_ -> ((free, e), acc)
abstractExp :: AnnExp -> Exp abstractExp :: AnnExp -> Exp
abstractExp (free, exp) = case exp of abstractExp (free, exp) = case exp of
@ -106,7 +122,11 @@ abstractExp (free, exp) = case exp of
AAdd e1 e2 -> EAdd (abstractExp e1) (abstractExp e2) AAdd e1 e2 -> EAdd (abstractExp e1) (abstractExp e2)
ALet bs e -> ELet (map go bs) $ abstractExp e ALet bs e -> ELet (map go bs) $ abstractExp e
where where
go (ABind name parms rhs) = Bind name parms $ skipLambdas abstractExp rhs go (ABind name parms rhs) =
let
(rhs', parms1) = flattenLambdas $ skipLambdas abstractExp rhs
in
Bind name (parms ++ parms1) rhs'
skipLambdas :: (AnnExp -> Exp) -> AnnExp -> Exp skipLambdas :: (AnnExp -> Exp) -> AnnExp -> Exp
skipLambdas f (free, ae) = case ae of skipLambdas f (free, ae) = case ae of
@ -114,11 +134,10 @@ abstractExp (free, exp) = case exp of
_ -> f (free, ae) _ -> f (free, ae)
-- Lift lambda into let and bind free variables -- Lift lambda into let and bind free variables
AAbs parms e -> foldl EApp sc $ map EId freeList AAbs par e -> foldl EApp sc $ map EId freeList
where where
freeList = Set.toList free freeList = Set.toList free
sc = ELet [Bind "sc" [] rhs] $ EId "sc" sc = ELet [Bind "sc" (snoc par freeList) $ abstractExp e] $ EId "sc"
rhs = EAbs (freeList ++ parms) $ abstractExp e
-- | Rename all supercombinators and variables -- | Rename all supercombinators and variables
rename :: Program -> Program rename :: Program -> Program
@ -147,21 +166,30 @@ renameExp env i = \case
(i1, e1') = renameExp env i e1 (i1, e1') = renameExp env i e1
(i2, e2') = renameExp env i1 e2 (i2, e2') = renameExp env i1 e2
ELet bs e -> (i3, ELet (zipWith3 Bind ns' xs es') e') ELet bs e -> (i3, ELet (zipWith3 Bind ns' pars' es') e')
where where
(i1, e') = renameExp e_env i e (i1, e') = renameExp e_env i e
(ns, xs, es) = fromBinders bs (names, pars, rhss) = fromBinders bs
(i2, ns', env') = newNames i1 ns (i2, ns', env') = newNames i1 (names ++ concat pars)
pars' = (map . map) renamePar pars
e_env = Map.union env' env e_env = Map.union env' env
(i3, es') = mapAccumL (renameExp e_env) i2 es (i3, es') = mapAccumL (renameExp e_env) i2 rhss
renamePar p = case Map.lookup p env' of
Just p' -> p'
Nothing -> error ("Can't find name for " ++ show p)
EAbs parms e -> (i2, EAbs ns e') EAbs par e -> (i2, EAbs par' e')
where where
(i1, ns, env') = newNames i parms (i1, par', env') = newName par
(i2, e') = renameExp (Map.union env' env ) i1 e (i2, e') = renameExp (Map.union env' env ) i1 e
newName :: Ident -> (Int, Ident, Map Ident Ident)
newName old_name = (i, head names, env)
where (i, names, env) = newNames 1 [old_name]
newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident) newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident)
newNames i old_names = (i', new_names, env) newNames i old_names = (i', new_names, env)
where where
@ -215,22 +243,26 @@ collectScsExp = \case
-- --
ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e') ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e')
where where
binds_scs = [ Bind n (parms ++ parms1) e1 binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in
| Bind n parms (EAbs parms1 e1) <- scs' Bind n (parms ++ parms1) rhs'
| Bind n parms rhs <- scs'
] ]
(rhss_scs, binds') = mapAccumL collectScsRhs [] binds (rhss_scs, binds') = mapAccumL collectScsRhs [] binds
(e_scs, e') = collectScsExp e (e_scs, e') = collectScsExp e
(scs', non_scs') = partition (\(Bind _ _ rhs) -> isEAbs rhs) binds' (scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds'
collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs') collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs')
where where
(rhs_scs, rhs') = collectScsExp rhs (rhs_scs, rhs') = collectScsExp rhs
isEAbs :: Exp -> Bool -- @\x.\y.\z. e → (e, [x,y,z])@
isEAbs = \case flattenLambdas :: Exp -> (Exp, [Ident])
EAbs {} -> True flattenLambdas e = go (e, [])
_ -> False where
go (e, acc) = case e of
EAbs par e1 -> go (e1, snoc par acc)
_ -> (e, acc)
mkEAbs :: [Bind] -> Exp -> Exp mkEAbs :: [Bind] -> Exp -> Exp
mkEAbs [] e = e mkEAbs [] e = e

View file

@ -3,7 +3,7 @@ module Main where
import Grammar.Par (myLexer, pProgram) import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree) import Grammar.Print (printTree)
import LambdaLifter (lambdaLift) import LambdaLifter (abstract, freeVars, lambdaLift)
import System.Environment (getArgs) import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess) import System.Exit (exitFailure, exitSuccess)
@ -20,7 +20,7 @@ main = getArgs >>= \case
Right prg -> do Right prg -> do
putStrLn "-- Parse" putStrLn "-- Parse"
putStrLn $ printTree prg putStrLn $ printTree prg
putStrLn "\n-- Lamda lifter" putStrLn "\n-- Lambda lifter"
putStrLn . printTree $ lambdaLift prg putStrLn . printTree $ lambdaLift prg
putStrLn "" putStrLn ""
exitSuccess exitSuccess