Review scatter nd update class for shape inference aspects (#15610)

* Review interval shape and labels propagation

* Review shape infer template implementation
- add test for default ctor

* Add evaluate upper, lower and label
- add new default label evaluator which propagates labels
from inputs list

* default_label_evaluator for input 0 only is a wrapper around the
generic evaluator implementation

* Use default_label_evaluator in ScatterUpdate

* Fix build issues
This commit is contained in:
Pawel Raasz
2023-02-16 13:49:19 +01:00
committed by GitHub
parent 71cff0ae62
commit d32da828b4
12 changed files with 301 additions and 157 deletions

View File

@@ -156,7 +156,7 @@ const std::vector<GatherTransformationTestValues> testValues = {
ngraph::element::u8,
{{ngraph::element::f32},
{{128.f}, element::undefined, {1, 3, 1}, false, 1ul, element::u8, true},
{{0.1f}, ngraph::element::f32, {1, 3, 1}}}}},
{{0.1f}, ngraph::element::f32, {1, 3, 1}}}}},
// U8: per-channel quantization, gather axis match with channel
{{1},
{0},

View File

@@ -25,6 +25,9 @@ public:
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate_lower(TensorVector& output_values) const override;
bool evaluate_upper(TensorVector& output_values) const override;
bool evaluate_label(TensorLabelVector& output_labels) const override;
bool has_evaluate() const override;
};
} // namespace v3

View File

@@ -8,14 +8,15 @@
#include "utils.hpp"
template <class T>
void shape_infer(const ov::op::util::ScatterNDBase* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 3 && output_shapes.size() == 1);
const auto& inputs_shape = input_shapes[0];
const auto& indices_shape = input_shapes[1];
const auto& updates_shape = input_shapes[2];
namespace ov {
namespace op {
template <class TShape>
std::vector<TShape> shape_infer(const util::ScatterNDBase* op, const std::vector<TShape>& input_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 3);
const auto& inputs_shape = input_shapes[util::ScatterNDBase::INPUTS];
const auto& indices_shape = input_shapes[util::ScatterNDBase::INDICES];
const auto& updates_shape = input_shapes[util::ScatterNDBase::UPDATES];
const auto& inputs_rank = inputs_shape.rank();
const auto& indices_rank = indices_shape.rank();
@@ -25,35 +26,51 @@ void shape_infer(const ov::op::util::ScatterNDBase* op,
indices_rank != 0 && inputs_rank != 0,
"Indices rank and inputs_rank are expected to be at least 1");
NODE_VALIDATION_CHECK(
op,
inputs_rank.is_dynamic() || indices_rank.is_dynamic() || indices_shape[indices_shape.size() - 1].is_dynamic() ||
static_cast<size_t>(indices_shape[indices_shape.size() - 1].get_length()) <= inputs_shape.size(),
"Last dimension of indices can be at most the rank of inputs");
if (inputs_rank.is_static() && indices_rank.is_static()) {
const auto last_idx_pos = indices_shape.size() - 1;
const auto& last_idx_dim = indices_shape[last_idx_pos];
if (inputs_rank.is_static() && indices_rank.is_static() && updates_rank.is_static() &&
indices_shape[indices_shape.size() - 1].is_static()) {
auto expected_updates_rank =
indices_shape.size() + inputs_shape.size() - indices_shape[indices_shape.size() - 1].get_length() - 1;
// If expected updates rank is 0D it also can be a tensor with one element
NODE_VALIDATION_CHECK(op,
updates_shape.size() == expected_updates_rank || expected_updates_rank == 0,
"Rank of updates must be rank of inputs + rank of indices - last dimension of indices "
"- 1");
if (last_idx_dim.is_static()) {
const auto last_idx_dim_size = static_cast<size_t>(last_idx_dim.get_length());
bool compatible = true;
size_t static_indices_rank = indices_shape.size();
for (size_t i = 0; i < static_indices_rank - 1; i++) {
compatible = compatible && updates_shape[i].compatible(indices_shape[i]);
NODE_VALIDATION_CHECK(op, compatible, "updates_shape[0:indices_rank-1] shape must be indices_shape[:-1]");
}
size_t j = indices_shape[static_indices_rank - 1].get_length();
for (int64_t i = static_indices_rank - 1; i < static_cast<int64_t>(expected_updates_rank); i++, j++) {
compatible = compatible && updates_shape[i].compatible(inputs_shape[j]);
NODE_VALIDATION_CHECK(op,
compatible,
"updates_shape[indices_rank-1:] shape must be input_shape[indices_shape[-1]:]");
last_idx_dim_size <= inputs_shape.size(),
"Last dimension of indices can be at most the rank of inputs");
if (updates_rank.is_static()) {
// Used last_idx_pos because is equal rank of indices - 1
const auto expected_updates_rank = inputs_shape.size() + last_idx_pos - last_idx_dim_size;
// If expected updates rank is 0D it also can be a tensor with one element
NODE_VALIDATION_CHECK(
op,
updates_shape.size() == expected_updates_rank || expected_updates_rank == 0,
"Rank of updates must be rank of inputs + rank of indices - last dimension of indices - 1");
auto update_iter = updates_shape.begin();
auto is_update_compatible = [&update_iter](const typename TShape::value_type& d) -> bool {
return d.compatible(*update_iter++);
};
NODE_VALIDATION_CHECK(
op,
std::all_of(indices_shape.begin(), indices_shape.begin() + last_idx_pos, is_update_compatible),
"updates_shape[0:indices_rank-1] shape must be indices_shape[:-1]");
NODE_VALIDATION_CHECK(
op,
std::all_of(inputs_shape.begin() + last_idx_dim_size, inputs_shape.end(), is_update_compatible),
"updates_shape[indices_rank-1:] shape must be input_shape[indices_shape[-1]:]");
}
}
}
output_shapes[0] = inputs_shape;
return {inputs_shape};
}
// \brief Legacy-style overload: writes the inferred output shape into `output_shapes`.
//
// Thin wrapper over the value-returning shape_infer above, kept for callers
// that still use the (input_shapes, output_shapes) out-parameter convention.
template <class TShape>
void shape_infer(const util::ScatterNDBase* op,
const std::vector<TShape>& input_shapes,
std::vector<TShape>& output_shapes) {
output_shapes = shape_infer(op, input_shapes);
}
} // namespace op
} // namespace ov

