From 59d9be87cb51ef025288bfd10aa82d628086d9d2 Mon Sep 17 00:00:00 2001 From: Martin Fredin Date: Tue, 28 Mar 2023 15:35:01 +0200 Subject: [PATCH] Add cases for lambda lifter --- src/LambdaLifter.hs | 54 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs index b85dd8b..5020fb6 100644 --- a/src/LambdaLifter.hs +++ b/src/LambdaLifter.hs @@ -4,11 +4,12 @@ module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where -import Auxiliary (snoc) +import Auxiliary (mapAccumM, snoc) import Control.Applicative (Applicative (liftA2)) +import Control.Arrow (Arrow (second)) import Control.Monad.State (MonadState (get, put), State, evalState) -import Data.List (partition) +import Data.List (mapAccumL, partition) import Data.Set (Set) import qualified Data.Set as Set import Prelude hiding (exp) @@ -40,6 +41,8 @@ freeVarsExp localVars (exp, t) = case exp of EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t)) | otherwise -> (mempty, (AVar n, t)) + EInj n -> (mempty, (AVar n, t)) + ELit lit -> (mempty, (ALit lit, t)) EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t)) @@ -68,6 +71,25 @@ freeVarsExp localVars (exp, t) = case exp of e' = freeVarsExp e_localVars e e_localVars = Set.insert name localVars + ECase e branches -> (frees, (ACase e' branches', t)) + where + frees = foldr (\b s -> Set.union s $ fst b) (freeVarsOf e') branches' + e' = freeVarsExp localVars e + branches' = map (freeVarsBranch localVars) branches + + +freeVarsBranch :: Set Ident -> Branch' Type -> (Set Ident, AnnBranch') +freeVarsBranch localVars (Branch (patt, t) exp) = (frees, AnnBranch (patt, t) exp') + where + frees = freeVarsOf exp' Set.\\ freeVarsOfPattern patt + exp' = freeVarsExp localVars exp + freeVarsOfPattern = Set.fromList . go [] + where + go acc = \case + PVar (n,_) -> snoc n acc + PInj _ ps -> foldl go acc ps + + freeVarsOf :: AnnExpT -> Set Ident freeVarsOf = fst @@ -81,6 +103,10 @@ data ABind = ABind Id [Id] AnnExpT deriving Show type AnnExpT' = (AnnExp, Type) +type AnnBranch = (Set Ident, AnnBranch') +data AnnBranch' = AnnBranch (Pattern, Type) AnnExpT + deriving Show + data AnnExp = AVar Ident | AInj Ident | ALit Lit @@ -88,6 +114,7 @@ data AnnExp = AVar Ident | AApp AnnExpT AnnExpT | AAdd AnnExpT AnnExpT | AAbs Ident AnnExpT + | ACase AnnExpT [AnnBranch] deriving Show -- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. @@ -115,7 +142,8 @@ flattenLambdasAnn ae = go (ae, []) abstractExp :: AnnExpT -> State Int ExpT abstractExp (free, (exp, typ)) = case exp of - AVar n -> pure (EVar n, typ) + AVar n -> pure (EVar n, typ) + AInj n -> pure (EInj n, typ) ALit lit -> pure (ELit lit, typ) AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2) AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2) @@ -132,6 +160,9 @@ abstractExp (free, (exp, typ)) = case exp of pure (EAbs par ae1', t) _ -> f (free, (ae, t)) + ACase e branches -> (, typ) <$> liftA2 ECase (abstractExp e) (mapM abstractBranch branches) + + -- Lift lambda into let and bind free variables AAbs parm e -> do i <- nextNumber @@ -149,6 +180,10 @@ abstractExp (free, (exp, typ)) = case exp of where (t_var, t_return) = applyVarType t + +abstractBranch :: AnnBranch -> State Int Branch +abstractBranch (_, AnnBranch patt exp) = Branch patt <$> abstractExp exp + applyVarType :: Type -> (Type, Type) applyVarType typ = (t1, foldr ($) t2 foralls) @@ -182,7 +217,8 @@ collectScs = concatMap collectFromRhs collectScsExp :: ExpT -> ([Bind], ExpT) collectScsExp expT@(exp, typ) = case exp of EVar _ -> ([], expT) - ELit _ -> ([], expT) + EInj _ -> ([], expT) + ELit _ -> ([], expT) EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ)) where @@ -198,6 +234,12 @@ collectScsExp expT@(exp, typ) = case exp of where (scs, e') = collectScsExp e + ECase e branches -> (scs, (ECase e branches', typ)) + where + (scs, branches') = mapAccumL f [] branches + f acc b = (acc ++ acc', b') + where (acc', b') = collectScsBranch b + -- Collect supercombinators from bind, the rhss, and the expression. -- -- > f = let sc x y = rhs in e @@ -210,6 +252,9 @@ collectScsExp expT@(exp, typ) = case exp of (rhs_scs, rhs') = collectScsExp rhs (et_scs, et') = collectScsExp e +collectScsBranch (Branch patt exp) = (scs, Branch patt exp') + where (scs, exp') = collectScsExp exp + -- @\x.\y.\z. e → (e, [x,y,z])@ flattenLambdas :: ExpT -> (ExpT, [Id]) @@ -240,3 +285,4 @@ skipForalls = go [] go acc typ = case typ of TAll tvar t -> go (snoc (TAll tvar) acc) t _ -> (acc, typ) +