[PT FE] Fix failing translation of aten::index_put_ (#16140)

* Initial commit

* Fix for reading processed list

* Format code

* Cleanup

* cleanup

* Cleanup

* cleanup test

* Add comment

* Add rt_info

* fix type

* Update src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

---------

Co-authored-by: Andrei Kochin <andrei.kochin@intel.com>
Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
Mateusz Mikolajczyk
2023-03-09 21:14:58 +01:00
committed by GitHub
parent 654f3d988f
commit 31489931cf
6 changed files with 301 additions and 144 deletions

View File

@@ -18,6 +18,7 @@
#include "transforms/append_list_unpack_replacer.hpp"
#include "transforms/aten_cat_replacer.hpp"
#include "transforms/aten_getitem_replacer.hpp"
#include "transforms/aten_index_put_replacer.hpp"
#include "transforms/aten_index_replacer.hpp"
#include "transforms/aten_stack_list_construct_replacer.hpp"
#include "transforms/einsum_list_construct.hpp"
@@ -100,6 +101,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::AtenGetItemReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::ListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexPutReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimListConstructPadReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();

View File

@@ -3,19 +3,7 @@
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/mod.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"
#include "pt_framework_node.hpp"
namespace ov {
namespace frontend {
@@ -24,101 +12,13 @@ namespace op {
using namespace ov::op;
namespace {
Output<Node> generate_zeros_with_convertlike(const NodeContext& context,
const Output<Node> sizes,
const Output<Node> tensor_of_type) {
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto zeros = context.mark_node(std::make_shared<v3::Broadcast>(const_0, sizes));
return context.mark_node(std::make_shared<v1::ConvertLike>(zeros, tensor_of_type));
}
} // namespace
OutputVector translate_index_put_(NodeContext& context) {
num_inputs_check(context, 4, 4);
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto const_max_int =
context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()}));
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
auto input = context.get_input(0);
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
auto indices = context.get_input(1);
auto values = context.get_input(2);
auto accumulate = context.const_input<bool>(3);
auto indices_partial_shape = indices.get_partial_shape();
FRONT_END_OP_CONVERSION_CHECK(indices_partial_shape.rank().is_static(),
"We support only indices with static rank.");
auto indices_first_dim = indices_partial_shape[0];
FRONT_END_OP_CONVERSION_CHECK(indices_first_dim.is_static(),
"We support only lists of tensors with static number of elements.");
int64_t indices_list_len = indices_first_dim.get_length();
if (indices_list_len == 0) {
return {values};
}
auto const_indices_list_len = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {indices_list_len}));
auto split_indices = context.mark_node(std::make_shared<v1::Split>(indices, const_0, indices_list_len));
std::shared_ptr<Node> broadcast_index_shape;
Output<Node> index;
if (indices_list_len > 1) {
index = split_indices->output(0);
for (int i = 1; i < indices_list_len; i++) {
index = context.mark_node(std::make_shared<v1::Add>(index, split_indices->output(i)));
}
broadcast_index_shape = context.mark_node(std::make_shared<v3::ShapeOf>(index, element::i32));
OutputVector indices_list;
for (int i = 0; i < indices_list_len; i++) {
auto broadcast =
context.mark_node(std::make_shared<v3::Broadcast>(split_indices->output(i), broadcast_index_shape));
auto unsqueeze = context.mark_node(std::make_shared<v0::Unsqueeze>(broadcast, const_neg_1));
// change negative indices to positive indices
auto const_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {i}));
auto dim_i = context.mark_node(std::make_shared<v8::Gather>(input_shape, const_i, const_0));
auto dim_i_correct_type = context.mark_node(std::make_shared<v1::ConvertLike>(dim_i, index));
unsqueeze = context.mark_node(std::make_shared<v1::Add>(unsqueeze, dim_i_correct_type));
unsqueeze = context.mark_node(std::make_shared<v1::Mod>(unsqueeze, dim_i_correct_type));
indices_list.push_back(unsqueeze);
}
index = context.mark_node(std::make_shared<v0::Concat>(indices_list, -1));
} else {
index = split_indices->output(0);
// change negative indices to positive indices
auto dim_0 = context.mark_node(std::make_shared<v8::Gather>(input_shape, const_0, const_0));
auto dim_0_correct_type = context.mark_node(std::make_shared<v1::ConvertLike>(dim_0, index));
index = context.mark_node(std::make_shared<v1::Add>(index, dim_0_correct_type));
index = context.mark_node(std::make_shared<v1::Mod>(index, dim_0_correct_type));
broadcast_index_shape = context.mark_node(std::make_shared<v3::ShapeOf>(index, element::i32));
index = context.mark_node(std::make_shared<v0::Unsqueeze>(index, const_neg_1));
}
auto sub_data_shape =
context.mark_node(std::make_shared<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1));
auto values_shape =
context.mark_node(std::make_shared<v0::Concat>(OutputVector{broadcast_index_shape, sub_data_shape}, 0));
values = context.mark_node(std::make_shared<v3::Broadcast>(values, values_shape));
values = context.mark_node(std::make_shared<v1::ConvertLike>(values, input));
Output<Node> result;
if (accumulate) {
auto zeros = generate_zeros_with_convertlike(context, input_shape, input);
result = context.mark_node(std::make_shared<v3::ScatterNDUpdate>(zeros, index, values));
result = context.mark_node(std::make_shared<v1::Add>(input, result));
} else {
result = context.mark_node(std::make_shared<v3::ScatterNDUpdate>(input, index, values));
}
return {result};
// Pass as PtFrameworkNode to register as `inplace_op`. Conversion to OV operators is done as transformation.
auto node = std::make_shared<PtFrameworkNode>(context.get_decoder(), context.inputs());
return {context.mark_node(node)};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
} // namespace ov

View File

@@ -0,0 +1,158 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "aten_index_put_replacer.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/frontend/pytorch/visibility.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/mod.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
using namespace ov::op;
namespace {
// Builds a zero tensor broadcast to the runtime shape `sizes` and converted
// (via ConvertLike, so the element type is resolved at runtime) to the element
// type of `tensor_of_type`.
//
// Note: parameters are taken by const reference instead of by value —
// Output<Node> copies are cheap but unnecessary here.
Output<Node> generate_zeros_with_convertlike(const Output<Node>& sizes, const Output<Node>& tensor_of_type) {
    auto const_0 = v0::Constant::create(element::i32, Shape{}, {0});
    auto zeros = std::make_shared<v3::Broadcast>(const_0, sizes);
    return std::make_shared<v1::ConvertLike>(zeros, tensor_of_type);
}
} // namespace
// Matcher pass that lowers the "aten::index_put_" framework node (kept as a
// PtFrameworkNode by the op translator) into a subgraph of standard OpenVINO
// operations built around ScatterNDUpdate.
AtenIndexPutReplacer::AtenIndexPutReplacer() {
    // Match any FrameworkNode; the callback filters for "aten::index_put_".
    auto index_op = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();

    ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
        auto index_op = cast_fw_node(m.get_match_root(), "aten::index_put_");
        if (!index_op) {
            return false;
        }
        // Nodes whose runtime info is copied onto the generated subgraph.
        NodeVector rt_copy_from{index_op};
        auto const_0 = v0::Constant::create(element::i32, Shape{}, {0});
        auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1});
        auto const_max_int = v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()});
        auto const_neg_1 = v0::Constant::create(element::i32, Shape{}, {-1});

        // aten::index_put_(input, indices, values, accumulate)
        auto input = index_op->input_value(0);
        auto input_shape = std::make_shared<v3::ShapeOf>(input, element::i32);
        auto indices = index_op->input_value(1);
        auto values = index_op->input_value(2);
        // `accumulate` must be a compile-time Constant; otherwise the node
        // cannot be converted and the match is rejected.
        auto acc_const =
            std::dynamic_pointer_cast<ov::op::v0::Constant>(index_op->input_value(3).get_node_shared_ptr());
        if (!acc_const) {
            return false;
        }
        bool accumulate = acc_const->cast_vector<bool>()[0];

        int64_t indices_list_len;
        OutputVector indices_inputs;
        if (auto listconstruct = cast_fw_node(indices.get_node_shared_ptr(), "prim::ListConstruct")) {
            // Indices arrive as an unprocessed prim::ListConstruct: take its
            // inputs directly, one index tensor per indexed dimension.
            rt_copy_from.push_back(listconstruct);
            indices_inputs = listconstruct->input_values();
            indices_list_len = indices_inputs.size();
        } else {
            // Indices were already packed into a single tensor; recover the
            // per-dimension index tensors by splitting along dim 0. This needs
            // a static rank and a static first dimension.
            auto indices_partial_shape = indices.get_partial_shape();
            if (!indices_partial_shape.rank().is_static()) {
                // "We support only indices with static rank."
                return false;
            }
            auto indices_first_dim = indices_partial_shape[0];
            if (!indices_first_dim.is_static()) {
                // We support only lists of tensors with static number of elements.
                return false;
            }
            indices_list_len = indices_first_dim.get_length();
            auto split = std::make_shared<v1::Split>(indices, const_0, indices_list_len);
            indices_inputs = split->outputs();
        }
        if (indices_list_len == 0) {
            // Empty index list: the op is replaced by `values` unchanged.
            copy_runtime_info(rt_copy_from, values.get_node_shared_ptr());
            replace_node(index_op, values.get_node_shared_ptr());
            return true;
        }
        auto const_indices_list_len = v0::Constant::create(element::i32, Shape{1}, {indices_list_len});
        std::shared_ptr<Node> broadcast_index_shape;
        Output<Node> index;
        if (indices_list_len > 1) {
            // Multiple indexed dimensions: derive the common broadcast shape of
            // all index tensors by summing them (Add broadcasts its inputs).
            index = indices_inputs[0];
            for (int i = 1; i < indices_list_len; i++) {
                index = std::make_shared<v1::Add>(index, indices_inputs[i]);
            }
            broadcast_index_shape = std::make_shared<v3::ShapeOf>(index, element::i32);
            OutputVector indices_list;
            for (int i = 0; i < indices_list_len; i++) {
                auto broadcast = std::make_shared<v3::Broadcast>(indices_inputs[i], broadcast_index_shape);
                auto unsqueeze = std::make_shared<v0::Unsqueeze>(broadcast, const_neg_1);
                // change negative indices to positive indices: (idx + dim_i) mod dim_i
                auto const_i = v0::Constant::create(element::i32, Shape{}, {i});
                auto dim_i = std::make_shared<v8::Gather>(input_shape, const_i, const_0);
                auto dim_i_correct_type = std::make_shared<v1::ConvertLike>(dim_i, index);
                auto unsqueeze_add = std::make_shared<v1::Add>(unsqueeze, dim_i_correct_type);
                auto unsqueeze_add_mod = std::make_shared<v1::Mod>(unsqueeze_add, dim_i_correct_type);
                indices_list.push_back(unsqueeze_add_mod);
            }
            // Concatenate per-dimension indices along the last axis to form the
            // coordinate tuples expected by ScatterNDUpdate.
            index = std::make_shared<v0::Concat>(indices_list, -1);
        } else {
            // Single indexed dimension.
            index = indices_inputs[0];
            // change negative indices to positive indices: (idx + dim_0) mod dim_0
            auto dim_0 = (std::make_shared<v8::Gather>(input_shape, const_0, const_0));
            auto dim_0_correct_type = (std::make_shared<v1::ConvertLike>(dim_0, index));
            index = std::make_shared<v1::Add>(index, dim_0_correct_type);
            index = std::make_shared<v1::Mod>(index, dim_0_correct_type);
            broadcast_index_shape = std::make_shared<v3::ShapeOf>(index, element::i32);
            index = std::make_shared<v0::Unsqueeze>(index, const_neg_1);
        }
        // Broadcast `values` to (index broadcast shape) + (the non-indexed
        // trailing dims of input), then convert to the input's element type.
        auto sub_data_shape = std::make_shared<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1);
        auto values_shape = std::make_shared<v0::Concat>(OutputVector{broadcast_index_shape, sub_data_shape}, 0);
        values = std::make_shared<v3::Broadcast>(values, values_shape);
        values = std::make_shared<v1::ConvertLike>(values, input);
        std::shared_ptr<ov::Node> result;
        if (accumulate) {
            // accumulate=True: scatter the values into a zeros tensor and add
            // that to the input, so touched elements are incremented rather
            // than overwritten.
            auto zeros = generate_zeros_with_convertlike(input_shape, input);
            auto scatter = std::make_shared<v3::ScatterNDUpdate>(zeros, index, values);
            result = std::make_shared<v1::Add>(input, scatter);
        } else {
            result = std::make_shared<v3::ScatterNDUpdate>(input, index, values);
        }
        copy_runtime_info(rt_copy_from, result);
        replace_node(index_op, result);
        result->set_friendly_name(index_op->get_friendly_name());
        return true;
    };

    auto m =
        std::make_shared<ov::pass::pattern::Matcher>(index_op, "ov::frontend::pytorch::pass::AtenIndexPutReplacer");
    this->register_matcher(m, callback);
}
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@@ -0,0 +1,25 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/frontend/pytorch/visibility.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
// Matcher pass that replaces the "aten::index_put_" framework node with a
// ScatterNDUpdate-based subgraph of standard OpenVINO operations.
// Registered by FrontEnd::normalize().
class PYTORCH_API AtenIndexPutReplacer : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ov::frontend::pytorch::pass::AtenIndexPutReplacer");
    AtenIndexPutReplacer();
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@@ -3,6 +3,7 @@
import itertools
import warnings
from copy import deepcopy
import numpy as np
from common.constants import test_device, test_precision
@@ -51,7 +52,7 @@ class PytorchLayerTest:
model = torch.jit.script(model)
else:
torch_inputs = [torch.from_numpy(inp) for inp in inputs]
model = torch.jit.trace(model, torch_inputs)
model = torch.jit.trace(model, deepcopy(torch_inputs))
if kwargs.get('freeze_model', True):
model = torch.jit.freeze(model)
graph = model.inlined_graph
@@ -91,14 +92,14 @@ class PytorchLayerTest:
# OV infer:
core = Core()
compiled = core.compile_model(om, ie_device)
infer_res = compiled(inputs)
infer_res = compiled(deepcopy(inputs))
if hasattr(self, 'skip_framework') and self.skip_framework:
warnings.warn('Framework is skipped')
return
# Framework infer:
fw_res = model(*torch_inps)
fw_res = model(*deepcopy(torch_inps))
if not isinstance(fw_res, (tuple)):
fw_res = (fw_res,)