View File

@@ -4,6 +4,7 @@
#include "bound_evaluate.hpp"
#include "dimension_tracker.hpp"
#include "ngraph/validation_util.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/opsets/opset10.hpp"
@@ -425,3 +426,58 @@ bool ov::have_node_inputs_bounds_set(const Node* const node, const size_t first_
}
return have_bound_set;
}
// \brief Propagate value labels from the listed inputs to the node's outputs by
// running the node's evaluate() on label tensors.
//
// For each input in `labeled_inputs` a tensor of label_t filled with the input's
// value labels is built (requires a static input shape); every other input must
// have a set lower/upper bound and contributes its lower value. Returns false if
// any required shape/bound is missing, if no listed input carries labels, or if
// evaluate() fails; on success fills `output_labels` and returns true.
bool ov::default_label_evaluator(const Node* node,
std::initializer_list<size_t> labeled_inputs,
TensorLabelVector& output_labels) {
bool has_any_input_labels = false;
const auto& inputs_count = node->get_input_size();
TensorVector inputs;
inputs.reserve(inputs_count);
for (size_t i = 0; i < inputs_count; ++i) {
if (std::find(labeled_inputs.begin(), labeled_inputs.end(), i) != labeled_inputs.end()) {
auto labels = node->get_input_tensor(i).get_value_label();
// Remember whether at least one listed input actually has labels;
// otherwise there is nothing to propagate and we bail out below.
if (!has_no_labels(labels) && !has_any_input_labels) {
has_any_input_labels = true;
}
if (node->get_input_partial_shape(i).is_static()) {
// Pad missing labels with no_label so the tensor covers the full shape.
labels.resize(shape_size(node->get_input_shape(i)), no_label);
inputs.emplace_back(element::from<label_t>(), node->get_input_shape(i));
std::copy(labels.begin(), labels.end(), inputs.back().data<label_t>());
} else {
// Cannot allocate a label tensor for a dynamic shape.
return false;
}
} else {
// Non-labeled inputs are fed by their (set) lower bound values.
if (node->get_input_tensor(i).has_and_set_bound()) {
inputs.push_back(node->get_input_tensor(i).get_lower_value());
} else {
return false;
}
}
}
if (has_any_input_labels) {
const auto& outputs_count = node->get_output_size();
TensorVector outputs;
outputs.reserve(outputs_count);
for (size_t i = 0; i < outputs_count; ++i) {
const auto& partial_shape = node->get_output_partial_shape(i);
// Set shape for static or Shape{0} for dynamic to postpone memory allocation
auto shape = partial_shape.is_static() ? partial_shape.to_shape() : Shape{0};
outputs.emplace_back(element::from<label_t>(), shape);
}
if (node->evaluate(outputs, inputs)) {
std::transform(outputs.cbegin(), outputs.cend(), output_labels.begin(), [](const Tensor& t) {
// Return empty label tensor if input tensor not valid (can have Shape{0})
return t ? TensorLabel(t.data<label_t>(), t.data<label_t>() + t.get_size()) : TensorLabel();
});
return true;
}
}
return false;
}

