ONNX: Throw exception if attribute reduction is used in ScatterND or ScatterElements (#13778)
* ONNX: Throw exception if `reduction` attribute has unsupported value Operators: - ScatterND - ScatterElements
This commit is contained in:
parent
5f0b063455
commit
a48d1558e3
@ -7,6 +7,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "default_opset.hpp"
|
#include "default_opset.hpp"
|
||||||
|
#include "exceptions.hpp"
|
||||||
#include "ngraph/opsets/opset3.hpp"
|
#include "ngraph/opsets/opset3.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
@ -17,9 +18,14 @@ OutputVector scatter_elements(const Node& node) {
|
|||||||
const auto data = node.get_ng_inputs().at(0);
|
const auto data = node.get_ng_inputs().at(0);
|
||||||
const auto indices = node.get_ng_inputs().at(1);
|
const auto indices = node.get_ng_inputs().at(1);
|
||||||
const auto updates = node.get_ng_inputs().at(2);
|
const auto updates = node.get_ng_inputs().at(2);
|
||||||
|
|
||||||
const auto axis_node = node.get_attribute_as_constant<std::int64_t>("axis", 0);
|
const auto axis_node = node.get_attribute_as_constant<std::int64_t>("axis", 0);
|
||||||
|
if (node.has_attribute("reduction")) {
|
||||||
|
const auto reduction = node.get_attribute_value<std::string>("reduction", "none");
|
||||||
|
CHECK_VALID_NODE(node,
|
||||||
|
reduction == "none",
|
||||||
|
"Unsupported value of attribute: `reduction`. Only `none` is supported, got:",
|
||||||
|
reduction);
|
||||||
|
}
|
||||||
return {std::make_shared<ngraph::opset3::ScatterElementsUpdate>(data, indices, updates, axis_node)};
|
return {std::make_shared<ngraph::opset3::ScatterElementsUpdate>(data, indices, updates, axis_node)};
|
||||||
}
|
}
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "default_opset.hpp"
|
#include "default_opset.hpp"
|
||||||
|
#include "exceptions.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace onnx_import {
|
namespace onnx_import {
|
||||||
@ -20,6 +21,13 @@ OutputVector scatter_nd(const Node& node) {
|
|||||||
auto data = ng_inputs.at(0);
|
auto data = ng_inputs.at(0);
|
||||||
auto indices = ng_inputs.at(1);
|
auto indices = ng_inputs.at(1);
|
||||||
auto updates = ng_inputs.at(2);
|
auto updates = ng_inputs.at(2);
|
||||||
|
if (node.has_attribute("reduction")) {
|
||||||
|
const auto reduction = node.get_attribute_value<std::string>("reduction", "none");
|
||||||
|
CHECK_VALID_NODE(node,
|
||||||
|
reduction == "none",
|
||||||
|
"Unsupported value of attribute: `reduction`. Only `none` is supported, got:",
|
||||||
|
reduction);
|
||||||
|
}
|
||||||
|
|
||||||
return {std::make_shared<default_opset::ScatterNDUpdate>(data, indices, updates)};
|
return {std::make_shared<default_opset::ScatterNDUpdate>(data, indices, updates)};
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,64 @@
|
|||||||
|
ir_version: 8
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
input: "indices"
|
||||||
|
input: "updates"
|
||||||
|
output: "y"
|
||||||
|
op_type: "ScatterElements"
|
||||||
|
attribute {
|
||||||
|
name: "axis"
|
||||||
|
i: 1
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
attribute {
|
||||||
|
name: "reduction"
|
||||||
|
s: "add"
|
||||||
|
type: STRING
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name: "test_scatter"
|
||||||
|
initializer {
|
||||||
|
dims: 1
|
||||||
|
dims: 5
|
||||||
|
data_type: 1
|
||||||
|
float_data: 1
|
||||||
|
float_data: 2
|
||||||
|
float_data: 3
|
||||||
|
float_data: 4
|
||||||
|
float_data: 5
|
||||||
|
name: "data"
|
||||||
|
}
|
||||||
|
initializer {
|
||||||
|
dims: 1
|
||||||
|
dims: 2
|
||||||
|
data_type: 6
|
||||||
|
int32_data: 1
|
||||||
|
int32_data: 3
|
||||||
|
name: "indices"
|
||||||
|
}
|
||||||
|
initializer {
|
||||||
|
dims: 1
|
||||||
|
dims: 2
|
||||||
|
data_type: 1
|
||||||
|
float_data: 1.1000000238418579
|
||||||
|
float_data: 2.0999999046325684
|
||||||
|
name: "updates"
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "y"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 16
|
||||||
|
}
|
@ -0,0 +1,64 @@
|
|||||||
|
ir_version: 8
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "data"
|
||||||
|
input: "indices"
|
||||||
|
input: "updates"
|
||||||
|
output: "y"
|
||||||
|
op_type: "ScatterElements"
|
||||||
|
attribute {
|
||||||
|
name: "axis"
|
||||||
|
i: 1
|
||||||
|
type: INT
|
||||||
|
}
|
||||||
|
attribute {
|
||||||
|
name: "reduction"
|
||||||
|
s: "none"
|
||||||
|
type: STRING
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name: "test_scatter"
|
||||||
|
initializer {
|
||||||
|
dims: 1
|
||||||
|
dims: 5
|
||||||
|
data_type: 1
|
||||||
|
float_data: 1
|
||||||
|
float_data: 2
|
||||||
|
float_data: 3
|
||||||
|
float_data: 4
|
||||||
|
float_data: 5
|
||||||
|
name: "data"
|
||||||
|
}
|
||||||
|
initializer {
|
||||||
|
dims: 1
|
||||||
|
dims: 2
|
||||||
|
data_type: 6
|
||||||
|
int32_data: 1
|
||||||
|
int32_data: 3
|
||||||
|
name: "indices"
|
||||||
|
}
|
||||||
|
initializer {
|
||||||
|
dims: 1
|
||||||
|
dims: 2
|
||||||
|
data_type: 1
|
||||||
|
float_data: 1.1000000238418579
|
||||||
|
float_data: 2.0999999046325684
|
||||||
|
name: "updates"
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "y"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 16
|
||||||
|
}
|
@ -0,0 +1,75 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "x"
|
||||||
|
input: "i"
|
||||||
|
input: "u"
|
||||||
|
output: "y"
|
||||||
|
op_type: "ScatterND"
|
||||||
|
attribute {
|
||||||
|
name: "reduction"
|
||||||
|
s: "add"
|
||||||
|
type: STRING
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name: "test_scatterND"
|
||||||
|
input {
|
||||||
|
name: "x"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
name: "i"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 7
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
name: "u"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "y"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 16
|
||||||
|
}
|
@ -0,0 +1,75 @@
|
|||||||
|
ir_version: 3
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "x"
|
||||||
|
input: "i"
|
||||||
|
input: "u"
|
||||||
|
output: "y"
|
||||||
|
op_type: "ScatterND"
|
||||||
|
attribute {
|
||||||
|
name: "reduction"
|
||||||
|
s: "none"
|
||||||
|
type: STRING
|
||||||
|
}
|
||||||
|
}
|
||||||
|
name: "test_scatterND"
|
||||||
|
input {
|
||||||
|
name: "x"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
name: "i"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 7
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input {
|
||||||
|
name: "u"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "y"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 16
|
||||||
|
}
|
@ -3105,6 +3105,29 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatterND_const_i32_indices) {
|
|||||||
test_case.run();
|
test_case.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatterND_opset16_reduction_none) {
|
||||||
|
const auto function =
|
||||||
|
onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||||
|
SERIALIZED_ZOO,
|
||||||
|
"onnx/scatter_nd_opset16_reduction_none.onnx"));
|
||||||
|
auto test_case = test::TestCase(function, s_device);
|
||||||
|
|
||||||
|
test_case.add_input<float>({1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
|
||||||
|
test_case.add_input<int64_t>({4, 3, 1, 7});
|
||||||
|
test_case.add_input<float>({9.f, 10.f, 11.f, 12.f});
|
||||||
|
test_case.add_expected_output<float>(Shape{8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f});
|
||||||
|
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatterND_opset16_reduction_add) {
|
||||||
|
EXPECT_THROW(onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||||
|
SERIALIZED_ZOO,
|
||||||
|
"onnx/scatter_nd_opset16_reduction_add.onnx")),
|
||||||
|
ngraph_error)
|
||||||
|
<< "Unsupported type of attribute: `reduction`. Only `none` is supported";
|
||||||
|
}
|
||||||
|
|
||||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_gather_float_1D) {
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_gather_float_1D) {
|
||||||
const auto function = onnx_import::import_onnx_model(
|
const auto function = onnx_import::import_onnx_model(
|
||||||
file_util::path_join(CommonTestUtils::getExecutableDirectory(), SERIALIZED_ZOO, "onnx/gather_float_1D.onnx"));
|
file_util::path_join(CommonTestUtils::getExecutableDirectory(), SERIALIZED_ZOO, "onnx/gather_float_1D.onnx"));
|
||||||
@ -3442,6 +3465,28 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatter_elements_import_only) {
|
|||||||
EXPECT_EQ(count_ops_of_type<op::v0::Constant>(scatter_fn), 4);
|
EXPECT_EQ(count_ops_of_type<op::v0::Constant>(scatter_fn), 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatter_elements_opset16_reduction_none) {
|
||||||
|
const auto scatter_fn =
|
||||||
|
onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||||
|
SERIALIZED_ZOO,
|
||||||
|
"onnx/scatter_elements_opset16_reduction_none.onnx"));
|
||||||
|
|
||||||
|
const Shape data_shape{1, 5};
|
||||||
|
|
||||||
|
EXPECT_EQ(scatter_fn->get_output_size(), 1);
|
||||||
|
EXPECT_EQ(scatter_fn->get_output_shape(0), data_shape);
|
||||||
|
EXPECT_EQ(count_ops_of_type<op::v3::ScatterElementsUpdate>(scatter_fn), 1);
|
||||||
|
EXPECT_EQ(count_ops_of_type<op::v0::Constant>(scatter_fn), 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatter_elements_opset16_reduction_add) {
|
||||||
|
const auto path = file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||||
|
SERIALIZED_ZOO,
|
||||||
|
"onnx/scatter_elements_opset16_reduction_add.onnx");
|
||||||
|
EXPECT_THROW(onnx_import::import_onnx_model(path), ngraph_error)
|
||||||
|
<< "Unsupported type of attribute: `reduction`. Only `none` is supported";
|
||||||
|
}
|
||||||
|
|
||||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_upsample6_nearest_infer) {
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_upsample6_nearest_infer) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
const auto function = onnx_import::import_onnx_model(
|
const auto function = onnx_import::import_onnx_model(
|
||||||
|
Loading…
Reference in New Issue
Block a user