[ONNX] remove unnecessary nodes when handling Identity op (#8060)

This commit is contained in:
Mateusz Tabaka 2021-10-19 10:18:13 +02:00 committed by GitHub
parent c84db94697
commit bcc1404c58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 28 deletions

View File

@ -15,13 +15,7 @@ namespace onnx_import {
namespace op {
namespace set_1 {
inline OutputVector identity(const Node& node) {
auto input = node.get_ng_inputs().at(0);
if (input.get_element_type() == ngraph::element::boolean) {
const auto logic_zero = default_opset::Constant::create(ngraph::element::boolean, {}, {false});
return {std::make_shared<default_opset::LogicalOr>(input, logic_zero)};
}
const auto zero = default_opset::Constant::create(input.get_element_type(), {}, {0});
return {std::make_shared<default_opset::Add>(input, zero)};
return node.get_ng_inputs();
}
} // namespace set_1

View File

@ -28,24 +28,13 @@ namespace {
/// As a result ngraph Loop shape inference is able to handle more
/// cases.
///
/// \param[in] body_out_cond Termination loop condition input of the body of
/// the Loop (value updated during Loop iterations).
/// \param[in] cond_in boolean input to the loop body depicting loop termination condition
///
/// \return true if termination condition is true and it cannot be changed
/// during Loop iterations, false otherwise.
bool is_termination_condition_always_true(const Output<ngraph::Node>& body_out_cond) {
// If body termination condition input matches Indentity op pattern the has
// value of loop_cond - true
// Identity op for boolean value is represented by LogicalOr op whose second
// input is always false
if (ov::is_type<default_opset::LogicalOr>(body_out_cond.get_node_shared_ptr())) {
const auto second_input = body_out_cond.get_node_shared_ptr()->input_value(1).get_node_shared_ptr();
if (ngraph::op::is_constant(second_input) && second_input->get_element_type() == element::boolean &&
ov::as_type_ptr<default_opset::Constant>(second_input)->cast_vector<bool>().at(0) == false) {
return true;
}
}
return false;
/// \param[in] cond_out loop termination condition computed after each iteration
///
/// \return true if termination condition is not modified during loop iterations, false otherwise.
bool is_termination_condition_always_true(const ngraph::Node* cond_in, const ngraph::Node* cond_out) {
return cond_in == cond_out;
}
} // namespace
@ -112,9 +101,10 @@ OutputVector loop(const Node& node) {
body_outputs[i] = std::make_shared<default_opset::Unsqueeze>(body_outputs[i], concat_axis_const);
}
const auto& body_loop_out_cond = body_outputs.at(0).get_node_shared_ptr();
const auto& cond_in = body_inputs[1];
const auto& cond_out = body_outputs[0];
// optimization allow to improve nG Loop shape inference
if (is_termination_condition_always_true(body_loop_out_cond)) {
if (is_termination_condition_always_true(cond_in.get(), cond_out.get_node())) {
body_outputs[0] = ngraph::op::Constant::create(ngraph::element::boolean, {1}, {true});
}

View File

@ -9,6 +9,12 @@ graph {
}
node {
input: "relu_t"
output: "abs_t"
op_type: "Abs"
name: "abs"
}
node {
input: "abs_t"
output: "final_output"
name: "ident"
op_type: "Identity"

View File

@ -22,14 +22,20 @@ NGRAPH_TEST(onnx_tensor_names, simple_model) {
auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/tensor_names.onnx"));
auto ops = function->get_ordered_ops();
// Parameter
ASSERT_EQ(ops[0]->get_friendly_name(), "input");
ASSERT_EQ(ops[0]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"input"});
ASSERT_EQ(ops[1]->get_friendly_name(), "relu_t");
ASSERT_EQ(ops[1]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"relu_t"});
// ops[2] is a constant created in the ONNX importer as part of Identity operator
// should be abs_t, but Identity operator gets cut out and that makes Abs become the 'final_output'
ASSERT_EQ(ops[2]->get_friendly_name(), "final_output");
ASSERT_EQ(ops[2]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
// Result node
ASSERT_EQ(ops[3]->get_friendly_name(), "final_output");
ASSERT_EQ(ops[3]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
ASSERT_EQ(ops[4]->get_friendly_name(), "final_output");
ASSERT_EQ(function->get_result()->get_input_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
ASSERT_EQ(function->get_result()->input_value(0).get_tensor().get_names(),