[config] protect `is_serializable` against circular references (#12196)

This commit is contained in:
Bénédikt Tran 2024-03-25 11:19:02 +01:00 committed by GitHub
parent 885818bb7f
commit f26d492d6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 210 additions and 7 deletions

View File

@ -51,17 +51,30 @@ class ConfigValue(NamedTuple):
rebuild: _ConfigRebuild
def is_serializable(obj: Any) -> bool:
def is_serializable(obj: object, *, _recursive_guard: frozenset[int] = frozenset()) -> bool:
"""Check if object is serializable or not."""
if isinstance(obj, UNSERIALIZABLE_TYPES):
return False
elif isinstance(obj, dict):
for key, value in obj.items():
if not is_serializable(key) or not is_serializable(value):
return False
elif isinstance(obj, (list, tuple, set)):
return all(map(is_serializable, obj))
# use id() to handle un-hashable objects
if id(obj) in _recursive_guard:
return True
if isinstance(obj, dict):
guard = _recursive_guard | {id(obj)}
for key, value in obj.items():
if (
not is_serializable(key, _recursive_guard=guard)
or not is_serializable(value, _recursive_guard=guard)
):
return False
elif isinstance(obj, (list, tuple, set, frozenset)):
guard = _recursive_guard | {id(obj)}
return all(is_serializable(item, _recursive_guard=guard) for item in obj)
# if an issue occurs for a non-serializable type, pickle will complain
# since the object is likely coming from a third-party extension (we
# natively expect 'simple' types and not weird ones)
return True

View File

@ -1,7 +1,11 @@
"""Test the sphinx.config.Config class."""
from __future__ import annotations
import pickle
import time
from collections import Counter
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import mock
import pytest
@ -14,10 +18,51 @@ from sphinx.config import (
_Opt,
check_confval_types,
correct_copyright_year,
is_serializable,
)
from sphinx.deprecation import RemovedInSphinx90Warning
from sphinx.errors import ConfigError, ExtensionError, VersionRequirementError
if TYPE_CHECKING:
from collections.abc import Iterable
from typing import Union
CircularList = list[Union[int, 'CircularList']]
CircularDict = dict[str, Union[int, 'CircularDict']]
def check_is_serializable(subject: object, *, circular: bool) -> None:
assert is_serializable(subject)
if circular:
class UselessGuard(frozenset[int]):
def __or__(self, other: object, /) -> UselessGuard:
# do nothing
return self
def union(self, *args: Iterable[object]) -> UselessGuard:
# do nothing
return self
# check that without recursive guards, a recursion error occurs
with pytest.raises(RecursionError):
assert is_serializable(subject, _recursive_guard=UselessGuard())
def test_is_serializable() -> None:
subject = [1, [2, {3, 'a'}], {'x': {'y': frozenset((4, 5))}}]
check_is_serializable(subject, circular=False)
a, b = [1], [2] # type: (CircularList, CircularList)
a.append(b)
b.append(a)
check_is_serializable(a, circular=True)
check_is_serializable(b, circular=True)
x: CircularDict = {'a': 1, 'b': {'c': 1}}
x['b'] = x
check_is_serializable(x, circular=True)
def test_config_opt_deprecated(recwarn):
opt = _Opt('default', '', ())
@ -102,6 +147,151 @@ def test_config_pickle_protocol(tmp_path, protocol: int):
assert repr(config) == repr(pickled_config)
def test_config_pickle_circular_reference_in_list():
a, b = [1], [2] # type: (CircularList, CircularList)
a.append(b)
b.append(a)
check_is_serializable(a, circular=True)
check_is_serializable(b, circular=True)
config = Config()
config.add('a', [], '', types=list)
config.add('b', [], '', types=list)
config.a, config.b = a, b
actual = pickle.loads(pickle.dumps(config))
assert isinstance(actual.a, list)
check_is_serializable(actual.a, circular=True)
assert isinstance(actual.b, list)
check_is_serializable(actual.b, circular=True)
assert actual.a[0] == 1
assert actual.a[1][0] == 2
assert actual.a[1][1][0] == 1
assert actual.a[1][1][1][0] == 2
assert actual.b[0] == 2
assert actual.b[1][0] == 1
assert actual.b[1][1][0] == 2
assert actual.b[1][1][1][0] == 1
assert len(actual.a) == 2
assert len(actual.a[1]) == 2
assert len(actual.a[1][1]) == 2
assert len(actual.a[1][1][1]) == 2
assert len(actual.a[1][1][1][1]) == 2
assert len(actual.b) == 2
assert len(actual.b[1]) == 2
assert len(actual.b[1][1]) == 2
assert len(actual.b[1][1][1]) == 2
assert len(actual.b[1][1][1][1]) == 2
def check(
u: list[list[object] | int],
v: list[list[object] | int],
*,
counter: Counter[type, int] | None = None,
guard: frozenset[int] = frozenset(),
) -> Counter[type, int]:
counter = Counter() if counter is None else counter
if id(u) in guard and id(v) in guard:
return counter
if isinstance(u, int):
assert v.__class__ is u.__class__
assert u == v
counter[type(u)] += 1
return counter
assert isinstance(u, list)
assert v.__class__ is u.__class__
assert len(u) == len(v)
for u_i, v_i in zip(u, v):
counter[type(u)] += 1
check(u_i, v_i, counter=counter, guard=guard | {id(u), id(v)})
return counter
counter = check(actual.a, a)
# check(actual.a, a)
# check(actual.a[0], a[0]) -> ++counter[dict]
# ++counter[int] (a[0] is an int)
# check(actual.a[1], a[1]) -> ++counter[dict]
# check(actual.a[1][0], a[1][0]) -> ++counter[dict]
# ++counter[int] (a[1][0] is an int)
# check(actual.a[1][1], a[1][1]) -> ++counter[dict]
# recursive guard since a[1][1] == a
assert counter[type(a[0])] == 2
assert counter[type(a[1])] == 4
# same logic as above
counter = check(actual.b, b)
assert counter[type(b[0])] == 2
assert counter[type(b[1])] == 4
def test_config_pickle_circular_reference_in_dict():
x: CircularDict = {'a': 1, 'b': {'c': 1}}
x['b'] = x
check_is_serializable(x, circular=True)
config = Config()
config.add('x', [], '', types=dict)
config.x = x
actual = pickle.loads(pickle.dumps(config))
check_is_serializable(actual.x, circular=True)
assert isinstance(actual.x, dict)
assert actual.x['a'] == 1
assert actual.x['b']['a'] == 1
assert len(actual.x) == 2
assert len(actual.x['b']) == 2
assert len(actual.x['b']['b']) == 2
def check(
u: dict[str, dict[str, object] | int],
v: dict[str, dict[str, object] | int],
*,
counter: Counter[type, int] | None = None,
guard: frozenset[int] = frozenset(),
) -> Counter:
counter = Counter() if counter is None else counter
if id(u) in guard and id(v) in guard:
return counter
if isinstance(u, int):
assert v.__class__ is u.__class__
assert u == v
counter[type(u)] += 1
return counter
assert isinstance(u, dict)
assert v.__class__ is u.__class__
assert len(u) == len(v)
for u_i, v_i in zip(u, v):
counter[type(u)] += 1
check(u[u_i], v[v_i], counter=counter, guard=guard | {id(u), id(v)})
return counter
counters = check(actual.x, x, counter=Counter())
# check(actual.x, x)
# check(actual.x['a'], x['a']) -> ++counter[dict]
# ++counter[int] (x['a'] is an int)
# check(actual.x['b'], x['b']) -> ++counter[dict]
# recursive guard since x['b'] == x
assert counters[type(x['a'])] == 1
assert counters[type(x['b'])] == 2
def test_extension_values():
config = Config()