diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 3a505b4..4e7e7d6 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -22,7 +22,7 @@ import Data.Map qualified as M import Data.Maybe (fromJust) import Data.Set (Set) import Data.Set qualified as S -import Debug.Trace (trace) +import Debug.Trace (trace, traceShow) import Grammar.Abs import Grammar.Print (printTree) import TypeChecker.TypeCheckerIr qualified as T @@ -158,36 +158,31 @@ freeOrdered _ = mempty checkBind :: Bind -> Infer (T.Bind' Type) checkBind (Bind name args e) = do let lambda = makeLambda e (reverse (coerce args)) - (e, lambda_t) <- inferExp lambda + (e, infSig) <- inferExp lambda s <- gets sigs case M.lookup (coerce name) s of - Just (Just t') -> do - let fvs0 = nub $ freeOrdered t' - let m0 = M.fromList $ zip fvs0 letters - let fvs1 = nub $ freeOrdered lambda_t - let m1 = M.fromList $ zip fvs1 letters - let t0 = replace m0 t' - let t1 = replace m1 lambda_t - -- Not sure if this is actually correct - sub <- unify t' lambda_t + Just (Just typSig) -> do + let genInfSig = generalize mempty infSig + (trace ("Inferred: " ++ printTree infSig ++ "\nGeneralized inferred: " ++ printTree genInfSig ++ "\nGiven: " ++ printTree typSig ++ "\n") pure ()) + sub <- genInfSig `unify` typSig unless - (t1 <<= t0) + (genInfSig <<= typSig) ( throwError $ Error ( Aux.do "Inferred type" - quote $ printTree t1 + quote $ printTree infSig "doesn't match given type" - quote $ printTree $ mkForall t0 + quote $ printTree typSig ) False ) - -- Applying sub to t' will worsen error messages. + -- Applying sub to typSig will worsen error messages. -- Unfortunately I do not know a better solution at the moment. - return $ T.Bind (coerce name, apply sub t') [] (apply sub e, lambda_t) + return $ T.Bind (coerce name, apply sub typSig) [] (apply sub e, typSig) _ -> do - insertSig (coerce name) (Just lambda_t) - return (T.Bind (coerce name, lambda_t) [] (e, lambda_t)) + insertSig (coerce name) (Just infSig) + return (T.Bind (coerce name, infSig) [] (e, infSig)) checkData :: (MonadState Env m, Monad m, MonadError Error m) => Data -> m () checkData err@(Data typ injs) = do @@ -276,7 +271,7 @@ algoW = \case quote $ printTree t' ) let comp = sub2 `compose` sub1 `compose` sub0 - return (comp, (apply comp e', skolemize t)) + return (comp, (apply comp e', t)) -- \| ------------------ -- \| Γ ⊢ i : Int, ∅ @@ -384,7 +379,9 @@ algoW = \case return (comp, apply comp (T.ECase (e', t) injs, ret_t)) checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type) -checkCase _ [] = catchableErr "Atleast one case required" +checkCase _ [] = do + fr <- fresh + return (nullSubst, [], fr) checkCase expT brnchs = do (subs, branchTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs let sub0 = composeAll subs @@ -608,15 +605,37 @@ currently this is not the case, the TAll pattern match is incorrectly implemente -} -- Is the left a subtype of the right (<<=) :: Type -> Type -> Bool -(<<=) (TVar _) _ = True -(<<=) t1@TAll{} t2 = skipForalls t1 <<= t2 -(<<=) t1 t2@TAll{} = t1 <<= skipForalls t2 -(<<=) (TFun a b) (TFun c d) = a <<= c && b <<= d -(<<=) (TData n1 ts1) (TData n2 ts2) = - n1 == n2 - && length ts1 == length ts2 - && and (zipWith (<<=) ts1 ts2) -(<<=) a b = a == b +(<<=) a b = case (a,b) of + (TVar _, _) -> True + (TFun a b,TFun c d) -> a <<= c && b <<= d + (TAll tvar1 t1, TAll tvar2 t2) -> ungo [tvar1, tvar2] t1 t2 + (TAll tvar t1, t2) -> ungo [tvar] t1 t2 + (t1, TAll tvar t2) -> ungo [tvar] t1 t2 + (TData n1 ts1, TData n2 ts2) -> n1 == n2 + && length ts1 == length ts2 + && and (zipWith (<<=) ts1 ts2) + (t1,t2) -> t1 == t2 + where + ungo :: [TVar] -> Type -> Type -> Bool + ungo tvars t1 t2 = case run (go tvars t1 t2) of + Right (b,_) -> b + _ -> False + go :: [TVar] -> Type -> Type -> Infer Bool + go tvars t1 t2 = do + fr <- fresh + let sub = M.fromList [(coerce x, fr) | (MkTVar x) <- tvars] + return (apply sub t1 <<= apply sub t2) + +{- + +typSig = TAll (MkTVar "a") (TAll (MkTVar "b") ((TVar (MkTVar "a") `TFun` (TVar (MkTVar "b"))))) + +infSig = generalize mempty $ TFun (TVar $ MkTVar "x") (TVar $ MkTVar "x") + +a = TAll (MkTVar "a") (TFun (TVar $ MkTVar "a") (TVar $ MkTVar "a")) +b = TAll (MkTVar "b") (TFun (TVar $ MkTVar "b") (TVar $ MkTVar "b")) +int = TFun (TLit "Int") (TLit "Int") +-} skipForalls :: Type -> Type skipForalls = \case @@ -897,6 +916,7 @@ data Error = Error {msg :: String, catchable :: Bool} type Subst = Map T.Ident Type newtype Warning = NonExhaustive String + deriving (Show) newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (WriterT [Warning] (ExceptT Error Identity))) a} deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env)