diff --git a/src/AnnForall.hs b/src/AnnForall.hs index f309a37..16222bd 100644 --- a/src/AnnForall.hs +++ b/src/AnnForall.hs @@ -1,16 +1,16 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedRecordDot #-} module AnnForall (annotateForall) where -import Auxiliary (partitionDefs) -import Control.Applicative (Applicative (liftA2)) -import Control.Monad.Except (throwError) -import Data.Function (on) -import Data.Set (Set) -import Data.Set qualified as Set -import Grammar.Abs -import Grammar.ErrM (Err) +import Auxiliary (partitionDefs) +import Control.Applicative (Applicative (liftA2)) +import Control.Monad.Except (throwError) +import Data.Function (on) +import Data.Set (Set) +import qualified Data.Set as Set +import Grammar.Abs +import Grammar.ErrM (Err) annotateForall :: Program -> Err Program annotateForall (Program defs) = do @@ -21,31 +21,30 @@ annotateForall (Program defs) = do ss' = map (DSig . annSig) ss (ds, ss, bs) = partitionDefs defs + annData :: Data -> Err Data annData (Data typ injs) = do - (typ', tvars) <- annTyp typ - pure (Data typ' $ map (annInj tvars) injs) + (typ', tvars) <- annTyp typ + pure (Data typ' $ map (annInj tvars) injs) + where annTyp typ = do (bounded, ts) <- boundedTVars mempty typ unbounded <- Set.fromList <$> mapM assertTVar ts let diff = unbounded Set.\\ bounded typ' = foldr TAll typ diff - (typ',) . fst <$> boundedTVars mempty typ' + (typ', ) . fst <$> boundedTVars mempty typ' where boundedTVars tvars typ = case typ of - TAll tvar t -> boundedTVars (Set.insert tvar tvars) t - TData _ ts -> pure (tvars, ts) - _ -> throwError "Misformed data declaration" + TAll tvar t -> boundedTVars (Set.insert tvar tvars) t + TData _ ts -> pure (tvars, ts) + _ -> throwError "Misformed data declaration" assertTVar typ = case typ of TVar tvar -> pure tvar - _ -> - throwError $ - unwords - [ "Misformed data declaration:" - , "Non type variable argument" - ] + _ -> throwError $ unwords [ "Misformed data declaration:" + , "Non type variable argument" + ] annInj tvars (Inj n t) = Inj n $ foldr TAll t (unboundedTVars t Set.\\ tvars) @@ -56,22 +55,20 @@ annBind :: Bind -> Err Bind annBind (Bind name vars exp) = Bind name vars <$> annExp exp where annExp = \case - -- Annotated types should not be - -- foralled without the consent of the user - EAnn e t -> flip EAnn t <$> annExp e - EApp e1 e2 -> liftA2 EApp (annExp e1) (annExp e2) - EAdd e1 e2 -> liftA2 EAdd (annExp e1) (annExp e2) + EAnn e t -> flip EAnn (annType t) <$> annExp e + EApp e1 e2 -> liftA2 EApp (annExp e1) (annExp e2) + EAdd e1 e2 -> liftA2 EAdd (annExp e1) (annExp e2) ELet bind e -> liftA2 ELet (annBind bind) (annExp e) - EAbs x e -> EAbs x <$> annExp e - ECase e bs -> liftA2 ECase (annExp e) (mapM annBranch bs) - e -> pure e + EAbs x e -> EAbs x <$> annExp e + ECase e bs -> liftA2 ECase (annExp e) (mapM annBranch bs) + e -> pure e annBranch (Branch p e) = Branch p <$> annExp e annType :: Type -> Type annType typ = go $ unboundedTVars typ where go us - | null us = typ + | null us = typ | otherwise = foldr TAll typ us unboundedTVars :: Type -> Set TVar @@ -82,25 +79,22 @@ unboundedTVars' bs typ = tvars.unbounded Set.\\ tvars.bounded where tvars = gatherTVars typ gatherTVars = \case - TAll tvar t -> - TVars - { bounded = Set.singleton tvar - , unbounded = unboundedTVars' (Set.insert tvar bs) t - } - TVar tvar -> uTVars $ Set.singleton tvar - TFun t1 t2 -> uTVars $ on Set.union (unboundedTVars' bs) t1 t2 - TData _ typs -> uTVars $ foldr (Set.union . unboundedTVars' bs) mempty typs - _ -> TVars{bounded = mempty, unbounded = mempty} + TAll tvar t -> TVars { bounded = Set.singleton tvar + , unbounded = unboundedTVars' (Set.insert tvar bs) t + } + TVar tvar -> uTVars $ Set.singleton tvar + TFun t1 t2 -> uTVars $ on Set.union (unboundedTVars' bs) t1 t2 + TData _ typs -> uTVars $ foldr (Set.union . unboundedTVars' bs) mempty typs + _ -> TVars { bounded = mempty, unbounded = mempty } data TVars = TVars - { bounded :: Set TVar - , unbounded :: Set TVar - } - deriving (Eq, Show, Ord) + { bounded :: Set TVar + , unbounded :: Set TVar + } deriving (Eq, Show, Ord) uTVars :: Set TVar -> TVars -uTVars us = - TVars - { bounded = mempty - , unbounded = us - } +uTVars us = TVars + { bounded = mempty + , unbounded = us + } +