From 9684f9184a9b87b774cdc6cec8eb8efa52c28e9b Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Thu, 15 Jun 2023 14:28:38 +0200 Subject: [PATCH] [PT FE] Improve operation naming (#17997) * Improve operation naming * Use set to reduce operations with indexes, fix code style * Refactor * Make good names in transformations * Remove tensor index from input/output names * Fix tests --- .../src/openvino/frontend/pytorch/decoder.py | 7 + src/frontends/pytorch/src/input_model.cpp | 2 + src/frontends/pytorch/src/node_context.cpp | 15 +- src/frontends/pytorch/src/place.cpp | 13 +- .../append_list_unpack_replacer.cpp | 14 +- .../src/transforms/aten_cat_replacer.cpp | 15 +- .../src/transforms/aten_getitem_replacer.cpp | 154 ++++++++---------- .../transforms/aten_index_put_replacer.cpp | 68 ++++---- .../src/transforms/aten_index_replacer.cpp | 85 +++++----- .../aten_stack_list_construct_replacer.cpp | 27 +-- .../src/transforms/einsum_list_construct.cpp | 2 +- .../index_loop_getitem_replacer.cpp | 2 +- .../src/transforms/listconstruct_replacer.cpp | 8 +- .../min_max_prim_list_construct_replacer.cpp | 29 ++-- .../transforms/prim_list_construct_pad.cpp | 40 ++--- .../transforms/prim_list_unpack_replacer.cpp | 114 ++++++------- .../transforms/string_equality_replacer.cpp | 4 +- .../pytorch/src/translate_session.cpp | 35 +++- .../pytorch/src/translate_session.hpp | 5 +- src/frontends/pytorch/src/utils.cpp | 25 +++ src/frontends/pytorch/src/utils.hpp | 4 + 21 files changed, 354 insertions(+), 314 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py index feed23d3822..704a78aaaf3 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py @@ -319,6 +319,13 @@ class TorchScriptPythonDecoder (Decoder): return self.outputs()[index] def mark_node(self, node): + name = self.graph_element.kind() + if "FrameworkNode" 
not in node.get_type_name(): + name += "/" + node.get_type_name() + if self.graph_element.scopeName(): + node.set_friendly_name(self.graph_element.scopeName().split("/")[-1] + "/" + name) + else: + node.set_friendly_name(name) return node def try_decode_get_attr(self): diff --git a/src/frontends/pytorch/src/input_model.cpp b/src/frontends/pytorch/src/input_model.cpp index db90c8bc275..68009bf8b49 100644 --- a/src/frontends/pytorch/src/input_model.cpp +++ b/src/frontends/pytorch/src/input_model.cpp @@ -15,6 +15,7 @@ InputModel::InputModel(const std::shared_ptr& model_decoder) : m_m const auto& inputs = m_model_decoder->inputs(); for (size_t i = 0; i < inputs.size(); ++i) { auto in_place = std::make_shared(*this, inputs[i]); + m_name_to_place.emplace(std::to_string(inputs[i]), std::dynamic_pointer_cast(in_place)); for (const auto& name : in_place->get_names()) { m_name_to_place.emplace(name, std::dynamic_pointer_cast(in_place)); } @@ -28,6 +29,7 @@ InputModel::InputModel(const std::shared_ptr& model_decoder) : m_m const auto& outputs = m_model_decoder->outputs(); for (size_t i = 0; i < outputs.size(); ++i) { auto out_place = std::make_shared(*this, outputs[i]); + m_name_to_place.emplace(std::to_string(outputs[i]), std::dynamic_pointer_cast(out_place)); for (const auto& name : out_place->get_names()) { m_name_to_place.emplace(name, std::dynamic_pointer_cast(out_place)); } diff --git a/src/frontends/pytorch/src/node_context.cpp b/src/frontends/pytorch/src/node_context.cpp index f0f667f3a18..aa204b7978e 100644 --- a/src/frontends/pytorch/src/node_context.cpp +++ b/src/frontends/pytorch/src/node_context.cpp @@ -40,23 +40,24 @@ OutputVector NodeContext::as_constant() const { fw_node->set_attrs(attrs); return {fw_node}; } else { - return m_decoder->as_constant(); + auto c_outs = m_decoder->as_constant(); + FRONT_END_OP_CONVERSION_CHECK(c_outs.size() == 1, "Constant must have exactly one output."); + 
c_outs[0].get_node_shared_ptr()->set_friendly_name(m_decoder->get_output_debug_name(0)); + return c_outs; } } std::shared_ptr NodeContext::mark_node(std::shared_ptr ov_node) const { - ov_node->set_friendly_name(get_op_type() + '_' + std::to_string(m_translate_session->m_friendly_name_counter++)); - return m_decoder->mark_node(ov_node); + ov_node = m_decoder->mark_node(ov_node); + m_translate_session->unique_name(ov_node); + return ov_node; } void NodeContext::mutate_input(size_t index, Output ov_output) const { FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index), "Input is none with index: ", index); auto input_id = m_decoder_inputs.at(index); FRONT_END_GENERAL_CHECK(m_tensor_map->count(input_id), "No tensor corresponding input: ", input_id, " exist."); - m_translate_session->encode_tensor_name( - ov_output, - input_id, - {m_decoder->get_input_debug_name(index), m_decoder->get_input_signature_name(index)}); + m_translate_session->encode_tensor_name(ov_output, input_id, {m_decoder->get_input_debug_name(index)}); (*m_tensor_map)[input_id] = ov_output; m_mutated_tensors->insert(input_id); diff --git a/src/frontends/pytorch/src/place.cpp b/src/frontends/pytorch/src/place.cpp index 9fbf55e5c4c..b68d848622b 100644 --- a/src/frontends/pytorch/src/place.cpp +++ b/src/frontends/pytorch/src/place.cpp @@ -18,7 +18,6 @@ Place::Place(const ov::frontend::InputModel& input_model, size_t tensor_index) m_tensor_index(tensor_index), m_is_input(false), m_is_output(false) { - m_names.push_back(std::to_string(tensor_index)); const auto im = dynamic_cast(&m_input_model); FRONT_END_GENERAL_CHECK(im, "PyTorch Place requires PyTorch InputModel class."); const auto& inputs = im->m_model_decoder->inputs(); @@ -26,23 +25,15 @@ Place::Place(const ov::frontend::InputModel& input_model, size_t tensor_index) auto in_it = std::find(inputs.begin(), inputs.end(), tensor_index); if (in_it != inputs.end()) { m_is_input = true; - const auto& debug_name = 
im->m_model_decoder->get_input_debug_name(std::distance(inputs.begin(), in_it)); - if (debug_name != m_names.at(0)) { - m_names.push_back(debug_name); - } const auto& signature_name = im->m_model_decoder->get_input_signature_name(std::distance(inputs.begin(), in_it)); - if (signature_name != m_names.at(0) && signature_name != debug_name) { - m_names.push_back(signature_name); - } + m_names.push_back(signature_name); } auto out_it = std::find(outputs.begin(), outputs.end(), tensor_index); if (out_it != outputs.end()) { m_is_output = true; const auto& debug_name = im->m_model_decoder->get_output_debug_name(std::distance(outputs.begin(), out_it)); - if (debug_name != m_names.at(0)) { - m_names.push_back(debug_name); - } + m_names.push_back(debug_name); } if (m_is_input && m_is_output) { OPENVINO_DEBUG << "[WARNING] Place " << tensor_index << " is input and output at a same time."; diff --git a/src/frontends/pytorch/src/transforms/append_list_unpack_replacer.cpp b/src/frontends/pytorch/src/transforms/append_list_unpack_replacer.cpp index 52fe00fd08e..d3dfc7467b3 100644 --- a/src/frontends/pytorch/src/transforms/append_list_unpack_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/append_list_unpack_replacer.cpp @@ -21,6 +21,8 @@ namespace frontend { namespace pytorch { namespace pass { +using namespace ov::op; + AppendListUnpackReplacer::AppendListUnpackReplacer() { auto list_unpack = ov::pass::pattern::wrap_type(); @@ -30,7 +32,7 @@ AppendListUnpackReplacer::AppendListUnpackReplacer() { return false; OutputVector tmp_inputs; - NodeVector rt_copy_from{list_unpack}; + NodeVector rt_copy_from; auto input_node = list_unpack->input_value(0).get_node_shared_ptr(); // Optional aten::__getitem__ node. @@ -60,7 +62,7 @@ AppendListUnpackReplacer::AppendListUnpackReplacer() { // If aten::__getitem__, expect inputs to be equivalent of pytorch Tensor[][]. // Tensor selected by aten::__getitem__ index needs to be splitted in axis 0. 
auto getitem_index_ptr = getitem_node->input_value(1).get_node_shared_ptr(); - auto getitem_index_const = std::dynamic_pointer_cast(getitem_index_ptr); + auto getitem_index_const = std::dynamic_pointer_cast(getitem_index_ptr); auto index_val = getitem_index_const->cast_vector(); if (index_val.size() != 1) { add_exception_to_fw_node(list_unpack, "prim::ListUnpack: index of aten::__getitem__ is not scalar."); @@ -70,16 +72,16 @@ AppendListUnpackReplacer::AppendListUnpackReplacer() { if (index_val[0] < 0) { index = inputs.size() + index; } - auto axis_0 = ov::op::v0::Constant::create(element::i32, Shape{}, {0}); - auto split = std::make_shared(inputs[index], axis_0, list_unpack->get_output_size()); + auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0}); + auto split = std::make_shared(inputs[index], axis_0, list_unpack->get_output_size()); NodeVector to_copy_rt{axis_0, split}; OutputVector res; for (auto output : split->outputs()) { - auto squeeze = std::make_shared(output, axis_0); + auto squeeze = std::make_shared(output, axis_0); to_copy_rt.push_back(squeeze); res.push_back(squeeze); } - copy_runtime_info(rt_copy_from, to_copy_rt); + copy_runtime_info_and_name(list_unpack, to_copy_rt, rt_copy_from); replace_node(list_unpack, res); return true; } else { diff --git a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp index facbf44949d..bb387ee77be 100644 --- a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp @@ -21,6 +21,8 @@ namespace frontend { namespace pytorch { namespace pass { +using namespace ov::op; + // aten::cat needs a special handling since it takes a Tensor[] as input. We set the inputs of ListConstruct as the // inputs of cat. 
// @@ -41,7 +43,7 @@ AtenCatToConcat::AtenCatToConcat() { int64_t axis; if (cat->get_input_size() > 1) { auto axis_node = cat->get_input_node_shared_ptr(1); - auto axis_const = std::dynamic_pointer_cast(axis_node); + auto axis_const = std::dynamic_pointer_cast(axis_node); if (!axis_const) { add_exception_to_fw_node(cat, "aten::cat unsupported case: axis is not a constant."); return false; @@ -62,7 +64,7 @@ AtenCatToConcat::AtenCatToConcat() { } std::shared_ptr input_node = cat->get_input_node_shared_ptr(0); - if (auto loop = std::dynamic_pointer_cast(input_node)) { + if (auto loop = std::dynamic_pointer_cast(input_node)) { // case when concatenation is done inside the Loop auto body = loop->get_function(); auto output_index = cat->input(0).get_source_output().get_index(); @@ -82,7 +84,7 @@ AtenCatToConcat::AtenCatToConcat() { "aten::cat unsupported case: aten::append wasn't found inside prim::Loop body."); return false; } - auto param = std::dynamic_pointer_cast(append->get_input_node_shared_ptr(0)); + auto param = std::dynamic_pointer_cast(append->get_input_node_shared_ptr(0)); if (!param) { add_exception_to_fw_node(cat, "aten::cat unsupported case: input of aten::append inside prim::Loop " @@ -106,7 +108,7 @@ AtenCatToConcat::AtenCatToConcat() { "body is not a prim::ListConstruct."); return false; } - auto new_result = std::make_shared(append->input_value(1)); + auto new_result = std::make_shared(append->input_value(1)); body->add_results({new_result}); auto new_output = loop->get_concatenated_slices(new_result, 0, 1, 1, -1, axis); copy_runtime_info(cat, loop); @@ -115,10 +117,9 @@ AtenCatToConcat::AtenCatToConcat() { } const auto&& tmp_inputs = get_list_as_outputs(cat->get_input_source_output(0)); - auto result = std::make_shared(OutputVector(tmp_inputs.begin(), tmp_inputs.end()), axis); - copy_runtime_info(cat, result); + auto result = std::make_shared(OutputVector(tmp_inputs.begin(), tmp_inputs.end()), axis); + copy_runtime_info_and_name(cat, {result}); 
replace_node(cat, result); - result->set_friendly_name(cat->get_friendly_name()); return true; }; diff --git a/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp index 1dc86a050fe..b0babdf82b2 100644 --- a/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/aten_getitem_replacer.cpp @@ -14,6 +14,10 @@ #include "openvino/op/convert.hpp" #include "openvino/op/divide.hpp" #include "openvino/op/gather.hpp" +#include "openvino/op/greater.hpp" +#include "openvino/op/greater_eq.hpp" +#include "openvino/op/less.hpp" +#include "openvino/op/mod.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/range.hpp" #include "openvino/op/shape_of.hpp" @@ -22,7 +26,6 @@ #include "openvino/op/unsqueeze.hpp" #include "openvino/op/util/framework_node.hpp" #include "openvino/op/variadic_split.hpp" -#include "openvino/opsets/opset10.hpp" #include "openvino/pass/pattern/matcher.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "pt_framework_node.hpp" @@ -33,6 +36,8 @@ namespace frontend { namespace pytorch { namespace pass { +using namespace ov::op; + AtenGetItemReplacer::AtenGetItemReplacer() { auto getitem = ov::pass::pattern::wrap_type(); @@ -41,6 +46,7 @@ AtenGetItemReplacer::AtenGetItemReplacer() { if (!getitem) return false; + ov::pass::NodeRegistry rg; auto input_node = getitem->input_value(0).get_node_shared_ptr(); if (auto torch_split = cast_fw_node(input_node, "aten::split")) { auto rank = torch_split->input(1).get_partial_shape().rank(); @@ -51,53 +57,48 @@ AtenGetItemReplacer::AtenGetItemReplacer() { if (rank.get_length() == 0) { // Based on slice_size and output index select size. // Constants required by transformation. 
- auto const_1 = ov::op::v0::Constant::create(element::i32, Shape{1}, {1}); - auto const_1_0d = ov::op::v0::Constant::create(element::i32, Shape{}, {1}); - auto const_0 = ov::op::v0::Constant::create(element::i32, Shape{1}, {0}); - auto const_0_0d = ov::op::v0::Constant::create(element::i32, Shape{}, {0}); + auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1}); + auto const_1_0d = v0::Constant::create(element::i32, Shape{}, {1}); + auto const_0 = v0::Constant::create(element::i32, Shape{1}, {0}); + auto const_0_0d = v0::Constant::create(element::i32, Shape{}, {0}); // Load and convert op inputs. auto input = torch_split->get_input_source_output(0); auto split_size = torch_split->get_input_source_output(1); - auto split_size_1d = std::make_shared(split_size, const_0); + auto split_size_1d = rg.make(split_size, const_0); auto axis = torch_split->get_input_source_output(2); - auto axis_1d = std::make_shared(axis, const_0); + auto axis_1d = rg.make(axis, const_0); auto getitem_idx = getitem->input(1).get_source_output(); // Calculate number of splits based on input shape and split_size. - auto shape = std::make_shared(input, element::i32); - auto len_to_split = std::make_shared(shape, axis, const_0); + auto shape = rg.make(input, element::i32); + auto len_to_split = rg.make(shape, axis, const_0); // Convert to f64 from int to calculate reminder - last chunk can be smaller if Shape in given axis is // not equally divisible. 
- auto len_to_split_float = std::make_shared(len_to_split, element::f64); - auto split_size_1d_float = std::make_shared(split_size_1d, element::f64); - auto out_div = std::make_shared(len_to_split_float, split_size_1d_float); - auto out_num = std::make_shared(out_div); - auto out_num_0d = std::make_shared(out_num, const_0); + auto len_to_split_float = rg.make(len_to_split, element::f64); + auto split_size_1d_float = rg.make(split_size_1d, element::f64); + auto out_div = rg.make(len_to_split_float, split_size_1d_float); + auto out_num = rg.make(out_div); + auto out_num_0d = rg.make(out_num, const_0); // Use Range and Gather to convert negative getitem indexes into positive due problems with indexing // with -1. - auto possible_out_idx = std::make_shared(const_0_0d, - out_num_0d, - const_1_0d, - split_size.get_element_type()); - auto always_positive_out_idx = - std::make_shared(possible_out_idx, getitem_idx, const_0); + auto possible_out_idx = + rg.make(const_0_0d, out_num_0d, const_1_0d, split_size.get_element_type()); + auto always_positive_out_idx = rg.make(possible_out_idx, getitem_idx, const_0); // Use Slice to get only split output selected by getitem idx. Couldn't use VariadicSplit due to // problems with dynamic inputs. 
- auto split_slice_start = std::make_shared(always_positive_out_idx, split_size_1d); - auto split_slice_end = std::make_shared(split_slice_start, split_size_1d); - auto split = - std::make_shared(input, split_slice_start, split_slice_end, const_1, axis_1d); - copy_runtime_info({getitem, input_node}, split); + auto split_slice_start = rg.make(always_positive_out_idx, split_size_1d); + auto split_slice_end = rg.make(split_slice_start, split_size_1d); + auto split = rg.make(input, split_slice_start, split_slice_end, const_1, axis_1d); replace_node(getitem, split); } else { auto getitem_index_ptr = getitem->input_value(1).get_node_shared_ptr(); - auto getitem_index_const = std::dynamic_pointer_cast(getitem_index_ptr); - auto split = std::make_shared(torch_split->get_input_source_output(0), - torch_split->get_input_source_output(2), - torch_split->get_input_source_output(1)); + auto getitem_index_const = std::dynamic_pointer_cast(getitem_index_ptr); + auto split = rg.make(torch_split->get_input_source_output(0), + torch_split->get_input_source_output(2), + torch_split->get_input_source_output(1)); auto index_val = getitem_index_const->cast_vector(); if (index_val.size() != 1) { add_exception_to_fw_node(getitem, "aten::__getitem__ index is not scalar."); @@ -107,82 +108,71 @@ AtenGetItemReplacer::AtenGetItemReplacer() { if (index < 0) { index = split->outputs().size() + index; } - OutputVector res{split->outputs()[index]}; - copy_runtime_info({getitem, input_node}, split); - replace_node(getitem, res); + replace_node(getitem, {split->outputs()[index]}); } - return true; - } - if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) { + } else if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) { auto getitem_idx = getitem->input_value(1).get_node_shared_ptr(); - auto getitem_idx_const = std::dynamic_pointer_cast(getitem_idx); + auto getitem_idx_const = std::dynamic_pointer_cast(getitem_idx); if (getitem_idx_const) { auto idx = 
getitem_idx_const->cast_vector(); auto element = list_construct->input_value(idx[0]).get_node_shared_ptr(); - copy_runtime_info({getitem, input_node}, element); replace_node(getitem, element); - return true; + } else { + auto input_concat = concat_list_construct(list_construct); + auto zero = v0::Constant::create(element::i32, Shape{}, {0}); + auto gather = rg.make(input_concat, getitem_idx, zero); + replace_node(getitem, gather); } - auto input_concat = concat_list_construct(list_construct); - auto zero = ov::op::v0::Constant::create(element::i32, Shape{}, {0}); - auto gather = std::make_shared(input_concat, getitem_idx, zero); - copy_runtime_info({getitem, input_node}, gather); - replace_node(getitem, gather); - return true; - } - if (auto chunk = cast_fw_node(input_node, "aten::chunk")) { + } else if (auto chunk = cast_fw_node(input_node, "aten::chunk")) { auto input_tensor = chunk->get_input_source_output(0); auto chunks_i32 = chunk->get_input_source_output(1); auto dim_i32 = chunk->get_input_source_output(2); - auto const_0 = opset10::Constant::create(element::i64, Shape{1}, {0}); - auto const_1 = opset10::Constant::create(element::i64, Shape{1}, {1}); - auto const_0_nodim = opset10::Constant::create(element::i64, Shape{}, {0}); + auto const_0 = v0::Constant::create(element::i64, Shape{1}, {0}); + auto const_1 = v0::Constant::create(element::i64, Shape{1}, {1}); + auto const_0_nodim = v0::Constant::create(element::i64, Shape{}, {0}); auto getitem_index_i32 = getitem->get_input_source_output(1); - auto getitem_index_i64 = std::make_shared(getitem_index_i32, element::i64); - auto getitem_index = std::make_shared(getitem_index_i64, const_0); - auto dim_i64 = std::make_shared(dim_i32, element::i64); - auto dim = std::make_shared(dim_i64, const_0); - auto chunks = std::make_shared(chunks_i32, element::i64); + auto getitem_index_i64 = rg.make(getitem_index_i32, element::i64); + auto getitem_index = rg.make(getitem_index_i64, const_0); + auto dim_i64 = 
rg.make(dim_i32, element::i64); + auto dim = rg.make(dim_i64, const_0); + auto chunks = rg.make(chunks_i32, element::i64); - auto input_shape = std::make_shared(input_tensor); - auto input_dimension = std::make_shared(input_shape, dim, const_0); - auto input_size = std::make_shared(input_dimension); + auto input_shape = rg.make(input_tensor); + auto input_dimension = rg.make(input_shape, dim, const_0); + auto input_size = rg.make(input_dimension); - auto chunk_size = std::make_shared(input_size, chunks, true); - auto last_chunk_size = std::make_shared(input_size, chunks); - auto is_last_nonzero = std::make_shared(last_chunk_size, const_0_nodim); - auto is_last_nonzero_int = std::make_shared(is_last_nonzero, element::i64); + auto chunk_size = rg.make(input_size, chunks, true); + auto last_chunk_size = rg.make(input_size, chunks); + auto is_last_nonzero = rg.make(last_chunk_size, const_0_nodim); + auto is_last_nonzero_int = rg.make(is_last_nonzero, element::i64); - auto computed_chunk_size = std::make_shared(chunk_size, is_last_nonzero_int); - auto computed_last_chunk_size = std::make_shared(input_size, computed_chunk_size); - auto computed_is_last_nonzero = std::make_shared(computed_last_chunk_size, const_0_nodim); - auto computed_chunks = std::make_shared(input_size, computed_chunk_size, true); + auto computed_chunk_size = rg.make(chunk_size, is_last_nonzero_int); + auto computed_last_chunk_size = rg.make(input_size, computed_chunk_size); + auto computed_is_last_nonzero = rg.make(computed_last_chunk_size, const_0_nodim); + auto computed_chunks = rg.make(input_size, computed_chunk_size, true); - auto is_slice_normal_size = std::make_shared(getitem_index, computed_chunks); - auto is_slice_not_normal_size = std::make_shared(getitem_index, computed_chunks); - auto is_slice_normal_size_int = std::make_shared(is_slice_normal_size, element::i64); - auto is_slice_not_normal_size_int = - std::make_shared(is_slice_not_normal_size, element::i64); + auto is_slice_normal_size = 
rg.make(getitem_index, computed_chunks); + auto is_slice_not_normal_size = rg.make(getitem_index, computed_chunks); + auto is_slice_normal_size_int = rg.make(is_slice_normal_size, element::i64); + auto is_slice_not_normal_size_int = rg.make(is_slice_not_normal_size, element::i64); - auto slice_size_lhs = std::make_shared(is_slice_normal_size_int, computed_chunk_size); - auto slice_size_rhs = - std::make_shared(is_slice_not_normal_size_int, computed_last_chunk_size); - auto slice_size = std::make_shared(slice_size_lhs, slice_size_rhs); + auto slice_size_lhs = rg.make(is_slice_normal_size_int, computed_chunk_size); + auto slice_size_rhs = rg.make(is_slice_not_normal_size_int, computed_last_chunk_size); + auto slice_size = rg.make(slice_size_lhs, slice_size_rhs); - auto slice_begin = std::make_shared(getitem_index, computed_chunk_size); - auto slice_end = std::make_shared(slice_begin, slice_size); + auto slice_begin = rg.make(getitem_index, computed_chunk_size); + auto slice_end = rg.make(slice_begin, slice_size); - auto sliced_chunk = std::make_shared(input_tensor, slice_begin, slice_end, const_1, dim); + auto sliced_chunk = rg.make(input_tensor, slice_begin, slice_end, const_1, dim); - copy_runtime_info({getitem, input_node}, sliced_chunk); replace_node(getitem, sliced_chunk); - - return true; + } else { + return false; } - - return false; + copy_runtime_info_and_name(getitem, rg.get(), {input_node}); + return true; }; auto m = std::make_shared(getitem, "ov::frontend::pytorch::pass::AtenGetItemReplacer"); diff --git a/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp index 6164b944a43..c73767840b1 100644 --- a/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/aten_index_put_replacer.cpp @@ -31,10 +31,12 @@ namespace pass { using namespace ov::op; namespace { -Output generate_zeros_with_convertlike(const Output sizes, const 
Output tensor_of_type) { +Output generate_zeros_with_convertlike(ov::pass::NodeRegistry& rg, + const Output sizes, + const Output tensor_of_type) { auto const_0 = v0::Constant::create(element::i32, Shape{}, {0}); - auto zeros = std::make_shared(const_0, sizes); - return std::make_shared(zeros, tensor_of_type); + auto zeros = rg.make(const_0, sizes); + return rg.make(zeros, tensor_of_type); } } // namespace @@ -46,18 +48,18 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() { if (!index_op) { return false; } - NodeVector rt_copy_from{index_op}; + NodeVector rt_copy_from; + ov::pass::NodeRegistry rg; auto const_0 = v0::Constant::create(element::i32, Shape{}, {0}); auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1}); auto const_max_int = v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits::max()}); auto const_neg_1 = v0::Constant::create(element::i32, Shape{}, {-1}); auto input = index_op->input_value(0); - auto input_shape = std::make_shared(input, element::i32); + auto input_shape = rg.make(input, element::i32); auto indices = index_op->input_value(1); auto values = index_op->input_value(2); - auto acc_const = - std::dynamic_pointer_cast(index_op->input_value(3).get_node_shared_ptr()); + auto acc_const = std::dynamic_pointer_cast(index_op->input_value(3).get_node_shared_ptr()); if (!acc_const) { add_exception_to_fw_node(index_op, "aten::index_put_: non constant accumulate input is not supported."); return false; @@ -85,12 +87,11 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() { return false; } indices_list_len = indices_first_dim.get_length(); - auto split = std::make_shared(indices, const_0, indices_list_len); + auto split = rg.make(indices, const_0, indices_list_len); indices_inputs = split->outputs(); } if (indices_list_len == 0) { - copy_runtime_info(rt_copy_from, values.get_node_shared_ptr()); replace_node(index_op, values.get_node_shared_ptr()); return true; } @@ -102,52 +103,51 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() { if 
(indices_list_len > 1) { index = indices_inputs[0]; for (int i = 1; i < indices_list_len; i++) { - index = std::make_shared(index, indices_inputs[i]); + index = rg.make(index, indices_inputs[i]); } - broadcast_index_shape = std::make_shared(index, element::i32); + broadcast_index_shape = rg.make(index, element::i32); OutputVector indices_list; for (int i = 0; i < indices_list_len; i++) { - auto broadcast = std::make_shared(indices_inputs[i], broadcast_index_shape); - auto unsqueeze = std::make_shared(broadcast, const_neg_1); + auto broadcast = rg.make(indices_inputs[i], broadcast_index_shape); + auto unsqueeze = rg.make(broadcast, const_neg_1); // change negative indices to positive indices auto const_i = v0::Constant::create(element::i32, Shape{}, {i}); - auto dim_i = std::make_shared(input_shape, const_i, const_0); - auto dim_i_correct_type = std::make_shared(dim_i, index); - auto unsqueeze_add = std::make_shared(unsqueeze, dim_i_correct_type); - auto unsqueeze_add_mod = std::make_shared(unsqueeze_add, dim_i_correct_type); + auto dim_i = rg.make(input_shape, const_i, const_0); + auto dim_i_correct_type = rg.make(dim_i, index); + auto unsqueeze_add = rg.make(unsqueeze, dim_i_correct_type); + auto unsqueeze_add_mod = rg.make(unsqueeze_add, dim_i_correct_type); indices_list.push_back(unsqueeze_add_mod); } - index = std::make_shared(indices_list, -1); + index = rg.make(indices_list, -1); } else { index = indices_inputs[0]; // change negative indices to positive indices - auto dim_0 = (std::make_shared(input_shape, const_0, const_0)); - auto dim_0_correct_type = (std::make_shared(dim_0, index)); - index = std::make_shared(index, dim_0_correct_type); - index = std::make_shared(index, dim_0_correct_type); + auto dim_0 = (rg.make(input_shape, const_0, const_0)); + auto dim_0_correct_type = (rg.make(dim_0, index)); + index = rg.make(index, dim_0_correct_type); + index = rg.make(index, dim_0_correct_type); - broadcast_index_shape = std::make_shared(index, element::i32); - 
index = std::make_shared(index, const_neg_1); + broadcast_index_shape = rg.make(index, element::i32); + index = rg.make(index, const_neg_1); } - auto sub_data_shape = std::make_shared(input_shape, const_indices_list_len, const_max_int, const_1); - auto values_shape = std::make_shared(OutputVector{broadcast_index_shape, sub_data_shape}, 0); - values = std::make_shared(values, values_shape); - values = std::make_shared(values, input); + auto sub_data_shape = rg.make(input_shape, const_indices_list_len, const_max_int, const_1); + auto values_shape = rg.make(OutputVector{broadcast_index_shape, sub_data_shape}, 0); + values = rg.make(values, values_shape); + values = rg.make(values, input); std::shared_ptr result; if (accumulate) { - auto zeros = generate_zeros_with_convertlike(input_shape, input); - auto scatter = std::make_shared(zeros, index, values); - result = std::make_shared(input, scatter); + auto zeros = generate_zeros_with_convertlike(rg, input_shape, input); + auto scatter = rg.make(zeros, index, values); + result = rg.make(input, scatter); } else { - result = std::make_shared(input, index, values); + result = rg.make(input, index, values); } - copy_runtime_info(rt_copy_from, result); + copy_runtime_info_and_name(index_op, rg.get(), rt_copy_from); replace_node(index_op, result); - result->set_friendly_name(index_op->get_friendly_name()); return true; }; diff --git a/src/frontends/pytorch/src/transforms/aten_index_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_index_replacer.cpp index 1afbb24b83a..0ae98cc322c 100644 --- a/src/frontends/pytorch/src/transforms/aten_index_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/aten_index_replacer.cpp @@ -33,9 +33,9 @@ namespace pytorch { namespace pass { using namespace ov::op; -namespace { -std::shared_ptr flatten(const Output& value, size_t axis) { +namespace { +Output flatten(ov::pass::NodeRegistry& rg, const Output& value, size_t axis) { // First dimension of output tensor is the product of [d_0, 
... d_{axis-1}] dimensions of // input tensor. The last dimension is the product of the rest of input tensor dimensions: // [d_{axis}, ..., d_n] @@ -45,20 +45,20 @@ std::shared_ptr flatten(const Output& value, size_t axis) { } else if (axis == 1) { output_shape = v0::Constant::create(element::i32, Shape{2}, {0, -1}); } else { - const auto value_shape = std::make_shared(value, element::i32); - const auto value_rank = std::make_shared(value_shape, element::i32); + const auto value_shape = rg.make(value, element::i32); + const auto value_rank = rg.make(value_shape, element::i32); const auto axis_node = v0::Constant::create(element::i32, Shape{1}, {axis}); auto start = v0::Constant::create(element::i32, Shape{1}, {0}); auto step = v0::Constant::create(element::i32, Shape{1}, {1}); - const auto first_part_dims = std::make_shared(value_shape, start, axis_node, step); + const auto first_part_dims = rg.make(value_shape, start, axis_node, step); auto zero = v0::Constant::create(element::i32, {}, {0}); - auto first_part_dims_length = std::make_shared(first_part_dims, zero, true); + auto first_part_dims_length = rg.make(first_part_dims, zero, true); auto remaining_part_length = v0::Constant::create(element::i32, {1}, {-1}); - output_shape = std::make_shared(OutputVector{first_part_dims_length, remaining_part_length}, 0); + output_shape = rg.make(OutputVector{first_part_dims_length, remaining_part_length}, 0); } - return std::make_shared(value, output_shape, true); + return rg.make(value, output_shape, true); } }; // namespace @@ -70,6 +70,7 @@ AtenIndexToSelect::AtenIndexToSelect() { if (!index_op) { return false; } + ov::pass::NodeRegistry rg; auto input_node = index_op->input_value(0); auto indicies = index_op->input_value(1).get_node_shared_ptr(); auto list_indicies = cast_fw_node(indicies, "prim::ListConstruct"); @@ -110,10 +111,10 @@ AtenIndexToSelect::AtenIndexToSelect() { } auto id_dtype = ids[i].get_element_type(); if (id_dtype == element::boolean || id_dtype == 
element::u8) { - auto idx = std::make_shared(ids[i], element::u8); - auto nonzero = std::make_shared(idx, element::i32); + auto idx = rg.make(ids[i], element::u8); + auto nonzero = rg.make(idx, element::i32); auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0}); - auto masked_id = std::make_shared(nonzero, input_order); + auto masked_id = rg.make(nonzero, input_order); masked_indicies.push_back(masked_id); is_masked_bool.push_back(true); } else { @@ -132,17 +133,15 @@ AtenIndexToSelect::AtenIndexToSelect() { if (advanced_ids.size() == 1) { auto index = masked_indicies[advanced_ids[0]]; if (is_masked_bool[advanced_ids[0]]) { - auto gather = std::make_shared(input_node, index); - copy_runtime_info({index_op, indicies}, gather); - gather->set_friendly_name(index_op->get_friendly_name()); + auto gather = rg.make(input_node, index); + copy_runtime_info_and_name(index_op, rg.get()); replace_node(index_op, gather); return true; } - index = std::make_shared(index, element::i32); + index = rg.make(index, element::i32); auto dim = v0::Constant::create(element::i32, Shape{}, {advanced_ids[0]}); - auto gather = std::make_shared(input_node, index, dim); - copy_runtime_info({index_op, indicies}, gather); - gather->set_friendly_name(index_op->get_friendly_name()); + auto gather = rg.make(input_node, index, dim); + copy_runtime_info_and_name(index_op, rg.get()); replace_node(index_op, gather); return true; } @@ -153,9 +152,9 @@ AtenIndexToSelect::AtenIndexToSelect() { add_exception_to_fw_node(index_op, "aten::index: dynamic rank for aten::index input is not supported."); return false; } - auto input_shape = std::make_shared(input_node, element::i32); + auto input_shape = rg.make(input_node, element::i32); auto zero = v0::Constant::create(element::i32, Shape{}, {0}); - auto input_dims = std::make_shared(input_shape, zero, rank.get_length()); + auto input_dims = rg.make(input_shape, zero, rank.get_length()); std::vector non_used_dims; for (auto i = 0; i < 
rank.get_length(); i++) { if (std::find(advanced_ids.begin(), advanced_ids.end(), i) == advanced_ids.end()) { @@ -166,23 +165,23 @@ AtenIndexToSelect::AtenIndexToSelect() { permutation_dims.insert(permutation_dims.end(), advanced_ids.begin(), advanced_ids.end()); permutation_dims.insert(permutation_dims.end(), non_used_dims.begin(), non_used_dims.end()); auto transpose_dims = v0::Constant::create(element::i32, Shape{permutation_dims.size()}, permutation_dims); - auto transposed_input = std::make_shared(input_node, transpose_dims); - auto flatten_input = flatten(transposed_input, adv_idx_count); + auto transposed_input = rg.make(input_node, transpose_dims); + auto flatten_input = flatten(rg, transposed_input, adv_idx_count); auto cum_adv_index = masked_indicies[advanced_ids[adv_idx_count - 1]]; - cum_adv_index = std::make_shared(cum_adv_index, element::i32); + cum_adv_index = rg.make(cum_adv_index, element::i32); auto multiplier = input_dims->output(advanced_ids[adv_idx_count - 1]); for (int i = static_cast(adv_idx_count) - 2; i > -1; i--) { - auto m_idx = std::make_shared(masked_indicies[i], element::i32); - auto adv_index = std::make_shared(m_idx, multiplier); - cum_adv_index = std::make_shared(cum_adv_index, adv_index); + auto m_idx = rg.make(masked_indicies[i], element::i32); + auto adv_index = rg.make(m_idx, multiplier); + cum_adv_index = rg.make(cum_adv_index, adv_index); auto input_id = advanced_ids[i]; - multiplier = std::make_shared(multiplier, input_dims->output(input_id)); + multiplier = rg.make(multiplier, input_dims->output(input_id)); } - std::shared_ptr gather = std::make_shared(flatten_input, cum_adv_index, zero); + std::shared_ptr gather = rg.make(flatten_input, cum_adv_index, zero); OutputVector concat_dims; // check if all advanced indices are consecutive. 
std::vector consequence_dims; - auto cum_adv_index_shape_tensor = std::make_shared(cum_adv_index, element::i32); + auto cum_adv_index_shape_tensor = rg.make(cum_adv_index, element::i32); for (size_t i = advanced_ids[0]; i <= advanced_ids[advanced_ids.size() - 1]; i++) { consequence_dims.push_back(i); } @@ -194,8 +193,8 @@ AtenIndexToSelect::AtenIndexToSelect() { for (auto i : non_used_dims) { folded_adv_idx_shape_vector.push_back(input_dims->output(i)); } - auto folded_adv_idx_shape = std::make_shared(folded_adv_idx_shape_vector, 0); - gather = std::make_shared(gather, folded_adv_idx_shape, false); + auto folded_adv_idx_shape = rg.make(folded_adv_idx_shape_vector, 0); + gather = rg.make(gather, folded_adv_idx_shape, false); std::vector adv_idx_permute; for (size_t i = 1; i < advanced_ids[0] + 1; i++) { adv_idx_permute.push_back(i); @@ -207,7 +206,7 @@ AtenIndexToSelect::AtenIndexToSelect() { // Transpose folded advanced indexed axis to its original location. auto permute_indicies = v0::Constant::create(element::i32, Shape{adv_idx_permute.size()}, adv_idx_permute); - gather = std::make_shared(gather, permute_indicies); + gather = rg.make(gather, permute_indicies); // unfold advanced index axes for (size_t i = 0; i < advanced_ids[0]; i++) { concat_dims.push_back(input_dims->output(i)); @@ -226,11 +225,10 @@ AtenIndexToSelect::AtenIndexToSelect() { concat_dims.push_back(input_dims->output(i)); } } - auto final_shape = std::make_shared(concat_dims, 0); - gather = std::make_shared(gather, final_shape, false); - copy_runtime_info({index_op, indicies}, gather); + auto final_shape = rg.make(concat_dims, 0); + gather = rg.make(gather, final_shape, false); + copy_runtime_info_and_name(index_op, rg.get()); replace_node(index_op, gather); - gather->set_friendly_name(index_op->get_friendly_name()); return true; } else { @@ -246,22 +244,21 @@ AtenIndexToSelect::AtenIndexToSelect() { } auto index_dtype = indicies->get_output_element_type(0); if (index_dtype == element::boolean || 
index_dtype == element::u8) { - auto nonzero = std::make_shared(indicies, element::i32); + auto nonzero = rg.make(indicies, element::i32); auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0}); - auto masked_id = std::make_shared(nonzero, input_order); - auto gather = std::make_shared(input_node, masked_id); - copy_runtime_info({index_op, indicies}, gather); + auto masked_id = rg.make(nonzero, input_order); + auto gather = rg.make(input_node, masked_id); + copy_runtime_info_and_name(index_op, rg.get()); replace_node(index_op, gather); return true; } if (index_dtype != element::i32) { - indicies = std::make_shared(indicies, element::i32); + indicies = rg.make(indicies, element::i32); } auto dim = v0::Constant::create(element::i32, Shape{}, {0}); - auto gather = std::make_shared(input_node, indicies, dim); - copy_runtime_info({index_op, indicies}, gather); + auto gather = rg.make(input_node, indicies, dim); + copy_runtime_info_and_name(index_op, rg.get()); replace_node(index_op, gather); - gather->set_friendly_name(index_op->get_friendly_name()); return true; } add_exception_to_fw_node(index_op, "Unsupported case of aten::index."); diff --git a/src/frontends/pytorch/src/transforms/aten_stack_list_construct_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_stack_list_construct_replacer.cpp index b811811d59f..f8de5275b69 100644 --- a/src/frontends/pytorch/src/transforms/aten_stack_list_construct_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/aten_stack_list_construct_replacer.cpp @@ -5,27 +5,30 @@ #include "aten_stack_list_construct_replacer.hpp" #include "openvino/core/rt_info.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/unsqueeze.hpp" #include "openvino/op/util/framework_node.hpp" -#include "openvino/opsets/opset10.hpp" #include "openvino/pass/pattern/matcher.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "utils.hpp" -using namespace ov::pass::pattern; - 
namespace ov { namespace frontend { namespace pytorch { namespace pass { +using namespace ov::op; +using namespace ov::pass::pattern; + AtenStackListConstructReplacer::AtenStackListConstructReplacer() { - auto list_construct = ov::pass::pattern::wrap_type(); - auto axis = ov::pass::pattern::wrap_type(); + auto list_construct = wrap_type(); + auto axis = wrap_type(); // We search for a pattern: ListConstruct -> aten::stack <- Constant - auto stack = ov::pass::pattern::wrap_type({list_construct, axis}); + auto stack = wrap_type({list_construct, axis}); - ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + ov::matcher_pass_callback callback = [=](Matcher& m) { auto stack = cast_fw_node(m.get_match_root(), "aten::stack"); if (!stack) { return false; @@ -33,23 +36,23 @@ AtenStackListConstructReplacer::AtenStackListConstructReplacer() { const auto& pattern_map = m.get_pattern_value_map(); auto input_node = pattern_map.at(list_construct).get_node_shared_ptr(); auto axis_node = pattern_map.at(axis).get_node_shared_ptr(); - auto axis_const = std::dynamic_pointer_cast(axis_node); + auto axis_const = std::dynamic_pointer_cast(axis_node); auto axis = axis_const->cast_vector(); // Check if ListConstruct is an input if (auto list_construct_node = cast_fw_node(input_node, "prim::ListConstruct")) { const auto& list_inputs = list_construct_node->input_values(); OutputVector node_vector; - auto zero = opset10::Constant::create(element::i32, Shape{}, {0}); + auto zero = v0::Constant::create(element::i32, Shape{}, {0}); // Iterate over values in ListConstruct for (const auto& list_input : list_inputs) { auto node = concat_list_construct(list_input); - auto unsqueezed_node = std::make_shared(node, axis_const); + auto unsqueezed_node = std::make_shared(node, axis_const); node_vector.push_back(unsqueezed_node); } // Concat vectors on provided axis - auto concat = std::make_shared(node_vector, axis[0]); + auto concat = std::make_shared(node_vector, axis[0]); - 
copy_runtime_info({stack, input_node}, concat); + copy_runtime_info_and_name(stack, {concat}, {input_node}); replace_node(stack, concat); return true; } diff --git a/src/frontends/pytorch/src/transforms/einsum_list_construct.cpp b/src/frontends/pytorch/src/transforms/einsum_list_construct.cpp index ca9f423ffc5..e9ca0c5fb94 100644 --- a/src/frontends/pytorch/src/transforms/einsum_list_construct.cpp +++ b/src/frontends/pytorch/src/transforms/einsum_list_construct.cpp @@ -50,7 +50,7 @@ AtenEinsumListConstructReplacer::AtenEinsumListConstructReplacer() { } auto einsum = std::make_shared(node_vector, equation); - copy_runtime_info({einsum_op, equation_input, tensor_list}, einsum); + copy_runtime_info_and_name(einsum_op, {einsum}, {equation_input, tensor_list}); replace_node(einsum_op, einsum); return true; } diff --git a/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp b/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp index c9ca09c376b..34154fd9cc4 100644 --- a/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp @@ -126,7 +126,7 @@ IndexLoopGetitemReplacer::IndexLoopGetitemReplacer() { auto stop = rg.make(start, chunks_size_body); auto curr_chunk = rg.make(chunk_param, start, stop, one_1d, dim_body); replace_node(getitem, curr_chunk); - copy_runtime_info({chunk_op, getitem}, rg.get()); + copy_runtime_info_and_name(chunk_op, rg.get(), {getitem}); curr_chunk->set_friendly_name(getitem->get_friendly_name()); return true; }; diff --git a/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp b/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp index 7f0ab3820f4..72c7d620592 100644 --- a/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp @@ -79,6 +79,7 @@ ListConstructReplacer::ListConstructReplacer() { // Concatenation is possible because 
all elements in list should be scalar or 1D tensors, // result should be 1D tensor. OutputVector inputs; + ov::pass::NodeRegistry rg; auto neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1}); const auto& start_output = list_node->output(0); for (const auto& input : get_list_as_outputs(start_output)) { @@ -94,13 +95,12 @@ ListConstructReplacer::ListConstructReplacer() { return false; } // reshape all elements to 1D - auto reshape = std::make_shared(input, neg_1, false); + auto reshape = rg.make(input, neg_1, false); inputs.push_back(reshape); } - auto concat = std::make_shared(inputs, 0); - copy_runtime_info({list_node}, concat); + auto concat = rg.make(inputs, 0); + copy_runtime_info_and_name(list_node, rg.get()); replace_node(list_node, concat); - concat->set_friendly_name(list_node->get_friendly_name()); return true; }; auto m = std::make_shared(lc_pattern, "ov::frontend::pytorch::pass::ListConstructReplacer"); diff --git a/src/frontends/pytorch/src/transforms/min_max_prim_list_construct_replacer.cpp b/src/frontends/pytorch/src/transforms/min_max_prim_list_construct_replacer.cpp index 2f38919df8e..dd672a94698 100644 --- a/src/frontends/pytorch/src/transforms/min_max_prim_list_construct_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/min_max_prim_list_construct_replacer.cpp @@ -23,6 +23,8 @@ namespace frontend { namespace pytorch { namespace pass { +using namespace ov::op; + MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() { auto op = ov::pass::pattern::wrap_type(); @@ -40,25 +42,26 @@ MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() { } else { op = max_op; } + ov::pass::NodeRegistry rg; auto input_node = op->input_value(0); auto num_inputs = op->inputs().size(); auto input = concat_list_construct(input_node); std::shared_ptr reduce_op; if (num_inputs == 1) { - auto start = std::make_shared(element::i32, Shape{}, 0); - auto step = std::make_shared(element::i32, Shape{}, 1); - auto shape = std::make_shared(input, 
element::i32); - auto rank = std::make_shared(shape, element::i32); - auto axis_0 = ov::op::v0::Constant::create(element::i32, Shape{}, {0}); - auto reduced_rank = std::make_shared(rank, axis_0); - auto axes = std::make_shared(start, reduced_rank, step, element::i32); + auto start = rg.make(element::i32, Shape{}, 0); + auto step = rg.make(element::i32, Shape{}, 1); + auto shape = rg.make(input, element::i32); + auto rank = rg.make(shape, element::i32); + auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0}); + auto reduced_rank = rg.make(rank, axis_0); + auto axes = rg.make(start, reduced_rank, step, element::i32); std::shared_ptr reduce_op; if (!is_min) { - reduce_op = std::make_shared(input, axes); + reduce_op = rg.make(input, axes); } else { - reduce_op = std::make_shared(input, axes); + reduce_op = rg.make(input, axes); } - copy_runtime_info({op, input_node.get_node_shared_ptr()}, reduce_op); + copy_runtime_info_and_name(op, rg.get()); replace_node(op, reduce_op); return true; } @@ -66,11 +69,11 @@ MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() { auto second_input = concat_list_construct(second_input_node); std::shared_ptr min_or_max_op; if (!is_min) { - min_or_max_op = std::make_shared(input, second_input); + min_or_max_op = rg.make(input, second_input); } else { - min_or_max_op = std::make_shared(input, second_input); + min_or_max_op = rg.make(input, second_input); } - copy_runtime_info({op, input_node.get_node_shared_ptr()}, min_or_max_op); + copy_runtime_info_and_name(op, rg.get()); replace_node(op, min_or_max_op); return true; }; diff --git a/src/frontends/pytorch/src/transforms/prim_list_construct_pad.cpp b/src/frontends/pytorch/src/transforms/prim_list_construct_pad.cpp index 02b0dbe38d6..ca8a1f1227f 100644 --- a/src/frontends/pytorch/src/transforms/prim_list_construct_pad.cpp +++ b/src/frontends/pytorch/src/transforms/prim_list_construct_pad.cpp @@ -30,7 +30,8 @@ namespace pass { using namespace ov::op; namespace { -Output 
create_padding(const Output& input_rank, +Output create_padding(ov::pass::NodeRegistry& rg, + const Output& input_rank, const Output& padding, const Output& start_id, const Output& end_id) { @@ -39,14 +40,14 @@ Output create_padding(const Output& input_rank, // OV expects paddings separated on begins and ends for each dimension from first to last auto minus_two = v0::Constant::create(element::i32, Shape{}, {-2}); auto zero = v0::Constant::create(element::i32, Shape{}, {0}); - auto pad_id_range = std::make_shared(start_id, end_id, minus_two, element::i32); - auto pads = std::make_shared(padding, pad_id_range, zero); + auto pad_id_range = rg.make(start_id, end_id, minus_two, element::i32); + auto pads = rg.make(padding, pad_id_range, zero); // add left side zero padding for difference between padding size and input rank - auto pads_short_len = std::make_shared(pads, element::i32); - auto pads_diff = std::make_shared(input_rank, pads_short_len); - auto pads_remaining = std::make_shared(zero, pads_diff); - auto pads_remaining_c = std::make_shared(pads_remaining, pads); - auto pads_full = std::make_shared(OutputVector{pads_remaining_c, pads}, 0); + auto pads_short_len = rg.make(pads, element::i32); + auto pads_diff = rg.make(input_rank, pads_short_len); + auto pads_remaining = rg.make(zero, pads_diff); + auto pads_remaining_c = rg.make(pads_remaining, pads); + auto pads_full = rg.make(OutputVector{pads_remaining_c, pads}, 0); return pads_full; } @@ -64,6 +65,7 @@ PrimListConstructPadReplacer::PrimListConstructPadReplacer() { if (!pad_op) { return false; } + ov::pass::NodeRegistry rg; auto minus_two = v0::Constant::create(element::i32, Shape{}, {-2}); auto minus_one = v0::Constant::create(element::i32, Shape{}, {-1}); auto zero = v0::Constant::create(element::i32, Shape{}, {0}); @@ -73,15 +75,15 @@ PrimListConstructPadReplacer::PrimListConstructPadReplacer() { auto pad_values = concat_list_construct(padding); std::string mode = "constant"; auto zero_f = 
v0::Constant::create(element::f32, Shape{}, {0}); - auto input_shape = std::make_shared(input_node, element::i32); - auto input_rank = std::make_shared(input_shape, element::i32); - auto pad_size_1d = std::make_shared(pad_values, element::i32); - auto pad_size = std::make_shared(pad_size_1d, zero); + auto input_shape = rg.make(input_node, element::i32); + auto input_rank = rg.make(input_shape, element::i32); + auto pad_size_1d = rg.make(pad_values, element::i32); + auto pad_size = rg.make(pad_size_1d, zero); // get pad_begins and pad_ends indexes starting for end of paddings - auto start_pad_begins = std::make_shared(pad_size, minus_two); - auto start_pad_ends = std::make_shared(pad_size, minus_one); - auto pad_begins_full = create_padding(input_rank, pad_values, start_pad_begins, minus_one); - auto pad_ends_full = create_padding(input_rank, pad_values, start_pad_ends, zero); + auto start_pad_begins = rg.make(pad_size, minus_two); + auto start_pad_ends = rg.make(pad_size, minus_one); + auto pad_begins_full = create_padding(rg, input_rank, pad_values, start_pad_begins, minus_one); + auto pad_ends_full = create_padding(rg, input_rank, pad_values, start_pad_ends, zero); auto mode_const = pad_op->input_value(2).get_node_shared_ptr(); auto pad_value = pad_op->input_value(3); if (const auto& fw_node_mode = cast_fw_node(mode_const, "prim::Constant")) { @@ -97,16 +99,16 @@ PrimListConstructPadReplacer::PrimListConstructPadReplacer() { pad_value = zero_f; } } - pad_value = std::make_shared(pad_value, input_node); + pad_value = rg.make(pad_value, input_node); } if (PAD_MODES.find(mode) == PAD_MODES.end()) { add_exception_to_fw_node(pad_op, "Unsupported mode: " + mode + "for aten::pad"); return false; } auto pad_mode = PAD_MODES.at(mode); - auto pad = std::make_shared(input_node, pad_begins_full, pad_ends_full, pad_value, pad_mode); + auto pad = rg.make(input_node, pad_begins_full, pad_ends_full, pad_value, pad_mode); replace_node(pad_op, pad); - copy_runtime_info({pad_op, 
padding.get_node_shared_ptr(), mode_const, pad_value.get_node_shared_ptr()}, pad); + copy_runtime_info_and_name(pad_op, rg.get()); pad->set_friendly_name(pad_op->get_friendly_name()); return true; }; diff --git a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp index 045546e4cd3..7bc4d39d353 100644 --- a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp @@ -28,6 +28,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { return false; auto input_node = list_unpack->input_value(0).get_node_shared_ptr(); + ov::pass::NodeRegistry rg; if (auto torch_split = cast_fw_node(input_node, "aten::split")) { auto rank = torch_split->input(1).get_partial_shape().rank(); if (rank.is_dynamic()) { @@ -44,19 +45,18 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { Shape{1}, {list_unpack->get_output_size() - 1}); auto const_neg_1 = opset10::Constant::create(split_size.get_element_type(), Shape{1}, {-1}); - auto split_lenghts_m_1 = std::make_shared(split_size, num_out_m_1); + auto split_lenghts_m_1 = rg.make(split_size, num_out_m_1); NodeVector concat_inputs{split_lenghts_m_1, const_neg_1}; - auto split_lenghts = std::make_shared(concat_inputs, 0); - split = std::make_shared(torch_split->get_input_source_output(0), - torch_split->get_input_source_output(2), - split_lenghts); + auto split_lenghts = rg.make(concat_inputs, 0); + split = rg.make(torch_split->get_input_source_output(0), + torch_split->get_input_source_output(2), + split_lenghts); } else { - split = std::make_shared(torch_split->get_input_source_output(0), - torch_split->get_input_source_output(2), - torch_split->get_input_source_output(1)); + split = rg.make(torch_split->get_input_source_output(0), + torch_split->get_input_source_output(2), + torch_split->get_input_source_output(1)); } - copy_runtime_info({list_unpack, input_node}, split); - 
split->set_friendly_name(input_node->get_friendly_name()); + copy_runtime_info_and_name(list_unpack, rg.get(), {input_node}); replace_node(list_unpack, split); return true; @@ -64,12 +64,11 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { if (auto split_with_sizes = cast_fw_node(input_node, "aten::split_with_sizes")) { auto split_lengths = concat_list_construct(split_with_sizes->get_input_source_output(1)); - auto split = std::make_shared(split_with_sizes->get_input_source_output(0), - split_with_sizes->get_input_source_output(2), - split_lengths); + auto split = rg.make(split_with_sizes->get_input_source_output(0), + split_with_sizes->get_input_source_output(2), + split_lengths); - copy_runtime_info({list_unpack, input_node}, split); - split->set_friendly_name(input_node->get_friendly_name()); + copy_runtime_info_and_name(list_unpack, rg.get(), {input_node}); replace_node(list_unpack, split); return true; @@ -87,27 +86,26 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { auto tensor_0 = opset10::Constant::create(element::i32, Shape{1}, {0}); auto tensor_neg_1 = opset10::Constant::create(element::i32, Shape{1}, {-1}); - auto input_shape = std::make_shared(input_tensor, element::i32); - auto input_dimension = std::make_shared(input_shape, dim, tensor_0); + auto input_shape = rg.make(input_tensor, element::i32); + auto input_dimension = rg.make(input_shape, dim, tensor_0); - auto init_chunk_size = std::make_shared(input_dimension, chunks, true); + auto init_chunk_size = rg.make(input_dimension, chunks, true); // Add 1 if input is not evenly divisible by chunks - auto last_chunk_size = std::make_shared(input_dimension, chunks); - auto is_last_nonzero = std::make_shared(last_chunk_size, tensor_0); - auto is_last_nonzero_int = std::make_shared(is_last_nonzero, element::i32); + auto last_chunk_size = rg.make(input_dimension, chunks); + auto is_last_nonzero = rg.make(last_chunk_size, tensor_0); + auto is_last_nonzero_int = rg.make(is_last_nonzero, element::i32); 
- auto chunk_size = std::make_shared(init_chunk_size, is_last_nonzero_int); + auto chunk_size = rg.make(init_chunk_size, is_last_nonzero_int); auto split_lengths_even_size = opset10::Constant::create(element::i32, Shape{1}, {list_unpack->get_output_size() - 1}); - auto split_lengths_even = std::make_shared(chunk_size, split_lengths_even_size); + auto split_lengths_even = rg.make(chunk_size, split_lengths_even_size); - auto split_lengths = std::make_shared(OutputVector{split_lengths_even, tensor_neg_1}, 0); - auto sliced_chunks = std::make_shared(input_tensor, dim, split_lengths); + auto split_lengths = rg.make(OutputVector{split_lengths_even, tensor_neg_1}, 0); + auto sliced_chunks = rg.make(input_tensor, dim, split_lengths); - copy_runtime_info({list_unpack, input_node}, sliced_chunks); - sliced_chunks->set_friendly_name(input_node->get_friendly_name()); + copy_runtime_info_and_name(list_unpack, rg.get(), {input_node}); replace_node(list_unpack, sliced_chunks); return true; @@ -117,51 +115,45 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { const auto input = unbind->get_input_source_output(0); const auto axis = unbind->get_input_source_output(1); const auto num_splits = list_unpack->get_output_size(); - auto split = std::make_shared(input, axis, num_splits); - NodeVector to_copy_rt{split}; + auto split = rg.make(input, axis, num_splits); OutputVector outputs; for (auto output : split->outputs()) { - const auto squeeze = std::make_shared(output, axis); + const auto squeeze = rg.make(output, axis); outputs.push_back(squeeze); - to_copy_rt.push_back(squeeze); } - copy_runtime_info({list_unpack, input_node}, to_copy_rt); + copy_runtime_info_and_name(list_unpack, rg.get(), {input_node}); replace_node(list_unpack, outputs); return true; } if (auto where = cast_fw_node(input_node, "aten::where")) { const auto input = where->get_input_source_output(0); - auto non_zero = std::make_shared(input); + auto non_zero = rg.make(input); auto axis = 
opset10::Constant::create(element::i32, Shape{}, {0}); const auto num_splits = list_unpack->get_output_size(); - auto split = std::make_shared(non_zero, axis, num_splits); - NodeVector to_copy_rt{split}; + auto split = rg.make(non_zero, axis, num_splits); OutputVector outputs; for (auto output : split->outputs()) { - const auto squeeze = std::make_shared(output, axis); + const auto squeeze = rg.make(output, axis); outputs.push_back(squeeze); - to_copy_rt.push_back(squeeze); } - copy_runtime_info({list_unpack, input_node}, to_copy_rt); + copy_runtime_info_and_name(list_unpack, rg.get(), {input_node}); replace_node(list_unpack, outputs); return true; } if (auto nonzero_numpy = cast_fw_node(input_node, "aten::nonzero_numpy")) { const auto input = nonzero_numpy->get_input_source_output(0); - auto non_zero = std::make_shared(input); + auto non_zero = rg.make(input); auto axis = opset10::Constant::create(element::i32, Shape{}, {0}); const auto num_splits = list_unpack->get_output_size(); - auto split = std::make_shared(non_zero, axis, num_splits); - NodeVector to_copy_rt{split}; + auto split = rg.make(non_zero, axis, num_splits); OutputVector outputs; for (auto output : split->outputs()) { - const auto squeeze = std::make_shared(output, axis); + const auto squeeze = rg.make(output, axis); outputs.push_back(squeeze); - to_copy_rt.push_back(squeeze); } - copy_runtime_info({list_unpack, input_node}, to_copy_rt); + copy_runtime_info_and_name(list_unpack, rg.get(), {input_node}); replace_node(list_unpack, outputs); return true; @@ -175,7 +167,6 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { add_exception_to_fw_node(input_node, "aten::meshgrid: only prim::ListConstruct supported as input."); return false; } - NodeVector rt_copy_from{list_unpack, input_node, meshgrid_input_node}; OutputVector meshgrid_inputs; for (auto& input : meshgrid_input_node->inputs()) { meshgrid_inputs.push_back(input.get_source_output()); @@ -203,29 +194,26 @@ 
PrimListUnpackReplacer::PrimListUnpackReplacer() { auto const_1 = opset10::Constant::create(element::i32, Shape{1}, {1}); int input_idx = 0; for (auto& input : meshgrid_inputs) { - auto reshaped_input = std::make_shared(input, const_neg_1, false); - auto shape = std::make_shared(reshaped_input, element::i32); + auto reshaped_input = rg.make(input, const_neg_1, false); + auto shape = rg.make(reshaped_input, element::i32); cat_shapes.push_back(shape); NodeVector cat_inputs(meshgrid_inputs.size(), const_1); cat_inputs[input_idx] = shape; input_idx++; - auto input_cat = std::make_shared(cat_inputs, 0); - auto reshape_cat = std::make_shared(reshaped_input, input_cat, false); + auto input_cat = rg.make(cat_inputs, 0); + auto reshape_cat = rg.make(reshaped_input, input_cat, false); reshapes.push_back(reshape_cat); } - auto cat = std::make_shared(cat_shapes, 0); - NodeVector to_copy_rt{cat}; - to_copy_rt.push_back(cat); + auto cat = rg.make(cat_shapes, 0); OutputVector outputs{}; for (auto& reshape : reshapes) { - auto out = std::make_shared(reshape, cat, ov::op::BroadcastType::BIDIRECTIONAL); - to_copy_rt.push_back(out); + auto out = rg.make(reshape, cat, ov::op::BroadcastType::BIDIRECTIONAL); outputs.push_back(out); } if (indexing == "xy" && meshgrid_inputs.size() >= 2) { std::swap(outputs[0], outputs[1]); } - copy_runtime_info(rt_copy_from, to_copy_rt); + copy_runtime_info_and_name(list_unpack, rg.get(), {input_node, meshgrid_input_node}); replace_node(list_unpack, outputs); return true; } @@ -234,17 +222,15 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { // case aten::size as input // Number of ListUnpack outputs should be equal to rank of input shape. 
auto axis_0 = opset10::Constant::create(element::i32, Shape{}, {0}); - auto split = std::make_shared(shape_of, axis_0, list_unpack->get_output_size()); + auto split = rg.make(shape_of, axis_0, list_unpack->get_output_size()); - NodeVector to_copy_rt{axis_0, split}; OutputVector res; for (auto output : split->outputs()) { - auto squeeze = std::make_shared(output, axis_0); - to_copy_rt.push_back(squeeze); + auto squeeze = rg.make(output, axis_0); res.push_back(squeeze); } - copy_runtime_info({list_unpack, input_node}, to_copy_rt); + copy_runtime_info_and_name(list_unpack, rg.get(), {input_node}); replace_node(list_unpack, res); return true; @@ -254,17 +240,15 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() { // case aten::slice as input // Number of ListUnpack outputs should be equal to rank of input shape. auto axis_0 = opset10::Constant::create(element::i32, Shape{}, {0}); - auto split = std::make_shared(slice, axis_0, list_unpack->get_output_size()); + auto split = rg.make(slice, axis_0, list_unpack->get_output_size()); - NodeVector to_copy_rt{axis_0, split}; OutputVector res; for (auto output : split->outputs()) { - auto squeeze = std::make_shared(output, axis_0); - to_copy_rt.push_back(squeeze); + auto squeeze = rg.make(output, axis_0); res.push_back(squeeze); } - copy_runtime_info({list_unpack, input_node}, to_copy_rt); + copy_runtime_info_and_name(list_unpack, rg.get(), {input_node}); replace_node(list_unpack, res); return true; diff --git a/src/frontends/pytorch/src/transforms/string_equality_replacer.cpp b/src/frontends/pytorch/src/transforms/string_equality_replacer.cpp index 11315353a3b..0219600799a 100644 --- a/src/frontends/pytorch/src/transforms/string_equality_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/string_equality_replacer.cpp @@ -67,14 +67,14 @@ StringEqualityReplacer::StringEqualityReplacer() { auto equal_node = pattern_map.at(equal_op).get_node_shared_ptr(); if (auto equal = std::dynamic_pointer_cast(equal_node)) { auto 
const_result = v0::Constant::create(element::boolean, Shape{}, {lhs == rhs}); - copy_runtime_info({lhs_node, rhs_node, equal_node}, const_result); + copy_runtime_info_and_name(equal_node, {const_result}); replace_node(equal_node, const_result); return true; }; auto not_equal_node = pattern_map.at(not_equal_op).get_node_shared_ptr(); if (auto equal = std::dynamic_pointer_cast(not_equal_node)) { auto const_result = v0::Constant::create(element::boolean, Shape{}, {lhs != rhs}); - copy_runtime_info({lhs_node, rhs_node, not_equal_node}, const_result); + copy_runtime_info_and_name(equal_node, {const_result}); replace_node(equal_node, const_result); return true; }; diff --git a/src/frontends/pytorch/src/translate_session.cpp b/src/frontends/pytorch/src/translate_session.cpp index 86894129fe0..594d74a8656 100644 --- a/src/frontends/pytorch/src/translate_session.cpp +++ b/src/frontends/pytorch/src/translate_session.cpp @@ -54,7 +54,26 @@ std::shared_ptr TranslateSession::get_converted_model() { std::shared_ptr TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& input_model) { auto pytorch_model = std::dynamic_pointer_cast(input_model); FRONT_END_GENERAL_CHECK(pytorch_model != nullptr, "Invalid input model"); - return convert_pytorch_model(pytorch_model->m_model_decoder, {}, pytorch_model->m_descriptors); + auto model = convert_pytorch_model(pytorch_model->m_model_decoder, {}, pytorch_model->m_descriptors); + // First delete tensor indexes from outputs then resolve input names, otherwise Parameter->Result will fail + for (auto& result : model->get_results()) { + auto tensor_desc = result->input_value(0); + auto names = tensor_desc.get_names(); + if (!names.empty()) { + auto tensor_idx = decode_tensor_name(tensor_desc); + if (names.erase(std::to_string(tensor_idx))) { + tensor_desc.set_names(names); + } + } + } + // Set input tensor names to be equal to signature name saved in friendly name + for (auto& param : model->get_parameters()) { + if 
(param->get_friendly_name() != param->get_name()) { + // get_name is autogenerated name, we need to make sure that this parameter was named by frontend + param->output(0).set_names({param->get_friendly_name()}); + } + } + return model; } std::shared_ptr TranslateSession::convert_pytorch_model( @@ -91,10 +110,8 @@ std::shared_ptr TranslateSession::convert_pytorch_model( } if (!input_node) { auto parameter = std::make_shared(type, pshape); - encode_tensor_name( - parameter->output(0), - inputs.at(i), - {pytorch_model->get_input_debug_name(i), pytorch_model->get_input_signature_name(i)}); + parameter->set_friendly_name(pytorch_model->get_input_signature_name(i)); + encode_tensor_name(parameter->output(0), inputs.at(i), {pytorch_model->get_input_debug_name(i)}); parameters->push_back(parameter); input_node = parameter; auto order = pytorch_model->get_input_transpose_order(i); @@ -404,6 +421,14 @@ Output TranslateSession::get_backprop_op(const std::shared_ptr(node, OutputVector{value}, 1, true); } +void TranslateSession::unique_name(const std::shared_ptr& node) { + if (m_unique_friendly_name_set.count(node->get_friendly_name())) { + node->set_friendly_name(node->get_friendly_name() + '_' + std::to_string(m_friendly_name_counter++)); + } else { + m_unique_friendly_name_set.insert(node->get_friendly_name()); + } +} + } // namespace pytorch } // namespace frontend } // namespace ov diff --git a/src/frontends/pytorch/src/translate_session.hpp b/src/frontends/pytorch/src/translate_session.hpp index 940edf1d867..6cf10d81180 100644 --- a/src/frontends/pytorch/src/translate_session.hpp +++ b/src/frontends/pytorch/src/translate_session.hpp @@ -48,7 +48,8 @@ public: /// \brief Gets pytorch tensor index from openvino tensor size_t decode_tensor_name(const Output& tensor_desc); - size_t m_friendly_name_counter = 0; + /// \brief Make sure Node has unique name + void unique_name(const std::shared_ptr& node); // Maps tensor index to initial tensor index which it is alias to, and to 
decoder of the node produced this alias // and to the output produced during conversion of this node @@ -64,6 +65,8 @@ private: std::map>> m_counter_map; std::map m_op_statistics; + std::unordered_set m_unique_friendly_name_set; + size_t m_friendly_name_counter = 0; }; } // namespace pytorch diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index bda0c085e3c..b6e9d0972be 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -5,6 +5,7 @@ #include "utils.hpp" #include "op_table.hpp" +#include "openvino/core/rt_info.hpp" #include "openvino/frontend/pytorch/decoder.hpp" #include "openvino/opsets/opset10.hpp" #include "openvino/util/log.hpp" @@ -497,6 +498,30 @@ void add_exception_to_fw_node(std::shared_ptr node, const std::string& msg } } +void copy_runtime_info_and_name(const std::shared_ptr& from, + ov::NodeVector to, + const ov::NodeVector& additional_rt_info_src) { + if (to.size() == 1) { + // We do 1 to 1 matching, no need to process names, just inherit initial name + to[0]->set_friendly_name(from->get_friendly_name()); + } else { + std::unordered_set unique_names; + size_t idx = 0; + for (auto& op : to) { + auto new_name = from->get_friendly_name() + '/' + op->get_type_name(); + if (unique_names.count(new_name)) { + new_name += '_' + std::to_string(idx++); + } else { + unique_names.insert(new_name); + } + op->set_friendly_name(new_name); + } + } + copy_runtime_info(from, to); + if (!additional_rt_info_src.empty()) + copy_runtime_info(additional_rt_info_src, to); +} + } // namespace pytorch } // namespace frontend } // namespace ov diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp index 32bb115f545..a3732b681ae 100644 --- a/src/frontends/pytorch/src/utils.hpp +++ b/src/frontends/pytorch/src/utils.hpp @@ -72,6 +72,10 @@ void align_output_types(const NodeContext& context, OutputVector& outputs); std::deque> get_list_as_outputs(const Output& start); +void 
copy_runtime_info_and_name(const std::shared_ptr& from, + ov::NodeVector to, + const ov::NodeVector& additional_rt_info_src = {}); + namespace op { template OutputVector inplace_op(const NodeContext& context) {