Revert back to one lambda par, and fix issues with lambda lifter

This commit is contained in:
Martin Fredin 2023-02-11 09:59:26 +01:00
parent 78a3ed56ea
commit e212c79a44
4 changed files with 73 additions and 35 deletions

View file

@ -7,7 +7,7 @@ EInt. Exp3 ::= Integer;
ELet. Exp3 ::= "let" [Bind] "in" Exp;
EApp. Exp2 ::= Exp2 Exp3;
EAdd. Exp1 ::= Exp1 "+" Exp2;
EAbs. Exp ::= "\\" [Ident] "." Exp;
EAbs. Exp ::= "\\" Ident "." Exp;
Bind. Bind ::= Ident [Ident] "=" Exp;
separator Bind ";";

View file

@ -1,5 +1,11 @@
module Auxiliary (module Auxiliary) where
import Control.Monad.Error.Class (liftEither)
import Control.Monad.Except (MonadError)
import Data.Either.Combinators (maybeToRight)
snoc :: a -> [a] -> [a]
snoc x xs = xs ++ [x]
maybeToRightM :: MonadError l m => l -> Maybe r -> m r
maybeToRightM err = liftEither . maybeToRight err

View file

@ -4,15 +4,17 @@
module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
import Data.List (mapAccumL, partition)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Set (Set, (\\))
import qualified Data.Set as Set
import Data.Tuple.Extra (uncurry3)
import Auxiliary (snoc)
import Data.Foldable.Extra (notNull)
import Data.List (mapAccumL, mapAccumR, partition)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Set (Set, (\\))
import qualified Data.Set as Set
import Data.Tuple.Extra (uncurry3)
import Grammar.Abs
import Prelude hiding (exp)
import Prelude hiding (exp)
-- | Lift lambdas and let expression into supercombinators.
@ -44,9 +46,9 @@ freeVarsExp localVars = \case
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EAbs parms e -> (freeVarsOf e' \\ Set.fromList parms, AAbs parms e')
EAbs par e -> (Set.delete par $ freeVarsOf e', AAbs par e')
where
e' = freeVarsExp (foldr Set.insert localVars parms) e
e' = freeVarsExp (Set.insert par localVars) e
-- Sum free variables present in binders and the expression
ELet binders e -> (Set.union binders_frees e_free, ALet binders' e')
@ -82,7 +84,7 @@ data AnnExp' = AId Ident
| AInt Integer
| AApp AnnExp AnnExp
| AAdd AnnExp AnnExp
| AAbs [Ident] AnnExp
| AAbs Ident AnnExp
| ALet [ABind] AnnExp
deriving Show
@ -93,10 +95,24 @@ abstract prog = Program $ map go prog
where
go :: (Ident, [Ident], AnnExp) -> Bind
go (name, pars, rhs@(_, e)) =
case e of
AAbs pars1 e1 -> Bind name (pars ++ pars1) $ abstractExp e1
_ -> Bind name pars $ abstractExp rhs
AAbs par e1 -> Bind name (snoc par pars ++ pars2) $ abstractExp e2
where
(e2, pars2) = flattenLambdasAnn e1
_ -> Bind name pars $ abstractExp rhs
-- | Flatten nested lambdas and collect the parameters
-- @\x.\y.\z. ae → (ae, [x,y,z])@
flattenLambdasAnn :: AnnExp -> (AnnExp, [Ident])
flattenLambdasAnn ae = go (ae, [])
where
go :: (AnnExp, [Ident]) -> (AnnExp, [Ident])
go ((free, e), acc) =
case e of
AAbs par (free1, e1) -> go ((Set.delete par free1, e1), snoc par acc)
_ -> ((free, e), acc)
abstractExp :: AnnExp -> Exp
abstractExp (free, exp) = case exp of
@ -106,7 +122,11 @@ abstractExp (free, exp) = case exp of
AAdd e1 e2 -> EAdd (abstractExp e1) (abstractExp e2)
ALet bs e -> ELet (map go bs) $ abstractExp e
where
go (ABind name parms rhs) = Bind name parms $ skipLambdas abstractExp rhs
go (ABind name parms rhs) =
let
(rhs', parms1) = flattenLambdas $ skipLambdas abstractExp rhs
in
Bind name (parms ++ parms1) rhs'
skipLambdas :: (AnnExp -> Exp) -> AnnExp -> Exp
skipLambdas f (free, ae) = case ae of
@ -114,11 +134,10 @@ abstractExp (free, exp) = case exp of
_ -> f (free, ae)
-- Lift lambda into let and bind free variables
AAbs parms e -> foldl EApp sc $ map EId freeList
AAbs par e -> foldl EApp sc $ map EId freeList
where
freeList = Set.toList free
sc = ELet [Bind "sc" [] rhs] $ EId "sc"
rhs = EAbs (freeList ++ parms) $ abstractExp e
sc = ELet [Bind "sc" (snoc par freeList) $ abstractExp e] $ EId "sc"
-- | Rename all supercombinators and variables
rename :: Program -> Program
@ -147,21 +166,30 @@ renameExp env i = \case
(i1, e1') = renameExp env i e1
(i2, e2') = renameExp env i1 e2
ELet bs e -> (i3, ELet (zipWith3 Bind ns' xs es') e')
ELet bs e -> (i3, ELet (zipWith3 Bind ns' pars' es') e')
where
(i1, e') = renameExp e_env i e
(ns, xs, es) = fromBinders bs
(i2, ns', env') = newNames i1 ns
(names, pars, rhss) = fromBinders bs
(i2, ns', env') = newNames i1 (names ++ concat pars)
pars' = (map . map) renamePar pars
e_env = Map.union env' env
(i3, es') = mapAccumL (renameExp e_env) i2 es
(i3, es') = mapAccumL (renameExp e_env) i2 rhss
renamePar p = case Map.lookup p env' of
Just p' -> p'
Nothing -> error ("Can't find name for " ++ show p)
EAbs parms e -> (i2, EAbs ns e')
EAbs par e -> (i2, EAbs par' e')
where
(i1, ns, env') = newNames i parms
(i2, e') = renameExp (Map.union env' env ) i1 e
(i1, par', env') = newName par
(i2, e') = renameExp (Map.union env' env ) i1 e
newName :: Ident -> (Int, Ident, Map Ident Ident)
newName old_name = (i, head names, env)
where (i, names, env) = newNames 1 [old_name]
newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident)
newNames i old_names = (i', new_names, env)
where
@ -215,22 +243,26 @@ collectScsExp = \case
--
ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e')
where
binds_scs = [ Bind n (parms ++ parms1) e1
| Bind n parms (EAbs parms1 e1) <- scs'
binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in
Bind n (parms ++ parms1) rhs'
| Bind n parms rhs <- scs'
]
(rhss_scs, binds') = mapAccumL collectScsRhs [] binds
(e_scs, e') = collectScsExp e
(scs', non_scs') = partition (\(Bind _ _ rhs) -> isEAbs rhs) binds'
(scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds'
collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs')
where
(rhs_scs, rhs') = collectScsExp rhs
isEAbs :: Exp -> Bool
isEAbs = \case
EAbs {} -> True
_ -> False
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Ident])
flattenLambdas e = go (e, [])
where
go (e, acc) = case e of
EAbs par e1 -> go (e1, snoc par acc)
_ -> (e, acc)
mkEAbs :: [Bind] -> Exp -> Exp
mkEAbs [] e = e

View file

@ -3,7 +3,7 @@ module Main where
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import LambdaLifter (lambdaLift)
import LambdaLifter (abstract, freeVars, lambdaLift)
import System.Environment (getArgs)
import System.Exit (exitFailure, exitSuccess)
@ -20,7 +20,7 @@ main = getArgs >>= \case
Right prg -> do
putStrLn "-- Parse"
putStrLn $ printTree prg
putStrLn "\n-- Lamda lifter"
putStrLn "\n-- Lambda lifter"
putStrLn . printTree $ lambdaLift prg
putStrLn ""
exitSuccess