[PT FE]: fix aten::index inconsistent reshape (#16741)
* [PT FE]: fix aten::index inconsistent reshape * add index name, return false * Update src/frontends/pytorch/src/transforms/aten_index_replacer.cpp
This commit is contained in:
parent
73ab0dd065
commit
837f5a7d53
@ -47,12 +47,12 @@ std::shared_ptr<Node> flatten(const Output<Node>& value, size_t axis) {
|
||||
} else {
|
||||
const auto value_shape = std::make_shared<v3::ShapeOf>(value, element::i32);
|
||||
const auto value_rank = std::make_shared<v3::ShapeOf>(value_shape, element::i32);
|
||||
const auto axis_node = v0::Constant::create(element::i32, Shape{}, {axis});
|
||||
auto start = v0::Constant::create(element::i32, Shape{}, {0});
|
||||
auto step = v0::Constant::create(element::i32, Shape{}, {1});
|
||||
const auto axis_node = v0::Constant::create(element::i32, Shape{1}, {axis});
|
||||
auto start = v0::Constant::create(element::i32, Shape{1}, {0});
|
||||
auto step = v0::Constant::create(element::i32, Shape{1}, {1});
|
||||
const auto first_part_dims = std::make_shared<v8::Slice>(value_shape, start, axis_node, step);
|
||||
auto zero = v0::Constant::create(element::i32, {}, {0});
|
||||
auto first_part_dims_length = std::make_shared<ov::op::v1::ReduceProd>(first_part_dims, zero, true);
|
||||
auto first_part_dims_length = std::make_shared<v1::ReduceProd>(first_part_dims, zero, true);
|
||||
|
||||
auto remaining_part_length = v0::Constant::create(element::i32, {1}, {-1});
|
||||
|
||||
@ -70,7 +70,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
|
||||
if (!index_op) {
|
||||
return false;
|
||||
}
|
||||
auto input_node = index_op->input_value(0).get_node_shared_ptr();
|
||||
auto input_node = index_op->input_value(0);
|
||||
auto indicies = index_op->input_value(1).get_node_shared_ptr();
|
||||
auto list_indicies = cast_fw_node(indicies, "prim::ListConstruct");
|
||||
if (list_indicies) {
|
||||
@ -108,10 +108,10 @@ AtenIndexToSelect::AtenIndexToSelect() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
auto id_dtype = ids[i].get_node_shared_ptr()->get_element_type();
|
||||
auto id_dtype = ids[i].get_element_type();
|
||||
if (id_dtype == element::boolean || id_dtype == element::u8) {
|
||||
auto idx = std::make_shared<ov::op::v0::Convert>(ids[i], element::u8);
|
||||
auto nonzero = std::make_shared<ov::op::v3::NonZero>(idx);
|
||||
auto idx = std::make_shared<v0::Convert>(ids[i], element::u8);
|
||||
auto nonzero = std::make_shared<v3::NonZero>(idx, element::i32);
|
||||
auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
|
||||
auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
|
||||
masked_indicies.push_back(masked_id);
|
||||
@ -125,30 +125,32 @@ AtenIndexToSelect::AtenIndexToSelect() {
|
||||
|
||||
// all indicies prim::Constant(None), return input as is
|
||||
if (advanced_ids.size() == 0) {
|
||||
copy_runtime_info({index_op, input_node}, input_node);
|
||||
replace_node(index_op, input_node);
|
||||
replace_node(index_op, input_node.get_node_shared_ptr());
|
||||
return true;
|
||||
}
|
||||
// perform gather for single element case
|
||||
if (advanced_ids.size() == 1) {
|
||||
auto index = masked_indicies[advanced_ids[0]];
|
||||
index = std::make_shared<v0::Convert>(index, element::i32);
|
||||
if (is_masked_bool[advanced_ids[0]]) {
|
||||
auto gather = std::make_shared<v8::GatherND>(input_node, index);
|
||||
copy_runtime_info({index_op, input_node, indicies}, gather);
|
||||
copy_runtime_info({index_op, indicies}, gather);
|
||||
gather->set_friendly_name(index_op->get_friendly_name());
|
||||
replace_node(index_op, gather);
|
||||
return true;
|
||||
}
|
||||
index = std::make_shared<v0::Convert>(index, element::i32);
|
||||
auto dim = v0::Constant::create(element::i32, Shape{}, {advanced_ids[0]});
|
||||
auto gather = std::make_shared<v8::Gather>(input_node, index, dim);
|
||||
copy_runtime_info({index_op, input_node, indicies}, gather);
|
||||
copy_runtime_info({index_op, indicies}, gather);
|
||||
gather->set_friendly_name(index_op->get_friendly_name());
|
||||
replace_node(index_op, gather);
|
||||
return true;
|
||||
}
|
||||
auto adv_idx_count = advanced_ids.size();
|
||||
auto rank = input_node->get_input_partial_shape(0).rank();
|
||||
auto rank = input_node.get_partial_shape().rank();
|
||||
// index transformation supports only tensors with static rank
|
||||
if (rank.is_dynamic()) {
|
||||
FRONT_END_CHECK_IMPLEMENTED(false, "indexing for tensor with dynamic rank is not implemented ");
|
||||
return false;
|
||||
}
|
||||
auto input_shape = std::make_shared<v3::ShapeOf>(input_node, element::i32);
|
||||
auto zero = v0::Constant::create(element::i32, Shape{}, {0});
|
||||
@ -166,9 +168,11 @@ AtenIndexToSelect::AtenIndexToSelect() {
|
||||
auto transposed_input = std::make_shared<v1::Transpose>(input_node, transpose_dims);
|
||||
auto flatten_input = flatten(transposed_input, adv_idx_count);
|
||||
auto cum_adv_index = masked_indicies[advanced_ids[adv_idx_count - 1]];
|
||||
cum_adv_index = std::make_shared<v0::Convert>(cum_adv_index, element::i32);
|
||||
auto multiplier = input_dims->output(advanced_ids[adv_idx_count - 1]);
|
||||
for (int i = static_cast<int>(adv_idx_count) - 2; i > 0; i--) {
|
||||
auto adv_index = std::make_shared<v1::Multiply>(masked_indicies[i], multiplier);
|
||||
for (int i = static_cast<int>(adv_idx_count) - 2; i > -1; i--) {
|
||||
auto m_idx = std::make_shared<v0::Convert>(masked_indicies[i], element::i32);
|
||||
auto adv_index = std::make_shared<v1::Multiply>(m_idx, multiplier);
|
||||
cum_adv_index = std::make_shared<v1::Add>(cum_adv_index, adv_index);
|
||||
auto input_id = advanced_ids[i];
|
||||
multiplier = std::make_shared<v1::Multiply>(multiplier, input_dims->output(input_id));
|
||||
@ -204,7 +208,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
|
||||
v0::Constant::create(element::i32, Shape{adv_idx_permute.size()}, adv_idx_permute);
|
||||
gather = std::make_shared<v1::Transpose>(gather, permute_indicies);
|
||||
// unfold advanced index axes
|
||||
for (size_t i = 0; i <= advanced_ids[0]; i++) {
|
||||
for (size_t i = 0; i < advanced_ids[0]; i++) {
|
||||
concat_dims.push_back(input_dims->output(i));
|
||||
}
|
||||
concat_dims.push_back(cum_adv_index_shape_tensor);
|
||||
@ -223,8 +227,9 @@ AtenIndexToSelect::AtenIndexToSelect() {
|
||||
}
|
||||
auto final_shape = std::make_shared<v0::Concat>(concat_dims, 0);
|
||||
gather = std::make_shared<v1::Reshape>(gather, final_shape, false);
|
||||
copy_runtime_info({index_op, input_node, indicies}, gather);
|
||||
copy_runtime_info({index_op, indicies}, gather);
|
||||
replace_node(index_op, gather);
|
||||
gather->set_friendly_name(index_op->get_friendly_name());
|
||||
return true;
|
||||
|
||||
} else {
|
||||
@ -234,28 +239,28 @@ AtenIndexToSelect::AtenIndexToSelect() {
|
||||
// index is None, stay input as is
|
||||
const auto& attrs = const_input->get_attrs();
|
||||
if (attrs.find("none_value") != attrs.end()) {
|
||||
copy_runtime_info({index_op, input_node, indicies}, input_node);
|
||||
replace_node(index_op, input_node);
|
||||
replace_node(index_op, input_node.get_node_shared_ptr());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
auto index_dtype = indicies->get_output_element_type(0);
|
||||
if (index_dtype == element::boolean || index_dtype == element::u8) {
|
||||
auto nonzero = std::make_shared<v3::NonZero>(indicies);
|
||||
auto nonzero = std::make_shared<v3::NonZero>(indicies, element::i32);
|
||||
auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
|
||||
auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
|
||||
auto gather = std::make_shared<v8::GatherND>(input_node, masked_id);
|
||||
copy_runtime_info({index_op, input_node, indicies}, gather);
|
||||
copy_runtime_info({index_op, indicies}, gather);
|
||||
replace_node(index_op, gather);
|
||||
return true;
|
||||
}
|
||||
if (index_dtype != element::i32 && index_dtype != element::i32) {
|
||||
if (index_dtype != element::i32) {
|
||||
indicies = std::make_shared<ov::op::v0::Convert>(indicies, element::i32);
|
||||
}
|
||||
auto dim = v0::Constant::create(element::i32, Shape{}, {0});
|
||||
auto gather = std::make_shared<v8::Gather>(input_node, indicies, dim);
|
||||
copy_runtime_info({index_op, input_node, indicies}, gather);
|
||||
copy_runtime_info({index_op, indicies}, gather);
|
||||
replace_node(index_op, gather);
|
||||
gather->set_friendly_name(index_op->get_friendly_name());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -25,7 +25,6 @@ class TestIndex(PytorchLayerTest):
|
||||
def forward(self, x, idx):
|
||||
return x.__getitem__(idx)
|
||||
|
||||
|
||||
class aten_index_list_bool(torch.nn.Module):
|
||||
|
||||
def forward(self, x, idx):
|
||||
@ -58,7 +57,8 @@ class TestIndex(PytorchLayerTest):
|
||||
([7, 8, 9], np.array((-1, 2, -3)).astype(int)),
|
||||
([2, 2, 3, 4], np.array((1,)).astype(int))])
|
||||
def test_index(self, input_shape, idx, case, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
|
||||
self._test(*self.create_model(case), ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@ -70,4 +70,35 @@ class TestIndex(PytorchLayerTest):
|
||||
((2, 2, 5), np.random.rand(2, 2, 5) > 0)
|
||||
])
|
||||
def test_index_bool(self, input_shape, idx, case, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(case), ie_device, precision, ir_version, kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
|
||||
self._test(*self.create_model(case), ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={"input_shape": input_shape, "idx": idx})
|
||||
|
||||
|
||||
class TestIndexRange(PytorchLayerTest):
|
||||
def _prepare_input(self, input_shape, idx):
|
||||
import numpy as np
|
||||
return (np.random.randn(*input_shape).astype(np.float32), np.array(idx).astype(np.int32))
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class aten_index_unsqueeze(torch.nn.Module):
|
||||
|
||||
def forward(self, x, y):
|
||||
x = x.reshape(x.shape[0], -1)
|
||||
return x[torch.arange(x.shape[0]), y]
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_index_unsqueeze(), ref_net, "aten::index"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize(("input_shape", "idx"), (
|
||||
((1, 1), [0]),
|
||||
([2, 3], [1, 2]),
|
||||
([7, 8, 9], [1]),
|
||||
([2, 2, 3, 4], [0])))
|
||||
def test_index_range(self, input_shape, idx, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
|
||||
"input_shape": input_shape, "idx": idx}, trace_model=True, dynamic_shapes=False)
|
||||
|
Loading…
Reference in New Issue
Block a user