[nGraph]: fix fused names for ShuffleChannelsFusion transformation (#15150)

* [nGraph]: fix fused names for ShuffleChannelsFusion transformation

* Review comments
This commit is contained in:
Nadezhda Ageeva 2023-02-03 16:45:12 +04:00 committed by GitHub
parent 4103a931c2
commit 049cfcb72c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 11 deletions

View File

@ -98,26 +98,21 @@ ov::pass::ShuffleChannelsFusion::ShuffleChannelsFusion(const bool reshape_consta
const auto& pattern_map = m.get_pattern_value_map();
auto data = pattern_map.at(input);
auto reshape_before_constant = std::dynamic_pointer_cast<opset6::Constant>(
pattern_map.at(reshape_before_const_pattern).get_node_shared_ptr());
auto reshape_before =
std::dynamic_pointer_cast<opset6::Reshape>(pattern_map.at(reshape_before_pattern).get_node_shared_ptr());
auto transpose =
std::dynamic_pointer_cast<opset6::Transpose>(pattern_map.at(transpose_pattern).get_node_shared_ptr());
auto reshape_after =
std::dynamic_pointer_cast<opset6::Reshape>(pattern_map.at(reshape_after_pattern).get_node_shared_ptr());
if (!reshape_after || !transpose || !reshape_after) {
auto reshape_after_constant = std::dynamic_pointer_cast<opset6::Constant>(
pattern_map.at(reshape_after_const_pattern).get_node_shared_ptr());
if (!reshape_after || !transpose || !reshape_after || !reshape_before_constant || !reshape_after_constant) {
return false;
}
if (reshape_constants_check) {
auto reshape_before_constant = std::dynamic_pointer_cast<opset6::Constant>(
pattern_map.at(reshape_before_const_pattern).get_node_shared_ptr());
auto reshape_after_constant = std::dynamic_pointer_cast<opset6::Constant>(
pattern_map.at(reshape_after_const_pattern).get_node_shared_ptr());
if (!reshape_before_constant || !reshape_after_constant) {
return false;
}
const auto& reshape_before_values = reshape_before_constant->cast_vector<int64_t>();
const auto& reshape_after_values = reshape_after_constant->cast_vector<int64_t>();
if (std::any_of(reshape_before_values.cbegin(),
@ -148,7 +143,13 @@ ov::pass::ShuffleChannelsFusion::ShuffleChannelsFusion(const bool reshape_consta
auto shuffle_shannels = std::make_shared<opset6::ShuffleChannels>(data, axis, group);
shuffle_shannels->set_friendly_name(reshape_after->get_friendly_name());
ngraph::copy_runtime_info({reshape_before, transpose, reshape_after}, shuffle_shannels);
ngraph::copy_runtime_info({reshape_before,
reshape_before_constant,
transpose,
transpose_constant,
reshape_after,
reshape_after_constant},
shuffle_shannels);
ngraph::replace_node(reshape_after, shuffle_shannels);
return true;
};

View File

@ -419,3 +419,45 @@ TEST_F(GetSupportedNodesTest, ShapeOfNonConstantNode) {
},
{"input", "slope_compressed", "slope", "prelu"}); // keep dummy only since it has no unsupported consumers
}
TEST_F(GetSupportedNodesTest, ShuffleChannelFusion) {
{
ov::Shape input_shape = {1, 112, 56, 56};
auto input = std::make_shared<ov::opset9::Parameter>(ov::element::f32, input_shape);
input->set_friendly_name("input");
ov::Shape reshape_before_shape = {1, 4, 28, 56, 56};
auto shape_reshape_before = ov::opset9::Constant::create(ov::element::i64,
ov::Shape{reshape_before_shape.size()},
reshape_before_shape);
shape_reshape_before->set_friendly_name("shape_reshape_before");
auto reshape_before = std::make_shared<ov::opset9::Reshape>(input, shape_reshape_before, true);
reshape_before->set_friendly_name("reshape_before");
ov::Shape permute_order = {0, 2, 1, 3, 4};
auto permutation =
ov::opset9::Constant::create(ov::element::i64, ov::Shape{permute_order.size()}, permute_order);
permutation->set_friendly_name("permutation");
auto permute = std::make_shared<ov::opset9::Transpose>(reshape_before, permutation);
permute->set_friendly_name("permute");
auto shape_reshape_after =
ov::opset9::Constant::create(ov::element::i64, ov::Shape{input_shape.size()}, input_shape);
shape_reshape_after->set_friendly_name("shape_reshape_after");
auto reshape_after = std::make_shared<ov::opset9::Reshape>(permute, shape_reshape_after, true);
reshape_after->set_friendly_name("reshape_after");
m_function = std::make_shared<ov::Model>(ov::NodeVector{reshape_after}, ov::ParameterVector{input});
}
Run(
[&](std::shared_ptr<ov::Model>& model) {
ov::pass::Manager m;
m.register_pass<ov::pass::InitNodeInfo>();
m.register_pass<ov::pass::CommonOptimizations>();
m.run_passes(model);
},
[&](const std::shared_ptr<ov::Node>& op) {
return ov::op::util::is_parameter(op) || ov::op::util::is_output(op) || ov::op::util::is_constant(op);
},
{}); // Nothing is supported due to unsupported ShuffleChannels
}