Fix ScatterND validation and implement evaluate (#4905)
* Fix ScatterND validation and implement evaluate
* Apply review feedback
* Update scatternd.py
parent 224dfd6520
commit 2bed9c9277
@@ -48,7 +48,7 @@ output = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
 
 * **2**: `indices` tensor with indices of arbitrary rank `q` >= 1 and of type *T_IND*. All index values `i_j` in index entry `(i_0, i_1, ..., i_k)` (where `k = indices.shape[-1]`) must be within bounds `[0, s_j - 1]` where `s_j = data.shape[j]`. `k` must be at most `r`. Required.
 
-* **3**: `updates` tensor of rank `r - indices.shape[-1] + q - 1` of type *T*. Required.
+* **3**: `updates` tensor of rank `r - indices.shape[-1] + q - 1` of type *T*. If the expected `updates` rank is 0D, it can be a tensor with a single element. Required.
 
 **Outputs**:
 
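To make the `updates` rank and shape rule quoted above concrete, here is a small Python sketch (an editor's illustration, not part of the operation specification; the helper name is hypothetical):

```python
def expected_updates(data_shape, indices_shape):
    r, q = len(data_shape), len(indices_shape)   # ranks of data and indices
    k = indices_shape[-1]                        # leading data dims addressed by each index tuple
    rank = r - k + q - 1                         # rank rule from the input table above
    shape = tuple(indices_shape[:-1]) + tuple(data_shape[k:])
    return rank, shape

print(expected_updates((4, 4, 4), (2, 1)))   # (3, (2, 4, 4))
print(expected_updates((8,), (1,)))          # (0, ()) -> 0D expected; a single-element tensor is also accepted
```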
@@ -44,14 +44,16 @@ class ScatterNDBase(Op):
         # 1. ranks of both input and indices must be at least 1
         assert len(input_shape) >= 1 and len(indices_shape) >= 1, \
             'The node "{}" input and indices ranks must be at least 1'.format(node_name)
 
         # 2. the last dimension of indices shape must be at most a rank of input
         assert indices_shape[-1] <= len(input_shape), \
             'The last dimension of indices shape must be at most a rank of input for the node "{}"'.format(node_name)
 
         # 3. updates is a tensor of shape indices_shape[:-1] + input_shape[indices_shape[-1]:]
+        # if the expected updates shape is scalar, updates can be a tensor with a single element (for example, of shape [1], [[1]], etc.)
         expected_updates_shape = np.concatenate((indices_shape[:-1], input_shape[indices_shape[-1]:]), axis=0)
-        assert np.array_equal(updates_shape, expected_updates_shape), \
+        assert np.array_equal(updates_shape, expected_updates_shape) or \
+               np.array_equal(expected_updates_shape, []) and np.array_equal(updates_shape, np.ones(len(updates_shape))), \
             'The updates shape must be equal to indices_shape[:-1] + input_shape[indices_shape[-1]:] for the node "{}"'.format(node_name)
 
         node.out_port(0).data.set_shape(input_shape)
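The relaxed assert above can be read as the following standalone predicate (an editor's sketch with a hypothetical helper name, not Model Optimizer API): an exact shape match, or a single-element `updates` tensor when the expected shape is scalar.

```python
import numpy as np

def updates_shape_ok(updates_shape, expected_updates_shape):
    # Exact match, or: expected shape is scalar ([]) and updates has exactly one
    # element, i.e. every dimension equals 1 (shape [1], [1, 1], ...).
    return (np.array_equal(updates_shape, expected_updates_shape)
            or (np.array_equal(expected_updates_shape, [])
                and np.array_equal(updates_shape, np.ones(len(updates_shape)))))

print(updates_shape_ok(np.array([2, 4, 4]), np.array([2, 4, 4])))  # True
print(updates_shape_ok(np.array([1]), np.array([])))               # True  (scalar expected, one-element updates)
print(updates_shape_ok(np.array([2]), np.array([])))               # False
```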
@@ -62,6 +62,11 @@ inputs7 = {'input': {'shape': int64_array([8]), 'value': int64_array([1, 2, 3, 4
            'updates': {'shape': int64_array([]), 'value': 9}}
 output7 = int64_array([1, 2, 3, 4, 9, 6, 7, 8])
 
+inputs8 = {'input': {'shape': int64_array([3]), 'value': int64_array([1, 2, 3])},
+           'indices': {'shape': int64_array([1]), 'value': int64_array([2])},
+           'updates': {'shape': int64_array([1]), 'value': int64_array([9])}}
+output8 = int64_array([1, 2, 9])
+
 class TestScatterNDUpdate(unittest.TestCase):
     def test_partial_infer1(self):
         graph = build_graph(nodes_attributes, edges, inputs1)
@@ -139,7 +144,7 @@ class TestScatterNDUpdate(unittest.TestCase):
         res_output_value = graph.node['output']['value']
 
         self.assertTrue(np.array_equal(output6, res_output_value),
-                        'values do not match expected: {} and given: {}'.format(output5, res_output_value))
+                        'values do not match expected: {} and given: {}'.format(output6, res_output_value))
 
     def test_infer7_scalar(self):
         graph = build_graph(nodes_attributes, edges, inputs7)
@@ -150,4 +155,15 @@ class TestScatterNDUpdate(unittest.TestCase):
         res_output_value = graph.node['output']['value']
 
         self.assertTrue(np.array_equal(output7, res_output_value),
-                        'values do not match expected: {} and given: {}'.format(output5, res_output_value))
+                        'values do not match expected: {} and given: {}'.format(output7, res_output_value))
+
+    def test_infer8(self):
+        graph = build_graph(nodes_attributes, edges, inputs8)
+        scatternd_node = Node(graph, 'scatternd_node')
+        ScatterNDUpdate.infer(scatternd_node)
+
+        # get the result
+        res_output_value = graph.node['output']['value']
+
+        self.assertTrue(np.array_equal(output8, res_output_value),
+                        'values do not match expected: {} and given: {}'.format(output8, res_output_value))
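For reference, the behaviour exercised by `inputs7`/`output7` and the new `inputs8`/`output8` cases can be reproduced with a plain NumPy emulation of ScatterNDUpdate semantics (an editor's sketch, not the implementation under test):

```python
import numpy as np

def scatter_nd_update(data, indices, updates):
    # Minimal emulation: each row of `indices` addresses a slice of `data`
    # that is replaced by the corresponding entry of `updates`.
    out = data.copy()
    indices = np.atleast_2d(indices)
    updates = np.asarray(updates).reshape(indices.shape[0], *data.shape[indices.shape[-1]:])
    for idx, upd in zip(indices, updates):
        out[tuple(idx)] = upd
    return out

# inputs7-style case: scalar update written at position 4
print(scatter_nd_update(np.arange(1, 9), np.array([[4]]), 9))        # [1 2 3 4 9 6 7 8]

# inputs8-style case: single-element updates tensor written at index [2]
print(scatter_nd_update(np.array([1, 2, 3]), np.array([[2]]), [9]))  # [1 2 9]
```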
@@ -32,6 +32,8 @@ namespace ngraph
 
                 virtual std::shared_ptr<Node>
                     clone_with_new_inputs(const OutputVector& new_args) const override;
+                bool evaluate(const HostTensorVector& outputs,
+                              const HostTensorVector& inputs) const override;
             };
         }
         using v3::ScatterNDUpdate;
@@ -4,6 +4,9 @@
 
 #include "ngraph/op/scatter_nd_update.hpp"
 #include "itt.hpp"
+#include "ngraph/runtime/host_tensor.hpp"
+#include "ngraph/runtime/reference/scatter_nd_update.hpp"
+#include "ngraph/validation_util.hpp"
 
 using namespace std;
 using namespace ngraph;
@@ -18,3 +21,79 @@ shared_ptr<Node> op::v3::ScatterNDUpdate::clone_with_new_inputs(const OutputVect
                                                  new_args.at(op::util::ScatterNDBase::INDICES),
                                                  new_args.at(op::util::ScatterNDBase::UPDATES));
 }
+
+namespace scatter
+{
+    template <element::Type_t ET>
+    bool evaluate(const HostTensorPtr& arg0,
+                  const HostTensorPtr& arg1,
+                  const HostTensorPtr& arg2,
+                  const HostTensorPtr& out)
+    {
+        using T = typename element_type_traits<ET>::value_type;
+        Shape params_shape = arg0->get_shape();
+        Shape indices_shape = arg1->get_shape();
+        Shape updates_shape = arg2->get_shape();
+        Shape out_shape(params_shape);
+        out->set_shape(out_shape);
+
+        if (arg1->get_element_type() == element::i64)
+        {
+            runtime::reference::scatterNdUpdate<T, int64_t>(arg0->get_data_ptr<ET>(),
+                                                            arg1->get_data_ptr<int64_t>(),
+                                                            arg2->get_data_ptr<ET>(),
+                                                            out->get_data_ptr<ET>(),
+                                                            arg0->get_shape(),
+                                                            arg1->get_shape(),
+                                                            arg2->get_shape());
+        }
+        else if (arg1->get_element_type() == element::i32)
+        {
+            runtime::reference::scatterNdUpdate<T, int32_t>(arg0->get_data_ptr<ET>(),
+                                                            arg1->get_data_ptr<int32_t>(),
+                                                            arg2->get_data_ptr<ET>(),
+                                                            out->get_data_ptr<ET>(),
+                                                            arg0->get_shape(),
+                                                            arg1->get_shape(),
+                                                            arg2->get_shape());
+        }
+        else
+        {
+            throw ngraph_error("Unexpected type");
+        }
+
+        return true;
+    }
+
+    bool evaluate_scatter(const HostTensorPtr& arg0,
+                          const HostTensorPtr& arg1,
+                          const HostTensorPtr& arg2,
+                          const HostTensorPtr& out)
+    {
+        bool rc = true;
+
+        switch (out->get_element_type())
+        {
+            NGRAPH_TYPE_CASE(evaluate_scatter, i32, arg0, arg1, arg2, out);
+            NGRAPH_TYPE_CASE(evaluate_scatter, i64, arg0, arg1, arg2, out);
+            NGRAPH_TYPE_CASE(evaluate_scatter, u32, arg0, arg1, arg2, out);
+            NGRAPH_TYPE_CASE(evaluate_scatter, u64, arg0, arg1, arg2, out);
+            NGRAPH_TYPE_CASE(evaluate_scatter, f16, arg0, arg1, arg2, out);
+            NGRAPH_TYPE_CASE(evaluate_scatter, f32, arg0, arg1, arg2, out);
+            NGRAPH_TYPE_CASE(evaluate_scatter, boolean, arg0, arg1, arg2, out);
+        default: rc = false; break;
+        }
+        return rc;
+    }
+}
+
+bool op::v3::ScatterNDUpdate::evaluate(const HostTensorVector& outputs,
+                                       const HostTensorVector& inputs) const
+{
+    NGRAPH_OP_SCOPE(v3_ScatterNDUpdate_evaluate);
+    NGRAPH_CHECK(this, !inputs.empty());
+    NGRAPH_CHECK(this, validate_host_tensor_vector(inputs, 3));
+    NGRAPH_CHECK(this, validate_host_tensor_vector(outputs, 1));
+
+    return scatter::evaluate_scatter(inputs[0], inputs[1], inputs[2], outputs[0]);
+}
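Conceptually, the reference kernel that this `evaluate` dispatches to overwrites the addressed slices of a flat data buffer. Below is a rough Python sketch of that flat-buffer formulation (an editor's illustration assuming row-major, C-contiguous data; it is not the nGraph `runtime::reference::scatterNdUpdate` code):

```python
import numpy as np

def scatter_nd_update_flat(data, indices, updates):
    # For every index tuple, compute a linear offset into the flattened data and
    # overwrite one contiguous block of prod(data.shape[k:]) elements.
    data = np.ascontiguousarray(data)
    out = data.reshape(-1).copy()
    k = indices.shape[-1]                                   # dims addressed by each index tuple
    block = int(np.prod(data.shape[k:], dtype=np.int64))    # contiguous elements per update
    strides = [int(np.prod(data.shape[d + 1:], dtype=np.int64)) for d in range(k)]
    upd = np.asarray(updates).reshape(-1, block)
    for row, idx in enumerate(indices.reshape(-1, k)):
        offset = int(np.dot(idx, strides))                  # linear offset of the addressed slice
        out[offset:offset + block] = upd[row]
    return out.reshape(data.shape)

data = np.zeros((2, 3), dtype=np.int64)
print(scatter_nd_update_flat(data, np.array([[1]]), np.array([[7, 8, 9]])))
# [[0 0 0]
#  [7 8 9]]
```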
@@ -40,6 +40,10 @@ void op::util::ScatterNDBase::validate_and_infer_types()
     const PartialShape& indices_shape = get_input_partial_shape(INDICES);
     const PartialShape& updates_shape = get_input_partial_shape(UPDATES);
 
+    const auto& inputs_rank = inputs_shape.rank();
+    const auto& indices_rank = indices_shape.rank();
+    const auto& updates_rank = updates_shape.rank();
+
     NODE_VALIDATION_CHECK(this,
                           indices_et == element::i32 || indices_et == element::i64,
                           "Indices element type must be i64 or i32");
@@ -48,47 +52,47 @@ void op::util::ScatterNDBase::validate_and_infer_types()
         this, updates_et == inputs_et, "Updates element type must be the same as inputs");
 
     NODE_VALIDATION_CHECK(this,
-                          indices_shape.rank().is_dynamic() ||
-                              indices_shape.rank().get_length() >= 1,
+                          indices_rank.is_dynamic() || indices_rank.get_length() >= 1,
                           "Indices rank is expected to be at least 1");
 
     NODE_VALIDATION_CHECK(this,
-                          inputs_shape.rank().is_dynamic() || indices_shape.rank().is_dynamic() ||
-                              indices_shape[indices_shape.rank().get_length() - 1].get_length() <=
-                                  inputs_shape.rank().get_length(),
+                          inputs_rank.is_dynamic() || indices_rank.is_dynamic() ||
+                              indices_shape[indices_rank.get_length() - 1].get_length() <=
+                                  inputs_rank.get_length(),
                           "Last dimension of indices can be at most the rank of inputs");
 
-    NODE_VALIDATION_CHECK(
-        this,
-        inputs_shape.rank().is_dynamic() || indices_shape.rank().is_dynamic() ||
-            updates_shape.rank().is_dynamic() ||
-            updates_shape.rank().get_length() ==
-                indices_shape.rank().get_length() + inputs_shape.rank().get_length() -
-                    indices_shape[indices_shape.rank().get_length() - 1].get_length() - 1,
-        "Rank of updates must be rank of inputs + rank of indices - last dimension of indices "
-        "- 1");
-
-    bool compatible = true;
-    if (inputs_shape.is_static() && indices_shape.is_static() && updates_shape.is_static())
+    if (inputs_rank.is_static() && indices_rank.is_static() && updates_rank.is_static())
     {
-        size_t indices_rank = indices_shape.rank().get_length();
-        size_t updates_rank = updates_shape.rank().get_length();
-        for (size_t i = 0; i < indices_rank - 1; i++)
+        auto expected_updates_rank = indices_rank.get_length() + inputs_rank.get_length() -
+                                     indices_shape[indices_rank.get_length() - 1].get_length() - 1;
+        // If expected updates rank is 0D it also can be a tensor with one element
+        NODE_VALIDATION_CHECK(
+            this,
+            updates_rank.get_length() == expected_updates_rank || expected_updates_rank == 0,
+            "Rank of updates must be rank of inputs + rank of indices - last dimension of indices "
+            "- 1");
+
+        bool compatible = true;
+        if (inputs_shape.is_static() && indices_shape.is_static() && updates_shape.is_static())
         {
-            compatible = compatible && updates_shape[i].same_scheme(indices_shape[i]);
-            NODE_VALIDATION_CHECK(
-                this,
-                compatible,
-                "updates_shape[0:indices_rank-1] shape must be indices_shape[:-1]");
-        }
-        size_t j = indices_shape[indices_rank - 1].get_length();
-        for (size_t i = indices_rank - 1; i < updates_rank; i++, j++)
-        {
-            compatible = compatible && updates_shape[i].same_scheme(inputs_shape[j]);
-            NODE_VALIDATION_CHECK(
-                this,
-                compatible,
-                "updates_shape[indices_rank-1:] shape must be input_shape[indices_shape[-1]:]");
-        }
+            size_t static_indices_rank = indices_rank.get_length();
+            for (size_t i = 0; i < static_indices_rank - 1; i++)
+            {
+                compatible = compatible && updates_shape[i].same_scheme(indices_shape[i]);
+                NODE_VALIDATION_CHECK(
+                    this,
+                    compatible,
+                    "updates_shape[0:indices_rank-1] shape must be indices_shape[:-1]");
+            }
+            size_t j = indices_shape[static_indices_rank - 1].get_length();
+            for (size_t i = static_indices_rank - 1; i < expected_updates_rank; i++, j++)
+            {
+                compatible = compatible && updates_shape[i].same_scheme(inputs_shape[j]);
+                NODE_VALIDATION_CHECK(
+                    this,
+                    compatible,
+                    "updates_shape[indices_rank-1:] shape must be input_shape[indices_shape[-1]:]");
+            }
+        }
     }
 }
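In short, the reshaped validation computes `expected_updates_rank`, accepts a one-element `updates` tensor when that rank is 0, and, for fully static shapes, checks the leading `updates` dimensions against `indices` and the trailing ones against `inputs`. A hedged Python sketch of those checks (editor's illustration with a hypothetical helper, not the nGraph implementation):

```python
def validate_updates(inputs_shape, indices_shape, updates_shape):
    k = indices_shape[-1]
    expected_updates_rank = len(indices_shape) + len(inputs_shape) - k - 1
    # A 0D expectation also accepts a tensor with exactly one element.
    if expected_updates_rank == 0:
        assert len(updates_shape) == 0 or list(updates_shape) == [1], "rank mismatch"
        return
    assert len(updates_shape) == expected_updates_rank, "rank mismatch"
    # Leading dims of updates must match indices_shape[:-1] ...
    for i in range(len(indices_shape) - 1):
        assert updates_shape[i] == indices_shape[i], "updates_shape[0:indices_rank-1] mismatch"
    # ... and trailing dims must match inputs_shape[indices_shape[-1]:]
    for i, j in zip(range(len(indices_shape) - 1, expected_updates_rank),
                    range(k, len(inputs_shape))):
        assert updates_shape[i] == inputs_shape[j], "updates_shape[indices_rank-1:] mismatch"

validate_updates([4, 4, 4], [2, 1], [2, 4, 4])   # passes
validate_updates([8], [1], [1])                  # passes: expected rank 0, one-element updates accepted
```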