Parens removed on types and infix symbols work almost, just need to adapt in LLVM

This commit is contained in:
sebastian 2023-05-04 22:50:15 +02:00
parent c309c439cb
commit 0dc06eaf80
10 changed files with 494 additions and 437 deletions

View file

@ -1,28 +1,28 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PatternSynonyms #-}
{-# HLINT ignore "Use camelCase" #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
module TestTypeCheckerBidir (test, testTypeCheckerBidir) where
import Test.Hspec
import AnnForall (annotateForall)
import Control.Monad ((<=<))
import Grammar.Abs (Program)
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck)
import qualified TypeChecker.TypeCheckerIr as T
import Test.Hspec
import AnnForall (annotateForall)
import Control.Monad ((<=<))
import Desugar.Desugar (desugar)
import Grammar.Abs (Program)
import Grammar.ErrM (Err, pattern Bad, pattern Ok)
import Grammar.Layout (resolveLayout)
import Grammar.Par (myLexer, pProgram)
import Grammar.Print (printTree)
import Renamer.Renamer (rename)
import ReportForall (reportForall)
import TypeChecker.RemoveForall (removeForall)
import TypeChecker.ReportTEVar (reportTEVar)
import TypeChecker.TypeChecker (TypeChecker (Bi))
import TypeChecker.TypeCheckerBidir (typecheck)
import TypeChecker.TypeCheckerIr qualified as T
test = hspec testTypeCheckerBidir
@ -120,9 +120,9 @@ tc_pair = describe "Pair. Type variables in Pair a b typechecked" $ do
specify "Correct arguments are accepted" $ run (fs ++ correct) `shouldSatisfy` ok
where
fs =
[ "data Pair (a b) where"
, " Pair : a -> b -> Pair (a b)"
, "main : Pair (Int Char)"
[ "data Pair a b where"
, " Pair : a -> b -> Pair a b"
, "main : Pair Int Char"
]
wrong = ["main = Pair 'a' 65"]
correct = ["main = Pair 65 'a'"]
@ -132,9 +132,9 @@ tc_tree = describe "Tree. Recursive data type" $ do
specify "Correct tree is accepted" $ run (fs ++ correct) `shouldSatisfy` ok
where
fs =
[ "data Tree (a) where"
, " Node : a -> Tree (a) -> Tree (a) -> Tree (a)"
, " Leaf : a -> Tree (a)"
[ "data Tree a where"
, " Node : a -> Tree a -> Tree a -> Tree a"
, " Leaf : a -> Tree a"
]
wrong = ["tree = Node 1 (Node 2 (Node 4) (Leaf 5)) (Leaf 3)"]
correct = ["tree = Node 1 (Node 2 (Leaf 4) (Leaf 5)) (Leaf 3)"]
@ -201,30 +201,30 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
run (fs ++ correct4) `shouldSatisfy` ok
where
fs =
[ "data List (a) where"
, " Nil : List (a)"
, " Cons : a -> List (a) -> List (a)"
[ "data List a where"
, " Nil : List a"
, " Cons : a -> List a -> List a"
]
wrong1 =
[ "length : List (c) -> Int"
[ "length : List c -> Int"
, "length = \\list. case list of"
, " Nil => 0"
, " Cons 6 xs => 1 + length xs"
]
wrong2 =
[ "length : List (c) -> Int"
[ "length : List c -> Int"
, "length = \\list. case list of"
, " Cons => 0"
, " Cons x xs => 1 + length xs"
]
wrong3 =
[ "length : List (c) -> Int"
[ "length : List c -> Int"
, "length = \\list. case list of"
, " 0 => 0"
, " Cons x xs => 1 + length xs"
]
wrong4 =
[ "elems : List (List(c)) -> Int"
[ "elems : List (List c) -> Int"
, "elems = \\list. case list of"
, " Nil => 0"
, " Cons Nil Nil => 0"
@ -232,27 +232,27 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
, " Cons (Cons Nil ys) xs => 1 + elems (Cons ys xs)"
]
correct1 =
[ "length : List (c) -> Int"
[ "length : List c -> Int"
, "length = \\list. case list of"
, " Nil => 0"
, " Cons x xs => 1 + length xs"
, " Cons x (Cons y Nil) => 2"
]
correct2 =
[ "length : List (c) -> Int"
[ "length : List c -> Int"
, "length = \\list. case list of"
, " Nil => 0"
, " non_empty => 1"
]
correct3 =
[ "length : List (Int) -> Int"
[ "length : List Int -> Int"
, "length = \\list. case list of"
, " Nil => 0"
, " Cons 1 Nil => 1"
, " Cons x (Cons 2 xs) => 2 + length xs"
]
correct4 =
[ "elems : List (List(c)) -> Int"
[ "elems : List (List c) -> Int"
, "elems = \\list. case list of"
, " Nil => 0"
, " Cons Nil Nil => 0"
@ -261,16 +261,16 @@ tc_pol_case = describe "Polymophic and recursive pattern matching" $ do
]
tc_if = specify "Test if else case expression" $ do
run [ "data Bool () where"
, " True : Bool ()"
, " False : Bool ()"
, "ifThenElse : Bool () -> a -> a -> a"
run
[ "data Bool where"
, " True : Bool"
, " False : Bool"
, "ifThenElse : Bool -> a -> a -> a"
, "ifThenElse b if else = case b of"
, " True => if"
, " False => else"
] `shouldSatisfy` ok
]
`shouldSatisfy` ok
tc_infer_case = describe "Infer case expression" $ do
specify "Wrong case expression rejected" $
@ -279,9 +279,9 @@ tc_infer_case = describe "Infer case expression" $ do
run (fs ++ correct) `shouldSatisfy` ok
where
fs =
[ "data Bool () where"
, " True : Bool ()"
, " False : Bool ()"
[ "data Bool where"
, " True : Bool"
, " False : Bool"
]
correct =
@ -296,33 +296,38 @@ tc_infer_case = describe "Infer case expression" $ do
, " _ => 1"
]
tc_rec1 = specify "Infer simple recursive definition" $
run ["test x = 1 + test (x + 1)"] `shouldSatisfy` ok
tc_rec1 =
specify "Infer simple recursive definition" $
run ["test x = 1 + test (x + 1)"] `shouldSatisfy` ok
tc_rec2 = specify "Infer recursive definition with pattern matching" $ run
[ "data Bool () where"
, " False : Bool ()"
, " True : Bool ()"
, "test = \\x. case x of"
, " 10 => True"
, " _ => test (x+1)"
] `shouldSatisfy` ok
tc_rec2 =
specify "Infer recursive definition with pattern matching" $
run
[ "data Bool where"
, " False : Bool"
, " True : Bool"
, "test = \\x. case x of"
, " 10 => True"
, " _ => test (x+1)"
]
`shouldSatisfy` ok
run :: [String] -> Err T.Program
run = fmap removeForall
. reportTEVar
<=< typecheck
<=< run'
run =
fmap removeForall
. reportTEVar
<=< typecheck
<=< run'
run' s = do
p <- (pProgram . resolveLayout True . myLexer . unlines) s
p <- (fmap desugar . pProgram . resolveLayout True . myLexer . unlines) s
reportForall Bi p
(rename <=< annotateForall) p
runPrint = (putStrLn . either show printTree . run')
["double x = x + x"]
runPrint =
(putStrLn . either show printTree . run')
["double x = x + x"]
ok = \case
Ok _ -> True
Ok _ -> True
Bad _ -> False