From 975dd340630eaa2223144f1ecc276ee6bd17615c Mon Sep 17 00:00:00 2001 From: sebastian Date: Sat, 25 Mar 2023 20:43:19 +0100 Subject: [PATCH] Better inference & stuff on pattern matches, added more tests for regression --- Justfile | 1 + src/TypeChecker/TypeChecker.hs | 91 +++++++++++++++++----------------- tests/Tests.hs | 66 +++++++++++++++++------- 3 files changed, 94 insertions(+), 64 deletions(-) diff --git a/Justfile b/Justfile index 8079213..7787dc8 100644 --- a/Justfile +++ b/Justfile @@ -5,6 +5,7 @@ build: clean: rm -r src/Grammar rm language + rm -r dist-newstyle/ # run all tests test: diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 4944071..2b53760 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -6,19 +6,19 @@ module TypeChecker.TypeChecker where import Auxiliary import Control.Monad.Except +import Control.Monad.Identity (runIdentity) import Control.Monad.Reader import Control.Monad.State import Data.Bifunctor (second) import Data.Coerce (coerce) import Data.Foldable (traverse_) -import Data.Functor.Identity (runIdentity) import Data.List (foldl') import Data.List.Extra (unsnoc) import Data.Map (Map) import Data.Map qualified as M +import Data.Maybe (fromJust) import Data.Set (Set) import Data.Set qualified as S -import Debug.Trace (trace) import Grammar.Abs import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr ( @@ -117,20 +117,12 @@ checkPrg (Program bs) = do (DSig _) -> checkDef xs checkBind :: Bind -> Infer T.Bind -checkBind err@(Bind name args e) = do +checkBind (Bind name args e) = do let lambda = makeLambda e (reverse (coerce args)) e@(_, args_t) <- inferExp lambda - -- args <- zip args <$> mapM (const fresh) args - -- withBindings (coerce args) $ do - -- e@(_, t) <- inferExp e - -- let args_t = foldl' T.TFun t (reverse (map snd args)) s <- gets sigs case M.lookup (coerce name) s of Just (Just t') -> do - -- sub <- bindErr (unify args_t t') err - -- let newT = apply sub args_t - -- insertSig (coerce name) (Just newT) - -- return $ T.Bind (apply sub (coerce name, newT)) [] e unless (args_t `typeEq` t') ( throwError $ @@ -152,7 +144,8 @@ typeEq (T.TData name a) (T.TData name' b) = length a == length b && name == name' && and (zipWith typeEq a b) -typeEq (T.TAll _ t1) (T.TAll _ t2) = t1 `typeEq` t2 +typeEq (T.TAll _ t1) t2 = t1 `typeEq` t2 +typeEq t1 (T.TAll _ t2) = t1 `typeEq` t2 typeEq (T.TVar _) (T.TVar _) = True typeEq _ _ = False @@ -164,6 +157,7 @@ isMoreSpecificOrEq (T.TData n1 ts1) (T.TData n2 ts2) = n1 == n2 && length ts1 == length ts2 && and (zipWith isMoreSpecificOrEq ts1 ts2) +isMoreSpecificOrEq _ (T.TVar _) = True isMoreSpecificOrEq a b = a == b isPoly :: Type -> Bool @@ -175,10 +169,7 @@ inferExp :: Exp -> Infer T.ExpT inferExp e = do (s, (e', t)) <- algoW e let subbed = apply s t - return $ replace subbed (e', t) - -replace :: T.Type -> T.ExpT -> T.ExpT -replace t = second (const t) + return $ second (const subbed) (e', t) class NewType a b where toNew :: a -> b @@ -200,7 +191,7 @@ instance NewType Data T.Data where toNew (Data t xs) = T.Data (name $ retType t) (toNew xs) where name (TData n _) = coerce n - name _ = error "Bug in toNew Data -> T.Data" + name _ = error "Bug: Data types should not be able to be typed over non type variables" instance NewType Constructor T.Constructor where toNew (Constructor name xs) = T.Constructor (coerce name) (toNew xs) @@ -213,7 +204,6 @@ instance NewType a b => NewType [a] [b] where algoW :: Exp -> Infer (Subst, T.ExpT) algoW = \case - -- \| TODO: More testing need to be done. Unsure of the correctness of this err@(EAnn e t) -> do (s1, (e', t')) <- exprErr (algoW e) err unless @@ -434,6 +424,9 @@ inst = \case compose :: Subst -> Subst -> Subst compose m1 m2 = M.map (apply m1) m2 `M.union` m1 +composeAll :: [Subst] -> Subst +composeAll = foldl' compose nullSubst + -- TODO: Split this class into two separate classes, one for free variables -- and one for applying substitutions @@ -477,21 +470,19 @@ instance SubstType (Map T.Ident T.Type) where apply :: Subst -> Map T.Ident T.Type -> Map T.Ident T.Type apply s = M.map (apply s) -instance SubstType T.ExpT where - apply :: Subst -> T.ExpT -> T.ExpT +instance SubstType T.Exp where + apply :: Subst -> T.Exp -> T.Exp apply s = \case - (T.EId i, outerT) -> (T.EId i, apply s outerT) - (T.ELit lit, t) -> (T.ELit lit, apply s t) - (T.ELet (T.Bind (ident, t1) args e1) e2, t2) -> - ( T.ELet + T.EId i -> T.EId i + T.ELit lit -> T.ELit lit + T.ELet (T.Bind (ident, t1) args e1) e2 -> + T.ELet (T.Bind (ident, apply s t1) args (apply s e1)) (apply s e2) - , apply s t2 - ) - (T.EApp e1 e2, t) -> (T.EApp (apply s e1) (apply s e2), apply s t) - (T.EAdd e1 e2, t) -> (T.EAdd (apply s e1) (apply s e2), apply s t) - (T.EAbs ident e, t1) -> (T.EAbs ident (apply s e), apply s t1) - (T.ECase e brnch, t) -> (T.ECase (apply s e) (apply s brnch), apply s t) + T.EApp e1 e2 -> T.EApp (apply s e1) (apply s e2) + T.EAdd e1 e2 -> T.EAdd (apply s e1) (apply s e2) + T.EAbs ident e -> T.EAbs ident (apply s e) + T.ECase e brnch -> T.ECase (apply s e) (apply s brnch) instance SubstType T.Branch where apply :: Subst -> T.Branch -> T.Branch @@ -509,6 +500,9 @@ instance SubstType T.Pattern where instance SubstType a => SubstType [a] where apply s = map (apply s) +instance (SubstType a, SubstType b) => SubstType (a, b) where + apply s (a, b) = (apply s a, apply s b) + instance SubstType T.Id where apply s (name, t) = (name, apply s t) @@ -548,8 +542,10 @@ insertConstr i t = -------- PATTERN MATCHING --------- checkCase :: T.Type -> [Branch] -> Infer (Subst, [T.Branch], T.Type) -checkCase expT injs = do - (injTs, injs, returns) <- unzip3 <$> mapM inferBranch injs +checkCase _ [] = throwError "Atleast one case required" +checkCase expT brnchs = do + (subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs + let sub0 = composeAll subs (sub1, _) <- foldM ( \(sub, acc) x -> @@ -564,17 +560,14 @@ checkCase expT injs = do ) (nullSubst, head returns) (tail returns) - return (sub2 `compose` sub1, injs, returns_type) + let comp = sub2 `compose` sub1 `compose` sub0 + return (comp, apply comp injs, apply comp returns_type) -{- | fst = type of init - | snd = type of expr --} -inferBranch :: Branch -> Infer (T.Type, T.Branch, T.Type) +inferBranch :: Branch -> Infer (Subst, T.Type, T.Branch, T.Type) inferBranch (Branch pat expr) = do newPat@(pat, branchT) <- inferPattern pat - trace ("BRANCH TYPE: " ++ show branchT) pure () - newExp@(_, exprT) <- withPattern pat (inferExp expr) - return (branchT, T.Branch newPat newExp, exprT) + (sub, newExp@(_, exprT)) <- withPattern pat (algoW expr) + return (sub, branchT, T.Branch (apply sub newPat) (apply sub newExp), apply sub exprT) withPattern :: T.Pattern -> Infer a -> Infer a withPattern p ma = case p of @@ -590,14 +583,17 @@ inferPattern = \case PInj constr patterns -> do t <- gets (M.lookup (coerce constr) . constructors) t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t - (vs, ret) <- maybeToRightM "Partial pattern match not allowed" (unsnoc $ flattenType t) + let numArgs = typeLength t - 1 + let (vs, ret) = fromJust (unsnoc $ flattenType t) patterns <- mapM inferPattern patterns - sub <- foldl' compose nullSubst <$> zipWithM unify vs (map snd patterns) + unless (length patterns == numArgs) (throwError $ "The constructor '" ++ printTree constr ++ "'" ++ " should have " ++ show numArgs ++ " arguments but has been given " ++ show (length patterns)) + sub <- composeAll <$> zipWithM unify vs (map snd patterns) return (T.PInj (coerce constr) (map fst patterns), apply sub ret) PCatch -> (T.PCatch,) <$> fresh PEnum p -> do t <- gets (M.lookup (coerce p) . constructors) t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t + unless (typeLength t == 1) (throwError $ "The constructor '" ++ printTree p ++ "'" ++ " should have " ++ show (typeLength t - 1) ++ " arguments but has been given 0") return (T.PEnum $ coerce p, t) PVar x -> do fr <- fresh @@ -608,6 +604,10 @@ flattenType :: T.Type -> [T.Type] flattenType (T.TFun a b) = flattenType a <> flattenType b flattenType a = [a] +typeLength :: T.Type -> Int +typeLength (T.TFun a b) = typeLength a + typeLength b +typeLength _ = 1 + litType :: Lit -> T.Type litType (LInt _) = int litType (LChar _) = char @@ -629,8 +629,7 @@ partitionType = go [] exprErr :: Infer a -> Exp -> Infer a exprErr ma exp = - catchError ma (\x -> throwError $ x <> " on expression: " <> printTree exp) + catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp) -bindErr :: Infer a -> Bind -> Infer a -bindErr ma exp = - catchError ma (\x -> throwError $ x <> " on expression: " <> printTree exp) +unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) +unzip4 = foldl' (\(as, bs, cs, ds) (a, b, c, d) -> (as ++ [a], bs ++ [b], cs ++ [c], ds ++ [d])) ([], [], [], []) diff --git a/tests/Tests.hs b/tests/Tests.hs index c6a92da..d1b87a6 100644 --- a/tests/Tests.hs +++ b/tests/Tests.hs @@ -21,16 +21,16 @@ goods = [ specify "Basic polymorphism with multiple type variables" $ run ( D.do - const + _const "main = const 'a' 65 ;" ) `shouldSatisfy` ok , specify "Head with a correct signature is accepted" $ run ( D.do - list - headSig - head + _list + _headSig + _head ) `shouldSatisfy` ok , specify "A basic arithmetic function should be able to be inferred" $ @@ -59,13 +59,13 @@ goods = , specify "Most simple inference possible" $ run ( D.do - "id x = x ;" + _id ) `shouldSatisfy` ok , specify "Pattern matching on a nested list" $ run ( D.do - list + _list "main : List (List (a)) -> Int ;" "main xs = case xs of {" " Cons Nil _ => 1 ;" @@ -73,6 +73,24 @@ goods = "};" ) `shouldSatisfy` ok + , specify "List of function Int -> Int functions should be inferred corretly" $ + run + ( D.do + _list + "main xs = case xs of {" + " Cons f _ => f 1 ;" + " Nil => 0 ;" + " };" + ) + `shouldBe` run + ( D.do + _list + "main : List (Int -> Int) -> Int ;" + "main xs = case xs of {" + " Cons f _ => f 1 ;" + " Nil => 0 ;" + " };" + ) ] bads = @@ -85,7 +103,7 @@ bads = , specify "Pattern matching using different types should not succeed" $ run ( D.do - list + _list "bad xs = case xs of {" " 1 => 0 ;" " Nil => 0 ;" @@ -95,7 +113,7 @@ bads = , specify "Using a concrete function (data type) on a skolem variable should not succeed" $ run ( D.do - bool + _bool _not "f : a -> Bool () ;" "f x = not x ;" @@ -113,21 +131,32 @@ bads = , specify "A function without signature used in an incompatible context should not succeed" $ run ( D.do - "main = id 1 2 ;" - "id x = x ;" + "main = _id 1 2 ;" + "_id x = x ;" ) `shouldSatisfy` bad - , specify "Pattern matching on literal and list should not succeed" $ + , specify "Pattern matching on literal and _list should not succeed" $ run ( D.do - list + _list "length : List (c) -> Int;" - "length list = case list of {" + "length _list = case _list of {" " 0 => 0;" " Cons x xs => 1 + length xs;" "};" ) `shouldSatisfy` bad + , specify "List of function Int -> Int functions should not be usable on Char" $ + run + ( D.do + _list + "main : List (Int -> Int) -> Int ;" + "main xs = case xs of {" + " Cons f _ => f 'a' ;" + " Nil => 0 ;" + " };" + ) + `shouldSatisfy` bad ] run = typecheck <=< pProgram . myLexer @@ -139,26 +168,26 @@ bad = not . ok -- FUNCTIONS -const = D.do +_const = D.do "const : a -> b -> a ;" "const x y = x ;" -list = D.do +_list = D.do "data List (a) where" " {" " Nil : List (a)" " Cons : a -> List (a) -> List (a)" " };" -headSig = D.do +_headSig = D.do "head : List (a) -> a ;" -head = D.do +_head = D.do "head xs = " " case xs of {" " Cons x xs => x ;" " };" -bool = D.do +_bool = D.do "data Bool () where {" " True : Bool ()" " False : Bool ()" @@ -170,3 +199,4 @@ _not = D.do " True => False ;" " False => True ;" "};" +_id = "id x = x ;"