From b10b773f0c02bef3a71b60beb497317db98d7a68 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Mon, 14 Aug 2023 15:08:08 +0200
Subject: [PATCH] [PT FE] Support aten::randint and aten::index_put_ on mask
 (#19158)

* [PT FE] Support aten::randint and aten::index_put_ on mask

* Fix code style
---
 src/frontends/pytorch/src/op/rand.cpp         | 36 +++++++++++++++----
 src/frontends/pytorch/src/op_table.cpp        |  2 ++
 .../transforms/aten_index_put_replacer.cpp    | 34 ++++++++++++++----
 .../src/transforms/listconstruct_replacer.cpp |  6 +++-
 tests/layer_tests/pytorch_tests/test_index.py | 26 ++++++++++++++
 .../pytorch_tests/test_index_put_.py          | 19 ++++++++++
 6 files changed, 109 insertions(+), 14 deletions(-)

diff --git a/src/frontends/pytorch/src/op/rand.cpp b/src/frontends/pytorch/src/op/rand.cpp
index 152c1a98701..f0c5308d7f2 100644
--- a/src/frontends/pytorch/src/op/rand.cpp
+++ b/src/frontends/pytorch/src/op/rand.cpp
@@ -170,8 +170,6 @@ OutputVector translate_randn(const NodeContext& context) {
         sizes = concat_list_construct(sizes);
     }
     sizes = context.mark_node(std::make_shared<v0::Convert>(sizes, element::i32));
-    auto low = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0}));
-    auto high = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {1}));
     auto dtype = element::f32;
     size_t out_id = 1;
     if (context.get_input_size() == 3) {
@@ -202,8 +200,6 @@
     if (std::dynamic_pointer_cast<v0::Constant>(
             context.get_input_from_visible_context(dtype_id).get_node_shared_ptr())) {
         dtype = convert_dtype(context.const_input<int64_t>(dtype_id));
-        low = context.mark_node(std::make_shared<v0::Convert>(low, dtype));
-        high = context.mark_node(std::make_shared<v0::Convert>(low, dtype));
     } else if (const auto& fw_node =
                    cast_fw_node(context.get_input(static_cast<int>(dtype_id)).get_node_shared_ptr(),
                                 "prim::dtype")) {
@@ -228,8 +224,6 @@ OutputVector translate_randn_like(const NodeContext& context) {
     num_inputs_check(context, 3, 6);
     auto inp_tensor = context.get_input(0);
     auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(inp_tensor, element::i32));
-    auto low = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {0}));
-    auto high = context.mark_node(v0::Constant::create(element::f32, Shape{1}, {1}));
     auto dtype = element::f32;
     if (context.get_input_size() == 3) {
         auto res = make_random_normal(context, sizes, dtype);
@@ -259,6 +253,36 @@ OutputVector translate_randn_like(const NodeContext& context) {
     return res;
 };
 
+OutputVector translate_randint(const NodeContext& context) {
+    // aten::randint.low(int low, int high, SymInt[] size, *, ScalarType? dtype=4, Layout? layout=None, Device?
+    // device=None, bool? pin_memory=None) -> Tensor
+    num_inputs_check(context, 7, 7);
+    auto low = context.get_input(0);
+    auto high = context.get_input(1);
+    auto sizes = context.get_input(2);
+    auto dtype = element::i64;
+    bool dtype_applied = true;
+    Output<Node> convert_like_out;
+    if (!context.input_is_none(3)) {
+        if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(3).get_node_shared_ptr())) {
+            dtype = convert_dtype(context.const_input<int64_t>(3));
+        } else if (const auto& fw_node =
+                       cast_fw_node(context.get_input(static_cast<int>(3)).get_node_shared_ptr(), "prim::dtype")) {
+            convert_like_out = fw_node->input_value(0);
+            dtype_applied = false;
+        } else {
+            FRONT_END_OP_CONVERSION_CHECK(false, "Couldn't get dtype input");
+        }
+    }
+    low = context.mark_node(std::make_shared<v0::Convert>(low, dtype));
+    high = context.mark_node(std::make_shared<v0::Convert>(high, dtype));
+    auto res = context.mark_node(std::make_shared<v8::RandomUniform>(sizes, low, high, dtype));
+    if (!dtype_applied) {
+        res = context.mark_node(std::make_shared<v1::ConvertLike>(res, convert_like_out));
+    }
+    return {res};
+};
+
 }  // namespace op
 }  // namespace pytorch
 }  // namespace frontend
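For orientation, here is a minimal PyTorch sketch (illustrative, not part of the patch; `RandIntModel` is a hypothetical name) that lowers to the `aten::randint.low` schema handled by `translate_randint` above:

```python
import torch

class RandIntModel(torch.nn.Module):
    def forward(self):
        # Traced as aten::randint.low with 7 inputs (low, high, size, dtype,
        # layout, device, pin_memory); the converter maps it to a RandomUniform
        # over [low, high) with the default i64 dtype.
        return torch.randint(low=0, high=10, size=(2, 3))

traced = torch.jit.trace(RandIntModel(), ())
print(traced.graph)  # the graph contains the aten::randint call
```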
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 6ba7e70ab73..fbc13a99447 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -125,6 +125,7 @@ OP_CONVERTER(translate_quantized_mul);
 OP_CONVERTER(translate_range_length);
 OP_CONVERTER(translate_rand);
 OP_CONVERTER(translate_randn);
+OP_CONVERTER(translate_randint);
 OP_CONVERTER(translate_rand_like);
 OP_CONVERTER(translate_randn_like);
 OP_CONVERTER(translate_reciprocal);
@@ -379,6 +380,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
        {"aten::quantize_per_tensor", op::translate_quantize_per_tensor},
        {"aten::rand", op::translate_rand},
        {"aten::randn", op::translate_randn},
+       {"aten::randint", op::translate_randint},
        {"aten::rand_like", op::translate_rand_like},
        {"aten::randn_like", op::translate_randn_like},
        {"aten::reciprocal", op::translate_reciprocal},
diff --git a/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp
index c73767840b1..3959e2383a8 100644
--- a/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp
@@ -13,10 +13,12 @@
 #include "openvino/op/convert_like.hpp"
 #include "openvino/op/gather.hpp"
 #include "openvino/op/mod.hpp"
+#include "openvino/op/non_zero.hpp"
 #include "openvino/op/scatter_nd_update.hpp"
 #include "openvino/op/shape_of.hpp"
 #include "openvino/op/slice.hpp"
 #include "openvino/op/split.hpp"
+#include "openvino/op/transpose.hpp"
 #include "openvino/op/unsqueeze.hpp"
 #include "openvino/op/util/framework_node.hpp"
 #include "openvino/pass/pattern/matcher.hpp"
@@ -123,14 +125,32 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
             index = rg.make<v0::Concat>(indices_list, -1);
         } else {
             index = indices_inputs[0];
-            // change negative indices to positive indices
-            auto dim_0 = (rg.make<v8::Gather>(input_shape, const_0, const_0));
-            auto dim_0_correct_type = (rg.make<v1::ConvertLike>(dim_0, index));
-            index = rg.make<v1::Add>(index, dim_0_correct_type);
-            index = rg.make<v1::Mod>(index, dim_0_correct_type);
+            auto index_dtype = index.get_element_type();
+            // Do we need to also check u8?
+            if (index_dtype == element::boolean) {
+                auto nonzero = rg.make<v3::NonZero>(index, element::i32);
+                auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
+                index = rg.make<v1::Transpose>(nonzero, input_order);
+                broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
+                auto start_0 = v0::Constant::create(element::i32, Shape{1}, {0});
+                auto end_neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
+                auto values_shape = rg.make<v8::Slice>(broadcast_index_shape, start_0, end_neg_1, const_1);
+                values = rg.make<v3::Broadcast>(values, values_shape);
+                values = rg.make<v1::ConvertLike>(values, input);
+                auto result = rg.make<v3::ScatterNDUpdate>(input, index, values);
+                copy_runtime_info_and_name(index_op, rg.get(), rt_copy_from);
+                replace_node(index_op, result);
+                return true;
+            } else {
+                // change negative indices to positive indices
+                auto dim_0 = (rg.make<v8::Gather>(input_shape, const_0, const_0));
+                auto dim_0_correct_type = (rg.make<v1::ConvertLike>(dim_0, index));
+                index = rg.make<v1::Add>(index, dim_0_correct_type);
+                index = rg.make<v1::Mod>(index, dim_0_correct_type);
 
-            broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
-            index = rg.make<v0::Unsqueeze>(index, const_neg_1);
+                broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
+                index = rg.make<v0::Unsqueeze>(index, const_neg_1);
+            }
         }
         auto sub_data_shape = rg.make<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1);
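To make the new boolean branch easier to follow, here is a NumPy sketch of the decomposition it builds (hypothetical helper; the op names in the comments refer to the OpenVINO nodes created above): NonZero plus Transpose turn the mask into an `[N, rank]` coordinate list, a Slice of its shape gives the `[N]` values shape, Broadcast and ConvertLike align the update values, and ScatterNDUpdate writes them back.

```python
import numpy as np

def index_put_mask_decomposition(data, mask, values):
    # NonZero + Transpose: boolean mask -> [N, rank] coordinate list
    index = np.transpose(np.nonzero(mask))           # shape [N, rank]
    # ShapeOf + Slice(0:-1): the update values must have shape [N]
    values_shape = index.shape[:-1]
    # Broadcast + ConvertLike: align values with the selected elements
    updates = np.broadcast_to(values, values_shape).astype(data.dtype)
    # ScatterNDUpdate: write the updates at the gathered coordinates
    out = data.copy()
    out[tuple(index.T)] = updates
    return out

x = np.random.randn(4, 3).astype(np.float32)
y = np.random.randn(4, 3).astype(np.float32)
expected = x.copy()
expected[x < 0] = y[x < 0]
assert np.allclose(index_put_mask_decomposition(x, x < 0, y[x < 0]), expected)
```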
diff --git a/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp b/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp
index 72c7d620592..e223c5a73ce 100644
--- a/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp
@@ -13,6 +13,7 @@
 #include "openvino/op/equal.hpp"
 #include "openvino/op/interpolate.hpp"
 #include "openvino/op/multiply.hpp"
+#include "openvino/op/random_uniform.hpp"
 #include "openvino/op/reshape.hpp"
 #include "openvino/op/roll.hpp"
 #include "openvino/op/select.hpp"
@@ -60,6 +61,8 @@ ListConstructReplacer::ListConstructReplacer() {
     auto interpolate_mul_op = pattern::wrap_type<v1::Multiply>({interpolate_convert_op, pattern::any_input()});
     auto interpolate_op =
         pattern::wrap_type<v11::Interpolate>({pattern::any_input(), interpolate_mul_op, pattern::any_input()});
+    // aten::randint case
+    auto rand_op = pattern::wrap_type<v8::RandomUniform>({list, pattern::any_input(), pattern::any_input()});
     auto lc_pattern = std::make_shared<pattern::op::Or>(OutputVector{reshape_op,
                                                                      roll_op,
                                                                      broadcast_op,
@@ -70,7 +73,8 @@ ListConstructReplacer::ListConstructReplacer() {
                                                                      tile_op,
                                                                      transpose_op,
                                                                      vsplit_op,
-                                                                     interpolate_op});
+                                                                     interpolate_op,
+                                                                     rand_op});
 
     ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
         auto& pattern_map = m.get_pattern_value_map();
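The new RandomUniform pattern matters because the traced size list of `aten::randint` arrives as a `prim::ListConstruct` feeding the shape input of the converted node. A sketch of a module that hits this path (illustrative code, not from the patch):

```python
import torch

class DynamicRandInt(torch.nn.Module):
    def forward(self, x):
        # The (x.shape[0], x.shape[1]) tuple is traced as a prim::ListConstruct
        # feeding aten::randint; after conversion it feeds RandomUniform's shape
        # input, which is why ListConstructReplacer now matches that consumer too.
        return torch.randint(0, 5, (x.shape[0], x.shape[1]))

traced = torch.jit.trace(DynamicRandInt(), torch.zeros(2, 3))
```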
diff --git a/tests/layer_tests/pytorch_tests/test_index.py b/tests/layer_tests/pytorch_tests/test_index.py
index f3f681d1c3c..6f7cea86990 100644
--- a/tests/layer_tests/pytorch_tests/test_index.py
+++ b/tests/layer_tests/pytorch_tests/test_index.py
@@ -125,3 +125,29 @@ class TestIndexRange(PytorchLayerTest):
     def test_index_range_free_dims(self, input_shape, idx, ie_device, precision, ir_version):
         self._test(*self.create_model2(), ie_device, precision, ir_version, kwargs_to_prepare_input={
                    "input_shape": input_shape, "idx": idx}, trace_model=True, dynamic_shapes=False)
+
+
+class TestIndexMask(PytorchLayerTest):
+    def _prepare_input(self, input_shape):
+        import numpy as np
+        return (np.random.randn(*input_shape).astype(np.float32),)
+
+    def create_model(self):
+        import torch
+
+        class aten_index_mask(torch.nn.Module):
+            def forward(self, x):
+                return x[x > 0]
+
+        ref_net = None
+
+        return aten_index_mask(), ref_net, "aten::index"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    @pytest.mark.parametrize("input_shape", ((1, 1),
+                                             [2, 3],
+                                             [7, 8, 9],
+                                             [2, 2, 3, 4]))
+    def test_index_mask(self, input_shape, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
+                   "input_shape": input_shape}, trace_model=True)
diff --git a/tests/layer_tests/pytorch_tests/test_index_put_.py b/tests/layer_tests/pytorch_tests/test_index_put_.py
index 61c6ced767b..55cbe39bd92 100644
--- a/tests/layer_tests/pytorch_tests/test_index_put_.py
+++ b/tests/layer_tests/pytorch_tests/test_index_put_.py
@@ -163,3 +163,22 @@ class TestNonZero_IndexPut(PytorchLayerTest):
         self.indices_0 = indices[0]
         self.indices_1 = indices[1]
         self._test(*self.create_model(accumulate), ie_device, precision, ir_version, trace_model=True)
+
+
+class TestMask_IndexPut(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.random.randn(100, 5).astype(np.float32), np.random.randn(100, 5).astype(np.float32))
+
+    def create_model(self):
+        class aten_index_put_mask(torch.nn.Module):
+            def forward(self, x, y):
+                x[x < 0] = y[x < 0]
+                return x
+
+        ref_net = None
+
+        return aten_index_put_mask(), ref_net, "aten::index_put_"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_mask_index_put_(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True)
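As a quick eager-mode sanity check (not part of the patch), the masked assignment exercised by `TestMask_IndexPut` is exactly `aten::index_put_` with a boolean index:

```python
import torch

x1 = torch.randn(100, 5)
x2 = x1.clone()
y = torch.randn(100, 5)
mask = x1 < 0

x1[mask] = y[mask]               # the sugar used in the test above
x2.index_put_((mask,), y[mask])  # explicit aten::index_put_ on a boolean index
assert torch.equal(x1, x2)
```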