mirror of
https://github.com/sphinx-doc/sphinx.git
synced 2025-02-25 18:55:22 -06:00
[config] protect `is_serializable` against circular references (#12196)
This commit is contained in:
parent
885818bb7f
commit
f26d492d6d
@ -51,17 +51,30 @@ class ConfigValue(NamedTuple):
|
||||
rebuild: _ConfigRebuild
|
||||
|
||||
|
||||
def is_serializable(obj: object, *, _recursive_guard: frozenset[int] = frozenset()) -> bool:
    """Check if object is serializable or not.

    ``dict``, ``list``, ``tuple``, ``set`` and ``frozenset`` objects are
    inspected recursively (for a ``dict``, both keys and values are checked).
    A container that is already being visited further up the call chain is
    reported as serializable instead of being entered again, so that
    self-referencing structures do not recurse without bound.

    :param obj: The object to check.
    :param _recursive_guard: Private. The ``id()`` of every container that is
        currently being visited by an enclosing call.
    :returns: ``False`` if *obj*, or any object reachable through the
        containers listed above, is an instance of ``UNSERIALIZABLE_TYPES``;
        ``True`` otherwise.
    """
    if isinstance(obj, UNSERIALIZABLE_TYPES):
        return False

    # use id() to handle un-hashable objects
    if id(obj) in _recursive_guard:
        # 'obj' is already being checked by an enclosing call: the verdict
        # for it is decided there, so do not descend into it a second time
        return True

    if isinstance(obj, dict):
        # extend the guard with '|' so that the enclosing call's guard is
        # left untouched and sibling branches do not see each other's ids
        guard = _recursive_guard | {id(obj)}
        for key, value in obj.items():
            if (
                not is_serializable(key, _recursive_guard=guard)
                or not is_serializable(value, _recursive_guard=guard)
            ):
                return False
    elif isinstance(obj, (list, tuple, set, frozenset)):
        guard = _recursive_guard | {id(obj)}
        return all(is_serializable(item, _recursive_guard=guard) for item in obj)

    # if an issue occurs for a non-serializable type, pickle will complain
    # since the object is likely coming from a third-party extension (we
    # natively expect 'simple' types and not weird ones)
    return True
|
||||
|
||||
|
||||
|
@ -1,7 +1,11 @@
|
||||
"""Test the sphinx.config.Config class."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import time
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@ -14,10 +18,51 @@ from sphinx.config import (
|
||||
_Opt,
|
||||
check_confval_types,
|
||||
correct_copyright_year,
|
||||
is_serializable,
|
||||
)
|
||||
from sphinx.deprecation import RemovedInSphinx90Warning
|
||||
from sphinx.errors import ConfigError, ExtensionError, VersionRequirementError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
|
||||
CircularList = list[Union[int, 'CircularList']]
|
||||
CircularDict = dict[str, Union[int, 'CircularDict']]
|
||||
|
||||
|
||||
def check_is_serializable(subject: object, *, circular: bool) -> None:
    """Assert that ``is_serializable`` accepts *subject*.

    When *circular* is true, also assert that ``is_serializable`` raises
    :exc:`RecursionError` for *subject* once it is handed a recursion guard
    whose union operations do nothing.
    """
    assert is_serializable(subject)

    if not circular:
        return

    # a frozenset whose '|' and 'union()' never record anything: both simply
    # hand back the (empty) guard itself
    class UselessGuard(frozenset[int]):
        def __or__(self, other: object, /) -> UselessGuard:
            return self

        def union(self, *args: Iterable[object]) -> UselessGuard:
            return self

    # check that without recursive guards, a recursion error occurs
    with pytest.raises(RecursionError):
        assert is_serializable(subject, _recursive_guard=UselessGuard())
|
||||
|
||||
|
||||
def test_is_serializable() -> None:
    """Run ``check_is_serializable`` on acyclic and self-referencing containers."""
    # acyclic nesting of list / set / dict / frozenset
    nested = [1, [2, {3, 'a'}], {'x': {'y': frozenset((4, 5))}}]
    check_is_serializable(nested, circular=False)

    # two lists, each holding a reference to the other
    first: CircularList = [1]
    second: CircularList = [2]
    first.append(second)
    second.append(first)
    for cyclic_list in (first, second):
        check_is_serializable(cyclic_list, circular=True)

    # a dict holding a reference to itself
    cyclic_dict: CircularDict = {'a': 1, 'b': {'c': 1}}
    cyclic_dict['b'] = cyclic_dict
    check_is_serializable(cyclic_dict, circular=True)
|
||||
|
||||
|
||||
def test_config_opt_deprecated(recwarn):
|
||||
opt = _Opt('default', '', ())
|
||||
@ -102,6 +147,151 @@ def test_config_pickle_protocol(tmp_path, protocol: int):
|
||||
assert repr(config) == repr(pickled_config)
|
||||
|
||||
|
||||
def test_config_pickle_circular_reference_in_list():
    """Pickle round trip of a config holding two mutually-referencing lists."""
    a: CircularList = [1]
    b: CircularList = [2]
    a.append(b)
    b.append(a)

    check_is_serializable(a, circular=True)
    check_is_serializable(b, circular=True)

    config = Config()
    config.add('a', [], '', types=list)
    config.add('b', [], '', types=list)
    config.a, config.b = a, b

    actual = pickle.loads(pickle.dumps(config))
    assert isinstance(actual.a, list)
    check_is_serializable(actual.a, circular=True)

    assert isinstance(actual.b, list)
    check_is_serializable(actual.b, circular=True)

    # walk the cycle a -> b -> a -> b by hand
    assert actual.a[0] == 1
    assert actual.a[1][0] == 2
    assert actual.a[1][1][0] == 1
    assert actual.a[1][1][1][0] == 2

    assert actual.b[0] == 2
    assert actual.b[1][0] == 1
    assert actual.b[1][1][0] == 2
    assert actual.b[1][1][1][0] == 1

    assert len(actual.a) == 2
    assert len(actual.a[1]) == 2
    assert len(actual.a[1][1]) == 2
    assert len(actual.a[1][1][1]) == 2
    assert len(actual.a[1][1][1][1]) == 2

    assert len(actual.b) == 2
    assert len(actual.b[1]) == 2
    assert len(actual.b[1][1]) == 2
    assert len(actual.b[1][1][1]) == 2
    assert len(actual.b[1][1][1][1]) == 2

    def check(
        u: list[object] | int,
        v: list[object] | int,
        *,
        counter: Counter[type] | None = None,
        guard: frozenset[int] = frozenset(),
    ) -> Counter[type]:
        """Compare *u* and *v* element-wise and count what was visited.

        ``counter[int]`` is incremented for every integer leaf compared and
        ``counter[list]`` once per element of every list that is entered.
        A pair whose two members are both already in *guard* is not entered
        again, which is what stops the walk on a cycle.
        """
        counter = Counter() if counter is None else counter

        if id(u) in guard and id(v) in guard:
            return counter

        if isinstance(u, int):
            assert v.__class__ is u.__class__
            assert u == v
            counter[type(u)] += 1
            return counter

        assert isinstance(u, list)
        assert v.__class__ is u.__class__
        assert len(u) == len(v)

        for u_i, v_i in zip(u, v):
            counter[type(u)] += 1
            check(u_i, v_i, counter=counter, guard=guard | {id(u), id(v)})

        return counter

    counter = check(actual.a, a)
    # check(actual.a, a)
    #   check(actual.a[0], a[0])                -> ++counter[list]
    #       ++counter[int]                      (a[0] is an int)
    #   check(actual.a[1], a[1])                -> ++counter[list]
    #       check(actual.a[1][0], a[1][0])      -> ++counter[list]
    #           ++counter[int]                  (a[1][0] is an int)
    #       check(actual.a[1][1], a[1][1])      -> ++counter[list]
    #           recursive guard since a[1][1] is a
    assert counter[type(a[0])] == 2
    assert counter[type(a[1])] == 4

    # same logic as above
    counter = check(actual.b, b)
    assert counter[type(b[0])] == 2
    assert counter[type(b[1])] == 4
|
||||
|
||||
|
||||
def test_config_pickle_circular_reference_in_dict():
    """Pickle round trip of a config holding a dict that contains itself."""
    x: CircularDict = {'a': 1, 'b': {'c': 1}}
    x['b'] = x
    check_is_serializable(x, circular=True)

    config = Config()
    # NOTE(review): the default passed here is a list although types=dict;
    # it is overwritten on the next line -- confirm this is intentional
    config.add('x', [], '', types=dict)
    config.x = x

    actual = pickle.loads(pickle.dumps(config))
    check_is_serializable(actual.x, circular=True)
    assert isinstance(actual.x, dict)

    assert actual.x['a'] == 1
    assert actual.x['b']['a'] == 1

    assert len(actual.x) == 2
    assert len(actual.x['b']) == 2
    assert len(actual.x['b']['b']) == 2

    def check(
        u: dict[str, object] | int,
        v: dict[str, object] | int,
        *,
        counter: Counter[type] | None = None,
        guard: frozenset[int] = frozenset(),
    ) -> Counter[type]:
        """Compare *u* and *v* key by key and count what was visited.

        ``counter[int]`` is incremented for every integer leaf compared and
        ``counter[dict]`` once per key of every dict that is entered.
        A pair whose two members are both already in *guard* is not entered
        again, which is what stops the walk on a cycle.
        """
        counter = Counter() if counter is None else counter

        if id(u) in guard and id(v) in guard:
            return counter

        if isinstance(u, int):
            assert v.__class__ is u.__class__
            assert u == v
            counter[type(u)] += 1
            return counter

        assert isinstance(u, dict)
        assert v.__class__ is u.__class__
        assert len(u) == len(v)
        # pair the values by key rather than by iteration position
        assert u.keys() == v.keys()

        for key in u:
            counter[type(u)] += 1
            check(u[key], v[key], counter=counter, guard=guard | {id(u), id(v)})
        return counter

    counters = check(actual.x, x, counter=Counter())
    # check(actual.x, x)
    #   check(actual.x['a'], x['a'])            -> ++counter[dict]
    #       ++counter[int]                      (x['a'] is an int)
    #   check(actual.x['b'], x['b'])            -> ++counter[dict]
    #       recursive guard since x['b'] is x
    assert counters[type(x['a'])] == 1
    assert counters[type(x['b'])] == 2
|
||||
|
||||
|
||||
def test_extension_values():
|
||||
config = Config()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user