mirror of
https://github.com/sphinx-doc/sphinx.git
synced 2025-02-25 18:55:22 -06:00
sphinx.util.parallel supports logging in child workers
This commit is contained in:
parent
b43523fcbe
commit
d8ad3d063c
@ -53,6 +53,14 @@ def getLogger(name):
|
||||
return SphinxLoggerAdapter(logging.getLogger(name), {})
|
||||
|
||||
|
||||
def convert_serializable(records):
|
||||
"""Convert LogRecord serializable."""
|
||||
for r in records:
|
||||
# extract arguments to a message and clear them
|
||||
r.msg = r.getMessage()
|
||||
r.args = ()
|
||||
|
||||
|
||||
class SphinxWarningLogRecord(logging.LogRecord):
|
||||
"""Log record class supporting location"""
|
||||
location = None # type: Any
|
||||
@ -113,6 +121,10 @@ class SphinxLoggerAdapter(logging.LoggerAdapter):
|
||||
|
||||
return msg, kwargs
|
||||
|
||||
def handle(self, record):
|
||||
# type: (logging.LogRecord) -> None
|
||||
self.logger.handle(record) # type: ignore
|
||||
|
||||
|
||||
class NewLineStreamHandlerPY2(logging.StreamHandler):
|
||||
"""StreamHandler which switches line terminator by record.nonl flag."""
|
||||
@ -177,6 +189,11 @@ class MemoryHandler(logging.handlers.BufferingHandler):
|
||||
finally:
|
||||
self.release()
|
||||
|
||||
def clear(self):
|
||||
# type: () -> List[logging.LogRecord]
|
||||
buffer, self.buffer = self.buffer, []
|
||||
return buffer
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pending_logging():
|
||||
@ -192,7 +209,7 @@ def pending_logging():
|
||||
handlers.append(handler)
|
||||
|
||||
logger.addHandler(memhandler)
|
||||
yield
|
||||
yield memhandler
|
||||
finally:
|
||||
logger.removeHandler(memhandler)
|
||||
|
||||
@ -202,6 +219,18 @@ def pending_logging():
|
||||
memhandler.flushTo(logger)
|
||||
|
||||
|
||||
class LogCollector(object):
|
||||
def __init__(self):
|
||||
self.logs = [] # type: logging.LogRecord
|
||||
|
||||
@contextmanager
|
||||
def collect(self):
|
||||
with pending_logging() as memhandler:
|
||||
yield
|
||||
|
||||
self.logs = memhandler.clear()
|
||||
|
||||
|
||||
class InfoFilter(logging.Filter):
|
||||
"""Filter error and warning messages."""
|
||||
|
||||
|
@ -21,11 +21,15 @@ except ImportError:
|
||||
multiprocessing = None
|
||||
|
||||
from sphinx.errors import SphinxParallelError
|
||||
from sphinx.util import logging
|
||||
|
||||
if False:
|
||||
# For type annotation
|
||||
from typing import Any, Callable, Sequence # NOQA
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# our parallel functionality only works for the forking Process
|
||||
parallel_available = multiprocessing and (os.name == 'posix')
|
||||
|
||||
@ -75,19 +79,24 @@ class ParallelTasks(object):
|
||||
def _process(self, pipe, func, arg):
|
||||
# type: (Any, Callable, Any) -> None
|
||||
try:
|
||||
if arg is None:
|
||||
ret = func()
|
||||
else:
|
||||
ret = func(arg)
|
||||
pipe.send((False, ret))
|
||||
collector = logging.LogCollector()
|
||||
with collector.collect():
|
||||
if arg is None:
|
||||
ret = func()
|
||||
else:
|
||||
ret = func(arg)
|
||||
failed = False
|
||||
except BaseException as err:
|
||||
pipe.send((True, (err, traceback.format_exc())))
|
||||
failed = True
|
||||
ret = (err, traceback.format_exc())
|
||||
logging.convert_serializable(collector.logs)
|
||||
pipe.send((failed, collector.logs, ret))
|
||||
|
||||
def add_task(self, task_func, arg=None, result_func=None):
|
||||
# type: (Callable, Any, Callable) -> None
|
||||
tid = self._taskid
|
||||
self._taskid += 1
|
||||
self._result_funcs[tid] = result_func or (lambda arg: None)
|
||||
self._result_funcs[tid] = result_func or (lambda arg, result: None)
|
||||
self._args[tid] = arg
|
||||
precv, psend = multiprocessing.Pipe(False)
|
||||
proc = multiprocessing.Process(target=self._process,
|
||||
@ -105,9 +114,11 @@ class ParallelTasks(object):
|
||||
# type: () -> None
|
||||
for tid, pipe in iteritems(self._precvs):
|
||||
if pipe.poll():
|
||||
exc, result = pipe.recv()
|
||||
exc, logs, result = pipe.recv()
|
||||
if exc:
|
||||
raise SphinxParallelError(*result)
|
||||
for log in logs:
|
||||
logger.handle(log)
|
||||
self._result_funcs.pop(tid)(self._args.pop(tid), result)
|
||||
self._procs[tid].join()
|
||||
self._pworking -= 1
|
||||
|
@ -16,6 +16,7 @@ from sphinx.errors import SphinxWarning
|
||||
from sphinx.util import logging
|
||||
from sphinx.util.console import colorize
|
||||
from sphinx.util.logging import is_suppressed_warning
|
||||
from sphinx.util.parallel import ParallelTasks
|
||||
|
||||
from util import with_app, raises, strip_escseq
|
||||
|
||||
@ -269,3 +270,19 @@ def test_colored_logs(app, status, warning):
|
||||
logger.info('message9', color='red')
|
||||
assert colorize('white', 'message8') in status.getvalue()
|
||||
assert colorize('red', 'message9') in status.getvalue()
|
||||
|
||||
|
||||
@with_app()
|
||||
def test_logging_in_ParallelTasks(app, status, warning):
|
||||
logging.setup(app, status, warning)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def child_process():
|
||||
logger.info('message1')
|
||||
logger.warning('message2', location='index')
|
||||
|
||||
tasks = ParallelTasks(1)
|
||||
tasks.add_task(child_process)
|
||||
tasks.join()
|
||||
assert 'message1' in status.getvalue()
|
||||
assert 'index.txt: WARNING: message2' in warning.getvalue()
|
||||
|
Loading…
Reference in New Issue
Block a user