[PT FE] Revise usage output vs node in frontend (#19215)
* [PT FE] Revise usage output vs node in frontend
* Fix code style
parent cce9872fea
commit daa4f17a0a
@@ -60,8 +60,7 @@ OutputVector translate_avg_poolnd(const NodeContext& context) {
         auto pads_len = context.mark_node(v0::Constant::create(element::i32, Shape{}, {pads.size()}));
         auto pads_diff = context.mark_node(std::make_shared<v1::Subtract>(rank, pads_len));
         auto pads_remaining = context.mark_node(std::make_shared<v3::Broadcast>(zero_i32, pads_diff));
-        auto padding = context.mark_node(
-            std::make_shared<v0::Concat>(NodeVector{pads_remaining, pad_values.get_node_shared_ptr()}, 0));
+        auto padding = context.mark_node(std::make_shared<v0::Concat>(OutputVector{pads_remaining, pad_values}, 0));
         input = context.mark_node(std::make_shared<v1::Pad>(input, padding, padding, zero, ov::op::PadMode::CONSTANT));
         pads = Shape(pads.size(), 0);
     }
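
Like the rest of this commit, the hunk above trades std::shared_ptr<ov::Node> plumbing for ov::Output<ov::Node> values: Concat is built from an OutputVector, so pad_values no longer needs get_node_shared_ptr(). A minimal sketch of the idea (not taken from the patch; the function name and pad shapes are made up):

    #include <memory>

    #include <openvino/op/concat.hpp>
    #include <openvino/op/constant.hpp>

    using namespace ov;

    // A shared_ptr<Node> and an Output<Node> can both be placed in one
    // OutputVector, which is what op constructors such as Concat accept.
    Output<Node> make_padding(const Output<Node>& pad_values) {
        auto pads_remaining = op::v0::Constant::create(element::i32, Shape{2}, {0, 0});
        return std::make_shared<op::v0::Concat>(OutputVector{pads_remaining, pad_values}, 0);
    }
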
@@ -52,11 +52,11 @@ OutputVector translate_pad(const NodeContext& context) {
     int64_t pad_l;
     int64_t pad_r;
     auto pad_last_id = paddings.size();
-    auto cur = data.get_node_shared_ptr();
+    auto cur = data;
     auto step = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
     auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
     for (size_t i = 0; i < pad_size_half; i++) {
-        ov::NodeVector tensors;
+        OutputVector tensors;
         pad_r = paddings[pad_last_id - (2 * i + 1)];
         pad_l = paddings[pad_last_id - (2 * i + 2)];
         auto axes = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2 + i}));
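
The translate_pad hunk keeps the running tensor as an Output<Node> and collects the pieces for Concat in an OutputVector rather than an ov::NodeVector. A rough sketch of that shape (illustrative only; the loop body is reduced to the part that matters):

    #include <memory>

    #include <openvino/op/concat.hpp>

    using namespace ov;

    Output<Node> pad_like_loop(const Output<Node>& data, size_t pad_size_half) {
        Output<Node> cur = data;       // was: auto cur = data.get_node_shared_ptr();
        for (size_t i = 0; i < pad_size_half; i++) {
            OutputVector tensors;      // was: ov::NodeVector tensors;
            tensors.push_back(cur);    // Output<Node> is stored directly
            cur = std::make_shared<op::v0::Concat>(tensors, 0);
        }
        return cur;
    }
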
@@ -108,15 +108,14 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
                 if (index < 0) {
                     index = split->outputs().size() + index;
                 }
-                replace_node(getitem, {split->outputs()[index]});
+                getitem->output(0).replace(split->outputs()[index]);
             }
         } else if (auto list_construct = cast_fw_node(input_node, "prim::ListConstruct")) {
             auto getitem_idx = getitem->input_value(1).get_node_shared_ptr();
             auto getitem_idx_const = std::dynamic_pointer_cast<v0::Constant>(getitem_idx);
             if (getitem_idx_const) {
                 auto idx = getitem_idx_const->cast_vector<int64_t>();
-                auto element = list_construct->input_value(idx[0]).get_node_shared_ptr();
-                replace_node(getitem, element);
+                getitem->output(0).replace(list_construct->input_value(idx[0]));
             } else {
                 auto input_concat = concat_list_construct(list_construct);
                 auto zero = v0::Constant::create(element::i32, Shape{}, {0});
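
In AtenGetItemReplacer the rewiring switches from replace_node to Output::replace: only the single output that getitem produces is re-pointed at the chosen Output<Node>, and no temporary node has to be materialized for the ListConstruct case. A small sketch contrasting the two calls (names are illustrative):

    #include <memory>

    #include <openvino/core/graph_util.hpp>  // ov::replace_node
    #include <openvino/core/node.hpp>

    void rewire(const std::shared_ptr<ov::Node>& getitem,
                const ov::Output<ov::Node>& replacement) {
        // Old pattern: replace the whole node with a vector of outputs.
        //   ov::replace_node(getitem, ov::OutputVector{replacement});
        // New pattern: rewire just the one output that is actually consumed.
        getitem->output(0).replace(replacement);
    }
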
@@ -126,7 +126,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
 
         // all indicies prim::Constant(None), return input as is
         if (advanced_ids.size() == 0) {
-            replace_node(index_op, input_node.get_node_shared_ptr());
+            index_op->output(0).replace(index_op->get_input_source_output(0));
             return true;
         }
         // perform gather for single element case
@@ -238,7 +238,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
                 // index is None, stay input as is
                 const auto& attrs = const_input->get_attrs();
                 if (attrs.find("none_value") != attrs.end()) {
-                    replace_node(index_op, input_node.get_node_shared_ptr());
+                    index_op->output(0).replace(index_op->get_input_source_output(0));
                     return true;
                 }
             }
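
Both AtenIndexToSelect hunks above apply the same pass-through rewiring: when the indices are effectively None, index_op is cut out of the graph by pointing its output at its own input's source output, instead of calling replace_node on input_node. In sketch form (the name pass_through is made up):

    #include <memory>

    #include <openvino/core/node.hpp>

    bool pass_through(const std::shared_ptr<ov::Node>& index_op) {
        // get_input_source_output(0) is the Output<Node> feeding input port 0;
        // after the replace, consumers of index_op read that output directly.
        index_op->output(0).replace(index_op->get_input_source_output(0));
        return true;
    }
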
@@ -74,7 +74,7 @@ IRFFTNComplexReplacer::IRFFTNComplexReplacer() {
         bool dim_use_default = is_none_node(irfftn_op->input_value(2));
         bool s_use_default = is_none_node(irfftn_op->input_value(1));
         // Can be None constant, when used check s_use_default.
-        auto raw_s_input_maybe = concat_list_construct(irfftn_op->input_value(1)).get_node_shared_ptr();
+        auto raw_s_input_maybe = concat_list_construct(irfftn_op->input_value(1));
 
         // Handle dim parameter containing vector of intigers indicating dimensions to be transformed.
         std::shared_ptr<ov::Node> dim;
@@ -57,7 +57,7 @@ RFFTNComplexReplacer::RFFTNComplexReplacer() {
         bool dim_use_default = is_none_node(rfftn_op->input_value(2));
         bool s_use_default = is_none_node(rfftn_op->input_value(1));
         // Can be None constant, when used check s_use_default.
-        auto raw_s_input_maybe = concat_list_construct(rfftn_op->input_value(1)).get_node_shared_ptr();
+        auto raw_s_input_maybe = concat_list_construct(rfftn_op->input_value(1));
 
         // Handle dim parameter containing vector of intigers indicating dimensions to be transformed.
         std::shared_ptr<ov::Node> dim;
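
Both FFT replacer hunks drop the trailing get_node_shared_ptr() because concat_list_construct now hands back an Output<Node> (see the utils hunk below); holding the Output is enough, and a node pointer can still be obtained at the one place it is really needed. A short sketch of that on-demand conversion, for reference:

    #include <memory>

    #include <openvino/core/node.hpp>

    void keep_as_output(const ov::Output<ov::Node>& raw_s_input_maybe) {
        // Convert on demand rather than eagerly at every call site.
        std::shared_ptr<ov::Node> as_node = raw_s_input_maybe.get_node_shared_ptr();
        (void)as_node;  // placeholder use
    }
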
@@ -181,7 +181,7 @@ Output<Node> concat_list_construct(const Output<Node>& input) {
         OutputVector node_vector;
         auto zero = opset10::Constant::create(element::i32, Shape{}, {0});
         for (size_t i = 0; i < list_inputs.size(); i++) {
-            auto node = concat_list_construct(list_inputs[i].get_node_shared_ptr());
+            auto node = concat_list_construct(list_inputs[i]);
             auto unsqueezed_node = std::make_shared<opset10::Unsqueeze>(node, zero);
             node_vector.push_back(unsqueezed_node);
         }
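
The recursive helper itself now takes and forwards Output<Node> values, so list elements are passed to the recursive call without get_node_shared_ptr(). A simplified, self-contained sketch of the same unsqueeze-and-concat pattern (this is not the frontend utility; recursion and the ListConstruct handling are omitted):

    #include <memory>

    #include <openvino/opsets/opset10.hpp>

    ov::Output<ov::Node> concat_outputs(const ov::OutputVector& list_inputs) {
        ov::OutputVector node_vector;
        auto zero = ov::opset10::Constant::create(ov::element::i32, ov::Shape{}, {0});
        for (const auto& item : list_inputs) {
            // item is already an Output<Node>; no conversion needed before use.
            node_vector.push_back(std::make_shared<ov::opset10::Unsqueeze>(item, zero));
        }
        return std::make_shared<ov::opset10::Concat>(node_vector, 0);
    }
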
@@ -136,7 +136,7 @@ inline OutputVector return_false_scalar(const NodeContext& context) {
 }
 
 inline OutputVector skip_node(const NodeContext& context) {
-    return {context.get_input(0).get_node_shared_ptr()};
+    return {context.get_input(0)};
 }
 
 }  // namespace op
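
Finally, skip_node returns the input Output<Node> as-is: OutputVector is simply std::vector<ov::Output<ov::Node>>, so no conversion to a node pointer is required. A sketch of the simplified helper shape (illustrative; not the real header):

    #include <openvino/core/node.hpp>

    ov::OutputVector skip_node_like(const ov::Output<ov::Node>& input) {
        return {input};  // was: return {input.get_node_shared_ptr()};
    }
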