mirror of
				https://github.com/sphinx-doc/sphinx.git
				synced 2025-02-25 18:55:22 -06:00 
			
		
		
		
	
		
			
				
	
	
		
			125 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			125 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import annotations
 | |
| 
 | |
| __all__ = ('http_server',)
 | |
| 
 | |
| import socket
 | |
| from contextlib import contextmanager
 | |
| from http.server import ThreadingHTTPServer
 | |
| from pathlib import Path
 | |
| from ssl import PROTOCOL_TLS_SERVER, SSLContext
 | |
| from threading import Thread
 | |
| from typing import TYPE_CHECKING
 | |
| from urllib.parse import urlparse
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from collections.abc import Iterator
 | |
|     from http.server import HTTPServer
 | |
|     from socketserver import BaseRequestHandler
 | |
|     from typing import Final
 | |
| 
 | |
|     from sphinx.application import Sphinx
 | |
| 
 | |
# Directory that contains this module.
TESTS_ROOT: Final[Path] = Path(__file__).resolve().parent

# Path of the PEM file loaded as the certificate chain by the HTTPS test server.
# The command below writes both the certificate and its key to ``cert.pem``.
# Generated with:
# $ openssl req -new -x509 -days 3650 -nodes -out cert.pem \
#     -keyout cert.pem -addext "subjectAltName = DNS:localhost"
CERT_FILE: Final[str] = str(TESTS_ROOT.joinpath('certs', 'cert.pem'))
 | |
| 
 | |
| 
 | |
class HttpServerThread(Thread):
    """Daemon thread that owns and runs a ``ThreadingHTTPServer`` on localhost.

    Ref: https://docs.python.org/3.11/library/socketserver.html#asynchronous-mixins
    """

    def __init__(self, handler: type[BaseRequestHandler], *, port: int = 0) -> None:
        """Create the server bound to ``('localhost', port)``.

        :param handler: Request handler class passed to the server.
        :param port: Port to bind to; the default of ``0`` delegates the
            selection of a port number to Python.
        """
        super().__init__(daemon=True)
        bind_address = ('localhost', port)
        self.server = ThreadingHTTPServer(bind_address, handler)

    def run(self) -> None:
        """Serve requests until the server is shut down."""
        self.server.serve_forever(poll_interval=0.001)

    def terminate(self) -> None:
        """Stop the serving loop, close the server and wait for the thread."""
        httpd = self.server
        httpd.shutdown()
        httpd.server_close()
        self.join()
 | |
| 
 | |
| 
 | |
class HttpsServerThread(HttpServerThread):
    """Variant of ``HttpServerThread`` whose listening socket is TLS-wrapped."""

    def __init__(self, handler: type[BaseRequestHandler], *, port: int = 0) -> None:
        """Create the server, then wrap its socket using ``CERT_FILE``.

        :param handler: Request handler class passed to the server.
        :param port: Port to bind to (``0`` lets Python choose one).
        """
        super().__init__(handler, port=port)
        tls_context = SSLContext(PROTOCOL_TLS_SERVER)
        tls_context.load_cert_chain(CERT_FILE)
        # Swap the plain listening socket for its server-side TLS wrapper.
        plain_socket = self.server.socket
        self.server.socket = tls_context.wrap_socket(plain_socket, server_side=True)
 | |
| 
 | |
| 
 | |
@contextmanager
def http_server(
    handler: type[BaseRequestHandler],
    *,
    tls_enabled: bool = False,
    port: int = 0,
) -> Iterator[HTTPServer]:
    """Run a temporary HTTP(S) server in a background thread.

    The server is yielded once a TCP connection to it has been confirmed
    possible, and is terminated when the ``with`` block exits.

    :param handler: Request handler class used by the server.
    :param tls_enabled: Whether to serve over TLS instead of plain HTTP.
    :param port: Port to bind to (``0`` lets Python choose one).
    """
    if tls_enabled:
        worker = HttpsServerThread(handler, port=port)
    else:
        worker = HttpServerThread(handler, port=port)
    worker.start()

    bound_port = worker.server.server_port
    assert port in {0, bound_port}
    try:
        probe = socket.create_connection(('localhost', bound_port), timeout=0.5)
        probe.close()
        # Connection has been confirmed possible; proceed.
        yield worker.server
    finally:
        worker.terminate()
 | |
| 
 | |
| 
 | |
@contextmanager
def rewrite_hyperlinks(app: Sphinx, server: HTTPServer) -> Iterator[None]:
    """Rewrite hyperlinks that refer to network location 'localhost:7777',
    allowing that location to vary dynamically with the arbitrary test HTTP
    server port assigned during unit testing.

    A listener is connected to the ``linkcheck-process-uri`` event for the
    duration of the ``with`` block.  It is disconnected again when the block
    exits, including when the block raises an exception.

    :param app: The Sphinx application where link replacement is to occur.
    :param server: Destination server to redirect the hyperlinks to.
    """
    match_netloc, replacement_netloc = (
        'localhost:7777',
        f'localhost:{server.server_port}',
    )

    def rewrite_hyperlink(_app: Sphinx, uri: str) -> str:
        # Only URIs whose network location is exactly the placeholder are
        # rewritten; everything else is passed through unchanged.
        parsed_uri = urlparse(uri)
        if parsed_uri.netloc != match_netloc:
            return uri
        return parsed_uri._replace(netloc=replacement_netloc).geturl()

    listener_id = app.connect('linkcheck-process-uri', rewrite_hyperlink)
    try:
        yield
    finally:
        # Always unregister, so a failing block does not leave the listener
        # attached to the application.
        app.disconnect(listener_id)
 | |
| 
 | |
| 
 | |
@contextmanager
def serve_application(
    app: Sphinx,
    handler: type[BaseRequestHandler],
    *,
    tls_enabled: bool = False,
    port: int = 0,
) -> Iterator[str]:
    """Start a temporary server for the links of a Sphinx application project
    and rewrite the project's placeholder hyperlinks to point at it.

    :param app: The Sphinx application.
    :param handler: Determines how each request will be handled.
    :param tls_enabled: Whether TLS (SSL) should be enabled for the server.
    :param port: Optional server port (default: auto).

    :return: The address of the temporary HTTP server.
    """
    with http_server(handler, tls_enabled=tls_enabled, port=port) as running_server:
        with rewrite_hyperlinks(app, running_server):
            address = f'localhost:{running_server.server_port}'
            yield address
 |