Revise LRN reference implementation (#2672)

* fix typo in LRN docs

* fix link to reference in LRN doc

* LRN, LRN_IE types alignment with spec

* align LRN ref implementation to plugins behavior

* update LRN docs

* Improve LRN reference implementation performance

* restore LRN constructor with no axes in the input

* apply code format

* revert double->float size_t->int change

* small fix to example in doc

* revert double->float size_t->int in onnx_importer and backend tests

* Changes to docs after review
Mateusz Tabaka 2020-10-19 07:40:04 +02:00 committed by GitHub
parent 84b5fc51dc
commit 5965010bec
7 changed files with 193 additions and 151 deletions


@@ -26,7 +26,7 @@
* *bias*
* **Description**: *beta* represents the offset. Usually a positive number to avoid dividing by zero.
* **Description**: *bias* represents the offset. Usually a positive number to avoid dividing by zero.
* **Range of values**: no restrictions
* **Type**: float
* **Default value**: None
@@ -50,13 +50,26 @@
* **1**: Output tensor of the same shape and type as the `data` input tensor.
**Detailed description**: [Reference](http://yeephycho.github.io/2016/08/03/Normalizations-in-neural-networks/#Local-Response-Normalization-LRN)
**Detailed description**:
Local Response Normalization performs a normalization over local input regions.
Each input value is divided by
\f[ (bias + \frac{alpha}{{size}^{len(axes)}} \cdot \sum_{i} data_{i}^{2})^{beta} \f]
The sum is taken over a region with side length `size` along each axis in `axes`, so the number of summed dimensions equals the number of axes.
The region is centered at the input value being normalized (with zero padding added if needed).
Here is an example for 4D `data` input tensor and `axes` = `[1]`:
Here is an example for 4D `data` input tensor and `axes = [1]`:
```
sqr_sum[a, b, c, d] =
sum(data[a, max(0, b - size / 2) : min(data.shape[1], b + size / 2 + 1), c, d] ** 2)
output = data / (bias + (alpha / size ** len(axes)) * sqr_sum) ** beta
```
sqr_sum[a, b, c, d] =
sum(input[a, b - local_size : b + local_size + 1, c, d] ** 2)
output = input / (bias + alpha * sqr_sum) ** beta
Example for 4D `data` input tensor and `axes = [2, 3]`:
```
sqr_sum[a, b, c, d] =
sum(data[a, b, max(0, c - size / 2) : min(data.shape[2], c + size / 2 + 1), max(0, d - size / 2) : min(data.shape[3], d + size / 2 + 1)] ** 2)
output = data / (bias + (alpha / size ** len(axes)) * sqr_sum) ** beta
```
**Example**
@@ -83,4 +96,4 @@ Here is an example for 4D `data` input tensor and `axes` = `[1]`:
</port>
</output>
</layer>
```
```
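
As a quick sanity check of the formula above, here is a minimal NumPy sketch of the documented behavior. It is only an illustration written for this description, not the nGraph kernel: the name `lrn_reference` is invented here, and the `(size - 1) // 2` half-window (equivalent to the `size / 2` bounds in the examples for odd `size`) follows the reference implementation further down.

```
import numpy as np

def lrn_reference(data, axes, alpha, beta, bias, size):
    # For every element: sum the squares over a window of side `size`
    # (clipped at the tensor borders) along each axis in `axes`, then
    # divide by (bias + alpha / size**len(axes) * sqr_sum) ** beta.
    out = np.empty_like(data)
    half = (size - 1) // 2
    scale = alpha / size ** len(axes)
    for idx in np.ndindex(data.shape):
        window = []
        for d, i in enumerate(idx):
            if d in axes:
                window.append(slice(max(0, i - half), min(data.shape[d], i + half + 1)))
            else:
                window.append(slice(i, i + 1))
        sqr_sum = np.sum(data[tuple(window)] ** 2)
        out[idx] = data[idx] / (bias + scale * sqr_sum) ** beta
    return out

# axes = [2, 3] normalizes over the spatial dims, as in the second example above.
x = np.arange(0, 12, 1).reshape(2, 3, 2, 1).astype(np.float32)
print(lrn_reference(x, axes=[2, 3], alpha=3.0, beta=0.5, bias=1.0, size=3))
```

With these parameters the first values come out as 0.0, 0.8660254, 0.8660254, 1.2990381, ..., in line with the updated `lrn_across_hw` expectations in the backend tests below.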


@@ -17,6 +17,8 @@ namespace {
const std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16};
const std::vector<std::vector<int64_t>> axes = {{1}, {2, 3}};
const double alpha = 9.9e-05;
const double beta = 2;
const double bias = 1.0;
@@ -27,7 +29,7 @@ INSTANTIATE_TEST_CASE_P(smoke_LrnCheck, LrnLayerTest,
::testing::Values(beta),
::testing::Values(bias),
::testing::Values(size),
::testing::Values(std::vector<int64_t>({1})),
::testing::ValuesIn(axes),
::testing::ValuesIn(netPrecisions),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),


@@ -15,6 +15,8 @@ namespace {
const std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16};
const std::vector<std::vector<int64_t>> axes = {{1}, {2, 3}};
const double alpha = 9.9e-05;
const double beta = 2;
const double bias = 1.0;
@@ -25,7 +27,7 @@ INSTANTIATE_TEST_CASE_P(smoke_LrnCheck, LrnLayerTest,
::testing::Values(beta),
::testing::Values(bias),
::testing::Values(size),
::testing::Values(std::vector<int64_t>({1})),
::testing::ValuesIn(axes),
::testing::ValuesIn(netPrecisions),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),


@@ -29,38 +29,51 @@ namespace ngraph
{
    namespace reference
    {
        template <typename T>
        void sum_region_across_axes(const T* arg,
                                    size_t current_axis_index,
                                    const std::vector<size_t>& axes,
                                    Coordinate& sum_coord,
                                    T& square_sum,
                                    const std::vector<size_t>& begin_area,
                                    const std::vector<size_t>& end_area,
                                    const CoordinateTransform& input_transform)
        {
            // all nested axes were visited
            if (current_axis_index == axes.size())
            {
                square_sum += arg[input_transform.index(sum_coord)] *
                              arg[input_transform.index(sum_coord)];
                return;
            }
            auto current_axis = axes[current_axis_index];
            for (auto current_axis_coord = begin_area[current_axis];
                 current_axis_coord < end_area[current_axis];
                 ++current_axis_coord)
            {
                sum_coord.at(current_axis) = current_axis_coord;
                sum_region_across_axes(arg,
                                       current_axis_index + 1,
                                       axes,
                                       sum_coord,
                                       square_sum,
                                       begin_area,
                                       end_area,
                                       input_transform);
            }
        }
        static size_t point_to_flat_idx(const Shape& shape, const std::vector<size_t>& point)
        {
            size_t idx = point[0];
            for (int i = 1; i < point.size(); i++)
            {
                idx *= shape[i];
                idx += point[i];
            }
            return idx;
        }
        static std::vector<size_t> slice_indices(const Shape& full_shape,
                                                 const std::vector<size_t>& begin,
                                                 const Shape& slice_shape)
        {
            size_t begin_idx = begin[0];
            size_t slice_size = shape_size(slice_shape);
            size_t rank = begin.size();
            auto coord = begin;
            std::vector<size_t> indices;
            indices.reserve(slice_size);
            indices.push_back(point_to_flat_idx(full_shape, coord));
            for (int i = 0; i < slice_size - 1; i++)
            {
                for (int r = rank - 1; r >= 0; r--)
                {
                    coord[r]++;
                    if (coord[r] < (begin[r] + slice_shape[r]))
                        break;
                    coord[r] = begin[r];
                }
                indices.push_back(point_to_flat_idx(full_shape, coord));
            }
            return indices;
        }
        template <typename T>
        static T sum_region_across_axes(const T* arg, const std::vector<size_t>& indices)
        {
            T square_sum = 0;
            for (auto index : indices)
            {
                square_sum += arg[index] * arg[index];
            }
            return square_sum;
        }
template <typename T>
@@ -76,39 +89,42 @@ namespace ngraph
T alpha = static_cast<T>(dalpha);
T beta = static_cast<T>(dbeta);
T bias = static_cast<T>(dbias);
T scale = alpha / std::pow(size, axes.size());
std::vector<size_t> begin_area(arg_shape.size());
std::vector<size_t> end_area(arg_shape.size());
Shape area_shape(arg_shape.size(), 1);
std::vector<bool> axes_map(arg_shape.size(), false);
for (const auto& axis_coord : axes)
{
axes_map[axis_coord] = true;
}
CoordinateTransform input_transform(arg_shape);
for (const Coordinate& in_coord : input_transform)
{
// area determined by in_coord local neighborhood
for (const auto& axis_coord : axes)
for (size_t i = 0; i < axes_map.size(); i++)
{
begin_area[axis_coord] =
std::max<int>(0, in_coord.at(axis_coord) - (size - 1) / 2);
end_area[axis_coord] = std::min<int>(
arg_shape.at(axis_coord), in_coord.at(axis_coord) + (size - 1) / 2 + 1);
if (axes_map[i])
{
begin_area[i] = std::max<int>(0, in_coord.at(i) - (size - 1) / 2);
area_shape[i] = std::min<int>(arg_shape.at(i),
in_coord.at(i) + (size - 1) / 2 + 1) -
begin_area[i];
}
else
{
begin_area[i] = in_coord.at(i);
}
}
T square_sum = 0;
auto sum_coord = in_coord;
auto axes_vec = std::vector<size_t>(axes.begin(), axes.end());
sum_region_across_axes(arg,
0,
axes_vec,
sum_coord,
square_sum,
begin_area,
end_area,
input_transform);
T x = arg[input_transform.index(in_coord)];
out[input_transform.index(in_coord)] =
x / (std::pow(bias + (alpha / size) * square_sum, beta));
T square_sum = sum_region_across_axes(
arg, slice_indices(arg_shape, begin_area, area_shape));
auto index = input_transform.index(in_coord);
T x = arg[index];
out[index] = x / (std::pow(bias + scale * square_sum, beta));
}
}
}
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph
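
The rewrite above replaces the per-element recursion with two small helpers: `point_to_flat_idx` flattens a coordinate into a row-major offset, and `slice_indices` lists the offsets of the clipped local region once, so `sum_region_across_axes` becomes a flat loop over `arg[index] * arg[index]`. A rough Python sketch of the same indexing idea (names borrowed from the C++ helpers purely for readability; `itertools.product` stands in for the odometer-style coordinate increment):

```
from itertools import product

def point_to_flat_idx(shape, point):
    # Row-major (C-order) flattening of a coordinate.
    idx = point[0]
    for dim, coord in zip(shape[1:], point[1:]):
        idx = idx * dim + coord
    return idx

def slice_indices(full_shape, begin, slice_shape):
    # Flat offsets of every point in the box [begin, begin + slice_shape)
    # inside a tensor of shape full_shape.
    ranges = [range(b, b + s) for b, s in zip(begin, slice_shape)]
    return [point_to_flat_idx(full_shape, p) for p in product(*ranges)]

# A size-3 channel window at n=0, h=1, w=0 of a (2, 3, 2, 1) tensor:
print(slice_indices((2, 3, 2, 1), begin=(0, 0, 1, 0), slice_shape=(1, 3, 1, 1)))  # [1, 3, 5]
```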


@@ -46,7 +46,6 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_across_channel)
double beta = 0.5;
double bias = 1;
size_t size = 3;
// lrn is performed across channel as default
auto lrn = make_shared<op::LRN>(A, alpha, beta, bias, size);
auto f = make_shared<Function>(lrn, ParameterVector{A});
@@ -55,11 +54,11 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_across_channel)
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(shape, a);
test_case.add_expected_output<float>(shape,
{0.f,
{0.0000000f,
0.3015113f,
0.4364357f,
0.5f,
0.8728715f,
0.4364358f,
0.5000000f,
0.8728716f,
0.8451542f,
0.5970223f,
0.6115928f,
@@ -67,6 +66,7 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_across_channel)
0.5669467f,
0.7784989f,
0.7720487f});
test_case.run();
}
@@ -87,7 +87,7 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_across_h)
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(shape, a);
test_case.add_expected_output<float>(shape,
{0.0f,
{0.0000000f,
0.7071068f,
0.5345225f,
0.8017837f,
@@ -97,8 +97,9 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_across_h)
0.7548294f,
0.6620847f,
0.7448453f,
0.671156f,
0.6711560f,
0.7382717f});
test_case.run();
}
@@ -119,18 +120,19 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_across_hw)
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(shape, a);
test_case.add_expected_output<float>(shape,
{0.0f,
0.7071068f,
0.5345225f,
0.8017837f,
0.6172134f,
0.7715167f,
0.6469966f,
0.7548294f,
0.6620847f,
0.7448453f,
0.671156f,
0.7382717f});
{0.0000000f,
0.8660254f,
0.8660254f,
1.2990381f,
1.0444659f,
1.3055824f,
1.1078234f,
1.2924607f,
1.1389896f,
1.2813632f,
1.1572751f,
1.2730026f});
test_case.run();
}
@@ -151,18 +153,19 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_across_all_dims)
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(shape, a);
test_case.add_expected_output<float>(shape,
{0.0f,
0.0638877f,
0.0888231f,
0.1332347f,
0.1949481f,
0.2436851f,
0.3833259f,
0.4472136f,
0.3552925f,
0.399704f,
0.4873702f,
0.5361072f});
{0.0000000f,
0.3156438f,
0.4501407f,
0.6752110f,
0.9830783f,
1.2288479f,
1.8938627f,
2.2095065f,
1.8005627f,
2.0256331f,
2.4576957f,
2.7034652f});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
@@ -183,18 +186,19 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_across_nw)
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(shape, a);
test_case.add_expected_output<float>(shape,
{0.0f,
0.140028f,
0.2407717f,
0.3144855f,
0.3698001f,
0.4123931f,
0.9863939f,
0.9801961f,
0.9630868f,
0.9434564f,
0.9245003f,
0.9072647f});
{0.0000000f,
0.2379155f,
0.4111132f,
0.5388159f,
0.6351073f,
0.7094756f,
1.6641006f,
1.6654084f,
1.6444529f,
1.6164477f,
1.5877683f,
1.5608464f});
test_case.run();
}
@@ -215,18 +219,19 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_across_empty)
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(shape, a);
test_case.add_expected_output<float>(shape,
{0.0f,
0.7071068f,
0.8944272f,
0.9486833f,
0.9701425f,
0.9805807f,
0.9863939f,
0.9899495f,
0.9922779f,
0.9938837f,
0.9950372f,
0.9958932f});
{0.0000000f,
0.5000000f,
0.5547002f,
0.5669467f,
0.5714286f,
0.5735393f,
0.5746958f,
0.5753965f,
0.5758526f,
0.5761660f,
0.5763904f,
0.5765567f});
test_case.run();
}
@@ -248,10 +253,11 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_6D_across_2_axes)
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(shape, a);
test_case.add_expected_output<float>(
shape, {0.0f, 0.2581989f, 0.5163978f, 0.7745967f, 0.3549426f, 0.4436783f,
0.5324139f, 0.6211495f, 0.4175966f, 0.4697962f, 0.5219957f, 0.5741953f,
0.4426267f, 0.4795122f, 0.5163978f, 0.5532833f, 0.4560274f, 0.4845291f,
0.5130308f, 0.5415326f, 0.4643635f, 0.4875816f, 0.5107998f, 0.534018f});
shape, {0.0000000f, 0.4200840f, 0.8401681f, 1.2602521f, 0.6099943f, 0.7624928f,
0.9149914f, 1.0674900f, 0.7213357f, 0.8115027f, 0.9016696f, 0.9918366f,
0.7656109f, 0.8294119f, 0.8932127f, 0.9570137f, 0.7892218f, 0.8385482f,
0.8878745f, 0.9372009f, 0.8038679f, 0.8440613f, 0.8842546f, 0.9244481f});
test_case.run();
}
@@ -272,18 +278,18 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_2d_across_empty)
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(shape, a);
test_case.add_expected_output<float>(shape,
{0.0f,
0.7071068f,
0.8944272f,
0.9486833f,
0.9701425f,
0.9805807f,
0.9863939f,
0.9899495f,
0.9922779f,
0.9938837f,
0.9950372f,
0.9958932f});
{0.0000000f,
0.5000000f,
0.5547002f,
0.5669467f,
0.5714286f,
0.5735393f,
0.5746958f,
0.5753964f,
0.5758526f,
0.5761660f,
0.5763904f,
0.5765566f});
test_case.run();
}
@@ -315,17 +321,18 @@ NGRAPH_TEST(${BACKEND_NAME}, lrn_2d_across_outermost_axis)
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(shape, a);
test_case.add_expected_output<float>(shape,
{0.45900404f,
0.14999892f,
-1.04828012f,
-0.99727529f,
0.41144446f,
0.08083449f,
-0.16259004f,
-0.09422511f,
-0.02180192f,
-0.34259823f,
0.35597473f,
-0.70393407f});
{0.4590040f,
0.1499989f,
-1.0482801f,
-0.9972753f,
0.4114444f,
0.0808345f,
-0.1625900f,
-0.0942251f,
-0.0218018f,
-0.3425926f,
0.3559732f,
-0.7039225f});
test_case.run(23);
}


@@ -25,20 +25,24 @@ def LRN(input, size=3, bias=1.0, alpha=3.0, beta=0.5):
H = input.shape[2]
W = input.shape[3]
for n in range(N):
begin_n = max(0, n - (size-1)//2)
end_n = min(N, n + (size-1)//2 + 1)
for c in range(C):
begin_c = max(0, c - (size-1)//2)
end_c = min(C, c + (size-1)//2 + 1)
for h in range(H):
begin_h = max(0, h - (size-1)/2)
end_h = min(H, h + (size-1)/2 + 1)
begin_h = max(0, h - (size-1)//2)
end_h = min(H, h + (size-1)//2 + 1)
for w in range(W):
begin_w = max(0, w - (size-1)/2)
end_w = min(W, w + (size-1)/2 + 1)
begin_w = max(0, w - (size-1)//2)
end_w = min(W, w + (size-1)//2 + 1)
patch = input[n, c, begin_h:end_h, begin_w:end_w]
output[n, c, h, w] /= (
np.power(bias + (alpha/size) * np.sum(patch * patch), beta))
np.power(bias + (alpha/(size**2)) * np.sum(patch * patch), beta))
return output
input = np.arange(0, 12, 1).reshape(2, 3, 2, 1).astype(np.float32)
result = LRN(input)
for elem in np.nditer(result):
print(str(round(elem, 7)) + "f, ")
print("{:.7f}f,".format(elem))


@@ -452,7 +452,6 @@ max_pool_3d
avg_pool_2d_2channel_2image_padded_only_above_include_in_computation
avg_pool_3d_uneven_strided_padded
multiple_result
lrn_across_hw
lrn_across_all_dims
elu
elu_negative_alpha
@@ -1331,7 +1330,6 @@ IE_GPU.max_3d_to_matrix_least_sig
IE_GPU.max_3d_to_vector
IE_GPU.max_3d_to_scalar
IE_GPU.max_3d_to_scalar_int32
IE_GPU.lrn_across_channel
IE_GPU.log
IE_GPU.gather_4d_indices_no_axis_2d_input
IE_GPU.gather_3d_indices_no_axis_2d_input