Revert AnnForall change

This commit is contained in:
Martin Fredin 2023-05-04 23:54:19 +02:00
parent 15025a67f9
commit 0a588c4e14

View file

@ -1,16 +1,16 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
module AnnForall (annotateForall) where
import Auxiliary (partitionDefs)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (throwError)
import Data.Function (on)
import Data.Set (Set)
import Data.Set qualified as Set
import Grammar.Abs
import Grammar.ErrM (Err)
import Auxiliary (partitionDefs)
import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (throwError)
import Data.Function (on)
import Data.Set (Set)
import qualified Data.Set as Set
import Grammar.Abs
import Grammar.ErrM (Err)
annotateForall :: Program -> Err Program
annotateForall (Program defs) = do
@ -21,31 +21,30 @@ annotateForall (Program defs) = do
ss' = map (DSig . annSig) ss
(ds, ss, bs) = partitionDefs defs
annData :: Data -> Err Data
annData (Data typ injs) = do
(typ', tvars) <- annTyp typ
pure (Data typ' $ map (annInj tvars) injs)
(typ', tvars) <- annTyp typ
pure (Data typ' $ map (annInj tvars) injs)
where
annTyp typ = do
(bounded, ts) <- boundedTVars mempty typ
unbounded <- Set.fromList <$> mapM assertTVar ts
let diff = unbounded Set.\\ bounded
typ' = foldr TAll typ diff
(typ',) . fst <$> boundedTVars mempty typ'
(typ', ) . fst <$> boundedTVars mempty typ'
where
boundedTVars tvars typ = case typ of
TAll tvar t -> boundedTVars (Set.insert tvar tvars) t
TData _ ts -> pure (tvars, ts)
_ -> throwError "Misformed data declaration"
TAll tvar t -> boundedTVars (Set.insert tvar tvars) t
TData _ ts -> pure (tvars, ts)
_ -> throwError "Misformed data declaration"
assertTVar typ = case typ of
TVar tvar -> pure tvar
_ ->
throwError $
unwords
[ "Misformed data declaration:"
, "Non type variable argument"
]
_ -> throwError $ unwords [ "Misformed data declaration:"
, "Non type variable argument"
]
annInj tvars (Inj n t) =
Inj n $ foldr TAll t (unboundedTVars t Set.\\ tvars)
@ -56,22 +55,20 @@ annBind :: Bind -> Err Bind
annBind (Bind name vars exp) = Bind name vars <$> annExp exp
where
annExp = \case
-- Annotated types should not be
-- foralled without the consent of the user
EAnn e t -> flip EAnn t <$> annExp e
EApp e1 e2 -> liftA2 EApp (annExp e1) (annExp e2)
EAdd e1 e2 -> liftA2 EAdd (annExp e1) (annExp e2)
EAnn e t -> flip EAnn (annType t) <$> annExp e
EApp e1 e2 -> liftA2 EApp (annExp e1) (annExp e2)
EAdd e1 e2 -> liftA2 EAdd (annExp e1) (annExp e2)
ELet bind e -> liftA2 ELet (annBind bind) (annExp e)
EAbs x e -> EAbs x <$> annExp e
ECase e bs -> liftA2 ECase (annExp e) (mapM annBranch bs)
e -> pure e
EAbs x e -> EAbs x <$> annExp e
ECase e bs -> liftA2 ECase (annExp e) (mapM annBranch bs)
e -> pure e
annBranch (Branch p e) = Branch p <$> annExp e
annType :: Type -> Type
annType typ = go $ unboundedTVars typ
where
go us
| null us = typ
| null us = typ
| otherwise = foldr TAll typ us
unboundedTVars :: Type -> Set TVar
@ -82,25 +79,22 @@ unboundedTVars' bs typ = tvars.unbounded Set.\\ tvars.bounded
where
tvars = gatherTVars typ
gatherTVars = \case
TAll tvar t ->
TVars
{ bounded = Set.singleton tvar
, unbounded = unboundedTVars' (Set.insert tvar bs) t
}
TVar tvar -> uTVars $ Set.singleton tvar
TFun t1 t2 -> uTVars $ on Set.union (unboundedTVars' bs) t1 t2
TData _ typs -> uTVars $ foldr (Set.union . unboundedTVars' bs) mempty typs
_ -> TVars{bounded = mempty, unbounded = mempty}
TAll tvar t -> TVars { bounded = Set.singleton tvar
, unbounded = unboundedTVars' (Set.insert tvar bs) t
}
TVar tvar -> uTVars $ Set.singleton tvar
TFun t1 t2 -> uTVars $ on Set.union (unboundedTVars' bs) t1 t2
TData _ typs -> uTVars $ foldr (Set.union . unboundedTVars' bs) mempty typs
_ -> TVars { bounded = mempty, unbounded = mempty }
data TVars = TVars
{ bounded :: Set TVar
, unbounded :: Set TVar
}
deriving (Eq, Show, Ord)
{ bounded :: Set TVar
, unbounded :: Set TVar
} deriving (Eq, Show, Ord)
uTVars :: Set TVar -> TVars
uTVars us =
TVars
{ bounded = mempty
, unbounded = us
}
uTVars us = TVars
{ bounded = mempty
, unbounded = us
}