[PT FE] Improve operation naming (#17997)

* Improve operation naming

* Use set to reduce operations with indexes, fix code style

* Refactor

* Make good names in transformations

* Remove tensor index from input/output names

* Fix tests
This commit is contained in:
Maxim Vafin 2023-06-15 14:28:38 +02:00 committed by GitHub
parent 483a040d52
commit 9684f9184a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 354 additions and 314 deletions

View File

@ -319,6 +319,13 @@ class TorchScriptPythonDecoder (Decoder):
return self.outputs()[index] return self.outputs()[index]
def mark_node(self, node): def mark_node(self, node):
name = self.graph_element.kind()
if "FrameworkNode" not in node.get_type_name():
name += "/" + node.get_type_name()
if self.graph_element.scopeName():
node.set_friendly_name(self.graph_element.scopeName().split("/")[-1] + "/" + name)
else:
node.set_friendly_name(name)
return node return node
def try_decode_get_attr(self): def try_decode_get_attr(self):

View File

@ -15,6 +15,7 @@ InputModel::InputModel(const std::shared_ptr<TorchDecoder>& model_decoder) : m_m
const auto& inputs = m_model_decoder->inputs(); const auto& inputs = m_model_decoder->inputs();
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
auto in_place = std::make_shared<pytorch::Place>(*this, inputs[i]); auto in_place = std::make_shared<pytorch::Place>(*this, inputs[i]);
m_name_to_place.emplace(std::to_string(inputs[i]), std::dynamic_pointer_cast<frontend::Place>(in_place));
for (const auto& name : in_place->get_names()) { for (const auto& name : in_place->get_names()) {
m_name_to_place.emplace(name, std::dynamic_pointer_cast<frontend::Place>(in_place)); m_name_to_place.emplace(name, std::dynamic_pointer_cast<frontend::Place>(in_place));
} }
@ -28,6 +29,7 @@ InputModel::InputModel(const std::shared_ptr<TorchDecoder>& model_decoder) : m_m
const auto& outputs = m_model_decoder->outputs(); const auto& outputs = m_model_decoder->outputs();
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
auto out_place = std::make_shared<pytorch::Place>(*this, outputs[i]); auto out_place = std::make_shared<pytorch::Place>(*this, outputs[i]);
m_name_to_place.emplace(std::to_string(inputs[i]), std::dynamic_pointer_cast<frontend::Place>(out_place));
for (const auto& name : out_place->get_names()) { for (const auto& name : out_place->get_names()) {
m_name_to_place.emplace(name, std::dynamic_pointer_cast<frontend::Place>(out_place)); m_name_to_place.emplace(name, std::dynamic_pointer_cast<frontend::Place>(out_place));
} }

View File

@ -40,23 +40,24 @@ OutputVector NodeContext::as_constant() const {
fw_node->set_attrs(attrs); fw_node->set_attrs(attrs);
return {fw_node}; return {fw_node};
} else { } else {
return m_decoder->as_constant(); auto c_outs = m_decoder->as_constant();
FRONT_END_OP_CONVERSION_CHECK(c_outs.size() == 1, "Constant must have exactly one output.");
c_outs[0].get_node_shared_ptr()->set_friendly_name(m_decoder->get_output_debug_name(0));
return c_outs;
} }
} }
std::shared_ptr<Node> NodeContext::mark_node(std::shared_ptr<Node> ov_node) const { std::shared_ptr<Node> NodeContext::mark_node(std::shared_ptr<Node> ov_node) const {
ov_node->set_friendly_name(get_op_type() + '_' + std::to_string(m_translate_session->m_friendly_name_counter++)); ov_node = m_decoder->mark_node(ov_node);
return m_decoder->mark_node(ov_node); m_translate_session->unique_name(ov_node);
return ov_node;
} }
void NodeContext::mutate_input(size_t index, Output<Node> ov_output) const { void NodeContext::mutate_input(size_t index, Output<Node> ov_output) const {
FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index), "Input is none with index: ", index); FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index), "Input is none with index: ", index);
auto input_id = m_decoder_inputs.at(index); auto input_id = m_decoder_inputs.at(index);
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input_id), "No tensor corresponding input: ", input_id, " exist."); FRONT_END_GENERAL_CHECK(m_tensor_map->count(input_id), "No tensor corresponding input: ", input_id, " exist.");
m_translate_session->encode_tensor_name( m_translate_session->encode_tensor_name(ov_output, input_id, {m_decoder->get_input_debug_name(index)});
ov_output,
input_id,
{m_decoder->get_input_debug_name(index), m_decoder->get_input_signature_name(index)});
(*m_tensor_map)[input_id] = ov_output; (*m_tensor_map)[input_id] = ov_output;
m_mutated_tensors->insert(input_id); m_mutated_tensors->insert(input_id);

View File

