diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 9cc37ee..8b7625e 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -12,6 +12,7 @@ import Control.Monad.State import Data.Bifunctor (second) import Data.Coerce (coerce) import Data.Foldable (traverse_) +import Data.Function (on) import Data.List (foldl') import Data.List.Extra (unsnoc) import Data.Map (Map) @@ -149,6 +150,12 @@ typeEq t1 (T.TAll _ t2) = t1 `typeEq` t2 typeEq (T.TVar _) (T.TVar _) = True typeEq _ _ = False +skolem :: T.Type -> T.Type +skolem (T.TVar (T.MkTVar a)) = T.TLit a +skolem (T.TAll x t) = T.TAll x (skolem t) +skolem (T.TFun t1 t2) = (T.TFun `on` skolem) t1 t2 +skolem t = t + isMoreSpecificOrEq :: T.Type -> T.Type -> Bool isMoreSpecificOrEq t1 (T.TAll _ t2) = isMoreSpecificOrEq t1 t2 isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) = @@ -181,7 +188,7 @@ instance CollectTVars Exp where instance CollectTVars Type where collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i) collectTypeVars (TAll _ t) = collectTypeVars t - collectTypeVars (TFun t1 t2) = collectTypeVars t1 `S.union` collectTypeVars t2 + collectTypeVars (TFun t1 t2) = (S.union `on` collectTypeVars) t1 t2 collectTypeVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTypeVars x) S.empty ts collectTypeVars _ = S.empty @@ -195,7 +202,7 @@ instance NewType Type T.Type where toNew = \case TLit i -> T.TLit $ coerce i TVar v -> T.TVar $ toNew v - TFun t1 t2 -> T.TFun (toNew t1) (toNew t2) + TFun t1 t2 -> (T.TFun `on` toNew) t1 t2 TAll b t -> T.TAll (toNew b) (toNew t) TData i ts -> T.TData (coerce i) (map toNew ts) TEVar _ -> error "Should not exist after typechecker" @@ -414,10 +421,8 @@ occurs i t = -- | Generalize a type over all free variables in the substitution set generalize :: Map T.Ident T.Type -> T.Type -> T.Type -generalize env t = go freeVars $ removeForalls t +generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) where - freeVars :: [T.Ident] - freeVars = S.toList $ free t S.\\ free env go :: [T.Ident] -> T.Type -> T.Type go [] t = t go (x : xs) t = T.TAll (T.MkTVar x) (go xs t)