[PT FE] Improve operation naming (#17997)
* Improve operation naming
* Use set to reduce operations with indexes, fix code style
* Refactor
* Make good names in transformations
* Remove tensor index from input/output names
* Fix tests
commit 9684f9184a
parent 483a040d52
@@ -319,6 +319,13 @@ class TorchScriptPythonDecoder (Decoder):
         return self.outputs()[index]
 
     def mark_node(self, node):
+        name = self.graph_element.kind()
+        if "FrameworkNode" not in node.get_type_name():
+            name += "/" + node.get_type_name()
+        if self.graph_element.scopeName():
+            node.set_friendly_name(self.graph_element.scopeName().split("/")[-1] + "/" + name)
+        else:
+            node.set_friendly_name(name)
        return node
 
    def try_decode_get_attr(self):
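The decoder-side change above derives friendly names from the TorchScript graph itself. As a hypothetical illustration (the scope and op are not taken from this diff): a node of kind aten::add converted to an OpenVINO Add inside scope `__module.model/__module.model.layer1` ends up named `__module.model.layer1/aten::add/Add`, while a node still represented by a FrameworkNode keeps just the kind, e.g. `__module.model.layer1/aten::add`.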
@@ -15,6 +15,7 @@ InputModel::InputModel(const std::shared_ptr<TorchDecoder>& model_decoder) : m_m
     const auto& inputs = m_model_decoder->inputs();
     for (size_t i = 0; i < inputs.size(); ++i) {
         auto in_place = std::make_shared<pytorch::Place>(*this, inputs[i]);
+        m_name_to_place.emplace(std::to_string(inputs[i]), std::dynamic_pointer_cast<frontend::Place>(in_place));
         for (const auto& name : in_place->get_names()) {
             m_name_to_place.emplace(name, std::dynamic_pointer_cast<frontend::Place>(in_place));
         }
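A side effect visible through the frontend API: each place is now registered both under its tensor index and under every decoder-provided name. A minimal sketch, assuming `input_model` is the `ov::frontend::InputModel` returned by `FrontEnd::load` and "x" is a hypothetical signature name:

    // Both lookups now resolve to the same pytorch::Place:
    auto by_index = input_model->get_place_by_tensor_name("0");  // tensor index as a string
    auto by_name = input_model->get_place_by_tensor_name("x");   // decoder-provided name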
@@ -28,6 +29,7 @@ InputModel::InputModel(const std::shared_ptr<TorchDecoder>& model_decoder) : m_m
     const auto& outputs = m_model_decoder->outputs();
     for (size_t i = 0; i < outputs.size(); ++i) {
         auto out_place = std::make_shared<pytorch::Place>(*this, outputs[i]);
+        m_name_to_place.emplace(std::to_string(outputs[i]), std::dynamic_pointer_cast<frontend::Place>(out_place));
         for (const auto& name : out_place->get_names()) {
             m_name_to_place.emplace(name, std::dynamic_pointer_cast<frontend::Place>(out_place));
         }
@@ -40,23 +40,24 @@ OutputVector NodeContext::as_constant() const {
        fw_node->set_attrs(attrs);
        return {fw_node};
    } else {
-        return m_decoder->as_constant();
+        auto c_outs = m_decoder->as_constant();
+        FRONT_END_OP_CONVERSION_CHECK(c_outs.size() == 1, "Constant must have exactly one output.");
+        c_outs[0].get_node_shared_ptr()->set_friendly_name(m_decoder->get_output_debug_name(0));
+        return c_outs;
    }
 }
 
 std::shared_ptr<Node> NodeContext::mark_node(std::shared_ptr<Node> ov_node) const {
-    ov_node->set_friendly_name(get_op_type() + '_' + std::to_string(m_translate_session->m_friendly_name_counter++));
-    return m_decoder->mark_node(ov_node);
+    ov_node = m_decoder->mark_node(ov_node);
+    m_translate_session->unique_name(ov_node);
+    return ov_node;
 }
 
 void NodeContext::mutate_input(size_t index, Output<Node> ov_output) const {
    FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index), "Input is none with index: ", index);
    auto input_id = m_decoder_inputs.at(index);
    FRONT_END_GENERAL_CHECK(m_tensor_map->count(input_id), "No tensor corresponding input: ", input_id, " exist.");
-    m_translate_session->encode_tensor_name(
-        ov_output,
-        input_id,
-        {m_decoder->get_input_debug_name(index), m_decoder->get_input_signature_name(index)});
+    m_translate_session->encode_tensor_name(ov_output, input_id, {m_decoder->get_input_debug_name(index)});
    (*m_tensor_map)[input_id] = ov_output;
    m_mutated_tensors->insert(input_id);
 
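Call sites in op converters are unchanged by this; a representative converter (simplified sketch) shows where the renaming now happens:

    OutputVector translate_relu(const NodeContext& context) {
        auto x = context.get_input(0);
        // mark_node lets the decoder attach the scope/kind-based friendly name,
        // then the translate session's unique_name() deduplicates it per model.
        return {context.mark_node(std::make_shared<ov::op::v0::Relu>(x))};
    }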
@@ -18,7 +18,6 @@ Place::Place(const ov::frontend::InputModel& input_model, size_t tensor_index)
      m_tensor_index(tensor_index),
      m_is_input(false),
      m_is_output(false) {
-    m_names.push_back(std::to_string(tensor_index));
    const auto im = dynamic_cast<const ov::frontend::pytorch::InputModel*>(&m_input_model);
    FRONT_END_GENERAL_CHECK(im, "PyTorch Place requires PyTorch InputModel class.");
    const auto& inputs = im->m_model_decoder->inputs();
@@ -26,24 +25,16 @@ Place::Place(const ov::frontend::InputModel& input_model, size_t tensor_index)
    auto in_it = std::find(inputs.begin(), inputs.end(), tensor_index);
    if (in_it != inputs.end()) {
        m_is_input = true;
-        const auto& debug_name = im->m_model_decoder->get_input_debug_name(std::distance(inputs.begin(), in_it));
-        if (debug_name != m_names.at(0)) {
-            m_names.push_back(debug_name);
-        }
        const auto& signature_name =
            im->m_model_decoder->get_input_signature_name(std::distance(inputs.begin(), in_it));
-        if (signature_name != m_names.at(0) && signature_name != debug_name) {
-            m_names.push_back(signature_name);
-        }
+        m_names.push_back(signature_name);
    }
    auto out_it = std::find(outputs.begin(), outputs.end(), tensor_index);
    if (out_it != outputs.end()) {
        m_is_output = true;
        const auto& debug_name = im->m_model_decoder->get_output_debug_name(std::distance(outputs.begin(), out_it));
-        if (debug_name != m_names.at(0)) {
-            m_names.push_back(debug_name);
-        }
+        m_names.push_back(debug_name);
    }
    if (m_is_input && m_is_output) {
        OPENVINO_DEBUG << "[WARNING] Place " << tensor_index << " is input and output at a same time.";
    }
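The net effect on a place's name list: an input place now carries only the decoder's signature name, an output place only the debug name, and the bare tensor index is no longer one of the place's own names (InputModel registers the index separately, as in the hunks above).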
@@ -21,6 +21,8 @@ namespace frontend {
 namespace pytorch {
 namespace pass {
 
+using namespace ov::op;
+
 AppendListUnpackReplacer::AppendListUnpackReplacer() {
    auto list_unpack = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
 
@@ -30,7 +32,7 @@ AppendListUnpackReplacer::AppendListUnpackReplacer() {
            return false;
 
        OutputVector tmp_inputs;
-        NodeVector rt_copy_from{list_unpack};
+        NodeVector rt_copy_from;
        auto input_node = list_unpack->input_value(0).get_node_shared_ptr();
 
        // Optional aten::__getitem__ node.
@@ -60,7 +62,7 @@ AppendListUnpackReplacer::AppendListUnpackReplacer() {
            // If aten::__getitem__, expect inputs to be equivalent of pytorch Tensor[][].
            // Tensor selected by aten::__getitem__ index needs to be splitted in axis 0.
            auto getitem_index_ptr = getitem_node->input_value(1).get_node_shared_ptr();
-            auto getitem_index_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(getitem_index_ptr);
+            auto getitem_index_const = std::dynamic_pointer_cast<v0::Constant>(getitem_index_ptr);
            auto index_val = getitem_index_const->cast_vector<int64_t>();
            if (index_val.size() != 1) {
                add_exception_to_fw_node(list_unpack, "prim::ListUnpack: index of aten::__getitem__ is not scalar.");
@@ -70,16 +72,16 @@ AppendListUnpackReplacer::AppendListUnpackReplacer() {
            if (index_val[0] < 0) {
                index = inputs.size() + index;
            }
-            auto axis_0 = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
-            auto split = std::make_shared<ov::op::v1::Split>(inputs[index], axis_0, list_unpack->get_output_size());
+            auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0});
+            auto split = std::make_shared<v1::Split>(inputs[index], axis_0, list_unpack->get_output_size());
            NodeVector to_copy_rt{axis_0, split};
            OutputVector res;
            for (auto output : split->outputs()) {
-                auto squeeze = std::make_shared<ov::op::v0::Squeeze>(output, axis_0);
+                auto squeeze = std::make_shared<v0::Squeeze>(output, axis_0);
                to_copy_rt.push_back(squeeze);
                res.push_back(squeeze);
            }
-            copy_runtime_info(rt_copy_from, to_copy_rt);
+            copy_runtime_info_and_name(list_unpack, to_copy_rt, rt_copy_from);
            replace_node(list_unpack, res);
            return true;
        } else {
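`copy_runtime_info_and_name` replaces the plain `copy_runtime_info` calls throughout these passes. Its body is not part of this diff; a minimal sketch of the assumed behavior:

    // Assumed helper shape: propagate rt_info from the replaced framework node
    // (plus extra sources such as rt_copy_from) to every new node, and reuse
    // the old friendly name so replacements stay traceable to the original op.
    void copy_runtime_info_and_name(const std::shared_ptr<Node>& from,
                                    ov::NodeVector to,
                                    const ov::NodeVector& additional_rt_info_src = {}) {
        if (to.size() == 1)
            to[0]->set_friendly_name(from->get_friendly_name());
        auto sources = additional_rt_info_src;
        sources.push_back(from);
        ov::copy_runtime_info(sources, to);
    }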
@@ -21,6 +21,8 @@ namespace frontend {
 namespace pytorch {
 namespace pass {
 
+using namespace ov::op;
+
 // aten::cat needs a special handling since it takes a Tensor[] as input. We set the inputs of ListConstruct as the
 // inputs of cat.
 //
@@ -41,7 +43,7 @@ AtenCatToConcat::AtenCatToConcat() {
        int64_t axis;
        if (cat->get_input_size() > 1) {
            auto axis_node = cat->get_input_node_shared_ptr(1);
-            auto axis_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(axis_node);
+            auto axis_const = std::dynamic_pointer_cast<v0::Constant>(axis_node);
            if (!axis_const) {
                add_exception_to_fw_node(cat, "aten::cat unsupported case: axis is not a constant.");
                return false;
@@ -62,7 +64,7 @@ AtenCatToConcat::AtenCatToConcat() {
        }
 
        std::shared_ptr<Node> input_node = cat->get_input_node_shared_ptr(0);
-        if (auto loop = std::dynamic_pointer_cast<ov::op::v5::Loop>(input_node)) {
+        if (auto loop = std::dynamic_pointer_cast<v5::Loop>(input_node)) {
            // case when concatenation is done inside the Loop
            auto body = loop->get_function();
            auto output_index = cat->input(0).get_source_output().get_index();
@@ -82,7 +84,7 @@ AtenCatToConcat::AtenCatToConcat() {
                                         "aten::cat unsupported case: aten::append wasn't found inside prim::Loop body.");
                return false;
            }
-            auto param = std::dynamic_pointer_cast<ov::op::v0::Parameter>(append->get_input_node_shared_ptr(0));
+            auto param = std::dynamic_pointer_cast<v0::Parameter>(append->get_input_node_shared_ptr(0));
            if (!param) {
                add_exception_to_fw_node(cat,
                                         "aten::cat unsupported case: input of aten::append inside prim::Loop "
@@ -106,7 +108,7 @@ AtenCatToConcat::AtenCatToConcat() {
                                         "body is not a prim::ListConstruct.");
                return false;
            }
-            auto new_result = std::make_shared<ov::op::v0::Result>(append->input_value(1));
+            auto new_result = std::make_shared<v0::Result>(append->input_value(1));
            body->add_results({new_result});
            auto new_output = loop->get_concatenated_slices(new_result, 0, 1, 1, -1, axis);
            copy_runtime_info(cat, loop);
@@ -115,10 +117,9 @@ AtenCatToConcat::AtenCatToConcat() {
        }
 
        const auto&& tmp_inputs = get_list_as_outputs(cat->get_input_source_output(0));
-        auto result = std::make_shared<ov::op::v0::Concat>(OutputVector(tmp_inputs.begin(), tmp_inputs.end()), axis);
-        copy_runtime_info(cat, result);
+        auto result = std::make_shared<v0::Concat>(OutputVector(tmp_inputs.begin(), tmp_inputs.end()), axis);
+        copy_runtime_info_and_name(cat, {result});
        replace_node(cat, result);
-        result->set_friendly_name(cat->get_friendly_name());
 
        return true;
    };
@@ -14,6 +14,10 @@
 #include "openvino/op/convert.hpp"
 #include "openvino/op/divide.hpp"
 #include "openvino/op/gather.hpp"
+#include "openvino/op/greater.hpp"
+#include "openvino/op/greater_eq.hpp"
+#include "openvino/op/less.hpp"
+#include "openvino/op/mod.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/range.hpp"
 #include "openvino/op/shape_of.hpp"
@@ -22,7 +26,6 @@
 #include "openvino/op/unsqueeze.hpp"
 #include "openvino/op/util/framework_node.hpp"
 #include "openvino/op/variadic_split.hpp"
-#include "openvino/opsets/opset10.hpp"
 #include "openvino/pass/pattern/matcher.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "pt_framework_node.hpp"
@@ -33,6 +36,8 @@ namespace frontend {
 namespace pytorch {
 namespace pass {
 
+using namespace ov::op;
+
 AtenGetItemReplacer::AtenGetItemReplacer() {
    auto getitem = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
 
@@ -41,6 +46,7 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
        if (!getitem)
            return false;
 
+        ov::pass::NodeRegistry rg;
        auto input_node = getitem->input_value(0).get_node_shared_ptr();
        if (auto torch_split = cast_fw_node(input_node, "aten::split")) {
            auto rank = torch_split->input(1).get_partial_shape().rank();
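`ov::pass::NodeRegistry` records every node built through it, which is what allows the single `copy_runtime_info_and_name(getitem, rg.get(), ...)` call at the end of this callback. Illustrative usage (a free-standing sketch; `input` and `matched_node` are placeholders, not names from this diff):

    ov::pass::NodeRegistry rg;
    auto axis = v0::Constant::create(element::i32, Shape{}, {0});  // built directly: not registered
    auto sum = rg.make<v1::ReduceSum>(input, axis);                // built and registered
    // ... construct the whole replacement subgraph through rg.make<>() ...
    copy_runtime_info_and_name(matched_node, rg.get());            // rg.get() returns all registered nodes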
@@ -51,51 +57,46 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
            if (rank.get_length() == 0) {
                // Based on slice_size and output index select size.
                // Constants required by transformation.
-                auto const_1 = ov::op::v0::Constant::create(element::i32, Shape{1}, {1});
-                auto const_1_0d = ov::op::v0::Constant::create(element::i32, Shape{}, {1});
-                auto const_0 = ov::op::v0::Constant::create(element::i32, Shape{1}, {0});
-                auto const_0_0d = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
+                auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1});
+                auto const_1_0d = v0::Constant::create(element::i32, Shape{}, {1});
+                auto const_0 = v0::Constant::create(element::i32, Shape{1}, {0});
+                auto const_0_0d = v0::Constant::create(element::i32, Shape{}, {0});
 
                // Load and convert op inputs.
                auto input = torch_split->get_input_source_output(0);
                auto split_size = torch_split->get_input_source_output(1);
-                auto split_size_1d = std::make_shared<ov::op::v0::Unsqueeze>(split_size, const_0);
+                auto split_size_1d = rg.make<v0::Unsqueeze>(split_size, const_0);
                auto axis = torch_split->get_input_source_output(2);
-                auto axis_1d = std::make_shared<ov::op::v0::Unsqueeze>(axis, const_0);
+                auto axis_1d = rg.make<v0::Unsqueeze>(axis, const_0);
                auto getitem_idx = getitem->input(1).get_source_output();
 
                // Calculate number of splits based on input shape and split_size.
-                auto shape = std::make_shared<ov::op::v3::ShapeOf>(input, element::i32);
-                auto len_to_split = std::make_shared<ov::op::v8::Gather>(shape, axis, const_0);
+                auto shape = rg.make<v3::ShapeOf>(input, element::i32);
+                auto len_to_split = rg.make<v8::Gather>(shape, axis, const_0);
                // Convert to f64 from int to calculate reminder - last chunk can be smaller if Shape in given axis is
                // not equally divisible.
-                auto len_to_split_float = std::make_shared<ov::op::v0::Convert>(len_to_split, element::f64);
-                auto split_size_1d_float = std::make_shared<ov::op::v0::Convert>(split_size_1d, element::f64);
-                auto out_div = std::make_shared<ov::op::v1::Divide>(len_to_split_float, split_size_1d_float);
-                auto out_num = std::make_shared<ov::op::v0::Ceiling>(out_div);
-                auto out_num_0d = std::make_shared<ov::op::v0::Squeeze>(out_num, const_0);
+                auto len_to_split_float = rg.make<v0::Convert>(len_to_split, element::f64);
+                auto split_size_1d_float = rg.make<v0::Convert>(split_size_1d, element::f64);
+                auto out_div = rg.make<v1::Divide>(len_to_split_float, split_size_1d_float);
+                auto out_num = rg.make<v0::Ceiling>(out_div);
+                auto out_num_0d = rg.make<v0::Squeeze>(out_num, const_0);
 
                // Use Range and Gather to convert negative getitem indexes into positive due problems with indexing
                // with -1.
-                auto possible_out_idx = std::make_shared<ov::op::v4::Range>(const_0_0d,
-                                                                            out_num_0d,
-                                                                            const_1_0d,
-                                                                            split_size.get_element_type());
-                auto always_positive_out_idx =
-                    std::make_shared<ov::op::v8::Gather>(possible_out_idx, getitem_idx, const_0);
+                auto possible_out_idx =
+                    rg.make<v4::Range>(const_0_0d, out_num_0d, const_1_0d, split_size.get_element_type());
+                auto always_positive_out_idx = rg.make<v8::Gather>(possible_out_idx, getitem_idx, const_0);
 
                // Use Slice to get only split output selected by getitem idx. Couldn't use VariadicSplit due to
                // problems with dynamic inputs.
-                auto split_slice_start = std::make_shared<ov::op::v1::Multiply>(always_positive_out_idx, split_size_1d);
-                auto split_slice_end = std::make_shared<ov::op::v1::Add>(split_slice_start, split_size_1d);
-                auto split =
-                    std::make_shared<ov::op::v8::Slice>(input, split_slice_start, split_slice_end, const_1, axis_1d);
-                copy_runtime_info({getitem, input_node}, split);
+                auto split_slice_start = rg.make<v1::Multiply>(always_positive_out_idx, split_size_1d);
+                auto split_slice_end = rg.make<v1::Add>(split_slice_start, split_size_1d);
+                auto split = rg.make<v8::Slice>(input, split_slice_start, split_slice_end, const_1, axis_1d);
                replace_node(getitem, split);
            } else {
                auto getitem_index_ptr = getitem->input_value(1).get_node_shared_ptr();
-                auto getitem_index_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(getitem_index_ptr);
-                auto split = std::make_shared<ov::op::v1::VariadicSplit>(torch_split->get_input_source_output(0),
-                                                                         torch_split->get_input_source_output(2),
-                                                                         torch_split->get_input_source_output(1));
+                auto getitem_index_const = std::dynamic_pointer_cast<v0::Constant>(getitem_index_ptr);
+                auto split = rg.make<v1::VariadicSplit>(torch_split->get_input_source_output(0),
+                                                        torch_split->get_input_source_output(2),
+                                                        torch_split->get_input_source_output(1));
                auto index_val = getitem_index_const->cast_vector<int64_t>();
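Worked example of the scalar-split_size branch above: for an input of length 10 along the split axis and split_size 4, out_num = ceil(10 / 4) = 3; a getitem index of -1 maps through the Range/Gather pair to 2, so the Slice takes [2 * 4, 2 * 4 + 4) = [8, 12), which Slice clamps to the remaining 2 elements.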
@@ -107,82 +108,71 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
                if (index < 0) {
                    index = split->outputs().size() + index;
                }
-                OutputVector res{split->outputs()[index]};
-                copy_runtime_info({getitem, input_node}, split);
-                replace_node(getitem, res);
+                replace_node(getitem, {split->outputs()[index]});
            }
-            return true;
-        }
-        if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) {
+        } else if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) {
            auto getitem_idx = getitem->input_value(1).get_node_shared_ptr();
-            auto getitem_idx_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(getitem_idx);
+            auto getitem_idx_const = std::dynamic_pointer_cast<v0::Constant>(getitem_idx);
            if (getitem_idx_const) {
                auto idx = getitem_idx_const->cast_vector<int64_t>();
                auto element = list_construct->input_value(idx[0]).get_node_shared_ptr();
-                copy_runtime_info({getitem, input_node}, element);
                replace_node(getitem, element);
-                return true;
-            }
-            auto input_concat = concat_list_construct(list_construct);
-            auto zero = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
-            auto gather = std::make_shared<ov::op::v8::Gather>(input_concat, getitem_idx, zero);
-            copy_runtime_info({getitem, input_node}, gather);
-            replace_node(getitem, gather);
-            return true;
-        }
-        if (auto chunk = cast_fw_node(input_node, "aten::chunk")) {
+            } else {
+                auto input_concat = concat_list_construct(list_construct);
+                auto zero = v0::Constant::create(element::i32, Shape{}, {0});
+                auto gather = rg.make<v8::Gather>(input_concat, getitem_idx, zero);
+                replace_node(getitem, gather);
+            }
+        } else if (auto chunk = cast_fw_node(input_node, "aten::chunk")) {
            auto input_tensor = chunk->get_input_source_output(0);
            auto chunks_i32 = chunk->get_input_source_output(1);
            auto dim_i32 = chunk->get_input_source_output(2);
 
-            auto const_0 = opset10::Constant::create(element::i64, Shape{1}, {0});
-            auto const_1 = opset10::Constant::create(element::i64, Shape{1}, {1});
-            auto const_0_nodim = opset10::Constant::create(element::i64, Shape{}, {0});
+            auto const_0 = v0::Constant::create(element::i64, Shape{1}, {0});
+            auto const_1 = v0::Constant::create(element::i64, Shape{1}, {1});
+            auto const_0_nodim = v0::Constant::create(element::i64, Shape{}, {0});
 
            auto getitem_index_i32 = getitem->get_input_source_output(1);
-            auto getitem_index_i64 = std::make_shared<opset10::Convert>(getitem_index_i32, element::i64);
-            auto getitem_index = std::make_shared<opset10::Unsqueeze>(getitem_index_i64, const_0);
-            auto dim_i64 = std::make_shared<opset10::Convert>(dim_i32, element::i64);
-            auto dim = std::make_shared<opset10::Unsqueeze>(dim_i64, const_0);
-            auto chunks = std::make_shared<opset10::Convert>(chunks_i32, element::i64);
+            auto getitem_index_i64 = rg.make<v0::Convert>(getitem_index_i32, element::i64);
+            auto getitem_index = rg.make<v0::Unsqueeze>(getitem_index_i64, const_0);
+            auto dim_i64 = rg.make<v0::Convert>(dim_i32, element::i64);
+            auto dim = rg.make<v0::Unsqueeze>(dim_i64, const_0);
+            auto chunks = rg.make<v0::Convert>(chunks_i32, element::i64);
 
-            auto input_shape = std::make_shared<opset10::ShapeOf>(input_tensor);
-            auto input_dimension = std::make_shared<opset10::Gather>(input_shape, dim, const_0);
-            auto input_size = std::make_shared<opset10::Squeeze>(input_dimension);
+            auto input_shape = rg.make<v3::ShapeOf>(input_tensor);
+            auto input_dimension = rg.make<v8::Gather>(input_shape, dim, const_0);
+            auto input_size = rg.make<v0::Squeeze>(input_dimension);
 
-            auto chunk_size = std::make_shared<opset10::Divide>(input_size, chunks, true);
-            auto last_chunk_size = std::make_shared<opset10::Mod>(input_size, chunks);
-            auto is_last_nonzero = std::make_shared<opset10::Greater>(last_chunk_size, const_0_nodim);
-            auto is_last_nonzero_int = std::make_shared<opset10::Convert>(is_last_nonzero, element::i64);
+            auto chunk_size = rg.make<v1::Divide>(input_size, chunks, true);
+            auto last_chunk_size = rg.make<v1::Mod>(input_size, chunks);
+            auto is_last_nonzero = rg.make<v1::Greater>(last_chunk_size, const_0_nodim);
+            auto is_last_nonzero_int = rg.make<v0::Convert>(is_last_nonzero, element::i64);
 
-            auto computed_chunk_size = std::make_shared<opset10::Add>(chunk_size, is_last_nonzero_int);
-            auto computed_last_chunk_size = std::make_shared<opset10::Mod>(input_size, computed_chunk_size);
-            auto computed_is_last_nonzero = std::make_shared<opset10::Greater>(computed_last_chunk_size, const_0_nodim);
-            auto computed_chunks = std::make_shared<opset10::Divide>(input_size, computed_chunk_size, true);
+            auto computed_chunk_size = rg.make<v1::Add>(chunk_size, is_last_nonzero_int);
+            auto computed_last_chunk_size = rg.make<v1::Mod>(input_size, computed_chunk_size);
+            auto computed_is_last_nonzero = rg.make<v1::Greater>(computed_last_chunk_size, const_0_nodim);
+            auto computed_chunks = rg.make<v1::Divide>(input_size, computed_chunk_size, true);
 
-            auto is_slice_normal_size = std::make_shared<opset10::Less>(getitem_index, computed_chunks);
-            auto is_slice_not_normal_size = std::make_shared<opset10::GreaterEqual>(getitem_index, computed_chunks);
-            auto is_slice_normal_size_int = std::make_shared<opset10::Convert>(is_slice_normal_size, element::i64);
-            auto is_slice_not_normal_size_int =
-                std::make_shared<opset10::Convert>(is_slice_not_normal_size, element::i64);
+            auto is_slice_normal_size = rg.make<v1::Less>(getitem_index, computed_chunks);
+            auto is_slice_not_normal_size = rg.make<v1::GreaterEqual>(getitem_index, computed_chunks);
+            auto is_slice_normal_size_int = rg.make<v0::Convert>(is_slice_normal_size, element::i64);
+            auto is_slice_not_normal_size_int = rg.make<v0::Convert>(is_slice_not_normal_size, element::i64);
 
-            auto slice_size_lhs = std::make_shared<opset10::Multiply>(is_slice_normal_size_int, computed_chunk_size);
-            auto slice_size_rhs =
-                std::make_shared<opset10::Multiply>(is_slice_not_normal_size_int, computed_last_chunk_size);
-            auto slice_size = std::make_shared<opset10::Add>(slice_size_lhs, slice_size_rhs);
+            auto slice_size_lhs = rg.make<v1::Multiply>(is_slice_normal_size_int, computed_chunk_size);
+            auto slice_size_rhs = rg.make<v1::Multiply>(is_slice_not_normal_size_int, computed_last_chunk_size);
+            auto slice_size = rg.make<v1::Add>(slice_size_lhs, slice_size_rhs);
 
-            auto slice_begin = std::make_shared<opset10::Multiply>(getitem_index, computed_chunk_size);
-            auto slice_end = std::make_shared<opset10::Add>(slice_begin, slice_size);
+            auto slice_begin = rg.make<v1::Multiply>(getitem_index, computed_chunk_size);
+            auto slice_end = rg.make<v1::Add>(slice_begin, slice_size);
 
-            auto sliced_chunk = std::make_shared<opset10::Slice>(input_tensor, slice_begin, slice_end, const_1, dim);
+            auto sliced_chunk = rg.make<v8::Slice>(input_tensor, slice_begin, slice_end, const_1, dim);
 
-            copy_runtime_info({getitem, input_node}, sliced_chunk);
            replace_node(getitem, sliced_chunk);
-            return true;
-        }
-
-        return false;
+        } else {
+            return false;
+        }
+        copy_runtime_info_and_name(getitem, rg.get(), {input_node});
+        return true;
    };
 
    auto m = std::make_shared<ov::pass::pattern::Matcher>(getitem, "ov::frontend::pytorch::pass::AtenGetItemReplacer");
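Worked example of the aten::chunk branch: input_size 10, chunks 3 gives chunk_size = 10 / 3 = 3 (truncated), last_chunk_size = 10 mod 3 = 1, so computed_chunk_size = 3 + 1 = 4; then computed_chunks = 10 / 4 = 2 and computed_last_chunk_size = 10 mod 4 = 2, i.e. getitem index 0 or 1 slices 4 elements and index 2 slices the trailing 2 — matching PyTorch, which returns chunks of sizes 4, 4, 2 here.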
@@ -31,10 +31,12 @@ namespace pass {
 using namespace ov::op;
 
 namespace {
-Output<Node> generate_zeros_with_convertlike(const Output<Node> sizes, const Output<Node> tensor_of_type) {
+Output<Node> generate_zeros_with_convertlike(ov::pass::NodeRegistry& rg,
+                                             const Output<Node> sizes,
+                                             const Output<Node> tensor_of_type) {
    auto const_0 = v0::Constant::create(element::i32, Shape{}, {0});
-    auto zeros = std::make_shared<v3::Broadcast>(const_0, sizes);
-    return std::make_shared<v1::ConvertLike>(zeros, tensor_of_type);
+    auto zeros = rg.make<v3::Broadcast>(const_0, sizes);
+    return rg.make<v1::ConvertLike>(zeros, tensor_of_type);
 }
 }  // namespace
 
@@ -46,18 +48,18 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
        if (!index_op) {
            return false;
        }
-        NodeVector rt_copy_from{index_op};
+        NodeVector rt_copy_from;
+        ov::pass::NodeRegistry rg;
        auto const_0 = v0::Constant::create(element::i32, Shape{}, {0});
        auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1});
        auto const_max_int = v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()});
        auto const_neg_1 = v0::Constant::create(element::i32, Shape{}, {-1});
 
        auto input = index_op->input_value(0);
-        auto input_shape = std::make_shared<v3::ShapeOf>(input, element::i32);
+        auto input_shape = rg.make<v3::ShapeOf>(input, element::i32);
        auto indices = index_op->input_value(1);
        auto values = index_op->input_value(2);
-        auto acc_const =
-            std::dynamic_pointer_cast<ov::op::v0::Constant>(index_op->input_value(3).get_node_shared_ptr());
+        auto acc_const = std::dynamic_pointer_cast<v0::Constant>(index_op->input_value(3).get_node_shared_ptr());
        if (!acc_const) {
            add_exception_to_fw_node(index_op, "aten::index_put_: non constant accumulate input is not supported.");
            return false;
@@ -85,12 +87,11 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
                return false;
            }
            indices_list_len = indices_first_dim.get_length();
-            auto split = std::make_shared<v1::Split>(indices, const_0, indices_list_len);
+            auto split = rg.make<v1::Split>(indices, const_0, indices_list_len);
            indices_inputs = split->outputs();
        }
 
        if (indices_list_len == 0) {
-            copy_runtime_info(rt_copy_from, values.get_node_shared_ptr());
            replace_node(index_op, values.get_node_shared_ptr());
            return true;
        }
@@ -102,52 +103,51 @@ AtenIndexPutReplacer::AtenIndexPutReplacer() {
        if (indices_list_len > 1) {
            index = indices_inputs[0];
            for (int i = 1; i < indices_list_len; i++) {
-                index = std::make_shared<v1::Add>(index, indices_inputs[i]);
+                index = rg.make<v1::Add>(index, indices_inputs[i]);
            }
-            broadcast_index_shape = std::make_shared<v3::ShapeOf>(index, element::i32);
+            broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
            OutputVector indices_list;
            for (int i = 0; i < indices_list_len; i++) {
-                auto broadcast = std::make_shared<v3::Broadcast>(indices_inputs[i], broadcast_index_shape);
-                auto unsqueeze = std::make_shared<v0::Unsqueeze>(broadcast, const_neg_1);
+                auto broadcast = rg.make<v3::Broadcast>(indices_inputs[i], broadcast_index_shape);
+                auto unsqueeze = rg.make<v0::Unsqueeze>(broadcast, const_neg_1);
 
                // change negative indices to positive indices
                auto const_i = v0::Constant::create(element::i32, Shape{}, {i});
-                auto dim_i = std::make_shared<v8::Gather>(input_shape, const_i, const_0);
-                auto dim_i_correct_type = std::make_shared<v1::ConvertLike>(dim_i, index);
-                auto unsqueeze_add = std::make_shared<v1::Add>(unsqueeze, dim_i_correct_type);
-                auto unsqueeze_add_mod = std::make_shared<v1::Mod>(unsqueeze_add, dim_i_correct_type);
+                auto dim_i = rg.make<v8::Gather>(input_shape, const_i, const_0);
+                auto dim_i_correct_type = rg.make<v1::ConvertLike>(dim_i, index);
+                auto unsqueeze_add = rg.make<v1::Add>(unsqueeze, dim_i_correct_type);
+                auto unsqueeze_add_mod = rg.make<v1::Mod>(unsqueeze_add, dim_i_correct_type);
 
                indices_list.push_back(unsqueeze_add_mod);
            }
-            index = std::make_shared<v0::Concat>(indices_list, -1);
+            index = rg.make<v0::Concat>(indices_list, -1);
        } else {
            index = indices_inputs[0];
            // change negative indices to positive indices
-            auto dim_0 = (std::make_shared<v8::Gather>(input_shape, const_0, const_0));
-            auto dim_0_correct_type = (std::make_shared<v1::ConvertLike>(dim_0, index));
-            index = std::make_shared<v1::Add>(index, dim_0_correct_type);
-            index = std::make_shared<v1::Mod>(index, dim_0_correct_type);
+            auto dim_0 = (rg.make<v8::Gather>(input_shape, const_0, const_0));
+            auto dim_0_correct_type = (rg.make<v1::ConvertLike>(dim_0, index));
+            index = rg.make<v1::Add>(index, dim_0_correct_type);
+            index = rg.make<v1::Mod>(index, dim_0_correct_type);
 
-            broadcast_index_shape = std::make_shared<v3::ShapeOf>(index, element::i32);
-            index = std::make_shared<v0::Unsqueeze>(index, const_neg_1);
+            broadcast_index_shape = rg.make<v3::ShapeOf>(index, element::i32);
+            index = rg.make<v0::Unsqueeze>(index, const_neg_1);
        }
 
-        auto sub_data_shape = std::make_shared<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1);
-        auto values_shape = std::make_shared<v0::Concat>(OutputVector{broadcast_index_shape, sub_data_shape}, 0);
-        values = std::make_shared<v3::Broadcast>(values, values_shape);
-        values = std::make_shared<v1::ConvertLike>(values, input);
+        auto sub_data_shape = rg.make<v8::Slice>(input_shape, const_indices_list_len, const_max_int, const_1);
+        auto values_shape = rg.make<v0::Concat>(OutputVector{broadcast_index_shape, sub_data_shape}, 0);
+        values = rg.make<v3::Broadcast>(values, values_shape);
+        values = rg.make<v1::ConvertLike>(values, input);
 
        std::shared_ptr<ov::Node> result;
        if (accumulate) {
-            auto zeros = generate_zeros_with_convertlike(input_shape, input);
-            auto scatter = std::make_shared<v3::ScatterNDUpdate>(zeros, index, values);
-            result = std::make_shared<v1::Add>(input, scatter);
+            auto zeros = generate_zeros_with_convertlike(rg, input_shape, input);
+            auto scatter = rg.make<v3::ScatterNDUpdate>(zeros, index, values);
+            result = rg.make<v1::Add>(input, scatter);
        } else {
-            result = std::make_shared<v3::ScatterNDUpdate>(input, index, values);
+            result = rg.make<v3::ScatterNDUpdate>(input, index, values);
        }
-        copy_runtime_info(rt_copy_from, result);
+        copy_runtime_info_and_name(index_op, rg.get(), rt_copy_from);
        replace_node(index_op, result);
-        result->set_friendly_name(index_op->get_friendly_name());
        return true;
    };
 
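Note on the accumulate path: it relies on ScatterNDUpdate rather than a true scatter-add. The values are first scattered into a zero tensor of the input's shape (built via ConvertLike so the element type stays in sync with the input), and the result is added to the input, which reproduces accumulate=True as long as the gathered indices are unique.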
@@ -33,9 +33,9 @@ namespace pytorch {
 namespace pass {
 
 using namespace ov::op;
-namespace {
 
-std::shared_ptr<Node> flatten(const Output<Node>& value, size_t axis) {
+namespace {
+Output<Node> flatten(ov::pass::NodeRegistry& rg, const Output<Node>& value, size_t axis) {
    // First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of
    // input tensor. The last dimension is the product of the rest of input tensor dimensions:
    // [d_{axis}, ..., d_n]
@@ -45,20 +45,20 @@ std::shared_ptr<Node> flatten(const Output<Node>& value, size_t axis) {
    } else if (axis == 1) {
        output_shape = v0::Constant::create(element::i32, Shape{2}, {0, -1});
    } else {
-        const auto value_shape = std::make_shared<v3::ShapeOf>(value, element::i32);
-        const auto value_rank = std::make_shared<v3::ShapeOf>(value_shape, element::i32);
+        const auto value_shape = rg.make<v3::ShapeOf>(value, element::i32);
+        const auto value_rank = rg.make<v3::ShapeOf>(value_shape, element::i32);
        const auto axis_node = v0::Constant::create(element::i32, Shape{1}, {axis});
        auto start = v0::Constant::create(element::i32, Shape{1}, {0});
        auto step = v0::Constant::create(element::i32, Shape{1}, {1});
-        const auto first_part_dims = std::make_shared<v8::Slice>(value_shape, start, axis_node, step);
+        const auto first_part_dims = rg.make<v8::Slice>(value_shape, start, axis_node, step);
        auto zero = v0::Constant::create(element::i32, {}, {0});
-        auto first_part_dims_length = std::make_shared<v1::ReduceProd>(first_part_dims, zero, true);
+        auto first_part_dims_length = rg.make<v1::ReduceProd>(first_part_dims, zero, true);
 
        auto remaining_part_length = v0::Constant::create(element::i32, {1}, {-1});
 
-        output_shape = std::make_shared<v0::Concat>(OutputVector{first_part_dims_length, remaining_part_length}, 0);
+        output_shape = rg.make<v0::Concat>(OutputVector{first_part_dims_length, remaining_part_length}, 0);
    }
-    return std::make_shared<v1::Reshape>(value, output_shape, true);
+    return rg.make<v1::Reshape>(value, output_shape, true);
 }
 };  // namespace
 
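Worked example of this flatten helper: for a value of shape [2, 3, 4, 5] and axis 2, the first output dimension is the product 2 * 3 = 6 and the -1 in the target shape absorbs the rest, so the Reshape produces [6, 20]; the axis == 0 and axis == 1 cases skip the shape arithmetic and use constant target shapes instead.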
@@ -70,6 +70,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
        if (!index_op) {
            return false;
        }
+        ov::pass::NodeRegistry rg;
        auto input_node = index_op->input_value(0);
        auto indicies = index_op->input_value(1).get_node_shared_ptr();
        auto list_indicies = cast_fw_node(indicies, "prim::ListConstruct");
@@ -110,10 +111,10 @@ AtenIndexToSelect::AtenIndexToSelect() {
                }
                auto id_dtype = ids[i].get_element_type();
                if (id_dtype == element::boolean || id_dtype == element::u8) {
-                    auto idx = std::make_shared<v0::Convert>(ids[i], element::u8);
-                    auto nonzero = std::make_shared<v3::NonZero>(idx, element::i32);
+                    auto idx = rg.make<v0::Convert>(ids[i], element::u8);
+                    auto nonzero = rg.make<v3::NonZero>(idx, element::i32);
                    auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
-                    auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
+                    auto masked_id = rg.make<v1::Transpose>(nonzero, input_order);
                    masked_indicies.push_back(masked_id);
                    is_masked_bool.push_back(true);
                } else {
@@ -132,17 +133,15 @@ AtenIndexToSelect::AtenIndexToSelect() {
            if (advanced_ids.size() == 1) {
                auto index = masked_indicies[advanced_ids[0]];
                if (is_masked_bool[advanced_ids[0]]) {
-                    auto gather = std::make_shared<v8::GatherND>(input_node, index);
-                    copy_runtime_info({index_op, indicies}, gather);
-                    gather->set_friendly_name(index_op->get_friendly_name());
+                    auto gather = rg.make<v8::GatherND>(input_node, index);
+                    copy_runtime_info_and_name(index_op, rg.get());
                    replace_node(index_op, gather);
                    return true;
                }
-                index = std::make_shared<v0::Convert>(index, element::i32);
+                index = rg.make<v0::Convert>(index, element::i32);
                auto dim = v0::Constant::create(element::i32, Shape{}, {advanced_ids[0]});
-                auto gather = std::make_shared<v8::Gather>(input_node, index, dim);
-                copy_runtime_info({index_op, indicies}, gather);
-                gather->set_friendly_name(index_op->get_friendly_name());
+                auto gather = rg.make<v8::Gather>(input_node, index, dim);
+                copy_runtime_info_and_name(index_op, rg.get());
                replace_node(index_op, gather);
                return true;
            }
@@ -153,9 +152,9 @@ AtenIndexToSelect::AtenIndexToSelect() {
                add_exception_to_fw_node(index_op, "aten::index: dynamic rank for aten::index input is not supported.");
                return false;
            }
-            auto input_shape = std::make_shared<v3::ShapeOf>(input_node, element::i32);
+            auto input_shape = rg.make<v3::ShapeOf>(input_node, element::i32);
            auto zero = v0::Constant::create(element::i32, Shape{}, {0});
-            auto input_dims = std::make_shared<v1::Split>(input_shape, zero, rank.get_length());
+            auto input_dims = rg.make<v1::Split>(input_shape, zero, rank.get_length());
            std::vector<size_t> non_used_dims;
            for (auto i = 0; i < rank.get_length(); i++) {
                if (std::find(advanced_ids.begin(), advanced_ids.end(), i) == advanced_ids.end()) {
@@ -166,23 +165,23 @@ AtenIndexToSelect::AtenIndexToSelect() {
            permutation_dims.insert(permutation_dims.end(), advanced_ids.begin(), advanced_ids.end());
            permutation_dims.insert(permutation_dims.end(), non_used_dims.begin(), non_used_dims.end());
            auto transpose_dims = v0::Constant::create(element::i32, Shape{permutation_dims.size()}, permutation_dims);
-            auto transposed_input = std::make_shared<v1::Transpose>(input_node, transpose_dims);
-            auto flatten_input = flatten(transposed_input, adv_idx_count);
+            auto transposed_input = rg.make<v1::Transpose>(input_node, transpose_dims);
+            auto flatten_input = flatten(rg, transposed_input, adv_idx_count);
            auto cum_adv_index = masked_indicies[advanced_ids[adv_idx_count - 1]];
-            cum_adv_index = std::make_shared<v0::Convert>(cum_adv_index, element::i32);
+            cum_adv_index = rg.make<v0::Convert>(cum_adv_index, element::i32);
            auto multiplier = input_dims->output(advanced_ids[adv_idx_count - 1]);
            for (int i = static_cast<int>(adv_idx_count) - 2; i > -1; i--) {
-                auto m_idx = std::make_shared<v0::Convert>(masked_indicies[i], element::i32);
-                auto adv_index = std::make_shared<v1::Multiply>(m_idx, multiplier);
-                cum_adv_index = std::make_shared<v1::Add>(cum_adv_index, adv_index);
+                auto m_idx = rg.make<v0::Convert>(masked_indicies[i], element::i32);
+                auto adv_index = rg.make<v1::Multiply>(m_idx, multiplier);
+                cum_adv_index = rg.make<v1::Add>(cum_adv_index, adv_index);
                auto input_id = advanced_ids[i];
-                multiplier = std::make_shared<v1::Multiply>(multiplier, input_dims->output(input_id));
+                multiplier = rg.make<v1::Multiply>(multiplier, input_dims->output(input_id));
            }
-            std::shared_ptr<Node> gather = std::make_shared<v8::Gather>(flatten_input, cum_adv_index, zero);
+            std::shared_ptr<Node> gather = rg.make<v8::Gather>(flatten_input, cum_adv_index, zero);
            OutputVector concat_dims;
            // check if all advanced indices are consecutive.
            std::vector<size_t> consequence_dims;
-            auto cum_adv_index_shape_tensor = std::make_shared<v3::ShapeOf>(cum_adv_index, element::i32);
+            auto cum_adv_index_shape_tensor = rg.make<v3::ShapeOf>(cum_adv_index, element::i32);
            for (size_t i = advanced_ids[0]; i <= advanced_ids[advanced_ids.size() - 1]; i++) {
                consequence_dims.push_back(i);
            }
@@ -194,8 +193,8 @@ AtenIndexToSelect::AtenIndexToSelect() {
            for (auto i : non_used_dims) {
                folded_adv_idx_shape_vector.push_back(input_dims->output(i));
            }
-            auto folded_adv_idx_shape = std::make_shared<v0::Concat>(folded_adv_idx_shape_vector, 0);
-            gather = std::make_shared<v1::Reshape>(gather, folded_adv_idx_shape, false);
+            auto folded_adv_idx_shape = rg.make<v0::Concat>(folded_adv_idx_shape_vector, 0);
+            gather = rg.make<v1::Reshape>(gather, folded_adv_idx_shape, false);
            std::vector<size_t> adv_idx_permute;
            for (size_t i = 1; i < advanced_ids[0] + 1; i++) {
                adv_idx_permute.push_back(i);
@@ -207,7 +206,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
                // Transpose folded advanced indexed axis to its original location.
                auto permute_indicies =
                    v0::Constant::create(element::i32, Shape{adv_idx_permute.size()}, adv_idx_permute);
-                gather = std::make_shared<v1::Transpose>(gather, permute_indicies);
+                gather = rg.make<v1::Transpose>(gather, permute_indicies);
                // unfold advanced index axes
                for (size_t i = 0; i < advanced_ids[0]; i++) {
                    concat_dims.push_back(input_dims->output(i));
@@ -226,11 +225,10 @@ AtenIndexToSelect::AtenIndexToSelect() {
                    concat_dims.push_back(input_dims->output(i));
                }
            }
-            auto final_shape = std::make_shared<v0::Concat>(concat_dims, 0);
-            gather = std::make_shared<v1::Reshape>(gather, final_shape, false);
-            copy_runtime_info({index_op, indicies}, gather);
+            auto final_shape = rg.make<v0::Concat>(concat_dims, 0);
+            gather = rg.make<v1::Reshape>(gather, final_shape, false);
+            copy_runtime_info_and_name(index_op, rg.get());
            replace_node(index_op, gather);
-            gather->set_friendly_name(index_op->get_friendly_name());
            return true;
 
        } else {
@@ -246,22 +244,21 @@ AtenIndexToSelect::AtenIndexToSelect() {
            }
            auto index_dtype = indicies->get_output_element_type(0);
            if (index_dtype == element::boolean || index_dtype == element::u8) {
-                auto nonzero = std::make_shared<v3::NonZero>(indicies, element::i32);
+                auto nonzero = rg.make<v3::NonZero>(indicies, element::i32);
                auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
-                auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
-                auto gather = std::make_shared<v8::GatherND>(input_node, masked_id);
-                copy_runtime_info({index_op, indicies}, gather);
+                auto masked_id = rg.make<v1::Transpose>(nonzero, input_order);
+                auto gather = rg.make<v8::GatherND>(input_node, masked_id);
+                copy_runtime_info_and_name(index_op, rg.get());
                replace_node(index_op, gather);
                return true;
            }
            if (index_dtype != element::i32) {
-                indicies = std::make_shared<ov::op::v0::Convert>(indicies, element::i32);
+                indicies = rg.make<ov::op::v0::Convert>(indicies, element::i32);
            }
            auto dim = v0::Constant::create(element::i32, Shape{}, {0});
-            auto gather = std::make_shared<v8::Gather>(input_node, indicies, dim);
-            copy_runtime_info({index_op, indicies}, gather);
+            auto gather = rg.make<v8::Gather>(input_node, indicies, dim);
+            copy_runtime_info_and_name(index_op, rg.get());
            replace_node(index_op, gather);
-            gather->set_friendly_name(index_op->get_friendly_name());
            return true;
        }
        add_exception_to_fw_node(index_op, "Unsupported case of aten::index.");
@@ -5,27 +5,30 @@
 #include "aten_stack_list_construct_replacer.hpp"
 
 #include "openvino/core/rt_info.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/unsqueeze.hpp"
 #include "openvino/op/util/framework_node.hpp"
-#include "openvino/opsets/opset10.hpp"
 #include "openvino/pass/pattern/matcher.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "utils.hpp"
 
-using namespace ov::pass::pattern;
-
 namespace ov {
 namespace frontend {
 namespace pytorch {
 namespace pass {
 
+using namespace ov::op;
+using namespace ov::pass::pattern;
+
 AtenStackListConstructReplacer::AtenStackListConstructReplacer() {
-    auto list_construct = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
-    auto axis = ov::pass::pattern::wrap_type<opset10::Constant>();
+    auto list_construct = wrap_type<ov::op::util::FrameworkNode>();
+    auto axis = wrap_type<v0::Constant>();
 
    // We search for a pattern: ListConstruct -> aten::stack <- Constant
-    auto stack = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>({list_construct, axis});
+    auto stack = wrap_type<ov::op::util::FrameworkNode>({list_construct, axis});
 
-    ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
+    ov::matcher_pass_callback callback = [=](Matcher& m) {
        auto stack = cast_fw_node(m.get_match_root(), "aten::stack");
        if (!stack) {
            return false;
@ -33,23 +36,23 @@ AtenStackListConstructReplacer::AtenStackListConstructReplacer() {
|
|||||||
const auto& pattern_map = m.get_pattern_value_map();
|
const auto& pattern_map = m.get_pattern_value_map();
|
||||||
auto input_node = pattern_map.at(list_construct).get_node_shared_ptr();
|
auto input_node = pattern_map.at(list_construct).get_node_shared_ptr();
|
||||||
auto axis_node = pattern_map.at(axis).get_node_shared_ptr();
|
auto axis_node = pattern_map.at(axis).get_node_shared_ptr();
|
||||||
auto axis_const = std::dynamic_pointer_cast<opset10::Constant>(axis_node);
|
auto axis_const = std::dynamic_pointer_cast<v0::Constant>(axis_node);
|
||||||
auto axis = axis_const->cast_vector<int64_t>();
|
auto axis = axis_const->cast_vector<int64_t>();
|
||||||
// Check if ListConstruct is an input
|
// Check if ListConstruct is an input
|
||||||
if (auto list_construct_node = cast_fw_node(input_node, "prim::ListConstruct")) {
|
if (auto list_construct_node = cast_fw_node(input_node, "prim::ListConstruct")) {
|
||||||
const auto& list_inputs = list_construct_node->input_values();
|
const auto& list_inputs = list_construct_node->input_values();
|
||||||
OutputVector node_vector;
|
OutputVector node_vector;
|
||||||
auto zero = opset10::Constant::create(element::i32, Shape{}, {0});
|
auto zero = v0::Constant::create(element::i32, Shape{}, {0});
|
||||||
// Iterate over values in ListConstruct
|
// Iterate over values in ListConstruct
|
||||||
for (const auto& list_input : list_inputs) {
|
for (const auto& list_input : list_inputs) {
|
||||||
auto node = concat_list_construct(list_input);
|
auto node = concat_list_construct(list_input);
|
||||||
auto unsqueezed_node = std::make_shared<opset10::Unsqueeze>(node, axis_const);
|
auto unsqueezed_node = std::make_shared<v0::Unsqueeze>(node, axis_const);
|
||||||
node_vector.push_back(unsqueezed_node);
|
node_vector.push_back(unsqueezed_node);
|
||||||
}
|
}
|
||||||
// Concat vectors on provided axis
|
// Concat vectors on provided axis
|
||||||
auto concat = std::make_shared<opset10::Concat>(node_vector, axis[0]);
|
auto concat = std::make_shared<v0::Concat>(node_vector, axis[0]);
|
||||||
|
|
||||||
copy_runtime_info({stack, input_node}, concat);
|
copy_runtime_info_and_name(stack, {concat}, {input_node});
|
||||||
replace_node(stack, concat);
|
replace_node(stack, concat);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
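For reference, the decomposition this replacer applies: aten::stack over N tensors becomes Unsqueeze of each element at the stack axis followed by Concat on that axis. A standalone sketch of the shape arithmetic (plain C++, no OpenVINO dependency; the helper name stacked_shape is ours):

```cpp
#include <cassert>
#include <iostream>
#include <vector>

// Shape of stack(tensors, axis) when lowered to unsqueeze-each + concat:
// every input of shape S gains a size-1 dim at `axis`, and Concat on `axis`
// sums those 1s up to the number of stacked tensors.
std::vector<size_t> stacked_shape(const std::vector<size_t>& elem_shape,
                                  size_t num_tensors,
                                  size_t axis) {
    assert(axis <= elem_shape.size());
    std::vector<size_t> unsqueezed = elem_shape;
    unsqueezed.insert(unsqueezed.begin() + axis, 1);  // Unsqueeze
    unsqueezed[axis] = num_tensors;                   // Concat over axis
    return unsqueezed;
}

int main() {
    auto s = stacked_shape({4, 5}, 3, 0);    // stack 3 tensors of shape [4,5] on axis 0
    for (auto d : s) std::cout << d << ' ';  // prints: 3 4 5
}
```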
@ -50,7 +50,7 @@ AtenEinsumListConstructReplacer::AtenEinsumListConstructReplacer() {
         }

         auto einsum = std::make_shared<v7::Einsum>(node_vector, equation);
-        copy_runtime_info({einsum_op, equation_input, tensor_list}, einsum);
+        copy_runtime_info_and_name(einsum_op, {einsum}, {equation_input, tensor_list});
         replace_node(einsum_op, einsum);
         return true;
     }
@ -126,7 +126,7 @@ IndexLoopGetitemReplacer::IndexLoopGetitemReplacer() {
         auto stop = rg.make<v1::Add>(start, chunks_size_body);
         auto curr_chunk = rg.make<v8::Slice>(chunk_param, start, stop, one_1d, dim_body);
         replace_node(getitem, curr_chunk);
-        copy_runtime_info({chunk_op, getitem}, rg.get());
+        copy_runtime_info_and_name(chunk_op, rg.get(), {getitem});
         curr_chunk->set_friendly_name(getitem->get_friendly_name());
         return true;
     };
@ -79,6 +79,7 @@ ListConstructReplacer::ListConstructReplacer() {
         // Concatenation is possible because all elements in list should be scalar or 1D tensors,
         // result should be 1D tensor.
         OutputVector inputs;
+        ov::pass::NodeRegistry rg;
         auto neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
         const auto& start_output = list_node->output(0);
         for (const auto& input : get_list_as_outputs(start_output)) {
@ -94,13 +95,12 @@ ListConstructReplacer::ListConstructReplacer() {
                 return false;
             }
             // reshape all elements to 1D
-            auto reshape = std::make_shared<v1::Reshape>(input, neg_1, false);
+            auto reshape = rg.make<v1::Reshape>(input, neg_1, false);
             inputs.push_back(reshape);
         }
-        auto concat = std::make_shared<v0::Concat>(inputs, 0);
+        auto concat = rg.make<v0::Concat>(inputs, 0);
-        copy_runtime_info({list_node}, concat);
+        copy_runtime_info_and_name(list_node, rg.get());
         replace_node(list_node, concat);
-        concat->set_friendly_name(list_node->get_friendly_name());
         return true;
     };
     auto m = std::make_shared<pattern::Matcher>(lc_pattern, "ov::frontend::pytorch::pass::ListConstructReplacer");
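The lowering above is simple enough to state directly: each list element (a scalar or 1D tensor) is reshaped to 1D with target shape {-1} and the pieces are concatenated on axis 0. A standalone sketch of the same data flow (plain C++; concat_list is our name for the illustration):

```cpp
#include <iostream>
#include <vector>

// prim::ListConstruct of scalars/1D tensors becomes: Reshape each element to
// 1D (target shape {-1}) and Concat the pieces along axis 0.
std::vector<int> concat_list(const std::vector<std::vector<int>>& elements) {
    std::vector<int> out;
    for (const auto& e : elements) {
        // Reshape(e, {-1}) analogue: a scalar or 1D tensor flattens to itself.
        out.insert(out.end(), e.begin(), e.end());  // Concat on axis 0
    }
    return out;
}

int main() {
    auto r = concat_list({{2}, {3, 4}, {5}});
    for (int v : r) std::cout << v << ' ';  // prints: 2 3 4 5
}
```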
@ -23,6 +23,8 @@ namespace frontend {
 namespace pytorch {
 namespace pass {

+using namespace ov::op;

 MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() {
     auto op = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();

@ -40,25 +42,26 @@ MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() {
         } else {
             op = max_op;
         }
+        ov::pass::NodeRegistry rg;
         auto input_node = op->input_value(0);
         auto num_inputs = op->inputs().size();
         auto input = concat_list_construct(input_node);
         std::shared_ptr<Node> reduce_op;
         if (num_inputs == 1) {
-            auto start = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 0);
+            auto start = rg.make<v0::Constant>(element::i32, Shape{}, 0);
-            auto step = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{}, 1);
+            auto step = rg.make<v0::Constant>(element::i32, Shape{}, 1);
-            auto shape = std::make_shared<ov::op::v3::ShapeOf>(input, element::i32);
+            auto shape = rg.make<v3::ShapeOf>(input, element::i32);
-            auto rank = std::make_shared<ov::op::v3::ShapeOf>(shape, element::i32);
+            auto rank = rg.make<v3::ShapeOf>(shape, element::i32);
-            auto axis_0 = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
+            auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0});
-            auto reduced_rank = std::make_shared<ov::op::v0::Squeeze>(rank, axis_0);
+            auto reduced_rank = rg.make<v0::Squeeze>(rank, axis_0);
-            auto axes = std::make_shared<ov::op::v4::Range>(start, reduced_rank, step, element::i32);
+            auto axes = rg.make<v4::Range>(start, reduced_rank, step, element::i32);
             std::shared_ptr<Node> reduce_op;
             if (!is_min) {
-                reduce_op = std::make_shared<ov::op::v1::ReduceMax>(input, axes);
+                reduce_op = rg.make<v1::ReduceMax>(input, axes);
             } else {
-                reduce_op = std::make_shared<ov::op::v1::ReduceMin>(input, axes);
+                reduce_op = rg.make<v1::ReduceMin>(input, axes);
             }
-            copy_runtime_info({op, input_node.get_node_shared_ptr()}, reduce_op);
+            copy_runtime_info_and_name(op, rg.get());
             replace_node(op, reduce_op);
             return true;
         }
@ -66,11 +69,11 @@ MinMaxPrimListConstructReplacer::MinMaxPrimListConstructReplacer() {
         auto second_input = concat_list_construct(second_input_node);
         std::shared_ptr<Node> min_or_max_op;
         if (!is_min) {
-            min_or_max_op = std::make_shared<ov::op::v1::Maximum>(input, second_input);
+            min_or_max_op = rg.make<v1::Maximum>(input, second_input);
         } else {
-            min_or_max_op = std::make_shared<ov::op::v1::Minimum>(input, second_input);
+            min_or_max_op = rg.make<v1::Minimum>(input, second_input);
         }
-        copy_runtime_info({op, input_node.get_node_shared_ptr()}, min_or_max_op);
+        copy_runtime_info_and_name(op, rg.get());
         replace_node(op, min_or_max_op);
         return true;
     };
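The single-input branch above reduces over every axis: it builds axes = Range(0, rank, 1) from a doubly-applied ShapeOf (shape of the shape gives the rank) and feeds that into ReduceMax/ReduceMin. A standalone sketch of the same arithmetic (plain C++; all_axes is our illustrative name):

```cpp
#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

// One-argument prim::min/prim::max reduces over every axis. The graph builds
// axes = Range(0, rank, 1); numerically that is just iota over the rank.
std::vector<int64_t> all_axes(size_t rank) {
    std::vector<int64_t> axes(rank);
    std::iota(axes.begin(), axes.end(), 0);
    return axes;
}

int main() {
    auto axes = all_axes(3);                       // {0, 1, 2} for a rank-3 input
    std::vector<float> data{3.f, -1.f, 7.f, 2.f};  // reducing over all axes of any
    float mx = *std::max_element(data.begin(), data.end());  // layout is a global max
    std::cout << axes.size() << ' ' << mx << '\n';  // prints: 3 7
}
```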
@ -30,7 +30,8 @@ namespace pass {
 using namespace ov::op;

 namespace {
-Output<Node> create_padding(const Output<Node>& input_rank,
+Output<Node> create_padding(ov::pass::NodeRegistry& rg,
+                            const Output<Node>& input_rank,
                             const Output<Node>& padding,
                             const Output<Node>& start_id,
                             const Output<Node>& end_id) {
@ -39,14 +40,14 @@ Output<Node> create_padding(const Output<Node>& input_rank,
     // OV expects paddings separated on begins and ends for each dimension from first to last
     auto minus_two = v0::Constant::create(element::i32, Shape{}, {-2});
     auto zero = v0::Constant::create(element::i32, Shape{}, {0});
-    auto pad_id_range = std::make_shared<v4::Range>(start_id, end_id, minus_two, element::i32);
+    auto pad_id_range = rg.make<v4::Range>(start_id, end_id, minus_two, element::i32);
-    auto pads = std::make_shared<v8::Gather>(padding, pad_id_range, zero);
+    auto pads = rg.make<v8::Gather>(padding, pad_id_range, zero);
     // add left side zero padding for difference between padding size and input rank
-    auto pads_short_len = std::make_shared<v3::ShapeOf>(pads, element::i32);
+    auto pads_short_len = rg.make<v3::ShapeOf>(pads, element::i32);
-    auto pads_diff = std::make_shared<v1::Subtract>(input_rank, pads_short_len);
+    auto pads_diff = rg.make<v1::Subtract>(input_rank, pads_short_len);
-    auto pads_remaining = std::make_shared<v3::Broadcast>(zero, pads_diff);
+    auto pads_remaining = rg.make<v3::Broadcast>(zero, pads_diff);
-    auto pads_remaining_c = std::make_shared<v1::ConvertLike>(pads_remaining, pads);
+    auto pads_remaining_c = rg.make<v1::ConvertLike>(pads_remaining, pads);
-    auto pads_full = std::make_shared<v0::Concat>(OutputVector{pads_remaining_c, pads}, 0);
+    auto pads_full = rg.make<v0::Concat>(OutputVector{pads_remaining_c, pads}, 0);
     return pads_full;
 }

@ -64,6 +65,7 @@ PrimListConstructPadReplacer::PrimListConstructPadReplacer() {
         if (!pad_op) {
             return false;
         }
+        ov::pass::NodeRegistry rg;
         auto minus_two = v0::Constant::create(element::i32, Shape{}, {-2});
         auto minus_one = v0::Constant::create(element::i32, Shape{}, {-1});
         auto zero = v0::Constant::create(element::i32, Shape{}, {0});
@ -73,15 +75,15 @@ PrimListConstructPadReplacer::PrimListConstructPadReplacer() {
         auto pad_values = concat_list_construct(padding);
         std::string mode = "constant";
         auto zero_f = v0::Constant::create(element::f32, Shape{}, {0});
-        auto input_shape = std::make_shared<v3::ShapeOf>(input_node, element::i32);
+        auto input_shape = rg.make<v3::ShapeOf>(input_node, element::i32);
-        auto input_rank = std::make_shared<v3::ShapeOf>(input_shape, element::i32);
+        auto input_rank = rg.make<v3::ShapeOf>(input_shape, element::i32);
-        auto pad_size_1d = std::make_shared<v3::ShapeOf>(pad_values, element::i32);
+        auto pad_size_1d = rg.make<v3::ShapeOf>(pad_values, element::i32);
-        auto pad_size = std::make_shared<v0::Squeeze>(pad_size_1d, zero);
+        auto pad_size = rg.make<v0::Squeeze>(pad_size_1d, zero);
         // get pad_begins and pad_ends indexes starting for end of paddings
-        auto start_pad_begins = std::make_shared<v1::Add>(pad_size, minus_two);
+        auto start_pad_begins = rg.make<v1::Add>(pad_size, minus_two);
-        auto start_pad_ends = std::make_shared<v1::Add>(pad_size, minus_one);
+        auto start_pad_ends = rg.make<v1::Add>(pad_size, minus_one);
-        auto pad_begins_full = create_padding(input_rank, pad_values, start_pad_begins, minus_one);
+        auto pad_begins_full = create_padding(rg, input_rank, pad_values, start_pad_begins, minus_one);
-        auto pad_ends_full = create_padding(input_rank, pad_values, start_pad_ends, zero);
+        auto pad_ends_full = create_padding(rg, input_rank, pad_values, start_pad_ends, zero);
         auto mode_const = pad_op->input_value(2).get_node_shared_ptr();
         auto pad_value = pad_op->input_value(3);
         if (const auto& fw_node_mode = cast_fw_node(mode_const, "prim::Constant")) {
@ -97,16 +99,16 @@ PrimListConstructPadReplacer::PrimListConstructPadReplacer() {
                 pad_value = zero_f;
             }
         }
-        pad_value = std::make_shared<v1::ConvertLike>(pad_value, input_node);
+        pad_value = rg.make<v1::ConvertLike>(pad_value, input_node);
         }
         if (PAD_MODES.find(mode) == PAD_MODES.end()) {
             add_exception_to_fw_node(pad_op, "Unsupported mode: " + mode + " for aten::pad");
             return false;
         }
         auto pad_mode = PAD_MODES.at(mode);
-        auto pad = std::make_shared<v1::Pad>(input_node, pad_begins_full, pad_ends_full, pad_value, pad_mode);
+        auto pad = rg.make<v1::Pad>(input_node, pad_begins_full, pad_ends_full, pad_value, pad_mode);
         replace_node(pad_op, pad);
-        copy_runtime_info({pad_op, padding.get_node_shared_ptr(), mode_const, pad_value.get_node_shared_ptr()}, pad);
+        copy_runtime_info_and_name(pad_op, rg.get());
         pad->set_friendly_name(pad_op->get_friendly_name());
         return true;
     };
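What create_padding computes: PyTorch's aten::pad takes a flat list ordered as [begin_last, end_last, begin_prev, end_prev, ...] (pairs from the last dimension backwards), while OV Pad wants separate begins/ends vectors ordered first-to-last dimension and zero-filled up to the input rank; the Range(start_id, end_id, -2) + Gather + zero Broadcast graph does exactly that reordering. A standalone numeric sketch (plain C++; the helper name side is ours):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Torch pad list is [begin_last, end_last, begin_prev, end_prev, ...];
// OV Pad wants begins/ends ordered from the first dim, zero-padded up to the
// input rank. offset 0 picks the begins, offset 1 the ends.
std::vector<int64_t> side(const std::vector<int64_t>& pads, size_t rank, size_t offset) {
    std::vector<int64_t> out(rank, 0);             // leading dims: zero padding
    size_t covered = pads.size() / 2;              // dims the list actually covers
    for (size_t i = 0; i < covered; ++i)           // walk pairs from the last dim back
        out[rank - 1 - i] = pads[2 * i + offset];
    return out;
}

int main() {
    std::vector<int64_t> pads{1, 2, 3, 4};  // last dim: (1,2); second-to-last: (3,4)
    auto begins = side(pads, 4, 0);         // {0, 0, 3, 1}
    auto ends = side(pads, 4, 1);           // {0, 0, 4, 2}
    for (auto v : begins) std::cout << v << ' ';
    std::cout << "| ";
    for (auto v : ends) std::cout << v << ' ';
}
```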
@ -28,6 +28,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
             return false;

         auto input_node = list_unpack->input_value(0).get_node_shared_ptr();
+        ov::pass::NodeRegistry rg;
         if (auto torch_split = cast_fw_node(input_node, "aten::split")) {
             auto rank = torch_split->input(1).get_partial_shape().rank();
             if (rank.is_dynamic()) {
@ -44,19 +45,18 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
                                                          Shape{1},
                                                          {list_unpack->get_output_size() - 1});
                 auto const_neg_1 = opset10::Constant::create(split_size.get_element_type(), Shape{1}, {-1});
-                auto split_lenghts_m_1 = std::make_shared<opset10::Tile>(split_size, num_out_m_1);
+                auto split_lenghts_m_1 = rg.make<opset10::Tile>(split_size, num_out_m_1);
                 NodeVector concat_inputs{split_lenghts_m_1, const_neg_1};
-                auto split_lenghts = std::make_shared<opset10::Concat>(concat_inputs, 0);
+                auto split_lenghts = rg.make<opset10::Concat>(concat_inputs, 0);
-                split = std::make_shared<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
+                split = rg.make<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
                                                                  torch_split->get_input_source_output(2),
                                                                  split_lenghts);
             } else {
-                split = std::make_shared<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
+                split = rg.make<opset10::VariadicSplit>(torch_split->get_input_source_output(0),
                                                                  torch_split->get_input_source_output(2),
                                                                  torch_split->get_input_source_output(1));
             }
-            copy_runtime_info({list_unpack, input_node}, split);
+            copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
-            split->set_friendly_name(input_node->get_friendly_name());
             replace_node(list_unpack, split);

             return true;
@ -64,12 +64,11 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {

         if (auto split_with_sizes = cast_fw_node(input_node, "aten::split_with_sizes")) {
             auto split_lengths = concat_list_construct(split_with_sizes->get_input_source_output(1));
-            auto split = std::make_shared<opset10::VariadicSplit>(split_with_sizes->get_input_source_output(0),
+            auto split = rg.make<opset10::VariadicSplit>(split_with_sizes->get_input_source_output(0),
                                                                   split_with_sizes->get_input_source_output(2),
                                                                   split_lengths);

-            copy_runtime_info({list_unpack, input_node}, split);
+            copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
-            split->set_friendly_name(input_node->get_friendly_name());
             replace_node(list_unpack, split);

             return true;
@ -87,27 +86,26 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
             auto tensor_0 = opset10::Constant::create(element::i32, Shape{1}, {0});
             auto tensor_neg_1 = opset10::Constant::create(element::i32, Shape{1}, {-1});

-            auto input_shape = std::make_shared<opset10::ShapeOf>(input_tensor, element::i32);
+            auto input_shape = rg.make<opset10::ShapeOf>(input_tensor, element::i32);
-            auto input_dimension = std::make_shared<opset10::Gather>(input_shape, dim, tensor_0);
+            auto input_dimension = rg.make<opset10::Gather>(input_shape, dim, tensor_0);

-            auto init_chunk_size = std::make_shared<opset10::Divide>(input_dimension, chunks, true);
+            auto init_chunk_size = rg.make<opset10::Divide>(input_dimension, chunks, true);

             // Add 1 if input is not evenly divisible by chunks
-            auto last_chunk_size = std::make_shared<opset10::Mod>(input_dimension, chunks);
+            auto last_chunk_size = rg.make<opset10::Mod>(input_dimension, chunks);
-            auto is_last_nonzero = std::make_shared<opset10::Greater>(last_chunk_size, tensor_0);
+            auto is_last_nonzero = rg.make<opset10::Greater>(last_chunk_size, tensor_0);
-            auto is_last_nonzero_int = std::make_shared<opset10::Convert>(is_last_nonzero, element::i32);
+            auto is_last_nonzero_int = rg.make<opset10::Convert>(is_last_nonzero, element::i32);

-            auto chunk_size = std::make_shared<opset10::Add>(init_chunk_size, is_last_nonzero_int);
+            auto chunk_size = rg.make<opset10::Add>(init_chunk_size, is_last_nonzero_int);

             auto split_lengths_even_size =
                 opset10::Constant::create(element::i32, Shape{1}, {list_unpack->get_output_size() - 1});
-            auto split_lengths_even = std::make_shared<opset10::Broadcast>(chunk_size, split_lengths_even_size);
+            auto split_lengths_even = rg.make<opset10::Broadcast>(chunk_size, split_lengths_even_size);

-            auto split_lengths = std::make_shared<opset10::Concat>(OutputVector{split_lengths_even, tensor_neg_1}, 0);
+            auto split_lengths = rg.make<opset10::Concat>(OutputVector{split_lengths_even, tensor_neg_1}, 0);
-            auto sliced_chunks = std::make_shared<opset10::VariadicSplit>(input_tensor, dim, split_lengths);
+            auto sliced_chunks = rg.make<opset10::VariadicSplit>(input_tensor, dim, split_lengths);

-            copy_runtime_info({list_unpack, input_node}, sliced_chunks);
+            copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
-            sliced_chunks->set_friendly_name(input_node->get_friendly_name());
             replace_node(list_unpack, sliced_chunks);

             return true;
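The chunk branch above implements aten::chunk as a VariadicSplit: every leading chunk gets ceil(dim / chunks) elements (truncating Divide plus a Mod/Greater correction), and the last split length is -1, which VariadicSplit interprets as "whatever remains". A standalone sketch of that arithmetic (plain C++; chunk_split_lengths is our name):

```cpp
#include <iostream>
#include <vector>

// aten::chunk lowering: all leading chunks get ceil(dim_size / chunks)
// elements, and the trailing VariadicSplit length is -1 ("the rest"),
// mirroring the Divide(..., true) + Mod + Greater graph in the hunk above.
std::vector<int64_t> chunk_split_lengths(int64_t dim_size, int64_t chunks, size_t num_outputs) {
    int64_t init = dim_size / chunks;                // truncating Divide
    int64_t bump = (dim_size % chunks) > 0 ? 1 : 0;  // Mod + Greater + Convert
    int64_t chunk_size = init + bump;                // ceil division
    std::vector<int64_t> lengths(num_outputs - 1, chunk_size);
    lengths.push_back(-1);                           // last chunk takes the rest
    return lengths;
}

int main() {
    for (auto v : chunk_split_lengths(10, 3, 3))  // prints: 4 4 -1 (actual chunks 4,4,2)
        std::cout << v << ' ';
}
```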
@ -117,51 +115,45 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
             const auto input = unbind->get_input_source_output(0);
             const auto axis = unbind->get_input_source_output(1);
             const auto num_splits = list_unpack->get_output_size();
-            auto split = std::make_shared<opset10::Split>(input, axis, num_splits);
+            auto split = rg.make<opset10::Split>(input, axis, num_splits);
-            NodeVector to_copy_rt{split};
             OutputVector outputs;
             for (auto output : split->outputs()) {
-                const auto squeeze = std::make_shared<opset10::Squeeze>(output, axis);
+                const auto squeeze = rg.make<opset10::Squeeze>(output, axis);
                 outputs.push_back(squeeze);
-                to_copy_rt.push_back(squeeze);
             }
-            copy_runtime_info({list_unpack, input_node}, to_copy_rt);
+            copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
             replace_node(list_unpack, outputs);

             return true;
         }
         if (auto where = cast_fw_node(input_node, "aten::where")) {
             const auto input = where->get_input_source_output(0);
-            auto non_zero = std::make_shared<opset10::NonZero>(input);
+            auto non_zero = rg.make<opset10::NonZero>(input);
             auto axis = opset10::Constant::create(element::i32, Shape{}, {0});
             const auto num_splits = list_unpack->get_output_size();
-            auto split = std::make_shared<opset10::Split>(non_zero, axis, num_splits);
+            auto split = rg.make<opset10::Split>(non_zero, axis, num_splits);
-            NodeVector to_copy_rt{split};
             OutputVector outputs;
             for (auto output : split->outputs()) {
-                const auto squeeze = std::make_shared<opset10::Squeeze>(output, axis);
+                const auto squeeze = rg.make<opset10::Squeeze>(output, axis);
                 outputs.push_back(squeeze);
-                to_copy_rt.push_back(squeeze);
             }
-            copy_runtime_info({list_unpack, input_node}, to_copy_rt);
+            copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
             replace_node(list_unpack, outputs);

             return true;
         }
         if (auto nonzero_numpy = cast_fw_node(input_node, "aten::nonzero_numpy")) {
             const auto input = nonzero_numpy->get_input_source_output(0);
-            auto non_zero = std::make_shared<opset10::NonZero>(input);
+            auto non_zero = rg.make<opset10::NonZero>(input);
             auto axis = opset10::Constant::create(element::i32, Shape{}, {0});
             const auto num_splits = list_unpack->get_output_size();
-            auto split = std::make_shared<opset10::Split>(non_zero, axis, num_splits);
+            auto split = rg.make<opset10::Split>(non_zero, axis, num_splits);
-            NodeVector to_copy_rt{split};
             OutputVector outputs;
             for (auto output : split->outputs()) {
-                const auto squeeze = std::make_shared<opset10::Squeeze>(output, axis);
+                const auto squeeze = rg.make<opset10::Squeeze>(output, axis);
                 outputs.push_back(squeeze);
-                to_copy_rt.push_back(squeeze);
             }
-            copy_runtime_info({list_unpack, input_node}, to_copy_rt);
+            copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
             replace_node(list_unpack, outputs);

             return true;
@ -175,7 +167,6 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
                 add_exception_to_fw_node(input_node, "aten::meshgrid: only prim::ListConstruct supported as input.");
                 return false;
             }
-            NodeVector rt_copy_from{list_unpack, input_node, meshgrid_input_node};
             OutputVector meshgrid_inputs;
             for (auto& input : meshgrid_input_node->inputs()) {
                 meshgrid_inputs.push_back(input.get_source_output());
@ -203,29 +194,26 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
             auto const_1 = opset10::Constant::create(element::i32, Shape{1}, {1});
             int input_idx = 0;
             for (auto& input : meshgrid_inputs) {
-                auto reshaped_input = std::make_shared<opset10::Reshape>(input, const_neg_1, false);
+                auto reshaped_input = rg.make<opset10::Reshape>(input, const_neg_1, false);
-                auto shape = std::make_shared<opset10::ShapeOf>(reshaped_input, element::i32);
+                auto shape = rg.make<opset10::ShapeOf>(reshaped_input, element::i32);
                 cat_shapes.push_back(shape);
                 NodeVector cat_inputs(meshgrid_inputs.size(), const_1);
                 cat_inputs[input_idx] = shape;
                 input_idx++;
-                auto input_cat = std::make_shared<opset10::Concat>(cat_inputs, 0);
+                auto input_cat = rg.make<opset10::Concat>(cat_inputs, 0);
-                auto reshape_cat = std::make_shared<opset10::Reshape>(reshaped_input, input_cat, false);
+                auto reshape_cat = rg.make<opset10::Reshape>(reshaped_input, input_cat, false);
                 reshapes.push_back(reshape_cat);
             }
-            auto cat = std::make_shared<opset10::Concat>(cat_shapes, 0);
+            auto cat = rg.make<opset10::Concat>(cat_shapes, 0);
-            NodeVector to_copy_rt{cat};
-            to_copy_rt.push_back(cat);
             OutputVector outputs{};
             for (auto& reshape : reshapes) {
-                auto out = std::make_shared<opset10::Broadcast>(reshape, cat, ov::op::BroadcastType::BIDIRECTIONAL);
+                auto out = rg.make<opset10::Broadcast>(reshape, cat, ov::op::BroadcastType::BIDIRECTIONAL);
-                to_copy_rt.push_back(out);
                 outputs.push_back(out);
             }
             if (indexing == "xy" && meshgrid_inputs.size() >= 2) {
                 std::swap(outputs[0], outputs[1]);
             }
-            copy_runtime_info(rt_copy_from, to_copy_rt);
+            copy_runtime_info_and_name(list_unpack, rg.get(), {input_node, meshgrid_input_node});
             replace_node(list_unpack, outputs);
             return true;
         }
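The meshgrid branch above flattens each input to 1D, reshapes input i to a shape that is all 1s except its own length at position i, and broadcasts every reshaped input bidirectionally to the concatenated full grid shape. A standalone sketch of the shape bookkeeping (plain C++; reshape_target is our illustrative name):

```cpp
#include <iostream>
#include <vector>

// meshgrid lowering: input i (flattened to length n_i) is reshaped to
// [1, ..., n_i, ..., 1] with n_i at position i, then broadcast to the full
// grid shape [n_0, n_1, ..., n_k] (the Concat of all flattened lengths).
std::vector<int64_t> reshape_target(const std::vector<int64_t>& lengths, size_t idx) {
    std::vector<int64_t> shape(lengths.size(), 1);  // cat_inputs filled with const_1
    shape[idx] = lengths[idx];                      // cat_inputs[input_idx] = shape
    return shape;
}

int main() {
    std::vector<int64_t> lengths{2, 3, 4};  // three 1D inputs
    for (size_t i = 0; i < lengths.size(); ++i) {
        for (auto d : reshape_target(lengths, i)) std::cout << d << ' ';
        std::cout << "-> broadcast to 2 3 4\n";
    }  // prints: 2 1 1 / 1 3 1 / 1 1 4, each broadcast to the 2x3x4 grid
}
```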
@ -234,17 +222,15 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
             // case aten::size as input
             // Number of ListUnpack outputs should be equal to rank of input shape.
             auto axis_0 = opset10::Constant::create(element::i32, Shape{}, {0});
-            auto split = std::make_shared<opset10::Split>(shape_of, axis_0, list_unpack->get_output_size());
+            auto split = rg.make<opset10::Split>(shape_of, axis_0, list_unpack->get_output_size());

-            NodeVector to_copy_rt{axis_0, split};
             OutputVector res;
             for (auto output : split->outputs()) {
-                auto squeeze = std::make_shared<opset10::Squeeze>(output, axis_0);
+                auto squeeze = rg.make<opset10::Squeeze>(output, axis_0);
-                to_copy_rt.push_back(squeeze);
                 res.push_back(squeeze);
             }

-            copy_runtime_info({list_unpack, input_node}, to_copy_rt);
+            copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
             replace_node(list_unpack, res);

             return true;
@ -254,17 +240,15 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
             // case aten::slice as input
             // Number of ListUnpack outputs should be equal to rank of input shape.
             auto axis_0 = opset10::Constant::create(element::i32, Shape{}, {0});
-            auto split = std::make_shared<opset10::Split>(slice, axis_0, list_unpack->get_output_size());
+            auto split = rg.make<opset10::Split>(slice, axis_0, list_unpack->get_output_size());

-            NodeVector to_copy_rt{axis_0, split};
             OutputVector res;
             for (auto output : split->outputs()) {
-                auto squeeze = std::make_shared<opset10::Squeeze>(output, axis_0);
+                auto squeeze = rg.make<opset10::Squeeze>(output, axis_0);
-                to_copy_rt.push_back(squeeze);
                 res.push_back(squeeze);
             }

-            copy_runtime_info({list_unpack, input_node}, to_copy_rt);
+            copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
             replace_node(list_unpack, res);

             return true;
@ -67,14 +67,14 @@ StringEqualityReplacer::StringEqualityReplacer() {
         auto equal_node = pattern_map.at(equal_op).get_node_shared_ptr();
         if (auto equal = std::dynamic_pointer_cast<v1::Equal>(equal_node)) {
             auto const_result = v0::Constant::create(element::boolean, Shape{}, {lhs == rhs});
-            copy_runtime_info({lhs_node, rhs_node, equal_node}, const_result);
+            copy_runtime_info_and_name(equal_node, {const_result});
             replace_node(equal_node, const_result);
             return true;
         };
         auto not_equal_node = pattern_map.at(not_equal_op).get_node_shared_ptr();
         if (auto equal = std::dynamic_pointer_cast<v1::NotEqual>(not_equal_node)) {
             auto const_result = v0::Constant::create(element::boolean, Shape{}, {lhs != rhs});
-            copy_runtime_info({lhs_node, rhs_node, not_equal_node}, const_result);
+            copy_runtime_info_and_name(not_equal_node, {const_result});
             replace_node(not_equal_node, const_result);
             return true;
         };
@ -54,7 +54,26 @@ std::shared_ptr<ov::Model> TranslateSession::get_converted_model() {
 std::shared_ptr<ov::Model> TranslateSession::translate_graph(const ov::frontend::InputModel::Ptr& input_model) {
     auto pytorch_model = std::dynamic_pointer_cast<pytorch::InputModel>(input_model);
     FRONT_END_GENERAL_CHECK(pytorch_model != nullptr, "Invalid input model");
-    return convert_pytorch_model(pytorch_model->m_model_decoder, {}, pytorch_model->m_descriptors);
+    auto model = convert_pytorch_model(pytorch_model->m_model_decoder, {}, pytorch_model->m_descriptors);
+    // First delete tensor indexes from outputs then resolve input names, otherwise Parameter->Result will fail
+    for (auto& result : model->get_results()) {
+        auto tensor_desc = result->input_value(0);
+        auto names = tensor_desc.get_names();
+        if (!names.empty()) {
+            auto tensor_idx = decode_tensor_name(tensor_desc);
+            if (names.erase(std::to_string(tensor_idx))) {
+                tensor_desc.set_names(names);
+            }
+        }
+    }
+    // Set input tensor names to be equal to signature name saved in friendly name
+    for (auto& param : model->get_parameters()) {
+        if (param->get_friendly_name() != param->get_name()) {
+            // get_name is autogenerated name, we need to make sure that this parameter was named by frontend
+            param->output(0).set_names({param->get_friendly_name()});
+        }
+    }
+    return model;
 }

 std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
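The post-conversion cleanup added above works purely on tensor name sets: the internal PyTorch tensor index (a number) is erased from result tensor names, and a parameter whose friendly name was set by the frontend (i.e. differs from the autogenerated one) gets that signature name as its tensor name. A standalone sketch over plain name sets (the concrete names "21", "logits", "input_ids" and "Parameter_3" are invented examples):

```cpp
#include <iostream>
#include <set>
#include <string>

int main() {
    // Result side: drop the numeric tensor-index name, keep user-facing names.
    std::set<std::string> result_names{"21", "logits"};  // "21" is the tensor index
    size_t tensor_idx = 21;                              // decode_tensor_name analogue
    result_names.erase(std::to_string(tensor_idx));      // keeps only "logits"

    // Parameter side: promote the frontend-assigned signature name.
    std::string autogenerated = "Parameter_3";  // get_name() analogue
    std::string friendly = "input_ids";         // friendly name from the signature
    std::set<std::string> param_names;
    if (friendly != autogenerated)              // was it named by the frontend?
        param_names.insert(friendly);

    for (const auto& n : result_names) std::cout << n << ' ';  // logits
    for (const auto& n : param_names) std::cout << n << '\n';  // input_ids
}
```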
@ -91,10 +110,8 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
             }
             if (!input_node) {
                 auto parameter = std::make_shared<v0::Parameter>(type, pshape);
-                encode_tensor_name(
-                    parameter->output(0),
-                    inputs.at(i),
-                    {pytorch_model->get_input_debug_name(i), pytorch_model->get_input_signature_name(i)});
+                parameter->set_friendly_name(pytorch_model->get_input_signature_name(i));
+                encode_tensor_name(parameter->output(0), inputs.at(i), {pytorch_model->get_input_debug_name(i)});
                 parameters->push_back(parameter);
                 input_node = parameter;
                 auto order = pytorch_model->get_input_transpose_order(i);
@ -404,6 +421,14 @@ Output<Node> TranslateSession::get_backprop_op(const std::shared_ptr<TorchDecode
     return std::make_shared<PtFrameworkNode>(node, OutputVector{value}, 1, true);
 }

+void TranslateSession::unique_name(const std::shared_ptr<Node>& node) {
+    if (m_unique_friendly_name_set.count(node->get_friendly_name())) {
+        node->set_friendly_name(node->get_friendly_name() + '_' + std::to_string(m_friendly_name_counter++));
+    } else {
+        m_unique_friendly_name_set.insert(node->get_friendly_name());
+    }
+}

 } // namespace pytorch
 } // namespace frontend
 } // namespace ov
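TranslateSession::unique_name, added above, keeps the first occurrence of a friendly name as-is and suffixes later duplicates with a session-wide counter. A standalone sketch of the same logic, wrapped in a small functor for demonstration (the Uniquifier type and example names are ours; the real method mutates the node in place):

```cpp
#include <iostream>
#include <string>
#include <unordered_set>

// First occurrence of a name is kept; any later duplicate gets a '_<counter>'
// suffix from a session-wide counter that only advances on collisions.
struct Uniquifier {
    std::unordered_set<std::string> seen;
    size_t counter = 0;
    std::string operator()(std::string name) {
        if (seen.count(name))
            return name + '_' + std::to_string(counter++);
        seen.insert(name);
        return name;
    }
};

int main() {
    Uniquifier u;
    std::cout << u("block/aten::add/Add") << '\n';  // block/aten::add/Add
    std::cout << u("block/aten::add/Add") << '\n';  // block/aten::add/Add_0
    std::cout << u("block/aten::add/Add") << '\n';  // block/aten::add/Add_1
}
```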
@ -48,7 +48,8 @@ public:
     /// \brief Gets pytorch tensor index from openvino tensor
     size_t decode_tensor_name(const Output<Node>& tensor_desc);

-    size_t m_friendly_name_counter = 0;
+    /// \brief Make sure Node has unique name
+    void unique_name(const std::shared_ptr<Node>& node);

     // Maps tensor index to initial tensor index which it is alias to, and to decoder of the node produced this alias
     // and to the output produced during conversion of this node
@ -64,6 +65,8 @@ private:

     std::map<size_t, std::pair<size_t, Output<Node>>> m_counter_map;
     std::map<std::string, uint64_t> m_op_statistics;
+    std::unordered_set<std::string> m_unique_friendly_name_set;
+    size_t m_friendly_name_counter = 0;
 };

 } // namespace pytorch
@ -5,6 +5,7 @@
 #include "utils.hpp"

 #include "op_table.hpp"
+#include "openvino/core/rt_info.hpp"
 #include "openvino/frontend/pytorch/decoder.hpp"
 #include "openvino/opsets/opset10.hpp"
 #include "openvino/util/log.hpp"
@ -497,6 +498,30 @@ void add_exception_to_fw_node(std::shared_ptr<Node> node, const std::string& msg
     }
 }

+void copy_runtime_info_and_name(const std::shared_ptr<Node>& from,
+                                ov::NodeVector to,
+                                const ov::NodeVector& additional_rt_info_src) {
+    if (to.size() == 1) {
+        // We do 1 to 1 matching, no need to process names, just inherit initial name
+        to[0]->set_friendly_name(from->get_friendly_name());
+    } else {
+        std::unordered_set<std::string> unique_names;
+        size_t idx = 0;
+        for (auto& op : to) {
+            auto new_name = from->get_friendly_name() + '/' + op->get_type_name();
+            if (unique_names.count(new_name)) {
+                new_name += '_' + std::to_string(idx++);
+            } else {
+                unique_names.insert(new_name);
+            }
+            op->set_friendly_name(new_name);
+        }
+    }
+    copy_runtime_info(from, to);
+    if (!additional_rt_info_src.empty())
+        copy_runtime_info(additional_rt_info_src, to);
+}

 } // namespace pytorch
 } // namespace frontend
 } // namespace ov
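The naming scheme implemented above, isolated from the node machinery: a single replacement node simply inherits the original friendly name, while each node of a multi-node replacement is named "<replaced-op name>/<op type>", with an index suffix deduplicating repeated types within one replacement. A standalone sketch over plain strings (derive_names is our name; the example op list is invented):

```cpp
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Multi-node case of copy_runtime_info_and_name: every new op is named
// "<replaced-op name>/<op type>"; a repeated type gets an '_<idx>' suffix so
// names inside one replacement stay unique.
std::vector<std::string> derive_names(const std::string& from,
                                      const std::vector<std::string>& op_types) {
    std::unordered_set<std::string> unique_names;
    std::vector<std::string> out;
    size_t idx = 0;
    for (const auto& type : op_types) {
        auto new_name = from + '/' + type;
        if (unique_names.count(new_name))
            new_name += '_' + std::to_string(idx++);
        else
            unique_names.insert(new_name);
        out.push_back(new_name);
    }
    return out;
}

int main() {
    for (const auto& n : derive_names("aten::chunk", {"Divide", "Mod", "Add", "Add"}))
        std::cout << n << '\n';  // .../Add, then .../Add_0 for the duplicate
}
```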
@ -72,6 +72,10 @@ void align_output_types(const NodeContext& context, OutputVector& outputs);

 std::deque<Output<Node>> get_list_as_outputs(const Output<Node>& start);

+void copy_runtime_info_and_name(const std::shared_ptr<Node>& from,
+                                ov::NodeVector to,
+                                const ov::NodeVector& additional_rt_info_src = {});

 namespace op {
 template <OutputVector (*T)(const NodeContext&), size_t idx = 0>
 OutputVector inplace_op(const NodeContext& context) {