Add test for pattern matching on recursive data types, and remove traces

This commit is contained in:
Martin Fredin 2023-03-29 11:25:45 +02:00
parent 52db1943bb
commit 4755f434fd
2 changed files with 49 additions and 16 deletions

View file

@ -243,8 +243,6 @@ subtype t1 t2 = case (t1, t2) of
, t1:t1s <- typs1 , t1:t1s <- typs1
, t2:t2s <- typs2 , t2:t2s <- typs2
-> do -> do
traceT "t1" (TData name1 typs1)
traceT "t2" (TData name2 typs2)
subtype t1 t2 subtype t1 t2
zipWithM_ go t1s t2s zipWithM_ go t1s t2s
where where
@ -868,7 +866,7 @@ putEnv = modifyEnv . const
modifyEnv :: (Env -> Env) -> Tc () modifyEnv :: (Env -> Env) -> Tc ()
modifyEnv f = modifyEnv f =
modify $ \cxt -> trace (ppEnv (f cxt.env)) cxt { env = f cxt.env } modify $ \cxt -> {- trace (ppEnv (f cxt.env)) -} cxt { env = f cxt.env }
pattern DBind' name vars exp = DBind (Bind name vars exp) pattern DBind' name vars exp = DBind (Bind name vars exp)
pattern DSig' name typ = DSig (Sig name typ) pattern DSig' name typ = DSig (Sig name typ)

View file

@ -1,20 +1,23 @@
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PatternSynonyms #-}
{-# HLINT ignore "Use camelCase" #-} {-# HLINT ignore "Use camelCase" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} {-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
module TestTypeCheckerBidir (testTypeCheckerBidir) where module TestTypeCheckerBidir (test, testTypeCheckerBidir) where
import Test.Hspec import Test.Hspec
import Control.Monad ((<=<)) import Control.Monad ((<=<))
import Grammar.ErrM (Err, pattern Bad, pattern Ok) import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Par (myLexer, pProgram) import Grammar.Par (myLexer, pProgram)
import Renamer.Renamer (rename) import Renamer.Renamer (rename)
import TypeChecker.RemoveTEVar (RemoveTEVar (rmTEVar)) import TypeChecker.RemoveTEVar (RemoveTEVar (rmTEVar))
import TypeChecker.TypeCheckerBidir (typecheck) import TypeChecker.TypeCheckerBidir (typecheck)
import TypeChecker.TypeCheckerIr qualified as T import qualified TypeChecker.TypeCheckerIr as T
test = hspec testTypeCheckerBidir
testTypeCheckerBidir = describe "Bidirectional type checker test" $ do testTypeCheckerBidir = describe "Bidirectional type checker test" $ do
tc_id tc_id
@ -176,17 +179,23 @@ tc_mono_case = describe "Monomorphic pattern matching" $ do
, "};" , "};"
] ]
tc_pol_case = describe "Polymophic pattern matching" $ do tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
specify "First wrong case expression rejected" $ specify "First wrong case expression rejected" $
run (fs ++ wrong1) `shouldNotSatisfy` ok run (fs ++ wrong1) `shouldNotSatisfy` ok
specify "Second wrong case expression rejected" $ specify "Second wrong case expression rejected" $
run (fs ++ wrong2) `shouldNotSatisfy` ok run (fs ++ wrong2) `shouldNotSatisfy` ok
specify "Third wrong case expression rejected" $ specify "Third wrong case expression rejected" $
run (fs ++ wrong3) `shouldNotSatisfy` ok run (fs ++ wrong3) `shouldNotSatisfy` ok
specify "Forth wrong case expression rejected" $
run (fs ++ wrong4) `shouldNotSatisfy` ok
specify "First correct case expression accepted" $ specify "First correct case expression accepted" $
run (fs ++ correct1) `shouldSatisfy` ok run (fs ++ correct1) `shouldSatisfy` ok
specify "Second correct case expression accepted" $ specify "Second correct case expression accepted" $
run (fs ++ correct2) `shouldSatisfy` ok run (fs ++ correct2) `shouldSatisfy` ok
specify "Third correct case expression accepted" $
run (fs ++ correct3) `shouldSatisfy` ok
specify "Forth correct case expression accepted" $
run (fs ++ correct4) `shouldSatisfy` ok
where where
fs = fs =
[ "data forall a. List (a) where {" [ "data forall a. List (a) where {"
@ -215,6 +224,15 @@ tc_pol_case = describe "Polymophic pattern matching" $ do
, " Cons x xs => 1 + length xs;" , " Cons x xs => 1 + length xs;"
, "};" , "};"
] ]
wrong4 =
[ "elems : forall c. List (List(c)) -> Int;"
, "elems = \\list. case list of {"
, " Nil => 0;"
, " Cons Nil Nil => 0;"
, " Cons Nil xs => elems xs;"
, " Cons (Cons Nil ys) xs => 1 + elems (Cons ys xs);"
, "};"
]
correct1 = correct1 =
[ "length : forall c. List (c) -> Int;" [ "length : forall c. List (c) -> Int;"
, "length = \\list. case list of {" , "length = \\list. case list of {"
@ -230,10 +248,27 @@ tc_pol_case = describe "Polymophic pattern matching" $ do
, " non_empty => 1;" , " non_empty => 1;"
, "};" , "};"
] ]
correct3 =
[ "length : List (Int) -> Int;"
, "length = \\list. case list of {"
, " Nil => 0;"
, " Cons 1 Nil => 1;"
, " Cons x (Cons 2 xs) => 2 + length xs;"
, "};"
]
correct4 =
[ "elems : forall c. List (List(c)) -> Int;"
, "elems = \\list. case list of {"
, " Nil => 0;"
, " Cons Nil Nil => 0;"
, " Cons Nil xs => elems xs;"
, " Cons (Cons _ ys) xs => 1 + elems (Cons ys xs);"
, "};"
]
run :: [String] -> Err T.Program run :: [String] -> Err T.Program
run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines run = rmTEVar <=< typecheck <=< pProgram . myLexer . unlines
ok = \case ok = \case
Ok _ -> True Ok _ -> True
Bad _ -> False Bad _ -> False