[Transformations] ShuffleChannelsFusion fix and tests added (#5448)
This commit is contained in:
parent
64a032fa18
commit
449f3376e1
@ -55,7 +55,11 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ShuffleChannelsFusion, "ShuffleChannelsFusi
|
||||
|
||||
ngraph::pass::ShuffleChannelsFusion::ShuffleChannelsFusion(const bool reshape_constants_check) {
|
||||
MATCHER_SCOPE(ShuffleChannelsFusion);
|
||||
auto input = ngraph::pattern::any_input(pattern::has_static_shape());
|
||||
auto has_static_4d_shape = [](const Output<Node>& output) {
|
||||
return pattern::has_static_shape()(output) && pattern::rank_equals(4)(output);
|
||||
};
|
||||
|
||||
auto input = ngraph::pattern::any_input(has_static_4d_shape);
|
||||
auto reshape_before_const_pattern = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
|
||||
auto transpose_const_pattern = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
|
||||
auto reshape_after_const_pattern = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
|
||||
|
@ -22,11 +22,10 @@ using namespace ngraph;
|
||||
|
||||
class ShuffleChannelsFusionTestValues {
|
||||
public:
|
||||
bool dynamicShape;
|
||||
ngraph::PartialShape inputPartialShape;
|
||||
std::vector<int64_t> reshape_before_val;
|
||||
std::vector<size_t> transpose_val;
|
||||
std::vector<int64_t> reshape_after_val;
|
||||
size_t batch_size;
|
||||
bool check_reshape_values;
|
||||
bool fuse_happened;
|
||||
};
|
||||
@ -49,8 +48,7 @@ public:
|
||||
void SetUp() override {
|
||||
const auto values = GetParam();
|
||||
{
|
||||
const PartialShape inputPartialShape = values.dynamicShape ? PartialShape::dynamic() : Shape{ values.batch_size, 128, 720, 480 };
|
||||
auto input0 = std::make_shared<opset6::Parameter>(element::f32, inputPartialShape);
|
||||
auto input0 = std::make_shared<opset6::Parameter>(element::f32, values.inputPartialShape);
|
||||
auto shape_reshape_before = opset6::Constant::create(element::i64, Shape{ values.reshape_before_val.size() }, values.reshape_before_val);
|
||||
auto permutation = opset6::Constant::create(element::i64, Shape{ values.transpose_val.size() }, values.transpose_val);
|
||||
auto shape_reshape_after = opset6::Constant::create(element::i64, Shape{ values.reshape_after_val.size() }, values.reshape_after_val);
|
||||
@ -69,7 +67,7 @@ public:
|
||||
}
|
||||
|
||||
if (values.fuse_happened) {
|
||||
auto input0 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ values.batch_size, 128, 720, 480 });
|
||||
auto input0 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, values.inputPartialShape);
|
||||
auto shuffle_channels = std::make_shared<ngraph::opset6::ShuffleChannels>(input0, 1, values.reshape_before_val[1]);
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ shuffle_channels }, ngraph::ParameterVector{ input0 });
|
||||
} else {
|
||||
@ -81,10 +79,10 @@ public:
|
||||
const ShuffleChannelsFusionTestValues testValues = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
if (testValues.dynamicShape) {
|
||||
if (testValues.inputPartialShape.is_dynamic()) {
|
||||
result << "_dynamic_shape_";
|
||||
} else {
|
||||
result << "_batch_size_" << testValues.batch_size;
|
||||
result << "_input_shape_" << testValues.inputPartialShape;
|
||||
}
|
||||
|
||||
result << "_before_" << testValues.reshape_before_val
|
||||
@ -105,19 +103,30 @@ TEST_P(ShuffleChannelsFusion, CompareFunctions) {
|
||||
}
|
||||
|
||||
const std::vector<ShuffleChannelsFusionTestValues> testValues = {
|
||||
{ true, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, 128, 720, 480}, 1, false, false },
|
||||
{ false, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, 128, 720, 480}, 1, false, true },
|
||||
{ false, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, 128, 720, 480}, 1, true, true },
|
||||
{ false, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, -1, 720, 480}, 1, false, true },
|
||||
{ false, {4, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, -1, 720, 480}, 4, false, false },
|
||||
{ false, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, -1, 720, 480}, 1, true, false },
|
||||
{ true, {1, 4, 32, 720 * 480}, {0, 2, 1, 3}, {1, 128, 720, 480}, 1, false, false },
|
||||
{ false, {1, 4, 32, 720 * 480}, {0, 2, 1, 3}, {1, 128, 720, 480}, 1, false, true },
|
||||
{ false, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, 128, 720, 480}, 1, true, true },
|
||||
{ false, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, 1, false, true },
|
||||
{ false, {4, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, 4, false, false },
|
||||
{ false, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, 1, true, false },
|
||||
// dynamic shape
|
||||
{ ngraph::PartialShape::dynamic(), {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, 128, 720, 480}, false, false },
|
||||
{ ngraph::PartialShape::dynamic(), {1, 4, 32, 720 * 480}, {0, 2, 1, 3}, {1, 128, 720, 480}, false, false },
|
||||
|
||||
// 4D, batch_size = 1, 4D reshape constant
|
||||
{ {1, 128, 720, 480}, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, 128, 720, 480}, false, true },
|
||||
{ {1, 128, 720, 480}, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, 128, 720, 480}, true, true },
|
||||
{ {1, 128, 720, 480}, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, -1, 720, 480}, false, true },
|
||||
{ {1, 128, 720, 480}, {1, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, -1, 720, 480}, true, false },
|
||||
|
||||
// 4D, batch_size = 1, 3D reshape constant
|
||||
{ {1, 128, 720, 480}, {1, 4, 32, 720 * 480}, {0, 2, 1, 3}, {1, 128, 720, 480}, false, true },
|
||||
{ {1, 128, 720, 480}, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, 128, 720, 480}, true, true },
|
||||
{ {1, 128, 720, 480}, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, false, true },
|
||||
{ {1, 128, 720, 480}, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, true, false },
|
||||
|
||||
// 4D, batch_size = 4
|
||||
{ {4, 128, 720, 480}, {4, 2, 64, 720, 480}, {0, 2, 1, 3, 4}, {1, -1, 720, 480}, false, false },
|
||||
{ {4, 128, 720, 480}, {4, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, false, false },
|
||||
|
||||
// 2D
|
||||
{ {128, 720 * 480}, {1, 2, 64, 720 * 480}, {0, 2, 1, 3}, {1, -1, 720, 480}, false, false },
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
TransformationTests,
|
||||
ShuffleChannelsFusion,
|
||||
|
Loading…
Reference in New Issue
Block a user