Reshape-Permute-Reshape pattern to DepthToSpace layer transformation (#601)

* implemented depth_to_space transformation

* renaming

* added functional tests, fixed mistakes in implementation of the transformation

* disable ConvertSpaceToDepth/ConvertDepthToSpace transformation for CPU plugin, enable DepthToSpaceFusion for CPU plugin only, add specific creators

* fix wrong include

* fix for functional tests: set transformation callback

* revert callback calls for CPU plugin

* move functions to .cpp file

* Apply review comments

* Apply additional review comments

* fix cast to bool type
This commit is contained in:
Ivan Tikhonov 2020-06-01 09:24:16 +03:00 committed by GitHub
parent b4893945c7
commit cd01ccd449
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 435 additions and 10 deletions

View File

@ -355,6 +355,24 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
return res;
});
addSpecificCreator({"DepthToSpace"}, [](const std::shared_ptr<::ngraph::Node>& node,
const std::map<std::string, std::string> params) -> CNNLayerPtr {
LayerParams attrs = {node->get_friendly_name(), node->description(),
details::convertPrecision(node->get_output_element_type(0))};
auto res = std::make_shared<DepthToSpaceLayer>(attrs);
res->params = params;
return res;
});
addSpecificCreator({"SpaceToDepth"}, [](const std::shared_ptr<::ngraph::Node>& node,
const std::map<std::string, std::string> params) -> CNNLayerPtr {
LayerParams attrs = {node->get_friendly_name(), node->description(),
details::convertPrecision(node->get_output_element_type(0))};
auto res = std::make_shared<SpaceToDepthLayer>(attrs);
res->params = params;
return res;
});
addSpecificCreator({"RNNCell"}, [](const std::shared_ptr<::ngraph::Node>& node,
const std::map<std::string, std::string> params) -> CNNLayerPtr {
THROW_IE_EXCEPTION << "RNNCell operation has a form that is not supported." << node->get_friendly_name()

View File

@ -25,3 +25,4 @@ NGRAPH_PASS(NopElimination, ::ngraph::pass) // may introduce fake dynamism
NGRAPH_PASS(AlgebraicSimplification, ::ngraph::pass) // may introduce fake dynamism
NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
NGRAPH_PASS(ConvertScatterElementsToScatter, ::ngraph::pass) // partially depends on CF
NGRAPH_PASS(DepthToSpaceFusion, ::ngraph::pass)

View File

@ -10,6 +10,7 @@
#include <ie_api.h>
#include <ngraph/pass/graph_rewrite.hpp>
#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
@ -19,9 +20,9 @@ class INFERENCE_ENGINE_API_CLASS(ConvertDepthToSpace);
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvertDepthToSpace: public ngraph::pass::GraphRewrite {
class ngraph::pass::ConvertDepthToSpace: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
public:
ConvertDepthToSpace() : GraphRewrite() {
ConvertDepthToSpace() : GraphRewrite(), PassParam() {
convert_depth_to_space();
}

View File

@ -19,4 +19,3 @@ NGRAPH_PASS(ConvertNMS3, ::ngraph::pass)
NGRAPH_PASS(ConvertShapeOf3, ::ngraph::pass)
NGRAPH_PASS(ConvertShuffleChannels3, ::ngraph::pass)
NGRAPH_PASS(ConvertTopK3, ::ngraph::pass)

View File

@ -10,6 +10,7 @@
#include <ie_api.h>
#include <ngraph/pass/graph_rewrite.hpp>
#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
@ -19,9 +20,9 @@ class INFERENCE_ENGINE_API_CLASS(ConvertSpaceToDepth);
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvertSpaceToDepth: public ngraph::pass::GraphRewrite {
class ngraph::pass::ConvertSpaceToDepth: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
public:
ConvertSpaceToDepth() : GraphRewrite() {
ConvertSpaceToDepth() : GraphRewrite(), PassParam() {
convert();
}

View File

@ -0,0 +1,54 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <ie_api.h>
#include <ngraph/pass/graph_rewrite.hpp>
#include "transformations/utils/pass_param.hpp"
namespace ngraph {
namespace pass {
class INFERENCE_ENGINE_API_CLASS(DepthToSpaceFusion);
} // namespace pass
} // namespace ngraph
/*
* Description:
* DepthToSpaceFusion transformation detects Reshape-Transpose-Reshape pattern and
* tries to fuse it into a single DepthToSpace layer.
*
* Usage:
* DepthToSpaceFusion transformation is optional and disabled by default.
* The transformation can be enabled with callback using setCallback method.
* See the example below.
*
* Callback example:
*
* // This callback enables DepthToSpaceFusion transformation
* auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
* return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
* };
*
* auto p = ngraph::pass::DepthToSpaceFusion();
* p.setCallback(callback);
* p.run_on_function(f);
*
*/
class ngraph::pass::DepthToSpaceFusion: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
public:
DepthToSpaceFusion() : GraphRewrite(), PassParam() {
depth_to_space_fusion();
}
private:
void depth_to_space_fusion();
};

View File

@ -6,6 +6,7 @@
#include "transformations/common_optimizations/common_optimizations.hpp"
#include "transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp"
#include "transformations/depth_to_space_fusion.hpp"
#include "transformations/optimize_strided_slice.hpp"
#include "transformations/convert_scatter_elements_to_scatter.hpp"
#include "transformations/remove_filtering_boxes_by_size.hpp"

View File

@ -14,9 +14,9 @@ void ngraph::pass::ConvertDepthToSpace::convert_depth_to_space() {
auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto dts_node = std::make_shared<ngraph::opset1::DepthToSpace>(input0, ngraph::op::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST);
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
if (!dts_node) {
if (!dts_node || transformation_callback(dts_node)) {
return false;
}

View File

@ -14,9 +14,9 @@ void ngraph::pass::ConvertSpaceToDepth::convert() {
auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto dts = std::make_shared<ngraph::opset1::SpaceToDepth>(input0, ngraph::opset1::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST);
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto std_node = std::dynamic_pointer_cast<ngraph::opset1::SpaceToDepth> (m.get_match_root());
if (!std_node) {
if (!std_node || transformation_callback(std_node)) {
return false;
}

View File

@ -0,0 +1,166 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/depth_to_space_fusion.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/rt_info.hpp>
bool check_block_first(const ngraph::Shape& shape_input, const ngraph::Shape& shape_reshape_before,
const ngraph::AxisVector& permutation, const ngraph::Shape& shape_reshape_after,
size_t& possible_block_size) {
bool is_transformation_valid = true;
uint64_t spatial_dims = shape_input.size() - 2;
possible_block_size = shape_reshape_before[1];
if (possible_block_size == 0)
return false;
uint64_t c_dim = shape_input[1] / std::pow(possible_block_size, spatial_dims);
// x' = reshape(data, [N, block_size, block_size, ..., block_size, C / (block_size ^ K), D1, D2, ..., DK])
ngraph::Shape expected_shape = {shape_input[0]};
for (uint64_t i = 0; i < spatial_dims; ++i)
expected_shape.push_back(possible_block_size);
expected_shape.push_back(c_dim);
for (uint64_t i = 2; i < shape_input.size(); ++i)
expected_shape.push_back(shape_input[i]);
is_transformation_valid &= (expected_shape == shape_reshape_before);
// x'' = transpose(x', [0, K + 1, K + 2, 1, K + 3, 2, K + 4, 3, ..., K + (K + 1), K])
ngraph::AxisVector expected_permutation = {0, spatial_dims + 1};
for (uint64_t i = 2; i < shape_input.size(); ++i) {
expected_permutation.push_back(spatial_dims + i);
expected_permutation.push_back(i - 1);
}
is_transformation_valid &= (expected_permutation == permutation);
// y = reshape(x'', [N, C / (block_size ^ K), D1 * block_size, D2 * block_size, D3 * block_size, ..., DK * block_size])
expected_shape = {shape_input[0], c_dim};
for (uint64_t i = 2; i < shape_input.size(); ++i)
expected_shape.push_back(shape_input[i] * possible_block_size);
is_transformation_valid &= (expected_shape == shape_reshape_after);
return is_transformation_valid;
}
bool check_depth_first(const ngraph::Shape& shape_input, const ngraph::Shape& shape_reshape_before,
const ngraph::AxisVector& permutation, const ngraph::Shape& shape_reshape_after,
size_t& possible_block_size) {
bool is_transformation_valid = true;
uint64_t spatial_dims = shape_input.size() - 2;
possible_block_size = shape_reshape_before[2];
if (possible_block_size == 0)
return false;
uint64_t c_dim = shape_input[1] / std::pow(possible_block_size, spatial_dims);
// x' = reshape(data, [N, C / (block_size ^ K), block_size, block_size, ..., block_size, D1, D2, ..., DK])
ngraph::Shape expected_shape = {shape_input[0], c_dim};
for (uint64_t i = 0; i < spatial_dims; ++i)
expected_shape.push_back(possible_block_size);
for (uint64_t i = 2; i < shape_input.size(); ++i)
expected_shape.push_back(shape_input[i]);
is_transformation_valid &= (expected_shape == shape_reshape_before);
// x'' = transpose(x', [0, 1, K + 2, 2, K + 3, 3, K + 4, 4, ..., K + (K + 1), K + 1])
ngraph::AxisVector expected_permutation = {0, 1};
for (uint64_t i = 2; i < shape_input.size(); ++i) {
expected_permutation.push_back(spatial_dims + i);
expected_permutation.push_back(i);
}
is_transformation_valid &= (expected_permutation == permutation);
// y = reshape(x'', [N, C / (block_size ^ K), D1 * block_size, D2 * block_size, D3 * block_size, ..., DK * block_size])
expected_shape = {shape_input[0], c_dim};
for (uint64_t i = 2; i < shape_input.size(); ++i)
expected_shape.push_back(shape_input[i] * possible_block_size);
is_transformation_valid &= (expected_shape == shape_reshape_after);
return is_transformation_valid;
}
void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
auto input2 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
auto input3 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, input1, false);
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, input2);
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, input3, false);
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
if (!transformation_callback(std::make_shared<ngraph::opset3::DepthToSpace>())) {
return false;
}
auto reshape_after = std::dynamic_pointer_cast<ngraph::opset3::Reshape>(m.get_match_root());
if (!reshape_after) {
return false;
}
auto permute = std::dynamic_pointer_cast<ngraph::opset3::Transpose>(reshape_after->input_value(0).get_node_shared_ptr());
if (!permute || permute->get_output_target_inputs(0).size() != 1) {
return false;
}
auto reshape_before = std::dynamic_pointer_cast<ngraph::opset3::Reshape>(permute->input_value(0).get_node_shared_ptr());
if (!reshape_before || reshape_before->get_output_target_inputs(0).size() != 1) {
return false;
}
auto p_shape_input = reshape_before->get_input_partial_shape(0);
auto p_shape_reshape_before = reshape_before->get_output_partial_shape(0);
auto p_shape_permute = permute->get_output_partial_shape(0);
auto p_shape_reshape_after = reshape_after->get_output_partial_shape(0);
if (p_shape_input.is_dynamic() || p_shape_reshape_before.is_dynamic() ||
p_shape_permute.is_dynamic() || p_shape_reshape_after.is_dynamic()) {
return false;
}
auto shape_input = p_shape_input.get_shape();
auto shape_reshape_before = p_shape_reshape_before.get_shape();
auto shape_permute = p_shape_permute.get_shape();
auto shape_reshape_after = p_shape_reshape_after.get_shape();
if (shape_input.size() < 3) {
return false;
}
// input shape: [ batch, C, spatial_dims], expected_shape = spatial_dims.size() * 2 + 2
size_t expected_shape_size = (shape_input.size() - 2) * 2 + 2;
if (shape_input.size() != shape_reshape_after.size() || shape_reshape_before.size() != expected_shape_size ||
shape_permute.size() != expected_shape_size) {
return false;
}
ngraph::AxisVector permutation;
if (auto input_const = std::dynamic_pointer_cast<opset3::Constant>(permute->input_value(1).get_node_shared_ptr())) {
permutation = input_const->get_axis_vector_val();
} else {
return false;
}
ngraph::opset3::DepthToSpace::DepthToSpaceMode mode;
size_t block_size;
if (check_depth_first(shape_input, shape_reshape_before, permutation, shape_reshape_after, block_size)) {
mode = ngraph::opset3::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST;
} else if (check_block_first(shape_input, shape_reshape_before, permutation, shape_reshape_after, block_size)) {
mode = ngraph::opset3::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST;
} else {
return false;
}
auto depth_to_space =
std::make_shared<ngraph::opset3::DepthToSpace>(reshape_before->input_value(0), mode, block_size);
depth_to_space->set_friendly_name(reshape_after->get_friendly_name());
ngraph::copy_runtime_info({reshape_before, permute, reshape_after}, depth_to_space);
ngraph::replace_node(reshape_after, depth_to_space);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_after, "DepthToSpaceFusion");
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
}

View File

@ -0,0 +1,184 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "common_test_utils/test_common.hpp"
#include <string>
#include <sstream>
#include <fstream>
#include <memory>
#include <queue>
#include <map>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <ngraph_ops/fully_connected.hpp>
#include <transformations/depth_to_space_fusion.hpp>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include "ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, DepthToSpaceFusionDepthFirst) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 32, 2, 2, 720, 480});
auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 1, 4, 2, 5, 3});
auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
ngraph::pass::InitNodeInfo().run_on_function(f);
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
};
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
depth_to_space_transform.setCallback(callback);
depth_to_space_transform.run_on_function(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
auto depth_to_space = std::make_shared<ngraph::opset3::DepthToSpace>(input0, ngraph::opset3::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST, 2);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input0});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, DepthToSpaceFusionBlockFirst) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 2, 2, 32, 720, 480});
auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
ngraph::pass::InitNodeInfo().run_on_function(f);
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
};
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
depth_to_space_transform.setCallback(callback);
depth_to_space_transform.run_on_function(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
auto depth_to_space = std::make_shared<ngraph::opset3::DepthToSpace>(input0, ngraph::opset3::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, 2);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input0});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, DepthToSpaceFusionDynamicShape) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
auto shape_reshape_before = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{6});
auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0, shape_reshape_before});
ngraph::pass::InitNodeInfo().run_on_function(f);
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
};
// transformation won't be applied because of shape_reshape_before is dynamic, the graph will remain the same
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
depth_to_space_transform.setCallback(callback);
depth_to_space_transform.run_on_function(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
auto shape_reshape_before = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{6});
auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0, shape_reshape_before});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, DepthToSpaceFusionSeveralConsumers) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 2, 2, 32, 720, 480});
auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
// additional consumers, not output of the function
auto result = std::make_shared<ngraph::opset3::Result> (reshape_before);
auto result_2 = std::make_shared<ngraph::opset3::Result> (permute);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
ngraph::pass::InitNodeInfo().run_on_function(f);
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
};
// transformation won't be applied because of reshape_before has several consumers, the graph will remain the same
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
depth_to_space_transform.setCallback(callback);
depth_to_space_transform.run_on_function(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 2, 2, 32, 720, 480});
auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
// additional consumers, not output of the function
auto result = std::make_shared<ngraph::opset3::Result> (reshape_before);
auto result_2 = std::make_shared<ngraph::opset3::Result> (permute);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}