261 lines
8.3 KiB
Haskell
261 lines
8.3 KiB
Haskell
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE OverloadedRecordDot #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
|
|
|
|
module LambdaLifter where
|
|
|
|
import Auxiliary (onM, snoc)
|
|
import Control.Applicative (Applicative (liftA2))
|
|
import Control.Monad.State (MonadState (get, put), State,
|
|
evalState)
|
|
import Data.Function (on)
|
|
import Data.List (delete, mapAccumL, (\\))
|
|
import Prelude hiding (exp)
|
|
import TypeChecker.TypeCheckerIr
|
|
|
|
|
|
-- | Lift lambdas and let expressions into top-level supercombinators.
--
-- The bindings of the program go through three phases:
--
-- 1. 'freeVars' annotates every expression with its free variables.
-- 2. 'abstract' turns every lambda into a let-bound supercombinator.
-- 3. 'collectScs' moves every let binding that has parameters (i.e. every
--    non-constant let) out to a top-level function.
--
-- Data type declarations are kept as they are and placed before the
-- resulting bindings.
lambdaLift :: Program -> Program
lambdaLift (Program defs) = Program (dataDefs ++ liftedBinds)
  where
    dataDefs    = [d | d@(DData _) <- defs]
    topBinds    = [b | DBind b <- defs]
    liftedBinds = DBind <$> (collectScs . abstract . freeVars) topBinds
|
|
|
|
-- | Annotate the top-level bindings with free variables.
--
-- A binding's parameters are local inside its body. A lambda nested in the
-- body that mentions a parameter must therefore record it as free, so the
-- lifted supercombinator receives it as an argument. Passing @[]@ as the
-- local scope would treat the parameters as globals and leave them unbound
-- after lifting.
--
-- The parameters themselves are not free in the binding, so they are
-- removed from the body's annotation.
freeVars :: [Bind] -> [ABind]
freeVars binds =
  [ ABind n xs (ae { frees = ae.frees \\ xs })
  | Bind n xs e <- binds
  , let ae = freeVarsExp xs e
  ]
|
|
|
|
-- | Annotate an expression, and recursively all of its sub-expressions,
-- with free variables.
--
-- @localVars@ are the variables bound by enclosing binders. A variable
-- occurrence is recorded as free only when it is in @localVars@; any other
-- variable is treated as global and not recorded.
--
-- For a let, the free variables of the binding and of the body are
-- combined, minus the let-bound name, e.g.
--
-- > let f x = x + y in f 5 + z    -- frees: y, z
freeVarsExp :: Frees -> ExpT -> Ann AExpT
freeVarsExp localVars (expr, typ) = case expr of
  EVar n ->
    Ann { frees = [ (n, typ) | (n, typ) `elem` localVars ]
        , term  = (AVar n, typ)
        }
  EInj n     -> closed (AInj n)
  ELit lit   -> closed (ALit lit)
  EApp e1 e2 -> pair AApp e1 e2
  EAdd e1 e2 -> pair AAdd e1 e2
  EAbs x body ->
    let argType = case typ of
          TFun t _ -> t
          _        -> error "Impossible"
        param   = (x, argType)
        annBody = freeVarsExp (localVars <| param) body
    in  Ann { frees = delete param annBody.frees
            , term  = (AAbs x annBody, typ)
            }
  ELet bind@(Bind n _ _) body ->
    let annBody = freeVarsExp (localVars <| n) body
        annBind = freeVarsBind localVars bind
    in  Ann { frees = delete n annBody.frees <|| annBind.frees
            , term  = (ALet annBind annBody, typ)
            }
  ECase scrut branches ->
    let annScrut    = freeVarsExp localVars scrut
        annBranches = map (freeVarsBranch localVars) branches
    in  Ann { frees = foldl (<||) annScrut.frees (map frees annBranches)
            , term  = (ACase annScrut annBranches, typ)
            }
  where
    -- A leaf that mentions no variable.
    closed ae = Ann { frees = [], term = (ae, typ) }
    -- A node with two sub-expressions whose free variables are merged.
    pair con e1 e2 =
      let ann1 = freeVarsExp localVars e1
          ann2 = freeVarsExp localVars e2
      in  Ann { frees = ann1.frees <|| ann2.frees
              , term  = (con ann1 ann2, typ)
              }
|
|
|
|
|
|
-- | Annotate a let binding with free variables.
--
-- The parameters are local inside the body. They are removed from the
-- binding's own annotation; the body's annotation keeps them.
freeVarsBind :: Frees -> Bind -> Ann ABind
freeVarsBind localVars (Bind name params body) =
  let annBody = freeVarsExp (localVars <|| params) body
      outside = annBody.frees \\ params
  in  Ann { frees = outside, term = ABind name params annBody }
