From 03d7080396bcd24b91506c5fb808e3759ef564e8 Mon Sep 17 00:00:00 2001 From: sebastianselander Date: Fri, 3 Mar 2023 11:46:54 +0100 Subject: [PATCH] pattern matching works? have to test more --- src/Main.hs | 2 +- src/Renamer/Renamer.hs | 4 +- src/TypeChecker/TypeChecker.hs | 63 +++++++++++++++++++------------- src/TypeChecker/TypeCheckerIr.hs | 11 ++++++ test_program | 30 ++++++++++++--- 5 files changed, 76 insertions(+), 34 deletions(-) diff --git a/src/Main.hs b/src/Main.hs index 8e62f2b..bef4a3b 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -33,7 +33,7 @@ main' s = do putStrLn "\n-- TypeChecker --" typechecked <- fromTypeCheckerErr $ typecheck renamed - putStrLn $ printTree typechecked + putStrLn $ show typechecked -- putStrLn "\n-- Lambda Lifter --" -- let lifted = lambdaLift typechecked diff --git a/src/Renamer/Renamer.hs b/src/Renamer/Renamer.hs index 24582f6..d471553 100644 --- a/src/Renamer/Renamer.hs +++ b/src/Renamer/Renamer.hs @@ -73,7 +73,9 @@ renameExp old_names = \case (new_names, e') <- renameExp old_names e pure (new_names, EAnn e' t) - ECase _ _ -> error "ECase NOT IMPLEMENTED YET" + ECase e injs -> do + (new_names, e') <- renameExp old_names e + pure (new_names, ECase e' injs) -- | Create a new name and add it to name environment. newName :: Names -> Ident -> Rn (Names, Ident) diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 7a2b96b..9c55388 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -1,5 +1,8 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} +{-# OPTIONS_GHC -Wno-unused-matches #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} +{-# HLINT ignore "Use mapAndUnzipM" #-} -- | A module for type checking and inference using algorithm W, Hindley-Milner module TypeChecker.TypeChecker where @@ -100,10 +103,12 @@ typeEq (TConstr (Constr name a)) (TConstr (Constr name' b)) = length a == length typeEq (TPol _) (TPol _) = True typeEq _ _ = False -isMoreGeneral :: Type -> Type -> Bool -isMoreGeneral _ (TPol _) = True -isMoreGeneral (TArr a b) (TArr c d) = isMoreGeneral a c && isMoreGeneral b d -isMoreGeneral a b = a == b +isMoreSpecificOrEq :: Type -> Type -> Bool +isMoreSpecificOrEq _ (TPol _) = True +isMoreSpecificOrEq (TArr a b) (TArr c d) = isMoreSpecificOrEq a c && isMoreSpecificOrEq b d +isMoreSpecificOrEq (TConstr (Constr n1 ts1)) (TConstr (Constr n2 ts2)) + = n1 == n2 && length ts1 == length ts2 && and (zipWith isMoreSpecificOrEq ts1 ts2) +isMoreSpecificOrEq a b = a == b isPoly :: Type -> Bool isPoly (TPol _) = True @@ -117,12 +122,13 @@ inferExp e = do replace :: Type -> T.Exp -> T.Exp replace t = \case - T.ELit _ e -> T.ELit t e - T.EId (n, _) -> T.EId (n, t) - T.EAbs _ name e -> T.EAbs t name e - T.EApp _ e1 e2 -> T.EApp t e1 e2 - T.EAdd _ e1 e2 -> T.EAdd t e1 e2 + T.ELit _ e -> T.ELit t e + T.EId (n, _) -> T.EId (n, t) + T.EAbs _ name e -> T.EAbs t name e + T.EApp _ e1 e2 -> T.EApp t e1 e2 + T.EAdd _ e1 e2 -> T.EAdd t e1 e2 T.ELet (T.Bind (n, _) e1) e2 -> T.ELet (T.Bind (n, t) e1) e2 + T.ECase _ expr injs -> T.ECase t expr injs algoW :: Exp -> Infer (Subst, Type, T.Exp) algoW = \case @@ -130,7 +136,7 @@ algoW = \case -- | TODO: Reason more about this one. Could be wrong EAnn e t -> do (s1, t', e') <- algoW e - unless (t `isMoreGeneral` t') (throwError $ unwords + unless (t `isMoreSpecificOrEq` t') (throwError $ unwords ["Annotated type:" , printTree t , "does not match inferred type:" @@ -218,13 +224,18 @@ algoW = \case let t' = generalize (apply s1 env) t1 withBinding name t' $ do (s2, t2, e1') <- algoW e1 - return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) e0') e1' ) + return (s2 `compose` s1, t2, T.ELet (T.Bind (name,t2) e0') e1') ECase caseExpr injs -> do (s0, t0, e0') <- algoW caseExpr - injs' <- mapM (checkInj t0) injs - undefined - + (injs', ts) <- unzip <$> mapM (checkInj t0) injs + case ts of + [] -> throwError "Case expression missing any matches" + ts -> do + unified <- zipWithM unify ts (tail ts) + let unified' = foldl' compose mempty unified + let typ = apply unified' (head ts) + return (unified', typ, T.ECase typ e0' injs') -- | Unify two types producing a new substitution @@ -340,19 +351,19 @@ insertConstr i t = modify (\st -> st { constructors = M.insert i t (constructors -------- PATTERN MATCHING --------- --- case expr of, the type of 'expr' is caseType -checkInj :: Type -> Inj -> Infer T.Inj +-- "case expr of", the type of 'expr' is caseType +checkInj :: Type -> Inj -> Infer (T.Inj, Type) checkInj caseType (Inj it expr) = do - (_, e') <- inferExp expr - t' <- initType caseType it - return $ T.Inj (it, t') e' + (args, t') <- initType caseType it + (s, t, e') <- local (\st -> st { vars = args }) (algoW expr) + return (T.Inj (it, t') e', t) -initType :: Type -> Init -> Infer Type +initType :: Type -> Init -> Infer (Map Ident Poly, Type) initType expected = \case InitLit lit -> let returnType = litType lit in if expected == returnType - then return expected - else throwError $ unwords ["Inferred type", printTree returnType, "does not match expected type:", printTree expected] + then return (mempty,expected) + else throwError $ unwords ["Inferred type", printTree returnType, "does not match expected type:", printTree expected] InitConstr c args -> do st <- gets constructors case M.lookup c st of @@ -360,14 +371,14 @@ initType expected = \case Just t -> do let flat = flattenType t let returnType = last flat - case (length (init flat) == length args, returnType == expected) of - (True, True) -> return returnType + case (length (init flat) == length args, returnType `isMoreSpecificOrEq` expected) of + (True, True) -> return (M.fromList $ zip args (map (Forall []) flat), expected) (False, _) -> throwError $ "Can't partially match on the constructor: " ++ printTree c - (_, False) -> throwError $ unwords ["Inferred type", printTree returnType, "does not match expected type:", printTree expected] + (_, False) -> throwError $ unwords ["Inferred type", printTree returnType, "does not match expected type:", printTree expected] -- Ignoring the variables for now, they can not be used in the expression to the -- right of '=>' - InitCatch -> return expected + InitCatch -> return (mempty, expected) flattenType :: Type -> [Type] flattenType (TArr a b) = flattenType a ++ flattenType b diff --git a/src/TypeChecker/TypeCheckerIr.hs b/src/TypeChecker/TypeCheckerIr.hs index a2c86f7..c07da96 100644 --- a/src/TypeChecker/TypeCheckerIr.hs +++ b/src/TypeChecker/TypeCheckerIr.hs @@ -127,6 +127,17 @@ instance Print Exp where , doc $ showString "." , prt 0 e ] + ECase t exp injs -> prPrec i 0 (concatD [doc (showString "case"), prt 0 exp, doc (showString "of"), doc (showString "{"), prt 0 injs, doc (showString "}"), doc (showString ":"), prt 0 t]) + +instance Print Inj where + prt i = \case + Inj (init,t) exp -> prPrec i 0 (concatD [prt 0 init, doc (showString ":"), prt 0 t, doc (showString "=>"), prt 0 exp]) + +instance Print [Inj] where + prt _ [] = concatD [] + prt _ [x] = concatD [prt 0 x] + prt _ (x:xs) = concatD [prt 0 x, doc (showString ";"), prt 0 xs] + diff --git a/test_program b/test_program index c5af112..26220e3 100644 --- a/test_program +++ b/test_program @@ -1,12 +1,30 @@ +-- data Bool () where { +-- True : Bool () +-- False : Bool () +-- }; +-- +-- main : _Int ; +-- main = case True of { +-- False => 0 ; +-- True => 1 +-- }; + data List ('a) where { Nil : List ('a) - Cons : 'a -> List ('a) -> List ('a) + Cons : ('a) -> List ('a) -> List ('a) }; -data Bool () where { - True : Bool () - False : Bool () +data Maybe ('a) where { + Nothing : Maybe ('a) + Just : 'a -> Maybe ('a) }; -main : List (_Int) ; -main = Cons 1 (Cons 0 Nil) ; +safeHead : List ('a) -> Maybe ('a) ; +safeHead xs = + case xs of { + Nil => Nothing ; + Cons x xs => Just x + }; + +main : Maybe (_Int) ; +main = safeHead (Cons 0 (Cons 1 Nil)) ;