Fix lambda lifter

This commit is contained in:
Martin Fredin 2023-04-29 15:52:37 +02:00
parent df1a5de04a
commit 619242ccaf
6 changed files with 280 additions and 182 deletions

View file

@ -83,6 +83,8 @@ Test-suite language-testsuite
TestAnnForall TestAnnForall
TestReportForall TestReportForall
TestRenamer TestRenamer
TestLambdaLifter
DoStrings
Grammar.Abs Grammar.Abs
Grammar.Lex Grammar.Lex

View file

@ -1,11 +1,9 @@
data Maybe () where { data Maybe () where
Just : Int -> Maybe () Just : Int -> Maybe ()
Nothing : Maybe () Nothing : Maybe ()
};
demoFunc x = case x of { demoFunc x = case x of
Just x => x + 24; Just x => x + 24
Nothing => 0; Nothing => 0
};
main = demoFunc Nothing ; main = demoFunc Nothing

View file

@ -1,26 +0,0 @@
main = case f (Just 10) of {
Just a => a ;
Nothing => 0 ;
};
f x = bind (fmap (\s . s + 1) x) (\s . pure (s + 10)) ;
data Maybe () where {
Just : Int -> Maybe ()
Nothing : Maybe ()
};
fmap : (Int -> Int) -> Maybe () -> Maybe () ;
fmap f m = case m of {
Just a => pure (f a) ;
Nothing => Nothing ;
};
pure : Int -> Maybe () ;
pure x = Just x;
bind : Maybe () -> (Int -> Maybe ()) -> Maybe () ;
bind x f = case x of {
Just x => f x ;
Nothing => Nothing ;
};

View file