|
|
|
|
|
|
-- | Annotate a case branch with free variables.
--
-- Variables bound by the pattern are local inside the branch body, so a
-- lambda in the body that uses one records it as free. They are not free
-- in the branch as a whole and are removed from its annotation.
freeVarsBranch :: Frees -> Branch -> Ann ABranch
freeVarsBranch localVars (Branch pt e) =
  Ann { frees = annae.frees \\ varsInPattern
      , term  = ABranch pt annae
      }
  where
    -- Pattern variables must be in scope as locals. Otherwise references to
    -- them are treated as global, and a lifted lambda would lose them.
    annae = freeVarsExp (localVars <|| varsInPattern) e
    varsInPattern = go [] pt
    -- Collect the variables of a pattern and its sub-patterns. Other pattern
    -- forms bind nothing, so the accumulator is returned unchanged. Returning
    -- [] would drop variables collected from earlier sub-patterns.
    go acc (p, t) = case p of
      PVar n    -> acc <| (n, t)
      PInj _ ps -> foldl go acc ps
      _         -> acc
|
|
|
|
|
|
-- AST annotated with free variables
|
|
|
|
-- | A set of variables, each given as a (name, type) pair; used for the free
-- variables of an expression and for the variables in local scope.
type Frees = [(Ident, Type)]
|
|
|
|
-- | A value paired with the free variables that occur in it.
data Ann a = Ann
  { frees :: Frees -- ^ free variables of 'term', as (name, type) pairs
  , term  :: a     -- ^ the annotated value
  }
  deriving (Show, Eq)
|
|
|
|
-- | A binding whose body is annotated with free variables: the bound name,
-- its parameters, and the annotated body.
data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq)
|
|
-- | A case branch: the typed pattern and the body annotated with free
-- variables.
data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq)
|
|
|
|
-- | An annotated expression together with its type.
type AExpT = (AExp, Type)
|
|
|
|
-- | Expressions whose sub-expressions carry free-variable annotations.
-- There is one constructor for each expression form handled by
-- 'freeVarsExp'.
data AExp
  = AVar Ident                      -- ^ variable, from 'EVar'
  | AInj Ident                      -- ^ injection, from 'EInj'
  | ALit Lit                        -- ^ literal, from 'ELit'
  | ALet (Ann ABind) (Ann AExpT)    -- ^ let: the binding and the body
  | AApp (Ann AExpT) (Ann AExpT)    -- ^ application: function and argument
  | AAdd (Ann AExpT) (Ann AExpT)    -- ^ addition of two operands
  | AAbs Ident (Ann AExpT)          -- ^ lambda: parameter name and body
  | ACase (Ann AExpT) [Ann ABranch] -- ^ case: scrutinee and branches
  deriving (Show, Eq)
|
|
|
|
-- | Turn the lambdas of all top-level bindings into let-bound
-- supercombinators.
--
-- Each top-level binding is wrapped with an empty free-variable set. One
-- name counter, starting at 0, is threaded through all the bindings, so
-- generated names are unique across the program.
abstract :: [ABind] -> [Bind]
abstract binds =
  flip evalState 0 $
    traverse (\b -> abstractAnnBind (Ann { frees = [], term = b })) binds
|
|
|
|
-- | Abstract the lambdas inside a binding.
--
-- Lambdas at the very start of the body are peeled off and become extra
-- parameters. They are placed after the binding's own parameters, matching
-- the argument order of the binding's type: @f x = \\y -> e@ becomes
-- @f x y = e'@. Lambdas deeper in the body are handled by 'abstractAnnExp'.
--
-- NOTE(review): the binding's own free-variable annotation is ignored here.
-- For a let binding with parameters, 'collectScs' later moves it to the top
-- level, so any local variables it captures would be unbound there. Confirm
-- that an earlier pass rules this out.
abstractAnnBind :: Ann ABind -> State Int Bind
abstractAnnBind Ann { term = ABind name vars annae } =
  Bind name (vars <|| lambdaVars) <$> abstractAnnExp body
  where
    (body, lambdaVars) = peel [] annae
    -- Collect directly nested lambda parameters, in order, with their types.
    peel acc = \case
      Ann { term = (AAbs x ae, TFun t _) } -> peel (snoc (x, t) acc) ae
      ae -> (ae, acc)
|
|
|
|
-- | Replace every lambda in an annotated expression by a let-bound
-- supercombinator applied to the lambda's free variables.
--
-- Lambdas nested directly inside a lambda's body are merged into the same
-- supercombinator. Its parameters are the free variables followed by the
-- lambda parameters, and the let body supplies the free variables at once:
--
-- > \x. \y. x + y + z  ==>  let sc_0 z x y = x + y + z in sc_0 z
--
-- Supercombinator names are @sc_<n>@, where @n@ is taken from the counter
-- in the 'State'.
abstractAnnExp :: Ann AExpT -> State Int ExpT
abstractAnnExp Ann { frees, term = (annae, typ) } = case annae of
  AVar n   -> pure (EVar n, typ)
  AInj n   -> pure (EInj n, typ)
  ALit lit -> pure (ELit lit, typ)
  AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2
  AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2

  AAbs x annae' -> do
    i <- nextNumber
    rhs <- abstractAnnExp body
    let sc_name = Ident ("sc_" ++ show i)
        -- Applying the supercombinator to every free variable yields a value
        -- of the lambda's own type, typ.
        e = foldl applyFree (EVar sc_name, sc_type) frees
    pure (ELet (Bind (sc_name, sc_type) vars rhs) e, typ)
    where
      vars = frees <| (x, t_x) <|| ys
      -- The free variables are the leading parameters, so the
      -- supercombinator's type is the lambda's type with their types
      -- prepended. Using typ itself would make applyFree strip the lambda's
      -- own argument types instead.
      sc_type = foldr (\(_, t) acc -> TFun t acc) typ frees
      t_x = case typ of
        TFun t _ -> t
        _        -> error "Impossible"
      -- Peel the lambdas nested directly in the body into extra parameters.
      (body, ys) = go [] annae'
      go acc = \case
        Ann { term = (AAbs y ae, TFun t _) } -> go (snoc (y, t) acc) ae
        ae -> (ae, acc)
      -- Apply an expression of function type to a variable; the result has
      -- the function's result type.
      applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type)
      applyFree (f, t_f) (v, t_v) = (EApp (f, t_f) (EVar v, t_v), t_res)
        where
          t_res = case t_f of
            TFun _ t -> t
            _        -> error "Impossible"

  ACase annae' branches -> do
    branches' <- mapM abstractBranch branches
    e <- abstractAnnExp annae'
    pure (ECase e branches', typ)
    where
      abstractBranch Ann { term = ABranch p annBody } =
        Branch p <$> abstractAnnExp annBody

  ALet b annae' ->
    (, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae')
|
|
|
|
|
|
-- | Collect supercombinators by lifting every non-constant let binding
-- (one with parameters) out to the top level.
--
-- Each input binding, with its rewritten body, is immediately followed by
-- the bindings lifted out of that body.
collectScs :: [Bind] -> [Bind]
collectScs binds =
  concat
    [ Bind name parms rhs' : lifted
    | Bind name parms rhs <- binds
    , let (lifted, rhs') = collectScsExp rhs
    ]
|
|
|
|
|
|
-- | Lift the let bindings that have parameters out of an expression.
--
-- Returns the lifted bindings together with the rewritten expression.
--
-- * A let whose binding has parameters is replaced by its body, and the
--   binding is returned among the lifted bindings:
--
--   > let sc x y = rhs in e   ==>   e      (sc x y = rhs is lifted)
--
-- * A let without parameters stays in place.
--
-- Lifting recurses into the binding's right-hand side and the let body.
collectScsExp :: ExpT -> ([Bind], ExpT)
collectScsExp expT@(exp, typ) = case exp of
  EVar _ -> ([], expT)
  EInj _ -> ([], expT)
  ELit _ -> ([], expT)
  EApp e1 e2 -> both EApp e1 e2
  EAdd e1 e2 -> both EAdd e1 e2
  EAbs par e ->
    let (scs, e') = collectScsExp e
    in  (scs, (EAbs par e', typ))
  ECase e branches ->
    let step acc b = let (scs, b') = collectScsBranch b in (acc ++ scs, b')
        (branchScs, branches') = mapAccumL step [] branches
        (scrutScs, e') = collectScsExp e
    in  (branchScs ++ scrutScs, (ECase e' branches', typ))
  ELet (Bind name parms rhs) e ->
    let (rhsScs, rhs')   = collectScsExp rhs
        (bodyScs, body') = collectScsExp e
        bind             = Bind name parms rhs'
    in  if null parms
          then (rhsScs ++ bodyScs, (ELet bind body', snd body'))
          else (bind : rhsScs ++ bodyScs, body')
  where
    -- Two sub-expressions: lift from both, keeping left-to-right order.
    both con e1 e2 =
      let (scs1, e1') = collectScsExp e1
          (scs2, e2') = collectScsExp e2
      in  (scs1 ++ scs2, (con e1' e2', typ))
|
|
|
-- | Lift the let bindings that have parameters out of a case branch's body;
-- the pattern is kept unchanged.
collectScsBranch :: Branch -> ([Bind], Branch)
collectScsBranch (Branch patt body) = (scs, Branch patt body')
  where
    (scs, body') = collectScsExp body
|
|
|
|
-- | Return the current counter value and advance the counter by one.
nextNumber :: State Int Int
nextNumber = get >>= \i -> i <$ put (succ i)
|
|
|
|
|
|
-- | @xs <| x@ adds @x@ to @xs@ with 'snoc' unless @xs@ already contains it.
-- 'snoc' is assumed to append at the end (see Auxiliary), which would make
-- '<|' an insertion-ordered set insert.
(<|) :: Eq a => [a] -> a -> [a]
xs <| x
  | x `notElem` xs = snoc x xs
  | otherwise      = xs
|
|
|
|
-- | @xs <|| ys@ adds every element of @ys@, left to right, to @xs@ with
-- '<|', skipping elements that are already present.
(<||) :: Eq a => [a] -> [a] -> [a]
acc <|| []       = acc
acc <|| (y : ys) = (acc <| y) <|| ys
|
|
|