From 0d2fe862e064e2cf428089aadab7ad0f2fd4995a Mon Sep 17 00:00:00 2001 From: sebastian Date: Mon, 27 Mar 2023 23:05:40 +0200 Subject: [PATCH] fixed bug and additional test --- src/TypeChecker/TypeCheckerHm.hs | 43 ++++++++++++++++++-------------- tests/TestTypeCheckerHm.hs | 38 ++++++++++++++++++++-------- 2 files changed, 52 insertions(+), 29 deletions(-) diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 026810f..92af317 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -113,7 +113,7 @@ preRun :: [Def] -> Infer () preRun [] = return () preRun (x : xs) = case x of DSig (Sig n t) -> do - collect (collectTypeVars t) + collect (collectTVars t) gets (M.member (coerce n) . sigs) >>= flip when @@ -123,20 +123,23 @@ preRun (x : xs) = case x of ) insertSig (coerce n) (Just t) >> preRun xs DBind (Bind n _ e) -> do - collect (collectTypeVars e) + collect (collectTVars e) s <- gets sigs case M.lookup (coerce n) s of Nothing -> insertSig (coerce n) Nothing >> preRun xs Just _ -> preRun xs - DData d@(Data t _) -> collect (collectTypeVars t) >> checkData d >> preRun xs + DData d@(Data t _) -> collect (collectTVars t) >> checkData d >> preRun xs checkDef :: [Def] -> Infer [T.Def' Type] checkDef [] = return [] checkDef (x : xs) = case x of (DBind b) -> do b' <- checkBind b - fmap (T.DBind b' :) (checkDef xs) - (DData d) -> fmap (T.DData (coerceData d) :) (checkDef xs) + xs' <- checkDef xs + return $ T.DBind b' : xs' + (DData d) -> do + xs' <- checkDef xs + return $ T.DData (coerceData d) : xs' (DSig _) -> checkDef xs where coerceData (Data t injs) = @@ -145,22 +148,24 @@ checkDef (x : xs) = case x of checkBind :: Bind -> Infer (T.Bind' Type) checkBind (Bind name args e) = do let lambda = makeLambda e (reverse (coerce args)) - e@(_, args_t) <- inferExp lambda + (e, lambda_t) <- inferExp lambda s <- gets sigs case M.lookup (coerce name) s of Just (Just t') -> do + sub1 <- unify lambda_t t' + sub2 <- unify t' lambda_t unless - (args_t `typeEq` t') + (apply sub1 lambda_t == t' && lambda_t == apply sub2 t') ( throwError $ Aux.do "Inferred type" - quote $ printTree args_t + quote $ printTree lambda_t "does not match specified type" quote $ printTree t' ) - return $ T.Bind (coerce name, t') [] e + return $ T.Bind (coerce name, t') [] (e, lambda_t) _ -> do - insertSig (coerce name) (Just args_t) - return (T.Bind (coerce name, args_t) [] e) + insertSig (coerce name) (Just lambda_t) + return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) typeEq :: Type -> Type -> Bool typeEq (TFun l r) (TFun l' r') = typeEq l l' && typeEq r r' @@ -203,18 +208,18 @@ inferExp e = do return $ second (const subbed) (e', t) class CollectTVars a where - collectTypeVars :: a -> Set T.Ident + collectTVars :: a -> Set T.Ident instance CollectTVars Exp where - collectTypeVars (EAnn e t) = collectTypeVars t `S.union` collectTypeVars e - collectTypeVars _ = S.empty + collectTVars (EAnn e t) = collectTVars t `S.union` collectTVars e + collectTVars _ = S.empty instance CollectTVars Type where - collectTypeVars (TVar (MkTVar i)) = S.singleton (coerce i) - collectTypeVars (TAll _ t) = collectTypeVars t - collectTypeVars (TFun t1 t2) = (S.union `on` collectTypeVars) t1 t2 - collectTypeVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTypeVars x) S.empty ts - collectTypeVars _ = S.empty + collectTVars (TVar (MkTVar i)) = S.singleton (coerce i) + collectTVars (TAll _ t) = collectTVars t + collectTVars (TFun t1 t2) = (S.union `on` collectTVars) t1 t2 + collectTVars (TData _ ts) = foldl' (\acc x -> acc `S.union` collectTVars x) S.empty ts + collectTVars _ = S.empty collect :: Set T.Ident -> Infer () collect s = modify (\st -> st{takenTypeVars = s `S.union` takenTypeVars st}) diff --git a/tests/TestTypeCheckerHm.hs b/tests/TestTypeCheckerHm.hs index 0a8e76f..e326bd5 100644 --- a/tests/TestTypeCheckerHm.hs +++ b/tests/TestTypeCheckerHm.hs @@ -1,17 +1,28 @@ +{-# LANGUAGE QualifiedDo #-} {-# LANGUAGE NoImplicitPrelude #-} -{-# LANGUAGE QualifiedDo #-} module TestTypeCheckerHm where -import Control.Monad ((<=<)) -import qualified DoStrings as D -import Grammar.Par (myLexer, pProgram) -import Prelude (Bool (..), Either (..), IO, foldl1, - mapM_, not, ($), (.), (>>)) -import Test.Hspec +import Control.Monad ((<=<)) +import DoStrings qualified as D +import Grammar.Par (myLexer, pProgram) +import Grammar.Print (printTree) +import Test.Hspec +import Prelude ( + Bool (..), + Either (..), + IO, + fmap, + foldl1, + mapM_, + not, + ($), + (.), + (>>), + ) -- import Test.QuickCheck -import TypeChecker.TypeCheckerHm (typecheck) +import TypeChecker.TypeCheckerHm (typecheck) testTypeCheckerHm = describe "Hindley-Milner type checker test" $ do foldl1 (>>) goods @@ -124,6 +135,13 @@ bads = "id x = x;" ) bad + , testSatisfy + "incorrect signature on const" + ( D.do + "const : a -> b -> b;" + "const x y = x" + ) + bad , testSatisfy "incorrect type signature on id lambda" ( D.do @@ -176,10 +194,10 @@ bes = testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe -run = typecheck <=< pProgram . myLexer +run = fmap printTree . typecheck <=< pProgram . myLexer ok (Right _) = True -ok (Left _) = False +ok (Left _) = False bad = not . ok