Add implicit foralls for bidir, update and unify pipeline
This commit is contained in:
parent
12bca1c32d
commit
9870802371
33 changed files with 1010 additions and 1055 deletions
100
src/AnnForall.hs
Normal file
100
src/AnnForall.hs
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
|
||||
module AnnForall (annotateForall) where
|
||||
|
||||
import Auxiliary (partitionDefs)
|
||||
import Control.Applicative (Applicative (liftA2))
|
||||
import Control.Monad.Except (throwError)
|
||||
import Data.Function (on)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as Set
|
||||
import Grammar.Abs
|
||||
import Grammar.ErrM (Err)
|
||||
|
||||
annotateForall :: Program -> Err Program
|
||||
annotateForall (Program defs) = do
|
||||
ds' <- mapM (fmap DData . annData) ds
|
||||
bs' <- mapM (fmap DBind . annBind) bs
|
||||
pure $ Program (ds' ++ ss' ++ bs')
|
||||
where
|
||||
ss' = map (DSig . annSig) ss
|
||||
(ds, ss, bs) = partitionDefs defs
|
||||
|
||||
|
||||
annData :: Data -> Err Data
|
||||
annData (Data typ injs) = do
|
||||
(typ', tvars) <- annTyp typ
|
||||
pure (Data typ' $ map (annInj tvars) injs)
|
||||
|
||||
where
|
||||
annTyp typ = do
|
||||
(bounded, ts) <- boundedTVars mempty typ
|
||||
unbounded <- Set.fromList <$> mapM assertTVar ts
|
||||
let diff = unbounded Set.\\ bounded
|
||||
typ' = foldr TAll typ diff
|
||||
(typ', ) . fst <$> boundedTVars mempty typ'
|
||||
where
|
||||
boundedTVars tvars typ = case typ of
|
||||
TAll tvar t -> boundedTVars (Set.insert tvar tvars) t
|
||||
TData _ ts -> pure (tvars, ts)
|
||||
_ -> throwError "Misformed data declaration"
|
||||
|
||||
assertTVar typ = case typ of
|
||||
TVar tvar -> pure tvar
|
||||
_ -> throwError $ unwords [ "Misformed data declaration:"
|
||||
, "Non type variable argument"
|
||||
]
|
||||
annInj tvars (Inj n t) =
|
||||
Inj n $ foldr TAll t (unboundedTVars t Set.\\ tvars)
|
||||
|
||||
annSig :: Sig -> Sig
|
||||
annSig (Sig name typ) = Sig name $ annType typ
|
||||
|
||||
annBind :: Bind -> Err Bind
|
||||
annBind (Bind name vars exp) = Bind name vars <$> annExp exp
|
||||
where
|
||||
annExp = \case
|
||||
EAnn e t -> flip EAnn (annType t) <$> annExp e
|
||||
EApp e1 e2 -> liftA2 EApp (annExp e1) (annExp e2)
|
||||
EAdd e1 e2 -> liftA2 EAdd (annExp e1) (annExp e2)
|
||||
ELet bind e -> liftA2 ELet (annBind bind) (annExp e)
|
||||
EAbs x e -> EAbs x <$> annExp e
|
||||
ECase e bs -> liftA2 ECase (annExp e) (mapM annBranch bs)
|
||||
e -> pure e
|
||||
annBranch (Branch p e) = Branch p <$> annExp e
|
||||
|
||||
annType :: Type -> Type
|
||||
annType typ = go $ unboundedTVars typ
|
||||
where
|
||||
go us
|
||||
| null us = typ
|
||||
| otherwise = foldr TAll typ us
|
||||
|
||||
unboundedTVars :: Type -> Set TVar
|
||||
unboundedTVars = unboundedTVars' mempty
|
||||
|
||||
unboundedTVars' :: Set TVar -> Type -> Set TVar
|
||||
unboundedTVars' bs typ = tvars.unbounded Set.\\ tvars.bounded
|
||||
where
|
||||
tvars = gatherTVars typ
|
||||
gatherTVars = \case
|
||||
TAll tvar t -> TVars { bounded = Set.singleton tvar
|
||||
, unbounded = unboundedTVars' (Set.insert tvar bs) t
|
||||
}
|
||||
TVar tvar -> uTVars $ Set.singleton tvar
|
||||
TFun t1 t2 -> uTVars $ on Set.union (unboundedTVars' bs) t1 t2
|
||||
TData _ typs -> uTVars $ foldr (Set.union . unboundedTVars' bs) mempty typs
|
||||
_ -> TVars { bounded = mempty, unbounded = mempty }
|
||||
|
||||
data TVars = TVars
|
||||
{ bounded :: Set TVar
|
||||
, unbounded :: Set TVar
|
||||
} deriving (Eq, Show, Ord)
|
||||
|
||||
uTVars :: Set TVar -> TVars
|
||||
uTVars us = TVars
|
||||
{ bounded = mempty
|
||||
, unbounded = us
|
||||
}
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue