265 lines
8.7 KiB
Haskell
265 lines
8.7 KiB
Haskell
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
|
|
|
|
module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where
|
|
|
|
import Auxiliary (mapAccumM, snoc)
|
|
import Control.Applicative (Applicative (liftA2))
|
|
import Control.Arrow (Arrow (second))
|
|
import Control.Monad.State (MonadState (get, put), State,
|
|
evalState)
|
|
import Data.List (mapAccumL, partition)
|
|
import Data.Set (Set)
|
|
import qualified Data.Set as Set
|
|
import Prelude hiding (exp)
|
|
import TypeChecker.TypeCheckerIr
|
|
|
|
|
|
-- | Lift lambdas and let expressions into supercombinators.
--
-- Every definition that is not a 'DBind' is passed through untouched and is
-- placed before the lifted bindings in the resulting program. The bindings
-- go through three phases:
--
-- 1. @freeVars@ annotates all the free variables.
-- 2. @abstract@ converts lambdas into let expressions.
-- 3. @collectScs@ moves every non-constant let expression to a top-level
--    function.
lambdaLift :: Program -> Program
lambdaLift (Program defs) = Program $ datatypes ++ lifted
  where
    lifted = map DBind . collectScs . abstract . freeVars $ binds

    -- Selecting the binds with a list comprehension keeps the extraction
    -- total; no partial lambda over 'DBind' is needed.
    binds     = [b | DBind b <- defs]
    datatypes = filter (not . isBind) defs

    isBind = \case
        DBind _ -> True
        _       -> False
|
|
|
|
-- | Annotate the body of every bind with its free variables.
--
-- The parameters of a bind form the initial set of local variables for its
-- body.
freeVars :: [Bind] -> AnnBinds
freeVars binds = do
    Bind name parms body <- binds
    let locals = Set.fromList (map fst parms)
    pure (name, parms, freeVarsExp locals body)
