MaxPool 8 reference implementation (#7115)

This commit is contained in:
Tomasz Dołbniak 2021-09-01 19:14:45 +03:00 committed by GitHub
parent e07ac53300
commit 28075fb7fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 792 additions and 12 deletions

View File

@ -0,0 +1,400 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <ie_core.hpp>
#include <ie_ngraph_utils.hpp>
#include <ngraph/ngraph.hpp>
#include <shared_test_classes/base/layer_test_utils.hpp>
#include <tuple>
#include "base_reference_test.hpp"
using namespace ngraph;
using namespace reference_tests;
using namespace InferenceEngine;
// Describes a single MaxPool-8 reference test case: the input tensor, the operator attributes
// and the expected outputs (pooled values and the indices of the selected elements).
struct MaxPoolParams {
// Builds a test case from plain vectors. The three data vectors are wrapped into blobs with CreateBlob.
// Input_t   - element type of the input_data and expected_values vectors
// Indices_t - element type of the expected_indices vector
// pad_type defaults to op::PadType::EXPLICIT and axis defaults to 0.
template <class Input_t, class Indices_t>
MaxPoolParams(const Shape& input_shape,
const element::Type& input_type,
const std::vector<Input_t>& input_data,
const std::vector<Input_t>& expected_values,
const element::Type& indices_type,
const std::vector<Indices_t>& expected_indices,
const Strides& strides,
const Strides& dilations,
const Shape& pads_begin,
const Shape& pads_end,
const Shape& kernel,
const op::PadType pad_type = op::PadType::EXPLICIT,
const int64_t axis = 0)
: m_input_shape(input_shape),
m_input_type(input_type),
m_indices_type(indices_type),
m_input_data(CreateBlob(input_type, input_data)),
m_expected_values(CreateBlob(input_type, expected_values)),
m_expected_indices(CreateBlob(indices_type, expected_indices)),
m_strides(strides),
m_dilations(dilations),
m_pads_begin(pads_begin),
m_pads_end(pads_end),
m_kernel(kernel),
m_pad_type(pad_type),
m_axis(axis) {}
// shape and element type of the Parameter node feeding the MaxPool op
Shape m_input_shape;
element::Type m_input_type;
// element type requested for the indices output of the op
element::Type m_indices_type;
// input tensor contents
InferenceEngine::Blob::Ptr m_input_data;
// expected contents of the first (values) and second (indices) reference outputs
InferenceEngine::Blob::Ptr m_expected_values;
InferenceEngine::Blob::Ptr m_expected_indices;
// MaxPool attributes, forwarded unchanged to the op::v8::MaxPool constructor
Strides m_strides;
Strides m_dilations;
Shape m_pads_begin;
Shape m_pads_end;
Shape m_kernel;
op::PadType m_pad_type;
int64_t m_axis;
};
// Parameterized fixture for the MaxPool-8 reference tests. Every MaxPoolParams instance is turned
// into a single-op function plus the input and expected output blobs consumed by the base class.
class ReferenceMaxPoolLayerTest : public testing::TestWithParam<MaxPoolParams>, public CommonReferenceTest {
public:
    // Prepares the function under test, its input and both expected outputs (values, indices).
    void SetUp() override {
        const auto& test_case = GetParam();
        function = CreateFunction(test_case);
        inputData = {test_case.m_input_data};
        refOutData = {test_case.m_expected_values, test_case.m_expected_indices};
    }

    // Produces a readable test name listing the spatial rank and all attributes of the test case.
    static std::string getTestCaseName(const testing::TestParamInfo<MaxPoolParams>& obj) {
        const auto& test_case = obj.param;
        std::ostringstream name;
        // the first two dimensions are batch and channels, the remaining ones are spatial
        name << test_case.m_input_shape.size() - 2 << "D/"
             << "input_shape=" << test_case.m_input_shape << ";"
             << "input_type=" << test_case.m_input_type << ";"
             << "indices_type=" << test_case.m_indices_type << ";"
             << "strides=" << test_case.m_strides << ";"
             << "dilations=" << test_case.m_dilations << ";"
             << "pads_begin=" << test_case.m_pads_begin << ";"
             << "pads_end=" << test_case.m_pads_end << ";"
             << "kernel=" << test_case.m_kernel << ";"
             << "pad_type=" << test_case.m_pad_type << ";"
             << "axis=" << test_case.m_axis;
        return name.str();
    }

private:
    // Builds a function made of one Parameter feeding one op::v8::MaxPool (FLOOR rounding).
    static std::shared_ptr<Function> CreateFunction(const MaxPoolParams& params) {
        const auto data = std::make_shared<op::Parameter>(params.m_input_type, params.m_input_shape);
        const auto pooling = std::make_shared<op::v8::MaxPool>(data,
                                                               params.m_strides,
                                                               params.m_dilations,
                                                               params.m_pads_begin,
                                                               params.m_pads_end,
                                                               params.m_kernel,
                                                               op::RoundingType::FLOOR,
                                                               params.m_pad_type,
                                                               params.m_indices_type,
                                                               params.m_axis);
        return std::make_shared<Function>(pooling, ParameterVector{data});
    }
};
// Runs Exec() (presumably inherited from CommonReferenceTest - confirm in base_reference_test.hpp)
// for every parameter set; SetUp() has already filled function, inputData and refOutData.
TEST_P(ReferenceMaxPoolLayerTest, CompareWithHardcodedRefs) {
Exec();
}
// Hard-coded MaxPool-8 test cases. The MaxPoolParams arguments are, in order: input shape, input type,
// input data, expected values, indices type, expected indices, strides, dilations, pads_begin, pads_end,
// kernel and optionally pad_type (EXPLICIT by default) and axis (0 by default).
INSTANTIATE_TEST_SUITE_P(
smoke_MaxPool_With_Hardcoded_Refs,
ReferenceMaxPoolLayerTest,
::testing::Values(
/*************************************************/
/***************** 1D test cases *****************/
/*************************************************/
// kernel: 2, strides: 1
MaxPoolParams(Shape{1, 1, 9},
element::i32,
std::vector<int32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9},
std::vector<int32_t>{2, 3, 4, 5, 6, 7, 8, 9},
element::i64,
std::vector<int64_t>{1, 2, 3, 4, 5, 6, 7, 8},
Strides{1},
Strides{1},
Shape{},
Shape{},
Shape{2}),
// kernel: 2, strides: 2
MaxPoolParams(Shape{1, 1, 9},
element::i32,
std::vector<int32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9},
std::vector<int32_t>{2, 4, 6, 8},
element::i64,
std::vector<int64_t>{1, 3, 5, 7},
Strides{2},
Strides{1},
Shape{},
Shape{},
Shape{2}),
// kernel: 2, strides: 2, pad_type: SAME_LOWER
MaxPoolParams(Shape{1, 1, 9},
element::i32,
std::vector<int32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9},
std::vector<int32_t>{1, 3, 5, 7, 9},
element::i64,
std::vector<int64_t>{0, 2, 4, 6, 8},
Strides{2},
Strides{1},
Shape{},
Shape{},
Shape{2},
op::PadType::SAME_LOWER),
// kernel: 2, strides: 2, pad_type: SAME_UPPER
MaxPoolParams(Shape{1, 1, 9},
element::i32,
std::vector<int32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9},
std::vector<int32_t>{2, 4, 6, 8, 9},
element::i64,
std::vector<int64_t>{1, 3, 5, 7, 8},
Strides{2},
Strides{1},
Shape{},
Shape{},
Shape{2},
op::PadType::SAME_UPPER),
// kernel: 2, strides: 2, dilations: 2
MaxPoolParams(Shape{1, 1, 9},
element::i32,
std::vector<int32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9},
std::vector<int32_t>{3, 5, 7, 9},
element::i32,
std::vector<int32_t>{2, 4, 6, 8},
Strides{2},
Strides{2},
Shape{},
Shape{},
Shape{2}),
// two channels, kernel: 3; with the default axis (0) the expected indices of the second channel
// (4 and 7) are positions in the whole flattened input
MaxPoolParams(Shape{1, 2, 4},
element::f32,
std::vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 0.0f, -3.14f, -2.71f, 5.0f},
std::vector<float>{3.0f, 4.0f, 0.0f, 5.0f},
element::i32,
std::vector<int32_t>{2, 3, 4, 7},
Strides{1},
Strides{1},
Shape{},
Shape{},
Shape{3}),
// same data with axis: 2; the expected indices of the second channel (0 and 3) are relative
// to the beginning of that channel
MaxPoolParams(Shape{1, 2, 4},
element::f32,
std::vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 0.0f, -3.14f, -2.71f, 5.0f},
std::vector<float>{3.0f, 4.0f, 0.0f, 5.0f},
element::i32,
std::vector<int32_t>{2, 3, 0, 3},
Strides{1},
Strides{1},
Shape{},
Shape{},
Shape{3},
op::PadType::EXPLICIT,
2),
// kernel: 3, strides: 3, pads_begin: 2, pads_end: 2
MaxPoolParams(Shape{1, 1, 9},
element::i32,
std::vector<int32_t>{1, 9, 3, 8, 5, 2, 6, 4, 7},
std::vector<int32_t>{1, 9, 6, 7},
element::i32,
std::vector<int32_t>{0, 1, 6, 8},
Strides{3},
Strides{1},
Shape{2},
Shape{2},
Shape{3}),
/*************************************************/
/***************** 2D test cases *****************/
/*************************************************/
MaxPoolParams(Shape{1, 1, 3, 3},
element::i32,
std::vector<int32_t>{3, 9, 10, 5, 7, 2, 18, 8, -2},
std::vector<int32_t>{9, 10, 18, 8},
element::i32,
std::vector<int32_t>{1, 2, 6, 7},
Strides{1, 1},
Strides{1, 1},
Shape{},
Shape{},
Shape{2, 2}),
MaxPoolParams(Shape{1, 1, 4, 4}, // simple 4x4 input test
element::i32,
std::vector<int32_t>{8, -9, 1, -16, -14, 15, -17, 19, -13, 3, 10, 17, 16, -11, -15, 20},
std::vector<int32_t>{15, 15, 19, 15, 15, 19, 16, 10, 20},
element::i32,
std::vector<int32_t>{5, 5, 7, 5, 5, 7, 12, 10, 15},
Strides{1, 1},
Strides{1, 1},
Shape{},
Shape{},
Shape{2, 2}),
MaxPoolParams(Shape{1, 1, 4, 4},
element::i32,
std::vector<int32_t>{8, -9, 1, -16, -14, 15, -17, 19, -13, 3, 10, 17, 16, -11, -15, 20},
std::vector<int32_t>{15, 19, 16, 20},
element::i32,
std::vector<int32_t>{5, 7, 12, 15},
Strides{2, 2}, // strides: 2x2
Strides{1, 1},
Shape{},
Shape{},
Shape{2, 2}),
MaxPoolParams(Shape{1, 1, 4, 4},
element::i32,
std::vector<int32_t>{8, -9, 1, -16, -14, 15, -17, 19, -13, 3, 10, 17, 16, -11, -15, 20},
std::vector<int32_t>{10, 17, 16, 20},
element::i32,
std::vector<int32_t>{10, 11, 12, 15},
Strides{1, 1},
Strides{2, 2}, // dilations: 2x2
Shape{},
Shape{},
Shape{2, 2}),
MaxPoolParams(Shape{1, 1, 4, 4},
element::i32,
std::vector<int32_t>{8, -9, 1, -16, -14, 15, -17, 19, -13, 3, 10, 17, 16, -11, -15, 20},
std::vector<int32_t>{15, 19, 16, 20},
element::i32,
std::vector<int32_t>{5, 7, 12, 15},
Strides{1, 1},
Strides{1, 1},
Shape{},
Shape{},
Shape{3, 3}), // kernel: 3x3
MaxPoolParams(Shape{1, 1, 5, 5},
element::i32,
std::vector<int32_t>{0, -2, 24, 13, 7, -5, -4, 4, 21, -18, 81, 20, -15,
37, 23, 41, 18, 42, 8, 32, 9, 57, 58, 29, 3},
std::vector<int32_t>{0, 21, 81, 37},
element::i32,
std::vector<int32_t>{0, 8, 10, 13},
Strides{2, 3}, // strides: 2x3
Strides{1, 1},
Shape{},
Shape{},
Shape{2, 2}),
MaxPoolParams(Shape{1, 1, 5, 5},
element::i32,
std::vector<int32_t>{0, -2, 24, 13, 7, -5, -4, 4, 21, -18, 81, 20, -15,
37, 23, 41, 18, 42, 8, 32, 9, 57, 58, 29, 3},
std::vector<int32_t>{0, 24, 57, 58},
element::i32,
std::vector<int32_t>{0, 2, 21, 22},
Strides{3, 2}, // strides: 3x2
Strides{1, 1},
Shape{},
Shape{},
Shape{2, 2}),
MaxPoolParams(Shape{1, 1, 5, 5},
element::i32,
std::vector<int32_t>{0, -2, 24, 13, 7, -5, -4, 4, 21, -18, 81, 20, -15,
37, 23, 41, 18, 42, 8, 32, 9, 57, 58, 29, 3},
std::vector<int32_t>{81, 24, 81, 58},
element::i32,
std::vector<int32_t>{10, 2, 10, 22},
Strides{2, 2}, // strides: 2x2
Strides{2, 2}, // dilations: 2x2
Shape{},
Shape{},
Shape{2, 2}),
MaxPoolParams(Shape{1, 1, 5, 5},
element::i32,
std::vector<int32_t>{0, -2, 24, 13, 7, -5, -4, 4, 21, -18, 81, 20, -15,
37, 23, 41, 18, 42, 8, 32, 9, 57, 58, 29, 3},
std::vector<int32_t>{0, 24, 21, 81, 42, 37, 57, 58, 32},
element::i32,
std::vector<int32_t>{0, 2, 8, 10, 17, 13, 21, 22, 19},
Strides{2, 2}, // strides: 2x2
Strides{1, 1},
Shape{1, 1}, // pads_begin: 1x1
Shape{1, 1}, // pads_end: 1x1
Shape{3, 3}),
MaxPoolParams(Shape{1, 1, 5, 5},
element::i32,
std::vector<int32_t>{0, -2, 24, 13, 7, -5, -4, 4, 21, -18, 81, 20, -15,
37, 23, 41, 18, 42, 8, 32, 9, 57, 58, 29, 3},
std::vector<int32_t>{81, 37, 81, 58, 58, 58},
element::i32,
std::vector<int32_t>{10, 13, 10, 22, 22, 22},
Strides{2, 2}, // strides: 2x2
Strides{1, 1},
Shape{},
Shape{2, 1}, // pads_end: 2x1
Shape{3, 3}),
MaxPoolParams(Shape{1, 2, 3, 3},
element::i64,
std::vector<int64_t>{0, -2, 24, 13, 7, -5, -4, 4, 21, -18, 81, 20, -15, 37, 23, 41, 18, 42},
std::vector<int64_t>{13, 24, 13, 21, 81, 81, 41, 42},
element::i64,
std::vector<int64_t>{3, 2, 3, 8, 1, 1, 6, 8},
Strides{1, 1},
Strides{1, 1},
Shape{},
Shape{},
Shape{2, 2},
op::PadType::EXPLICIT,
2), // axis: 2
MaxPoolParams(Shape{1, 1, 2, 2},
element::i32,
std::vector<int32_t>{1, 2, 3, 4},
std::vector<int32_t>{1, 2, 3, 4},
element::i32,
std::vector<int32_t>{0, 1, 2, 3},
Strides{1, 1},
Strides{1, 1},
Shape{},
Shape{},
Shape{1, 1}), // kernel: 1x1
/*************************************************/
/***************** 3D test cases *****************/
/*************************************************/
MaxPoolParams(Shape{1, 1, 3, 3, 3},
element::i32,
std::vector<int32_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 20, 30, 40, -20,
60, 70, 80, 50, 50, 1, 2, 3, -15, -10, 50, 30, 81},
std::vector<int32_t>{40, 60, 80, 80, 50, 60, 80, 81},
element::i32,
std::vector<int32_t>{12, 14, 16, 16, 18, 14, 16, 26},
Strides{1, 1, 1},
Strides{1, 1, 1},
Shape{},
Shape{},
Shape{2, 2, 2}),
// strides: 2x2x2, dilations: 2x2x2, pads_begin and pads_end: 1x1x1
MaxPoolParams(Shape{1, 1, 3, 3, 3},
element::i32,
std::vector<int32_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 20, 30, 40, -20,
60, 70, 80, 50, 50, 1, 2, 3, -15, -10, 50, 30, 81},
std::vector<int32_t>{-20, -20, -20, -20, -20, -20, -20, -20},
element::i32,
std::vector<int32_t>{13, 13, 13, 13, 13, 13, 13, 13},
Strides{2, 2, 2},
Strides{2, 2, 2},
Shape{1, 1, 1},
Shape{1, 1, 1},
Shape{2, 2, 2}),
// kernel: 1x3x3
MaxPoolParams(Shape{1, 1, 3, 3, 3},
element::i32,
std::vector<int32_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 20, 30, 40, -20,
60, 70, 80, 50, 50, 1, 2, 3, -15, -10, 50, 30, 81},
std::vector<int32_t>{8, 80, 81},
element::i32,
std::vector<int32_t>{8, 16, 26},
Strides{1, 1, 1},
Strides{1, 1, 1},
Shape{},
Shape{},
Shape{1, 3, 3}),
// kernel: 1x2x2, axis: 3
MaxPoolParams(Shape{1, 1, 3, 3, 3},
element::i32,
std::vector<int32_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 20, 30, 40, -20,
60, 70, 80, 50, 50, 1, 2, 3, -15, -10, 50, 30, 81},
std::vector<int32_t>{4, 5, 7, 8, 40, 60, 80, 80, 50, 2, 50, 81},
element::i32,
std::vector<int32_t>{4, 5, 7, 8, 3, 5, 7, 7, 0, 2, 6, 8},
Strides{1, 1, 1},
Strides{1, 1, 1},
Shape{},
Shape{},
Shape{1, 2, 2},
op::PadType::EXPLICIT,
3)),
ReferenceMaxPoolLayerTest::getTestCaseName);

