churf/src/LambdaLifter.hs
2023-05-10 19:42:27 +02:00

299 lines
9.6 KiB
Haskell

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter where
import Auxiliary (onM, snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State,
evalState)
import Data.Function (on)
import Data.List (delete, mapAccumL, (\\))
import Data.Tuple.Extra (first, second)
import LambdaLifterIr (T)
import qualified LambdaLifterIr as L
import Prelude hiding (exp)
import TypeChecker.TypeCheckerIr hiding (T)
-- | Lift lambdas and let expressions into supercombinators.
--
-- Data declarations are converted one-to-one; function bindings go through
-- three phases:
--
-- 1. @freeVars@ annotates every expression with its free variables.
-- 2. @abstract@ turns lambdas into named let expressions.
-- 3. @collectScs@ moves every let binding that takes parameters to the top
--    level.
--
-- The resulting program lists the data declarations first, followed by the
-- lifted bindings.
lambdaLift :: Program -> L.Program
lambdaLift (Program defs) = L.Program (liftedData ++ liftedBinds)
  where
    liftedData = map (L.DData . toLirData) [d | DData d <- defs]
    sourceBinds = [b | DBind b <- defs]
    liftedBinds = L.DBind <$> collectScs (abstract (freeVars sourceBinds))
-- | Annotate the body of every top-level binding with its free variables.
--
-- The parameters of a binding are in scope while its body is analysed, so
-- that a lambda nested inside the body records the parameters it captures
-- (the same way 'freeVarsBind' treats let-bound functions). They are then
-- removed from the free-variable set of the body itself, because they are
-- bound by the binding and therefore not free in it.
freeVars :: [Bind] -> [ABind]
freeVars binds =
    [ ABind name params (body { frees = body.frees \\ params })
    | Bind name params rhs <- binds
    , let body = freeVarsExp params rhs
    ]
-- | Annotate an expression, and all of its sub-expressions, with free
-- variables.
--
-- The first argument lists the variables that are currently considered local.
-- A variable occurrence is reported as free only when it is found in that
-- list with the same type; any other name yields an empty set.
freeVarsExp :: Frees -> T Exp -> Ann (T AExp)
freeVarsExp scope (expr, typ) = case expr of
    EVar name
        | (name, typ) `elem` scope -> annotate [(name, typ)] (AVar name)
        | otherwise -> closed (AVar name)
    EInj name -> closed (AInj name)
    ELit lit -> closed (ALit lit)
    EApp l r -> binary AApp l r
    EAdd l r -> binary AAdd l r
    -- The binder is local inside the body and is not free in the lambda.
    EAbs param body ->
        let paramTyp = case typ of
                TFun argTyp _ -> argTyp
                _ -> error "Impossible"
            binder = (param, paramTyp)
            body' = freeVarsExp (scope <| binder) body
         in annotate (delete binder body'.frees) (AAbs param body')
    -- Sum free variables present in the bind and in the expression:
    -- let f x = x + y in f 5 + z → frees: y, z
    -- The bound name is local in the let body only, not in the bind itself.
    ELet bind@(Bind name _ _) body ->
        let body' = freeVarsExp (scope <| name) body
            bind' = freeVarsBind scope bind
         in annotate (delete name body'.frees <|| bind'.frees) (ALet bind' body')
    -- Free variables of the scrutinee followed by those of every branch.
    ECase scrut branches ->
        let scrut' = freeVarsExp scope scrut
            branches' = map (freeVarsBranch scope) branches
            allFrees = foldl (<||) scrut'.frees (map frees branches')
         in annotate allFrees (ACase scrut' branches')
  where
    -- Attach a free-variable set and the expression's type to a node.
    annotate fvs node = Ann { frees = fvs, term = (node, typ) }
    -- A node without free variables.
    closed = annotate []
    -- Binary nodes: union of the free variables of both operands.
    binary ctor l r =
        let l' = freeVarsExp scope l
            r' = freeVarsExp scope r
         in annotate (l'.frees <|| r'.frees) (ctor l' r')
