[Transformations] ShuffleChannelsFusion fix and tests added (#5448)

This commit is contained in:
Vladislav Golubev 2021-04-29 10:47:04 +03:00 committed by GitHub
parent 64a032fa18
commit 449f3376e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 20 deletions

View File

@ -55,7 +55,11 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ShuffleChannelsFusion, "ShuffleChannelsFusi
ngraph::pass::ShuffleChannelsFusion::ShuffleChannelsFusion(const bool reshape_constants_check) {
MATCHER_SCOPE(ShuffleChannelsFusion);
auto input = ngraph::pattern::any_input(pattern::has_static_shape());
auto has_static_4d_shape = [](const Output<Node>& output) {
return pattern::has_static_shape()(output) && pattern::rank_equals(4)(output);
};
auto input = ngraph::pattern::any_input(has_static_4d_shape);
auto reshape_before_const_pattern = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
auto transpose_const_pattern = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
auto reshape_after_const_pattern = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();

View File

@ -22,11 +22,10 @@ using namespace ngraph;
class ShuffleChannelsFusionTestValues {
public:
bool dynamicShape;
ngraph::PartialShape inputPartialShape;
std::vector<int64_t> reshape_before_val;
std::vector<size_t> transpose_val;
std::vector<int64_t> reshape_after_val;
size_t batch_size;
bool check_reshape_values;
bool fuse_happened;
};
@ -49,8 +48,7 @@ public:
void SetUp() override {
const auto values = GetParam();
{
const PartialShape inputPartialShape = values.dynamicShape ? PartialShape::dynamic() : Shape{ values.batch_size, 128, 720, 480 };
auto input0 = std::make_shared<opset6::Parameter>(element::f32, inputPartialShape);
auto input0 = std::make_shared<opset6::Parameter>(element::f32, values.inputPartialShape);
auto shape_reshape_before = opset6::Constant::create(element::i64, Shape{ values.reshape_before_val.size() }, values.reshape_before_val);
auto permutation = opset6::Constant::create(element::i64, Shape{ values.transpose_val.size() }, values.transpose_val);
auto shape_reshape_after = opset6::Constant::create(element::i64, Shape{ values.reshape_after_val.size() }, values.reshape_after_val);
@ -69,7 +67,7 @@ public:
}
if (values.fuse_happened) {
auto input0 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ values.batch_size, 128, 720, 480 });
auto input0 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, values.inputPartialShape);
auto shuffle_channels = std::make_shared<ngraph::opset6::ShuffleChannels>(input0, 1, values.reshape_before_val[1]);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ shuffle_channels }, ngraph::ParameterVector{ input0 });
} else {
@ -81,10 +79,10 @@ public:
const ShuffleChannelsFusionTestValues testValues = obj.param;
std::ostringstream result;
if (testValues.dynamicShape) {
if (testValues.inputPartialShape.is_dynamic()) {
result << "_dynamic_shape_";
} else {
result << "_batch_size_" << testValues.batch_size;
result << "_input_shape_" << testValues.inputPartialShape;
}
result << "_before_" << testValues.reshape_before_val
@ -105,19 +103,30 @@ TEST_P(ShuffleChannelsFusion, CompareFunctions) {
}
const std::vector<ShuffleChannelsFusionTestValues> testValues = {
{ true, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, 128, 720, 480}, 1, false, false },
{ false, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, 128, 720, 480}, 1, false, true },
{ false, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, 128, 720, 480}, 1, true, true },
{ false, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, -1, 720, 480}, 1, false, true },
{ false, {4, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, -1, 720, 480}, 4, false, false },
{ false, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, -1, 720, 480}, 1, true, false },
{ true, {1, 4, 32, 720 * 480}, {0, 2, 1, 3}, {1, 128, 720, 480}, 1, false, false },
{ false, {1, 4, 32, 720 * 480}, {0, 2, 1, 3}, {1, 128, 720, 480}, 1, false, true },
{ false, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, 128, 720, 480}, 1, true, true },
{ false, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, 1, false, true },
{ false, {4, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, 4, false, false },
{ false, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, 1, true, false },
// dynamic shape
{ ngraph::PartialShape::dynamic(), {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, 128, 720, 480}, false, false },
{ ngraph::PartialShape::dynamic(), {1, 4, 32, 720 * 480}, {0, 2, 1, 3}, {1, 128, 720, 480}, false, false },
// 4D, batch_size = 1, 4D reshape constant
{ {1, 128, 720, 480}, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, 128, 720, 480}, false, true },
{ {1, 128, 720, 480}, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, 128, 720, 480}, true, true },
{ {1, 128, 720, 480}, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, -1, 720, 480}, false, true },
{ {1, 128, 720, 480}, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, -1, 720, 480}, true, false },
// 4D, batch_size = 1, 3D reshape constant
{ {1, 128, 720, 480}, {1, 4, 32, 720 * 480}, {0, 2, 1, 3}, {1, 128, 720, 480}, false, true },
{ {1, 128, 720, 480}, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, 128, 720, 480}, true, true },
{ {1, 128, 720, 480}, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, false, true },
{ {1, 128, 720, 480}, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, true, false },
// 4D, batch_size = 4
{ {4, 128, 720, 480}, {4, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, -1, 720, 480}, false, false },
{ {4, 128, 720, 480}, {4, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, false, false },
// 2D
{ {128, 720 * 480}, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, false, false },
};
INSTANTIATE_TEST_CASE_P(
TransformationTests,
ShuffleChannelsFusion,