Fix ScatterND validation and implement evaluate (#4905)

* Fix ScatterND validation and implement evaluate

* Apply review feedback

* Update scatternd.py
This commit is contained in:
Maxim Vafin 2021-03-31 19:01:46 +03:00 committed by GitHub
parent 224dfd6520
commit 2bed9c9277
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 142 additions and 39 deletions

View File

@ -48,7 +48,7 @@ output = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
* **2**: `indices` tensor with indices of arbitrary rank `q` >= 1 and of type *T_IND*. All index values `i_j` in index entry `(i_0, i_1, ...,i_k)` (where `k = indices.shape[-1]`) must be within bounds `[0, s_j - 1]` where `s_j = data.shape[j]`. `k` must be at most `r`. Required.
* **3**: `updates` tensor of rank `r - indices.shape[-1] + q - 1` of type *T*. Required.
* **3**: `updates` tensor of rank `r - indices.shape[-1] + q - 1` of type *T*. If expected `updates` rank is 0D it can be a tensor with single element. Required.
**Outputs**:

View File

@ -44,14 +44,16 @@ class ScatterNDBase(Op):
# 1. ranks of both input and indices must be at least 1
assert len(input_shape) >= 1 and len(indices_shape) >= 1, \
'The node "{}" input and indices ranks must be at least 1'.format(node_name)
# 2. the last dimension of indices shape must be at most a rank of input
assert indices_shape[-1] <= len(input_shape), \
'The last dimension of indices shape must be at most a rank of input for the node "{}"'.format(node_name)
# 3. updates is a tensor of shape indices_shape[:-1] + input_shape[indices_shape[-1]:]
# if expected updates shape is scalar, updates can be tensor with the single element (for example, of shape [1], [[1]], etc.)
expected_updates_shape = np.concatenate((indices_shape[:-1], input_shape[indices_shape[-1]:]), axis=0)
assert np.array_equal(updates_shape, expected_updates_shape), \
assert np.array_equal(updates_shape, expected_updates_shape) or\
np.array_equal(expected_updates_shape, []) and np.array_equal(updates_shape, np.ones(len(updates_shape))), \
'The updates shape must be equal to indices_shape[:-1] + input_shape[indices_shape[-1]:] for the node "{}"'.format(node_name)
node.out_port(0).data.set_shape(input_shape)

View File

@ -62,6 +62,11 @@ inputs7 = {'input': {'shape': int64_array([8]), 'value': int64_array([1, 2, 3, 4
'updates': {'shape': int64_array([]), 'value': 9}}
output7 = int64_array([1, 2, 3, 4, 9, 6, 7, 8])
inputs8 = {'input': {'shape': int64_array([3]), 'value': int64_array([1, 2, 3])},
'indices': {'shape': int64_array([1]), 'value': int64_array([2])},
'updates': {'shape': int64_array([1]), 'value': int64_array([9])}}
output8 = int64_array([1, 2, 9])
class TestScatterNDUpdate(unittest.TestCase):
def test_partial_infer1(self):
graph = build_graph(nodes_attributes, edges, inputs1)
@ -139,7 +144,7 @@ class TestScatterNDUpdate(unittest.TestCase):
res_output_value = graph.node['output']['value']
self.assertTrue(np.array_equal(output6, res_output_value),
'values do not match expected: {} and given: {}'.format(output5, res_output_value))
'values do not match expected: {} and given: {}'.format(output6, res_output_value))
def test_infer7_scalar(self):
graph = build_graph(nodes_attributes, edges, inputs7)
@ -150,4 +155,15 @@ class TestScatterNDUpdate(unittest.TestCase):
res_output_value = graph.node['output']['value']
self.assertTrue(np.array_equal(output7, res_output_value),
'values do not match expected: {} and given: {}'.format(output5, res_output_value))
'values do not match expected: {} and given: {}'.format(output7, res_output_value))
def test_infer8(self):
graph = build_graph(nodes_attributes, edges, inputs8)
scatternd_node = Node(graph, 'scatternd_node')
ScatterNDUpdate.infer(scatternd_node)
# get the result
res_output_value = graph.node['output']['value']
self.assertTrue(np.array_equal(output8, res_output_value),
'values do not match expected: {} and given: {}'.format(output8, res_output_value))

View File

@ -32,6 +32,8 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
};
}
using v3::ScatterNDUpdate;

