[PT FE]: fix for aten::index_put_ if values.r > indices.r (#21255)

* [PT FE]: fix for aten::index_put_ if values.r > indicies.r

* add more complex test case
This commit is contained in:
Ekaterina Aidova 2023-11-27 11:01:16 +04:00 committed by GitHub
parent 9421f4cf2d
commit a5d53aeaef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 8 deletions

View File

@ -14,6 +14,7 @@
#include "openvino/op/gather.hpp"
#include "openvino/op/mod.hpp"
#include "openvino/op/non_zero.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
@ -128,16 +129,22 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
auto index_dtype = index.get_element_type();
// Do we need to also check u8?
if (index_dtype == element::boolean) {
auto nonzero = rg.make<v3::NonZero>(index, element::i32);
// then apply masked scatter
auto input_shape = rg.make<v3::ShapeOf>(input, element::i32);
auto expanded_mask = rg.make<v3::Broadcast>(index, input_shape, BroadcastType::BIDIRECTIONAL);
auto nonzero = rg.make<v3::NonZero>(expanded_mask, element::i32);
auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
index = rg.make<v1::Transpose>(nonzero, input_order);
broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
auto start_0 = v0::Constant::create(element::i32, Shape{1}, {0});
auto end_neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
auto values_shape = rg.make<v8::Slice>(broadcast_index_shape, start_0, end_neg_1, const_1);
values = rg.make<v3::Broadcast>(values, values_shape);
values = rg.make<v1::ConvertLike>(values, input);
auto result = rg.make<v3::ScatterNDUpdate>(input, index, values);
// source can be arbitary shape, select only relevant data
auto const_minus_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
auto flatten_values = rg.make<v1::Reshape>(values, const_minus_1, false);
auto const_0 = v0::Constant::create(element::i32, Shape{1}, {0});
auto index_shape = rg.make<v3::ShapeOf>(index, element::i32);
auto index_dim_zero = rg.make<v8::Gather>(index_shape, const_0, const_0);
auto slice_steps = v0::Constant::create(element::i32, Shape{1}, {1});
auto sliced_source = rg.make<v8::Slice>(flatten_values, const_0, index_dim_zero, slice_steps, const_0);
auto result = rg.make<v3::ScatterNDUpdate>(input, index, sliced_source);
copy_runtime_info_and_name(index_op, rg.get(), rt_copy_from);
replace_node(index_op, result);
return true;

View File

@ -142,6 +142,18 @@ class TestNonZero_IndexPut(PytorchLayerTest):
"input_shape": [3, 3],
"values": np.array([10, 11, 12]).astype(np.float32),
},
{
"input_shape": [3, 3, 3],
"values": np.array([[10, 11, 12]]).astype(np.float32),
},
{
"input_shape": [3, 3, 3],
"values": np.array(10).astype(np.float32),
},
{
"input_shape": [3, 3, 3],
"values": np.zeros((1, 1, 3)).astype(np.float32),
},
),
)
@pytest.mark.parametrize(
@ -166,6 +178,18 @@ class TestNonZero_IndexPut(PytorchLayerTest):
self.indices_0 = indices[0]
self.indices_1 = indices[1]
self._test(*self.create_model(accumulate), ie_device, precision, ir_version, trace_model=True, use_convert_model=True)
@pytest.mark.nightly
@pytest.mark.precommit
def test_nonzero_index_put_different_ranks(self, ie_device, precision, ir_version):
self.input_tensor = np.random.randn(1, 10, 2).astype(np.float32)
self.values = np.zeros((10, 2), dtype=np.float32)
self.indices_0 = np.array([[0, 0, 1, 1, 1, 1, 1, 1, 0, 0]]).astype(np.float32)
self.indices_1 = np.zeros((1, 10), dtype=np.float32)
self._test(*self.create_model(False), ie_device, precision, ir_version, trace_model=True, use_convert_model=True)
class TestMask_IndexPut(PytorchLayerTest):
def _prepare_input(self):