Fix lambda lifter

This commit is contained in:
Martin Fredin 2023-04-29 15:52:37 +02:00
parent df1a5de04a
commit 619242ccaf
6 changed files with 280 additions and 182 deletions

View file

@ -83,6 +83,8 @@ Test-suite language-testsuite
TestAnnForall
TestReportForall
TestRenamer
TestLambdaLifter
DoStrings
Grammar.Abs
Grammar.Lex

View file

@ -1,11 +1,9 @@
data Maybe () where {
data Maybe () where
Just : Int -> Maybe ()
Nothing : Maybe ()
};
demoFunc x = case x of {
Just x => x + 24;
Nothing => 0;
};
demoFunc x = case x of
Just x => x + 24
Nothing => 0
main = demoFunc Nothing ;
main = demoFunc Nothing

View file

@ -1,26 +0,0 @@
main = case f (Just 10) of {
Just a => a ;
Nothing => 0 ;
};
f x = bind (fmap (\s . s + 1) x) (\s . pure (s + 10)) ;
data Maybe () where {
Just : Int -> Maybe ()
Nothing : Maybe ()
};
fmap : (Int -> Int) -> Maybe () -> Maybe () ;
fmap f m = case m of {
Just a => pure (f a) ;
Nothing => Nothing ;
};
pure : Int -> Maybe () ;
pure x = Just x;
bind : Maybe () -> (Int -> Maybe ()) -> Maybe () ;
bind x f = case x of {
Just x => f x ;
Nothing => Nothing ;
};

View file

@ -40,4 +40,4 @@ repeatHelp acc x n = case n of {
-- represents minus one :)
minusOne : Int ;
minusOne = 9223372036854775807 + 9223372036854775807 + 1;
minusOne = 9223372036854775807 + 9223372036854775807 + 1;

View file

