[PT FE] Support aten::randint and aten::index_put_ on mask (#19158)
* [PT FE] Support aten::randint and aten::index_put_ on mask * Fix code style
This commit is contained in:
parent
11610b2cc9
commit
b10b773f0c
@ -170,8 +170,6 @@ OutputVector translate_randn(const NodeContext& context) {
|
|||||||
sizes = concat_list_construct(sizes);
|
sizes = concat_list_construct(sizes);
|
||||||
}
|
}
|
||||||
sizes = context.mark_node(std::make_shared<v0::Convert>(sizes, element::i32));
|
sizes = context.mark_node(std::make_shared<v0::Convert>(sizes, element::i32));
|
||||||
auto low = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0}));
|
|
||||||
auto high = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {1}));
|
|
||||||
auto dtype = element::f32;
|
auto dtype = element::f32;
|
||||||
size_t out_id = 1;
|
size_t out_id = 1;
|
||||||
if (context.get_input_size() == 3) {
|
if (context.get_input_size() == 3) {
|
||||||
@ -202,8 +200,6 @@ OutputVector translate_randn(const NodeContext& context) {
|
|||||||
if (std::dynamic_pointer_cast<v0::Constant>(
|
if (std::dynamic_pointer_cast<v0::Constant>(
|
||||||
context.get_input_from_visible_context(dtype_id).get_node_shared_ptr())) {
|
context.get_input_from_visible_context(dtype_id).get_node_shared_ptr())) {
|
||||||
dtype = convert_dtype(context.const_input<int64_t>(dtype_id));
|
dtype = convert_dtype(context.const_input<int64_t>(dtype_id));
|
||||||
low = context.mark_node(std::make_shared<v0::Convert>(low, dtype));
|
|
||||||
high = context.mark_node(std::make_shared<v0::Convert>(low, dtype));
|
|
||||||
} else if (const auto& fw_node =
|
} else if (const auto& fw_node =
|
||||||
cast_fw_node(context.get_input(static_cast<int>(dtype_id)).get_node_shared_ptr(),
|
cast_fw_node(context.get_input(static_cast<int>(dtype_id)).get_node_shared_ptr(),
|
||||||
"prim::dtype")) {
|
"prim::dtype")) {
|
||||||
@ -228,8 +224,6 @@ OutputVector translate_randn_like(const NodeContext& context) {
|
|||||||
num_inputs_check(context, 3, 6);
|
num_inputs_check(context, 3, 6);
|
||||||
auto inp_tensor = context.get_input(0);
|
auto inp_tensor = context.get_input(0);
|
||||||
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(inp_tensor, element::i32));
|
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(inp_tensor, element::i32));
|
||||||
auto low = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0}));
|
|
||||||
auto high = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {1}));
|
|
||||||
auto dtype = element::f32;
|
auto dtype = element::f32;
|
||||||
if (context.get_input_size() == 3) {
|
if (context.get_input_size() == 3) {
|
||||||
auto res = make_random_normal(context, sizes, dtype);
|
auto res = make_random_normal(context, sizes, dtype);
|
||||||
@ -259,6 +253,36 @@ OutputVector translate_randn_like(const NodeContext& context) {
|
|||||||
return res;
|
return res;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
OutputVector translate_randint(const NodeContext& context) {
|
||||||
|
// aten::randint.low(int low, int high, SymInt[] size, *, ScalarType? dtype=4, Layout? layout=None, Device?
|
||||||
|
// device=None, bool? pin_memory=None) -> Tensor
|
||||||
|
num_inputs_check(context, 7, 7);
|
||||||
|
auto low = context.get_input(0);
|
||||||
|
auto high = context.get_input(1);
|
||||||
|
auto sizes = context.get_input(2);
|
||||||
|
auto dtype = element::i64;
|
||||||
|
bool dtype_applied = true;
|
||||||
|
Output<Node> convert_like_out;
|
||||||
|
if (!context.input_is_none(3)) {
|
||||||
|
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(3).get_node_shared_ptr())) {
|
||||||
|
dtype = convert_dtype(context.const_input<int64_t>(3));
|
||||||
|
} else if (const auto& fw_node =
|
||||||
|
cast_fw_node(context.get_input(static_cast<int>(3)).get_node_shared_ptr(), "prim::dtype")) {
|
||||||
|
convert_like_out = fw_node->input_value(0);
|
||||||
|
dtype_applied = false;
|
||||||
|
} else {
|
||||||
|
FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
low = context.mark_node(std::make_shared<v0::Convert>(low, dtype));
|
||||||
|
high = context.mark_node(std::make_shared<v0::Convert>(high, dtype));
|
||||||
|
auto res = context.mark_node(std::make_shared<v8::RandomUniform>(sizes, low, high, dtype));
|
||||||
|
if (!dtype_applied) {
|
||||||
|
res = context.mark_node(std::make_shared<v1::ConvertLike>(res, convert_like_out));
|
||||||
|
}
|
||||||
|
return {res};
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
} // namespace pytorch
|
} // namespace pytorch
|
||||||
} // namespace frontend
|
} // namespace frontend
|
||||||
|
@ -125,6 +125,7 @@ OP_CONVERTER(translate_quantized_mul);
|
|||||||
OP_CONVERTER(translate_range_length);
|
OP_CONVERTER(translate_range_length);
|
||||||
OP_CONVERTER(translate_rand);
|
OP_CONVERTER(translate_rand);
|
||||||
OP_CONVERTER(translate_randn);
|
OP_CONVERTER(translate_randn);
|
||||||
|
OP_CONVERTER(translate_randint);
|
||||||
OP_CONVERTER(translate_rand_like);
|
OP_CONVERTER(translate_rand_like);
|
||||||
OP_CONVERTER(translate_randn_like);
|
OP_CONVERTER(translate_randn_like);
|
||||||
OP_CONVERTER(translate_reciprocal);
|
OP_CONVERTER(translate_reciprocal);
|
||||||
@ -379,6 +380,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
|||||||
{"aten::quantize_per_tensor", op::translate_quantize_per_tensor},
|
{"aten::quantize_per_tensor", op::translate_quantize_per_tensor},
|
||||||
{"aten::rand", op::translate_rand},
|
{"aten::rand", op::translate_rand},
|
||||||
{"aten::randn", op::translate_randn},
|
{"aten::randn", op::translate_randn},
|
||||||
|
{"aten::randint", op::translate_randint},
|
||||||
{"aten::rand_like", op::translate_rand_like},
|
{"aten::rand_like", op::translate_rand_like},
|
||||||
{"aten::randn_like", op::translate_randn_like},
|
{"aten::randn_like", op::translate_randn_like},
|
||||||
{"aten::reciprocal", op::translate_reciprocal},
|
{"aten::reciprocal", op::translate_reciprocal},
|
||||||
|
@ -13,10 +13,12 @@
|
|||||||
#include "openvino/op/convert_like.hpp"
|
#include "openvino/op/convert_like.hpp"
|
||||||
#include "openvino/op/gather.hpp"
|
#include "openvino/op/gather.hpp"
|
||||||
#include "openvino/op/mod.hpp"
|
#include "openvino/op/mod.hpp"
|
||||||
|
#include "openvino/op/non_zero.hpp"
|
||||||
#include "openvino/op/scatter_nd_update.hpp"
|
#include "openvino/op/scatter_nd_update.hpp"
|
||||||
#include "openvino/op/shape_of.hpp"
|
#include "openvino/op/shape_of.hpp"
|
||||||
#include "openvino/op/slice.hpp"
|
#include "openvino/op/slice.hpp"
|
||||||
#include "openvino/op/split.hpp"
|
#include "openvino/op/split.hpp"
|
||||||
|
#include "openvino/op/transpose.hpp"
|
||||||
#include "openvino/op/unsqueeze.hpp"
|
#include "openvino/op/unsqueeze.hpp"
|
||||||
#include "openvino/op/util/framework_node.hpp"
|
#include "openvino/op/util/framework_node.hpp"
|
||||||
#include "openvino/pass/pattern/matcher.hpp"
|
#include "openvino/pass/pattern/matcher.hpp"
|
||||||
@ -123,14 +125,32 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
|
|||||||
index = rg.make<v0::Concat>(indices_list, -1);
|
index = rg.make<v0::Concat>(indices_list, -1);
|
||||||
} else {
|
} else {
|
||||||
index = indices_inputs[0];
|
index = indices_inputs[0];
|
||||||
// change negative indices to positive indices
|
auto index_dtype = index.get_element_type();
|
||||||
auto dim_0 = (rg.make<v8::Gather>(input_shape, const_0, const_0));
|
// Do we need to also check u8?
|
||||||
auto dim_0_correct_type = (rg.make<v1::ConvertLike>(dim_0, index));
|
if (index_dtype == element::boolean) {
|
||||||
index = rg.make<v1::Add>(index, dim_0_correct_type);
|
auto nonzero = rg.make<v3::NonZero>(index, element::i32);
|
||||||
index = rg.make<v1::Mod>(index, dim_0_correct_type);
|
auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
|
||||||
|
index = rg.make<v1::Transpose>(nonzero, input_order);
|
||||||
|
broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
|
||||||
|
auto start_0 = v0::Constant::create(element::i32, Shape{1}, {0});
|
||||||
|
auto end_neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
|
||||||
|
auto values_shape = rg.make<v8::Slice>(broadcast_index_shape, start_0, end_neg_1, const_1);
|
||||||
|
values = rg.make<v3::Broadcast>(values, values_shape);
|
||||||
|
values = rg.make<v1::ConvertLike>(values, input);
|
||||||
|
auto result = rg.make<v3::ScatterNDUpdate>(input, index, values);
|
||||||
|
copy_runtime_info_and_name(index_op, rg.get(), rt_copy_from);
|
||||||
|
replace_node(index_op, result);
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
// change negative indices to positive indices
|
||||||
|
auto dim_0 = (rg.make<v8::Gather>(input_shape, const_0, const_0));
|
||||||
|
auto dim_0_correct_type = (rg.make<v1::ConvertLike>(dim_0, index));
|
||||||
|
index = rg.make<v1::Add>(index, dim_0_correct_type);
|
||||||
|
index = rg.make<v1::Mod>(index, dim_0_correct_type);
|
||||||
|
|
||||||
broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
|
broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
|
||||||
index = rg.make<v0::Unsqueeze>(index, const_neg_1);
|
index = rg.make<v0::Unsqueeze>(index, const_neg_1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto sub_data_shape = rg.make<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1);
|
auto sub_data_shape = rg.make<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1);
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
#include "openvino/op/equal.hpp"
|
#include "openvino/op/equal.hpp"
|
||||||
#include "openvino/op/interpolate.hpp"
|
#include "openvino/op/interpolate.hpp"
|
||||||
#include "openvino/op/multiply.hpp"
|
#include "openvino/op/multiply.hpp"
|
||||||
|
#include "openvino/op/random_uniform.hpp"
|
||||||
#include "openvino/op/reshape.hpp"
|
#include "openvino/op/reshape.hpp"
|
||||||
#include "openvino/op/roll.hpp"
|
#include "openvino/op/roll.hpp"
|
||||||
#include "openvino/op/select.hpp"
|
#include "openvino/op/select.hpp"
|
||||||
@ -60,6 +61,8 @@ ListConstructReplacer::ListConstructReplacer() {
|
|||||||
auto interpolate_mul_op = pattern::wrap_type<v1::Multiply>({interpolate_convert_op, pattern::any_input()});
|
auto interpolate_mul_op = pattern::wrap_type<v1::Multiply>({interpolate_convert_op, pattern::any_input()});
|
||||||
auto interpolate_op =
|
auto interpolate_op =
|
||||||
pattern::wrap_type<v11::Interpolate>({pattern::any_input(), interpolate_mul_op, pattern::any_input()});
|
pattern::wrap_type<v11::Interpolate>({pattern::any_input(), interpolate_mul_op, pattern::any_input()});
|
||||||
|
// aten::randint case
|
||||||
|
auto rand_op = pattern::wrap_type<v8::RandomUniform>({list, pattern::any_input(), pattern::any_input()});
|
||||||
auto lc_pattern = std::make_shared<pattern::op::Or>(OutputVector{reshape_op,
|
auto lc_pattern = std::make_shared<pattern::op::Or>(OutputVector{reshape_op,
|
||||||
roll_op,
|
roll_op,
|
||||||
broadcast_op,
|
broadcast_op,
|
||||||
@ -70,7 +73,8 @@ ListConstructReplacer::ListConstructReplacer() {
|
|||||||
tile_op,
|
tile_op,
|
||||||
transpose_op,
|
transpose_op,
|
||||||
vsplit_op,
|
vsplit_op,
|
||||||
interpolate_op});
|
interpolate_op,
|
||||||
|
rand_op});
|
||||||
|
|
||||||
ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||||
auto& pattern_map = m.get_pattern_value_map();
|
auto& pattern_map = m.get_pattern_value_map();
|
||||||
|
@ -125,3 +125,29 @@ class TestIndexRange(PytorchLayerTest):
|
|||||||
def test_index_range_free_dims(self, input_shape, idx, ie_device, precision, ir_version):
|
def test_index_range_free_dims(self, input_shape, idx, ie_device, precision, ir_version):
|
||||||
self._test(*self.create_model2(), ie_device, precision, ir_version, kwargs_to_prepare_input={
|
self._test(*self.create_model2(), ie_device, precision, ir_version, kwargs_to_prepare_input={
|
||||||
"input_shape": input_shape, "idx": idx}, trace_model=True, dynamic_shapes=False)
|
"input_shape": input_shape, "idx": idx}, trace_model=True, dynamic_shapes=False)
|
||||||
|
|
||||||
|
class TestIndexMask(PytorchLayerTest):
|
||||||
|
def _prepare_input(self, input_shape):
|
||||||
|
import numpy as np
|
||||||
|
return (np.random.randn(*input_shape).astype(np.float32),)
|
||||||
|
|
||||||
|
def create_model(self):
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class aten_index_mask(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x[x > 0]
|
||||||
|
|
||||||
|
ref_net = None
|
||||||
|
|
||||||
|
return aten_index_mask(), ref_net, "aten::index"
|
||||||
|
|
||||||
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit
|
||||||
|
@pytest.mark.parametrize(("input_shape"), ((1, 1),
|
||||||
|
[2, 3],
|
||||||
|
[7, 8, 9],
|
||||||
|
[2, 2, 3, 4]))
|
||||||
|
def test_index_mask(self, input_shape, ie_device, precision, ir_version):
|
||||||
|
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
|
||||||
|
"input_shape": input_shape}, trace_model=True)
|
||||||
|
@ -163,3 +163,22 @@ class TestNonZero_IndexPut(PytorchLayerTest):
|
|||||||
self.indices_0 = indices[0]
|
self.indices_0 = indices[0]
|
||||||
self.indices_1 = indices[1]
|
self.indices_1 = indices[1]
|
||||||
self._test(*self.create_model(accumulate), ie_device, precision, ir_version, trace_model=True)
|
self._test(*self.create_model(accumulate), ie_device, precision, ir_version, trace_model=True)
|
||||||
|
|
||||||
|
class TestMask_IndexPut(PytorchLayerTest):
|
||||||
|
def _prepare_input(self):
|
||||||
|
return (np.random.randn(100, 5).astype(np.float32),np.random.randn(100, 5).astype(np.float32))
|
||||||
|
|
||||||
|
def create_model(self):
|
||||||
|
class aten_index_put_mask(torch.nn.Module):
|
||||||
|
def forward(self, x, y):
|
||||||
|
x[x < 0] = y[x < 0]
|
||||||
|
return x
|
||||||
|
|
||||||
|
ref_net = None
|
||||||
|
|
||||||
|
return aten_index_put_mask(), ref_net, "aten::index_put_"
|
||||||
|
|
||||||
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit
|
||||||
|
def test_nonzero_index_put_(self, ie_device, precision, ir_version):
|
||||||
|
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user