From 30824443474e0124df2c86415ee523b225fc9d2d Mon Sep 17 00:00:00 2001 From: sebastian Date: Sat, 25 Mar 2023 18:42:11 +0100 Subject: [PATCH] fixed bugs potentially. tests are working atleast --- src/TypeChecker/TypeChecker.hs | 62 +++++++++++++++++++++--------- tests/Tests.hs | 70 +++++++++++++++++++++++++++++++++- 2 files changed, 112 insertions(+), 20 deletions(-) diff --git a/src/TypeChecker/TypeChecker.hs b/src/TypeChecker/TypeChecker.hs index 70fb894..152669e 100644 --- a/src/TypeChecker/TypeChecker.hs +++ b/src/TypeChecker/TypeChecker.hs @@ -119,23 +119,45 @@ checkPrg (Program bs) = do checkBind :: Bind -> Infer T.Bind checkBind err@(Bind name args e) = do let lambda = makeLambda e (reverse (coerce args)) - (_, lambdaT) <- inferExp lambda - args <- zip args <$> mapM (const fresh) args - withBindings (map coerce args) $ do - e@(_, _) <- inferExp e - s <- gets sigs - case M.lookup (coerce name) s of - Just (Just t) -> do - sub <- bindErr (unify t lambdaT) err - let newT = apply sub t - insertSig (coerce name) (Just newT) - return $ T.Bind (apply sub (coerce name, newT)) (map coerce args) e - _ -> do - insertSig (coerce name) (Just lambdaT) - return (T.Bind (coerce name, lambdaT) (map coerce args) e) + e@(_, args_t) <- inferExp lambda + -- args <- zip args <$> mapM (const fresh) args + -- withBindings (coerce args) $ do + -- e@(_, t) <- inferExp e + -- let args_t = foldl' T.TFun t (reverse (map snd args)) + s <- gets sigs + case M.lookup (coerce name) s of + Just (Just t') -> do + -- sub <- bindErr (unify args_t t') err + -- let newT = apply sub args_t + -- insertSig (coerce name) (Just newT) + -- return $ T.Bind (apply sub (coerce name, newT)) [] e + unless + (args_t `typeEq` t') + ( throwError $ + "Inferred type '" + ++ printTree args_t + ++ " does not match specified type '" + ++ printTree t' + ++ "'" + ) + return $ T.Bind (coerce name, t') [] e + _ -> do + insertSig (coerce name) (Just args_t) + return (T.Bind (coerce name, args_t) [] e) + +typeEq :: T.Type -> T.Type -> Bool +typeEq (T.TFun l r) (T.TFun l' r') = typeEq l l' && typeEq r r' +typeEq (T.TLit a) (T.TLit b) = a == b +typeEq (T.TData name a) (T.TData name' b) = + length a == length b + && name == name' + && and (zipWith typeEq a b) +typeEq (T.TAll _ t1) (T.TAll _ t2) = t1 `typeEq` t2 +typeEq (T.TVar _) (T.TVar _) = True +typeEq _ _ = False isMoreSpecificOrEq :: T.Type -> T.Type -> Bool -isMoreSpecificOrEq _ (T.TAll _ _) = True +isMoreSpecificOrEq t1 (T.TAll _ t2) = isMoreSpecificOrEq t1 t2 isMoreSpecificOrEq (T.TFun a b) (T.TFun c d) = isMoreSpecificOrEq a c && isMoreSpecificOrEq b d isMoreSpecificOrEq (T.TData n1 ts1) (T.TData n2 ts2) = @@ -224,8 +246,10 @@ algoW = \case sig <- gets sigs case M.lookup (coerce i) sig of Just (Just t) -> return (nullSubst, (T.EId $ coerce i, t)) - Just Nothing -> - (\x -> (nullSubst, (T.EId $ coerce i, x))) <$> fresh + Just Nothing -> do + fr <- fresh + insertSig (coerce i) (Just fr) + return (nullSubst, (T.EId $ coerce i, fr)) Nothing -> throwError $ "Unbound variable: " <> printTree i EInj i -> do constr <- gets constructors @@ -315,15 +339,17 @@ makeLambda = foldl (flip (EAbs . coerce)) -- | Unify two types producing a new substitution unify :: T.Type -> T.Type -> Infer Subst +-- unify t0 t1 | trace ("T0: " ++ show t0 ++ "\nT1: " ++ show t1 ++ "\n") False = undefined unify t0 t1 = do case (t0, t1) of (T.TFun a b, T.TFun c d) -> do s1 <- unify a c s2 <- unify (apply s1 b) (apply s1 d) return $ s1 `compose` s2 - -- TODO: BEWARY. THIS IS PROBABLY WRONG!!! + ----------- TODO: CAREFUL!!!! THIS IS PROBABLY WRONG!!! ----------- (T.TVar (T.MkTVar a), t@(T.TData _ _)) -> return $ M.singleton a t (t@(T.TData _ _), T.TVar (T.MkTVar b)) -> return $ M.singleton b t + ------------------------------------------------------------------- (T.TVar (T.MkTVar a), t) -> occurs a t (t, T.TVar (T.MkTVar b)) -> occurs b t (T.TAll _ t, b) -> unify t b diff --git a/tests/Tests.hs b/tests/Tests.hs index 55ae9ab..99c49e6 100644 --- a/tests/Tests.hs +++ b/tests/Tests.hs @@ -7,7 +7,7 @@ import Control.Monad ((<=<)) import DoStrings qualified as D import Grammar.Par (myLexer, pProgram) import Test.Hspec -import Prelude (Bool (..), Either (..), IO, fmap, not, ($), (.)) +import Prelude (Bool (..), Either (..), IO, not, ($), (.)) -- import Test.QuickCheck import TypeChecker.TypeChecker (typecheck) @@ -16,9 +16,14 @@ main :: IO () main = hspec $ do ok1 ok2 + ok3 + ok4 + ok5 bad1 bad2 bad3 + bad4 + bad5 ok1 = specify "Basic polymorphism with multiple type variables" $ @@ -38,6 +43,41 @@ ok2 = ) `shouldSatisfy` ok +ok3 = + specify "A basic arithmetic function should be able to be inferred" $ + run + ( D.do + "plusOne x = x + 1 ;" + "main x = plusOne x ;" + ) + `shouldBe` run + ( D.do + "plusOne : Int -> Int ;" + "plusOne x = x + 1 ;" + "main : Int -> Int ;" + "main x = plusOne x ;" + ) + +ok4 = + specify "A basic arithmetic function should be able to be inferred" $ + run + ( D.do + "plusOne x = x + 1 ;" + ) + `shouldBe` run + ( D.do + "plusOne : Int -> Int ;" + "plusOne x = x + 1 ;" + ) + +ok5 = + specify "Most simple inference possible" $ + run + ( D.do + "id x = x ;" + ) + `shouldSatisfy` ok + bad1 = specify "Infinite type unification should not succeed" $ run @@ -59,7 +99,7 @@ bad2 = `shouldSatisfy` bad bad3 = - specify "Using a concrete function on a skolem variable should not succeed" $ + specify "Using a concrete function (data type) on a skolem variable should not succeed" $ run ( D.do bool @@ -69,6 +109,26 @@ bad3 = ) `shouldSatisfy` bad +bad4 = + specify "Using a concrete function (primitive type) on a skolem variable should not succeed" $ + run + ( D.do + "plusOne : Int -> Int ;" + "plusOne x = x + 1 ;" + "f : a -> Int ;" + " f x = plusOne x ;" + ) + `shouldSatisfy` bad + +bad5 = + specify "A function without signature used in an incompatible context should not succeed" $ + run + ( D.do + "main = id 1 2 ;" + "id x = x ;" + ) + `shouldSatisfy` bad + run = typecheck <=< pProgram . myLexer ok (Right _) = True @@ -90,6 +150,7 @@ list = D.do headSig = D.do "head : List (a) -> a ;" + head = D.do "head xs = " " case xs of {" @@ -108,3 +169,8 @@ _not = D.do " True => False ;" " False => True ;" "};" + +{- + [a, b, c] | (Int -> Int) + (a -> (b -> (c -> (Int -> Int)))) +-}