68 lines
2.2 KiB
Haskell
68 lines
2.2 KiB
Haskell
{-# LANGUAGE LambdaCase #-}
|
|
|
|
module ReportForall (reportForall) where
|
|
|
|
import Auxiliary (partitionDefs)
|
|
import Control.Monad (unless, void, when)
|
|
import Control.Monad.Except (MonadError (throwError))
|
|
import Data.Either.Combinators (mapRight)
|
|
import Data.Foldable (foldlM)
|
|
import Data.Function (on)
|
|
import Data.List (delete)
|
|
import Grammar.Abs
|
|
import Grammar.ErrM (Err)
|
|
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm))
|
|
|
|
-- | Validate the quantifiers appearing in every type of a program.
--
-- When the selected checker is 'Hm', 'rpaType' is run over the program
-- first and rejects any @forall@ that is not part of the leading run of
-- quantifiers ("Higher rank forall not allowed"). Afterwards, for every
-- checker, 'rpuType' is run over the program to report unused quantifiers.
reportForall :: TypeChecker -> Program -> Err ()
reportForall tc p = rankCheck >> rpProgram rpuType p
  where
    -- Only the 'Hm' checker is restricted to non-higher-rank types here.
    rankCheck
      | tc == Hm = rpProgram rpaType p
      | otherwise = pure ()
|
|
|
|
-- | Reject any @forall@ whose bound variable is never used.
--
-- Every quantifier @TAll a body@ in the type is checked, including those
-- nested inside function arguments/results and data type arguments. The
-- check requires @a@ to occur free in @body@, and fails with
-- @"Unused forall"@ otherwise.
--
-- Scoping is lexical:
--
--   * an occurrence under an inner quantifier for the same variable belongs
--     to that inner binder;
--   * a quantifier inside one function argument (or data type argument)
--     cannot be "used" by a variable that appears in a sibling argument.
rpuType :: Type -> Err ()
rpuType = checkType
  where
    checkType :: Type -> Err ()
    checkType = \case
      TAll tvar body -> do
        unless (tvar `occursIn` body) $ throwError "Unused forall"
        checkType body
      TFun t1 t2 -> checkType t1 >> checkType t2
      TData _ typs -> mapM_ checkType typs
      _ -> pure ()

    -- Does the variable occur free in the given type? A binder for the same
    -- variable shadows it for the rest of that subtree.
    occursIn tvar = \case
      TVar tvar' -> tvar == tvar'
      TAll tvar' body -> tvar /= tvar' && tvar `occursIn` body
      TFun t1 t2 -> tvar `occursIn` t1 || tvar `occursIn` t2
      TData _ typs -> any (tvar `occursIn`) typs
      _ -> False
|
|
|
|
|
|
-- | Reject higher-rank quantification.
--
-- The leading (outermost) run of @forall@ binders is permitted and
-- stripped. Any 'TAll' remaining anywhere in the rest of the type fails
-- with @"Higher rank forall not allowed"@. This covers function arguments,
-- function results, and data type arguments.
rpaType :: Type -> Err ()
rpaType typ = noInnerForall (stripPrenex typ)
  where
    -- Drop the outermost quantifiers, which are allowed.
    stripPrenex (TAll _ body) = stripPrenex body
    stripPrenex other = other

    -- Any quantifier found below the leading ones is an error.
    noInnerForall :: Type -> Err ()
    noInnerForall ty = case ty of
      TAll {} -> throwError "Higher rank forall not allowed"
      TFun arg res -> noInnerForall arg >> noInnerForall res
      TData _ args -> mapM_ noInnerForall args
      _ -> pure ()
|
|
|
|
-- | Run the type check @rf@ over every type written in a program.
--
-- Definitions are split with 'partitionDefs' and visited in this order:
--
--   1. binds: the type of every annotation ('EAnn') reachable from a
--      right-hand side. The walk goes through applications, additions,
--      lets, lambdas, case scrutinees and case-branch bodies;
--   2. data declarations: the declared type, then each constructor ('Inj')
--      type;
--   3. signatures: the declared type.
--
-- All checks are sequenced with '>>' / 'mapM_' in 'Err'. NOTE(review): this
-- assumes 'Err' short-circuits like 'Either', so the first failing check
-- is the result; confirm against "Grammar.ErrM".
rpProgram :: (Type -> Err ()) -> Program -> Err ()
rpProgram rf (Program defs) = do
  mapM_ rpuBind bs
  mapM_ rpuData ds
  mapM_ rpuSig ss
  where
    (ds, ss, bs) = partitionDefs defs

    rpuSig (Sig _ typ) = rf typ

    -- 'mapM_' rather than 'mapM': the per-constructor results are all '()'
    -- and were only being collected into a list that is thrown away.
    rpuData (Data typ injs) = rf typ >> mapM_ rpuInj injs

    rpuInj (Inj _ typ) = rf typ

    rpuBind (Bind _ _ rhs) = rpuExp rhs

    rpuBranch (Branch _ e) = rpuExp e

    -- Within expressions, @rf@ is applied only to annotation types ('EAnn');
    -- the other listed constructors are just traversed.
    rpuExp = \case
      EAnn e t -> rpuExp e >> rf t
      EApp e1 e2 -> rpuExp e1 >> rpuExp e2
      EAdd e1 e2 -> rpuExp e1 >> rpuExp e2
      ELet bind e -> rpuBind bind >> rpuExp e
      EAbs _ e -> rpuExp e
      -- 'branches' (not 'bs') so the outer list of binds is not shadowed.
      ECase e branches -> rpuExp e >> mapM_ rpuBranch branches
      _ -> pure ()
|
|
|