Nested pattern matching should work correctly, added more tests

This commit is contained in:
sebastian 2023-03-25 19:17:46 +01:00
parent 3082444347
commit 88eaa466e4
2 changed files with 44 additions and 48 deletions

View file

@ -339,7 +339,6 @@ makeLambda = foldl (flip (EAbs . coerce))
-- | Unify two types producing a new substitution -- | Unify two types producing a new substitution
unify :: T.Type -> T.Type -> Infer Subst unify :: T.Type -> T.Type -> Infer Subst
-- unify t0 t1 | trace ("T0: " ++ show t0 ++ "\nT1: " ++ show t1 ++ "\n") False = undefined
unify t0 t1 = do unify t0 t1 = do
case (t0, t1) of case (t0, t1) of
(T.TFun a b, T.TFun c d) -> do (T.TFun a b, T.TFun c d) -> do
@ -573,6 +572,7 @@ checkCase expT injs = do
inferBranch :: Branch -> Infer (T.Type, T.Branch, T.Type) inferBranch :: Branch -> Infer (T.Type, T.Branch, T.Type)
inferBranch (Branch pat expr) = do inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat newPat@(pat, branchT) <- inferPattern pat
trace ("BRANCH TYPE: " ++ show branchT) pure ()
newExp@(_, exprT) <- withPattern pat (inferExp expr) newExp@(_, exprT) <- withPattern pat (inferExp expr)
return (branchT, T.Branch newPat newExp, exprT) return (branchT, T.Branch newPat newExp, exprT)
@ -592,8 +592,8 @@ inferPattern = \case
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
(vs, ret) <- maybeToRightM "Partial pattern match not allowed" (unsnoc $ flattenType t) (vs, ret) <- maybeToRightM "Partial pattern match not allowed" (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns patterns <- mapM inferPattern patterns
zipWithM_ unify vs (map snd patterns) sub <- foldl' compose nullSubst <$> zipWithM unify vs (map snd patterns)
return (T.PInj (coerce constr) (map fst patterns), ret) return (T.PInj (coerce constr) (map fst patterns), apply sub ret)
PCatch -> (T.PCatch,) <$> fresh PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do PEnum p -> do
t <- gets (M.lookup (coerce p) . constructors) t <- gets (M.lookup (coerce p) . constructors)

View file

@ -7,34 +7,25 @@ import Control.Monad ((<=<))
import DoStrings qualified as D import DoStrings qualified as D
import Grammar.Par (myLexer, pProgram) import Grammar.Par (myLexer, pProgram)
import Test.Hspec import Test.Hspec
import Prelude (Bool (..), Either (..), IO, not, ($), (.)) import Prelude (Bool (..), Either (..), IO, mapM_, not, ($), (.))
-- import Test.QuickCheck -- import Test.QuickCheck
import TypeChecker.TypeChecker (typecheck) import TypeChecker.TypeChecker (typecheck)
main :: IO () main :: IO ()
main = hspec $ do main = do
ok1 mapM_ hspec goods
ok2 mapM_ hspec bads
ok3
ok4
ok5
bad1
bad2
bad3
bad4
bad5
ok1 = goods =
specify "Basic polymorphism with multiple type variables" $ [ specify "Basic polymorphism with multiple type variables" $
run run
( D.do ( D.do
const const
"main = const 'a' 65 ;" "main = const 'a' 65 ;"
) )
`shouldSatisfy` ok `shouldSatisfy` ok
ok2 = , specify "Head with a correct signature is accepted" $
specify "Head with a correct signature is accepted" $
run run
( D.do ( D.do
list list
@ -42,9 +33,7 @@ ok2 =
head head
) )
`shouldSatisfy` ok `shouldSatisfy` ok
, specify "A basic arithmetic function should be able to be inferred" $
ok3 =
specify "A basic arithmetic function should be able to be inferred" $
run run
( D.do ( D.do
"plusOne x = x + 1 ;" "plusOne x = x + 1 ;"
@ -57,9 +46,7 @@ ok3 =
"main : Int -> Int ;" "main : Int -> Int ;"
"main x = plusOne x ;" "main x = plusOne x ;"
) )
, specify "A basic arithmetic function should be able to be inferred" $
ok4 =
specify "A basic arithmetic function should be able to be inferred" $
run run
( D.do ( D.do
"plusOne x = x + 1 ;" "plusOne x = x + 1 ;"
@ -69,25 +56,33 @@ ok4 =
"plusOne : Int -> Int ;" "plusOne : Int -> Int ;"
"plusOne x = x + 1 ;" "plusOne x = x + 1 ;"
) )
, specify "Most simple inference possible" $
ok5 =
specify "Most simple inference possible" $
run run
( D.do ( D.do
"id x = x ;" "id x = x ;"
) )
`shouldSatisfy` ok `shouldSatisfy` ok
, specify "Pattern matching on a nested list" $
run
( D.do
list
"main : List (List (a)) -> Int ;"
"main xs = case xs of {"
" Cons Nil _ => 1 ;"
" _ => 0 ;"
"};"
)
`shouldSatisfy` ok
]
bad1 = bads =
specify "Infinite type unification should not succeed" $ [ specify "Infinite type unification should not succeed" $
run run
( D.do ( D.do
"main = \\x. x x ;" "main = \\x. x x ;"
) )
`shouldSatisfy` bad `shouldSatisfy` bad
, specify "Pattern matching using different types should not succeed" $
bad2 =
specify "Pattern matching using different types should not succeed" $
run run
( D.do ( D.do
list list
@ -97,9 +92,7 @@ bad2 =
"};" "};"
) )
`shouldSatisfy` bad `shouldSatisfy` bad
, specify "Using a concrete function (data type) on a skolem variable should not succeed" $
bad3 =
specify "Using a concrete function (data type) on a skolem variable should not succeed" $
run run
( D.do ( D.do
bool bool
@ -108,9 +101,7 @@ bad3 =
"f x = not x ;" "f x = not x ;"
) )
`shouldSatisfy` bad `shouldSatisfy` bad
, specify "Using a concrete function (primitive type) on a skolem variable should not succeed" $
bad4 =
specify "Using a concrete function (primitive type) on a skolem variable should not succeed" $
run run
( D.do ( D.do
"plusOne : Int -> Int ;" "plusOne : Int -> Int ;"
@ -119,15 +110,25 @@ bad4 =
" f x = plusOne x ;" " f x = plusOne x ;"
) )
`shouldSatisfy` bad `shouldSatisfy` bad
, specify "A function without signature used in an incompatible context should not succeed" $
bad5 =
specify "A function without signature used in an incompatible context should not succeed" $
run run
( D.do ( D.do
"main = id 1 2 ;" "main = id 1 2 ;"
"id x = x ;" "id x = x ;"
) )
`shouldSatisfy` bad `shouldSatisfy` bad
, specify "Pattern matching on literal and list should not succeed" $
run
( D.do
list
"length : List (c) -> Int;"
"length list = case list of {"
" 0 => 0;"
" Cons x xs => 1 + length xs;"
"};"
)
`shouldSatisfy` bad
]
run = typecheck <=< pProgram . myLexer run = typecheck <=< pProgram . myLexer
@ -169,8 +170,3 @@ _not = D.do
" True => False ;" " True => False ;"
" False => True ;" " False => True ;"
"};" "};"
{-
[a, b, c] | (Int -> Int)
(a -> (b -> (c -> (Int -> Int))))
-}