Improve type inference and related handling of pattern matches; add more regression tests

This commit is contained in:
sebastian 2023-03-25 20:43:19 +01:00
parent 88eaa466e4
commit 975dd34063
3 changed files with 94 additions and 64 deletions

View file

@ -5,6 +5,7 @@ build:
clean: clean:
rm -r src/Grammar rm -r src/Grammar
rm language rm language
rm -r dist-newstyle/
# run all tests # run all tests
test: test:

View file

@ -6,19 +6,19 @@ module TypeChecker.TypeChecker where
import Auxiliary import Auxiliary
import Control.Monad.Except import Control.Monad.Except
import Control.Monad.Identity (runIdentity)
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.State import Control.Monad.State
import Data.Bifunctor (second) import Data.Bifunctor (second)
import Data.Coerce (coerce) import Data.Coerce (coerce)
import Data.Foldable (traverse_) import Data.Foldable (traverse_)
import Data.Functor.Identity (runIdentity)
import Data.List (foldl') import Data.List (foldl')
import Data.List.Extra (unsnoc) import Data.List.Extra (unsnoc)
import Data.Map (Map) import Data.Map (Map)
import Data.Map qualified as M import Data.Map qualified as M
import Data.Maybe (fromJust)
import Data.Set (Set) import Data.Set (Set)
import Data.Set qualified as S import Data.Set qualified as S
import Debug.Trace (trace)
import Grammar.Abs import Grammar.Abs
import Grammar.Print (printTree) import Grammar.Print (printTree)
import TypeChecker.TypeCheckerIr ( import TypeChecker.TypeCheckerIr (
@ -117,20 +117,12 @@ checkPrg (Program bs) = do
(DSig _) -> checkDef xs (DSig _) -> checkDef xs
checkBind :: Bind -> Infer T.Bind checkBind :: Bind -> Infer T.Bind
checkBind err@(Bind name args e) = do checkBind (Bind name args e) = do
let lambda = makeLambda e (reverse (coerce args)) let lambda = makeLambda e (reverse (coerce args))
e@(_, args_t) <- inferExp lambda e@(_, args_t) <- inferExp lambda
-- args <- zip args <$> mapM (const fresh) args
-- withBindings (coerce args) $ do
-- e@(_, t) <- inferExp e
-- let args_t = foldl' T.TFun t (reverse (map snd args))
s <- gets sigs s <- gets sigs
case M.lookup (coerce name) s of case M.lookup (coerce name) s of
Just (Just t') -> do Just (Just t') -> do
-- sub <- bindErr (unify args_t t') err
-- let newT = apply sub args_t
-- insertSig (coerce name) (Just newT)
-- return $ T.Bind (apply sub (coerce name, newT)) [] e
unless unless
(args_t `typeEq` t') (args_t `typeEq` t')
( throwError $ ( throwError $
@ -152,7 +144,8 @@ typeEq (T.TData name a) (T.TData name' b) =
length a == length b length a == length b
&& name == name' && name == name'
&& and (zipWith typeEq a b) && and (zipWith typeEq a b)
typeEq (T.TAll _ t1) (T.TAll _ t2) = t1 `typeEq` t2 typeEq (T.TAll _ t1) t2 = t1 `typeEq` t2
typeEq t1 (T.TAll _ t2) = t1 `typeEq` t2
typeEq (T.TVar _) (T.TVar _) = True typeEq (T.TVar _) (T.TVar _) = True
typeEq _ _ = False typeEq _ _ = False
@ -164,6 +157,7 @@ isMoreSpecificOrEq (T.TData n1 ts1) (T.TData n2 ts2) =
n1 == n2 n1 == n2
&& length ts1 == length ts2 && length ts1 == length ts2
&& and (zipWith isMoreSpecificOrEq ts1 ts2) && and (zipWith isMoreSpecificOrEq ts1 ts2)
isMoreSpecificOrEq _ (T.TVar _) = True
isMoreSpecificOrEq a b = a == b isMoreSpecificOrEq a b = a == b
isPoly :: Type -> Bool isPoly :: Type -> Bool
@ -175,10 +169,7 @@ inferExp :: Exp -> Infer T.ExpT
inferExp e = do inferExp e = do
(s, (e', t)) <- algoW e (s, (e', t)) <- algoW e
let subbed = apply s t let subbed = apply s t
return $ replace subbed (e', t) return $ second (const subbed) (e', t)
replace :: T.Type -> T.ExpT -> T.ExpT
replace t = second (const t)
class NewType a b where class NewType a b where
toNew :: a -> b toNew :: a -> b
@ -200,7 +191,7 @@ instance NewType Data T.Data where
toNew (Data t xs) = T.Data (name $ retType t) (toNew xs) toNew (Data t xs) = T.Data (name $ retType t) (toNew xs)
where where
name (TData n _) = coerce n name (TData n _) = coerce n
name _ = error "Bug in toNew Data -> T.Data" name _ = error "Bug: Data types should not be able to be typed over non type variables"
instance NewType Constructor T.Constructor where instance NewType Constructor T.Constructor where
toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs) toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs)
@ -213,7 +204,6 @@ instance NewType a b => NewType [a] [b] where
algoW :: Exp -> Infer (Subst, T.ExpT) algoW :: Exp -> Infer (Subst, T.ExpT)
algoW = \case algoW = \case
-- \| TODO: More testing need to be done. Unsure of the correctness of this
err@(EAnn e t) -> do err@(EAnn e t) -> do
(s1, (e', t')) <- exprErr (algoW e) err (s1, (e', t')) <- exprErr (algoW e) err
unless unless
@ -434,6 +424,9 @@ inst = \case
compose :: Subst -> Subst -> Subst compose :: Subst -> Subst -> Subst
compose m1 m2 = M.map (apply m1) m2 `M.union` m1 compose m1 m2 = M.map (apply m1) m2 `M.union` m1
composeAll :: [Subst] -> Subst
composeAll = foldl' compose nullSubst
-- TODO: Split this class into two separate classes, one for free variables -- TODO: Split this class into two separate classes, one for free variables
-- and one for applying substitutions -- and one for applying substitutions
@ -477,21 +470,19 @@ instance SubstType (Map T.Ident T.Type) where
apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type
apply s = M.map (apply s) apply s = M.map (apply s)
instance SubstType T.ExpT where instance SubstType T.Exp where
apply :: Subst -> T.ExpT -> T.ExpT apply :: Subst -> T.Exp -> T.Exp
apply s = \case apply s = \case
(T.EId i, outerT) -> (T.EId i, apply s outerT) T.EId i -> T.EId i
(T.ELit lit, t) -> (T.ELit lit, apply s t) T.ELit lit -> T.ELit lit
(T.ELet (T.Bind (ident, t1) args e1) e2, t2) -> T.ELet (T.Bind (ident, t1) args e1) e2 ->
( T.ELet T.ELet
(T.Bind (ident, apply s t1) args (apply s e1)) (T.Bind (ident, apply s t1) args (apply s e1))
(apply s e2) (apply s e2)
, apply s t2 T.EApp e1 e2 -> T.EApp (apply s e1) (apply s e2)
) T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2)
(T.EApp e1 e2, t) -> (T.EApp (apply s e1) (apply s e2), apply s t) T.EAbs ident e -> T.EAbs ident (apply s e)
(T.EAdd e1 e2, t) -> (T.EAdd (apply s e1) (apply s e2), apply s t) T.ECase e brnch -> T.ECase (apply s e) (apply s brnch)
(T.EAbs ident e, t1) -> (T.EAbs ident (apply s e), apply s t1)
(T.ECase e brnch, t) -> (T.ECase (apply s e) (apply s brnch), apply s t)
instance SubstType T.Branch where instance SubstType T.Branch where
apply :: Subst -> T.Branch -> T.Branch apply :: Subst -> T.Branch -> T.Branch
@ -509,6 +500,9 @@ instance SubstType T.Pattern where
instance SubstType a => SubstType [a] where instance SubstType a => SubstType [a] where
apply s = map (apply s) apply s = map (apply s)
instance (SubstType a, SubstType b) => SubstType (a, b) where
apply s (a, b) = (apply s a, apply s b)
instance SubstType T.Id where instance SubstType T.Id where
apply s (name, t) = (name, apply s t) apply s (name, t) = (name, apply s t)
@ -548,8 +542,10 @@ insertConstr i t =
-------- PATTERN MATCHING --------- -------- PATTERN MATCHING ---------
checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type) checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type)
checkCase expT injs = do checkCase _ [] = throwError "Atleast one case required"
(injTs, injs, returns) <- unzip3 <$> mapM inferBranch injs checkCase expT brnchs = do
(subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs
let sub0 = composeAll subs
(sub1, _) <- (sub1, _) <-
foldM foldM
( \(sub, acc) x -> ( \(sub, acc) x ->
@ -564,17 +560,14 @@ checkCase expT injs = do
) )
(nullSubst, head returns) (nullSubst, head returns)
(tail returns) (tail returns)
return (sub2 `compose` sub1, injs, returns_type) let comp = sub2 `compose` sub1 `compose` sub0
return (comp, apply comp injs, apply comp returns_type)
{- | fst = type of init inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type)
| snd = type of expr
-}
inferBranch :: Branch -> Infer (T.Type, T.Branch, T.Type)
inferBranch (Branch pat expr) = do inferBranch (Branch pat expr) = do
newPat@(pat, branchT) <- inferPattern pat newPat@(pat, branchT) <- inferPattern pat
trace ("BRANCH TYPE: " ++ show branchT) pure () (sub, newExp@(_, exprT)) <- withPattern pat (algoW expr)
newExp@(_, exprT) <- withPattern pat (inferExp expr) return (sub, branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT)
return (branchT, T.Branch newPat newExp, exprT)
withPattern :: T.Pattern -> Infer a -> Infer a withPattern :: T.Pattern -> Infer a -> Infer a
withPattern p ma = case p of withPattern p ma = case p of
@ -590,14 +583,17 @@ inferPattern = \case
PInj constr patterns -> do PInj constr patterns -> do
t <- gets (M.lookup (coerce constr) . constructors) t <- gets (M.lookup (coerce constr) . constructors)
t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t
(vs, ret) <- maybeToRightM "Partial pattern match not allowed" (unsnoc $ flattenType t) let numArgs = typeLength t - 1
let (vs, ret) = fromJust (unsnoc $ flattenType t)
patterns <- mapM inferPattern patterns patterns <- mapM inferPattern patterns
sub <- foldl' compose nullSubst <$> zipWithM unify vs (map snd patterns) unless (length patterns == numArgs) (throwError $ "The constructor '" ++ printTree constr ++ "'" ++ " should have " ++ show numArgs ++ " arguments but has been given " ++ show (length patterns))
sub <- composeAll <$> zipWithM unify vs (map snd patterns)
return (T.PInj (coerce constr) (map fst patterns), apply sub ret) return (T.PInj (coerce constr) (map fst patterns), apply sub ret)
PCatch -> (T.PCatch,) <$> fresh PCatch -> (T.PCatch,) <$> fresh
PEnum p -> do PEnum p -> do
t <- gets (M.lookup (coerce p) . constructors) t <- gets (M.lookup (coerce p) . constructors)
t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t
unless (typeLength t == 1) (throwError $ "The constructor '" ++ printTree p ++ "'" ++ " should have " ++ show (typeLength t - 1) ++ " arguments but has been given 0")
return (T.PEnum $ coerce p, t) return (T.PEnum $ coerce p, t)
PVar x -> do PVar x -> do
fr <- fresh fr <- fresh
@ -608,6 +604,10 @@ flattenType :: T.Type -> [T.Type]
flattenType (T.TFun a b) = flattenType a <> flattenType b flattenType (T.TFun a b) = flattenType a <> flattenType b
flattenType a = [a] flattenType a = [a]
typeLength :: T.Type -> Int
typeLength (T.TFun a b) = typeLength a + typeLength b
typeLength _ = 1
litType :: Lit -> T.Type litType :: Lit -> T.Type
litType (LInt _) = int litType (LInt _) = int
litType (LChar _) = char litType (LChar _) = char
@ -629,8 +629,7 @@ partitionType = go []
exprErr :: Infer a -> Exp -> Infer a exprErr :: Infer a -> Exp -> Infer a
exprErr ma exp = exprErr ma exp =
catchError ma (\x -> throwError $ x <> " on expression: " <> printTree exp) catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp)
bindErr :: Infer a -> Bind -> Infer a unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d])
bindErr ma exp = unzip4 = foldl' (\(as, bs, cs, ds) (a, b, c, d) -> (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])) ([], [], [], [])
catchError ma (\x -> throwError $ x <> " on expression: " <> printTree exp)

