diff --git a/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp index 3959e2383a8..39c2baef8ef 100644 --- a/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp @@ -14,6 +14,7 @@ #include "openvino/op/gather.hpp" #include "openvino/op/mod.hpp" #include "openvino/op/non_zero.hpp" +#include "openvino/op/reshape.hpp" #include "openvino/op/scatter_nd_update.hpp" #include "openvino/op/shape_of.hpp" #include "openvino/op/slice.hpp" @@ -128,16 +129,22 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() { auto index_dtype = index.get_element_type(); // Do we need to also check u8? if (index_dtype == element::boolean) { - auto nonzero = rg.make(index, element::i32); + // then apply masked scatter + auto input_shape = rg.make(input, element::i32); + auto expanded_mask = rg.make(index, input_shape, BroadcastType::BIDIRECTIONAL); + auto nonzero = rg.make(expanded_mask, element::i32); auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0}); index = rg.make(nonzero, input_order); - broadcast_index_shape = rg.make(index, element::i32); - auto start_0 = v0::Constant::create(element::i32, Shape{1}, {0}); - auto end_neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1}); - auto values_shape = rg.make(broadcast_index_shape, start_0, end_neg_1, const_1); - values = rg.make(values, values_shape); - values = rg.make(values, input); - auto result = rg.make(input, index, values); + // source can be arbitrary shape, select only relevant data + auto const_minus_1 = v0::Constant::create(element::i32, Shape{1}, {-1}); + auto flatten_values = rg.make(values, const_minus_1, false); + auto const_0 = v0::Constant::create(element::i32, Shape{1}, {0}); + + auto index_shape = rg.make(index, element::i32); + auto index_dim_zero = rg.make(index_shape, const_0, const_0); + auto slice_steps = v0::Constant::create(element::i32, Shape{1}, {1}); + auto sliced_source = rg.make(flatten_values, const_0, index_dim_zero, slice_steps, const_0); + auto result = rg.make(input, index, sliced_source); copy_runtime_info_and_name(index_op, rg.get(), rt_copy_from); replace_node(index_op, result); return true; diff --git a/tests/layer_tests/pytorch_tests/test_index_put_.py b/tests/layer_tests/pytorch_tests/test_index_put_.py index 68eaed21626..e367d2a6d68 100644 --- a/tests/layer_tests/pytorch_tests/test_index_put_.py +++ b/tests/layer_tests/pytorch_tests/test_index_put_.py @@ -142,6 +142,18 @@ class TestNonZero_IndexPut(PytorchLayerTest): "input_shape": [3, 3], "values": np.array([10, 11, 12]).astype(np.float32), }, + { + "input_shape": [3, 3, 3], + "values": np.array([[10, 11, 12]]).astype(np.float32), + }, + { + "input_shape": [3, 3, 3], + "values": np.array(10).astype(np.float32), + }, + { + "input_shape": [3, 3, 3], + "values": np.zeros((1, 1, 3)).astype(np.float32), + }, ), ) @pytest.mark.parametrize( @@ -166,6 +178,18 @@ class TestNonZero_IndexPut(PytorchLayerTest): self.indices_0 = indices[0] self.indices_1 = indices[1] self._test(*self.create_model(accumulate), ie_device, precision, ir_version, trace_model=True, use_convert_model=True) + + + @pytest.mark.nightly + @pytest.mark.precommit + def test_nonzero_index_put_different_ranks(self, ie_device, precision, ir_version): + self.input_tensor = np.random.randn(1, 10, 2).astype(np.float32) + self.values = np.zeros((10, 2), dtype=np.float32) + self.indices_0 = np.array([[0, 0, 1, 1, 1, 1, 1, 1, 0, 0]]).astype(np.float32) + self.indices_1 = np.zeros((1, 10), dtype=np.float32) + self._test(*self.create_model(False), ie_device, precision, ir_version, trace_model=True, use_convert_model=True) + + class TestMask_IndexPut(PytorchLayerTest): def _prepare_input(self):