diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index 3596e91b9d8..685b14c157d 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -116,7 +116,7 @@ void FrontEnd::normalize(const std::shared_ptr& model) const {
 
     // Usually if nn.Module.forward is given as a source model for conversion, there is the first Parameter
     // that represents original `self` argument in forward(self, ...). `self` shouldn't play any role in model
-    // inference if model is completelly frozed and all methods are inlined. So we check if it doesn't have any
+    // inference if model is completely frozen and all methods are inlined. So we check if it doesn't have any
     // consumers in the finally converted model and remove this parameter. This parameter should have index 0.
     if (model->get_parameters().size() > 0) {
         auto self = model->get_parameters()[0];
diff --git a/src/frontends/pytorch/src/op/full.cpp b/src/frontends/pytorch/src/op/full.cpp
index bbb7f98022f..cf60d096555 100644
--- a/src/frontends/pytorch/src/op/full.cpp
+++ b/src/frontends/pytorch/src/op/full.cpp
@@ -176,7 +176,7 @@ OutputVector translate_empty(const NodeContext& context) {
     // side, so just skip these parameters
     num_inputs_check(context, 1, 6);
     auto sizes = context.get_input(0);
-    // In OV uninitialised data is not supported, so we create a tensor filled with zeros with a given shape and type.
+    // In OV uninitialized data is not supported, so we create a tensor filled with zeros with a given shape and type.
     auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
     int dtype_id = 1;
     Output empty;
diff --git a/src/frontends/pytorch/src/op/if.cpp b/src/frontends/pytorch/src/op/if.cpp
index 7fb3ecce123..77015fb1dee 100644
--- a/src/frontends/pytorch/src/op/if.cpp
+++ b/src/frontends/pytorch/src/op/if.cpp
@@ -13,6 +13,31 @@ namespace frontend {
 namespace pytorch {
 namespace op {
 
+namespace {
+// TODO: Ticket 106627. This is a workaround: branch results whose static types cannot be merged are aligned by
+// converting the narrower one. It works only if both branches eventually reach an op whose output type is the same.
+void align_result_types(const NodeContext& context,  // NOTE(review): context is not used in the body
+                        std::shared_ptr r1,
+                        std::shared_ptr r2) {
+    auto r1_tensor = r1->input_value(0);
+    auto r2_tensor = r2->input_value(0);
+    auto r1_type = r1_tensor.get_element_type();
+    auto r2_type = r2_tensor.get_element_type();
+    if (r1_type.is_dynamic() || r2_type.is_dynamic())  // a type is still unknown: leave both results untouched
+        return;
+    element::Type merged_type;  // only used as the out-parameter of merge()
+    if (!element::Type::merge(merged_type, r1_type, r2_type)) {  // both static, but merge reports a mismatch
+        if (r1_type.bitwidth() >= r2_type.bitwidth()) {  // r1 is at least as wide (ties included): r2 is re-typed
+            auto convert = std::make_shared(r2_tensor, r1_type);
+            r2->set_argument(0, convert);
+        } else {  // r2 is strictly wider: r1 is re-typed
+            auto convert = std::make_shared(r1_tensor, r2_type);
+            r1->set_argument(0, convert);
+        }
+    }
+}
+}  // namespace
+
 OutputVector translate_if(const NodeContext& context) {
     auto if_node = std::make_shared(context.get_input(0));
     context.mark_node(if_node);
@@ -62,6 +87,7 @@ OutputVector translate_if(const NodeContext& context) {
     FRONT_END_OP_CONVERSION_CHECK(then_results.size() >= num_outs && else_results.size() >= num_outs,
                                   "Else or then body have less outputs than prim::If requires.");
     for (size_t i = 0; i < num_outs; i++) {
+        align_result_types(context, then_results[i], else_results[i]);
         res.push_back(if_node->set_output(then_results[i], else_results[i]));
     }
     // Each body can have mutated outputs that are not included into pytorch node outputs.
@@ -136,6 +162,7 @@ OutputVector translate_if(const NodeContext& context) {
         }
     }
     for (const auto& output_idx : extra_output_idxs) {
+        align_result_types(context, extra_then_body_results.at(output_idx), extra_else_body_results.at(output_idx));
         context.add_tensor_to_context(
             output_idx,
             if_node->set_output(extra_then_body_results.at(output_idx), extra_else_body_results.at(output_idx)));
diff --git a/src/frontends/pytorch/src/op/select.cpp b/src/frontends/pytorch/src/op/select.cpp
index ea5255f2410..7cd898fdf22 100644
--- a/src/frontends/pytorch/src/op/select.cpp
+++ b/src/frontends/pytorch/src/op/select.cpp
@@ -5,11 +5,7 @@
 #include "openvino/op/select.hpp"
 
 #include "openvino/frontend/pytorch/node_context.hpp"
-#include "openvino/op/add.hpp"
-#include "openvino/op/constant.hpp"
-#include "openvino/op/less.hpp"
-#include "openvino/op/reshape.hpp"
-#include "openvino/op/slice.hpp"
+#include "openvino/op/gather.hpp"
 #include "openvino/op/squeeze.hpp"
 #include "utils.hpp"
 
@@ -21,22 +17,12 @@ namespace op {
 using namespace ov::op;
 
 OutputVector translate_select(const NodeContext& context) {
+    // aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a); exactly three inputs are required
     num_inputs_check(context, 3, 3);
-    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
-    auto const_minus_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
-    auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
-
-    auto input_tensor = context.get_input(0);
-    auto dim = context.mark_node(std::make_shared(context.get_input(1), const_1, false));
-    auto start = context.mark_node(std::make_shared(context.get_input(2), const_1, false));
-
-    auto less = context.mark_node(std::make_shared(start, const_0));
-    auto const_1_signed = context.mark_node(std::make_shared(less, const_minus_1, const_1));
-    auto stop = context.mark_node(std::make_shared(start, const_1_signed));
-
-    auto slice_node = context.mark_node(std::make_shared(input_tensor, start, stop, const_1_signed, dim));
-
-    return {context.mark_node(std::make_shared(slice_node, dim))};
+    auto data = context.get_input(0);   // input 0: the tensor to select from
+    auto dim = context.get_input(1);    // input 1: dim
+    auto index = context.get_input(2);  // input 2: index
+    return {context.mark_node(std::make_shared(data, index, dim))};  // note the argument order: (data, index, dim)
 };
 
 }  // namespace op
diff --git a/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp b/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp
index a5496501dba..6b1792f7a63 100644
--- a/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp
+++ b/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp
@@ -17,6 +17,7 @@
 #include "openvino/op/tile.hpp"
 #include "openvino/op/transpose.hpp"
 #include "openvino/op/util/framework_node.hpp"
+#include "openvino/op/variadic_split.hpp"
 #include "openvino/pass/pattern/matcher.hpp"
 #include "openvino/pass/pattern/op/or.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
@@ -49,6 +50,8 @@ ListConstructReplacer::ListConstructReplacer() {
     auto tile_op = pattern::wrap_type({pattern::any_input(), list});
     // replace aten::permute(tensor, prim::ListConstruct)
     auto transpose_op = pattern::wrap_type({pattern::any_input(), list});
+    // aten::split_with_sizes case: here the list is the third input of the matched node
+    auto vsplit_op = pattern::wrap_type({pattern::any_input(), pattern::any_input(), list});
     auto lc_pattern = std::make_shared(OutputVector{reshape_op,
                                                     roll_op,
                                                     broadcast_op,
@@ -57,7 +60,8 @@ ListConstructReplacer::ListConstructReplacer() {
                                                     equal_op,
                                                     select_op,
                                                     tile_op,
-                                                    transpose_op});
+                                                    transpose_op,
+                                                    vsplit_op});
 
     ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
         auto& pattern_map = m.get_pattern_value_map();
diff --git a/src/frontends/pytorch/src/transforms/min_max_prim_list_construct_replacer.cpp b/src/frontends/pytorch/src/transforms/min_max_prim_list_construct_replacer.cpp
index eed50b174f2..d4602ee162c 100644
--- a/src/frontends/pytorch/src/transforms/min_max_prim_list_construct_replacer.cpp
+++
b/src/frontends/pytorch/src/transforms/min_max_prim_list_construct_replacer.cpp @@ -49,7 +49,8 @@ MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() { auto step = std::make_shared(element::i32, Shape{}, 1); auto shape = std::make_shared(input, element::i32); auto rank = std::make_shared(shape, element::i32); - auto reduced_rank = std::make_shared(rank); + auto axis_0 = ov::op::v0::Constant::create(element::i32, Shape{}, {0}); + auto reduced_rank = std::make_shared(rank, axis_0); auto axes = std::make_shared(start, reduced_rank, step, element::i32); std::shared_ptr reduce_op; if (!is_min) { diff --git a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp index a4ba9a8b3cd..cb7704a99a6 100644 --- a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp @@ -33,6 +33,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { if (rank.is_dynamic()) { return false; } + std::shared_ptr split; if (rank.get_length() == 0) { // Create split_lenghts tensor from split_size int, // allow for last chunk to be smaller if data is not equally divisible. 
@@ -45,18 +46,17 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { auto split_lenghts_m_1 = std::make_shared(split_size, num_out_m_1); NodeVector concat_inputs{split_lenghts_m_1, const_neg_1}; auto split_lenghts = std::make_shared(concat_inputs, 0); - auto split = std::make_shared(torch_split->get_input_source_output(0), - torch_split->get_input_source_output(2), - split_lenghts); - copy_runtime_info({list_unpack, input_node}, split); - replace_node(list_unpack, split); + split = std::make_shared(torch_split->get_input_source_output(0), + torch_split->get_input_source_output(2), + split_lenghts); } else { - auto split = std::make_shared(torch_split->get_input_source_output(0), - torch_split->get_input_source_output(2), - torch_split->get_input_source_output(1)); - copy_runtime_info({list_unpack, input_node}, split); - replace_node(list_unpack, split); + split = std::make_shared(torch_split->get_input_source_output(0), + torch_split->get_input_source_output(2), + torch_split->get_input_source_output(1)); } + copy_runtime_info({list_unpack, input_node}, split); + split->set_friendly_name(input_node->get_friendly_name()); + replace_node(list_unpack, split); return true; } @@ -67,6 +67,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { split_with_sizes->get_input_source_output(1)); copy_runtime_info({list_unpack, input_node}, split); + split->set_friendly_name(input_node->get_friendly_name()); replace_node(list_unpack, split); return true; diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index bdae3e9e75e..e9c67d73f54 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -66,7 +66,8 @@ std::tuple, Output> get_shape_rank(const NodeContext& context auto shape = context.mark_node(std::make_shared(x, output_type)); Output rank = context.mark_node(std::make_shared(shape, output_type)); if (as_scalar) { - rank = context.mark_node(std::make_shared(rank)); + auto axis_0 = 
context.mark_node(opset10::Constant::create(output_type, Shape{}, {0})); + rank = context.mark_node(std::make_shared(rank, axis_0)); } return std::make_tuple(shape, rank); } @@ -110,9 +111,8 @@ std::shared_ptr get_axes_range(const NodeContext& context, int input_id) { auto x = context.get_input(input_id); auto start = std::make_shared(element::i32, Shape{}, 0); auto step = std::make_shared(element::i32, Shape{}, 1); - auto shape = context.mark_node(std::make_shared(x, element::i32)); - auto rank = context.mark_node(std::make_shared(shape, element::i32)); - auto reduced_rank = context.mark_node(std::make_shared(rank)); + Output reduced_rank; + std::tie(std::ignore, reduced_rank) = get_shape_rank(context, x, true); return context.mark_node(std::make_shared(start, reduced_rank, step, element::i32)); }; diff --git a/tests/layer_tests/pytorch_tests/test_argsort.py b/tests/layer_tests/pytorch_tests/test_argsort.py index c29a5e91ae9..e3514d4c0e6 100644 --- a/tests/layer_tests/pytorch_tests/test_argsort.py +++ b/tests/layer_tests/pytorch_tests/test_argsort.py @@ -11,7 +11,7 @@ def not_yet_supported(value): return pytest.param( value, marks = pytest.mark.xfail( - reason="Failed due to aten::sargsort not yet supporting stable sorting. Ticket 105242" + reason="Failed due to aten::argsort not yet supporting stable sorting. 
Ticket 105242"
         ),
     )
 
diff --git a/tests/layer_tests/pytorch_tests/test_if.py b/tests/layer_tests/pytorch_tests/test_if.py
new file mode 100644
index 00000000000..9e18d1d8f3d
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_if.py
@@ -0,0 +1,43 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import numpy as np
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestIf(PytorchLayerTest):
+    """Layer test for prim::If whose two branches return tensors of different element types."""
+
+    def _prepare_input(self):
+        # x is random float32 data; y is the scalar condition stored on the instance by test_if.
+        return (np.random.randn(1, 3, 224, 224).astype(np.float32), self.y)
+
+    def create_model(self):
+        import torch
+
+        class prim_if(torch.nn.Module):
+            def __init__(self):
+                super(prim_if, self).__init__()
+
+            def forward(self, x, y):
+                # The branches build tensors of different dtypes (uint8 vs bool).
+                if y > 0:
+                    res = x.new_empty((0, 10), dtype=torch.uint8)
+                else:
+                    res = torch.zeros(x.shape[:2], dtype=torch.bool)
+                return res.to(torch.bool)
+
+        ref_net = None
+
+        return prim_if(), ref_net, "prim::If"
+
+    @pytest.mark.parametrize("y", [np.array(1),
+                                   np.array(-1)
+                                   ])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_if(self, y, ie_device, precision, ir_version):
+        self.y = y
+        self._test(*self.create_model(), ie_device, precision, ir_version)
diff --git a/tests/layer_tests/pytorch_tests/test_split.py b/tests/layer_tests/pytorch_tests/test_split.py
index 3328557c2f1..57627672c5f 100644
--- a/tests/layer_tests/pytorch_tests/test_split.py
+++ b/tests/layer_tests/pytorch_tests/test_split.py
@@ -47,7 +47,7 @@ class TestSplit(PytorchLayerTest):
 
         return aten_split(self.split_param, self.axis), ref_net, "aten::split"
 
-    # Test case - (split_param, axis), always split into 5 due to hardcoded number of outputs in ListUnpack test.
+    # Test case - (split_param, axis), always split into 5 due to hardcoded number of outputs in ListUnpack test.
test_cases = [ (2, 1), (45, 2), @@ -64,7 +64,8 @@ class TestSplit(PytorchLayerTest): def test_split_getitem(self, params, getitem, ie_device, precision, ir_version): (self.split_param, self.axis) = params self.getitem = getitem - self._test(*self.create_model_split_getitem(), ie_device, precision, ir_version) + self._test(*self.create_model_split_getitem(), + ie_device, precision, ir_version) @pytest.mark.parametrize("params", test_cases) @pytest.mark.nightly @@ -74,3 +75,30 @@ class TestSplit(PytorchLayerTest): self._test( *self.create_model_split_listunpack(), ie_device, precision, ir_version ) + + +class TestSplitWithSizes(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(20).astype(np.float32),np.random.randn(20).astype(np.float32)) + + def create_model(self): + import torch + + class aten_split_with_sizes(torch.nn.Module): + def __init__(self): + super(aten_split_with_sizes, self).__init__() + #self.sizes = 20 + + def forward(self, x, y): + return x.split([y.shape[0]], dim=0) + + ref_net = None + + return aten_split_with_sizes(), ref_net, ["aten::split_with_sizes", "prim::ListConstruct"] + + @pytest.mark.nightly + @pytest.mark.precommit + def test_relu(self, ie_device, precision, ir_version): + self._test(*self.create_model(), + ie_device, precision, ir_version, trace_model=True)