@ -18,7 +18,6 @@ Place::Place(const ov::frontend::InputModel& input_model, size_t tensor_index)
m_tensor_index(tensor_index), m_tensor_index(tensor_index),
m_is_input(false), m_is_input(false),
m_is_output(false) { m_is_output(false) {
m_names.push_back(std::to_string(tensor_index));
const auto im = dynamic_cast<const ov::frontend::pytorch::InputModel*>(&m_input_model); const auto im = dynamic_cast<const ov::frontend::pytorch::InputModel*>(&m_input_model);
FRONT_END_GENERAL_CHECK(im, "PyTorch Place requires PyTorch InputModel class."); FRONT_END_GENERAL_CHECK(im, "PyTorch Place requires PyTorch InputModel class.");
const auto& inputs = im->m_model_decoder->inputs(); const auto& inputs = im->m_model_decoder->inputs();
@ -26,24 +25,16 @@ Place::Place(const ov::frontend::InputModel& input_model, size_t tensor_index)
auto in_it = std::find(inputs.begin(), inputs.end(), tensor_index); auto in_it = std::find(inputs.begin(), inputs.end(), tensor_index);
if (in_it != inputs.end()) { if (in_it != inputs.end()) {
m_is_input = true; m_is_input = true;
const auto& debug_name = im->m_model_decoder->get_input_debug_name(std::distance(inputs.begin(), in_it));
if (debug_name != m_names.at(0)) {
m_names.push_back(debug_name);
}
const auto& signature_name = const auto& signature_name =
im->m_model_decoder->get_input_signature_name(std::distance(inputs.begin(), in_it)); im->m_model_decoder->get_input_signature_name(std::distance(inputs.begin(), in_it));
if (signature_name != m_names.at(0) && signature_name != debug_name) {
m_names.push_back(signature_name); m_names.push_back(signature_name);
} }
}
auto out_it = std::find(outputs.begin(), outputs.end(), tensor_index); auto out_it = std::find(outputs.begin(), outputs.end(), tensor_index);
if (out_it != outputs.end()) { if (out_it != outputs.end()) {
m_is_output = true; m_is_output = true;
const auto& debug_name = im->m_model_decoder->get_output_debug_name(std::distance(outputs.begin(), out_it)); const auto& debug_name = im->m_model_decoder->get_output_debug_name(std::distance(outputs.begin(), out_it));
if (debug_name != m_names.at(0)) {
m_names.push_back(debug_name); m_names.push_back(debug_name);
} }
}
if (m_is_input && m_is_output) { if (m_is_input && m_is_output) {
OPENVINO_DEBUG << "[WARNING] Place " << tensor_index << " is input and output at a same time."; OPENVINO_DEBUG << "[WARNING] Place " << tensor_index << " is input and output at a same time.";
} }

View File

@ -21,6 +21,8 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace pass { namespace pass {
using namespace ov::op;
AppendListUnpackReplacer::AppendListUnpackReplacer() { AppendListUnpackReplacer::AppendListUnpackReplacer() {
auto list_unpack = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>(); auto list_unpack = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
@ -30,7 +32,7 @@ AppendListUnpackReplacer::AppendListUnpackReplacer() {
return false; return false;
OutputVector tmp_inputs; OutputVector tmp_inputs;
NodeVector rt_copy_from{list_unpack}; NodeVector rt_copy_from;
auto input_node = list_unpack->input_value(0).get_node_shared_ptr(); auto input_node = list_unpack->input_value(0).get_node_shared_ptr();
// Optional aten::__getitem__ node. // Optional aten::__getitem__ node.
@ -60,7 +62,7 @@ AppendListUnpackReplacer::AppendListUnpackReplacer() {
// If aten::__getitem__, expect inputs to be equivalent of pytorch Tensor[][]. // If aten::__getitem__, expect inputs to be equivalent of pytorch Tensor[][].
// Tensor selected by aten::__getitem__ index needs to be splitted in axis 0. // Tensor selected by aten::__getitem__ index needs to be splitted in axis 0.
auto getitem_index_ptr = getitem_node->input_value(1).get_node_shared_ptr(); auto getitem_index_ptr = getitem_node->input_value(1).get_node_shared_ptr();
auto getitem_index_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(getitem_index_ptr); auto getitem_index_const = std::dynamic_pointer_cast<v0::Constant>(getitem_index_ptr);
auto index_val = getitem_index_const->cast_vector<int64_t>(); auto index_val = getitem_index_const->cast_vector<int64_t>();
if (index_val.size() != 1) { if (index_val.size() != 1) {
add_exception_to_fw_node(list_unpack, "prim::ListUnpack: index of aten::__getitem__ is not scalar."); add_exception_to_fw_node(list_unpack, "prim::ListUnpack: index of aten::__getitem__ is not scalar.");
@ -70,16 +72,16 @@ AppendListUnpackReplacer::AppendListUnpackReplacer() {
if (index_val[0] < 0) { if (index_val[0] < 0) {
index = inputs.size() + index; index = inputs.size() + index;
} }
auto axis_0 = ov::op::v0::Constant::create(element::i32, Shape{}, {0}); auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0});
auto split = std::make_shared<ov::op::v1::Split>(inputs[index], axis_0, list_unpack->get_output_size()); auto split = std::make_shared<v1::Split>(inputs[index], axis_0, list_unpack->get_output_size());
NodeVector to_copy_rt{axis_0, split}; NodeVector to_copy_rt{axis_0, split};
OutputVector res; OutputVector res;
for (auto output : split->outputs()) { for (auto output : split->outputs()) {
auto squeeze = std::make_shared<ov::op::v0::Squeeze>(output, axis_0); auto squeeze = std::make_shared<v0::Squeeze>(output, axis_0);
to_copy_rt.push_back(squeeze); to_copy_rt.push_back(squeeze);
res.push_back(squeeze); res.push_back(squeeze);
} }
copy_runtime_info(rt_copy_from, to_copy_rt); copy_runtime_info_and_name(list_unpack, to_copy_rt, rt_copy_from);
replace_node(list_unpack, res); replace_node(list_unpack, res);
return true; return true;
} else { } else {

View File

@ -21,6 +21,8 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace pass { namespace pass {
using namespace ov::op;
// aten::cat needs a special handling since it takes a Tensor[] as input. We set the inputs of ListConstruct as the // aten::cat needs a special handling since it takes a Tensor[] as input. We set the inputs of ListConstruct as the
// inputs of cat. // inputs of cat.
// //
@ -41,7 +43,7 @@ AtenCatToConcat::AtenCatToConcat() {
int64_t axis; int64_t axis;
if (cat->get_input_size() > 1) { if (cat->get_input_size() > 1) {
auto axis_node = cat->get_input_node_shared_ptr(1); auto axis_node = cat->get_input_node_shared_ptr(1);
auto axis_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(axis_node); auto axis_const = std::dynamic_pointer_cast<v0::Constant>(axis_node);
if (!axis_const) { if (!axis_const) {
add_exception_to_fw_node(cat, "aten::cat unsupported case: axis is not a constant."); add_exception_to_fw_node(cat, "aten::cat unsupported case: axis is not a constant.");
return false; return false;
@ -62,7 +64,7 @@ AtenCatToConcat::AtenCatToConcat() {
} }
std::shared_ptr<Node> input_node = cat->get_input_node_shared_ptr(0); std::shared_ptr<Node> input_node = cat->get_input_node_shared_ptr(0);
if (auto loop = std::dynamic_pointer_cast<ov::op::v5::Loop>(input_node)) { if (auto loop = std::dynamic_pointer_cast<v5::Loop>(input_node)) {
// case when concatenation is done inside the Loop // case when concatenation is done inside the Loop
auto body = loop->get_function(); auto body = loop->get_function();
auto output_index = cat->input(0).get_source_output().get_index(); auto output_index = cat->input(0).get_source_output().get_index();
@ -82,7 +84,7 @@ AtenCatToConcat::AtenCatToConcat() {
"aten::cat unsupported case: aten::append wasn't found inside prim::Loop body."); "aten::cat unsupported case: aten::append wasn't found inside prim::Loop body.");
return false; return false;
} }
auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(append->get_input_node_shared_ptr(0)); auto param = std::dynamic_pointer_cast<v0::Parameter>(append->get_input_node_shared_ptr(0));
if (!param) { if (!param) {
add_exception_to_fw_node(cat, add_exception_to_fw_node(cat,
"aten::cat unsupported case: input of aten::append inside prim::Loop " "aten::cat unsupported case: input of aten::append inside prim::Loop "
@ -106,7 +108,7 @@ AtenCatToConcat::AtenCatToConcat() {
"body is not a prim::ListConstruct."); "body is not a prim::ListConstruct.");
return false; return false;
} }
auto new_result = std::make_shared<ov::op::v0::Result>(append->input_value(1)); auto new_result = std::make_shared<v0::Result>(append->input_value(1));
body->add_results({new_result}); body->add_results({new_result});
auto new_output = loop->get_concatenated_slices(new_result, 0, 1, 1, -1, axis); auto new_output = loop->get_concatenated_slices(new_result, 0, 1, 1, -1, axis);
copy_runtime_info(cat, loop); copy_runtime_info(cat, loop);
@ -115,10 +117,9 @@ AtenCatToConcat::AtenCatToConcat() {
} }
const auto&& tmp_inputs = get_list_as_outputs(cat->get_input_source_output(0)); const auto&& tmp_inputs = get_list_as_outputs(cat->get_input_source_output(0));
auto result = std::make_shared<ov::op::v0::Concat>(OutputVector(tmp_inputs.begin(), tmp_inputs.end()), axis); auto result = std::make_shared<v0::Concat>(OutputVector(tmp_inputs.begin(), tmp_inputs.end()), axis);
copy_runtime_info(cat, result); copy_runtime_info_and_name(cat, {result});
replace_node(cat, result); replace_node(cat, result);
result->set_friendly_name(cat->get_friendly_name());
return true; return true;
}; };

View File

@ -14,6 +14,10 @@
#include "openvino/op/convert.hpp" #include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp" #include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp" #include "openvino/op/gather.hpp"
#include "openvino/op/greater.hpp"
#include "openvino/op/greater_eq.hpp"
#include "openvino/op/less.hpp"
#include "openvino/op/mod.hpp"
#include "openvino/op/multiply.hpp" #include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp" #include "openvino/op/range.hpp"
#include "openvino/op/shape_of.hpp" #include "openvino/op/shape_of.hpp"
@ -22,7 +26,6 @@
#include "openvino/op/unsqueeze.hpp" #include "openvino/op/unsqueeze.hpp"
#include "openvino/op/util/framework_node.hpp" #include "openvino/op/util/framework_node.hpp"
#include "openvino/op/variadic_split.hpp" #include "openvino/op/variadic_split.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/matcher.hpp" #include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp"
#include "pt_framework_node.hpp" #include "pt_framework_node.hpp"
@ -33,6 +36,8 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace pass { namespace pass {
using namespace ov::op;
AtenGetItemReplacer::AtenGetItemReplacer() { AtenGetItemReplacer::AtenGetItemReplacer() {
auto getitem = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>(); auto getitem = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
@ -41,6 +46,7 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
if (!getitem) if (!getitem)
return false; return false;
ov::pass::NodeRegistry rg;
auto input_node = getitem->input_value(0).get_node_shared_ptr(); auto input_node = getitem->input_value(0).get_node_shared_ptr();
if (auto torch_split = cast_fw_node(input_node, "aten::split")) { if (auto torch_split = cast_fw_node(input_node, "aten::split")) {
auto rank = torch_split->input(1).get_partial_shape().rank(); auto rank = torch_split->input(1).get_partial_shape().rank();
@ -51,51 +57,46 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
if (rank.get_length() == 0) { if (rank.get_length() == 0) {
// Based on slice_size and output index select size. // Based on slice_size and output index select size.
// Constants required by transformation. // Constants required by transformation.
auto const_1 = ov::op::v0::Constant::create(element::i32, Shape{1}, {1}); auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1});
auto const_1_0d = ov::op::v0::Constant::create(element::i32, Shape{}, {1}); auto const_1_0d = v0::Constant::create(element::i32, Shape{}, {1});
auto const_0 = ov::op::v0::Constant::create(element::i32, Shape{1}, {0}); auto const_0 = v0::Constant::create(element::i32, Shape{1}, {0});
auto const_0_0d = ov::op::v0::Constant::create(element::i32, Shape{}, {0}); auto const_0_0d = v0::Constant::create(element::i32, Shape{}, {0});
// Load and convert op inputs. // Load and convert op inputs.
auto input = torch_split->get_input_source_output(0); auto input = torch_split->get_input_source_output(0);
auto split_size = torch_split->get_input_source_output(1); auto split_size = torch_split->get_input_source_output(1);
auto split_size_1d = std::make_shared<ov::op::v0::Unsqueeze>(split_size, const_0); auto split_size_1d = rg.make<v0::Unsqueeze>(split_size, const_0);
auto axis = torch_split->get_input_source_output(2); auto axis = torch_split->get_input_source_output(2);
auto axis_1d = std::make_shared<ov::op::v0::Unsqueeze>(axis, const_0); auto axis_1d = rg.make<v0::Unsqueeze>(axis, const_0);
auto getitem_idx = getitem->input(1).get_source_output(); auto getitem_idx = getitem->input(1).get_source_output();
// Calculate number of splits based on input shape and split_size. // Calculate number of splits based on input shape and split_size.
auto shape = std::make_shared<ov::op::v3::ShapeOf>(input, element::i32); auto shape = rg.make<v3::ShapeOf>(input, element::i32);
auto len_to_split = std::make_shared<ov::op::v8::Gather>(shape, axis, const_0); auto len_to_split = rg.make<v8::Gather>(shape, axis, const_0);
// Convert to f64 from int to calculate reminder - last chunk can be smaller if Shape in given axis is // Convert to f64 from int to calculate reminder - last chunk can be smaller if Shape in given axis is
// not equally divisible. // not equally divisible.
auto len_to_split_float = std::make_shared<ov::op::v0::Convert>(len_to_split, element::f64); auto len_to_split_float = rg.make<v0::Convert>(len_to_split, element::f64);
auto split_size_1d_float = std::make_shared<ov::op::v0::Convert>(split_size_1d, element::f64); auto split_size_1d_float = rg.make<v0::Convert>(split_size_1d, element::f64);
auto out_div = std::make_shared<ov::op::v1::Divide>(len_to_split_float, split_size_1d_float); auto out_div = rg.make<v1::Divide>(len_to_split_float, split_size_1d_float);
auto out_num = std::make_shared<ov::op::v0::Ceiling>(out_div); auto out_num = rg.make<v0::Ceiling>(out_div);
auto out_num_0d = std::make_shared<ov::op::v0::Squeeze>(out_num, const_0); auto out_num_0d = rg.make<v0::Squeeze>(out_num, const_0);
// Use Range and Gather to convert negative getitem indexes into positive due problems with indexing // Use Range and Gather to convert negative getitem indexes into positive due problems with indexing
// with -1. // with -1.
auto possible_out_idx = std::make_shared<ov::op::v4::Range>(const_0_0d, auto possible_out_idx =
out_num_0d, rg.make<v4::Range>(const_0_0d, out_num_0d, const_1_0d, split_size.get_element_type());
const_1_0d, auto always_positive_out_idx = rg.make<v8::Gather>(possible_out_idx, getitem_idx, const_0);
split_size.get_element_type());
auto always_positive_out_idx =
std::make_shared<ov::op::v8::Gather>(possible_out_idx, getitem_idx, const_0);
// Use Slice to get only split output selected by getitem idx. Couldn't use VariadicSplit due to // Use Slice to get only split output selected by getitem idx. Couldn't use VariadicSplit due to
// problems with dynamic inputs. // problems with dynamic inputs.
auto split_slice_start = std::make_shared<ov::op::v1::Multiply>(always_positive_out_idx, split_size_1d); auto split_slice_start = rg.make<v1::Multiply>(always_positive_out_idx, split_size_1d);
auto split_slice_end = std::make_shared<ov::op::v1::Add>(split_slice_start, split_size_1d); auto split_slice_end = rg.make<v1::Add>(split_slice_start, split_size_1d);
auto split = auto split = rg.make<v8::Slice>(input, split_slice_start, split_slice_end, const_1, axis_1d);
std::make_shared<ov::op::v8::Slice>(input, split_slice_start, split_slice_end, const_1, axis_1d);
copy_runtime_info({getitem, input_node}, split);
replace_node(getitem, split); replace_node(getitem, split);
} else { } else {
auto getitem_index_ptr = getitem->input_value(1).get_node_shared_ptr(); auto getitem_index_ptr = getitem->input_value(1).get_node_shared_ptr();
auto getitem_index_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(getitem_index_ptr); auto getitem_index_const = std::dynamic_pointer_cast<v0::Constant>(getitem_index_ptr);
auto split = std::make_shared<ov::op::v1::VariadicSplit>(torch_split->get_input_source_output(0), auto split = rg.make<v1::VariadicSplit>(torch_split->get_input_source_output(0),
torch_split->get_input_source_output(2), torch_split->get_input_source_output(2),
torch_split->get_input_source_output(1)); torch_split->get_input_source_output(1));
auto index_val = getitem_index_const->cast_vector<int64_t>(); auto index_val = getitem_index_const->cast_vector<int64_t>();
@ -107,82 +108,71 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
if (index < 0) { if (index < 0) {
index = split->outputs().size() + index; index = split->outputs().size() + index;
} }
OutputVector res{split->outputs()[index]}; replace_node(getitem, {split->outputs()[index]});
copy_runtime_info({getitem, input_node}, split);
replace_node(getitem, res);
} }
return true; } else if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) {
}
if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) {
auto getitem_idx = getitem->input_value(1).get_node_shared_ptr(); auto getitem_idx = getitem->input_value(1).get_node_shared_ptr();
auto getitem_idx_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(getitem_idx); auto getitem_idx_const = std::dynamic_pointer_cast<v0::Constant>(getitem_idx);
if (getitem_idx_const) { if (getitem_idx_const) {
auto idx = getitem_idx_const->cast_vector<int64_t>(); auto idx = getitem_idx_const->cast_vector<int64_t>();
auto element = list_construct->input_value(idx[0]).get_node_shared_ptr(); auto element = list_construct->input_value(idx[0]).get_node_shared_ptr();
copy_runtime_info({getitem, input_node}, element);
replace_node(getitem, element); replace_node(getitem, element);
return true; } else {
}
auto input_concat = concat_list_construct(list_construct); auto input_concat = concat_list_construct(list_construct);
auto zero = ov::op::v0::Constant::create(element::i32, Shape{}, {0}); auto zero = v0::Constant::create(element::i32, Shape{}, {0});
auto gather = std::make_shared<ov::op::v8::Gather>(input_concat, getitem_idx, zero); auto gather = rg.make<v8::Gather>(input_concat, getitem_idx, zero);
copy_runtime_info({getitem, input_node}, gather);
replace_node(getitem, gather); replace_node(getitem, gather);
return true;
} }
if (auto chunk = cast_fw_node(input_node, "aten::chunk")) { } else if (auto chunk = cast_fw_node(input_node, "aten::chunk")) {
auto input_tensor = chunk->get_input_source_output(0); auto input_tensor = chunk->get_input_source_output(0);
auto chunks_i32 = chunk->get_input_source_output(1); auto chunks_i32 = chunk->get_input_source_output(1);
auto dim_i32 = chunk->get_input_source_output(2); auto dim_i32 = chunk->get_input_source_output(2);
auto const_0 = opset10::Constant::create(element::i64, Shape{1}, {0}); auto const_0 = v0::Constant::create(element::i64, Shape{1}, {0});
auto const_1 = opset10::Constant::create(element::i64, Shape{1}, {1}); auto const_1 = v0::Constant::create(element::i64, Shape{1}, {1});
auto const_0_nodim = opset10::Constant::create(element::i64, Shape{}, {0}); auto const_0_nodim = v0::Constant::create(element::i64, Shape{}, {0});
auto getitem_index_i32 = getitem->get_input_source_output(1); auto getitem_index_i32 = getitem->get_input_source_output(1);
auto getitem_index_i64 = std::make_shared<opset10::Convert>(getitem_index_i32, element::i64); auto getitem_index_i64 = rg.make<v0::Convert>(getitem_index_i32, element::i64);
auto getitem_index = std::make_shared<opset10::Unsqueeze>(getitem_index_i64, const_0); auto getitem_index = rg.make<v0::Unsqueeze>(getitem_index_i64, const_0);
auto dim_i64 = std::make_shared<opset10::Convert>(dim_i32, element::i64); auto dim_i64 = rg.make<v0::Convert>(dim_i32, element::i64);
auto dim = std::make_shared<opset10::Unsqueeze>(dim_i64, const_0); auto dim = rg.make<v0::Unsqueeze>(dim_i64, const_0);
auto chunks = std::make_shared<opset10::Convert>(chunks_i32, element::i64); auto chunks = rg.make<v0::Convert>(chunks_i32, element::i64);
auto input_shape = std::make_shared<opset10::ShapeOf>(input_tensor); auto input_shape = rg.make<v3::ShapeOf>(input_tensor);
auto input_dimension = std::make_shared<opset10::Gather>(input_shape, dim, const_0); auto input_dimension = rg.make<v8::Gather>(input_shape, dim, const_0);
auto input_size = std::make_shared<opset10::Squeeze>(input_dimension); auto input_size = rg.make<v0::Squeeze>(input_dimension);
auto chunk_size = std::make_shared<opset10::Divide>(input_size, chunks, true); auto chunk_size = rg.make<v1::Divide>(input_size, chunks, true);
auto last_chunk_size = std::make_shared<opset10::Mod>(input_size, chunks); auto last_chunk_size = rg.make<v1::Mod>(input_size, chunks);
auto is_last_nonzero = std::make_shared<opset10::Greater>(last_chunk_size, const_0_nodim); auto is_last_nonzero = rg.make<v1::Greater>(last_chunk_size, const_0_nodim);
auto is_last_nonzero_int = std::make_shared<opset10::Convert>(is_last_nonzero, element::i64); auto is_last_nonzero_int = rg.make<v0::Convert>(is_last_nonzero, element::i64);
auto computed_chunk_size = std::make_shared<opset10::Add>(chunk_size, is_last_nonzero_int); auto computed_chunk_size = rg.make<v1::Add>(chunk_size, is_last_nonzero_int);
auto computed_last_chunk_size = std::make_shared<opset10::Mod>(input_size, computed_chunk_size); auto computed_last_chunk_size = rg.make<v1::Mod>(input_size, computed_chunk_size);
auto computed_is_last_nonzero = std::make_shared<opset10::Greater>(computed_last_chunk_size, const_0_nodim); auto computed_is_last_nonzero = rg.make<v1::Greater>(computed_last_chunk_size, const_0_nodim);
auto computed_chunks = std::make_shared<opset10::Divide>(input_size, computed_chunk_size, true); auto computed_chunks = rg.make<v1::Divide>(input_size, computed_chunk_size, true);
auto is_slice_normal_size = std::make_shared<opset10::Less>(getitem_index, computed_chunks); auto is_slice_normal_size = rg.make<v1::Less>(getitem_index, computed_chunks);
auto is_slice_not_normal_size = std::make_shared<opset10::GreaterEqual>(getitem_index, computed_chunks); auto is_slice_not_normal_size = rg.make<v1::GreaterEqual>(getitem_index, computed_chunks);
auto is_slice_normal_size_int = std::make_shared<opset10::Convert>(is_slice_normal_size, element::i64); auto is_slice_normal_size_int = rg.make<v0::Convert>(is_slice_normal_size, element::i64);
auto is_slice_not_normal_size_int = auto is_slice_not_normal_size_int = rg.make<v0::Convert>(is_slice_not_normal_size, element::i64);
std::make_shared<opset10::Convert>(is_slice_not_normal_size, element::i64);
auto slice_size_lhs = std::make_shared<opset10::Multiply>(is_slice_normal_size_int, computed_chunk_size); auto slice_size_lhs = rg.make<v1::Multiply>(is_slice_normal_size_int, computed_chunk_size);
auto slice_size_rhs = auto slice_size_rhs = rg.make<v1::Multiply>(is_slice_not_normal_size_int, computed_last_chunk_size);
std::make_shared<opset10::Multiply>(is_slice_not_normal_size_int, computed_last_chunk_size); auto slice_size = rg.make<v1::Add>(slice_size_lhs, slice_size_rhs);
auto slice_size = std::make_shared<opset10::Add>(slice_size_lhs, slice_size_rhs);
auto slice_begin = std::make_shared<opset10::Multiply>(getitem_index, computed_chunk_size); auto slice_begin = rg.make<v1::Multiply>(getitem_index, computed_chunk_size);
auto slice_end = std::make_shared<opset10::Add>(slice_begin, slice_size); auto slice_end = rg.make<v1::Add>(slice_begin, slice_size);
auto sliced_chunk = std::make_shared<opset10::Slice>(input_tensor, slice_begin, slice_end, const_1, dim); auto sliced_chunk = rg.make<v8::Slice>(input_tensor, slice_begin, slice_end, const_1, dim);
copy_runtime_info({getitem, input_node}, sliced_chunk);
replace_node(getitem, sliced_chunk); replace_node(getitem, sliced_chunk);
} else {
return true;
}
return false; return false;
}
copy_runtime_info_and_name(getitem, rg.get(), {input_node});
return true;
}; };
auto m = std::make_shared<ov::pass::pattern::Matcher>(getitem, "ov::frontend::pytorch::pass::AtenGetItemReplacer"); auto m = std::make_shared<ov::pass::pattern::Matcher>(getitem, "ov::frontend::pytorch::pass::AtenGetItemReplacer");

View File

@ -31,10 +31,12 @@ namespace pass {
using namespace ov::op; using namespace ov::op;
namespace { namespace {
Output<Node> generate_zeros_with_convertlike(const Output<Node> sizes, const Output<Node> tensor_of_type) { Output<Node> generate_zeros_with_convertlike(ov::pass::NodeRegistry& rg,
const Output<Node> sizes,
const Output<Node> tensor_of_type) {
auto const_0 = v0::Constant::create(element::i32, Shape{}, {0}); auto const_0 = v0::Constant::create(element::i32, Shape{}, {0});
auto zeros = std::make_shared<v3::Broadcast>(const_0, sizes); auto zeros = rg.make<v3::Broadcast>(const_0, sizes);
return std::make_shared<v1::ConvertLike>(zeros, tensor_of_type); return rg.make<v1::ConvertLike>(zeros, tensor_of_type);
} }
} // namespace } // namespace
@ -46,18 +48,18 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
if (!index_op) { if (!index_op) {
return false; return false;
} }
NodeVector rt_copy_from{index_op}; NodeVector rt_copy_from;
ov::pass::NodeRegistry rg;
auto const_0 = v0::Constant::create(element::i32, Shape{}, {0}); auto const_0 = v0::Constant::create(element::i32, Shape{}, {0});
auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1}); auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1});
auto const_max_int = v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()}); auto const_max_int = v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()});
auto const_neg_1 = v0::Constant::create(element::i32, Shape{}, {-1}); auto const_neg_1 = v0::Constant::create(element::i32, Shape{}, {-1});
auto input = index_op->input_value(0); auto input = index_op->input_value(0);
auto input_shape = std::make_shared<v3::ShapeOf>(input, element::i32); auto input_shape = rg.make<v3::ShapeOf>(input, element::i32);
auto indices = index_op->input_value(1); auto indices = index_op->input_value(1);
auto values = index_op->input_value(2); auto values = index_op->input_value(2);
auto acc_const = auto acc_const = std::dynamic_pointer_cast<v0::Constant>(index_op->input_value(3).get_node_shared_ptr());
std::dynamic_pointer_cast<ov::op::v0::Constant>(index_op->input_value(3).get_node_shared_ptr());
if (!acc_const) { if (!acc_const) {
add_exception_to_fw_node(index_op, "aten::index_put_: non constant accumulate input is not supported."); add_exception_to_fw_node(index_op, "aten::index_put_: non constant accumulate input is not supported.");
return false; return false;
@ -85,12 +87,11 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
return false; return false;
} }
indices_list_len = indices_first_dim.get_length(); indices_list_len = indices_first_dim.get_length();
auto split = std::make_shared<v1::Split>(indices, const_0, indices_list_len); auto split = rg.make<v1::Split>(indices, const_0, indices_list_len);
indices_inputs = split->outputs(); indices_inputs = split->outputs();
} }
if (indices_list_len == 0) { if (indices_list_len == 0) {
copy_runtime_info(rt_copy_from, values.get_node_shared_ptr());
replace_node(index_op, values.get_node_shared_ptr()); replace_node(index_op, values.get_node_shared_ptr());
return true; return true;
} }
@ -102,52 +103,51 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
if (indices_list_len > 1) { if (indices_list_len > 1) {
index = indices_inputs[0]; index = indices_inputs[0];
for (int i = 1; i < indices_list_len; i++) { for (int i = 1; i < indices_list_len; i++) {
index = std::make_shared<v1::Add>(index, indices_inputs[i]); index = rg.make<v1::Add>(index, indices_inputs[i]);
} }
broadcast_index_shape = std::make_shared<v3::ShapeOf>(index, element::i32); broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
OutputVector indices_list; OutputVector indices_list;
for (int i = 0; i < indices_list_len; i++) { for (int i = 0; i < indices_list_len; i++) {
auto broadcast = std::make_shared<v3::Broadcast>(indices_inputs[i], broadcast_index_shape); auto broadcast = rg.make<v3::Broadcast>(indices_inputs[i], broadcast_index_shape);
auto unsqueeze = std::make_shared<v0::Unsqueeze>(broadcast, const_neg_1); auto unsqueeze = rg.make<v0::Unsqueeze>(broadcast, const_neg_1);
// change negative indices to positive indices // change negative indices to positive indices
auto const_i = v0::Constant::create(element::i32, Shape{}, {i}); auto const_i = v0::Constant::create(element::i32, Shape{}, {i});
auto dim_i = std::make_shared<v8::Gather>(input_shape, const_i, const_0); auto dim_i = rg.make<v8::Gather>(input_shape, const_i, const_0);
auto dim_i_correct_type = std::make_shared<v1::ConvertLike>(dim_i, index); auto dim_i_correct_type = rg.make<v1::ConvertLike>(dim_i, index);
auto unsqueeze_add = std::make_shared<v1::Add>(unsqueeze, dim_i_correct_type); auto unsqueeze_add = rg.make<v1::Add>(unsqueeze, dim_i_correct_type);
auto unsqueeze_add_mod = std::make_shared<v1::Mod>(unsqueeze_add, dim_i_correct_type); auto unsqueeze_add_mod = rg.make<v1::Mod>(unsqueeze_add, dim_i_correct_type);
indices_list.push_back(unsqueeze_add_mod); indices_list.push_back(unsqueeze_add_mod);
} }
index = std::make_shared<v0::Concat>(indices_list, -1); index = rg.make<v0::Concat>(indices_list, -1);
} else { } else {
index = indices_inputs[0]; index = indices_inputs[0];
// change negative indices to positive indices // change negative indices to positive indices
auto dim_0 = (std::make_shared<v8::Gather>(input_shape, const_0, const_0)); auto dim_0 = (rg.make<v8::Gather>(input_shape, const_0, const_0));
auto dim_0_correct_type = (std::make_shared<v1::ConvertLike>(dim_0, index)); auto dim_0_correct_type = (rg.make<v1::ConvertLike>(dim_0, index));
index = std::make_shared<v1::Add>(index, dim_0_correct_type); index = rg.make<v1::Add>(index, dim_0_correct_type);
index = std::make_shared<v1::Mod>(index, dim_0_correct_type); index = rg.make<v1::Mod>(index, dim_0_correct_type);
broadcast_index_shape = std::make_shared<v3::ShapeOf>(index, element::i32); broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
index = std::make_shared<v0::Unsqueeze>(index, const_neg_1); index = rg.make<v0::Unsqueeze>(index, const_neg_1);
} }
auto sub_data_shape = std::make_shared<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1); auto sub_data_shape = rg.make<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1);
auto values_shape = std::make_shared<v0::Concat>(OutputVector{broadcast_index_shape, sub_data_shape}, 0); auto values_shape = rg.make<v0::Concat>(OutputVector{broadcast_index_shape, sub_data_shape}, 0);
values = std::make_shared<v3::Broadcast>(values, values_shape); values = rg.make<v3::Broadcast>(values, values_shape);
values = std::make_shared<v1::ConvertLike>(values, input); values = rg.make<v1::ConvertLike>(values, input);
std::shared_ptr<ov::Node> result; std::shared_ptr<ov::Node> result;
if (accumulate) { if (accumulate) {
auto zeros = generate_zeros_with_convertlike(input_shape, input); auto zeros = generate_zeros_with_convertlike(rg, input_shape, input);
auto scatter = std::make_shared<v3::ScatterNDUpdate>(zeros, index, values); auto scatter = rg.make<v3::ScatterNDUpdate>(zeros, index, values);
result = std::make_shared<v1::Add>(input, scatter); result = rg.make<v1::Add>(input, scatter);
} else { } else {
result = std::make_shared<v3::ScatterNDUpdate>(input, index, values); result = rg.make<v3::ScatterNDUpdate>(input, index, values);
} }
copy_runtime_info(rt_copy_from, result); copy_runtime_info_and_name(index_op, rg.get(), rt_copy_from);
replace_node(index_op, result); replace_node(index_op, result);
result->set_friendly_name(index_op->get_friendly_name());
return true; return true;
}; };

