Add implicit foralls for bidir, update and unify pipeline

This commit is contained in:
Martin Fredin 2023-04-03 17:34:33 +02:00
parent 12bca1c32d
commit 9870802371
33 changed files with 1010 additions and 1055 deletions

View file

@ -8,19 +8,25 @@ module TestTypeCheckerBidir (test, testTypeCheckerBidir) where
import Test.Hspec
import AnnForall (annotateForall)
import Control.Monad ((<=<))
import Grammar.Abs (Program)
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import TypeChecker.RemoveTEVar (RemoveTEVar (rmTEVar))
import ReportForall (reportForall)
import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck)
import qualified TypeChecker.TypeCheckerIr as T
test = hspec testTypeCheckerBidir
testTypeCheckerBidir = describe "Bidirectional type checker test" $ do
testTypeCheckerBidir = describe "Test Bidirectional type checker" $ do
tc_id
tc_double
tc_add_lam
@ -39,7 +45,7 @@ testTypeCheckerBidir = describe "Bidirectional type checker test" $ do
tc_id =
specify "Basic identity function polymorphism" $
run
[ "id : forall a. a -> a"
[ "id : a -> a"
, "id x = x"
, "main = id 4"
]
@ -60,7 +66,7 @@ tc_add_lam =
tc_const =
specify "Basic polymorphism with multiple type variables" $
run
[ "const : forall a. forall b. a -> b -> a"
[ "const : a -> b -> a"
, "const x y = x"
, "main = const 'a' 65"
]
@ -69,9 +75,9 @@ tc_const =
tc_simple_rank2 =
specify "Simple rank two polymorphism" $
run
[ "id : forall a. a -> a"
[ "id : a -> a"
, "id x = x"
, "f : forall a. a -> (forall b. b -> b) -> a"
, "f : a -> (forall b. b -> b) -> a"
, "f x g = g x"
, "main = f 4 id"
]
@ -80,11 +86,11 @@ tc_simple_rank2 =
tc_rank2 =
specify "Rank two polymorphism is ok" $
run
[ "const : forall a. forall b. a -> b -> a"
[ "const : a -> b -> a"
, "const x y = x"
, "rank2 : forall a. forall b. a -> (forall c. c -> Int) -> b -> Int"
, "rank2 : a -> (forall c. c -> Int) -> b -> Int"
, "rank2 x f y = f x + f y"
, "main = rank2 3 (\\x. const 5 x : forall a. a -> Int) 'h'"
, "main = rank2 3 (\\x. const 5 x : a -> Int) 'h'"
]
`shouldSatisfy` ok
@ -93,9 +99,9 @@ tc_identity = describe "(∀b. b → b) should only accept the identity function
specify "identity is accepted" $ run (fs ++ id) `shouldSatisfy` ok
where
fs =
[ "f : forall a. a -> (forall b. b -> b) -> a"
[ "f : a -> (forall b. b -> b) -> a"
, "f x g = g x"
, "id : forall a. a -> a"
, "id : a -> a"
, "id x = x"
, "id_int : Int -> Int"
, "id_int x = x"
@ -114,7 +120,7 @@ tc_pair = describe "Pair. Type variables in Pair a b typechecked" $ do
specify "Correct arguments are accepted" $ run (fs ++ correct) `shouldSatisfy` ok
where
fs =
[ "data forall a. forall b. Pair (a b) where"
[ "data Pair (a b) where"
, " Pair : a -> b -> Pair (a b)"
, "main : Pair (Int Char)"
]
@ -126,7 +132,7 @@ tc_tree = describe "Tree. Recursive data type" $ do
specify "Correct tree is accepted" $ run (fs ++ correct) `shouldSatisfy` ok
where
fs =
[ "data forall a. Tree (a) where"
[ "data Tree (a) where"
, " Node : a -> Tree (a) -> Tree (a) -> Tree (a)"
, " Leaf : a -> Tree (a)"
]
@ -195,30 +201,30 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
run (fs ++ correct4) `shouldSatisfy` ok
where
fs =
[ "data forall a. List (a) where"
[ "data List (a) where"
, " Nil : List (a)"
, " Cons : a -> List (a) -> List (a)"
]
wrong1 =
[ "length : forall c. List (c) -> Int"
[ "length : List (c) -> Int"
, "length = \\list. case list of"
, " Nil => 0"
, " Cons 6 xs => 1 + length xs"
]
wrong2 =
[ "length : forall c. List (c) -> Int"
[ "length : List (c) -> Int"
, "length = \\list. case list of"
, " Cons => 0"
, " Cons x xs => 1 + length xs"
]
wrong3 =
[ "length : forall c. List (c) -> Int"
[ "length : List (c) -> Int"
, "length = \\list. case list of"
, " 0 => 0"
, " Cons x xs => 1 + length xs"
]
wrong4 =
[ "elems : forall c. List (List(c)) -> Int"
[ "elems : List (List(c)) -> Int"
, "elems = \\list. case list of"
, " Nil => 0"
, " Cons Nil Nil => 0"
@ -226,14 +232,14 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
, " Cons (Cons Nil ys) xs => 1 + elems (Cons ys xs)"
]
correct1 =
[ "length : forall c. List (c) -> Int"
[ "length : List (c) -> Int"
, "length = \\list. case list of"
, " Nil => 0"
, " Cons x xs => 1 + length xs"
, " Cons x (Cons y Nil) => 2"
]
correct2 =
[ "length : forall c. List (c) -> Int"
[ "length : List (c) -> Int"
, "length = \\list. case list of"
, " Nil => 0"
, " non_empty => 1"
@ -246,7 +252,7 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
, " Cons x (Cons 2 xs) => 2 + length xs"
]
correct4 =
[ "elems : forall c. List (List(c)) -> Int"
[ "elems : List (List(c)) -> Int"
, "elems = \\list. case list of"
, " Nil => 0"
, " Cons Nil Nil => 0"
@ -292,9 +298,19 @@ tc_rec2 = specify "Infer recursive definition with pattern matching" $ run
, " _ => test (x+1)"
] `shouldSatisfy` ok
run :: [String] -> Err T.Program
run = rmTEVar <=< typecheck <=< pProgram . resolveLayout True . myLexer . unlines
run = fmap removeForall
. reportTEVar
<=< typecheck
<=< run'
run' s = do
p <- (pProgram . resolveLayout True . myLexer . unlines) s
reportForall Bi p
(rename <=< annotateForall) p
runPrint = (putStrLn . either show printTree . run')
["double x = x + x"]
ok = \case
Ok _ -> True