diff --git a/docs/ops/movement/ScatterNDUpdate_3.md b/docs/ops/movement/ScatterNDUpdate_3.md index 93398fa3f98..5dd1ed9a462 100644 --- a/docs/ops/movement/ScatterNDUpdate_3.md +++ b/docs/ops/movement/ScatterNDUpdate_3.md @@ -48,7 +48,7 @@ output = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], * **2**: `indices` tensor with indices of arbitrary rank `q` >= 1 and of type *T_IND*. All index values `i_j` in index entry `(i_0, i_1, ...,i_k)` (where `k = indices.shape[-1]`) must be within bounds `[0, s_j - 1]` where `s_j = data.shape[j]`. `k` must be at most `r`. Required. -* **3**: `updates` tensor of rank `r - indices.shape[-1] + q - 1` of type *T*. Required. +* **3**: `updates` tensor of rank `r - indices.shape[-1] + q - 1` of type *T*. If expected `updates` rank is 0D it can be a tensor with single element. Required. **Outputs**: diff --git a/model-optimizer/extensions/ops/scatternd.py b/model-optimizer/extensions/ops/scatternd.py index cffb226b268..8917d11cfb8 100644 --- a/model-optimizer/extensions/ops/scatternd.py +++ b/model-optimizer/extensions/ops/scatternd.py @@ -44,14 +44,16 @@ class ScatterNDBase(Op): # 1. ranks of both input and indices must be at least 1 assert len(input_shape) >= 1 and len(indices_shape) >= 1, \ 'The node "{}" input and indices ranks must be at least 1'.format(node_name) - + # 2. the last dimension of indices shape must be at most a rank of input assert indices_shape[-1] <= len(input_shape), \ 'The last dimension of indices shape must be at most a rank of input for the node "{}"'.format(node_name) # 3. updates is a tensor of shape indices_shape[:-1] + input_shape[indices_shape[-1]:] + # if expected updates shape is scalar, updates can be tensor with the single element (for example, of shape [1], [[1]], etc.) expected_updates_shape = np.concatenate((indices_shape[:-1], input_shape[indices_shape[-1]:]), axis=0) - assert np.array_equal(updates_shape, expected_updates_shape), \ + assert np.array_equal(updates_shape, expected_updates_shape) or\ + np.array_equal(expected_updates_shape, []) and np.array_equal(updates_shape, np.ones(len(updates_shape))), \ 'The updates shape must be equal to indices_shape[:-1] + input_shape[indices_shape[-1]:] for the node "{}"'.format(node_name) node.out_port(0).data.set_shape(input_shape) diff --git a/model-optimizer/extensions/ops/scatternd_test.py b/model-optimizer/extensions/ops/scatternd_test.py index 41e9ad94101..a53b020202c 100644 --- a/model-optimizer/extensions/ops/scatternd_test.py +++ b/model-optimizer/extensions/ops/scatternd_test.py @@ -62,6 +62,11 @@ inputs7 = {'input': {'shape': int64_array([8]), 'value': int64_array([1, 2, 3, 4 'updates': {'shape': int64_array([]), 'value': 9}} output7 = int64_array([1, 2, 3, 4, 9, 6, 7, 8]) +inputs8 = {'input': {'shape': int64_array([3]), 'value': int64_array([1, 2, 3])}, + 'indices': {'shape': int64_array([1]), 'value': int64_array([2])}, + 'updates': {'shape': int64_array([1]), 'value': int64_array([9])}} +output8 = int64_array([1, 2, 9]) + class TestScatterNDUpdate(unittest.TestCase): def test_partial_infer1(self): graph = build_graph(nodes_attributes, edges, inputs1) @@ -139,7 +144,7 @@ class TestScatterNDUpdate(unittest.TestCase): res_output_value = graph.node['output']['value'] self.assertTrue(np.array_equal(output6, res_output_value), - 'values do not match expected: {} and given: {}'.format(output5, res_output_value)) + 'values do not match expected: {} and given: {}'.format(output6, res_output_value)) def test_infer7_scalar(self): graph = build_graph(nodes_attributes, edges, inputs7) @@ -150,4 +155,15 @@ class TestScatterNDUpdate(unittest.TestCase): res_output_value = graph.node['output']['value'] self.assertTrue(np.array_equal(output7, res_output_value), - 'values do not match expected: {} and given: {}'.format(output5, res_output_value)) + 'values do not match expected: {} and given: {}'.format(output7, res_output_value)) + + def test_infer8(self): + graph = build_graph(nodes_attributes, edges, inputs8) + scatternd_node = Node(graph, 'scatternd_node') + ScatterNDUpdate.infer(scatternd_node) + + # get the result + res_output_value = graph.node['output']['value'] + + self.assertTrue(np.array_equal(output8, res_output_value), + 'values do not match expected: {} and given: {}'.format(output8, res_output_value)) diff --git a/ngraph/core/include/ngraph/op/scatter_nd_update.hpp b/ngraph/core/include/ngraph/op/scatter_nd_update.hpp index ae646d6980a..3fccbdee97a 100644 --- a/ngraph/core/include/ngraph/op/scatter_nd_update.hpp +++ b/ngraph/core/include/ngraph/op/scatter_nd_update.hpp @@ -32,6 +32,8 @@ namespace ngraph virtual std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + bool evaluate(const HostTensorVector& outputs, + const HostTensorVector& inputs) const override; }; } using v3::ScatterNDUpdate; diff --git a/ngraph/core/src/op/scatter_nd_update.cpp b/ngraph/core/src/op/scatter_nd_update.cpp index 0142286981d..2c4716c9c01 100644 --- a/ngraph/core/src/op/scatter_nd_update.cpp +++ b/ngraph/core/src/op/scatter_nd_update.cpp @@ -4,6 +4,9 @@ #include "ngraph/op/scatter_nd_update.hpp" #include "itt.hpp" +#include "ngraph/runtime/host_tensor.hpp" +#include "ngraph/runtime/reference/scatter_nd_update.hpp" +#include "ngraph/validation_util.hpp" using namespace std; using namespace ngraph; @@ -18,3 +21,79 @@ shared_ptr op::v3::ScatterNDUpdate::clone_with_new_inputs(const OutputVect new_args.at(op::util::ScatterNDBase::INDICES), new_args.at(op::util::ScatterNDBase::UPDATES)); } + +namespace scatter +{ + template + bool evaluate(const HostTensorPtr& arg0, + const HostTensorPtr& arg1, + const HostTensorPtr& arg2, + const HostTensorPtr& out) + { + using T = typename element_type_traits::value_type; + Shape params_shape = arg0->get_shape(); + Shape indices_shape = arg1->get_shape(); + Shape updates_shape = arg1->get_shape(); + Shape out_shape(params_shape); + out->set_shape(out_shape); + + if (arg1->get_element_type() == element::i64) + { + runtime::reference::scatterNdUpdate(arg0->get_data_ptr(), + arg1->get_data_ptr(), + arg2->get_data_ptr(), + out->get_data_ptr(), + arg0->get_shape(), + arg1->get_shape(), + arg2->get_shape()); + } + else if (arg1->get_element_type() == element::i32) + { + runtime::reference::scatterNdUpdate(arg0->get_data_ptr(), + arg1->get_data_ptr(), + arg2->get_data_ptr(), + out->get_data_ptr(), + arg0->get_shape(), + arg1->get_shape(), + arg2->get_shape()); + } + else + { + throw ngraph_error("Unexpected type"); + } + + return true; + } + + bool evaluate_scatter(const HostTensorPtr& arg0, + const HostTensorPtr& arg1, + const HostTensorPtr& arg2, + const HostTensorPtr& out) + { + bool rc = true; + + switch (out->get_element_type()) + { + NGRAPH_TYPE_CASE(evaluate_scatter, i32, arg0, arg1, arg2, out); + NGRAPH_TYPE_CASE(evaluate_scatter, i64, arg0, arg1, arg2, out); + NGRAPH_TYPE_CASE(evaluate_scatter, u32, arg0, arg1, arg2, out); + NGRAPH_TYPE_CASE(evaluate_scatter, u64, arg0, arg1, arg2, out); + NGRAPH_TYPE_CASE(evaluate_scatter, f16, arg0, arg1, arg2, out); + NGRAPH_TYPE_CASE(evaluate_scatter, f32, arg0, arg1, arg2, out); + NGRAPH_TYPE_CASE(evaluate_scatter, boolean, arg0, arg1, arg2, out); + default: rc = false; break; + } + return rc; + } +} + +bool op::v3::ScatterNDUpdate::evaluate(const HostTensorVector& outputs, + const HostTensorVector& inputs) const +{ + NGRAPH_OP_SCOPE(v3_ScatterNDUpdate_evaluate); + NGRAPH_CHECK(this, !inputs.empty()); + NGRAPH_CHECK(this, validate_host_tensor_vector(inputs, 3)); + NGRAPH_CHECK(this, validate_host_tensor_vector(outputs, 1)); + + return scatter::evaluate_scatter(inputs[0], inputs[1], inputs[2], outputs[0]); +} diff --git a/ngraph/core/src/op/util/scatter_nd_base.cpp b/ngraph/core/src/op/util/scatter_nd_base.cpp index 449de80cf20..9d91891c58f 100644 --- a/ngraph/core/src/op/util/scatter_nd_base.cpp +++ b/ngraph/core/src/op/util/scatter_nd_base.cpp @@ -40,6 +40,10 @@ void op::util::ScatterNDBase::validate_and_infer_types() const PartialShape& indices_shape = get_input_partial_shape(INDICES); const PartialShape& updates_shape = get_input_partial_shape(UPDATES); + const auto& inputs_rank = inputs_shape.rank(); + const auto& indices_rank = indices_shape.rank(); + const auto& updates_rank = updates_shape.rank(); + NODE_VALIDATION_CHECK(this, indices_et == element::i32 || indices_et == element::i64, "Indices element type must be i64 or i32"); @@ -48,47 +52,47 @@ void op::util::ScatterNDBase::validate_and_infer_types() this, updates_et == inputs_et, "Updates element type must be the same as inputs"); NODE_VALIDATION_CHECK(this, - indices_shape.rank().is_dynamic() || - indices_shape.rank().get_length() >= 1, + indices_rank.is_dynamic() || indices_rank.get_length() >= 1, "Indices rank is expected to be at least 1"); NODE_VALIDATION_CHECK(this, - inputs_shape.rank().is_dynamic() || indices_shape.rank().is_dynamic() || - indices_shape[indices_shape.rank().get_length() - 1].get_length() <= - inputs_shape.rank().get_length(), + inputs_rank.is_dynamic() || indices_rank.is_dynamic() || + indices_shape[indices_rank.get_length() - 1].get_length() <= + inputs_rank.get_length(), "Last dimension of indices can be at most the rank of inputs"); - NODE_VALIDATION_CHECK( - this, - inputs_shape.rank().is_dynamic() || indices_shape.rank().is_dynamic() || - updates_shape.rank().is_dynamic() || - updates_shape.rank().get_length() == - indices_shape.rank().get_length() + inputs_shape.rank().get_length() - - indices_shape[indices_shape.rank().get_length() - 1].get_length() - 1, - "Rank of updates must be rank of inputs + rank of indices - last dimension of indices " - "- 1"); - - bool compatible = true; - if (inputs_shape.is_static() && indices_shape.is_static() && updates_shape.is_static()) + if (inputs_rank.is_static() && indices_rank.is_static() && updates_rank.is_static()) { - size_t indices_rank = indices_shape.rank().get_length(); - size_t updates_rank = updates_shape.rank().get_length(); - for (size_t i = 0; i < indices_rank - 1; i++) + auto expected_updates_rank = indices_rank.get_length() + inputs_rank.get_length() - + indices_shape[indices_rank.get_length() - 1].get_length() - 1; + // If expected updates rank is 0D it also can be a tensor with one element + NODE_VALIDATION_CHECK( + this, + updates_rank.get_length() == expected_updates_rank || expected_updates_rank == 0, + "Rank of updates must be rank of inputs + rank of indices - last dimension of indices " + "- 1"); + + bool compatible = true; + if (inputs_shape.is_static() && indices_shape.is_static() && updates_shape.is_static()) { - compatible = compatible && updates_shape[i].same_scheme(indices_shape[i]); - NODE_VALIDATION_CHECK( - this, - compatible, - "updates_shape[0:indices_rank-1] shape must be indices_shape[:-1]"); - } - size_t j = indices_shape[indices_rank - 1].get_length(); - for (size_t i = indices_rank - 1; i < updates_rank; i++, j++) - { - compatible = compatible && updates_shape[i].same_scheme(inputs_shape[j]); - NODE_VALIDATION_CHECK( - this, - compatible, - "updates_shape[indices_rank-1:] shape must be input_shape[indices_shape[-1]:]"); + size_t static_indices_rank = indices_rank.get_length(); + for (size_t i = 0; i < static_indices_rank - 1; i++) + { + compatible = compatible && updates_shape[i].same_scheme(indices_shape[i]); + NODE_VALIDATION_CHECK( + this, + compatible, + "updates_shape[0:indices_rank-1] shape must be indices_shape[:-1]"); + } + size_t j = indices_shape[static_indices_rank - 1].get_length(); + for (size_t i = static_indices_rank - 1; i < expected_updates_rank; i++, j++) + { + compatible = compatible && updates_shape[i].same_scheme(inputs_shape[j]); + NODE_VALIDATION_CHECK( + this, + compatible, + "updates_shape[indices_rank-1:] shape must be input_shape[indices_shape[-1]:]"); + } } }