221 lines
7.1 KiB
Haskell
221 lines
7.1 KiB
Haskell
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
|
|
|
|
module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
|
|
|
|
import Auxiliary (snoc)
|
|
import Control.Applicative (Applicative (liftA2))
|
|
import Control.Monad.State (MonadState (get, put), State, evalState)
|
|
import Data.Set (Set)
|
|
import qualified Data.Set as Set
|
|
import Debug.Trace (trace)
|
|
import qualified Grammar.Abs as GA
|
|
import Prelude hiding (exp)
|
|
import Renamer
|
|
import TypeCheckerIr
|
|
|
|
|
|
-- | Lift lambdas and let expression into supercombinators.
|
|
-- Three phases:
|
|
-- @freeVars@ annotatss all the free variables.
|
|
-- @abstract@ converts lambdas into let expressions.
|
|
-- @collectScs@ moves every non-constant let expression to a top-level function.
|
|
lambdaLift :: Program -> Program
|
|
lambdaLift = collectScs . abstract . freeVars
|
|
|
|
-- | Annotate free variables
|
|
freeVars :: Program -> AnnProgram
|
|
freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
|
|
| Bind n xs e <- ds
|
|
]
|
|
|
|
freeVarsExp :: Set Id -> Exp -> AnnExp
|
|
freeVarsExp localVars = \case
|
|
EId n | Set.member n localVars -> (Set.singleton n, AId n)
|
|
| otherwise -> (mempty, AId n)
|
|
|
|
EInt i -> (mempty, AInt i)
|
|
|
|
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
|
|
where
|
|
e1' = freeVarsExp localVars e1
|
|
e2' = freeVarsExp localVars e2
|
|
|
|
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
|
|
where
|
|
e1' = freeVarsExp localVars e1
|
|
e2' = freeVarsExp localVars e2
|
|
|
|
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
|
|
where
|
|
e' = freeVarsExp (Set.insert par localVars) e
|
|
|
|
-- Sum free variables present in bind and the expression
|
|
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
|
|
where
|
|
binders_frees = Set.delete name $ freeVarsOf rhs'
|
|
e_free = Set.delete name $ freeVarsOf e'
|
|
|
|
rhs' = freeVarsExp e_localVars rhs
|
|
new_bind = ABind name parms rhs'
|
|
|
|
e' = freeVarsExp e_localVars e
|
|
e_localVars = Set.insert name localVars
|
|
|
|
(ECase t e cs) -> do
|
|
let e' = freeVarsExp localVars e
|
|
let vars = freeVarsOf e'
|
|
let (vars', cs') = foldr (\(_, Case c e) (vars,acc) -> do
|
|
let e' = freeVarsExp vars e
|
|
let vars' = freeVarsOf e'
|
|
(Set.union vars vars', AnnCase c e' : acc)
|
|
) (vars, []) cs
|
|
(vars', ACase t e' (reverse cs'))
|
|
|
|
|
|
freeVarsOf :: AnnExp -> Set Id
|
|
freeVarsOf = fst
|
|
|
|
-- AST annotated with free variables
|
|
type AnnProgram = [(Id, [Id], AnnExp)]
|
|
|
|
type AnnExp = (Set Id, AnnExp')
|
|
|
|
data ABind = ABind Id [Id] AnnExp deriving Show
|
|
|
|
data AnnExp' = AId Id
|
|
| AInt Integer
|
|
| ALet ABind AnnExp
|
|
| AApp Type AnnExp AnnExp
|
|
| AAdd Type AnnExp AnnExp
|
|
| AAbs Type Id AnnExp
|
|
| ACase Type AnnExp [AnnCase]
|
|
deriving Show
|
|
data AnnCase = AnnCase GA.Case AnnExp
|
|
deriving Show
|
|
|
|
|
|
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
|
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
|
|
abstract :: AnnProgram -> Program
|
|
abstract prog = Program $ evalState (mapM go prog) 0
|
|
where
|
|
go :: (Id, [Id], AnnExp) -> State Int Bind
|
|
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
|
|
where
|
|
(rhs', parms1) = flattenLambdasAnn rhs
|
|
|
|
|
|
-- | Flatten nested lambdas and collect the parameters
|
|
-- @\x.\y.\z. ae → (ae, [x,y,z])@
|
|
flattenLambdasAnn :: AnnExp -> (AnnExp, [Id])
|
|
flattenLambdasAnn ae = go (ae, [])
|
|
where
|
|
go :: (AnnExp, [Id]) -> (AnnExp, [Id])
|
|
go ((free, e), acc) =
|
|
case e of
|
|
AAbs _ par (free1, e1) ->
|
|
go ((Set.delete par free1, e1), snoc par acc)
|
|
_ -> ((free, e), acc)
|
|
|
|
abstractExp :: AnnExp -> State Int Exp
|
|
abstractExp (free, exp) = case exp of
|
|
AId n -> pure $ EId n
|
|
AInt i -> pure $ EInt i
|
|
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
|
|
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
|
|
ALet b e -> liftA2 ELet (go b) (abstractExp e)
|
|
where
|
|
go (ABind name parms rhs) = do
|
|
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
|
|
pure $ Bind name (parms ++ parms1) rhs'
|
|
|
|
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
|
|
skipLambdas f (free, ae) = case ae of
|
|
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
|
|
_ -> f (free, ae)
|
|
|
|
|
|
ACase t e cs -> do
|
|
e' <- abstractExp e
|
|
cs' <- mapM (\(AnnCase c e) -> do
|
|
e' <- abstractExp e
|
|
pure (t,Case c e')) cs
|
|
pure $ ECase t e' cs'
|
|
|
|
-- Lift lambda into let and bind free variables
|
|
AAbs t parm e -> do
|
|
i <- nextNumber
|
|
rhs <- abstractExp e
|
|
|
|
let sc_name = Ident ("sc_" ++ show i)
|
|
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
|
|
|
|
pure $ foldl (EApp TInt) sc $ map EId freeList
|
|
where
|
|
freeList = Set.toList free
|
|
parms = snoc parm freeList
|
|
|
|
|
|
nextNumber :: State Int Int
|
|
nextNumber = do
|
|
i <- get
|
|
put $ succ i
|
|
pure i
|
|
|
|
-- | Collects supercombinators by lifting non-constant let expressions
|
|
collectScs :: Program -> Program
|
|
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
|
|
where
|
|
collectFromRhs (Bind name parms rhs) =
|
|
let (rhs_scs, rhs') = collectScsExp rhs
|
|
in Bind name parms rhs' : rhs_scs
|
|
|
|
|
|
collectScsExp :: Exp -> ([Bind], Exp)
|
|
collectScsExp = \case
|
|
EId n -> ([], EId n)
|
|
EInt i -> ([], EInt i)
|
|
|
|
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
|
|
where
|
|
(scs1, e1') = collectScsExp e1
|
|
(scs2, e2') = collectScsExp e2
|
|
|
|
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
|
|
where
|
|
(scs1, e1') = collectScsExp e1
|
|
(scs2, e2') = collectScsExp e2
|
|
|
|
EAbs t par e -> (scs, EAbs t par e')
|
|
where
|
|
(scs, e') = collectScsExp e
|
|
|
|
-- Collect supercombinators from bind, the rhss, and the expression.
|
|
--
|
|
-- > f = let sc x y = rhs in e
|
|
--
|
|
ELet (Bind name parms rhs) e -> if null parms
|
|
then ( rhs_scs ++ e_scs, ELet bind e')
|
|
else (bind : rhs_scs ++ e_scs, e')
|
|
where
|
|
bind = Bind name parms rhs'
|
|
(rhs_scs, rhs') = collectScsExp rhs
|
|
(e_scs, e') = collectScsExp e
|
|
ECase t e cs -> do
|
|
let (scs, e') = collectScsExp e
|
|
let (scs',cs') = foldr (\(t, Case c e) (scs, acc) -> do
|
|
let (scs', e') = collectScsExp e
|
|
(scs ++ scs', (t,Case c e') : acc)
|
|
) (scs,[]) cs
|
|
(scs', ECase t e' cs')
|
|
|
|
|
|
-- @\x.\y.\z. e → (e, [x,y,z])@
|
|
flattenLambdas :: Exp -> (Exp, [Id])
|
|
flattenLambdas = go . (, [])
|
|
where
|
|
go (e, acc) = case e of
|
|
EAbs _ par e1 -> go (e1, snoc par acc)
|
|
_ -> (e, acc)
|