diff --git a/module/Rest/src/Middleware/CrossDomainMiddleware.php b/module/Rest/src/Middleware/CrossDomainMiddleware.php index 75024b62..cf76fad1 100644 --- a/module/Rest/src/Middleware/CrossDomainMiddleware.php +++ b/module/Rest/src/Middleware/CrossDomainMiddleware.php @@ -4,29 +4,19 @@ declare(strict_types=1); namespace Shlinkio\Shlink\Rest\Middleware; use Fig\Http\Message\RequestMethodInterface; -use Psr\Http\Message\ResponseInterface as Response; -use Psr\Http\Message\ServerRequestInterface as Request; +use Psr\Http\Message\ResponseInterface; +use Psr\Http\Message\ServerRequestInterface; use Psr\Http\Server\MiddlewareInterface; use Psr\Http\Server\RequestHandlerInterface; use Shlinkio\Shlink\Rest\Authentication; +use Zend\Expressive\Router\RouteResult; use function implode; class CrossDomainMiddleware implements MiddlewareInterface, RequestMethodInterface { - /** - * Process an incoming server request and return a response, optionally delegating - * to the next middleware component to create the response. - * - * @param Request $request - * @param RequestHandlerInterface $handler - * - * @return Response - * @throws \InvalidArgumentException - */ - public function process(Request $request, RequestHandlerInterface $handler): Response + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface { - /** @var Response $response */ $response = $handler->handle($request); if (! $request->hasHeader('Origin')) { return $response; @@ -42,13 +32,28 @@ class CrossDomainMiddleware implements MiddlewareInterface, RequestMethodInterfa return $response; } - // Add OPTIONS-specific headers - foreach ([ - 'Access-Control-Allow-Methods' => 'GET,POST,PUT,PATCH,DELETE,OPTIONS', // TODO Should be dynamic -// 'Access-Control-Allow-Methods' => $response->getHeaderLine('Allow'), + return $this->addOptionsHeaders($request, $response); + } + + private function addOptionsHeaders(ServerRequestInterface $request, ResponseInterface $response): ResponseInterface + { + /** @var RouteResult $matchedRoute */ + $matchedRoute = $request->getAttribute(RouteResult::class); + $matchedMethods = $matchedRoute !== null ? $matchedRoute->getAllowedMethods() : [ + self::METHOD_GET, + self::METHOD_POST, + self::METHOD_PUT, + self::METHOD_PATCH, + self::METHOD_DELETE, + self::METHOD_OPTIONS, + ]; + $corsHeaders = [ + 'Access-Control-Allow-Methods' => implode(',', $matchedMethods), 'Access-Control-Max-Age' => '1000', 'Access-Control-Allow-Headers' => $request->getHeaderLine('Access-Control-Request-Headers'), - ] as $key => $value) { + ]; + + foreach ($corsHeaders as $key => $value) { $response = $response->withHeader($key, $value); } diff --git a/module/Rest/test/Middleware/CrossDomainMiddlewareTest.php b/module/Rest/test/Middleware/CrossDomainMiddlewareTest.php index 17539f86..9e93c500 100644 --- a/module/Rest/test/Middleware/CrossDomainMiddlewareTest.php +++ b/module/Rest/test/Middleware/CrossDomainMiddlewareTest.php @@ -10,63 +10,108 @@ use Psr\Http\Server\RequestHandlerInterface; use Shlinkio\Shlink\Rest\Middleware\CrossDomainMiddleware; use Zend\Diactoros\Response; use Zend\Diactoros\ServerRequest; +use Zend\Expressive\Router\Route; +use Zend\Expressive\Router\RouteResult; + +use function Zend\Stratigility\middleware; class CrossDomainMiddlewareTest extends TestCase { /** @var CrossDomainMiddleware */ private $middleware; /** @var ObjectProphecy */ - private $delegate; + private $handler; public function setUp(): void { $this->middleware = new CrossDomainMiddleware(); - $this->delegate = $this->prophesize(RequestHandlerInterface::class); + $this->handler = $this->prophesize(RequestHandlerInterface::class); } /** @test */ - public function nonCrossDomainRequestsAreNotAffected() + public function nonCrossDomainRequestsAreNotAffected(): void { $originalResponse = new Response(); - $this->delegate->handle(Argument::any())->willReturn($originalResponse)->shouldBeCalledOnce(); + $this->handler->handle(Argument::any())->willReturn($originalResponse)->shouldBeCalledOnce(); - $response = $this->middleware->process(new ServerRequest(), $this->delegate->reveal()); + $response = $this->middleware->process(new ServerRequest(), $this->handler->reveal()); $this->assertSame($originalResponse, $response); $headers = $response->getHeaders(); $this->assertArrayNotHasKey('Access-Control-Allow-Origin', $headers); + $this->assertArrayNotHasKey('Access-Control-Expose-Headers', $headers); + $this->assertArrayNotHasKey('Access-Control-Allow-Methods', $headers); + $this->assertArrayNotHasKey('Access-Control-Max-Age', $headers); $this->assertArrayNotHasKey('Access-Control-Allow-Headers', $headers); } /** @test */ - public function anyRequestIncludesTheAllowAccessHeader() + public function anyRequestIncludesTheAllowAccessHeader(): void { $originalResponse = new Response(); - $this->delegate->handle(Argument::any())->willReturn($originalResponse)->shouldBeCalledOnce(); + $this->handler->handle(Argument::any())->willReturn($originalResponse)->shouldBeCalledOnce(); $response = $this->middleware->process( (new ServerRequest())->withHeader('Origin', 'local'), - $this->delegate->reveal() + $this->handler->reveal() ); $this->assertNotSame($originalResponse, $response); $headers = $response->getHeaders(); $this->assertArrayHasKey('Access-Control-Allow-Origin', $headers); + $this->assertArrayHasKey('Access-Control-Expose-Headers', $headers); + $this->assertArrayNotHasKey('Access-Control-Allow-Methods', $headers); + $this->assertArrayNotHasKey('Access-Control-Max-Age', $headers); $this->assertArrayNotHasKey('Access-Control-Allow-Headers', $headers); } /** @test */ - public function optionsRequestIncludesMoreHeaders() + public function optionsRequestIncludesMoreHeaders(): void { $originalResponse = new Response(); $request = (new ServerRequest())->withMethod('OPTIONS')->withHeader('Origin', 'local'); - $this->delegate->handle(Argument::any())->willReturn($originalResponse)->shouldBeCalledOnce(); + $this->handler->handle(Argument::any())->willReturn($originalResponse)->shouldBeCalledOnce(); - $response = $this->middleware->process($request, $this->delegate->reveal()); + $response = $this->middleware->process($request, $this->handler->reveal()); $this->assertNotSame($originalResponse, $response); $headers = $response->getHeaders(); $this->assertArrayHasKey('Access-Control-Allow-Origin', $headers); + $this->assertArrayHasKey('Access-Control-Expose-Headers', $headers); + $this->assertArrayHasKey('Access-Control-Allow-Methods', $headers); + $this->assertArrayHasKey('Access-Control-Max-Age', $headers); $this->assertArrayHasKey('Access-Control-Allow-Headers', $headers); } + + /** + * @test + * @dataProvider provideRouteResults + */ + public function optionsRequestParsesRouteMatchToDetermineAllowedMethods( + ?RouteResult $result, + string $expectedAllowedMethods + ): void { + $originalResponse = new Response(); + $request = (new ServerRequest())->withAttribute(RouteResult::class, $result) + ->withMethod('OPTIONS') + ->withHeader('Origin', 'local'); + $this->handler->handle(Argument::any())->willReturn($originalResponse)->shouldBeCalledOnce(); + + $response = $this->middleware->process($request, $this->handler->reveal()); + + $this->assertEquals($response->getHeaderLine('Access-Control-Allow-Methods'), $expectedAllowedMethods); + } + + public function provideRouteResults(): iterable + { + yield 'with no route result' => [null, 'GET,POST,PUT,PATCH,DELETE,OPTIONS']; + yield 'with failed route result' => [RouteResult::fromRouteFailure(['POST', 'GET']), 'POST,GET']; + yield 'with success route result' => [ + RouteResult::fromRoute( + new Route('/', middleware(function () { + }), ['DELETE', 'PATCH', 'PUT']) + ), + 'DELETE,PATCH,PUT', + ]; + } }