[Core] Allow ScatterND inputs type to be dynamic (#16236)
* Allow ScatterND inputs type to be dynamic Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Update src/core/src/op/util/scatter_nd_base.cpp Co-authored-by: Pawel Raasz <pawel.raasz@intel.com> * Update src/core/src/op/util/scatter_nd_base.cpp Co-authored-by: Pawel Raasz <pawel.raasz@intel.com> * Update src/core/src/op/util/scatter_nd_base.cpp * Apply code-style --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> Co-authored-by: Pawel Raasz <pawel.raasz@intel.com>
This commit is contained in:
parent
df6cd3303a
commit
0ffa4eb507
@ -35,10 +35,13 @@ void ov::op::util::ScatterNDBase::validate_and_infer_types() {
|
||||
element::Type updates_et = get_input_element_type(UPDATES);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
indices_et == element::i32 || indices_et == element::i64,
|
||||
indices_et.compatible(element::i32) || indices_et.compatible(element::i64),
|
||||
"Indices element type must be i64 or i32");
|
||||
|
||||
NODE_VALIDATION_CHECK(this, updates_et == inputs_et, "Updates element type must be the same as inputs");
|
||||
element::Type outputs_et = element::dynamic;
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
element::Type::merge(outputs_et, inputs_et, updates_et),
|
||||
"Updates element type must be the same as inputs");
|
||||
|
||||
const auto& inputs = get_input_partial_shape(0);
|
||||
const auto& indices = get_input_partial_shape(1);
|
||||
@ -48,5 +51,5 @@ void ov::op::util::ScatterNDBase::validate_and_infer_types() {
|
||||
std::vector<ov::PartialShape> input_shapes = {inputs, indices, updates};
|
||||
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
set_output_type(0, inputs_et, output_shapes[0]);
|
||||
set_output_type(0, outputs_et, output_shapes[0]);
|
||||
}
|
||||
|
@ -209,3 +209,36 @@ TEST_F(TypePropScatterUpdateNDV3Test, preserve_partial_values_and_labels_via_eva
|
||||
EXPECT_EQ(bc->get_output_partial_shape(0), PartialShape({{3, 4}, 3, {10, 20}, 4}));
|
||||
EXPECT_THAT(get_shape_labels(bc->get_output_partial_shape(0)), ElementsAre(21, ov::no_label, 20, ov::no_label));
|
||||
}
|
||||
|
||||
TEST_F(TypePropScatterUpdateNDV3Test, indices_dynamic_type) {
|
||||
const auto d = std::make_shared<Parameter>(element::f32, data_3d_dynamic);
|
||||
const auto i = std::make_shared<Parameter>(element::dynamic, PartialShape{3, 2});
|
||||
const auto u = std::make_shared<Parameter>(element::f32, PartialShape{3, 5});
|
||||
|
||||
const auto op = make_op(d, i, u);
|
||||
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic);
|
||||
}
|
||||
|
||||
TEST_F(TypePropScatterUpdateNDV3Test, updates_dynamic_type) {
|
||||
const auto d = std::make_shared<Parameter>(element::i64, data_3d_dynamic);
|
||||
const auto i = std::make_shared<Parameter>(element::i32, PartialShape{3, 2});
|
||||
const auto u = std::make_shared<Parameter>(element::dynamic, PartialShape{3, 5});
|
||||
|
||||
const auto op = make_op(d, i, u);
|
||||
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::i64);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic);
|
||||
}
|
||||
|
||||
TEST_F(TypePropScatterUpdateNDV3Test, all_dynamic_type) {
|
||||
const auto d = std::make_shared<Parameter>(element::dynamic, data_3d_dynamic);
|
||||
const auto i = std::make_shared<Parameter>(element::i64, PartialShape{3, 2});
|
||||
const auto u = std::make_shared<Parameter>(element::dynamic, PartialShape{3, 5});
|
||||
|
||||
const auto op = make_op(d, i, u);
|
||||
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::dynamic);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user