|
|
|
|
-- | Annotate an expression, and every sub-expression, with its free
-- variables.
--
-- The first argument is the set of variables that are local at this point.
-- Only variables in that set are reported as free; any other variable
-- contributes nothing to the free-variable set.
freeVarsExp :: Set Ident -> ExpT -> AnnExpT
freeVarsExp localVars (exp, t) = case exp of
    EVar n
        | Set.member n localVars -> (Set.singleton n, (AVar n, t))
        | otherwise              -> (mempty, (AVar n, t))

    -- Injections must stay injections: 'abstractExp' maps 'AInj' back to
    -- 'EInj', whereas an 'AVar' would come back as a plain 'EVar'.
    EInj n -> (mempty, (AInj n, t))

    ELit lit -> (mempty, (ALit lit, t))

    EApp e1 e2 -> binary AApp e1 e2

    EAdd e1 e2 -> binary AAdd e1 e2

    -- The parameter is local inside the body and bound by the lambda itself.
    EAbs par e -> (Set.delete par $ freeVarsOf e', (AAbs par e', t))
      where
        e' = freeVarsExp (Set.insert par localVars) e

    -- Sum free variables present in bind and the expression. The bound name
    -- is in scope in both the right-hand side and the body, and is removed
    -- from the free variables of each.
    --
    -- NOTE(review): @parms@ is not added to the local variables of @rhs@ —
    -- presumably let binds carry no parameters at this stage; confirm.
    ELet (Bind (name, t_bind) parms rhs) e -> (Set.union binders_frees e_free, (ALet new_bind e', t))
      where
        binders_frees = Set.delete name $ freeVarsOf rhs'
        e_free        = Set.delete name $ freeVarsOf e'

        rhs'     = freeVarsExp e_localVars rhs
        new_bind = ABind (name, t_bind) parms rhs'

        e'          = freeVarsExp e_localVars e
        e_localVars = Set.insert name localVars

    -- Free variables of the scrutinee together with those of every branch.
    ECase e branches -> (frees, (ACase e' branches', t))
      where
        frees     = foldr (\b s -> Set.union s $ fst b) (freeVarsOf e') branches'
        e'        = freeVarsExp localVars e
        branches' = map (freeVarsBranch localVars) branches
  where
    -- Annotate both operands of a binary node and union their free variables.
    binary con e1 e2 = (Set.union (freeVarsOf e1') (freeVarsOf e2'), (con e1' e2', t))
      where
        e1' = freeVarsExp localVars e1
        e2' = freeVarsExp localVars e2
|
|
|
|
|
|
-- | Annotate a case branch with its free variables.
--
-- Variables bound by the branch pattern are local inside the branch body.
-- They are put in scope while the body is annotated, so that expressions
-- nested in the body can report them as free, and are then removed from the
-- free variables of the branch as a whole.
freeVarsBranch :: Set Ident -> Branch' Type -> (Set Ident, AnnBranch')
freeVarsBranch localVars (Branch (patt, t) exp) = (frees, AnnBranch (patt, t) exp')
  where
    bound = patternVars patt
    frees = freeVarsOf exp' Set.\\ bound
    exp'  = freeVarsExp (Set.union bound localVars) exp

    -- The variables bound by a pattern, collected through nested injections.
    patternVars = \case
        PVar (n, _) -> Set.singleton n
        PInj _ ps   -> foldMap patternVars ps
|
|
|
|
|
|
|
|
-- | The free-variable set stored in an annotated expression.
freeVarsOf :: AnnExpT -> Set Ident
freeVarsOf (frees, _) = frees
|
|
|
|
-- AST annotated with free variables

-- | Annotated top-level binds: name, parameters and annotated body.
type AnnBinds = [(Id, [Id], AnnExpT)]

-- | A typed annotated expression paired with its set of free variables.
type AnnExpT = (Set Ident, AnnExpT')

-- | An annotated let bind: name, parameters and annotated right-hand side.
data ABind = ABind Id [Id] AnnExpT deriving Show

-- | An annotated expression together with its type.
type AnnExpT' = (AnnExp, Type)

-- | An annotated case branch paired with its set of free variables.
type AnnBranch = (Set Ident, AnnBranch')

-- | A case branch: a typed pattern and the annotated branch body.
data AnnBranch' = AnnBranch (Pattern, Type) AnnExpT
    deriving Show
|
|
|
|
-- | Expressions whose sub-expressions carry free-variable annotations.
-- Each constructor is converted back to the correspondingly named @E@
-- constructor by @abstractExp@, except 'AAbs', which is lifted into a let.
data AnnExp = AVar Ident              -- ^ Variable
            | AInj Ident              -- ^ Injection
            | ALit Lit                -- ^ Literal
            | ALet ABind AnnExpT      -- ^ Let bind and the body it scopes over
            | AApp AnnExpT AnnExpT    -- ^ Application
            | AAdd AnnExpT AnnExpT    -- ^ Addition
            | AAbs Ident AnnExpT      -- ^ Lambda: parameter and body
            | ACase AnnExpT [AnnBranch] -- ^ Case: scrutinee and branches
    deriving Show
|
|
|
|
-- | Lift lambdas into let expressions of the form @let sc = \v₁ x₁ -> e₁@,
-- where the free variables @v₁ v₂ .. vₙ@ become bound as parameters.
--
-- The leading lambdas of every top-level right-hand side are first flattened
-- into extra parameters of the bind. A counter starting at 0 is threaded
-- through the traversal and numbers the generated supercombinators.
abstract :: AnnBinds -> [Bind]
abstract prog = evalState (traverse liftBind prog) 0
  where
    liftBind :: (Id, [Id], AnnExpT) -> State Int Bind
    liftBind (name, parms, rhs) = do
        let (body, lamParms) = flattenLambdasAnn rhs
        body' <- abstractExp body
        pure $ Bind name (parms ++ lamParms) body'
|
|
|
|
|
|
-- | Peel the leading lambdas off an annotated expression, collecting their
-- typed parameters in order.
--
-- @\x.\y.\z. ae → (ae, [x,y,z])@
--
-- A lambda is only peeled when its type is a function type; the parameter
-- takes the argument type, and is removed from the free variables of the
-- body that is passed on.
flattenLambdasAnn :: AnnExpT -> (AnnExpT, [Id])
flattenLambdasAnn = peel []
  where
    peel :: [Id] -> AnnExpT -> (AnnExpT, [Id])
    peel parms ann@(_, (e, t)) = case (e, t) of
        (AAbs par (bodyFree, body), TFun parTy _) ->
            peel (snoc (par, parTy) parms) (Set.delete par bodyFree, body)
        _ -> (ann, parms)
|
|
|
|
-- | Convert an annotated expression back into a plain typed expression,
-- replacing every lambda by a let-bound supercombinator that is applied to
-- the lambda's free variables.
--
-- Runs in @State Int@; the counter (see @nextNumber@) supplies the numbers
-- used in the generated @sc_N@ names.
abstractExp :: AnnExpT -> State Int ExpT
abstractExp (free, (exp, typ)) = case exp of
    AVar n -> pure (EVar n, typ)
    AInj n -> pure (EInj n, typ)
    ALit lit -> pure (ELit lit, typ)
    AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2)
    AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2)
    ALet b e -> (, typ) <$> liftA2 ELet (go b) (abstractExp e)
      where
        -- The leading lambdas of a let-bound right-hand side are not lifted
        -- into their own supercombinators: they are rebuilt by 'skipLambdas'
        -- and then flattened into extra parameters of the bind.
        go (ABind name parms rhs) = do
            (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
            pure $ Bind name (parms ++ parms1) rhs'

        -- Rebuild leading lambdas as 'EAbs' unchanged and apply @f@ to the
        -- first body that is not a lambda.
        skipLambdas :: (AnnExpT -> State Int ExpT) -> AnnExpT -> State Int ExpT
        skipLambdas f (free, (ae, t)) = case ae of
            AAbs par ae1 -> do
                ae1' <- skipLambdas f ae1
                pure (EAbs par ae1', t)
            _ -> f (free, (ae, t))

    ACase e branches -> (, typ) <$> liftA2 ECase (abstractExp e) (mapM abstractBranch branches)

    -- Lift lambda into let and bind free variables: the lambda becomes
    -- @let sc_N vars = rhs in sc_N@, which is then applied to each free
    -- variable in turn.
    AAbs parm e -> do
        i <- nextNumber
        rhs <- abstractExp e

        let sc_name = Ident ("sc_" ++ show i)
            sc = (ELet (Bind (sc_name, typ) vars rhs) (EVar sc_name, typ), typ)
        pure $ foldl applyVars sc freeList

      where
        freeList = Set.toList free
        -- NOTE(review): @names@ lists the free variables before the lambda
        -- parameter, while @getVars typ@ yields the argument types of the
        -- lambda's own type. 'zip' truncates to the shorter list, so the
        -- pairing looks misaligned whenever @free@ is non-empty — confirm.
        vars = zip names $ getVars typ
        names = snoc parm freeList
        -- Apply the expression to one free variable, taking the variable's
        -- type and the result type from the expression's function type.
        applyVars (e, t) name = (EApp (e, t) (EVar name, t_var), t_return)
          where
            -- NOTE(review): only 'TFun' is handled; any other type makes
            -- this match fail once @t_var@ or @t_return@ is demanded.
            (t_var, t_return) = case t of
                TFun t1 t2 -> (t1, t2)
|
|
|
|
|
|
|
|
-- | Abstract the body of a case branch, keeping its pattern and dropping the
-- branch's free-variable annotation.
abstractBranch :: AnnBranch -> State Int Branch
abstractBranch (_, AnnBranch patt body) = do
    body' <- abstractExp body
    pure (Branch patt body')
|
|
|
|
-- | Return the current counter value and advance the counter by one.
nextNumber :: State Int Int
nextNumber = get >>= \current -> current <$ put (succ current)
|
|
|
|
-- | Collect supercombinators by lifting non-constant let expressions.
--
-- Each bind is followed in the output by the supercombinators that were
-- lifted out of its right-hand side.
collectScs :: [Bind] -> [Bind]
collectScs = concatMap liftFromBind
  where
    liftFromBind (Bind name parms rhs) = Bind name parms body : lifted
      where
        (lifted, body) = collectScsExp rhs
|
|
|
|
|
|
-- | Pull the let-bound functions out of an expression.
--
-- Returns the lifted binds together with the expression that remains. A let
-- whose bind has parameters is removed and its bind is returned as a
-- supercombinator; a let without parameters stays in place.
collectScsExp :: ExpT -> ([Bind], ExpT)
collectScsExp expT@(exp, typ) = case exp of
    EVar _ -> unchanged
    EInj _ -> unchanged
    ELit _ -> unchanged

    EApp l r -> binary EApp l r
    EAdd l r -> binary EAdd l r

    EAbs par body ->
        let (lifted, body') = collectScsExp body
         in (lifted, (EAbs par body', typ))

    -- Binds lifted from the branches come before those of the scrutinee.
    ECase scrut branches ->
        let (fromScrut, scrut')       = collectScsExp scrut
            (fromBranches, branches') = mapAccumL step [] branches
            step acc branch =
                let (new, branch') = collectScsBranch branch
                 in (acc ++ new, branch')
         in (fromBranches ++ fromScrut, (ECase scrut' branches', typ))

    -- Collect supercombinators from the bind, its right-hand side, and the
    -- body of the let.
    --
    -- > f = let sc x y = rhs in e
    --
    ELet (Bind name parms rhs) body
        | null parms -> (lifted, (ELet bind body', snd body'))
        | otherwise  -> (bind : lifted, body')
      where
        bind              = Bind name parms rhs'
        lifted            = fromRhs ++ fromBody
        (fromRhs, rhs')   = collectScsExp rhs
        (fromBody, body') = collectScsExp body
  where
    -- Leaves contain nothing to lift.
    unchanged = ([], expT)

    -- Binary nodes: binds from the left operand precede those of the right.
    binary con l r = (fromL ++ fromR, (con l' r', typ))
      where
        (fromL, l') = collectScsExp l
        (fromR, r') = collectScsExp r
|
|
|
|
-- | Collect the supercombinators lifted out of a case branch's body,
-- returning them together with the branch rebuilt around the remaining body.
collectScsBranch :: Branch -> ([Bind], Branch)
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
  where
    (scs, exp') = collectScsExp exp
|
|
|
|
|
|
-- | Flatten nested lambdas and collect their typed parameters in order.
--
-- @\x.\y.\z. e → (e, [x,y,z])@
--
-- Each parameter is given the argument type of the function type carried by
-- its lambda.
flattenLambdas :: ExpT -> (ExpT, [Id])
flattenLambdas expT = go (expT, [])
  where
    go ((e, t), acc) = case e of
        EAbs name e1 -> go (e1, snoc (name, t_var) acc)
          where
            -- Match the function type directly instead of taking 'head' of
            -- all argument types: only the first argument is needed, and a
            -- lambda without a function type reports a descriptive error
            -- rather than an anonymous empty-list failure.
            t_var = case t of
                TFun t_par _ -> t_par
                _ -> error "LambdaLifter.flattenLambdas: lambda does not have a function type"
        _ -> ((e, t), acc)
|
|
|
|
-- | The argument types of a type, as split off by 'partitionType'.
getVars :: Type -> [Type]
getVars t = argTypes
  where
    (argTypes, _) = partitionType t
|
|
|
|
-- | Split a type into its argument types and its final result type by
-- following the right-hand side of each 'TFun'.
partitionType :: Type -> ([Type], Type)
partitionType = walk []
  where
    walk args (TFun arg rest) = walk (snoc arg args) rest
    walk args result          = (args, result)
|
|
|