diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 5ef3f47..3a505b4 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -1,31 +1,31 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedRecordDot #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE QualifiedDo #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QualifiedDo #-} -- | A module for type checking and inference using algorithm W, Hindley-Milner module TypeChecker.TypeCheckerHm where -import Auxiliary (int, litType, maybeToRightM, unzip4) -import qualified Auxiliary as Aux -import Control.Monad.Except -import Control.Monad.Identity (Identity, runIdentity) -import Control.Monad.Reader -import Control.Monad.State -import Control.Monad.Writer -import Data.Coerce (coerce) -import Data.Function (on) -import Data.List (foldl', nub, sortOn) -import Data.List.Extra (unsnoc) -import Data.Map (Map) -import qualified Data.Map as M -import Data.Maybe (fromJust) -import Data.Set (Set) -import qualified Data.Set as S -import Debug.Trace (trace) -import Grammar.Abs -import Grammar.Print (printTree) -import qualified TypeChecker.TypeCheckerIr as T +import Auxiliary (int, litType, maybeToRightM, unzip4) +import Auxiliary qualified as Aux +import Control.Monad.Except +import Control.Monad.Identity (Identity, runIdentity) +import Control.Monad.Reader +import Control.Monad.State +import Control.Monad.Writer +import Data.Coerce (coerce) +import Data.Function (on) +import Data.List (foldl', nub, sortOn) +import Data.List.Extra (unsnoc) +import Data.Map (Map) +import Data.Map qualified as M +import Data.Maybe (fromJust) +import Data.Set (Set) +import Data.Set qualified as S +import Debug.Trace (trace) +import Grammar.Abs +import Grammar.Print (printTree) +import TypeChecker.TypeCheckerIr qualified as T {- TODO @@ -40,7 +40,7 @@ typecheck :: Program -> Either String (T.Program' Type, [Warning]) typecheck = onLeft msg . run . checkPrg where onLeft :: (Error -> String) -> Either Error a -> Either String a - onLeft f (Left x) = Left $ f x + onLeft f (Left x) = Left $ f x onLeft _ (Right x) = Right x checkPrg :: Program -> Infer (T.Program' Type) @@ -67,13 +67,13 @@ prettify s (T.Program defs) = T.Program $ map (go s) defs replace :: Map T.Ident T.Ident -> Type -> Type replace m def@(TVar (MkTVar (LIdent a))) = case M.lookup (coerce a) m of - Just t -> TVar . MkTVar . LIdent $ coerce t + Just t -> TVar . MkTVar . LIdent $ coerce t Nothing -> def replace m (TFun t1 t2) = (TFun `on` replace m) t1 t2 replace m (TData name ts) = TData name (map (replace m) ts) replace m def@(TAll (MkTVar forall_) t) = case M.lookup (coerce forall_) m of Just found -> TAll (MkTVar $ coerce found) (replace m t) - Nothing -> def + Nothing -> def replace _ t = t bindCount :: [Def] -> Infer [(Int, Def)] @@ -127,7 +127,7 @@ preRun (x : xs) = case x of s <- gets sigs case M.lookup (coerce n) s of Nothing -> insertSig (coerce n) Nothing >> preRun xs - Just _ -> preRun xs + Just _ -> preRun xs DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs where -- Check if function body / signature has been declared already @@ -149,11 +149,11 @@ checkDef (x : xs) = case x of T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs freeOrdered :: Type -> [T.Ident] -freeOrdered (TVar (MkTVar a)) = return (coerce a) +freeOrdered (TVar (MkTVar a)) = return (coerce a) freeOrdered (TAll (MkTVar bound) t) = return (coerce bound) ++ freeOrdered t -freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b -freeOrdered (TData _ a) = concatMap freeOrdered a -freeOrdered _ = mempty +freeOrdered (TFun a b) = freeOrdered a ++ freeOrdered b +freeOrdered (TData _ a) = concatMap freeOrdered a +freeOrdered _ = mempty checkBind :: Bind -> Infer (T.Bind' Type) checkBind (Bind name args e) = do @@ -168,6 +168,8 @@ checkBind (Bind name args e) = do let m1 = M.fromList $ zip fvs1 letters let t0 = replace m0 t' let t1 = replace m1 lambda_t + -- Not sure if this is actually correct + sub <- unify t' lambda_t unless (t1 <<= t0) ( throwError $ @@ -180,7 +182,9 @@ checkBind (Bind name args e) = do ) False ) - return $ T.Bind (coerce name, t') [] (e, lambda_t) + -- Applying sub to t' will worsen error messages. + -- Unfortunately I do not know a better solution at the moment. + return $ T.Bind (coerce name, apply sub t') [] (apply sub e, lambda_t) _ -> do insertSig (coerce name) (Just lambda_t) return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) @@ -227,11 +231,11 @@ checkInj (Inj c inj_typ) name tvars toTVar :: Type -> Either Error TVar toTVar = \case TVar tvar -> pure tvar - _ -> uncatchableErr "Not a type variable" + _ -> uncatchableErr "Not a type variable" returnType :: Type -> Type returnType (TFun _ t2) = returnType t2 -returnType a = a +returnType a = a inferExp :: Exp -> Infer (T.ExpT' Type) inferExp e = do @@ -244,7 +248,7 @@ class CollectTVars a where instance CollectTVars Exp where collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e - collectTVars _ = S.empty + collectTVars _ = S.empty instance CollectTVars Type where collectTVars (TVar (MkTVar i)) = S.singleton (coerce i) @@ -563,12 +567,12 @@ generalize :: Map T.Ident Type -> Type -> Type generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) where go :: [T.Ident] -> Type -> Type - go [] t = t + go [] t = t go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t) removeForalls :: Type -> Type - removeForalls (TAll _ t) = removeForalls t + removeForalls (TAll _ t) = removeForalls t removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2) - removeForalls t = t + removeForalls t = t {- | Instantiate a polymorphic type. The free type variables are substituted with fresh ones. @@ -617,27 +621,27 @@ currently this is not the case, the TAll pattern match is incorrectly implemente skipForalls :: Type -> Type skipForalls = \case TAll _ t -> skipForalls t - t -> t + t -> t foralls :: Type -> [T.Ident] foralls (TAll (MkTVar a) t) = coerce a : foralls t -foralls _ = [] +foralls _ = [] mkForall :: Type -> Type mkForall t = case map (TAll . MkTVar . coerce) $ S.toList $ free t of [] -> t (x : xs) -> - let f acc [] = acc + let f acc [] = acc f acc (x : xs) = f (x acc) xs (y : ys) = reverse $ x : xs in f (y t) ys skolemize :: Type -> Type skolemize (TVar (MkTVar a)) = TEVar $ MkTEVar a -skolemize (TAll x t) = TAll x (skolemize t) -skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2 -skolemize (TData n ts) = TData n (map skolemize ts) -skolemize t = t +skolemize (TAll x t) = TAll x (skolemize t) +skolemize (TFun t1 t2) = (TFun `on` skolemize) t1 t2 +skolemize (TData n ts) = TData n (map skolemize ts) +skolemize t = t -- | A class for substitutions class SubstType t where @@ -671,10 +675,10 @@ instance SubstType Type where TLit _ -> t TVar (MkTVar a) -> case M.lookup (coerce a) sub of Nothing -> TVar (MkTVar $ coerce a) - Just t -> t + Just t -> t TAll (MkTVar i) t -> case M.lookup (coerce i) sub of Nothing -> TAll (MkTVar i) (apply sub t) - Just _ -> apply sub t + Just _ -> apply sub t TFun a b -> TFun (apply sub a) (apply sub b) TData name a -> TData name (apply sub a) TEVar (MkTEVar _) -> t @@ -719,10 +723,10 @@ instance SubstType (T.Branch' Type) where instance SubstType (T.Pattern' Type) where apply s = \case T.PVar iden -> T.PVar iden - T.PLit lit -> T.PLit lit + T.PLit lit -> T.PLit lit T.PInj i ps -> T.PInj i $ apply s ps - T.PCatch -> T.PCatch - T.PEnum i -> T.PEnum i + T.PCatch -> T.PCatch + T.PEnum i -> T.PEnum i instance SubstType (T.Pattern' Type, Type) where apply s (p, t) = (apply s p, apply s t) @@ -763,11 +767,11 @@ withBindings xs = -- | Run the monadic action with a pattern withPattern :: (Monad m, MonadReader Ctx m) => (T.Pattern' Type, Type) -> m a -> m a withPattern (p, t) ma = case p of - T.PVar x -> withBinding x t ma + T.PVar x -> withBinding x t ma T.PInj _ ps -> foldl' (flip withPattern) ma ps - T.PLit _ -> ma - T.PCatch -> ma - T.PEnum _ -> ma + T.PLit _ -> ma + T.PCatch -> ma + T.PEnum _ -> ma -- | Insert a function signature into the environment insertSig :: T.Ident -> Maybe Type -> Infer () @@ -792,11 +796,11 @@ existInj n = gets (M.lookup n . injections) flattenType :: Type -> [Type] flattenType (TFun a b) = flattenType a <> flattenType b -flattenType a = [a] +flattenType a = [a] typeLength :: Type -> Int typeLength (TFun _ b) = 1 + typeLength b -typeLength _ = 1 +typeLength _ = 1 {- | Catch an error if possible and add the given expression as addition to the error message @@ -879,11 +883,11 @@ newtype Ctx = Ctx {vars :: Map T.Ident Type} deriving (Show) data Env = Env - { count :: Int - , nextChar :: Char - , sigs :: Map T.Ident (Maybe Type) + { count :: Int + , nextChar :: Char + , sigs :: Map T.Ident (Maybe Type) , takenTypeVars :: Set T.Ident - , injections :: Map T.Ident Type + , injections :: Map T.Ident Type , declaredBinds :: Set T.Ident } deriving (Show)