Add implicit foralls for bidir, update and unify pipeline

This commit is contained in:
Martin Fredin 2023-04-03 17:34:33 +02:00
parent 12bca1c32d
commit 9870802371
33 changed files with 1010 additions and 1055 deletions

70
src/ReportForall.hs Normal file
View file

@ -0,0 +1,70 @@
{-# LANGUAGE LambdaCase #-}
module ReportForall (reportForall) where
import Auxiliary (partitionDefs)
import Control.Monad (unless, void, when)
import Control.Monad.Except (MonadError (throwError))
import Data.Either.Combinators (mapRight)
import Data.Foldable (foldlM)
import Data.Function (on)
import Data.List (delete)
import Grammar.Abs
import Grammar.ErrM (Err)
import TypeChecker.TypeChecker (TypeChecker (Bi, Hm))
reportForall :: TypeChecker -> Program -> Err ()
reportForall tc p = do
when (tc == Hm) $ rpProgram rpaType p
rpProgram rpuType p
rpuType :: Type -> Err ()
rpuType typ = do
tvars <- go [] typ
unless (null tvars) $ throwError "Unused forall"
where
go tvars = \case
TAll tvar t
| tvar `elem` tvars -> throwError "Duplicate forall"
| otherwise -> go (tvar : tvars) t
TVar tvar -> pure (delete tvar tvars)
TFun t1 t2 -> go tvars t1 >>= (`go` t2)
TData _ typs -> foldlM go tvars typs
_ -> pure tvars
rpaType :: Type -> Err ()
rpaType = rpForall . skipForall
where
skipForall = \case
TAll _ t -> skipForall t
t -> t
rpForall = \case
TAll {} -> throwError "Higher rank forall not allowed"
TFun t1 t2 -> on (>>) rpForall t1 t2
TData _ typs -> mapM_ rpForall typs
_ -> pure ()
rpProgram :: (Type -> Err ()) -> Program -> Err ()
rpProgram rf (Program defs) = do
mapM_ rpuBind bs
mapM_ rpuData ds
mapM_ rpuSig ss
where
(ds, ss, bs) = partitionDefs defs
rpuSig (Sig _ typ) = rf typ
rpuData (Data typ injs) = rf typ >> mapM rpuInj injs
rpuInj (Inj _ typ) = rf typ
rpuBind (Bind _ _ rhs) = rpuExp rhs
rpuBranch (Branch _ e) = rpuExp e
rpuExp = \case
EAnn e t -> rpuExp e >> rf t
EApp e1 e2 -> on (>>) rpuExp e1 e2
EAdd e1 e2 -> on (>>) rpuExp e1 e2
ELet bind e -> rpuBind bind >> rpuExp e
EAbs _ e -> rpuExp e
ECase e bs -> rpuExp e >> mapM_ rpuBranch bs
_ -> pure ()
reportAnyForall :: Program -> Err ()
reportAnyForall = undefined