Change grammar: only one bind in let and no EAnn for typed syntax

This commit is contained in:
Martin Fredin 2023-02-18 12:57:23 +01:00
parent 7cedc2e28c
commit a3e57dde7b
7 changed files with 172 additions and 228 deletions

View file

@ -7,12 +7,10 @@ module LambdaLifter (lambdaLift, freeVars, abstract, rename, collectScs) where
import Auxiliary (snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.State (MonadState (get, put), State, evalState)
import Data.Foldable.Extra (notNull)
import Data.List (mapAccumL, partition)
import Data.Set (Set, (\\))
import Data.Set (Set)
import qualified Data.Set as Set
import Prelude hiding (exp)
import Renamer hiding (fromBinders)
import Renamer
import TypeCheckerIr
@ -49,35 +47,22 @@ freeVarsExp localVars = \case
where
e' = freeVarsExp (Set.insert par localVars) e
-- Sum free variables present in binders and the expression
ELet binders e -> (Set.union binders_frees e_free, ALet binders' e')
-- Sum free variables present in bind and the expression
ELet (Bind name parms rhs) e -> (Set.union binders_frees e_free, ALet new_bind e')
where
binders_frees = rhss_frees \\ names_set
e_free = freeVarsOf e' \\ names_set
binders_frees = Set.delete name $ freeVarsOf rhs'
e_free = Set.delete name $ freeVarsOf e'
rhss_frees = foldr1 Set.union (map freeVarsOf rhss')
names_set = Set.fromList names
rhs' = freeVarsExp e_localVars rhs
new_bind = ABind name parms rhs'
(names, parms, rhss) = fromBinders binders
rhss' = map (freeVarsExp e_localVars) rhss
e_localVars = Set.union localVars names_set
binders' = zipWith3 ABind names parms rhss'
e' = freeVarsExp e_localVars e
EAnn e t -> (freeVarsOf e', AAnn e' t)
where
e' = freeVarsExp localVars e
e' = freeVarsExp e_localVars e
e_localVars = Set.insert name localVars
freeVarsOf :: AnnExp -> Set Id
freeVarsOf = fst
fromBinders :: [Bind] -> ([Id], [[Id]], [Exp])
fromBinders bs = unzip3 [ (name, parms, rhs) | Bind name parms rhs <- bs ]
-- AST annotated with free variables
type AnnProgram = [(Id, [Id], AnnExp)]
@ -87,14 +72,11 @@ data ABind = ABind Id [Id] AnnExp deriving Show
data AnnExp' = AId Id
| AInt Integer
| ALet [ABind] AnnExp
| ALet ABind AnnExp
| AApp Type AnnExp AnnExp
| AAdd Type AnnExp AnnExp
| AAbs Type Id AnnExp
| AAnn AnnExp Type
deriving Show
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnProgram -> Program
@ -124,7 +106,7 @@ abstractExp (free, exp) = case exp of
AInt i -> pure $ EInt i
AApp t e1 e2 -> liftA2 (EApp t) (abstractExp e1) (abstractExp e2)
AAdd t e1 e2 -> liftA2 (EAdd t) (abstractExp e1) (abstractExp e2)
ALet bs e -> liftA2 ELet (mapM go bs) (abstractExp e)
ALet b e -> liftA2 ELet (go b) (abstractExp e)
where
go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
@ -141,14 +123,13 @@ abstractExp (free, exp) = case exp of
rhs <- abstractExp e
let sc_name = Ident ("sc_" ++ show i)
sc = ELet [Bind (sc_name, t) parms rhs] $ EId (sc_name, t)
sc = ELet (Bind (sc_name, t) parms rhs) $ EId (sc_name, t)
pure $ foldl (EApp TInt) sc $ map EId freeList
where
freeList = Set.toList free
parms = snoc parm freeList
AAnn e t -> abstractExp e >>= \e' -> pure $ EAnn e' t
nextNumber :: State Int Int
nextNumber = do
@ -156,7 +137,7 @@ nextNumber = do
put $ succ i
pure i
-- | Collects supercombinators by lifting appropriate let expressions
-- | Collects supercombinators by lifting non-constant let expressions
collectScs :: Program -> Program
collectScs (Program scs) = Program $ concatMap collectFromRhs scs
where
@ -167,49 +148,35 @@ collectScs (Program scs) = Program $ concatMap collectFromRhs scs
collectScsExp :: Exp -> ([Bind], Exp)
collectScsExp = \case
EId n -> ([], EId n)
EInt i -> ([], EInt i)
EId n -> ([], EId n)
EInt i -> ([], EInt i)
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EApp t e1 e2 -> (scs1 ++ scs2, EApp t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAdd t e1 e2 -> (scs1 ++ scs2, EAdd t e1' e2')
where
(scs1, e1') = collectScsExp e1
(scs2, e2') = collectScsExp e2
EAbs t par e -> (scs, EAbs t par e')
where
(scs, e') = collectScsExp e
EAbs t par e -> (scs, EAbs t par e')
where
(scs, e') = collectScsExp e
-- Collect supercombinators from binds, the rhss, and the expression.
--
-- > f = let
-- > sc = rhs
-- > sc1 = rhs1
-- > ...
-- > in e
--
ELet binds e -> (binds_scs ++ rhss_scs ++ e_scs, mkEAbs non_scs' e')
where
binds_scs = [ let (rhs', parms1) = flattenLambdas rhs in
Bind n (parms ++ parms1) rhs'
| Bind n parms rhs <- scs'
]
(rhss_scs, binds') = mapAccumL collectScsRhs [] binds
(e_scs, e') = collectScsExp e
-- Collect supercombinators from bind, the rhss, and the expression.
--
-- > f = let sc x y = rhs in e
--
ELet (Bind name parms rhs) e -> if null parms
then ( rhs_scs ++ e_scs, ELet bind e')
else (bind : rhs_scs ++ e_scs, e')
where
bind = Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(e_scs, e') = collectScsExp e
(scs', non_scs') = partition (\(Bind _ pars _) -> notNull pars) binds'
collectScsRhs acc (Bind n xs rhs) = (acc ++ rhs_scs, Bind n xs rhs')
where
(rhs_scs, rhs') = collectScsExp rhs
EAnn e t -> (scs, EAnn e' t)
where
(scs, e') = collectScsExp e
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: Exp -> (Exp, [Id])
@ -218,7 +185,3 @@ flattenLambdas = go . (, [])
go (e, acc) = case e of
EAbs _ par e1 -> go (e1, snoc par acc)
_ -> (e, acc)
mkEAbs :: [Bind] -> Exp -> Exp
mkEAbs [] e = e
mkEAbs bs e = ELet bs e