diff --git a/src/TypeChecker/TypeCheckerHm.hs b/src/TypeChecker/TypeCheckerHm.hs index 1d40a5c..1fc0ee4 100644 --- a/src/TypeChecker/TypeCheckerHm.hs +++ b/src/TypeChecker/TypeCheckerHm.hs @@ -1,4 +1,5 @@ {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QualifiedDo #-} @@ -39,32 +40,36 @@ runC e c = . flip evalStateT e . runInfer -typecheck :: Program -> Either Error (T.Program' Type) -typecheck = run . checkPrg +typecheck :: Program -> Either String (T.Program' Type) +typecheck = onLeft msg . run . checkPrg + where + onLeft :: (Error -> String) -> Either Error a -> Either String a + onLeft f (Left x) = Left $ f x + onLeft _ (Right x) = Right x checkData :: Data -> Infer () -checkData (Data typ injs) = do +checkData err@(Data typ injs) = do (name, tvars) <- go typ - mapM_ (\i -> typecheckInj i name tvars) injs + dataErr (mapM_ (\i -> typecheckInj i name tvars) injs) err where go = \case TData name typs | Right tvars' <- mapM toTVar typs -> pure (name, tvars') - TAll _ _ -> throwError "Explicit foralls not allowed, for now" - _ -> throwError $ unwords ["Bad data type definition: ", printTree typ] + TAll _ _ -> uncatchableErr "Explicit foralls not allowed, for now" + _ -> uncatchableErr $ unwords ["Bad data type definition: ", printTree typ] typecheckInj :: Inj -> UIdent -> [TVar] -> Infer () typecheckInj (Inj c inj_typ) name tvars | Right False <- boundTVars tvars inj_typ = - throwError "Unbound type variables" + catchableErr "Unbound type variables" | TData name' typs <- returnType inj_typ , Right tvars' <- mapM toTVar typs , name' == name , tvars' == tvars = do exist <- existInj (coerce c) case exist of - Just t -> throwError $ Aux.do + Just t -> uncatchableErr $ Aux.do "Constructor" quote $ coerce name "with type" @@ -72,7 +77,7 @@ typecheckInj (Inj c inj_typ) name tvars "already exist" Nothing -> insertInj (coerce c) inj_typ | otherwise = - throwError $ + uncatchableErr $ unwords [ "Bad type constructor: " , show name @@ -84,7 +89,7 @@ typecheckInj (Inj c inj_typ) name tvars where boundTVars :: [TVar] -> Type -> Either Error Bool boundTVars tvars' = \case - TAll{} -> throwError "Explicit foralls not allowed, for now" + TAll{} -> uncatchableErr "Explicit foralls not allowed, for now" TFun t1 t2 -> do t1' <- boundTVars tvars t1 t2' <- boundTVars tvars t2 @@ -94,10 +99,10 @@ typecheckInj (Inj c inj_typ) name tvars TLit _ -> return True TEVar _ -> error "TEVar in data type declaration" -toTVar :: Type -> Either String TVar +toTVar :: Type -> Either Error TVar toTVar = \case TVar tvar -> pure tvar - _ -> throwError "Not a type variable" + _ -> uncatchableErr "Not a type variable" returnType :: Type -> Type returnType (TFun _ t2) = returnType t2 @@ -117,7 +122,7 @@ preRun (x : xs) = case x of gets (M.member (coerce n) . sigs) >>= flip when - ( throwError $ Aux.do + ( uncatchableErr $ Aux.do "Duplicate signatures for function" quote $ printTree n ) @@ -156,7 +161,7 @@ checkBind (Bind name args e) = do sub2 <- unify t' lambda_t unless (apply sub1 lambda_t == t' && lambda_t == apply sub2 t') - ( throwError $ Aux.do + ( uncatchableErr $ Aux.do "Inferred type" quote $ printTree lambda_t "does not match specified type" @@ -232,13 +237,11 @@ algoW = \case sub2 <- unify t' t unless (apply sub1 t == t' && apply sub2 t' == t) - ( throwError $ - unwords - [ "Annotated type:" - , printTree t - , "does not match inferred type:" - , printTree t' - ] + ( uncatchableErr $ Aux.do + "Annotated type" + quote $ printTree t + "does not match inferred type" + quote $ printTree t' ) s2 <- exprErr (unify t t') err let comp = s2 `compose` s1 @@ -263,16 +266,16 @@ algoW = \case fr <- fresh insertSig (coerce i) (Just fr) return (nullSubst, (T.EVar $ coerce i, fr)) - Nothing -> throwError $ "Unbound variable: " <> printTree i + Nothing -> uncatchableErr $ "Unbound variable: " <> printTree i EInj i -> do constr <- gets injections case M.lookup (coerce i) constr of Just t -> return (nullSubst, (T.EVar $ coerce i, t)) Nothing -> - throwError $ - "Constructor: '" - <> printTree i - <> "' is not defined" + uncatchableErr $ Aux.do + "Constructor:" + quote $ printTree i + "is not defined" -- \| τ = newvar Γ, x : τ ⊢ e : τ', S -- \| --------------------------------- @@ -365,7 +368,7 @@ unify t0 t1 = do (TLit a, TLit b) -> if a == b then return M.empty - else throwError $ + else catchableErr $ Aux.do "Can not unify" quote $ printTree (TLit a) @@ -376,7 +379,7 @@ unify t0 t1 = do then do xs <- zipWithM unify t t' return $ foldr compose nullSubst xs - else throwError $ + else catchableErr $ Aux.do "Type constructor:" printTree name @@ -387,14 +390,14 @@ unify t0 t1 = do (TEVar a, TEVar b) -> if a == b then return M.empty - else throwError $ + else catchableErr $ Aux.do "Can not unify" quote $ printTree (TEVar a) "with" quote $ printTree (TEVar b) (a, b) -> do - throwError $ + catchableErr $ Aux.do "Can not unify" quote $ printTree a @@ -409,12 +412,14 @@ occurs :: T.Ident -> Type -> Infer Subst occurs i t@(TVar _) = return (M.singleton i t) occurs i t = if S.member i (free t) - then throwError $ - Aux.do - "Occurs check failed, can't unify" - quote $ printTree (TVar $ T.MkTVar (coerce i)) - "with" - quote $ printTree t + then + catchableErr + ( Aux.do + "Occurs check failed, can't unify" + quote $ printTree (TVar $ T.MkTVar (coerce i)) + "with" + quote $ printTree t + ) else return $ M.singleton i t -- | Generalize a type over all free variables in the substitution set @@ -581,7 +586,7 @@ existInj n = gets (M.lookup n . injections) -------- PATTERN MATCHING --------- checkCase :: Type -> [Branch] -> Infer (Subst, [T.Branch' Type], Type) -checkCase _ [] = throwError "Atleast one case required" +checkCase _ [] = catchableErr "Atleast one case required" checkCase expT brnchs = do (subs, injTs, injs, returns) <- unzip4 <$> mapM inferBranch brnchs let sub0 = composeAll subs @@ -621,13 +626,23 @@ inferPattern = \case PLit lit -> let lt = litType lit in return (T.PLit (lit, lt), lt) PInj constr patterns -> do t <- gets (M.lookup (coerce constr) . injections) - t <- maybeToRightM ("Constructor: " <> printTree constr <> " does not exist") t + t <- + maybeToRightM + ( Error + ( Aux.do + "Constructor:" + quote $ printTree constr + "does not exist" + ) + True + ) + t let numArgs = typeLength t - 1 let (vs, ret) = fromJust (unsnoc $ flattenType t) patterns <- mapM inferPattern patterns unless (length patterns == numArgs) - ( throwError $ Aux.do + ( catchableErr $ Aux.do "The constructor" quote $ printTree constr " should have " @@ -640,10 +655,20 @@ inferPattern = \case PCatch -> (T.PCatch,) <$> fresh PEnum p -> do t <- gets (M.lookup (coerce p) . injections) - t <- maybeToRightM ("Constructor: " <> printTree p <> " does not exist") t + t <- + maybeToRightM + ( Error + ( Aux.do + "Constructor:" + quote $ printTree p + "does not exist" + ) + True + ) + t unless (typeLength t == 1) - ( throwError $ Aux.do + ( catchableErr $ Aux.do "The constructor" quote $ printTree p " should have " @@ -686,8 +711,10 @@ partitionType = go [] _ -> error "Number of parameters and type doesn't match" exprErr :: Infer a -> Exp -> Infer a -exprErr ma exp = - catchError ma (\x -> throwError $ x <> " in expression: \n" <> printTree exp) +exprErr ma exp = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in expression: \n" <> printTree exp, catchable = False}) else throwError x) + +dataErr :: Infer a -> Data -> Infer a +dataErr ma d = catchError ma (\x -> if x.catchable then throwError (x{msg = x.msg <> " in data: \n" <> printTree d}) else throwError (x{catchable = False})) unzip4 :: [(a, b, c, d)] -> ([a], [b], [c], [d]) unzip4 = @@ -709,11 +736,17 @@ data Env = Env } deriving (Show) -type Error = String +data Error = Error {msg :: String, catchable :: Bool} type Subst = Map T.Ident Type newtype Infer a = Infer {runInfer :: StateT Env (ReaderT Ctx (ExceptT Error Identity)) a} deriving (Functor, Applicative, Monad, MonadReader Ctx, MonadError Error, MonadState Env) +catchableErr :: MonadError Error m => String -> m a +catchableErr msg = throwError $ Error msg True + +uncatchableErr :: MonadError Error m => String -> m a +uncatchableErr msg = throwError $ Error msg False + quote :: String -> String quote s = "'" ++ s ++ "'"