-- | Annotate a (let-bound) binding with its free variables.
--
-- The binding's parameters are local while its body is analysed and are
-- removed from the free-variable set reported for the binding as a whole.
freeVarsBind :: Frees -> Bind -> Ann ABind
freeVarsBind scope (Bind name params rhs) =
    let body = freeVarsExp (scope <|| params) rhs
     in Ann
            { frees = body.frees \\ params
            , term = ABind name params body
            }
-- | Annotate a case branch with its free variables.
--
-- Every variable bound by the branch's pattern is local inside the branch
-- body and is removed from the free-variable set reported for the branch.
freeVarsBranch :: Frees -> Branch -> Ann ABranch
freeVarsBranch localVars (Branch pt e) =
    Ann { frees = annae.frees \\ varsInPattern
        , term = ABranch pt annae
        }
  where
    annae = freeVarsExp (localVars <|| varsInPattern) e
    varsInPattern = go [] pt
    -- Collect the variables bound by a pattern, left to right.
    go acc (p, t) = case p of
        PVar n -> acc <| (n, t)
        PInj _ ps -> foldl go acc ps
        -- Patterns that bind nothing must leave the accumulator untouched;
        -- returning [] here would forget the variables collected from the
        -- preceding sub-patterns (e.g. the x in @Pair x _@).
        _ -> acc
-- AST annotated with free variables

-- | A set of typed variables, represented as a list.
type Frees = [(Ident, Type)]

-- | A term together with the free variables computed for it.
data Ann a = Ann
    { frees :: Frees
    -- ^ Free variables of the term.
    , term :: a
    -- ^ The annotated term itself.
    }
    deriving (Show, Eq)
-- | A binding (typed name, typed parameters) whose body is annotated with
-- free variables.
data ABind
    = ABind (T Ident) [T Ident] (Ann (T AExp))
    deriving (Show, Eq)

-- | A case branch: a typed pattern and an annotated body.
data ABranch
    = ABranch (Pattern, Type) (Ann (T AExp))
    deriving (Show, Eq)

-- | Expressions in which every sub-expression carries its free variables.
data AExp
    = AVar Ident
    | AInj Ident
    | ALit Lit
    | ALet (Ann ABind) (Ann (T AExp))
    | AApp (Ann (T AExp)) (Ann (T AExp))
    | AAdd (Ann (T AExp)) (Ann (T AExp))
    | AAbs Ident (Ann (T AExp))
    | ACase (Ann (T AExp)) [Ann ABranch]
    deriving (Show, Eq)
-- | Bindings produced by @abstract@.
data BBind
    = -- | Typed name, typed parameters and body.
      BBind (T Ident) [T Ident] (T BExp)
    | -- | Like 'BBind', preceded by the context of captured variables.
      BBindCxt [T Ident] (T Ident) [T Ident] (T BExp)
    deriving (Eq, Ord, Show)

-- | A case branch: a typed pattern and a typed body.
data BBranch
    = BBranch (T Pattern) (T BExp)
    deriving (Eq, Ord, Show)

-- | Expressions produced by @abstract@. There is no lambda constructor:
-- @abstractAnnExp@ replaces every lambda by a 'BLet'.
data BExp
    = BVar Ident
    | -- | A variable together with a context of captured variables.
      BVarC [T Ident] Ident
    | BInj Ident
    | BLit Lit
    | BLet BBind (T BExp)
    | BApp (T BExp) (T BExp)
    | BAdd (T BExp) (T BExp)
    | BCase (T BExp) [BBranch]
    deriving (Eq, Ord, Show)
-- | Replace every lambda in the given bindings by a named let expression.
--
-- All bindings share one counter, starting at 0, which supplies the numbers
-- used in the generated names.
abstract :: [ABind] -> [BBind]
abstract binds = evalState (traverse abstractTopLevel binds) 0
  where
    -- Top-level bindings are wrapped with an empty free-variable set.
    abstractTopLevel bind = abstractAnnBind Ann { frees = [], term = bind }
