Add closures and fix lets in monomorphizer
This commit is contained in:
parent
677a200a15
commit
72e599d5de
26 changed files with 1440 additions and 692 deletions
|
|
@ -11,9 +11,11 @@ import Control.Monad.State (MonadState (get, put), State,
|
|||
evalState)
|
||||
import Data.Function (on)
|
||||
import Data.List (delete, mapAccumL, (\\))
|
||||
import Data.Tuple.Extra (first, second)
|
||||
import LambdaLifterIr (T)
|
||||
import qualified LambdaLifterIr as L
|
||||
import Prelude hiding (exp)
|
||||
import TypeChecker.TypeCheckerIr
|
||||
|
||||
import TypeChecker.TypeCheckerIr hiding (T)
|
||||
|
||||
-- | Lift lambdas and let expression into supercombinators.
|
||||
-- Three phases:
|
||||
|
|
@ -21,12 +23,13 @@ import TypeChecker.TypeCheckerIr
|
|||
-- @abstract@ converts lambdas into let expressions.
|
||||
-- @collectScs@ moves every non-constant let expression to a top-level function.
|
||||
--
|
||||
lambdaLift :: Program -> Program
|
||||
lambdaLift (Program ds) = Program (datatypes ++ binds)
|
||||
lambdaLift :: Program -> L.Program
|
||||
lambdaLift (Program ds) = L.Program (datatypes ++ binds)
|
||||
where
|
||||
datatypes = flip filter ds $ \case DData _ -> True
|
||||
_ -> False
|
||||
binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
|
||||
datatypes = [L.DData (toLirData d) | DData d <- ds]
|
||||
|
||||
binds = map L.DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
|
||||
|
||||
|
||||
-- | Annotate free variables
|
||||
freeVars :: [Bind] -> [ABind]
|
||||
|
|
@ -36,7 +39,7 @@ freeVars binds = [ let ae = freeVarsExp [] e
|
|||
| Bind n xs e <- binds
|
||||
]
|
||||
|
||||
freeVarsExp :: Frees -> ExpT -> Ann AExpT
|
||||
freeVarsExp :: Frees -> T Exp -> Ann (T AExp)
|
||||
freeVarsExp localVars (ae, t) = case ae of
|
||||
EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)]
|
||||
, term = (AVar n, t)
|
||||
|
|
@ -121,27 +124,47 @@ data Ann a = Ann
|
|||
, term :: a
|
||||
} deriving (Show, Eq)
|
||||
|
||||
data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq)
|
||||
data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq)
|
||||
|
||||
type AExpT = (AExp, Type)
|
||||
data ABind = ABind (T Ident) [T Ident] (Ann (T AExp)) deriving (Show, Eq)
|
||||
data ABranch = ABranch (Pattern, Type) (Ann (T AExp)) deriving (Show, Eq)
|
||||
|
||||
data AExp = AVar Ident
|
||||
| AInj Ident
|
||||
| ALit Lit
|
||||
| ALet (Ann ABind) (Ann AExpT)
|
||||
| AApp (Ann AExpT) (Ann AExpT)
|
||||
| AAdd (Ann AExpT) (Ann AExpT)
|
||||
| AAbs Ident (Ann AExpT)
|
||||
| ACase (Ann AExpT) [Ann ABranch]
|
||||
| ALet (Ann ABind) (Ann (T AExp))
|
||||
| AApp (Ann (T AExp)) (Ann (T AExp))
|
||||
| AAdd (Ann (T AExp)) (Ann (T AExp))
|
||||
| AAbs Ident (Ann (T AExp))
|
||||
| ACase (Ann (T AExp)) [Ann ABranch]
|
||||
deriving (Show, Eq)
|
||||
|
||||
abstract :: [ABind] -> [Bind]
|
||||
|
||||
|
||||
data BBind = BBind (T Ident) [T Ident] (T BExp)
|
||||
| BBindCxt [T Ident] (T Ident) [T Ident] (T BExp)
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
|
||||
data BBranch = BBranch (T Pattern) (T BExp)
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
data BExp
|
||||
= BVar Ident
|
||||
| BVarC [T Ident] Ident
|
||||
| BInj Ident
|
||||
| BLit Lit
|
||||
| BLet BBind (T BExp)
|
||||
| BApp (T BExp)(T BExp)
|
||||
| BAdd (T BExp)(T BExp)
|
||||
| BCase (T BExp) [BBranch]
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
|
||||
abstract :: [ABind] -> [BBind]
|
||||
abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0
|
||||
|
||||
abstractAnnBind :: Ann ABind -> State Int Bind
|
||||
abstractAnnBind :: Ann ABind -> State Int BBind
|
||||
abstractAnnBind Ann { term = ABind name vars annae } =
|
||||
Bind name (vars' <|| vars) <$> abstractAnnExp annae'
|
||||
BBind name (vars' <|| vars) <$> abstractAnnExp annae'
|
||||
where
|
||||
(annae', vars') = go [] annae
|
||||
where
|
||||
|
|
@ -149,24 +172,27 @@ abstractAnnBind Ann { term = ABind name vars annae } =
|
|||
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
|
||||
ae -> (ae, acc)
|
||||
|
||||
abstractAnnExp :: Ann AExpT -> State Int ExpT
|
||||
abstractAnnExp :: Ann (T AExp) -> State Int (T BExp)
|
||||
abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
|
||||
AVar n -> pure (EVar n, typ)
|
||||
AInj n -> pure (EInj n, typ)
|
||||
ALit lit -> pure (ELit lit, typ)
|
||||
AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2
|
||||
AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2
|
||||
AVar n -> pure (BVar n, typ)
|
||||
AInj n -> pure (BInj n, typ)
|
||||
ALit lit -> pure (BLit lit, typ)
|
||||
AApp annae1 annae2 -> (, typ) <$> onM BApp abstractAnnExp annae1 annae2
|
||||
AAdd annae1 annae2 -> (, typ) <$> onM BAdd abstractAnnExp annae1 annae2
|
||||
|
||||
-- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc
|
||||
AAbs x annae' -> do
|
||||
i <- nextNumber
|
||||
rhs <- abstractAnnExp annae''
|
||||
let sc_name = Ident ("sc_" ++ show i)
|
||||
e@(_, t) = foldl applyFree (EVar sc_name, typ) frees
|
||||
pure (ELet (Bind (sc_name, typ) vars rhs) e ,t)
|
||||
sc | null frees = (BVar sc_name, typ)
|
||||
| otherwise = (BVarC frees sc_name, typ)
|
||||
bind | null frees = BBind (sc_name, typ) vars rhs
|
||||
| otherwise = BBindCxt frees (sc_name, typ) vars rhs
|
||||
|
||||
pure (BLet bind sc ,typ)
|
||||
|
||||
where
|
||||
vars = frees <| (x, t_x) <|| ys
|
||||
vars = [(x, t_x)] <|| ys
|
||||
t_x = case typ of TFun t _ -> t
|
||||
_ -> error "Impossible"
|
||||
|
||||
|
|
@ -176,54 +202,48 @@ abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
|
|||
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
|
||||
ae -> (ae, acc)
|
||||
|
||||
|
||||
applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type)
|
||||
applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e')
|
||||
where
|
||||
t_e' = case t_e of TFun _ t -> t
|
||||
_ -> error "Impossible"
|
||||
|
||||
ACase annae' bs -> do
|
||||
bs <- mapM go bs
|
||||
e <- abstractAnnExp annae'
|
||||
pure (ECase e bs, typ)
|
||||
pure (BCase e bs, typ)
|
||||
where
|
||||
go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae
|
||||
go Ann { term = ABranch p annae } = BBranch p <$> abstractAnnExp annae
|
||||
|
||||
ALet b annae' ->
|
||||
(, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae')
|
||||
(, typ) <$> liftA2 BLet (abstractAnnBind b) (abstractAnnExp annae')
|
||||
|
||||
|
||||
-- | Collects supercombinators by lifting non-constant let expressions
|
||||
collectScs :: [Bind] -> [Bind]
|
||||
collectScs :: [BBind] -> [L.Bind]
|
||||
collectScs = concatMap collectFromRhs
|
||||
where
|
||||
collectFromRhs (Bind name parms rhs) =
|
||||
collectFromRhs (BBind name parms rhs) =
|
||||
let (rhs_scs, rhs') = collectScsExp rhs
|
||||
in Bind name parms rhs' : rhs_scs
|
||||
in L.Bind name parms rhs' : rhs_scs
|
||||
collectFromRhs (BBindCxt cxt name parms rhs) =
|
||||
let (rhs_scs, rhs') = collectScsExp rhs
|
||||
in L.BindC cxt name parms rhs' : rhs_scs
|
||||
|
||||
|
||||
collectScsExp :: ExpT -> ([Bind], ExpT)
|
||||
collectScsExp expT@(exp, typ) = case exp of
|
||||
EVar _ -> ([], expT)
|
||||
EInj _ -> ([], expT)
|
||||
ELit _ -> ([], expT)
|
||||
collectScsExp :: T BExp -> ([L.Bind], T L.Exp)
|
||||
collectScsExp (exp, typ) = case exp of
|
||||
BVar x -> ([], (L.EVar x, typ))
|
||||
BVarC as x -> ([], (L.EVarC as x, typ))
|
||||
BInj k -> ([], (L.EInj k, typ))
|
||||
BLit lit -> ([], (L.ELit lit, typ))
|
||||
|
||||
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
|
||||
BApp e1 e2 -> (scs1 ++ scs2, (L.EApp e1' e2', typ))
|
||||
where
|
||||
(scs1, e1') = collectScsExp e1
|
||||
(scs2, e2') = collectScsExp e2
|
||||
|
||||
EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ))
|
||||
BAdd e1 e2 -> (scs1 ++ scs2, (L.EAdd e1' e2', typ))
|
||||
where
|
||||
(scs1, e1') = collectScsExp e1
|
||||
(scs2, e2') = collectScsExp e2
|
||||
|
||||
EAbs par e -> (scs, (EAbs par e', typ))
|
||||
where
|
||||
(scs, e') = collectScsExp e
|
||||
|
||||
ECase e branches -> (scs ++ scs_e, (ECase e' branches', typ))
|
||||
BCase e branches -> (scs ++ scs_e, (L.ECase e' branches', typ))
|
||||
where
|
||||
(scs, branches') = mapAccumL f [] branches
|
||||
(scs_e, e') = collectScsExp e
|
||||
|
|
@ -234,15 +254,24 @@ collectScsExp expT@(exp, typ) = case exp of
|
|||
--
|
||||
-- > f = let sc x y = rhs in e
|
||||
--
|
||||
ELet (Bind name parms rhs) e
|
||||
| null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et'))
|
||||
BLet (BBind name parms rhs) e
|
||||
| null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
|
||||
| otherwise -> (bind : rhs_scs ++ et_scs, et')
|
||||
where
|
||||
bind = Bind name parms rhs'
|
||||
bind = L.Bind name parms rhs'
|
||||
(rhs_scs, rhs') = collectScsExp rhs
|
||||
(et_scs, et') = collectScsExp e
|
||||
|
||||
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
|
||||
|
||||
BLet (BBindCxt cxt name parms rhs) e
|
||||
| null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
|
||||
| otherwise -> (bind : rhs_scs ++ et_scs, et')
|
||||
where
|
||||
bind = L.BindC cxt name parms rhs'
|
||||
(rhs_scs, rhs') = collectScsExp rhs
|
||||
(et_scs, et') = collectScsExp e
|
||||
|
||||
collectScsBranch (BBranch patt exp) = (scs, L.Branch (first toLirPattern patt) exp')
|
||||
where (scs, exp') = collectScsExp exp
|
||||
|
||||
nextNumber :: State Int Int
|
||||
|
|
@ -259,3 +288,19 @@ xs <| x | elem x xs = xs
|
|||
(<||) :: Eq a => [a] -> [a] -> [a]
|
||||
xs <|| ys = foldl (<|) xs ys
|
||||
|
||||
|
||||
|
||||
toLirData (Data t injs) = L.Data t (map toLirInj injs)
|
||||
toLirInj (Inj n t) = L.Inj n t
|
||||
|
||||
toLirPattern :: Pattern -> L.Pattern
|
||||
toLirPattern = \case
|
||||
PVar x -> L.PVar x
|
||||
PLit lit -> L.PLit lit
|
||||
PCatch -> L.PCatch
|
||||
PEnum k -> L.PEnum k
|
||||
PInj k ps -> L.PInj k (map (first toLirPattern) ps)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue