From 8d90661d0a7e0b5d7bb91e1b48a477500151ca53 Mon Sep 17 00:00:00 2001 From: Alejandro Celaya Date: Sat, 6 Jul 2024 10:12:05 +0200 Subject: [PATCH] Extract logic to match IP address against list of groups --- .../Exception/InvalidIpFormatException.php | 15 +++++ module/Core/src/Util/IpAddressUtils.php | 65 +++++++++++++++++++ module/Core/src/Visit/RequestTracker.php | 51 +++------------ module/Core/test/Visit/RequestTrackerTest.php | 15 +++++ 4 files changed, 103 insertions(+), 43 deletions(-) create mode 100644 module/Core/src/Exception/InvalidIpFormatException.php create mode 100644 module/Core/src/Util/IpAddressUtils.php diff --git a/module/Core/src/Exception/InvalidIpFormatException.php b/module/Core/src/Exception/InvalidIpFormatException.php new file mode 100644 index 00000000..df9d207b --- /dev/null +++ b/module/Core/src/Exception/InvalidIpFormatException.php @@ -0,0 +1,15 @@ + strict equality with provided IP address. + * * CIDR block -> provided IP address is part of that block. + * * Wildcard -> static parts match the corresponding ones in provided IP address. + * + * @param string[] $groups + * @throws InvalidIpFormatException + */ + public static function ipAddressMatchesGroups(string $ipAddress, array $groups): bool + { + $ip = IPv4::parseString($ipAddress); + if ($ip === null) { + throw InvalidIpFormatException::fromInvalidIp($ipAddress); + } + + $ipAddressParts = explode('.', $ipAddress); + + return some($groups, function (string $value) use ($ip, $ipAddressParts): bool { + $range = str_contains($value, '*') + ? self::parseValueWithWildcards($value, $ipAddressParts) + : Factory::parseRangeString($value); + + return $range !== null && $ip->matches($range); + }); + } + + private static function parseValueWithWildcards(string $value, array $ipAddressParts): ?RangeInterface + { + $octets = explode('.', $value); + $keys = array_keys($octets); + + // Replace wildcard parts with the corresponding ones from the remote address + return Factory::parseRangeString( + implode('.', array_map( + fn (string $part, int $index) => $part === '*' ? $ipAddressParts[$index] : $part, + $octets, + $keys, + )), + ); + } +} diff --git a/module/Core/src/Visit/RequestTracker.php b/module/Core/src/Visit/RequestTracker.php index 1a6b04f9..ecc3d94f 100644 --- a/module/Core/src/Visit/RequestTracker.php +++ b/module/Core/src/Visit/RequestTracker.php @@ -5,30 +5,20 @@ declare(strict_types=1); namespace Shlinkio\Shlink\Core\Visit; use Fig\Http\Message\RequestMethodInterface; -use IPLib\Address\IPv4; -use IPLib\Factory; -use IPLib\Range\RangeInterface; use Mezzio\Router\Middleware\ImplicitHeadMiddleware; use Psr\Http\Message\ServerRequestInterface; use Shlinkio\Shlink\Common\Middleware\IpAddressMiddlewareFactory; use Shlinkio\Shlink\Core\ErrorHandler\Model\NotFoundType; +use Shlinkio\Shlink\Core\Exception\InvalidIpFormatException; use Shlinkio\Shlink\Core\Options\TrackingOptions; use Shlinkio\Shlink\Core\ShortUrl\Entity\ShortUrl; +use Shlinkio\Shlink\Core\Util\IpAddressUtils; use Shlinkio\Shlink\Core\Visit\Model\Visitor; -use function array_keys; -use function array_map; -use function explode; -use function implode; -use function Shlinkio\Shlink\Core\ArrayUtils\some; -use function str_contains; - -class RequestTracker implements RequestTrackerInterface, RequestMethodInterface +readonly class RequestTracker implements RequestTrackerInterface, RequestMethodInterface { - public function __construct( - private readonly VisitsTrackerInterface $visitsTracker, - private readonly TrackingOptions $trackingOptions, - ) { + public function __construct(private VisitsTrackerInterface $visitsTracker, private TrackingOptions $trackingOptions) + { } public function trackIfApplicable(ShortUrl $shortUrl, ServerRequestInterface $request): void @@ -78,35 +68,10 @@ class RequestTracker implements RequestTrackerInterface, RequestMethodInterface return false; } - $ip = IPv4::parseString($remoteAddr); - if ($ip === null) { + try { + return IpAddressUtils::ipAddressMatchesGroups($remoteAddr, $this->trackingOptions->disableTrackingFrom); + } catch (InvalidIpFormatException) { return false; } - - $remoteAddrParts = explode('.', $remoteAddr); - $disableTrackingFrom = $this->trackingOptions->disableTrackingFrom; - - return some($disableTrackingFrom, function (string $value) use ($ip, $remoteAddrParts): bool { - $range = str_contains($value, '*') - ? $this->parseValueWithWildcards($value, $remoteAddrParts) - : Factory::parseRangeString($value); - - return $range !== null && $ip->matches($range); - }); - } - - private function parseValueWithWildcards(string $value, array $remoteAddrParts): ?RangeInterface - { - $octets = explode('.', $value); - $keys = array_keys($octets); - - // Replace wildcard parts with the corresponding ones from the remote address - return Factory::parseRangeString( - implode('.', array_map( - fn (string $part, int $index) => $part === '*' ? $remoteAddrParts[$index] : $part, - $octets, - $keys, - )), - ); } } diff --git a/module/Core/test/Visit/RequestTrackerTest.php b/module/Core/test/Visit/RequestTrackerTest.php index fdf7e493..c8746f91 100644 --- a/module/Core/test/Visit/RequestTrackerTest.php +++ b/module/Core/test/Visit/RequestTrackerTest.php @@ -92,6 +92,21 @@ class RequestTrackerTest extends TestCase $this->requestTracker->trackIfApplicable($shortUrl, $this->request); } + #[Test] + public function trackingHappensOverShortUrlsWhenRemoteAddressIsInvalid(): void + { + $shortUrl = ShortUrl::withLongUrl(self::LONG_URL); + $this->visitsTracker->expects($this->once())->method('track')->with( + $shortUrl, + $this->isInstanceOf(Visitor::class), + ); + + $this->requestTracker->trackIfApplicable($shortUrl, ServerRequestFactory::fromGlobals()->withAttribute( + IpAddressMiddlewareFactory::REQUEST_ATTR, + 'invalid', + )); + } + #[Test] public function baseUrlErrorIsTracked(): void {