[Core] Allow ScatterND inputs type to be dynamic (#16236)

* Allow ScatterND inputs type to be dynamic

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Update src/core/src/op/util/scatter_nd_base.cpp

Co-authored-by: Pawel Raasz <pawel.raasz@intel.com>

* Update src/core/src/op/util/scatter_nd_base.cpp

Co-authored-by: Pawel Raasz <pawel.raasz@intel.com>

* Update src/core/src/op/util/scatter_nd_base.cpp

* Apply code-style

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Co-authored-by: Pawel Raasz <pawel.raasz@intel.com>
This commit is contained in:
Roman Kazantsev 2023-03-13 14:36:00 +04:00 committed by GitHub
parent df6cd3303a
commit 0ffa4eb507
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 3 deletions

View File

@ -35,10 +35,13 @@ void ov::op::util::ScatterNDBase::validate_and_infer_types() {
element::Type updates_et = get_input_element_type(UPDATES);
NODE_VALIDATION_CHECK(this,
indices_et == element::i32 || indices_et == element::i64,
indices_et.compatible(element::i32) || indices_et.compatible(element::i64),
"Indices element type must be i64 or i32");
NODE_VALIDATION_CHECK(this, updates_et == inputs_et, "Updates element type must be the same as inputs");
element::Type outputs_et = element::dynamic;
NODE_VALIDATION_CHECK(this,
element::Type::merge(outputs_et, inputs_et, updates_et),
"Updates element type must be the same as inputs");
const auto& inputs = get_input_partial_shape(0);
const auto& indices = get_input_partial_shape(1);
@ -48,5 +51,5 @@ void ov::op::util::ScatterNDBase::validate_and_infer_types() {
std::vector<ov::PartialShape> input_shapes = {inputs, indices, updates};
shape_infer(this, input_shapes, output_shapes);
set_output_type(0, inputs_et, output_shapes[0]);
set_output_type(0, outputs_et, output_shapes[0]);
}

View File

@ -209,3 +209,36 @@ TEST_F(TypePropScatterUpdateNDV3Test, preserve_partial_values_and_labels_via_eva
EXPECT_EQ(bc->get_output_partial_shape(0), PartialShape({{3, 4}, 3, {10, 20}, 4}));
EXPECT_THAT(get_shape_labels(bc->get_output_partial_shape(0)), ElementsAre(21, ov::no_label, 20, ov::no_label));
}
TEST_F(TypePropScatterUpdateNDV3Test, indices_dynamic_type) {
const auto d = std::make_shared<Parameter>(element::f32, data_3d_dynamic);
const auto i = std::make_shared<Parameter>(element::dynamic, PartialShape{3, 2});
const auto u = std::make_shared<Parameter>(element::f32, PartialShape{3, 5});
const auto op = make_op(d, i, u);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic);
}
TEST_F(TypePropScatterUpdateNDV3Test, updates_dynamic_type) {
const auto d = std::make_shared<Parameter>(element::i64, data_3d_dynamic);
const auto i = std::make_shared<Parameter>(element::i32, PartialShape{3, 2});
const auto u = std::make_shared<Parameter>(element::dynamic, PartialShape{3, 5});
const auto op = make_op(d, i, u);
EXPECT_EQ(op->get_output_element_type(0), element::i64);
EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic);
}
TEST_F(TypePropScatterUpdateNDV3Test, all_dynamic_type) {
const auto d = std::make_shared<Parameter>(element::dynamic, data_3d_dynamic);
const auto i = std::make_shared<Parameter>(element::i64, PartialShape{3, 2});
const auto u = std::make_shared<Parameter>(element::dynamic, PartialShape{3, 5});
const auto op = make_op(d, i, u);
EXPECT_EQ(op->get_output_element_type(0), element::dynamic);
EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic);
}