Add closures and fix lets in monomorphizer

This commit is contained in:
Martin Fredin 2023-05-06 22:49:08 +02:00
parent 677a200a15
commit 72e599d5de
26 changed files with 1440 additions and 692 deletions

View file

@ -11,9 +11,11 @@ import Control.Monad.State (MonadState (get, put), State,
evalState)
import Data.Function (on)
import Data.List (delete, mapAccumL, (\\))
import Data.Tuple.Extra (first, second)
import LambdaLifterIr (T)
import qualified LambdaLifterIr as L
import Prelude hiding (exp)
import TypeChecker.TypeCheckerIr
import TypeChecker.TypeCheckerIr hiding (T)
-- | Lift lambdas and let expression into supercombinators.
-- Three phases:
@ -21,12 +23,13 @@ import TypeChecker.TypeCheckerIr
-- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function.
--
lambdaLift :: Program -> Program
lambdaLift (Program ds) = Program (datatypes ++ binds)
lambdaLift :: Program -> L.Program
lambdaLift (Program ds) = L.Program (datatypes ++ binds)
where
datatypes = flip filter ds $ \case DData _ -> True
_ -> False
binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
datatypes = [L.DData (toLirData d) | DData d <- ds]
binds = map L.DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
-- | Annotate free variables
freeVars :: [Bind] -> [ABind]
@ -36,7 +39,7 @@ freeVars binds = [ let ae = freeVarsExp [] e
| Bind n xs e <- binds
]
freeVarsExp :: Frees -> ExpT -> Ann AExpT
freeVarsExp :: Frees -> T Exp -> Ann (T AExp)
freeVarsExp localVars (ae, t) = case ae of
EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)]
, term = (AVar n, t)
@ -121,27 +124,47 @@ data Ann a = Ann
, term :: a
} deriving (Show, Eq)
data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq)
data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq)
type AExpT = (AExp, Type)
data ABind = ABind (T Ident) [T Ident] (Ann (T AExp)) deriving (Show, Eq)
data ABranch = ABranch (Pattern, Type) (Ann (T AExp)) deriving (Show, Eq)
data AExp = AVar Ident
| AInj Ident
| ALit Lit
| ALet (Ann ABind) (Ann AExpT)
| AApp (Ann AExpT) (Ann AExpT)
| AAdd (Ann AExpT) (Ann AExpT)
| AAbs Ident (Ann AExpT)
| ACase (Ann AExpT) [Ann ABranch]
| ALet (Ann ABind) (Ann (T AExp))
| AApp (Ann (T AExp)) (Ann (T AExp))
| AAdd (Ann (T AExp)) (Ann (T AExp))
| AAbs Ident (Ann (T AExp))
| ACase (Ann (T AExp)) [Ann ABranch]
deriving (Show, Eq)
abstract :: [ABind] -> [Bind]
data BBind = BBind (T Ident) [T Ident] (T BExp)
| BBindCxt [T Ident] (T Ident) [T Ident] (T BExp)
deriving (Eq, Ord, Show)
data BBranch = BBranch (T Pattern) (T BExp)
deriving (Eq, Ord, Show)
data BExp
= BVar Ident
| BVarC [T Ident] Ident
| BInj Ident
| BLit Lit
| BLet BBind (T BExp)
| BApp (T BExp)(T BExp)
| BAdd (T BExp)(T BExp)
| BCase (T BExp) [BBranch]
deriving (Eq, Ord, Show)
abstract :: [ABind] -> [BBind]
abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0
abstractAnnBind :: Ann ABind -> State Int Bind
abstractAnnBind :: Ann ABind -> State Int BBind
abstractAnnBind Ann { term = ABind name vars annae } =
Bind name (vars' <|| vars) <$> abstractAnnExp annae'
BBind name (vars' <|| vars) <$> abstractAnnExp annae'
where
(annae', vars') = go [] annae
where
@ -149,24 +172,27 @@ abstractAnnBind Ann { term = ABind name vars annae } =
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc)
abstractAnnExp :: Ann AExpT -> State Int ExpT
abstractAnnExp :: Ann (T AExp) -> State Int (T BExp)
abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
AVar n -> pure (EVar n, typ)
AInj n -> pure (EInj n, typ)
ALit lit -> pure (ELit lit, typ)
AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2
AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2
AVar n -> pure (BVar n, typ)
AInj n -> pure (BInj n, typ)
ALit lit -> pure (BLit lit, typ)
AApp annae1 annae2 -> (, typ) <$> onM BApp abstractAnnExp annae1 annae2
AAdd annae1 annae2 -> (, typ) <$> onM BAdd abstractAnnExp annae1 annae2
-- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc
AAbs x annae' -> do
i <- nextNumber
rhs <- abstractAnnExp annae''
let sc_name = Ident ("sc_" ++ show i)
e@(_, t) = foldl applyFree (EVar sc_name, typ) frees
pure (ELet (Bind (sc_name, typ) vars rhs) e ,t)
sc | null frees = (BVar sc_name, typ)
| otherwise = (BVarC frees sc_name, typ)
bind | null frees = BBind (sc_name, typ) vars rhs
| otherwise = BBindCxt frees (sc_name, typ) vars rhs
pure (BLet bind sc ,typ)
where
vars = frees <| (x, t_x) <|| ys
vars = [(x, t_x)] <|| ys
t_x = case typ of TFun t _ -> t
_ -> error "Impossible"
@ -176,54 +202,48 @@ abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc)
applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type)
applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e')
where
t_e' = case t_e of TFun _ t -> t
_ -> error "Impossible"
ACase annae' bs -> do
bs <- mapM go bs
e <- abstractAnnExp annae'
pure (ECase e bs, typ)
pure (BCase e bs, typ)
where
go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae
go Ann { term = ABranch p annae } = BBranch p <$> abstractAnnExp annae
ALet b annae' ->
(, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae')
(, typ) <$> liftA2 BLet (abstractAnnBind b) (abstractAnnExp annae')
-- | Collects supercombinators by lifting non-constant let expressions
collectScs :: [Bind] -> [Bind]
collectScs :: [BBind] -> [L.Bind]
collectScs = concatMap collectFromRhs
where
collectFromRhs (Bind name parms rhs) =
collectFromRhs (BBind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs
in L.Bind name parms rhs' : rhs_scs
collectFromRhs (BBindCxt cxt name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs
in L.BindC cxt name parms rhs' : rhs_scs
collectScsExp :: ExpT -> ([Bind], ExpT)
collectScsExp expT@(exp, typ) = case exp of
EVar _ -> ([], expT)
EInj _ -> ([], expT)
ELit _ -> ([], expT)
collectScsExp :: T BExp -> ([L.Bind], T L.Exp)
collectScsExp (exp, typ) = case exp of
BVar x -> ([], (L.EVar x, typ))
BVarC as x -> ([], (L.EVarC as x, typ))
BInj k -> ([], (L.EInj k, typ))
BLit lit -> ([], (L.ELit lit, typ))
EApp e1 e2 -> (scs1 ++ scs2, (EApp e1' e2', typ))
BApp e1 e2 -> (scs1 ++ scs2, (L.EApp e1' e2', typ))
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAdd e1 e2 -> (scs1 ++ scs2, (EAdd e1' e2', typ))
BAdd e1 e2 -> (scs1 ++ scs2, (L.EAdd e1' e2', typ))
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAbs par e -> (scs, (EAbs par e', typ))
where
(scs, e') = collectScsExp e
ECase e branches -> (scs ++ scs_e, (ECase e' branches', typ))
BCase e branches -> (scs ++ scs_e, (L.ECase e' branches', typ))
where
(scs, branches') = mapAccumL f [] branches
(scs_e, e') = collectScsExp e
@ -234,15 +254,24 @@ collectScsExp expT@(exp, typ) = case exp of
--
-- > f = let sc x y = rhs in e
--
ELet (Bind name parms rhs) e
| null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et'))
BLet (BBind name parms rhs) e
| null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
| otherwise -> (bind : rhs_scs ++ et_scs, et')
where
bind = Bind name parms rhs'
bind = L.Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
BLet (BBindCxt cxt name parms rhs) e
| null parms -> (rhs_scs ++ et_scs, (L.ELet name rhs' et', snd et'))
| otherwise -> (bind : rhs_scs ++ et_scs, et')
where
bind = L.BindC cxt name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e
collectScsBranch (BBranch patt exp) = (scs, L.Branch (first toLirPattern patt) exp')
where (scs, exp') = collectScsExp exp
nextNumber :: State Int Int
@ -259,3 +288,19 @@ xs <| x | elem x xs = xs
(<||) :: Eq a => [a] -> [a] -> [a]
xs <|| ys = foldl (<|) xs ys
toLirData (Data t injs) = L.Data t (map toLirInj injs)
toLirInj (Inj n t) = L.Inj n t
toLirPattern :: Pattern -> L.Pattern
toLirPattern = \case
PVar x -> L.PVar x
PLit lit -> L.PLit lit
PCatch -> L.PCatch
PEnum k -> L.PEnum k
PInj k ps -> L.PInj k (map (first toLirPattern) ps)