[nGraph]: fix fused names for ShuffleChannelsFusion transformation (#15150)
* [nGraph]: fix fused names for ShuffleChannelsFusion transformation * Review comments
This commit is contained in:
parent
4103a931c2
commit
049cfcb72c
@ -98,26 +98,21 @@ ov::pass::ShuffleChannelsFusion::ShuffleChannelsFusion(const bool reshape_consta
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
|
||||
auto data = pattern_map.at(input);
|
||||
auto reshape_before_constant = std::dynamic_pointer_cast<opset6::Constant>(
|
||||
pattern_map.at(reshape_before_const_pattern).get_node_shared_ptr());
|
||||
auto reshape_before =
|
||||
std::dynamic_pointer_cast<opset6::Reshape>(pattern_map.at(reshape_before_pattern).get_node_shared_ptr());
|
||||
auto transpose =
|
||||
std::dynamic_pointer_cast<opset6::Transpose>(pattern_map.at(transpose_pattern).get_node_shared_ptr());
|
||||
auto reshape_after =
|
||||
std::dynamic_pointer_cast<opset6::Reshape>(pattern_map.at(reshape_after_pattern).get_node_shared_ptr());
|
||||
if (!reshape_after || !transpose || !reshape_after) {
|
||||
auto reshape_after_constant = std::dynamic_pointer_cast<opset6::Constant>(
|
||||
pattern_map.at(reshape_after_const_pattern).get_node_shared_ptr());
|
||||
if (!reshape_after || !transpose || !reshape_after || !reshape_before_constant || !reshape_after_constant) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (reshape_constants_check) {
|
||||
auto reshape_before_constant = std::dynamic_pointer_cast<opset6::Constant>(
|
||||
pattern_map.at(reshape_before_const_pattern).get_node_shared_ptr());
|
||||
auto reshape_after_constant = std::dynamic_pointer_cast<opset6::Constant>(
|
||||
pattern_map.at(reshape_after_const_pattern).get_node_shared_ptr());
|
||||
|
||||
if (!reshape_before_constant || !reshape_after_constant) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& reshape_before_values = reshape_before_constant->cast_vector<int64_t>();
|
||||
const auto& reshape_after_values = reshape_after_constant->cast_vector<int64_t>();
|
||||
if (std::any_of(reshape_before_values.cbegin(),
|
||||
@ -148,7 +143,13 @@ ov::pass::ShuffleChannelsFusion::ShuffleChannelsFusion(const bool reshape_consta
|
||||
|
||||
auto shuffle_shannels = std::make_shared<opset6::ShuffleChannels>(data, axis, group);
|
||||
shuffle_shannels->set_friendly_name(reshape_after->get_friendly_name());
|
||||
ngraph::copy_runtime_info({reshape_before, transpose, reshape_after}, shuffle_shannels);
|
||||
ngraph::copy_runtime_info({reshape_before,
|
||||
reshape_before_constant,
|
||||
transpose,
|
||||
transpose_constant,
|
||||
reshape_after,
|
||||
reshape_after_constant},
|
||||
shuffle_shannels);
|
||||
ngraph::replace_node(reshape_after, shuffle_shannels);
|
||||
return true;
|
||||
};
|
||||
|
@ -419,3 +419,45 @@ TEST_F(GetSupportedNodesTest, ShapeOfNonConstantNode) {
|
||||
},
|
||||
{"input", "slope_compressed", "slope", "prelu"}); // keep dummy only since it has no unsupported consumers
|
||||
}
|
||||
|
||||
TEST_F(GetSupportedNodesTest, ShuffleChannelFusion) {
|
||||
{
|
||||
ov::Shape input_shape = {1, 112, 56, 56};
|
||||
auto input = std::make_shared<ov::opset9::Parameter>(ov::element::f32, input_shape);
|
||||
input->set_friendly_name("input");
|
||||
|
||||
ov::Shape reshape_before_shape = {1, 4, 28, 56, 56};
|
||||
auto shape_reshape_before = ov::opset9::Constant::create(ov::element::i64,
|
||||
ov::Shape{reshape_before_shape.size()},
|
||||
reshape_before_shape);
|
||||
shape_reshape_before->set_friendly_name("shape_reshape_before");
|
||||
auto reshape_before = std::make_shared<ov::opset9::Reshape>(input, shape_reshape_before, true);
|
||||
reshape_before->set_friendly_name("reshape_before");
|
||||
|
||||
ov::Shape permute_order = {0, 2, 1, 3, 4};
|
||||
auto permutation =
|
||||
ov::opset9::Constant::create(ov::element::i64, ov::Shape{permute_order.size()}, permute_order);
|
||||
permutation->set_friendly_name("permutation");
|
||||
auto permute = std::make_shared<ov::opset9::Transpose>(reshape_before, permutation);
|
||||
permute->set_friendly_name("permute");
|
||||
|
||||
auto shape_reshape_after =
|
||||
ov::opset9::Constant::create(ov::element::i64, ov::Shape{input_shape.size()}, input_shape);
|
||||
shape_reshape_after->set_friendly_name("shape_reshape_after");
|
||||
auto reshape_after = std::make_shared<ov::opset9::Reshape>(permute, shape_reshape_after, true);
|
||||
reshape_after->set_friendly_name("reshape_after");
|
||||
|
||||
m_function = std::make_shared<ov::Model>(ov::NodeVector{reshape_after}, ov::ParameterVector{input});
|
||||
}
|
||||
Run(
|
||||
[&](std::shared_ptr<ov::Model>& model) {
|
||||
ov::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ov::pass::CommonOptimizations>();
|
||||
m.run_passes(model);
|
||||
},
|
||||
[&](const std::shared_ptr<ov::Node>& op) {
|
||||
return ov::op::util::is_parameter(op) || ov::op::util::is_output(op) || ov::op::util::is_constant(op);
|
||||
},
|
||||
{}); // Nothing is supported due to unsupported ShuffleChannels
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user