View File

@@ -53,4 +53,15 @@ bool has_and_set_equal_bounds(const Output<Node>& source);
/// greater than node's inputs count.
bool have_node_inputs_bounds_set(const ov::Node* const node, const size_t first_idx, const size_t last_idx);
/// \brief Propagates value label from given inputs list to the only output through an operation.
/// Not applicable for operations which require values interaction (example: mathematical
/// operations). Could be used for movement operations (example: gathering, shape change)
///
/// \param node Operation to be performed
/// \param labeled_inputs List of node inputs to propagate labels.
/// \param output_labels Vector of TensorLabel objects representing resulting value labels
/// \return boolean status if label evaluation was successful.
bool default_label_evaluator(const Node* node,
std::initializer_list<size_t> labeled_inputs,
TensorLabelVector& output_labels);
} // namespace ov

View File

@@ -4,6 +4,7 @@
#include "ngraph/op/scatter_nd_update.hpp"
#include "bound_evaluate.hpp"
#include "itt.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/scatter_nd_update.hpp"
@@ -28,11 +29,7 @@ bool evaluate(const HostTensorPtr& arg0,
const HostTensorPtr& arg2,
const HostTensorPtr& out) {
using T = typename element_type_traits<ET>::value_type;
ov::Shape params_shape = arg0->get_shape();
ov::Shape indices_shape = arg1->get_shape();
ov::Shape updates_shape = arg1->get_shape();
const ov::Shape& out_shape(params_shape);
out->set_shape(out_shape);
out->set_shape(arg0->get_shape());
if (arg1->get_element_type() == element::i64) {
runtime::reference::scatterNdUpdate<T, int64_t>(arg0->get_data_ptr<ET>(),
@@ -114,3 +111,19 @@ bool op::v3::ScatterNDUpdate::has_evaluate() const {
}
return true;
}
// \brief Lower-bound evaluation; requires the indices input (1) to have set bounds
// before delegating to the generic lower-bound evaluator.
bool op::v3::ScatterNDUpdate::evaluate_lower(ov::TensorVector& output_values) const {
OV_OP_SCOPE(v3_ScatterNDUpdate_evaluate_lower);
return get_input_tensor(1).has_and_set_bound() && ov::default_lower_bound_evaluator(this, output_values);
}
// \brief Upper-bound evaluation; requires the indices input (1) to have set bounds
// before delegating to the generic upper-bound evaluator.
bool op::v3::ScatterNDUpdate::evaluate_upper(ov::TensorVector& output_values) const {
OV_OP_SCOPE(v3_ScatterNDUpdate_evaluate_upper);
return get_input_tensor(1).has_and_set_bound() && ov::default_upper_bound_evaluator(this, output_values);
}
// \brief Label propagation: labels flow from the data (0) and updates (2) inputs
// through the scatter operation via the generic label evaluator.
bool op::v3::ScatterNDUpdate::evaluate_label(TensorLabelVector& output_labels) const {
OV_OP_SCOPE(v3_ScatterNDUpdate_evaluate_label);
return ov::default_label_evaluator(this, {0, 2}, output_labels);
}

View File

@@ -124,54 +124,7 @@ bool op::v3::ScatterUpdate::has_evaluate() const {
return false;
}
namespace {
// \brief Propagate value labels through a scatter node by evaluating it on label
// tensors built from the data (input 0) and updates (input 2) labels.
//
// Other inputs are fed by their lower-bound values; the caller is expected to
// have verified static shapes for inputs 0/2 and set bounds for the rest.
// Returns false when neither data nor updates carry labels, or if evaluate() fails.
bool scatter_label_evaluator(const Node* node, TensorLabelVector& output_labels) {
const auto& input_values = node->input_values();
constexpr auto data_in_idx = 0;
constexpr auto updates_in_idx = 2;
auto data_labels = input_values[data_in_idx].get_tensor().get_value_label();
auto updates_labels = input_values[updates_in_idx].get_tensor().get_value_label();
// Nothing to propagate if both label sources are empty.
if (ov::has_no_labels(data_labels) && ov::has_no_labels(updates_labels)) {
return false;
}
ov::TensorVector input_tensors;
input_tensors.reserve(input_values.size());
// Build a label_t tensor for `input` and copy its labels in. resize()
// value-initializes any padding labels (i.e. to no_label == 0).
auto make_input_label = [&](const Output<Node>& input, TensorLabel& labels) {
input_tensors.emplace_back(ov::element::from<ov::label_t>(), input.get_shape());
labels.resize(shape_size(input.get_shape()));
memcpy(input_tensors.back().data(), labels.data(), input_tensors.back().get_byte_size());
};
for (size_t i = 0; i < input_values.size(); ++i) {
const auto& input = input_values[i];
if (i == data_in_idx) {
make_input_label(input, data_labels);
} else if (i == updates_in_idx) {
make_input_label(input, updates_labels);
} else {
// Non-label inputs (indices, axis) contribute their lower-bound values.
input_tensors.push_back(input.get_tensor().get_lower_value());
}
}
ov::TensorVector output_tensors{ov::Tensor(ov::element::from<ov::label_t>(), node->get_output_shape(0))};
if (node->evaluate(output_tensors, input_tensors)) {
output_labels[0] = ov::TensorLabel(output_tensors[0].data<ov::label_t>(),
output_tensors[0].data<ov::label_t>() + output_tensors[0].get_size());
return true;
}
return false;
}
} // namespace
// \brief Label propagation for ScatterUpdate: labels flow from the data (0) and
// updates (2) inputs through the generic label evaluator, which itself checks
// static shapes on the labeled inputs and set bounds on the remaining ones.
bool op::v3::ScatterUpdate::evaluate_label(TensorLabelVector& output_labels) const {
    OV_OP_SCOPE(v3_ScatterUpdate_evaluate_label);
    return ov::default_label_evaluator(this, {0, 2}, output_labels);
}

View File

@@ -1203,45 +1203,7 @@ bool ov::evaluate_as_partial_shape(const Output<Node>& output, PartialShape& psh
}
// \brief Single-source label propagation: forwards labels from input 0 only.
//
// Thin wrapper over the list-based default_label_evaluator; kept as the common
// case for operations whose output labels depend solely on the first input.
bool ov::default_label_evaluator(const Node* node, TensorLabelVector& output_labels) {
    return default_label_evaluator(node, {0}, output_labels);
}
shared_ptr<op::Constant> ngraph::get_constant_max_of_type(element::Type_t t) {

View File

@@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "openvino/opsets/opset10.hpp"
#include "util/type_prop.hpp"
using namespace std;
@@ -106,3 +107,105 @@ TEST(type_prop, scatter_nd_update_fail_indices_last_dim) {
FAIL() << "Deduced type check failed for unexpected reason";
}
}
using namespace ov::opset10;
using namespace testing;
// Type-prop fixture for v3::ScatterNDUpdate; provides a labeled 3-D dynamic
// data shape shared by the tests (labels 10, 11, 12 assigned in SetUp).
class TypePropScatterUpdateNDV3Test : public TypePropOpTest<op::v3::ScatterNDUpdate> {
protected:
void SetUp() override {
set_shape_labels(data_3d_dynamic, 10);
}
PartialShape data_3d_dynamic{{2, 5}, 2, {4, 10}};
};
// Output shape, type and dimension labels are taken from the data input.
TEST_F(TypePropScatterUpdateNDV3Test, data_input_partial_shape_and_labels_propagation) {
const auto d = std::make_shared<Parameter>(element::f32, data_3d_dynamic);
const auto i = std::make_shared<Parameter>(element::i32, PartialShape{3, 2});
const auto u = std::make_shared<Parameter>(element::f32, PartialShape{3, 5});
const auto op = make_op(d, i, u);
EXPECT_EQ(op->get_input_size(), 3);
EXPECT_EQ(op->get_output_size(), 1);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic);
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 11, 12));
}
// Dynamic-rank indices must not block data shape/label propagation.
TEST_F(TypePropScatterUpdateNDV3Test, indicies_input_is_dynamic) {
const auto d = std::make_shared<Parameter>(element::f64, data_3d_dynamic);
const auto i = std::make_shared<Parameter>(element::i32, PartialShape::dynamic());
const auto u = std::make_shared<Parameter>(element::f64, PartialShape{3, 5});
const auto op = make_op(d, i, u);
EXPECT_EQ(op->get_output_element_type(0), element::f64);
EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic);
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 11, 12));
}
// Dynamic-rank updates must not block data shape/label propagation.
TEST_F(TypePropScatterUpdateNDV3Test, updates_input_is_dynamic) {
const auto d = std::make_shared<Parameter>(element::f64, data_3d_dynamic);
const auto i = std::make_shared<Parameter>(element::i32, PartialShape{3, 2});
const auto u = std::make_shared<Parameter>(element::f64, PartialShape::dynamic());
const auto op = make_op(d, i, u);
EXPECT_EQ(op->get_output_element_type(0), element::f64);
EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic);
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 11, 12));
}
// Interval dimensions in indices are accepted; output still mirrors the data input.
TEST_F(TypePropScatterUpdateNDV3Test, indicies_input_has_interval_dimensions) {
const auto d = std::make_shared<Parameter>(element::i64, data_3d_dynamic);
const auto i = std::make_shared<Parameter>(element::i32, PartialShape{{0, 3}, 1});
const auto u = std::make_shared<Parameter>(element::i64, PartialShape{3, 2, {8, 10}});
const auto op = make_op(d, i, u);
EXPECT_EQ(op->get_output_element_type(0), element::i64);
EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic);
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 11, 12));
}
// A scalar updates tensor is valid when the expected updates rank is 0.
TEST_F(TypePropScatterUpdateNDV3Test, updates_input_is_scalar) {
const auto d = std::make_shared<Parameter>(element::i8, data_3d_dynamic);
const auto i = std::make_shared<Parameter>(element::i32, PartialShape{3});
const auto u = std::make_shared<Parameter>(element::i8, PartialShape{});
const auto op = make_op(d, i, u);
EXPECT_EQ(op->get_output_element_type(0), element::i8);
EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic);
}
// Default-constructed op with arguments set afterwards infers the same
// shape/type; unlabeled inputs yield no_label on the output dimensions.
TEST_F(TypePropScatterUpdateNDV3Test, default_ctor) {
const auto d = std::make_shared<Parameter>(element::i64, PartialShape{2, 3, 5, 1});
const auto i = std::make_shared<Parameter>(element::i32, PartialShape{1, 3});
const auto u = std::make_shared<Parameter>(element::i64, PartialShape{1, 1});
const auto op = make_op();
op->set_arguments(OutputVector{d, i, u});
op->validate_and_infer_types();
EXPECT_EQ(op->get_output_element_type(0), element::i64);
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({2, 3, 5, 1}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), Each(ov::no_label));
}
// Bound/label evaluation: scattering ShapeOf values into constant data and
// broadcasting by the result preserves the labeled interval dims at the
// scattered positions (indices {2, 0}).
TEST_F(TypePropScatterUpdateNDV3Test, preserve_partial_values_and_labels_via_evaluates_bounds) {
const auto d = Constant::create(element::i64, Shape{4}, {2, 3, 15, 4});
const auto i = Constant::create(element::i64, Shape{2, 1}, {2, 0});
auto u_shape = PartialShape{{10, 20}, {3, 4}};
set_shape_labels(u_shape, 20);
const auto shape_of_u = std::make_shared<op::ShapeOf>(std::make_shared<Parameter>(element::i64, u_shape));
const auto op = make_op(d, i, shape_of_u);
auto param = std::make_shared<op::Parameter>(element::f32, PartialShape{1});
auto bc = std::make_shared<op::v3::Broadcast>(param, op, op::BroadcastType::BIDIRECTIONAL);
EXPECT_EQ(bc->get_output_partial_shape(0), PartialShape({{3, 4}, 3, {10, 20}, 4}));
EXPECT_THAT(get_shape_labels(bc->get_output_partial_shape(0)), ElementsAre(21, ov::no_label, 20, ov::no_label));
}

