[ONNX] remove unnecessary nodes when handling Identity op (#8060)

This commit is contained in:
Mateusz Tabaka 2021-10-19 10:18:13 +02:00 committed by GitHub
parent c84db94697
commit bcc1404c58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 28 deletions

View File

@ -15,13 +15,7 @@ namespace onnx_import {
namespace op { namespace op {
namespace set_1 { namespace set_1 {
inline OutputVector identity(const Node& node) { inline OutputVector identity(const Node& node) {
auto input = node.get_ng_inputs().at(0); return node.get_ng_inputs();
if (input.get_element_type() == ngraph::element::boolean) {
const auto logic_zero = default_opset::Constant::create(ngraph::element::boolean, {}, {false});
return {std::make_shared<default_opset::LogicalOr>(input, logic_zero)};
}
const auto zero = default_opset::Constant::create(input.get_element_type(), {}, {0});
return {std::make_shared<default_opset::Add>(input, zero)};
} }
} // namespace set_1 } // namespace set_1

View File

@ -28,24 +28,13 @@ namespace {
/// As a result ngraph Loop shape inference is able to handle more /// As a result ngraph Loop shape inference is able to handle more
/// cases. /// cases.
/// ///
/// \param[in] body_out_cond Termination loop condition input of the body of /// \param[in] cond_in boolean input to the loop body depicting loop termination condition
/// the Loop (value updated during Loop iterations).
/// ///
/// \return true if termination condition is true and it cannot be changed /// \param[in] cond_out loop termination condition computed after each iteration
/// during Loop iterations, false otherwise. ///
bool is_termination_condition_always_true(const Output<ngraph::Node>& body_out_cond) { /// \return true if termination condition is not modified during loop iterations, false otherwise.
// If body termination condition input matches Indentity op pattern the has bool is_termination_condition_always_true(const ngraph::Node* cond_in, const ngraph::Node* cond_out) {
// value of loop_cond - true return cond_in == cond_out;
// Identity op for boolean value is represented by LogicalOr op whose second
// input is always false
if (ov::is_type<default_opset::LogicalOr>(body_out_cond.get_node_shared_ptr())) {
const auto second_input = body_out_cond.get_node_shared_ptr()->input_value(1).get_node_shared_ptr();
if (ngraph::op::is_constant(second_input) && second_input->get_element_type() == element::boolean &&
ov::as_type_ptr<default_opset::Constant>(second_input)->cast_vector<bool>().at(0) == false) {
return true;
}
}
return false;
} }
} // namespace } // namespace
@ -112,9 +101,10 @@ OutputVector loop(const Node& node) {
body_outputs[i] = std::make_shared<default_opset::Unsqueeze>(body_outputs[i], concat_axis_const); body_outputs[i] = std::make_shared<default_opset::Unsqueeze>(body_outputs[i], concat_axis_const);
} }
const auto& body_loop_out_cond = body_outputs.at(0).get_node_shared_ptr(); const auto& cond_in = body_inputs[1];
const auto& cond_out = body_outputs[0];
// optimization allow to improve nG Loop shape inference // optimization allow to improve nG Loop shape inference
if (is_termination_condition_always_true(body_loop_out_cond)) { if (is_termination_condition_always_true(cond_in.get(), cond_out.get_node())) {
body_outputs[0] = ngraph::op::Constant::create(ngraph::element::boolean, {1}, {true}); body_outputs[0] = ngraph::op::Constant::create(ngraph::element::boolean, {1}, {true});
} }

View File

@ -9,6 +9,12 @@ graph {
} }
node { node {
input: "relu_t" input: "relu_t"
output: "abs_t"
op_type: "Abs"
name: "abs"
}
node {
input: "abs_t"
output: "final_output" output: "final_output"
name: "ident" name: "ident"
op_type: "Identity" op_type: "Identity"

View File

@ -22,14 +22,20 @@ NGRAPH_TEST(onnx_tensor_names, simple_model) {
auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/tensor_names.onnx")); auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/tensor_names.onnx"));
auto ops = function->get_ordered_ops(); auto ops = function->get_ordered_ops();
// Parameter
ASSERT_EQ(ops[0]->get_friendly_name(), "input"); ASSERT_EQ(ops[0]->get_friendly_name(), "input");
ASSERT_EQ(ops[0]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"input"}); ASSERT_EQ(ops[0]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"input"});
ASSERT_EQ(ops[1]->get_friendly_name(), "relu_t"); ASSERT_EQ(ops[1]->get_friendly_name(), "relu_t");
ASSERT_EQ(ops[1]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"relu_t"}); ASSERT_EQ(ops[1]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"relu_t"});
// ops[2] is a constant created in the ONNX importer as part of Identity operator
// should be abs_t, but Identity operator gets cut out and that makes Abs become the 'final_output'
ASSERT_EQ(ops[2]->get_friendly_name(), "final_output");
ASSERT_EQ(ops[2]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
// Result node
ASSERT_EQ(ops[3]->get_friendly_name(), "final_output"); ASSERT_EQ(ops[3]->get_friendly_name(), "final_output");
ASSERT_EQ(ops[3]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"}); ASSERT_EQ(ops[3]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
ASSERT_EQ(ops[4]->get_friendly_name(), "final_output");
ASSERT_EQ(function->get_result()->get_input_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"}); ASSERT_EQ(function->get_result()->get_input_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
ASSERT_EQ(function->get_result()->input_value(0).get_tensor().get_names(), ASSERT_EQ(function->get_result()->input_value(0).get_tensor().get_names(),