EliminatePad and PadFusion support Pad12 positive indexes (#18278)
* update opset5::Pad -> PadBase * rewrite unit tests * refactor unit tests * add unit test NegativePadElimination * fix add destructor * clang fixes * fix unit tests * add unit tests * bug fix --------- Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
This commit is contained in:
parent
8a76f4e7fa
commit
f313dde66b
@ -9,6 +9,7 @@
|
||||
#include <ngraph/util.hpp>
|
||||
#include <numeric>
|
||||
#include <openvino/core/validation_util.hpp>
|
||||
#include <openvino/op/util/pad_base.hpp>
|
||||
#include <openvino/opsets/opset3.hpp>
|
||||
#include <openvino/opsets/opset7.hpp>
|
||||
#include <openvino/opsets/opset8.hpp>
|
||||
@ -315,7 +316,7 @@ SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, simplify_gather, opset3::Gather,
|
||||
|
||||
pass::EliminatePad::EliminatePad() {
|
||||
MATCHER_SCOPE(EliminatePad);
|
||||
auto pad_node_pattern = pattern::wrap_type<opset8::Pad>();
|
||||
auto pad_node_pattern = pattern::wrap_type<op::util::PadBase>();
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto pad = m.get_match_root();
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/validation_util.hpp>
|
||||
#include <openvino/op/util/pad_base.hpp>
|
||||
#include <openvino/opsets/opset5.hpp>
|
||||
#include <vector>
|
||||
|
||||
@ -17,7 +18,7 @@
|
||||
using namespace ov;
|
||||
|
||||
template <typename T>
|
||||
static bool can_be_fused(const std::shared_ptr<opset5::Pad>& pad,
|
||||
static bool can_be_fused(const std::shared_ptr<op::util::PadBase>& pad,
|
||||
const std::shared_ptr<T>& node,
|
||||
const std::shared_ptr<Node>& pad_value_node,
|
||||
const std::shared_ptr<opset5::Constant>& pads_begin,
|
||||
@ -96,14 +97,14 @@ pass::PadFusionAvgPool::PadFusionAvgPool() {
|
||||
auto pads_end_pattern = pattern::wrap_type<opset5::Constant>();
|
||||
auto pad_value_pattern = pattern::any_input();
|
||||
auto pad_node_pattern =
|
||||
pattern::wrap_type<opset5::Pad>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
|
||||
pattern::consumers_count(1));
|
||||
pattern::wrap_type<op::util::PadBase>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
|
||||
pattern::consumers_count(1));
|
||||
auto avg_pool_pattern = pattern::wrap_type<opset5::AvgPool>({pad_node_pattern});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto pattern_map = m.get_pattern_value_map();
|
||||
auto data = pattern_map[data_pattern];
|
||||
auto pad = std::dynamic_pointer_cast<opset5::Pad>(pattern_map[pad_node_pattern].get_node_shared_ptr());
|
||||
auto pad = std::dynamic_pointer_cast<op::util::PadBase>(pattern_map[pad_node_pattern].get_node_shared_ptr());
|
||||
auto pad_value = pattern_map[pad_value_pattern].get_node_shared_ptr();
|
||||
auto pads_begin =
|
||||
std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pads_begin_pattern].get_node_shared_ptr());
|
||||
@ -196,15 +197,16 @@ pass::PadFusionConvolution::PadFusionConvolution() {
|
||||
auto pads_end_pattern = pattern::wrap_type<opset5::Constant>();
|
||||
auto pad_value_pattern = pattern::any_input();
|
||||
auto pad_node_pattern =
|
||||
pattern::wrap_type<opset5::Pad>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
|
||||
pattern::consumers_count(1));
|
||||
pattern::wrap_type<op::util::PadBase>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
|
||||
pattern::consumers_count(1));
|
||||
auto conv_pattern = pattern::wrap_type<opset5::Convolution>({pad_node_pattern, filter_pattern});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
std::cout << "[EMUTEX DEBUG] CHECKPOINT PadFusionConvolution" << std::endl;
|
||||
auto pattern_map = m.get_pattern_value_map();
|
||||
auto data = pattern_map[data_pattern];
|
||||
auto filter = pattern_map[filter_pattern];
|
||||
auto pad = std::dynamic_pointer_cast<opset5::Pad>(pattern_map[pad_node_pattern].get_node_shared_ptr());
|
||||
auto pad = std::dynamic_pointer_cast<op::util::PadBase>(pattern_map[pad_node_pattern].get_node_shared_ptr());
|
||||
auto pad_value = pattern_map[pad_value_pattern].get_node_shared_ptr();
|
||||
auto pads_begin =
|
||||
std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pads_begin_pattern].get_node_shared_ptr());
|
||||
@ -243,15 +245,15 @@ pass::PadFusionConvolutionBackpropData::PadFusionConvolutionBackpropData() {
|
||||
auto pads_end_pattern = pattern::wrap_type<opset5::Constant>();
|
||||
auto pad_value_pattern = pattern::any_input();
|
||||
auto pad_node_pattern =
|
||||
pattern::wrap_type<opset5::Pad>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
|
||||
pattern::consumers_count(1));
|
||||
pattern::wrap_type<op::util::PadBase>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
|
||||
pattern::consumers_count(1));
|
||||
auto conv_pattern = pattern::wrap_type<opset5::ConvolutionBackpropData>({pad_node_pattern, filter_pattern});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto pattern_map = m.get_pattern_value_map();
|
||||
auto data = pattern_map[data_pattern];
|
||||
auto filter = pattern_map[filter_pattern];
|
||||
auto pad = std::dynamic_pointer_cast<opset5::Pad>(pattern_map[pad_node_pattern].get_node_shared_ptr());
|
||||
auto pad = std::dynamic_pointer_cast<op::util::PadBase>(pattern_map[pad_node_pattern].get_node_shared_ptr());
|
||||
auto pad_value = pattern_map[pad_value_pattern].get_node_shared_ptr();
|
||||
auto pads_begin =
|
||||
std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pads_begin_pattern].get_node_shared_ptr());
|
||||
@ -301,15 +303,15 @@ pass::PadFusionGroupConvolution::PadFusionGroupConvolution() {
|
||||
auto pads_end_pattern = pattern::wrap_type<opset5::Constant>();
|
||||
auto pad_value_pattern = pattern::any_input();
|
||||
auto pad_node_pattern =
|
||||
pattern::wrap_type<opset5::Pad>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
|
||||
pattern::consumers_count(1));
|
||||
pattern::wrap_type<op::util::PadBase>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
|
||||
pattern::consumers_count(1));
|
||||
auto conv_pattern = pattern::wrap_type<opset5::GroupConvolution>({pad_node_pattern, filter_pattern});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto pattern_map = m.get_pattern_value_map();
|
||||
auto data = pattern_map[data_pattern];
|
||||
auto filter = pattern_map[filter_pattern];
|
||||
auto pad = std::dynamic_pointer_cast<opset5::Pad>(pattern_map[pad_node_pattern].get_node_shared_ptr());
|
||||
auto pad = std::dynamic_pointer_cast<op::util::PadBase>(pattern_map[pad_node_pattern].get_node_shared_ptr());
|
||||
auto pad_value = pattern_map[pad_value_pattern].get_node_shared_ptr();
|
||||
auto pads_begin =
|
||||
std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pads_begin_pattern].get_node_shared_ptr());
|
||||
@ -349,15 +351,15 @@ pass::PadFusionGroupConvolutionBackpropData::PadFusionGroupConvolutionBackpropDa
|
||||
auto pads_end_pattern = pattern::wrap_type<opset5::Constant>();
|
||||
auto pad_value_pattern = pattern::any_input();
|
||||
auto pad_node_pattern =
|
||||
pattern::wrap_type<opset5::Pad>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
|
||||
pattern::consumers_count(1));
|
||||
pattern::wrap_type<op::util::PadBase>({data_pattern, pads_begin_pattern, pads_end_pattern, pad_value_pattern},
|
||||
pattern::consumers_count(1));
|
||||
auto conv_pattern = pattern::wrap_type<opset5::GroupConvolutionBackpropData>({pad_node_pattern, filter_pattern});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto pattern_map = m.get_pattern_value_map();
|
||||
auto data = pattern_map[data_pattern];
|
||||
auto filter = pattern_map[filter_pattern];
|
||||
auto pad = std::dynamic_pointer_cast<opset5::Pad>(pattern_map[pad_node_pattern].get_node_shared_ptr());
|
||||
auto pad = std::dynamic_pointer_cast<op::util::PadBase>(pattern_map[pad_node_pattern].get_node_shared_ptr());
|
||||
auto pad_value = pattern_map[pad_value_pattern].get_node_shared_ptr();
|
||||
auto pads_begin =
|
||||
std::dynamic_pointer_cast<opset5::Constant>(pattern_map[pads_begin_pattern].get_node_shared_ptr());
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user