@ -1,17 +1,16 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where
module LambdaLifter where
import Auxiliary (mapAccumM, snoc)
import Auxiliary (onM, snoc)
import Control.Applicative (Applicative (liftA2))
import Control.Arrow (Arrow (second))
import Control.Monad.State (MonadState (get, put), State,
evalState)
import Data.List (mapAccumL, partition)
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Function (on)
import Data.List (delete, mapAccumL, (\\))
import Prelude hiding (exp)
import TypeChecker.TypeCheckerIr
@ -21,176 +20,190 @@ import TypeChecker.TypeCheckerIr
-- @freeVars@ annotates all the free variables.
-- @abstract@ converts lambdas into let expressions.
-- @collectScs@ moves every non-constant let expression to a top-level function.
--
lambdaLift :: Program -> Program
lambdaLift (Program defs) = Program $ datatypes ++ ll binds
lambdaLift (Program ds) = Program (datatypes ++ binds)
where
ll = map DBind . collectScs . abstract . freeVars . map (\(DBind b) -> b)
(binds, datatypes) = partition isBind defs
isBind = \case
DBind _ -> True
_ -> False
datatypes = flip filter ds $ \case DData _ -> True
_ -> False
binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
-- lambdaLift (Program defs) = trace (printTree abst) $ Program $ datatypes ++ ll binds
-- where
-- abst = abstract frees
-- frees = freeVars [b | DBind b@(Bind (Ident s, _) _ _) <- binds, s == "f"]
--
-- ll = map DBind . collectScs . abstract . freeVars . map (\(DBind b) -> b)
-- (binds, datatypes) = partition isBind defs
-- isBind = \case
-- DBind _ -> True
-- _ -> False
-- | Annotate free variables
freeVars :: [Bind] -> AnnBinds
freeVars binds = [ (n, xs, freeVarsExp (Set.fromList $ map fst xs) e)
freeVars :: [Bind] -> [ABind]
freeVars binds = [ let ae = freeVarsExp [] e
ae' = ae { frees = ae.frees \\ xs }
in ABind n xs ae'
| Bind n xs e <- binds
]
freeVarsExp :: Set Ident -> ExpT -> AnnExpT
freeVarsExp localVars (exp, t) = case exp of
EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t))
| otherwise -> (mempty, (AVar n, t))
freeVarsExp :: Frees -> ExpT -> Ann AExpT
freeVarsExp localVars (ae, t) = case ae of
EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)]
, term = (AVar n, t)
}
| otherwise -> Ann { frees = []
, term = (AVar n, t)
}
EInj n -> (mempty, (AVar n, t))
EInj n -> Ann { frees = [], term = (AInj n, t) }
ELit lit -> (mempty, (ALit lit, t))
ELit lit -> Ann { frees = [], term = (ALit lit, t) }
EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t))
EApp e1 e2 -> Ann { frees = annae1.frees <|| annae2.frees
, term = (AApp annae1 annae2, t)
}
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
(annae1, annae2) = on (,) (freeVarsExp localVars) e1 e2
EAdd e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AAdd e1' e2', t))
EAdd e1 e2 -> Ann { frees = annae1.frees <|| annae2.frees
, term = (AAdd annae1 annae2, t)
}
where
e1' = freeVarsExp localVars e1
e2' = freeVarsExp localVars e2
(annae1, annae2) = on (,) (freeVarsExp localVars) e1 e2
EAbs par e -> (Set.delete par $ freeVarsOf e', (AAbs par e', t))
EAbs x e -> Ann { frees = delete (x,t_x) $ annae.frees
, term = (AAbs x annae, t) }
where
e' = freeVarsExp (Set.insert par localVars) e
annae = freeVarsExp (localVars <| (x,t_x)) e
t_x = case t of TFun t _ -> t
_ -> error "Impossible"
-- Sum free variables present in bind and the expression
ELet (Bind (name, t_bind) parms rhs) e -> (Set.union binders_frees e_free, (ALet new_bind e', t))
-- let f x = x + y in f 5 + z → frees: y, z
ELet bind@(Bind n _ _) e ->
Ann { frees = delete n annae.frees <|| annbind.frees
, term = (ALet annbind annae, t)
}
where
binders_frees = Set.delete name $ freeVarsOf rhs'
e_free = Set.delete name $ freeVarsOf e'
annae = freeVarsExp (localVars <| n) e
annbind = freeVarsBind localVars bind
rhs' = freeVarsExp e_localVars rhs
new_bind = ABind (name, t_bind) parms rhs'
e' = freeVarsExp e_localVars e
e_localVars = Set.insert name localVars
ECase e branches -> (frees, (ACase e' branches', t))
ECase e branches ->
Ann { frees = foldl (<||) annae.frees (map frees annbranches)
, term = (ACase annae annbranches, t)
}
where
frees = foldr (\b s -> Set.union s $ fst b) (freeVarsOf e') branches'
e' = freeVarsExp localVars e
branches' = map (freeVarsBranch localVars) branches
annae = freeVarsExp localVars e
annbranches = map (freeVarsBranch localVars) branches
freeVarsBranch :: Set Ident -> Branch' Type -> (Set Ident, AnnBranch')
freeVarsBranch localVars (Branch (patt, t) exp) = (frees, AnnBranch (patt, t) exp')
freeVarsBind :: Frees -> Bind -> Ann ABind
freeVarsBind localVars (Bind name vars e) =
Ann { frees = annae.frees \\ vars
, term = ABind name vars annae
}
where
frees = freeVarsOf exp' Set.\\ freeVarsOfPattern patt
exp' = freeVarsExp localVars exp
freeVarsOfPattern = Set.fromList . go []
annae = freeVarsExp (localVars <|| vars) e
freeVarsBranch :: Frees -> Branch -> Ann ABranch
freeVarsBranch localVars (Branch pt e) =
Ann { frees = annae.frees \\ varsInPattern
, term = ABranch pt annae
}
where
annae = freeVarsExp localVars e
varsInPattern = go [] pt
where
go acc = \case
PVar n -> snoc n acc
PInj _ ps -> foldl go acc $ map fst ps
go acc (p, t) = case p of
PVar n -> acc <| (n, t)
PInj _ ps -> foldl go acc ps
_ -> []
freeVarsOf :: AnnExpT -> Set Ident
freeVarsOf = fst
-- AST annotated with free variables
type AnnBinds = [(Id, [Id], AnnExpT)]
type AnnExpT = (Set Ident, AnnExpT')
type Frees = [(Ident, Type)]
data ABind = ABind Id [Id] AnnExpT deriving Show
data Ann a = Ann
{ frees :: Frees
, term :: a
} deriving (Show, Eq)
type AnnExpT' = (AnnExp, Type)
data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq)
data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq)
type AnnBranch = (Set Ident, AnnBranch')
data AnnBranch' = AnnBranch (Pattern, Type) AnnExpT
deriving Show
type AExpT = (AExp, Type)
data AnnExp = AVar Ident
data AExp = AVar Ident
| AInj Ident
| ALit Lit
| ALet ABind AnnExpT
| AApp AnnExpT AnnExpT
| AAdd AnnExpT AnnExpT
| AAbs Ident AnnExpT
| ACase AnnExpT [AnnBranch]
deriving Show
| ALet (Ann ABind) (Ann AExpT)
| AApp (Ann AExpT) (Ann AExpT)
| AAdd (Ann AExpT) (Ann AExpT)
| AAbs Ident (Ann AExpT)
| ACase (Ann AExpT) [Ann ABranch]
deriving (Show, Eq)
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
abstract :: AnnBinds -> [Bind]
abstract prog = evalState (mapM go prog) 0
abstract :: [ABind] -> [Bind]
abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0
abstractAnnBind :: Ann ABind -> State Int Bind
abstractAnnBind Ann { term = ABind name vars annae } =
Bind name (vars' <|| vars) <$> abstractAnnExp annae'
where
go :: (Id, [Id], AnnExpT) -> State Int Bind
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
(annae', vars') = go [] annae
where
(rhs', parms1) = flattenLambdasAnn rhs
go acc = \case
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc)
-- | Flatten nested lambdas and collect the parameters
-- @\x.\y.\z. ae → (ae, [x,y,z])@
flattenLambdasAnn :: AnnExpT -> (AnnExpT, [Id])
flattenLambdasAnn ae = go (ae, [])
where
go :: (AnnExpT, [Id]) -> (AnnExpT, [Id])
go ((free, (e, t)), acc)
| AAbs par (free1, e1) <- e
, TFun t_par _ <- t
= go ((Set.delete par free1, e1), snoc (par, t_par) acc)
| otherwise = ((free, (e, t)), acc)
abstractExp :: AnnExpT -> State Int ExpT
abstractExp (free, (exp, typ)) = case exp of
AVar n -> pure (EVar n, typ)
AInj n -> pure (EInj n, typ)
abstractAnnExp :: Ann AExpT -> State Int ExpT
abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
AVar n -> pure (EVar n, typ)
AInj n -> pure (EInj n, typ)
ALit lit -> pure (ELit lit, typ)
AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2)
AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2)
ALet b e -> (, typ) <$> liftA2 ELet (go b) (abstractExp e)
where
go (ABind name parms rhs) = do
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
pure $ Bind name (parms ++ parms1) rhs'
AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2
AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2
skipLambdas :: (AnnExpT -> State Int ExpT) -> AnnExpT -> State Int ExpT
skipLambdas f (free, (ae, t)) = case ae of
AAbs par ae1 -> do
ae1' <- skipLambdas f ae1
pure (EAbs par ae1', t)
_ -> f (free, (ae, t))
ACase e branches -> (, typ) <$> liftA2 ECase (abstractExp e) (mapM abstractBranch branches)
-- Lift lambda into let and bind free variables
AAbs parm e -> do
-- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc
AAbs x annae' -> do
i <- nextNumber
rhs <- abstractExp e
rhs <- abstractAnnExp annae''
let sc_name = Ident ("sc_" ++ show i)
sc = (ELet (Bind (sc_name, typ) vars rhs) (EVar sc_name, typ), typ)
pure $ foldl applyVars sc freeList
pure $ foldl applyFree sc frees
where
freeList = Set.toList free
vars = zip names $ getVars typ
names = snoc parm freeList
applyVars (e, t) name = (EApp (e, t) (EVar name, t_var), t_return)
vars = frees <| (x, t_x) <|| ys
t_x = case typ of TFun t _ -> t
_ -> error "Impossible"
(annae'', ys) = go [] annae'
where
(t_var, t_return) = case t of
TFun t1 t2 -> (t1, t2)
go acc = \case
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
ae -> (ae, acc)
applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type)
applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e')
where
t_e' = case t_e of TFun _ t -> t
_ -> error "Impossible"
abstractBranch :: AnnBranch -> State Int Branch
abstractBranch (_, AnnBranch patt exp) = Branch patt <$> abstractExp exp
ACase annae' bs -> do
bs <- mapM go bs
e <- abstractAnnExp annae'
pure (ECase e bs, typ)
where
go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae
ALet b annae' ->
(, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae')
nextNumber :: State Int Int
nextNumber = do
i <- get
put $ succ i
pure i
-- | Collects supercombinators by lifting non-constant let expressions
collectScs :: [Bind] -> [Bind]
@ -232,34 +245,28 @@ collectScsExp expT@(exp, typ) = case exp of
--
-- > f = let sc x y = rhs in e
--
ELet (Bind name parms rhs) e -> if null parms
then ( rhs_scs ++ et_scs, (ELet bind et', snd et'))
else (bind : rhs_scs ++ et_scs, et')
ELet (Bind name parms rhs) e
| null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et'))
| otherwise -> (bind : rhs_scs ++ et_scs, et')
where
bind = Bind name parms rhs'
(rhs_scs, rhs') = collectScsExp rhs
(et_scs, et') = collectScsExp e
(et_scs, et') = collectScsExp e
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
where (scs, exp') = collectScsExp exp
nextNumber :: State Int Int
nextNumber = do
i <- get
put $ succ i
pure i
-- @\x.\y.\z. e → (e, [x,y,z])@
flattenLambdas :: ExpT -> (ExpT, [Id])
flattenLambdas = go . (, [])
where
go ((e, t), acc) = case e of
EAbs name e1 -> go (e1, snoc (name, t_var) acc)
where t_var = head $ getVars t
_ -> ((e, t), acc)
getVars :: Type -> [Type]
getVars = fst . partitionType
(<|) :: Eq a => [a] -> a -> [a]
xs <| x | elem x xs = xs
| otherwise = snoc x xs
partitionType :: Type -> ([Type], Type)
partitionType = go []
where
go acc t = case t of
TFun t1 t2 -> go (snoc t1 acc) t2
_ -> (acc, t)
(<||) :: Eq a => [a] -> [a] -> [a]
xs <|| ys = foldl (<|) xs ys

117
tests/TestLambdaLifter.hs Normal file
View file

@ -0,0 +1,117 @@
{-# LANGUAGE PatternSynonyms #-}
{-# HLINT ignore "Use camelCase" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-}
module TestLambdaLifter where
import Test.Hspec
import AnnForall (annotateForall)
import Control.Monad ((<=<))
import Control.Monad.Error.Class (liftEither)
import Control.Monad.Extra (eitherM)
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import LambdaLifter
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck)
import TypeChecker.TypeCheckerIr
test = hspec testLambdaLifter
testLambdaLifter = describe "Test Lambda Lifter" $ do
undefined
-- frees_exp1
-- frees_exp1 = specify "Free variables 1" $
-- freeVarsExp [] (EAbs "x" (EVar "x", TVar' "a"), TVar' "a")
-- `shouldBe` answer
-- where
-- answer = Ann { frees = []
-- , term = (AAbs (Ident "x") (Ann { frees = [Ident "x"]
-- , term = (AVar (Ident "x"),TVar (MkTVar (Ident "a")))
-- }
-- ),TVar (MkTVar (Ident "a")))
-- }
abs_1 = undefined
where
input = unlines [ "data List (a) where"
, " Nil : List (a)"
, " Cons : a -> List (a) -> List (a)"
, "map : (a -> b) -> List (a) -> List (b)"
, "add : Int -> Int -> Int"
, "f : List (Int)"
, "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))"
]
runPrintFree = print $ freeVarsExp [] (EAbs "x" (EVar "x", TVar' "a"), TVar' "a")
runAbstract = either putStrLn (putStrLn . printTree) (runAbs s2)
where
s = unlines [ "add : Int -> Int -> Int"
, "f : Int -> Int -> Int"
, "f x y = add x y"
, "f = \\x. (\\y. add x y)"
]
s2 = unlines [ "data List (a) where"
, " Nil : List (a)"
, " Cons : a -> List (a) -> List (a)"
, "map : (a -> b) -> List (a) -> List (b)"
, "add : Int -> Int -> Int"
, "f : List (Int)"
, "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))"
]
runCollect = either putStrLn (putStrLn . printTree) (run s)
where
s = unlines [ "data List (a) where"
, " Nil : List (a)"
, " Cons : a -> List (a) -> List (a)"
, "add : Int -> Int -> Int"
, "map : (a -> b) -> List (a) -> List (b)"
, "map f xs = case xs of"
, " Nil => Nil"
, " Cons x xs => Cons (f x) (map f xs)"
, "f : List (Int)"
, "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))"
]
run = fmap collectScs . runAbs
runAbs s = do
Program ds <- run' s
pure $ (abstract . freeVars) [b | DBind b <- ds]
run' = fmap removeForall
. reportTEVar
<=< typecheck
<=< run''
run'' s = do
p <- (pProgram . resolveLayout True . myLexer) s
reportForall Bi p
(rename <=< annotateForall) p