Unroll If in ngraph (#7403)

* added unroll_if

* Fix interface function

* Fix code style

* add false test for unroll if

* add if transformation in subgraphs

* fix code style

* fix transformations

* delete internal transformations

* Fix comments

* fix to replace_node

* fix description fo transformation

* fix CompileNetwork

* fix comments

* fix comments

* add function get_ie_output_name(input);

* fix code style

* disable cpu test
This commit is contained in:
Eugeny Volosenkov 2021-10-15 12:37:43 +03:00 committed by GitHub
parent 3f690314fa
commit 3e48008c3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 227 additions and 22 deletions

View File

@ -113,8 +113,9 @@ void TemplatePlugin::ExecutableNetwork::CompileNetwork(const std::shared_ptr<con
// Generate backend specific blob mappings. For example Inference Engine uses not ngraph::Result nodes friendly name // Generate backend specific blob mappings. For example Inference Engine uses not ngraph::Result nodes friendly name
// as inference request output names but the name of the layer before. // as inference request output names but the name of the layer before.
for (auto&& result : _function->get_results()) { for (auto&& result : _function->get_results()) {
auto outputName = ngraph::op::util::create_ie_output_name(result->input_value(0)); const auto& input = result->input_value(0);
_outputIndex.emplace(outputName, _function->get_result_index(result)); auto name = ngraph::op::util::get_ie_output_name(input);
_outputIndex.emplace(name, _function->get_result_index(result));
} }
for (auto&& parameter : _function->get_parameters()) { for (auto&& parameter : _function->get_parameters()) {
_inputIndex.emplace(parameter->get_friendly_name(), _function->get_parameter_index(parameter)); _inputIndex.emplace(parameter->get_friendly_name(), _function->get_parameter_index(parameter));

View File

@ -2037,12 +2037,9 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
cnnLayer->outData.clear(); cnnLayer->outData.clear();
continue; continue;
} }
NGRAPH_SUPPRESS_DEPRECATED_START
auto outName = layer->output(i).get_tensor().get_name(); auto outName = ngraph::op::util::get_ie_output_name(layer->output(i));
NGRAPH_SUPPRESS_DEPRECATED_END
if (outName.empty()) {
outName = ngraph::op::util::create_ie_output_name(layer->output(i));
}
DataPtr &ptr = cnnNetworkImpl->getData(outName.c_str()); DataPtr &ptr = cnnNetworkImpl->getData(outName.c_str());
IE_ASSERT(layer->get_output_partial_shape(i).is_static()) << " nGraph " IE_ASSERT(layer->get_output_partial_shape(i).is_static()) << " nGraph "
@ -2093,13 +2090,7 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
if (std::dynamic_pointer_cast<::ngraph::op::Result>(layer)) { if (std::dynamic_pointer_cast<::ngraph::op::Result>(layer)) {
IE_ASSERT(layer->get_input_size() == 1); IE_ASSERT(layer->get_input_size() == 1);
const auto &input = layer->input_value(0); const auto &input = layer->input_value(0);
NGRAPH_SUPPRESS_DEPRECATED_START cnnNetworkImpl->addOutput(ngraph::op::util::get_ie_output_name(input));
auto name = input.get_tensor().get_name();
NGRAPH_SUPPRESS_DEPRECATED_END
if (!name.empty())
cnnNetworkImpl->addOutput(name);
else
cnnNetworkImpl->addOutput(ngraph::op::util::create_ie_output_name(input));
continue; continue;
} }

View File

@ -228,12 +228,7 @@ void MKLDNNGraph::Replicate(const CNNNetwork &network, const MKLDNNExtensionMana
if (op->get_type_info() == ngraph::op::v0::Result::type_info) { if (op->get_type_info() == ngraph::op::v0::Result::type_info) {
const auto &input = op->input_value(0); const auto &input = op->input_value(0);
NGRAPH_SUPPRESS_DEPRECATED_START auto name = ngraph::op::util::get_ie_output_name(input);
auto name = input.get_tensor().get_name();
NGRAPH_SUPPRESS_DEPRECATED_END
if (name.empty()) {
name = ngraph::op::util::create_ie_output_name(input);
}
if (outputsInfo.count(name) != 0) { if (outputsInfo.count(name) != 0) {
outputNodesMap[name] = node; outputNodesMap[name] = node;

View File

@ -0,0 +1,29 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/pass/graph_rewrite.hpp>
#include <transformations_visibility.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API UnrollIf;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief The transformation replaces 'If' operations with one of the internal functions (bodies) if the provided condition is constant.
* The condition is true: 'If' op is replaced with then_body
* The condition is false 'If' op is replaced with else_body
*/
class ngraph::pass::UnrollIf : public ngraph::pass::FunctionPass {
public:
NGRAPH_RTTI_DECLARATION;
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};

View File

@ -49,7 +49,6 @@ bool has_op_with_type(const std::shared_ptr<const ngraph::Function> &function) {
} }
return false; return false;
} }
inline std::string create_ie_output_name(const ngraph::Output<ngraph::Node>& output) { inline std::string create_ie_output_name(const ngraph::Output<ngraph::Node>& output) {
const auto& prev_layer = output.get_node_shared_ptr(); const auto& prev_layer = output.get_node_shared_ptr();
std::string out_name = prev_layer->get_friendly_name(); std::string out_name = prev_layer->get_friendly_name();
@ -57,6 +56,15 @@ inline std::string create_ie_output_name(const ngraph::Output<ngraph::Node>& out
out_name += "." + std::to_string(output.get_index()); out_name += "." + std::to_string(output.get_index());
return out_name; return out_name;
} }
inline std::string get_ie_output_name(const ngraph::Output<ngraph::Node>& output) {
NGRAPH_SUPPRESS_DEPRECATED_START
auto name = output.get_tensor().get_name();
NGRAPH_SUPPRESS_DEPRECATED_END
if (name.empty()) {
name = create_ie_output_name(output);
}
return name;
}
template <typename T> template <typename T>
bool has_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant, bool has_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant,

View File

@ -82,6 +82,8 @@
#include <ngraph/pass/constant_folding.hpp> #include <ngraph/pass/constant_folding.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp> #include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp> #include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
#include <transformations/control_flow/unroll_if.hpp>
#include <transformations/op_conversions/normalize_l2_decomposition.hpp> #include <transformations/op_conversions/normalize_l2_decomposition.hpp>
#include <transformations/op_conversions/softmax_decomposition.hpp> #include <transformations/op_conversions/softmax_decomposition.hpp>
#include <transformations/common_optimizations/moc_transformations.hpp> #include <transformations/common_optimizations/moc_transformations.hpp>
@ -143,6 +145,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
// LinOpSequenceFusion must be executed after all decompositions // LinOpSequenceFusion must be executed after all decompositions
manager.register_pass<ngraph::pass::LinOpSequenceFusion>(); manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
manager.register_pass<ngraph::pass::UnrollIf>();
auto conv_fusions = manager.register_pass<ngraph::pass::GraphRewrite>(); auto conv_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
conv_fusions->add_matcher<ngraph::pass::ConvolutionMultiplyFusion>(); conv_fusions->add_matcher<ngraph::pass::ConvolutionMultiplyFusion>();

View File

@ -0,0 +1,61 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/control_flow/unroll_if.hpp"
#include <memory>
#include <ngraph/graph_util.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/validation_util.hpp>
#include "itt.hpp"
#include "transformations/utils/utils.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::UnrollIf, "UnrollIf", 0);
bool ngraph::pass::UnrollIf::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(UnrollIf);
bool is_applicable = false;
for (const auto& op : f->get_ordered_ops()) {
auto if_node = std::dynamic_pointer_cast<opset8::If>(op);
if (!if_node || transformation_callback(if_node)) {
continue;
}
Output<Node> cond = if_node->input_value(0);
const auto cond_is_const = ngraph::get_constant_from_source(cond);
if (!cond_is_const) {
continue;
}
auto cond_value = cond_is_const->cast_vector<bool>();
auto body = (cond_value[0]) ? if_node->get_then_body() : if_node->get_else_body();
auto input_descriptions = if_node->get_input_descriptions(static_cast<int>(!cond_value[0]));
auto output_descriptions = if_node->get_output_descriptions(static_cast<int>(!cond_value[0]));
// connect inputs instead of body parameters
for (const auto& input_descr : input_descriptions) {
auto in_data = if_node->input_value(input_descr->m_input_index);
auto& param = body->get_parameters()[input_descr->m_body_parameter_index];
ngraph::replace_node(param, in_data.get_node_shared_ptr());
}
for (const auto& output_desc : output_descriptions) {
std::shared_ptr<opset8::Result> result = body->get_results()[output_desc->m_body_value_index];
const auto& in_value = result->input_value(0);
// set output name to Tensor to store it for ngraph to cnn conversion
NGRAPH_SUPPRESS_DEPRECATED_START
in_value.get_tensor().set_name(op::util::create_ie_output_name(if_node->output(output_desc->m_output_index)));
NGRAPH_SUPPRESS_DEPRECATED_END
for (const auto& input : if_node->output(output_desc->m_output_index).get_target_inputs()) {
input.replace_source_output(result->get_input_source_output(0));
}
}
is_applicable = true;
f->add_sinks(body->get_sinks());
copy_runtime_info(if_node, body->get_ops());
}
return is_applicable;
}

View File

@ -0,0 +1,115 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <string>
#include <transformations/control_flow/unroll_if.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, UnrollIfCondIsTrue) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto X = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto Y = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto cond = std::make_shared<ngraph::opset1::Constant>(ngraph::element::boolean, ngraph::Shape{ 1 }, true);
auto if_op = std::make_shared<ngraph::opset8::If>(cond);
auto Xt = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto Yt = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto add_op = std::make_shared<ngraph::opset1::Add>(Xt, Yt);
auto then_op_result = std::make_shared<ngraph::opset1::Result>(add_op);
auto then_body = std::make_shared<ngraph::Function>(ngraph::OutputVector{ then_op_result }, ngraph::ParameterVector{ Xt, Yt });
auto Xe = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto Ye = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto mul_op = std::make_shared<ngraph::opset1::Multiply>(Xe, Ye);
auto else_op_result = std::make_shared<ngraph::opset1::Result>(mul_op);
auto else_body = std::make_shared<ngraph::Function>(ngraph::OutputVector{ else_op_result }, ngraph::ParameterVector{ Xe, Ye });
if_op->set_then_body(then_body);
if_op->set_else_body(else_body);
if_op->set_input(X, Xt, Xe);
if_op->set_input(Y, Yt, Ye);
if_op->set_output(then_op_result, else_op_result);
auto if_result = std::make_shared<ngraph::opset1::Result>(if_op);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ if_result }, ngraph::ParameterVector{ X, Y });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::UnrollIf>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto X = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto Y = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto add_op = std::make_shared<ngraph::opset1::Add>(X, Y);
auto if_result = std::make_shared<ngraph::opset1::Result>(add_op);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ if_result }, ngraph::ParameterVector{ X, Y });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, UnrollIfCondIsFalse) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
auto X = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto Y = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto cond = std::make_shared<ngraph::opset1::Constant>(ngraph::element::boolean, ngraph::Shape{ 1 }, false);
auto if_op = std::make_shared<ngraph::opset8::If>(cond);
auto Xt = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto Yt = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto add_op = std::make_shared<ngraph::opset1::Add>(Xt, Yt);
auto then_op_result = std::make_shared<ngraph::opset1::Result>(add_op);
auto then_body = std::make_shared<ngraph::Function>(ngraph::OutputVector{ then_op_result }, ngraph::ParameterVector{ Xt, Yt });
auto Xe = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto Ye = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto mul_op = std::make_shared<ngraph::opset1::Multiply>(Xe, Ye);
auto else_op_result = std::make_shared<ngraph::opset1::Result>(mul_op);
auto else_body = std::make_shared<ngraph::Function>(ngraph::OutputVector{ else_op_result }, ngraph::ParameterVector{ Xe, Ye });
if_op->set_then_body(then_body);
if_op->set_else_body(else_body);
if_op->set_input(X, Xt, Xe);
if_op->set_input(Y, Yt, Ye);
if_op->set_output(then_op_result, else_op_result);
auto if_result = std::make_shared<ngraph::opset1::Result>(if_op);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ if_result }, ngraph::ParameterVector{ X, Y });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::UnrollIf>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto X = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto Y = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
auto mul_op = std::make_shared<ngraph::opset1::Multiply>(X, Y);
auto if_result = std::make_shared<ngraph::opset1::Result>(mul_op);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ if_result }, ngraph::ParameterVector{ X, Y });
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}

View File

@ -131,6 +131,8 @@ std::vector<std::string> disabledTestPatterns() {
R"(smoke_PrePostProcess.*resize_and_convert_layout_i8.*)", R"(smoke_PrePostProcess.*resize_and_convert_layout_i8.*)",
// Issue 67910 // Issue 67910
R"(.*smoke_PrePostProcess.*two_inputs_trivial.*)", R"(.*smoke_PrePostProcess.*two_inputs_trivial.*)",
// TODO: CVS-67255
R"(smoke_If.*SimpleIf2OutTest.*)"
}; };
#define FIX_62820 0 #define FIX_62820 0