Add cases for lambda lifter
This commit is contained in:
parent
5986e2108e
commit
59d9be87cb
1 changed files with 50 additions and 4 deletions
|
|
@ -4,11 +4,12 @@
|
||||||
|
|
||||||
module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where
|
module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where
|
||||||
|
|
||||||
import Auxiliary (snoc)
|
import Auxiliary (mapAccumM, snoc)
|
||||||
import Control.Applicative (Applicative (liftA2))
|
import Control.Applicative (Applicative (liftA2))
|
||||||
|
import Control.Arrow (Arrow (second))
|
||||||
import Control.Monad.State (MonadState (get, put), State,
|
import Control.Monad.State (MonadState (get, put), State,
|
||||||
evalState)
|
evalState)
|
||||||
import Data.List (partition)
|
import Data.List (mapAccumL, partition)
|
||||||
import Data.Set (Set)
|
import Data.Set (Set)
|
||||||
import qualified Data.Set as Set
|
import qualified Data.Set as Set
|
||||||
import Prelude hiding (exp)
|
import Prelude hiding (exp)
|
||||||
|
|
@ -40,6 +41,8 @@ freeVarsExp localVars (exp, t) = case exp of
|
||||||
EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t))
|
EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t))
|
||||||
| otherwise -> (mempty, (AVar n, t))
|
| otherwise -> (mempty, (AVar n, t))
|
||||||
|
|
||||||
|
EInj n -> (mempty, (AVar n, t))
|
||||||
|
|
||||||
ELit lit -> (mempty, (ALit lit, t))
|
ELit lit -> (mempty, (ALit lit, t))
|
||||||
|
|
||||||
EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t))
|
EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t))
|
||||||
|
|
@ -68,6 +71,25 @@ freeVarsExp localVars (exp, t) = case exp of
|
||||||
e' = freeVarsExp e_localVars e
|
e' = freeVarsExp e_localVars e
|
||||||
e_localVars = Set.insert name localVars
|
e_localVars = Set.insert name localVars
|
||||||
|
|
||||||
|
ECase e branches -> (frees, (ACase e' branches', t))
|
||||||
|
where
|
||||||
|
frees = foldr (\b s -> Set.union s $ fst b) (freeVarsOf e') branches'
|
||||||
|
e' = freeVarsExp localVars e
|
||||||
|
branches' = map (freeVarsBranch localVars) branches
|
||||||
|
|
||||||
|
|
||||||
|
freeVarsBranch :: Set Ident -> Branch' Type -> (Set Ident, AnnBranch')
|
||||||
|
freeVarsBranch localVars (Branch (patt, t) exp) = (frees, AnnBranch (patt, t) exp')
|
||||||
|
where
|
||||||
|
frees = freeVarsOf exp' Set.\\ freeVarsOfPattern patt
|
||||||
|
exp' = freeVarsExp localVars exp
|
||||||
|
freeVarsOfPattern = Set.fromList . go []
|
||||||
|
where
|
||||||
|
go acc = \case
|
||||||
|
PVar (n,_) -> snoc n acc
|
||||||
|
PInj _ ps -> foldl go acc ps
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
freeVarsOf :: AnnExpT -> Set Ident
|
freeVarsOf :: AnnExpT -> Set Ident
|
||||||
freeVarsOf = fst
|
freeVarsOf = fst
|
||||||
|
|
@ -81,6 +103,10 @@ data ABind = ABind Id [Id] AnnExpT deriving Show
|
||||||
|
|
||||||
type AnnExpT' = (AnnExp, Type)
|
type AnnExpT' = (AnnExp, Type)
|
||||||
|
|
||||||
|
type AnnBranch = (Set Ident, AnnBranch')
|
||||||
|
data AnnBranch' = AnnBranch (Pattern, Type) AnnExpT
|
||||||
|
deriving Show
|
||||||
|
|
||||||
data AnnExp = AVar Ident
|
data AnnExp = AVar Ident
|
||||||
| AInj Ident
|
| AInj Ident
|
||||||
| ALit Lit
|
| ALit Lit
|
||||||
|
|
@ -88,6 +114,7 @@ data AnnExp = AVar Ident
|
||||||
| AApp AnnExpT AnnExpT
|
| AApp AnnExpT AnnExpT
|
||||||
| AAdd AnnExpT AnnExpT
|
| AAdd AnnExpT AnnExpT
|
||||||
| AAbs Ident AnnExpT
|
| AAbs Ident AnnExpT
|
||||||
|
| ACase AnnExpT [AnnBranch]
|
||||||
deriving Show
|
deriving Show
|
||||||
|
|
||||||
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
||||||
|
|
@ -116,6 +143,7 @@ flattenLambdasAnn ae = go (ae, [])
|
||||||
abstractExp :: AnnExpT -> State Int ExpT
|
abstractExp :: AnnExpT -> State Int ExpT
|
||||||
abstractExp (free, (exp, typ)) = case exp of
|
abstractExp (free, (exp, typ)) = case exp of
|
||||||
AVar n -> pure (EVar n, typ)
|
AVar n -> pure (EVar n, typ)
|
||||||
|
AInj n -> pure (EInj n, typ)
|
||||||
ALit lit -> pure (ELit lit, typ)
|
ALit lit -> pure (ELit lit, typ)
|
||||||
AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2)
|
AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2)
|
||||||
AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2)
|
AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2)
|
||||||
|
|
@ -132,6 +160,9 @@ abstractExp (free, (exp, typ)) = case exp of
|
||||||
pure (EAbs par ae1', t)
|
pure (EAbs par ae1', t)
|
||||||
_ -> f (free, (ae, t))
|
_ -> f (free, (ae, t))
|
||||||
|
|
||||||
|
ACase e branches -> (, typ) <$> liftA2 ECase (abstractExp e) (mapM abstractBranch branches)
|
||||||
|
|
||||||
|
|
||||||
-- Lift lambda into let and bind free variables
|
-- Lift lambda into let and bind free variables
|
||||||
AAbs parm e -> do
|
AAbs parm e -> do
|
||||||
i <- nextNumber
|
i <- nextNumber
|
||||||
|
|
@ -149,6 +180,10 @@ abstractExp (free, (exp, typ)) = case exp of
|
||||||
where
|
where
|
||||||
(t_var, t_return) = applyVarType t
|
(t_var, t_return) = applyVarType t
|
||||||
|
|
||||||
|
|
||||||
|
abstractBranch :: AnnBranch -> State Int Branch
|
||||||
|
abstractBranch (_, AnnBranch patt exp) = Branch patt <$> abstractExp exp
|
||||||
|
|
||||||
applyVarType :: Type -> (Type, Type)
|
applyVarType :: Type -> (Type, Type)
|
||||||
applyVarType typ = (t1, foldr ($) t2 foralls)
|
applyVarType typ = (t1, foldr ($) t2 foralls)
|
||||||
|
|
||||||
|
|
@ -182,6 +217,7 @@ collectScs = concatMap collectFromRhs
|
||||||
collectScsExp :: ExpT -> ([Bind], ExpT)
|
collectScsExp :: ExpT -> ([Bind], ExpT)
|
||||||
collectScsExp expT@(exp, typ) = case exp of
|
collectScsExp expT@(exp, typ) = case exp of
|
||||||
EVar _ -> ([], expT)
|
EVar _ -> ([], expT)
|
||||||
|
EInj _ -> ([], expT)
|
||||||
ELit _ -> ([], expT)
|
ELit _ -> ([], expT)
|
||||||
|
|
||||||
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
|
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
|
||||||
|
|
@ -198,6 +234,12 @@ collectScsExp expT@(exp, typ) = case exp of
|
||||||
where
|
where
|
||||||
(scs, e') = collectScsExp e
|
(scs, e') = collectScsExp e
|
||||||
|
|
||||||
|
ECase e branches -> (scs, (ECase e branches', typ))
|
||||||
|
where
|
||||||
|
(scs, branches') = mapAccumL f [] branches
|
||||||
|
f acc b = (acc ++ acc', b')
|
||||||
|
where (acc', b') = collectScsBranch b
|
||||||
|
|
||||||
-- Collect supercombinators from bind, the rhss, and the expression.
|
-- Collect supercombinators from bind, the rhss, and the expression.
|
||||||
--
|
--
|
||||||
-- > f = let sc x y = rhs in e
|
-- > f = let sc x y = rhs in e
|
||||||
|
|
@ -210,6 +252,9 @@ collectScsExp expT@(exp, typ) = case exp of
|
||||||
(rhs_scs, rhs') = collectScsExp rhs
|
(rhs_scs, rhs') = collectScsExp rhs
|
||||||
(et_scs, et') = collectScsExp e
|
(et_scs, et') = collectScsExp e
|
||||||
|
|
||||||
|
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
|
||||||
|
where (scs, exp') = collectScsExp exp
|
||||||
|
|
||||||
|
|
||||||
-- @\x.\y.\z. e → (e, [x,y,z])@
|
-- @\x.\y.\z. e → (e, [x,y,z])@
|
||||||
flattenLambdas :: ExpT -> (ExpT, [Id])
|
flattenLambdas :: ExpT -> (ExpT, [Id])
|
||||||
|
|
@ -240,3 +285,4 @@ skipForalls = go []
|
||||||
go acc typ = case typ of
|
go acc typ = case typ of
|
||||||
TAll tvar t -> go (snoc (TAll tvar) acc) t
|
TAll tvar t -> go (snoc (TAll tvar) acc) t
|
||||||
_ -> (acc, typ)
|
_ -> (acc, typ)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue