[PT FE] Fix failing translation of aten::index_put_ (#16140)
* Initial commit * Fix for reading processed list * Format code * Cleanup * cleanup * Cleanup * cleanup test * Add comment * Add rt_info * fix type * Update src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp Co-authored-by: Maxim Vafin <maxim.vafin@intel.com> --------- Co-authored-by: Andrei Kochin <andrei.kochin@intel.com> Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
committed by
GitHub
parent
654f3d988f
commit
31489931cf
@@ -18,6 +18,7 @@
|
||||
#include "transforms/append_list_unpack_replacer.hpp"
|
||||
#include "transforms/aten_cat_replacer.hpp"
|
||||
#include "transforms/aten_getitem_replacer.hpp"
|
||||
#include "transforms/aten_index_put_replacer.hpp"
|
||||
#include "transforms/aten_index_replacer.hpp"
|
||||
#include "transforms/aten_stack_list_construct_replacer.hpp"
|
||||
#include "transforms/einsum_list_construct.hpp"
|
||||
@@ -100,6 +101,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AtenGetItemReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::ListConstructReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexPutReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::PrimListConstructPadReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();
|
||||
|
||||
@@ -3,19 +3,7 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/mod.hpp"
|
||||
#include "openvino/op/scatter_nd_update.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "openvino/op/split.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
#include "pt_framework_node.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
@@ -24,101 +12,13 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
namespace {
|
||||
Output<Node> generate_zeros_with_convertlike(const NodeContext& context,
|
||||
const Output<Node> sizes,
|
||||
const Output<Node> tensor_of_type) {
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto zeros = context.mark_node(std::make_shared<v3::Broadcast>(const_0, sizes));
|
||||
return context.mark_node(std::make_shared<v1::ConvertLike>(zeros, tensor_of_type));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_index_put_(NodeContext& context) {
|
||||
num_inputs_check(context, 4, 4);
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_max_int =
|
||||
context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()}));
|
||||
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
|
||||
|
||||
auto input = context.get_input(0);
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
|
||||
auto indices = context.get_input(1);
|
||||
auto values = context.get_input(2);
|
||||
auto accumulate = context.const_input<bool>(3);
|
||||
|
||||
auto indices_partial_shape = indices.get_partial_shape();
|
||||
FRONT_END_OP_CONVERSION_CHECK(indices_partial_shape.rank().is_static(),
|
||||
"We support only indices with static rank.");
|
||||
auto indices_first_dim = indices_partial_shape[0];
|
||||
FRONT_END_OP_CONVERSION_CHECK(indices_first_dim.is_static(),
|
||||
"We support only lists of tensors with static number of elements.");
|
||||
int64_t indices_list_len = indices_first_dim.get_length();
|
||||
if (indices_list_len == 0) {
|
||||
return {values};
|
||||
}
|
||||
|
||||
auto const_indices_list_len = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {indices_list_len}));
|
||||
auto split_indices = context.mark_node(std::make_shared<v1::Split>(indices, const_0, indices_list_len));
|
||||
|
||||
std::shared_ptr<Node> broadcast_index_shape;
|
||||
Output<Node> index;
|
||||
if (indices_list_len > 1) {
|
||||
index = split_indices->output(0);
|
||||
for (int i = 1; i < indices_list_len; i++) {
|
||||
index = context.mark_node(std::make_shared<v1::Add>(index, split_indices->output(i)));
|
||||
}
|
||||
broadcast_index_shape = context.mark_node(std::make_shared<v3::ShapeOf>(index, element::i32));
|
||||
OutputVector indices_list;
|
||||
for (int i = 0; i < indices_list_len; i++) {
|
||||
auto broadcast =
|
||||
context.mark_node(std::make_shared<v3::Broadcast>(split_indices->output(i), broadcast_index_shape));
|
||||
auto unsqueeze = context.mark_node(std::make_shared<v0::Unsqueeze>(broadcast, const_neg_1));
|
||||
|
||||
// change negative indices to positive indices
|
||||
auto const_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {i}));
|
||||
auto dim_i = context.mark_node(std::make_shared<v8::Gather>(input_shape, const_i, const_0));
|
||||
auto dim_i_correct_type = context.mark_node(std::make_shared<v1::ConvertLike>(dim_i, index));
|
||||
unsqueeze = context.mark_node(std::make_shared<v1::Add>(unsqueeze, dim_i_correct_type));
|
||||
unsqueeze = context.mark_node(std::make_shared<v1::Mod>(unsqueeze, dim_i_correct_type));
|
||||
|
||||
indices_list.push_back(unsqueeze);
|
||||
}
|
||||
index = context.mark_node(std::make_shared<v0::Concat>(indices_list, -1));
|
||||
} else {
|
||||
index = split_indices->output(0);
|
||||
|
||||
// change negative indices to positive indices
|
||||
auto dim_0 = context.mark_node(std::make_shared<v8::Gather>(input_shape, const_0, const_0));
|
||||
auto dim_0_correct_type = context.mark_node(std::make_shared<v1::ConvertLike>(dim_0, index));
|
||||
index = context.mark_node(std::make_shared<v1::Add>(index, dim_0_correct_type));
|
||||
index = context.mark_node(std::make_shared<v1::Mod>(index, dim_0_correct_type));
|
||||
|
||||
broadcast_index_shape = context.mark_node(std::make_shared<v3::ShapeOf>(index, element::i32));
|
||||
index = context.mark_node(std::make_shared<v0::Unsqueeze>(index, const_neg_1));
|
||||
}
|
||||
|
||||
auto sub_data_shape =
|
||||
context.mark_node(std::make_shared<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1));
|
||||
auto values_shape =
|
||||
context.mark_node(std::make_shared<v0::Concat>(OutputVector{broadcast_index_shape, sub_data_shape}, 0));
|
||||
values = context.mark_node(std::make_shared<v3::Broadcast>(values, values_shape));
|
||||
values = context.mark_node(std::make_shared<v1::ConvertLike>(values, input));
|
||||
|
||||
Output<Node> result;
|
||||
if (accumulate) {
|
||||
auto zeros = generate_zeros_with_convertlike(context, input_shape, input);
|
||||
result = context.mark_node(std::make_shared<v3::ScatterNDUpdate>(zeros, index, values));
|
||||
result = context.mark_node(std::make_shared<v1::Add>(input, result));
|
||||
} else {
|
||||
result = context.mark_node(std::make_shared<v3::ScatterNDUpdate>(input, index, values));
|
||||
}
|
||||
|
||||
return {result};
|
||||
// Pass as PtFrameworkNode to register as `inplace_op`. Conversion to OV operators is done as transformation.
|
||||
auto node = std::make_shared<PtFrameworkNode>(context.get_decoder(), context.inputs());
|
||||
return {context.mark_node(node)};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
} // namespace ov
|
||||
|
||||
158
src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp
Normal file
158
src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp
Normal file
@@ -0,0 +1,158 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "aten_index_put_replacer.hpp"
|
||||
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
#include "openvino/frontend/pytorch/visibility.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/mod.hpp"
|
||||
#include "openvino/op/scatter_nd_update.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "openvino/op/split.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "openvino/op/util/framework_node.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace pass {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
namespace {
|
||||
Output<Node> generate_zeros_with_convertlike(const Output<Node> sizes, const Output<Node> tensor_of_type) {
|
||||
auto const_0 = v0::Constant::create(element::i32, Shape{}, {0});
|
||||
auto zeros = std::make_shared<v3::Broadcast>(const_0, sizes);
|
||||
return std::make_shared<v1::ConvertLike>(zeros, tensor_of_type);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AtenIndexPutReplacer::AtenIndexPutReplacer() {
|
||||
auto index_op = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
|
||||
|
||||
ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
|
||||
auto index_op = cast_fw_node(m.get_match_root(), "aten::index_put_");
|
||||
if (!index_op) {
|
||||
return false;
|
||||
}
|
||||
NodeVector rt_copy_from{index_op};
|
||||
auto const_0 = v0::Constant::create(element::i32, Shape{}, {0});
|
||||
auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1});
|
||||
auto const_max_int = v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()});
|
||||
auto const_neg_1 = v0::Constant::create(element::i32, Shape{}, {-1});
|
||||
|
||||
auto input = index_op->input_value(0);
|
||||
auto input_shape = std::make_shared<v3::ShapeOf>(input, element::i32);
|
||||
auto indices = index_op->input_value(1);
|
||||
auto values = index_op->input_value(2);
|
||||
auto acc_const =
|
||||
std::dynamic_pointer_cast<ov::op::v0::Constant>(index_op->input_value(3).get_node_shared_ptr());
|
||||
if (!acc_const) {
|
||||
return false;
|
||||
}
|
||||
bool accumulate = acc_const->cast_vector<bool>()[0];
|
||||
|
||||
int64_t indices_list_len;
|
||||
OutputVector indices_inputs;
|
||||
if (auto listconstruct = cast_fw_node(indices.get_node_shared_ptr(), "prim::ListConstruct")) {
|
||||
rt_copy_from.push_back(listconstruct);
|
||||
indices_inputs = listconstruct->input_values();
|
||||
indices_list_len = indices_inputs.size();
|
||||
} else {
|
||||
auto indices_partial_shape = indices.get_partial_shape();
|
||||
if (!indices_partial_shape.rank().is_static()) {
|
||||
// "We support only indices with static rank."
|
||||
return false;
|
||||
}
|
||||
auto indices_first_dim = indices_partial_shape[0];
|
||||
if (!indices_first_dim.is_static()) {
|
||||
// We support only lists of tensors with static number of elements.
|
||||
return false;
|
||||
}
|
||||
indices_list_len = indices_first_dim.get_length();
|
||||
auto split = std::make_shared<v1::Split>(indices, const_0, indices_list_len);
|
||||
indices_inputs = split->outputs();
|
||||
}
|
||||
|
||||
if (indices_list_len == 0) {
|
||||
copy_runtime_info(rt_copy_from, values.get_node_shared_ptr());
|
||||
replace_node(index_op, values.get_node_shared_ptr());
|
||||
return true;
|
||||
}
|
||||
|
||||
auto const_indices_list_len = v0::Constant::create(element::i32, Shape{1}, {indices_list_len});
|
||||
|
||||
std::shared_ptr<Node> broadcast_index_shape;
|
||||
Output<Node> index;
|
||||
if (indices_list_len > 1) {
|
||||
index = indices_inputs[0];
|
||||
for (int i = 1; i < indices_list_len; i++) {
|
||||
index = std::make_shared<v1::Add>(index, indices_inputs[i]);
|
||||
}
|
||||
broadcast_index_shape = std::make_shared<v3::ShapeOf>(index, element::i32);
|
||||
OutputVector indices_list;
|
||||
for (int i = 0; i < indices_list_len; i++) {
|
||||
auto broadcast = std::make_shared<v3::Broadcast>(indices_inputs[i], broadcast_index_shape);
|
||||
auto unsqueeze = std::make_shared<v0::Unsqueeze>(broadcast, const_neg_1);
|
||||
|
||||
// change negative indices to positive indices
|
||||
auto const_i = v0::Constant::create(element::i32, Shape{}, {i});
|
||||
auto dim_i = std::make_shared<v8::Gather>(input_shape, const_i, const_0);
|
||||
auto dim_i_correct_type = std::make_shared<v1::ConvertLike>(dim_i, index);
|
||||
auto unsqueeze_add = std::make_shared<v1::Add>(unsqueeze, dim_i_correct_type);
|
||||
auto unsqueeze_add_mod = std::make_shared<v1::Mod>(unsqueeze_add, dim_i_correct_type);
|
||||
|
||||
indices_list.push_back(unsqueeze_add_mod);
|
||||
}
|
||||
index = std::make_shared<v0::Concat>(indices_list, -1);
|
||||
} else {
|
||||
index = indices_inputs[0];
|
||||
// change negative indices to positive indices
|
||||
auto dim_0 = (std::make_shared<v8::Gather>(input_shape, const_0, const_0));
|
||||
auto dim_0_correct_type = (std::make_shared<v1::ConvertLike>(dim_0, index));
|
||||
index = std::make_shared<v1::Add>(index, dim_0_correct_type);
|
||||
index = std::make_shared<v1::Mod>(index, dim_0_correct_type);
|
||||
|
||||
broadcast_index_shape = std::make_shared<v3::ShapeOf>(index, element::i32);
|
||||
index = std::make_shared<v0::Unsqueeze>(index, const_neg_1);
|
||||
}
|
||||
|
||||
auto sub_data_shape = std::make_shared<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1);
|
||||
auto values_shape = std::make_shared<v0::Concat>(OutputVector{broadcast_index_shape, sub_data_shape}, 0);
|
||||
values = std::make_shared<v3::Broadcast>(values, values_shape);
|
||||
values = std::make_shared<v1::ConvertLike>(values, input);
|
||||
|
||||
std::shared_ptr<ov::Node> result;
|
||||
if (accumulate) {
|
||||
auto zeros = generate_zeros_with_convertlike(input_shape, input);
|
||||
auto scatter = std::make_shared<v3::ScatterNDUpdate>(zeros, index, values);
|
||||
result = std::make_shared<v1::Add>(input, scatter);
|
||||
} else {
|
||||
result = std::make_shared<v3::ScatterNDUpdate>(input, index, values);
|
||||
}
|
||||
copy_runtime_info(rt_copy_from, result);
|
||||
replace_node(index_op, result);
|
||||
result->set_friendly_name(index_op->get_friendly_name());
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m =
|
||||
std::make_shared<ov::pass::pattern::Matcher>(index_op, "ov::frontend::pytorch::pass::AtenIndexPutReplacer");
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
} // namespace pass
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
@@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/frontend/pytorch/visibility.hpp"
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace pass {
|
||||
|
||||
class PYTORCH_API AtenIndexPutReplacer : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::frontend::pytorch::pass::AtenIndexPutReplacer");
|
||||
AtenIndexPutReplacer();
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import itertools
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from common.constants import test_device, test_precision
|
||||
@@ -51,7 +52,7 @@ class PytorchLayerTest:
|
||||
model = torch.jit.script(model)
|
||||
else:
|
||||
torch_inputs = [torch.from_numpy(inp) for inp in inputs]
|
||||
model = torch.jit.trace(model, torch_inputs)
|
||||
model = torch.jit.trace(model, deepcopy(torch_inputs))
|
||||
if kwargs.get('freeze_model', True):
|
||||
model = torch.jit.freeze(model)
|
||||
graph = model.inlined_graph
|
||||
@@ -91,14 +92,14 @@ class PytorchLayerTest:
|
||||
# OV infer:
|
||||
core = Core()
|
||||
compiled = core.compile_model(om, ie_device)
|
||||
infer_res = compiled(inputs)
|
||||
infer_res = compiled(deepcopy(inputs))
|
||||
|
||||
if hasattr(self, 'skip_framework') and self.skip_framework:
|
||||
warnings.warn('Framework is skipped')
|
||||
return
|
||||
|
||||
# Framework infer:
|
||||
fw_res = model(*torch_inps)
|
||||
fw_res = model(*deepcopy(torch_inps))
|
||||
|
||||
if not isinstance(fw_res, (tuple)):
|
||||
fw_res = (fw_res,)
|
||||
|
||||
@@ -14,9 +14,7 @@ class TestIndexPut_SingleIndices(PytorchLayerTest):
|
||||
return (self.input_tensor, self.values)
|
||||
|
||||
def create_model(self, indices, accumulate):
|
||||
|
||||
class aten_index_put_(torch.nn.Module):
|
||||
|
||||
def __init__(self, indices, accumulate):
|
||||
super().__init__()
|
||||
self.indices = indices
|
||||
@@ -30,31 +28,43 @@ class TestIndexPut_SingleIndices(PytorchLayerTest):
|
||||
|
||||
return aten_index_put_(indices, accumulate), ref_net, "aten::index_put_"
|
||||
|
||||
@pytest.mark.parametrize('input_data', ({'input_tensor': np.random.randn(5).astype(np.float32),
|
||||
'values': np.array(11).astype(np.float32)},
|
||||
{'input_tensor': np.random.randn(3, 3).astype(np.float32),
|
||||
'values': np.array([10, 11, 12]).astype(np.float32)}))
|
||||
@pytest.mark.parametrize('indices', (torch.tensor([0], dtype=torch.long),
|
||||
torch.tensor([-1, -2], dtype=torch.long),
|
||||
torch.tensor([0, -1, -2], dtype=torch.long),
|
||||
torch.tensor([1, 2], dtype=torch.long),
|
||||
torch.tensor([0, 1, 2], dtype=torch.long)))
|
||||
@pytest.mark.parametrize('accumulate', (True, False))
|
||||
@pytest.mark.parametrize(
|
||||
"input_data",
|
||||
(
|
||||
{
|
||||
"input_tensor": np.random.randn(5).astype(np.float32),
|
||||
"values": np.array(11).astype(np.float32)},
|
||||
{
|
||||
"input_tensor": np.random.randn(3, 3).astype(np.float32),
|
||||
"values": np.array([10, 11, 12]).astype(np.float32),
|
||||
},
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"indices",
|
||||
(
|
||||
torch.tensor([0], dtype=torch.long),
|
||||
torch.tensor([-1, -2], dtype=torch.long),
|
||||
torch.tensor([0, -1, -2], dtype=torch.long),
|
||||
torch.tensor([1, 2], dtype=torch.long),
|
||||
torch.tensor([0, 1, 2], dtype=torch.long),
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("accumulate", (True, False))
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_index_put_single_indices(self, ie_device, precision, ir_version, input_data, indices, accumulate):
|
||||
self.input_tensor = input_data['input_tensor']
|
||||
self.values = input_data['values']
|
||||
self.input_tensor = input_data["input_tensor"]
|
||||
self.values = input_data["values"]
|
||||
self._test(*self.create_model(indices, accumulate), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestIndexPut_ManyIndices(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (self.input_tensor, self.values)
|
||||
|
||||
def create_model(self, indices, accumulate):
|
||||
|
||||
class aten_index_put_(torch.nn.Module):
|
||||
|
||||
def __init__(self, indices, accumulate):
|
||||
super().__init__()
|
||||
self.indices_first = indices[0]
|
||||
@@ -69,26 +79,87 @@ class TestIndexPut_ManyIndices(PytorchLayerTest):
|
||||
|
||||
return aten_index_put_(indices, accumulate), ref_net, "aten::index_put_"
|
||||
|
||||
@pytest.mark.parametrize('input_data', ({'input_tensor': np.random.randn(3, 3).astype(np.float32),
|
||||
'values': np.array(12).astype(np.float32)},
|
||||
{'input_tensor': np.random.randn(3, 3, 3).astype(np.float32),
|
||||
'values': np.array([10, 11, 12]).astype(np.float32)},))
|
||||
@pytest.mark.parametrize('indices', ((torch.tensor([0], dtype=torch.long),
|
||||
torch.tensor([2], dtype=torch.long)),
|
||||
(torch.tensor([1, 2], dtype=torch.long),
|
||||
torch.tensor([0, 1], dtype=torch.long)),
|
||||
(torch.tensor([0, 1], dtype=torch.long),
|
||||
torch.tensor([0, 1], dtype=torch.long)),
|
||||
(torch.tensor([0], dtype=torch.long),
|
||||
torch.tensor([-2], dtype=torch.long)),
|
||||
(torch.tensor([-1, -2], dtype=torch.long),
|
||||
torch.tensor([0, 1], dtype=torch.long)),
|
||||
(torch.tensor([0, -1], dtype=torch.long),
|
||||
torch.tensor([0, -1], dtype=torch.long))))
|
||||
@pytest.mark.parametrize('accumulate', (True, False))
|
||||
@pytest.mark.parametrize(
|
||||
"input_data",
|
||||
(
|
||||
{
|
||||
"input_tensor": np.random.randn(3, 3).astype(np.float32),
|
||||
"values": np.array(12).astype(np.float32)
|
||||
},
|
||||
{
|
||||
"input_tensor": np.random.randn(3, 3, 3).astype(np.float32),
|
||||
"values": np.array([10, 11, 12]).astype(np.float32),
|
||||
},
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"indices",
|
||||
(
|
||||
(torch.tensor([0], dtype=torch.long), torch.tensor([2], dtype=torch.long)),
|
||||
(torch.tensor([1, 2], dtype=torch.long), torch.tensor([0, 1], dtype=torch.long)),
|
||||
(torch.tensor([0, 1], dtype=torch.long), torch.tensor([0, 1], dtype=torch.long)),
|
||||
(torch.tensor([0], dtype=torch.long), torch.tensor([-2], dtype=torch.long)),
|
||||
(torch.tensor([-1, -2], dtype=torch.long), torch.tensor([0, 1], dtype=torch.long)),
|
||||
(torch.tensor([0, -1], dtype=torch.long), torch.tensor([0, -1], dtype=torch.long)),
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("accumulate", (True, False))
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_index_put_many_indices(self, ie_device, precision, ir_version, input_data, indices, accumulate):
|
||||
self.input_tensor = input_data['input_tensor']
|
||||
self.values = input_data['values']
|
||||
self._test(*self.create_model(indices, accumulate), ie_device, precision, ir_version)
|
||||
self.input_tensor = input_data["input_tensor"]
|
||||
self.values = input_data["values"]
|
||||
self._test(*self.create_model(indices, accumulate), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestNonZero_IndexPut(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (self.input_tensor, self.values, self.indices_0, self.indices_1)
|
||||
|
||||
def create_model(self, accumulate):
|
||||
class aten_index_put_(torch.nn.Module):
|
||||
def __init__(self, accumulate):
|
||||
super().__init__()
|
||||
self.accumulate = accumulate
|
||||
|
||||
def forward(self, input_tensor, values, indices_0, indices_1):
|
||||
nonzero = (indices_0 == indices_1).nonzero(as_tuple=True)[0]
|
||||
input_tensor.index_put_((nonzero,), values, self.accumulate)
|
||||
return input_tensor
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_index_put_(accumulate), ref_net, "aten::index_put_"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_data",
|
||||
(
|
||||
{
|
||||
"input_tensor": np.random.randn(3).astype(np.float32),
|
||||
"values": np.array(11).astype(np.float32),
|
||||
},
|
||||
{
|
||||
"input_tensor": np.random.randn(3, 3).astype(np.float32),
|
||||
"values": np.array([10, 11, 12]).astype(np.float32),
|
||||
},
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"indices",
|
||||
(
|
||||
(np.random.randint(low=0, high=2, size=(1,)), np.random.randint(low=0, high=2, size=(1,))),
|
||||
(np.random.randint(low=0, high=2, size=(2,)), np.random.randint(low=0, high=2, size=(2,))),
|
||||
(np.array([0, 1, 0]), np.array([1, 1, 0])),
|
||||
(np.ones(shape=(3,)), np.ones(shape=(3,))),
|
||||
(np.ones(shape=(3,)), np.zeros(shape=(3,))),
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("accumulate", (False, True))
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_nonzero_index_put_(self, ie_device, precision, ir_version, input_data, indices, accumulate):
|
||||
self.input_tensor = input_data["input_tensor"]
|
||||
self.values = input_data["values"]
|
||||
self.indices_0 = indices[0]
|
||||
self.indices_1 = indices[1]
|
||||
self._test(*self.create_model(accumulate), ie_device, precision, ir_version, trace_model=True)
|
||||
|
||||
Reference in New Issue
Block a user