ONNX MaxPool (opset 8+) (#7100)

Tomasz Dołbniak 2021-11-11 12:48:32 +01:00 committed by GitHub
parent 9d42aa22b6
commit 0e7eef3c88
23 changed files with 725 additions and 161 deletions

View File

@@ -142,12 +142,7 @@ void validate_max_pool_kernel_params(const size_t dims,
/// but at the same time it can represent pixel offsets in the filter itself (dilated or non-dilated)
template <typename T>
struct Coord : public std::vector<T> {
Coord(const Shape& pads_begin) {
std::vector<T>::reserve(pads_begin.size());
for (const auto axis_padding : pads_begin) {
std::vector<T>::push_back(0 - axis_padding);
}
}
Coord() = default;
Coord(std::initializer_list<T>&& values) : std::vector<T>{std::move(values)} {}
};
@@ -165,47 +160,15 @@ bool elem_in_padding_area(const Coord<int>& kernel_position,
return false;
}
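Only the tail of elem_in_padding_area survives in this hunk. A minimal sketch of the bounds test it presumably performs, inferred from how the helper is called below (the body here is an assumption, not the committed code; spatial dimensions start at index 2 of data_shape):

bool elem_in_padding_area(const Coord<int>& kernel_position,
                          const Coord<size_t>& kernel_offset,
                          const Shape& data_shape) {
    // assumption: an element lies in the padding area when the kernel's
    // top-left corner plus the in-kernel offset falls outside the spatial
    // extent of the data tensor (layout N, C, spatial...)
    for (size_t i = 0; i < kernel_position.size(); ++i) {
        const int coord = kernel_position[i] + static_cast<int>(kernel_offset[i]);
        if (coord < 0 || coord >= static_cast<int>(data_shape[i + 2])) {
            return true;
        }
    }
    return false;
}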
template <typename T>
Coord<T> next_kernel_position_2D(Coord<T> kernel_position,
const Shape& kernel,
const Strides& kernel_strides,
const Strides& kernel_dilations,
const Shape& data_shape,
const Shape& pads_begin,
const Shape& pads_end) {
// move the kernel horizontally one stride to the right
kernel_position[1] += kernel_strides[1];
// if the top-right corner of the kernel is outside of the padding area,
// move it back to the left and one stride down
if (kernel_position[1] + (kernel[1] - 1) * kernel_dilations[1] >= data_shape[3] + pads_end[1]) {
kernel_position[1] = 0 - pads_begin[1];
kernel_position[0] += kernel_strides[0];
Coord<int> calculate_kernel_position(const Coord<size_t>& out_elem_coord,
const Strides& kernel_strides,
const Shape& pads_begin) {
Coord<int> top_left_corner;
top_left_corner.reserve(out_elem_coord.size());
for (size_t i = 0u; i < out_elem_coord.size(); ++i) {
top_left_corner.emplace_back(out_elem_coord[i] * kernel_strides[i] - pads_begin[i]);
}
return kernel_position;
}
template <typename T>
Coord<T> next_kernel_position_3D(Coord<T> kernel_position,
const Shape& kernel,
const Strides& kernel_strides,
const Strides& kernel_dilations,
const Shape& data_shape,
const Shape& pads_begin,
const Shape& pads_end) {
kernel_position[2] += kernel_strides[2];
if (kernel_position[2] + (kernel[2] - 1) * kernel_dilations[2] >= data_shape[4] + pads_end[2]) {
kernel_position[2] = 0 - pads_begin[2];
kernel_position[1] += kernel_strides[1];
if (kernel_position[1] + (kernel[1] - 1) * kernel_dilations[1] >= data_shape[3] + pads_end[1]) {
kernel_position[1] = 0 - pads_begin[1];
kernel_position[0] += kernel_strides[0];
}
}
return kernel_position;
return top_left_corner;
}
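The new helper replaces the two stepping functions above: rather than walking the kernel across the data one stride at a time, the top-left corner is computed directly from the output element coordinate as out_coord * stride - pad_begin per axis. A quick sanity check using the kernel/strides/pads of the max_pool_4d_strides test model further down (kernel 3x3, strides {3, 3}, pads {2, 2, 2, 2}):

const Strides kernel_strides{3, 3};
const Shape pads_begin{2, 2};
// output element (0, 0) -> top-left corner (0*3 - 2, 0*3 - 2) = (-2, -2),
// i.e. the kernel starts inside the padding area (negative coordinates)
const auto c00 = calculate_kernel_position({0, 0}, kernel_strides, pads_begin);
// output element (1, 1) -> top-left corner (1*3 - 2, 1*3 - 2) = (1, 1)
const auto c11 = calculate_kernel_position({1, 1}, kernel_strides, pads_begin);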
namespace kernel {
@@ -255,43 +218,44 @@ void max_pool_2d(const Values_t* data,
const size_t indices_offset) {
validate_max_pool_kernel_params(2, kernel, kernel_strides, kernel_dilations, pads_begin, pads_end);
Coord<int> kernel_position{pads_begin};
// helper constants (axes) denoting dimensions in the input data shape and kernel shape
constexpr size_t data_H = 2, data_W = 3;
constexpr size_t kernel_H = 0, kernel_W = 1;
// select max elem and its index for each "placeholder" in the out buffer (pointed to by out_idx)
for (size_t out_idx = 0; out_idx < out_shape[2] * out_shape[3]; ++out_idx) {
Values_t max_elem = std::numeric_limits<Values_t>::lowest();
Indices_t max_elem_idx = Indices_t{0};
size_t out_idx = 0u;
for (size_t out_row = 0u; out_row < out_shape[data_H]; ++out_row) {
for (size_t out_col = 0u; out_col < out_shape[data_W]; ++out_col) {
Values_t max_elem = std::numeric_limits<Values_t>::lowest();
Indices_t max_elem_idx = Indices_t{0};
// find the max element in the area covered by a current position of the kernel
for (size_t kernel_row = 0; kernel_row < kernel[0]; ++kernel_row) {
for (size_t kernel_col = 0; kernel_col < kernel[1]; ++kernel_col) {
// offset from the top-left corner of the kernel for a given row and col
const Coord<size_t> kernel_offset{kernel_row * kernel_dilations[0], kernel_col * kernel_dilations[1]};
const auto kernel_position = calculate_kernel_position({out_row, out_col}, kernel_strides, pads_begin);
// find the max element in the area covered by a current position of the kernel
for (size_t kernel_row = 0; kernel_row < kernel[kernel_H]; ++kernel_row) {
for (size_t kernel_col = 0; kernel_col < kernel[kernel_W]; ++kernel_col) {
// offset from the top-left corner of the kernel for a given row and col
const Coord<size_t> kernel_offset{kernel_row * kernel_dilations[kernel_H],
kernel_col * kernel_dilations[kernel_W]};
// ignore the elements in the padding area
if (!elem_in_padding_area(kernel_position, kernel_offset, data_shape)) {
// index of the flattened tensor element under the current row & column of the kernel
const size_t data_elem_index =
data_shape[2] * (kernel_offset[0] + kernel_position[0]) + kernel_offset[1] + kernel_position[1];
// ignore the elements in the padding area
if (!elem_in_padding_area(kernel_position, kernel_offset, data_shape)) {
// index of the flattened tensor element under the current row & column of the kernel
const size_t data_elem_index =
data_shape[data_W] * (kernel_offset[kernel_H] + kernel_position[kernel_H]) +
kernel_offset[kernel_W] + kernel_position[kernel_W];
if (data[data_elem_index] > max_elem) {
max_elem = data[data_elem_index];
max_elem_idx = data_elem_index;
if (data[data_elem_index] > max_elem) {
max_elem = data[data_elem_index];
max_elem_idx = data_elem_index;
}
}
}
}
values[out_idx] = max_elem;
indices[out_idx] = max_elem_idx + indices_offset;
++out_idx;
}
values[out_idx] = max_elem;
indices[out_idx] = max_elem_idx + indices_offset;
kernel_position = next_kernel_position_2D(kernel_position,
kernel,
kernel_strides,
kernel_dilations,
data_shape,
pads_begin,
pads_end);
}
}
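To see the flattened-index arithmetic (data_shape[data_W] * row + col) line up with a concrete case, take the max_pool_4d_dilations test model added later in this commit (4x4 input, 2x2 kernel, dilations {2, 2}):

// data (row-major 4x4):
//    9 10 11 12
//    1  2  3  4
//   16 14 15 13
//    5  6  8  7
// output element (0, 0): kernel_position = (0, 0)
//   covered elements (dilation 2): (0,0), (0,2), (2,0), (2,2)
//   flat indices:  4*0+0 = 0, 4*0+2 = 2, 4*2+0 = 8, 4*2+2 = 10
//   values:        9, 11, 16, 15  ->  max = 16 at flat index 8
// which matches the first expected value (16) and index (8) of the
// onnx_model_max_pool_4d_dilations backend test below.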
@@ -309,49 +273,51 @@ void max_pool_3d(const Values_t* data,
const size_t indices_offset) {
validate_max_pool_kernel_params(3, kernel, kernel_strides, kernel_dilations, pads_begin, pads_end);
Coord<int> kernel_position{pads_begin};
const size_t out_elems = shape_size(std::begin(out_shape) + 2, std::end(out_shape));
// helper constants (axes) denoting dimensions in the input data shape and kernel shape
constexpr size_t data_D = 2, data_H = 3, data_W = 4;
constexpr size_t kernel_D = 0, kernel_H = 1, kernel_W = 2;
// select max elem and its index for each "placeholder" in the out buffer (pointed to by out_idx)
for (size_t out_idx = 0; out_idx < out_elems; ++out_idx) {
Values_t max_elem = std::numeric_limits<Values_t>::lowest();
Indices_t max_elem_idx = Indices_t{0};
size_t out_idx = 0u;
for (size_t out_channel = 0u; out_channel < out_shape[data_D]; ++out_channel) {
for (size_t out_row = 0u; out_row < out_shape[data_H]; ++out_row) {
for (size_t out_col = 0u; out_col < out_shape[data_W]; ++out_col) {
Values_t max_elem = std::numeric_limits<Values_t>::lowest();
Indices_t max_elem_idx = Indices_t{0};
for (size_t kernel_channel = 0; kernel_channel < kernel[0]; ++kernel_channel) {
for (size_t kernel_row = 0; kernel_row < kernel[1]; ++kernel_row) {
for (size_t kernel_col = 0; kernel_col < kernel[2]; ++kernel_col) {
// offset from the top-left corner of the kernel for a given row and col
const Coord<size_t> kernel_offset{kernel_channel * kernel_dilations[0],
kernel_row * kernel_dilations[1],
kernel_col * kernel_dilations[2]};
const auto kernel_position =
calculate_kernel_position({out_channel, out_row, out_col}, kernel_strides, pads_begin);
// ignore the elements in the padding area
if (!elem_in_padding_area(kernel_position, kernel_offset, data_shape)) {
// index of the flattened tensor element under the current row & column of the kernel
const size_t data_elem_index =
data_shape[2] * data_shape[3] * (kernel_offset[0] + kernel_position[0]) +
data_shape[3] * (kernel_offset[1] + kernel_position[1]) + kernel_offset[2] +
kernel_position[2];
for (size_t kernel_channel = 0; kernel_channel < kernel[kernel_D]; ++kernel_channel) {
for (size_t kernel_row = 0; kernel_row < kernel[kernel_H]; ++kernel_row) {
for (size_t kernel_col = 0; kernel_col < kernel[kernel_W]; ++kernel_col) {
// offset from the top-left corner of the kernel for a given row and col
const Coord<size_t> kernel_offset{kernel_channel * kernel_dilations[kernel_D],
kernel_row * kernel_dilations[kernel_H],
kernel_col * kernel_dilations[kernel_W]};
if (data[data_elem_index] > max_elem) {
max_elem = data[data_elem_index];
max_elem_idx = data_elem_index;
// ignore the elements in the padding area
if (!elem_in_padding_area(kernel_position, kernel_offset, data_shape)) {
// index of the flattened tensor element under the current row & column of the kernel
const size_t data_elem_index =
data_shape[data_H] * data_shape[data_W] *
(kernel_offset[kernel_D] + kernel_position[kernel_D]) +
data_shape[data_W] * (kernel_offset[kernel_H] + kernel_position[kernel_H]) +
kernel_offset[kernel_W] + kernel_position[kernel_W];
if (data[data_elem_index] > max_elem) {
max_elem = data[data_elem_index];
max_elem_idx = data_elem_index;
}
}
}
}
}
values[out_idx] = max_elem;
indices[out_idx] = max_elem_idx + indices_offset;
++out_idx;
}
}
values[out_idx] = max_elem;
indices[out_idx] = max_elem_idx + indices_offset;
kernel_position = next_kernel_position_3D(kernel_position,
kernel,
kernel_strides,
kernel_dilations,
data_shape,
pads_begin,
pads_end);
}
}
} // namespace kernel
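The 3D kernel generalizes the same indexing: with data_shape = {N, C, D, H, W}, the slice-local flat index of the element at spatial position (d, h, w) is

// flat = data_shape[3] * data_shape[4] * d + data_shape[4] * h + w

and indices_offset (presumably the flat offset of the current batch/channel slice, i.e. (n * C + c) * D * H * W for batch n and channel c) rebases that slice-local index into the whole tensor.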

View File

@@ -220,8 +220,10 @@ bool evaluate_maxpool(const HostTensorPtr& data,
switch (indices->get_element_type()) {
case element::Type_t::i32: {
switch (data->get_element_type()) {
EVAL_MAX_POOL_8(i8, i32);
EVAL_MAX_POOL_8(i32, i32);
EVAL_MAX_POOL_8(i64, i32);
EVAL_MAX_POOL_8(u8, i32);
EVAL_MAX_POOL_8(u32, i32);
EVAL_MAX_POOL_8(u64, i32);
EVAL_MAX_POOL_8(f16, i32);
@@ -233,8 +235,10 @@ bool evaluate_maxpool(const HostTensorPtr& data,
} break;
case element::Type_t::i64: {
switch (data->get_element_type()) {
EVAL_MAX_POOL_8(i8, i64);
EVAL_MAX_POOL_8(i32, i64);
EVAL_MAX_POOL_8(i64, i64);
EVAL_MAX_POOL_8(u8, i64);
EVAL_MAX_POOL_8(u32, i64);
EVAL_MAX_POOL_8(u64, i64);
EVAL_MAX_POOL_8(f16, i64);
@@ -319,8 +323,10 @@ shared_ptr<Node> op::v8::MaxPool::clone_with_new_inputs(const OutputVector& new_
bool op::v8::MaxPool::has_evaluate() const {
NGRAPH_OP_SCOPE(v8_MaxPool_has_evaluate);
switch (get_input_element_type(0)) {
case ngraph::element::i8:
case ngraph::element::i32:
case ngraph::element::i64:
case ngraph::element::u8:
case ngraph::element::u32:
case ngraph::element::u64:
case ngraph::element::f16:

View File

@@ -10,6 +10,7 @@
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/runtime/reference/convert.hpp>
#include <vector>
@@ -30,6 +31,7 @@ bool fuse_type_to_nms3(const std::shared_ptr<ngraph::Node>& node, ngraph::elemen
bool fuse_type_to_nms4(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_nms5(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_topk(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_maxpool(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_nonzero(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_bucketize(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_ctc_greedy_decoder_seq_len(const std::shared_ptr<ngraph::Node>& node,
@@ -253,6 +255,7 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
{opset5::NonMaxSuppression::get_type_info_static(), fuse_type_to_nms5},
{opset6::CTCGreedyDecoderSeqLen::get_type_info_static(), fuse_type_to_ctc_greedy_decoder_seq_len},
{opset4::TopK::get_type_info_static(), fuse_type_to_topk},
{opset8::MaxPool::get_type_info_static(), fuse_type_to_maxpool},
{opset4::NonZero::get_type_info_static(), fuse_type_to_nonzero},
{opset4::Bucketize::get_type_info_static(), fuse_type_to_bucketize},
{opset4::Equal::get_type_info_static(), fuse_type_to_binary_comparision<opset4::Equal>},
@@ -392,6 +395,16 @@ bool fuse_type_to_topk(const std::shared_ptr<ngraph::Node>& node, ngraph::elemen
return false;
}
bool fuse_type_to_maxpool(const std::shared_ptr<ngraph::Node>& node, ngraph::element::Type to, size_t idx) {
if (auto maxpool = ov::as_type_ptr<opset8::MaxPool>(node)) {
if (idx == 1 && (to == element::i32 || to == element::i64)) {
maxpool->set_index_element_type(to);
return true;
}
}
return false;
}
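fuse_type_to_maxpool only ever touches output 1 (the indices) and only for the two index types v8 MaxPool supports. A comment-only sketch of its effect, given some v8 MaxPool node mp (a hypothetical variable):

// fuse_type_to_maxpool(mp, element::i32, 1);  -> true,  indices output becomes i32
// fuse_type_to_maxpool(mp, element::i64, 1);  -> true,  indices output becomes i64
// fuse_type_to_maxpool(mp, element::i32, 0);  -> false, idx 0 is the values output
// fuse_type_to_maxpool(mp, element::u8, 1);   -> false, not a supported index type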
bool fuse_type_to_ctc_greedy_decoder_seq_len(const std::shared_ptr<ngraph::Node>& node,
ngraph::element::Type to,
size_t idx) {

View File

@@ -25,8 +25,14 @@ OutputVector max_pool(const Node& node) {
} // namespace set_1
namespace set_8 {
OutputVector max_pool(const Node& node) {
return pooling::PoolingFactory(node).make_max_pool_with_indices();
}
} // namespace set_8
} // namespace op
} // namespace onnx_import
} // namespace ngraph

View File

@@ -23,6 +23,19 @@ OutputVector max_pool(const Node& node);
} // namespace set_1
namespace set_8 {
///
/// \brief Convert ONNX MaxPool operation to an nGraph node.
///
/// \param node The ONNX node object representing this operation.
///
/// \return The vector containing nGraph nodes producing the output of
///         the ONNX MaxPool operation.
///
OutputVector max_pool(const Node& node);
} // namespace set_8
} // namespace op
} // namespace onnx_import

View File

@@ -361,6 +361,7 @@ OperatorsBridge::OperatorsBridge() {
REGISTER_OPERATOR("MatMulInteger", 1, matmul_integer);
REGISTER_OPERATOR("MatMul", 1, matmul);
REGISTER_OPERATOR("MaxPool", 1, max_pool);
REGISTER_OPERATOR("MaxPool", 8, max_pool);
REGISTER_OPERATOR("Max", 1, max);
REGISTER_OPERATOR("Max", 8, max);
REGISTER_OPERATOR("Mean", 1, mean);

View File

@@ -14,6 +14,27 @@
namespace ngraph {
namespace onnx_import {
namespace pooling {
namespace {
std::shared_ptr<default_opset::Constant> transposition_axis_order(const Rank& input_rank) {
NGRAPH_CHECK(input_rank.is_static(),
"Generating column-major MaxPool results is supported only for inputs with static rank.");
const auto rank = static_cast<size_t>(input_rank.get_length());
std::vector<int32_t> axes(rank);
std::iota(axes.begin(), axes.end(), 0);
std::reverse(axes.begin() + 2, axes.end());
return std::make_shared<default_opset::Constant>(element::i32, Shape{rank}, axes);
}
std::shared_ptr<ngraph::Node> identity(Output<ngraph::Node> node_output) {
const auto zero = default_opset::Constant::create(node_output.get_element_type(), {}, {0});
return std::make_shared<default_opset::Add>(node_output, zero);
}
} // namespace
PoolingFactory::PoolingFactory(const Node& node)
: m_onnx_node{node},
m_inputs{node.get_ng_inputs()},
@@ -27,6 +48,7 @@ PoolingFactory::PoolingFactory(const Node& node)
const CoordinateDiff& padding_below{paddings.first};
m_padding_below = Shape{std::begin(padding_below), std::end(padding_below)};
m_padding_above = Shape{std::begin(padding_above), std::end(padding_above)};
m_storage_order = static_cast<StorageOrder>(node.get_attribute_value<int64_t>("storage_order", 0));
}
OutputVector PoolingFactory::make_avg_pool() const {
@@ -50,6 +72,26 @@ OutputVector PoolingFactory::make_max_pool() const {
m_rounding_type,
m_auto_pad)};
}
OutputVector PoolingFactory::make_max_pool_with_indices() const {
const auto max_pool = std::make_shared<op::v8::MaxPool>(m_inputs.at(0),
m_strides,
m_dilations,
m_padding_below,
m_padding_above,
m_kernel_shape,
m_rounding_type,
m_auto_pad);
if (m_storage_order == StorageOrder::COLUMN_MAJOR) {
const auto transposition_axes = transposition_axis_order(m_inputs.at(0).get_partial_shape().rank());
const auto transposed_indices =
std::make_shared<default_opset::Transpose>(max_pool->output(1), transposition_axes);
return {max_pool->output(0), transposed_indices};
} else {
return {identity(max_pool->output(0)), identity(max_pool->output(1))};
}
}
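When storage_order == COLUMN_MAJOR (ONNX attribute storage_order = 1), the indices output is transposed over the spatial axes. For a rank-4 input, transposition_axis_order produces:

// iota over rank 4 -> {0, 1, 2, 3}; reverse from begin() + 2 -> {0, 1, 3, 2}
// batch and channel axes stay put, H and W are swapped, so each spatial
// slice of the indices output is read in column-major order.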
} // namespace pooling
} // namespace onnx_import
} // namespace ngraph

View File

@@ -46,6 +46,9 @@ public:
///
OutputVector make_max_pool() const;
/// \brief Creates max pooling ONNX operation with 2 outputs (values and indices).
OutputVector make_max_pool_with_indices() const;
protected:
Node m_onnx_node;
const OutputVector m_inputs;
@@ -56,6 +59,10 @@ protected:
Shape m_padding_above;
ngraph::op::PadType m_auto_pad;
ngraph::op::RoundingType m_rounding_type;
enum class StorageOrder : int64_t { ROW_MAJOR = 0, COLUMN_MAJOR = 1 };
StorageOrder m_storage_order;
};
} // namespace pooling
} // namespace onnx_import

View File

@@ -0,0 +1,76 @@
ir_version: 3
producer_name: "backend-test"
graph {
node {
input: "x"
output: "y"
output: "z"
op_type: "MaxPool"
attribute {
name: "kernel_shape"
ints: 2
type: INTS
}
}
name: "maxpool_test"
input {
name: "x"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "z"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}
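A reading note for this and the following test protos: elem_type values follow onnx.TensorProto.DataType, so 1 = FLOAT, 3 = INT8, 6 = INT32 and 7 = INT64; every "z" (indices) output uses elem_type 7, matching the int64_t expectations in the backend tests added below.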

View File

@@ -0,0 +1,91 @@
ir_version: 3
producer_name: "backend-test"
graph {
node {
input: "x"
output: "y"
output: "z"
op_type: "MaxPool"
attribute {
name: "kernel_shape"
ints: 3
ints: 3
type: INTS
}
attribute {
name: "ceil_mode"
i: 1
type: INT
}
}
name: "maxpool_test"
input {
name: "x"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "z"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@@ -0,0 +1,97 @@
ir_version: 4
producer_name: "backend-test"
graph {
node {
input: "x"
output: "y"
output: "z"
op_type: "MaxPool"
attribute {
name: "ceil_mode"
i: 1
type: INT
}
attribute {
name: "kernel_shape"
ints: 3
ints: 3
type: INTS
}
attribute {
name: "strides"
ints: 2
ints: 2
type: INTS
}
}
name: "test_maxpool_2d_ceil"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "z"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 10
}

View File

@@ -0,0 +1,92 @@
ir_version: 3
producer_name: "backend-test"
graph {
node {
input: "x"
output: "y"
output: "z"
op_type: "MaxPool"
attribute {
name: "kernel_shape"
ints: 2
ints: 2
type: INTS
}
attribute {
name: "dilations"
ints: 2
ints: 2
type: INTS
}
}
name: "maxpool_test"
input {
name: "x"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 4
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "z"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@@ -0,0 +1,100 @@
ir_version: 3
producer_name: "backend-test"
graph {
node {
input: "x"
output: "y"
output: "z"
op_type: "MaxPool"
attribute {
name: "kernel_shape"
ints: 3
ints: 3
type: INTS
}
attribute {
name: "strides"
ints: 3
ints: 3
type: INTS
}
attribute {
name: "pads"
ints: 2
ints: 2
ints: 2
ints: 2
type: INTS
}
}
name: "maxpool_test"
input {
name: "x"
type {
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 5
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "z"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@@ -22,14 +22,16 @@
#endif
// clang-format on
#include "onnx_import/core/null_node.hpp"
#include <cpp/ie_cnn_network.h>
#include "default_opset.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/manager.hpp"
#include "onnx_import/core/null_node.hpp"
#include "onnx_import/onnx.hpp"
#include "onnx_import/onnx_utils.hpp"
#include "default_opset.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
@@ -38,7 +40,6 @@
#include "engines_util/test_engines.hpp"
#include "util/test_tools.hpp"
#include "util/type_prop.hpp"
#include <cpp/ie_cnn_network.h>
NGRAPH_SUPPRESS_DEPRECATED_START
@@ -4165,6 +4166,73 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_float16_tensor_as_int32) {
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_max_pool_3d) {
const auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_3d.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<int32_t>(Shape{1, 3, 3}, {-1, 0, 1, 20, -20, 10, 0, 2, 1});
test_case.add_expected_output<int32_t>(Shape{1, 3, 2}, {0, 1, 20, 10, 2, 2});
test_case.add_expected_output<int64_t>(Shape{1, 3, 2}, {1, 2, 3, 5, 7, 7});
test_case.run();
}
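The expected outputs can be verified by hand: the 1x3x3 input holds three length-3 rows pooled with a size-2 window, and the indices are flat offsets into the whole tensor:

// row 0 {-1, 0, 1}:    max(-1, 0)  = 0  at flat 1;  max(0, 1)   = 1  at flat 2
// row 1 {20, -20, 10}: max(20,-20) = 20 at flat 3;  max(-20,10) = 10 at flat 5
// row 2 {0, 2, 1}:     max(0, 2)   = 2  at flat 7;  max(2, 1)   = 2  at flat 7
// -> values {0, 1, 20, 10, 2, 2}, indices {1, 2, 3, 5, 7, 7}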
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_max_pool_4d_ceil_mode) {
const auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_4d_ceil_mode.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<int32_t>(Shape{1, 1, 4, 4}, gen_range<int32_t>(16, 1));
test_case.add_expected_output<int32_t>(Shape{1, 1, 2, 2}, {11, 12, 15, 16});
test_case.add_expected_output<int64_t>(Shape{1, 1, 2, 2}, {10, 11, 14, 15});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_max_pool_4d_dilations) {
const auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_4d_dilations.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<int32_t>(Shape{1, 1, 4, 4}, {9, 10, 11, 12, 1, 2, 3, 4, 16, 14, 15, 13, 5, 6, 8, 7});
test_case.add_expected_output<int32_t>(Shape{1, 1, 2, 2}, {16, 14, 8, 7});
test_case.add_expected_output<int64_t>(Shape{1, 1, 2, 2}, {8, 9, 14, 15});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_max_pool_4d_strides) {
// kernel: 3x3
// strides: 3, 3
// explicit pads: 2, 2, 2, 2
const auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_4d_strides.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<int8_t>(Shape{1, 1, 5, 5}, gen_range<int8_t>(25, 1));
test_case.add_expected_output<int8_t>(Shape{1, 1, 3, 3}, {1, 4, 5, 16, 19, 20, 21, 24, 25});
test_case.add_expected_output<int64_t>(Shape{1, 1, 3, 3}, {0, 3, 4, 15, 18, 19, 20, 23, 24});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_max_pool_4d_ceil_strides) {
// kernel: 3x3
// strides: 2, 2
// ceil_mode: 1
const auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_4d_ceil_strides.onnx"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(
Shape{1, 1, 4, 4},
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f});
test_case.add_expected_output<float>(Shape{1, 1, 2, 2}, {11.0f, 12.0f, 15.0f, 16.0f});
test_case.add_expected_output<int64_t>(Shape{1, 1, 2, 2}, {10, 11, 14, 15});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_random_uniform) {
const auto function =
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/random_uniform.onnx"));

View File

@@ -15,9 +15,9 @@
#include <iterator>
#include <numeric>
#include "default_opset.hpp"
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "default_opset.hpp"
#include "onnx_import/onnx.hpp"
#include "engines_util/test_engines.hpp"
#include "engines_util/test_case.hpp"
@@ -333,10 +333,9 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_dyn_shapes_max_pool_with_indices_output) {
25.f, 25.f, 23.f, 24.f, 25.f, 25.f, 25.f, 23.f, 24.f, 25.f, 25.f, 25.f};
test_case.add_expected_output<float>(Shape{1, 1, 5, 5}, expected_values);
// indices output is not supported and is ignored in the current implementation
// std::vector<int64_t> expected_indices{12, 13, 14, 14, 14, 17, 18, 19, 19, 19, 22, 23, 24, 24,
// 24, 22, 23, 24, 24, 24, 22, 23, 24, 24, 24};
// test_case.add_expected_output<float>(Shape{1, 1, 5, 5}, expected_indices);
std::vector<int64_t> expected_indices{12, 13, 14, 14, 14, 17, 18, 19, 19, 19, 22, 23, 24,
24, 24, 22, 23, 24, 24, 24, 22, 23, 24, 24, 24};
test_case.add_expected_output<int64_t>(Shape{1, 1, 5, 5}, expected_indices);
test_case.run();
}

View File

@@ -371,14 +371,12 @@ tile_3d_few_repeats
# Result mismatch
sum_large_1d_to_scalar
sum_stable_acc
max_pool_3d
avg_pool_2d_2channel_2image_padded_only_above_include_in_computation
avg_pool_3d_uneven_strided_padded
multiple_result
lrn_across_all_dims
elu
elu_negative_alpha
max_pool_2d_1channel_1image_overpadded
grn_2d_with_bias
erf
divide_adjoint_stability
@@ -386,8 +384,6 @@ notequal
less
sum_3d_to_scalar_int32
sum_2d_to_scalar_int8
max_pool_uint8
max_pool_int8
avg_pool_uint8
avg_pool_int8
max_to_scalar_int8
@@ -432,6 +428,12 @@ onnx_constant_integer_array
adaptive_max_pool_1d
adaptive_max_pool_2d
adaptive_max_pool_3d
onnx_dyn_shapes_max_pool_with_indices_output
onnx_model_max_pool_3d
onnx_model_max_pool_4d_ceil_mode
onnx_model_max_pool_4d_dilations
onnx_model_max_pool_4d_strides
onnx_model_max_pool_4d_ceil_strides
# Unsupported primitive of type: SigmoidBackprop
sigmoid_bprop_n1c1h4
@@ -554,9 +556,6 @@ product_to_scalar_int8
min_to_scalar_int8
# Pooling layer. Unsupported mode. Only 4D and 5D blobs are supported as input.
max_pool_1d_1channel_1image
max_pool_1d_1channel_2image
max_pool_1d_2channel_2image
avg_pool_1d_1channel_1image
avg_pool_1d_1channel_2image
avg_pool_1d_2channel_2image

View File

@@ -55,3 +55,12 @@ std::vector<T> read_binary_file(const std::string& path) {
inputs_fs.read(reinterpret_cast<char*>(file_content.data()), size);
return file_content;
}
template <typename T = int32_t>
std::vector<T> gen_range(const size_t elements, const T start = T{0}) {
std::vector<T> range;
range.resize(elements);
std::iota(range.begin(), range.end(), start);
return range;
}
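A usage note: gen_range fills a vector with consecutive values starting at start, as in the ceil_mode test above:

const auto input = gen_range<int32_t>(16, 1);  // {1, 2, ..., 16} - the 4x4 input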

View File

@@ -47,7 +47,6 @@ xfail_issue_33651 = xfail_test(reason="RuntimeError: nGraph does not support the
"TfIdfVectorizer")
xfail_issue_33581 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: "
"GatherElements")
xfail_issue_33633 = xfail_test(reason="MaxPool: dilations unsupported")
xfail_issue_35923 = xfail_test(reason="RuntimeError: PReLU without weights is not supported")
xfail_issue_35927 = xfail_test(reason="RuntimeError: B has zero dimension that is not allowable")
xfail_issue_36486 = xfail_test(reason="RuntimeError: HardSigmoid operation should be converted "
@@ -93,7 +92,6 @@ xfail_issue_44965 = xfail_test(reason="Expected: RuntimeError: value info has no
xfail_issue_44968 = xfail_test(reason="Expected: Unsupported dynamic op: Squeeze")
xfail_issue_47323 = xfail_test(reason="RuntimeError: The plugin does not support FP64")
xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot")
xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output")
xfail_issue_55760 = xfail_test(reason="RuntimeError: Reversed axis have axes above the source space shape")
# Model MSFT issues:
@@ -136,7 +134,6 @@ xfail_issue_63036 = xfail_test(reason="Changes in ConvTranspose padding")
xfail_issue_63039 = xfail_test(reason="Result mismatches with UINT8 operations")
xfail_issue_63043 = xfail_test(reason="Recurrent node expects constants as W, R, B inputs.")
xfail_issue_63044 = xfail_test(reason="ONNX opset 14 operation: Trilu")
xfail_issue_63045 = xfail_test(reason="Maxpool with strides, padding and dilations fail")
skip_rng_tests = pytest.mark.skip(reason="Tests use random number generator with no seed.")
xfail_issue_63136 = xfail_test(reason="Unsupported operation: CastLike")

View File

@@ -12,11 +12,9 @@ from tests import (
xfail_issue_33538,
xfail_issue_33581,
xfail_issue_33589,
xfail_issue_33593,
xfail_issue_33595,
xfail_issue_33596,
xfail_issue_33606,
xfail_issue_33633,
xfail_issue_33651,
xfail_issue_38091,
xfail_issue_38699,
@@ -49,7 +47,6 @@ from tests import (
xfail_issue_63039,
xfail_issue_63043,
xfail_issue_63044,
xfail_issue_63045,
xfail_issue_63136,
xfail_issue_63137,
xfail_issue_63138,
@@ -143,7 +140,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_scatter_elements_with_negative_indices_cpu",
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu",
),
(xfail_issue_33633, "OnnxBackendNodeModelTest.test_maxpool_2d_dilations_cpu"),
(
xfail_issue_55760,
"OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_example_select_last_index_cpu",
@@ -338,11 +334,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_squeeze_cpu",
"OnnxBackendNodeModelTest.test_squeeze_negative_axes_cpu",
),
(
xfail_issue_33593,
"OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_strides_cpu",
"OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_pads_cpu",
),
(xfail_issue_58033, "OnnxBackendNodeModelTest.test_einsum_batch_diagonal_cpu"),
(
xfail_issue_63033,
@@ -387,11 +378,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_triu_square_neg_cpu",
"OnnxBackendNodeModelTest.test_triu_zero_cpu",
),
(
xfail_issue_63045,
"OnnxBackendPyTorchConvertedModelTest.test_MaxPool1d_stride_padding_dilation_cpu",
"OnnxBackendPyTorchConvertedModelTest.test_MaxPool2d_stride_padding_dilation_cpu",
),
(
skip_rng_tests,
"OnnxBackendNodeModelTest.test_bernoulli_cpu",

View File

@@ -369,7 +369,7 @@ def test_pool_max(ndarray_1x1x4x4):
x = ndarray_1x1x4x4
y = np.array([[16, 18], [24, 26]], dtype=np.float32).reshape([1, 1, 2, 2])
ng_results = run_node(node, [x])
ng_results = run_node(node, [x], opset_version=7)
assert np.array_equal(ng_results, [y])

View File

@@ -77,6 +77,7 @@ tolerance_map = {
"resnet34-v2-7": {"atol": 0.001, "rtol": 0.001},
"vgg16-7": {"atol": 0.001, "rtol": 0.001},
"vgg19-bn-7": {"atol": 0.001, "rtol": 0.001},
"vgg19-7": {"atol": 0.001, "rtol": 0.001},
"tinyyolov2-7": {"atol": 0.001, "rtol": 0.001},
"tinyyolov2-8": {"atol": 0.001, "rtol": 0.001},
"candy-8": {"atol": 0.001, "rtol": 0.001},
@@ -115,6 +116,12 @@ tolerance_map = {
"test_retinanet_resnet101": {"atol": 1.3e-06},
}
def tolerance_map_key_in_model_path(path):
for key in tolerance_map:
if key in path:
return key
return None
zoo_models = []
# rglob doesn't work for symlinks, so models have to be physically somewhere inside "MODELS_ROOT_DIR"
for path in Path(MODELS_ROOT_DIR).rglob("*.onnx"):
@@ -127,6 +134,12 @@ for path in Path(MODELS_ROOT_DIR).rglob("*.onnx"):
# updated model looks now:
# {"model_name": path, "model_file": file, "dir": mdir, "atol": ..., "rtol": ...}
model.update(tolerance_map[basedir])
else:
# some models have the same stem, have to check if any of the keys from tolerance_map
# is found in the full model path
model_key = tolerance_map_key_in_model_path(str(path))
if model_key is not None:
model.update(tolerance_map[model_key])
if basedir in post_processing:
model.update(post_processing[basedir])
zoo_models.append(model)
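With this fallback, a model whose immediate directory stem is not a tolerance_map key but whose full path contains one (for example the "vgg19-7" entry added above) still picks up its atol/rtol settings.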

View File

@@ -46,7 +46,6 @@ xfail_issue_33651 = xfail_test(reason="RuntimeError: nGraph does not support the
"TfIdfVectorizer")
xfail_issue_33581 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: "
"GatherElements")
xfail_issue_33633 = xfail_test(reason="MaxPool: dilations unsupported")
xfail_issue_35923 = xfail_test(reason="RuntimeError: PReLU without weights is not supported")
xfail_issue_35927 = xfail_test(reason="RuntimeError: B has zero dimension that is not allowable")
xfail_issue_36486 = xfail_test(reason="RuntimeError: HardSigmoid operation should be converted "
@@ -99,7 +98,6 @@ xfail_issue_44965 = xfail_test(reason="Expected: RuntimeError: value info has no
xfail_issue_44968 = xfail_test(reason="Expected: Unsupported dynamic op: Squeeze")
xfail_issue_47323 = xfail_test(reason="RuntimeError: The plugin does not support FP64")
xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot")
xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output")
xfail_issue_55760 = xfail_test(reason="RuntimeError: Reversed axis have axes above the source space shape")
# Model MSFT issues:
@@ -142,7 +140,6 @@ xfail_issue_63036 = xfail_test(reason="Changes in ConvTranspose padding")
xfail_issue_63039 = xfail_test(reason="Result mismatches with UINT8 operations")
xfail_issue_63043 = xfail_test(reason="Recurrent node expects constants as W, R, B inputs.")
xfail_issue_63044 = xfail_test(reason="ONNX opset 14 operation: Trilu")
xfail_issue_63045 = xfail_test(reason="Maxpool with strides, padding and dilations fail")
skip_rng_tests = pytest.mark.skip(reason="Tests use random number generator with no seed.")
xfail_issue_63136 = xfail_test(reason="Unsupported operation: CastLike")

View File

@@ -11,11 +11,9 @@ from tests_compatibility import (
xfail_issue_33538,
xfail_issue_33581,
xfail_issue_33589,
xfail_issue_33593,
xfail_issue_33595,
xfail_issue_33596,
xfail_issue_33606,
xfail_issue_33633,
xfail_issue_33651,
xfail_issue_38091,
xfail_issue_38699,
@@ -48,7 +46,6 @@ from tests_compatibility import (
xfail_issue_63039,
xfail_issue_63043,
xfail_issue_63044,
xfail_issue_63045,
xfail_issue_63136,
xfail_issue_63137,
xfail_issue_63138,
@@ -132,7 +129,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_scatter_elements_with_negative_indices_cpu",
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu",
),
(xfail_issue_33633, "OnnxBackendNodeModelTest.test_maxpool_2d_dilations_cpu"),
(
xfail_issue_55760,
"OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_example_select_last_index_cpu",
@@ -327,11 +323,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_squeeze_cpu",
"OnnxBackendNodeModelTest.test_squeeze_negative_axes_cpu",
),
(
xfail_issue_33593,
"OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_strides_cpu",
"OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_pads_cpu",
),
(xfail_issue_58033, "OnnxBackendNodeModelTest.test_einsum_batch_diagonal_cpu"),
(
xfail_issue_63033,
@@ -376,11 +367,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_triu_square_neg_cpu",
"OnnxBackendNodeModelTest.test_triu_zero_cpu",
),
(
xfail_issue_63045,
"OnnxBackendPyTorchConvertedModelTest.test_MaxPool1d_stride_padding_dilation_cpu",
"OnnxBackendPyTorchConvertedModelTest.test_MaxPool2d_stride_padding_dilation_cpu",
),
(
skip_rng_tests,
"OnnxBackendNodeModelTest.test_bernoulli_cpu",