[PT FE] Revise usage output vs node in frontend (#19215)

* [PT FE] Revise usage output vs node in frontend

* Fix code style
This commit is contained in:
Maxim Vafin 2023-08-16 15:05:25 +02:00 committed by GitHub
parent cce9872fea
commit daa4f17a0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 11 additions and 13 deletions

View File

@ -60,8 +60,7 @@ OutputVector translate_avg_poolnd(const NodeContext& context) {
auto pads_len = context.mark_node(v0::Constant::create(element::i32, Shape{}, {pads.size()}));
auto pads_diff = context.mark_node(std::make_shared<v1::Subtract>(rank, pads_len));
auto pads_remaining = context.mark_node(std::make_shared<v3::Broadcast>(zero_i32, pads_diff));
auto padding = context.mark_node(
std::make_shared<v0::Concat>(NodeVector{pads_remaining, pad_values.get_node_shared_ptr()}, 0));
auto padding = context.mark_node(std::make_shared<v0::Concat>(OutputVector{pads_remaining, pad_values}, 0));
input = context.mark_node(std::make_shared<v1::Pad>(input, padding, padding, zero, ov::op::PadMode::CONSTANT));
pads = Shape(pads.size(), 0);
}

View File

@ -52,11 +52,11 @@ OutputVector translate_pad(const NodeContext& context) {
int64_t pad_l;
int64_t pad_r;
auto pad_last_id = paddings.size();
auto cur = data.get_node_shared_ptr();
auto cur = data;
auto step = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
for (size_t i = 0; i < pad_size_half; i++) {
ov::NodeVector tensors;
OutputVector tensors;
pad_r = paddings[pad_last_id - (2 * i + 1)];
pad_l = paddings[pad_last_id - (2 * i + 2)];
auto axes = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2 + i}));

View File

@ -108,15 +108,14 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
if (index < 0) {
index = split->outputs().size() + index;
}
replace_node(getitem, {split->outputs()[index]});
getitem->output(0).replace(split->outputs()[index]);
}
} else if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) {
auto getitem_idx = getitem->input_value(1).get_node_shared_ptr();
auto getitem_idx_const = std::dynamic_pointer_cast<v0::Constant>(getitem_idx);
if (getitem_idx_const) {
auto idx = getitem_idx_const->cast_vector<int64_t>();
auto element = list_construct->input_value(idx[0]).get_node_shared_ptr();
replace_node(getitem, element);
getitem->output(0).replace(list_construct->input_value(idx[0]));
} else {
auto input_concat = concat_list_construct(list_construct);
auto zero = v0::Constant::create(element::i32, Shape{}, {0});

View File

@ -126,7 +126,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
// all indicies prim::Constant(None), return input as is
if (advanced_ids.size() == 0) {
replace_node(index_op, input_node.get_node_shared_ptr());
index_op->output(0).replace(index_op->get_input_source_output(0));
return true;
}
// perform gather for single element case
@ -238,7 +238,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
// index is None, stay input as is
const auto& attrs = const_input->get_attrs();
if (attrs.find("none_value") != attrs.end()) {
replace_node(index_op, input_node.get_node_shared_ptr());
index_op->output(0).replace(index_op->get_input_source_output(0));
return true;
}
}

View File

@ -74,7 +74,7 @@ IRFFTNComplexReplacer::IRFFTNComplexReplacer() {
bool dim_use_default = is_none_node(irfftn_op->input_value(2));
bool s_use_default = is_none_node(irfftn_op->input_value(1));
// Can be None constant, when used check s_use_default.
auto raw_s_input_maybe = concat_list_construct(irfftn_op->input_value(1)).get_node_shared_ptr();
auto raw_s_input_maybe = concat_list_construct(irfftn_op->input_value(1));
// Handle dim parameter containing vector of intigers indicating dimensions to be transformed.
std::shared_ptr<ov::Node> dim;

View File

@ -57,7 +57,7 @@ RFFTNComplexReplacer::RFFTNComplexReplacer() {
bool dim_use_default = is_none_node(rfftn_op->input_value(2));
bool s_use_default = is_none_node(rfftn_op->input_value(1));
// Can be None constant, when used check s_use_default.
auto raw_s_input_maybe = concat_list_construct(rfftn_op->input_value(1)).get_node_shared_ptr();
auto raw_s_input_maybe = concat_list_construct(rfftn_op->input_value(1));
// Handle dim parameter containing vector of intigers indicating dimensions to be transformed.
std::shared_ptr<ov::Node> dim;

View File

@ -181,7 +181,7 @@ Output<Node> concat_list_construct(const Output<Node>& input) {
OutputVector node_vector;
auto zero = opset10::Constant::create(element::i32, Shape{}, {0});
for (size_t i = 0; i < list_inputs.size(); i++) {
auto node = concat_list_construct(list_inputs[i].get_node_shared_ptr());
auto node = concat_list_construct(list_inputs[i]);
auto unsqueezed_node = std::make_shared<opset10::Unsqueeze>(node, zero);
node_vector.push_back(unsqueezed_node);
}

View File

@ -136,7 +136,7 @@ inline OutputVector return_false_scalar(const NodeContext& context) {
}
inline OutputVector skip_node(const NodeContext& context) {
return {context.get_input(0).get_node_shared_ptr()};
return {context.get_input(0)};
}
} // namespace op