From b7f708dc634732aa90978e1c141ba6bd2af7ee84 Mon Sep 17 00:00:00 2001 From: James Addison <55152140+jayaddison@users.noreply.github.com> Date: Wed, 20 Mar 2024 22:13:41 +0000 Subject: [PATCH] [tests] utils: refactor type-hint signatures. (#12144) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consolidate the provision of a single `http_server` utility method, with `tls_enabled` as a boolean flag, and rework type annotations to make them more understandable. Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com> Co-authored-by: Chris Sewell --- tests/test_builders/test_build_linkcheck.py | 12 +++---- tests/utils.py | 40 ++++++++------------- 2 files changed, 21 insertions(+), 31 deletions(-) diff --git a/tests/test_builders/test_build_linkcheck.py b/tests/test_builders/test_build_linkcheck.py index c630b700e..f19838d05 100644 --- a/tests/test_builders/test_build_linkcheck.py +++ b/tests/test_builders/test_build_linkcheck.py @@ -28,7 +28,7 @@ from sphinx.deprecation import RemovedInSphinx80Warning from sphinx.testing.util import strip_escseq from sphinx.util import requests -from tests.utils import CERT_FILE, http_server, https_server +from tests.utils import CERT_FILE, http_server ts_re = re.compile(r".*\[(?P.*)\].*") @@ -633,7 +633,7 @@ def test_invalid_ssl(get_request, app): @pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True) def test_connect_to_selfsigned_fails(app): - with https_server(OKHandler): + with http_server(OKHandler, tls_enabled=True): app.build() with open(app.outdir / 'output.json', encoding='utf-8') as fp: @@ -648,7 +648,7 @@ def test_connect_to_selfsigned_fails(app): @pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True) def test_connect_to_selfsigned_with_tls_verify_false(app): app.config.tls_verify = False - with https_server(OKHandler): + with http_server(OKHandler, tls_enabled=True): app.build() with open(app.outdir / 'output.json', encoding='utf-8') as fp: @@ -666,7 +666,7 @@ def test_connect_to_selfsigned_with_tls_verify_false(app): @pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True) def test_connect_to_selfsigned_with_tls_cacerts(app): app.config.tls_cacerts = CERT_FILE - with https_server(OKHandler): + with http_server(OKHandler, tls_enabled=True): app.build() with open(app.outdir / 'output.json', encoding='utf-8') as fp: @@ -684,7 +684,7 @@ def test_connect_to_selfsigned_with_tls_cacerts(app): @pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True) def test_connect_to_selfsigned_with_requests_env_var(monkeypatch, app): monkeypatch.setenv("REQUESTS_CA_BUNDLE", CERT_FILE) - with https_server(OKHandler): + with http_server(OKHandler, tls_enabled=True): app.build() with open(app.outdir / 'output.json', encoding='utf-8') as fp: @@ -702,7 +702,7 @@ def test_connect_to_selfsigned_with_requests_env_var(monkeypatch, app): @pytest.mark.sphinx('linkcheck', testroot='linkcheck-localserver-https', freshenv=True) def test_connect_to_selfsigned_nonexistent_cert_file(app): app.config.tls_cacerts = "does/not/exist" - with https_server(OKHandler): + with http_server(OKHandler, tls_enabled=True): app.build() with open(app.outdir / 'output.json', encoding='utf-8') as fp: diff --git a/tests/utils.py b/tests/utils.py index 1fbd431eb..6e3aed2c9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,17 +1,18 @@ from __future__ import annotations -import contextlib +__all__ = ("http_server",) + +from contextlib import contextmanager from http.server import ThreadingHTTPServer from pathlib import Path from ssl import PROTOCOL_TLS_SERVER, SSLContext from threading import Thread -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING import filelock if TYPE_CHECKING: - from collections.abc import Callable, Generator - from contextlib import AbstractContextManager + from collections.abc import Iterator from socketserver import BaseRequestHandler from typing import Any, Final @@ -49,24 +50,13 @@ class HttpsServerThread(HttpServerThread): self.server.socket = sslcontext.wrap_socket(self.server.socket, server_side=True) -_T_co = TypeVar('_T_co', bound=HttpServerThread, covariant=True) - - -def create_server( - server_thread_class: type[_T_co], -) -> Callable[[type[BaseRequestHandler]], AbstractContextManager[_T_co]]: - @contextlib.contextmanager - def server(handler_class: type[BaseRequestHandler]) -> Generator[_T_co, None, None]: - lock = filelock.FileLock(LOCK_PATH) - with lock: - server_thread = server_thread_class(handler_class, daemon=True) - server_thread.start() - try: - yield server_thread - finally: - server_thread.terminate() - return server - - -http_server = create_server(HttpServerThread) -https_server = create_server(HttpsServerThread) +@contextmanager +def http_server(handler: type[BaseRequestHandler], *, tls_enabled: bool = False) -> Iterator[HttpServerThread]: + server_cls = HttpsServerThread if tls_enabled else HttpServerThread + with filelock.FileLock(LOCK_PATH): + server = server_cls(handler, daemon=True) + server.start() + try: + yield server + finally: + server.terminate()