diff --git a/sphinx/util/__init__.py b/sphinx/util/__init__.py index 18e9f8701..bfe99778a 100644 --- a/sphinx/util/__init__.py +++ b/sphinx/util/__init__.py @@ -48,7 +48,7 @@ from sphinx.util.matching import patfilter # noqa if False: # For type annotation - from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Pattern, Sequence, Set, Tuple, Union # NOQA + from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Pattern, Sequence, Set, Tuple, Type, Union # NOQA logger = logging.getLogger(__name__) @@ -676,6 +676,10 @@ def status_iterator(iterable, summary, color="darkgreen", length=0, verbosity=0, logger.info('') +class SkipProgressMessage(Exception): + pass + + class progress_message: def __init__(self, message): # type: (str) -> None @@ -686,8 +690,13 @@ class progress_message: logger.info(bold(self.message + '... '), nonl=True) def __exit__(self, exc_type, exc_value, traceback): - # type: (Any, Any, Any) -> bool - if exc_type: + # type: (Type[Exception], Exception, Any) -> bool + if isinstance(exc_value, SkipProgressMessage): + logger.info(__('skipped')) + if exc_value.args: + logger.info(*exc_value.args) + return True + elif exc_type: logger.info(__('failed')) else: logger.info(__('done')) diff --git a/tests/test_util.py b/tests/test_util.py index 0860ac6a4..0926096f4 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -18,8 +18,8 @@ import sphinx from sphinx.errors import PycodeError from sphinx.testing.util import strip_escseq from sphinx.util import ( - display_chunk, encode_uri, ensuredir, get_module_source, parselinenos, status_iterator, - xmlname_checker + SkipProgressMessage, display_chunk, encode_uri, ensuredir, get_module_source, + parselinenos, progress_message, status_iterator, xmlname_checker ) from sphinx.util import logging @@ -123,6 +123,44 @@ def test_parselinenos(): parselinenos('3-1', 10) +def test_progress_message(app, status, warning): + logging.setup(app, status, warning) + logger = logging.getLogger(__name__) + + # standard case + with progress_message('testing'): + logger.info('blah ', nonl=True) + + output = strip_escseq(status.getvalue()) + assert 'testing... blah done\n' in output + + # skipping case + with progress_message('testing'): + raise SkipProgressMessage('Reason: %s', 'error') + + output = strip_escseq(status.getvalue()) + assert 'testing... skipped\nReason: error\n' in output + + # error case + try: + with progress_message('testing'): + raise + except Exception: + pass + + output = strip_escseq(status.getvalue()) + assert 'testing... failed\n' in output + + # decorator + @progress_message('testing') + def func(): + logger.info('in func ', nonl=True) + + func() + output = strip_escseq(status.getvalue()) + assert 'testing... in func done\n' in output + + def test_xmlname_check(): checker = xmlname_checker() assert checker.match('id-pub')