View File

@@ -579,7 +579,6 @@ const IShapeInferCommonFactory::TRegistry IShapeInferCommonFactory::registry{
_OV_OP_SHAPE_INFER_REG(ScatterElementsUpdate, entryIOC),
_OV_OP_SHAPE_INFER_REG(ScatterNDUpdate, entryIO),
_OV_OP_SHAPE_INFER_REG(Select, entryIO),
_OV_OP_SHAPE_INFER_REG(Select, entryIO),
_OV_OP_SHAPE_INFER_REG(ShapeOf, entryIO),
_OV_OP_SHAPE_INFER_REG(ShuffleChannels, entryIO),
_OV_OP_SHAPE_INFER_REG(Slice, entryIOC),

View File

@@ -1,29 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <scatter_nd_base_shape_inference.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
using namespace ov::intel_cpu;
// Static shape inference for ScatterNDUpdate with fully dynamic parameters:
// the output static shape must equal the data input's static shape.
TEST(StaticShapeInferenceTest, ScatterNDUpdateTest) {
auto data_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
auto indices_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1});
auto updates_shape = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1});
auto scatter_nd_update = std::make_shared<op::v3::ScatterNDUpdate>(data_shape, indices_shape, updates_shape);
std::vector<StaticShape> input_shapes = {StaticShape{1000, 256, 10, 15},
StaticShape{25, 125, 3},
StaticShape{25, 125, 15}},
output_shapes = {StaticShape{}};
shape_inference(scatter_nd_update.get(), input_shapes, output_shapes);
ASSERT_EQ(output_shapes[0], StaticShape({1000, 256, 10, 15}));
}

