[nGraph][ONNX] Extend ONNX Importer for operation "GatherElements-6" (#3822)
This commit is contained in:
41
ngraph/frontend/onnx_import/src/op/gather_elements.hpp
Normal file
41
ngraph/frontend/onnx_import/src/op/gather_elements.hpp
Normal file
@@ -0,0 +1,41 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/output_vector.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace onnx_import
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace set_1
|
||||
{
|
||||
inline OutputVector gather_elements(const Node& node)
|
||||
{
|
||||
OutputVector ng_inputs{node.get_ng_inputs()};
|
||||
auto data = ng_inputs.at(0);
|
||||
auto indices = ng_inputs.at(1);
|
||||
auto axis = node.get_attribute_value<int64_t>("axis", 0);
|
||||
|
||||
return {std::make_shared<ngraph::op::v6::GatherElements>(data, indices, axis)};
|
||||
}
|
||||
} // namespace set_1
|
||||
} // namespace op
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
@@ -60,6 +60,7 @@
|
||||
#include "op/flatten.hpp"
|
||||
#include "op/floor.hpp"
|
||||
#include "op/gather.hpp"
|
||||
#include "op/gather_elements.hpp"
|
||||
#include "op/gather_nd.hpp"
|
||||
#include "op/gemm.hpp"
|
||||
#include "op/global_average_pool.hpp"
|
||||
@@ -346,6 +347,7 @@ namespace ngraph
|
||||
REGISTER_OPERATOR("Flatten", 1, flatten);
|
||||
REGISTER_OPERATOR("Floor", 1, floor);
|
||||
REGISTER_OPERATOR("Gather", 1, gather);
|
||||
REGISTER_OPERATOR("GatherElements", 1, gather_elements);
|
||||
REGISTER_OPERATOR("GatherND", 1, gather_nd);
|
||||
REGISTER_OPERATOR("Gemm", 1, gemm);
|
||||
REGISTER_OPERATOR("Gemm", 6, gemm);
|
||||
|
||||
58
ngraph/test/models/onnx/gather_elements_float_1D.prototxt
Normal file
58
ngraph/test/models/onnx/gather_elements_float_1D.prototxt
Normal file
@@ -0,0 +1,58 @@
|
||||
ir_version: 7
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "data"
|
||||
input: "indices"
|
||||
output: "output"
|
||||
op_type: "GatherElements"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "test_gather_elements"
|
||||
input {
|
||||
name: "data"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "indices"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "output"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
ir_version: 7
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "data"
|
||||
input: "indices"
|
||||
output: "output"
|
||||
op_type: "GatherElements"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 2
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "test_gather_elements"
|
||||
input {
|
||||
name: "data"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "indices"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "output"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
ir_version: 7
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "data"
|
||||
input: "indices"
|
||||
output: "output"
|
||||
op_type: "GatherElements"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: -1
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "test_gather_elements"
|
||||
input {
|
||||
name: "data"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "indices"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "output"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
ir_version: 7
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "data"
|
||||
input: "indices"
|
||||
output: "output"
|
||||
op_type: "GatherElements"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "test_gather_elements"
|
||||
input {
|
||||
name: "data"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "indices"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "output"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
67
ngraph/test/models/onnx/gather_elements_int8_axis_1.prototxt
Normal file
67
ngraph/test/models/onnx/gather_elements_int8_axis_1.prototxt
Normal file
@@ -0,0 +1,67 @@
|
||||
ir_version: 7
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "data"
|
||||
input: "indices"
|
||||
output: "output"
|
||||
op_type: "GatherElements"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 1
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "test_gather_elements"
|
||||
input {
|
||||
name: "data"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 3
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "indices"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "output"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 3
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
||||
@@ -3020,6 +3020,71 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_scatterND_const_i32_indices)
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_gather_elements_float_1D)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/gather_elements_float_1D.prototxt"));
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input<float>(Shape{3}, {1, 2, 3});
|
||||
test_case.add_input<int64_t>(Shape{1}, {1});
|
||||
test_case.add_expected_output<float>(Shape{1}, {2});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_gather_elements_int8_axis_1)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/gather_elements_int8_axis_1.prototxt"));
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input<int8_t>(Shape{2, 2}, {1, 2, 3, 4});
|
||||
test_case.add_input<int32_t>(Shape{2, 2}, {0, 0, 1, 0});
|
||||
test_case.add_expected_output<int8_t>(Shape{2, 2}, {1, 1, 4, 3});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_gather_elements_int32_axis_0)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/gather_elements_int32_axis_0.prototxt"));
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input<int32_t>(Shape{3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
test_case.add_input<int64_t>(Shape{2, 3}, {1, 2, 0, 2, 0, 0});
|
||||
test_case.add_expected_output<int32_t>(Shape{2, 3}, {4, 8, 3, 7, 2, 3});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_gather_elements_float_negative_axis)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/gather_elements_float_negative_axis.prototxt"));
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input<float>(Shape{2, 2}, {1, 2, 3, 4});
|
||||
test_case.add_input<int64_t>(Shape{2, 2}, {1, 1, 1, 0});
|
||||
test_case.add_expected_output<float>(Shape{2, 2}, {2, 2, 4, 3});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_gather_elements_float_3D_axis_2)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/gather_elements_float_3D_axis_2.prototxt"));
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input<float>(Shape{2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
||||
test_case.add_input<int64_t>(Shape{2, 2, 1}, {0, 1, 0, 1});
|
||||
test_case.add_expected_output<float>(Shape{2, 2, 1}, {1, 4, 5, 8});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_gatherND_int32)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
|
||||
@@ -1562,6 +1562,11 @@ IE_CPU.evaluate_1D_gather_elements_negative_test
|
||||
IE_CPU.evaluate_2D_gather_elements_negative_test
|
||||
IE_CPU.evaluate_2D_gather_elements_2x2x1_data_float32
|
||||
IE_CPU.evaluate_4D_gather_elements_3x2x2x2_indices_int64
|
||||
IE_CPU.onnx_model_gather_elements_float_1D
|
||||
IE_CPU.onnx_model_gather_elements_float_negative_axis
|
||||
IE_CPU.onnx_model_gather_elements_int32_axis_0
|
||||
IE_CPU.onnx_model_gather_elements_int8_axis_1
|
||||
IE_CPU.onnx_model_gather_elements_float_3D_axis_2
|
||||
IE_GPU.evaluate_1D_gather_elements_3_indices_int32
|
||||
IE_GPU.evaluate_2D_gather_elements_2x2_indices_int32_axis_0
|
||||
IE_GPU.evaluate_2D_gather_elements_2x2_indices_int32_axis_1
|
||||
@@ -1575,6 +1580,11 @@ IE_GPU.evaluate_1D_gather_elements_negative_test
|
||||
IE_GPU.evaluate_2D_gather_elements_negative_test
|
||||
IE_GPU.evaluate_2D_gather_elements_2x2x1_data_float32
|
||||
IE_GPU.evaluate_4D_gather_elements_3x2x2x2_indices_int64
|
||||
IE_GPU.onnx_model_gather_elements_float_1D
|
||||
IE_GPU.onnx_model_gather_elements_float_negative_axis
|
||||
IE_GPU.onnx_model_gather_elements_int32_axis_0
|
||||
IE_GPU.onnx_model_gather_elements_int8_axis_1
|
||||
IE_GPU.onnx_model_gather_elements_float_3D_axis_2
|
||||
|
||||
# incorrect result for Minimum if u16 type is unsupported
|
||||
minimum_u16
|
||||
|
||||
@@ -210,6 +210,8 @@ std::set<NodeTypeInfo> test::IE_Engine::get_ie_ops() const
|
||||
ie_ops.insert(opset4.begin(), opset4.end());
|
||||
const auto& opset5 = get_opset5().get_type_info_set();
|
||||
ie_ops.insert(opset5.begin(), opset5.end());
|
||||
const auto& opset6 = get_opset6().get_type_info_set();
|
||||
ie_ops.insert(opset6.begin(), opset6.end());
|
||||
return ie_ops;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user