From 0e7eef3c886abf944886d9ae5f0e0387074d5fa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tomasz=20Do=C5=82bniak?= Date: Thu, 11 Nov 2021 12:48:32 +0100 Subject: [PATCH] ONNX MaxPool (opset 8+) (#7100) --- .../ngraph/runtime/reference/max_pool.hpp | 180 +++++++----------- ngraph/core/src/op/max_pool.cpp | 6 + ngraph/core/src/pass/convert_precision.cpp | 13 ++ .../onnx/frontend/src/op/max_pool.cpp | 8 +- .../onnx/frontend/src/op/max_pool.hpp | 13 ++ .../frontend/onnx/frontend/src/ops_bridge.cpp | 1 + .../frontend/src/utils/pooling_factory.cpp | 42 ++++ .../frontend/src/utils/pooling_factory.hpp | 7 + ngraph/test/models/onnx/max_pool_3d.prototxt | 76 ++++++++ .../onnx/max_pool_4d_ceil_mode.prototxt | 91 +++++++++ .../onnx/max_pool_4d_ceil_strides.prototxt | 97 ++++++++++ .../onnx/max_pool_4d_dilations.prototxt | 92 +++++++++ .../models/onnx/max_pool_4d_strides.prototxt | 100 ++++++++++ ngraph/test/onnx/onnx_import.in.cpp | 80 +++++++- .../test/onnx/onnx_import_dyn_shapes.in.cpp | 9 +- ngraph/test/runtime/ie/unit_test.manifest | 13 +- ngraph/test/util/test_tools.hpp | 9 + runtime/bindings/python/tests/__init__.py | 3 - .../python/tests/test_onnx/test_backend.py | 14 -- .../tests/test_onnx/test_ops_convpool.py | 2 +- .../python/tests/test_onnx/test_zoo_models.py | 13 ++ .../python/tests_compatibility/__init__.py | 3 - .../test_onnx/test_backend.py | 14 -- 23 files changed, 725 insertions(+), 161 deletions(-) create mode 100644 ngraph/test/models/onnx/max_pool_3d.prototxt create mode 100644 ngraph/test/models/onnx/max_pool_4d_ceil_mode.prototxt create mode 100644 ngraph/test/models/onnx/max_pool_4d_ceil_strides.prototxt create mode 100644 ngraph/test/models/onnx/max_pool_4d_dilations.prototxt create mode 100644 ngraph/test/models/onnx/max_pool_4d_strides.prototxt diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/max_pool.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/max_pool.hpp index 74c0261112b..72c1f2230b2 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/max_pool.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/max_pool.hpp @@ -142,12 +142,7 @@ void validate_max_pool_kernel_params(const size_t dims, /// but at the same time it can represent pixel offsets in the filter itself (dilated or non-dilated) template struct Coord : public std::vector { - Coord(const Shape& pads_begin) { - std::vector::reserve(pads_begin.size()); - for (const auto axis_padding : pads_begin) { - std::vector::push_back(0 - axis_padding); - } - } + Coord() = default; Coord(std::initializer_list&& values) : std::vector{std::move(values)} {} }; @@ -165,47 +160,15 @@ bool elem_in_padding_area(const Coord& kernel_position, return false; } -template -Coord next_kernel_position_2D(Coord kernel_position, - const Shape& kernel, - const Strides& kernel_strides, - const Strides& kernel_dilations, - const Shape& data_shape, - const Shape& pads_begin, - const Shape& pads_end) { - // move the kernel horizontally one stride to the right - kernel_position[1] += kernel_strides[1]; - - // if the top-right corner of the kernel is outside of the padding area, - // move it back to the left and one stride down - if (kernel_position[1] + (kernel[1] - 1) * kernel_dilations[1] >= data_shape[3] + pads_end[1]) { - kernel_position[1] = 0 - pads_begin[1]; - kernel_position[0] += kernel_strides[0]; +Coord calculate_kernel_position(const Coord& out_elem_coord, + const Strides& kernel_strides, + const Shape& pads_begin) { + Coord top_left_corner; + top_left_corner.reserve(out_elem_coord.size()); + for (size_t i = 0u; i < out_elem_coord.size(); ++i) { + top_left_corner.emplace_back(out_elem_coord[i] * kernel_strides[i] - pads_begin[i]); } - - return kernel_position; -} - -template -Coord next_kernel_position_3D(Coord kernel_position, - const Shape& kernel, - const Strides& kernel_strides, - const Strides& kernel_dilations, - const Shape& data_shape, - const Shape& pads_begin, - const Shape& pads_end) { - kernel_position[2] += kernel_strides[2]; - - if (kernel_position[2] + (kernel[2] - 1) * kernel_dilations[2] >= data_shape[4] + pads_end[2]) { - kernel_position[2] = 0 - pads_begin[2]; - kernel_position[1] += kernel_strides[1]; - if (kernel_position[1] + (kernel[1] - 1) * kernel_dilations[1] >= data_shape[3] + pads_end[1]) { - kernel_position[1] = 0 - pads_begin[1]; - kernel_position[0] += kernel_strides[0]; - } - } - - return kernel_position; + return top_left_corner; } namespace kernel { @@ -255,43 +218,44 @@ void max_pool_2d(const Values_t* data, const size_t indices_offset) { validate_max_pool_kernel_params(2, kernel, kernel_strides, kernel_dilations, pads_begin, pads_end); - Coord kernel_position{pads_begin}; + // helper constants(axes) denoting dimensions in the input data shape and kernel shape + constexpr size_t data_H = 2, data_W = 3; + constexpr size_t kernel_H = 0, kernel_W = 1; // select max elem and its index for each "placeholder" in the out buffer (pointed to by out_idx) - for (size_t out_idx = 0; out_idx < out_shape[2] * out_shape[3]; ++out_idx) { - Values_t max_elem = std::numeric_limits::lowest(); - Indices_t max_elem_idx = Indices_t{0}; + size_t out_idx = 0u; + for (size_t out_row = 0u; out_row < out_shape[data_H]; ++out_row) { + for (size_t out_col = 0u; out_col < out_shape[data_W]; ++out_col) { + Values_t max_elem = std::numeric_limits::lowest(); + Indices_t max_elem_idx = Indices_t{0}; - // find the max element in the area covered by a current position of the kernel - for (size_t kernel_row = 0; kernel_row < kernel[0]; ++kernel_row) { - for (size_t kernel_col = 0; kernel_col < kernel[1]; ++kernel_col) { - // offset from the top-left corner of the kernel for a given row and col - const Coord kernel_offset{kernel_row * kernel_dilations[0], kernel_col * kernel_dilations[1]}; + const auto kernel_position = calculate_kernel_position({out_row, out_col}, kernel_strides, pads_begin); + // find the max element in the area covered by a current position of the kernel + for (size_t kernel_row = 0; kernel_row < kernel[kernel_H]; ++kernel_row) { + for (size_t kernel_col = 0; kernel_col < kernel[kernel_W]; ++kernel_col) { + // offset from the top-left corner of the kernel for a given row and col + const Coord kernel_offset{kernel_row * kernel_dilations[kernel_H], + kernel_col * kernel_dilations[kernel_W]}; - // ignore the elements in the padding area - if (!elem_in_padding_area(kernel_position, kernel_offset, data_shape)) { - // index of the flattened tensor element under the current row & column of the kernel - const size_t data_elem_index = - data_shape[2] * (kernel_offset[0] + kernel_position[0]) + kernel_offset[1] + kernel_position[1]; + // ignore the elements in the padding area + if (!elem_in_padding_area(kernel_position, kernel_offset, data_shape)) { + // index of the flattened tensor element under the current row & column of the kernel + const size_t data_elem_index = + data_shape[data_W] * (kernel_offset[kernel_H] + kernel_position[kernel_H]) + + kernel_offset[kernel_W] + kernel_position[kernel_W]; - if (data[data_elem_index] > max_elem) { - max_elem = data[data_elem_index]; - max_elem_idx = data_elem_index; + if (data[data_elem_index] > max_elem) { + max_elem = data[data_elem_index]; + max_elem_idx = data_elem_index; + } } } } + + values[out_idx] = max_elem; + indices[out_idx] = max_elem_idx + indices_offset; + ++out_idx; } - - values[out_idx] = max_elem; - indices[out_idx] = max_elem_idx + indices_offset; - - kernel_position = next_kernel_position_2D(kernel_position, - kernel, - kernel_strides, - kernel_dilations, - data_shape, - pads_begin, - pads_end); } } @@ -309,49 +273,51 @@ void max_pool_3d(const Values_t* data, const size_t indices_offset) { validate_max_pool_kernel_params(3, kernel, kernel_strides, kernel_dilations, pads_begin, pads_end); - Coord kernel_position{pads_begin}; - - const size_t out_elems = shape_size(std::begin(out_shape) + 2, std::end(out_shape)); + // helper constants(axes) denoting dimensions in the input data shape and kernel shape + constexpr size_t data_D = 2, data_H = 3, data_W = 4; + constexpr size_t kernel_D = 0, kernel_H = 1, kernel_W = 2; // select max elem and its index for each "placeholder" in the out buffer (pointed to by out_idx) - for (size_t out_idx = 0; out_idx < out_elems; ++out_idx) { - Values_t max_elem = std::numeric_limits::lowest(); - Indices_t max_elem_idx = Indices_t{0}; + size_t out_idx = 0u; + for (size_t out_channel = 0u; out_channel < out_shape[data_D]; ++out_channel) { + for (size_t out_row = 0u; out_row < out_shape[data_H]; ++out_row) { + for (size_t out_col = 0u; out_col < out_shape[data_W]; ++out_col) { + Values_t max_elem = std::numeric_limits::lowest(); + Indices_t max_elem_idx = Indices_t{0}; - for (size_t kernel_channel = 0; kernel_channel < kernel[0]; ++kernel_channel) { - for (size_t kernel_row = 0; kernel_row < kernel[1]; ++kernel_row) { - for (size_t kernel_col = 0; kernel_col < kernel[2]; ++kernel_col) { - // offset from the top-left corner of the kernel for a given row and col - const Coord kernel_offset{kernel_channel * kernel_dilations[0], - kernel_row * kernel_dilations[1], - kernel_col * kernel_dilations[2]}; + const auto kernel_position = + calculate_kernel_position({out_channel, out_row, out_col}, kernel_strides, pads_begin); - // ignore the elements in the padding area - if (!elem_in_padding_area(kernel_position, kernel_offset, data_shape)) { - // index of the flattened tensor element under the current row & column of the kernel - const size_t data_elem_index = - data_shape[2] * data_shape[3] * (kernel_offset[0] + kernel_position[0]) + - data_shape[3] * (kernel_offset[1] + kernel_position[1]) + kernel_offset[2] + - kernel_position[2]; + for (size_t kernel_channel = 0; kernel_channel < kernel[kernel_D]; ++kernel_channel) { + for (size_t kernel_row = 0; kernel_row < kernel[kernel_H]; ++kernel_row) { + for (size_t kernel_col = 0; kernel_col < kernel[kernel_W]; ++kernel_col) { + // offset from the top-left corner of the kernel for a given row and col + const Coord kernel_offset{kernel_channel * kernel_dilations[kernel_D], + kernel_row * kernel_dilations[kernel_H], + kernel_col * kernel_dilations[kernel_W]}; - if (data[data_elem_index] > max_elem) { - max_elem = data[data_elem_index]; - max_elem_idx = data_elem_index; + // ignore the elements in the padding area + if (!elem_in_padding_area(kernel_position, kernel_offset, data_shape)) { + // index of the flattened tensor element under the current row & column of the kernel + const size_t data_elem_index = + data_shape[data_H] * data_shape[data_W] * + (kernel_offset[kernel_D] + kernel_position[kernel_D]) + + data_shape[data_W] * (kernel_offset[kernel_H] + kernel_position[kernel_H]) + + kernel_offset[kernel_W] + kernel_position[kernel_W]; + + if (data[data_elem_index] > max_elem) { + max_elem = data[data_elem_index]; + max_elem_idx = data_elem_index; + } + } } } } + values[out_idx] = max_elem; + indices[out_idx] = max_elem_idx + indices_offset; + ++out_idx; } } - values[out_idx] = max_elem; - indices[out_idx] = max_elem_idx + indices_offset; - - kernel_position = next_kernel_position_3D(kernel_position, - kernel, - kernel_strides, - kernel_dilations, - data_shape, - pads_begin, - pads_end); } } } // namespace kernel diff --git a/ngraph/core/src/op/max_pool.cpp b/ngraph/core/src/op/max_pool.cpp index afb78c6fe3b..ed9775c1d0c 100644 --- a/ngraph/core/src/op/max_pool.cpp +++ b/ngraph/core/src/op/max_pool.cpp @@ -220,8 +220,10 @@ bool evaluate_maxpool(const HostTensorPtr& data, switch (indices->get_element_type()) { case element::Type_t::i32: { switch (data->get_element_type()) { + EVAL_MAX_POOL_8(i8, i32); EVAL_MAX_POOL_8(i32, i32); EVAL_MAX_POOL_8(i64, i32); + EVAL_MAX_POOL_8(u8, i32); EVAL_MAX_POOL_8(u32, i32); EVAL_MAX_POOL_8(u64, i32); EVAL_MAX_POOL_8(f16, i32); @@ -233,8 +235,10 @@ bool evaluate_maxpool(const HostTensorPtr& data, } break; case element::Type_t::i64: { switch (data->get_element_type()) { + EVAL_MAX_POOL_8(i8, i64); EVAL_MAX_POOL_8(i32, i64); EVAL_MAX_POOL_8(i64, i64); + EVAL_MAX_POOL_8(u8, i64); EVAL_MAX_POOL_8(u32, i64); EVAL_MAX_POOL_8(u64, i64); EVAL_MAX_POOL_8(f16, i64); @@ -319,8 +323,10 @@ shared_ptr op::v8::MaxPool::clone_with_new_inputs(const OutputVector& new_ bool op::v8::MaxPool::has_evaluate() const { NGRAPH_OP_SCOPE(v8_MaxPool_has_evaluate); switch (get_input_element_type(0)) { + case ngraph::element::i8: case ngraph::element::i32: case ngraph::element::i64: + case ngraph::element::u8: case ngraph::element::u32: case ngraph::element::u64: case ngraph::element::f16: diff --git a/ngraph/core/src/pass/convert_precision.cpp b/ngraph/core/src/pass/convert_precision.cpp index 41793842b4d..b11ee9fe029 100644 --- a/ngraph/core/src/pass/convert_precision.cpp +++ b/ngraph/core/src/pass/convert_precision.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -30,6 +31,7 @@ bool fuse_type_to_nms3(const std::shared_ptr& node, ngraph::elemen bool fuse_type_to_nms4(const std::shared_ptr& node, ngraph::element::Type to, size_t idx); bool fuse_type_to_nms5(const std::shared_ptr& node, ngraph::element::Type to, size_t idx); bool fuse_type_to_topk(const std::shared_ptr& node, ngraph::element::Type to, size_t idx); +bool fuse_type_to_maxpool(const std::shared_ptr& node, ngraph::element::Type to, size_t idx); bool fuse_type_to_nonzero(const std::shared_ptr& node, ngraph::element::Type to, size_t idx); bool fuse_type_to_bucketize(const std::shared_ptr& node, ngraph::element::Type to, size_t idx); bool fuse_type_to_ctc_greedy_decoder_seq_len(const std::shared_ptr& node, @@ -253,6 +255,7 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr}, @@ -392,6 +395,16 @@ bool fuse_type_to_topk(const std::shared_ptr& node, ngraph::elemen return false; } +bool fuse_type_to_maxpool(const std::shared_ptr& node, ngraph::element::Type to, size_t idx) { + if (auto maxpool = ov::as_type_ptr(node)) { + if (idx == 1 && (to == element::i32 || to == element::i64)) { + maxpool->set_index_element_type(to); + return true; + } + } + return false; +} + bool fuse_type_to_ctc_greedy_decoder_seq_len(const std::shared_ptr& node, ngraph::element::Type to, size_t idx) { diff --git a/ngraph/frontend/onnx/frontend/src/op/max_pool.cpp b/ngraph/frontend/onnx/frontend/src/op/max_pool.cpp index 18004d6bf43..df02d83943e 100644 --- a/ngraph/frontend/onnx/frontend/src/op/max_pool.cpp +++ b/ngraph/frontend/onnx/frontend/src/op/max_pool.cpp @@ -25,8 +25,14 @@ OutputVector max_pool(const Node& node) { } // namespace set_1 +namespace set_8 { +OutputVector max_pool(const Node& node) { + return pooling::PoolingFactory(node).make_max_pool_with_indices(); +} +} // namespace set_8 + } // namespace op } // namespace onnx_import -} // namespace ngraph \ No newline at end of file +} // namespace ngraph diff --git a/ngraph/frontend/onnx/frontend/src/op/max_pool.hpp b/ngraph/frontend/onnx/frontend/src/op/max_pool.hpp index 82c10cc45c5..b7fda5e4058 100644 --- a/ngraph/frontend/onnx/frontend/src/op/max_pool.hpp +++ b/ngraph/frontend/onnx/frontend/src/op/max_pool.hpp @@ -23,6 +23,19 @@ OutputVector max_pool(const Node& node); } // namespace set_1 +namespace set_8 { +/// +/// \brief Convert ONNX MaxPool operation to an nGraph node. +/// +/// \param node The ONNX node object representing this operation. +/// +/// \return The vector containing Ngraph nodes producing output of ONNX MaxPool +/// operation. +/// +OutputVector max_pool(const Node& node); + +} // namespace set_8 + } // namespace op } // namespace onnx_import diff --git a/ngraph/frontend/onnx/frontend/src/ops_bridge.cpp b/ngraph/frontend/onnx/frontend/src/ops_bridge.cpp index 6dfa81ea02b..ffe15e5b1e5 100644 --- a/ngraph/frontend/onnx/frontend/src/ops_bridge.cpp +++ b/ngraph/frontend/onnx/frontend/src/ops_bridge.cpp @@ -361,6 +361,7 @@ OperatorsBridge::OperatorsBridge() { REGISTER_OPERATOR("MatMulInteger", 1, matmul_integer); REGISTER_OPERATOR("MatMul", 1, matmul); REGISTER_OPERATOR("MaxPool", 1, max_pool); + REGISTER_OPERATOR("MaxPool", 8, max_pool); REGISTER_OPERATOR("Max", 1, max); REGISTER_OPERATOR("Max", 8, max); REGISTER_OPERATOR("Mean", 1, mean); diff --git a/ngraph/frontend/onnx/frontend/src/utils/pooling_factory.cpp b/ngraph/frontend/onnx/frontend/src/utils/pooling_factory.cpp index fff79f9988f..4984698bbac 100644 --- a/ngraph/frontend/onnx/frontend/src/utils/pooling_factory.cpp +++ b/ngraph/frontend/onnx/frontend/src/utils/pooling_factory.cpp @@ -14,6 +14,27 @@ namespace ngraph { namespace onnx_import { namespace pooling { + +namespace { +std::shared_ptr transposition_axis_order(const Rank& input_rank) { + NGRAPH_CHECK(input_rank.is_static(), + "Generating column-major MaxPool results is supported only for inputs with static rank."); + + const auto rank = static_cast(input_rank.get_length()); + + std::vector axes(rank); + std::iota(axes.begin(), axes.end(), 0); + std::reverse(axes.begin() + 2, axes.end()); + + return std::make_shared(element::i32, Shape{rank}, axes); +} + +std::shared_ptr identity(Output node_output) { + const auto zero = default_opset::Constant::create(node_output.get_element_type(), {}, {0}); + return std::make_shared(node_output, zero); +} +} // namespace + PoolingFactory::PoolingFactory(const Node& node) : m_onnx_node{node}, m_inputs{node.get_ng_inputs()}, @@ -27,6 +48,7 @@ PoolingFactory::PoolingFactory(const Node& node) const CoordinateDiff& padding_below{paddings.first}; m_padding_below = Shape{std::begin(padding_below), std::end(padding_below)}; m_padding_above = Shape{std::begin(padding_above), std::end(padding_above)}; + m_storage_order = static_cast(node.get_attribute_value("storage_order", 0)); } OutputVector PoolingFactory::make_avg_pool() const { @@ -50,6 +72,26 @@ OutputVector PoolingFactory::make_max_pool() const { m_rounding_type, m_auto_pad)}; } + +OutputVector PoolingFactory::make_max_pool_with_indices() const { + const auto max_pool = std::make_shared(m_inputs.at(0), + m_strides, + m_dilations, + m_padding_below, + m_padding_above, + m_kernel_shape, + m_rounding_type, + m_auto_pad); + if (m_storage_order == StorageOrder::COLUMN_MAJOR) { + const auto transposition_axes = transposition_axis_order(m_inputs.at(0).get_partial_shape().rank()); + const auto transposed_indices = + std::make_shared(max_pool->output(1), transposition_axes); + + return {max_pool->output(0), transposed_indices}; + } else { + return {identity(max_pool->output(0)), identity(max_pool->output(1))}; + } +} } // namespace pooling } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx/frontend/src/utils/pooling_factory.hpp b/ngraph/frontend/onnx/frontend/src/utils/pooling_factory.hpp index 133f21a6c36..4e4f750db25 100644 --- a/ngraph/frontend/onnx/frontend/src/utils/pooling_factory.hpp +++ b/ngraph/frontend/onnx/frontend/src/utils/pooling_factory.hpp @@ -46,6 +46,9 @@ public: /// OutputVector make_max_pool() const; + /// \brief Creates max pooling ONNX operation with 2 outputs (values and indices). + OutputVector make_max_pool_with_indices() const; + protected: Node m_onnx_node; const OutputVector m_inputs; @@ -56,6 +59,10 @@ protected: Shape m_padding_above; ngraph::op::PadType m_auto_pad; ngraph::op::RoundingType m_rounding_type; + + enum class StorageOrder : int64_t { ROW_MAJOR = 0, COLUMN_MAJOR = 1 }; + + StorageOrder m_storage_order; }; } // namespace pooling } // namespace onnx_import diff --git a/ngraph/test/models/onnx/max_pool_3d.prototxt b/ngraph/test/models/onnx/max_pool_3d.prototxt new file mode 100644 index 00000000000..d84830a05d3 --- /dev/null +++ b/ngraph/test/models/onnx/max_pool_3d.prototxt @@ -0,0 +1,76 @@ +ir_version: 3 +producer_name: "backend-test" +graph { + node { + input: "x" + output: "y" + output: "z" + op_type: "MaxPool" + attribute { + name: "kernel_shape" + ints: 2 + type: INTS + } + } + name: "maxpool_test" + input { + name: "x" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 3 + } + dim { + dim_value: 3 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 3 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "z" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 3 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/max_pool_4d_ceil_mode.prototxt b/ngraph/test/models/onnx/max_pool_4d_ceil_mode.prototxt new file mode 100644 index 00000000000..456aa6df3c4 --- /dev/null +++ b/ngraph/test/models/onnx/max_pool_4d_ceil_mode.prototxt @@ -0,0 +1,91 @@ +ir_version: 3 +producer_name: "backend-test" +graph { + node { + input: "x" + output: "y" + output: "z" + op_type: "MaxPool" + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "ceil_mode" + i: 1 + type: INT + } + } + name: "maxpool_test" + input { + name: "x" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "z" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/max_pool_4d_ceil_strides.prototxt b/ngraph/test/models/onnx/max_pool_4d_ceil_strides.prototxt new file mode 100644 index 00000000000..dab7cfdb090 --- /dev/null +++ b/ngraph/test/models/onnx/max_pool_4d_ceil_strides.prototxt @@ -0,0 +1,97 @@ +ir_version: 4 +producer_name: "backend-test" +graph { + node { + input: "x" + output: "y" + output: "z" + op_type: "MaxPool" + attribute { + name: "ceil_mode" + i: 1 + type: INT + } + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "strides" + ints: 2 + ints: 2 + type: INTS + } + } + name: "test_maxpool_2d_ceil" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "z" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 10 +} diff --git a/ngraph/test/models/onnx/max_pool_4d_dilations.prototxt b/ngraph/test/models/onnx/max_pool_4d_dilations.prototxt new file mode 100644 index 00000000000..12cfdbe162b --- /dev/null +++ b/ngraph/test/models/onnx/max_pool_4d_dilations.prototxt @@ -0,0 +1,92 @@ +ir_version: 3 +producer_name: "backend-test" +graph { + node { + input: "x" + output: "y" + output: "z" + op_type: "MaxPool" + attribute { + name: "kernel_shape" + ints: 2 + ints: 2 + type: INTS + } + attribute { + name: "dilations" + ints: 2 + ints: 2 + type: INTS + } + } + name: "maxpool_test" + input { + name: "x" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 6 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "z" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/max_pool_4d_strides.prototxt b/ngraph/test/models/onnx/max_pool_4d_strides.prototxt new file mode 100644 index 00000000000..2c207a1a9c0 --- /dev/null +++ b/ngraph/test/models/onnx/max_pool_4d_strides.prototxt @@ -0,0 +1,100 @@ +ir_version: 3 +producer_name: "backend-test" +graph { + node { + input: "x" + output: "y" + output: "z" + op_type: "MaxPool" + attribute { + name: "kernel_shape" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "strides" + ints: 3 + ints: 3 + type: INTS + } + attribute { + name: "pads" + ints: 2 + ints: 2 + ints: 2 + ints: 2 + type: INTS + } + } + name: "maxpool_test" + input { + name: "x" + type { + tensor_type { + elem_type: 3 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 5 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "y" + type { + tensor_type { + elem_type: 3 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 3 + } + dim { + dim_value: 3 + } + } + } + } + } + output { + name: "z" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 3 + } + dim { + dim_value: 3 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp index 0e0be7e5bb5..7663bb7280a 100644 --- a/ngraph/test/onnx/onnx_import.in.cpp +++ b/ngraph/test/onnx/onnx_import.in.cpp @@ -22,14 +22,16 @@ #endif // clang-format on -#include "onnx_import/core/null_node.hpp" +#include + +#include "default_opset.hpp" #include "gtest/gtest.h" +#include "ngraph/ngraph.hpp" +#include "ngraph/pass/constant_folding.hpp" +#include "ngraph/pass/manager.hpp" +#include "onnx_import/core/null_node.hpp" #include "onnx_import/onnx.hpp" #include "onnx_import/onnx_utils.hpp" -#include "default_opset.hpp" -#include "ngraph/ngraph.hpp" -#include "ngraph/pass/manager.hpp" -#include "ngraph/pass/constant_folding.hpp" #include "util/all_close.hpp" #include "util/all_close_f.hpp" #include "util/ndarray.hpp" @@ -38,7 +40,6 @@ #include "engines_util/test_engines.hpp" #include "util/test_tools.hpp" #include "util/type_prop.hpp" -#include NGRAPH_SUPPRESS_DEPRECATED_START @@ -4165,6 +4166,73 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_float16_tensor_as_int32) { test_case.run(); } +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_max_pool_3d) { + const auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_3d.onnx")); + + auto test_case = test::TestCase(function); + test_case.add_input(Shape{1, 3, 3}, {-1, 0, 1, 20, -20, 10, 0, 2, 1}); + test_case.add_expected_output(Shape{1, 3, 2}, {0, 1, 20, 10, 2, 2}); + test_case.add_expected_output(Shape{1, 3, 2}, {1, 2, 3, 5, 7, 7}); + + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_max_pool_4d_ceil_mode) { + const auto function = + onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_4d_ceil_mode.onnx")); + + auto test_case = test::TestCase(function); + test_case.add_input(Shape{1, 1, 4, 4}, gen_range(16, 1)); + test_case.add_expected_output(Shape{1, 1, 2, 2}, {11, 12, 15, 16}); + test_case.add_expected_output(Shape{1, 1, 2, 2}, {10, 11, 14, 15}); + + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_max_pool_4d_dilations) { + const auto function = + onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_4d_dilations.onnx")); + + auto test_case = test::TestCase(function); + test_case.add_input(Shape{1, 1, 4, 4}, {9, 10, 11, 12, 1, 2, 3, 4, 16, 14, 15, 13, 5, 6, 8, 7}); + test_case.add_expected_output(Shape{1, 1, 2, 2}, {16, 14, 8, 7}); + test_case.add_expected_output(Shape{1, 1, 2, 2}, {8, 9, 14, 15}); + + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_max_pool_4d_strides) { + // kernel: 3x3 + // strides: 3, 3 + // explicit pads: 2, 2, 2, 2 + const auto function = + onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_4d_strides.onnx")); + + auto test_case = test::TestCase(function); + test_case.add_input(Shape{1, 1, 5, 5}, gen_range(25, 1)); + test_case.add_expected_output(Shape{1, 1, 3, 3}, {1, 4, 5, 16, 19, 20, 21, 24, 25}); + test_case.add_expected_output(Shape{1, 1, 3, 3}, {0, 3, 4, 15, 18, 19, 20, 23, 24}); + + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_max_pool_4d_ceil_strides) { + // kernel: 3x3 + // strides: 2, 2 + // ceil_mode: 1 + const auto function = + onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/max_pool_4d_ceil_strides.onnx")); + + auto test_case = test::TestCase(function); + test_case.add_input( + Shape{1, 1, 4, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f}); + test_case.add_expected_output(Shape{1, 1, 2, 2}, {11.0f, 12.0f, 15.0f, 16.0f}); + test_case.add_expected_output(Shape{1, 1, 2, 2}, {10, 11, 14, 15}); + + test_case.run(); +} + NGRAPH_TEST(${BACKEND_NAME}, onnx_model_random_uniform) { const auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/random_uniform.onnx")); diff --git a/ngraph/test/onnx/onnx_import_dyn_shapes.in.cpp b/ngraph/test/onnx/onnx_import_dyn_shapes.in.cpp index df3184753be..fdc4eab69bd 100644 --- a/ngraph/test/onnx/onnx_import_dyn_shapes.in.cpp +++ b/ngraph/test/onnx/onnx_import_dyn_shapes.in.cpp @@ -15,9 +15,9 @@ #include #include +#include "default_opset.hpp" #include "gtest/gtest.h" #include "ngraph/file_util.hpp" -#include "default_opset.hpp" #include "onnx_import/onnx.hpp" #include "engines_util/test_engines.hpp" #include "engines_util/test_case.hpp" @@ -333,10 +333,9 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_dyn_shapes_max_pool_with_indices_output) { 25.f, 25.f, 23.f, 24.f, 25.f, 25.f, 25.f, 23.f, 24.f, 25.f, 25.f, 25.f}; test_case.add_expected_output(Shape{1, 1, 5, 5}, expected_values); - // indices output is not supported and is ingored in current implementation - // std::vector expected_indices{12, 13, 14, 14, 14, 17, 18, 19, 19, 19, 22, 23, 24, 24, - // 24, 22, 23, 24, 24, 24, 22, 23, 24, 24, 24}; - // test_case.add_expected_output(Shape{1, 1, 5, 5}, expected_indices); + std::vector expected_indices{12, 13, 14, 14, 14, 17, 18, 19, 19, 19, 22, 23, 24, + 24, 24, 22, 23, 24, 24, 24, 22, 23, 24, 24, 24}; + test_case.add_expected_output(Shape{1, 1, 5, 5}, expected_indices); test_case.run(); } diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index 8a55b98d623..5942e64e256 100644 --- a/ngraph/test/runtime/ie/unit_test.manifest +++ b/ngraph/test/runtime/ie/unit_test.manifest @@ -371,14 +371,12 @@ tile_3d_few_repeats # Result mismatch sum_large_1d_to_scalar sum_stable_acc -max_pool_3d avg_pool_2d_2channel_2image_padded_only_above_include_in_computation avg_pool_3d_uneven_strided_padded multiple_result lrn_across_all_dims elu elu_negative_alpha -max_pool_2d_1channel_1image_overpadded grn_2d_with_bias erf divide_adjoint_stability @@ -386,8 +384,6 @@ notequal less sum_3d_to_scalar_int32 sum_2d_to_scalar_int8 -max_pool_uint8 -max_pool_int8 avg_pool_uint8 avg_pool_int8 max_to_scalar_int8 @@ -432,6 +428,12 @@ onnx_constant_integer_array adaptive_max_pool_1d adaptive_max_pool_2d adaptive_max_pool_3d +onnx_dyn_shapes_max_pool_with_indices_output +onnx_model_max_pool_3d +onnx_model_max_pool_4d_ceil_mode +onnx_model_max_pool_4d_dilations +onnx_model_max_pool_4d_strides +onnx_model_max_pool_4d_ceil_strides # Unsupported primitive of type: SigmoidBackprop sigmoid_bprop_n1c1h4 @@ -554,9 +556,6 @@ product_to_scalar_int8 min_to_scalar_int8 # Pooling layer. Unsupported mode. Only 4D and 5D blobs are supported as input. -max_pool_1d_1channel_1image -max_pool_1d_1channel_2image -max_pool_1d_2channel_2image avg_pool_1d_1channel_1image avg_pool_1d_1channel_2image avg_pool_1d_2channel_2image diff --git a/ngraph/test/util/test_tools.hpp b/ngraph/test/util/test_tools.hpp index 1a3dc91ad0d..f559fe3c48f 100644 --- a/ngraph/test/util/test_tools.hpp +++ b/ngraph/test/util/test_tools.hpp @@ -55,3 +55,12 @@ std::vector read_binary_file(const std::string& path) { inputs_fs.read(reinterpret_cast(file_content.data()), size); return file_content; } + +template +std::vector gen_range(const size_t elements, const T start = T{0}) { + std::vector range; + range.resize(elements); + std::iota(range.begin(), range.end(), start); + + return range; +} diff --git a/runtime/bindings/python/tests/__init__.py b/runtime/bindings/python/tests/__init__.py index 7b6826b239d..07e4bd87d06 100644 --- a/runtime/bindings/python/tests/__init__.py +++ b/runtime/bindings/python/tests/__init__.py @@ -47,7 +47,6 @@ xfail_issue_33651 = xfail_test(reason="RuntimeError: nGraph does not support the "TfIdfVectorizer") xfail_issue_33581 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: " "GatherElements") -xfail_issue_33633 = xfail_test(reason="MaxPool: dilations unsupported") xfail_issue_35923 = xfail_test(reason="RuntimeError: PReLU without weights is not supported") xfail_issue_35927 = xfail_test(reason="RuntimeError: B has zero dimension that is not allowable") xfail_issue_36486 = xfail_test(reason="RuntimeError: HardSigmoid operation should be converted " @@ -93,7 +92,6 @@ xfail_issue_44965 = xfail_test(reason="Expected: RuntimeError: value info has no xfail_issue_44968 = xfail_test(reason="Expected: Unsupported dynamic op: Squeeze") xfail_issue_47323 = xfail_test(reason="RuntimeError: The plugin does not support FP64") xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot") -xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output") xfail_issue_55760 = xfail_test(reason="RuntimeError: Reversed axis have axes above the source space shape") # Model MSFT issues: @@ -136,7 +134,6 @@ xfail_issue_63036 = xfail_test(reason="Changes in ConvTranspose padding") xfail_issue_63039 = xfail_test(reason="Result mismatches with UINT8 operations") xfail_issue_63043 = xfail_test(reason="Recurrent node expects constants as W, R, B inputs.") xfail_issue_63044 = xfail_test(reason="ONNX opset 14 operation: Trilu") -xfail_issue_63045 = xfail_test(reason="Maxpool with strides, padding and dilations fail") skip_rng_tests = pytest.mark.skip(reason="Tests use random number generator with no seed.") xfail_issue_63136 = xfail_test(reason="Unsupported operation: CastLike") diff --git a/runtime/bindings/python/tests/test_onnx/test_backend.py b/runtime/bindings/python/tests/test_onnx/test_backend.py index 06e07603710..70b6702adfc 100644 --- a/runtime/bindings/python/tests/test_onnx/test_backend.py +++ b/runtime/bindings/python/tests/test_onnx/test_backend.py @@ -12,11 +12,9 @@ from tests import ( xfail_issue_33538, xfail_issue_33581, xfail_issue_33589, - xfail_issue_33593, xfail_issue_33595, xfail_issue_33596, xfail_issue_33606, - xfail_issue_33633, xfail_issue_33651, xfail_issue_38091, xfail_issue_38699, @@ -49,7 +47,6 @@ from tests import ( xfail_issue_63039, xfail_issue_63043, xfail_issue_63044, - xfail_issue_63045, xfail_issue_63136, xfail_issue_63137, xfail_issue_63138, @@ -143,7 +140,6 @@ tests_expected_to_fail = [ "OnnxBackendNodeModelTest.test_scatter_elements_with_negative_indices_cpu", "OnnxBackendNodeModelTest.test_gather_negative_indices_cpu", ), - (xfail_issue_33633, "OnnxBackendNodeModelTest.test_maxpool_2d_dilations_cpu"), ( xfail_issue_55760, "OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_example_select_last_index_cpu", @@ -338,11 +334,6 @@ tests_expected_to_fail = [ "OnnxBackendNodeModelTest.test_squeeze_cpu", "OnnxBackendNodeModelTest.test_squeeze_negative_axes_cpu", ), - ( - xfail_issue_33593, - "OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_strides_cpu", - "OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_pads_cpu", - ), (xfail_issue_58033, "OnnxBackendNodeModelTest.test_einsum_batch_diagonal_cpu"), ( xfail_issue_63033, @@ -387,11 +378,6 @@ tests_expected_to_fail = [ "OnnxBackendNodeModelTest.test_triu_square_neg_cpu", "OnnxBackendNodeModelTest.test_triu_zero_cpu", ), - ( - xfail_issue_63045, - "OnnxBackendPyTorchConvertedModelTest.test_MaxPool1d_stride_padding_dilation_cpu", - "OnnxBackendPyTorchConvertedModelTest.test_MaxPool2d_stride_padding_dilation_cpu", - ), ( skip_rng_tests, "OnnxBackendNodeModelTest.test_bernoulli_cpu", diff --git a/runtime/bindings/python/tests/test_onnx/test_ops_convpool.py b/runtime/bindings/python/tests/test_onnx/test_ops_convpool.py index 6637a06463b..6d431fea9e2 100644 --- a/runtime/bindings/python/tests/test_onnx/test_ops_convpool.py +++ b/runtime/bindings/python/tests/test_onnx/test_ops_convpool.py @@ -369,7 +369,7 @@ def test_pool_max(ndarray_1x1x4x4): x = ndarray_1x1x4x4 y = np.array([[16, 18], [24, 26]], dtype=np.float32).reshape([1, 1, 2, 2]) - ng_results = run_node(node, [x]) + ng_results = run_node(node, [x], opset_version=7) assert np.array_equal(ng_results, [y]) diff --git a/runtime/bindings/python/tests/test_onnx/test_zoo_models.py b/runtime/bindings/python/tests/test_onnx/test_zoo_models.py index 2042d7a0f6d..ee16fa4724a 100644 --- a/runtime/bindings/python/tests/test_onnx/test_zoo_models.py +++ b/runtime/bindings/python/tests/test_onnx/test_zoo_models.py @@ -77,6 +77,7 @@ tolerance_map = { "resnet34-v2-7": {"atol": 0.001, "rtol": 0.001}, "vgg16-7": {"atol": 0.001, "rtol": 0.001}, "vgg19-bn-7": {"atol": 0.001, "rtol": 0.001}, + "vgg19-7": {"atol": 0.001, "rtol": 0.001}, "tinyyolov2-7": {"atol": 0.001, "rtol": 0.001}, "tinyyolov2-8": {"atol": 0.001, "rtol": 0.001}, "candy-8": {"atol": 0.001, "rtol": 0.001}, @@ -115,6 +116,12 @@ tolerance_map = { "test_retinanet_resnet101": {"atol": 1.3e-06}, } +def tolerance_map_key_in_model_path(path): + for key in tolerance_map: + if key in path: + return key + return None + zoo_models = [] # rglob doesn't work for symlinks, so models have to be physically somwhere inside "MODELS_ROOT_DIR" for path in Path(MODELS_ROOT_DIR).rglob("*.onnx"): @@ -127,6 +134,12 @@ for path in Path(MODELS_ROOT_DIR).rglob("*.onnx"): # updated model looks now: # {"model_name": path, "model_file": file, "dir": mdir, "atol": ..., "rtol": ...} model.update(tolerance_map[basedir]) + else: + # some models have the same stem, have to check if any of the keys from tolerance_map + # is found in the full model path + model_key = tolerance_map_key_in_model_path(str(path)) + if model_key is not None: + model.update(tolerance_map[model_key]) if basedir in post_processing: model.update(post_processing[basedir]) zoo_models.append(model) diff --git a/runtime/bindings/python/tests_compatibility/__init__.py b/runtime/bindings/python/tests_compatibility/__init__.py index 8d6cd57543c..f8d40d21e7b 100644 --- a/runtime/bindings/python/tests_compatibility/__init__.py +++ b/runtime/bindings/python/tests_compatibility/__init__.py @@ -46,7 +46,6 @@ xfail_issue_33651 = xfail_test(reason="RuntimeError: nGraph does not support the "TfIdfVectorizer") xfail_issue_33581 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: " "GatherElements") -xfail_issue_33633 = xfail_test(reason="MaxPool: dilations unsupported") xfail_issue_35923 = xfail_test(reason="RuntimeError: PReLU without weights is not supported") xfail_issue_35927 = xfail_test(reason="RuntimeError: B has zero dimension that is not allowable") xfail_issue_36486 = xfail_test(reason="RuntimeError: HardSigmoid operation should be converted " @@ -99,7 +98,6 @@ xfail_issue_44965 = xfail_test(reason="Expected: RuntimeError: value info has no xfail_issue_44968 = xfail_test(reason="Expected: Unsupported dynamic op: Squeeze") xfail_issue_47323 = xfail_test(reason="RuntimeError: The plugin does not support FP64") xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot") -xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output") xfail_issue_55760 = xfail_test(reason="RuntimeError: Reversed axis have axes above the source space shape") # Model MSFT issues: @@ -142,7 +140,6 @@ xfail_issue_63036 = xfail_test(reason="Changes in ConvTranspose padding") xfail_issue_63039 = xfail_test(reason="Result mismatches with UINT8 operations") xfail_issue_63043 = xfail_test(reason="Recurrent node expects constants as W, R, B inputs.") xfail_issue_63044 = xfail_test(reason="ONNX opset 14 operation: Trilu") -xfail_issue_63045 = xfail_test(reason="Maxpool with strides, padding and dilations fail") skip_rng_tests = pytest.mark.skip(reason="Tests use random number generator with no seed.") xfail_issue_63136 = xfail_test(reason="Unsupported operation: CastLike") diff --git a/runtime/bindings/python/tests_compatibility/test_onnx/test_backend.py b/runtime/bindings/python/tests_compatibility/test_onnx/test_backend.py index 3321e036e24..baafa738e26 100644 --- a/runtime/bindings/python/tests_compatibility/test_onnx/test_backend.py +++ b/runtime/bindings/python/tests_compatibility/test_onnx/test_backend.py @@ -11,11 +11,9 @@ from tests_compatibility import ( xfail_issue_33538, xfail_issue_33581, xfail_issue_33589, - xfail_issue_33593, xfail_issue_33595, xfail_issue_33596, xfail_issue_33606, - xfail_issue_33633, xfail_issue_33651, xfail_issue_38091, xfail_issue_38699, @@ -48,7 +46,6 @@ from tests_compatibility import ( xfail_issue_63039, xfail_issue_63043, xfail_issue_63044, - xfail_issue_63045, xfail_issue_63136, xfail_issue_63137, xfail_issue_63138, @@ -132,7 +129,6 @@ tests_expected_to_fail = [ "OnnxBackendNodeModelTest.test_scatter_elements_with_negative_indices_cpu", "OnnxBackendNodeModelTest.test_gather_negative_indices_cpu", ), - (xfail_issue_33633, "OnnxBackendNodeModelTest.test_maxpool_2d_dilations_cpu"), ( xfail_issue_55760, "OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_example_select_last_index_cpu", @@ -327,11 +323,6 @@ tests_expected_to_fail = [ "OnnxBackendNodeModelTest.test_squeeze_cpu", "OnnxBackendNodeModelTest.test_squeeze_negative_axes_cpu", ), - ( - xfail_issue_33593, - "OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_strides_cpu", - "OnnxBackendNodeModelTest.test_maxpool_with_argmax_2d_precomputed_pads_cpu", - ), (xfail_issue_58033, "OnnxBackendNodeModelTest.test_einsum_batch_diagonal_cpu"), ( xfail_issue_63033, @@ -376,11 +367,6 @@ tests_expected_to_fail = [ "OnnxBackendNodeModelTest.test_triu_square_neg_cpu", "OnnxBackendNodeModelTest.test_triu_zero_cpu", ), - ( - xfail_issue_63045, - "OnnxBackendPyTorchConvertedModelTest.test_MaxPool1d_stride_padding_dilation_cpu", - "OnnxBackendPyTorchConvertedModelTest.test_MaxPool2d_stride_padding_dilation_cpu", - ), ( skip_rng_tests, "OnnxBackendNodeModelTest.test_bernoulli_cpu",