diff --git a/ngraph/frontend/onnx_import/src/op/reduce.cpp b/ngraph/frontend/onnx_import/src/op/reduce.cpp
index 29f117055b9..4f9517b07c8 100644
--- a/ngraph/frontend/onnx_import/src/op/reduce.cpp
+++ b/ngraph/frontend/onnx_import/src/op/reduce.cpp
@@ -21,6 +21,7 @@
 #include "exceptions.hpp"
 #include "ngraph/builder/norm.hpp"
 #include "ngraph/node.hpp"
+#include "op/identity.hpp"
 #include "utils/common.hpp"
 
 namespace ngraph
@@ -31,7 +32,64 @@ namespace ngraph
         {
             namespace
             {
-                std::shared_ptr<ngraph::Node> get_reduction_axes(const Node& node)
+                std::shared_ptr<ngraph::Node> get_dynamic_all_axes_range(const Node& node)
+                {
+                    const auto input = node.get_ng_inputs().at(0);
+                    const auto shape_of_input = std::make_shared<default_opset::ShapeOf>(input);
+                    const auto scalar =
+                        default_opset::Constant::create(element::i32, Shape{1}, {0});
+                    const auto rank_of_input =
+                        std::make_shared<default_opset::ShapeOf>(shape_of_input);
+                    const auto rank_of_input_scalar =
+                        std::make_shared<default_opset::Squeeze>(rank_of_input, scalar);
+                    const auto start = default_opset::Constant::create(element::i32, Shape{}, {0});
+                    const auto step = default_opset::Constant::create(element::i32, Shape{}, {1});
+                    return std::make_shared<default_opset::Range>(
+                        start, rank_of_input_scalar, step, element::i64);
+                }
+
+                std::shared_ptr<ngraph::Node> get_reduction_axes_from_input(const Node& node)
+                {
+                    const std::int64_t noop_with_empty_axes =
+                        node.get_attribute_value<std::int64_t>("noop_with_empty_axes", 0);
+                    const auto input = node.get_ng_inputs().at(0);
+                    const auto input_rank = node.get_ng_inputs().at(0).get_partial_shape().rank();
+                    if (node.get_ng_inputs().size() > 1)
+                    {
+                        const auto reduction_axes = node.get_ng_inputs().at(1);
+                        const auto reduction_axes_rank = reduction_axes.get_partial_shape().rank();
+                        NGRAPH_CHECK(reduction_axes.get_partial_shape().is_static(),
+                                     "The axes tensor's shape needs to be known(static). Node: ",
+                                     node.get_description());
+
+                        if (reduction_axes_rank.get_length() != 0 &&
+                            reduction_axes.get_shape() != Shape{0})
+                        {
+                            return reduction_axes.get_node_shared_ptr();
+                        }
+                    }
+
+                    if (noop_with_empty_axes)
+                    {
+                        return nullptr;
+                    }
+                    else
+                    {
+                        if (input_rank.is_static())
+                        {
+                            auto all_axes = onnx_import::common::get_monotonic_range<int64_t>(
+                                input_rank.get_length());
+                            return default_opset::Constant::create(
+                                element::i64, Shape{all_axes.size()}, all_axes);
+                        }
+                        else
+                        {
+                            return get_dynamic_all_axes_range(node);
+                        }
+                    }
+                }
+
+                std::shared_ptr<ngraph::Node> get_reduction_axes_from_attr(const Node& node)
                 {
                     auto reduction_axes =
                         node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
@@ -40,13 +98,15 @@ namespace ngraph
 
                     if (reduction_axes.empty())
                     {
-                        NGRAPH_CHECK(input_rank.is_static(),
-                                     "The input tensor's rank needs to be known(static) when the "
-                                     "'axes' attribute is not specified. Node: ",
-                                     node.get_description());
-
-                        reduction_axes = onnx_import::common::get_monotonic_range<std::int64_t>(
-                            input_rank.get_length());
+                        if (input_rank.is_static())
+                        {
+                            reduction_axes = onnx_import::common::get_monotonic_range<std::int64_t>(
+                                input_rank.get_length());
+                        }
+                        else
+                        {
+                            return get_dynamic_all_axes_range(node);
+                        }
                     }
 
                     if (input_rank.is_static())
@@ -66,19 +126,36 @@ namespace ngraph
 
                 template <typename OpType>
                 std::shared_ptr<ngraph::Node>
-                    make_ng_reduction_op(const Node& node, const Output<ngraph::Node>& ng_input)
+                    make_ng_reduction_op(const Node& node,
+                                         const Output<ngraph::Node>& ng_input,
+                                         bool axes_as_attr = true)
                 {
-                    const auto reduction_axes = get_reduction_axes(node);
                     const std::int64_t keepdims =
                         node.get_attribute_value<std::int64_t>("keepdims", 1);
-                    const auto op_node = std::make_shared<OpType>(
-                        ng_input, reduction_axes, static_cast<bool>(keepdims));
-
-                    return op_node;
+                    const auto reduction_axes = axes_as_attr ? get_reduction_axes_from_attr(node)
+                                                             : get_reduction_axes_from_input(node);
+                    if (reduction_axes != nullptr)
+                    {
+                        return std::make_shared<OpType>(
+                            ng_input, reduction_axes, static_cast<bool>(keepdims));
+                    }
+                    else
+                    {
+                        return op::set_1::identity(node).at(0).get_node_shared_ptr();
+                    }
                 }
             } // namespace
 
+            namespace set_13
+            {
+                OutputVector reduce_sum(const Node& node)
+                {
+                    return {make_ng_reduction_op<default_opset::ReduceSum>(
+                        node, node.get_ng_inputs().at(0), false)};
+                }
+            } // namespace set_13
+
             namespace set_1
             {
                 OutputVector reduce_log_sum(const Node& node)
diff --git a/ngraph/frontend/onnx_import/src/op/reduce.hpp b/ngraph/frontend/onnx_import/src/op/reduce.hpp
index 0d601c1a0eb..8d8620cfb46 100644
--- a/ngraph/frontend/onnx_import/src/op/reduce.hpp
+++ b/ngraph/frontend/onnx_import/src/op/reduce.hpp
@@ -24,6 +24,22 @@ namespace ngraph
     {
         namespace op
         {
+            namespace set_13
+            {
+                /// \brief      Compute the sum of the input tensor's elements along the
+                ///             provided axes.
+                ///
+                /// \par Overview
+                ///     The output tensor has the same rank as the input if the keepdims
+                ///     attribute equals 1. If keepdims equals 0, the reduced dimensions are
+                ///     pruned from the output tensor.
+                ///
+                /// \param[in]  node  The ONNX node representing this operation.
+                ///
+                /// \return     The nGraph node equivalent of the ONNX operation.
+                ///
+                OutputVector reduce_sum(const Node& node);
+            }
             namespace set_1
             {
                 /// \brief      Compute the log sum of the input tensor's elements along the
diff --git a/ngraph/frontend/onnx_import/src/ops_bridge.cpp b/ngraph/frontend/onnx_import/src/ops_bridge.cpp
index 230834216e6..fb9be12dfbd 100644
--- a/ngraph/frontend/onnx_import/src/ops_bridge.cpp
+++ b/ngraph/frontend/onnx_import/src/ops_bridge.cpp
@@ -404,6 +404,7 @@ namespace ngraph
             REGISTER_OPERATOR("ReduceMin", 1, reduce_min);
             REGISTER_OPERATOR("ReduceProd", 1, reduce_prod);
             REGISTER_OPERATOR("ReduceSum", 1, reduce_sum);
+            REGISTER_OPERATOR("ReduceSum", 13, reduce_sum);
             REGISTER_OPERATOR("ReduceSumSquare", 1, reduce_sum_square);
             REGISTER_OPERATOR("Relu", 1, relu);
             REGISTER_OPERATOR("Reshape", 1, reshape);
diff --git a/ngraph/python/tests/__init__.py b/ngraph/python/tests/__init__.py
index d76c7f1d0d3..bf08d8cfa1b 100644
--- a/ngraph/python/tests/__init__.py
+++ b/ngraph/python/tests/__init__.py
@@ -167,8 +167,8 @@ xfail_issue_38735 = xfail_test(reason="RuntimeError: nGraph does not support the
                                       "ai.onnx.preview.training.Adagrad")
 xfail_issue_38736 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
                                       "NegativeLogLikelihoodLoss")
-xfail_issue_43523 = xfail_test(reason="onnx.onnx_cpp2py_export.checker.ValidationError:"
-                               " Unrecognized attribute: axes for operator ReduceSum")
+xfail_issue_45177 = xfail_test(reason="RuntimeError: axes has zero dimension which is not allowed")
+xfail_issue_45180 = xfail_test(reason="RuntimeError: Unsupported dynamic op: ReduceSum")
 xfail_issue_44839 = xfail_test(reason="Huge computation missmatch")
 xfail_issue_44848 = xfail_test(reason="E Unsupported dynamic op: Range")
 xfail_issue_44851 = xfail_test(reason="E Unsupported dynamic op: Broadcast")
diff --git a/ngraph/python/tests/test_onnx/test_backend.py b/ngraph/python/tests/test_onnx/test_backend.py
index ac68f751792..08f2b68308d 100644
--- a/ngraph/python/tests/test_onnx/test_backend.py
+++ b/ngraph/python/tests/test_onnx/test_backend.py
@@ -76,7 +76,8 @@ from tests import (BACKEND_NAME,
                    xfail_issue_40319,
                    xfail_issue_40485,
                    xfail_issue_41894,
-                   xfail_issue_43523,
+                   xfail_issue_45177,
+                   xfail_issue_45180,
                    xfail_issue_43742,
                    xfail_issue_44839,
                   xfail_issue_44848,
@@ -632,38 +633,27 @@ tests_expected_to_fail = [
      "OnnxBackendNodeModelTest.test_adagrad_cpu"),
     (xfail_issue_41894,
      "OnnxBackendNodeModelTest.test_max_uint16_cpu"),
-    (xfail_issue_43523,
-     "OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_example_cpu",
-     "OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_random_cpu",
-     "OnnxBackendNodeModelTest.test_reduce_sum_keepdims_example_cpu",
-     "OnnxBackendNodeModelTest.test_reduce_sum_keepdims_random_cpu",
-     "OnnxBackendNodeModelTest.test_reduce_sum_negative_axes_keepdims_example_cpu",
+    (xfail_issue_45177,
      "OnnxBackendNodeModelTest.test_reduce_sum_default_axes_keepdims_example_cpu",
      "OnnxBackendNodeModelTest.test_reduce_sum_default_axes_keepdims_random_cpu",
      "OnnxBackendNodeModelTest.test_reduce_sum_empty_axes_input_noop_example_cpu",
      "OnnxBackendNodeModelTest.test_reduce_sum_empty_axes_input_noop_random_cpu",
      "OnnxBackendNodeModelTest.test_reduce_sum_negative_axes_keepdims_random_cpu"),
+    (xfail_issue_45180,
+     "OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_example_cpu",
+     "OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_random_cpu",
+     "OnnxBackendNodeModelTest.test_reduce_sum_keepdims_example_cpu",
+     "OnnxBackendNodeModelTest.test_reduce_sum_keepdims_random_cpu",
+     "OnnxBackendNodeModelTest.test_reduce_sum_negative_axes_keepdims_example_cpu"),
     (xfail_issue_43742,
      "OnnxBackendNodeModelTest.test_if_cpu",
      "OnnxBackendNodeModelTest.test_if_seq_cpu"),
     (xfail_issue_44839,
      "OnnxBackendNodeModelTest.test_logsoftmax_axis_0_cpu",
-     "OnnxBackendNodeModelTest.test_logsoftmax_axis_0_expanded_cpu",
      "OnnxBackendNodeModelTest.test_logsoftmax_axis_1_cpu",
-     "OnnxBackendNodeModelTest.test_logsoftmax_axis_1_expanded_cpu",
-     "OnnxBackendNodeModelTest.test_logsoftmax_axis_2_expanded_cpu",
-     "OnnxBackendNodeModelTest.test_logsoftmax_default_axis_expanded_cpu",
-     "OnnxBackendNodeModelTest.test_logsoftmax_large_number_expanded_cpu",
-     "OnnxBackendNodeModelTest.test_logsoftmax_negative_axis_expanded_cpu",
      "OnnxBackendNodeModelTest.test_softmax_axis_0_cpu",
-     "OnnxBackendNodeModelTest.test_softmax_axis_0_expanded_cpu",
      "OnnxBackendNodeModelTest.test_softmax_axis_1_cpu",
-     "OnnxBackendNodeModelTest.test_softmax_axis_1_expanded_cpu",
-     "OnnxBackendNodeModelTest.test_softmax_axis_2_expanded_cpu",
      "OnnxBackendNodeModelTest.test_softmax_default_axis_cpu",
-     "OnnxBackendNodeModelTest.test_softmax_default_axis_expanded_cpu",
-     "OnnxBackendNodeModelTest.test_softmax_large_number_expanded_cpu",
-     "OnnxBackendNodeModelTest.test_softmax_negative_axis_expanded_cpu",
      "OnnxBackendNodeModelTest.test_hardmax_axis_0_cpu",
      "OnnxBackendNodeModelTest.test_hardmax_axis_1_cpu",
      "OnnxBackendNodeModelTest.test_hardmax_default_axis_cpu",),
diff --git a/ngraph/python/tests/test_onnx/test_ops_reduction.py b/ngraph/python/tests/test_onnx/test_ops_reduction.py
index abc4ad46a02..c46399f2e15 100644
--- a/ngraph/python/tests/test_onnx/test_ops_reduction.py
+++ b/ngraph/python/tests/test_onnx/test_ops_reduction.py
@@ -17,9 +17,12 @@ import numpy as np
 import onnx
 import pytest
 
-from tests.test_onnx.utils import run_node
-from tests import (xfail_issue_35925,
-                   xfail_issue_43523)
+from tests.runtime import get_runtime
+from tests.test_onnx.utils import (
+    run_node,
+    import_onnx_model,
+)
+from tests import xfail_issue_35925
 
 reduce_data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
 reduce_axis_parameters = [
@@ -32,14 +35,17 @@ reduce_axis_parameters = [
     (0, 1, 2)
 ]
 
-reduce_operation_parameters = [
+reduce_operation_parameters_as_attr = [ ("ReduceMax", np.max), ("ReduceMin", np.min), ("ReduceMean", np.mean), - pytest.param("ReduceSum", np.sum, marks=xfail_issue_43523), ("ReduceProd", np.prod) ] +reduce_operation_parameters_as_const = [ + ("ReduceSum", np.sum), +] + def import_and_compute(op_type, input_data, **node_attrs): data_inputs = [np.array(input_data)] @@ -47,25 +53,59 @@ def import_and_compute(op_type, input_data, **node_attrs): return run_node(node, data_inputs).pop() -@pytest.mark.parametrize("operation, ref_operation", [ - ("ReduceMax", np.max), - ("ReduceMin", np.min), - ("ReduceMean", np.mean), - ("ReduceSum", np.sum), - ("ReduceProd", np.prod) -]) +def import_and_compute_with_axes_as_const(op_type, data, axes, **node_attrs): + data_input = np.array(data) + axes_input = np.array(axes, dtype=int) + axes_const_node = onnx.helper.make_node( + "Constant", + inputs=[], + outputs=["const_axes"], + value=onnx.helper.make_tensor( + name="const_axes", + data_type=onnx.TensorProto.INT64, + dims=axes_input.shape, + vals=axes_input.flatten(), + ), + ) + node = onnx.helper.make_node( + op_type, inputs=["x", "const_axes"], outputs=["y"], **node_attrs + ) + graph = onnx.helper.make_graph( + [axes_const_node, node], + "test_graph", + [onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, data_input.shape)], + [onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, ())], + ) + + model = onnx.helper.make_model(graph, producer_name="ngraph ONNX Importer") + model.opset_import[0].version = 13 + ng_model_function = import_onnx_model(model) + runtime = get_runtime() + computation = runtime.computation(ng_model_function) + return computation(data_input)[0] + + +@pytest.mark.parametrize("operation, ref_operation", + reduce_operation_parameters_as_attr + reduce_operation_parameters_as_const) def test_reduce_operation_keepdims_none_axes(operation, ref_operation): assert np.array_equal(import_and_compute(operation, reduce_data, keepdims=True), ref_operation(reduce_data, keepdims=True)) -@pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters) +@pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters_as_attr) @pytest.mark.parametrize("axes", reduce_axis_parameters) -def test_reduce_operation_keepdims(operation, ref_operation, axes): +def test_reduce_operation_keepdims_with_axes_as_attr(operation, ref_operation, axes): assert np.array_equal(import_and_compute(operation, reduce_data, axes=axes, keepdims=True), ref_operation(reduce_data, keepdims=True, axis=axes)) +@pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters_as_const) +@pytest.mark.parametrize("axes", reduce_axis_parameters) +def test_reduce_operation_keepdims_with_axes_as_const(operation, ref_operation, axes): + assert np.array_equal(import_and_compute_with_axes_as_const(operation, reduce_data, axes, keepdims=True), + ref_operation(reduce_data, keepdims=True, axis=axes)) + + @pytest.mark.parametrize("axes", [ pytest.param(None, marks=xfail_issue_35925), (0,), @@ -75,8 +115,8 @@ def test_reduce_operation_keepdims(operation, ref_operation, axes): (0, 2), (1, 2), pytest.param((0, 1, 2), marks=xfail_issue_35925)]) -@pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters) -def test_reduce_operation_no_keepdims(operation, ref_operation, axes): +@pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters_as_attr) +def test_reduce_operation_no_keepdims_axes_as_attr(operation, ref_operation, axes): if axes: 
assert np.array_equal(import_and_compute(operation, reduce_data, axes=axes, keepdims=False), ref_operation(reduce_data, keepdims=False, axis=axes)) @@ -85,6 +125,28 @@ def test_reduce_operation_no_keepdims(operation, ref_operation, axes): ref_operation(reduce_data, keepdims=False)) +@pytest.mark.parametrize("axes", [ + pytest.param(None, marks=xfail_issue_35925), + (0,), + (1,), + (2,), + (0, 1), + (0, 2), + (1, 2), + pytest.param((0, 1, 2), marks=xfail_issue_35925)]) +@pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters_as_const) +def test_reduce_operation_no_keepdims_axes_as_const(operation, ref_operation, axes): + if axes: + assert np.array_equal(import_and_compute_with_axes_as_const(operation, + reduce_data, + axes, + keepdims=False), + ref_operation(reduce_data, keepdims=False, axis=axes)) + else: + assert np.array_equal(import_and_compute(operation, reduce_data, keepdims=False), + ref_operation(reduce_data, keepdims=False)) + + @pytest.mark.parametrize("reduction_axes", [(0,), (0, 2), (0, 1, 2)]) def test_reduce_l1(reduction_axes): shape = [2, 4, 3, 2] diff --git a/ngraph/test/models/onnx/reduce_sum_13_axes_as_0_dim_input.prototxt b/ngraph/test/models/onnx/reduce_sum_13_axes_as_0_dim_input.prototxt new file mode 100644 index 00000000000..a840a9e35dd --- /dev/null +++ b/ngraph/test/models/onnx/reduce_sum_13_axes_as_0_dim_input.prototxt @@ -0,0 +1,75 @@ +ir_version: 7 +producer_name: "backend-test" +graph { + node { + input: "data" + input: "axes" + output: "reduced" + op_type: "ReduceSum" + attribute { + name: "keepdims" + i: 1 + type: INT + } + attribute { + name: "noop_with_empty_axes" + i: 1 + type: INT + } + } + name: "test_reduce_sum_empty_axes_input_noop_example" + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } + input { + name: "axes" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 0 + } + } + } + } + } + output { + name: "reduced" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/reduce_sum_13_axes_as_constant.prototxt b/ngraph/test/models/onnx/reduce_sum_13_axes_as_constant.prototxt new file mode 100644 index 00000000000..f3f2dd09c59 --- /dev/null +++ b/ngraph/test/models/onnx/reduce_sum_13_axes_as_constant.prototxt @@ -0,0 +1,67 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + output: "axes" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 4 + data_type: 7 + int64_data: 0 + int64_data: 1 + int64_data: 2 + int64_data: 3 + name: "const_tensor" + } + type: TENSOR + } + } + node { + input: "data" + input: "axes" + output: "B" + op_type: "ReduceSum" + } + name: "compute_graph" + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + dim { + dim_value: 4 + } + } + } + } + } + + output { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/reduce_sum_13_axes_as_constant_keepdims_off.prototxt b/ngraph/test/models/onnx/reduce_sum_13_axes_as_constant_keepdims_off.prototxt new file mode 100644 index 00000000000..066d22727b6 --- /dev/null +++ 
b/ngraph/test/models/onnx/reduce_sum_13_axes_as_constant_keepdims_off.prototxt @@ -0,0 +1,72 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + output: "axes" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 4 + data_type: 7 + int64_data: 0 + int64_data: 1 + int64_data: 2 + int64_data: 3 + name: "const_tensor" + } + type: TENSOR + } + } + node { + input: "data" + input: "axes" + attribute { + name: "keepdims" + i: 0 + type: INT + } + output: "B" + op_type: "ReduceSum" + } + name: "compute_graph" + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + dim { + dim_value: 4 + } + } + } + } + } + + output { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/reduce_sum_13_axes_as_constant_single_axis.prototxt b/ngraph/test/models/onnx/reduce_sum_13_axes_as_constant_single_axis.prototxt new file mode 100644 index 00000000000..4fe71a1dd07 --- /dev/null +++ b/ngraph/test/models/onnx/reduce_sum_13_axes_as_constant_single_axis.prototxt @@ -0,0 +1,67 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + output: "axes" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + int64_data: 1 + name: "const_tensor" + } + type: TENSOR + } + } + node { + input: "data" + input: "axes" + output: "B" + op_type: "ReduceSum" + } + name: "compute_graph" + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + } + } + } + } + + output { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/reduce_sum_13_axes_as_input.prototxt b/ngraph/test/models/onnx/reduce_sum_13_axes_as_input.prototxt new file mode 100644 index 00000000000..6a458e74393 --- /dev/null +++ b/ngraph/test/models/onnx/reduce_sum_13_axes_as_input.prototxt @@ -0,0 +1,63 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "data" + input: "axes" + output: "B" + op_type: "ReduceSum" + } + name: "compute_graph" + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "axes" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 4 + } + } + } + } + } + + output { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/reduce_sum_13_axes_empty.prototxt b/ngraph/test/models/onnx/reduce_sum_13_axes_empty.prototxt new file mode 100644 index 00000000000..44d655723cf --- /dev/null +++ b/ngraph/test/models/onnx/reduce_sum_13_axes_empty.prototxt @@ -0,0 +1,49 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "data" + output: "B" + op_type: "ReduceSum" + } + name: "compute_graph" + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + dim { + dim_value: 4 + } + } + } + } + } + + output { + name: "B" + type { + 
tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/reduce_sum_13_axes_empty_dynamic_rank_input.prototxt b/ngraph/test/models/onnx/reduce_sum_13_axes_empty_dynamic_rank_input.prototxt new file mode 100644 index 00000000000..61bb13255b4 --- /dev/null +++ b/ngraph/test/models/onnx/reduce_sum_13_axes_empty_dynamic_rank_input.prototxt @@ -0,0 +1,35 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "data" + output: "B" + op_type: "ReduceSum" + } + name: "compute_graph" + input { + name: "data" + type { + tensor_type { + elem_type: 1 + } + } + } + + output { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/reduce_sum_13_axes_empty_with_noop.prototxt b/ngraph/test/models/onnx/reduce_sum_13_axes_empty_with_noop.prototxt new file mode 100644 index 00000000000..e44b67dcd83 --- /dev/null +++ b/ngraph/test/models/onnx/reduce_sum_13_axes_empty_with_noop.prototxt @@ -0,0 +1,63 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "data" + output: "B" + op_type: "ReduceSum" + attribute { + name: "noop_with_empty_axes" + i: 1 + type: INT + } + } + name: "compute_graph" + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + dim { + dim_value: 4 + } + } + } + } + } + + output { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/reduce_sum_13_axes_empty_without_noop.prototxt b/ngraph/test/models/onnx/reduce_sum_13_axes_empty_without_noop.prototxt new file mode 100644 index 00000000000..73c22f7739c --- /dev/null +++ b/ngraph/test/models/onnx/reduce_sum_13_axes_empty_without_noop.prototxt @@ -0,0 +1,63 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "data" + output: "B" + op_type: "ReduceSum" + attribute { + name: "noop_with_empty_axes" + i: 0 + type: INT + } + } + name: "compute_graph" + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + dim { + dim_value: 4 + } + } + } + } + } + + output { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/reduce_sum_13_input_dynamic.prototxt b/ngraph/test/models/onnx/reduce_sum_13_input_dynamic.prototxt new file mode 100644 index 00000000000..bacf4619341 --- /dev/null +++ b/ngraph/test/models/onnx/reduce_sum_13_input_dynamic.prototxt @@ -0,0 +1,89 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + output: "B" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + int64_data: -1 + name: "const_tensor" + } + type: TENSOR + } + } + node { + input: "A" + input: "B" + output: "X" + attribute { + name: "special_zero" + i: 1 + type: INT + } + op_type: "Reshape" + } + node { + input: "X" + output: "C" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + int64_data: 5 + name: 
"const_tensor" + } + type: TENSOR + } + op_type: "ConstantOfShape" + } + node { + input: "C" + output: "Y" + op_type: "ReduceSum" + } + name: "compute_graph" + input { + name: "A" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 1 + } + dim { + dim_value: 4 + } + dim { + dim_value: 4 + } + } + } + } + } + + output { + name: "Y" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + version: 13 +} diff --git a/ngraph/test/models/onnx/reduce_sum_dynamic_rank_input.prototxt b/ngraph/test/models/onnx/reduce_sum_dynamic_rank_input.prototxt new file mode 100644 index 00000000000..3a9f2fd52ff --- /dev/null +++ b/ngraph/test/models/onnx/reduce_sum_dynamic_rank_input.prototxt @@ -0,0 +1,34 @@ +ir_version: 3 +producer_name: "nGraph ONNX Importer" +graph { + node { + input: "A" + output: "B" + op_type: "ReduceSum" + } + name: "compute_graph" + input { + name: "A" + type { + tensor_type { + elem_type: 1 + } + } + } + output { + name: "B" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + } + } + } + } +} +opset_import { + version: 1 +} diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp index f4472b9179e..aef86887575 100644 --- a/ngraph/test/onnx/onnx_import.in.cpp +++ b/ngraph/test/onnx/onnx_import.in.cpp @@ -1102,6 +1102,33 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum) test_case.run(); } +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_dynamic_rank_input) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_dynamic_rank_input.prototxt")); + auto test_case = test::TestCase(function); + test_case.add_input(Shape{1, 1, 4, 4}, + {1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f}); + + test_case.add_expected_output(Shape{1, 1, 1, 1}, {16.0f}); + test_case.run(); +} + NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_square) { auto function = onnx_import::import_onnx_model( @@ -1121,6 +1148,242 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_square) test_case.run(); } +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_as_constant) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_as_constant.prototxt")); + + Inputs inputs{test::NDArray({{{{1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 1.0f, 1.0f}}}}) + .get_vector()}; + + auto test_case = test::TestCase(function); + + test_case.add_expected_output(Shape{1, 1, 1, 1}, {16.0f}); + + test_case.add_multiple_inputs(inputs); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_as_constant_single_axis) +{ + auto function = onnx_import::import_onnx_model(file_util::path_join( + SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_as_constant_single_axis.prototxt")); + + Inputs inputs{ + test::NDArray({{{1, 2, 3}, {4, 5, 6}}, {{7, 8, 9}, {10, 11, 12}}}).get_vector()}; + + auto test_case = test::TestCase(function); + + test_case.add_expected_output(Shape{2, 1, 3}, {5.0f, 7.0f, 9.0f, 17.0f, 19.0f, 21.0f}); + + test_case.add_multiple_inputs(inputs); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_as_constant_keepdims_off) +{ + auto function = onnx_import::import_onnx_model(file_util::path_join( + SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_as_constant_keepdims_off.prototxt")); 
+ + // input data shape (1, 1, 4, 4) + Inputs inputs{test::NDArray({{{{1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 1.0f, 1.0f}}}}) + .get_vector()}; + + auto test_case = test::TestCase(function); + + test_case.add_expected_output(Shape{}, {16.0f}); + + test_case.add_multiple_inputs(inputs); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_as_input) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_as_input.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_input({1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f}); + test_case.add_input({0, 1, 2, 3}); + + test_case.add_expected_output(Shape{1, 1, 1, 1}, {16.0f}); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_as_0_dim_input) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_as_0_dim_input.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_input( + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}); + + test_case.add_expected_output( + Shape{3, 2, 2}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_input_dynamic) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_13_input_dynamic.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_input({1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + + test_case.add_expected_output(Shape{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {5}); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_empty) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_empty.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_input({1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f}); + + test_case.add_expected_output(Shape{1, 1, 1, 1}, {16.0f}); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_empty_dynamic_rank_input) +{ + auto function = onnx_import::import_onnx_model(file_util::path_join( + SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_empty_dynamic_rank_input.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_input(Shape{1, 1, 4, 4}, + {1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f}); + + test_case.add_expected_output(Shape{1, 1, 1, 1}, {16.0f}); + test_case.run(); +} + +NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_empty_with_noop) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_empty_with_noop.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_input({1.f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f}); + + test_case.add_expected_output(Shape{1, 1, 4, 4}, + {1.f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f, + 1.0f}); + test_case.run(); +} + 
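+// In contrast to the noop case above, the next model omits the 'axes' input and sets
+// noop_with_empty_axes to 0, so the importer reduces over every axis and the
+// (1, 1, 4, 4) tensor of ones sums to a single 16.0f value.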
+NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_empty_without_noop)
+{
+    auto function = onnx_import::import_onnx_model(file_util::path_join(
+        SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_empty_without_noop.prototxt"));
+
+    auto test_case = test::TestCase<TestEngine>(function);
+    test_case.add_input<float>({1.f,
+                                1.0f,
+                                1.0f,
+                                1.0f,
+                                1.0f,
+                                1.0f,
+                                1.0f,
+                                1.0f,
+                                1.0f,
+                                1.0f,
+                                1.0f,
+                                1.0f,
+                                1.0f,
+                                1.0f,
+                                1.0f,
+                                1.0f});
+
+    test_case.add_expected_output<float>(Shape{1, 1, 1, 1}, {16.0f});
+    test_case.run();
+}
+
 NGRAPH_TEST(${BACKEND_NAME}, onnx_resize11_empty_constant_as_input)
 {
     // this model contains a Constant node with an empty underlying tensor
diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest
index e41d779f697..67f6929685a 100644
--- a/ngraph/test/runtime/ie/unit_test.manifest
+++ b/ngraph/test/runtime/ie/unit_test.manifest
@@ -1168,6 +1168,13 @@ IE_CPU.roi_pooling_1x1_bilinear
 
 # Unsupported dynamic op
 IE_CPU.range_v4_trunc_inputs
+IE_CPU.onnx_model_reduce_sum_13_axes_as_input
+IE_CPU.onnx_model_reduce_sum_13_input_dynamic
+IE_CPU.onnx_model_reduce_sum_13_axes_empty_dynamic_rank_input
+IE_CPU.onnx_model_reduce_sum_dynamic_rank_input
+
+# Axes has zero dimension which is not allowed
+IE_CPU.onnx_model_reduce_sum_13_axes_as_0_dim_input
 
 # output mismatch
 IE_CPU.gather_nd_batch_1d_from_3d_negative
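
For reference, the snippet below is a minimal, self-contained sketch (not part of the patch) of the scenario the new set_13 importer path and the reduce_sum_13_axes_as_input.prototxt model cover: an opset-13 ReduceSum whose axes arrive as a second graph input rather than as an attribute. It uses only onnx.helper and NumPy for the reference value; the helper name build_reduce_sum_13 is made up for the example and is not part of this change.

# Illustrative sketch (not part of the patch): builds an opset-13 ReduceSum model with
# 'axes' supplied as a graph input -- the case handled by get_reduction_axes_from_input --
# and computes the expected result with NumPy. The name build_reduce_sum_13 is hypothetical.
import numpy as np
import onnx
from onnx import helper, TensorProto


def build_reduce_sum_13(data_shape, keepdims=1, noop_with_empty_axes=0):
    node = helper.make_node(
        "ReduceSum",
        inputs=["data", "axes"],
        outputs=["reduced"],
        keepdims=keepdims,
        noop_with_empty_axes=noop_with_empty_axes,
    )
    graph = helper.make_graph(
        [node],
        "reduce_sum_13_axes_as_input",
        [
            helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
            # Dynamic 1-D axes input, provided at runtime.
            helper.make_tensor_value_info("axes", TensorProto.INT64, [None]),
        ],
        [helper.make_tensor_value_info("reduced", TensorProto.FLOAT, None)],
    )
    model = helper.make_model(graph, producer_name="reduce-sum-13-example",
                              opset_imports=[helper.make_opsetid("", 13)])
    onnx.checker.check_model(model)
    return model


model = build_reduce_sum_13([1, 1, 4, 4])
data = np.ones((1, 1, 4, 4), dtype=np.float32)
axes = np.array([0, 1, 2, 3], dtype=np.int64)
# NumPy reference: reducing a (1, 1, 4, 4) tensor of ones over all axes with
# keepdims=1 yields a (1, 1, 1, 1) tensor holding 16.0, matching the test vectors above.
expected = np.sum(data, axis=tuple(axes), keepdims=True)
print(expected.shape, expected.ravel()[0])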