Fix lambda lifter
This commit is contained in:
parent
df1a5de04a
commit
619242ccaf
6 changed files with 280 additions and 182 deletions
|
|
@ -83,6 +83,8 @@ Test-suite language-testsuite
|
||||||
TestAnnForall
|
TestAnnForall
|
||||||
TestReportForall
|
TestReportForall
|
||||||
TestRenamer
|
TestRenamer
|
||||||
|
TestLambdaLifter
|
||||||
|
DoStrings
|
||||||
|
|
||||||
Grammar.Abs
|
Grammar.Abs
|
||||||
Grammar.Lex
|
Grammar.Lex
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,9 @@
|
||||||
data Maybe () where {
|
data Maybe () where
|
||||||
Just : Int -> Maybe ()
|
Just : Int -> Maybe ()
|
||||||
Nothing : Maybe ()
|
Nothing : Maybe ()
|
||||||
};
|
|
||||||
|
|
||||||
demoFunc x = case x of {
|
demoFunc x = case x of
|
||||||
Just x => x + 24;
|
Just x => x + 24
|
||||||
Nothing => 0;
|
Nothing => 0
|
||||||
};
|
|
||||||
|
|
||||||
main = demoFunc Nothing ;
|
main = demoFunc Nothing
|
||||||
|
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
main = case f (Just 10) of {
|
|
||||||
Just a => a ;
|
|
||||||
Nothing => 0 ;
|
|
||||||
};
|
|
||||||
|
|
||||||
f x = bind (fmap (\s . s + 1) x) (\s . pure (s + 10)) ;
|
|
||||||
|
|
||||||
data Maybe () where {
|
|
||||||
Just : Int -> Maybe ()
|
|
||||||
Nothing : Maybe ()
|
|
||||||
};
|
|
||||||
|
|
||||||
fmap : (Int -> Int) -> Maybe () -> Maybe () ;
|
|
||||||
fmap f m = case m of {
|
|
||||||
Just a => pure (f a) ;
|
|
||||||
Nothing => Nothing ;
|
|
||||||
};
|
|
||||||
|
|
||||||
pure : Int -> Maybe () ;
|
|
||||||
pure x = Just x;
|
|
||||||
|
|
||||||
bind : Maybe () -> (Int -> Maybe ()) -> Maybe () ;
|
|
||||||
bind x f = case x of {
|
|
||||||
Just x => f x ;
|
|
||||||
Nothing => Nothing ;
|
|
||||||
};
|
|
||||||
|
|
@ -1,17 +1,16 @@
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedRecordDot #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
|
||||||
module LambdaLifter (lambdaLift, freeVars, abstract, collectScs) where
|
module LambdaLifter where
|
||||||
|
|
||||||
import Auxiliary (mapAccumM, snoc)
|
import Auxiliary (onM, snoc)
|
||||||
import Control.Applicative (Applicative (liftA2))
|
import Control.Applicative (Applicative (liftA2))
|
||||||
import Control.Arrow (Arrow (second))
|
|
||||||
import Control.Monad.State (MonadState (get, put), State,
|
import Control.Monad.State (MonadState (get, put), State,
|
||||||
evalState)
|
evalState)
|
||||||
import Data.List (mapAccumL, partition)
|
import Data.Function (on)
|
||||||
import Data.Set (Set)
|
import Data.List (delete, mapAccumL, (\\))
|
||||||
import qualified Data.Set as Set
|
|
||||||
import Prelude hiding (exp)
|
import Prelude hiding (exp)
|
||||||
import TypeChecker.TypeCheckerIr
|
import TypeChecker.TypeCheckerIr
|
||||||
|
|
||||||
|
|
@ -21,176 +20,190 @@ import TypeChecker.TypeCheckerIr
|
||||||
-- @freeVars@ annotates all the free variables.
|
-- @freeVars@ annotates all the free variables.
|
||||||
-- @abstract@ converts lambdas into let expressions.
|
-- @abstract@ converts lambdas into let expressions.
|
||||||
-- @collectScs@ moves every non-constant let expression to a top-level function.
|
-- @collectScs@ moves every non-constant let expression to a top-level function.
|
||||||
|
--
|
||||||
lambdaLift :: Program -> Program
|
lambdaLift :: Program -> Program
|
||||||
lambdaLift (Program defs) = Program $ datatypes ++ ll binds
|
lambdaLift (Program ds) = Program (datatypes ++ binds)
|
||||||
where
|
where
|
||||||
ll = map DBind . collectScs . abstract . freeVars . map (\(DBind b) -> b)
|
datatypes = flip filter ds $ \case DData _ -> True
|
||||||
(binds, datatypes) = partition isBind defs
|
_ -> False
|
||||||
isBind = \case
|
binds = map DBind $ (collectScs . abstract . freeVars) [b | DBind b <- ds]
|
||||||
DBind _ -> True
|
|
||||||
_ -> False
|
-- lambdaLift (Program defs) = trace (printTree abst) $ Program $ datatypes ++ ll binds
|
||||||
|
-- where
|
||||||
|
-- abst = abstract frees
|
||||||
|
-- frees = freeVars [b | DBind b@(Bind (Ident s, _) _ _) <- binds, s == "f"]
|
||||||
|
--
|
||||||
|
-- ll = map DBind . collectScs . abstract . freeVars . map (\(DBind b) -> b)
|
||||||
|
-- (binds, datatypes) = partition isBind defs
|
||||||
|
-- isBind = \case
|
||||||
|
-- DBind _ -> True
|
||||||
|
-- _ -> False
|
||||||
|
|
||||||
-- | Annotate free variables
|
-- | Annotate free variables
|
||||||
freeVars :: [Bind] -> AnnBinds
|
freeVars :: [Bind] -> [ABind]
|
||||||
freeVars binds = [ (n, xs, freeVarsExp (Set.fromList $ map fst xs) e)
|
freeVars binds = [ let ae = freeVarsExp [] e
|
||||||
|
ae' = ae { frees = ae.frees \\ xs }
|
||||||
|
in ABind n xs ae'
|
||||||
| Bind n xs e <- binds
|
| Bind n xs e <- binds
|
||||||
]
|
]
|
||||||
|
|
||||||
freeVarsExp :: Set Ident -> ExpT -> AnnExpT
|
freeVarsExp :: Frees -> ExpT -> Ann AExpT
|
||||||
freeVarsExp localVars (exp, t) = case exp of
|
freeVarsExp localVars (ae, t) = case ae of
|
||||||
EVar n | Set.member n localVars -> (Set.singleton n, (AVar n, t))
|
EVar n | elem (n,t) localVars -> Ann { frees = [(n, t)]
|
||||||
| otherwise -> (mempty, (AVar n, t))
|
, term = (AVar n, t)
|
||||||
|
}
|
||||||
|
| otherwise -> Ann { frees = []
|
||||||
|
, term = (AVar n, t)
|
||||||
|
}
|
||||||
|
|
||||||
EInj n -> (mempty, (AVar n, t))
|
EInj n -> Ann { frees = [], term = (AInj n, t) }
|
||||||
|
|
||||||
ELit lit -> (mempty, (ALit lit, t))
|
ELit lit -> Ann { frees = [], term = (ALit lit, t) }
|
||||||
|
|
||||||
EApp e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AApp e1' e2', t))
|
EApp e1 e2 -> Ann { frees = annae1.frees <|| annae2.frees
|
||||||
|
, term = (AApp annae1 annae2, t)
|
||||||
|
}
|
||||||
where
|
where
|
||||||
e1' = freeVarsExp localVars e1
|
(annae1, annae2) = on (,) (freeVarsExp localVars) e1 e2
|
||||||
e2' = freeVarsExp localVars e2
|
|
||||||
|
|
||||||
EAdd e1 e2 -> (Set.union (freeVarsOf e1') (freeVarsOf e2'), (AAdd e1' e2', t))
|
EAdd e1 e2 -> Ann { frees = annae1.frees <|| annae2.frees
|
||||||
|
, term = (AAdd annae1 annae2, t)
|
||||||
|
}
|
||||||
where
|
where
|
||||||
e1' = freeVarsExp localVars e1
|
(annae1, annae2) = on (,) (freeVarsExp localVars) e1 e2
|
||||||
e2' = freeVarsExp localVars e2
|
|
||||||
|
|
||||||
EAbs par e -> (Set.delete par $ freeVarsOf e', (AAbs par e', t))
|
|
||||||
|
EAbs x e -> Ann { frees = delete (x,t_x) $ annae.frees
|
||||||
|
, term = (AAbs x annae, t) }
|
||||||
where
|
where
|
||||||
e' = freeVarsExp (Set.insert par localVars) e
|
annae = freeVarsExp (localVars <| (x,t_x)) e
|
||||||
|
t_x = case t of TFun t _ -> t
|
||||||
|
_ -> error "Impossible"
|
||||||
|
|
||||||
-- Sum free variables present in bind and the expression
|
-- Sum free variables present in bind and the expression
|
||||||
ELet (Bind (name, t_bind) parms rhs) e -> (Set.union binders_frees e_free, (ALet new_bind e', t))
|
-- let f x = x + y in f 5 + z → frees: y, z
|
||||||
|
ELet bind@(Bind n _ _) e ->
|
||||||
|
Ann { frees = delete n annae.frees <|| annbind.frees
|
||||||
|
, term = (ALet annbind annae, t)
|
||||||
|
}
|
||||||
where
|
where
|
||||||
binders_frees = Set.delete name $ freeVarsOf rhs'
|
annae = freeVarsExp (localVars <| n) e
|
||||||
e_free = Set.delete name $ freeVarsOf e'
|
annbind = freeVarsBind localVars bind
|
||||||
|
|
||||||
rhs' = freeVarsExp e_localVars rhs
|
ECase e branches ->
|
||||||
new_bind = ABind (name, t_bind) parms rhs'
|
Ann { frees = foldl (<||) annae.frees (map frees annbranches)
|
||||||
|
, term = (ACase annae annbranches, t)
|
||||||
e' = freeVarsExp e_localVars e
|
}
|
||||||
e_localVars = Set.insert name localVars
|
|
||||||
|
|
||||||
ECase e branches -> (frees, (ACase e' branches', t))
|
|
||||||
where
|
where
|
||||||
frees = foldr (\b s -> Set.union s $ fst b) (freeVarsOf e') branches'
|
annae = freeVarsExp localVars e
|
||||||
e' = freeVarsExp localVars e
|
annbranches = map (freeVarsBranch localVars) branches
|
||||||
branches' = map (freeVarsBranch localVars) branches
|
|
||||||
|
|
||||||
|
|
||||||
freeVarsBranch :: Set Ident -> Branch' Type -> (Set Ident, AnnBranch')
|
freeVarsBind :: Frees -> Bind -> Ann ABind
|
||||||
freeVarsBranch localVars (Branch (patt, t) exp) = (frees, AnnBranch (patt, t) exp')
|
freeVarsBind localVars (Bind name vars e) =
|
||||||
|
Ann { frees = annae.frees \\ vars
|
||||||
|
, term = ABind name vars annae
|
||||||
|
}
|
||||||
where
|
where
|
||||||
frees = freeVarsOf exp' Set.\\ freeVarsOfPattern patt
|
annae = freeVarsExp (localVars <|| vars) e
|
||||||
exp' = freeVarsExp localVars exp
|
|
||||||
freeVarsOfPattern = Set.fromList . go []
|
|
||||||
|
freeVarsBranch :: Frees -> Branch -> Ann ABranch
|
||||||
|
freeVarsBranch localVars (Branch pt e) =
|
||||||
|
Ann { frees = annae.frees \\ varsInPattern
|
||||||
|
, term = ABranch pt annae
|
||||||
|
}
|
||||||
|
where
|
||||||
|
annae = freeVarsExp localVars e
|
||||||
|
varsInPattern = go [] pt
|
||||||
where
|
where
|
||||||
go acc = \case
|
go acc (p, t) = case p of
|
||||||
PVar n -> snoc n acc
|
PVar n -> acc <| (n, t)
|
||||||
PInj _ ps -> foldl go acc $ map fst ps
|
PInj _ ps -> foldl go acc ps
|
||||||
|
_ -> []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
freeVarsOf :: AnnExpT -> Set Ident
|
|
||||||
freeVarsOf = fst
|
|
||||||
|
|
||||||
-- AST annotated with free variables
|
-- AST annotated with free variables
|
||||||
type AnnBinds = [(Id, [Id], AnnExpT)]
|
|
||||||
|
|
||||||
type AnnExpT = (Set Ident, AnnExpT')
|
type Frees = [(Ident, Type)]
|
||||||
|
|
||||||
data ABind = ABind Id [Id] AnnExpT deriving Show
|
data Ann a = Ann
|
||||||
|
{ frees :: Frees
|
||||||
|
, term :: a
|
||||||
|
} deriving (Show, Eq)
|
||||||
|
|
||||||
type AnnExpT' = (AnnExp, Type)
|
data ABind = ABind Id [Id] (Ann AExpT) deriving (Show, Eq)
|
||||||
|
data ABranch = ABranch (Pattern, Type) (Ann AExpT) deriving (Show, Eq)
|
||||||
|
|
||||||
type AnnBranch = (Set Ident, AnnBranch')
|
type AExpT = (AExp, Type)
|
||||||
data AnnBranch' = AnnBranch (Pattern, Type) AnnExpT
|
|
||||||
deriving Show
|
|
||||||
|
|
||||||
data AnnExp = AVar Ident
|
data AExp = AVar Ident
|
||||||
| AInj Ident
|
| AInj Ident
|
||||||
| ALit Lit
|
| ALit Lit
|
||||||
| ALet ABind AnnExpT
|
| ALet (Ann ABind) (Ann AExpT)
|
||||||
| AApp AnnExpT AnnExpT
|
| AApp (Ann AExpT) (Ann AExpT)
|
||||||
| AAdd AnnExpT AnnExpT
|
| AAdd (Ann AExpT) (Ann AExpT)
|
||||||
| AAbs Ident AnnExpT
|
| AAbs Ident (Ann AExpT)
|
||||||
| ACase AnnExpT [AnnBranch]
|
| ACase (Ann AExpT) [Ann ABranch]
|
||||||
deriving Show
|
deriving (Show, Eq)
|
||||||
|
|
||||||
-- | Lift lambdas to let expression of the form @let sc = \v₁ x₁ -> e₁@.
|
abstract :: [ABind] -> [Bind]
|
||||||
-- Free variables are @v₁ v₂ .. vₙ@ are bound.
|
abstract bs = evalState (mapM (abstractAnnBind . Ann []) bs) 0
|
||||||
abstract :: AnnBinds -> [Bind]
|
|
||||||
abstract prog = evalState (mapM go prog) 0
|
abstractAnnBind :: Ann ABind -> State Int Bind
|
||||||
|
abstractAnnBind Ann { term = ABind name vars annae } =
|
||||||
|
Bind name (vars' <|| vars) <$> abstractAnnExp annae'
|
||||||
where
|
where
|
||||||
go :: (Id, [Id], AnnExpT) -> State Int Bind
|
(annae', vars') = go [] annae
|
||||||
go (name, parms, rhs) = Bind name (parms ++ parms1) <$> abstractExp rhs'
|
|
||||||
where
|
where
|
||||||
(rhs', parms1) = flattenLambdasAnn rhs
|
go acc = \case
|
||||||
|
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
|
||||||
|
ae -> (ae, acc)
|
||||||
|
|
||||||
|
abstractAnnExp :: Ann AExpT -> State Int ExpT
|
||||||
-- | Flatten nested lambdas and collect the parameters
|
abstractAnnExp Ann {frees, term = (annae, typ) } = case annae of
|
||||||
-- @\x.\y.\z. ae → (ae, [x,y,z])@
|
AVar n -> pure (EVar n, typ)
|
||||||
flattenLambdasAnn :: AnnExpT -> (AnnExpT, [Id])
|
AInj n -> pure (EInj n, typ)
|
||||||
flattenLambdasAnn ae = go (ae, [])
|
|
||||||
where
|
|
||||||
go :: (AnnExpT, [Id]) -> (AnnExpT, [Id])
|
|
||||||
go ((free, (e, t)), acc)
|
|
||||||
| AAbs par (free1, e1) <- e
|
|
||||||
, TFun t_par _ <- t
|
|
||||||
= go ((Set.delete par free1, e1), snoc (par, t_par) acc)
|
|
||||||
| otherwise = ((free, (e, t)), acc)
|
|
||||||
|
|
||||||
abstractExp :: AnnExpT -> State Int ExpT
|
|
||||||
abstractExp (free, (exp, typ)) = case exp of
|
|
||||||
AVar n -> pure (EVar n, typ)
|
|
||||||
AInj n -> pure (EInj n, typ)
|
|
||||||
ALit lit -> pure (ELit lit, typ)
|
ALit lit -> pure (ELit lit, typ)
|
||||||
AApp e1 e2 -> (, typ) <$> liftA2 EApp (abstractExp e1) (abstractExp e2)
|
AApp annae1 annae2 -> (, typ) <$> onM EApp abstractAnnExp annae1 annae2
|
||||||
AAdd e1 e2 -> (, typ) <$> liftA2 EAdd (abstractExp e1) (abstractExp e2)
|
AAdd annae1 annae2 -> (, typ) <$> onM EAdd abstractAnnExp annae1 annae2
|
||||||
ALet b e -> (, typ) <$> liftA2 ELet (go b) (abstractExp e)
|
|
||||||
where
|
|
||||||
go (ABind name parms rhs) = do
|
|
||||||
(rhs', parms1) <- flattenLambdas <$> skipLambdas abstractExp rhs
|
|
||||||
pure $ Bind name (parms ++ parms1) rhs'
|
|
||||||
|
|
||||||
skipLambdas :: (AnnExpT -> State Int ExpT) -> AnnExpT -> State Int ExpT
|
-- \x. \y. x + y + z ⇒ let sc x y z = x + y + z in sc
|
||||||
skipLambdas f (free, (ae, t)) = case ae of
|
AAbs x annae' -> do
|
||||||
AAbs par ae1 -> do
|
|
||||||
ae1' <- skipLambdas f ae1
|
|
||||||
pure (EAbs par ae1', t)
|
|
||||||
_ -> f (free, (ae, t))
|
|
||||||
|
|
||||||
ACase e branches -> (, typ) <$> liftA2 ECase (abstractExp e) (mapM abstractBranch branches)
|
|
||||||
|
|
||||||
|
|
||||||
-- Lift lambda into let and bind free variables
|
|
||||||
AAbs parm e -> do
|
|
||||||
i <- nextNumber
|
i <- nextNumber
|
||||||
rhs <- abstractExp e
|
rhs <- abstractAnnExp annae''
|
||||||
|
|
||||||
let sc_name = Ident ("sc_" ++ show i)
|
let sc_name = Ident ("sc_" ++ show i)
|
||||||
sc = (ELet (Bind (sc_name, typ) vars rhs) (EVar sc_name, typ), typ)
|
sc = (ELet (Bind (sc_name, typ) vars rhs) (EVar sc_name, typ), typ)
|
||||||
pure $ foldl applyVars sc freeList
|
pure $ foldl applyFree sc frees
|
||||||
|
|
||||||
where
|
where
|
||||||
freeList = Set.toList free
|
vars = frees <| (x, t_x) <|| ys
|
||||||
vars = zip names $ getVars typ
|
t_x = case typ of TFun t _ -> t
|
||||||
names = snoc parm freeList
|
_ -> error "Impossible"
|
||||||
applyVars (e, t) name = (EApp (e, t) (EVar name, t_var), t_return)
|
|
||||||
|
(annae'', ys) = go [] annae'
|
||||||
where
|
where
|
||||||
(t_var, t_return) = case t of
|
go acc = \case
|
||||||
TFun t1 t2 -> (t1, t2)
|
Ann { term = (AAbs x ae, TFun t _) } -> go (snoc (x, t) acc) ae
|
||||||
|
ae -> (ae, acc)
|
||||||
|
|
||||||
|
|
||||||
|
applyFree :: (Exp' Type, Type) -> (Ident, Type) -> (Exp' Type, Type)
|
||||||
|
applyFree (e, t_e) (x, t_x) = (EApp (e, t_e) (EVar x, t_x), t_e')
|
||||||
|
where
|
||||||
|
t_e' = case t_e of TFun _ t -> t
|
||||||
|
_ -> error "Impossible"
|
||||||
|
|
||||||
abstractBranch :: AnnBranch -> State Int Branch
|
ACase annae' bs -> do
|
||||||
abstractBranch (_, AnnBranch patt exp) = Branch patt <$> abstractExp exp
|
bs <- mapM go bs
|
||||||
|
e <- abstractAnnExp annae'
|
||||||
|
pure (ECase e bs, typ)
|
||||||
|
where
|
||||||
|
go Ann { term = ABranch p annae } = Branch p <$> abstractAnnExp annae
|
||||||
|
|
||||||
|
ALet b annae' ->
|
||||||
|
(, typ) <$> liftA2 ELet (abstractAnnBind b) (abstractAnnExp annae')
|
||||||
|
|
||||||
nextNumber :: State Int Int
|
|
||||||
nextNumber = do
|
|
||||||
i <- get
|
|
||||||
put $ succ i
|
|
||||||
pure i
|
|
||||||
|
|
||||||
-- | Collects supercombinators by lifting non-constant let expressions
|
-- | Collects supercombinators by lifting non-constant let expressions
|
||||||
collectScs :: [Bind] -> [Bind]
|
collectScs :: [Bind] -> [Bind]
|
||||||
|
|
@ -232,34 +245,28 @@ collectScsExp expT@(exp, typ) = case exp of
|
||||||
--
|
--
|
||||||
-- > f = let sc x y = rhs in e
|
-- > f = let sc x y = rhs in e
|
||||||
--
|
--
|
||||||
ELet (Bind name parms rhs) e -> if null parms
|
ELet (Bind name parms rhs) e
|
||||||
then ( rhs_scs ++ et_scs, (ELet bind et', snd et'))
|
| null parms -> (rhs_scs ++ et_scs, (ELet bind et', snd et'))
|
||||||
else (bind : rhs_scs ++ et_scs, et')
|
| otherwise -> (bind : rhs_scs ++ et_scs, et')
|
||||||
where
|
where
|
||||||
bind = Bind name parms rhs'
|
bind = Bind name parms rhs'
|
||||||
(rhs_scs, rhs') = collectScsExp rhs
|
(rhs_scs, rhs') = collectScsExp rhs
|
||||||
(et_scs, et') = collectScsExp e
|
(et_scs, et') = collectScsExp e
|
||||||
|
|
||||||
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
|
collectScsBranch (Branch patt exp) = (scs, Branch patt exp')
|
||||||
where (scs, exp') = collectScsExp exp
|
where (scs, exp') = collectScsExp exp
|
||||||
|
|
||||||
|
nextNumber :: State Int Int
|
||||||
|
nextNumber = do
|
||||||
|
i <- get
|
||||||
|
put $ succ i
|
||||||
|
pure i
|
||||||
|
|
||||||
-- @\x.\y.\z. e → (e, [x,y,z])@
|
|
||||||
flattenLambdas :: ExpT -> (ExpT, [Id])
|
|
||||||
flattenLambdas = go . (, [])
|
|
||||||
where
|
|
||||||
go ((e, t), acc) = case e of
|
|
||||||
EAbs name e1 -> go (e1, snoc (name, t_var) acc)
|
|
||||||
where t_var = head $ getVars t
|
|
||||||
_ -> ((e, t), acc)
|
|
||||||
|
|
||||||
getVars :: Type -> [Type]
|
(<|) :: Eq a => [a] -> a -> [a]
|
||||||
getVars = fst . partitionType
|
xs <| x | elem x xs = xs
|
||||||
|
| otherwise = snoc x xs
|
||||||
|
|
||||||
partitionType :: Type -> ([Type], Type)
|
(<||) :: Eq a => [a] -> [a] -> [a]
|
||||||
partitionType = go []
|
xs <|| ys = foldl (<|) xs ys
|
||||||
where
|
|
||||||
go acc t = case t of
|
|
||||||
TFun t1 t2 -> go (snoc t1 acc) t2
|
|
||||||
_ -> (acc, t)
|
|
||||||
|
|
||||||
|
|
|
||||||
117
tests/TestLambdaLifter.hs
Normal file
117
tests/TestLambdaLifter.hs
Normal file
|
|
@ -0,0 +1,117 @@
|
||||||
|
{-# LANGUAGE PatternSynonyms #-}
|
||||||
|
{-# HLINT ignore "Use camelCase" #-}
|
||||||
|
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE QualifiedDo #-}
|
||||||
|
|
||||||
|
module TestLambdaLifter where
|
||||||
|
|
||||||
|
import Test.Hspec
|
||||||
|
|
||||||
|
import AnnForall (annotateForall)
|
||||||
|
import Control.Monad ((<=<))
|
||||||
|
import Control.Monad.Error.Class (liftEither)
|
||||||
|
import Control.Monad.Extra (eitherM)
|
||||||
|
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
|
||||||
|
import Grammar.Layout (resolveLayout)
|
||||||
|
import Grammar.Par (myLexer, pProgram)
|
||||||
|
import Grammar.Print (printTree)
|
||||||
|
import LambdaLifter
|
||||||
|
import Renamer.Renamer (rename)
|
||||||
|
import ReportForall (reportForall)
|
||||||
|
import TypeChecker.RemoveForall (removeForall)
|
||||||
|
import TypeChecker.ReportTEVar (reportTEVar)
|
||||||
|
import TypeChecker.TypeChecker (TypeChecker (Bi))
|
||||||
|
import TypeChecker.TypeCheckerBidir (typecheck)
|
||||||
|
import TypeChecker.TypeCheckerIr
|
||||||
|
|
||||||
|
|
||||||
|
test = hspec testLambdaLifter
|
||||||
|
|
||||||
|
testLambdaLifter = describe "Test Lambda Lifter" $ do
|
||||||
|
undefined
|
||||||
|
-- frees_exp1
|
||||||
|
|
||||||
|
-- frees_exp1 = specify "Free variables 1" $
|
||||||
|
-- freeVarsExp [] (EAbs "x" (EVar "x", TVar' "a"), TVar' "a")
|
||||||
|
-- `shouldBe` answer
|
||||||
|
-- where
|
||||||
|
-- answer = Ann { frees = []
|
||||||
|
-- , term = (AAbs (Ident "x") (Ann { frees = [Ident "x"]
|
||||||
|
-- , term = (AVar (Ident "x"),TVar (MkTVar (Ident "a")))
|
||||||
|
-- }
|
||||||
|
-- ),TVar (MkTVar (Ident "a")))
|
||||||
|
-- }
|
||||||
|
|
||||||
|
|
||||||
|
abs_1 = undefined
|
||||||
|
where
|
||||||
|
input = unlines [ "data List (a) where"
|
||||||
|
, " Nil : List (a)"
|
||||||
|
, " Cons : a -> List (a) -> List (a)"
|
||||||
|
, "map : (a -> b) -> List (a) -> List (b)"
|
||||||
|
, "add : Int -> Int -> Int"
|
||||||
|
|
||||||
|
, "f : List (Int)"
|
||||||
|
, "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
runPrintFree = print $ freeVarsExp [] (EAbs "x" (EVar "x", TVar' "a"), TVar' "a")
|
||||||
|
|
||||||
|
runAbstract = either putStrLn (putStrLn . printTree) (runAbs s2)
|
||||||
|
where
|
||||||
|
s = unlines [ "add : Int -> Int -> Int"
|
||||||
|
, "f : Int -> Int -> Int"
|
||||||
|
, "f x y = add x y"
|
||||||
|
, "f = \\x. (\\y. add x y)"
|
||||||
|
]
|
||||||
|
|
||||||
|
s2 = unlines [ "data List (a) where"
|
||||||
|
, " Nil : List (a)"
|
||||||
|
, " Cons : a -> List (a) -> List (a)"
|
||||||
|
, "map : (a -> b) -> List (a) -> List (b)"
|
||||||
|
, "add : Int -> Int -> Int"
|
||||||
|
|
||||||
|
, "f : List (Int)"
|
||||||
|
, "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
runCollect = either putStrLn (putStrLn . printTree) (run s)
|
||||||
|
where
|
||||||
|
s = unlines [ "data List (a) where"
|
||||||
|
, " Nil : List (a)"
|
||||||
|
, " Cons : a -> List (a) -> List (a)"
|
||||||
|
, "add : Int -> Int -> Int"
|
||||||
|
, "map : (a -> b) -> List (a) -> List (b)"
|
||||||
|
, "map f xs = case xs of"
|
||||||
|
, " Nil => Nil"
|
||||||
|
, " Cons x xs => Cons (f x) (map f xs)"
|
||||||
|
|
||||||
|
, "f : List (Int)"
|
||||||
|
, "f = (\\x.\\ys. map (\\y. add y x) ys) 4 (Cons 1 (Cons 2 Nil))"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
run = fmap collectScs . runAbs
|
||||||
|
|
||||||
|
runAbs s = do
|
||||||
|
Program ds <- run' s
|
||||||
|
pure $ (abstract . freeVars) [b | DBind b <- ds]
|
||||||
|
|
||||||
|
|
||||||
|
run' = fmap removeForall
|
||||||
|
. reportTEVar
|
||||||
|
<=< typecheck
|
||||||
|
<=< run''
|
||||||
|
|
||||||
|
run'' s = do
|
||||||
|
p <- (pProgram . resolveLayout True . myLexer) s
|
||||||
|
reportForall Bi p
|
||||||
|
(rename <=< annotateForall) p
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue