Add implicit foralls for bidir, update and unify pipeline

This commit is contained in:
Martin Fredin 2023-04-03 17:34:33 +02:00
parent 12bca1c32d
commit 9870802371
33 changed files with 1010 additions and 1055 deletions

View file

@ -1,31 +1,31 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QualifiedDo #-}
-- | A module for type checking and inference using algorithm W, Hindley-Milner
module TypeChecker.TypeCheckerHm where
import Auxiliary (int, litType, maybeToRightM, unzip4)
import Auxiliary qualified as Aux
import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl', nub, sortOn)
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import Data.Map qualified as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr qualified as T
import Auxiliary (int, litType, maybeToRightM, unzip4)
import qualified Auxiliary as Aux
import Control.Monad.Except
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Coerce (coerce)
import Data.Function (on)
import Data.List (foldl', nub, sortOn)
import Data.List.Extra (unsnoc)
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromJust)
import Data.Set (Set)
import qualified Data.Set as S
import Debug.Trace (trace)
import Grammar.Abs
import Grammar.Print (printTree)
import qualified TypeChecker.TypeCheckerIr as T
-- TODO: Disallow mutual recursion
@ -34,7 +34,7 @@ typecheck :: Program -> Either String (T.Program' Type, [Warning])
typecheck = onLeft msg . run . checkPrg
where
onLeft :: (Error -> String) -> Either Error a -> Either String a
onLeft f (Left x) = Left $ f x
onLeft f (Left x) = Left $ f x
onLeft _ (Right x) = Right x
checkPrg :: Program -> Infer (T.Program' Type)
@ -118,7 +118,7 @@ preRun (x : xs) = case x of
s <- gets sigs
case M.lookup (coerce n) s of
Nothing -> insertSig (coerce n) Nothing >> preRun xs
Just _ -> preRun xs
Just _ -> preRun xs
DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs
where
-- Check if function body / signature has been declared already
@ -140,11 +140,11 @@ checkDef (x : xs) = case x of
T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs
freeOrdered :: Type -> [T.Ident]
freeOrdered (TVar (MkTVar a)) = return (coerce a)
freeOrdered (TVar (MkTVar a)) = return (coerce a)
freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
freeOrdered (TData _ a) = concatMap freeOrdered a
freeOrdered _ = mempty
freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b
freeOrdered (TData _ a) = concatMap freeOrdered a
freeOrdered _ = mempty
checkBind :: Bind -> Infer (T.Bind' Type)
checkBind (Bind name args e) = do
@ -178,22 +178,19 @@ checkBind (Bind name args e) = do
checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m ()
checkData err@(Data typ injs) = do
(name, tvars) <- go typ
(name, tvars) <- go (skipForalls typ)
dataErr (mapM_ (\i -> checkInj i name tvars) injs) err
where
go = \case
TData name typs
| Right tvars' <- mapM toTVar typs ->
pure (name, tvars')
TAll _ _ -> uncatchableErr "Explicit forall not allowed, for now"
_ ->
uncatchableErr $
unwords ["Bad data type definition: ", printTree typ]
checkInj :: (MonadError Error m, MonadState Env m, Monad m) => Inj -> UIdent -> [TVar] -> m ()
checkInj (Inj c inj_typ) name tvars
| Right False <- boundTVars tvars inj_typ =
catchableErr "Unbound type variables"
| TData name' typs <- returnType inj_typ
, Right tvars' <- mapM toTVar typs
, name' == name
@ -217,27 +214,15 @@ checkInj (Inj c inj_typ) name tvars
, "\nActual: "
, printTree $ returnType inj_typ
]
where
boundTVars :: [TVar] -> Type -> Either Error Bool
boundTVars tvars' = \case
TAll{} -> uncatchableErr "Explicit forall not allowed, for now"
TFun t1 t2 -> do
t1' <- boundTVars tvars t1
t2' <- boundTVars tvars t2
return $ t1' && t2'
TVar tvar -> return $ tvar `elem` tvars'
TData _ typs -> and <$> mapM (boundTVars tvars) typs
TLit _ -> return True
TEVar _ -> error "TEVar in data type declaration"
toTVar :: Type -> Either Error TVar
toTVar = \case
TVar tvar -> pure tvar
_ -> uncatchableErr "Not a type variable"
_ -> uncatchableErr "Not a type variable"
returnType :: Type -> Type
returnType (TFun _ t2) = returnType t2
returnType a = a
returnType a = a
inferExp :: Exp -> Infer (T.ExpT' Type)
inferExp e = do
@ -250,7 +235,7 @@ class CollectTVars a where
instance CollectTVars Exp where
collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e
collectTVars _ = S.empty
collectTVars _ = S.empty
instance CollectTVars Type where
collectTVars (TVar (MkTVar i)) = S.singleton (coerce i)
@ -569,12 +554,12 @@ generalize :: Map T.Ident Type -> Type -> Type
generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t)
where
go :: [T.Ident] -> Type -> Type
go [] t = t
go [] t = t
go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t)
removeForalls :: Type -> Type
removeForalls (TAll _ t) = removeForalls t
removeForalls (TAll _ t) = removeForalls t
removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2)
removeForalls t = t
removeForalls t = t
{- | Instantiate a polymorphic type. The free type variables are substituted
with fresh ones.
@ -611,42 +596,39 @@ currently this is not the case, the TAll pattern match is incorrectly implemente
-- Is the left a subtype of the right
(<<=) :: Type -> Type -> Bool
(<<=) (TVar _) _ = True
(<<=) (TAll _ t1) (TAll _ t2) = t1 <<= t2
(<<=) t1@TAll{} t2 = skipForalls t1 <<= t2
(<<=) t1 t2@TAll{} = t1 <<= skipForalls t2
(<<=) (TFun a b) (TFun c d) = a <<= c && b <<= d
(<<=) (TData n1 ts1) (TData n2 ts2) =
n1 == n2
&& length ts1 == length ts2
&& and (zipWith (<<=) ts1 ts2)
(<<=) t0 t@(TAll _ _) = go t0 t
where
go t0 t@(TAll _ t1) = S.toList (free t0) == foralls t && go' t0 t1
go _ _ = undefined
go' (TEVar (MkTEVar a)) (TVar (MkTVar b)) = a == b
go' (TEVar (MkTEVar a)) (TEVar (MkTEVar b)) = a == b
go' (TFun a b) (TFun c d) = a `go'` c && b `go'` d
go' _ _ = False
(<<=) a b = a == b
skipForalls :: Type -> Type
skipForalls = \case
TAll _ t -> t
t -> t
foralls :: Type -> [T.Ident]
foralls (TAll (MkTVar a) t) = coerce a : foralls t
foralls _ = []
foralls _ = []
mkForall :: Type -> Type
mkForall t = case map (TAll . MkTVar . coerce) $ S.toList $ free t of
[] -> t
(x : xs) ->
let f acc [] = acc
let f acc [] = acc
f acc (x : xs) = f (x acc) xs
(y : ys) = reverse $ x : xs
in f (y t) ys
skolemize :: Type -> Type
skolemize (TVar (MkTVar a)) = TEVar $ MkTEVar a
skolemize (TAll x t) = TAll x (skolemize t)
skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2
skolemize (TData n ts) = TData n (map skolemize ts)
skolemize t = t
skolemize (TAll x t) = TAll x (skolemize t)
skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2
skolemize (TData n ts) = TData n (map skolemize ts)
skolemize t = t
-- | A class for substitutions
class SubstType t where
@ -680,10 +662,10 @@ instance SubstType Type where
TLit _ -> t
TVar (MkTVar a) -> case M.lookup (coerce a) sub of
Nothing -> TVar (MkTVar $ coerce a)
Just t -> t
Just t -> t
TAll (MkTVar i) t -> case M.lookup (coerce i) sub of
Nothing -> TAll (MkTVar i) (apply sub t)
Just _ -> apply sub t
Just _ -> apply sub t
TFun a b -> TFun (apply sub a) (apply sub b)
TData name a -> TData name (apply sub a)
TEVar (MkTEVar _) -> t
@ -728,10 +710,10 @@ instance SubstType (T.Branch' Type) where
instance SubstType (T.Pattern' Type) where
apply s = \case
T.PVar (iden, t) -> T.PVar (iden, apply s t)
T.PLit (lit, t) -> T.PLit (lit, apply s t)
T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
T.PLit (lit, t) -> T.PLit (lit, apply s t)
T.PInj i ps -> T.PInj i $ apply s ps
T.PCatch -> T.PCatch
T.PEnum i -> T.PEnum i
instance SubstType (T.Pattern' Type, Type) where
apply s (p, t) = (apply s p, apply s t)
@ -773,10 +755,10 @@ withBindings xs =
withPattern :: (Monad m, MonadReader Ctx m) => T.Pattern' Type -> m a -> m a
withPattern p ma = case p of
T.PVar (x, t) -> withBinding x t ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma
T.PCatch -> ma
T.PEnum _ -> ma
T.PInj _ ps -> foldl' (flip withPattern) ma ps
T.PLit _ -> ma
T.PCatch -> ma
T.PEnum _ -> ma
-- | Insert a function signature into the environment
insertSig :: T.Ident -> Maybe Type -> Infer ()
@ -801,11 +783,11 @@ existInj n = gets (M.lookup n . injections)
flattenType :: Type -> [Type]
flattenType (TFun a b) = flattenType a <> flattenType b
flattenType a = [a]
flattenType a = [a]
typeLength :: Type -> Int
typeLength (TFun _ b) = 1 + typeLength b
typeLength _ = 1
typeLength _ = 1
{- | Catch an error if possible and add the given
expression as addition to the error message
@ -888,11 +870,11 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type}
deriving (Show)
data Env = Env
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Maybe Type)
{ count :: Int
, nextChar :: Char
, sigs :: Map T.Ident (Maybe Type)
, takenTypeVars :: Set T.Ident
, injections :: Map T.Ident Type
, injections :: Map T.Ident Type
, declaredBinds :: Set T.Ident
}
deriving (Show)