Add cases for lambda lifter

This commit is contained in:
Martin Fredin 2023-03-28 15:35:01 +02:00
parent 5986e2108e
commit 59d9be87cb

View file

@ -4,11 +4,12 @@
module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where
import Auxiliary (snoc) import Auxiliary (mapAccumM, snoc)
import Control.Applicative (Applicative (liftA2)) import Control.Applicative (Applicative (liftA2))
import Control.Arrow (Arrow (second))
import Control.Monad.State (MonadState (get, put), State, import Control.Monad.State (MonadState (get, put), State,
evalState) evalState)
import Data.List (partition) import Data.List (mapAccumL, partition)
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Set as Set import qualified Data.Set as Set
import Prelude hiding (exp) import Prelude hiding (exp)
@ -40,6 +41,8 @@ freeVarsExp localVars (exp, t) = case exp of
EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t)) EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t))
| otherwise -> (mempty, (AVar n, t)) | otherwise -> (mempty, (AVar n, t))
EInj n -> (mempty, (AVar n, t))
ELit lit -> (mempty, (ALit lit, t)) ELit lit -> (mempty, (ALit lit, t))
EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t)) EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t))
@ -68,6 +71,25 @@ freeVarsExp localVars (exp, t) = case exp of
e' = freeVarsExp e_localVars e e' = freeVarsExp e_localVars e
e_localVars = Set.insert name localVars e_localVars = Set.insert name localVars
ECase e branches -> (frees, (ACase e' branches', t))
where
frees = foldr (\b s -> Set.union s $ fst b) (freeVarsOf e') branches'
e' = freeVarsExp localVars e
branches' = map (freeVarsBranch localVars) branches
freeVarsBranch :: Set Ident -> Branch' Type -> (Set Ident, AnnBranch')
freeVarsBranch localVars (Branch (patt, t) exp) = (frees, AnnBranch (patt, t) exp')
where
frees = freeVarsOf exp' Set.\\ freeVarsOfPattern patt
exp' = freeVarsExp localVars exp
freeVarsOfPattern = Set.fromList . go []
where
go acc = \case
PVar (n,_) -> snoc n acc
PInj _ ps -> foldl go acc ps
freeVarsOf :: AnnExpT -> Set Ident freeVarsOf :: AnnExpT -> Set Ident
freeVarsOf = fst freeVarsOf = fst
@ -81,6 +103,10 @@ data ABind = ABind Id [Id] AnnExpT deriving Show
type AnnExpT' = (AnnExp, Type) type AnnExpT' = (AnnExp, Type)
type AnnBranch = (Set Ident, AnnBranch')
data AnnBranch' = AnnBranch (Pattern, Type) AnnExpT
deriving Show
data AnnExp = AVar Ident data AnnExp = AVar Ident
| AInj Ident | AInj Ident
| ALit Lit | ALit Lit
@ -88,6 +114,7 @@ data AnnExp = AVar Ident
| AApp AnnExpT AnnExpT | AApp AnnExpT AnnExpT
| AAdd AnnExpT AnnExpT | AAdd AnnExpT AnnExpT
| AAbs Ident AnnExpT | AAbs Ident AnnExpT
| ACase AnnExpT [AnnBranch]
deriving Show deriving Show
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. -- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
@ -115,7 +142,8 @@ flattenLambdasAnn ae = go (ae, [])
abstractExp :: AnnExpT -> State Int ExpT abstractExp :: AnnExpT -> State Int ExpT
abstractExp (free, (exp, typ)) = case exp of abstractExp (free, (exp, typ)) = case exp of
AVar n -> pure (EVar n, typ) AVar n -> pure (EVar n, typ)
AInj n -> pure (EInj n, typ)
ALit lit -> pure (ELit lit, typ) ALit lit -> pure (ELit lit, typ)
AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2) AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2)
AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2) AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2)
@ -132,6 +160,9 @@ abstractExp (free, (exp, typ)) = case exp of
pure (EAbs par ae1', t) pure (EAbs par ae1', t)
_ -> f (free, (ae, t)) _ -> f (free, (ae, t))
ACase e branches -> (, typ) <$> liftA2 ECase (abstractExp e) (mapM abstractBranch branches)
-- Lift lambda into let and bind free variables -- Lift lambda into let and bind free variables
AAbs parm e -> do AAbs parm e -> do
i <- nextNumber i <- nextNumber
@ -149,6 +180,10 @@ abstractExp (free, (exp, typ)) = case exp of
where where
(t_var, t_return) = applyVarType t (t_var, t_return) = applyVarType t
abstractBranch :: AnnBranch -> State Int Branch
abstractBranch (_, AnnBranch patt exp) = Branch patt <$> abstractExp exp
applyVarType :: Type -> (Type, Type) applyVarType :: Type -> (Type, Type)
applyVarType typ = (t1, foldr ($) t2 foralls) applyVarType typ = (t1, foldr ($) t2 foralls)
@ -182,7 +217,8 @@ collectScs = concatMap collectFromRhs
collectScsExp :: ExpT -> ([Bind], ExpT) collectScsExp :: ExpT -> ([Bind], ExpT)
collectScsExp expT@(exp, typ) = case exp of collectScsExp expT@(exp, typ) = case exp of
EVar _ -> ([], expT) EVar _ -> ([], expT)
ELit _ -> ([], expT) EInj _ -> ([], expT)
ELit _ -> ([], expT)
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ)) EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
where where
@ -198,6 +234,12 @@ collectScsExp expT@(exp, typ) = case exp of
where where
(scs, e') = collectScsExp e (scs, e') = collectScsExp e
ECase e branches -> (scs, (ECase e branches', typ))
where
(scs, branches') = mapAccumL f [] branches
f acc b = (acc ++ acc', b')
where (acc', b') = collectScsBranch b
-- Collect supercombinators from bind, the rhss, and the expression. -- Collect supercombinators from bind, the rhss, and the expression.
-- --
-- > f = let sc x y = rhs in e -- > f = let sc x y = rhs in e
@ -210,6 +252,9 @@ collectScsExp expT@(exp, typ) = case exp of
(rhs_scs, rhs') = collectScsExp rhs (rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e (et_scs, et') = collectScsExp e
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
where (scs, exp') = collectScsExp exp
-- @\x.\y.\z. e → (e, [x,y,z])@ -- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: ExpT -> (ExpT, [Id]) flattenLambdas :: ExpT -> (ExpT, [Id])
@ -240,3 +285,4 @@ skipForalls = go []
go acc typ = case typ of go acc typ = case typ of
TAll tvar t -> go (snoc (TAll tvar) acc) t TAll tvar t -> go (snoc (TAll tvar) acc) t
_ -> (acc, typ) _ -> (acc, typ)