Merge closures mostly done. Desugaring cases is a problem.

This commit is contained in:
Martin Fredin 2023-05-06 23:38:56 +02:00
commit 019ed0d45a
29 changed files with 1484 additions and 757 deletions

View file

@ -2,15 +2,15 @@
module TypeChecker.ReportTEVar where
import Auxiliary (onM)
import Control.Applicative (Applicative (liftA2), liftA3)
import Control.Monad.Except (MonadError (throwError))
import Data.Coerce (coerce)
import Data.Tuple.Extra (secondM)
import Grammar.Abs qualified as G
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr hiding (Type (..))
import Auxiliary (onM)
import Control.Applicative (Applicative (liftA2), liftA3)
import Control.Monad.Except (MonadError (throwError))
import Data.Coerce (coerce)
import Data.Tuple.Extra (secondM)
import qualified Grammar.Abs as G
import Grammar.ErrM (Err)
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr hiding (Type (..))
data Type
= TLit Ident
@ -18,7 +18,7 @@ data Type
| TData Ident [Type]
| TFun Type Type
| TAll TVar Type
deriving (Eq, Ord, Show, Read)
deriving (Eq, Ord, Show)
class ReportTEVar a b where
reportTEVar :: a -> Err b
@ -29,20 +29,20 @@ instance ReportTEVar (Program' G.Type) (Program' Type) where
instance ReportTEVar (Def' G.Type) (Def' Type) where
reportTEVar = \case
DBind bind -> DBind <$> reportTEVar bind
DData dat -> DData <$> reportTEVar dat
DData dat -> DData <$> reportTEVar dat
instance ReportTEVar (Bind' G.Type) (Bind' Type) where
reportTEVar (Bind id vars rhs) = liftA3 Bind (reportTEVar id) (reportTEVar vars) (reportTEVar rhs)
instance ReportTEVar (Exp' G.Type) (Exp' Type) where
reportTEVar exp = case exp of
EVar name -> pure $ EVar name
EInj name -> pure $ EInj name
ELit lit -> pure $ ELit lit
ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e)
EApp e1 e2 -> onM EApp reportTEVar e1 e2
EAdd e1 e2 -> onM EAdd reportTEVar e1 e2
EAbs name e -> EAbs name <$> reportTEVar e
EVar name -> pure $ EVar name
EInj name -> pure $ EInj name
ELit lit -> pure $ ELit lit
ELet bind e -> liftA2 ELet (reportTEVar bind) (reportTEVar e)
EApp e1 e2 -> onM EApp reportTEVar e1 e2
EAdd e1 e2 -> onM EAdd reportTEVar e1 e2
EAbs name e -> EAbs name <$> reportTEVar e
ECase e branches -> liftA2 ECase (reportTEVar e) (reportTEVar branches)
instance ReportTEVar (Branch' G.Type) (Branch' Type) where
@ -53,10 +53,10 @@ instance ReportTEVar (Pattern' G.Type, G.Type) (Pattern' Type, Type) where
instance ReportTEVar (Pattern' G.Type) (Pattern' Type) where
reportTEVar = \case
PVar name -> pure $ PVar name
PLit lit -> pure $ PLit lit
PCatch -> pure PCatch
PEnum name -> pure $ PEnum name
PVar name -> pure $ PVar name
PLit lit -> pure $ PLit lit
PCatch -> pure PCatch
PEnum name -> pure $ PEnum name
PInj name ps -> PInj name <$> reportTEVar ps
instance ReportTEVar (Data' G.Type) (Data' Type) where
@ -65,10 +65,10 @@ instance ReportTEVar (Data' G.Type) (Data' Type) where
instance ReportTEVar (Inj' G.Type) (Inj' Type) where
reportTEVar (Inj name typ) = Inj name <$> reportTEVar typ
instance ReportTEVar (Id' G.Type) (Id' Type) where
instance ReportTEVar (a, G.Type) (a, Type) where
reportTEVar = secondM reportTEVar
instance ReportTEVar (ExpT' G.Type) (ExpT' Type) where
instance ReportTEVar (T' Exp' G.Type) (T' Exp' Type) where
reportTEVar (exp, typ) = liftA2 (,) (reportTEVar exp) (reportTEVar typ)
instance ReportTEVar a b => ReportTEVar [a] [b] where
@ -76,9 +76,9 @@ instance ReportTEVar a b => ReportTEVar [a] [b] where
instance ReportTEVar G.Type Type where
reportTEVar = \case
G.TLit lit -> pure $ TLit (coerce lit)
G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i)
G.TData name typs -> TData (coerce name) <$> reportTEVar typs
G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2)
G.TLit lit -> pure $ TLit (coerce lit)
G.TVar (G.MkTVar i) -> pure $ TVar (MkTVar $ coerce i)
G.TData name typs -> TData (coerce name) <$> reportTEVar typs
G.TFun t1 t2 -> liftA2 TFun (reportTEVar t1) (reportTEVar t2)
G.TAll (G.MkTVar i) t -> TAll (MkTVar $ coerce i) <$> reportTEVar t
G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar)
G.TEVar tevar -> throwError ("Found TEVar: " ++ printTree tevar)

