Extend nGraph for operation GatherND-5 and implement reference (#2587)

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
Roman Kazantsev 2020-10-14 12:20:22 +03:00 committed by GitHub
parent 6d72110365
commit 9956639531
17 changed files with 1245 additions and 307 deletions


@ -0,0 +1,122 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <string>
#include "ngraph_reader_tests.hpp"
TEST_F(NGraphReaderTests, ReadGatherNDNetwork) {
std::string model = R"V0G0N(
<net name="Network" version="10">
<layers>
<layer id="0" name="params_x" type="Parameter" version="opset1">
<data element_type="f32" shape="10,20,30"/>
<output>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>20</dim>
<dim>30</dim>
</port>
</output>
</layer>
<layer id="1" name="indices_y" type="Parameter" version="opset1">
<data element_type="i32" shape="10,3,2"/>
<output>
<port id="0" precision="I32">
<dim>10</dim>
<dim>3</dim>
<dim>2</dim>
</port>
</output>
</layer>
<layer id="2" name="MyGatherND" type="GatherND" version="opset5">
<data batch_dims="0"/>
<input>
<port id="0">
<dim>10</dim>
<dim>20</dim>
<dim>30</dim>
</port>
<port id="1">
<dim>10</dim>
<dim>3</dim>
<dim>2</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>10</dim>
<dim>3</dim>
<dim>30</dim>
</port>
</output>
</layer>
<layer id="3" name="MyGatherND/sink_port_0" type="Result" version="opset1">
<input>
<port id="0">
<dim>10</dim>
<dim>3</dim>
<dim>30</dim>
</port>
</input>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="2" to-port="0"/>
<edge from-layer="1" from-port="0" to-layer="2" to-port="1"/>
<edge from-layer="2" from-port="2" to-layer="3" to-port="0"/>
</edges>
</net>
)V0G0N";
std::string modelV5 = R"V0G0N(
<net name="Network" version="5" precision="FP32" batch="1">
<layers>
<layer id="0" name="params_x" type="Input" precision="FP32">
<output>
<port id="0" precision="FP32">
<dim>10</dim>
<dim>20</dim>
<dim>30</dim>
</port>
</output>
</layer>
<layer id="1" name="indices_y" type="Input" precision="I32">
<output>
<port id="0" precision="I32">
<dim>10</dim>
<dim>3</dim>
<dim>2</dim>
</port>
</output>
</layer>
<layer id="2" name="MyGatherND" type="GatherND" version="opset5">
<data batch_dims="0"/>
<input>
<port id="0">
<dim>10</dim>
<dim>20</dim>
<dim>30</dim>
</port>
<port id="1">
<dim>10</dim>
<dim>3</dim>
<dim>2</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>10</dim>
<dim>3</dim>
<dim>30</dim>
</port>
</output>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="2" to-port="0"/>
<edge from-layer="1" from-port="0" to-layer="2" to-port="1"/>
</edges>
</net>
)V0G0N";
compareIRs(model, modelV5, 10);
}


@ -51,5 +51,36 @@ namespace ngraph
NGRAPH_SUPPRESS_DEPRECATED_START
using v0::GatherND;
NGRAPH_SUPPRESS_DEPRECATED_END
namespace v5
{
/// \brief GatherND operation
///
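/// Gathers slices from the data tensor at positions specified by index
/// tuples in the indices tensor; the first batch_dims dimensions of data
/// and indices are treated as shared batch dimensions.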
class NGRAPH_API GatherND : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
GatherND() = default;
/// \brief Constructs a GatherND operation.
///
/// \param data Node producing the tensor from which slices are gathered
/// \param indices Node producing indices by which the operation gathers elements
/// or slices from data
/// \param batch_dims Specifies the number of leading batch dimensions
GatherND(const Output<Node>& data,
const Output<Node>& indices,
const size_t batch_dims = 0);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
size_t get_batch_dims() const { return m_batch_dims; }
private:
size_t m_batch_dims;
};
}
}
}
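For orientation, a minimal sketch of constructing the new op (the shapes are borrowed from the type_prop tests further down; make_gather_nd_graph is an illustrative name, not part of the patch):

#include "ngraph/ngraph.hpp"

using namespace ngraph;

// Illustrative only: data {2, 3, 11, 12}, indices {2, 3, 2}, batch_dims = 2.
std::shared_ptr<Function> make_gather_nd_graph()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 11, 12});
    auto indices = std::make_shared<op::Parameter>(element::i32, Shape{2, 3, 2});
    auto gather = std::make_shared<op::v5::GatherND>(data, indices, 2);
    // the 2 x 3 batch dims fuse, so the inferred output shape is {6}
    return std::make_shared<Function>(gather, ParameterVector{data, indices});
}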


@ -164,6 +164,7 @@ NGRAPH_OP(SoftPlus, ngraph::op::v4)
NGRAPH_OP(Swish, ngraph::op::v4)
// New operations added in opset5
NGRAPH_OP(GatherND, ngraph::op::v5)
NGRAPH_OP(LogSoftmax, ngraph::op::v5)
NGRAPH_OP(LSTMSequence, ngraph::op::v5)
NGRAPH_OP(GRUSequence, ngraph::op::v5)


@ -30,12 +30,12 @@ namespace ngraph
// vector = indices[leaf_vector_index]
// out[leaf_vector_index:] = params[vector]
template <typename T, typename U>
void gather_nd(const T* params,
const U* indices,
T* out,
const Shape& params_shape,
const Shape& indices_shape,
const Shape& out_shape)
void gather_nd_batch(const T* params,
const U* indices,
T* out,
const Shape& params_shape,
const Shape& indices_shape,
const Shape& out_shape)
{
using namespace std;
// Create a CoordinateTransform for "indices" that visits only the first element
@ -105,6 +105,89 @@ namespace ngraph
out_coord_iter++;
}
}
template <typename T, typename U>
void gather_nd(const T* params,
const U* indices,
T* out,
const Shape& params_shape,
const Shape& indices_shape,
const Shape& out_shape,
int batch_dims = 0)
{
using namespace std;
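// Treat the leading batch_dims axes of params and indices as a fused batch:
// walk the outer (batch) coordinates of indices, params and out in lock-step
// and delegate each per-batch slice to gather_nd_batch above.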
if (batch_dims == 0)
{
gather_nd_batch(params, indices, out, params_shape, indices_shape, out_shape);
return;
}
size_t indices_ndim = static_cast<size_t>(indices_shape.size());
Coordinate indices_outer_start_corner(indices_ndim, 0);
Coordinate indices_outer_end_corner(indices_shape);
for (size_t i = batch_dims; i < indices_ndim; i++)
{
indices_outer_end_corner[i] = 1;
}
Strides indices_strides(indices_ndim, 1);
AxisVector indices_axis_order(indices_ndim);
std::iota(indices_axis_order.begin(), indices_axis_order.end(), 0);
CoordinateTransform indices_outer_transform(indices_shape,
indices_outer_start_corner,
indices_outer_end_corner,
indices_strides,
indices_axis_order);
size_t params_ndim = static_cast<size_t>(params_shape.size());
Coordinate params_outer_start_corner(params_ndim, 0);
Coordinate params_outer_end_corner(params_shape);
for (size_t i = batch_dims; i < params_ndim; i++)
{
params_outer_end_corner[i] = 1;
}
Strides params_strides(params_ndim, 1);
AxisVector params_axis_order(params_ndim);
std::iota(params_axis_order.begin(), params_axis_order.end(), 0);
CoordinateTransform params_outer_transform(params_shape,
params_outer_start_corner,
params_outer_end_corner,
params_strides,
params_axis_order);
size_t out_ndim = static_cast<size_t>(out_shape.size());
Coordinate out_start_corner(out_ndim, 0);
Coordinate out_end_corner(out_shape);
for (size_t i = 1; i < out_ndim; i++)
{
out_end_corner[i] = 1;
}
Strides out_strides(out_ndim, 1);
AxisVector out_axis_order(out_ndim);
std::iota(out_axis_order.begin(), out_axis_order.end(), 0);
CoordinateTransform out_transform(
out_shape, out_start_corner, out_end_corner, out_strides, out_axis_order);
Shape indices_shape_batch(indices_shape.begin() + batch_dims, indices_shape.end());
Shape params_shape_batch(params_shape.begin() + batch_dims, params_shape.end());
Shape output_shape_batch(out_shape.begin() + 1, out_shape.end());
auto out_coord_iter = out_transform.begin();
auto params_coord_iter = params_outer_transform.begin();
for (const Coordinate& indices_coord : indices_outer_transform)
{
auto indices_index = indices_outer_transform.index(indices_coord);
auto params_index = params_outer_transform.index(*params_coord_iter);
auto output_index = out_transform.index(*out_coord_iter);
gather_nd_batch(params + params_index,
indices + indices_index,
out + output_index,
params_shape_batch,
indices_shape_batch,
output_shape_batch);
out_coord_iter++;
params_coord_iter++;
}
}
}
}
}
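As a sanity check on the semantics, here is a self-contained plain-loop model of the batch_dims = 1 case; the numbers mirror the gather_nd_batch_dims1 backend test added later in this commit, and the snippet is illustrative rather than a substitute for the general reference above:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    // data shape {2, 3, 4}, indices shape {2, 1}, batch_dims = 1
    const std::size_t B = 2, N = 3, C = 4;
    std::vector<float> data(B * N * C);
    for (std::size_t v = 0; v < data.size(); ++v)
        data[v] = static_cast<float>(v + 1); // values 1..24
    std::vector<std::int32_t> indices{1, 0};
    std::vector<float> out(B * C); // output shape {2, 4}
    for (std::size_t b = 0; b < B; ++b)     // the batch axis is shared
        for (std::size_t c = 0; c < C; ++c) // copy the selected row
            out[b * C + c] = data[(b * N + indices[b]) * C + c];
    // expected output: {5, 6, 7, 8, 13, 14, 15, 16}
    assert(out[0] == 5.0f && out[4] == 13.0f);
    return 0;
}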


@ -17,11 +17,151 @@
#include "ngraph/op/gather_nd.hpp"
#include "ngraph/shape.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
using namespace ngraph;
// ------------------------------ V5 ------------------------------
NGRAPH_RTTI_DEFINITION(op::v5::GatherND, "GatherND", 5);
op::v5::GatherND::GatherND(const Output<Node>& data,
const Output<Node>& indices,
const size_t batch_dims)
: Op({data, indices})
, m_batch_dims(batch_dims)
{
constructor_validate_and_infer_types();
}
void op::v5::GatherND::validate_and_infer_types()
{
// check types of input tensors
const auto& data_type = get_input_element_type(0);
const auto& indices_type = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
indices_type.is_integral_number(),
"The indices type is expected to be an integer type. Got: ",
indices_type);
// check ranks of input tensors
const auto& data_pshape = get_input_partial_shape(0);
const auto& indices_pshape = get_input_partial_shape(1);
if (data_pshape.rank().is_static())
{
NODE_VALIDATION_CHECK(
this, data_pshape.rank().get_length() > 0, "Data rank must be at least 1.");
NODE_VALIDATION_CHECK(this,
data_pshape.rank().get_length() > m_batch_dims,
"Number of batch dimensions must not exceed a rank of data.");
}
if (indices_pshape.rank().is_static())
{
NODE_VALIDATION_CHECK(
this, indices_pshape.rank().get_length() > 0, "Indices rank must be at least 1.");
NODE_VALIDATION_CHECK(this,
indices_pshape.rank().get_length() > m_batch_dims,
"Number of batch dimensions must not exceed a rank of indices.");
}
if (data_pshape.rank().is_static() && indices_pshape.rank().is_static())
{
// check that batch dimensions of data and indices are the same
for (auto batch_dim = 0; batch_dim < m_batch_dims; batch_dim++)
{
if (data_pshape[batch_dim].is_static() && indices_pshape[batch_dim].is_static())
{
NODE_VALIDATION_CHECK(this,
data_pshape[batch_dim].get_length() ==
indices_pshape[batch_dim].get_length(),
"Batch dimensions of data and indices must be the same.");
}
}
if (indices_pshape[indices_pshape.rank().get_length() - 1].is_static())
{
NODE_VALIDATION_CHECK(
this,
(indices_pshape[indices_pshape.rank().get_length() - 1].get_length() +
m_batch_dims) <= data_pshape.rank().get_length(),
"Length of a tuple with indices must not exceed a rank of data tensor excluding "
"batch dimensions.");
}
}
// set output shape
set_output_size(1);
if (data_pshape.rank().is_static() && indices_pshape.rank().is_static() &&
indices_pshape[indices_pshape.rank().get_length() - 1].is_static())
{
auto indices_tuple_length =
indices_pshape[indices_pshape.rank().get_length() - 1].get_length();
auto slice_length = data_pshape.rank().get_length() - indices_tuple_length - m_batch_dims;
auto output_indices_length = indices_pshape.rank().get_length() - m_batch_dims - 1;
auto output_rank = output_indices_length + slice_length;
size_t delta_output_rank = 0;
if (m_batch_dims > 0)
{
delta_output_rank = 1;
}
std::vector<Dimension> output_shape(output_rank + delta_output_rank);
if (m_batch_dims > 0)
{
output_shape[0] = 1;
for (auto dim = 0; dim < m_batch_dims; dim++)
{
if (data_pshape[dim].is_static())
{
output_shape[0] *= data_pshape[dim].get_length();
}
else if (indices_pshape[dim].is_static())
{
output_shape[0] *= indices_pshape[dim].get_length();
}
else
{
output_shape[0] = Dimension::dynamic();
break;
}
}
}
for (auto dim = 0; dim < output_indices_length; dim++)
{
output_shape[dim + delta_output_rank] = indices_pshape[dim + m_batch_dims];
}
for (auto dim = 0; dim < slice_length; dim++)
{
output_shape[output_indices_length + dim + delta_output_rank] =
data_pshape[m_batch_dims + indices_tuple_length + dim];
}
set_output_type(0, data_type, PartialShape(output_shape));
}
else
{
set_output_type(0, data_type, PartialShape{Dimension::dynamic()});
}
}
bool op::v5::GatherND::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("batch_dims", m_batch_dims);
return true;
}
shared_ptr<Node> op::v5::GatherND::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v5::GatherND>(new_args.at(0), new_args.at(1), m_batch_dims);
}
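To make the output-shape rule above concrete, a hypothetical standalone check using the numbers from the gather_nd_batch_dims2 backend test (data {2, 3, 4, 2}, indices {2, 3, 3, 2}, batch_dims = 2):

#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t data_rank = 4, indices_rank = 4, batch_dims = 2;
    const std::size_t tuple_length = 2; // last dimension of indices
    const std::size_t slice_length = data_rank - tuple_length - batch_dims;  // 0
    const std::size_t out_indices_length = indices_rank - batch_dims - 1;    // 1
    // batch_dims > 0, so the batch dimensions collapse into one leading dim
    const std::size_t output_rank = out_indices_length + slice_length + 1;   // 2
    const std::size_t fused_batch = 2 * 3; // product of the batch dimensions
    // the remaining indices dimension (3) fills the second output dim
    assert(output_rank == 2 && fused_batch == 6); // output shape {6, 3}
    return 0;
}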
// ------------------------------ V0 ------------------------------
NGRAPH_SUPPRESS_DEPRECATED_START
static int PARAMS = 0;
static int INDICES = 1;


@ -76,6 +76,7 @@ from ngraph.opset5 import fake_quantize
from ngraph.opset5 import floor
from ngraph.opset5 import floor_mod
from ngraph.opset5 import gather
from ngraph.opset5 import gather_nd
from ngraph.opset5 import gather_tree
from ngraph.opset5 import gelu
from ngraph.opset5 import greater


@ -63,6 +63,7 @@ from ngraph.opset1.ops import fake_quantize
from ngraph.opset1.ops import floor
from ngraph.opset1.ops import floor_mod
from ngraph.opset1.ops import gather
from ngraph.opset5.ops import gather_nd
from ngraph.opset1.ops import gather_tree
from ngraph.opset2.ops import gelu
from ngraph.opset1.ops import greater


@ -58,6 +58,29 @@ _get_node_factory_opset5 = partial(_get_node_factory, "opset5")
# -------------------------------------------- ops ------------------------------------------------
@nameable_op
def gather_nd(
    data: NodeInput,
    indices: NodeInput,
    batch_dims: int = 0,
    name: Optional[str] = None,
) -> Node:
    """Return a node which performs GatherND.

    :param data: N-D tensor with data for gathering
    :param indices: K-D tensor of tuples with indices by which data is gathered
    :param batch_dims: Number of batch dimensions
    :param name: Optional name for the output node
    :return: The new node which performs GatherND
    """
    inputs = as_nodes(data, indices)
    attributes = {"batch_dims": batch_dims}
    return _get_node_factory_opset5().create("GatherND", inputs, attributes)
@nameable_op
def log_softmax(data: NodeInput, axis: int, name: Optional[str] = None) -> Node:
"""Apply LogSoftmax operation on each element of input tensor.


@ -102,7 +102,7 @@ namespace
return it->second();
}
const ngraph::OpSet& m_opset{ngraph::get_opset4()};
const ngraph::OpSet& m_opset{ngraph::get_opset5()};
};
}


@ -17,6 +17,7 @@ import numpy as np
import pytest
import ngraph as ng
from ngraph.impl import Type
from tests.runtime import get_runtime
from tests.test_ngraph.util import run_op_node
@ -199,3 +200,18 @@ def test_select():
    result = run_op_node([cond, then_node, else_node], ng.select)
    assert np.allclose(result, excepted)
def test_gather_nd():
    indices_type = np.int32
    data_dtype = np.float32
    data = ng.parameter([2, 10, 80, 30, 50], dtype=data_dtype, name="data")
    indices = ng.parameter([2, 10, 30, 40, 2], dtype=indices_type, name="indices")
    batch_dims = 2
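    # with batch_dims=2 the two leading dims fuse: 2 * 10 = 20; 30 and 40
    # come from indices, and 50 is the un-indexed trailing data dimension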
    expected_shape = [20, 30, 40, 50]
    node = ng.gather_nd(data, indices, batch_dims)
    assert node.get_type_name() == "GatherND"
    assert node.get_output_size() == 1
    assert list(node.get_output_shape(0)) == expected_shape
    assert node.get_output_element_type(0) == Type.f32


@ -287,6 +287,7 @@ set(MULTI_TEST_SRC
backend/function_name.in.cpp
backend/fused_op.in.cpp
backend/gather.in.cpp
backend/gather_nd.in.cpp
backend/gelu.in.cpp
backend/group_convolution.in.cpp
backend/interpolate.in.cpp


@ -324,288 +324,6 @@ NGRAPH_TEST(${BACKEND_NAME}, gather_scalar_indices_axis_1_2d_input)
(vector<float>{1.0f, 2.0f, 3.0f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_single_indices)
{
Shape params_shape{3, 3};
Shape indices_shape{2};
Shape out_shape{};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 2});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.5f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_scalar_from_2d)
{
Shape params_shape{2, 2};
Shape indices_shape{2, 2};
Shape out_shape{2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 0, 1, 1});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.0f, 1.3f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_1d_from_2d)
{
Shape params_shape{2, 2};
Shape indices_shape{2, 1};
Shape out_shape{2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 0});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.2f, 1.3f, 1.0f, 1.1f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_scalar_from_3d)
{
Shape params_shape{2, 2, 2};
Shape indices_shape{2, 3};
Shape out_shape{2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 2.0f, 2.1f, 2.2f, 2.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 0, 1, 1, 0, 1});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.1f, 2.1f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_1d_from_3d)
{
Shape params_shape{2, 2, 2};
Shape indices_shape{2, 2};
Shape out_shape{2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 2.0f, 2.1f, 2.2f, 2.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 1, 0});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.2f, 1.3f, 2.0f, 2.1f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_2d_from_3d)
{
Shape params_shape{2, 2, 2};
Shape indices_shape{1, 1};
Shape out_shape{1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 2.0f, 2.1f, 2.2f, 2.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{2.0f, 2.1f, 2.2f, 2.3f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_scalar_from_2d)
{
Shape params_shape{2, 2};
Shape indices_shape{2, 1, 2};
Shape out_shape{2, 1};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 0, 0, 1});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.0f, 1.1f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_1d_from_2d)
{
Shape params_shape{2, 2};
Shape indices_shape{2, 1, 1};
Shape out_shape{2, 1, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 0});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.2f, 1.3f, 1.0f, 1.1f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_scalar_from_3d)
{
Shape params_shape{2, 2, 2};
Shape indices_shape{2, 2, 3};
Shape out_shape{2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 2.0f, 2.1f, 2.2f, 2.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.1f, 2.1f, 1.3f, 2.2f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_1d_from_3d)
{
Shape params_shape{2, 2, 2};
Shape indices_shape{2, 2, 2};
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 2.0f, 2.1f, 2.2f, 2.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 1, 0, 0, 0, 1, 1});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.2f, 1.3f, 2.0f, 2.1f, 1.0f, 1.1f, 2.2f, 2.3f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_2d_from_3d)
{
Shape params_shape{2, 2, 2};
Shape indices_shape{2, 1, 1};
Shape out_shape{2, 1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 2.0f, 2.1f, 2.2f, 2.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 0});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{2.0f, 2.1f, 2.2f, 2.3f, 1.0f, 1.1f, 1.2f, 1.3f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_no_axis_int8)
{
Shape params_shape{3, 2};


@ -0,0 +1,494 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstdlib>
#include <random>
#include <string>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "runtime/backend.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_single_indices)
{
Shape params_shape{3, 3};
Shape indices_shape{2};
Shape out_shape{};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 2});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.5f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto f5 = make_shared<Function>(G5, ParameterVector{P, I});
auto c5 = backend->compile(f5);
c5->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.5f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_scalar_from_2d)
{
Shape params_shape{2, 2};
Shape indices_shape{2, 2};
Shape out_shape{2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 0, 1, 1});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.0f, 1.3f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto f5 = make_shared<Function>(G5, ParameterVector{P, I});
auto c5 = backend->compile(f5);
c5->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.0f, 1.3f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_1d_from_2d)
{
Shape params_shape{2, 2};
Shape indices_shape{2, 1};
Shape out_shape{2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 0});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.2f, 1.3f, 1.0f, 1.1f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto f5 = make_shared<Function>(G5, ParameterVector{P, I});
auto c5 = backend->compile(f5);
c5->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.2f, 1.3f, 1.0f, 1.1f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_scalar_from_3d)
{
Shape params_shape{2, 2, 2};
Shape indices_shape{2, 3};
Shape out_shape{2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 2.0f, 2.1f, 2.2f, 2.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 0, 1, 1, 0, 1});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.1f, 2.1f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto f5 = make_shared<Function>(G5, ParameterVector{P, I});
auto c5 = backend->compile(f5);
c5->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.1f, 2.1f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_1d_from_3d)
{
Shape params_shape{2, 2, 2};
Shape indices_shape{2, 2};
Shape out_shape{2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 2.0f, 2.1f, 2.2f, 2.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 1, 0});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.2f, 1.3f, 2.0f, 2.1f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto f5 = make_shared<Function>(G5, ParameterVector{P, I});
auto c5 = backend->compile(f5);
c5->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.2f, 1.3f, 2.0f, 2.1f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_2d_from_3d)
{
Shape params_shape{2, 2, 2};
Shape indices_shape{1, 1};
Shape out_shape{1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 2.0f, 2.1f, 2.2f, 2.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{2.0f, 2.1f, 2.2f, 2.3f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto f5 = make_shared<Function>(G5, ParameterVector{P, I});
auto c5 = backend->compile(f5);
c5->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{2.0f, 2.1f, 2.2f, 2.3f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_scalar_from_2d)
{
Shape params_shape{2, 2};
Shape indices_shape{2, 1, 2};
Shape out_shape{2, 1};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 0, 0, 1});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.0f, 1.1f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto f5 = make_shared<Function>(G5, ParameterVector{P, I});
auto c5 = backend->compile(f5);
c5->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{1.0f, 1.1f}), read_vector<float>(result), MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_1d_from_2d)
{
Shape params_shape{2, 2};
Shape indices_shape{2, 1, 1};
Shape out_shape{2, 1, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 0});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.2f, 1.3f, 1.0f, 1.1f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto f5 = make_shared<Function>(G5, ParameterVector{P, I});
auto c5 = backend->compile(f5);
c5->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.2f, 1.3f, 1.0f, 1.1f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_scalar_from_3d)
{
Shape params_shape{2, 2, 2};
Shape indices_shape{2, 2, 3};
Shape out_shape{2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 2.0f, 2.1f, 2.2f, 2.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.1f, 2.1f, 1.3f, 2.2f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto f5 = make_shared<Function>(G5, ParameterVector{P, I});
auto c5 = backend->compile(f5);
c5->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.1f, 2.1f, 1.3f, 2.2f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_1d_from_3d)
{
Shape params_shape{2, 2, 2};
Shape indices_shape{2, 2, 2};
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 2.0f, 2.1f, 2.2f, 2.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{0, 1, 1, 0, 0, 0, 1, 1});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.2f, 1.3f, 2.0f, 2.1f, 1.0f, 1.1f, 2.2f, 2.3f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto f5 = make_shared<Function>(G5, ParameterVector{P, I});
auto c5 = backend->compile(f5);
c5->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{1.2f, 1.3f, 2.0f, 2.1f, 1.0f, 1.1f, 2.2f, 2.3f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_2d_from_3d)
{
Shape params_shape{2, 2, 2};
Shape indices_shape{2, 1, 1};
Shape out_shape{2, 1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1.0f, 1.1f, 1.2f, 1.3f, 2.0f, 2.1f, 2.2f, 2.3f});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 0});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{2.0f, 2.1f, 2.2f, 2.3f, 1.0f, 1.1f, 1.2f, 1.3f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
auto G5 = make_shared<op::v5::GatherND>(P, I);
auto f5 = make_shared<Function>(G5, ParameterVector{P, I});
auto c5 = backend->compile(f5);
c5->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{2.0f, 2.1f, 2.2f, 2.3f, 1.0f, 1.1f, 1.2f, 1.3f}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_dims1)
{
Shape params_shape{2, 3, 4};
Shape indices_shape{2, 1};
Shape out_shape{2, 4};
int batch_dims = 1;
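// batch_dims = 1: indices[b] picks one row of length 4 from data[b] ({3, 4})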
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::v5::GatherND>(P, I, batch_dims);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 0});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{5, 6, 7, 8, 13, 14, 15, 16}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_dims2)
{
Shape params_shape{2, 3, 4, 2};
Shape indices_shape{2, 3, 3, 2};
Shape out_shape{6, 3};
int batch_dims = 2;
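// batch_dims = 2 fuses the 2 x 3 leading dims into 6 batches; each (i, j)
// tuple picks one scalar from the per-batch {4, 2} slice, 3 tuples per batch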
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::v5::GatherND>(P, I, batch_dims);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 0, 3, 1, 2, 1, 0, 1, 1, 1, 2, 0, 3, 0, 3, 1, 2, 1,
2, 0, 1, 1, 3, 1, 1, 1, 2, 0, 2, 0, 0, 0, 3, 1, 3, 1});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f(
(vector<float>{3, 8, 6, 10, 12, 13, 23, 24, 22, 29, 28, 32, 36, 37, 37, 41, 48, 48}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, gather_nd_batch_dims2_lead_dims)
{
Shape params_shape{2, 3, 4};
Shape indices_shape{2, 3, 1, 1};
Shape out_shape{6, 1};
int batch_dims = 2;
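// tuple length 1 with batch_dims = 2: each of the 2 x 3 = 6 fused batches
// contributes one scalar from its length-4 row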
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::v5::GatherND>(P, I, batch_dims);
auto f = make_shared<Function>(G, ParameterVector{P, I});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto p = backend->create_tensor(element::f32, params_shape);
copy_data(p, vector<float>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
auto i = backend->create_tensor(element::i32, indices_shape);
copy_data(i, vector<int32_t>{1, 0, 2, 0, 2, 2});
auto result = backend->create_tensor(element::f32, out_shape);
auto c = backend->compile(f);
c->call_with_validate({result}, {p, i});
EXPECT_TRUE(test::all_close_f((vector<float>{2, 5, 11, 13, 19, 23}),
read_vector<float>(result),
MIN_FLOAT_TOLERANCE_BITS));
}


@ -787,6 +787,9 @@ gather_nd_batch_1d_from_2d
gather_nd_batch_scalar_from_3d
gather_nd_batch_1d_from_3d
gather_nd_batch_2d_from_3d
gather_nd_batch_dims1
gather_nd_batch_dims2
gather_nd_batch_dims2_lead_dims
# Cannot cast ngraph node Stack to CNNLayer!
stack_matrix_rowise


@ -712,6 +712,35 @@ protected:
}
break;
}
case OP_TYPEID::GatherND_v5:
{
const op::v5::GatherND* gatherNDNode = static_cast<const op::v5::GatherND*>(&node);
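// Dispatch the templated reference on the runtime precision of the indices.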
if (node.get_input_element_type(1) == element::i64)
{
reference::gather_nd<T, int64_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
gatherNDNode->get_batch_dims());
}
else if (node.get_input_element_type(1) == element::i32)
{
reference::gather_nd<T, int32_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int32_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
gatherNDNode->get_batch_dims());
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::GRUCell_v3:
{
const op::v3::GRUCell* gru_cell = static_cast<const op::v3::GRUCell*>(&node);


@ -51,6 +51,7 @@ NGRAPH_OP(LSTMCell, op::v4)
#undef ID_SUFFIX
#define ID_SUFFIX(NAME) NAME##_v5
NGRAPH_OP(GatherND, op::v5)
NGRAPH_OP(LSTMSequence, op::v5)
NGRAPH_OP(GRUSequence, op::v5)
NGRAPH_OP(RNNSequence, op::v5)


@ -18,11 +18,160 @@
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
using namespace ngraph;
// ------------------------------ V5 ------------------------------
TEST(type_prop, gather_nd_slices_from_4d_batch_dims0)
{
Shape params_shape{2, 3, 11, 12};
Shape indices_shape{2, 3, 2};
Shape out_shape{2, 3, 11, 12};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 0);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_scalars_from_4d_batch_dims2)
{
Shape params_shape{2, 3, 11, 12};
Shape indices_shape{2, 3, 2};
Shape out_shape{6};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_slices_from_5d_batch_dims2)
{
Shape params_shape{7, 5, 11, 12, 32};
Shape indices_shape{7, 5, 3, 1};
Shape out_shape{35, 3, 12, 32};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_batch_dim2_with_dyn_dim)
{
PartialShape params_shape{7, Dimension::dynamic(), 11, 12, 32};
Shape indices_shape{7, 5, 3, 1};
Shape out_shape{35, 3, 12, 32};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_batch_dim2_with_dyn_dim2)
{
PartialShape params_shape{7, Dimension::dynamic(), Dimension::dynamic(), 12, 32};
Shape indices_shape{7, 5, 3, 1};
Shape out_shape{35, 3, 12, 32};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_batch_dim2_with_dyn_dim3)
{
PartialShape params_shape{
7, Dimension::dynamic(), Dimension::dynamic(), 12, Dimension::dynamic()};
Shape indices_shape{7, 5, 3, 1};
PartialShape out_shape{35, 3, 12, Dimension::dynamic()};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_TRUE(G5->get_output_partial_shape(0).same_scheme(out_shape));
}
TEST(type_prop, gather_nd_fail_batch_dims_greater_indices_rank)
{
Shape params_shape{2, 3, 4, 5};
Shape indices_shape{2, 1};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
try
{
auto G5 = make_shared<op::v5::GatherND>(P, I, 3);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Number of batch dimensions must not exceed a rank of indices."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_nd_fail_unequal_batch_dims)
{
Shape params_shape{2, 3, 4, 5};
Shape indices_shape{2, 1, 4};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
try
{
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Batch dimensions of data and indices must be the same."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_nd_fail_indices_tuple_greater_data_rank_batch_dims2)
{
Shape params_shape{2, 1, 4, 5};
Shape indices_shape{2, 1, 5, 3};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
try
{
auto G5 = make_shared<op::v5::GatherND>(P, I, 2);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Length of a tuple with indices must not exceed a rank of "
"data tensor excluding batch dimensions."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
// ------------------------------ V0 + V5 ------------------------------
TEST(type_prop, gather_nd_scalar_from_2d)
{
Shape params_shape{2, 2};
@ -30,9 +179,16 @@ TEST(type_prop, gather_nd_scalar_from_2d)
Shape out_shape{2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
NGRAPH_SUPPRESS_DEPRECATED_START
auto G = make_shared<op::v0::GatherND>(P, I);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
NGRAPH_SUPPRESS_DEPRECATED_END
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_1d_from_2d)
@ -42,9 +198,16 @@ TEST(type_prop, gather_nd_1d_from_2d)
Shape out_shape{2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
NGRAPH_SUPPRESS_DEPRECATED_START
auto G = make_shared<op::v0::GatherND>(P, I);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
NGRAPH_SUPPRESS_DEPRECATED_END
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_scalar_from_3d)
@ -54,9 +217,16 @@ TEST(type_prop, gather_nd_scalar_from_3d)
Shape out_shape{2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
NGRAPH_SUPPRESS_DEPRECATED_START
auto G = make_shared<op::v0::GatherND>(P, I);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
NGRAPH_SUPPRESS_DEPRECATED_END
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_1d_from_3d)
@ -66,9 +236,16 @@ TEST(type_prop, gather_nd_1d_from_3d)
Shape out_shape{2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
NGRAPH_SUPPRESS_DEPRECATED_START
auto G = make_shared<op::v0::GatherND>(P, I);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
NGRAPH_SUPPRESS_DEPRECATED_END
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_2d_from_3d)
@ -78,9 +255,16 @@ TEST(type_prop, gather_nd_2d_from_3d)
Shape out_shape{1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
NGRAPH_SUPPRESS_DEPRECATED_START
auto G = make_shared<op::v0::GatherND>(P, I);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
NGRAPH_SUPPRESS_DEPRECATED_END
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_batch_scalar_from_2d)
@ -90,9 +274,16 @@ TEST(type_prop, gather_nd_batch_scalar_from_2d)
Shape out_shape{2, 1};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
NGRAPH_SUPPRESS_DEPRECATED_START
auto G = make_shared<op::v0::GatherND>(P, I);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
NGRAPH_SUPPRESS_DEPRECATED_END
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_batch_1d_from_2d)
@ -102,9 +293,16 @@ TEST(type_prop, gather_nd_batch_1d_from_2d)
Shape out_shape{2, 1, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
NGRAPH_SUPPRESS_DEPRECATED_START
auto G = make_shared<op::v0::GatherND>(P, I);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
NGRAPH_SUPPRESS_DEPRECATED_END
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_batch_scalar_from_3d)
@ -114,9 +312,16 @@ TEST(type_prop, gather_nd_batch_scalar_from_3d)
Shape out_shape{2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
NGRAPH_SUPPRESS_DEPRECATED_START
auto G = make_shared<op::v0::GatherND>(P, I);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
NGRAPH_SUPPRESS_DEPRECATED_END
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_batch_1d_from_3d)
@ -126,9 +331,16 @@ TEST(type_prop, gather_nd_batch_1d_from_3d)
Shape out_shape{2, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
NGRAPH_SUPPRESS_DEPRECATED_START
auto G = make_shared<op::v0::GatherND>(P, I);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
NGRAPH_SUPPRESS_DEPRECATED_END
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_batch_2d_from_3d)
@ -138,9 +350,16 @@ TEST(type_prop, gather_nd_batch_2d_from_3d)
Shape out_shape{2, 1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
auto G = make_shared<op::GatherND>(P, I);
NGRAPH_SUPPRESS_DEPRECATED_START
auto G = make_shared<op::v0::GatherND>(P, I);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
NGRAPH_SUPPRESS_DEPRECATED_END
auto G5 = make_shared<op::v5::GatherND>(P, I);
ASSERT_EQ(G5->get_element_type(), element::f32);
ASSERT_EQ(G5->get_shape(), out_shape);
}
TEST(type_prop, gather_nd_fail_params_rank)
@ -150,9 +369,11 @@ TEST(type_prop, gather_nd_fail_params_rank)
Shape out_shape{2, 1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
NGRAPH_SUPPRESS_DEPRECATED_START
try
{
auto G = make_shared<op::GatherND>(P, I);
auto G = make_shared<op::v0::GatherND>(P, I);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect params rank";
}
@ -164,6 +385,22 @@ TEST(type_prop, gather_nd_fail_params_rank)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
NGRAPH_SUPPRESS_DEPRECATED_END
try
{
auto G5 = make_shared<op::v5::GatherND>(P, I);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect params rank";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data rank must be at least 1."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_nd_fail_indices_rank)
@ -173,9 +410,11 @@ TEST(type_prop, gather_nd_fail_indices_rank)
Shape out_shape{2, 1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
NGRAPH_SUPPRESS_DEPRECATED_START
try
{
auto G = make_shared<op::GatherND>(P, I);
auto G = make_shared<op::v0::GatherND>(P, I);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
}
@ -188,6 +427,22 @@ TEST(type_prop, gather_nd_fail_indices_rank)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
NGRAPH_SUPPRESS_DEPRECATED_END
try
{
auto G5 = make_shared<op::v5::GatherND>(P, I);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), std::string("Indices rank must be at least 1."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_nd_fail_indices_element_type)
@ -196,10 +451,12 @@ TEST(type_prop, gather_nd_fail_indices_element_type)
Shape indices_shape{2, 1, 1};
Shape out_shape{2, 1, 2, 2};
auto P = make_shared<op::Parameter>(element::f32, params_shape);
auto I = make_shared<op::Parameter>(element::i16, indices_shape);
auto I = make_shared<op::Parameter>(element::f32, indices_shape);
NGRAPH_SUPPRESS_DEPRECATED_START
try
{
auto G = make_shared<op::GatherND>(P, I);
auto G = make_shared<op::v0::GatherND>(P, I);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices element type";
}
@ -211,4 +468,21 @@ TEST(type_prop, gather_nd_fail_indices_element_type)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
NGRAPH_SUPPRESS_DEPRECATED_END
try
{
auto G5 = make_shared<op::v5::GatherND>(P, I);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices element type";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("The indices type is expected to be an integer type."));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}