Revert AnnForall change

This commit is contained in:
Martin Fredin 2023-05-04 23:54:19 +02:00
parent 15025a67f9
commit 0a588c4e14

View file

@ -8,7 +8,7 @@ import Control.Applicative (Applicative (liftA2))
import Control.Monad.Except (throwError) import Control.Monad.Except (throwError)
import Data.Function (on) import Data.Function (on)
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as Set import qualified Data.Set as Set
import Grammar.Abs import Grammar.Abs
import Grammar.ErrM (Err) import Grammar.ErrM (Err)
@ -21,17 +21,19 @@ annotateForall (Program defs) = do
ss' = map (DSig . annSig) ss ss' = map (DSig . annSig) ss
(ds, ss, bs) = partitionDefs defs (ds, ss, bs) = partitionDefs defs
annData :: Data -> Err Data annData :: Data -> Err Data
annData (Data typ injs) = do annData (Data typ injs) = do
(typ', tvars) <- annTyp typ (typ', tvars) <- annTyp typ
pure (Data typ' $ map (annInj tvars) injs) pure (Data typ' $ map (annInj tvars) injs)
where where
annTyp typ = do annTyp typ = do
(bounded, ts) <- boundedTVars mempty typ (bounded, ts) <- boundedTVars mempty typ
unbounded <- Set.fromList <$> mapM assertTVar ts unbounded <- Set.fromList <$> mapM assertTVar ts
let diff = unbounded Set.\\ bounded let diff = unbounded Set.\\ bounded
typ' = foldr TAll typ diff typ' = foldr TAll typ diff
(typ',) . fst <$> boundedTVars mempty typ' (typ', ) . fst <$> boundedTVars mempty typ'
where where
boundedTVars tvars typ = case typ of boundedTVars tvars typ = case typ of
TAll tvar t -> boundedTVars (Set.insert tvar tvars) t TAll tvar t -> boundedTVars (Set.insert tvar tvars) t
@ -40,10 +42,7 @@ annData (Data typ injs) = do
assertTVar typ = case typ of assertTVar typ = case typ of
TVar tvar -> pure tvar TVar tvar -> pure tvar
_ -> _ -> throwError $ unwords [ "Misformed data declaration:"
throwError $
unwords
[ "Misformed data declaration:"
, "Non type variable argument" , "Non type variable argument"
] ]
annInj tvars (Inj n t) = annInj tvars (Inj n t) =
@ -56,9 +55,7 @@ annBind :: Bind -> Err Bind
annBind (Bind name vars exp) = Bind name vars <$> annExp exp annBind (Bind name vars exp) = Bind name vars <$> annExp exp
where where
annExp = \case annExp = \case
-- Annotated types should not be EAnn e t -> flip EAnn (annType t) <$> annExp e
-- foralled without the consent of the user
EAnn e t -> flip EAnn t <$> annExp e
EApp e1 e2 -> liftA2 EApp (annExp e1) (annExp e2) EApp e1 e2 -> liftA2 EApp (annExp e1) (annExp e2)
EAdd e1 e2 -> liftA2 EAdd (annExp e1) (annExp e2) EAdd e1 e2 -> liftA2 EAdd (annExp e1) (annExp e2)
ELet bind e -> liftA2 ELet (annBind bind) (annExp e) ELet bind e -> liftA2 ELet (annBind bind) (annExp e)
@ -82,25 +79,22 @@ unboundedTVars' bs typ = tvars.unbounded Set.\\ tvars.bounded
where where
tvars = gatherTVars typ tvars = gatherTVars typ
gatherTVars = \case gatherTVars = \case
TAll tvar t -> TAll tvar t -> TVars { bounded = Set.singleton tvar
TVars
{ bounded = Set.singleton tvar
, unbounded = unboundedTVars' (Set.insert tvar bs) t , unbounded = unboundedTVars' (Set.insert tvar bs) t
} }
TVar tvar -> uTVars $ Set.singleton tvar TVar tvar -> uTVars $ Set.singleton tvar
TFun t1 t2 -> uTVars $ on Set.union (unboundedTVars' bs) t1 t2 TFun t1 t2 -> uTVars $ on Set.union (unboundedTVars' bs) t1 t2
TData _ typs -> uTVars $ foldr (Set.union . unboundedTVars' bs) mempty typs TData _ typs -> uTVars $ foldr (Set.union . unboundedTVars' bs) mempty typs
_ -> TVars{bounded = mempty, unbounded = mempty} _ -> TVars { bounded = mempty, unbounded = mempty }
data TVars = TVars data TVars = TVars
{ bounded :: Set TVar { bounded :: Set TVar
, unbounded :: Set TVar , unbounded :: Set TVar
} } deriving (Eq, Show, Ord)
deriving (Eq, Show, Ord)
uTVars :: Set TVar -> TVars uTVars :: Set TVar -> TVars
uTVars us = uTVars us = TVars
TVars
{ bounded = mempty { bounded = mempty
, unbounded = us , unbounded = us
} }