View File

@ -33,9 +33,9 @@ namespace pytorch {
namespace pass { namespace pass {
using namespace ov::op; using namespace ov::op;
namespace {
std::shared_ptr<Node> flatten(const Output<Node>& value, size_t axis) { namespace {
Output<Node> flatten(ov::pass::NodeRegistry& rg, const Output<Node>& value, size_t axis) {
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of // First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of
// input tensor. The last dimension is the product of the rest of input tensor dimensions: // input tensor. The last dimension is the product of the rest of input tensor dimensions:
// [d_{axis}, ..., d_n] // [d_{axis}, ..., d_n]
@ -45,20 +45,20 @@ std::shared_ptr<Node> flatten(const Output<Node>& value, size_t axis) {
} else if (axis == 1) { } else if (axis == 1) {
output_shape = v0::Constant::create(element::i32, Shape{2}, {0, -1}); output_shape = v0::Constant::create(element::i32, Shape{2}, {0, -1});
} else { } else {
const auto value_shape = std::make_shared<v3::ShapeOf>(value, element::i32); const auto value_shape = rg.make<v3::ShapeOf>(value, element::i32);
const auto value_rank = std::make_shared<v3::ShapeOf>(value_shape, element::i32); const auto value_rank = rg.make<v3::ShapeOf>(value_shape, element::i32);
const auto axis_node = v0::Constant::create(element::i32, Shape{1}, {axis}); const auto axis_node = v0::Constant::create(element::i32, Shape{1}, {axis});
auto start = v0::Constant::create(element::i32, Shape{1}, {0}); auto start = v0::Constant::create(element::i32, Shape{1}, {0});
auto step = v0::Constant::create(element::i32, Shape{1}, {1}); auto step = v0::Constant::create(element::i32, Shape{1}, {1});
const auto first_part_dims = std::make_shared<v8::Slice>(value_shape, start, axis_node, step); const auto first_part_dims = rg.make<v8::Slice>(value_shape, start, axis_node, step);
auto zero = v0::Constant::create(element::i32, {}, {0}); auto zero = v0::Constant::create(element::i32, {}, {0});
auto first_part_dims_length = std::make_shared<v1::ReduceProd>(first_part_dims, zero, true); auto first_part_dims_length = rg.make<v1::ReduceProd>(first_part_dims, zero, true);
auto remaining_part_length = v0::Constant::create(element::i32, {1}, {-1}); auto remaining_part_length = v0::Constant::create(element::i32, {1}, {-1});
output_shape = std::make_shared<v0::Concat>(OutputVector{first_part_dims_length, remaining_part_length}, 0); output_shape = rg.make<v0::Concat>(OutputVector{first_part_dims_length, remaining_part_length}, 0);
} }
return std::make_shared<v1::Reshape>(value, output_shape, true); return rg.make<v1::Reshape>(value, output_shape, true);
} }
}; // namespace }; // namespace
@ -70,6 +70,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
if (!index_op) { if (!index_op) {
return false; return false;
} }
ov::pass::NodeRegistry rg;
auto input_node = index_op->input_value(0); auto input_node = index_op->input_value(0);
auto indicies = index_op->input_value(1).get_node_shared_ptr(); auto indicies = index_op->input_value(1).get_node_shared_ptr();
auto list_indicies = cast_fw_node(indicies, "prim::ListConstruct"); auto list_indicies = cast_fw_node(indicies, "prim::ListConstruct");
@ -110,10 +111,10 @@ AtenIndexToSelect::AtenIndexToSelect() {
} }
auto id_dtype = ids[i].get_element_type(); auto id_dtype = ids[i].get_element_type();
if (id_dtype == element::boolean || id_dtype == element::u8) { if (id_dtype == element::boolean || id_dtype == element::u8) {
auto idx = std::make_shared<v0::Convert>(ids[i], element::u8); auto idx = rg.make<v0::Convert>(ids[i], element::u8);
auto nonzero = std::make_shared<v3::NonZero>(idx, element::i32); auto nonzero = rg.make<v3::NonZero>(idx, element::i32);
auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0}); auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order); auto masked_id = rg.make<v1::Transpose>(nonzero, input_order);
masked_indicies.push_back(masked_id); masked_indicies.push_back(masked_id);
is_masked_bool.push_back(true); is_masked_bool.push_back(true);
} else { } else {
@ -132,17 +133,15 @@ AtenIndexToSelect::AtenIndexToSelect() {
if (advanced_ids.size() == 1) { if (advanced_ids.size() == 1) {
auto index = masked_indicies[advanced_ids[0]]; auto index = masked_indicies[advanced_ids[0]];
if (is_masked_bool[advanced_ids[0]]) { if (is_masked_bool[advanced_ids[0]]) {
auto gather = std::make_shared<v8::GatherND>(input_node, index); auto gather = rg.make<v8::GatherND>(input_node, index);
copy_runtime_info({index_op, indicies}, gather); copy_runtime_info_and_name(index_op, rg.get());
gather->set_friendly_name(index_op->get_friendly_name());
replace_node(index_op, gather); replace_node(index_op, gather);
return true; return true;
} }
index = std::make_shared<v0::Convert>(index, element::i32); index = rg.make<v0::Convert>(index, element::i32);
auto dim = v0::Constant::create(element::i32, Shape{}, {advanced_ids[0]}); auto dim = v0::Constant::create(element::i32, Shape{}, {advanced_ids[0]});
auto gather = std::make_shared<v8::Gather>(input_node, index, dim); auto gather = rg.make<v8::Gather>(input_node, index, dim);
copy_runtime_info({index_op, indicies}, gather); copy_runtime_info_and_name(index_op, rg.get());
gather->set_friendly_name(index_op->get_friendly_name());
replace_node(index_op, gather); replace_node(index_op, gather);
return true; return true;
} }
@ -153,9 +152,9 @@ AtenIndexToSelect::AtenIndexToSelect() {
add_exception_to_fw_node(index_op, "aten::index: dynamic rank for aten::index input is not supported."); add_exception_to_fw_node(index_op, "aten::index: dynamic rank for aten::index input is not supported.");
return false; return false;
} }
auto input_shape = std::make_shared<v3::ShapeOf>(input_node, element::i32); auto input_shape = rg.make<v3::ShapeOf>(input_node, element::i32);
auto zero = v0::Constant::create(element::i32, Shape{}, {0}); auto zero = v0::Constant::create(element::i32, Shape{}, {0});
auto input_dims = std::make_shared<v1::Split>(input_shape, zero, rank.get_length()); auto input_dims = rg.make<v1::Split>(input_shape, zero, rank.get_length());
std::vector<size_t> non_used_dims; std::vector<size_t> non_used_dims;
for (auto i = 0; i < rank.get_length(); i++) { for (auto i = 0; i < rank.get_length(); i++) {
if (std::find(advanced_ids.begin(), advanced_ids.end(), i) == advanced_ids.end()) { if (std::find(advanced_ids.begin(), advanced_ids.end(), i) == advanced_ids.end()) {
@ -166,23 +165,23 @@ AtenIndexToSelect::AtenIndexToSelect() {
permutation_dims.insert(permutation_dims.end(), advanced_ids.begin(), advanced_ids.end()); permutation_dims.insert(permutation_dims.end(), advanced_ids.begin(), advanced_ids.end());
permutation_dims.insert(permutation_dims.end(), non_used_dims.begin(), non_used_dims.end()); permutation_dims.insert(permutation_dims.end(), non_used_dims.begin(), non_used_dims.end());
auto transpose_dims = v0::Constant::create(element::i32, Shape{permutation_dims.size()}, permutation_dims); auto transpose_dims = v0::Constant::create(element::i32, Shape{permutation_dims.size()}, permutation_dims);
auto transposed_input = std::make_shared<v1::Transpose>(input_node, transpose_dims); auto transposed_input = rg.make<v1::Transpose>(input_node, transpose_dims);
auto flatten_input = flatten(transposed_input, adv_idx_count); auto flatten_input = flatten(rg, transposed_input, adv_idx_count);
auto cum_adv_index = masked_indicies[advanced_ids[adv_idx_count - 1]]; auto cum_adv_index = masked_indicies[advanced_ids[adv_idx_count - 1]];
cum_adv_index = std::make_shared<v0::Convert>(cum_adv_index, element::i32); cum_adv_index = rg.make<v0::Convert>(cum_adv_index, element::i32);
auto multiplier = input_dims->output(advanced_ids[adv_idx_count - 1]); auto multiplier = input_dims->output(advanced_ids[adv_idx_count - 1]);
for (int i = static_cast<int>(adv_idx_count) - 2; i > -1; i--) { for (int i = static_cast<int>(adv_idx_count) - 2; i > -1; i--) {
auto m_idx = std::make_shared<v0::Convert>(masked_indicies[i], element::i32); auto m_idx = rg.make<v0::Convert>(masked_indicies[i], element::i32);
auto adv_index = std::make_shared<v1::Multiply>(m_idx, multiplier); auto adv_index = rg.make<v1::Multiply>(m_idx, multiplier);
cum_adv_index = std::make_shared<v1::Add>(cum_adv_index, adv_index); cum_adv_index = rg.make<v1::Add>(cum_adv_index, adv_index);
auto input_id = advanced_ids[i]; auto input_id = advanced_ids[i];
multiplier = std::make_shared<v1::Multiply>(multiplier, input_dims->output(input_id)); multiplier = rg.make<v1::Multiply>(multiplier, input_dims->output(input_id));
} }
std::shared_ptr<Node> gather = std::make_shared<v8::Gather>(flatten_input, cum_adv_index, zero); std::shared_ptr<Node> gather = rg.make<v8::Gather>(flatten_input, cum_adv_index, zero);
OutputVector concat_dims; OutputVector concat_dims;
// check if all advanced indices are consecutive. // check if all advanced indices are consecutive.
std::vector<size_t> consequence_dims; std::vector<size_t> consequence_dims;
auto cum_adv_index_shape_tensor = std::make_shared<v3::ShapeOf>(cum_adv_index, element::i32); auto cum_adv_index_shape_tensor = rg.make<v3::ShapeOf>(cum_adv_index, element::i32);
for (size_t i = advanced_ids[0]; i <= advanced_ids[advanced_ids.size() - 1]; i++) { for (size_t i = advanced_ids[0]; i <= advanced_ids[advanced_ids.size() - 1]; i++) {
consequence_dims.push_back(i); consequence_dims.push_back(i);
} }
@ -194,8 +193,8 @@ AtenIndexToSelect::AtenIndexToSelect() {
for (auto i : non_used_dims) { for (auto i : non_used_dims) {
folded_adv_idx_shape_vector.push_back(input_dims->output(i)); folded_adv_idx_shape_vector.push_back(input_dims->output(i));
} }
auto folded_adv_idx_shape = std::make_shared<v0::Concat>(folded_adv_idx_shape_vector, 0); auto folded_adv_idx_shape = rg.make<v0::Concat>(folded_adv_idx_shape_vector, 0);
gather = std::make_shared<v1::Reshape>(gather, folded_adv_idx_shape, false); gather = rg.make<v1::Reshape>(gather, folded_adv_idx_shape, false);
std::vector<size_t> adv_idx_permute; std::vector<size_t> adv_idx_permute;
for (size_t i = 1; i < advanced_ids[0] + 1; i++) { for (size_t i = 1; i < advanced_ids[0] + 1; i++) {
adv_idx_permute.push_back(i); adv_idx_permute.push_back(i);
@ -207,7 +206,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
// Transpose folded advanced indexed axis to its original location. // Transpose folded advanced indexed axis to its original location.
auto permute_indicies = auto permute_indicies =
v0::Constant::create(element::i32, Shape{adv_idx_permute.size()}, adv_idx_permute); v0::Constant::create(element::i32, Shape{adv_idx_permute.size()}, adv_idx_permute);
gather = std::make_shared<v1::Transpose>(gather, permute_indicies); gather = rg.make<v1::Transpose>(gather, permute_indicies);
// unfold advanced index axes // unfold advanced index axes
for (size_t i = 0; i < advanced_ids[0]; i++) { for (size_t i = 0; i < advanced_ids[0]; i++) {
concat_dims.push_back(input_dims->output(i)); concat_dims.push_back(input_dims->output(i));
@ -226,11 +225,10 @@ AtenIndexToSelect::AtenIndexToSelect() {
concat_dims.push_back(input_dims->output(i)); concat_dims.push_back(input_dims->output(i));
} }
} }
auto final_shape = std::make_shared<v0::Concat>(concat_dims, 0); auto final_shape = rg.make<v0::Concat>(concat_dims, 0);
gather = std::make_shared<v1::Reshape>(gather, final_shape, false); gather = rg.make<v1::Reshape>(gather, final_shape, false);
copy_runtime_info({index_op, indicies}, gather); copy_runtime_info_and_name(index_op, rg.get());
replace_node(index_op, gather); replace_node(index_op, gather);
gather->set_friendly_name(index_op->get_friendly_name());
return true; return true;
} else { } else {
@ -246,22 +244,21 @@ AtenIndexToSelect::AtenIndexToSelect() {
} }
auto index_dtype = indicies->get_output_element_type(0); auto index_dtype = indicies->get_output_element_type(0);
if (index_dtype == element::boolean || index_dtype == element::u8) { if (index_dtype == element::boolean || index_dtype == element::u8) {
auto nonzero = std::make_shared<v3::NonZero>(indicies, element::i32); auto nonzero = rg.make<v3::NonZero>(indicies, element::i32);
auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0}); auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order); auto masked_id = rg.make<v1::Transpose>(nonzero, input_order);
auto gather = std::make_shared<v8::GatherND>(input_node, masked_id); auto gather = rg.make<v8::GatherND>(input_node, masked_id);
copy_runtime_info({index_op, indicies}, gather); copy_runtime_info_and_name(index_op, rg.get());
replace_node(index_op, gather); replace_node(index_op, gather);
return true; return true;
} }
if (index_dtype != element::i32) { if (index_dtype != element::i32) {
indicies = std::make_shared<ov::op::v0::Convert>(indicies, element::i32); indicies = rg.make<ov::op::v0::Convert>(indicies, element::i32);
} }
auto dim = v0::Constant::create(element::i32, Shape{}, {0}); auto dim = v0::Constant::create(element::i32, Shape{}, {0});
auto gather = std::make_shared<v8::Gather>(input_node, indicies, dim); auto gather = rg.make<v8::Gather>(input_node, indicies, dim);
copy_runtime_info({index_op, indicies}, gather); copy_runtime_info_and_name(index_op, rg.get());
replace_node(index_op, gather); replace_node(index_op, gather);
gather->set_friendly_name(index_op->get_friendly_name());
return true; return true;
} }
add_exception_to_fw_node(index_op, "Unsupported case of aten::index."); add_exception_to_fw_node(index_op, "Unsupported case of aten::index.");

View File

@ -5,27 +5,30 @@
#include "aten_stack_list_construct_replacer.hpp" #include "aten_stack_list_construct_replacer.hpp"
#include "openvino/core/rt_info.hpp" #include "openvino/core/rt_info.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/util/framework_node.hpp" #include "openvino/op/util/framework_node.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/matcher.hpp" #include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp" #include "utils.hpp"
using namespace ov::pass::pattern;
namespace ov { namespace ov {
namespace frontend { namespace frontend {
namespace pytorch { namespace pytorch {
namespace pass { namespace pass {
using namespace ov::op;
using namespace ov::pass::pattern;
AtenStackListConstructReplacer::AtenStackListConstructReplacer() { AtenStackListConstructReplacer::AtenStackListConstructReplacer() {
auto list_construct = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>(); auto list_construct = wrap_type<ov::op::util::FrameworkNode>();
auto axis = ov::pass::pattern::wrap_type<opset10::Constant>(); auto axis = wrap_type<v0::Constant>();
// We search for a pattern: ListConstruct -> aten::stack <- Constant // We search for a pattern: ListConstruct -> aten::stack <- Constant
auto stack = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>({list_construct, axis}); auto stack = wrap_type<ov::op::util::FrameworkNode>({list_construct, axis});
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { ov::matcher_pass_callback callback = [=](Matcher& m) {
auto stack = cast_fw_node(m.get_match_root(), "aten::stack"); auto stack = cast_fw_node(m.get_match_root(), "aten::stack");
if (!stack) { if (!stack) {
return false; return false;
@ -33,23 +36,23 @@ AtenStackListConstructReplacer::AtenStackListConstructReplacer() {
const auto& pattern_map = m.get_pattern_value_map(); const auto& pattern_map = m.get_pattern_value_map();
auto input_node = pattern_map.at(list_construct).get_node_shared_ptr(); auto input_node = pattern_map.at(list_construct).get_node_shared_ptr();
auto axis_node = pattern_map.at(axis).get_node_shared_ptr(); auto axis_node = pattern_map.at(axis).get_node_shared_ptr();
auto axis_const = std::dynamic_pointer_cast<opset10::Constant>(axis_node); auto axis_const = std::dynamic_pointer_cast<v0::Constant>(axis_node);
auto axis = axis_const->cast_vector<int64_t>(); auto axis = axis_const->cast_vector<int64_t>();
// Check if ListConstruct is an input // Check if ListConstruct is an input
if (auto list_construct_node = cast_fw_node(input_node, "prim::ListConstruct")) { if (auto list_construct_node = cast_fw_node(input_node, "prim::ListConstruct")) {
const auto& list_inputs = list_construct_node->input_values(); const auto& list_inputs = list_construct_node->input_values();
OutputVector node_vector; OutputVector node_vector;
auto zero = opset10::Constant::create(element::i32, Shape{}, {0}); auto zero = v0::Constant::create(element::i32, Shape{}, {0});
// Iterate over values in ListConstruct // Iterate over values in ListConstruct
for (const auto& list_input : list_inputs) { for (const auto& list_input : list_inputs) {
auto node = concat_list_construct(list_input); auto node = concat_list_construct(list_input);
auto unsqueezed_node = std::make_shared<opset10::Unsqueeze>(node, axis_const); auto unsqueezed_node = std::make_shared<v0::Unsqueeze>(node, axis_const);
node_vector.push_back(unsqueezed_node); node_vector.push_back(unsqueezed_node);
} }
// Concat vectors on provided axis // Concat vectors on provided axis
auto concat = std::make_shared<opset10::Concat>(node_vector, axis[0]); auto concat = std::make_shared<v0::Concat>(node_vector, axis[0]);
copy_runtime_info({stack, input_node}, concat); copy_runtime_info_and_name(stack, {concat}, {input_node});
replace_node(stack, concat); replace_node(stack, concat);
return true; return true;
} }

View File

@ -50,7 +50,7 @@ AtenEinsumListConstructReplacer::AtenEinsumListConstructReplacer() {
} }
auto einsum = std::make_shared<v7::Einsum>(node_vector, equation); auto einsum = std::make_shared<v7::Einsum>(node_vector, equation);
copy_runtime_info({einsum_op, equation_input, tensor_list}, einsum); copy_runtime_info_and_name(einsum_op, {einsum}, {equation_input, tensor_list});
replace_node(einsum_op, einsum); replace_node(einsum_op, einsum);
return true; return true;
} }

View File

@ -126,7 +126,7 @@ IndexLoopGetitemReplacer::IndexLoopGetitemReplacer() {
auto stop = rg.make<v1::Add>(start, chunks_size_body); auto stop = rg.make<v1::Add>(start, chunks_size_body);
auto curr_chunk = rg.make<v8::Slice>(chunk_param, start, stop, one_1d, dim_body); auto curr_chunk = rg.make<v8::Slice>(chunk_param, start, stop, one_1d, dim_body);
replace_node(getitem, curr_chunk); replace_node(getitem, curr_chunk);
copy_runtime_info({chunk_op, getitem}, rg.get()); copy_runtime_info_and_name(chunk_op, rg.get(), {getitem});
curr_chunk->set_friendly_name(getitem->get_friendly_name()); curr_chunk->set_friendly_name(getitem->get_friendly_name());
return true; return true;
}; };

View File

@ -79,6 +79,7 @@ ListConstructReplacer::ListConstructReplacer() {
// Concatenation is possible because all elements in list should be scalar or 1D tensors, // Concatenation is possible because all elements in list should be scalar or 1D tensors,
// result should be 1D tensor. // result should be 1D tensor.
OutputVector inputs; OutputVector inputs;
ov::pass::NodeRegistry rg;
auto neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1}); auto neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
const auto& start_output = list_node->output(0); const auto& start_output = list_node->output(0);
for (const auto& input : get_list_as_outputs(start_output)) { for (const auto& input : get_list_as_outputs(start_output)) {
@ -94,13 +95,12 @@ ListConstructReplacer::ListConstructReplacer() {
return false; return false;
} }
// reshape all elements to 1D // reshape all elements to 1D
auto reshape = std::make_shared<v1::Reshape>(input, neg_1, false); auto reshape = rg.make<v1::Reshape>(input, neg_1, false);
inputs.push_back(reshape); inputs.push_back(reshape);
} }
auto concat = std::make_shared<v0::Concat>(inputs, 0); auto concat = rg.make<v0::Concat>(inputs, 0);
copy_runtime_info({list_node}, concat); copy_runtime_info_and_name(list_node, rg.get());
replace_node(list_node, concat); replace_node(list_node, concat);
concat->set_friendly_name(list_node->get_friendly_name());
return true; return true;
}; };
auto m = std::make_shared<pattern::Matcher>(lc_pattern, "ov::frontend::pytorch::pass::ListConstructReplacer"); auto m = std::make_shared<pattern::Matcher>(lc_pattern, "ov::frontend::pytorch::pass::ListConstructReplacer");

View File

@ -23,6 +23,8 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace pass { namespace pass {
using namespace ov::op;
MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() { MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() {
auto op = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>(); auto op = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
@ -40,25 +42,26 @@ MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() {
} else { } else {
op = max_op; op = max_op;
} }
ov::pass::NodeRegistry rg;
auto input_node = op->input_value(0); auto input_node = op->input_value(0);
auto num_inputs = op->inputs().size(); auto num_inputs = op->inputs().size();
auto input = concat_list_construct(input_node); auto input = concat_list_construct(input_node);
std::shared_ptr<Node> reduce_op; std::shared_ptr<Node> reduce_op;
if (num_inputs == 1) { if (num_inputs == 1) {
auto start = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0); auto start = rg.make<v0::Constant>(element::i32, Shape{}, 0);
auto step = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 1); auto step = rg.make<v0::Constant>(element::i32, Shape{}, 1);
auto shape = std::make_shared<ov::op::v3::ShapeOf>(input, element::i32); auto shape = rg.make<v3::ShapeOf>(input, element::i32);
auto rank = std::make_shared<ov::op::v3::ShapeOf>(shape, element::i32); auto rank = rg.make<v3::ShapeOf>(shape, element::i32);
auto axis_0 = ov::op::v0::Constant::create(element::i32, Shape{}, {0}); auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0});
auto reduced_rank = std::make_shared<ov::op::v0::Squeeze>(rank, axis_0); auto reduced_rank = rg.make<v0::Squeeze>(rank, axis_0);
auto axes = std::make_shared<ov::op::v4::Range>(start, reduced_rank, step, element::i32); auto axes = rg.make<v4::Range>(start, reduced_rank, step, element::i32);
std::shared_ptr<Node> reduce_op; std::shared_ptr<Node> reduce_op;
if (!is_min) { if (!is_min) {
reduce_op = std::make_shared<ov::op::v1::ReduceMax>(input, axes); reduce_op = rg.make<v1::ReduceMax>(input, axes);
} else { } else {
reduce_op = std::make_shared<ov::op::v1::ReduceMin>(input, axes); reduce_op = rg.make<v1::ReduceMin>(input, axes);
} }
copy_runtime_info({op, input_node.get_node_shared_ptr()}, reduce_op); copy_runtime_info_and_name(op, rg.get());
replace_node(op, reduce_op); replace_node(op, reduce_op);
return true; return true;
} }
@ -66,11 +69,11 @@ MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() {
auto second_input = concat_list_construct(second_input_node); auto second_input = concat_list_construct(second_input_node);
std::shared_ptr<Node> min_or_max_op; std::shared_ptr<Node> min_or_max_op;
if (!is_min) { if (!is_min) {
min_or_max_op = std::make_shared<ov::op::v1::Maximum>(input, second_input); min_or_max_op = rg.make<v1::Maximum>(input, second_input);
} else { } else {
min_or_max_op = std::make_shared<ov::op::v1::Minimum>(input, second_input); min_or_max_op = rg.make<v1::Minimum>(input, second_input);
} }
copy_runtime_info({op, input_node.get_node_shared_ptr()}, min_or_max_op); copy_runtime_info_and_name(op, rg.get());
replace_node(op, min_or_max_op); replace_node(op, min_or_max_op);
return true; return true;
}; };

View File

@ -30,7 +30,8 @@ namespace pass {
using namespace ov::op; using namespace ov::op;
namespace { namespace {
Output<Node> create_padding(const Output<Node>& input_rank, Output<Node> create_padding(ov::pass::NodeRegistry& rg,
const Output<Node>& input_rank,
const Output<Node>& padding, const Output<Node>& padding,
const Output<Node>& start_id, const Output<Node>& start_id,
const Output<Node>& end_id) { const Output<Node>& end_id) {
@ -39,14 +40,14 @@ Output<Node> create_padding(const Output<Node>& input_rank,
// OV expects paddings separated on begins and ends for each dimension from first to last // OV expects paddings separated on begins and ends for each dimension from first to last
auto minus_two = v0::Constant::create(element::i32, Shape{}, {-2}); auto minus_two = v0::Constant::create(element::i32, Shape{}, {-2});
auto zero = v0::Constant::create(element::i32, Shape{}, {0}); auto zero = v0::Constant::create(element::i32, Shape{}, {0});
auto pad_id_range = std::make_shared<v4::Range>(start_id, end_id, minus_two, element::i32); auto pad_id_range = rg.make<v4::Range>(start_id, end_id, minus_two, element::i32);
auto pads = std::make_shared<v8::Gather>(padding, pad_id_range, zero); auto pads = rg.make<v8::Gather>(padding, pad_id_range, zero);
// add left side zero padding for difference between padding size and input rank // add left side zero padding for difference between padding size and input rank
auto pads_short_len = std::make_shared<v3::ShapeOf>(pads, element::i32); auto pads_short_len = rg.make<v3::ShapeOf>(pads, element::i32);
auto pads_diff = std::make_shared<v1::Subtract>(input_rank, pads_short_len); auto pads_diff = rg.make<v1::Subtract>(input_rank, pads_short_len);
auto pads_remaining = std::make_shared<v3::Broadcast>(zero, pads_diff); auto pads_remaining = rg.make<v3::Broadcast>(zero, pads_diff);
auto pads_remaining_c = std::make_shared<v1::ConvertLike>(pads_remaining, pads); auto pads_remaining_c = rg.make<v1::ConvertLike>(pads_remaining, pads);
auto pads_full = std::make_shared<v0::Concat>(OutputVector{pads_remaining_c, pads}, 0); auto pads_full = rg.make<v0::Concat>(OutputVector{pads_remaining_c, pads}, 0);
return pads_full; return pads_full;
} }
@ -64,6 +65,7 @@ PrimListConstructPadReplacer::PrimListConstructPadReplacer() {
if (!pad_op) { if (!pad_op) {
return false; return false;
} }
ov::pass::NodeRegistry rg;
auto minus_two = v0::Constant::create(element::i32, Shape{}, {-2}); auto minus_two = v0::Constant::create(element::i32, Shape{}, {-2});
auto minus_one = v0::Constant::create(element::i32, Shape{}, {-1}); auto minus_one = v0::Constant::create(element::i32, Shape{}, {-1});
auto zero = v0::Constant::create(element::i32, Shape{}, {0}); auto zero = v0::Constant::create(element::i32, Shape{}, {0});
@ -73,15 +75,15 @@ PrimListConstructPadReplacer::PrimListConstructPadReplacer() {
auto pad_values = concat_list_construct(padding); auto pad_values = concat_list_construct(padding);
std::string mode = "constant"; std::string mode = "constant";
auto zero_f = v0::Constant::create(element::f32, Shape{}, {0}); auto zero_f = v0::Constant::create(element::f32, Shape{}, {0});
auto input_shape = std::make_shared<v3::ShapeOf>(input_node, element::i32); auto input_shape = rg.make<v3::ShapeOf>(input_node, element::i32);
auto input_rank = std::make_shared<v3::ShapeOf>(input_shape, element::i32); auto input_rank = rg.make<v3::ShapeOf>(input_shape, element::i32);
auto pad_size_1d = std::make_shared<v3::ShapeOf>(pad_values, element::i32); auto pad_size_1d = rg.make<v3::ShapeOf>(pad_values, element::i32);
auto pad_size = std::make_shared<v0::Squeeze>(pad_size_1d, zero); auto pad_size = rg.make<v0::Squeeze>(pad_size_1d, zero);
// get pad_begins and pad_ends indexes starting for end of paddings // get pad_begins and pad_ends indexes starting for end of paddings
auto start_pad_begins = std::make_shared<v1::Add>(pad_size, minus_two); auto start_pad_begins = rg.make<v1::Add>(pad_size, minus_two);
auto start_pad_ends = std::make_shared<v1::Add>(pad_size, minus_one); auto start_pad_ends = rg.make<v1::Add>(pad_size, minus_one);
auto pad_begins_full = create_padding(input_rank, pad_values, start_pad_begins, minus_one); auto pad_begins_full = create_padding(rg, input_rank, pad_values, start_pad_begins, minus_one);
auto pad_ends_full = create_padding(input_rank, pad_values, start_pad_ends, zero); auto pad_ends_full = create_padding(rg, input_rank, pad_values, start_pad_ends, zero);
auto mode_const = pad_op->input_value(2).get_node_shared_ptr(); auto mode_const = pad_op->input_value(2).get_node_shared_ptr();
auto pad_value = pad_op->input_value(3); auto pad_value = pad_op->input_value(3);
if (const auto& fw_node_mode = cast_fw_node(mode_const, "prim::Constant")) { if (const auto& fw_node_mode = cast_fw_node(mode_const, "prim::Constant")) {
@ -97,16 +99,16 @@ PrimListConstructPadReplacer::PrimListConstructPadReplacer() {
pad_value = zero_f; pad_value = zero_f;
} }
} }
pad_value = std::make_shared<v1::ConvertLike>(pad_value, input_node); pad_value = rg.make<v1::ConvertLike>(pad_value, input_node);
} }
if (PAD_MODES.find(mode) == PAD_MODES.end()) { if (PAD_MODES.find(mode) == PAD_MODES.end()) {
add_exception_to_fw_node(pad_op, "Unsupported mode: " + mode + "for aten::pad"); add_exception_to_fw_node(pad_op, "Unsupported mode: " + mode + "for aten::pad");
return false; return false;
} }
auto pad_mode = PAD_MODES.at(mode); auto pad_mode = PAD_MODES.at(mode);
auto pad = std::make_shared<v1::Pad>(input_node, pad_begins_full, pad_ends_full, pad_value, pad_mode); auto pad = rg.make<v1::Pad>(input_node, pad_begins_full, pad_ends_full, pad_value, pad_mode);
replace_node(pad_op, pad); replace_node(pad_op, pad);
copy_runtime_info({pad_op, padding.get_node_shared_ptr(), mode_const, pad_value.get_node_shared_ptr()}, pad); copy_runtime_info_and_name(pad_op, rg.get());
pad->set_friendly_name(pad_op->get_friendly_name()); pad->set_friendly_name(pad_op->get_friendly_name());
return true; return true;
}; };

View File

@ -28,6 +28,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
return false; return false;
auto input_node = list_unpack->input_value(0).get_node_shared_ptr(); auto input_node = list_unpack->input_value(0).get_node_shared_ptr();
ov::pass::NodeRegistry rg;
if (auto torch_split = cast_fw_node(input_node, "aten::split")) { if (auto torch_split = cast_fw_node(input_node, "aten::split")) {
auto rank = torch_split->input(1).get_partial_shape().rank(); auto rank = torch_split->input(1).get_partial_shape().rank();
if (rank.is_dynamic()) { if (rank.is_dynamic()) {
@ -44,19 +45,18 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
Shape{1}, Shape{1},
{list_unpack->get_output_size() - 1}); {list_unpack->get_output_size() - 1});
auto const_neg_1 = opset10::Constant::create(split_size.get_element_type(), Shape{1}, {-1}); auto const_neg_1 = opset10::Constant::create(split_size.get_element_type(), Shape{1}, {-1});
auto split_lenghts_m_1 = std::make_shared<opset10::Tile>(split_size, num_out_m_1); auto split_lenghts_m_1 = rg.make<opset10::Tile>(split_size, num_out_m_1);
NodeVector concat_inputs{split_lenghts_m_1, const_neg_1}; NodeVector concat_inputs{split_lenghts_m_1, const_neg_1};
auto split_lenghts = std::make_shared<opset10::Concat>(concat_inputs, 0); auto split_lenghts = rg.make<opset10::Concat>(concat_inputs, 0);
split = std::make_shared<opset10::VariadicSplit>(torch_split->get_input_source_output(0), split = rg.make<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
torch_split->get_input_source_output(2), torch_split->get_input_source_output(2),
split_lenghts); split_lenghts);
} else { } else {
split = std::make_shared<opset10::VariadicSplit>(torch_split->get_input_source_output(0), split = rg.make<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
torch_split->get_input_source_output(2), torch_split->get_input_source_output(2),
torch_split->get_input_source_output(1)); torch_split->get_input_source_output(1));
} }
copy_runtime_info({list_unpack, input_node}, split); copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
split->set_friendly_name(input_node->get_friendly_name());
replace_node(list_unpack, split); replace_node(list_unpack, split);
return true; return true;
@ -64,12 +64,11 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
if (auto split_with_sizes = cast_fw_node(input_node, "aten::split_with_sizes")) { if (auto split_with_sizes = cast_fw_node(input_node, "aten::split_with_sizes")) {
auto split_lengths = concat_list_construct(split_with_sizes->get_input_source_output(1)); auto split_lengths = concat_list_construct(split_with_sizes->get_input_source_output(1));
auto split = std::make_shared<opset10::VariadicSplit>(split_with_sizes->get_input_source_output(0), auto split = rg.make<opset10::VariadicSplit>(split_with_sizes->get_input_source_output(0),
split_with_sizes->get_input_source_output(2), split_with_sizes->get_input_source_output(2),
split_lengths); split_lengths);
copy_runtime_info({list_unpack, input_node}, split); copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
split->set_friendly_name(input_node->get_friendly_name());
replace_node(list_unpack, split); replace_node(list_unpack, split);
return true; return true;
@ -87,27 +86,26 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
auto tensor_0 = opset10::Constant::create(element::i32, Shape{1}, {0}); auto tensor_0 = opset10::Constant::create(element::i32, Shape{1}, {0});
auto tensor_neg_1 = opset10::Constant::create(element::i32, Shape{1}, {-1}); auto tensor_neg_1 = opset10::Constant::create(element::i32, Shape{1}, {-1});
auto input_shape = std::make_shared<opset10::ShapeOf>(input_tensor, element::i32); auto input_shape = rg.make<opset10::ShapeOf>(input_tensor, element::i32);
auto input_dimension = std::make_shared<opset10::Gather>(input_shape, dim, tensor_0); auto input_dimension = rg.make<opset10::Gather>(input_shape, dim, tensor_0);
auto init_chunk_size = std::make_shared<opset10::Divide>(input_dimension, chunks, true); auto init_chunk_size = rg.make<opset10::Divide>(input_dimension, chunks, true);
// Add 1 if input is not evenly divisible by chunks // Add 1 if input is not evenly divisible by chunks
auto last_chunk_size = std::make_shared<opset10::Mod>(input_dimension, chunks); auto last_chunk_size = rg.make<opset10::Mod>(input_dimension, chunks);
auto is_last_nonzero = std::make_shared<opset10::Greater>(last_chunk_size, tensor_0); auto is_last_nonzero = rg.make<opset10::Greater>(last_chunk_size, tensor_0);
auto is_last_nonzero_int = std::make_shared<opset10::Convert>(is_last_nonzero, element::i32); auto is_last_nonzero_int = rg.make<opset10::Convert>(is_last_nonzero, element::i32);
auto chunk_size = std::make_shared<opset10::Add>(init_chunk_size, is_last_nonzero_int); auto chunk_size = rg.make<opset10::Add>(init_chunk_size, is_last_nonzero_int);
auto split_lengths_even_size = auto split_lengths_even_size =
opset10::Constant::create(element::i32, Shape{1}, {list_unpack->get_output_size() - 1}); opset10::Constant::create(element::i32, Shape{1}, {list_unpack->get_output_size() - 1});
auto split_lengths_even = std::make_shared<opset10::Broadcast>(chunk_size, split_lengths_even_size); auto split_lengths_even = rg.make<opset10::Broadcast>(chunk_size, split_lengths_even_size);
auto split_lengths = std::make_shared<opset10::Concat>(OutputVector{split_lengths_even, tensor_neg_1}, 0); auto split_lengths = rg.make<opset10::Concat>(OutputVector{split_lengths_even, tensor_neg_1}, 0);
auto sliced_chunks = std::make_shared<opset10::VariadicSplit>(input_tensor, dim, split_lengths); auto sliced_chunks = rg.make<opset10::VariadicSplit>(input_tensor, dim, split_lengths);
copy_runtime_info({list_unpack, input_node}, sliced_chunks); copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
sliced_chunks->set_friendly_name(input_node->get_friendly_name());
replace_node(list_unpack, sliced_chunks); replace_node(list_unpack, sliced_chunks);
return true; return true;
@ -117,51 +115,45 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
const auto input = unbind->get_input_source_output(0); const auto input = unbind->get_input_source_output(0);
const auto axis = unbind->get_input_source_output(1); const auto axis = unbind->get_input_source_output(1);
const auto num_splits = list_unpack->get_output_size(); const auto num_splits = list_unpack->get_output_size();
auto split = std::make_shared<opset10::Split>(input, axis, num_splits); auto split = rg.make<opset10::Split>(input, axis, num_splits);
NodeVector to_copy_rt{split};
OutputVector outputs; OutputVector outputs;
for (auto output : split->outputs()) { for (auto output : split->outputs()) {
const auto squeeze = std::make_shared<opset10::Squeeze>(output, axis); const auto squeeze = rg.make<opset10::Squeeze>(output, axis);
outputs.push_back(squeeze); outputs.push_back(squeeze);
to_copy_rt.push_back(squeeze);
} }
copy_runtime_info({list_unpack, input_node}, to_copy_rt); copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
replace_node(list_unpack, outputs); replace_node(list_unpack, outputs);
return true; return true;
} }
if (auto where = cast_fw_node(input_node, "aten::where")) { if (auto where = cast_fw_node(input_node, "aten::where")) {
const auto input = where->get_input_source_output(0); const auto input = where->get_input_source_output(0);
auto non_zero = std::make_shared<opset10::NonZero>(input); auto non_zero = rg.make<opset10::NonZero>(input);
auto axis = opset10::Constant::create(element::i32, Shape{}, {0}); auto axis = opset10::Constant::create(element::i32, Shape{}, {0});
const auto num_splits = list_unpack->get_output_size(); const auto num_splits = list_unpack->get_output_size();
auto split = std::make_shared<opset10::Split>(non_zero, axis, num_splits); auto split = rg.make<opset10::Split>(non_zero, axis, num_splits);
NodeVector to_copy_rt{split};
OutputVector outputs; OutputVector outputs;
for (auto output : split->outputs()) { for (auto output : split->outputs()) {
const auto squeeze = std::make_shared<opset10::Squeeze>(output, axis); const auto squeeze = rg.make<opset10::Squeeze>(output, axis);
outputs.push_back(squeeze); outputs.push_back(squeeze);
to_copy_rt.push_back(squeeze);
} }
copy_runtime_info({list_unpack, input_node}, to_copy_rt); copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
replace_node(list_unpack, outputs); replace_node(list_unpack, outputs);
return true; return true;
} }
if (auto nonzero_numpy = cast_fw_node(input_node, "aten::nonzero_numpy")) { if (auto nonzero_numpy = cast_fw_node(input_node, "aten::nonzero_numpy")) {
const auto input = nonzero_numpy->get_input_source_output(0); const auto input = nonzero_numpy->get_input_source_output(0);
auto non_zero = std::make_shared<opset10::NonZero>(input); auto non_zero = rg.make<opset10::NonZero>(input);
auto axis = opset10::Constant::create(element::i32, Shape{}, {0}); auto axis = opset10::Constant::create(element::i32, Shape{}, {0});
const auto num_splits = list_unpack->get_output_size(); const auto num_splits = list_unpack->get_output_size();
auto split = std::make_shared<opset10::Split>(non_zero, axis, num_splits); auto split = rg.make<opset10::Split>(non_zero, axis, num_splits);
NodeVector to_copy_rt{split};
OutputVector outputs; OutputVector outputs;
for (auto output : split->outputs()) { for (auto output : split->outputs()) {
const auto squeeze = std::make_shared<opset10::Squeeze>(output, axis); const auto squeeze = rg.make<opset10::Squeeze>(output, axis);
outputs.push_back(squeeze); outputs.push_back(squeeze);
to_copy_rt.push_back(squeeze);
} }
copy_runtime_info({list_unpack, input_node}, to_copy_rt); copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
replace_node(list_unpack, outputs); replace_node(list_unpack, outputs);
return true; return true;
@ -175,7 +167,6 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
add_exception_to_fw_node(input_node, "aten::meshgrid: only prim::ListConstruct supported as input."); add_exception_to_fw_node(input_node, "aten::meshgrid: only prim::ListConstruct supported as input.");
return false; return false;
} }
NodeVector rt_copy_from{list_unpack, input_node, meshgrid_input_node};
OutputVector meshgrid_inputs; OutputVector meshgrid_inputs;
for (auto& input : meshgrid_input_node->inputs()) { for (auto& input : meshgrid_input_node->inputs()) {
meshgrid_inputs.push_back(input.get_source_output()); meshgrid_inputs.push_back(input.get_source_output());
@ -203,29 +194,26 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
auto const_1 = opset10::Constant::create(element::i32, Shape{1}, {1}); auto const_1 = opset10::Constant::create(element::i32, Shape{1}, {1});
int input_idx = 0; int input_idx = 0;
for (auto& input : meshgrid_inputs) { for (auto& input : meshgrid_inputs) {
auto reshaped_input = std::make_shared<opset10::Reshape>(input, const_neg_1, false); auto reshaped_input = rg.make<opset10::Reshape>(input, const_neg_1, false);
auto shape = std::make_shared<opset10::ShapeOf>(reshaped_input, element::i32); auto shape = rg.make<opset10::ShapeOf>(reshaped_input, element::i32);
cat_shapes.push_back(shape); cat_shapes.push_back(shape);
NodeVector cat_inputs(meshgrid_inputs.size(), const_1); NodeVector cat_inputs(meshgrid_inputs.size(), const_1);
cat_inputs[input_idx] = shape; cat_inputs[input_idx] = shape;
input_idx++; input_idx++;
auto input_cat = std::make_shared<opset10::Concat>(cat_inputs, 0); auto input_cat = rg.make<opset10::Concat>(cat_inputs, 0);
auto reshape_cat = std::make_shared<opset10::Reshape>(reshaped_input, input_cat, false); auto reshape_cat = rg.make<opset10::Reshape>(reshaped_input, input_cat, false);
reshapes.push_back(reshape_cat); reshapes.push_back(reshape_cat);
} }
auto cat = std::make_shared<opset10::Concat>(cat_shapes, 0); auto cat = rg.make<opset10::Concat>(cat_shapes, 0);
NodeVector to_copy_rt{cat};
to_copy_rt.push_back(cat);
OutputVector outputs{}; OutputVector outputs{};
for (auto& reshape : reshapes) { for (auto& reshape : reshapes) {
auto out = std::make_shared<opset10::Broadcast>(reshape, cat, ov::op::BroadcastType::BIDIRECTIONAL); auto out = rg.make<opset10::Broadcast>(reshape, cat, ov::op::BroadcastType::BIDIRECTIONAL);
to_copy_rt.push_back(out);
outputs.push_back(out); outputs.push_back(out);
} }
if (indexing == "xy" && meshgrid_inputs.size() >= 2) { if (indexing == "xy" && meshgrid_inputs.size() >= 2) {
std::swap(outputs[0], outputs[1]); std::swap(outputs[0], outputs[1]);
} }
copy_runtime_info(rt_copy_from, to_copy_rt); copy_runtime_info_and_name(list_unpack, rg.get(), {input_node, meshgrid_input_node});
replace_node(list_unpack, outputs); replace_node(list_unpack, outputs);
return true; return true;
} }
@ -234,17 +222,15 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
// case aten::size as input // case aten::size as input
// Number of ListUnpack outputs should be equal to rank of input shape. // Number of ListUnpack outputs should be equal to rank of input shape.
auto axis_0 = opset10::Constant::create(element::i32, Shape{}, {0}); auto axis_0 = opset10::Constant::create(element::i32, Shape{}, {0});
auto split = std::make_shared<opset10::Split>(shape_of, axis_0, list_unpack->get_output_size()); auto split = rg.make<opset10::Split>(shape_of, axis_0, list_unpack->get_output_size());
NodeVector to_copy_rt{axis_0, split};
OutputVector res; OutputVector res;
for (auto output : split->outputs()) { for (auto output : split->outputs()) {
auto squeeze = std::make_shared<opset10::Squeeze>(output, axis_0); auto squeeze = rg.make<opset10::Squeeze>(output, axis_0);
to_copy_rt.push_back(squeeze);
res.push_back(squeeze); res.push_back(squeeze);
} }
copy_runtime_info({list_unpack, input_node}, to_copy_rt); copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
replace_node(list_unpack, res); replace_node(list_unpack, res);
return true; return true;
@ -254,17 +240,15 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
// case aten::slice as input // case aten::slice as input
// Number of ListUnpack outputs should be equal to rank of input shape. // Number of ListUnpack outputs should be equal to rank of input shape.
auto axis_0 = opset10::Constant::create(element::i32, Shape{}, {0}); auto axis_0 = opset10::Constant::create(element::i32, Shape{}, {0});
auto split = std::make_shared<opset10::Split>(slice, axis_0, list_unpack->get_output_size()); auto split = rg.make<opset10::Split>(slice, axis_0, list_unpack->get_output_size());
NodeVector to_copy_rt{axis_0, split};
OutputVector res; OutputVector res;
for (auto output : split->outputs()) { for (auto output : split->outputs()) {
auto squeeze = std::make_shared<opset10::Squeeze>(output, axis_0); auto squeeze = rg.make<opset10::Squeeze>(output, axis_0);
to_copy_rt.push_back(squeeze);
res.push_back(squeeze); res.push_back(squeeze);
} }
copy_runtime_info({list_unpack, input_node}, to_copy_rt); copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
replace_node(list_unpack, res); replace_node(list_unpack, res);
return true; return true;

View File

@ -67,14 +67,14 @@ StringEqualityReplacer::StringEqualityReplacer() {
auto equal_node = pattern_map.at(equal_op).get_node_shared_ptr(); auto equal_node = pattern_map.at(equal_op).get_node_shared_ptr();
if (auto equal = std::dynamic_pointer_cast<v1::Equal>(equal_node)) { if (auto equal = std::dynamic_pointer_cast<v1::Equal>(equal_node)) {
auto const_result = v0::Constant::create(element::boolean, Shape{}, {lhs == rhs}); auto const_result = v0::Constant::create(element::boolean, Shape{}, {lhs == rhs});
copy_runtime_info({lhs_node, rhs_node, equal_node}, const_result); copy_runtime_info_and_name(equal_node, {const_result});
replace_node(equal_node, const_result); replace_node(equal_node, const_result);
return true; return true;
}; };
auto not_equal_node = pattern_map.at(not_equal_op).get_node_shared_ptr(); auto not_equal_node = pattern_map.at(not_equal_op).get_node_shared_ptr();
if (auto equal = std::dynamic_pointer_cast<v1::NotEqual>(not_equal_node)) { if (auto equal = std::dynamic_pointer_cast<v1::NotEqual>(not_equal_node)) {
auto const_result = v0::Constant::create(element::boolean, Shape{}, {lhs != rhs}); auto const_result = v0::Constant::create(element::boolean, Shape{}, {lhs != rhs});
copy_runtime_info({lhs_node, rhs_node, not_equal_node}, const_result); copy_runtime_info_and_name(equal_node, {const_result});
replace_node(equal_node, const_result); replace_node(equal_node, const_result);
return true; return true;
}; };

View File

@ -54,7 +54,26 @@ std::shared_ptr<ov::Model> TranslateSession::get_converted_model() {
std::shared_ptr<ov::Model> TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& input_model) { std::shared_ptr<ov::Model> TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& input_model) {
auto pytorch_model = std::dynamic_pointer_cast<pytorch::InputModel>(input_model); auto pytorch_model = std::dynamic_pointer_cast<pytorch::InputModel>(input_model);
FRONT_END_GENERAL_CHECK(pytorch_model != nullptr, "Invalid input model"); FRONT_END_GENERAL_CHECK(pytorch_model != nullptr, "Invalid input model");
return convert_pytorch_model(pytorch_model->m_model_decoder, {}, pytorch_model->m_descriptors); auto model = convert_pytorch_model(pytorch_model->m_model_decoder, {}, pytorch_model->m_descriptors);
// First delete tensor indexes from outputs then resolve input names, otherwise Parameter->Result will fail
for (auto& result : model->get_results()) {
auto tensor_desc = result->input_value(0);
auto names = tensor_desc.get_names();
if (!names.empty()) {
auto tensor_idx = decode_tensor_name(tensor_desc);
if (names.erase(std::to_string(tensor_idx))) {
tensor_desc.set_names(names);
}
}
}
// Set input tensor names to be equal to signature name saved in friendly name
for (auto& param : model->get_parameters()) {
if (param->get_friendly_name() != param->get_name()) {
// get_name is autogenerated name, we need to make sure that this parameter was named by frontend
param->output(0).set_names({param->get_friendly_name()});
}
}
return model;
} }
std::shared_ptr<Model> TranslateSession::convert_pytorch_model( std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
@ -91,10 +110,8 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
} }
if (!input_node) { if (!input_node) {
auto parameter = std::make_shared<v0::Parameter>(type, pshape); auto parameter = std::make_shared<v0::Parameter>(type, pshape);
encode_tensor_name( parameter->set_friendly_name(pytorch_model->get_input_signature_name(i));
parameter->output(0), encode_tensor_name(parameter->output(0), inputs.at(i), {pytorch_model->get_input_debug_name(i)});
inputs.at(i),
{pytorch_model->get_input_debug_name(i), pytorch_model->get_input_signature_name(i)});
parameters->push_back(parameter); parameters->push_back(parameter);
input_node = parameter; input_node = parameter;
auto order = pytorch_model->get_input_transpose_order(i); auto order = pytorch_model->get_input_transpose_order(i);
@ -404,6 +421,14 @@ Output<Node> TranslateSession::get_backprop_op(const std::shared_ptr<TorchDecode
return std::make_shared<PtFrameworkNode>(node, OutputVector{value}, 1, true); return std::make_shared<PtFrameworkNode>(node, OutputVector{value}, 1, true);
} }
void TranslateSession::unique_name(const std::shared_ptr<Node>& node) {
if (m_unique_friendly_name_set.count(node->get_friendly_name())) {
node->set_friendly_name(node->get_friendly_name() + '_' + std::to_string(m_friendly_name_counter++));
} else {
m_unique_friendly_name_set.insert(node->get_friendly_name());
}
}
} // namespace pytorch } // namespace pytorch
} // namespace frontend } // namespace frontend
} // namespace ov } // namespace ov

View File

@ -48,7 +48,8 @@ public:
/// \brief Gets pytorch tensor index from openvino tensor /// \brief Gets pytorch tensor index from openvino tensor
size_t decode_tensor_name(const Output<Node>& tensor_desc); size_t decode_tensor_name(const Output<Node>& tensor_desc);
size_t m_friendly_name_counter = 0; /// \brief Make sure Node has unique name
void unique_name(const std::shared_ptr<Node>& node);
// Maps tensor index to initial tensor index which it is alias to, and to decoder of the node produced this alias // Maps tensor index to initial tensor index which it is alias to, and to decoder of the node produced this alias
// and to the output produced during conversion of this node // and to the output produced during conversion of this node
@ -64,6 +65,8 @@ private:
std::map<size_t, std::pair<size_t, Output<Node>>> m_counter_map; std::map<size_t, std::pair<size_t, Output<Node>>> m_counter_map;
std::map<std::string, uint64_t> m_op_statistics; std::map<std::string, uint64_t> m_op_statistics;
std::unordered_set<std::string> m_unique_friendly_name_set;
size_t m_friendly_name_counter = 0;
}; };
} // namespace pytorch } // namespace pytorch