View File

@ -4,6 +4,9 @@
#include "ngraph/op/scatter_nd_update.hpp"
#include "itt.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/scatter_nd_update.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
using namespace ngraph;
@ -18,3 +21,79 @@ shared_ptr<Node> op::v3::ScatterNDUpdate::clone_with_new_inputs(const OutputVect
new_args.at(op::util::ScatterNDBase::INDICES),
new_args.at(op::util::ScatterNDBase::UPDATES));
}
namespace scatter
{
template <element::Type_t ET>
bool evaluate(const HostTensorPtr& arg0,
const HostTensorPtr& arg1,
const HostTensorPtr& arg2,
const HostTensorPtr& out)
{
using T = typename element_type_traits<ET>::value_type;
Shape params_shape = arg0->get_shape();
Shape indices_shape = arg1->get_shape();
Shape updates_shape = arg1->get_shape();
Shape out_shape(params_shape);
out->set_shape(out_shape);
if (arg1->get_element_type() == element::i64)
{
runtime::reference::scatterNdUpdate<T, int64_t>(arg0->get_data_ptr<ET>(),
arg1->get_data_ptr<int64_t>(),
arg2->get_data_ptr<ET>(),
out->get_data_ptr<ET>(),
arg0->get_shape(),
arg1->get_shape(),
arg2->get_shape());
}
else if (arg1->get_element_type() == element::i32)
{
runtime::reference::scatterNdUpdate<T, int32_t>(arg0->get_data_ptr<ET>(),
arg1->get_data_ptr<int32_t>(),
arg2->get_data_ptr<ET>(),
out->get_data_ptr<ET>(),
arg0->get_shape(),
arg1->get_shape(),
arg2->get_shape());
}
else
{
throw ngraph_error("Unexpected type");
}
return true;
}
bool evaluate_scatter(const HostTensorPtr& arg0,
const HostTensorPtr& arg1,
const HostTensorPtr& arg2,
const HostTensorPtr& out)
{
bool rc = true;
switch (out->get_element_type())
{
NGRAPH_TYPE_CASE(evaluate_scatter, i32, arg0, arg1, arg2, out);
NGRAPH_TYPE_CASE(evaluate_scatter, i64, arg0, arg1, arg2, out);
NGRAPH_TYPE_CASE(evaluate_scatter, u32, arg0, arg1, arg2, out);
NGRAPH_TYPE_CASE(evaluate_scatter, u64, arg0, arg1, arg2, out);
NGRAPH_TYPE_CASE(evaluate_scatter, f16, arg0, arg1, arg2, out);
NGRAPH_TYPE_CASE(evaluate_scatter, f32, arg0, arg1, arg2, out);
NGRAPH_TYPE_CASE(evaluate_scatter, boolean, arg0, arg1, arg2, out);
default: rc = false; break;
}
return rc;
}
}
bool op::v3::ScatterNDUpdate::evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const
{
NGRAPH_OP_SCOPE(v3_ScatterNDUpdate_evaluate);
NGRAPH_CHECK(this, !inputs.empty());
NGRAPH_CHECK(this, validate_host_tensor_vector(inputs, 3));
NGRAPH_CHECK(this, validate_host_tensor_vector(outputs, 1));
return scatter::evaluate_scatter(inputs[0], inputs[1], inputs[2], outputs[0]);
}

View File

