From 837f5a7d53c46c0e76138f330a688bf1cf564066 Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova <ekaterina.aidova@intel.com>
Date: Wed, 5 Apr 2023 12:44:25 +0400
Subject: [PATCH] [PT FE]: fix aten::index inconsistent reshape (#16741)

* [PT FE]: fix aten::index inconsistent reshape

* add index name, return false

* Update src/frontends/pytorch/src/transforms/aten_index_replacer.cpp
---
 .../src/transforms/aten_index_replacer.cpp    | 55 ++++++++++---------
 tests/layer_tests/pytorch_tests/test_index.py | 47 +++++++++++++---
 2 files changed, 69 insertions(+), 33 deletions(-)

diff --git a/src/frontends/pytorch/src/transforms/aten_index_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_index_replacer.cpp
index 93d4214add1..cf9bd5c9fa2 100644
--- a/src/frontends/pytorch/src/transforms/aten_index_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/aten_index_replacer.cpp
@@ -47,12 +47,12 @@ std::shared_ptr<Node> flatten(const Output<Node>& value, size_t axis) {
     } else {
         const auto value_shape = std::make_shared<v3::ShapeOf>(value, element::i32);
         const auto value_rank = std::make_shared<v3::ShapeOf>(value_shape, element::i32);
-        const auto axis_node = v0::Constant::create(element::i32, Shape{}, {axis});
-        auto start = v0::Constant::create(element::i32, Shape{}, {0});
-        auto step = v0::Constant::create(element::i32, Shape{}, {1});
+        const auto axis_node = v0::Constant::create(element::i32, Shape{1}, {axis});
+        auto start = v0::Constant::create(element::i32, Shape{1}, {0});
+        auto step = v0::Constant::create(element::i32, Shape{1}, {1});
         const auto first_part_dims = std::make_shared<v8::Slice>(value_shape, start, axis_node, step);
         auto zero = v0::Constant::create(element::i32, {}, {0});
-        auto first_part_dims_length = std::make_shared(first_part_dims, zero, true);
+        auto first_part_dims_length = std::make_shared<v1::ReduceProd>(first_part_dims, zero, true);
 
         auto remaining_part_length = v0::Constant::create(element::i32, {1}, {-1});
 
@@ -70,7 +70,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
         if (!index_op) {
             return false;
         }
-        auto input_node = index_op->input_value(0).get_node_shared_ptr();
+        auto input_node = index_op->input_value(0);
         auto indicies = index_op->input_value(1).get_node_shared_ptr();
         auto list_indicies = cast_fw_node(indicies, "prim::ListConstruct");
         if (list_indicies) {
@@ -108,10 +108,10 @@ AtenIndexToSelect::AtenIndexToSelect() {
                         continue;
                     }
                 }
-                auto id_dtype = ids[i].get_node_shared_ptr()->get_element_type();
+                auto id_dtype = ids[i].get_element_type();
                 if (id_dtype == element::boolean || id_dtype == element::u8) {
-                    auto idx = std::make_shared<v0::Convert>(ids[i], element::u8);
-                    auto nonzero = std::make_shared<v3::NonZero>(idx);
+                    auto idx = std::make_shared<v0::Convert>(ids[i], element::u8);
+                    auto nonzero = std::make_shared<v3::NonZero>(idx, element::i32);
                     auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
                     auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
                     masked_indicies.push_back(masked_id);
@@ -125,30 +125,32 @@ AtenIndexToSelect::AtenIndexToSelect() {
 
             // all indicies prim::Constant(None), return input as is
             if (advanced_ids.size() == 0) {
-                copy_runtime_info({index_op, input_node}, input_node);
-                replace_node(index_op, input_node);
+                replace_node(index_op, input_node.get_node_shared_ptr());
                 return true;
             }
             // perform gather for single element case
             if (advanced_ids.size() == 1) {
                 auto index = masked_indicies[advanced_ids[0]];
-                index = std::make_shared<v0::Convert>(index, element::i32);
                 if (is_masked_bool[advanced_ids[0]]) {
                     auto gather = std::make_shared<v8::GatherND>(input_node, index);
-                    copy_runtime_info({index_op, input_node, indicies}, gather);
+                    copy_runtime_info({index_op, indicies}, gather);
+                    gather->set_friendly_name(index_op->get_friendly_name());
                     replace_node(index_op, gather);
                     return true;
                 }
+                index = std::make_shared<v0::Convert>(index, element::i32);
                 auto dim = v0::Constant::create(element::i32, Shape{}, {advanced_ids[0]});
                 auto gather = std::make_shared<v8::Gather>(input_node, index, dim);
-                copy_runtime_info({index_op, input_node, indicies}, gather);
+                copy_runtime_info({index_op, indicies}, gather);
+                gather->set_friendly_name(index_op->get_friendly_name());
                 replace_node(index_op, gather);
                 return true;
             }
             auto adv_idx_count = advanced_ids.size();
-            auto rank = input_node->get_input_partial_shape(0).rank();
+            auto rank = input_node.get_partial_shape().rank();
+            // index transformation supports only tensors with static rank
             if (rank.is_dynamic()) {
-                FRONT_END_CHECK_IMPLEMENTED(false, "indexing for tensor with dynamic rank is not implemented ");
+                return false;
             }
             auto input_shape = std::make_shared<v3::ShapeOf>(input_node, element::i32);
             auto zero = v0::Constant::create(element::i32, Shape{}, {0});
@@ -166,9 +168,11 @@ AtenIndexToSelect::AtenIndexToSelect() {
             auto transposed_input = std::make_shared<v1::Transpose>(input_node, transpose_dims);
             auto flatten_input = flatten(transposed_input, adv_idx_count);
             auto cum_adv_index = masked_indicies[advanced_ids[adv_idx_count - 1]];
+            cum_adv_index = std::make_shared<v0::Convert>(cum_adv_index, element::i32);
             auto multiplier = input_dims->output(advanced_ids[adv_idx_count - 1]);
-            for (int i = static_cast<int>(adv_idx_count) - 2; i > 0; i--) {
-                auto adv_index = std::make_shared<v1::Multiply>(masked_indicies[i], multiplier);
+            for (int i = static_cast<int>(adv_idx_count) - 2; i > -1; i--) {
+                auto m_idx = std::make_shared<v0::Convert>(masked_indicies[i], element::i32);
+                auto adv_index = std::make_shared<v1::Multiply>(m_idx, multiplier);
                 cum_adv_index = std::make_shared<v1::Add>(cum_adv_index, adv_index);
                 auto input_id = advanced_ids[i];
                 multiplier = std::make_shared<v1::Multiply>(multiplier, input_dims->output(input_id));
@@ -204,7 +208,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
                     v0::Constant::create(element::i32, Shape{adv_idx_permute.size()}, adv_idx_permute);
                 gather = std::make_shared<v1::Transpose>(gather, permute_indicies);
                 // unfold advanced index axes
-                for (size_t i = 0; i <= advanced_ids[0]; i++) {
+                for (size_t i = 0; i < advanced_ids[0]; i++) {
                     concat_dims.push_back(input_dims->output(i));
                 }
                 concat_dims.push_back(cum_adv_index_shape_tensor);
@@ -223,8 +227,9 @@ AtenIndexToSelect::AtenIndexToSelect() {
 
             }
             auto final_shape = std::make_shared<v0::Concat>(concat_dims, 0);
             gather = std::make_shared<v1::Reshape>(gather, final_shape, false);
-            copy_runtime_info({index_op, input_node, indicies}, gather);
+            copy_runtime_info({index_op, indicies}, gather);
             replace_node(index_op, gather);
+            gather->set_friendly_name(index_op->get_friendly_name());
             return true;
         } else {
@@ -234,28 +239,28 @@ AtenIndexToSelect::AtenIndexToSelect() {
                 // index is None, stay input as is
                 const auto& attrs = const_input->get_attrs();
                 if (attrs.find("none_value") != attrs.end()) {
-                    copy_runtime_info({index_op, input_node, indicies}, input_node);
-                    replace_node(index_op, input_node);
+                    replace_node(index_op, input_node.get_node_shared_ptr());
                     return true;
                 }
             }
             auto index_dtype = indicies->get_output_element_type(0);
             if (index_dtype == element::boolean || index_dtype == element::u8) {
-                auto nonzero = std::make_shared<v3::NonZero>(indicies);
+                auto nonzero = std::make_shared<v3::NonZero>(indicies, element::i32);
                 auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
                 auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
                 auto gather = std::make_shared<v8::GatherND>(input_node, masked_id);
-                copy_runtime_info({index_op, input_node, indicies}, gather);
+                copy_runtime_info({index_op, indicies}, gather);
                 replace_node(index_op, gather);
                 return true;
             }
-            if (index_dtype != element::i32 && index_dtype != element::i32) {
+            if (index_dtype != element::i32) {
                 indicies = std::make_shared<v0::Convert>(indicies, element::i32);
             }
             auto dim = v0::Constant::create(element::i32, Shape{}, {0});
             auto gather = std::make_shared<v8::Gather>(input_node, indicies, dim);
-            copy_runtime_info({index_op, input_node, indicies}, gather);
+            copy_runtime_info({index_op, indicies}, gather);
             replace_node(index_op, gather);
+            gather->set_friendly_name(index_op->get_friendly_name());
             return true;
         }
         return false;
diff --git a/tests/layer_tests/pytorch_tests/test_index.py b/tests/layer_tests/pytorch_tests/test_index.py
index 967ef4c98af..c4e303d244f 100644
--- a/tests/layer_tests/pytorch_tests/test_index.py
+++ b/tests/layer_tests/pytorch_tests/test_index.py
@@ -24,7 +24,6 @@ class TestIndex(PytorchLayerTest):
 
             def forward(self, x, idx):
                 return x.__getitem__(idx)
-
 
         class aten_index_list_bool(torch.nn.Module):
 
@@ -52,13 +51,14 @@ class TestIndex(PytorchLayerTest):
     @pytest.mark.precommit
     @pytest.mark.parametrize("case", ["list", "getitem"])
     @pytest.mark.parametrize(("input_shape", "idx"), [
-            ((1,), np.array(0).astype(int)),
-            ([2, 3], np.array(-1).astype(int)),
-            ([4, 5, 6], np.array((1, 2)).astype(int)),
-            ([7, 8, 9], np.array((-1, 2, -3)).astype(int)),
+        ((1,), np.array(0).astype(int)),
+        ([2, 3], np.array(-1).astype(int)),
+        ([4, 5, 6], np.array((1, 2)).astype(int)),
+        ([7, 8, 9], np.array((-1, 2, -3)).astype(int)),
         ([2, 2, 3, 4], np.array((1,)).astype(int))])
     def test_index(self, input_shape, idx, case, ie_device, precision, ir_version):
-        self._test(*self.create_model(case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
+        self._test(*self.create_model(case), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
 
     @pytest.mark.nightly
     @pytest.mark.precommit
@@ -68,6 +68,37 @@ class TestIndex(PytorchLayerTest):
         ((2, 2, 5), np.zeros([2, 2, 5]).astype(bool)),
         ((2, 2, 5), np.ones([2, 2, 5]).astype(bool)),
         ((2, 2, 5), np.random.rand(2, 2, 5) > 0)
-        ])
+    ])
     def test_index_bool(self, input_shape, idx, case, ie_device, precision, ir_version):
-        self._test(*self.create_model(case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
\ No newline at end of file
+        self._test(*self.create_model(case), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
+
+
+class TestIndexRange(PytorchLayerTest):
+    def _prepare_input(self, input_shape, idx):
+        import numpy as np
+        return (np.random.randn(*input_shape).astype(np.float32), np.array(idx).astype(np.int32))
+
+    def create_model(self):
+        import torch
+
+        class aten_index_unsqueeze(torch.nn.Module):
+
+            def forward(self, x, y):
+                x = x.reshape(x.shape[0], -1)
+                return x[torch.arange(x.shape[0]), y]
+
+        ref_net = None
+
+        return aten_index_unsqueeze(), ref_net, "aten::index"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize(("input_shape", "idx"), (
+        ((1, 1), [0]),
+        ([2, 3], [1, 2]),
+        ([7, 8, 9], [1]),
+        ([2, 2, 3, 4], [0])))
+    def test_index_range(self, input_shape, idx, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
+            "input_shape": input_shape, "idx": idx}, trace_model=True, dynamic_shapes=False)
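
Note (illustrative, not part of the patch): the new TestIndexRange case boils down to the following standalone sketch: two advanced indices applied after a partial reshape, which is the aten::index form whose inconsistent reshape this fix targets. The Model class name and the tensor shapes below are arbitrary examples:

    import torch

    class Model(torch.nn.Module):
        def forward(self, x, y):
            # collapse all trailing dims, then index with two index tensors;
            # this traces to a single aten::index over a prim::ListConstruct of
            # indices, taking the transpose+flatten path in AtenIndexToSelect
            x = x.reshape(x.shape[0], -1)
            return x[torch.arange(x.shape[0]), y]

    x = torch.randn(2, 2, 3, 4)
    y = torch.tensor([0, 1])
    traced = torch.jit.trace(Model(), (x, y))  # the PT frontend consumes the traced graph
    print(traced(x, y).shape)  # torch.Size([2])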