Add cases for lambda lifter
This commit is contained in:
parent
5986e2108e
commit
59d9be87cb
1 changed files with 50 additions and 4 deletions
|
|
@ -4,11 +4,12 @@
|
|||
|
||||
module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where
|
||||
|
||||
import Auxiliary (snoc)
|
||||
import Auxiliary (mapAccumM, snoc)
|
||||
import Control.Applicative (Applicative (liftA2))
|
||||
import Control.Arrow (Arrow (second))
|
||||
import Control.Monad.State (MonadState (get, put), State,
|
||||
evalState)
|
||||
import Data.List (partition)
|
||||
import Data.List (mapAccumL, partition)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as Set
|
||||
import Prelude hiding (exp)
|
||||
|
|
@ -40,6 +41,8 @@ freeVarsExp localVars (exp, t) = case exp of
|
|||
EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t))
|
||||
| otherwise -> (mempty, (AVar n, t))
|
||||
|
||||
EInj n -> (mempty, (AVar n, t))
|
||||
|
||||
ELit lit -> (mempty, (ALit lit, t))
|
||||
|
||||
EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t))
|
||||
|
|
@ -68,6 +71,25 @@ freeVarsExp localVars (exp, t) = case exp of
|
|||
e' = freeVarsExp e_localVars e
|
||||
e_localVars = Set.insert name localVars
|
||||
|
||||
ECase e branches -> (frees, (ACase e' branches', t))
|
||||
where
|
||||
frees = foldr (\b s -> Set.union s $ fst b) (freeVarsOf e') branches'
|
||||
e' = freeVarsExp localVars e
|
||||
branches' = map (freeVarsBranch localVars) branches
|
||||
|
||||
|
||||
freeVarsBranch :: Set Ident -> Branch' Type -> (Set Ident, AnnBranch')
|
||||
freeVarsBranch localVars (Branch (patt, t) exp) = (frees, AnnBranch (patt, t) exp')
|
||||
where
|
||||
frees = freeVarsOf exp' Set.\\ freeVarsOfPattern patt
|
||||
exp' = freeVarsExp localVars exp
|
||||
freeVarsOfPattern = Set.fromList . go []
|
||||
where
|
||||
go acc = \case
|
||||
PVar (n,_) -> snoc n acc
|
||||
PInj _ ps -> foldl go acc ps
|
||||
|
||||
|
||||
|
||||
freeVarsOf :: AnnExpT -> Set Ident
|
||||
freeVarsOf = fst
|
||||
|
|
@ -81,6 +103,10 @@ data ABind = ABind Id [Id] AnnExpT deriving Show
|
|||
|
||||
type AnnExpT' = (AnnExp, Type)
|
||||
|
||||
type AnnBranch = (Set Ident, AnnBranch')
|
||||
data AnnBranch' = AnnBranch (Pattern, Type) AnnExpT
|
||||
deriving Show
|
||||
|
||||
data AnnExp = AVar Ident
|
||||
| AInj Ident
|
||||
| ALit Lit
|
||||
|
|
@ -88,6 +114,7 @@ data AnnExp = AVar Ident
|
|||
| AApp AnnExpT AnnExpT
|
||||
| AAdd AnnExpT AnnExpT
|
||||
| AAbs Ident AnnExpT
|
||||
| ACase AnnExpT [AnnBranch]
|
||||
deriving Show
|
||||
|
||||
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
||||
|
|
@ -115,7 +142,8 @@ flattenLambdasAnn ae = go (ae, [])
|
|||
|
||||
abstractExp :: AnnExpT -> State Int ExpT
|
||||
abstractExp (free, (exp, typ)) = case exp of
|
||||
AVar n -> pure (EVar n, typ)
|
||||
AVar n -> pure (EVar n, typ)
|
||||
AInj n -> pure (EInj n, typ)
|
||||
ALit lit -> pure (ELit lit, typ)
|
||||
AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2)
|
||||
AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2)
|
||||
|
|
@ -132,6 +160,9 @@ abstractExp (free, (exp, typ)) = case exp of
|
|||
pure (EAbs par ae1', t)
|
||||
_ -> f (free, (ae, t))
|
||||
|
||||
ACase e branches -> (, typ) <$> liftA2 ECase (abstractExp e) (mapM abstractBranch branches)
|
||||
|
||||
|
||||
-- Lift lambda into let and bind free variables
|
||||
AAbs parm e -> do
|
||||
i <- nextNumber
|
||||
|
|
@ -149,6 +180,10 @@ abstractExp (free, (exp, typ)) = case exp of
|
|||
where
|
||||
(t_var, t_return) = applyVarType t
|
||||
|
||||
|
||||
abstractBranch :: AnnBranch -> State Int Branch
|
||||
abstractBranch (_, AnnBranch patt exp) = Branch patt <$> abstractExp exp
|
||||
|
||||
applyVarType :: Type -> (Type, Type)
|
||||
applyVarType typ = (t1, foldr ($) t2 foralls)
|
||||
|
||||
|
|
@ -182,7 +217,8 @@ collectScs = concatMap collectFromRhs
|
|||
collectScsExp :: ExpT -> ([Bind], ExpT)
|
||||
collectScsExp expT@(exp, typ) = case exp of
|
||||
EVar _ -> ([], expT)
|
||||
ELit _ -> ([], expT)
|
||||
EInj _ -> ([], expT)
|
||||
ELit _ -> ([], expT)
|
||||
|
||||
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
|
||||
where
|
||||
|
|
@ -198,6 +234,12 @@ collectScsExp expT@(exp, typ) = case exp of
|
|||
where
|
||||
(scs, e') = collectScsExp e
|
||||
|
||||
ECase e branches -> (scs, (ECase e branches', typ))
|
||||
where
|
||||
(scs, branches') = mapAccumL f [] branches
|
||||
f acc b = (acc ++ acc', b')
|
||||
where (acc', b') = collectScsBranch b
|
||||
|
||||
-- Collect supercombinators from bind, the rhss, and the expression.
|
||||
--
|
||||
-- > f = let sc x y = rhs in e
|
||||
|
|
@ -210,6 +252,9 @@ collectScsExp expT@(exp, typ) = case exp of
|
|||
(rhs_scs, rhs') = collectScsExp rhs
|
||||
(et_scs, et') = collectScsExp e
|
||||
|
||||
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
|
||||
where (scs, exp') = collectScsExp exp
|
||||
|
||||
|
||||
-- @\x.\y.\z. e → (e, [x,y,z])@
|
||||
flattenLambdas :: ExpT -> (ExpT, [Id])
|
||||
|
|
@ -240,3 +285,4 @@ skipForalls = go []
|
|||
go acc typ = case typ of
|
||||
TAll tvar t -> go (snoc (TAll tvar) acc) t
|
||||
_ -> (acc, typ)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue