[PT FE] Support aten::randint and aten::index_put_ on mask (#19158)
* [PT FE] Support aten::randint and aten::index_put_ on mask * Fix code style
This commit is contained in:
parent
11610b2cc9
commit
b10b773f0c
@ -170,8 +170,6 @@ OutputVector translate_randn(const NodeContext& context) {
|
||||
sizes = concat_list_construct(sizes);
|
||||
}
|
||||
sizes = context.mark_node(std::make_shared<v0::Convert>(sizes, element::i32));
|
||||
auto low = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0}));
|
||||
auto high = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {1}));
|
||||
auto dtype = element::f32;
|
||||
size_t out_id = 1;
|
||||
if (context.get_input_size() == 3) {
|
||||
@ -202,8 +200,6 @@ OutputVector translate_randn(const NodeContext& context) {
|
||||
if (std::dynamic_pointer_cast<v0::Constant>(
|
||||
context.get_input_from_visible_context(dtype_id).get_node_shared_ptr())) {
|
||||
dtype = convert_dtype(context.const_input<int64_t>(dtype_id));
|
||||
low = context.mark_node(std::make_shared<v0::Convert>(low, dtype));
|
||||
high = context.mark_node(std::make_shared<v0::Convert>(low, dtype));
|
||||
} else if (const auto& fw_node =
|
||||
cast_fw_node(context.get_input(static_cast<int>(dtype_id)).get_node_shared_ptr(),
|
||||
"prim::dtype")) {
|
||||
@ -228,8 +224,6 @@ OutputVector translate_randn_like(const NodeContext& context) {
|
||||
num_inputs_check(context, 3, 6);
|
||||
auto inp_tensor = context.get_input(0);
|
||||
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(inp_tensor, element::i32));
|
||||
auto low = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0}));
|
||||
auto high = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {1}));
|
||||
auto dtype = element::f32;
|
||||
if (context.get_input_size() == 3) {
|
||||
auto res = make_random_normal(context, sizes, dtype);
|
||||
@ -259,6 +253,36 @@ OutputVector translate_randn_like(const NodeContext& context) {
|
||||
return res;
|
||||
};
|
||||
|
||||
OutputVector translate_randint(const NodeContext& context) {
|
||||
// aten::randint.low(int low, int high, SymInt[] size, *, ScalarType? dtype=4, Layout? layout=None, Device?
|
||||
// device=None, bool? pin_memory=None) -> Tensor
|
||||
num_inputs_check(context, 7, 7);
|
||||
auto low = context.get_input(0);
|
||||
auto high = context.get_input(1);
|
||||
auto sizes = context.get_input(2);
|
||||
auto dtype = element::i64;
|
||||
bool dtype_applied = true;
|
||||
Output<Node> convert_like_out;
|
||||
if (!context.input_is_none(3)) {
|
||||
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(3).get_node_shared_ptr())) {
|
||||
dtype = convert_dtype(context.const_input<int64_t>(3));
|
||||
} else if (const auto& fw_node =
|
||||
cast_fw_node(context.get_input(static_cast<int>(3)).get_node_shared_ptr(), "prim::dtype")) {
|
||||
convert_like_out = fw_node->input_value(0);
|
||||
dtype_applied = false;
|
||||
} else {
|
||||
FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
|
||||
}
|
||||
}
|
||||
low = context.mark_node(std::make_shared<v0::Convert>(low, dtype));
|
||||
high = context.mark_node(std::make_shared<v0::Convert>(high, dtype));
|
||||
auto res = context.mark_node(std::make_shared<v8::RandomUniform>(sizes, low, high, dtype));
|
||||
if (!dtype_applied) {
|
||||
res = context.mark_node(std::make_shared<v1::ConvertLike>(res, convert_like_out));
|
||||
}
|
||||
return {res};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
|
@ -125,6 +125,7 @@ OP_CONVERTER(translate_quantized_mul);
|
||||
OP_CONVERTER(translate_range_length);
|
||||
OP_CONVERTER(translate_rand);
|
||||
OP_CONVERTER(translate_randn);
|
||||
OP_CONVERTER(translate_randint);
|
||||
OP_CONVERTER(translate_rand_like);
|
||||
OP_CONVERTER(translate_randn_like);
|
||||
OP_CONVERTER(translate_reciprocal);
|
||||
@ -379,6 +380,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::quantize_per_tensor", op::translate_quantize_per_tensor},
|
||||
{"aten::rand", op::translate_rand},
|
||||
{"aten::randn", op::translate_randn},
|
||||
{"aten::randint", op::translate_randint},
|
||||
{"aten::rand_like", op::translate_rand_like},
|
||||
{"aten::randn_like", op::translate_randn_like},
|
||||
{"aten::reciprocal", op::translate_reciprocal},
|
||||
|
@ -13,10 +13,12 @@
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/mod.hpp"
|
||||
#include "openvino/op/non_zero.hpp"
|
||||
#include "openvino/op/scatter_nd_update.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "openvino/op/split.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "openvino/op/util/framework_node.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
@ -123,14 +125,32 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
|
||||
index = rg.make<v0::Concat>(indices_list, -1);
|
||||
} else {
|
||||
index = indices_inputs[0];
|
||||
// change negative indices to positive indices
|
||||
auto dim_0 = (rg.make<v8::Gather>(input_shape, const_0, const_0));
|
||||
auto dim_0_correct_type = (rg.make<v1::ConvertLike>(dim_0, index));
|
||||
index = rg.make<v1::Add>(index, dim_0_correct_type);
|
||||
index = rg.make<v1::Mod>(index, dim_0_correct_type);
|
||||
auto index_dtype = index.get_element_type();
|
||||
// Do we need to also check u8?
|
||||
if (index_dtype == element::boolean) {
|
||||
auto nonzero = rg.make<v3::NonZero>(index, element::i32);
|
||||
auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
|
||||
index = rg.make<v1::Transpose>(nonzero, input_order);
|
||||
broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
|
||||
auto start_0 = v0::Constant::create(element::i32, Shape{1}, {0});
|
||||
auto end_neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
|
||||
auto values_shape = rg.make<v8::Slice>(broadcast_index_shape, start_0, end_neg_1, const_1);
|
||||
values = rg.make<v3::Broadcast>(values, values_shape);
|
||||
values = rg.make<v1::ConvertLike>(values, input);
|
||||
auto result = rg.make<v3::ScatterNDUpdate>(input, index, values);
|
||||
copy_runtime_info_and_name(index_op, rg.get(), rt_copy_from);
|
||||
replace_node(index_op, result);
|
||||
return true;
|
||||
} else {
|
||||
// change negative indices to positive indices
|
||||
auto dim_0 = (rg.make<v8::Gather>(input_shape, const_0, const_0));
|
||||
auto dim_0_correct_type = (rg.make<v1::ConvertLike>(dim_0, index));
|
||||
index = rg.make<v1::Add>(index, dim_0_correct_type);
|
||||
index = rg.make<v1::Mod>(index, dim_0_correct_type);
|
||||
|
||||
broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
|
||||
index = rg.make<v0::Unsqueeze>(index, const_neg_1);
|
||||
broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
|
||||
index = rg.make<v0::Unsqueeze>(index, const_neg_1);
|
||||
}
|
||||
}
|
||||
|
||||
auto sub_data_shape = rg.make<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1);
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include "openvino/op/equal.hpp"
|
||||
#include "openvino/op/interpolate.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/random_uniform.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/roll.hpp"
|
||||
#include "openvino/op/select.hpp"
|
||||
@ -60,6 +61,8 @@ ListConstructReplacer::ListConstructReplacer() {
|
||||
auto interpolate_mul_op = pattern::wrap_type<v1::Multiply>({interpolate_convert_op, pattern::any_input()});
|
||||
auto interpolate_op =
|
||||
pattern::wrap_type<v11::Interpolate>({pattern::any_input(), interpolate_mul_op, pattern::any_input()});
|
||||
// aten::randint case
|
||||
auto rand_op = pattern::wrap_type<v8::RandomUniform>({list, pattern::any_input(), pattern::any_input()});
|
||||
auto lc_pattern = std::make_shared<pattern::op::Or>(OutputVector{reshape_op,
|
||||
roll_op,
|
||||
broadcast_op,
|
||||
@ -70,7 +73,8 @@ ListConstructReplacer::ListConstructReplacer() {
|
||||
tile_op,
|
||||
transpose_op,
|
||||
vsplit_op,
|
||||
interpolate_op});
|
||||
interpolate_op,
|
||||
rand_op});
|
||||
|
||||
ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto& pattern_map = m.get_pattern_value_map();
|
||||
|
@ -125,3 +125,29 @@ class TestIndexRange(PytorchLayerTest):
|
||||
def test_index_range_free_dims(self, input_shape, idx, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model2(), ie_device, precision, ir_version, kwargs_to_prepare_input={
|
||||
"input_shape": input_shape, "idx": idx}, trace_model=True, dynamic_shapes=False)
|
||||
|
||||
class TestIndexMask(PytorchLayerTest):
|
||||
def _prepare_input(self, input_shape):
|
||||
import numpy as np
|
||||
return (np.random.randn(*input_shape).astype(np.float32),)
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class aten_index_mask(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x[x > 0]
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_index_mask(), ref_net, "aten::index"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize(("input_shape"), ((1, 1),
|
||||
[2, 3],
|
||||
[7, 8, 9],
|
||||
[2, 2, 3, 4]))
|
||||
def test_index_mask(self, input_shape, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
|
||||
"input_shape": input_shape}, trace_model=True)
|
||||
|
@ -163,3 +163,22 @@ class TestNonZero_IndexPut(PytorchLayerTest):
|
||||
self.indices_0 = indices[0]
|
||||
self.indices_1 = indices[1]
|
||||
self._test(*self.create_model(accumulate), ie_device, precision, ir_version, trace_model=True)
|
||||
|
||||
class TestMask_IndexPut(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(100, 5).astype(np.float32),np.random.randn(100, 5).astype(np.float32))
|
||||
|
||||
def create_model(self):
|
||||
class aten_index_put_mask(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
x[x < 0] = y[x < 0]
|
||||
return x
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_index_put_mask(), ref_net, "aten::index_put_"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_nonzero_index_put_(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True)
|
||||
|
Loading…
Reference in New Issue
Block a user