churf/src/LambdaLifter.hs
2023-04-29 21:59:05 +02:00

261 lines
8.3 KiB
Haskell

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter where
import Auxiliary (onM, snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State,
evalState)
import Data.Function (on)
import Data.List (delete, mapAccumL, (\\))
import Prelude hiding (exp)
import TypeChecker.TypeCheckerIr
-- | Lift lambdas and let expression into supercombinators.
-- Three phases:
-- @freeVars@ annotates all the free variables.
-- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function.
--
lambdaLift :: Program -> Program
lambdaLift (Program ds) = Program (datatypes ++ binds)
where
datatypes = flip filter ds $ \case DData _ -> True
_ -> False
binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
-- | Annotate free variables
freeVars :: [Bind] -> [ABind]
freeVars binds = [ let ae = freeVarsExp [] e
ae' = ae { frees = ae.frees \\ xs }
in ABind n xs ae'
| Bind n xs e <- binds
]
freeVarsExp :: Frees -> ExpT -> Ann AExpT
freeVarsExp localVars (ae, t) = case ae of
EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)]
, term = (AVar n, t)
}
| otherwise -> Ann { frees = []
, term = (AVar n, t)
}
EInj n -> Ann { frees = [], term = (AInj n, t) }
ELit lit -> Ann { frees = [], term = (ALit lit, t) }
EApp e1 e2 -> Ann { frees = annae1.frees <|| annae2.frees
, term = (AApp annae1 annae2, t)
}
where
(annae1, annae2) = on (,) (freeVarsExp localVars) e1 e2
EAdd e1 e2 -> Ann { frees = annae1.frees <|| annae2.frees
, term = (AAdd annae1 annae2, t)
}
where
(annae1, annae2) = on (,) (freeVarsExp localVars) e1 e2
EAbs x e -> Ann { frees = delete (x,t_x) $ annae.frees
, term = (AAbs x annae, t) }
where
annae = freeVarsExp (localVars <| (x,t_x)) e
t_x = case t of TFun t _ -> t
_ -> error "Impossible"
-- Sum free variables present in bind and the expression
-- let f x = x + y in f 5 + z → frees: y, z
ELet bind@(Bind n _ _) e ->
Ann { frees = delete n annae.frees <|| annbind.frees
, term = (ALet annbind annae, t)
}
where
annae = freeVarsExp (localVars <| n) e
annbind = freeVarsBind localVars bind
ECase e branches ->
Ann { frees = foldl (<||) annae.frees (map frees annbranches)
, term = (ACase annae annbranches, t)
}
where
annae = freeVarsExp localVars e
annbranches = map (freeVarsBranch localVars) branches
freeVarsBind :: Frees -> Bind -> Ann ABind
freeVarsBind localVars (Bind name vars e) =
Ann { frees = annae.frees \\ vars
, term = ABind name vars annae
}
where
annae = freeVarsExp (localVars <|| vars) e
freeVarsBranch :: Frees -> Branch -> Ann ABranch
freeVarsBranch localVars (Branch pt e) =
Ann { frees = annae.frees \\ varsInPattern
, term = ABranch pt annae
}
where
annae = freeVarsExp localVars e
varsInPattern = go [] pt
where
go acc (p, t) = case p of
PVar n -> acc <| (n, t)
PInj _ ps -> foldl go acc ps
_ -> []
-- AST annotated with free variables
type Frees = [(Ident, Type)]
data Ann a = Ann
{ frees :: Frees
, term :: a
} deriving (Show, Eq)
data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq)
data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq)
type AExpT = (AExp, Type)
data AExp = AVar Ident
| AInj Ident
| ALit Lit
| ALet (Ann ABind) (Ann AExpT)
| AApp (Ann AExpT) (Ann AExpT)
| AAdd (Ann AExpT) (Ann AExpT)
| AAbs Ident (Ann AExpT)
| ACase (Ann AExpT) [Ann ABranch]
deriving (Show, Eq)
abstract :: [ABind] -> [Bind]
abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0
abstractAnnBind :: Ann ABind -> State Int Bind
abstractAnnBind Ann { term = ABind name vars annae } =
Bind name (vars' <|| vars) <$> abstractAnnExp annae'
where
(annae', vars') = go [] annae
where
go acc = \case
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc)
abstractAnnExp :: Ann AExpT -> State Int ExpT
abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
AVar n -> pure (EVar n, typ)
AInj n -> pure (EInj n, typ)
ALit lit -> pure (ELit lit, typ)
AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2
AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2
-- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc
AAbs x annae' -> do
i <- nextNumber
rhs <- abstractAnnExp annae''
let sc_name = Ident ("sc_" ++ show i)
e@(_, t) = foldl applyFree (EVar sc_name, typ) frees
pure (ELet (Bind (sc_name, typ) vars rhs) e ,t)
where
vars = frees <| (x, t_x) <|| ys
t_x = case typ of TFun t _ -> t
_ -> error "Impossible"
(annae'', ys) = go [] annae'
where
go acc = \case
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc)
applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type)
applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e')
where
t_e' = case t_e of TFun _ t -> t
_ -> error "Impossible"
ACase annae' bs -> do
bs <- mapM go bs
e <- abstractAnnExp annae'
pure (ECase e bs, typ)
where
go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae
ALet b annae' ->
(, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae')
-- | Collects supercombinators by lifting non-constant let expressions
collectScs :: [Bind] -> [Bind]
collectScs = concatMap collectFromRhs
where
collectFromRhs (Bind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs
collectScsExp :: ExpT -> ([Bind], ExpT)
collectScsExp expT@(exp, typ) = case exp of
EVar _ -> ([], expT)
EInj _ -> ([], expT)
ELit _ -> ([], expT)
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ))
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAbs par e -> (scs, (EAbs par e', typ))
where
(scs, e') = collectScsExp e
ECase e branches -> (scs ++ scs_e, (ECase e' branches', typ))
where
(scs, branches') = mapAccumL f [] branches
(scs_e, e') = collectScsExp e
f acc b = (acc ++ acc', b')
where (acc', b') = collectScsBranch b
-- Collect supercombinators from bind, the rhss, and the expression.
--
-- > f = let sc x y = rhs in e
--
ELet (Bind name parms rhs) e
| null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et'))
| otherwise -> (bind : rhs_scs ++ et_scs, et')
where
bind = Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
where (scs, exp') = collectScsExp exp
nextNumber :: State Int Int
nextNumber = do
i <- get
put $ succ i
pure i
(<|) :: Eq a => [a] -> a -> [a]
xs <| x | elem x xs = xs
| otherwise = snoc x xs
(<||) :: Eq a => [a] -> [a] -> [a]
xs <|| ys = foldl (<|) xs ys