View File

@ -5,6 +5,7 @@
#pragma once
#include <cstdio>
#include <numeric>
#include <vector>
#include "ngraph/attribute_adapter.hpp"
@ -45,6 +46,20 @@ size_t shape_size(const SHAPE_TYPE& shape) {
return size;
}
/// Number of elements in a subset of dimensions of a shape.
/// Returns a product of dimensions in a range [start_dim;end_dim)
///
/// \param start_dim Iterator to the first dimension included in the product.
/// \param end_dim   Iterator one past the last dimension included in the product.
/// \return The product of the dimensions; 1 for an empty range.
///
/// \note Every dimension is converted to size_t before it is multiplied. The product is accumulated
///       in size_t rather than in the iterator's value_type, so ranges of narrow or signed integers
///       (for example uint8_t or int) do not overflow or wrap around in their own type.
template <typename ForwardIt>
size_t shape_size(ForwardIt start_dim, const ForwardIt end_dim) {
    using dim_type = typename std::iterator_traits<ForwardIt>::value_type;
    static_assert(std::is_arithmetic<dim_type>::value,
                  "shape_size expects 2 forward iterators as inputs. value_type of those iterators has to be an "
                  "arithmetic type so that they can be used in multiplication operation.");
    return std::accumulate(start_dim, end_dim, size_t{1}, [](const size_t product, const dim_type dim) {
        return product * static_cast<size_t>(dim);
    });
}
/// Row-major strides for a shape
template <typename SHAPE_TYPE>
std::vector<size_t> row_major_strides(const SHAPE_TYPE& shape) {

View File

@ -110,6 +110,339 @@ void max_pool(const T* arg,
}
NGRAPH_SUPPRESS_DEPRECATED_END
}
namespace {
/// \brief Checks that every MaxPool attribute has exactly `dims` entries (one per spatial dimension).
///
/// \param dims              Number of spatial dimensions handled by the calling kernel.
/// \param kernel            Kernel (window) sizes.
/// \param kernel_strides    Kernel strides.
/// \param kernel_dilations  Kernel dilations.
/// \param pads_begin        Padding added before the data along each spatial dimension.
/// \param pads_end          Padding added after the data along each spatial dimension.
///
/// The check is done with NGRAPH_CHECK; when it fails the message lists all of the attributes.
void validate_max_pool_kernel_params(const size_t dims,
const Shape& kernel,
const Strides& kernel_strides,
const Strides& kernel_dilations,
const Shape& pads_begin,
const Shape& pads_end) {
// only the number of entries is verified here, the values themselves are not inspected
NGRAPH_CHECK(kernel.size() == dims && kernel_strides.size() == dims && kernel_dilations.size() == dims &&
pads_begin.size() == dims && pads_end.size() == dims,
"One of the MaxPool params does not match the ",
dims,
"D implementation.\nkernel=",
kernel,
"\nkernel_strides=",
kernel_strides,
"\nkernel_dilations=",
kernel_dilations,
"\npads_begin=",
pads_begin,
"\npads_end=",
pads_end);
}
/// \brief A helper struct representing spatial coordinates of a tensor element. It can use signed numbers as the
/// underlying type; this way it is possible to represent elements which belong to the padding area
/// (by using negative values).
///
/// \note This struct can be used to represent a location of a pooling kernel in space (non-flattened version)
/// but at the same time it can represent pixel offsets in the filter itself (dilated or non-dilated)
template <typename T>
struct Coord : public std::vector<T> {
Coord(const Shape& pads_begin) {
std::vector<T>::reserve(pads_begin.size());
for (const auto axis_padding : pads_begin) {
std::vector<T>::push_back(0 - axis_padding);
}
}
Coord(std::initializer_list<T>&& values) : std::vector<T>{std::move(values)} {}
};
/// \brief Tells whether the element addressed by a kernel position and an offset inside of the kernel
///        lies outside of the data (in the padding area).
///
/// \param kernel_position Coordinates of the first element covered by the kernel (can be negative).
/// \param kernel_offset   Offset of the inspected element from kernel_position, per spatial dimension.
/// \param data_shape      Shape of the input; its spatial dimensions start at index 2.
///
/// \return true if the element is outside of the data along at least one spatial dimension.
bool elem_in_padding_area(const Coord<int>& kernel_position,
                          const Coord<size_t>& kernel_offset,
                          const Shape& data_shape) {
    for (size_t dim = 0; dim + 2 < data_shape.size(); ++dim) {
        // The sum is calculated in a signed type on purpose: "int + size_t" is an unsigned expression
        // which can never compare as negative, so the lower bound could only be detected by wrap-around.
        const int64_t elem_coord =
            static_cast<int64_t>(kernel_position[dim]) + static_cast<int64_t>(kernel_offset[dim]);
        if (elem_coord < 0 || elem_coord >= static_cast<int64_t>(data_shape[dim + 2])) {
            return true;
        }
    }

    return false;
}
/// \brief Calculates the next position of a 2D pooling kernel (row-major traversal of the padded input).
///
/// \param kernel_position  Current position: {row, column} of the first element covered by the kernel.
/// \param kernel           Kernel sizes.
/// \param kernel_strides   Kernel strides.
/// \param kernel_dilations Kernel dilations.
/// \param data_shape       Shape of the input; data_shape[3] is the number of columns.
/// \param pads_begin       Padding before the data, used when the kernel returns to the first column.
/// \param pads_end         Padding after the data, it extends the area the kernel is allowed to cover.
///
/// \return The position one stride to the right or, when the kernel would not fit there any more,
///         the first column of the row located one stride below.
template <typename T>
Coord<T> next_kernel_position_2D(Coord<T> kernel_position,
                                 const Shape& kernel,
                                 const Strides& kernel_strides,
                                 const Strides& kernel_dilations,
                                 const Shape& data_shape,
                                 const Shape& pads_begin,
                                 const Shape& pads_end) {
    // The boundary check is done on signed 64-bit values. The position can still be negative after
    // a stride (when the begin padding is wider than the stride plus the kernel span); a mixed
    // signed/unsigned comparison would turn it into a huge number and force a premature row change.
    const int64_t kernel_span = static_cast<int64_t>((kernel[1] - 1) * kernel_dilations[1]);
    const int64_t columns_end = static_cast<int64_t>(data_shape[3]) + static_cast<int64_t>(pads_end[1]);

    // move the kernel horizontally one stride to the right
    kernel_position[1] += static_cast<T>(kernel_strides[1]);

    // if the top-right corner of the kernel is outside of the padding area,
    // move it back to the left and one stride down
    if (static_cast<int64_t>(kernel_position[1]) + kernel_span >= columns_end) {
        kernel_position[1] = static_cast<T>(0 - pads_begin[1]);
        kernel_position[0] += static_cast<T>(kernel_strides[0]);
    }

    return kernel_position;
}
/// \brief Calculates the next position of a 3D pooling kernel. The innermost (last) axis is traversed
///        first, then the middle one and finally the outermost one.
///
/// \param kernel_position  Current position of the first element covered by the kernel, one entry
///                         per spatial dimension.
/// \param kernel           Kernel sizes.
/// \param kernel_strides   Kernel strides.
/// \param kernel_dilations Kernel dilations.
/// \param data_shape       Shape of the input; data_shape[3] and data_shape[4] are the sizes of the
///                         middle and the innermost spatial dimension.
/// \param pads_begin       Padding before the data, used when the kernel returns to the start of an axis.
/// \param pads_end         Padding after the data, it extends the area the kernel is allowed to cover.
///
/// \return The next position of the kernel.
template <typename T>
Coord<T> next_kernel_position_3D(Coord<T> kernel_position,
                                 const Shape& kernel,
                                 const Strides& kernel_strides,
                                 const Strides& kernel_dilations,
                                 const Shape& data_shape,
                                 const Shape& pads_begin,
                                 const Shape& pads_end) {
    // All boundary checks are done on signed 64-bit values. The position can still be negative after
    // a stride; a mixed signed/unsigned comparison would turn it into a huge number and make the
    // kernel leave the current row (or plane) too early.
    const int64_t inner_span = static_cast<int64_t>((kernel[2] - 1) * kernel_dilations[2]);
    const int64_t inner_end = static_cast<int64_t>(data_shape[4]) + static_cast<int64_t>(pads_end[2]);
    const int64_t middle_span = static_cast<int64_t>((kernel[1] - 1) * kernel_dilations[1]);
    const int64_t middle_end = static_cast<int64_t>(data_shape[3]) + static_cast<int64_t>(pads_end[1]);

    // one stride along the innermost axis
    kernel_position[2] += static_cast<T>(kernel_strides[2]);
    if (static_cast<int64_t>(kernel_position[2]) + inner_span >= inner_end) {
        // the kernel does not fit any more: back to the start of the innermost axis, one stride along the middle one
        kernel_position[2] = static_cast<T>(0 - pads_begin[2]);
        kernel_position[1] += static_cast<T>(kernel_strides[1]);
        if (static_cast<int64_t>(kernel_position[1]) + middle_span >= middle_end) {
            // the same for the middle axis: back to its start and one stride along the outermost axis
            kernel_position[1] = static_cast<T>(0 - pads_begin[1]);
            kernel_position[0] += static_cast<T>(kernel_strides[0]);
        }
    }

    return kernel_position;
}
namespace kernel {
/// \brief 1D max pooling of a single channel.
///
/// \param data            First element of the processed channel.
/// \param values          Output buffer for the selected (maximum) values, out_elems entries.
/// \param indices         Output buffer for the indices of the selected values, out_elems entries.
/// \param data_elems      Number of elements in the processed channel.
/// \param out_elems       Number of output elements to produce.
/// \param kernel_size     Number of elements covered by the kernel.
/// \param kernel_stride   Distance between two consecutive kernel positions.
/// \param kernel_dilation Distance between two consecutive elements of the kernel.
/// \param pads_begin      Padding before the data; the first kernel position is -pads_begin.
/// \param pads_end        Padding after the data; not needed by this kernel because the number of
///                        kernel positions is given by out_elems.
/// \param indices_offset  Value added to every index found in this channel.
///
/// \note The first element of a window which belongs to the data always becomes the initial maximum
///       (the same way std::max_element works). Thanks to that a window containing nothing but
///       numeric_limits::lowest() or -infinity reports its real value and index. If a window lies
///       entirely in the padding area, lowest() and index 0 (plus indices_offset) are returned.
template <typename Values_t, typename Indices_t>
void max_pool_1d(const Values_t* data,
                 Values_t* values,
                 Indices_t* indices,
                 const size_t data_elems,
                 const size_t out_elems,
                 const size_t kernel_size,
                 const size_t kernel_stride,
                 const size_t kernel_dilation,
                 const size_t pads_begin,
                 const size_t pads_end,
                 const size_t indices_offset) {
    // signed arithmetic is used for the positions because they are negative in the begin padding area
    const int64_t data_size = static_cast<int64_t>(data_elems);
    int64_t kernel_position = -static_cast<int64_t>(pads_begin);

    // select max elem and its index for each "placeholder" in the out buffer (pointed to by out_idx)
    for (size_t out_idx = 0; out_idx < out_elems; ++out_idx) {
        Values_t max_elem = std::numeric_limits<Values_t>::lowest();
        Indices_t max_elem_idx = Indices_t{0};
        bool max_found = false;

        for (size_t kernel_elem = 0; kernel_elem < kernel_size; ++kernel_elem) {
            const int64_t data_idx = kernel_position + static_cast<int64_t>(kernel_elem * kernel_dilation);
            // don't process the padding elements
            if (data_idx < 0 || data_idx >= data_size) {
                continue;
            }
            // strict comparison: the first of several equal maxima is the one which gets reported
            if (!max_found || data[data_idx] > max_elem) {
                max_elem = data[data_idx];
                max_elem_idx = static_cast<Indices_t>(data_idx);
                max_found = true;
            }
        }

        values[out_idx] = max_elem;
        indices[out_idx] = static_cast<Indices_t>(max_elem_idx + indices_offset);
        kernel_position += static_cast<int64_t>(kernel_stride);
    }
}
/// \brief 2D max pooling of a single channel.
///
/// \param data             First element of the processed channel (row-major layout).
/// \param values           Output buffer for the selected (maximum) values of the channel.
/// \param indices          Output buffer for the indices of the selected values.
/// \param data_shape       Shape of the whole input; data_shape[2] and data_shape[3] are the numbers
///                         of rows and columns of the channel.
/// \param out_shape        Shape of the whole output; out_shape[2] * out_shape[3] elements are produced.
/// \param kernel           Kernel sizes: {rows, columns}.
/// \param kernel_strides   Kernel strides.
/// \param kernel_dilations Kernel dilations.
/// \param pads_begin       Padding before the data.
/// \param pads_end         Padding after the data.
/// \param indices_offset   Value added to every index found in this channel.
///
/// \note The first element of a window which belongs to the data always becomes the initial maximum,
///       so a window containing nothing but numeric_limits::lowest() or -infinity reports its real
///       value and index. If a window lies entirely in the padding area, lowest() and index 0
///       (plus indices_offset) are returned.
template <typename Values_t, typename Indices_t>
void max_pool_2d(const Values_t* data,
                 Values_t* values,
                 Indices_t* indices,
                 const Shape& data_shape,
                 const Shape& out_shape,
                 const Shape& kernel,
                 const Strides& kernel_strides,
                 const Strides& kernel_dilations,
                 const Shape& pads_begin,
                 const Shape& pads_end,
                 const size_t indices_offset) {
    validate_max_pool_kernel_params(2, kernel, kernel_strides, kernel_dilations, pads_begin, pads_end);

    Coord<int> kernel_position{pads_begin};

    const size_t out_elems = out_shape[2] * out_shape[3];
    // Consecutive rows of a row-major channel are separated by the number of columns (the width).
    // Using the number of rows here would address wrong elements for non-square inputs.
    const size_t row_stride = data_shape[3];

    // select max elem and its index for each "placeholder" in the out buffer (pointed to by out_idx)
    for (size_t out_idx = 0; out_idx < out_elems; ++out_idx) {
        Values_t max_elem = std::numeric_limits<Values_t>::lowest();
        Indices_t max_elem_idx = Indices_t{0};
        bool max_found = false;

        // find the max element in the area covered by a current position of the kernel
        for (size_t kernel_row = 0; kernel_row < kernel[0]; ++kernel_row) {
            for (size_t kernel_col = 0; kernel_col < kernel[1]; ++kernel_col) {
                // offset from the top-left corner of the kernel for a given row and col
                const Coord<size_t> kernel_offset{kernel_row * kernel_dilations[0], kernel_col * kernel_dilations[1]};

                // ignore the elements in the padding area
                if (elem_in_padding_area(kernel_position, kernel_offset, data_shape)) {
                    continue;
                }

                // index of the flattened tensor element under the current row & column of the kernel
                // (both sums are non-negative here because the element is not in the padding area)
                const size_t data_elem_index =
                    row_stride * (kernel_offset[0] + kernel_position[0]) + kernel_offset[1] + kernel_position[1];

                // strict comparison: the first of several equal maxima is the one which gets reported
                if (!max_found || data[data_elem_index] > max_elem) {
                    max_elem = data[data_elem_index];
                    max_elem_idx = static_cast<Indices_t>(data_elem_index);
                    max_found = true;
                }
            }
        }

        values[out_idx] = max_elem;
        indices[out_idx] = static_cast<Indices_t>(max_elem_idx + indices_offset);
        kernel_position = next_kernel_position_2D(kernel_position,
                                                  kernel,
                                                  kernel_strides,
                                                  kernel_dilations,
                                                  data_shape,
                                                  pads_begin,
                                                  pads_end);
    }
}
/// \brief 3D max pooling of a single channel.
///
/// \param data             First element of the processed channel (row-major layout).
/// \param values           Output buffer for the selected (maximum) values of the channel.
/// \param indices          Output buffer for the indices of the selected values.
/// \param data_shape       Shape of the whole input; data_shape[2], data_shape[3] and data_shape[4]
///                         are the sizes of the spatial dimensions of the channel.
/// \param out_shape        Shape of the whole output; the product of its spatial dimensions is the
///                         number of produced elements.
/// \param kernel           Kernel sizes, one per spatial dimension.
/// \param kernel_strides   Kernel strides.
/// \param kernel_dilations Kernel dilations.
/// \param pads_begin       Padding before the data.
/// \param pads_end         Padding after the data.
/// \param indices_offset   Value added to every index found in this channel.
///
/// \note The first element of a window which belongs to the data always becomes the initial maximum,
///       so a window containing nothing but numeric_limits::lowest() or -infinity reports its real
///       value and index. If a window lies entirely in the padding area, lowest() and index 0
///       (plus indices_offset) are returned.
template <typename Values_t, typename Indices_t>
void max_pool_3d(const Values_t* data,
                 Values_t* values,
                 Indices_t* indices,
                 const Shape& data_shape,
                 const Shape& out_shape,
                 const Shape& kernel,
                 const Strides& kernel_strides,
                 const Strides& kernel_dilations,
                 const Shape& pads_begin,
                 const Shape& pads_end,
                 const size_t indices_offset) {
    validate_max_pool_kernel_params(3, kernel, kernel_strides, kernel_dilations, pads_begin, pads_end);

    Coord<int> kernel_position{pads_begin};

    const size_t out_elems = shape_size(std::begin(out_shape) + 2, std::end(out_shape));
    // Strides of a row-major channel with the spatial dimensions {data_shape[2], data_shape[3], data_shape[4]}:
    // a step along the first one skips data_shape[3] * data_shape[4] elements and a step along the
    // second one skips data_shape[4] elements. Other combinations are only right for cubic inputs.
    const size_t row_stride = data_shape[4];
    const size_t plane_stride = data_shape[3] * data_shape[4];

    // select max elem and its index for each "placeholder" in the out buffer (pointed to by out_idx)
    for (size_t out_idx = 0; out_idx < out_elems; ++out_idx) {
        Values_t max_elem = std::numeric_limits<Values_t>::lowest();
        Indices_t max_elem_idx = Indices_t{0};
        bool max_found = false;

        for (size_t kernel_channel = 0; kernel_channel < kernel[0]; ++kernel_channel) {
            for (size_t kernel_row = 0; kernel_row < kernel[1]; ++kernel_row) {
                for (size_t kernel_col = 0; kernel_col < kernel[2]; ++kernel_col) {
                    // offset from the first element covered by the kernel, per spatial dimension
                    const Coord<size_t> kernel_offset{kernel_channel * kernel_dilations[0],
                                                      kernel_row * kernel_dilations[1],
                                                      kernel_col * kernel_dilations[2]};

                    // ignore the elements in the padding area
                    if (elem_in_padding_area(kernel_position, kernel_offset, data_shape)) {
                        continue;
                    }

                    // index of the flattened tensor element under the current element of the kernel
                    // (all sums are non-negative here because the element is not in the padding area)
                    const size_t data_elem_index = plane_stride * (kernel_offset[0] + kernel_position[0]) +
                                                   row_stride * (kernel_offset[1] + kernel_position[1]) +
                                                   kernel_offset[2] + kernel_position[2];

                    // strict comparison: the first of several equal maxima is the one which gets reported
                    if (!max_found || data[data_elem_index] > max_elem) {
                        max_elem = data[data_elem_index];
                        max_elem_idx = static_cast<Indices_t>(data_elem_index);
                        max_found = true;
                    }
                }
            }
        }

        values[out_idx] = max_elem;
        indices[out_idx] = static_cast<Indices_t>(max_elem_idx + indices_offset);
        kernel_position = next_kernel_position_3D(kernel_position,
                                                  kernel,
                                                  kernel_strides,
                                                  kernel_dilations,
                                                  data_shape,
                                                  pads_begin,
                                                  pads_end);
    }
}
} // namespace kernel
} // namespace
template <typename Values_t, typename Indices_t>
void max_pool(const Values_t* data,
Values_t* values,
Indices_t* indices,
const Shape& data_shape,
const Shape& out_shape,
const Shape& kernel,
const Strides& strides,
const Strides& dilations,
const Shape& pads_begin,
const Shape& pads_end,
const int64_t axis = 0) {
const auto data_batch_elems = shape_size(std::begin(data_shape) + 1, std::end(data_shape));
const auto data_channel_elems = shape_size(std::begin(data_shape) + 2, std::end(data_shape));
const auto out_batch_elems = shape_size(std::begin(out_shape) + 1, std::end(out_shape));
const auto out_channel_elems = shape_size(std::begin(out_shape) + 2, std::end(out_shape));
for (size_t b = 0; b < data_shape[0]; ++b) {
const Indices_t batch_indices_offset = b * data_batch_elems;
for (size_t c = 0; c < data_shape[1]; ++c) {
// calculate the buffer offsets for a given channel "c" then execute an appropriate
// kernel for each processed channel
const Values_t* data_channel_first_elem = data + b * data_batch_elems + c * data_channel_elems;
Values_t* out_channel_first_elem = values + b * out_batch_elems + c * out_channel_elems;
Indices_t* indices_channel_first_elem = indices + b * out_batch_elems + c * out_channel_elems;
const Indices_t channel_indices_offset = c * data_channel_elems;
// total offset of the flattened tensor indices for currently processed batch and channel
const Indices_t indices_offset = batch_indices_offset + channel_indices_offset;
if (data_shape.size() == 3) {
kernel::max_pool_1d<Values_t, Indices_t>(data_channel_first_elem,
out_channel_first_elem,
indices_channel_first_elem,
data_shape[2],
out_shape[2],
kernel[0],
strides[0],
dilations[0],
pads_begin[0],
pads_end[0],
indices_offset);
} else if (data_shape.size() == 4) {
kernel::max_pool_2d<Values_t, Indices_t>(data_channel_first_elem,
out_channel_first_elem,
indices_channel_first_elem,
data_shape,
out_shape,
kernel,
strides,
dilations,
pads_begin,
pads_end,
indices_offset);
} else if (data_shape.size() == 5) {
kernel::max_pool_3d<Values_t, Indices_t>(data_channel_first_elem,
out_channel_first_elem,
indices_channel_first_elem,
data_shape,
out_shape,
kernel,
strides,
dilations,
pads_begin,
pads_end,
indices_offset);
} else {
NGRAPH_CHECK(false,
"Unsupported input shape ",
data_shape,
" passed to the MaxPool reference implementation. Supported shapes: 3D, 4D and 5D.");
}
}
}
// adjust the calculated indices to the requested range (specified by the axis attribute) if needed
if (axis != 0) {
const Indices_t max_index = shape_size(std::begin(data_shape) + axis, std::end(data_shape));
const auto indices_number = shape_size(out_shape);
for (size_t i = 0; i < indices_number; ++i) {
indices[i] %= max_index;
}
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -51,6 +51,7 @@
#include <ngraph/runtime/reference/lrn.hpp>
#include <ngraph/runtime/reference/lstm_cell.hpp>
#include <ngraph/runtime/reference/matrix_nms.hpp>
#include <ngraph/runtime/reference/max_pool.hpp>
#include <ngraph/runtime/reference/mod.hpp>
#include <ngraph/runtime/reference/multiclass_nms.hpp>
#include <ngraph/runtime/reference/mvn.hpp>
@ -2945,6 +2946,47 @@ namespace
return true;
}
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v8::MaxPool>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs)
{
using T = typename element_type_traits<ET>::value_type;
if (op->get_index_element_type() == element::i32)
{
runtime::reference::max_pool(inputs[0]->get_data_ptr<const T>(),
outputs[0]->get_data_ptr<T>(),
outputs[1]->get_data_ptr<int32_t>(),
inputs[0]->get_shape(),
outputs[0]->get_shape(),
op->get_kernel(),
op->get_strides(),
op->get_dilations(),
op->get_pads_begin(),
op->get_pads_end(),
op->get_axis());
}
else if (op->get_index_element_type() == element::i64)
{
runtime::reference::max_pool(inputs[0]->get_data_ptr<const T>(),
outputs[0]->get_data_ptr<T>(),
outputs[1]->get_data_ptr<int64_t>(),
inputs[0]->get_shape(),
outputs[0]->get_shape(),
op->get_kernel(),
op->get_strides(),
op->get_dilations(),
op->get_pads_begin(),
op->get_pads_end(),
op->get_axis());
}
else
{
return false;
}
return true;
}
template <typename T>
bool evaluate_node(std::shared_ptr<Node> node,
const HostTensorVector& outputs,
@ -2959,18 +3001,7 @@ namespace
{
element_type = node->get_input_element_type(0);
}
for (size_t i = 1; i < node->outputs().size(); i++)
{
if ((ov::is_type<op::v5::NonMaxSuppression>(node) ||
ov::is_type<op::v8::MulticlassNms>(node) ||
ov::is_type<op::v8::MatrixNms>(node) ||
ov::is_type<op::v6::ExperimentalDetectronDetectionOutput>(node) ||
ov::is_type<op::v8::AdaptiveMaxPool>(node)) &&
i == 1)
{
continue;
}
}
switch (element_type)
{
case element::Type_t::boolean:

View File

@ -101,4 +101,5 @@ NGRAPH_OP(AdaptiveAvgPool, ngraph::op::v8)
NGRAPH_OP(AdaptiveMaxPool, ngraph::op::v8)
NGRAPH_OP(Gather, op::v8)
NGRAPH_OP(MatrixNms, op::v8)
NGRAPH_OP(MaxPool, op::v8)
NGRAPH_OP(MulticlassNms, op::v8)