[PT FE]: fix aten::index inconsistent reshape (#16741)

* [PT FE]: fix aten::index inconsistent reshape

* add index name, return false

* Update src/frontends/pytorch/src/transforms/aten_index_replacer.cpp
Ekaterina Aidova 2023-04-05 12:44:25 +04:00 committed by GitHub
parent 73ab0dd065
commit 837f5a7d53
2 changed files with 69 additions and 33 deletions
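
The pattern being fixed is the one covered by the new TestIndexRange test added in this PR: gathering rows with a range index plus a second index tensor. A minimal reproduction sketch in plain PyTorch (module name and shapes below are illustrative, not taken from the PR):

import torch

# Tracing this module produces an aten::index node whose translation previously
# built a Reshape with an inconsistent target shape (the bug this commit fixes).
class IndexRange(torch.nn.Module):
    def forward(self, x, y):
        x = x.reshape(x.shape[0], -1)
        return x[torch.arange(x.shape[0]), y]

traced = torch.jit.trace(IndexRange(), (torch.randn(2, 3), torch.tensor([1, 2])))
print(traced(torch.randn(2, 3), torch.tensor([0, 2])).shape)  # torch.Size([2])
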

src/frontends/pytorch/src/transforms/aten_index_replacer.cpp

@@ -47,12 +47,12 @@ std::shared_ptr<Node> flatten(const Output<Node>& value, size_t axis) {
} else {
const auto value_shape = std::make_shared<v3::ShapeOf>(value, element::i32);
const auto value_rank = std::make_shared<v3::ShapeOf>(value_shape, element::i32);
-const auto axis_node = v0::Constant::create(element::i32, Shape{}, {axis});
-auto start = v0::Constant::create(element::i32, Shape{}, {0});
-auto step = v0::Constant::create(element::i32, Shape{}, {1});
+const auto axis_node = v0::Constant::create(element::i32, Shape{1}, {axis});
+auto start = v0::Constant::create(element::i32, Shape{1}, {0});
+auto step = v0::Constant::create(element::i32, Shape{1}, {1});
const auto first_part_dims = std::make_shared<v8::Slice>(value_shape, start, axis_node, step);
auto zero = v0::Constant::create(element::i32, {}, {0});
-auto first_part_dims_length = std::make_shared<ov::op::v1::ReduceProd>(first_part_dims, zero, true);
+auto first_part_dims_length = std::make_shared<v1::ReduceProd>(first_part_dims, zero, true);
auto remaining_part_length = v0::Constant::create(element::i32, {1}, {-1});
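
The Shape{} to Shape{1} change turns the scalar constants into 1-D ones, matching the 1-D start/stop/step inputs that Slice takes; the helper itself only collapses the dims around `axis`. A rough numpy sketch (not OpenVINO code) of what flatten(value, axis) computes:

import numpy as np

# Dims before `axis` are multiplied together (ReduceProd over the sliced shape);
# everything from `axis` onward collapses into the trailing -1 dimension.
x = np.random.rand(2, 3, 4, 5)
axis = 2
flat = x.reshape(int(np.prod(x.shape[:axis])), -1)
print(flat.shape)  # (6, 20)
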
@@ -70,7 +70,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
if (!index_op) {
return false;
}
-auto input_node = index_op->input_value(0).get_node_shared_ptr();
+auto input_node = index_op->input_value(0);
auto indicies = index_op->input_value(1).get_node_shared_ptr();
auto list_indicies = cast_fw_node(indicies, "prim::ListConstruct");
if (list_indicies) {
@@ -108,10 +108,10 @@ AtenIndexToSelect::AtenIndexToSelect() {
continue;
}
}
-auto id_dtype = ids[i].get_node_shared_ptr()->get_element_type();
+auto id_dtype = ids[i].get_element_type();
if (id_dtype == element::boolean || id_dtype == element::u8) {
-auto idx = std::make_shared<ov::op::v0::Convert>(ids[i], element::u8);
-auto nonzero = std::make_shared<ov::op::v3::NonZero>(idx);
+auto idx = std::make_shared<v0::Convert>(ids[i], element::u8);
+auto nonzero = std::make_shared<v3::NonZero>(idx, element::i32);
auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
masked_indicies.push_back(masked_id);
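
The boolean/u8 branch above lowers a mask index to explicit coordinates: NonZero, then a [1, 0] Transpose, so the result can feed GatherND. A small numpy sketch (not OpenVINO code) of the equivalent computation:

import numpy as np

x = np.arange(12).reshape(3, 4)
mask = x % 2 == 0
coords = np.argwhere(mask)                 # shape [N, rank], like Transpose(NonZero(mask))
gathered = x[tuple(coords.T)]              # what GatherND(x, coords) returns
print(np.array_equal(gathered, x[mask]))   # True
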
@@ -125,30 +125,32 @@ AtenIndexToSelect::AtenIndexToSelect() {
// all indicies prim::Constant(None), return input as is
if (advanced_ids.size() == 0) {
copy_runtime_info({index_op, input_node}, input_node);
-replace_node(index_op, input_node);
+replace_node(index_op, input_node.get_node_shared_ptr());
return true;
}
// perform gather for single element case
if (advanced_ids.size() == 1) {
auto index = masked_indicies[advanced_ids[0]];
-index = std::make_shared<v0::Convert>(index, element::i32);
if (is_masked_bool[advanced_ids[0]]) {
auto gather = std::make_shared<v8::GatherND>(input_node, index);
-copy_runtime_info({index_op, input_node, indicies}, gather);
+copy_runtime_info({index_op, indicies}, gather);
gather->set_friendly_name(index_op->get_friendly_name());
replace_node(index_op, gather);
return true;
}
+index = std::make_shared<v0::Convert>(index, element::i32);
auto dim = v0::Constant::create(element::i32, Shape{}, {advanced_ids[0]});
auto gather = std::make_shared<v8::Gather>(input_node, index, dim);
-copy_runtime_info({index_op, input_node, indicies}, gather);
+copy_runtime_info({index_op, indicies}, gather);
gather->set_friendly_name(index_op->get_friendly_name());
replace_node(index_op, gather);
return true;
}
auto adv_idx_count = advanced_ids.size();
-auto rank = input_node->get_input_partial_shape(0).rank();
+auto rank = input_node.get_partial_shape().rank();
// index transformation supports only tensors with static rank
if (rank.is_dynamic()) {
-FRONT_END_CHECK_IMPLEMENTED(false, "indexing for tensor with dynamic rank is not implemented ");
+return false;
}
auto input_shape = std::make_shared<v3::ShapeOf>(input_node, element::i32);
auto zero = v0::Constant::create(element::i32, Shape{}, {0});
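
For the single-advanced-index branch above, the replacement is a plain Gather along the indexed axis; in numpy terms (illustrative sketch, not OpenVINO code):

import numpy as np

# x[:, idx] with one advanced index on axis 1 is exactly Gather(x, idx, axis=1).
x = np.random.rand(4, 5, 6)
idx = np.array([1, 2])
print(np.array_equal(np.take(x, idx, axis=1), x[:, idx]))  # True
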
@@ -166,9 +168,11 @@ AtenIndexToSelect::AtenIndexToSelect() {
auto transposed_input = std::make_shared<v1::Transpose>(input_node, transpose_dims);
auto flatten_input = flatten(transposed_input, adv_idx_count);
auto cum_adv_index = masked_indicies[advanced_ids[adv_idx_count - 1]];
+cum_adv_index = std::make_shared<v0::Convert>(cum_adv_index, element::i32);
auto multiplier = input_dims->output(advanced_ids[adv_idx_count - 1]);
-for (int i = static_cast<int>(adv_idx_count) - 2; i > 0; i--) {
-auto adv_index = std::make_shared<v1::Multiply>(masked_indicies[i], multiplier);
+for (int i = static_cast<int>(adv_idx_count) - 2; i > -1; i--) {
+auto m_idx = std::make_shared<v0::Convert>(masked_indicies[i], element::i32);
+auto adv_index = std::make_shared<v1::Multiply>(m_idx, multiplier);
cum_adv_index = std::make_shared<v1::Add>(cum_adv_index, adv_index);
auto input_id = advanced_ids[i];
multiplier = std::make_shared<v1::Multiply>(multiplier, input_dims->output(input_id));
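
The loop above accumulates a row-major flat index over the advanced axes (the bound change from `i > 0` to `i > -1` makes it also fold in the first advanced index). A numpy sketch of the arithmetic (not OpenVINO code):

import numpy as np

# Starting from the last advanced index, each earlier index is multiplied by the
# product of the trailing advanced dims, i.e. a row-major ravel of the multi-index.
dims = (3, 4)
i0, i1 = np.array([0, 2]), np.array([1, 3])
flat = i1 + i0 * dims[1]   # multiplier grows exactly as in the C++ loop
print(np.array_equal(flat, np.ravel_multi_index((i0, i1), dims)))  # True
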
@@ -204,7 +208,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
v0::Constant::create(element::i32, Shape{adv_idx_permute.size()}, adv_idx_permute);
gather = std::make_shared<v1::Transpose>(gather, permute_indicies);
// unfold advanced index axes
-for (size_t i = 0; i <= advanced_ids[0]; i++) {
+for (size_t i = 0; i < advanced_ids[0]; i++) {
concat_dims.push_back(input_dims->output(i));
}
concat_dims.push_back(cum_adv_index_shape_tensor);
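
The `<=` to `<` change is the reshape fix itself: only the dims strictly before the first advanced axis belong in concat_dims, because that axis is replaced by the advanced-index shape pushed right after the loop; with `<=` the dim at the first advanced axis was pushed once too many. PyTorch's own indexing shows the expected shapes:

import torch

# The advanced axis is replaced by the index shape, not kept alongside it.
x = torch.randn(7, 8, 9)
idx = torch.tensor([1, 2])
print(x[idx].shape)      # torch.Size([2, 8, 9])
print(x[:, idx].shape)   # torch.Size([7, 2, 9]) -- dim 0 kept, dim 1 replaced by idx shape
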
@@ -223,8 +227,9 @@ AtenIndexToSelect::AtenIndexToSelect() {
}
auto final_shape = std::make_shared<v0::Concat>(concat_dims, 0);
gather = std::make_shared<v1::Reshape>(gather, final_shape, false);
-copy_runtime_info({index_op, input_node, indicies}, gather);
+copy_runtime_info({index_op, indicies}, gather);
replace_node(index_op, gather);
+gather->set_friendly_name(index_op->get_friendly_name());
return true;
} else {
@@ -234,28 +239,28 @@ AtenIndexToSelect::AtenIndexToSelect() {
// index is None, stay input as is
const auto& attrs = const_input->get_attrs();
if (attrs.find("none_value") != attrs.end()) {
-copy_runtime_info({index_op, input_node, indicies}, input_node);
-replace_node(index_op, input_node);
+replace_node(index_op, input_node.get_node_shared_ptr());
return true;
}
}
auto index_dtype = indicies->get_output_element_type(0);
if (index_dtype == element::boolean || index_dtype == element::u8) {
-auto nonzero = std::make_shared<v3::NonZero>(indicies);
+auto nonzero = std::make_shared<v3::NonZero>(indicies, element::i32);
auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
auto gather = std::make_shared<v8::GatherND>(input_node, masked_id);
-copy_runtime_info({index_op, input_node, indicies}, gather);
+copy_runtime_info({index_op, indicies}, gather);
replace_node(index_op, gather);
return true;
}
-if (index_dtype != element::i32 && index_dtype != element::i32) {
+if (index_dtype != element::i32) {
indicies = std::make_shared<ov::op::v0::Convert>(indicies, element::i32);
}
auto dim = v0::Constant::create(element::i32, Shape{}, {0});
auto gather = std::make_shared<v8::Gather>(input_node, indicies, dim);
-copy_runtime_info({index_op, input_node, indicies}, gather);
+copy_runtime_info({index_op, indicies}, gather);
replace_node(index_op, gather);
gather->set_friendly_name(index_op->get_friendly_name());
return true;
}
return false;

tests/layer_tests/pytorch_tests/test_index.py

@@ -24,7 +24,6 @@ class TestIndex(PytorchLayerTest):
def forward(self, x, idx):
return x.__getitem__(idx)
class aten_index_list_bool(torch.nn.Module):
@@ -52,13 +51,14 @@ class TestIndex(PytorchLayerTest):
@pytest.mark.precommit
@pytest.mark.parametrize("case", ["list", "getitem"])
@pytest.mark.parametrize(("input_shape", "idx"), [
-((1,), np.array(0).astype(int)),
-([2, 3], np.array(-1).astype(int)),
-([4, 5, 6], np.array((1, 2)).astype(int)),
-([7, 8, 9], np.array((-1, 2, -3)).astype(int)),
+((1,), np.array(0).astype(int)),
+([2, 3], np.array(-1).astype(int)),
+([4, 5, 6], np.array((1, 2)).astype(int)),
+([7, 8, 9], np.array((-1, 2, -3)).astype(int)),
+([2, 2, 3, 4], np.array((1,)).astype(int))])
def test_index(self, input_shape, idx, case, ie_device, precision, ir_version):
-self._test(*self.create_model(case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
+self._test(*self.create_model(case), ie_device, precision, ir_version,
+kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
@pytest.mark.nightly
@pytest.mark.precommit
@@ -68,6 +68,37 @@ class TestIndex(PytorchLayerTest):
((2, 2, 5), np.zeros([2, 2, 5]).astype(bool)),
((2, 2, 5), np.ones([2, 2, 5]).astype(bool)),
((2, 2, 5), np.random.rand(2, 2, 5) > 0)
-])
+])
def test_index_bool(self, input_shape, idx, case, ie_device, precision, ir_version):
-self._test(*self.create_model(case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
+self._test(*self.create_model(case), ie_device, precision, ir_version,
+kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
+class TestIndexRange(PytorchLayerTest):
+    def _prepare_input(self, input_shape, idx):
+        import numpy as np
+        return (np.random.randn(*input_shape).astype(np.float32), np.array(idx).astype(np.int32))
+
+    def create_model(self):
+        import torch
+
+        class aten_index_unsqueeze(torch.nn.Module):
+            def forward(self, x, y):
+                x = x.reshape(x.shape[0], -1)
+                return x[torch.arange(x.shape[0]), y]
+
+        ref_net = None
+
+        return aten_index_unsqueeze(), ref_net, "aten::index"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize(("input_shape", "idx"), (
+        ((1, 1), [0]),
+        ([2, 3], [1, 2]),
+        ([7, 8, 9], [1]),
+        ([2, 2, 3, 4], [0])))
+    def test_index_range(self, input_shape, idx, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
+            "input_shape": input_shape, "idx": idx}, trace_model=True, dynamic_shapes=False)