mirror of
https://github.com/shlinkio/shlink.git
synced 2024-11-22 08:56:42 -06:00
Extract logic to match IP address against list of groups
This commit is contained in:
parent
b6b2530cb6
commit
8d90661d0a
15
module/Core/src/Exception/InvalidIpFormatException.php
Normal file
15
module/Core/src/Exception/InvalidIpFormatException.php
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Shlinkio\Shlink\Core\Exception;
|
||||||
|
|
||||||
|
use function sprintf;
|
||||||
|
|
||||||
|
class InvalidIpFormatException extends RuntimeException implements ExceptionInterface
|
||||||
|
{
|
||||||
|
public static function fromInvalidIp(string $ipAddress): self
|
||||||
|
{
|
||||||
|
return new self(sprintf('Provided IP %s does not have the right format. Expected X.X.X.X', $ipAddress));
|
||||||
|
}
|
||||||
|
}
|
65
module/Core/src/Util/IpAddressUtils.php
Normal file
65
module/Core/src/Util/IpAddressUtils.php
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Shlinkio\Shlink\Core\Util;
|
||||||
|
|
||||||
|
use IPLib\Address\IPv4;
|
||||||
|
use IPLib\Factory;
|
||||||
|
use IPLib\Range\RangeInterface;
|
||||||
|
use Shlinkio\Shlink\Core\Exception\InvalidIpFormatException;
|
||||||
|
|
||||||
|
use function array_keys;
|
||||||
|
use function array_map;
|
||||||
|
use function explode;
|
||||||
|
use function implode;
|
||||||
|
use function Shlinkio\Shlink\Core\ArrayUtils\some;
|
||||||
|
|
||||||
|
class IpAddressUtils
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* Checks if an IP address matches any of provided groups.
|
||||||
|
* Every group can be a static IP address (100.200.80.40), a CIDR block (192.168.10.0/24) or a wildcard pattern
|
||||||
|
* (11.22.*.*).
|
||||||
|
*
|
||||||
|
* Matching will happen as follows:
|
||||||
|
* * Static IP address -> strict equality with provided IP address.
|
||||||
|
* * CIDR block -> provided IP address is part of that block.
|
||||||
|
* * Wildcard -> static parts match the corresponding ones in provided IP address.
|
||||||
|
*
|
||||||
|
* @param string[] $groups
|
||||||
|
* @throws InvalidIpFormatException
|
||||||
|
*/
|
||||||
|
public static function ipAddressMatchesGroups(string $ipAddress, array $groups): bool
|
||||||
|
{
|
||||||
|
$ip = IPv4::parseString($ipAddress);
|
||||||
|
if ($ip === null) {
|
||||||
|
throw InvalidIpFormatException::fromInvalidIp($ipAddress);
|
||||||
|
}
|
||||||
|
|
||||||
|
$ipAddressParts = explode('.', $ipAddress);
|
||||||
|
|
||||||
|
return some($groups, function (string $value) use ($ip, $ipAddressParts): bool {
|
||||||
|
$range = str_contains($value, '*')
|
||||||
|
? self::parseValueWithWildcards($value, $ipAddressParts)
|
||||||
|
: Factory::parseRangeString($value);
|
||||||
|
|
||||||
|
return $range !== null && $ip->matches($range);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private static function parseValueWithWildcards(string $value, array $ipAddressParts): ?RangeInterface
|
||||||
|
{
|
||||||
|
$octets = explode('.', $value);
|
||||||
|
$keys = array_keys($octets);
|
||||||
|
|
||||||
|
// Replace wildcard parts with the corresponding ones from the remote address
|
||||||
|
return Factory::parseRangeString(
|
||||||
|
implode('.', array_map(
|
||||||
|
fn (string $part, int $index) => $part === '*' ? $ipAddressParts[$index] : $part,
|
||||||
|
$octets,
|
||||||
|
$keys,
|
||||||
|
)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
@ -5,30 +5,20 @@ declare(strict_types=1);
|
|||||||
namespace Shlinkio\Shlink\Core\Visit;
|
namespace Shlinkio\Shlink\Core\Visit;
|
||||||
|
|
||||||
use Fig\Http\Message\RequestMethodInterface;
|
use Fig\Http\Message\RequestMethodInterface;
|
||||||
use IPLib\Address\IPv4;
|
|
||||||
use IPLib\Factory;
|
|
||||||
use IPLib\Range\RangeInterface;
|
|
||||||
use Mezzio\Router\Middleware\ImplicitHeadMiddleware;
|
use Mezzio\Router\Middleware\ImplicitHeadMiddleware;
|
||||||
use Psr\Http\Message\ServerRequestInterface;
|
use Psr\Http\Message\ServerRequestInterface;
|
||||||
use Shlinkio\Shlink\Common\Middleware\IpAddressMiddlewareFactory;
|
use Shlinkio\Shlink\Common\Middleware\IpAddressMiddlewareFactory;
|
||||||
use Shlinkio\Shlink\Core\ErrorHandler\Model\NotFoundType;
|
use Shlinkio\Shlink\Core\ErrorHandler\Model\NotFoundType;
|
||||||
|
use Shlinkio\Shlink\Core\Exception\InvalidIpFormatException;
|
||||||
use Shlinkio\Shlink\Core\Options\TrackingOptions;
|
use Shlinkio\Shlink\Core\Options\TrackingOptions;
|
||||||
use Shlinkio\Shlink\Core\ShortUrl\Entity\ShortUrl;
|
use Shlinkio\Shlink\Core\ShortUrl\Entity\ShortUrl;
|
||||||
|
use Shlinkio\Shlink\Core\Util\IpAddressUtils;
|
||||||
use Shlinkio\Shlink\Core\Visit\Model\Visitor;
|
use Shlinkio\Shlink\Core\Visit\Model\Visitor;
|
||||||
|
|
||||||
use function array_keys;
|
readonly class RequestTracker implements RequestTrackerInterface, RequestMethodInterface
|
||||||
use function array_map;
|
{
|
||||||
use function explode;
|
public function __construct(private VisitsTrackerInterface $visitsTracker, private TrackingOptions $trackingOptions)
|
||||||
use function implode;
|
|
||||||
use function Shlinkio\Shlink\Core\ArrayUtils\some;
|
|
||||||
use function str_contains;
|
|
||||||
|
|
||||||
class RequestTracker implements RequestTrackerInterface, RequestMethodInterface
|
|
||||||
{
|
{
|
||||||
public function __construct(
|
|
||||||
private readonly VisitsTrackerInterface $visitsTracker,
|
|
||||||
private readonly TrackingOptions $trackingOptions,
|
|
||||||
) {
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public function trackIfApplicable(ShortUrl $shortUrl, ServerRequestInterface $request): void
|
public function trackIfApplicable(ShortUrl $shortUrl, ServerRequestInterface $request): void
|
||||||
@ -78,35 +68,10 @@ class RequestTracker implements RequestTrackerInterface, RequestMethodInterface
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
$ip = IPv4::parseString($remoteAddr);
|
try {
|
||||||
if ($ip === null) {
|
return IpAddressUtils::ipAddressMatchesGroups($remoteAddr, $this->trackingOptions->disableTrackingFrom);
|
||||||
|
} catch (InvalidIpFormatException) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
$remoteAddrParts = explode('.', $remoteAddr);
|
|
||||||
$disableTrackingFrom = $this->trackingOptions->disableTrackingFrom;
|
|
||||||
|
|
||||||
return some($disableTrackingFrom, function (string $value) use ($ip, $remoteAddrParts): bool {
|
|
||||||
$range = str_contains($value, '*')
|
|
||||||
? $this->parseValueWithWildcards($value, $remoteAddrParts)
|
|
||||||
: Factory::parseRangeString($value);
|
|
||||||
|
|
||||||
return $range !== null && $ip->matches($range);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private function parseValueWithWildcards(string $value, array $remoteAddrParts): ?RangeInterface
|
|
||||||
{
|
|
||||||
$octets = explode('.', $value);
|
|
||||||
$keys = array_keys($octets);
|
|
||||||
|
|
||||||
// Replace wildcard parts with the corresponding ones from the remote address
|
|
||||||
return Factory::parseRangeString(
|
|
||||||
implode('.', array_map(
|
|
||||||
fn (string $part, int $index) => $part === '*' ? $remoteAddrParts[$index] : $part,
|
|
||||||
$octets,
|
|
||||||
$keys,
|
|
||||||
)),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -92,6 +92,21 @@ class RequestTrackerTest extends TestCase
|
|||||||
$this->requestTracker->trackIfApplicable($shortUrl, $this->request);
|
$this->requestTracker->trackIfApplicable($shortUrl, $this->request);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[Test]
|
||||||
|
public function trackingHappensOverShortUrlsWhenRemoteAddressIsInvalid(): void
|
||||||
|
{
|
||||||
|
$shortUrl = ShortUrl::withLongUrl(self::LONG_URL);
|
||||||
|
$this->visitsTracker->expects($this->once())->method('track')->with(
|
||||||
|
$shortUrl,
|
||||||
|
$this->isInstanceOf(Visitor::class),
|
||||||
|
);
|
||||||
|
|
||||||
|
$this->requestTracker->trackIfApplicable($shortUrl, ServerRequestFactory::fromGlobals()->withAttribute(
|
||||||
|
IpAddressMiddlewareFactory::REQUEST_ATTR,
|
||||||
|
'invalid',
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
#[Test]
|
#[Test]
|
||||||
public function baseUrlErrorIsTracked(): void
|
public function baseUrlErrorIsTracked(): void
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user