churf/src/ReportForall.hs
2023-04-24 10:10:15 +02:00

68 lines
2.2 KiB
Haskell

{-# LANGUAGE LambdaCase #-}
module ReportForall (reportForall) where
import Auxiliary (partitionDefs)
import Control.Monad (unless, void, when)
import Control.Monad.Except (MonadError (throwError))
import Data.Either.Combinators (mapRight)
import Data.Foldable (foldlM)
import Data.Function (on)
import Data.List (delete)
import Grammar.Abs
import Grammar.ErrM (Err)
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm))
-- | Check how @forall@ is used in every type that appears in a program.
--
-- When the selected type checker is 'Hm', the program is first checked with
-- 'rpaType', which rejects a @forall@ anywhere except at the front of a type.
-- Regardless of the type checker, the program is then checked with 'rpuType',
-- which rejects quantifiers whose variable is unused.
reportForall :: TypeChecker -> Program -> Err ()
reportForall tc prog = rankCheck >> rpProgram rpuType prog
  where
    -- The nested-forall check only applies to the 'Hm' type checker.
    rankCheck :: Err ()
    rankCheck
      | tc == Hm = rpProgram rpaType prog
      | otherwise = pure ()
-- | Reject a type in which some @forall@ binds a variable it never uses.
--
-- Every quantifier must bind a variable that occurs free in its own body.
-- Scoping is respected:
--
-- * an occurrence outside the quantifier's body does not count as a use, so
--   @(forall a. Int) -> a@ is rejected;
-- * an inner quantifier that rebinds the same name shadows the outer one, so
--   @forall a. (forall a. a) -> a@ is accepted (both binders are used) while
--   @forall a. forall a. a@ is rejected (the outer binder is unused).
--
-- Fails with @"Unused forall"@ at the first offending quantifier found in a
-- left-to-right, outermost-first traversal.
rpuType :: Type -> Err ()
rpuType = \case
  TAll tvar body
    | tvar `occursIn` body -> rpuType body
    | otherwise -> throwError "Unused forall"
  TFun t1 t2 -> rpuType t1 >> rpuType t2
  TData _ typs -> mapM_ rpuType typs
  _ -> pure ()
  where
    -- Does the variable occur free in the given type?
    occursIn tvar = \case
      TVar tvar' -> tvar == tvar'
      -- A nested binder of the same name shadows @tvar@ inside its body.
      TAll tvar' t -> tvar /= tvar' && occursIn tvar t
      TFun t1 t2 -> occursIn tvar t1 || occursIn tvar t2
      TData _ typs -> any (occursIn tvar) typs
      _ -> False
-- | Reject a type that contains a @forall@ anywhere other than at its front.
--
-- The leading run of quantifiers is stripped first; any @forall@ found in
-- what remains (inside function types or data type arguments) fails with
-- @"Higher rank forall not allowed"@.
rpaType :: Type -> Err ()
rpaType = noForall . dropPrefix
  where
    -- Strip the quantifiers at the very front of the type.
    dropPrefix :: Type -> Type
    dropPrefix (TAll _ body) = dropPrefix body
    dropPrefix other = other

    -- Fail on the first quantifier encountered, searching left to right.
    noForall :: Type -> Err ()
    noForall (TAll _ _) = throwError "Higher rank forall not allowed"
    noForall (TFun dom cod) = noForall dom >> noForall cod
    noForall (TData _ args) = mapM_ noForall args
    noForall _ = pure ()
-- | Run a type check over every type that is written out in a program.
--
-- The check @rf@ is applied to:
--
-- * every type annotation ('EAnn') inside the right-hand side of a binding,
--   including bindings introduced by 'ELet' and the bodies of case branches;
-- * the type of every data declaration and of each of its injections;
-- * the type of every signature.
--
-- Bindings are visited first, then data declarations, then signatures. The
-- checks are sequenced in 'Err', so an error raised by @rf@ is the result of
-- the whole traversal.
rpProgram :: (Type -> Err ()) -> Program -> Err ()
rpProgram rf (Program defs) = do
  mapM_ checkBind bs
  mapM_ checkData ds
  mapM_ checkSig ss
  where
    (ds, ss, bs) = partitionDefs defs

    checkSig (Sig _ typ) = rf typ

    -- mapM_ instead of mapM: the per-injection results are all unit, so there
    -- is no reason to collect them into a list.
    checkData (Data typ injs) = rf typ >> mapM_ checkInj injs

    checkInj (Inj _ typ) = rf typ

    checkBind (Bind _ _ rhs) = checkExp rhs

    checkBranch (Branch _ body) = checkExp body

    -- Only annotations carry types; every other expression form is either
    -- traversed for nested annotations or ignored.
    checkExp = \case
      EAnn e t -> checkExp e >> rf t
      EApp e1 e2 -> checkExp e1 >> checkExp e2
      EAdd e1 e2 -> checkExp e1 >> checkExp e2
      ELet bind e -> checkBind bind >> checkExp e
      EAbs _ e -> checkExp e
      ECase scrut branches -> checkExp scrut >> mapM_ checkBranch branches
      _ -> pure ()