[PT FE]: fix aten::index inconsistent reshape (#16741)

* [PT FE]: fix aten::index inconsistent reshape

* add index name, return false

* Update src/frontends/pytorch/src/transforms/aten_index_replacer.cpp
Ekaterina Aidova 2023-04-05 12:44:25 +04:00 committed by GitHub
parent 73ab0dd065
commit 837f5a7d53
2 changed files with 69 additions and 33 deletions

src/frontends/pytorch/src/transforms/aten_index_replacer.cpp

@@ -47,12 +47,12 @@ std::shared_ptr<Node> flatten(const Output<Node>& value, size_t axis) {
     } else {
         const auto value_shape = std::make_shared<v3::ShapeOf>(value, element::i32);
         const auto value_rank = std::make_shared<v3::ShapeOf>(value_shape, element::i32);
-        const auto axis_node = v0::Constant::create(element::i32, Shape{}, {axis});
-        auto start = v0::Constant::create(element::i32, Shape{}, {0});
-        auto step = v0::Constant::create(element::i32, Shape{}, {1});
+        const auto axis_node = v0::Constant::create(element::i32, Shape{1}, {axis});
+        auto start = v0::Constant::create(element::i32, Shape{1}, {0});
+        auto step = v0::Constant::create(element::i32, Shape{1}, {1});
         const auto first_part_dims = std::make_shared<v8::Slice>(value_shape, start, axis_node, step);
         auto zero = v0::Constant::create(element::i32, {}, {0});
-        auto first_part_dims_length = std::make_shared<ov::op::v1::ReduceProd>(first_part_dims, zero, true);
+        auto first_part_dims_length = std::make_shared<v1::ReduceProd>(first_part_dims, zero, true);
         auto remaining_part_length = v0::Constant::create(element::i32, {1}, {-1});
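
Here the start/stop/step constants become 1-D (Shape{1}) tensors, matching what Slice expects, and ReduceProd picks up the v1 namespace alias used elsewhere in the file. The helper itself just collapses the first `axis` dimensions into one and folds the remainder into -1. A minimal NumPy sketch of that shape arithmetic (the function name is illustrative, not from the patch):

    import numpy as np

    def flatten_leading(value, axis):
        # ReduceProd over shape[:axis] gives the single leading dim; the rest folds into -1
        first_part = int(np.prod(value.shape[:axis]))
        return value.reshape(first_part, -1)

    x = np.arange(24).reshape(2, 3, 4)
    assert flatten_leading(x, 2).shape == (6, 4)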
@@ -70,7 +70,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
         if (!index_op) {
             return false;
         }
-        auto input_node = index_op->input_value(0).get_node_shared_ptr();
+        auto input_node = index_op->input_value(0);
         auto indicies = index_op->input_value(1).get_node_shared_ptr();
         auto list_indicies = cast_fw_node(indicies, "prim::ListConstruct");
         if (list_indicies) {
@@ -108,10 +108,10 @@ AtenIndexToSelect::AtenIndexToSelect() {
                         continue;
                     }
                 }
-                auto id_dtype = ids[i].get_node_shared_ptr()->get_element_type();
+                auto id_dtype = ids[i].get_element_type();
                 if (id_dtype == element::boolean || id_dtype == element::u8) {
-                    auto idx = std::make_shared<ov::op::v0::Convert>(ids[i], element::u8);
-                    auto nonzero = std::make_shared<ov::op::v3::NonZero>(idx);
+                    auto idx = std::make_shared<v0::Convert>(ids[i], element::u8);
+                    auto nonzero = std::make_shared<v3::NonZero>(idx, element::i32);
                     auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
                     auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
                     masked_indicies.push_back(masked_id);
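
In this mask branch, NonZero emits coordinates as a [rank, n] tensor (now pinned to i32) and the Transpose flips them to [n, rank], the row-per-point layout GatherND consumes. The same shuffle in NumPy, as a sketch:

    import numpy as np

    mask = np.array([[True, False], [False, True]])
    coords = np.stack(np.nonzero(mask))  # [rank, n], like NonZero's output
    masked_id = coords.T                 # [n, rank]: one coordinate row per selected point
    assert masked_id.tolist() == [[0, 0], [1, 1]]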
@@ -125,30 +125,32 @@ AtenIndexToSelect::AtenIndexToSelect() {
             // all indicies prim::Constant(None), return input as is
             if (advanced_ids.size() == 0) {
-                copy_runtime_info({index_op, input_node}, input_node);
-                replace_node(index_op, input_node);
+                replace_node(index_op, input_node.get_node_shared_ptr());
                 return true;
             }
             // perform gather for single element case
             if (advanced_ids.size() == 1) {
                 auto index = masked_indicies[advanced_ids[0]];
+                index = std::make_shared<v0::Convert>(index, element::i32);
                 if (is_masked_bool[advanced_ids[0]]) {
                     auto gather = std::make_shared<v8::GatherND>(input_node, index);
-                    copy_runtime_info({index_op, input_node, indicies}, gather);
+                    copy_runtime_info({index_op, indicies}, gather);
+                    gather->set_friendly_name(index_op->get_friendly_name());
                     replace_node(index_op, gather);
                     return true;
                 }
-                index = std::make_shared<v0::Convert>(index, element::i32);
                 auto dim = v0::Constant::create(element::i32, Shape{}, {advanced_ids[0]});
                 auto gather = std::make_shared<v8::Gather>(input_node, index, dim);
-                copy_runtime_info({index_op, input_node, indicies}, gather);
+                copy_runtime_info({index_op, indicies}, gather);
+                gather->set_friendly_name(index_op->get_friendly_name());
                 replace_node(index_op, gather);
                 return true;
             }
             auto adv_idx_count = advanced_ids.size();
-            auto rank = input_node->get_input_partial_shape(0).rank();
+            auto rank = input_node.get_partial_shape().rank();
+            // index transformation supports only tensors with static rank
             if (rank.is_dynamic()) {
-                FRONT_END_CHECK_IMPLEMENTED(false, "indexing for tensor with dynamic rank is not implemented ");
+                return false;
             }
             auto input_shape = std::make_shared<v3::ShapeOf>(input_node, element::i32);
             auto zero = v0::Constant::create(element::i32, Shape{}, {0});
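
For these single-advanced-index paths, aten::index reduces to one Gather along the indexed axis, or a GatherND over the nonzero coordinates of a boolean mask. The PyTorch behaviour being matched, roughly (a sketch, not part of the patch):

    import torch

    x = torch.arange(24.).reshape(2, 3, 4)
    idx = torch.tensor([2, 0])
    # an integer index in position 1 maps to Gather along dim 1
    assert torch.equal(x[:, idx], torch.index_select(x, 1, idx))

    mask = x > 11.0
    # a boolean mask maps to nonzero coordinates plus an N-D gather
    assert torch.equal(x[mask], x[mask.nonzero(as_tuple=True)])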
@@ -166,9 +168,11 @@ AtenIndexToSelect::AtenIndexToSelect() {
                 auto transposed_input = std::make_shared<v1::Transpose>(input_node, transpose_dims);
                 auto flatten_input = flatten(transposed_input, adv_idx_count);
                 auto cum_adv_index = masked_indicies[advanced_ids[adv_idx_count - 1]];
+                cum_adv_index = std::make_shared<v0::Convert>(cum_adv_index, element::i32);
                 auto multiplier = input_dims->output(advanced_ids[adv_idx_count - 1]);
-                for (int i = static_cast<int>(adv_idx_count) - 2; i > 0; i--) {
-                    auto adv_index = std::make_shared<v1::Multiply>(masked_indicies[i], multiplier);
+                for (int i = static_cast<int>(adv_idx_count) - 2; i > -1; i--) {
+                    auto m_idx = std::make_shared<v0::Convert>(masked_indicies[i], element::i32);
+                    auto adv_index = std::make_shared<v1::Multiply>(m_idx, multiplier);
                     cum_adv_index = std::make_shared<v1::Add>(cum_adv_index, adv_index);
                     auto input_id = advanced_ids[i];
                     multiplier = std::make_shared<v1::Multiply>(multiplier, input_dims->output(input_id));
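
The loop bound is the substantive fix on this path: with `i > 0` the iteration stopped before i == 0, so the first advanced index never entered the cumulative sum; `i > -1` includes it, and each index is now converted to i32 before the multiply. The pattern linearizes the advanced indices with row-major strides and gathers once from the flattened axes. A NumPy sketch with two index tensors on the two leading axes (illustrative only):

    import numpy as np

    x = np.arange(24.).reshape(2, 3, 4)
    i0 = np.array([0, 1])
    i1 = np.array([2, 0])

    # cum_adv_index = i1 + i0 * dim1; the multiplier grows from right to left
    cum = i1 + i0 * x.shape[1]
    flat = x.reshape(-1, x.shape[2])  # the two advanced axes collapsed, as flatten() does
    assert np.array_equal(flat[cum], x[i0, i1])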
@@ -204,7 +208,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
                     v0::Constant::create(element::i32, Shape{adv_idx_permute.size()}, adv_idx_permute);
                 gather = std::make_shared<v1::Transpose>(gather, permute_indicies);
                 // unfold advanced index axes
-                for (size_t i = 0; i <= advanced_ids[0]; i++) {
+                for (size_t i = 0; i < advanced_ids[0]; i++) {
                     concat_dims.push_back(input_dims->output(i));
                 }
                 concat_dims.push_back(cum_adv_index_shape_tensor);
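
This off-by-one is the inconsistent reshape from the title: with `i <= advanced_ids[0]`, the first advanced axis was pushed into `concat_dims` and the combined index shape was appended right after it, so the Reshape target carried one dimension too many. For adjacent advanced indices, the result keeps the axes before the first indexed one, then the broadcast index shape, then the trailing axes. A quick NumPy check of the expected shape:

    import numpy as np

    x = np.zeros((2, 3, 4, 5))
    i0 = np.array([1, 0])
    i1 = np.array([0, 2])
    # advanced indices on axes 1 and 2: keep axis 0, insert the index shape, keep axis 3
    assert x[:, i0, i1, :].shape == (2, 2, 5)  # not (2, 3, 2, 5)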
@@ -223,8 +227,9 @@ AtenIndexToSelect::AtenIndexToSelect() {
                 }
                 auto final_shape = std::make_shared<v0::Concat>(concat_dims, 0);
                 gather = std::make_shared<v1::Reshape>(gather, final_shape, false);
-                copy_runtime_info({index_op, input_node, indicies}, gather);
+                copy_runtime_info({index_op, indicies}, gather);
                 replace_node(index_op, gather);
+                gather->set_friendly_name(index_op->get_friendly_name());
                 return true;
             } else {
@@ -234,28 +239,28 @@ AtenIndexToSelect::AtenIndexToSelect() {
                 // index is None, stay input as is
                 const auto& attrs = const_input->get_attrs();
                 if (attrs.find("none_value") != attrs.end()) {
-                    copy_runtime_info({index_op, input_node, indicies}, input_node);
-                    replace_node(index_op, input_node);
+                    replace_node(index_op, input_node.get_node_shared_ptr());
                     return true;
                 }
             }
             auto index_dtype = indicies->get_output_element_type(0);
             if (index_dtype == element::boolean || index_dtype == element::u8) {
-                auto nonzero = std::make_shared<v3::NonZero>(indicies);
+                auto nonzero = std::make_shared<v3::NonZero>(indicies, element::i32);
                 auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
                 auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
                 auto gather = std::make_shared<v8::GatherND>(input_node, masked_id);
-                copy_runtime_info({index_op, input_node, indicies}, gather);
+                copy_runtime_info({index_op, indicies}, gather);
                 replace_node(index_op, gather);
                 return true;
             }
-            if (index_dtype != element::i32 && index_dtype != element::i32) {
+            if (index_dtype != element::i32) {
                 indicies = std::make_shared<ov::op::v0::Convert>(indicies, element::i32);
             }
             auto dim = v0::Constant::create(element::i32, Shape{}, {0});
             auto gather = std::make_shared<v8::Gather>(input_node, indicies, dim);
-            copy_runtime_info({index_op, input_node, indicies}, gather);
+            copy_runtime_info({index_op, indicies}, gather);
             replace_node(index_op, gather);
+            gather->set_friendly_name(index_op->get_friendly_name());
             return true;
         }
         return false;

tests/layer_tests/pytorch_tests/test_index.py

@@ -25,7 +25,6 @@ class TestIndex(PytorchLayerTest):
             def forward(self, x, idx):
                 return x.__getitem__(idx)
-
         class aten_index_list_bool(torch.nn.Module):
             def forward(self, x, idx):
@@ -58,7 +57,8 @@ class TestIndex(PytorchLayerTest):
                               ([7, 8, 9], np.array((-1, 2, -3)).astype(int)),
                               ([2, 2, 3, 4], np.array((1,)).astype(int))])
     def test_index(self, input_shape, idx, case, ie_device, precision, ir_version):
-        self._test(*self.create_model(case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
+        self._test(*self.create_model(case), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})

     @pytest.mark.nightly
     @pytest.mark.precommit
@@ -70,4 +70,35 @@ class TestIndex(PytorchLayerTest):
                               ((2, 2, 5), np.random.rand(2, 2, 5) > 0)
                               ])
     def test_index_bool(self, input_shape, idx, case, ie_device, precision, ir_version):
-        self._test(*self.create_model(case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
+        self._test(*self.create_model(case), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
+
+
+class TestIndexRange(PytorchLayerTest):
+    def _prepare_input(self, input_shape, idx):
+        import numpy as np
+        return (np.random.randn(*input_shape).astype(np.float32), np.array(idx).astype(np.int32))
+
+    def create_model(self):
+        import torch
+
+        class aten_index_unsqueeze(torch.nn.Module):
+            def forward(self, x, y):
+                x = x.reshape(x.shape[0], -1)
+                return x[torch.arange(x.shape[0]), y]
+
+        ref_net = None
+
+        return aten_index_unsqueeze(), ref_net, "aten::index"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize(("input_shape", "idx"), (
+        ((1, 1), [0]),
+        ([2, 3], [1, 2]),
+        ([7, 8, 9], [1]),
+        ([2, 2, 3, 4], [0])))
+    def test_index_range(self, input_shape, idx, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
+            "input_shape": input_shape, "idx": idx}, trace_model=True, dynamic_shapes=False)