Consolidate both `handle_exception()` implementations

This commit is contained in:
Adam Turner 2025-01-10 14:54:24 +00:00
parent dec45eaf28
commit 404e4ffbed
4 changed files with 89 additions and 157 deletions

View File

@ -2,15 +2,18 @@ from __future__ import annotations
import re
import sys
import tempfile
from typing import TYPE_CHECKING, TextIO
from typing import TYPE_CHECKING
from sphinx.errors import SphinxParallelError
from sphinx.errors import SphinxError, SphinxParallelError
if TYPE_CHECKING:
from typing import Final
from collections.abc import Collection
from typing import Final, Protocol
from sphinx.application import Sphinx
from sphinx.extension import Extension
class SupportsWrite(Protocol):
def write(self, text: str, /) -> int | None: ...
_CSI: Final[str] = re.escape('\x1b[') # 'ESC [': Control Sequence Introducer
@ -50,6 +53,34 @@ def strip_escape_sequences(text: str, /) -> str:
return _ANSI_CODES.sub('', text)
def full_exception_context(
exception: BaseException,
*,
message_log: Collection[str] = (),
extensions: Collection[Extension] = (),
) -> str:
"""Return a formatted message containing useful debugging context."""
last_msgs = '\n'.join(f'* {strip_escape_sequences(s)}' for s in message_log)
exts_list = '\n'.join(
f'* {ext.name} ({ext.version})'
for ext in extensions
if ext.version != 'builtin'
)
exc_format = format_traceback(exception)
return error_info(last_msgs or 'None.', exts_list or 'None.', exc_format)
def format_traceback(exception: BaseException, /) -> str:
"""Format the given exception's traceback."""
if isinstance(exception, SphinxParallelError):
return f'(Error in parallel process)\n{exception.traceback}'
else:
from traceback import format_exception
exc_format = ''.join(format_exception(exception))
return exc_format
def error_info(messages: str, extensions: str, traceback: str) -> str:
"""Format the traceback and extensions list with environment information."""
import platform
@ -88,37 +119,26 @@ Traceback
"""
def format_traceback(app: Sphinx | None, exc: BaseException) -> str:
"""Format the given exception's traceback with environment information."""
if isinstance(exc, SphinxParallelError):
exc_format = '(Error in parallel process)\n' + exc.traceback
else:
import traceback
exc_format = traceback.format_exc()
last_msgs = exts_list = ''
if app is not None:
extensions = app.extensions.values()
last_msgs = '\n'.join(f'* {strip_escape_sequences(s)}' for s in app.messagelog)
exts_list = '\n'.join(
f'* {ext.name} ({ext.version})'
for ext in extensions
if ext.version != 'builtin'
)
return error_info(last_msgs, exts_list, exc_format)
def save_traceback(app: Sphinx | None, exc: BaseException) -> str:
def save_traceback(
exception: BaseException,
*,
message_log: Collection[str] = (),
extensions: Collection[Extension] = (),
) -> str:
"""Save the given exception's traceback in a temporary file."""
output = format_traceback(app=app, exc=exc)
output = full_exception_context(
exception=exception,
message_log=message_log,
extensions=extensions,
)
filename = write_temporary_file(output)
return filename
def write_temporary_file(content: str) -> str:
"""Write content to a temporary file and return the filename."""
import tempfile
with tempfile.NamedTemporaryFile(
'w', encoding='utf-8', suffix='.log', prefix='sphinx-err-', delete=False
) as f:
@ -131,18 +151,18 @@ def handle_exception(
exception: BaseException,
/,
*,
stderr: TextIO = sys.stderr,
stderr: SupportsWrite = sys.stderr,
use_pdb: bool = False,
print_traceback: bool = False,
app: Sphinx | None = None,
message_log: Collection[str] = (),
extensions: Collection[Extension] = (),
) -> None:
from bdb import BdbQuit
from traceback import TracebackException, print_exc
from traceback import TracebackException
from docutils.utils import SystemMessage
from sphinx._cli.util.colour import red
from sphinx.errors import SphinxError
from sphinx.locale import __
if isinstance(exception, BdbQuit):
@ -156,14 +176,14 @@ def handle_exception(
print_err()
if print_traceback or use_pdb:
print_exc(file=stderr)
print_err(format_traceback(exception))
print_err()
if use_pdb:
from pdb import post_mortem
print_red(__('Exception occurred, starting debugger:'))
post_mortem()
post_mortem(exception.__traceback__)
return
if isinstance(exception, KeyboardInterrupt):
@ -193,7 +213,7 @@ def handle_exception(
__(
'This can happen with very large or deeply nested source '
'files. You can carefully increase the default Python '
'recursion limit of 1000 in conf.py with e.g.:'
'recursion limit of 1,000 in conf.py with e.g.:'
)
)
print_err('\n import sys\n sys.setrecursionlimit(1_500)\n')
@ -205,7 +225,9 @@ def handle_exception(
print_red(__('Exception occurred:'))
print_err(formatted_tb)
traceback_info_path = save_traceback(app, exception)
traceback_info_path = save_traceback(
exception, message_log=message_log, extensions=extensions
)
print_err(__('The full traceback has been saved in:'))
print_err(traceback_info_path)
print_err()

View File

@ -3,36 +3,28 @@
from __future__ import annotations
import argparse
import bdb
import contextlib
import locale
import multiprocessing
import pdb # NoQA: T100
import sys
import traceback
from pathlib import Path
from typing import TYPE_CHECKING, Any, TextIO
from docutils.utils import SystemMessage
import sphinx._cli.util.errors
import sphinx.locale
from sphinx import __display_version__
from sphinx.application import Sphinx
from sphinx.errors import SphinxError, SphinxParallelError
from sphinx.locale import __
from sphinx.util._io import TeeStripANSI
from sphinx.util._pathlib import _StrPath
from sphinx.util.console import color_terminal, nocolor, red, terminal_safe
from sphinx.util.console import color_terminal, nocolor
from sphinx.util.docutils import docutils_namespace, patch_docutils
from sphinx.util.exceptions import format_exception_cut_frames, save_traceback
from sphinx.util.osutil import ensuredir
if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Protocol
from collections.abc import Collection, Sequence
class SupportsWrite(Protocol):
def write(self, text: str, /) -> int | None: ...
from sphinx.extension import Extension
def handle_exception(
@ -41,92 +33,19 @@ def handle_exception(
exception: BaseException,
stderr: TextIO = sys.stderr,
) -> None:
if isinstance(exception, bdb.BdbQuit):
return
if args.pdb:
print(
red(__('Exception occurred while building, starting debugger:')),
file=stderr,
)
traceback.print_exc()
pdb.post_mortem(sys.exc_info()[2])
if app is not None:
message_log: Sequence[str] = app.messagelog
extensions: Collection[Extension] = app.extensions.values()
else:
print(file=stderr)
if args.verbosity or args.traceback:
exc = sys.exc_info()[1]
if isinstance(exc, SphinxParallelError):
exc_format = '(Error in parallel process)\n' + exc.traceback
print(exc_format, file=stderr)
else:
traceback.print_exc(None, stderr)
print(file=stderr)
if isinstance(exception, KeyboardInterrupt):
print(__('Interrupted!'), file=stderr)
elif isinstance(exception, SystemMessage):
print(red(__('reST markup error:')), file=stderr)
print(terminal_safe(exception.args[0]), file=stderr)
elif isinstance(exception, SphinxError):
print(red('%s:' % exception.category), file=stderr)
print(str(exception), file=stderr)
elif isinstance(exception, UnicodeError):
print(red(__('Encoding error:')), file=stderr)
print(terminal_safe(str(exception)), file=stderr)
tbpath = save_traceback(app, exception)
print(
red(
__(
'The full traceback has been saved in %s, if you want '
'to report the issue to the developers.'
)
% tbpath
),
file=stderr,
)
elif (
isinstance(exception, RuntimeError)
and 'recursion depth' in str(exception)
): # fmt: skip
print(red(__('Recursion error:')), file=stderr)
print(terminal_safe(str(exception)), file=stderr)
print(file=stderr)
print(
__(
'This can happen with very large or deeply nested source '
'files. You can carefully increase the default Python '
'recursion limit of 1000 in conf.py with e.g.:'
),
file=stderr,
)
print(' import sys; sys.setrecursionlimit(1500)', file=stderr)
else:
print(red(__('Exception occurred:')), file=stderr)
print(format_exception_cut_frames().rstrip(), file=stderr)
tbpath = save_traceback(app, exception)
print(
red(
__(
'The full traceback has been saved in %s, if you '
'want to report the issue to the developers.'
)
% tbpath
),
file=stderr,
)
print(
__(
'Please also report this if it was a user error, so '
'that a better error message can be provided next time.'
),
file=stderr,
)
print(
__(
'A bug report can be filed in the tracker at '
'<https://github.com/sphinx-doc/sphinx/issues>. Thanks!'
),
file=stderr,
)
message_log = extensions = ()
return sphinx._cli.util.errors.handle_exception(
exception,
stderr=stderr,
use_pdb=args.pdb,
print_traceback=args.verbosity or args.traceback,
message_log=message_log,
extensions=extensions,
)
def jobs_argument(value: str) -> int:
@ -512,7 +431,19 @@ def build_main(argv: Sequence[str]) -> int:
app.build(args.force_all, args.filenames)
return app.statuscode
except (Exception, KeyboardInterrupt) as exc:
handle_exception(app, args, exc, args.error)
if app is not None:
message_log: Sequence[str] = app.messagelog
extensions: Collection[Extension] = app.extensions.values()
else:
message_log = extensions = ()
sphinx._cli.util.errors.handle_exception(
exc,
stderr=args.error,
use_pdb=args.pdb,
print_traceback=args.verbosity or args.traceback,
message_log=message_log,
extensions=extensions,
)
return 2
finally:
if warnfp is not None:

View File

@ -2,8 +2,6 @@
from __future__ import annotations
from typing import Any
class SphinxError(Exception):
"""Base class for Sphinx errors.
@ -109,7 +107,7 @@ class SphinxParallelError(SphinxError):
category = 'Sphinx parallel build error'
def __init__(self, message: str, traceback: Any) -> None:
def __init__(self, message: str, traceback: str) -> None:
self.message = message
self.traceback = traceback

View File

@ -1,19 +0,0 @@
from __future__ import annotations
import sys
import traceback
from sphinx._cli.util.errors import save_traceback
__all__ = 'save_traceback', 'format_exception_cut_frames'
def format_exception_cut_frames(x: int = 1) -> str:
"""Format an exception with traceback, but only the last x frames."""
typ, val, tb = sys.exc_info()
# res = ['Traceback (most recent call last):\n']
res: list[str] = []
tbres = traceback.format_tb(tb)
res += tbres[-x:]
res += traceback.format_exception_only(typ, val)
return ''.join(res)