Reject negative pads in SpaceToBatchFusion (#18028)
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <openvino/op/util/pad_base.hpp>
|
||||
#include <openvino/opsets/opset6.hpp>
|
||||
#include <vector>
|
||||
|
||||
@@ -28,7 +29,7 @@ ov::pass::SpaceToBatchFusion::SpaceToBatchFusion() {
|
||||
auto pads_begin_pattern = pattern::wrap_type<opset6::Constant>();
|
||||
auto pads_end_pattern = pattern::wrap_type<opset6::Constant>();
|
||||
auto pad_value = pattern::wrap_type<opset6::Constant>();
|
||||
auto pad_pattern = pattern::wrap_type<opset6::Pad>(
|
||||
auto pad_pattern = pattern::wrap_type<op::util::PadBase>(
|
||||
{reshape_or_transpose_before_pattern, pads_begin_pattern, pads_end_pattern, pad_value});
|
||||
auto space_to_depth_pattern = pattern::wrap_type<opset6::SpaceToDepth>({pad_pattern}, pattern::has_static_shape());
|
||||
auto reshape_after_pattern =
|
||||
@@ -60,6 +61,20 @@ ov::pass::SpaceToBatchFusion::SpaceToBatchFusion() {
|
||||
input_shape[2] == output_shape[2] && input_shape[3] == output_shape[3];
|
||||
};
|
||||
|
||||
auto pads_are_negative = [](const std::shared_ptr<Node>& pads) -> bool {
|
||||
auto constant = ov::as_type_ptr<opset6::Constant>(pads);
|
||||
if (!constant)
|
||||
return true;
|
||||
|
||||
for (auto pad : constant->cast_vector<int>()) {
|
||||
if (pad < 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
std::shared_ptr<Node> reshape_or_trans_before =
|
||||
get_reshape_or_transpose(reshape_before_pattern, trans_before_pattern);
|
||||
if (!reshape_or_trans_before)
|
||||
@@ -73,7 +88,7 @@ ov::pass::SpaceToBatchFusion::SpaceToBatchFusion() {
|
||||
if (!check_input_output_shape(reshape_or_trans_after))
|
||||
return false;
|
||||
|
||||
auto pad = std::dynamic_pointer_cast<opset6::Pad>(pattern_map.at(pad_pattern).get_node_shared_ptr());
|
||||
auto pad = std::dynamic_pointer_cast<op::util::PadBase>(pattern_map.at(pad_pattern).get_node_shared_ptr());
|
||||
if (!pad || pad->get_pad_mode() != op::PadMode::CONSTANT)
|
||||
return false;
|
||||
auto pad_value_const =
|
||||
@@ -84,6 +99,13 @@ ov::pass::SpaceToBatchFusion::SpaceToBatchFusion() {
|
||||
if (pad_value.size() != 1 || pad_value[0] != 0.0f)
|
||||
return false;
|
||||
|
||||
const auto pads_begin = pattern_map.at(pads_begin_pattern).get_node_shared_ptr();
|
||||
if (pads_are_negative(pads_begin))
|
||||
return false;
|
||||
const auto pads_end = pattern_map.at(pads_end_pattern).get_node_shared_ptr();
|
||||
if (pads_are_negative(pads_end))
|
||||
return false;
|
||||
|
||||
auto space_to_depth = std::dynamic_pointer_cast<opset6::SpaceToDepth>(
|
||||
pattern_map.at(space_to_depth_pattern).get_node_shared_ptr());
|
||||
if (!space_to_depth)
|
||||
@@ -93,10 +115,8 @@ ov::pass::SpaceToBatchFusion::SpaceToBatchFusion() {
|
||||
auto block_size = static_cast<int64_t>(space_to_depth->get_block_size());
|
||||
auto block_shape =
|
||||
opset6::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{1, 1, block_size, block_size});
|
||||
auto space_to_batch = register_new_node<opset6::SpaceToBatch>(pattern_map.at(data_pattern),
|
||||
block_shape,
|
||||
pattern_map.at(pads_begin_pattern),
|
||||
pattern_map.at(pads_end_pattern));
|
||||
auto space_to_batch =
|
||||
register_new_node<opset6::SpaceToBatch>(pattern_map.at(data_pattern), block_shape, pads_begin, pads_end);
|
||||
space_to_batch->set_friendly_name(reshape_or_trans_after->get_friendly_name());
|
||||
|
||||
copy_runtime_info(
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <openvino/op/pad.hpp>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <transformations/common_optimizations/space_to_batch_fusion.hpp>
|
||||
@@ -52,6 +53,59 @@ TEST_F(TransformationTestsF, SpaceToBatchFusionTranspose) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SpaceToBatchFusionTransposePad12) {
|
||||
{
|
||||
auto data = std::make_shared<opset6::Parameter>(element::f32, Shape{12, 3, 4, 8});
|
||||
auto trans_before =
|
||||
std::make_shared<opset6::Transpose>(data, op::Constant::create(element::i64, Shape{4}, {1, 0, 2, 3}));
|
||||
auto pad = std::make_shared<op::v12::Pad>(trans_before,
|
||||
op::Constant::create(element::i64, Shape{4}, {1, 1, 1, 1}),
|
||||
op::Constant::create(element::i64, Shape{4}, {2, 2, 3, 3}),
|
||||
op::Constant::create(element::f32, Shape{}, {0}),
|
||||
op::PadMode::CONSTANT);
|
||||
auto space_to_depth =
|
||||
std::make_shared<opset6::SpaceToDepth>(pad, opset6::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST, 2);
|
||||
auto trans_after =
|
||||
std::make_shared<opset6::Transpose>(space_to_depth,
|
||||
op::Constant::create(element::i64, Shape{4}, {1, 0, 2, 3}));
|
||||
function = std::make_shared<Function>(NodeVector{trans_after}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::SpaceToBatchFusion>();
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<opset6::Parameter>(element::f32, Shape{12, 3, 4, 8});
|
||||
auto space_to_batch =
|
||||
std::make_shared<opset6::SpaceToBatch>(data,
|
||||
op::Constant::create(element::i64, Shape{4}, {1, 1, 2, 2}),
|
||||
op::Constant::create(element::i64, Shape{4}, {1, 1, 1, 1}),
|
||||
op::Constant::create(element::i64, Shape{4}, {2, 2, 3, 3}));
|
||||
|
||||
function_ref = std::make_shared<Function>(NodeVector{space_to_batch}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SpaceToBatchFusionTransposeNegativePads) {
|
||||
{
|
||||
auto data = std::make_shared<opset6::Parameter>(element::f32, Shape{12, 3, 4, 8});
|
||||
auto trans_before =
|
||||
std::make_shared<opset6::Transpose>(data, op::Constant::create(element::i64, Shape{4}, {1, 0, 2, 3}));
|
||||
auto pad = std::make_shared<op::v12::Pad>(trans_before,
|
||||
op::Constant::create(element::i64, Shape{4}, {1, 1, -1, -1}),
|
||||
op::Constant::create(element::i64, Shape{4}, {2, 2, -3, -3}),
|
||||
op::Constant::create(element::f32, Shape{}, {0}),
|
||||
op::PadMode::CONSTANT);
|
||||
auto space_to_depth =
|
||||
std::make_shared<opset6::SpaceToDepth>(pad, opset6::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST, 4);
|
||||
auto trans_after =
|
||||
std::make_shared<opset6::Transpose>(space_to_depth,
|
||||
op::Constant::create(element::i64, Shape{4}, {1, 0, 2, 3}));
|
||||
function = std::make_shared<Function>(NodeVector{trans_after}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::SpaceToBatchFusion>();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SpaceToBatchFusionReshape) {
|
||||
{
|
||||
auto data = std::make_shared<opset6::Parameter>(element::f32, Shape{12, 3, 4, 8});
|
||||
|
||||
Reference in New Issue
Block a user