View File

@@ -14,9 +14,7 @@ class TestIndexPut_SingleIndices(PytorchLayerTest):
return (self.input_tensor, self.values)
def create_model(self, indices, accumulate):
class aten_index_put_(torch.nn.Module):
def __init__(self, indices, accumulate):
super().__init__()
self.indices = indices
@@ -30,31 +28,43 @@ class TestIndexPut_SingleIndices(PytorchLayerTest):
return aten_index_put_(indices, accumulate), ref_net, "aten::index_put_"
@pytest.mark.parametrize('input_data', ({'input_tensor': np.random.randn(5).astype(np.float32),
'values': np.array(11).astype(np.float32)},
{'input_tensor': np.random.randn(3, 3).astype(np.float32),
'values': np.array([10, 11, 12]).astype(np.float32)}))
@pytest.mark.parametrize('indices', (torch.tensor([0], dtype=torch.long),
torch.tensor([-1, -2], dtype=torch.long),
torch.tensor([0, -1, -2], dtype=torch.long),
torch.tensor([1, 2], dtype=torch.long),
torch.tensor([0, 1, 2], dtype=torch.long)))
@pytest.mark.parametrize('accumulate', (True, False))
@pytest.mark.parametrize(
"input_data",
(
{
"input_tensor": np.random.randn(5).astype(np.float32),
"values": np.array(11).astype(np.float32)},
{
"input_tensor": np.random.randn(3, 3).astype(np.float32),
"values": np.array([10, 11, 12]).astype(np.float32),
},
),
)
@pytest.mark.parametrize(
"indices",
(
torch.tensor([0], dtype=torch.long),
torch.tensor([-1, -2], dtype=torch.long),
torch.tensor([0, -1, -2], dtype=torch.long),
torch.tensor([1, 2], dtype=torch.long),
torch.tensor([0, 1, 2], dtype=torch.long),
),
)
@pytest.mark.parametrize("accumulate", (True, False))
@pytest.mark.nightly
@pytest.mark.precommit
def test_index_put_single_indices(self, ie_device, precision, ir_version, input_data, indices, accumulate):
self.input_tensor = input_data['input_tensor']
self.values = input_data['values']
self.input_tensor = input_data["input_tensor"]
self.values = input_data["values"]
self._test(*self.create_model(indices, accumulate), ie_device, precision, ir_version)
class TestIndexPut_ManyIndices(PytorchLayerTest):
def _prepare_input(self):
return (self.input_tensor, self.values)
def create_model(self, indices, accumulate):
class aten_index_put_(torch.nn.Module):
def __init__(self, indices, accumulate):
super().__init__()
self.indices_first = indices[0]
@@ -69,26 +79,87 @@ class TestIndexPut_ManyIndices(PytorchLayerTest):
return aten_index_put_(indices, accumulate), ref_net, "aten::index_put_"
@pytest.mark.parametrize('input_data', ({'input_tensor': np.random.randn(3, 3).astype(np.float32),
'values': np.array(12).astype(np.float32)},
{'input_tensor': np.random.randn(3, 3, 3).astype(np.float32),
'values': np.array([10, 11, 12]).astype(np.float32)},))
@pytest.mark.parametrize('indices', ((torch.tensor([0], dtype=torch.long),
torch.tensor([2], dtype=torch.long)),
(torch.tensor([1, 2], dtype=torch.long),
torch.tensor([0, 1], dtype=torch.long)),
(torch.tensor([0, 1], dtype=torch.long),
torch.tensor([0, 1], dtype=torch.long)),
(torch.tensor([0], dtype=torch.long),
torch.tensor([-2], dtype=torch.long)),
(torch.tensor([-1, -2], dtype=torch.long),
torch.tensor([0, 1], dtype=torch.long)),
(torch.tensor([0, -1], dtype=torch.long),
torch.tensor([0, -1], dtype=torch.long))))
@pytest.mark.parametrize('accumulate', (True, False))
@pytest.mark.parametrize(
"input_data",
(
{
"input_tensor": np.random.randn(3, 3).astype(np.float32),
"values": np.array(12).astype(np.float32)
},
{
"input_tensor": np.random.randn(3, 3, 3).astype(np.float32),
"values": np.array([10, 11, 12]).astype(np.float32),
},
),
)
@pytest.mark.parametrize(
"indices",
(
(torch.tensor([0], dtype=torch.long), torch.tensor([2], dtype=torch.long)),
(torch.tensor([1, 2], dtype=torch.long), torch.tensor([0, 1], dtype=torch.long)),
(torch.tensor([0, 1], dtype=torch.long), torch.tensor([0, 1], dtype=torch.long)),
(torch.tensor([0], dtype=torch.long), torch.tensor([-2], dtype=torch.long)),
(torch.tensor([-1, -2], dtype=torch.long), torch.tensor([0, 1], dtype=torch.long)),
(torch.tensor([0, -1], dtype=torch.long), torch.tensor([0, -1], dtype=torch.long)),
),
)
@pytest.mark.parametrize("accumulate", (True, False))
@pytest.mark.nightly
@pytest.mark.precommit
def test_index_put_many_indices(self, ie_device, precision, ir_version, input_data, indices, accumulate):
self.input_tensor = input_data['input_tensor']
self.values = input_data['values']
self._test(*self.create_model(indices, accumulate), ie_device, precision, ir_version)
self.input_tensor = input_data["input_tensor"]
self.values = input_data["values"]
self._test(*self.create_model(indices, accumulate), ie_device, precision, ir_version)
class TestNonZero_IndexPut(PytorchLayerTest):
    """Tests aten::index_put_ whose indices are produced dynamically by nonzero()."""

    def _prepare_input(self):
        # Inputs fed to both the traced model and the OV compiled model.
        return (self.input_tensor, self.values, self.indices_0, self.indices_1)

    def create_model(self, accumulate):
        class aten_index_put_(torch.nn.Module):
            def __init__(self, accumulate):
                super().__init__()
                self.accumulate = accumulate

            def forward(self, input_tensor, values, indices_0, indices_1):
                # Index positions where the two index tensors agree; nonzero()
                # makes the index tensor data-dependent (dynamic).
                nonzero = (indices_0 == indices_1).nonzero(as_tuple=True)[0]
                input_tensor.index_put_((nonzero,), values, self.accumulate)
                return input_tensor

        ref_net = None

        return aten_index_put_(accumulate), ref_net, "aten::index_put_"

    @pytest.mark.parametrize(
        "input_data",
        (
            {
                "input_tensor": np.random.randn(3).astype(np.float32),
                "values": np.array(11).astype(np.float32),
            },
            {
                "input_tensor": np.random.randn(3, 3).astype(np.float32),
                "values": np.array([10, 11, 12]).astype(np.float32),
            },
        ),
    )
    @pytest.mark.parametrize(
        "indices",
        (
            (np.random.randint(low=0, high=2, size=(1,)), np.random.randint(low=0, high=2, size=(1,))),
            (np.random.randint(low=0, high=2, size=(2,)), np.random.randint(low=0, high=2, size=(2,))),
            (np.array([0, 1, 0]), np.array([1, 1, 0])),
            (np.ones(shape=(3,)), np.ones(shape=(3,))),
            (np.ones(shape=(3,)), np.zeros(shape=(3,))),
        ),
    )
    @pytest.mark.parametrize("accumulate", (False, True))
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_nonzero_index_put_(self, ie_device, precision, ir_version, input_data, indices, accumulate):
        self.input_tensor = input_data["input_tensor"]
        self.values = input_data["values"]
        self.indices_0 = indices[0]
        self.indices_1 = indices[1]
        # trace_model=True: scripting cannot be used with data-dependent nonzero indices here.
        self._test(*self.create_model(accumulate), ie_device, precision, ir_version, trace_model=True)