@ -40,6 +40,10 @@ void op::util::ScatterNDBase::validate_and_infer_types()
const PartialShape& indices_shape = get_input_partial_shape(INDICES);
const PartialShape& updates_shape = get_input_partial_shape(UPDATES);
const auto& inputs_rank = inputs_shape.rank();
const auto& indices_rank = indices_shape.rank();
const auto& updates_rank = updates_shape.rank();
NODE_VALIDATION_CHECK(this,
indices_et == element::i32 || indices_et == element::i64,
"Indices element type must be i64 or i32");
@ -48,47 +52,47 @@ void op::util::ScatterNDBase::validate_and_infer_types()
this, updates_et == inputs_et, "Updates element type must be the same as inputs");
NODE_VALIDATION_CHECK(this,
indices_shape.rank().is_dynamic() ||
indices_shape.rank().get_length() >= 1,
indices_rank.is_dynamic() || indices_rank.get_length() >= 1,
"Indices rank is expected to be at least 1");
NODE_VALIDATION_CHECK(this,
inputs_shape.rank().is_dynamic() || indices_shape.rank().is_dynamic() ||
indices_shape[indices_shape.rank().get_length() - 1].get_length() <=
inputs_shape.rank().get_length(),
inputs_rank.is_dynamic() || indices_rank.is_dynamic() ||
indices_shape[indices_rank.get_length() - 1].get_length() <=
inputs_rank.get_length(),
"Last dimension of indices can be at most the rank of inputs");
NODE_VALIDATION_CHECK(
this,
inputs_shape.rank().is_dynamic() || indices_shape.rank().is_dynamic() ||
updates_shape.rank().is_dynamic() ||
updates_shape.rank().get_length() ==
indices_shape.rank().get_length() + inputs_shape.rank().get_length() -
indices_shape[indices_shape.rank().get_length() - 1].get_length() - 1,
"Rank of updates must be rank of inputs + rank of indices - last dimension of indices "
"- 1");
bool compatible = true;
if (inputs_shape.is_static() && indices_shape.is_static() && updates_shape.is_static())
if (inputs_rank.is_static() && indices_rank.is_static() && updates_rank.is_static())
{
size_t indices_rank = indices_shape.rank().get_length();
size_t updates_rank = updates_shape.rank().get_length();
for (size_t i = 0; i < indices_rank - 1; i++)
auto expected_updates_rank = indices_rank.get_length() + inputs_rank.get_length() -
indices_shape[indices_rank.get_length() - 1].get_length() - 1;
// If expected updates rank is 0D it also can be a tensor with one element
NODE_VALIDATION_CHECK(
this,
updates_rank.get_length() == expected_updates_rank || expected_updates_rank == 0,
"Rank of updates must be rank of inputs + rank of indices - last dimension of indices "
"- 1");
bool compatible = true;
if (inputs_shape.is_static() && indices_shape.is_static() && updates_shape.is_static())
{
compatible = compatible && updates_shape[i].same_scheme(indices_shape[i]);
NODE_VALIDATION_CHECK(
this,
compatible,
"updates_shape[0:indices_rank-1] shape must be indices_shape[:-1]");
}
size_t j = indices_shape[indices_rank - 1].get_length();
for (size_t i = indices_rank - 1; i < updates_rank; i++, j++)
{
compatible = compatible && updates_shape[i].same_scheme(inputs_shape[j]);
NODE_VALIDATION_CHECK(
this,
compatible,
"updates_shape[indices_rank-1:] shape must be input_shape[indices_shape[-1]:]");
size_t static_indices_rank = indices_rank.get_length();
for (size_t i = 0; i < static_indices_rank - 1; i++)
{
compatible = compatible && updates_shape[i].same_scheme(indices_shape[i]);
NODE_VALIDATION_CHECK(
this,
compatible,
"updates_shape[0:indices_rank-1] shape must be indices_shape[:-1]");
}
size_t j = indices_shape[static_indices_rank - 1].get_length();
for (size_t i = static_indices_rank - 1; i < expected_updates_rank; i++, j++)
{
compatible = compatible && updates_shape[i].same_scheme(inputs_shape[j]);
NODE_VALIDATION_CHECK(
this,
compatible,
"updates_shape[indices_rank-1:] shape must be input_shape[indices_shape[-1]:]");
}
}
}