View File

@@ -0,0 +1,56 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "openvino/opsets/opset10.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
// Static-shape-inference fixture for v3::ScatterNDUpdate; pre-sizes the
// output vector so shape_inference can write the single result in place.
class ScatterNDUpdateV3StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v3::ScatterNDUpdate> {
protected:
void SetUp() override {
output_shapes.resize(1);
}
};
// Inference works on a default-constructed op (no inputs attached).
TEST_F(ScatterNDUpdateV3StaticShapeInferenceTest, default_ctor) {
const auto op = make_op();
input_shapes = ShapeVector{{1000, 256, 10, 13}, {25, 125, 3}, {25, 125, 13}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes[0], StaticShape({1000, 256, 10, 13}));
}
// Compatible static inputs: output static shape equals the data shape.
TEST_F(ScatterNDUpdateV3StaticShapeInferenceTest, correct_inputs) {
const auto d = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
const auto i = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1});
const auto u = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1});
const auto op = make_op(d, i, u);
input_shapes = ShapeVector{{1000, 256, 10, 15}, {25, 125, 3}, {25, 125, 15}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes[0], StaticShape({1000, 256, 10, 15}));
}
// Fully dynamic-rank parameters: static inference still resolves from the
// concrete input shapes supplied at call time.
TEST_F(ScatterNDUpdateV3StaticShapeInferenceTest, params_are_dynamic_rank) {
const auto d = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
const auto i = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
const auto u = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
const auto op = make_op(d, i, u);
input_shapes = ShapeVector{{5000, 256, 10, 15}, {30, 25, 3}, {30, 25, 15}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes[0], StaticShape({5000, 256, 10, 15}));
}