diff --git a/src/common/transformations/src/transformations/common_optimizations/shuffle_channels_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/shuffle_channels_fusion.cpp index 8e1f95c2616..9a3cafb4ee8 100644 --- a/src/common/transformations/src/transformations/common_optimizations/shuffle_channels_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/shuffle_channels_fusion.cpp @@ -98,26 +98,21 @@ ov::pass::ShuffleChannelsFusion::ShuffleChannelsFusion(const bool reshape_consta const auto& pattern_map = m.get_pattern_value_map(); auto data = pattern_map.at(input); + auto reshape_before_constant = std::dynamic_pointer_cast( + pattern_map.at(reshape_before_const_pattern).get_node_shared_ptr()); auto reshape_before = std::dynamic_pointer_cast(pattern_map.at(reshape_before_pattern).get_node_shared_ptr()); auto transpose = std::dynamic_pointer_cast(pattern_map.at(transpose_pattern).get_node_shared_ptr()); auto reshape_after = std::dynamic_pointer_cast(pattern_map.at(reshape_after_pattern).get_node_shared_ptr()); - if (!reshape_after || !transpose || !reshape_after) { + auto reshape_after_constant = std::dynamic_pointer_cast( + pattern_map.at(reshape_after_const_pattern).get_node_shared_ptr()); + if (!reshape_after || !transpose || !reshape_after || !reshape_before_constant || !reshape_after_constant) { return false; } if (reshape_constants_check) { - auto reshape_before_constant = std::dynamic_pointer_cast( - pattern_map.at(reshape_before_const_pattern).get_node_shared_ptr()); - auto reshape_after_constant = std::dynamic_pointer_cast( - pattern_map.at(reshape_after_const_pattern).get_node_shared_ptr()); - - if (!reshape_before_constant || !reshape_after_constant) { - return false; - } - const auto& reshape_before_values = reshape_before_constant->cast_vector(); const auto& reshape_after_values = reshape_after_constant->cast_vector(); if (std::any_of(reshape_before_values.cbegin(), @@ -148,7 +143,13 @@ ov::pass::ShuffleChannelsFusion::ShuffleChannelsFusion(const bool reshape_consta auto shuffle_shannels = std::make_shared(data, axis, group); shuffle_shannels->set_friendly_name(reshape_after->get_friendly_name()); - ngraph::copy_runtime_info({reshape_before, transpose, reshape_after}, shuffle_shannels); + ngraph::copy_runtime_info({reshape_before, + reshape_before_constant, + transpose, + transpose_constant, + reshape_after, + reshape_after_constant}, + shuffle_shannels); ngraph::replace_node(reshape_after, shuffle_shannels); return true; }; diff --git a/src/inference/tests/unit/query_model_test.cpp b/src/inference/tests/unit/query_model_test.cpp index fff953eec6c..803e4f72ea5 100644 --- a/src/inference/tests/unit/query_model_test.cpp +++ b/src/inference/tests/unit/query_model_test.cpp @@ -419,3 +419,45 @@ TEST_F(GetSupportedNodesTest, ShapeOfNonConstantNode) { }, {"input", "slope_compressed", "slope", "prelu"}); // keep dummy only since it has no unsupported consumers } + +TEST_F(GetSupportedNodesTest, ShuffleChannelFusion) { + { + ov::Shape input_shape = {1, 112, 56, 56}; + auto input = std::make_shared(ov::element::f32, input_shape); + input->set_friendly_name("input"); + + ov::Shape reshape_before_shape = {1, 4, 28, 56, 56}; + auto shape_reshape_before = ov::opset9::Constant::create(ov::element::i64, + ov::Shape{reshape_before_shape.size()}, + reshape_before_shape); + shape_reshape_before->set_friendly_name("shape_reshape_before"); + auto reshape_before = std::make_shared(input, shape_reshape_before, true); + reshape_before->set_friendly_name("reshape_before"); + + ov::Shape permute_order = {0, 2, 1, 3, 4}; + auto permutation = + ov::opset9::Constant::create(ov::element::i64, ov::Shape{permute_order.size()}, permute_order); + permutation->set_friendly_name("permutation"); + auto permute = std::make_shared(reshape_before, permutation); + permute->set_friendly_name("permute"); + + auto shape_reshape_after = + ov::opset9::Constant::create(ov::element::i64, ov::Shape{input_shape.size()}, input_shape); + shape_reshape_after->set_friendly_name("shape_reshape_after"); + auto reshape_after = std::make_shared(permute, shape_reshape_after, true); + reshape_after->set_friendly_name("reshape_after"); + + m_function = std::make_shared(ov::NodeVector{reshape_after}, ov::ParameterVector{input}); + } + Run( + [&](std::shared_ptr& model) { + ov::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(model); + }, + [&](const std::shared_ptr& op) { + return ov::op::util::is_parameter(op) || ov::op::util::is_output(op) || ov::op::util::is_constant(op); + }, + {}); // Nothing is supported due to unsupported ShuffleChannels +}