[PT FE]: handle prim::ListConstruct + aten::pad case (#15288)

This commit is contained in:
Ekaterina Aidova 2023-01-31 18:08:22 +04:00 committed by GitHub
parent 407590cfc2
commit a12de8183c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 240 additions and 0 deletions

View File

@ -15,6 +15,7 @@
#include "transforms/aten_cat_replacer.hpp"
#include "transforms/aten_getitem_replacer.hpp"
#include "transforms/max_prim_list_construct_replacer.hpp"
#include "transforms/prim_list_construct_pad.hpp"
#include "transforms/prim_list_unpack_replacer.hpp"
#include "transforms/prim_tuple_construct_replacer.hpp"
@ -87,6 +88,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::PrimListUnpackReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::AtenGetItemReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::MaxPrimListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::PrimListConstructPadReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeTupleResults>();
manager.register_pass<ov::pass::ConstantFolding>();

View File

@ -0,0 +1,122 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "prim_list_construct_pad.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/pad.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
namespace {
std::shared_ptr<Node> create_padding(std::shared_ptr<Node> input_rank,
std::shared_ptr<Node> padding,
std::shared_ptr<Node> start_id,
std::shared_ptr<Node> end_id) {
// PyTorch paddings represented as [N_pad_begins, N_pad_ends, N-1_pad_begins, N-1_pad_ends, ... ]
// if len of paddings not equal to input rank * 2, zero padding added to first rank - N dimensions
// OV expects paddings separated on begins and ends for each dimension from first to last
auto minus_two = ov::op::v0::Constant::create(element::i64, Shape{}, {-2});
auto zero = ov::op::v0::Constant::create(element::i64, Shape{}, {0});
auto pad_id_range = std::make_shared<ov::op::v4::Range>(start_id, end_id, minus_two, element::i64);
auto pads = std::make_shared<ov::op::v8::Gather>(padding, pad_id_range, zero);
// add left side zero padding for difference between padding size and input rank
auto pads_short_len = std::make_shared<ov::op::v3::ShapeOf>(pads);
auto pads_diff = std::make_shared<ov::op::v1::Subtract>(input_rank, pads_short_len);
auto pads_remaining = std::make_shared<ov::op::v3::Broadcast>(zero, pads_diff);
auto pads_remaining_c = std::make_shared<ov::op::v1::ConvertLike>(pads_remaining, pads);
auto pads_full = std::make_shared<ov::op::v0::Concat>(OutputVector{pads_remaining_c, pads}, 0);
return pads_full;
}
const std::unordered_map<std::string, ov::op::PadMode> PAD_MODES = {{"constant", ov::op::PadMode::CONSTANT},
{"reflect", ov::op::PadMode::REFLECT},
{"replicate", ov::op::PadMode::EDGE}};
}; // namespace
PrimListConstructPadReplacer::PrimListConstructPadReplacer() {
// transformation for case aten::pad + prim::ListConstruct as paddings
auto pad_op = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
auto pad_op = cast_fw_node(m.get_match_root(), "aten::pad");
if (!pad_op) {
return false;
}
auto minus_two = ov::op::v0::Constant::create(element::i64, Shape{}, {-2});
auto minus_one = ov::op::v0::Constant::create(element::i64, Shape{}, {-1});
auto zero = ov::op::v0::Constant::create(element::i64, Shape{}, {0});
auto input_node = pad_op->input_value(0).get_node_shared_ptr();
auto padding = pad_op->input_value(1).get_node_shared_ptr();
// for case. when padding is list of scalars, concatenate them into one tensor
auto pad_values = concat_list_construct(padding);
std::string mode = "constant";
auto zero_f = ov::op::v0::Constant::create(element::f32, Shape{}, {0});
auto input_shape = std::make_shared<ov::op::v3::ShapeOf>(input_node);
auto input_rank = std::make_shared<ov::op::v3::ShapeOf>(input_shape);
auto pad_size_1d = std::make_shared<ov::op::v3::ShapeOf>(pad_values);
auto pad_size = std::make_shared<ov::op::v0::Squeeze>(pad_size_1d, zero);
// get pad_begins and pad_ends indexes starting for end of paddings
auto start_pad_begins = std::make_shared<ov::op::v1::Add>(pad_size, minus_two);
auto start_pad_ends = std::make_shared<ov::op::v1::Add>(pad_size, minus_one);
auto pad_begins_full = create_padding(input_rank, pad_values, start_pad_begins, minus_one);
auto pad_ends_full = create_padding(input_rank, pad_values, start_pad_ends, zero);
auto mode_const = pad_op->input_value(2).get_node_shared_ptr();
auto pad_value = pad_op->input_value(3).get_node_shared_ptr();
if (const auto& fw_node_mode = cast_fw_node(mode_const, "prim::Constant")) {
const auto& attrs = fw_node_mode->get_attrs();
if (attrs.find("string_value") != attrs.end()) {
mode = attrs.at("string_value");
}
}
if (mode == "constant") {
if (const auto& fw_node_value = cast_fw_node(pad_value, "prim::Constant")) {
const auto& attrs = fw_node_value->get_attrs();
if (attrs.find("none_value") != attrs.end()) {
pad_value = zero_f;
}
}
}
FRONT_END_OP_CONVERSION_CHECK(PAD_MODES.find(mode) != PAD_MODES.end(),
"Unsupported mode: ",
mode,
"for aten::pad");
auto pad_mode = PAD_MODES.at(mode);
auto pad = std::make_shared<ov::op::v1::Pad>(input_node, pad_begins_full, pad_ends_full, pad_value, pad_mode);
replace_node(pad_op, pad);
copy_runtime_info({pad_op,
input_node,
padding,
pad_op->input_value(2).get_node_shared_ptr(),
pad_op->input_value(3).get_node_shared_ptr()},
pad);
return true;
};
auto m = std::make_shared<ov::pass::pattern::Matcher>(pad_op,
"ov::frontend::pytorch::pass::PrimListConstructPadReplacer");
this->register_matcher(m, callback);
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
class PrimListConstructPadReplacer : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::PrimListConstructPadReplacer");
PrimListConstructPadReplacer();
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -108,3 +108,95 @@ class TestPad(PytorchLayerTest):
def test_pad2d(self, pads, mode, value, ie_device, precision, ir_version):
self._test(*self.create_model(pads, mode, value), ie_device, precision, ir_version,
kwargs_to_prepare_input={'ndim': 2}, trace_model=True)
class TestPadListPaddingings(PytorchLayerTest):
def _prepare_input(self, ndim=4, pad_w=0, pad_h=0):
import numpy as np
input_5d_shape = [1, 3, 14, 14, 18]
return (np.random.randn(*input_5d_shape[:ndim]).astype(np.float32), np.array(pad_w, dtype=np.int32), np.array(pad_h, dtype=np.int32))
def create_model(self, mode, value=None):
import torch
import torch.nn.functional as F
class aten_pad(torch.nn.Module):
def __init__(self, mode, value=None):
super().__init__()
self.mode = mode
self.value = value
def forward(self, x, pad_w:int, pad_h:int):
return F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=self.value)
ref_net = None
return aten_pad(mode, value), ref_net, "aten::pad"
@pytest.mark.parametrize("pad_w,pad_h,mode,value", [
(2, 0, "reflect", None),
(0, 2, "reflect", None),
(10, 10, "reflect", None),
(0, 0, "reflect", None),
(5, 3, "reflect", None),
(2, 0, "replicate", None),
(0, 2, "replicate", None),
(10, 10, "replicate", None),
(5, 3, "replicate", None),
(0, 0, "replicate", None),
(2, 0, "constant", None),
(0, 3, "constant", 42.),
(4, 4, "constant", -0.57),
(1, 2, "constant", None),
(0, 0, "constant", -0.57),
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_pad4d(self, pad_w, pad_h, mode, value, ie_device, precision, ir_version):
self._test(*self.create_model(mode, value), ie_device, precision, ir_version,
kwargs_to_prepare_input={'ndim': 4, "pad_w": pad_w, "pad_h": pad_h})
@pytest.mark.parametrize("pad_w,pad_h,mode,value", [
(2, 0, "reflect", None),
(0, 2, "reflect", None),
(10, 10, "reflect", None),
(0, 0, "reflect", None),
(5, 3, "reflect", None),
(2, 0, "replicate", None),
(0, 2, "replicate", None),
(10, 10, "replicate", None),
(5, 3, "replicate", None),
(0, 0, "replicate", None),
(2, 0, "constant", None),
(0, 3, "constant", 42.),
(4, 4, "constant", -0.57),
(1, 2, "constant", None),
(0, 0, "constant", -0.57)
])
@pytest.mark.nightly
def test_pad5d(self, pad_w, pad_h, mode, value, ie_device, precision, ir_version):
self._test(*self.create_model(mode, value), ie_device, precision, ir_version,
kwargs_to_prepare_input={'ndim': 5, "pad_w": pad_w, "pad_h": pad_h})
@pytest.mark.parametrize("pad_w,pad_h,mode,value", [
(2, 0, "reflect", None),
(0, 2, "reflect", None),
(10, 10, "reflect", None),
(0, 0, "reflect", None),
(5, 3, "reflect", None),
(2, 0, "replicate", None),
(0, 2, "replicate", None),
(10, 10, "replicate", None),
(5, 3, "replicate", None),
(0, 0, "replicate", None),
(2, 0, "constant", None),
(0, 3, "constant", 42.),
(4, 4, "constant", -0.57),
(1, 2, "constant", None),
(0, 0, "constant", -0.57)
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_pad2d(self, pad_w, pad_h, mode, value, ie_device, precision, ir_version):
self._test(*self.create_model(mode, value), ie_device, precision, ir_version,
kwargs_to_prepare_input={'ndim': 2, "pad_w": pad_w, "pad_h": pad_h})