diff --git a/language.cabal b/language.cabal index 6ae9e12..af7178c 100644 --- a/language.cabal +++ b/language.cabal @@ -83,6 +83,8 @@ Test-suite language-testsuite TestAnnForall TestReportForall TestRenamer + TestLambdaLifter + DoStrings Grammar.Abs Grammar.Lex diff --git a/sample-programs/example-programs/ex4.crf b/sample-programs/example-programs/ex4.crf index a64adb5..9f412c6 100644 --- a/sample-programs/example-programs/ex4.crf +++ b/sample-programs/example-programs/ex4.crf @@ -1,11 +1,9 @@ -data Maybe () where { +data Maybe () where Just : Int -> Maybe () Nothing : Maybe () -}; -demoFunc x = case x of { - Just x => x + 24; - Nothing => 0; -}; +demoFunc x = case x of + Just x => x + 24 + Nothing => 0 -main = demoFunc Nothing ; \ No newline at end of file +main = demoFunc Nothing diff --git a/sample-programs/example-programs/ex5.crf b/sample-programs/example-programs/ex5.crf index b9457ed..e69de29 100644 --- a/sample-programs/example-programs/ex5.crf +++ b/sample-programs/example-programs/ex5.crf @@ -1,26 +0,0 @@ -main = case f (Just 10) of { - Just a => a ; - Nothing => 0 ; -}; - -f x = bind (fmap (\s . s + 1) x) (\s . pure (s + 10)) ; - -data Maybe () where { - Just : Int -> Maybe () - Nothing : Maybe () -}; - -fmap : (Int -> Int) -> Maybe () -> Maybe () ; -fmap f m = case m of { - Just a => pure (f a) ; - Nothing => Nothing ; -}; - -pure : Int -> Maybe () ; -pure x = Just x; - -bind : Maybe () -> (Int -> Maybe ()) -> Maybe () ; -bind x f = case x of { - Just x => f x ; - Nothing => Nothing ; -}; \ No newline at end of file diff --git a/sample-programs/example-programs/ex6.crf b/sample-programs/example-programs/ex6.crf index 41894a0..ebf8c6b 100644 --- a/sample-programs/example-programs/ex6.crf +++ b/sample-programs/example-programs/ex6.crf @@ -40,4 +40,4 @@ repeatHelp acc x n = case n of { -- represents minus one :) minusOne : Int ; -minusOne = 9223372036854775807 + 9223372036854775807 + 1; \ No newline at end of file +minusOne = 9223372036854775807 + 9223372036854775807 + 1; diff --git a/src/LambdaLifter.hs b/src/LambdaLifter.hs index dcd715b..83d3466 100644 --- a/src/LambdaLifter.hs +++ b/src/LambdaLifter.hs @@ -1,17 +1,16 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedStrings #-} -module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where +module LambdaLifter where -import Auxiliary (mapAccumM, snoc) +import Auxiliary (onM, snoc) import Control.Applicative (Applicative (liftA2)) -import Control.Arrow (Arrow (second)) import Control.Monad.State (MonadState (get, put), State, evalState) -import Data.List (mapAccumL, partition) -import Data.Set (Set) -import qualified Data.Set as Set +import Data.Function (on) +import Data.List (delete, mapAccumL, (\\)) import Prelude hiding (exp) import TypeChecker.TypeCheckerIr @@ -21,176 +20,190 @@ import TypeChecker.TypeCheckerIr -- @freeVars@ annotates all the free variables. -- @abstract@ converts lambdas into let expressions. -- @collectScs@ moves every non-constant let expression to a top-level function. +-- lambdaLift :: Program -> Program -lambdaLift (Program defs) = Program $ datatypes ++ ll binds +lambdaLift (Program ds) = Program (datatypes ++ binds) where - ll = map DBind . collectScs . abstract . freeVars . map (\(DBind b) -> b) - (binds, datatypes) = partition isBind defs - isBind = \case - DBind _ -> True - _ -> False + datatypes = flip filter ds $ \case DData _ -> True + _ -> False + binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds] + +-- lambdaLift (Program defs) = trace (printTree abst) $ Program $ datatypes ++ ll binds +-- where +-- abst = abstract frees +-- frees = freeVars [b | DBind b@(Bind (Ident s, _) _ _) <- binds, s == "f"] +-- +-- ll = map DBind . collectScs . abstract . freeVars . map (\(DBind b) -> b) +-- (binds, datatypes) = partition isBind defs +-- isBind = \case +-- DBind _ -> True +-- _ -> False -- | Annotate free variables -freeVars :: [Bind] -> AnnBinds -freeVars binds = [ (n, xs, freeVarsExp (Set.fromList $ map fst xs) e) +freeVars :: [Bind] -> [ABind] +freeVars binds = [ let ae = freeVarsExp [] e + ae' = ae { frees = ae.frees \\ xs } + in ABind n xs ae' | Bind n xs e <- binds ] -freeVarsExp :: Set Ident -> ExpT -> AnnExpT -freeVarsExp localVars (exp, t) = case exp of - EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t)) - | otherwise -> (mempty, (AVar n, t)) +freeVarsExp :: Frees -> ExpT -> Ann AExpT +freeVarsExp localVars (ae, t) = case ae of + EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)] + , term = (AVar n, t) + } + | otherwise -> Ann { frees = [] + , term = (AVar n, t) + } - EInj n -> (mempty, (AVar n, t)) + EInj n -> Ann { frees = [], term = (AInj n, t) } - ELit lit -> (mempty, (ALit lit, t)) + ELit lit -> Ann { frees = [], term = (ALit lit, t) } - EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t)) + EApp e1 e2 -> Ann { frees = annae1.frees <|| annae2.frees + , term = (AApp annae1 annae2, t) + } where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 + (annae1, annae2) = on (,) (freeVarsExp localVars) e1 e2 - EAdd e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AAdd e1' e2', t)) + EAdd e1 e2 -> Ann { frees = annae1.frees <|| annae2.frees + , term = (AAdd annae1 annae2, t) + } where - e1' = freeVarsExp localVars e1 - e2' = freeVarsExp localVars e2 + (annae1, annae2) = on (,) (freeVarsExp localVars) e1 e2 - EAbs par e -> (Set.delete par $ freeVarsOf e', (AAbs par e', t)) + + EAbs x e -> Ann { frees = delete (x,t_x) $ annae.frees + , term = (AAbs x annae, t) } where - e' = freeVarsExp (Set.insert par localVars) e + annae = freeVarsExp (localVars <| (x,t_x)) e + t_x = case t of TFun t _ -> t + _ -> error "Impossible" -- Sum free variables present in bind and the expression - ELet (Bind (name, t_bind) parms rhs) e -> (Set.union binders_frees e_free, (ALet new_bind e', t)) + -- let f x = x + y in f 5 + z → frees: y, z + ELet bind@(Bind n _ _) e -> + Ann { frees = delete n annae.frees <|| annbind.frees + , term = (ALet annbind annae, t) + } where - binders_frees = Set.delete name $ freeVarsOf rhs' - e_free = Set.delete name $ freeVarsOf e' + annae = freeVarsExp (localVars <| n) e + annbind = freeVarsBind localVars bind - rhs' = freeVarsExp e_localVars rhs - new_bind = ABind (name, t_bind) parms rhs' - - e' = freeVarsExp e_localVars e - e_localVars = Set.insert name localVars - - ECase e branches -> (frees, (ACase e' branches', t)) + ECase e branches -> + Ann { frees = foldl (<||) annae.frees (map frees annbranches) + , term = (ACase annae annbranches, t) + } where - frees = foldr (\b s -> Set.union s $ fst b) (freeVarsOf e') branches' - e' = freeVarsExp localVars e - branches' = map (freeVarsBranch localVars) branches + annae = freeVarsExp localVars e + annbranches = map (freeVarsBranch localVars) branches -freeVarsBranch :: Set Ident -> Branch' Type -> (Set Ident, AnnBranch') -freeVarsBranch localVars (Branch (patt, t) exp) = (frees, AnnBranch (patt, t) exp') +freeVarsBind :: Frees -> Bind -> Ann ABind +freeVarsBind localVars (Bind name vars e) = + Ann { frees = annae.frees \\ vars + , term = ABind name vars annae + } where - frees = freeVarsOf exp' Set.\\ freeVarsOfPattern patt - exp' = freeVarsExp localVars exp - freeVarsOfPattern = Set.fromList . go [] + annae = freeVarsExp (localVars <|| vars) e + + +freeVarsBranch :: Frees -> Branch -> Ann ABranch +freeVarsBranch localVars (Branch pt e) = + Ann { frees = annae.frees \\ varsInPattern + , term = ABranch pt annae + } + where + annae = freeVarsExp localVars e + varsInPattern = go [] pt where - go acc = \case - PVar n -> snoc n acc - PInj _ ps -> foldl go acc $ map fst ps + go acc (p, t) = case p of + PVar n -> acc <| (n, t) + PInj _ ps -> foldl go acc ps + _ -> [] - -freeVarsOf :: AnnExpT -> Set Ident -freeVarsOf = fst - -- AST annotated with free variables -type AnnBinds = [(Id, [Id], AnnExpT)] -type AnnExpT = (Set Ident, AnnExpT') +type Frees = [(Ident, Type)] -data ABind = ABind Id [Id] AnnExpT deriving Show +data Ann a = Ann + { frees :: Frees + , term :: a + } deriving (Show, Eq) -type AnnExpT' = (AnnExp, Type) +data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq) +data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq) -type AnnBranch = (Set Ident, AnnBranch') -data AnnBranch' = AnnBranch (Pattern, Type) AnnExpT - deriving Show +type AExpT = (AExp, Type) -data AnnExp = AVar Ident +data AExp = AVar Ident | AInj Ident | ALit Lit - | ALet ABind AnnExpT - | AApp AnnExpT AnnExpT - | AAdd AnnExpT AnnExpT - | AAbs Ident AnnExpT - | ACase AnnExpT [AnnBranch] - deriving Show + | ALet (Ann ABind) (Ann AExpT) + | AApp (Ann AExpT) (Ann AExpT) + | AAdd (Ann AExpT) (Ann AExpT) + | AAbs Ident (Ann AExpT) + | ACase (Ann AExpT) [Ann ABranch] + deriving (Show, Eq) --- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@. --- Free variables are @v₁ v₂ .. vₙ@ are bound. -abstract :: AnnBinds -> [Bind] -abstract prog = evalState (mapM go prog) 0 +abstract :: [ABind] -> [Bind] +abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0 + +abstractAnnBind :: Ann ABind -> State Int Bind +abstractAnnBind Ann { term = ABind name vars annae } = + Bind name (vars' <|| vars) <$> abstractAnnExp annae' where - go :: (Id, [Id], AnnExpT) -> State Int Bind - go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs' + (annae', vars') = go [] annae where - (rhs', parms1) = flattenLambdasAnn rhs + go acc = \case + Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae + ae -> (ae, acc) - --- | Flatten nested lambdas and collect the parameters --- @\x.\y.\z. ae → (ae, [x,y,z])@ -flattenLambdasAnn :: AnnExpT -> (AnnExpT, [Id]) -flattenLambdasAnn ae = go (ae, []) - where - go :: (AnnExpT, [Id]) -> (AnnExpT, [Id]) - go ((free, (e, t)), acc) - | AAbs par (free1, e1) <- e - , TFun t_par _ <- t - = go ((Set.delete par free1, e1), snoc (par, t_par) acc) - | otherwise = ((free, (e, t)), acc) - -abstractExp :: AnnExpT -> State Int ExpT -abstractExp (free, (exp, typ)) = case exp of - AVar n -> pure (EVar n, typ) - AInj n -> pure (EInj n, typ) +abstractAnnExp :: Ann AExpT -> State Int ExpT +abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of + AVar n -> pure (EVar n, typ) + AInj n -> pure (EInj n, typ) ALit lit -> pure (ELit lit, typ) - AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2) - AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2) - ALet b e -> (, typ) <$> liftA2 ELet (go b) (abstractExp e) - where - go (ABind name parms rhs) = do - (rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs - pure $ Bind name (parms ++ parms1) rhs' + AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2 + AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2 - skipLambdas :: (AnnExpT -> State Int ExpT) -> AnnExpT -> State Int ExpT - skipLambdas f (free, (ae, t)) = case ae of - AAbs par ae1 -> do - ae1' <- skipLambdas f ae1 - pure (EAbs par ae1', t) - _ -> f (free, (ae, t)) - - ACase e branches -> (, typ) <$> liftA2 ECase (abstractExp e) (mapM abstractBranch branches) - - - -- Lift lambda into let and bind free variables - AAbs parm e -> do + -- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc + AAbs x annae' -> do i <- nextNumber - rhs <- abstractExp e - + rhs <- abstractAnnExp annae'' let sc_name = Ident ("sc_" ++ show i) sc = (ELet (Bind (sc_name, typ) vars rhs) (EVar sc_name, typ), typ) - pure $ foldl applyVars sc freeList + pure $ foldl applyFree sc frees where - freeList = Set.toList free - vars = zip names $ getVars typ - names = snoc parm freeList - applyVars (e, t) name = (EApp (e, t) (EVar name, t_var), t_return) + vars = frees <| (x, t_x) <|| ys + t_x = case typ of TFun t _ -> t + _ -> error "Impossible" + + (annae'', ys) = go [] annae' where - (t_var, t_return) = case t of - TFun t1 t2 -> (t1, t2) + go acc = \case + Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae + ae -> (ae, acc) + applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type) + applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e') + where + t_e' = case t_e of TFun _ t -> t + _ -> error "Impossible" -abstractBranch :: AnnBranch -> State Int Branch -abstractBranch (_, AnnBranch patt exp) = Branch patt <$> abstractExp exp + ACase annae' bs -> do + bs <- mapM go bs + e <- abstractAnnExp annae' + pure (ECase e bs, typ) + where + go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae + + ALet b annae' -> + (, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae') -nextNumber :: State Int Int -nextNumber = do - i <- get - put $ succ i - pure i -- | Collects supercombinators by lifting non-constant let expressions collectScs :: [Bind] -> [Bind] @@ -232,34 +245,28 @@ collectScsExp expT@(exp, typ) = case exp of -- -- > f = let sc x y = rhs in e -- - ELet (Bind name parms rhs) e -> if null parms - then ( rhs_scs ++ et_scs, (ELet bind et', snd et')) - else (bind : rhs_scs ++ et_scs, et') + ELet (Bind name parms rhs) e + | null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et')) + | otherwise -> (bind : rhs_scs ++ et_scs, et') where bind = Bind name parms rhs' (rhs_scs, rhs') = collectScsExp rhs - (et_scs, et') = collectScsExp e + (et_scs, et') = collectScsExp e collectScsBranch (Branch patt exp) = (scs, Branch patt exp') where (scs, exp') = collectScsExp exp +nextNumber :: State Int Int +nextNumber = do + i <- get + put $ succ i + pure i --- @\x.\y.\z. e → (e, [x,y,z])@ -flattenLambdas :: ExpT -> (ExpT, [Id]) -flattenLambdas = go . (, []) - where - go ((e, t), acc) = case e of - EAbs name e1 -> go (e1, snoc (name, t_var) acc) - where t_var = head $ getVars t - _ -> ((e, t), acc) -getVars :: Type -> [Type] -getVars = fst . partitionType +(<|) :: Eq a => [a] -> a -> [a] +xs <| x | elem x xs = xs + | otherwise = snoc x xs -partitionType :: Type -> ([Type], Type) -partitionType = go [] - where - go acc t = case t of - TFun t1 t2 -> go (snoc t1 acc) t2 - _ -> (acc, t) +(<||) :: Eq a => [a] -> [a] -> [a] +xs <|| ys = foldl (<|) xs ys diff --git a/tests/TestLambdaLifter.hs b/tests/TestLambdaLifter.hs new file mode 100644 index 0000000..79c78b2 --- /dev/null +++ b/tests/TestLambdaLifter.hs @@ -0,0 +1,117 @@ +{-# LANGUAGE PatternSynonyms #-} +{-# HLINT ignore "Use camelCase" #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QualifiedDo #-} + +module TestLambdaLifter where + +import Test.Hspec + +import AnnForall (annotateForall) +import Control.Monad ((<=<)) +import Control.Monad.Error.Class (liftEither) +import Control.Monad.Extra (eitherM) +import Grammar.ErrM (Err, pattern Bad, pattern Ok) +import Grammar.Layout (resolveLayout) +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) +import LambdaLifter +import Renamer.Renamer (rename) +import ReportForall (reportForall) +import TypeChecker.RemoveForall (removeForall) +import TypeChecker.ReportTEVar (reportTEVar) +import TypeChecker.TypeChecker (TypeChecker (Bi)) +import TypeChecker.TypeCheckerBidir (typecheck) +import TypeChecker.TypeCheckerIr + + +test = hspec testLambdaLifter + +testLambdaLifter = describe "Test Lambda Lifter" $ do + undefined +-- frees_exp1 + +-- frees_exp1 = specify "Free variables 1" $ +-- freeVarsExp [] (EAbs "x" (EVar "x", TVar' "a"), TVar' "a") +-- `shouldBe` answer +-- where +-- answer = Ann { frees = [] +-- , term = (AAbs (Ident "x") (Ann { frees = [Ident "x"] +-- , term = (AVar (Ident "x"),TVar (MkTVar (Ident "a"))) +-- } +-- ),TVar (MkTVar (Ident "a"))) +-- } + + +abs_1 = undefined + where + input = unlines [ "data List (a) where" + , " Nil : List (a)" + , " Cons : a -> List (a) -> List (a)" + , "map : (a -> b) -> List (a) -> List (b)" + , "add : Int -> Int -> Int" + + , "f : List (Int)" + , "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))" + ] + + + +runPrintFree = print $ freeVarsExp [] (EAbs "x" (EVar "x", TVar' "a"), TVar' "a") + +runAbstract = either putStrLn (putStrLn . printTree) (runAbs s2) + where + s = unlines [ "add : Int -> Int -> Int" + , "f : Int -> Int -> Int" + , "f x y = add x y" + , "f = \\x. (\\y. add x y)" + ] + + s2 = unlines [ "data List (a) where" + , " Nil : List (a)" + , " Cons : a -> List (a) -> List (a)" + , "map : (a -> b) -> List (a) -> List (b)" + , "add : Int -> Int -> Int" + + , "f : List (Int)" + , "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))" + ] + + +runCollect = either putStrLn (putStrLn . printTree) (run s) + where + s = unlines [ "data List (a) where" + , " Nil : List (a)" + , " Cons : a -> List (a) -> List (a)" + , "add : Int -> Int -> Int" + , "map : (a -> b) -> List (a) -> List (b)" + , "map f xs = case xs of" + , " Nil => Nil" + , " Cons x xs => Cons (f x) (map f xs)" + + , "f : List (Int)" + , "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))" + ] + + +run = fmap collectScs . runAbs + +runAbs s = do + Program ds <- run' s + pure $ (abstract . freeVars) [b | DBind b <- ds] + + +run' = fmap removeForall + . reportTEVar + <=< typecheck + <=< run'' + +run'' s = do + p <- (pProgram . resolveLayout True . myLexer) s + reportForall Bi p + (rename <=< annotateForall) p + + + +