diff --git a/Grammar.cf b/Grammar.cf index 72e01da..410d11d 100644 --- a/Grammar.cf +++ b/Grammar.cf @@ -7,7 +7,7 @@ EInt. Exp3 ::= Integer; ELet. Exp3 ::= "let" [Bind] "in" Exp; EApp. Exp2 ::= Exp2 Exp3; EAdd. Exp1 ::= Exp1 "+" Exp2; -EAbs. Exp ::= "\\" [Ident] "." Exp; +EAbs. Exp ::= "\\" Ident "." Exp; Bind. Bind ::= Ident [Ident] "=" Exp; separator Bind ";"; diff --git a/src/Auxiliary.hs b/src/Auxiliary.hs index cd844d7..2de36a7 100644 --- a/src/Auxiliary.hs +++ b/src/Auxiliary.hs @@ -1,5 +1,11 @@ module Auxiliary (module Auxiliary) where +import Control.Monad.Error.Class (liftEither) +import Control.Monad.Except (MonadError) +import Data.Either.Combinators (maybeToRight) snoc :: a -> [a] -> [a] snoc x xs = xs ++ [x] + +maybeToRightM :: MonadError l m => l -> Maybe r -> m r +maybeToRightM err = liftEither . maybeToRight err diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs index 625041c..3d9595d 100644 --- a/src/LambdaLifter.hs +++ b/src/LambdaLifter.hs @@ -4,15 +4,17 @@ module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where -import Data.List (mapAccumL, partition) -import Data.Map (Map) -import qualified Data.Map as Map -import Data.Maybe (fromMaybe) -import Data.Set (Set, (\\)) -import qualified Data.Set as Set -import Data.Tuple.Extra (uncurry3) +import Auxiliary (snoc) +import Data.Foldable.Extra (notNull) +import Data.List (mapAccumL, mapAccumR, partition) +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe (fromMaybe, mapMaybe) +import Data.Set (Set, (\\)) +import qualified Data.Set as Set +import Data.Tuple.Extra (uncurry3) import Grammar.Abs -import Prelude hiding (exp) +import Prelude hiding (exp) -- | Lift lambdas and let expression into supercombinators. @@ -44,9 +46,9 @@ freeVarsExp localVars = \case e1' = freeVarsExp localVars e1 e2' = freeVarsExp localVars e2 - EAbs parms e -> (freeVarsOf e' \\ Set.fromList parms, AAbs parms e') + EAbs par e -> (Set.delete par $ freeVarsOf e', AAbs par e') where - e' = freeVarsExp (foldr Set.insert localVars parms) e + e' = freeVarsExp (Set.insert par localVars) e -- Sum free variables present in binders and the expression ELet binders e -> (Set.union binders_frees e_free, ALet binders' e') @@ -82,7 +84,7 @@ data AnnExp' = AId Ident | AInt Integer | AApp AnnExp AnnExp | AAdd AnnExp AnnExp - | AAbs [Ident] AnnExp + | AAbs Ident AnnExp | ALet [ABind] AnnExp deriving Show @@ -93,10 +95,24 @@ abstract prog = Program $ map go prog where go :: (Ident, [Ident], AnnExp) -> Bind go (name, pars, rhs@(_, e)) = - case e of - AAbs pars1 e1 -> Bind name (pars ++ pars1) $ abstractExp e1 - _ -> Bind name pars $ abstractExp rhs + AAbs par e1 -> Bind name (snoc par pars ++ pars2) $ abstractExp e2 + where + (e2, pars2) = flattenLambdasAnn e1 + _ -> Bind name pars $ abstractExp rhs + + +-- | Flatten nested lambdas and collect the parameters +-- @\x.\y.\z. ae → (ae, [x,y,z])@ +flattenLambdasAnn :: AnnExp -> (AnnExp, [Ident]) +flattenLambdasAnn ae = go (ae, []) + where + go :: (AnnExp, [Ident]) -> (AnnExp, [Ident]) + go ((free, e), acc) = + case e of + AAbs par (free1, e1) -> go ((Set.delete par free1, e1), snoc par acc) + _ -> ((free, e), acc) + abstractExp :: AnnExp -> Exp abstractExp (free, exp) = case exp of @@ -106,7 +122,11 @@ abstractExp (free, exp) = case exp of AAdd e1 e2 -> EAdd (abstractExp e1) (abstractExp e2) ALet bs e -> ELet (map go bs) $ abstractExp e where - go (ABind name parms rhs) = Bind name parms $ skipLambdas abstractExp rhs + go (ABind name parms rhs) = + let + (rhs', parms1) = flattenLambdas $ skipLambdas abstractExp rhs + in + Bind name (parms ++ parms1) rhs' skipLambdas :: (AnnExp -> Exp) -> AnnExp -> Exp skipLambdas f (free, ae) = case ae of @@ -114,11 +134,10 @@ abstractExp (free, exp) = case exp of _ -> f (free, ae) -- Lift lambda into let and bind free variables - AAbs parms e -> foldl EApp sc $ map EId freeList + AAbs par e -> foldl EApp sc $ map EId freeList where freeList = Set.toList free - sc = ELet [Bind "sc" [] rhs] $ EId "sc" - rhs = EAbs (freeList ++ parms) $ abstractExp e + sc = ELet [Bind "sc" (snoc par freeList) $ abstractExp e] $ EId "sc" -- | Rename all supercombinators and variables rename :: Program -> Program @@ -147,21 +166,30 @@ renameExp env i = \case (i1, e1') = renameExp env i e1 (i2, e2') = renameExp env i1 e2 - ELet bs e -> (i3, ELet (zipWith3 Bind ns' xs es') e') + ELet bs e -> (i3, ELet (zipWith3 Bind ns' pars' es') e') where (i1, e') = renameExp e_env i e - (ns, xs, es) = fromBinders bs - (i2, ns', env') = newNames i1 ns + (names, pars, rhss) = fromBinders bs + (i2, ns', env') = newNames i1 (names ++ concat pars) + pars' = (map . map) renamePar pars e_env = Map.union env' env - (i3, es') = mapAccumL (renameExp e_env) i2 es + (i3, es') = mapAccumL (renameExp e_env) i2 rhss + + renamePar p = case Map.lookup p env' of + Just p' -> p' + Nothing -> error ("Can't find name for " ++ show p) - EAbs parms e -> (i2, EAbs ns e') + EAbs par e -> (i2, EAbs par' e') where - (i1, ns, env') = newNames i parms - (i2, e') = renameExp (Map.union env' env ) i1 e + (i1, par', env') = newName par + (i2, e') = renameExp (Map.union env' env ) i1 e +newName :: Ident -> (Int, Ident, Map Ident Ident) +newName old_name = (i, head names, env) + where (i, names, env) = newNames 1 [old_name] + newNames :: Int -> [Ident] -> (Int, [Ident], Map Ident Ident) newNames i old_names = (i', new_names, env) where @@ -215,22 +243,26 @@ collectScsExp = \case -- ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e') where - binds_scs = [ Bind n (parms ++ parms1) e1 - | Bind n parms (EAbs parms1 e1) <- scs' + binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in + Bind n (parms ++ parms1) rhs' + | Bind n parms rhs <- scs' ] (rhss_scs, binds') = mapAccumL collectScsRhs [] binds (e_scs, e') = collectScsExp e - (scs', non_scs') = partition (\(Bind _ _ rhs) -> isEAbs rhs) binds' + (scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds' collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs') where (rhs_scs, rhs') = collectScsExp rhs -isEAbs :: Exp -> Bool -isEAbs = \case - EAbs {} -> True - _ -> False +-- @\x.\y.\z. e → (e, [x,y,z])@ +flattenLambdas :: Exp -> (Exp, [Ident]) +flattenLambdas e = go (e, []) + where + go (e, acc) = case e of + EAbs par e1 -> go (e1, snoc par acc) + _ -> (e, acc) mkEAbs :: [Bind] -> Exp -> Exp mkEAbs [] e = e diff --git a/src/Main.hs b/src/Main.hs index 570ac1a..ba6edf2 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -3,7 +3,7 @@ module Main where import Grammar.Par (myLexer, pProgram) import Grammar.Print (printTree) -import LambdaLifter (lambdaLift) +import LambdaLifter (abstract, freeVars, lambdaLift) import System.Environment (getArgs) import System.Exit (exitFailure, exitSuccess) @@ -20,7 +20,7 @@ main = getArgs >>= \case Right prg -> do putStrLn "-- Parse" putStrLn $ printTree prg - putStrLn "\n-- Lamda lifter" + putStrLn "\n-- Lambda lifter" putStrLn . printTree $ lambdaLift prg putStrLn "" exitSuccess