From 0ffa4eb507396948d87efc01303f7e5b84f8005b Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Mon, 13 Mar 2023 14:36:00 +0400 Subject: [PATCH] [Core] Allow ScatterND inputs type to be dynamic (#16236) * Allow ScatterND inputs type to be dynamic Signed-off-by: Kazantsev, Roman * Update src/core/src/op/util/scatter_nd_base.cpp Co-authored-by: Pawel Raasz * Update src/core/src/op/util/scatter_nd_base.cpp Co-authored-by: Pawel Raasz * Update src/core/src/op/util/scatter_nd_base.cpp * Apply code-style --------- Signed-off-by: Kazantsev, Roman Co-authored-by: Pawel Raasz --- src/core/src/op/util/scatter_nd_base.cpp | 9 +++-- .../tests/type_prop/scatter_nd_update.cpp | 33 +++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/src/core/src/op/util/scatter_nd_base.cpp b/src/core/src/op/util/scatter_nd_base.cpp index 22724cd4bc8..f4d70e1d2c9 100644 --- a/src/core/src/op/util/scatter_nd_base.cpp +++ b/src/core/src/op/util/scatter_nd_base.cpp @@ -35,10 +35,13 @@ void ov::op::util::ScatterNDBase::validate_and_infer_types() { element::Type updates_et = get_input_element_type(UPDATES); NODE_VALIDATION_CHECK(this, - indices_et == element::i32 || indices_et == element::i64, + indices_et.compatible(element::i32) || indices_et.compatible(element::i64), "Indices element type must be i64 or i32"); - NODE_VALIDATION_CHECK(this, updates_et == inputs_et, "Updates element type must be the same as inputs"); + element::Type outputs_et = element::dynamic; + NODE_VALIDATION_CHECK(this, + element::Type::merge(outputs_et, inputs_et, updates_et), + "Updates element type must be the same as inputs"); const auto& inputs = get_input_partial_shape(0); const auto& indices = get_input_partial_shape(1); @@ -48,5 +51,5 @@ void ov::op::util::ScatterNDBase::validate_and_infer_types() { std::vector input_shapes = {inputs, indices, updates}; shape_infer(this, input_shapes, output_shapes); - set_output_type(0, inputs_et, output_shapes[0]); + set_output_type(0, outputs_et, output_shapes[0]); } diff --git a/src/core/tests/type_prop/scatter_nd_update.cpp b/src/core/tests/type_prop/scatter_nd_update.cpp index 07808adf3db..47374d0b671 100644 --- a/src/core/tests/type_prop/scatter_nd_update.cpp +++ b/src/core/tests/type_prop/scatter_nd_update.cpp @@ -209,3 +209,36 @@ TEST_F(TypePropScatterUpdateNDV3Test, preserve_partial_values_and_labels_via_eva EXPECT_EQ(bc->get_output_partial_shape(0), PartialShape({{3, 4}, 3, {10, 20}, 4})); EXPECT_THAT(get_shape_labels(bc->get_output_partial_shape(0)), ElementsAre(21, ov::no_label, 20, ov::no_label)); } + +TEST_F(TypePropScatterUpdateNDV3Test, indices_dynamic_type) { + const auto d = std::make_shared(element::f32, data_3d_dynamic); + const auto i = std::make_shared(element::dynamic, PartialShape{3, 2}); + const auto u = std::make_shared(element::f32, PartialShape{3, 5}); + + const auto op = make_op(d, i, u); + + EXPECT_EQ(op->get_output_element_type(0), element::f32); + EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic); +} + +TEST_F(TypePropScatterUpdateNDV3Test, updates_dynamic_type) { + const auto d = std::make_shared(element::i64, data_3d_dynamic); + const auto i = std::make_shared(element::i32, PartialShape{3, 2}); + const auto u = std::make_shared(element::dynamic, PartialShape{3, 5}); + + const auto op = make_op(d, i, u); + + EXPECT_EQ(op->get_output_element_type(0), element::i64); + EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic); +} + +TEST_F(TypePropScatterUpdateNDV3Test, all_dynamic_type) { + const auto d = std::make_shared(element::dynamic, data_3d_dynamic); + const auto i = std::make_shared(element::i64, PartialShape{3, 2}); + const auto u = std::make_shared(element::dynamic, PartialShape{3, 5}); + + const auto op = make_op(d, i, u); + + EXPECT_EQ(op->get_output_element_type(0), element::dynamic); + EXPECT_EQ(op->get_output_partial_shape(0), data_3d_dynamic); +}