Add implicit foralls for bidir, update and unify pipeline

This commit is contained in:
Martin Fredin 2023-04-03 17:34:33 +02:00
parent 12bca1c32d
commit 9870802371
33 changed files with 1010 additions and 1055 deletions

View file

@ -1,224 +1,112 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2))
import Control.Monad (when)
import Control.Monad.Except (
ExceptT,
MonadError (catchError, throwError),
runExceptT,
)
import Control.Monad.State (
MonadState,
State,
StateT,
evalState,
evalStateT,
get,
gets,
lift,
mapAndUnzipM,
modify,
put,
)
import Data.Function (on)
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromMaybe)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Tuple.Extra (dupe, second)
import Grammar.Abs
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
import Auxiliary (maybeToRightM, onM, partitionDefs)
import Control.Applicative (liftA2)
import Control.Monad.Except (ExceptT, MonadError, runExceptT)
import Control.Monad.State (MonadState, State, evalState, gets,
modify)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Tuple.Extra (dupe)
import Grammar.Abs
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
-- | Rename all variables and local binds
rename :: Program -> Err Program
rename (Program defs) = Program <$> renameDefs defs
rename (Program defs) = rename' $ do
ds' <- mapM (fmap DData . rnData) ds
ss' <- mapM (fmap DSig . rnSig) ss
bs' <- mapM (fmap DBind . rnTopBind) bs
pure $ Program (ds' ++ ss' ++ bs')
where
(ds, ss, bs) = partitionDefs defs
rename' = flip evalState initCxt
. runExceptT
. runRn
initCxt = Cxt
{ counter = 0
, names = Map.fromList $ [ dupe n | Sig n _ <- ss ]
++ [ dupe n | Bind n _ _ <- bs ]
}
rnData :: Data -> Rn Data
rnData (Data typ injs) = liftA2 Data (rnType typ) (mapM rnInj injs)
where
rnInj (Inj name t) = Inj name <$> rnType t
initCxt :: Cxt
initCxt = Cxt 0 0
rnSig :: Sig -> Rn Sig
rnSig (Sig name typ) = liftA2 Sig (getName name) (rnType typ)
rnType :: Type -> Rn Type
rnType = \case
TVar (MkTVar name) -> TVar . MkTVar <$> getName name
TData name ts -> TData name <$> localNames (mapM rnType ts)
TFun t1 t2 -> onM TFun (localNames . rnType) t1 t2
TAll (MkTVar name) t -> liftA2 (TAll . MkTVar) (newName name) (rnType t)
typ -> pure typ
rnTopBind :: Bind -> Rn Bind
rnTopBind = rnBind' False
rnLocalBind :: Bind -> Rn Bind
rnLocalBind = rnBind' True
rnBind' :: Bool -> Bind -> Rn Bind
rnBind' isLocal (Bind name vars rhs) = do
name' <- if isLocal then newName name else getName name
(vars', rhs') <- localNames $ liftA2 (,) (mapM newName vars) (rnExp rhs)
pure (Bind name' vars' rhs')
rnExp :: Exp -> Rn Exp
rnExp = \case
EVar x -> EVar <$> getName x
EInj x -> pure (EInj x)
ELit lit -> pure (ELit lit)
EApp e1 e2 -> onM EApp (localNames . rnExp) e1 e2
EAdd e1 e2 -> onM EAdd (localNames . rnExp) e1 e2
ELet bind e -> liftA2 ELet (rnLocalBind bind) (rnExp e)
EAbs x e -> liftA2 EAbs (newName x) (rnExp e)
EAnn e t -> liftA2 EAnn (rnExp e) (rnType t)
ECase e bs -> liftA2 ECase (rnExp e) (mapM (localNames . rnBranch) bs)
rnBranch :: Branch -> Rn Branch
rnBranch (Branch p e) = liftA2 Branch (rnPattern p) (rnExp e)
rnPattern :: Pattern -> Rn Pattern
rnPattern = \case
PVar x -> PVar <$> newName x
PLit lit -> pure (PLit lit)
PCatch -> pure PCatch
PEnum name -> pure (PEnum name)
PInj name ps -> PInj name <$> mapM rnPattern ps
data Cxt = Cxt
{ var_counter :: Int
, tvar_counter :: Int
{ counter :: Int
, names :: Map LIdent LIdent
}
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: ExceptT String (State Cxt) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Maps old to new name
type Names = Map String String
getName :: LIdent -> Rn LIdent
getName name = maybeToRightM err =<< gets (Map.lookup name . names)
where err = "Can't find new name " ++ printTree name
renameDefs :: [Def] -> Err [Def]
renameDefs defs = evalState (runExceptT (runRn $ mapM renameDef defs)) initCxt
newName :: LIdent -> Rn LIdent
newName name = do
name' <- gets (mk name . counter)
modify $ \cxt -> cxt { counter = succ cxt.counter
, names = Map.insert name name' cxt.names
}
pure name'
where
initNames = Map.fromList [dupe s | DBind (Bind (LIdent s) _ _) <- defs]
mk (LIdent name) i = LIdent ("#" ++ show i ++ name)
renameDef :: Def -> Rn Def
renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind (Bind name vars rhs) -> do
(new_names, vars') <- newNamesL initNames vars
rhs' <- snd <$> renameExp new_names rhs
pure . DBind $ Bind name vars' rhs'
DData (Data typ injs) -> do
tvars <- collectTVars [] typ
tvars' <- mapM nextNameTVar tvars
let tvars_lt = zip tvars tvars'
typ' = substituteTVar tvars_lt typ
injs' = map (renameInj tvars_lt) injs
pure . DData $ Data typ' injs'
where
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t
TData _ _ -> pure tvars
_ -> throwError ("Bad data type definition: " ++ printTree typ)
renameInj :: [(TVar, TVar)] -> Inj -> Inj
renameInj new_types (Inj name typ) =
Inj name $ substituteTVar new_types typ
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of
TLit _ -> typ
TVar tvar
| Just tvar' <- lookup tvar new_names ->
TVar tvar'
| otherwise ->
typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
| Just tvar' <- lookup tvar new_names ->
TAll tvar' $ substitute' t
| otherwise ->
TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ)
where
substitute' = substituteTVar new_names
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EVar (LIdent n) -> pure (old_names, EVar . LIdent . fromMaybe n $ Map.lookup n old_names)
EInj (UIdent n) -> pure (old_names, EInj . UIdent . fromMaybe n $ Map.lookup n old_names)
ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
-- TODO fix shadowing
ELet (Bind name vars rhs) e -> do
(new_names, name') <- newNameL old_names name
(new_names', vars') <- newNamesL new_names vars
(new_names'', rhs') <- renameExp new_names' rhs
(new_names''', e') <- renameExp new_names'' e
pure (new_names''', ELet (Bind name' vars' rhs') e')
EAbs par e -> do
(new_names, par') <- newNameL old_names par
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs par' e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
t' <- renameTVars t
pure (new_names, EAnn e' t')
ECase e injs -> do
(new_names, e') <- renameExp old_names e
(new_names', injs') <- renameBranches new_names injs
pure (new_names', ECase e' injs')
renameBranches :: Names -> [Branch] -> Rn (Names, [Branch])
renameBranches ns xs = do
(new_names, xs') <- mapAndUnzipM (renameBranch ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameBranch :: Names -> Branch -> Rn (Names, Branch)
renameBranch ns b@(Branch patt e) = do
(new_names, patt') <- catchError (evalStateT (renamePattern ns patt) mempty) (\x -> throwError $ x ++ " in pattern '" ++ printTree b ++ "'")
(new_names', e') <- renameExp new_names e
return (new_names', Branch patt' e')
renamePattern :: Names -> Pattern -> StateT (Set LIdent) Rn (Names, Pattern)
renamePattern ns p = case p of
PInj cs ps -> do
(ns_new, ps') <- mapAccumM renamePattern ns ps
return (ns_new, PInj cs ps')
PVar name -> do
vs <- get
when (name `Set.member` vs) (throwError $ "Conflicting definitions of '" ++ printTree name ++ "'")
put (Set.insert name vs)
nn <- lift $ newNameL ns name
return $ second PVar nn
_ -> return (ns, p)
renameTVars :: Type -> Rn Type
renameTVars typ = case typ of
TAll tvar t -> do
tvar' <- nextNameTVar tvar
t' <- renameTVars $ substitute tvar tvar' t
pure $ TAll tvar' t'
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
_ -> pure typ
substitute ::
TVar -> -- α
TVar -> -- α_n
Type -> -- A
Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of
TLit _ -> typ
TVar tvar
| tvar == tvar1 -> TVar tvar2
| otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
| tvar == tvar1 -> TAll tvar2 $ substitute' t
| otherwise -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible"
where
substitute' = substitute tvar1 tvar2
-- | Create multiple names and add them to the name environment
newNamesL :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNamesL = mapAccumM newNameL
-- | Create a new name and add it to name environment.
newNameL :: Names -> LIdent -> Rn (Names, LIdent)
newNameL env (LIdent old_name) = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, LIdent new_name)
-- | Create multiple names and add them to the name environment
newNamesU :: Names -> [UIdent] -> Rn (Names, [UIdent])
newNamesU = mapAccumM newNameU
-- | Create a new name and add it to name environment.
newNameU :: Names -> UIdent -> Rn (Names, UIdent)
newNameU env (UIdent old_name) = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, UIdent new_name)
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: String -> Rn String
makeName prefix = do
i <- gets var_counter
let name = prefix ++ "_" ++ show i
modify $ \cxt -> cxt{var_counter = succ cxt.var_counter}
pure name
nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent s)) = do
i <- gets tvar_counter
let tvar = MkTVar . LIdent $ s ++ "_" ++ show i
modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter}
pure tvar
localNames :: MonadState Cxt m => m b -> m b
localNames m = do
old_names <- gets names
m <* modify ( \cxt' -> cxt' { names = old_names })

View file

@ -1,206 +0,0 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use mapAndUnzipM" #-}
module Renamer.Renamer (rename) where
import Auxiliary (mapAccumM)
import Control.Applicative (Applicative (liftA2))
import Control.Monad (foldM)
import Control.Monad.Except (ExceptT, MonadError, runExceptT,
throwError)
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.State (MonadState, StateT, evalStateT, gets,
modify)
import Data.Coerce (coerce)
import Data.Function (on)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Tuple.Extra (dupe)
import Grammar.Abs
-- | Rename all variables and local binds
rename :: Program -> Either String Program
rename (Program defs) = Program <$> renameDefs defs
renameDefs :: [Def] -> Either String [Def]
renameDefs defs = runIdentity $ runExceptT $ evalStateT (runRn $ mapM renameDef defs) initCxt
where
initNames = Map.fromList [dupe (coerce name) | DBind (Bind name _ _) <- defs]
renameDef :: Def -> Rn Def
renameDef = \case
DSig (Sig name typ) -> DSig . Sig name <$> renameTVars typ
DBind bind -> DBind . snd <$> renameBind initNames bind
DData (Data (TData cname types) constrs) -> do
tvars_ <- tvars
tvars' <- mapM nextNameTVar tvars_
let tvars_lt = zip tvars_ tvars'
typ' = map (substituteTVar tvars_lt) types
constrs' = map (renameConstr tvars_lt) constrs
pure . DData $ Data (TData cname typ') constrs'
where
tvars = concat <$> mapM (collectTVars []) types
collectTVars :: [TVar] -> Type -> Rn [TVar]
collectTVars tvars = \case
TAll tvar t -> collectTVars (tvar : tvars) t
TData _ _ -> return tvars
-- Should be monad error
TVar v -> return [v]
_ -> throwError ("Bad data type definition: " ++ show types)
DData (Data types _) -> throwError ("Bad data type definition: " ++ show types)
renameConstr :: [(TVar, TVar)] -> Inj -> Inj
renameConstr new_types (Inj name typ) =
Inj name $ substituteTVar new_types typ
renameBind :: Names -> Bind -> Rn (Names, Bind)
renameBind old_names (Bind name vars rhs) = do
(new_names, vars') <- newNames old_names (coerce vars)
(newer_names, rhs') <- renameExp new_names rhs
pure (newer_names, Bind name (coerce vars') rhs')
substituteTVar :: [(TVar, TVar)] -> Type -> Type
substituteTVar new_names typ = case typ of
TLit _ -> typ
TVar tvar
| Just tvar' <- lookup tvar new_names ->
TVar tvar'
| otherwise ->
typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t
| Just tvar' <- lookup tvar new_names ->
TAll tvar' $ substitute' t
| otherwise ->
TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error ("Impossible " ++ show typ)
where
substitute' = substituteTVar new_names
initCxt :: Cxt
initCxt = Cxt 0 0
data Cxt = Cxt
{ var_counter :: Int
, tvar_counter :: Int
}
-- | Rename monad. State holds the number of renamed names.
newtype Rn a = Rn {runRn :: StateT Cxt (ExceptT String Identity) a}
deriving (Functor, Applicative, Monad, MonadState Cxt, MonadError String)
-- | Maps old to new name
type Names = Map LIdent LIdent
renameExp :: Names -> Exp -> Rn (Names, Exp)
renameExp old_names = \case
EVar n -> pure (coerce old_names, EVar . fromMaybe n $ Map.lookup n old_names)
EInj n -> pure (old_names, EInj n)
ELit lit -> pure (old_names, ELit lit)
EApp e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EApp e1' e2')
EAdd e1 e2 -> do
(env1, e1') <- renameExp old_names e1
(env2, e2') <- renameExp old_names e2
pure (Map.union env1 env2, EAdd e1' e2')
-- TODO fix shadowing
ELet bind e -> do
(new_names, bind') <- renameBind old_names bind
(new_names', e') <- renameExp new_names e
pure (new_names', ELet bind' e')
EAbs par e -> do
(new_names, par') <- newName old_names (coerce par)
(new_names', e') <- renameExp new_names e
pure (new_names', EAbs (coerce par') e')
EAnn e t -> do
(new_names, e') <- renameExp old_names e
t' <- renameTVars t
pure (new_names, EAnn e' t')
ECase e injs -> do
(new_names, e') <- renameExp old_names e
(new_names', injs') <- renameBranches new_names injs
pure (new_names', ECase e' injs')
renameBranches :: Names -> [Branch] -> Rn (Names, [Branch])
renameBranches ns xs = do
(new_names, xs') <- unzip <$> mapM (renameBranch ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameBranch :: Names -> Branch -> Rn (Names, Branch)
renameBranch ns (Branch init e) = do
(new_names, init') <- renamePattern ns init
(new_names', e') <- renameExp new_names e
return (new_names', Branch init' e')
renamePattern :: Names -> Pattern -> Rn (Names, Pattern)
renamePattern ns i = case i of
PInj cs ps -> do
(ns_new, ps) <- renamePatterns ns ps
return (ns_new, PInj cs ps)
rest -> return (ns, rest)
renamePatterns :: Names -> [Pattern] -> Rn (Names, [Pattern])
renamePatterns ns xs = do
(new_names, xs') <- unzip <$> mapM (renamePattern ns) xs
if null new_names then return (mempty, xs') else return (head new_names, xs')
renameTVars :: Type -> Rn Type
renameTVars typ = case typ of
TAll tvar t -> do
tvar' <- nextNameTVar tvar
t' <- renameTVars $ substitute tvar tvar' t
pure $ TAll tvar' t'
TFun t1 t2 -> liftA2 TFun (renameTVars t1) (renameTVars t2)
_ -> pure typ
substitute ::
TVar -> -- α
TVar -> -- α_n
Type -> -- A
Type -- [α_n/α]A
substitute tvar1 tvar2 typ = case typ of
TLit _ -> typ
TVar tvar'
| tvar' == tvar1 -> TVar tvar2
| otherwise -> typ
TFun t1 t2 -> on TFun substitute' t1 t2
TAll tvar t -> TAll tvar $ substitute' t
TData name typs -> TData name $ map substitute' typs
_ -> error "Impossible"
where
substitute' = substitute tvar1 tvar2
-- | Create a new name and add it to name environment.
newName :: Names -> LIdent -> Rn (Names, LIdent)
newName env old_name = do
new_name <- makeName old_name
pure (Map.insert old_name new_name env, new_name)
-- | Create multiple names and add them to the name environment
newNames :: Names -> [LIdent] -> Rn (Names, [LIdent])
newNames = mapAccumM newName
-- | Annotate name with number and increment the number @prefix ⇒ prefix_number@.
makeName :: LIdent -> Rn LIdent
makeName (LIdent prefix) = do
i <- gets var_counter
let name = LIdent $ prefix ++ "_" ++ show i
modify $ \cxt -> cxt{var_counter = succ cxt.var_counter}
pure name
nextNameTVar :: TVar -> Rn TVar
nextNameTVar (MkTVar (LIdent s)) = do
i <- gets tvar_counter
let tvar = MkTVar $ coerce $ s ++ "_" ++ show i
modify $ \cxt -> cxt{tvar_counter = succ cxt.tvar_counter}
pure tvar