Add implicit foralls for bidir, update and unify pipeline
This commit is contained in:
parent
12bca1c32d
commit
9870802371
33 changed files with 1010 additions and 1055 deletions
48
src/TypeChecker/RemoveForall.hs
Normal file
48
src/TypeChecker/RemoveForall.hs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
|
||||
module TypeChecker.RemoveForall (removeForall) where
|
||||
|
||||
import Auxiliary (onM)
|
||||
import Control.Applicative (Applicative (liftA2))
|
||||
import Data.Function (on)
|
||||
import Data.List (partition)
|
||||
import Data.Tuple.Extra (second)
|
||||
import Grammar.ErrM (Err)
|
||||
import qualified TypeChecker.ReportTEVar as R
|
||||
import TypeChecker.TypeCheckerIr
|
||||
|
||||
removeForall :: Program' R.Type -> Program
|
||||
removeForall (Program defs) = Program $ map (DData . rfData) ds
|
||||
++ map (DBind . rfBind) bs
|
||||
where
|
||||
(ds, bs) = ([d | DData d <- defs ], [ b | DBind b <- defs ])
|
||||
rfData (Data typ injs) = Data (rfType typ) (map rfInj injs)
|
||||
rfInj (Inj name typ) = Inj name (rfType typ)
|
||||
rfBind (Bind name vars rhs) = Bind (rfId name) (map rfId vars) (rfExpT rhs)
|
||||
rfId = second rfType
|
||||
rfExpT (e, t) = (rfExp e, rfType t)
|
||||
rfExp = \case
|
||||
EApp e1 e2 -> on EApp rfExpT e1 e2
|
||||
EAdd e1 e2 -> on EAdd rfExpT e1 e2
|
||||
ELet bind e -> ELet (rfBind bind) (rfExpT e)
|
||||
EAbs name e -> EAbs name (rfExpT e)
|
||||
ECase e bs -> ECase (rfExpT e) (map rfBranch bs)
|
||||
ELit lit -> ELit lit
|
||||
EVar name -> EVar name
|
||||
EInj name -> EInj name
|
||||
rfBranch (Branch (p, t) e) = Branch (rfPattern p, rfType t) (rfExpT e)
|
||||
rfPattern = \case
|
||||
PVar id -> PVar (rfId id)
|
||||
PLit (lit, t) -> PLit (lit, rfType t)
|
||||
PCatch -> PCatch
|
||||
PEnum name -> PEnum name
|
||||
PInj name ps -> PInj name (map rfPattern ps)
|
||||
|
||||
rfType :: R.Type -> Type
|
||||
rfType = \case
|
||||
R.TAll _ t -> rfType t
|
||||
R.TFun t1 t2 -> on TFun rfType t1 t2
|
||||
R.TData name ts -> TData name (map rfType ts)
|
||||
R.TLit lit -> TLit lit
|
||||
R.TVar tvar -> TVar tvar
|
||||
|
||||
|
|
@ -1,71 +0,0 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
|
||||
module TypeChecker.RemoveTEVar where
|
||||
|
||||
import Control.Applicative (Applicative (liftA2), liftA3)
|
||||
import Control.Monad.Except (MonadError (throwError))
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Tuple.Extra (secondM)
|
||||
import Grammar.Abs
|
||||
import Grammar.ErrM (Err)
|
||||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
|
||||
class RemoveTEVar a b where
|
||||
rmTEVar :: a -> Err b
|
||||
|
||||
instance RemoveTEVar (T.Program' Type) (T.Program' T.Type) where
|
||||
rmTEVar (T.Program defs) = T.Program <$> rmTEVar defs
|
||||
|
||||
instance RemoveTEVar (T.Def' Type) (T.Def' T.Type) where
|
||||
rmTEVar = \case
|
||||
T.DBind bind -> T.DBind <$> rmTEVar bind
|
||||
T.DData dat -> T.DData <$> rmTEVar dat
|
||||
|
||||
instance RemoveTEVar (T.Bind' Type) (T.Bind' T.Type) where
|
||||
rmTEVar (T.Bind id vars rhs) = liftA3 T.Bind (rmTEVar id) (rmTEVar vars) (rmTEVar rhs)
|
||||
|
||||
instance RemoveTEVar (T.Exp' Type) (T.Exp' T.Type) where
|
||||
rmTEVar exp = case exp of
|
||||
T.EVar name -> pure $ T.EVar name
|
||||
T.EInj name -> pure $ T.EInj name
|
||||
T.ELit lit -> pure $ T.ELit lit
|
||||
T.ELet bind e -> liftA2 T.ELet (rmTEVar bind) (rmTEVar e)
|
||||
T.EApp e1 e2 -> liftA2 T.EApp (rmTEVar e1) (rmTEVar e2)
|
||||
T.EAdd e1 e2 -> liftA2 T.EAdd (rmTEVar e1) (rmTEVar e2)
|
||||
T.EAbs name e -> T.EAbs name <$> rmTEVar e
|
||||
T.ECase e branches -> liftA2 T.ECase (rmTEVar e) (rmTEVar branches)
|
||||
|
||||
instance RemoveTEVar (T.Branch' Type) (T.Branch' T.Type) where
|
||||
rmTEVar (T.Branch (patt, t_patt) e) = liftA2 T.Branch (liftA2 (,) (rmTEVar patt) (rmTEVar t_patt)) (rmTEVar e)
|
||||
|
||||
instance RemoveTEVar (T.Pattern' Type) (T.Pattern' T.Type) where
|
||||
rmTEVar = \case
|
||||
T.PVar (name, t) -> T.PVar . (name,) <$> rmTEVar t
|
||||
T.PLit (lit, t) -> T.PLit . (lit,) <$> rmTEVar t
|
||||
T.PCatch -> pure T.PCatch
|
||||
T.PEnum name -> pure $ T.PEnum name
|
||||
T.PInj name ps -> T.PInj name <$> rmTEVar ps
|
||||
|
||||
instance RemoveTEVar (T.Data' Type) (T.Data' T.Type) where
|
||||
rmTEVar (T.Data typ injs) = liftA2 T.Data (rmTEVar typ) (rmTEVar injs)
|
||||
|
||||
instance RemoveTEVar (T.Inj' Type) (T.Inj' T.Type) where
|
||||
rmTEVar (T.Inj name typ) = T.Inj name <$> rmTEVar typ
|
||||
|
||||
instance RemoveTEVar (T.Id' Type) (T.Id' T.Type) where
|
||||
rmTEVar = secondM rmTEVar
|
||||
|
||||
instance RemoveTEVar (T.ExpT' Type) (T.ExpT' T.Type) where
|
||||
rmTEVar (exp, typ) = liftA2 (,) (rmTEVar exp) (rmTEVar typ)
|
||||
|
||||
instance RemoveTEVar a b => RemoveTEVar [a] [b] where
|
||||
rmTEVar = mapM rmTEVar
|
||||
|
||||
instance RemoveTEVar Type T.Type where
|
||||
rmTEVar = \case
|
||||
TLit lit -> pure $ T.TLit (coerce lit)
|
||||
TVar (MkTVar i) -> pure $ T.TVar (T.MkTVar $ coerce i)
|
||||
TData name typs -> T.TData (coerce name) <$> rmTEVar typs
|
||||
TFun t1 t2 -> liftA2 T.TFun (rmTEVar t1) (rmTEVar t2)
|
||||
TAll (MkTVar i) t -> T.TAll (T.MkTVar $ coerce i) <$> rmTEVar t
|
||||
TEVar _ -> throwError "NewType TEVar!"
|
||||
81
src/TypeChecker/ReportTEVar.hs
Normal file
81
src/TypeChecker/ReportTEVar.hs
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
|
||||
module TypeChecker.ReportTEVar where
|
||||
|
||||
import Auxiliary (onM)
|
||||
import Control.Applicative (Applicative (liftA2), liftA3)
|
||||
import Control.Monad.Except (MonadError (throwError))
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Tuple.Extra (secondM)
|
||||
import qualified Grammar.Abs as G
|
||||
import Grammar.ErrM (Err)
|
||||
import TypeChecker.TypeCheckerIr hiding (Type (..))
|
||||
|
||||
|
||||
data Type
|
||||
= TLit Ident
|
||||
| TVar TVar
|
||||
| TData Ident [Type]
|
||||
| TFun Type Type
|
||||
| TAll TVar Type
|
||||
deriving (Eq, Ord, Show, Read)
|
||||
|
||||
class ReportTEVar a b where
|
||||
reportTEVar :: a -> Err b
|
||||
|
||||
instance ReportTEVar (Program' G.Type) (Program' Type) where
|
||||
reportTEVar (Program defs) = Program <$> reportTEVar defs
|
||||
|
||||
instance ReportTEVar (Def' G.Type) (Def' Type) where
|
||||
reportTEVar = \case
|
||||
DBind bind -> DBind <$> reportTEVar bind
|
||||
DData dat -> DData <$> reportTEVar dat
|
||||
|
||||
instance ReportTEVar (Bind' G.Type) (Bind' Type) where
|
||||
reportTEVar (Bind id vars rhs) = liftA3 Bind (reportTEVar id) (reportTEVar vars) (reportTEVar rhs)
|
||||
|
||||
instance ReportTEVar (Exp' G.Type) (Exp' Type) where
|
||||
reportTEVar exp = case exp of
|
||||
EVar name -> pure $ EVar name
|
||||
EInj name -> pure $ EInj name
|
||||
ELit lit -> pure $ ELit lit
|
||||
ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e)
|
||||
EApp e1 e2 -> onM EApp reportTEVar e1 e2
|
||||
EAdd e1 e2 -> onM EAdd reportTEVar e1 e2
|
||||
EAbs name e -> EAbs name <$> reportTEVar e
|
||||
ECase e branches -> liftA2 ECase (reportTEVar e) (reportTEVar branches)
|
||||
|
||||
instance ReportTEVar (Branch' G.Type) (Branch' Type) where
|
||||
reportTEVar (Branch (patt, t_patt) e) = liftA2 Branch (liftA2 (,) (reportTEVar patt) (reportTEVar t_patt)) (reportTEVar e)
|
||||
|
||||
instance ReportTEVar (Pattern' G.Type) (Pattern' Type) where
|
||||
reportTEVar = \case
|
||||
PVar (name, t) -> PVar . (name,) <$> reportTEVar t
|
||||
PLit (lit, t) -> PLit . (lit,) <$> reportTEVar t
|
||||
PCatch -> pure PCatch
|
||||
PEnum name -> pure $ PEnum name
|
||||
PInj name ps -> PInj name <$> reportTEVar ps
|
||||
|
||||
instance ReportTEVar (Data' G.Type) (Data' Type) where
|
||||
reportTEVar (Data typ injs) = liftA2 Data (reportTEVar typ) (reportTEVar injs)
|
||||
|
||||
instance ReportTEVar (Inj' G.Type) (Inj' Type) where
|
||||
reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ
|
||||
|
||||
instance ReportTEVar (Id' G.Type) (Id' Type) where
|
||||
reportTEVar = secondM reportTEVar
|
||||
|
||||
instance ReportTEVar (ExpT' G.Type) (ExpT' Type) where
|
||||
reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ)
|
||||
|
||||
instance ReportTEVar a b => ReportTEVar [a] [b] where
|
||||
reportTEVar = mapM reportTEVar
|
||||
|
||||
instance ReportTEVar G.Type Type where
|
||||
reportTEVar = \case
|
||||
G.TLit lit -> pure $ TLit (coerce lit)
|
||||
G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i)
|
||||
G.TData name typs -> TData (coerce name) <$> reportTEVar typs
|
||||
G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2)
|
||||
G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t
|
||||
G.TEVar _ -> throwError "NewType TEVar!"
|
||||
|
|
@ -1,17 +1,19 @@
|
|||
module TypeChecker.TypeChecker (typecheck, TypeChecker (..)) where
|
||||
|
||||
import Control.Monad ((<=<))
|
||||
import Grammar.Abs
|
||||
import Grammar.ErrM (Err)
|
||||
import TypeChecker.RemoveTEVar (RemoveTEVar (rmTEVar))
|
||||
import TypeChecker.TypeCheckerBidir qualified as Bi
|
||||
import TypeChecker.TypeCheckerHm qualified as Hm
|
||||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
import Control.Monad ((<=<))
|
||||
import qualified Grammar.Abs as G
|
||||
import Grammar.ErrM (Err)
|
||||
import TypeChecker.RemoveForall (removeForall)
|
||||
import qualified TypeChecker.ReportTEVar as R
|
||||
import TypeChecker.ReportTEVar (reportTEVar)
|
||||
import qualified TypeChecker.TypeCheckerBidir as Bi
|
||||
import qualified TypeChecker.TypeCheckerHm as Hm
|
||||
import TypeChecker.TypeCheckerIr
|
||||
|
||||
data TypeChecker = Bi | Hm
|
||||
data TypeChecker = Bi | Hm deriving Eq
|
||||
|
||||
typecheck :: TypeChecker -> Program -> Err T.Program
|
||||
typecheck tc = rmTEVar <=< f
|
||||
typecheck :: TypeChecker -> G.Program -> Err Program
|
||||
typecheck tc = fmap removeForall . (reportTEVar <=< f)
|
||||
where
|
||||
f = case tc of
|
||||
Bi -> Bi.typecheck
|
||||
|
|
|
|||
|
|
@ -121,6 +121,7 @@ typecheckBind (Bind name vars rhs) = do
|
|||
, "Did you forget to add type annotation to a polymorphic function?"
|
||||
]
|
||||
|
||||
-- TODO remove some checks
|
||||
typecheckDataType :: Data -> Err (T.Data' Type)
|
||||
typecheckDataType (Data typ injs) = do
|
||||
(name, tvars) <- go [] typ
|
||||
|
|
@ -135,6 +136,7 @@ typecheckDataType (Data typ injs) = do
|
|||
-> pure (name, tvars')
|
||||
_ -> throwError $ unwords ["Bad data type definition: ", ppT typ]
|
||||
|
||||
-- TODO remove some checks
|
||||
typecheckInj :: Inj -> UIdent -> [TVar] -> Err (T.Inj' Type)
|
||||
typecheckInj (Inj inj_name inj_typ) name tvars
|
||||
| not $ boundTVars tvars inj_typ
|
||||
|
|
@ -878,18 +880,18 @@ traceTs s xs = trace (s ++ " [ " ++ intercalate ", " (map ppT xs) ++ " ]") pure
|
|||
|
||||
ppT = \case
|
||||
TLit (UIdent s) -> s
|
||||
TVar (MkTVar (LIdent s)) -> "α_" ++ s
|
||||
TFun t1 t2 -> ppT t1 ++ "→" ++ ppT t2
|
||||
TVar (MkTVar (LIdent s)) -> "a_" ++ s
|
||||
TFun t1 t2 -> ppT t1 ++ "->" ++ ppT t2
|
||||
TAll (MkTVar (LIdent s)) t -> "forall " ++ s ++ ". " ++ ppT t
|
||||
TEVar (MkTEVar (LIdent s)) -> "ά_" ++ s
|
||||
TEVar (MkTEVar (LIdent s)) -> "a^_" ++ s
|
||||
TData (UIdent name) typs -> name ++ " (" ++ unwords (map ppT typs)
|
||||
++ " )"
|
||||
ppEnvElem = \case
|
||||
EnvVar (LIdent s) t -> s ++ ":" ++ ppT t
|
||||
EnvTVar (MkTVar (LIdent s)) -> "α_" ++ s
|
||||
EnvTEVar (MkTEVar (LIdent s)) -> "ά_" ++ s
|
||||
EnvTEVarSolved (MkTEVar (LIdent s)) t -> "ά_" ++ s ++ "=" ++ ppT t
|
||||
EnvMark (MkTEVar (LIdent s)) -> "▶" ++ "ά_" ++ s
|
||||
EnvTVar (MkTVar (LIdent s)) -> "a_" ++ s
|
||||
EnvTEVar (MkTEVar (LIdent s)) -> "a^_" ++ s
|
||||
EnvTEVarSolved (MkTEVar (LIdent s)) t -> "_" ++ s ++ "=" ++ ppT t
|
||||
EnvMark (MkTEVar (LIdent s)) -> "▶" ++ "a^_" ++ s
|
||||
|
||||
ppEnv = \case
|
||||
Empty -> "·"
|
||||
|
|
|
|||
|
|
@ -1,31 +1,31 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QualifiedDo #-}
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
{-# LANGUAGE QualifiedDo #-}
|
||||
|
||||
-- | A module for type checking and inference using algorithm W, Hindley-Milner
|
||||
module TypeChecker.TypeCheckerHm where
|
||||
|
||||
import Auxiliary (int, litType, maybeToRightM, unzip4)
|
||||
import Auxiliary qualified as Aux
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Identity (Identity, runIdentity)
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Control.Monad.Writer
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Function (on)
|
||||
import Data.List (foldl', nub, sortOn)
|
||||
import Data.List.Extra (unsnoc)
|
||||
import Data.Map (Map)
|
||||
import Data.Map qualified as M
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.Set (Set)
|
||||
import Data.Set qualified as S
|
||||
import Debug.Trace (trace)
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import TypeChecker.TypeCheckerIr qualified as T
|
||||
import Auxiliary (int, litType, maybeToRightM, unzip4)
|
||||
import qualified Auxiliary as Aux
|
||||
import Control.Monad.Except
|
||||
import Control.Monad.Identity (Identity, runIdentity)
|
||||
import Control.Monad.Reader
|
||||
import Control.Monad.State
|
||||
import Control.Monad.Writer
|
||||
import Data.Coerce (coerce)
|
||||
import Data.Function (on)
|
||||
import Data.List (foldl', nub, sortOn)
|
||||
import Data.List.Extra (unsnoc)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as M
|
||||
import Data.Maybe (fromJust)
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as S
|
||||
import Debug.Trace (trace)
|
||||
import Grammar.Abs
|
||||
import Grammar.Print (printTree)
|
||||
import qualified TypeChecker.TypeCheckerIr as T
|
||||
|
||||
-- TODO: Disallow mutual recursion
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ typecheck :: Program -> Either String (T.Program' Type, [Warning])
|
|||
typecheck = onLeft msg . run . checkPrg
|
||||
where
|
||||
onLeft :: (Error -> String) -> Either Error a -> Either String a
|
||||
onLeft f (Left x) = Left $ f x
|
||||
onLeft f (Left x) = Left $ f x
|
||||
onLeft _ (Right x) = Right x
|
||||
|
||||
checkPrg :: Program -> Infer (T.Program' Type)
|
||||
|
|
@ -118,7 +118,7 @@ preRun (x : xs) = case x of
|
|||
s <- gets sigs
|
||||
case M.lookup (coerce n) s of
|
||||
Nothing -> insertSig (coerce n) Nothing >> preRun xs
|
||||
Just _ -> preRun xs
|
||||
Just _ -> preRun xs
|
||||
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
|
||||
where
|
||||
-- Check if function body / signature has been declared already
|
||||
|
|
@ -140,11 +140,11 @@ checkDef (x : xs) = case x of
|
|||
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
|
||||
|
||||
freeOrdered :: Type -> [T.Ident]
|
||||
freeOrdered (TVar (MkTVar a)) = return (coerce a)
|
||||
freeOrdered (TVar (MkTVar a)) = return (coerce a)
|
||||
freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t
|
||||
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
|
||||
freeOrdered (TData _ a) = concatMap freeOrdered a
|
||||
freeOrdered _ = mempty
|
||||
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
|
||||
freeOrdered (TData _ a) = concatMap freeOrdered a
|
||||
freeOrdered _ = mempty
|
||||
|
||||
checkBind :: Bind -> Infer (T.Bind' Type)
|
||||
checkBind (Bind name args e) = do
|
||||
|
|
@ -178,22 +178,19 @@ checkBind (Bind name args e) = do
|
|||
|
||||
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
|
||||
checkData err@(Data typ injs) = do
|
||||
(name, tvars) <- go typ
|
||||
(name, tvars) <- go (skipForalls typ)
|
||||
dataErr (mapM_ (\i -> checkInj i name tvars) injs) err
|
||||
where
|
||||
go = \case
|
||||
TData name typs
|
||||
| Right tvars' <- mapM toTVar typs ->
|
||||
pure (name, tvars')
|
||||
TAll _ _ -> uncatchableErr "Explicit forall not allowed, for now"
|
||||
_ ->
|
||||
uncatchableErr $
|
||||
unwords ["Bad data type definition: ", printTree typ]
|
||||
|
||||
checkInj :: (MonadError Error m, MonadState Env m, Monad m) => Inj -> UIdent -> [TVar] -> m ()
|
||||
checkInj (Inj c inj_typ) name tvars
|
||||
| Right False <- boundTVars tvars inj_typ =
|
||||
catchableErr "Unbound type variables"
|
||||
| TData name' typs <- returnType inj_typ
|
||||
, Right tvars' <- mapM toTVar typs
|
||||
, name' == name
|
||||
|
|
@ -217,27 +214,15 @@ checkInj (Inj c inj_typ) name tvars
|
|||
, "\nActual: "
|
||||
, printTree $ returnType inj_typ
|
||||
]
|
||||
where
|
||||
boundTVars :: [TVar] -> Type -> Either Error Bool
|
||||
boundTVars tvars' = \case
|
||||
TAll{} -> uncatchableErr "Explicit forall not allowed, for now"
|
||||
TFun t1 t2 -> do
|
||||
t1' <- boundTVars tvars t1
|
||||
t2' <- boundTVars tvars t2
|
||||
return $ t1' && t2'
|
||||
TVar tvar -> return $ tvar `elem` tvars'
|
||||
TData _ typs -> and <$> mapM (boundTVars tvars) typs
|
||||
TLit _ -> return True
|
||||
TEVar _ -> error "TEVar in data type declaration"
|
||||
|
||||
toTVar :: Type -> Either Error TVar
|
||||
toTVar = \case
|
||||
TVar tvar -> pure tvar
|
||||
_ -> uncatchableErr "Not a type variable"
|
||||
_ -> uncatchableErr "Not a type variable"
|
||||
|
||||
returnType :: Type -> Type
|
||||
returnType (TFun _ t2) = returnType t2
|
||||
returnType a = a
|
||||
returnType a = a
|
||||
|
||||
inferExp :: Exp -> Infer (T.ExpT' Type)
|
||||
inferExp e = do
|
||||
|
|
@ -250,7 +235,7 @@ class CollectTVars a where
|
|||
|
||||
instance CollectTVars Exp where
|
||||
collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e
|
||||
collectTVars _ = S.empty
|
||||
collectTVars _ = S.empty
|
||||
|
||||
instance CollectTVars Type where
|
||||
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
|
||||
|
|
@ -569,12 +554,12 @@ generalize :: Map T.Ident Type -> Type -> Type
|
|||
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
|
||||
where
|
||||
go :: [T.Ident] -> Type -> Type
|
||||
go [] t = t
|
||||
go [] t = t
|
||||
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
|
||||
removeForalls :: Type -> Type
|
||||
removeForalls (TAll _ t) = removeForalls t
|
||||
removeForalls (TAll _ t) = removeForalls t
|
||||
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
|
||||
removeForalls t = t
|
||||
removeForalls t = t
|
||||
|
||||
{- | Instantiate a polymorphic type. The free type variables are substituted
|
||||
with fresh ones.
|
||||
|
|
@ -611,42 +596,39 @@ currently this is not the case, the TAll pattern match is incorrectly implemente
|
|||
-- Is the left a subtype of the right
|
||||
(<<=) :: Type -> Type -> Bool
|
||||
(<<=) (TVar _) _ = True
|
||||
(<<=) (TAll _ t1) (TAll _ t2) = t1 <<= t2
|
||||
(<<=) t1@TAll{} t2 = skipForalls t1 <<= t2
|
||||
(<<=) t1 t2@TAll{} = t1 <<= skipForalls t2
|
||||
(<<=) (TFun a b) (TFun c d) = a <<= c && b <<= d
|
||||
(<<=) (TData n1 ts1) (TData n2 ts2) =
|
||||
n1 == n2
|
||||
&& length ts1 == length ts2
|
||||
&& and (zipWith (<<=) ts1 ts2)
|
||||
(<<=) t0 t@(TAll _ _) = go t0 t
|
||||
where
|
||||
go t0 t@(TAll _ t1) = S.toList (free t0) == foralls t && go' t0 t1
|
||||
go _ _ = undefined
|
||||
|
||||
go' (TEVar (MkTEVar a)) (TVar (MkTVar b)) = a == b
|
||||
go' (TEVar (MkTEVar a)) (TEVar (MkTEVar b)) = a == b
|
||||
go' (TFun a b) (TFun c d) = a `go'` c && b `go'` d
|
||||
go' _ _ = False
|
||||
(<<=) a b = a == b
|
||||
|
||||
skipForalls :: Type -> Type
|
||||
skipForalls = \case
|
||||
TAll _ t -> t
|
||||
t -> t
|
||||
|
||||
foralls :: Type -> [T.Ident]
|
||||
foralls (TAll (MkTVar a) t) = coerce a : foralls t
|
||||
foralls _ = []
|
||||
foralls _ = []
|
||||
|
||||
mkForall :: Type -> Type
|
||||
mkForall t = case map (TAll . MkTVar . coerce) $ S.toList $ free t of
|
||||
[] -> t
|
||||
(x : xs) ->
|
||||
let f acc [] = acc
|
||||
let f acc [] = acc
|
||||
f acc (x : xs) = f (x acc) xs
|
||||
(y : ys) = reverse $ x : xs
|
||||
in f (y t) ys
|
||||
|
||||
skolemize :: Type -> Type
|
||||
skolemize (TVar (MkTVar a)) = TEVar $ MkTEVar a
|
||||
skolemize (TAll x t) = TAll x (skolemize t)
|
||||
skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2
|
||||
skolemize (TData n ts) = TData n (map skolemize ts)
|
||||
skolemize t = t
|
||||
skolemize (TAll x t) = TAll x (skolemize t)
|
||||
skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2
|
||||
skolemize (TData n ts) = TData n (map skolemize ts)
|
||||
skolemize t = t
|
||||
|
||||
-- | A class for substitutions
|
||||
class SubstType t where
|
||||
|
|
@ -680,10 +662,10 @@ instance SubstType Type where
|
|||
TLit _ -> t
|
||||
TVar (MkTVar a) -> case M.lookup (coerce a) sub of
|
||||
Nothing -> TVar (MkTVar $ coerce a)
|
||||
Just t -> t
|
||||
Just t -> t
|
||||
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
|
||||
Nothing -> TAll (MkTVar i) (apply sub t)
|
||||
Just _ -> apply sub t
|
||||
Just _ -> apply sub t
|
||||
TFun a b -> TFun (apply sub a) (apply sub b)
|
||||
TData name a -> TData name (apply sub a)
|
||||
TEVar (MkTEVar _) -> t
|
||||
|
|
@ -728,10 +710,10 @@ instance SubstType (T.Branch' Type) where
|
|||
instance SubstType (T.Pattern' Type) where
|
||||
apply s = \case
|
||||
T.PVar (iden, t) -> T.PVar (iden, apply s t)
|
||||
T.PLit (lit, t) -> T.PLit (lit, apply s t)
|
||||
T.PInj i ps -> T.PInj i $ apply s ps
|
||||
T.PCatch -> T.PCatch
|
||||
T.PEnum i -> T.PEnum i
|
||||
T.PLit (lit, t) -> T.PLit (lit, apply s t)
|
||||
T.PInj i ps -> T.PInj i $ apply s ps
|
||||
T.PCatch -> T.PCatch
|
||||
T.PEnum i -> T.PEnum i
|
||||
|
||||
instance SubstType (T.Pattern' Type, Type) where
|
||||
apply s (p, t) = (apply s p, apply s t)
|
||||
|
|
@ -773,10 +755,10 @@ withBindings xs =
|
|||
withPattern :: (Monad m, MonadReader Ctx m) => T.Pattern' Type -> m a -> m a
|
||||
withPattern p ma = case p of
|
||||
T.PVar (x, t) -> withBinding x t ma
|
||||
T.PInj _ ps -> foldl' (flip withPattern) ma ps
|
||||
T.PLit _ -> ma
|
||||
T.PCatch -> ma
|
||||
T.PEnum _ -> ma
|
||||
T.PInj _ ps -> foldl' (flip withPattern) ma ps
|
||||
T.PLit _ -> ma
|
||||
T.PCatch -> ma
|
||||
T.PEnum _ -> ma
|
||||
|
||||
-- | Insert a function signature into the environment
|
||||
insertSig :: T.Ident -> Maybe Type -> Infer ()
|
||||
|
|
@ -801,11 +783,11 @@ existInj n = gets (M.lookup n . injections)
|
|||
|
||||
flattenType :: Type -> [Type]
|
||||
flattenType (TFun a b) = flattenType a <> flattenType b
|
||||
flattenType a = [a]
|
||||
flattenType a = [a]
|
||||
|
||||
typeLength :: Type -> Int
|
||||
typeLength (TFun _ b) = 1 + typeLength b
|
||||
typeLength _ = 1
|
||||
typeLength _ = 1
|
||||
|
||||
{- | Catch an error if possible and add the given
|
||||
expression as addition to the error message
|
||||
|
|
@ -888,11 +870,11 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
|
|||
deriving (Show)
|
||||
|
||||
data Env = Env
|
||||
{ count :: Int
|
||||
, nextChar :: Char
|
||||
, sigs :: Map T.Ident (Maybe Type)
|
||||
{ count :: Int
|
||||
, nextChar :: Char
|
||||
, sigs :: Map T.Ident (Maybe Type)
|
||||
, takenTypeVars :: Set T.Ident
|
||||
, injections :: Map T.Ident Type
|
||||
, injections :: Map T.Ident Type
|
||||
, declaredBinds :: Set T.Ident
|
||||
}
|
||||
deriving (Show)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# LANGUAGE PatternSynonyms #-}
|
||||
|
||||
module TypeChecker.TypeCheckerIr (
|
||||
|
|
@ -6,11 +6,11 @@ module TypeChecker.TypeCheckerIr (
|
|||
module TypeChecker.TypeCheckerIr,
|
||||
) where
|
||||
|
||||
import Data.String (IsString)
|
||||
import Grammar.Abs (Lit (..))
|
||||
import Grammar.Print
|
||||
import Prelude
|
||||
import Prelude qualified as C (Eq, Ord, Read, Show)
|
||||
import Data.String (IsString)
|
||||
import Grammar.Abs (Lit (..))
|
||||
import Grammar.Print
|
||||
import Prelude
|
||||
import qualified Prelude as C (Eq, Ord, Read, Show)
|
||||
|
||||
newtype Program' t = Program [Def' t]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
||||
|
|
@ -25,8 +25,7 @@ data Type
|
|||
| TVar TVar
|
||||
| TData Ident [Type]
|
||||
| TFun Type Type
|
||||
| TAll TVar Type
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read)
|
||||
deriving (Eq, Ord, Show, Read)
|
||||
|
||||
data Data' t = Data t [Inj' t]
|
||||
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
|
||||
|
|
@ -105,8 +104,8 @@ instance Print t => Print (ExpT' t) where
|
|||
]
|
||||
|
||||
instance Print t => Print [Bind' t] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
prtIdPs :: Print t => Int -> [Id' t] -> Doc
|
||||
|
|
@ -171,13 +170,13 @@ instance Print t => Print (Branch' t) where
|
|||
prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp])
|
||||
|
||||
instance Print t => Print [Branch' t] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
instance Print t => Print (Def' t) where
|
||||
prt i = \case
|
||||
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
|
||||
DBind bind -> prPrec i 0 (concatD [prt 0 bind])
|
||||
DData data_ -> prPrec i 0 (concatD [prt 0 data_])
|
||||
|
||||
instance Print t => Print (Data' t) where
|
||||
|
|
@ -202,12 +201,12 @@ instance Print t => Print (Pattern' t) where
|
|||
PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns])
|
||||
|
||||
instance Print t => Print [Def' t] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ [] = concatD []
|
||||
prt _ [x] = concatD [prt 0 x]
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs]
|
||||
|
||||
instance Print [Type] where
|
||||
prt _ [] = concatD []
|
||||
prt _ [] = concatD []
|
||||
prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs]
|
||||
|
||||
instance Print Type where
|
||||
|
|
@ -216,7 +215,6 @@ instance Print Type where
|
|||
TVar tvar -> prPrec i 1 (concatD [prt 0 tvar])
|
||||
TData uident types -> prPrec i 1 (concatD [prt 0 uident, doc (showString "("), prt 0 types, doc (showString ")")])
|
||||
TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2])
|
||||
TAll tvar type_ -> prPrec i 0 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_])
|
||||
|
||||
instance Print TVar where
|
||||
prt i (MkTVar ident) = prt i ident
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue