Unroll If in ngraph (#7403)
* added unroll_if * Fix interface function * Fix code style * add false test for unroll if * add if transformation in subgraphs * fix code style * fix transformations * delete internal transformations * Fix comments * fix to replace_node * fix description fo transformation * fix CompileNetwork * fix comments * fix comments * add function get_ie_output_name(input); * fix code style * disable cpu test
This commit is contained in:
parent
3f690314fa
commit
3e48008c3f
@ -113,8 +113,9 @@ void TemplatePlugin::ExecutableNetwork::CompileNetwork(const std::shared_ptr<con
|
||||
// Generate backend specific blob mappings. For example Inference Engine uses not ngraph::Result nodes friendly name
|
||||
// as inference request output names but the name of the layer before.
|
||||
for (auto&& result : _function->get_results()) {
|
||||
auto outputName = ngraph::op::util::create_ie_output_name(result->input_value(0));
|
||||
_outputIndex.emplace(outputName, _function->get_result_index(result));
|
||||
const auto& input = result->input_value(0);
|
||||
auto name = ngraph::op::util::get_ie_output_name(input);
|
||||
_outputIndex.emplace(name, _function->get_result_index(result));
|
||||
}
|
||||
for (auto&& parameter : _function->get_parameters()) {
|
||||
_inputIndex.emplace(parameter->get_friendly_name(), _function->get_parameter_index(parameter));
|
||||
|
@ -2037,12 +2037,9 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
|
||||
cnnLayer->outData.clear();
|
||||
continue;
|
||||
}
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
auto outName = layer->output(i).get_tensor().get_name();
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
if (outName.empty()) {
|
||||
outName = ngraph::op::util::create_ie_output_name(layer->output(i));
|
||||
}
|
||||
|
||||
auto outName = ngraph::op::util::get_ie_output_name(layer->output(i));
|
||||
|
||||
|
||||
DataPtr &ptr = cnnNetworkImpl->getData(outName.c_str());
|
||||
IE_ASSERT(layer->get_output_partial_shape(i).is_static()) << " nGraph "
|
||||
@ -2093,13 +2090,7 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
|
||||
if (std::dynamic_pointer_cast<::ngraph::op::Result>(layer)) {
|
||||
IE_ASSERT(layer->get_input_size() == 1);
|
||||
const auto &input = layer->input_value(0);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
auto name = input.get_tensor().get_name();
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
if (!name.empty())
|
||||
cnnNetworkImpl->addOutput(name);
|
||||
else
|
||||
cnnNetworkImpl->addOutput(ngraph::op::util::create_ie_output_name(input));
|
||||
cnnNetworkImpl->addOutput(ngraph::op::util::get_ie_output_name(input));
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -228,12 +228,7 @@ void MKLDNNGraph::Replicate(const CNNNetwork &network, const MKLDNNExtensionMana
|
||||
|
||||
if (op->get_type_info() == ngraph::op::v0::Result::type_info) {
|
||||
const auto &input = op->input_value(0);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
auto name = input.get_tensor().get_name();
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
if (name.empty()) {
|
||||
name = ngraph::op::util::create_ie_output_name(input);
|
||||
}
|
||||
auto name = ngraph::op::util::get_ie_output_name(input);
|
||||
|
||||
if (outputsInfo.count(name) != 0) {
|
||||
outputNodesMap[name] = node;
|
||||
|
@ -0,0 +1,29 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API UnrollIf;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief The transformation replaces 'If' operations with one of the internal functions (bodies) if the provided condition is constant.
|
||||
* The condition is true: 'If' op is replaced with then_body
|
||||
* The condition is false 'If' op is replaced with else_body
|
||||
*/
|
||||
|
||||
class ngraph::pass::UnrollIf : public ngraph::pass::FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
};
|
@ -49,7 +49,6 @@ bool has_op_with_type(const std::shared_ptr<const ngraph::Function> &function) {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline std::string create_ie_output_name(const ngraph::Output<ngraph::Node>& output) {
|
||||
const auto& prev_layer = output.get_node_shared_ptr();
|
||||
std::string out_name = prev_layer->get_friendly_name();
|
||||
@ -57,6 +56,15 @@ inline std::string create_ie_output_name(const ngraph::Output<ngraph::Node>& out
|
||||
out_name += "." + std::to_string(output.get_index());
|
||||
return out_name;
|
||||
}
|
||||
inline std::string get_ie_output_name(const ngraph::Output<ngraph::Node>& output) {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
auto name = output.get_tensor().get_name();
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
if (name.empty()) {
|
||||
name = create_ie_output_name(output);
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool has_constant_value(const std::shared_ptr<ngraph::opset4::Constant>& constant,
|
||||
|
@ -82,6 +82,8 @@
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
|
||||
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
||||
|
||||
#include <transformations/control_flow/unroll_if.hpp>
|
||||
#include <transformations/op_conversions/normalize_l2_decomposition.hpp>
|
||||
#include <transformations/op_conversions/softmax_decomposition.hpp>
|
||||
#include <transformations/common_optimizations/moc_transformations.hpp>
|
||||
@ -143,6 +145,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
|
||||
// LinOpSequenceFusion must be executed after all decompositions
|
||||
manager.register_pass<ngraph::pass::LinOpSequenceFusion>();
|
||||
manager.register_pass<ngraph::pass::UnrollIf>();
|
||||
|
||||
auto conv_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
conv_fusions->add_matcher<ngraph::pass::ConvolutionMultiplyFusion>();
|
||||
|
@ -0,0 +1,61 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/control_flow/unroll_if.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/graph_util.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/validation_util.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::UnrollIf, "UnrollIf", 0);
|
||||
|
||||
bool ngraph::pass::UnrollIf::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
RUN_ON_FUNCTION_SCOPE(UnrollIf);
|
||||
bool is_applicable = false;
|
||||
for (const auto& op : f->get_ordered_ops()) {
|
||||
auto if_node = std::dynamic_pointer_cast<opset8::If>(op);
|
||||
if (!if_node || transformation_callback(if_node)) {
|
||||
continue;
|
||||
}
|
||||
Output<Node> cond = if_node->input_value(0);
|
||||
const auto cond_is_const = ngraph::get_constant_from_source(cond);
|
||||
if (!cond_is_const) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto cond_value = cond_is_const->cast_vector<bool>();
|
||||
auto body = (cond_value[0]) ? if_node->get_then_body() : if_node->get_else_body();
|
||||
auto input_descriptions = if_node->get_input_descriptions(static_cast<int>(!cond_value[0]));
|
||||
auto output_descriptions = if_node->get_output_descriptions(static_cast<int>(!cond_value[0]));
|
||||
|
||||
// connect inputs instead of body parameters
|
||||
for (const auto& input_descr : input_descriptions) {
|
||||
auto in_data = if_node->input_value(input_descr->m_input_index);
|
||||
auto& param = body->get_parameters()[input_descr->m_body_parameter_index];
|
||||
ngraph::replace_node(param, in_data.get_node_shared_ptr());
|
||||
}
|
||||
for (const auto& output_desc : output_descriptions) {
|
||||
std::shared_ptr<opset8::Result> result = body->get_results()[output_desc->m_body_value_index];
|
||||
const auto& in_value = result->input_value(0);
|
||||
|
||||
// set output name to Tensor to store it for ngraph to cnn conversion
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
in_value.get_tensor().set_name(op::util::create_ie_output_name(if_node->output(output_desc->m_output_index)));
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
for (const auto& input : if_node->output(output_desc->m_output_index).get_target_inputs()) {
|
||||
input.replace_source_output(result->get_input_source_output(0));
|
||||
}
|
||||
}
|
||||
is_applicable = true;
|
||||
f->add_sinks(body->get_sinks());
|
||||
copy_runtime_info(if_node, body->get_ops());
|
||||
}
|
||||
return is_applicable;
|
||||
}
|
@ -0,0 +1,115 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <string>
|
||||
#include <transformations/control_flow/unroll_if.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST(TransformationTests, UnrollIfCondIsTrue) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto X = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto Y = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto cond = std::make_shared<ngraph::opset1::Constant>(ngraph::element::boolean, ngraph::Shape{ 1 }, true);
|
||||
auto if_op = std::make_shared<ngraph::opset8::If>(cond);
|
||||
auto Xt = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto Yt = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto add_op = std::make_shared<ngraph::opset1::Add>(Xt, Yt);
|
||||
auto then_op_result = std::make_shared<ngraph::opset1::Result>(add_op);
|
||||
auto then_body = std::make_shared<ngraph::Function>(ngraph::OutputVector{ then_op_result }, ngraph::ParameterVector{ Xt, Yt });
|
||||
|
||||
auto Xe = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto Ye = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto mul_op = std::make_shared<ngraph::opset1::Multiply>(Xe, Ye);
|
||||
auto else_op_result = std::make_shared<ngraph::opset1::Result>(mul_op);
|
||||
auto else_body = std::make_shared<ngraph::Function>(ngraph::OutputVector{ else_op_result }, ngraph::ParameterVector{ Xe, Ye });
|
||||
|
||||
if_op->set_then_body(then_body);
|
||||
if_op->set_else_body(else_body);
|
||||
if_op->set_input(X, Xt, Xe);
|
||||
if_op->set_input(Y, Yt, Ye);
|
||||
if_op->set_output(then_op_result, else_op_result);
|
||||
auto if_result = std::make_shared<ngraph::opset1::Result>(if_op);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ if_result }, ngraph::ParameterVector{ X, Y });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::UnrollIf>();
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto Y = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto add_op = std::make_shared<ngraph::opset1::Add>(X, Y);
|
||||
auto if_result = std::make_shared<ngraph::opset1::Result>(add_op);
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ if_result }, ngraph::ParameterVector{ X, Y });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, UnrollIfCondIsFalse) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto X = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto Y = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto cond = std::make_shared<ngraph::opset1::Constant>(ngraph::element::boolean, ngraph::Shape{ 1 }, false);
|
||||
auto if_op = std::make_shared<ngraph::opset8::If>(cond);
|
||||
auto Xt = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto Yt = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto add_op = std::make_shared<ngraph::opset1::Add>(Xt, Yt);
|
||||
auto then_op_result = std::make_shared<ngraph::opset1::Result>(add_op);
|
||||
auto then_body = std::make_shared<ngraph::Function>(ngraph::OutputVector{ then_op_result }, ngraph::ParameterVector{ Xt, Yt });
|
||||
|
||||
auto Xe = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto Ye = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto mul_op = std::make_shared<ngraph::opset1::Multiply>(Xe, Ye);
|
||||
auto else_op_result = std::make_shared<ngraph::opset1::Result>(mul_op);
|
||||
auto else_body = std::make_shared<ngraph::Function>(ngraph::OutputVector{ else_op_result }, ngraph::ParameterVector{ Xe, Ye });
|
||||
|
||||
if_op->set_then_body(then_body);
|
||||
if_op->set_else_body(else_body);
|
||||
if_op->set_input(X, Xt, Xe);
|
||||
if_op->set_input(Y, Yt, Ye);
|
||||
if_op->set_output(then_op_result, else_op_result);
|
||||
auto if_result = std::make_shared<ngraph::opset1::Result>(if_op);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ if_result }, ngraph::ParameterVector{ X, Y });
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::UnrollIf>();
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto Y = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3 });
|
||||
auto mul_op = std::make_shared<ngraph::opset1::Multiply>(X, Y);
|
||||
auto if_result = std::make_shared<ngraph::opset1::Result>(mul_op);
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ if_result }, ngraph::ParameterVector{ X, Y });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
@ -131,6 +131,8 @@ std::vector<std::string> disabledTestPatterns() {
|
||||
R"(smoke_PrePostProcess.*resize_and_convert_layout_i8.*)",
|
||||
// Issue 67910
|
||||
R"(.*smoke_PrePostProcess.*two_inputs_trivial.*)",
|
||||
// TODO: CVS-67255
|
||||
R"(smoke_If.*SimpleIf2OutTest.*)"
|
||||
};
|
||||
|
||||
#define FIX_62820 0
|
||||
|
Loading…
Reference in New Issue
Block a user