Add support for ONNX Operator ReduceSum v13 and revise other Reduce operators (#3605)
This commit is contained in:
parent
e20a58d770
commit
928201bee4
@ -21,6 +21,7 @@
|
|||||||
#include "exceptions.hpp"
|
#include "exceptions.hpp"
|
||||||
#include "ngraph/builder/norm.hpp"
|
#include "ngraph/builder/norm.hpp"
|
||||||
#include "ngraph/node.hpp"
|
#include "ngraph/node.hpp"
|
||||||
|
#include "op/identity.hpp"
|
||||||
#include "utils/common.hpp"
|
#include "utils/common.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph
|
||||||
@ -31,7 +32,64 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
std::shared_ptr<default_opset::Constant> get_reduction_axes(const Node& node)
|
std::shared_ptr<ngraph::Node> get_dynamic_all_axes_range(const Node& node)
|
||||||
|
{
|
||||||
|
const auto input = node.get_ng_inputs().at(0);
|
||||||
|
const auto shape_of_input = std::make_shared<default_opset::ShapeOf>(input);
|
||||||
|
const auto scalar =
|
||||||
|
default_opset::Constant::create(element::i32, Shape{1}, {0});
|
||||||
|
const auto rank_of_input =
|
||||||
|
std::make_shared<default_opset::ShapeOf>(shape_of_input);
|
||||||
|
const auto rank_of_input_scalar =
|
||||||
|
std::make_shared<default_opset::Squeeze>(rank_of_input, scalar);
|
||||||
|
const auto start = default_opset::Constant::create(element::i32, Shape{}, {0});
|
||||||
|
const auto step = default_opset::Constant::create(element::i32, Shape{}, {1});
|
||||||
|
return std::make_shared<default_opset::Range>(
|
||||||
|
start, rank_of_input_scalar, step, element::i64);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> get_reduction_axes_from_input(const Node& node)
|
||||||
|
{
|
||||||
|
const std::int64_t noop_with_empty_axes =
|
||||||
|
node.get_attribute_value<std::int64_t>("noop_with_empty_axes", 0);
|
||||||
|
const auto input = node.get_ng_inputs().at(0);
|
||||||
|
const auto input_rank = node.get_ng_inputs().at(0).get_partial_shape().rank();
|
||||||
|
if (node.get_ng_inputs().size() > 1)
|
||||||
|
{
|
||||||
|
const auto reduction_axes = node.get_ng_inputs().at(1);
|
||||||
|
const auto reduction_axes_rank = reduction_axes.get_partial_shape().rank();
|
||||||
|
NGRAPH_CHECK(reduction_axes.get_partial_shape().is_static(),
|
||||||
|
"The axes tensor's shape needs to be known(static). Node: ",
|
||||||
|
node.get_description());
|
||||||
|
|
||||||
|
if (reduction_axes_rank.get_length() != 0 &&
|
||||||
|
reduction_axes.get_shape() != Shape{0})
|
||||||
|
{
|
||||||
|
return reduction_axes.get_node_shared_ptr();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (noop_with_empty_axes)
|
||||||
|
{
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (input_rank.is_static())
|
||||||
|
{
|
||||||
|
auto all_axes = onnx_import::common::get_monotonic_range<int64_t>(
|
||||||
|
input_rank.get_length());
|
||||||
|
return default_opset::Constant::create(
|
||||||
|
element::i64, Shape{all_axes.size()}, all_axes);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return get_dynamic_all_axes_range(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Node> get_reduction_axes_from_attr(const Node& node)
|
||||||
{
|
{
|
||||||
auto reduction_axes =
|
auto reduction_axes =
|
||||||
node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
|
node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
|
||||||
@ -40,13 +98,15 @@ namespace ngraph
|
|||||||
|
|
||||||
if (reduction_axes.empty())
|
if (reduction_axes.empty())
|
||||||
{
|
{
|
||||||
NGRAPH_CHECK(input_rank.is_static(),
|
if (input_rank.is_static())
|
||||||
"The input tensor's rank needs to be known(static) when the "
|
{
|
||||||
"'axes' attribute is not specified. Node: ",
|
reduction_axes = onnx_import::common::get_monotonic_range<int64_t>(
|
||||||
node.get_description());
|
input_rank.get_length());
|
||||||
|
}
|
||||||
reduction_axes = onnx_import::common::get_monotonic_range<int64_t>(
|
else
|
||||||
input_rank.get_length());
|
{
|
||||||
|
return get_dynamic_all_axes_range(node);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (input_rank.is_static())
|
if (input_rank.is_static())
|
||||||
@ -66,19 +126,36 @@ namespace ngraph
|
|||||||
|
|
||||||
template <typename OpType>
|
template <typename OpType>
|
||||||
std::shared_ptr<ngraph::Node>
|
std::shared_ptr<ngraph::Node>
|
||||||
make_ng_reduction_op(const Node& node, const Output<ngraph::Node>& ng_input)
|
make_ng_reduction_op(const Node& node,
|
||||||
|
const Output<ngraph::Node>& ng_input,
|
||||||
|
bool axes_as_attr = true)
|
||||||
{
|
{
|
||||||
const auto reduction_axes = get_reduction_axes(node);
|
|
||||||
const std::int64_t keepdims =
|
const std::int64_t keepdims =
|
||||||
node.get_attribute_value<std::int64_t>("keepdims", 1);
|
node.get_attribute_value<std::int64_t>("keepdims", 1);
|
||||||
|
|
||||||
const auto op_node = std::make_shared<OpType>(
|
const auto reduction_axes = axes_as_attr ? get_reduction_axes_from_attr(node)
|
||||||
ng_input, reduction_axes, static_cast<bool>(keepdims));
|
: get_reduction_axes_from_input(node);
|
||||||
|
if (reduction_axes != nullptr)
|
||||||
return op_node;
|
{
|
||||||
|
return std::make_shared<OpType>(
|
||||||
|
ng_input, reduction_axes, static_cast<bool>(keepdims));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return op::set_1::identity(node).at(0).get_node_shared_ptr();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace set_13
|
||||||
|
{
|
||||||
|
OutputVector reduce_sum(const Node& node)
|
||||||
|
{
|
||||||
|
return {make_ng_reduction_op<default_opset::ReduceSum>(
|
||||||
|
node, node.get_ng_inputs().at(0), false)};
|
||||||
|
}
|
||||||
|
} // namespace set_13
|
||||||
|
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
OutputVector reduce_log_sum(const Node& node)
|
OutputVector reduce_log_sum(const Node& node)
|
||||||
|
@ -24,6 +24,22 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
namespace op
|
namespace op
|
||||||
{
|
{
|
||||||
|
namespace set_13
|
||||||
|
{
|
||||||
|
/// \brief Compute the sum of the input tensor's elements along the provided
|
||||||
|
/// axes.
|
||||||
|
///
|
||||||
|
/// \par Overview
|
||||||
|
/// The output tensor has the same rank as the input if Node attribute keepdims
|
||||||
|
/// equals 1. If keepdims equals 0, then the output tensor has the reduced
|
||||||
|
/// dimension pruned.
|
||||||
|
///
|
||||||
|
/// \param[in] node The ONNX node representing operation.
|
||||||
|
///
|
||||||
|
/// \return The nGraph node equivalent of the ONNX operation.
|
||||||
|
///
|
||||||
|
OutputVector reduce_sum(const Node& node);
|
||||||
|
}
|
||||||
namespace set_1
|
namespace set_1
|
||||||
{
|
{
|
||||||
/// \brief Compute the log sum of the input tensor's elements along the
|
/// \brief Compute the log sum of the input tensor's elements along the
|
||||||
|
@ -404,6 +404,7 @@ namespace ngraph
|
|||||||
REGISTER_OPERATOR("ReduceMin", 1, reduce_min);
|
REGISTER_OPERATOR("ReduceMin", 1, reduce_min);
|
||||||
REGISTER_OPERATOR("ReduceProd", 1, reduce_prod);
|
REGISTER_OPERATOR("ReduceProd", 1, reduce_prod);
|
||||||
REGISTER_OPERATOR("ReduceSum", 1, reduce_sum);
|
REGISTER_OPERATOR("ReduceSum", 1, reduce_sum);
|
||||||
|
REGISTER_OPERATOR("ReduceSum", 13, reduce_sum);
|
||||||
REGISTER_OPERATOR("ReduceSumSquare", 1, reduce_sum_square);
|
REGISTER_OPERATOR("ReduceSumSquare", 1, reduce_sum_square);
|
||||||
REGISTER_OPERATOR("Relu", 1, relu);
|
REGISTER_OPERATOR("Relu", 1, relu);
|
||||||
REGISTER_OPERATOR("Reshape", 1, reshape);
|
REGISTER_OPERATOR("Reshape", 1, reshape);
|
||||||
|
@ -167,8 +167,8 @@ xfail_issue_38735 = xfail_test(reason="RuntimeError: nGraph does not support the
|
|||||||
"ai.onnx.preview.training.Adagrad")
|
"ai.onnx.preview.training.Adagrad")
|
||||||
xfail_issue_38736 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
|
xfail_issue_38736 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
|
||||||
"NegativeLogLikelihoodLoss")
|
"NegativeLogLikelihoodLoss")
|
||||||
xfail_issue_43523 = xfail_test(reason="onnx.onnx_cpp2py_export.checker.ValidationError:"
|
xfail_issue_45177 = xfail_test(reason="RuntimeError: axes has zero dimension which is not allowed")
|
||||||
" Unrecognized attribute: axes for operator ReduceSum")
|
xfail_issue_45180 = xfail_test(reason="RuntimeError: Unsupported dynamic op: ReduceSum")
|
||||||
xfail_issue_44839 = xfail_test(reason="Huge computation missmatch")
|
xfail_issue_44839 = xfail_test(reason="Huge computation missmatch")
|
||||||
xfail_issue_44848 = xfail_test(reason="E Unsupported dynamic op: Range")
|
xfail_issue_44848 = xfail_test(reason="E Unsupported dynamic op: Range")
|
||||||
xfail_issue_44851 = xfail_test(reason="E Unsupported dynamic op: Broadcast")
|
xfail_issue_44851 = xfail_test(reason="E Unsupported dynamic op: Broadcast")
|
||||||
|
@ -76,7 +76,8 @@ from tests import (BACKEND_NAME,
|
|||||||
xfail_issue_40319,
|
xfail_issue_40319,
|
||||||
xfail_issue_40485,
|
xfail_issue_40485,
|
||||||
xfail_issue_41894,
|
xfail_issue_41894,
|
||||||
xfail_issue_43523,
|
xfail_issue_45177,
|
||||||
|
xfail_issue_45180,
|
||||||
xfail_issue_43742,
|
xfail_issue_43742,
|
||||||
xfail_issue_44839,
|
xfail_issue_44839,
|
||||||
xfail_issue_44848,
|
xfail_issue_44848,
|
||||||
@ -632,38 +633,27 @@ tests_expected_to_fail = [
|
|||||||
"OnnxBackendNodeModelTest.test_adagrad_cpu"),
|
"OnnxBackendNodeModelTest.test_adagrad_cpu"),
|
||||||
(xfail_issue_41894,
|
(xfail_issue_41894,
|
||||||
"OnnxBackendNodeModelTest.test_max_uint16_cpu"),
|
"OnnxBackendNodeModelTest.test_max_uint16_cpu"),
|
||||||
(xfail_issue_43523,
|
(xfail_issue_45177,
|
||||||
"OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_example_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_random_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_reduce_sum_keepdims_example_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_reduce_sum_keepdims_random_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_reduce_sum_negative_axes_keepdims_example_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_reduce_sum_default_axes_keepdims_example_cpu",
|
"OnnxBackendNodeModelTest.test_reduce_sum_default_axes_keepdims_example_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_reduce_sum_default_axes_keepdims_random_cpu",
|
"OnnxBackendNodeModelTest.test_reduce_sum_default_axes_keepdims_random_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_reduce_sum_empty_axes_input_noop_example_cpu",
|
"OnnxBackendNodeModelTest.test_reduce_sum_empty_axes_input_noop_example_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_reduce_sum_empty_axes_input_noop_random_cpu",
|
"OnnxBackendNodeModelTest.test_reduce_sum_empty_axes_input_noop_random_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_reduce_sum_negative_axes_keepdims_random_cpu"),
|
"OnnxBackendNodeModelTest.test_reduce_sum_negative_axes_keepdims_random_cpu"),
|
||||||
|
(xfail_issue_45180,
|
||||||
|
"OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_example_cpu",
|
||||||
|
"OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_random_cpu",
|
||||||
|
"OnnxBackendNodeModelTest.test_reduce_sum_keepdims_example_cpu",
|
||||||
|
"OnnxBackendNodeModelTest.test_reduce_sum_keepdims_random_cpu",
|
||||||
|
"OnnxBackendNodeModelTest.test_reduce_sum_negative_axes_keepdims_example_cpu"),
|
||||||
(xfail_issue_43742,
|
(xfail_issue_43742,
|
||||||
"OnnxBackendNodeModelTest.test_if_cpu",
|
"OnnxBackendNodeModelTest.test_if_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_if_seq_cpu"),
|
"OnnxBackendNodeModelTest.test_if_seq_cpu"),
|
||||||
(xfail_issue_44839,
|
(xfail_issue_44839,
|
||||||
"OnnxBackendNodeModelTest.test_logsoftmax_axis_0_cpu",
|
"OnnxBackendNodeModelTest.test_logsoftmax_axis_0_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_logsoftmax_axis_0_expanded_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_logsoftmax_axis_1_cpu",
|
"OnnxBackendNodeModelTest.test_logsoftmax_axis_1_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_logsoftmax_axis_1_expanded_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_logsoftmax_axis_2_expanded_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_logsoftmax_default_axis_expanded_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_logsoftmax_large_number_expanded_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_logsoftmax_negative_axis_expanded_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_softmax_axis_0_cpu",
|
"OnnxBackendNodeModelTest.test_softmax_axis_0_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_softmax_axis_0_expanded_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_softmax_axis_1_cpu",
|
"OnnxBackendNodeModelTest.test_softmax_axis_1_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_softmax_axis_1_expanded_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_softmax_axis_2_expanded_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_softmax_default_axis_cpu",
|
"OnnxBackendNodeModelTest.test_softmax_default_axis_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_softmax_default_axis_expanded_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_softmax_large_number_expanded_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_softmax_negative_axis_expanded_cpu",
|
|
||||||
"OnnxBackendNodeModelTest.test_hardmax_axis_0_cpu",
|
"OnnxBackendNodeModelTest.test_hardmax_axis_0_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_hardmax_axis_1_cpu",
|
"OnnxBackendNodeModelTest.test_hardmax_axis_1_cpu",
|
||||||
"OnnxBackendNodeModelTest.test_hardmax_default_axis_cpu",),
|
"OnnxBackendNodeModelTest.test_hardmax_default_axis_cpu",),
|
||||||
|
@ -17,9 +17,12 @@ import numpy as np
|
|||||||
import onnx
|
import onnx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.test_onnx.utils import run_node
|
from tests.runtime import get_runtime
|
||||||
from tests import (xfail_issue_35925,
|
from tests.test_onnx.utils import (
|
||||||
xfail_issue_43523)
|
run_node,
|
||||||
|
import_onnx_model,
|
||||||
|
)
|
||||||
|
from tests import xfail_issue_35925
|
||||||
|
|
||||||
reduce_data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
|
reduce_data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32)
|
||||||
reduce_axis_parameters = [
|
reduce_axis_parameters = [
|
||||||
@ -32,14 +35,17 @@ reduce_axis_parameters = [
|
|||||||
(0, 1, 2)
|
(0, 1, 2)
|
||||||
]
|
]
|
||||||
|
|
||||||
reduce_operation_parameters = [
|
reduce_operation_parameters_as_attr = [
|
||||||
("ReduceMax", np.max),
|
("ReduceMax", np.max),
|
||||||
("ReduceMin", np.min),
|
("ReduceMin", np.min),
|
||||||
("ReduceMean", np.mean),
|
("ReduceMean", np.mean),
|
||||||
pytest.param("ReduceSum", np.sum, marks=xfail_issue_43523),
|
|
||||||
("ReduceProd", np.prod)
|
("ReduceProd", np.prod)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
reduce_operation_parameters_as_const = [
|
||||||
|
("ReduceSum", np.sum),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def import_and_compute(op_type, input_data, **node_attrs):
|
def import_and_compute(op_type, input_data, **node_attrs):
|
||||||
data_inputs = [np.array(input_data)]
|
data_inputs = [np.array(input_data)]
|
||||||
@ -47,25 +53,59 @@ def import_and_compute(op_type, input_data, **node_attrs):
|
|||||||
return run_node(node, data_inputs).pop()
|
return run_node(node, data_inputs).pop()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("operation, ref_operation", [
|
def import_and_compute_with_axes_as_const(op_type, data, axes, **node_attrs):
|
||||||
("ReduceMax", np.max),
|
data_input = np.array(data)
|
||||||
("ReduceMin", np.min),
|
axes_input = np.array(axes, dtype=int)
|
||||||
("ReduceMean", np.mean),
|
axes_const_node = onnx.helper.make_node(
|
||||||
("ReduceSum", np.sum),
|
"Constant",
|
||||||
("ReduceProd", np.prod)
|
inputs=[],
|
||||||
])
|
outputs=["const_axes"],
|
||||||
|
value=onnx.helper.make_tensor(
|
||||||
|
name="const_axes",
|
||||||
|
data_type=onnx.TensorProto.INT64,
|
||||||
|
dims=axes_input.shape,
|
||||||
|
vals=axes_input.flatten(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
node = onnx.helper.make_node(
|
||||||
|
op_type, inputs=["x", "const_axes"], outputs=["y"], **node_attrs
|
||||||
|
)
|
||||||
|
graph = onnx.helper.make_graph(
|
||||||
|
[axes_const_node, node],
|
||||||
|
"test_graph",
|
||||||
|
[onnx.helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, data_input.shape)],
|
||||||
|
[onnx.helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, ())],
|
||||||
|
)
|
||||||
|
|
||||||
|
model = onnx.helper.make_model(graph, producer_name="ngraph ONNX Importer")
|
||||||
|
model.opset_import[0].version = 13
|
||||||
|
ng_model_function = import_onnx_model(model)
|
||||||
|
runtime = get_runtime()
|
||||||
|
computation = runtime.computation(ng_model_function)
|
||||||
|
return computation(data_input)[0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("operation, ref_operation",
|
||||||
|
reduce_operation_parameters_as_attr + reduce_operation_parameters_as_const)
|
||||||
def test_reduce_operation_keepdims_none_axes(operation, ref_operation):
|
def test_reduce_operation_keepdims_none_axes(operation, ref_operation):
|
||||||
assert np.array_equal(import_and_compute(operation, reduce_data, keepdims=True),
|
assert np.array_equal(import_and_compute(operation, reduce_data, keepdims=True),
|
||||||
ref_operation(reduce_data, keepdims=True))
|
ref_operation(reduce_data, keepdims=True))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters)
|
@pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters_as_attr)
|
||||||
@pytest.mark.parametrize("axes", reduce_axis_parameters)
|
@pytest.mark.parametrize("axes", reduce_axis_parameters)
|
||||||
def test_reduce_operation_keepdims(operation, ref_operation, axes):
|
def test_reduce_operation_keepdims_with_axes_as_attr(operation, ref_operation, axes):
|
||||||
assert np.array_equal(import_and_compute(operation, reduce_data, axes=axes, keepdims=True),
|
assert np.array_equal(import_and_compute(operation, reduce_data, axes=axes, keepdims=True),
|
||||||
ref_operation(reduce_data, keepdims=True, axis=axes))
|
ref_operation(reduce_data, keepdims=True, axis=axes))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters_as_const)
|
||||||
|
@pytest.mark.parametrize("axes", reduce_axis_parameters)
|
||||||
|
def test_reduce_operation_keepdims_with_axes_as_const(operation, ref_operation, axes):
|
||||||
|
assert np.array_equal(import_and_compute_with_axes_as_const(operation, reduce_data, axes, keepdims=True),
|
||||||
|
ref_operation(reduce_data, keepdims=True, axis=axes))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("axes", [
|
@pytest.mark.parametrize("axes", [
|
||||||
pytest.param(None, marks=xfail_issue_35925),
|
pytest.param(None, marks=xfail_issue_35925),
|
||||||
(0,),
|
(0,),
|
||||||
@ -75,8 +115,8 @@ def test_reduce_operation_keepdims(operation, ref_operation, axes):
|
|||||||
(0, 2),
|
(0, 2),
|
||||||
(1, 2),
|
(1, 2),
|
||||||
pytest.param((0, 1, 2), marks=xfail_issue_35925)])
|
pytest.param((0, 1, 2), marks=xfail_issue_35925)])
|
||||||
@pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters)
|
@pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters_as_attr)
|
||||||
def test_reduce_operation_no_keepdims(operation, ref_operation, axes):
|
def test_reduce_operation_no_keepdims_axes_as_attr(operation, ref_operation, axes):
|
||||||
if axes:
|
if axes:
|
||||||
assert np.array_equal(import_and_compute(operation, reduce_data, axes=axes, keepdims=False),
|
assert np.array_equal(import_and_compute(operation, reduce_data, axes=axes, keepdims=False),
|
||||||
ref_operation(reduce_data, keepdims=False, axis=axes))
|
ref_operation(reduce_data, keepdims=False, axis=axes))
|
||||||
@ -85,6 +125,28 @@ def test_reduce_operation_no_keepdims(operation, ref_operation, axes):
|
|||||||
ref_operation(reduce_data, keepdims=False))
|
ref_operation(reduce_data, keepdims=False))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("axes", [
|
||||||
|
pytest.param(None, marks=xfail_issue_35925),
|
||||||
|
(0,),
|
||||||
|
(1,),
|
||||||
|
(2,),
|
||||||
|
(0, 1),
|
||||||
|
(0, 2),
|
||||||
|
(1, 2),
|
||||||
|
pytest.param((0, 1, 2), marks=xfail_issue_35925)])
|
||||||
|
@pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters_as_const)
|
||||||
|
def test_reduce_operation_no_keepdims_axes_as_const(operation, ref_operation, axes):
|
||||||
|
if axes:
|
||||||
|
assert np.array_equal(import_and_compute_with_axes_as_const(operation,
|
||||||
|
reduce_data,
|
||||||
|
axes,
|
||||||
|
keepdims=False),
|
||||||
|
ref_operation(reduce_data, keepdims=False, axis=axes))
|
||||||
|
else:
|
||||||
|
assert np.array_equal(import_and_compute(operation, reduce_data, keepdims=False),
|
||||||
|
ref_operation(reduce_data, keepdims=False))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("reduction_axes", [(0,), (0, 2), (0, 1, 2)])
|
@pytest.mark.parametrize("reduction_axes", [(0,), (0, 2), (0, 1, 2)])
|
||||||
def test_reduce_l1(reduction_axes):
|
def test_reduce_l1(reduction_axes):
|
||||||
shape = [2, 4, 3, 2]
|
shape = [2, 4, 3, 2]
|
||||||
|
@ -0,0 +1,75 @@
|
|||||||
|
ir_version: 7
|
||||||
|
producer_name: "backend-test"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
input: "axes"
|
||||||
|
output: "reduced"
|
||||||
|
op_type: "ReduceSum"
|
||||||
|
attribute {
|
||||||
|
name: "keepdims"
|
||||||
|
i: 1
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
attribute {
|
||||||
|
name: "noop_with_empty_axes"
|
||||||
|
i: 1
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name: "test_reduce_sum_empty_axes_input_noop_example"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
name: "axes"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 7
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "reduced"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
@ -0,0 +1,67 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
output: "axes"
|
||||||
|
op_type: "Constant"
|
||||||
|
attribute {
|
||||||
|
name: "value"
|
||||||
|
t {
|
||||||
|
dims: 4
|
||||||
|
data_type: 7
|
||||||
|
int64_data: 0
|
||||||
|
int64_data: 1
|
||||||
|
int64_data: 2
|
||||||
|
int64_data: 3
|
||||||
|
name: "const_tensor"
|
||||||
|
}
|
||||||
|
type: TENSOR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
input: "axes"
|
||||||
|
output: "B"
|
||||||
|
op_type: "ReduceSum"
|
||||||
|
}
|
||||||
|
name: "compute_graph"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output {
|
||||||
|
name: "B"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
@ -0,0 +1,72 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
output: "axes"
|
||||||
|
op_type: "Constant"
|
||||||
|
attribute {
|
||||||
|
name: "value"
|
||||||
|
t {
|
||||||
|
dims: 4
|
||||||
|
data_type: 7
|
||||||
|
int64_data: 0
|
||||||
|
int64_data: 1
|
||||||
|
int64_data: 2
|
||||||
|
int64_data: 3
|
||||||
|
name: "const_tensor"
|
||||||
|
}
|
||||||
|
type: TENSOR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
input: "axes"
|
||||||
|
attribute {
|
||||||
|
name: "keepdims"
|
||||||
|
i: 0
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
output: "B"
|
||||||
|
op_type: "ReduceSum"
|
||||||
|
}
|
||||||
|
name: "compute_graph"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output {
|
||||||
|
name: "B"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
@ -0,0 +1,67 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
output: "axes"
|
||||||
|
op_type: "Constant"
|
||||||
|
attribute {
|
||||||
|
name: "value"
|
||||||
|
t {
|
||||||
|
dims: 1
|
||||||
|
data_type: 7
|
||||||
|
int64_data: 1
|
||||||
|
name: "const_tensor"
|
||||||
|
}
|
||||||
|
type: TENSOR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
input: "axes"
|
||||||
|
output: "B"
|
||||||
|
op_type: "ReduceSum"
|
||||||
|
}
|
||||||
|
name: "compute_graph"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output {
|
||||||
|
name: "B"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
63
ngraph/test/models/onnx/reduce_sum_13_axes_as_input.prototxt
Normal file
63
ngraph/test/models/onnx/reduce_sum_13_axes_as_input.prototxt
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
input: "axes"
|
||||||
|
output: "B"
|
||||||
|
op_type: "ReduceSum"
|
||||||
|
}
|
||||||
|
name: "compute_graph"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
name: "axes"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 7
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output {
|
||||||
|
name: "B"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
49
ngraph/test/models/onnx/reduce_sum_13_axes_empty.prototxt
Normal file
49
ngraph/test/models/onnx/reduce_sum_13_axes_empty.prototxt
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
output: "B"
|
||||||
|
op_type: "ReduceSum"
|
||||||
|
}
|
||||||
|
name: "compute_graph"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output {
|
||||||
|
name: "B"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
@ -0,0 +1,35 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
output: "B"
|
||||||
|
op_type: "ReduceSum"
|
||||||
|
}
|
||||||
|
name: "compute_graph"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output {
|
||||||
|
name: "B"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
@ -0,0 +1,63 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
output: "B"
|
||||||
|
op_type: "ReduceSum"
|
||||||
|
attribute {
|
||||||
|
name: "noop_with_empty_axes"
|
||||||
|
i: 1
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name: "compute_graph"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output {
|
||||||
|
name: "B"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
@ -0,0 +1,63 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
output: "B"
|
||||||
|
op_type: "ReduceSum"
|
||||||
|
attribute {
|
||||||
|
name: "noop_with_empty_axes"
|
||||||
|
i: 0
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name: "compute_graph"
|
||||||
|
input {
|
||||||
|
name: "data"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output {
|
||||||
|
name: "B"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
89
ngraph/test/models/onnx/reduce_sum_13_input_dynamic.prototxt
Normal file
89
ngraph/test/models/onnx/reduce_sum_13_input_dynamic.prototxt
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
output: "B"
|
||||||
|
op_type: "Constant"
|
||||||
|
attribute {
|
||||||
|
name: "value"
|
||||||
|
t {
|
||||||
|
dims: 1
|
||||||
|
data_type: 7
|
||||||
|
int64_data: -1
|
||||||
|
name: "const_tensor"
|
||||||
|
}
|
||||||
|
type: TENSOR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
input: "A"
|
||||||
|
input: "B"
|
||||||
|
output: "X"
|
||||||
|
attribute {
|
||||||
|
name: "special_zero"
|
||||||
|
i: 1
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
op_type: "Reshape"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
input: "X"
|
||||||
|
output: "C"
|
||||||
|
attribute {
|
||||||
|
name: "value"
|
||||||
|
t {
|
||||||
|
dims: 1
|
||||||
|
data_type: 7
|
||||||
|
int64_data: 5
|
||||||
|
name: "const_tensor"
|
||||||
|
}
|
||||||
|
type: TENSOR
|
||||||
|
}
|
||||||
|
op_type: "ConstantOfShape"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
input: "C"
|
||||||
|
output: "Y"
|
||||||
|
op_type: "ReduceSum"
|
||||||
|
}
|
||||||
|
name: "compute_graph"
|
||||||
|
input {
|
||||||
|
name: "A"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 7
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output {
|
||||||
|
name: "Y"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "A"
|
||||||
|
output: "B"
|
||||||
|
op_type: "ReduceSum"
|
||||||
|
}
|
||||||
|
name: "compute_graph"
|
||||||
|
input {
|
||||||
|
name: "A"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "B"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 1
|
||||||
|
}
|
@ -1102,6 +1102,33 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum)
|
|||||||
test_case.run();
|
test_case.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_dynamic_rank_input)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_dynamic_rank_input.prototxt"));
|
||||||
|
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
|
||||||
|
test_case.add_input<float>(Shape{1, 1, 4, 4},
|
||||||
|
{1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f});
|
||||||
|
|
||||||
|
test_case.add_expected_output<float>(Shape{1, 1, 1, 1}, {16.0f});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_square)
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_square)
|
||||||
{
|
{
|
||||||
auto function = onnx_import::import_onnx_model(
|
auto function = onnx_import::import_onnx_model(
|
||||||
@ -1121,6 +1148,242 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_square)
|
|||||||
test_case.run();
|
test_case.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_as_constant)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_as_constant.prototxt"));
|
||||||
|
|
||||||
|
Inputs inputs{test::NDArray<float, 4>({{{{1.0f, 1.0f, 1.0f, 1.0f},
|
||||||
|
{1.0f, 1.0f, 1.0f, 1.0f},
|
||||||
|
{1.0f, 1.0f, 1.0f, 1.0f},
|
||||||
|
{1.0f, 1.0f, 1.0f, 1.0f}}}})
|
||||||
|
.get_vector()};
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
|
||||||
|
test_case.add_expected_output<float>(Shape{1, 1, 1, 1}, {16.0f});
|
||||||
|
|
||||||
|
test_case.add_multiple_inputs(inputs);
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_as_constant_single_axis)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(file_util::path_join(
|
||||||
|
SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_as_constant_single_axis.prototxt"));
|
||||||
|
|
||||||
|
Inputs inputs{
|
||||||
|
test::NDArray<float, 3>({{{1, 2, 3}, {4, 5, 6}}, {{7, 8, 9}, {10, 11, 12}}}).get_vector()};
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
|
||||||
|
test_case.add_expected_output<float>(Shape{2, 1, 3}, {5.0f, 7.0f, 9.0f, 17.0f, 19.0f, 21.0f});
|
||||||
|
|
||||||
|
test_case.add_multiple_inputs(inputs);
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_as_constant_keepdims_off)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(file_util::path_join(
|
||||||
|
SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_as_constant_keepdims_off.prototxt"));
|
||||||
|
|
||||||
|
// input data shape (1, 1, 4, 4)
|
||||||
|
Inputs inputs{test::NDArray<float, 4>({{{{1.0f, 1.0f, 1.0f, 1.0f},
|
||||||
|
{1.0f, 1.0f, 1.0f, 1.0f},
|
||||||
|
{1.0f, 1.0f, 1.0f, 1.0f},
|
||||||
|
{1.0f, 1.0f, 1.0f, 1.0f}}}})
|
||||||
|
.get_vector()};
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
|
||||||
|
test_case.add_expected_output<float>(Shape{}, {16.0f});
|
||||||
|
|
||||||
|
test_case.add_multiple_inputs(inputs);
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_as_input)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_as_input.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
|
||||||
|
test_case.add_input<float>({1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f});
|
||||||
|
test_case.add_input<int64_t>({0, 1, 2, 3});
|
||||||
|
|
||||||
|
test_case.add_expected_output<float>(Shape{1, 1, 1, 1}, {16.0f});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_as_0_dim_input)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_as_0_dim_input.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
|
||||||
|
test_case.add_input<float>(
|
||||||
|
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f});
|
||||||
|
|
||||||
|
test_case.add_expected_output<float>(
|
||||||
|
Shape{3, 2, 2},
|
||||||
|
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_input_dynamic)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_13_input_dynamic.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
|
||||||
|
test_case.add_input<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||||||
|
|
||||||
|
test_case.add_expected_output<int64_t>(Shape{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||||
|
{5});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_empty)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_empty.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
test_case.add_input<float>({1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f});
|
||||||
|
|
||||||
|
test_case.add_expected_output<float>(Shape{1, 1, 1, 1}, {16.0f});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_empty_dynamic_rank_input)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(file_util::path_join(
|
||||||
|
SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_empty_dynamic_rank_input.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
|
||||||
|
test_case.add_input<float>(Shape{1, 1, 4, 4},
|
||||||
|
{1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f});
|
||||||
|
|
||||||
|
test_case.add_expected_output<float>(Shape{1, 1, 1, 1}, {16.0f});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_empty_with_noop)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_empty_with_noop.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
test_case.add_input<float>({1.f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f});
|
||||||
|
|
||||||
|
test_case.add_expected_output<float>(Shape{1, 1, 4, 4},
|
||||||
|
{1.f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_13_axes_empty_without_noop)
|
||||||
|
{
|
||||||
|
auto function = onnx_import::import_onnx_model(file_util::path_join(
|
||||||
|
SERIALIZED_ZOO, "onnx/reduce_sum_13_axes_empty_without_noop.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
test_case.add_input<float>({1.f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f,
|
||||||
|
1.0f});
|
||||||
|
|
||||||
|
test_case.add_expected_output<float>(Shape{1, 1, 1, 1}, {16.0f});
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_resize11_empty_constant_as_input)
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_resize11_empty_constant_as_input)
|
||||||
{
|
{
|
||||||
// this model contains a Constant node with an empty underlying tensor
|
// this model contains a Constant node with an empty underlying tensor
|
||||||
|
@ -1168,6 +1168,13 @@ IE_CPU.roi_pooling_1x1_bilinear
|
|||||||
|
|
||||||
# Unsupported dynamic op
|
# Unsupported dynamic op
|
||||||
IE_CPU.range_v4_trunc_inputs
|
IE_CPU.range_v4_trunc_inputs
|
||||||
|
IE_CPU.onnx_model_reduce_sum_13_axes_as_input
|
||||||
|
IE_CPU.onnx_model_reduce_sum_13_input_dynamic
|
||||||
|
IE_CPU.onnx_model_reduce_sum_13_axes_empty_dynamic_rank_input
|
||||||
|
IE_CPU.onnx_model_reduce_sum_dynamic_rank_input
|
||||||
|
|
||||||
|
# Axes has zero dimension which is not allowed
|
||||||
|
IE_CPU.onnx_model_reduce_sum_13_axes_as_0_dim_input
|
||||||
|
|
||||||
# output mismatch
|
# output mismatch
|
||||||
IE_CPU.gather_nd_batch_1d_from_3d_negative
|
IE_CPU.gather_nd_batch_1d_from_3d_negative
|
||||||
|
Loading…
Reference in New Issue
Block a user