View file

@ -21,16 +21,16 @@ goods =
[ specify "Basic polymorphism with multiple type variables" $ [ specify "Basic polymorphism with multiple type variables" $
run run
( D.do ( D.do
const _const
"main = const 'a' 65 ;" "main = const 'a' 65 ;"
) )
`shouldSatisfy` ok `shouldSatisfy` ok
, specify "Head with a correct signature is accepted" $ , specify "Head with a correct signature is accepted" $
run run
( D.do ( D.do
list _list
headSig _headSig
head _head
) )
`shouldSatisfy` ok `shouldSatisfy` ok
, specify "A basic arithmetic function should be able to be inferred" $ , specify "A basic arithmetic function should be able to be inferred" $
@ -59,13 +59,13 @@ goods =
, specify "Most simple inference possible" $ , specify "Most simple inference possible" $
run run
( D.do ( D.do
"id x = x ;" _id
) )
`shouldSatisfy` ok `shouldSatisfy` ok
, specify "Pattern matching on a nested list" $ , specify "Pattern matching on a nested list" $
run run
( D.do ( D.do
list _list
"main : List (List (a)) -> Int ;" "main : List (List (a)) -> Int ;"
"main xs = case xs of {" "main xs = case xs of {"
" Cons Nil _ => 1 ;" " Cons Nil _ => 1 ;"
@ -73,6 +73,24 @@ goods =
"};" "};"
) )
`shouldSatisfy` ok `shouldSatisfy` ok
, specify "List of function Int -> Int functions should be inferred corretly" $
run
( D.do
_list
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
`shouldBe` run
( D.do
_list
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 1 ;"
" Nil => 0 ;"
" };"
)
] ]
bads = bads =
@ -85,7 +103,7 @@ bads =
, specify "Pattern matching using different types should not succeed" $ , specify "Pattern matching using different types should not succeed" $
run run
( D.do ( D.do
list _list
"bad xs = case xs of {" "bad xs = case xs of {"
" 1 => 0 ;" " 1 => 0 ;"
" Nil => 0 ;" " Nil => 0 ;"
@ -95,7 +113,7 @@ bads =
, specify "Using a concrete function (data type) on a skolem variable should not succeed" $ , specify "Using a concrete function (data type) on a skolem variable should not succeed" $
run run
( D.do ( D.do
bool _bool
_not _not
"f : a -> Bool () ;" "f : a -> Bool () ;"
"f x = not x ;" "f x = not x ;"
@ -113,21 +131,32 @@ bads =
, specify "A function without signature used in an incompatible context should not succeed" $ , specify "A function without signature used in an incompatible context should not succeed" $
run run
( D.do ( D.do
"main = id 1 2 ;" "main = _id 1 2 ;"
"id x = x ;" "_id x = x ;"
) )
`shouldSatisfy` bad `shouldSatisfy` bad
, specify "Pattern matching on literal and list should not succeed" $ , specify "Pattern matching on literal and _list should not succeed" $
run run
( D.do ( D.do
list _list
"length : List (c) -> Int;" "length : List (c) -> Int;"
"length list = case list of {" "length _list = case _list of {"
" 0 => 0;" " 0 => 0;"
" Cons x xs => 1 + length xs;" " Cons x xs => 1 + length xs;"
"};" "};"
) )
`shouldSatisfy` bad `shouldSatisfy` bad
, specify "List of function Int -> Int functions should not be usable on Char" $
run
( D.do
_list
"main : List (Int -> Int) -> Int ;"
"main xs = case xs of {"
" Cons f _ => f 'a' ;"
" Nil => 0 ;"
" };"
)
`shouldSatisfy` bad
] ]
run = typecheck <=< pProgram . myLexer run = typecheck <=< pProgram . myLexer
@ -139,26 +168,26 @@ bad = not . ok
-- FUNCTIONS -- FUNCTIONS
const = D.do _const = D.do
"const : a -> b -> a ;" "const : a -> b -> a ;"
"const x y = x ;" "const x y = x ;"
list = D.do _list = D.do
"data List (a) where" "data List (a) where"
" {" " {"
" Nil : List (a)" " Nil : List (a)"
" Cons : a -> List (a) -> List (a)" " Cons : a -> List (a) -> List (a)"
" };" " };"
headSig = D.do _headSig = D.do
"head : List (a) -> a ;" "head : List (a) -> a ;"
head = D.do _head = D.do
"head xs = " "head xs = "
" case xs of {" " case xs of {"
" Cons x xs => x ;" " Cons x xs => x ;"
" };" " };"
bool = D.do _bool = D.do
"data Bool () where {" "data Bool () where {"
" True : Bool ()" " True : Bool ()"
" False : Bool ()" " False : Bool ()"
@ -170,3 +199,4 @@ _not = D.do
" True => False ;" " True => False ;"
" False => True ;" " False => True ;"
"};" "};"
_id = "id x = x ;"