View file

@ -31,6 +31,7 @@ import Grammar.ErrM
import Grammar.Print (printTree)
import Prelude hiding (exp)
import qualified TypeChecker.TypeCheckerIr as T
import TypeChecker.TypeCheckerIr (T, T')
-- Implementation is derived from the paper (Dunfield and Krishnaswami 2013)
-- https://doi.org/10.1145/2500365.2500582
@ -172,7 +173,7 @@ typecheckInj (Inj inj_name inj_typ) name tvars
-- | Γ ⊢ e ↑ A ⊣ Δ
-- Under input context Γ, e checks against input type A, with output context ∆
check :: Exp -> Type -> Tc (T.ExpT' Type)
check :: Exp -> Type -> Tc (T' T.Exp' Type)
-- Γ,α ⊢ e ↑ A ⊣ Δ,α
-- ------------------- ∀I
@ -212,12 +213,6 @@ check (ECase scrut pi) c = do
e' <- check e c
pure (T.Branch p' e')
apply (T.ECase (scrut', a) pi', c)
where
go (pi, b) (Branch p e) = do
p' <- checkPattern p =<< apply a
e'@(_, b') <- infer e
subtype b' b
apply (T.Branch p' e' : pi, b')
-- Γ,α ⊢ e ↓ A ⊣ Θ Θ ⊢ [Θ]A <: [Θ]B ⊣ Δ
@ -229,9 +224,6 @@ check e b = do
subtype a b'
apply (e', b)
checkPattern :: Pattern -> Type -> Tc (T.Pattern' Type, Type)
checkPattern patt t_patt = case patt of
@ -297,7 +289,7 @@ checkPattern patt t_patt = case patt of
-- | Γ ⊢ e ↓ A ⊣ Δ
-- Under input context Γ, e infers output type A, with output context ∆
infer :: Exp -> Tc (T.ExpT' Type)
infer :: Exp -> Tc (T' T.Exp' Type)
infer (ELit lit) = apply (T.ELit lit, litType lit)
-- Γ ∋ (x : A) Γ ⊢ rec(x)
@ -391,7 +383,7 @@ infer (ECase scrut pi) = do
-- | Γ ⊢ A • e ⇓ C ⊣ Δ
-- Under input context Γ , applying a function of type A to e infers type C, with output context ∆
-- Instantiate existential type variables until there is an arrow type.
applyInfer :: Type -> Exp -> Tc (T.ExpT' Type, Type)
applyInfer :: Type -> Exp -> Tc (T' T.Exp' Type, Type)
-- Γ,ά ⊢ [ά/α]A • e ⇓ C ⊣ Δ
-- ------------------------ ∀App

View file

@ -1,32 +1,33 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-}
{-# OPTIONS_GHC -Wno-incomplete-patterns #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where
import Auxiliary (int, litType, maybeToRightM, unzip4)
import Auxiliary qualified as Aux
import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl', nub, sortOn)
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace, traceShow)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T
import Auxiliary (int, litType, maybeToRightM, unzip4)
import qualified Auxiliary as Aux
import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl', nub, sortOn)
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace (trace, traceShow)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr (T, T')
import qualified TypeChecker.TypeCheckerIr as T
{-
TODO
@ -41,7 +42,7 @@ typecheck :: Program -> Either String (T.Program' Type, [Warning])
typecheck = onLeft msg . run . checkPrg
where
onLeft :: (Error -> String) -> Either Error a -> Either String a
onLeft f (Left x) = Left $ f x
onLeft f (Left x) = Left $ f x
onLeft _ (Right x) = Right x
checkPrg :: Program -> Infer (T.Program' Type)
@ -68,13 +69,13 @@ prettify s (T.Program defs) = T.Program $ map (go s) defs
replace :: Map T.Ident T.Ident -> Type -> Type
replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of
Just t -> TVar . MkTVar . LIdent $ coerce t
Just t -> TVar . MkTVar . LIdent $ coerce t
Nothing -> def
replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2
replace m (TData name ts) = TData name (map (replace m) ts)
replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of
Just found -> TAll (MkTVar $ coerce found) (replace m t)
Nothing -> def
Nothing -> def
replace _ t = t
bindCount :: [Def] -> Infer [(Int, Def)]
@ -128,7 +129,7 @@ preRun (x : xs) = case x of
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
where
-- Check if function body / signature has been declared already
@ -150,11 +151,11 @@ checkDef (x : xs) = case x of
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
freeOrdered :: Type -> [T.Ident]
freeOrdered (TVar (MkTVar a)) = return (coerce a)
freeOrdered (TVar (MkTVar a)) = return (coerce a)
freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
freeOrdered (TData _ a) = concatMap freeOrdered a
freeOrdered _ = mempty
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
freeOrdered (TData _ a) = concatMap freeOrdered a
freeOrdered _ = mempty
-- Much cleaner implementation, unfortunately one minor bug
-- checkBind :: Bind -> Infer (T.Bind' Type)
@ -259,13 +260,13 @@ checkInj (Inj c inj_typ) name tvars
toTVar :: Type -> Either Error TVar
toTVar = \case
TVar tvar -> pure tvar
_ -> uncatchableErr "Not a type variable"
_ -> uncatchableErr "Not a type variable"
returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2
returnType a = a
returnType a = a
inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp :: Exp -> Infer (T' T.Exp' Type)
inferExp e = do
(s, (e', t)) <- algoW e
let subbed = apply s t
@ -276,7 +277,7 @@ class CollectTVars a where
instance CollectTVars Exp where
collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e
collectTVars _ = S.empty
collectTVars _ = S.empty
instance CollectTVars Type where
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
@ -289,7 +290,7 @@ instance CollectTVars Type where
collect :: Set T.Ident -> Infer ()
collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st})
algoW :: Exp -> Infer (Subst, T.ExpT' Type)
algoW :: Exp -> Infer (Subst, T' T.Exp' Type)
algoW = \case
err@(EAnn e t) -> do
(sub0, (e', t')) <- exprErr (algoW e) err
@ -602,12 +603,12 @@ generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where
go :: [T.Ident] -> Type -> Type
go [] t = t
go [] t = t
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t
removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t
removeForalls t = t
{- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones.
@ -655,7 +656,7 @@ fresh = do
ungo :: [TVar] -> Type -> Type -> Bool
ungo tvars t1 t2 = case run (go tvars t1 t2) of
Right (b, _) -> b
_ -> False
_ -> False
-- TODO: Fix the following
-- Maybe locally using the Infer monad can cause trouble.
-- Since the fresh count starts from zero
@ -700,15 +701,15 @@ instance SubstType Type where
TLit _ -> t
TVar (MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (MkTVar $ coerce a)
Just t -> t
Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (MkTVar i) (apply sub t)
Just _ -> apply sub t
Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a)
TEVar (MkTEVar a) -> case M.lookup (coerce a) sub of
Nothing -> TEVar (MkTEVar $ coerce a)
Just t -> t
Just t -> t
instance FreeVars (Map T.Ident Type) where
free :: Map T.Ident Type -> Set T.Ident
@ -721,7 +722,7 @@ instance SubstType (Map T.Ident Type) where
instance SubstType (Map T.Ident (Maybe Type)) where
apply s = M.map (fmap $ apply s)
instance SubstType (T.ExpT' Type) where
instance SubstType (T' T.Exp' Type) where
apply s (e, t) = (apply s e, apply s t)
instance SubstType (T.Exp' Type) where
@ -750,10 +751,10 @@ instance SubstType (T.Branch' Type) where
instance SubstType (T.Pattern' Type) where
apply s = \case
T.PVar iden -> T.PVar iden
T.PLit lit -> T.PLit lit
T.PLit lit -> T.PLit lit
T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
instance SubstType (T.Pattern' Type, Type) where
apply s (p, t) = (apply s p, apply s t)
@ -761,7 +762,7 @@ instance SubstType (T.Pattern' Type, Type) where
instance SubstType a => SubstType [a] where
apply s = map (apply s)
instance SubstType (T.Id' Type) where
instance SubstType (T T.Ident Type) where
apply s (name, t) = (name, apply s t)
-- | Represents the empty substition set
@ -794,11 +795,11 @@ withBindings xs =
-- | Run the monadic action with a pattern
withPattern :: (Monad m, MonadReader Ctx m) => (T.Pattern' Type, Type) -> m a -> m a
withPattern (p, t) ma = case p of
T.PVar x -> withBinding x t ma
T.PVar x -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma
T.PCatch -> ma
T.PEnum _ -> ma
T.PLit _ -> ma
T.PCatch -> ma
T.PEnum _ -> ma
-- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe Type -> Infer ()
@ -823,11 +824,11 @@ existInj n = gets (M.lookup n . injections)
flattenType :: Type -> [Type]
flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a]
flattenType a = [a]
typeLength :: Type -> Int
typeLength (TFun _ b) = 1 + typeLength b
typeLength _ = 1
typeLength _ = 1
{- | Catch an error if possible and add the given
expression as addition to the error message
@ -910,11 +911,11 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
deriving (Show)
data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Maybe Type)
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Maybe Type)
, takenTypeVars :: Set T.Ident
, injections :: Map T.Ident Type
, injections :: Map T.Ident Type
, declaredBinds :: Set T.Ident
}
deriving (Show)

View file

@ -10,31 +10,30 @@ import Data.String (IsString)
import Grammar.Abs (Lit (..))
import Grammar.Print
import Prelude
import qualified Prelude as C (Eq, Ord, Read, Show)
newtype Program' t = Program [Def' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
deriving (Eq, Ord, Show, Functor)
data Def' t
= DBind (Bind' t)
| DData (Data' t)
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
deriving (Eq, Ord, Show, Functor)
data Type
= TLit Ident
| TVar TVar
| TData Ident [Type]
| TFun Type Type
deriving (Eq, Ord, Show, Read)
deriving (Eq, Ord, Show)
data Data' t = Data t [Inj' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
deriving (Eq, Ord, Show, Functor)
data Inj' t = Inj Ident t
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
deriving (Eq, Ord, Show, Functor)
newtype Ident = Ident String
deriving (C.Eq, C.Ord, C.Show, C.Read, IsString)
deriving (Eq, Ord, Show, IsString)
data Pattern' t
= PVar Ident
@ -42,30 +41,31 @@ data Pattern' t
| PCatch
| PEnum Ident
| PInj Ident [(Pattern' t, t)]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
deriving (Eq, Ord, Show, Functor)
data Exp' t
= EVar Ident
| EInj Ident
| ELit Lit
| ELet (Bind' t) (ExpT' t)
| EApp (ExpT' t) (ExpT' t)
| EAdd (ExpT' t) (ExpT' t)
| EAbs Ident (ExpT' t)
| ECase (ExpT' t) [Branch' t]
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
| ELet (Bind' t) (T' Exp' t)
| EApp (T' Exp' t) (T' Exp' t)
| EAdd (T' Exp' t) (T' Exp' t)
| EAbs Ident (T' Exp' t)
| ECase (T' Exp' t) [Branch' t]
deriving (Eq, Ord, Show, Functor)
newtype TVar = MkTVar Ident
deriving (C.Eq, C.Ord, C.Show, C.Read)
deriving (Eq, Ord, Show)
type Id' t = (Ident, t)
type ExpT' t = (Exp' t, t)
type T' a t = (a t, t)
type T a t = (a, t)
data Bind' t = Bind (Id' t) [Id' t] (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
data Branch' t = Branch (Pattern' t, t) (ExpT' t)
deriving (C.Eq, C.Ord, C.Show, C.Read, Functor)
data Bind' t = Bind (T Ident t) [T Ident t] (T' Exp' t)
deriving (Eq, Ord, Show, Functor)
data Branch' t = Branch (T' Pattern' t) (T' Exp' t)
deriving (Eq, Ord, Show, Functor)
instance Print Ident where
prt _ (Ident s) = doc $ showString s
@ -81,22 +81,22 @@ instance Print t => Print (Bind' t) where
, prt i rhs
]
prtSig :: Print t => Id' t -> Doc
prtSig (name, t) =
prtSig :: Print t => T Ident t -> Doc
prtSig (x, t) =
concatD
[ prt 0 name
[ prt 0 x
, doc $ showString ":"
, prt 0 t
]
instance Print t => Print (ExpT' t) where
prt i (e, t) =
instance (Print a, Print t) => Print (T a t) where
prt i (x, t) =
concatD
[ doc $ showString "("
, prt i e
, doc $ showString ":"
, prt 0 t
, doc $ showString ")"
[ -- doc $ showString "("
{- , -} prt i x
-- , doc $ showString ":"
-- , prt 0 t
-- , doc $ showString ")"
]
instance Print t => Print [Bind' t] where
@ -104,16 +104,6 @@ instance Print t => Print [Bind' t] where
prt i [x] = concatD [prt i x]
prt i (x : xs) = concatD [prt i x, doc (showString ";"), prt i xs]
instance Print t => Print (Id' t) where
prt i (name, t) =
concatD
[ doc $ showString "("
, prt i name
, doc $ showString ","
, prt i t
, doc $ showString ")"
]
instance Print t => Print (Exp' t) where
prt i = \case
EVar lident -> prPrec i 3 (concatD [prt 0 lident])
@ -151,9 +141,6 @@ instance Print t => Print [Inj' t] where
prt i [x] = prt i x
prt i (x : xs) = prPrec i 0 $ concatD [prt i x, doc $ showString "\n ", prt i xs]
instance Print t => Print (Pattern' t, t) where
prt i (p, t) = prPrec i 1 (concatD [prt i p, prt i t])
instance Print t => Print (Pattern' t) where
prt i = \case
PVar name -> prPrec i 1 (concatD [prt 0 name])
@ -189,8 +176,6 @@ type Branch = Branch' Type
type Pattern = Pattern' Type
type Inj = Inj' Type
type Exp = Exp' Type
type ExpT = ExpT' Type
type Id = Id' Type
pattern TVar' s = TVar (MkTVar s)
pattern DBind' id vars expt = DBind (Bind id vars expt)
pattern DData' typ injs = DData (Data typ injs)