-- | Convert an annotated binding.
--
-- Lambdas that directly form the body of the binding are not given a name of
-- their own; their binders become additional parameters of the binding.
-- The declared parameters come first, followed by the lambda binders from
-- the outermost to the innermost, which is the order in which the arguments
-- are applied. No binder is dropped, so the number of parameters always
-- matches the number of binders in the source.
abstractAnnBind :: Ann ABind -> State Int BBind
abstractAnnBind Ann { term = ABind name vars annae } =
    BBind name (vars ++ lambdaVars) <$> abstractAnnExp body
  where
    (body, lambdaVars) = peel [] annae
    -- Strip leading lambdas, collecting their typed binders in order.
    peel acc = \case
        Ann { term = (AAbs x ae, TFun t _) } -> peel (snoc (x, t) acc) ae
        ae -> (ae, acc)
-- | Convert an annotated expression, replacing every lambda by a let
-- expression that binds a fresh name @sc_<n>@ and immediately refers to it.
--
-- Directly nested lambdas are merged into a single binding with several
-- parameters. When the lambda has free variables the binding and the
-- reference carry them as a context ('BBindCxt' / 'BVarC').
--
-- The counter is read before the lambda body is converted. For a case
-- expression the branches are converted before the scrutinee.
abstractAnnExp :: Ann (T AExp) -> State Int (T BExp)
abstractAnnExp Ann { frees = captured, term = (node, typ) } = case node of
    AVar name -> leaf (BVar name)
    AInj name -> leaf (BInj name)
    ALit lit -> leaf (BLit lit)
    AApp fun arg -> typed <$> onM BApp abstractAnnExp fun arg
    AAdd lhs rhs -> typed <$> onM BAdd abstractAnnExp lhs rhs
    AAbs param body -> do
        n <- nextNumber
        rhs <- abstractAnnExp inner
        let scIdent = Ident ("sc_" ++ show n)
            scName = (scIdent, typ)
        pure . typed $ case captured of
            [] -> BLet (BBind scName params rhs) (typed (BVar scIdent))
            _ ->
                BLet
                    (BBindCxt captured scName params rhs)
                    (typed (BVarC captured scIdent))
      where
        (inner, nested) = peel [] body
        -- NOTE(review): (<||) skips binders that are already present, so a
        -- nested lambda rebinding the same typed name adds no parameter.
        params = [(param, paramTyp)] <|| nested
        paramTyp = case typ of
            TFun argTyp _ -> argTyp
            _ -> error "Impossible"
        -- Strip the lambdas directly below, collecting their typed binders.
        peel acc ann = case ann of
            Ann { term = (AAbs y next, TFun argTyp _) } ->
                peel (snoc (y, argTyp) acc) next
            _ -> (ann, acc)
    ACase scrut alts -> do
        alts' <- traverse abstractBranch alts
        scrut' <- abstractAnnExp scrut
        pure (typed (BCase scrut' alts'))
      where
        abstractBranch Ann { term = ABranch pat rhs } =
            BBranch pat <$> abstractAnnExp rhs
    ALet bind body ->
        typed <$> (BLet <$> abstractAnnBind bind <*> abstractAnnExp body)
  where
    -- Pair a converted node with the type of the original expression.
    typed e = (e, typ)
    leaf e = pure (typed e)
-- | Collect supercombinators: every binding is emitted with its converted
-- body, directly followed by the supercombinators found inside that body.
collectScs :: [BBind] -> [L.Bind]
collectScs = concatMap fromTopLevel
  where
    fromTopLevel bind = case bind of
        BBind name parms rhs -> withLifted (L.Bind name parms) rhs
        BBindCxt cxt name parms rhs -> withLifted (L.BindC cxt name parms) rhs
    -- Finish the binding with the converted body and append what was lifted.
    withLifted mkBind rhs =
        let (lifted, rhs') = collectScsExp rhs
         in mkBind rhs' : lifted
-- | Convert an expression and collect the supercombinators inside it.
--
-- Returns the bindings that were lifted out, paired with the remaining
-- expression.
collectScsExp :: T BExp -> ([L.Bind], T L.Exp)
collectScsExp (node, typ) = case node of
    BVar x -> leaf (L.EVar x)
    BVarC as x -> leaf (L.EVarC as x)
    BInj k -> leaf (L.EInj k)
    BLit lit -> leaf (L.ELit lit)
    BApp l r -> binary L.EApp l r
    BAdd l r -> binary L.EAdd l r
    -- Supercombinators of the branches come before those of the scrutinee.
    BCase scrut branches ->
        let (scrutScs, scrut') = collectScsExp scrut
            (branchScs, branches') = unzip (map collectScsBranch branches)
         in (concat branchScs ++ scrutScs, (L.ECase scrut' branches', typ))
    -- Collect supercombinators from the bind, its rhs, and the expression.
    --
    -- > f = let sc x y = rhs in e
    --
    BLet (BBind name parms rhs) body ->
        hoist (L.Bind name parms) name parms rhs body
    BLet (BBindCxt cxt name parms rhs) body ->
        hoist (L.BindC cxt name parms) name parms rhs body
  where
    leaf e = ([], (e, typ))
    binary ctor l r =
        let (lScs, l') = collectScsExp l
            (rScs, r') = collectScsExp r
         in (lScs ++ rScs, (ctor l' r', typ))
    -- A let without parameters stays a let (any context is not carried
    -- over); a let with parameters becomes a supercombinator and only the
    -- let body remains in the expression.
    hoist mkSc name parms rhs body
        | null parms = (nestedScs, (L.ELet name rhs' body', snd body'))
        | otherwise = (mkSc rhs' : nestedScs, body')
      where
        (rhsScs, rhs') = collectScsExp rhs
        (bodyScs, body') = collectScsExp body
        nestedScs = rhsScs ++ bodyScs
-- | Convert a case branch: the pattern is translated with 'toLirPattern'
-- (its type is kept) and the supercombinators found in the branch body are
-- returned alongside the converted branch.
collectScsBranch (BBranch patt body) =
    let (lifted, body') = collectScsExp body
     in (lifted, L.Branch (first toLirPattern patt) body')
-- | Return the current value of the counter and advance the counter to its
-- successor.
nextNumber :: State Int Int
nextNumber = get >>= \current -> put (succ current) >> pure current
-- | Add an element to a list unless the list already contains it.
--
-- A new element is added with @snoc@ (defined in "Auxiliary"; presumably it
-- appends at the end of the list).
(<|) :: Eq a => [a] -> a -> [a]
xs <| x = if x `elem` xs then xs else snoc x xs
-- | Add the elements of the second list to the first one, from left to
-- right, using '<|'; elements that are already present are skipped.
(<||) :: Eq a => [a] -> [a] -> [a]
(<||) = foldl (<|)
-- | Convert a data declaration, translating each of its constructors with
-- 'toLirInj'.
toLirData (Data typ constructors) = L.Data typ (toLirInj <$> constructors)
-- | Convert a single constructor; both of its fields are passed on unchanged.
toLirInj inj = case inj of
    Inj name typ -> L.Inj name typ
-- | Translate a pattern constructor by constructor. The sub-patterns of a
-- constructor pattern are translated recursively and keep their types.
toLirPattern :: Pattern -> L.Pattern
toLirPattern (PVar x) = L.PVar x
toLirPattern (PLit lit) = L.PLit lit
toLirPattern PCatch = L.PCatch
toLirPattern (PEnum k) = L.PEnum k
toLirPattern (PInj k ps) = L.PInj k (first toLirPattern <$> ps)