[PT FE]: fix for aten::index_put_ if values.r > indices.r (#21255)
* [PT FE]: fix for aten::index_put_ if values.r > indicies.r * add more complex test case
This commit is contained in:
parent
9421f4cf2d
commit
a5d53aeaef
@ -14,6 +14,7 @@
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/mod.hpp"
|
||||
#include "openvino/op/non_zero.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/scatter_nd_update.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
@ -128,16 +129,22 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
|
||||
auto index_dtype = index.get_element_type();
|
||||
// Do we need to also check u8?
|
||||
if (index_dtype == element::boolean) {
|
||||
auto nonzero = rg.make<v3::NonZero>(index, element::i32);
|
||||
// then apply masked scatter
|
||||
auto input_shape = rg.make<v3::ShapeOf>(input, element::i32);
|
||||
auto expanded_mask = rg.make<v3::Broadcast>(index, input_shape, BroadcastType::BIDIRECTIONAL);
|
||||
auto nonzero = rg.make<v3::NonZero>(expanded_mask, element::i32);
|
||||
auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
|
||||
index = rg.make<v1::Transpose>(nonzero, input_order);
|
||||
broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
|
||||
auto start_0 = v0::Constant::create(element::i32, Shape{1}, {0});
|
||||
auto end_neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
|
||||
auto values_shape = rg.make<v8::Slice>(broadcast_index_shape, start_0, end_neg_1, const_1);
|
||||
values = rg.make<v3::Broadcast>(values, values_shape);
|
||||
values = rg.make<v1::ConvertLike>(values, input);
|
||||
auto result = rg.make<v3::ScatterNDUpdate>(input, index, values);
|
||||
// source can be arbitary shape, select only relevant data
|
||||
auto const_minus_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
|
||||
auto flatten_values = rg.make<v1::Reshape>(values, const_minus_1, false);
|
||||
auto const_0 = v0::Constant::create(element::i32, Shape{1}, {0});
|
||||
|
||||
auto index_shape = rg.make<v3::ShapeOf>(index, element::i32);
|
||||
auto index_dim_zero = rg.make<v8::Gather>(index_shape, const_0, const_0);
|
||||
auto slice_steps = v0::Constant::create(element::i32, Shape{1}, {1});
|
||||
auto sliced_source = rg.make<v8::Slice>(flatten_values, const_0, index_dim_zero, slice_steps, const_0);
|
||||
auto result = rg.make<v3::ScatterNDUpdate>(input, index, sliced_source);
|
||||
copy_runtime_info_and_name(index_op, rg.get(), rt_copy_from);
|
||||
replace_node(index_op, result);
|
||||
return true;
|
||||
|
@ -142,6 +142,18 @@ class TestNonZero_IndexPut(PytorchLayerTest):
|
||||
"input_shape": [3, 3],
|
||||
"values": np.array([10, 11, 12]).astype(np.float32),
|
||||
},
|
||||
{
|
||||
"input_shape": [3, 3, 3],
|
||||
"values": np.array([[10, 11, 12]]).astype(np.float32),
|
||||
},
|
||||
{
|
||||
"input_shape": [3, 3, 3],
|
||||
"values": np.array(10).astype(np.float32),
|
||||
},
|
||||
{
|
||||
"input_shape": [3, 3, 3],
|
||||
"values": np.zeros((1, 1, 3)).astype(np.float32),
|
||||
},
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
@ -167,6 +179,18 @@ class TestNonZero_IndexPut(PytorchLayerTest):
|
||||
self.indices_1 = indices[1]
|
||||
self._test(*self.create_model(accumulate), ie_device, precision, ir_version, trace_model=True, use_convert_model=True)
|
||||
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_nonzero_index_put_different_ranks(self, ie_device, precision, ir_version):
|
||||
self.input_tensor = np.random.randn(1, 10, 2).astype(np.float32)
|
||||
self.values = np.zeros((10, 2), dtype=np.float32)
|
||||
self.indices_0 = np.array([[0, 0, 1, 1, 1, 1, 1, 1, 0, 0]]).astype(np.float32)
|
||||
self.indices_1 = np.zeros((1, 10), dtype=np.float32)
|
||||
self._test(*self.create_model(False), ie_device, precision, ir_version, trace_model=True, use_convert_model=True)
|
||||
|
||||
|
||||
|
||||
class TestMask_IndexPut(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(100, 5).astype(np.float32),np.random.randn(100, 5).astype(np.float32))
|
||||
|
Loading…
Reference in New Issue
Block a user