Reshape-Permute-Reshape pattern to DepthToSpace layer transformation (#601)
* implemented depth_to_space transformation * renaming * added functional tests, fixed mistakes in implementation of the transformation * disable ConvertSpaceToDepth/ConvertDepthToSpace transformation for CPU plugin, enable DepthToSpaceFusion for CPU plugin only, add specific creators * fix wrong include * fix for functional tests: set transformation callback * revert callback calls for CPU plugin * move functions to .cpp file * Apply review comments * Apply additional review comments * fix cast to bool type
This commit is contained in:
parent
b4893945c7
commit
cd01ccd449
@ -332,7 +332,7 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
|
|||||||
res->params = params;
|
res->params = params;
|
||||||
return res;
|
return res;
|
||||||
});
|
});
|
||||||
|
|
||||||
addSpecificCreator({"Assign"}, [](const std::shared_ptr<::ngraph::Node>& node,
|
addSpecificCreator({"Assign"}, [](const std::shared_ptr<::ngraph::Node>& node,
|
||||||
const std::map<std::string, std::string> params) -> CNNLayerPtr {
|
const std::map<std::string, std::string> params) -> CNNLayerPtr {
|
||||||
LayerParams attrs = {node->get_friendly_name(), "Memory",
|
LayerParams attrs = {node->get_friendly_name(), "Memory",
|
||||||
@ -355,6 +355,24 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
|
|||||||
return res;
|
return res;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
addSpecificCreator({"DepthToSpace"}, [](const std::shared_ptr<::ngraph::Node>& node,
|
||||||
|
const std::map<std::string, std::string> params) -> CNNLayerPtr {
|
||||||
|
LayerParams attrs = {node->get_friendly_name(), node->description(),
|
||||||
|
details::convertPrecision(node->get_output_element_type(0))};
|
||||||
|
auto res = std::make_shared<DepthToSpaceLayer>(attrs);
|
||||||
|
res->params = params;
|
||||||
|
return res;
|
||||||
|
});
|
||||||
|
|
||||||
|
addSpecificCreator({"SpaceToDepth"}, [](const std::shared_ptr<::ngraph::Node>& node,
|
||||||
|
const std::map<std::string, std::string> params) -> CNNLayerPtr {
|
||||||
|
LayerParams attrs = {node->get_friendly_name(), node->description(),
|
||||||
|
details::convertPrecision(node->get_output_element_type(0))};
|
||||||
|
auto res = std::make_shared<SpaceToDepthLayer>(attrs);
|
||||||
|
res->params = params;
|
||||||
|
return res;
|
||||||
|
});
|
||||||
|
|
||||||
addSpecificCreator({"RNNCell"}, [](const std::shared_ptr<::ngraph::Node>& node,
|
addSpecificCreator({"RNNCell"}, [](const std::shared_ptr<::ngraph::Node>& node,
|
||||||
const std::map<std::string, std::string> params) -> CNNLayerPtr {
|
const std::map<std::string, std::string> params) -> CNNLayerPtr {
|
||||||
THROW_IE_EXCEPTION << "RNNCell operation has a form that is not supported." << node->get_friendly_name()
|
THROW_IE_EXCEPTION << "RNNCell operation has a form that is not supported." << node->get_friendly_name()
|
||||||
|
@ -25,3 +25,4 @@ NGRAPH_PASS(NopElimination, ::ngraph::pass) // may introduce fake dynamism
|
|||||||
NGRAPH_PASS(AlgebraicSimplification, ::ngraph::pass) // may introduce fake dynamism
|
NGRAPH_PASS(AlgebraicSimplification, ::ngraph::pass) // may introduce fake dynamism
|
||||||
NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
|
NGRAPH_PASS(ConstantFolding, ::ngraph::pass)
|
||||||
NGRAPH_PASS(ConvertScatterElementsToScatter, ::ngraph::pass) // partially depends on CF
|
NGRAPH_PASS(ConvertScatterElementsToScatter, ::ngraph::pass) // partially depends on CF
|
||||||
|
NGRAPH_PASS(DepthToSpaceFusion, ::ngraph::pass)
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
#include <ie_api.h>
|
#include <ie_api.h>
|
||||||
|
|
||||||
#include <ngraph/pass/graph_rewrite.hpp>
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
#include "transformations/utils/pass_param.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
@ -19,9 +20,9 @@ class INFERENCE_ENGINE_API_CLASS(ConvertDepthToSpace);
|
|||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
|
||||||
class ngraph::pass::ConvertDepthToSpace: public ngraph::pass::GraphRewrite {
|
class ngraph::pass::ConvertDepthToSpace: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
|
||||||
public:
|
public:
|
||||||
ConvertDepthToSpace() : GraphRewrite() {
|
ConvertDepthToSpace() : GraphRewrite(), PassParam() {
|
||||||
convert_depth_to_space();
|
convert_depth_to_space();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,4 +19,3 @@ NGRAPH_PASS(ConvertNMS3, ::ngraph::pass)
|
|||||||
NGRAPH_PASS(ConvertShapeOf3, ::ngraph::pass)
|
NGRAPH_PASS(ConvertShapeOf3, ::ngraph::pass)
|
||||||
NGRAPH_PASS(ConvertShuffleChannels3, ::ngraph::pass)
|
NGRAPH_PASS(ConvertShuffleChannels3, ::ngraph::pass)
|
||||||
NGRAPH_PASS(ConvertTopK3, ::ngraph::pass)
|
NGRAPH_PASS(ConvertTopK3, ::ngraph::pass)
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
#include <ie_api.h>
|
#include <ie_api.h>
|
||||||
|
|
||||||
#include <ngraph/pass/graph_rewrite.hpp>
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
#include "transformations/utils/pass_param.hpp"
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
@ -19,9 +20,9 @@ class INFERENCE_ENGINE_API_CLASS(ConvertSpaceToDepth);
|
|||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
|
||||||
class ngraph::pass::ConvertSpaceToDepth: public ngraph::pass::GraphRewrite {
|
class ngraph::pass::ConvertSpaceToDepth: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
|
||||||
public:
|
public:
|
||||||
ConvertSpaceToDepth() : GraphRewrite() {
|
ConvertSpaceToDepth() : GraphRewrite(), PassParam() {
|
||||||
convert();
|
convert();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,54 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include <ie_api.h>
|
||||||
|
|
||||||
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
#include "transformations/utils/pass_param.hpp"
|
||||||
|
|
||||||
|
namespace ngraph {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class INFERENCE_ENGINE_API_CLASS(DepthToSpaceFusion);
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ngraph
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Description:
|
||||||
|
* DepthToSpaceFusion transformation detects Reshape-Transpose-Reshape pattern and
|
||||||
|
* tries to fuse it into a single DepthToSpace layer.
|
||||||
|
*
|
||||||
|
* Usage:
|
||||||
|
* DepthToSpaceFusion transformation is optional and disabled by default.
|
||||||
|
* The transformation can be enabled with callback using setCallback method.
|
||||||
|
* See the example below.
|
||||||
|
*
|
||||||
|
* Callback example:
|
||||||
|
*
|
||||||
|
* // This callback enables DepthToSpaceFusion transformation
|
||||||
|
* auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||||
|
* return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
|
||||||
|
* };
|
||||||
|
*
|
||||||
|
* auto p = ngraph::pass::DepthToSpaceFusion();
|
||||||
|
* p.setCallback(callback);
|
||||||
|
* p.run_on_function(f);
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
class ngraph::pass::DepthToSpaceFusion: public ngraph::pass::GraphRewrite, public ngraph::pass::PassParam {
|
||||||
|
public:
|
||||||
|
DepthToSpaceFusion() : GraphRewrite(), PassParam() {
|
||||||
|
depth_to_space_fusion();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void depth_to_space_fusion();
|
||||||
|
};
|
@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
#include "transformations/common_optimizations/common_optimizations.hpp"
|
#include "transformations/common_optimizations/common_optimizations.hpp"
|
||||||
#include "transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp"
|
#include "transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp"
|
||||||
|
#include "transformations/depth_to_space_fusion.hpp"
|
||||||
#include "transformations/optimize_strided_slice.hpp"
|
#include "transformations/optimize_strided_slice.hpp"
|
||||||
#include "transformations/convert_scatter_elements_to_scatter.hpp"
|
#include "transformations/convert_scatter_elements_to_scatter.hpp"
|
||||||
#include "transformations/remove_filtering_boxes_by_size.hpp"
|
#include "transformations/remove_filtering_boxes_by_size.hpp"
|
||||||
|
@ -14,9 +14,9 @@ void ngraph::pass::ConvertDepthToSpace::convert_depth_to_space() {
|
|||||||
auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
||||||
auto dts_node = std::make_shared<ngraph::opset1::DepthToSpace>(input0, ngraph::op::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST);
|
auto dts_node = std::make_shared<ngraph::opset1::DepthToSpace>(input0, ngraph::op::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST);
|
||||||
|
|
||||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||||
auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
|
auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
|
||||||
if (!dts_node) {
|
if (!dts_node || transformation_callback(dts_node)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,9 +14,9 @@ void ngraph::pass::ConvertSpaceToDepth::convert() {
|
|||||||
auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
||||||
auto dts = std::make_shared<ngraph::opset1::SpaceToDepth>(input0, ngraph::opset1::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST);
|
auto dts = std::make_shared<ngraph::opset1::SpaceToDepth>(input0, ngraph::opset1::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST);
|
||||||
|
|
||||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||||
auto std_node = std::dynamic_pointer_cast<ngraph::opset1::SpaceToDepth> (m.get_match_root());
|
auto std_node = std::dynamic_pointer_cast<ngraph::opset1::SpaceToDepth> (m.get_match_root());
|
||||||
if (!std_node) {
|
if (!std_node || transformation_callback(std_node)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,166 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/depth_to_space_fusion.hpp"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <ngraph/opsets/opset3.hpp>
|
||||||
|
#include <ngraph/rt_info.hpp>
|
||||||
|
|
||||||
|
bool check_block_first(const ngraph::Shape& shape_input, const ngraph::Shape& shape_reshape_before,
|
||||||
|
const ngraph::AxisVector& permutation, const ngraph::Shape& shape_reshape_after,
|
||||||
|
size_t& possible_block_size) {
|
||||||
|
bool is_transformation_valid = true;
|
||||||
|
uint64_t spatial_dims = shape_input.size() - 2;
|
||||||
|
possible_block_size = shape_reshape_before[1];
|
||||||
|
if (possible_block_size == 0)
|
||||||
|
return false;
|
||||||
|
uint64_t c_dim = shape_input[1] / std::pow(possible_block_size, spatial_dims);
|
||||||
|
|
||||||
|
// x' = reshape(data, [N, block_size, block_size, ..., block_size, C / (block_size ^ K), D1, D2, ..., DK])
|
||||||
|
ngraph::Shape expected_shape = {shape_input[0]};
|
||||||
|
for (uint64_t i = 0; i < spatial_dims; ++i)
|
||||||
|
expected_shape.push_back(possible_block_size);
|
||||||
|
expected_shape.push_back(c_dim);
|
||||||
|
for (uint64_t i = 2; i < shape_input.size(); ++i)
|
||||||
|
expected_shape.push_back(shape_input[i]);
|
||||||
|
is_transformation_valid &= (expected_shape == shape_reshape_before);
|
||||||
|
|
||||||
|
// x'' = transpose(x', [0, K + 1, K + 2, 1, K + 3, 2, K + 4, 3, ..., K + (K + 1), K])
|
||||||
|
ngraph::AxisVector expected_permutation = {0, spatial_dims + 1};
|
||||||
|
for (uint64_t i = 2; i < shape_input.size(); ++i) {
|
||||||
|
expected_permutation.push_back(spatial_dims + i);
|
||||||
|
expected_permutation.push_back(i - 1);
|
||||||
|
}
|
||||||
|
is_transformation_valid &= (expected_permutation == permutation);
|
||||||
|
|
||||||
|
// y = reshape(x'', [N, C / (block_size ^ K), D1 * block_size, D2 * block_size, D3 * block_size, ..., DK * block_size])
|
||||||
|
expected_shape = {shape_input[0], c_dim};
|
||||||
|
for (uint64_t i = 2; i < shape_input.size(); ++i)
|
||||||
|
expected_shape.push_back(shape_input[i] * possible_block_size);
|
||||||
|
is_transformation_valid &= (expected_shape == shape_reshape_after);
|
||||||
|
|
||||||
|
return is_transformation_valid;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool check_depth_first(const ngraph::Shape& shape_input, const ngraph::Shape& shape_reshape_before,
|
||||||
|
const ngraph::AxisVector& permutation, const ngraph::Shape& shape_reshape_after,
|
||||||
|
size_t& possible_block_size) {
|
||||||
|
bool is_transformation_valid = true;
|
||||||
|
uint64_t spatial_dims = shape_input.size() - 2;
|
||||||
|
possible_block_size = shape_reshape_before[2];
|
||||||
|
if (possible_block_size == 0)
|
||||||
|
return false;
|
||||||
|
uint64_t c_dim = shape_input[1] / std::pow(possible_block_size, spatial_dims);
|
||||||
|
|
||||||
|
// x' = reshape(data, [N, C / (block_size ^ K), block_size, block_size, ..., block_size, D1, D2, ..., DK])
|
||||||
|
ngraph::Shape expected_shape = {shape_input[0], c_dim};
|
||||||
|
for (uint64_t i = 0; i < spatial_dims; ++i)
|
||||||
|
expected_shape.push_back(possible_block_size);
|
||||||
|
for (uint64_t i = 2; i < shape_input.size(); ++i)
|
||||||
|
expected_shape.push_back(shape_input[i]);
|
||||||
|
is_transformation_valid &= (expected_shape == shape_reshape_before);
|
||||||
|
|
||||||
|
// x'' = transpose(x', [0, 1, K + 2, 2, K + 3, 3, K + 4, 4, ..., K + (K + 1), K + 1])
|
||||||
|
ngraph::AxisVector expected_permutation = {0, 1};
|
||||||
|
for (uint64_t i = 2; i < shape_input.size(); ++i) {
|
||||||
|
expected_permutation.push_back(spatial_dims + i);
|
||||||
|
expected_permutation.push_back(i);
|
||||||
|
}
|
||||||
|
is_transformation_valid &= (expected_permutation == permutation);
|
||||||
|
|
||||||
|
// y = reshape(x'', [N, C / (block_size ^ K), D1 * block_size, D2 * block_size, D3 * block_size, ..., DK * block_size])
|
||||||
|
expected_shape = {shape_input[0], c_dim};
|
||||||
|
for (uint64_t i = 2; i < shape_input.size(); ++i)
|
||||||
|
expected_shape.push_back(shape_input[i] * possible_block_size);
|
||||||
|
is_transformation_valid &= (expected_shape == shape_reshape_after);
|
||||||
|
|
||||||
|
return is_transformation_valid;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
|
||||||
|
auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
||||||
|
auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
|
||||||
|
auto input2 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
|
||||||
|
auto input3 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
|
||||||
|
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, input1, false);
|
||||||
|
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, input2);
|
||||||
|
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, input3, false);
|
||||||
|
|
||||||
|
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||||
|
if (!transformation_callback(std::make_shared<ngraph::opset3::DepthToSpace>())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto reshape_after = std::dynamic_pointer_cast<ngraph::opset3::Reshape>(m.get_match_root());
|
||||||
|
if (!reshape_after) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto permute = std::dynamic_pointer_cast<ngraph::opset3::Transpose>(reshape_after->input_value(0).get_node_shared_ptr());
|
||||||
|
if (!permute || permute->get_output_target_inputs(0).size() != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto reshape_before = std::dynamic_pointer_cast<ngraph::opset3::Reshape>(permute->input_value(0).get_node_shared_ptr());
|
||||||
|
if (!reshape_before || reshape_before->get_output_target_inputs(0).size() != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto p_shape_input = reshape_before->get_input_partial_shape(0);
|
||||||
|
auto p_shape_reshape_before = reshape_before->get_output_partial_shape(0);
|
||||||
|
auto p_shape_permute = permute->get_output_partial_shape(0);
|
||||||
|
auto p_shape_reshape_after = reshape_after->get_output_partial_shape(0);
|
||||||
|
|
||||||
|
if (p_shape_input.is_dynamic() || p_shape_reshape_before.is_dynamic() ||
|
||||||
|
p_shape_permute.is_dynamic() || p_shape_reshape_after.is_dynamic()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shape_input = p_shape_input.get_shape();
|
||||||
|
auto shape_reshape_before = p_shape_reshape_before.get_shape();
|
||||||
|
auto shape_permute = p_shape_permute.get_shape();
|
||||||
|
auto shape_reshape_after = p_shape_reshape_after.get_shape();
|
||||||
|
|
||||||
|
if (shape_input.size() < 3) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// input shape: [ batch, C, spatial_dims], expected_shape = spatial_dims.size() * 2 + 2
|
||||||
|
size_t expected_shape_size = (shape_input.size() - 2) * 2 + 2;
|
||||||
|
if (shape_input.size() != shape_reshape_after.size() || shape_reshape_before.size() != expected_shape_size ||
|
||||||
|
shape_permute.size() != expected_shape_size) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
ngraph::AxisVector permutation;
|
||||||
|
if (auto input_const = std::dynamic_pointer_cast<opset3::Constant>(permute->input_value(1).get_node_shared_ptr())) {
|
||||||
|
permutation = input_const->get_axis_vector_val();
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
ngraph::opset3::DepthToSpace::DepthToSpaceMode mode;
|
||||||
|
size_t block_size;
|
||||||
|
if (check_depth_first(shape_input, shape_reshape_before, permutation, shape_reshape_after, block_size)) {
|
||||||
|
mode = ngraph::opset3::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST;
|
||||||
|
} else if (check_block_first(shape_input, shape_reshape_before, permutation, shape_reshape_after, block_size)) {
|
||||||
|
mode = ngraph::opset3::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto depth_to_space =
|
||||||
|
std::make_shared<ngraph::opset3::DepthToSpace>(reshape_before->input_value(0), mode, block_size);
|
||||||
|
depth_to_space->set_friendly_name(reshape_after->get_friendly_name());
|
||||||
|
ngraph::copy_runtime_info({reshape_before, permute, reshape_after}, depth_to_space);
|
||||||
|
ngraph::replace_node(reshape_after, depth_to_space);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_after, "DepthToSpaceFusion");
|
||||||
|
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||||
|
}
|
@ -0,0 +1,184 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "common_test_utils/test_common.hpp"
|
||||||
|
#include <string>
|
||||||
|
#include <sstream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset3.hpp>
|
||||||
|
#include <ngraph/pass/constant_folding.hpp>
|
||||||
|
#include <ngraph_ops/fully_connected.hpp>
|
||||||
|
#include <transformations/depth_to_space_fusion.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
|
||||||
|
#include "ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
|
TEST(TransformationTests, DepthToSpaceFusionDepthFirst) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
|
||||||
|
auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 32, 2, 2, 720, 480});
|
||||||
|
auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 1, 4, 2, 5, 3});
|
||||||
|
auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
|
||||||
|
|
||||||
|
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
|
||||||
|
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
|
||||||
|
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
|
||||||
|
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||||
|
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||||
|
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
|
||||||
|
depth_to_space_transform.setCallback(callback);
|
||||||
|
depth_to_space_transform.run_on_function(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
|
||||||
|
auto depth_to_space = std::make_shared<ngraph::opset3::DepthToSpace>(input0, ngraph::opset3::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST, 2);
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input0});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, DepthToSpaceFusionBlockFirst) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
|
||||||
|
auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 2, 2, 32, 720, 480});
|
||||||
|
auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
|
||||||
|
auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
|
||||||
|
|
||||||
|
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
|
||||||
|
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
|
||||||
|
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
|
||||||
|
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||||
|
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||||
|
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
|
||||||
|
depth_to_space_transform.setCallback(callback);
|
||||||
|
depth_to_space_transform.run_on_function(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
|
||||||
|
auto depth_to_space = std::make_shared<ngraph::opset3::DepthToSpace>(input0, ngraph::opset3::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, 2);
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input0});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, DepthToSpaceFusionDynamicShape) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
|
||||||
|
auto shape_reshape_before = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{6});
|
||||||
|
auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
|
||||||
|
auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
|
||||||
|
|
||||||
|
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
|
||||||
|
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
|
||||||
|
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0, shape_reshape_before});
|
||||||
|
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||||
|
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||||
|
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
// transformation won't be applied because of shape_reshape_before is dynamic, the graph will remain the same
|
||||||
|
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
|
||||||
|
depth_to_space_transform.setCallback(callback);
|
||||||
|
depth_to_space_transform.run_on_function(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
|
||||||
|
auto shape_reshape_before = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{6});
|
||||||
|
auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
|
||||||
|
auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
|
||||||
|
|
||||||
|
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
|
||||||
|
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
|
||||||
|
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0, shape_reshape_before});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, DepthToSpaceFusionSeveralConsumers) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
|
||||||
|
auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 2, 2, 32, 720, 480});
|
||||||
|
auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
|
||||||
|
auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
|
||||||
|
|
||||||
|
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
|
||||||
|
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
|
||||||
|
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
|
||||||
|
|
||||||
|
// additional consumers, not output of the function
|
||||||
|
auto result = std::make_shared<ngraph::opset3::Result> (reshape_before);
|
||||||
|
auto result_2 = std::make_shared<ngraph::opset3::Result> (permute);
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
|
||||||
|
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||||
|
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||||
|
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
// transformation won't be applied because of reshape_before has several consumers, the graph will remain the same
|
||||||
|
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
|
||||||
|
depth_to_space_transform.setCallback(callback);
|
||||||
|
depth_to_space_transform.run_on_function(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{1, 128, 720, 480});
|
||||||
|
auto shape_reshape_before = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {1, 2, 2, 32, 720, 480});
|
||||||
|
auto permutation = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{6}, {0, 3, 4, 1, 5, 2});
|
||||||
|
auto shape_reshape_after = ngraph::opset3::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 32, 1440, 960});
|
||||||
|
|
||||||
|
auto reshape_before = std::make_shared<ngraph::opset3::Reshape> (input0, shape_reshape_before, false);
|
||||||
|
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, permutation);
|
||||||
|
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
|
||||||
|
|
||||||
|
// additional consumers, not output of the function
|
||||||
|
auto result = std::make_shared<ngraph::opset3::Result> (reshape_before);
|
||||||
|
auto result_2 = std::make_shared<ngraph::opset3::Result> (permute);
|
||||||
|
|
||||||
|
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user