View File

@ -5,6 +5,7 @@
#include "utils.hpp" #include "utils.hpp"
#include "op_table.hpp" #include "op_table.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/frontend/pytorch/decoder.hpp" #include "openvino/frontend/pytorch/decoder.hpp"
#include "openvino/opsets/opset10.hpp" #include "openvino/opsets/opset10.hpp"
#include "openvino/util/log.hpp" #include "openvino/util/log.hpp"
@ -497,6 +498,30 @@ void add_exception_to_fw_node(std::shared_ptr<Node> node, const std::string& msg
} }
} }
void copy_runtime_info_and_name(const std::shared_ptr<Node>& from,
ov::NodeVector to,
const ov::NodeVector& additional_rt_info_src) {
if (to.size() == 1) {
// We do 1 to 1 matching, no need to process names, just inherit initial name
to[0]->set_friendly_name(from->get_friendly_name());
} else {
std::unordered_set<std::string> unique_names;
size_t idx = 0;
for (auto& op : to) {
auto new_name = from->get_friendly_name() + '/' + op->get_type_name();
if (unique_names.count(new_name)) {
new_name += '_' + std::to_string(idx++);
} else {
unique_names.insert(new_name);
}
op->set_friendly_name(new_name);
}
}
copy_runtime_info(from, to);
if (!additional_rt_info_src.empty())
copy_runtime_info(additional_rt_info_src, to);
}
} // namespace pytorch } // namespace pytorch
} // namespace frontend } // namespace frontend
} // namespace ov } // namespace ov

View File

@ -72,6 +72,10 @@ void align_output_types(const NodeContext& context, OutputVector& outputs);
std::deque<Output<Node>> get_list_as_outputs(const Output<Node>& start); std::deque<Output<Node>> get_list_as_outputs(const Output<Node>& start);
void copy_runtime_info_and_name(const std::shared_ptr<Node>& from,
ov::NodeVector to,
const ov::NodeVector& additional_rt_info_src = {});
namespace op { namespace op {
template <OutputVector (*T)(const NodeContext&), size_t idx = 0> template <OutputVector (*T)(const NodeContext&), size_t idx = 0>
OutputVector inplace_op(const NodeContext& context) { OutputVector inplace_op(const NodeContext& context) {