Add cases for lambda lifter

This commit is contained in:
Martin Fredin 2023-03-28 15:35:01 +02:00
parent 5986e2108e
commit 59d9be87cb

View file

@ -4,11 +4,12 @@
module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where
import Auxiliary (snoc)
import Auxiliary (mapAccumM, snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Arrow (Arrow (second))
import Control.Monad.State (MonadState (get, put), State,
evalState)
import Data.List (partition)
import Data.List (mapAccumL, partition)
import Data.Set (Set)
import qualified Data.Set as Set
import Prelude hiding (exp)
@ -40,6 +41,8 @@ freeVarsExp localVars (exp, t) = case exp of
EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t))
| otherwise -> (mempty, (AVar n, t))
EInj n -> (mempty, (AVar n, t))
ELit lit -> (mempty, (ALit lit, t))
EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t))
@ -68,6 +71,25 @@ freeVarsExp localVars (exp, t) = case exp of
e' = freeVarsExp e_localVars e
e_localVars = Set.insert name localVars
ECase e branches -> (frees, (ACase e' branches', t))
where
frees = foldr (\b s -> Set.union s $ fst b) (freeVarsOf e') branches'
e' = freeVarsExp localVars e
branches' = map (freeVarsBranch localVars) branches
freeVarsBranch :: Set Ident -> Branch' Type -> (Set Ident, AnnBranch')
freeVarsBranch localVars (Branch (patt, t) exp) = (frees, AnnBranch (patt, t) exp')
where
frees = freeVarsOf exp' Set.\\ freeVarsOfPattern patt
exp' = freeVarsExp localVars exp
freeVarsOfPattern = Set.fromList . go []
where
go acc = \case
PVar (n,_) -> snoc n acc
PInj _ ps -> foldl go acc ps
freeVarsOf :: AnnExpT -> Set Ident
freeVarsOf = fst
@ -81,6 +103,10 @@ data ABind = ABind Id [Id] AnnExpT deriving Show
type AnnExpT' = (AnnExp, Type)
type AnnBranch = (Set Ident, AnnBranch')
data AnnBranch' = AnnBranch (Pattern, Type) AnnExpT
deriving Show
data AnnExp = AVar Ident
| AInj Ident
| ALit Lit
@ -88,6 +114,7 @@ data AnnExp = AVar Ident
| AApp AnnExpT AnnExpT
| AAdd AnnExpT AnnExpT
| AAbs Ident AnnExpT
| ACase AnnExpT [AnnBranch]
deriving Show
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
@ -115,7 +142,8 @@ flattenLambdasAnn ae = go (ae, [])
abstractExp :: AnnExpT -> State Int ExpT
abstractExp (free, (exp, typ)) = case exp of
AVar n -> pure (EVar n, typ)
AVar n -> pure (EVar n, typ)
AInj n -> pure (EInj n, typ)
ALit lit -> pure (ELit lit, typ)
AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2)
AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2)
@ -132,6 +160,9 @@ abstractExp (free, (exp, typ)) = case exp of
pure (EAbs par ae1', t)
_ -> f (free, (ae, t))
ACase e branches -> (, typ) <$> liftA2 ECase (abstractExp e) (mapM abstractBranch branches)
-- Lift lambda into let and bind free variables
AAbs parm e -> do
i <- nextNumber
@ -149,6 +180,10 @@ abstractExp (free, (exp, typ)) = case exp of
where
(t_var, t_return) = applyVarType t
abstractBranch :: AnnBranch -> State Int Branch
abstractBranch (_, AnnBranch patt exp) = Branch patt <$> abstractExp exp
applyVarType :: Type -> (Type, Type)
applyVarType typ = (t1, foldr ($) t2 foralls)
@ -182,7 +217,8 @@ collectScs = concatMap collectFromRhs
collectScsExp :: ExpT -> ([Bind], ExpT)
collectScsExp expT@(exp, typ) = case exp of
EVar _ -> ([], expT)
ELit _ -> ([], expT)
EInj _ -> ([], expT)
ELit _ -> ([], expT)
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
where
@ -198,6 +234,12 @@ collectScsExp expT@(exp, typ) = case exp of
where
(scs, e') = collectScsExp e
ECase e branches -> (scs, (ECase e branches', typ))
where
(scs, branches') = mapAccumL f [] branches
f acc b = (acc ++ acc', b')
where (acc', b') = collectScsBranch b
-- Collect supercombinators from bind, the rhss, and the expression.
--
-- > f = let sc x y = rhs in e
@ -210,6 +252,9 @@ collectScsExp expT@(exp, typ) = case exp of
(rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
where (scs, exp') = collectScsExp exp
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: ExpT -> (ExpT, [Id])
@ -240,3 +285,4 @@ skipForalls = go []
go acc typ = case typ of
TAll tvar t -> go (snoc (TAll tvar) acc) t
_ -> (acc, typ)