diff --git a/src/TypeChecker/RemoveTEVar.hs b/src/TypeChecker/RemoveTEVar.hs index 43a87f7..bfa06ba 100644 --- a/src/TypeChecker/RemoveTEVar.hs +++ b/src/TypeChecker/RemoveTEVar.hs @@ -64,8 +64,8 @@ instance RemoveTEVar a b => RemoveTEVar [a] [b] where instance RemoveTEVar Type T.Type where rmTEVar = \case TLit lit -> pure $ T.TLit (coerce lit) - TVar tvar -> pure $ T.TVar tvar + TVar tvar -> pure $ T.TVar (coerce tvar) TData name typs -> T.TData (coerce name) <$> rmTEVar typs TFun t1 t2 -> liftA2 T.TFun (rmTEVar t1) (rmTEVar t2) - TAll tvar t -> T.TAll tvar <$> rmTEVar t + TAll tvar t -> T.TAll (coerce tvar) <$> rmTEVar t TEVar _ -> throwError "NewType TEVar!" diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 2edd1f2..0cb8a4a 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -421,10 +421,10 @@ unify t0 t1 = s1 <- unify a c s2 <- unify (apply s1 b) (apply s1 d) return $ s1 `compose` s2 - (TVar (T.MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t - (t@(TData _ _), TVar (T.MkTVar b)) -> return $ M.singleton (coerce b) t - (TVar (T.MkTVar a), t) -> occurs (coerce a) t - (t, TVar (T.MkTVar b)) -> occurs (coerce b) t + (TVar (MkTVar a), t@(TData _ _)) -> return $ M.singleton (coerce a) t + (t@(TData _ _), TVar (MkTVar b)) -> return $ M.singleton (coerce b) t + (TVar (MkTVar a), t) -> occurs (coerce a) t + (t, TVar (MkTVar b)) -> occurs (coerce b) t (TAll _ t, b) -> unify t b (a, TAll _ t) -> unify a t (TLit a, TLit b) -> @@ -478,7 +478,7 @@ occurs i t = catchableErr ( Aux.do "Occurs check failed, can't unify" - quote $ printTree (TVar $ T.MkTVar (coerce i)) + quote $ printTree (TVar $ MkTVar (coerce i)) "with" quote $ printTree t ) @@ -495,7 +495,7 @@ generalize env t = go (S.toList $ free t S.\\ free env) (removeForalls t) where go :: [T.Ident] -> Type -> Type go [] t = t - go (x : xs) t = TAll (T.MkTVar (coerce x)) (go xs t) + go (x : xs) t = TAll (MkTVar (coerce x)) (go xs t) removeForalls :: Type -> Type removeForalls (TAll _ t) = removeForalls t removeForalls (TFun t1 t2) = TFun (removeForalls t1) (removeForalls t2) @@ -506,7 +506,7 @@ with fresh ones. -} inst :: Type -> Infer Type inst = \case - TAll (T.MkTVar bound) t -> do + TAll (MkTVar bound) t -> do fr <- fresh let s = M.singleton (coerce bound) fr apply s <$> inst t @@ -528,8 +528,8 @@ fresh = do fresh else if n == 0 - then return . TVar . T.MkTVar $ LIdent [c] - else return . TVar . T.MkTVar . LIdent $ c : show n + then return . TVar . MkTVar $ LIdent [c] + else return . TVar . MkTVar . LIdent $ c : show n where next :: Char -> Char next 'z' = 'a' @@ -546,8 +546,8 @@ class FreeVars t where instance FreeVars Type where free :: Type -> Set T.Ident - free (TVar (T.MkTVar a)) = S.singleton (coerce a) - free (TAll (T.MkTVar bound) t) = + free (TVar (MkTVar a)) = S.singleton (coerce a) + free (TAll (MkTVar bound) t) = S.singleton (coerce bound) `S.intersection` free t free (TLit _) = mempty free (TFun a b) = free a `S.union` free b @@ -562,11 +562,11 @@ instance SubstType Type where apply sub t = do case t of TLit a -> TLit a - TVar (T.MkTVar a) -> case M.lookup (coerce a) sub of - Nothing -> TVar (T.MkTVar $ coerce a) + TVar (MkTVar a) -> case M.lookup (coerce a) sub of + Nothing -> TVar (MkTVar $ coerce a) Just t -> t - TAll (T.MkTVar i) t -> case M.lookup (coerce i) sub of - Nothing -> TAll (T.MkTVar i) (apply sub t) + TAll (MkTVar i) t -> case M.lookup (coerce i) sub of + Nothing -> TAll (MkTVar i) (apply sub t) Just _ -> apply sub t TFun a b -> TFun (apply sub a) (apply sub b) TData name a -> TData name (apply sub a) @@ -683,7 +683,7 @@ int = TLit "Int" char = TLit "Char" typeEq :: Type -> Type -> StateT Subst (ExceptT Error Identity) () -typeEq (TVar (T.MkTVar a)) t@(TVar _) = do +typeEq (TVar (MkTVar a)) t@(TVar _) = do st <- get case M.lookup (coerce a) st of Nothing -> put $ M.insert (coerce a) t st diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index c307ffe..b3f51d7 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE PatternSynonyms #-} module TypeChecker.TypeCheckerIr ( @@ -6,11 +6,11 @@ module TypeChecker.TypeCheckerIr ( module TypeChecker.TypeCheckerIr, ) where -import Data.String (IsString) -import Grammar.Abs (Lit (..)) -import Grammar.Print -import Prelude -import qualified Prelude as C (Eq, Ord, Read, Show) +import Data.String (IsString) +import Grammar.Abs (Lit (..)) +import Grammar.Print +import Prelude +import Prelude qualified as C (Eq, Ord, Read, Show) newtype Program' t = Program [Def' t] deriving (C.Eq, C.Ord, C.Show, C.Read) @@ -56,8 +56,8 @@ data Exp' t | ECase (ExpT' t) [Branch' t] deriving (C.Eq, C.Ord, C.Show, C.Read) -data TVar = MkTVar Ident - deriving (C.Eq, C.Ord, C.Show, C.Read) +newtype TVar = MkTVar Ident + deriving (C.Eq, C.Ord, C.Show, C.Read) type Id' t = (Ident, t) type ExpT' t = (Exp' t, t) @@ -105,8 +105,8 @@ instance Print t => Print (ExpT' t) where ] instance Print t => Print [Bind' t] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] prtIdPs :: Print t => Int -> [Id' t] -> Doc @@ -171,13 +171,13 @@ instance Print t => Print (Branch' t) where prt i (Branch (pattern_, t) exp) = prPrec i 0 (concatD [doc (showString "("), prt 0 pattern_, doc (showString " : "), prt 0 t, doc (showString ")"), doc (showString "=>"), prt 0 exp]) instance Print t => Print [Branch' t] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] instance Print t => Print (Def' t) where prt i = \case - DBind bind -> prPrec i 0 (concatD [prt 0 bind]) + DBind bind -> prPrec i 0 (concatD [prt 0 bind]) DData data_ -> prPrec i 0 (concatD [prt 0 data_]) instance Print t => Print (Data' t) where @@ -197,12 +197,12 @@ instance Print t => Print (Pattern' t) where PInj uident patterns -> prPrec i 0 (concatD [prt 0 uident, prt 1 patterns]) instance Print t => Print [Def' t] where - prt _ [] = concatD [] - prt _ [x] = concatD [prt 0 x] + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] prt _ (x : xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] instance Print [Type] where - prt _ [] = concatD [] + prt _ [] = concatD [] prt _ (x : xs) = concatD [prt 0 x, doc (showString " "), prt 0 xs] instance Print Type where @@ -213,6 +213,9 @@ instance Print Type where TFun type_1 type_2 -> prPrec i 0 (concatD [prt 1 type_1, doc (showString "->"), prt 0 type_2]) TAll tvar type_ -> prPrec i 0 (concatD [doc (showString "forall"), prt 0 tvar, doc (showString "."), prt 0 type_]) +instance Print TVar where + prt i (MkTVar ident) = prt i ident + type Program = Program' Type type Def = Def' Type type Data = Data' Type