diff --git a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp index 22db9348541..e56ed03e82a 100644 --- a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp +++ b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp @@ -332,7 +332,7 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr res->params = params; return res; }); - + addSpecificCreator({"Assign"}, [](const std::shared_ptr<::ngraph::Node>& node, const std::map params) -> CNNLayerPtr { LayerParams attrs = {node->get_friendly_name(), "Memory", @@ -355,6 +355,24 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr return res; }); + addSpecificCreator({"DepthToSpace"}, [](const std::shared_ptr<::ngraph::Node>& node, + const std::map params) -> CNNLayerPtr { + LayerParams attrs = {node->get_friendly_name(), node->description(), + details::convertPrecision(node->get_output_element_type(0))}; + auto res = std::make_shared(attrs); + res->params = params; + return res; + }); + + addSpecificCreator({"SpaceToDepth"}, [](const std::shared_ptr<::ngraph::Node>& node, + const std::map params) -> CNNLayerPtr { + LayerParams attrs = {node->get_friendly_name(), node->description(), + details::convertPrecision(node->get_output_element_type(0))}; + auto res = std::make_shared(attrs); + res->params = params; + return res; + }); + addSpecificCreator({"RNNCell"}, [](const std::shared_ptr<::ngraph::Node>& node, const std::map params) -> CNNLayerPtr { THROW_IE_EXCEPTION << "RNNCell operation has a form that is not supported." << node->get_friendly_name() diff --git a/inference-engine/src/transformations/include/transformations/common_optimizations/common_optimizations_tbl.hpp b/inference-engine/src/transformations/include/transformations/common_optimizations/common_optimizations_tbl.hpp index c0a53fcf5f7..f2da6d7049e 100644 --- a/inference-engine/src/transformations/include/transformations/common_optimizations/common_optimizations_tbl.hpp +++ b/inference-engine/src/transformations/include/transformations/common_optimizations/common_optimizations_tbl.hpp @@ -25,3 +25,4 @@ NGRAPH_PASS(NopElimination, ::ngraph::pass) // may introduce fake dynamism NGRAPH_PASS(AlgebraicSimplification, ::ngraph::pass) // may introduce fake dynamism NGRAPH_PASS(ConstantFolding, ::ngraph::pass) NGRAPH_PASS(ConvertScatterElementsToScatter, ::ngraph::pass) // partially depends on CF +NGRAPH_PASS(DepthToSpaceFusion, ::ngraph::pass) diff --git a/inference-engine/src/transformations/include/transformations/convert_depth_to_space.hpp b/inference-engine/src/transformations/include/transformations/convert_depth_to_space.hpp index 7ad3f6c26e3..8ac0cbabba9 100644 --- a/inference-engine/src/transformations/include/transformations/convert_depth_to_space.hpp +++ b/inference-engine/src/transformations/include/transformations/convert_depth_to_space.hpp @@ -10,6 +10,7 @@ #include #include +#include "transformations/utils/pass_param.hpp" namespace ngraph { namespace pass { @@ -19,9 +20,9 @@ class INFERENCE_ENGINE_API_CLASS(ConvertDepthToSpace); } // namespace pass } // namespace ngraph -class ngraph::pass::ConvertDepthToSpace: public ngraph::pass::GraphRewrite { +class ngraph::pass::ConvertDepthToSpace: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam { public: - ConvertDepthToSpace() : GraphRewrite() { + ConvertDepthToSpace() : GraphRewrite(), PassParam() { convert_depth_to_space(); } diff --git a/inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2_tbl.hpp b/inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2_tbl.hpp index c802dcb3f2a..271c0e96493 100644 --- a/inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2_tbl.hpp +++ b/inference-engine/src/transformations/include/transformations/convert_opset3_to_opset2/convert_opset3_to_opset2_tbl.hpp @@ -19,4 +19,3 @@ NGRAPH_PASS(ConvertNMS3, ::ngraph::pass) NGRAPH_PASS(ConvertShapeOf3, ::ngraph::pass) NGRAPH_PASS(ConvertShuffleChannels3, ::ngraph::pass) NGRAPH_PASS(ConvertTopK3, ::ngraph::pass) - diff --git a/inference-engine/src/transformations/include/transformations/convert_space_to_depth.hpp b/inference-engine/src/transformations/include/transformations/convert_space_to_depth.hpp index 75058b30360..f148c90f4a6 100644 --- a/inference-engine/src/transformations/include/transformations/convert_space_to_depth.hpp +++ b/inference-engine/src/transformations/include/transformations/convert_space_to_depth.hpp @@ -10,6 +10,7 @@ #include #include +#include "transformations/utils/pass_param.hpp" namespace ngraph { namespace pass { @@ -19,9 +20,9 @@ class INFERENCE_ENGINE_API_CLASS(ConvertSpaceToDepth); } // namespace pass } // namespace ngraph -class ngraph::pass::ConvertSpaceToDepth: public ngraph::pass::GraphRewrite { +class ngraph::pass::ConvertSpaceToDepth: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam { public: - ConvertSpaceToDepth() : GraphRewrite() { + ConvertSpaceToDepth() : GraphRewrite(), PassParam() { convert(); } diff --git a/inference-engine/src/transformations/include/transformations/depth_to_space_fusion.hpp b/inference-engine/src/transformations/include/transformations/depth_to_space_fusion.hpp new file mode 100644 index 00000000000..e9e0fc159c7 --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/depth_to_space_fusion.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include + +#include +#include "transformations/utils/pass_param.hpp" + +namespace ngraph { +namespace pass { + + class INFERENCE_ENGINE_API_CLASS(DepthToSpaceFusion); + +} // namespace pass +} // namespace ngraph + +/* + * Description: + * DepthToSpaceFusion transformation detects Reshape-Transpose-Reshape pattern and + * tries to fuse it into a single DepthToSpace layer. + * + * Usage: + * DepthToSpaceFusion transformation is optional and disabled by default. + * The transformation can be enabled with callback using setCallback method. + * See the example below. + * + * Callback example: + * + * // This callback enables DepthToSpaceFusion transformation + * auto callback = [](const std::shared_ptr & node) -> bool { + * return std::dynamic_pointer_cast(node) != nullptr; + * }; + * + * auto p = ngraph::pass::DepthToSpaceFusion(); + * p.setCallback(callback); + * p.run_on_function(f); + * + */ + +class ngraph::pass::DepthToSpaceFusion: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam { +public: + DepthToSpaceFusion() : GraphRewrite(), PassParam() { + depth_to_space_fusion(); + } + +private: + void depth_to_space_fusion(); +}; diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 467f0766d59..5b1a956d187 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -6,6 +6,7 @@ #include "transformations/common_optimizations/common_optimizations.hpp" #include "transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp" +#include "transformations/depth_to_space_fusion.hpp" #include "transformations/optimize_strided_slice.hpp" #include "transformations/convert_scatter_elements_to_scatter.hpp" #include "transformations/remove_filtering_boxes_by_size.hpp" diff --git a/inference-engine/src/transformations/src/transformations/convert_depth_to_space.cpp b/inference-engine/src/transformations/src/transformations/convert_depth_to_space.cpp index aea1555256b..959b4591f78 100644 --- a/inference-engine/src/transformations/src/transformations/convert_depth_to_space.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_depth_to_space.cpp @@ -14,9 +14,9 @@ void ngraph::pass::ConvertDepthToSpace::convert_depth_to_space() { auto input0 = std::make_shared(element::f32, Shape{1, 1, 1, 1}); auto dts_node = std::make_shared(input0, ngraph::op::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST); - ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) { + ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { auto dts_node = std::dynamic_pointer_cast (m.get_match_root()); - if (!dts_node) { + if (!dts_node || transformation_callback(dts_node)) { return false; } diff --git a/inference-engine/src/transformations/src/transformations/convert_space_to_depth.cpp b/inference-engine/src/transformations/src/transformations/convert_space_to_depth.cpp index f4ff7ec97f2..5eb16557ce9 100644 --- a/inference-engine/src/transformations/src/transformations/convert_space_to_depth.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_space_to_depth.cpp @@ -14,9 +14,9 @@ void ngraph::pass::ConvertSpaceToDepth::convert() { auto input0 = std::make_shared(element::f32, Shape{1, 1, 1, 1}); auto dts = std::make_shared(input0, ngraph::opset1::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST); - ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) { + ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { auto std_node = std::dynamic_pointer_cast (m.get_match_root()); - if (!std_node) { + if (!std_node || transformation_callback(std_node)) { return false; } diff --git a/inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp b/inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp new file mode 100644 index 00000000000..a2323b16202 --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/depth_to_space_fusion.cpp @@ -0,0 +1,166 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/depth_to_space_fusion.hpp" + +#include +#include + +#include +#include + +bool check_block_first(const ngraph::Shape& shape_input, const ngraph::Shape& shape_reshape_before, + const ngraph::AxisVector& permutation, const ngraph::Shape& shape_reshape_after, + size_t& possible_block_size) { + bool is_transformation_valid = true; + uint64_t spatial_dims = shape_input.size() - 2; + possible_block_size = shape_reshape_before[1]; + if (possible_block_size == 0) + return false; + uint64_t c_dim = shape_input[1] / std::pow(possible_block_size, spatial_dims); + + // x' = reshape(data, [N, block_size, block_size, ..., block_size, C / (block_size ^ K), D1, D2, ..., DK]) + ngraph::Shape expected_shape = {shape_input[0]}; + for (uint64_t i = 0; i < spatial_dims; ++i) + expected_shape.push_back(possible_block_size); + expected_shape.push_back(c_dim); + for (uint64_t i = 2; i < shape_input.size(); ++i) + expected_shape.push_back(shape_input[i]); + is_transformation_valid &= (expected_shape == shape_reshape_before); + + // x'' = transpose(x', [0, K + 1, K + 2, 1, K + 3, 2, K + 4, 3, ..., K + (K + 1), K]) + ngraph::AxisVector expected_permutation = {0, spatial_dims + 1}; + for (uint64_t i = 2; i < shape_input.size(); ++i) { + expected_permutation.push_back(spatial_dims + i); + expected_permutation.push_back(i - 1); + } + is_transformation_valid &= (expected_permutation == permutation); + + // y = reshape(x'', [N, C / (block_size ^ K), D1 * block_size, D2 * block_size, D3 * block_size, ..., DK * block_size]) + expected_shape = {shape_input[0], c_dim}; + for (uint64_t i = 2; i < shape_input.size(); ++i) + expected_shape.push_back(shape_input[i] * possible_block_size); + is_transformation_valid &= (expected_shape == shape_reshape_after); + + return is_transformation_valid; +} + +bool check_depth_first(const ngraph::Shape& shape_input, const ngraph::Shape& shape_reshape_before, + const ngraph::AxisVector& permutation, const ngraph::Shape& shape_reshape_after, + size_t& possible_block_size) { + bool is_transformation_valid = true; + uint64_t spatial_dims = shape_input.size() - 2; + possible_block_size = shape_reshape_before[2]; + if (possible_block_size == 0) + return false; + uint64_t c_dim = shape_input[1] / std::pow(possible_block_size, spatial_dims); + + // x' = reshape(data, [N, C / (block_size ^ K), block_size, block_size, ..., block_size, D1, D2, ..., DK]) + ngraph::Shape expected_shape = {shape_input[0], c_dim}; + for (uint64_t i = 0; i < spatial_dims; ++i) + expected_shape.push_back(possible_block_size); + for (uint64_t i = 2; i < shape_input.size(); ++i) + expected_shape.push_back(shape_input[i]); + is_transformation_valid &= (expected_shape == shape_reshape_before); + + // x'' = transpose(x', [0, 1, K + 2, 2, K + 3, 3, K + 4, 4, ..., K + (K + 1), K + 1]) + ngraph::AxisVector expected_permutation = {0, 1}; + for (uint64_t i = 2; i < shape_input.size(); ++i) { + expected_permutation.push_back(spatial_dims + i); + expected_permutation.push_back(i); + } + is_transformation_valid &= (expected_permutation == permutation); + + // y = reshape(x'', [N, C / (block_size ^ K), D1 * block_size, D2 * block_size, D3 * block_size, ..., DK * block_size]) + expected_shape = {shape_input[0], c_dim}; + for (uint64_t i = 2; i < shape_input.size(); ++i) + expected_shape.push_back(shape_input[i] * possible_block_size); + is_transformation_valid &= (expected_shape == shape_reshape_after); + + return is_transformation_valid; +} + +void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() { + auto input0 = std::make_shared(element::f32, Shape{1, 1, 1, 1}); + auto input1 = std::make_shared(element::i64, Shape{4}); + auto input2 = std::make_shared(element::i64, Shape{4}); + auto input3 = std::make_shared(element::i64, Shape{4}); + auto reshape_before = std::make_shared (input0, input1, false); + auto permute = std::make_shared (reshape_before, input2); + auto reshape_after = std::make_shared (permute, input3, false); + + ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { + if (!transformation_callback(std::make_shared())) { + return false; + } + + auto reshape_after = std::dynamic_pointer_cast(m.get_match_root()); + if (!reshape_after) { + return false; + } + + auto permute = std::dynamic_pointer_cast(reshape_after->input_value(0).get_node_shared_ptr()); + if (!permute || permute->get_output_target_inputs(0).size() != 1) { + return false; + } + + auto reshape_before = std::dynamic_pointer_cast(permute->input_value(0).get_node_shared_ptr()); + if (!reshape_before || reshape_before->get_output_target_inputs(0).size() != 1) { + return false; + } + + auto p_shape_input = reshape_before->get_input_partial_shape(0); + auto p_shape_reshape_before = reshape_before->get_output_partial_shape(0); + auto p_shape_permute = permute->get_output_partial_shape(0); + auto p_shape_reshape_after = reshape_after->get_output_partial_shape(0); + + if (p_shape_input.is_dynamic() || p_shape_reshape_before.is_dynamic() || + p_shape_permute.is_dynamic() || p_shape_reshape_after.is_dynamic()) { + return false; + } + + auto shape_input = p_shape_input.get_shape(); + auto shape_reshape_before = p_shape_reshape_before.get_shape(); + auto shape_permute = p_shape_permute.get_shape(); + auto shape_reshape_after = p_shape_reshape_after.get_shape(); + + if (shape_input.size() < 3) { + return false; + } + + // input shape: [ batch, C, spatial_dims], expected_shape = spatial_dims.size() * 2 + 2 + size_t expected_shape_size = (shape_input.size() - 2) * 2 + 2; + if (shape_input.size() != shape_reshape_after.size() || shape_reshape_before.size() != expected_shape_size || + shape_permute.size() != expected_shape_size) { + return false; + } + + ngraph::AxisVector permutation; + if (auto input_const = std::dynamic_pointer_cast(permute->input_value(1).get_node_shared_ptr())) { + permutation = input_const->get_axis_vector_val(); + } else { + return false; + } + + ngraph::opset3::DepthToSpace::DepthToSpaceMode mode; + size_t block_size; + if (check_depth_first(shape_input, shape_reshape_before, permutation, shape_reshape_after, block_size)) { + mode = ngraph::opset3::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST; + } else if (check_block_first(shape_input, shape_reshape_before, permutation, shape_reshape_after, block_size)) { + mode = ngraph::opset3::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST; + } else { + return false; + } + + auto depth_to_space = + std::make_shared(reshape_before->input_value(0), mode, block_size); + depth_to_space->set_friendly_name(reshape_after->get_friendly_name()); + ngraph::copy_runtime_info({reshape_before, permute, reshape_after}, depth_to_space); + ngraph::replace_node(reshape_after, depth_to_space); + return true; + }; + + auto m = std::make_shared(reshape_after, "DepthToSpaceFusion"); + this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE); +} \ No newline at end of file diff --git a/inference-engine/tests/functional/inference_engine/transformations/depth_to_space_fusion_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/depth_to_space_fusion_test.cpp new file mode 100644 index 00000000000..55a91872470 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/depth_to_space_fusion_test.cpp @@ -0,0 +1,184 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "common_test_utils/test_common.hpp" +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "ngraph_test_utils.hpp" + +using namespace testing; + +TEST(TransformationTests, DepthToSpaceFusionDepthFirst) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input0 = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480}); + auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 32, 2, 2, 720, 480}); + auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 1, 4, 2, 5, 3}); + auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960}); + + auto reshape_before = std::make_shared (input0, shape_reshape_before, false); + auto permute = std::make_shared (reshape_before, permutation); + auto reshape_after = std::make_shared (permute, shape_reshape_after, false); + + f = std::make_shared(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0}); + ngraph::pass::InitNodeInfo().run_on_function(f); + auto callback = [](const std::shared_ptr & node) -> bool { + return std::dynamic_pointer_cast(node) != nullptr; + }; + + auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion(); + depth_to_space_transform.setCallback(callback); + depth_to_space_transform.run_on_function(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input0 = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480}); + auto depth_to_space = std::make_shared(input0, ngraph::opset3::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST, 2); + f_ref = std::make_shared(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input0}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, DepthToSpaceFusionBlockFirst) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input0 = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480}); + auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 2, 2, 32, 720, 480}); + auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2}); + auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960}); + + auto reshape_before = std::make_shared (input0, shape_reshape_before, false); + auto permute = std::make_shared (reshape_before, permutation); + auto reshape_after = std::make_shared (permute, shape_reshape_after, false); + + f = std::make_shared(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0}); + ngraph::pass::InitNodeInfo().run_on_function(f); + auto callback = [](const std::shared_ptr & node) -> bool { + return std::dynamic_pointer_cast(node) != nullptr; + }; + + auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion(); + depth_to_space_transform.setCallback(callback); + depth_to_space_transform.run_on_function(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input0 = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480}); + auto depth_to_space = std::make_shared(input0, ngraph::opset3::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, 2); + f_ref = std::make_shared(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input0}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, DepthToSpaceFusionDynamicShape) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input0 = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480}); + auto shape_reshape_before = std::make_shared(ngraph::element::i64, ngraph::Shape{6}); + auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2}); + auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960}); + + auto reshape_before = std::make_shared (input0, shape_reshape_before, false); + auto permute = std::make_shared (reshape_before, permutation); + auto reshape_after = std::make_shared (permute, shape_reshape_after, false); + + f = std::make_shared(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0, shape_reshape_before}); + ngraph::pass::InitNodeInfo().run_on_function(f); + auto callback = [](const std::shared_ptr & node) -> bool { + return std::dynamic_pointer_cast(node) != nullptr; + }; + + // transformation won't be applied because of shape_reshape_before is dynamic, the graph will remain the same + auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion(); + depth_to_space_transform.setCallback(callback); + depth_to_space_transform.run_on_function(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input0 = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480}); + auto shape_reshape_before = std::make_shared(ngraph::element::i64, ngraph::Shape{6}); + auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2}); + auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960}); + + auto reshape_before = std::make_shared (input0, shape_reshape_before, false); + auto permute = std::make_shared (reshape_before, permutation); + auto reshape_after = std::make_shared (permute, shape_reshape_after, false); + + f_ref = std::make_shared(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0, shape_reshape_before}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, DepthToSpaceFusionSeveralConsumers) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto input0 = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480}); + auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 2, 2, 32, 720, 480}); + auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2}); + auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960}); + + auto reshape_before = std::make_shared (input0, shape_reshape_before, false); + auto permute = std::make_shared (reshape_before, permutation); + auto reshape_after = std::make_shared (permute, shape_reshape_after, false); + + // additional consumers, not output of the function + auto result = std::make_shared (reshape_before); + auto result_2 = std::make_shared (permute); + f = std::make_shared(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0}); + ngraph::pass::InitNodeInfo().run_on_function(f); + auto callback = [](const std::shared_ptr & node) -> bool { + return std::dynamic_pointer_cast(node) != nullptr; + }; + + // transformation won't be applied because of reshape_before has several consumers, the graph will remain the same + auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion(); + depth_to_space_transform.setCallback(callback); + depth_to_space_transform.run_on_function(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto input0 = std::make_shared(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480}); + auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 2, 2, 32, 720, 480}); + auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2}); + auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960}); + + auto reshape_before = std::make_shared (input0, shape_reshape_before, false); + auto permute = std::make_shared (reshape_before, permutation); + auto reshape_after = std::make_shared (permute, shape_reshape_after, false); + + // additional consumers, not output of the function + auto result = std::make_shared (reshape_before); + auto result_2 = std::make_shared (permute); + + f_ref = std::make_shared(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +}