Document and fix code style

This commit is contained in:
Martin Fredin 2023-02-18 13:35:33 +01:00
parent 21fb6bf5ed
commit b8aedd541d

View file

@ -15,6 +15,10 @@ import TypeCheckerIr
-- | Lift lambdas and let expression into supercombinators. -- | Lift lambdas and let expression into supercombinators.
-- Three phases:
-- @freeVars@ annotatss all the free variables.
-- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function.
lambdaLift :: Program -> Program lambdaLift :: Program -> Program
lambdaLift = collectScs . abstract . freeVars lambdaLift = collectScs . abstract . freeVars
@ -27,37 +31,36 @@ freeVars (Program ds) = [ (n, xs, freeVarsExp (Set.fromList xs) e)
freeVarsExp :: Set Id -> Exp -> AnnExp freeVarsExp :: Set Id -> Exp -> AnnExp
freeVarsExp localVars = \case freeVarsExp localVars = \case
EId n | Set.member n localVars -> (Set.singleton n, AId n)
| otherwise -> (mempty, AId n)
EId n | Set.member n localVars -> (Set.singleton n, AId n) EInt i -> (mempty, AInt i)
| otherwise -> (mempty, AId n)
EInt i -> (mempty, AInt i) EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2')
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
EApp t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AApp t e1' e2') EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2')
where where
e1' = freeVarsExp localVars e1 e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2 e2' = freeVarsExp localVars e2
EAdd t e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), AAdd t e1' e2') EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e')
where where
e1' = freeVarsExp localVars e1 e' = freeVarsExp (Set.insert par localVars) e
e2' = freeVarsExp localVars e2
EAbs t par e -> (Set.delete par $ freeVarsOf e', AAbs t par e') -- Sum free variables present in bind and the expression
where ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
e' = freeVarsExp (Set.insert par localVars) e where
binders_frees = Set.delete name $ freeVarsOf rhs'
e_free = Set.delete name $ freeVarsOf e'
-- Sum free variables present in bind and the expression rhs' = freeVarsExp e_localVars rhs
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e') new_bind = ABind name parms rhs'
where
binders_frees = Set.delete name $ freeVarsOf rhs'
e_free = Set.delete name $ freeVarsOf e'
rhs' = freeVarsExp e_localVars rhs e' = freeVarsExp e_localVars e
new_bind = ABind name parms rhs' e_localVars = Set.insert name localVars
e' = freeVarsExp e_localVars e
e_localVars = Set.insert name localVars
freeVarsOf :: AnnExp -> Set Id freeVarsOf :: AnnExp -> Set Id
@ -82,10 +85,10 @@ data AnnExp' = AId Id
abstract :: AnnProgram -> Program abstract :: AnnProgram -> Program
abstract prog = Program $ evalState (mapM go prog) 0 abstract prog = Program $ evalState (mapM go prog) 0
where where
go :: (Id, [Id], AnnExp) -> State Int Bind go :: (Id, [Id], AnnExp) -> State Int Bind
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
where where
(rhs', parms1) = flattenLambdasAnn rhs (rhs', parms1) = flattenLambdasAnn rhs
-- | Flatten nested lambdas and collect the parameters -- | Flatten nested lambdas and collect the parameters
@ -95,55 +98,55 @@ flattenLambdasAnn ae = go (ae, [])
where where
go :: (AnnExp, [Id]) -> (AnnExp, [Id]) go :: (AnnExp, [Id]) -> (AnnExp, [Id])
go ((free, e), acc) = go ((free, e), acc) =
case e of case e of
AAbs _ par (free1, e1) -> AAbs _ par (free1, e1) ->
go ((Set.delete par free1, e1), snoc par acc) go ((Set.delete par free1, e1), snoc par acc)
_ -> ((free, e), acc) _ -> ((free, e), acc)
abstractExp :: AnnExp -> State Int Exp abstractExp :: AnnExp -> State Int Exp
abstractExp (free, exp) = case exp of abstractExp (free, exp) = case exp of
AId n -> pure $ EId n AId n -> pure $ EId n
AInt i -> pure $ EInt i AInt i -> pure $ EInt i
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2) AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2) AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
ALet b e -> liftA2 ELet (go b) (abstractExp e) ALet b e -> liftA2 ELet (go b) (abstractExp e)
where where
go (ABind name parms rhs) = do go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs' pure $ Bind name (parms ++ parms1) rhs'
skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp skipLambdas :: (AnnExp -> State Int Exp) -> AnnExp -> State Int Exp
skipLambdas f (free, ae) = case ae of skipLambdas f (free, ae) = case ae of
AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1 AAbs t par ae1 -> EAbs t par <$> skipLambdas f ae1
_ -> f (free, ae) _ -> f (free, ae)
-- Lift lambda into let and bind free variables -- Lift lambda into let and bind free variables
AAbs t parm e -> do AAbs t parm e -> do
i <- nextNumber i <- nextNumber
rhs <- abstractExp e rhs <- abstractExp e
let sc_name = Ident ("sc_" ++ show i) let sc_name = Ident ("sc_" ++ show i)
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t) sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
pure $ foldl (EApp TInt) sc $ map EId freeList pure $ foldl (EApp TInt) sc $ map EId freeList
where where
freeList = Set.toList free freeList = Set.toList free
parms = snoc parm freeList parms = snoc parm freeList
nextNumber :: State Int Int nextNumber :: State Int Int
nextNumber = do nextNumber = do
i <- get i <- get
put $ succ i put $ succ i
pure i pure i
-- | Collects supercombinators by lifting non-constant let expressions -- | Collects supercombinators by lifting non-constant let expressions
collectScs :: Program -> Program collectScs :: Program -> Program
collectScs (Program scs) = Program $ concatMap collectFromRhs scs collectScs (Program scs) = Program $ concatMap collectFromRhs scs
where where
collectFromRhs (Bind name parms rhs) = collectFromRhs (Bind name parms rhs) =
let (rhs_scs, rhs') = collectScsExp rhs let (rhs_scs, rhs') = collectScsExp rhs
in Bind name parms rhs' : rhs_scs in Bind name parms rhs' : rhs_scs
collectScsExp :: Exp -> ([Bind], Exp) collectScsExp :: Exp -> ([Bind], Exp)
@ -183,5 +186,5 @@ flattenLambdas :: Exp -> (Exp, [Id])
flattenLambdas = go . (, []) flattenLambdas = go . (, [])
where where
go (e, acc) = case e of go (e, acc) = case e of
EAbs _ par e1 -> go (e1, snoc par acc) EAbs _ par e1 -> go (e1, snoc par acc)
_ -> (e, acc) _ -> (e, acc)