From 4efe7cf9a2f47e386d229b4774d813c8f98ac3fa Mon Sep 17 00:00:00 2001 From: sebastianselander Date: Wed, 29 Mar 2023 17:30:31 +0200 Subject: [PATCH] inference does not depend on order. mutual recursion still not working correctly --- src/TypeChecker/TypeCheckerHm.hs | 139 ++++++++++----- tests/TestTypeChekerHm.hs/DoStrings.hs | 9 - tests/TestTypeChekerHm.hs/Tests.hs | 231 ------------------------- 3 files changed, 97 insertions(+), 282 deletions(-) delete mode 100644 tests/TestTypeChekerHm.hs/DoStrings.hs delete mode 100644 tests/TestTypeChekerHm.hs/Tests.hs diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 3d1121e..ea819fc 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -12,6 +12,7 @@ import Control.Monad.Except import Control.Monad.Identity (Identity, runIdentity) import Control.Monad.Reader import Control.Monad.State +import Data.Bifunctor (first) import Data.Coerce (coerce) import Data.Function (on) import Data.List (foldl') @@ -27,7 +28,7 @@ import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr qualified as T initCtx = Ctx mempty -initEnv = Env 0 'a' mempty mempty mempty +initEnv = Env 0 'a' mempty mempty mempty "" mempty mempty run :: Infer a -> Either Error a run = run' initEnv initCtx @@ -51,8 +52,20 @@ typecheck = onLeft msg . run . checkPrg checkPrg :: Program -> Infer (T.Program' Type) checkPrg (Program bs) = do preRun bs - bs' <- checkDef bs - return $ T.Program bs' + bs <- checkDef bs + sub <- solveUndecidable + dec <- gets toDecide + trace (printTree bs) pure () + bs <- mapM (mono sub) bs + return $ T.Program bs + +mono :: Subst -> T.Def' Type -> Infer (T.Def' Type) +mono s bind@(T.DBind (T.Bind (name, t) args e)) = do + b <- gets (S.member name . toDecide) + if b + then return $ T.DBind $ T.Bind (name, apply s t) (apply s args) (apply s e) + else return bind +mono _ (T.DData d) = return $ T.DData d preRun :: [Def] -> Infer () preRun [] = return () @@ -66,7 +79,7 @@ preRun (x : xs) = case x of "Duplicate signatures for function" quote $ printTree n ) - insertSig (coerce n) (Just $ skolemize t) >> preRun xs + insertSig (coerce n) (Just t) >> preRun xs DBind (Bind n _ e) -> do collect (collectTVars e) s <- gets sigs @@ -91,25 +104,15 @@ checkDef (x : xs) = case x of T.Data t $ map (\(Inj name typ) -> T.Inj (coerce name) typ) injs checkBind :: Bind -> Infer (T.Bind' Type) -checkBind (Bind name args e) = do +checkBind bind@(Bind name args e) = do + setCurrentBind $ coerce name let lambda = makeLambda e (reverse (coerce args)) - (sub0, (e, lambda_t)) <- inferExp lambda + (e, lambda_t) <- inferExp lambda s <- gets sigs case M.lookup (coerce name) s of Just (Just t') -> do - -- \| TODO: Fix, this is not correct - let fsig = apply sub0 t' - sub1 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq fsig lambda_t) mempty - sub2 <- liftEither $ runIdentity $ runExceptT $ execStateT (typeEq lambda_t fsig) mempty - unless - (lambda_t == apply sub1 fsig && apply sub2 lambda_t == fsig) - ( uncatchableErr $ Aux.do - "Inferred type" - quote $ printTree lambda_t - "does not match specified type" - quote $ printTree t' - ) - return $ T.Bind (coerce name, lambda_t) [] (e, lambda_t) + sub1 <- bindErr (unify lambda_t (skolemize t')) bind + return $ T.Bind (coerce name, apply sub1 t') [] (e, lambda_t) _ -> do insertSig (coerce name) (Just lambda_t) return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) @@ -123,7 +126,7 @@ checkData err@(Data typ injs) = do TData name typs | Right tvars' <- mapM toTVar typs -> pure (name, tvars') - TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now" + TAll _ _ -> uncatchableErr "Explicit forall not allowed, for now" _ -> uncatchableErr $ unwords ["Bad data type definition: ", printTree typ] @@ -158,7 +161,7 @@ checkInj (Inj c inj_typ) name tvars where boundTVars :: [TVar] -> Type -> Either Error Bool boundTVars tvars' = \case - TAll{} -> uncatchableErr "Explicit foralls not allowed, for now" + TAll{} -> uncatchableErr "Explicit forall not allowed, for now" TFun t1 t2 -> do t1' <- boundTVars tvars t1 t2' <- boundTVars tvars t2 @@ -177,11 +180,12 @@ returnType :: Type -> Type returnType (TFun _ t2) = returnType t2 returnType a = a -inferExp :: Exp -> Infer (Subst, T.ExpT' Type) +inferExp :: Exp -> Infer (T.ExpT' Type) inferExp e = do (s, (e', t)) <- algoW e let subbed = apply s t - return (s, (e', subbed)) + modify (\st -> st{undecidedSigs = apply s st.undecidedSigs}) + return (e', subbed) class CollectTVars a where collectTVars :: a -> Set T.Ident @@ -225,7 +229,7 @@ algoW = \case -- \| x : σ ∈ Γ   τ = inst(σ) -- \| ---------------------- -- \| Γ ⊢ x : τ, ∅ - EVar i -> do + EVar (LIdent i) -> do var <- asks vars case M.lookup (coerce i) var of Just t -> @@ -237,7 +241,8 @@ algoW = \case Just (Just t) -> return (nullSubst, (T.EVar $ coerce i, t)) Just Nothing -> do fr <- fresh - insertSig (coerce i) (Just fr) + cb <- gets currentBind + modify (\st -> st{toDecide = S.insert cb st.toDecide, undecidedSigs = M.insert (coerce $ concat [[prefix], i, [delim], coerce cb]) fr st.undecidedSigs}) return (nullSubst, (T.EVar $ coerce i, fr)) Nothing -> uncatchableErr $ @@ -591,6 +596,9 @@ instance SubstType (Map T.Ident Type) where apply :: Subst -> Map T.Ident Type -> Map T.Ident Type apply = M.map . apply +instance SubstType (T.ExpT' Type) where + apply s (e, t) = (apply s e, apply s t) + instance SubstType (T.Exp' Type) where apply s = \case T.EVar i -> T.EVar i @@ -605,6 +613,11 @@ instance SubstType (T.Exp' Type) where T.ECase e brnch -> T.ECase (apply s e) (apply s brnch) T.EInj i -> T.EInj i +instance SubstType (T.Def' Type) where + apply s = \case + T.DBind (T.Bind name args e) -> T.DBind $ T.Bind (apply s name) (apply s args) (apply s e) + d -> d + instance SubstType (T.Branch' Type) where apply s (T.Branch (i, t) e) = T.Branch (apply s i, apply s t) (apply s e) @@ -616,18 +629,18 @@ instance SubstType (T.Pattern' Type) where T.PCatch -> T.PCatch T.PEnum i -> T.PEnum i +instance SubstType (T.Pattern' Type, Type) where + apply s (p, t) = (apply s p, apply s t) + instance SubstType a => SubstType [a] where apply s = map (apply s) -instance (SubstType a, SubstType b) => SubstType (a, b) where - apply s (a, b) = (apply s a, apply s b) - instance SubstType (T.Id' Type) where apply s (name, t) = (name, apply s t) -- | Represents the empty substition set nullSubst :: Subst -nullSubst = M.empty +nullSubst = mempty -- | Compose two substitution sets compose :: Subst -> Subst -> Subst @@ -676,6 +689,31 @@ with an equivalent name has been declared already existInj :: T.Ident -> Infer (Maybe Type) existInj n = gets (M.lookup n . injections) +setCurrentBind :: T.Ident -> Infer () +setCurrentBind i = modify (\st -> st{currentBind = i}) + +solveUndecidable :: Infer Subst +solveUndecidable = do + sigs <- gets sigs + undecided <- gets undecidedSigs + let xs = M.toList undecided + ys <- + maybeToRightM + (Error "SIGNATURE MISSING" False) + (mapM (tupSequence . first (join . flip M.lookup sigs . getOriginal)) xs) + composeAll <$> mapM (uncurry unify) ys + +tupSequence :: Monad m => (m a, b) -> m (a, b) +tupSequence (ma, b) = (,b) <$> ma + +getOriginal :: T.Ident -> T.Ident +getOriginal (T.Ident i) = coerce $ takeWhile (/= delim) $ drop 1 i + +delim :: Char +delim = '_' +prefix :: Char +prefix = '$' + flattenType :: Type -> [Type] flattenType (TFun a b) = flattenType a <> flattenType b flattenType a = [a] @@ -740,19 +778,30 @@ exprErr :: (Monad m, MonadError Error m) => m a -> Exp -> m a exprErr ma exp = catchError ma - ( \x -> - if x.catchable - then - throwError - ( x - { msg = - x.msg + ( \err -> if err.catchable + then throwError + ( err { msg = err.msg <> " in expression: \n" <> printTree exp , catchable = False } ) - else throwError x + else throwError err + ) + +bindErr :: (Monad m, MonadError Error m) => m a -> Bind -> m a +bindErr ma bind = + catchError + ma + ( \err -> if err.catchable + then throwError + ( err { msg = err.msg + <> " in function: \n" + <> printTree bind + , catchable = False + } + ) + else throwError err ) {- | Catch an error if possible and add the given @@ -762,18 +811,18 @@ dataErr :: Infer a -> Data -> Infer a dataErr ma d = catchError ma - ( \x -> - if x.catchable + ( \err -> + if err.catchable then throwError - ( x + ( err { msg = - x.msg + err.msg <> " in data: \n" <> printTree d } ) - else throwError (x{catchable = False}) + else throwError (err{catchable = False}) ) unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) @@ -793,6 +842,9 @@ data Env = Env , sigs :: Map T.Ident (Maybe Type) , injections :: Map T.Ident Type , takenTypeVars :: Set T.Ident + , currentBind :: T.Ident + , undecidedSigs :: Map T.Ident Type + , toDecide :: Set T.Ident } deriving (Show) @@ -811,3 +863,6 @@ uncatchableErr msg = throwError $ Error msg False quote :: String -> String quote s = "'" ++ s ++ "'" + +ctrace :: (Monad m, Show a) => String -> a -> m () +ctrace str a = trace (str ++ ": " ++ show a) pure () diff --git a/tests/TestTypeChekerHm.hs/DoStrings.hs b/tests/TestTypeChekerHm.hs/DoStrings.hs deleted file mode 100644 index dabf5d6..0000000 --- a/tests/TestTypeChekerHm.hs/DoStrings.hs +++ /dev/null @@ -1,9 +0,0 @@ -module DoStrings where - -import Prelude hiding ((>>), (>>=)) - -(>>) :: String -> String -> String -(>>) str1 str2 = str1 ++ "\n" ++ str2 - -(>>=) :: String -> (String -> String) -> String -(>>=) str f = f str diff --git a/tests/TestTypeChekerHm.hs/Tests.hs b/tests/TestTypeChekerHm.hs/Tests.hs deleted file mode 100644 index b5d14c6..0000000 --- a/tests/TestTypeChekerHm.hs/Tests.hs +++ /dev/null @@ -1,231 +0,0 @@ -{-# LANGUAGE QualifiedDo #-} -{-# LANGUAGE NoImplicitPrelude #-} - -module Main where - -import Control.Monad ((<=<)) -import DoStrings qualified as D -import Grammar.Par (myLexer, pProgram) -import Test.Hspec -import Prelude (Bool (..), Either (..), IO, mapM_, not, ($), (.)) - --- import Test.QuickCheck -import TypeChecker.TypeChecker (typecheck) - -main :: IO () -main = do - mapM_ hspec goods - mapM_ hspec bads - mapM_ hspec bes - -goods = - [ testSatisfy - "Basic polymorphism with multiple type variables" - ( D.do - _const - "main = const 'a' 65 ;" - ) - ok - , testSatisfy - "Head with a correct signature is accepted" - ( D.do - _List - _headSig - _head - ) - ok - , testSatisfy - "Most simple inference possible" - ( D.do - _id - ) - ok - , testSatisfy - "Pattern matching on a nested list" - ( D.do - _List - "main : List (List (a)) -> Int ;" - "main xs = case xs of {" - " Cons Nil _ => 1 ;" - " _ => 0 ;" - "};" - ) - ok - ] - -bads = - [ testSatisfy - "Infinite type unification should not succeed" - ( D.do - "main = \\x. x x ;" - ) - bad - , testSatisfy - "Pattern matching using different types should not succeed" - ( D.do - _List - "bad xs = case xs of {" - " 1 => 0 ;" - " Nil => 0 ;" - "};" - ) - bad - , testSatisfy - "Using a concrete function (data type) on a skolem variable should not succeed" - ( D.do - _Bool - _not - "f : a -> Bool () ;" - "f x = not x ;" - ) - bad - , testSatisfy - "Using a concrete function (primitive type) on a skolem variable should not succeed" - ( D.do - "plusOne : Int -> Int ;" - "plusOne x = x + 1 ;" - "f : a -> Int ;" - "f x = plusOne x ;" - ) - bad - , testSatisfy - "A function without signature used in an incompatible context should not succeed" - ( D.do - "main = _id 1 2 ;" - "_id x = x ;" - ) - bad - , testSatisfy - "Pattern matching on literal and _List should not succeed" - ( D.do - _List - "length : List (c) -> Int;" - "length _List = case _List of {" - " 0 => 0;" - " Cons x xs => 1 + length xs;" - "};" - ) - bad - , testSatisfy - "List of function Int -> Int functions should not be usable on Char" - ( D.do - _List - "main : List (Int -> Int) -> Int ;" - "main xs = case xs of {" - " Cons f _ => f 'a' ;" - " Nil => 0 ;" - " };" - ) - bad - , testSatisfy - "id with incorrect signature" - ( D.do - "id : a -> b;" - "id x = x;" - ) - bad - , testSatisfy - "incorrect type signature on id lambda" - ( D.do - "id = ((\\x. x) : a -> b);" - ) - bad - ] - -bes = - [ testBe - "A basic arithmetic function should be able to be inferred" - ( D.do - "plusOne x = x + 1 ;" - "main x = plusOne x ;" - ) - ( D.do - "plusOne : Int -> Int ;" - "plusOne x = x + 1 ;" - "main : Int -> Int ;" - "main x = plusOne x ;" - ) - , testBe - "A basic arithmetic function should be able to be inferred" - ( D.do - "plusOne x = x + 1 ;" - ) - ( D.do - "plusOne : Int -> Int ;" - "plusOne x = x + 1 ;" - ) - , testBe - "List of function Int -> Int functions should be inferred corretly" - ( D.do - _List - "main xs = case xs of {" - " Cons f _ => f 1 ;" - " Nil => 0 ;" - " };" - ) - ( D.do - _List - "main : List (Int -> Int) -> Int ;" - "main xs = case xs of {" - " Cons f _ => f 1 ;" - " Nil => 0 ;" - " };" - ) - ] - -testSatisfy desc test satisfaction = specify desc $ run test `shouldSatisfy` satisfaction -testBe desc test shouldbe = specify desc $ run test `shouldBe` run shouldbe - -run = typecheck <=< pProgram . myLexer - -ok (Right _) = True -ok (Left _) = False - -bad = not . ok - --- FUNCTIONS - -_const = D.do - "const : a -> b -> a ;" - "const x y = x ;" -_List = D.do - "data List (a) where" - " {" - " Nil : List (a)" - " Cons : a -> List (a) -> List (a)" - " };" - -_headSig = D.do - "head : List (a) -> a ;" - -_head = D.do - "head xs = " - " case xs of {" - " Cons x xs => x ;" - " };" - -_Bool = D.do - "data Bool () where {" - " True : Bool ()" - " False : Bool ()" - "};" - -_not = D.do - "not : Bool () -> Bool () ;" - "not x = case x of {" - " True => False ;" - " False => True ;" - "};" -_id = "id x = x ;" - -_Maybe = D.do - "data Maybe (a) where {" - " Nothing : Maybe (a)" - " Just : a -> Maybe (a)" - " };" - -_fmap = D.do - "fmap f ma = case ma of {" - " Nothing => Nothing ;" - " Just a => Just (f a) ;" - "};"