From 404e4ffbedb316fe2ba61e3f610437ecc0194211 Mon Sep 17 00:00:00 2001 From: Adam Turner <9087854+aa-turner@users.noreply.github.com> Date: Fri, 10 Jan 2025 14:54:24 +0000 Subject: [PATCH] Consolidate both ``handle_exception()`` implementations --- sphinx/_cli/util/errors.py | 96 +++++++++++++++++----------- sphinx/cmd/build.py | 127 +++++++++---------------------------- sphinx/errors.py | 4 +- sphinx/util/exceptions.py | 19 ------ 4 files changed, 89 insertions(+), 157 deletions(-) delete mode 100644 sphinx/util/exceptions.py diff --git a/sphinx/_cli/util/errors.py b/sphinx/_cli/util/errors.py index 8bf2ac91d..df7d1ad7c 100644 --- a/sphinx/_cli/util/errors.py +++ b/sphinx/_cli/util/errors.py @@ -2,15 +2,18 @@ from __future__ import annotations import re import sys -import tempfile -from typing import TYPE_CHECKING, TextIO +from typing import TYPE_CHECKING -from sphinx.errors import SphinxParallelError +from sphinx.errors import SphinxError, SphinxParallelError if TYPE_CHECKING: - from typing import Final + from collections.abc import Collection + from typing import Final, Protocol - from sphinx.application import Sphinx + from sphinx.extension import Extension + + class SupportsWrite(Protocol): + def write(self, text: str, /) -> int | None: ... _CSI: Final[str] = re.escape('\x1b[') # 'ESC [': Control Sequence Introducer @@ -50,6 +53,34 @@ def strip_escape_sequences(text: str, /) -> str: return _ANSI_CODES.sub('', text) +def full_exception_context( + exception: BaseException, + *, + message_log: Collection[str] = (), + extensions: Collection[Extension] = (), +) -> str: + """Return a formatted message containing useful debugging context.""" + last_msgs = '\n'.join(f'* {strip_escape_sequences(s)}' for s in message_log) + exts_list = '\n'.join( + f'* {ext.name} ({ext.version})' + for ext in extensions + if ext.version != 'builtin' + ) + exc_format = format_traceback(exception) + return error_info(last_msgs or 'None.', exts_list or 'None.', exc_format) + + +def format_traceback(exception: BaseException, /) -> str: + """Format the given exception's traceback.""" + if isinstance(exception, SphinxParallelError): + return f'(Error in parallel process)\n{exception.traceback}' + else: + from traceback import format_exception + + exc_format = ''.join(format_exception(exception)) + return exc_format + + def error_info(messages: str, extensions: str, traceback: str) -> str: """Format the traceback and extensions list with environment information.""" import platform @@ -88,37 +119,26 @@ Traceback """ -def format_traceback(app: Sphinx | None, exc: BaseException) -> str: - """Format the given exception's traceback with environment information.""" - if isinstance(exc, SphinxParallelError): - exc_format = '(Error in parallel process)\n' + exc.traceback - else: - import traceback - - exc_format = traceback.format_exc() - - last_msgs = exts_list = '' - if app is not None: - extensions = app.extensions.values() - last_msgs = '\n'.join(f'* {strip_escape_sequences(s)}' for s in app.messagelog) - exts_list = '\n'.join( - f'* {ext.name} ({ext.version})' - for ext in extensions - if ext.version != 'builtin' - ) - - return error_info(last_msgs, exts_list, exc_format) - - -def save_traceback(app: Sphinx | None, exc: BaseException) -> str: +def save_traceback( + exception: BaseException, + *, + message_log: Collection[str] = (), + extensions: Collection[Extension] = (), +) -> str: """Save the given exception's traceback in a temporary file.""" - output = format_traceback(app=app, exc=exc) + output = full_exception_context( + exception=exception, + message_log=message_log, + extensions=extensions, + ) filename = write_temporary_file(output) return filename def write_temporary_file(content: str) -> str: """Write content to a temporary file and return the filename.""" + import tempfile + with tempfile.NamedTemporaryFile( 'w', encoding='utf-8', suffix='.log', prefix='sphinx-err-', delete=False ) as f: @@ -131,18 +151,18 @@ def handle_exception( exception: BaseException, /, *, - stderr: TextIO = sys.stderr, + stderr: SupportsWrite = sys.stderr, use_pdb: bool = False, print_traceback: bool = False, - app: Sphinx | None = None, + message_log: Collection[str] = (), + extensions: Collection[Extension] = (), ) -> None: from bdb import BdbQuit - from traceback import TracebackException, print_exc + from traceback import TracebackException from docutils.utils import SystemMessage from sphinx._cli.util.colour import red - from sphinx.errors import SphinxError from sphinx.locale import __ if isinstance(exception, BdbQuit): @@ -156,14 +176,14 @@ def handle_exception( print_err() if print_traceback or use_pdb: - print_exc(file=stderr) + print_err(format_traceback(exception)) print_err() if use_pdb: from pdb import post_mortem print_red(__('Exception occurred, starting debugger:')) - post_mortem() + post_mortem(exception.__traceback__) return if isinstance(exception, KeyboardInterrupt): @@ -193,7 +213,7 @@ def handle_exception( __( 'This can happen with very large or deeply nested source ' 'files. You can carefully increase the default Python ' - 'recursion limit of 1000 in conf.py with e.g.:' + 'recursion limit of 1,000 in conf.py with e.g.:' ) ) print_err('\n import sys\n sys.setrecursionlimit(1_500)\n') @@ -205,7 +225,9 @@ def handle_exception( print_red(__('Exception occurred:')) print_err(formatted_tb) - traceback_info_path = save_traceback(app, exception) + traceback_info_path = save_traceback( + exception, message_log=message_log, extensions=extensions + ) print_err(__('The full traceback has been saved in:')) print_err(traceback_info_path) print_err() diff --git a/sphinx/cmd/build.py b/sphinx/cmd/build.py index 21f741b60..fe9ec3e86 100644 --- a/sphinx/cmd/build.py +++ b/sphinx/cmd/build.py @@ -3,36 +3,28 @@ from __future__ import annotations import argparse -import bdb import contextlib import locale import multiprocessing -import pdb # NoQA: T100 import sys -import traceback from pathlib import Path from typing import TYPE_CHECKING, Any, TextIO -from docutils.utils import SystemMessage - +import sphinx._cli.util.errors import sphinx.locale from sphinx import __display_version__ from sphinx.application import Sphinx -from sphinx.errors import SphinxError, SphinxParallelError from sphinx.locale import __ from sphinx.util._io import TeeStripANSI from sphinx.util._pathlib import _StrPath -from sphinx.util.console import color_terminal, nocolor, red, terminal_safe +from sphinx.util.console import color_terminal, nocolor from sphinx.util.docutils import docutils_namespace, patch_docutils -from sphinx.util.exceptions import format_exception_cut_frames, save_traceback from sphinx.util.osutil import ensuredir if TYPE_CHECKING: - from collections.abc import Sequence - from typing import Protocol + from collections.abc import Collection, Sequence - class SupportsWrite(Protocol): - def write(self, text: str, /) -> int | None: ... + from sphinx.extension import Extension def handle_exception( @@ -41,92 +33,19 @@ def handle_exception( exception: BaseException, stderr: TextIO = sys.stderr, ) -> None: - if isinstance(exception, bdb.BdbQuit): - return - - if args.pdb: - print( - red(__('Exception occurred while building, starting debugger:')), - file=stderr, - ) - traceback.print_exc() - pdb.post_mortem(sys.exc_info()[2]) + if app is not None: + message_log: Sequence[str] = app.messagelog + extensions: Collection[Extension] = app.extensions.values() else: - print(file=stderr) - if args.verbosity or args.traceback: - exc = sys.exc_info()[1] - if isinstance(exc, SphinxParallelError): - exc_format = '(Error in parallel process)\n' + exc.traceback - print(exc_format, file=stderr) - else: - traceback.print_exc(None, stderr) - print(file=stderr) - if isinstance(exception, KeyboardInterrupt): - print(__('Interrupted!'), file=stderr) - elif isinstance(exception, SystemMessage): - print(red(__('reST markup error:')), file=stderr) - print(terminal_safe(exception.args[0]), file=stderr) - elif isinstance(exception, SphinxError): - print(red('%s:' % exception.category), file=stderr) - print(str(exception), file=stderr) - elif isinstance(exception, UnicodeError): - print(red(__('Encoding error:')), file=stderr) - print(terminal_safe(str(exception)), file=stderr) - tbpath = save_traceback(app, exception) - print( - red( - __( - 'The full traceback has been saved in %s, if you want ' - 'to report the issue to the developers.' - ) - % tbpath - ), - file=stderr, - ) - elif ( - isinstance(exception, RuntimeError) - and 'recursion depth' in str(exception) - ): # fmt: skip - print(red(__('Recursion error:')), file=stderr) - print(terminal_safe(str(exception)), file=stderr) - print(file=stderr) - print( - __( - 'This can happen with very large or deeply nested source ' - 'files. You can carefully increase the default Python ' - 'recursion limit of 1000 in conf.py with e.g.:' - ), - file=stderr, - ) - print(' import sys; sys.setrecursionlimit(1500)', file=stderr) - else: - print(red(__('Exception occurred:')), file=stderr) - print(format_exception_cut_frames().rstrip(), file=stderr) - tbpath = save_traceback(app, exception) - print( - red( - __( - 'The full traceback has been saved in %s, if you ' - 'want to report the issue to the developers.' - ) - % tbpath - ), - file=stderr, - ) - print( - __( - 'Please also report this if it was a user error, so ' - 'that a better error message can be provided next time.' - ), - file=stderr, - ) - print( - __( - 'A bug report can be filed in the tracker at ' - '. Thanks!' - ), - file=stderr, - ) + message_log = extensions = () + return sphinx._cli.util.errors.handle_exception( + exception, + stderr=stderr, + use_pdb=args.pdb, + print_traceback=args.verbosity or args.traceback, + message_log=message_log, + extensions=extensions, + ) def jobs_argument(value: str) -> int: @@ -512,7 +431,19 @@ def build_main(argv: Sequence[str]) -> int: app.build(args.force_all, args.filenames) return app.statuscode except (Exception, KeyboardInterrupt) as exc: - handle_exception(app, args, exc, args.error) + if app is not None: + message_log: Sequence[str] = app.messagelog + extensions: Collection[Extension] = app.extensions.values() + else: + message_log = extensions = () + sphinx._cli.util.errors.handle_exception( + exc, + stderr=args.error, + use_pdb=args.pdb, + print_traceback=args.verbosity or args.traceback, + message_log=message_log, + extensions=extensions, + ) return 2 finally: if warnfp is not None: diff --git a/sphinx/errors.py b/sphinx/errors.py index c0339b4e9..c1ae15446 100644 --- a/sphinx/errors.py +++ b/sphinx/errors.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import Any - class SphinxError(Exception): """Base class for Sphinx errors. @@ -109,7 +107,7 @@ class SphinxParallelError(SphinxError): category = 'Sphinx parallel build error' - def __init__(self, message: str, traceback: Any) -> None: + def __init__(self, message: str, traceback: str) -> None: self.message = message self.traceback = traceback diff --git a/sphinx/util/exceptions.py b/sphinx/util/exceptions.py deleted file mode 100644 index c25a9ac7f..000000000 --- a/sphinx/util/exceptions.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -import sys -import traceback - -from sphinx._cli.util.errors import save_traceback - -__all__ = 'save_traceback', 'format_exception_cut_frames' - - -def format_exception_cut_frames(x: int = 1) -> str: - """Format an exception with traceback, but only the last x frames.""" - typ, val, tb = sys.exc_info() - # res = ['Traceback (most recent call last):\n'] - res: list[str] = [] - tbres = traceback.format_tb(tb) - res += tbres[-x:] - res += traceback.format_exception_only(typ, val) - return ''.join(res)