@ -1,17 +1,16 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where module LambdaLifter where
import Auxiliary (mapAccumM, snoc) import Auxiliary (onM, snoc)
import Control.Applicative (Applicative (liftA2)) import Control.Applicative (Applicative (liftA2))
import Control.Arrow (Arrow (second))
import Control.Monad.State (MonadState (get, put), State, import Control.Monad.State (MonadState (get, put), State,
evalState) evalState)
import Data.List (mapAccumL, partition) import Data.Function (on)
import Data.Set (Set) import Data.List (delete, mapAccumL, (\\))
import qualified Data.Set as Set
import Prelude hiding (exp) import Prelude hiding (exp)
import TypeChecker.TypeCheckerIr import TypeChecker.TypeCheckerIr
@ -21,176 +20,190 @@ import TypeChecker.TypeCheckerIr
-- @freeVars@ annotates all the free variables. -- @freeVars@ annotates all the free variables.
-- @abstract@ converts lambdas into let expressions. -- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function. -- @collectScs@ moves every non-constant let expression to a top-level function.
--
lambdaLift :: Program -> Program lambdaLift :: Program -> Program
lambdaLift (Program defs) = Program $ datatypes ++ ll binds lambdaLift (Program ds) = Program (datatypes ++ binds)
where where
ll = map DBind . collectScs . abstract . freeVars . map (\(DBind b) -> b) datatypes = flip filter ds $ \case DData _ -> True
(binds, datatypes) = partition isBind defs
isBind = \case
DBind _ -> True
_ -> False _ -> False
binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
-- lambdaLift (Program defs) = trace (printTree abst) $ Program $ datatypes ++ ll binds
-- where
-- abst = abstract frees
-- frees = freeVars [b | DBind b@(Bind (Ident s, _) _ _) <- binds, s == "f"]
--
-- ll = map DBind . collectScs . abstract . freeVars . map (\(DBind b) -> b)
-- (binds, datatypes) = partition isBind defs
-- isBind = \case
-- DBind _ -> True
-- _ -> False
-- | Annotate free variables -- | Annotate free variables
freeVars :: [Bind] -> AnnBinds freeVars :: [Bind] -> [ABind]
freeVars binds = [ (n, xs, freeVarsExp (Set.fromList $ map fst xs) e) freeVars binds = [ let ae = freeVarsExp [] e
ae' = ae { frees = ae.frees \\ xs }
in ABind n xs ae'
| Bind n xs e <- binds | Bind n xs e <- binds
] ]
freeVarsExp :: Set Ident -> ExpT -> AnnExpT freeVarsExp :: Frees -> ExpT -> Ann AExpT
freeVarsExp localVars (exp, t) = case exp of freeVarsExp localVars (ae, t) = case ae of
EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t)) EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)]
| otherwise -> (mempty, (AVar n, t)) , term = (AVar n, t)
}
| otherwise -> Ann { frees = []
, term = (AVar n, t)
}
EInj n -> (mempty, (AVar n, t)) EInj n -> Ann { frees = [], term = (AInj n, t) }
ELit lit -> (mempty, (ALit lit, t)) ELit lit -> Ann { frees = [], term = (ALit lit, t) }
EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t)) EApp e1 e2 -> Ann { frees = annae1.frees <|| annae2.frees
, term = (AApp annae1 annae2, t)
}
where where
e1' = freeVarsExp localVars e1 (annae1, annae2) = on (,) (freeVarsExp localVars) e1 e2
e2' = freeVarsExp localVars e2
EAdd e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AAdd e1' e2', t)) EAdd e1 e2 -> Ann { frees = annae1.frees <|| annae2.frees
, term = (AAdd annae1 annae2, t)
}
where where
e1' = freeVarsExp localVars e1 (annae1, annae2) = on (,) (freeVarsExp localVars) e1 e2
e2' = freeVarsExp localVars e2
EAbs par e -> (Set.delete par $ freeVarsOf e', (AAbs par e', t))
EAbs x e -> Ann { frees = delete (x,t_x) $ annae.frees
, term = (AAbs x annae, t) }
where where
e' = freeVarsExp (Set.insert par localVars) e annae = freeVarsExp (localVars <| (x,t_x)) e
t_x = case t of TFun t _ -> t
_ -> error "Impossible"
-- Sum free variables present in bind and the expression -- Sum free variables present in bind and the expression
ELet (Bind (name, t_bind) parms rhs) e -> (Set.union binders_frees e_free, (ALet new_bind e', t)) -- let f x = x + y in f 5 + z → frees: y, z
ELet bind@(Bind n _ _) e ->
Ann { frees = delete n annae.frees <|| annbind.frees
, term = (ALet annbind annae, t)
}
where where
binders_frees = Set.delete name $ freeVarsOf rhs' annae = freeVarsExp (localVars <| n) e
e_free = Set.delete name $ freeVarsOf e' annbind = freeVarsBind localVars bind
rhs' = freeVarsExp e_localVars rhs ECase e branches ->
new_bind = ABind (name, t_bind) parms rhs' Ann { frees = foldl (<||) annae.frees (map frees annbranches)
, term = (ACase annae annbranches, t)
e' = freeVarsExp e_localVars e }
e_localVars = Set.insert name localVars
ECase e branches -> (frees, (ACase e' branches', t))
where where
frees = foldr (\b s -> Set.union s $ fst b) (freeVarsOf e') branches' annae = freeVarsExp localVars e
e' = freeVarsExp localVars e annbranches = map (freeVarsBranch localVars) branches
branches' = map (freeVarsBranch localVars) branches
freeVarsBranch :: Set Ident -> Branch' Type -> (Set Ident, AnnBranch') freeVarsBind :: Frees -> Bind -> Ann ABind
freeVarsBranch localVars (Branch (patt, t) exp) = (frees, AnnBranch (patt, t) exp') freeVarsBind localVars (Bind name vars e) =
Ann { frees = annae.frees \\ vars
, term = ABind name vars annae
}
where where
frees = freeVarsOf exp' Set.\\ freeVarsOfPattern patt annae = freeVarsExp (localVars <|| vars) e
exp' = freeVarsExp localVars exp
freeVarsOfPattern = Set.fromList . go []
freeVarsBranch :: Frees -> Branch -> Ann ABranch
freeVarsBranch localVars (Branch pt e) =
Ann { frees = annae.frees \\ varsInPattern
, term = ABranch pt annae
}
where where
go acc = \case annae = freeVarsExp localVars e
PVar n -> snoc n acc varsInPattern = go [] pt
PInj _ ps -> foldl go acc $ map fst ps where
go acc (p, t) = case p of
PVar n -> acc <| (n, t)
PInj _ ps -> foldl go acc ps
_ -> []
freeVarsOf :: AnnExpT -> Set Ident
freeVarsOf = fst
-- AST annotated with free variables -- AST annotated with free variables
type AnnBinds = [(Id, [Id], AnnExpT)]
type AnnExpT = (Set Ident, AnnExpT') type Frees = [(Ident, Type)]
data ABind = ABind Id [Id] AnnExpT deriving Show data Ann a = Ann
{ frees :: Frees
, term :: a
} deriving (Show, Eq)
type AnnExpT' = (AnnExp, Type) data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq)
data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq)
type AnnBranch = (Set Ident, AnnBranch') type AExpT = (AExp, Type)
data AnnBranch' = AnnBranch (Pattern, Type) AnnExpT
deriving Show
data AnnExp = AVar Ident data AExp = AVar Ident
| AInj Ident | AInj Ident
| ALit Lit | ALit Lit
| ALet ABind AnnExpT | ALet (Ann ABind) (Ann AExpT)
| AApp AnnExpT AnnExpT | AApp (Ann AExpT) (Ann AExpT)
| AAdd AnnExpT AnnExpT | AAdd (Ann AExpT) (Ann AExpT)
| AAbs Ident AnnExpT | AAbs Ident (Ann AExpT)
| ACase AnnExpT [AnnBranch] | ACase (Ann AExpT) [Ann ABranch]
deriving Show deriving (Show, Eq)
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. abstract :: [ABind] -> [Bind]
-- Free variables are @v₁ v₂ .. vₙ@ are bound. abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0
abstract :: AnnBinds -> [Bind]
abstract prog = evalState (mapM go prog) 0 abstractAnnBind :: Ann ABind -> State Int Bind
abstractAnnBind Ann { term = ABind name vars annae } =
Bind name (vars' <|| vars) <$> abstractAnnExp annae'
where where
go :: (Id, [Id], AnnExpT) -> State Int Bind (annae', vars') = go [] annae
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
where where
(rhs', parms1) = flattenLambdasAnn rhs go acc = \case
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc)
abstractAnnExp :: Ann AExpT -> State Int ExpT
-- | Flatten nested lambdas and collect the parameters abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
-- @\x.\y.\z. ae → (ae, [x,y,z])@
flattenLambdasAnn :: AnnExpT -> (AnnExpT, [Id])
flattenLambdasAnn ae = go (ae, [])
where
go :: (AnnExpT, [Id]) -> (AnnExpT, [Id])
go ((free, (e, t)), acc)
| AAbs par (free1, e1) <- e
, TFun t_par _ <- t
= go ((Set.delete par free1, e1), snoc (par, t_par) acc)
| otherwise = ((free, (e, t)), acc)
abstractExp :: AnnExpT -> State Int ExpT
abstractExp (free, (exp, typ)) = case exp of
AVar n -> pure (EVar n, typ) AVar n -> pure (EVar n, typ)
AInj n -> pure (EInj n, typ) AInj n -> pure (EInj n, typ)
ALit lit -> pure (ELit lit, typ) ALit lit -> pure (ELit lit, typ)
AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2) AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2
AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2) AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2
ALet b e -> (, typ) <$> liftA2 ELet (go b) (abstractExp e)
where
go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs'
skipLambdas :: (AnnExpT -> State Int ExpT) -> AnnExpT -> State Int ExpT -- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc
skipLambdas f (free, (ae, t)) = case ae of AAbs x annae' -> do
AAbs par ae1 -> do
ae1' <- skipLambdas f ae1
pure (EAbs par ae1', t)
_ -> f (free, (ae, t))
ACase e branches -> (, typ) <$> liftA2 ECase (abstractExp e) (mapM abstractBranch branches)
-- Lift lambda into let and bind free variables
AAbs parm e -> do
i <- nextNumber i <- nextNumber
rhs <- abstractExp e rhs <- abstractAnnExp annae''
let sc_name = Ident ("sc_" ++ show i) let sc_name = Ident ("sc_" ++ show i)
sc = (ELet (Bind (sc_name, typ) vars rhs) (EVar sc_name, typ), typ) sc = (ELet (Bind (sc_name, typ) vars rhs) (EVar sc_name, typ), typ)
pure $ foldl applyVars sc freeList pure $ foldl applyFree sc frees
where where
freeList = Set.toList free vars = frees <| (x, t_x) <|| ys
vars = zip names $ getVars typ t_x = case typ of TFun t _ -> t
names = snoc parm freeList _ -> error "Impossible"
applyVars (e, t) name = (EApp (e, t) (EVar name, t_var), t_return)
(annae'', ys) = go [] annae'
where where
(t_var, t_return) = case t of go acc = \case
TFun t1 t2 -> (t1, t2) Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc)
applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type)
applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e')
where
t_e' = case t_e of TFun _ t -> t
_ -> error "Impossible"
abstractBranch :: AnnBranch -> State Int Branch ACase annae' bs -> do
abstractBranch (_, AnnBranch patt exp) = Branch patt <$> abstractExp exp bs <- mapM go bs
e <- abstractAnnExp annae'
pure (ECase e bs, typ)
where
go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae
ALet b annae' ->
(, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae')
nextNumber :: State Int Int
nextNumber = do
i <- get
put $ succ i
pure i
-- | Collects supercombinators by lifting non-constant let expressions -- | Collects supercombinators by lifting non-constant let expressions
collectScs :: [Bind] -> [Bind] collectScs :: [Bind] -> [Bind]
@ -232,9 +245,9 @@ collectScsExp expT@(exp, typ) = case exp of
-- --
-- > f = let sc x y = rhs in e -- > f = let sc x y = rhs in e
-- --
ELet (Bind name parms rhs) e -> if null parms ELet (Bind name parms rhs) e
then ( rhs_scs ++ et_scs, (ELet bind et', snd et')) | null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et'))
else (bind : rhs_scs ++ et_scs, et') | otherwise -> (bind : rhs_scs ++ et_scs, et')
where where
bind = Bind name parms rhs' bind = Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs (rhs_scs, rhs') = collectScsExp rhs
@ -243,23 +256,17 @@ collectScsExp expT@(exp, typ) = case exp of
collectScsBranch (Branch patt exp) = (scs, Branch patt exp') collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
where (scs, exp') = collectScsExp exp where (scs, exp') = collectScsExp exp
nextNumber :: State Int Int
nextNumber = do
i <- get
put $ succ i
pure i
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: ExpT -> (ExpT, [Id])
flattenLambdas = go . (, [])
where
go ((e, t), acc) = case e of
EAbs name e1 -> go (e1, snoc (name, t_var) acc)
where t_var = head $ getVars t
_ -> ((e, t), acc)
getVars :: Type -> [Type] (<|) :: Eq a => [a] -> a -> [a]
getVars = fst . partitionType xs <| x | elem x xs = xs
| otherwise = snoc x xs
partitionType :: Type -> ([Type], Type) (<||) :: Eq a => [a] -> [a] -> [a]
partitionType = go [] xs <|| ys = foldl (<|) xs ys
where
go acc t = case t of
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> (acc, t)

117
tests/TestLambdaLifter.hs Normal file
View file

@ -0,0 +1,117 @@
{-# LANGUAGE PatternSynonyms #-}
{-# HLINT ignore "Use camelCase" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-}
module TestLambdaLifter where
import Test.Hspec
import AnnForall (annotateForall)
import Control.Monad ((<=<))
import Control.Monad.Error.Class (liftEither)
import Control.Monad.Extra (eitherM)
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import LambdaLifter
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck)
import TypeChecker.TypeCheckerIr
test = hspec testLambdaLifter
testLambdaLifter = describe "Test Lambda Lifter" $ do
undefined
-- frees_exp1
-- frees_exp1 = specify "Free variables 1" $
-- freeVarsExp [] (EAbs "x" (EVar "x", TVar' "a"), TVar' "a")
-- `shouldBe` answer
-- where
-- answer = Ann { frees = []
-- , term = (AAbs (Ident "x") (Ann { frees = [Ident "x"]
-- , term = (AVar (Ident "x"),TVar (MkTVar (Ident "a")))
-- }
-- ),TVar (MkTVar (Ident "a")))
-- }
abs_1 = undefined
where
input = unlines [ "data List (a) where"
, " Nil : List (a)"
, " Cons : a -> List (a) -> List (a)"
, "map : (a -> b) -> List (a) -> List (b)"
, "add : Int -> Int -> Int"
, "f : List (Int)"
, "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))"
]
runPrintFree = print $ freeVarsExp [] (EAbs "x" (EVar "x", TVar' "a"), TVar' "a")
runAbstract = either putStrLn (putStrLn . printTree) (runAbs s2)
where
s = unlines [ "add : Int -> Int -> Int"
, "f : Int -> Int -> Int"
, "f x y = add x y"
, "f = \\x. (\\y. add x y)"
]
s2 = unlines [ "data List (a) where"
, " Nil : List (a)"
, " Cons : a -> List (a) -> List (a)"
, "map : (a -> b) -> List (a) -> List (b)"
, "add : Int -> Int -> Int"
, "f : List (Int)"
, "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))"
]
runCollect = either putStrLn (putStrLn . printTree) (run s)
where
s = unlines [ "data List (a) where"
, " Nil : List (a)"
, " Cons : a -> List (a) -> List (a)"
, "add : Int -> Int -> Int"
, "map : (a -> b) -> List (a) -> List (b)"
, "map f xs = case xs of"
, " Nil => Nil"
, " Cons x xs => Cons (f x) (map f xs)"
, "f : List (Int)"
, "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))"
]
run = fmap collectScs . runAbs
runAbs s = do
Program ds <- run' s
pure $ (abstract . freeVars) [b | DBind b <- ds]
run' = fmap removeForall
. reportTEVar
<=< typecheck
<=< run''
run'' s = do
p <- (pProgram . resolveLayout True . myLexer) s
reportForall Bi p
(rename <=< annotateForall) p