From bcc1404c58b5df24f3fc31e87e5ef19b31e3b077 Mon Sep 17 00:00:00 2001 From: Mateusz Tabaka Date: Tue, 19 Oct 2021 10:18:13 +0200 Subject: [PATCH] [ONNX] remove unnecessary nodes when handling Identity op (#8060) --- .../onnx/frontend/src/op/identity.hpp | 8 +----- ngraph/frontend/onnx/frontend/src/op/loop.cpp | 28 ++++++------------- ngraph/test/models/onnx/tensor_names.prototxt | 6 ++++ ngraph/test/onnx/onnx_tensor_names.cpp | 10 +++++-- 4 files changed, 24 insertions(+), 28 deletions(-) diff --git a/ngraph/frontend/onnx/frontend/src/op/identity.hpp b/ngraph/frontend/onnx/frontend/src/op/identity.hpp index e811a73d0e7..16217f66709 100644 --- a/ngraph/frontend/onnx/frontend/src/op/identity.hpp +++ b/ngraph/frontend/onnx/frontend/src/op/identity.hpp @@ -15,13 +15,7 @@ namespace onnx_import { namespace op { namespace set_1 { inline OutputVector identity(const Node& node) { - auto input = node.get_ng_inputs().at(0); - if (input.get_element_type() == ngraph::element::boolean) { - const auto logic_zero = default_opset::Constant::create(ngraph::element::boolean, {}, {false}); - return {std::make_shared(input, logic_zero)}; - } - const auto zero = default_opset::Constant::create(input.get_element_type(), {}, {0}); - return {std::make_shared(input, zero)}; + return node.get_ng_inputs(); } } // namespace set_1 diff --git a/ngraph/frontend/onnx/frontend/src/op/loop.cpp b/ngraph/frontend/onnx/frontend/src/op/loop.cpp index 979f1665320..cf95abd68e4 100644 --- a/ngraph/frontend/onnx/frontend/src/op/loop.cpp +++ b/ngraph/frontend/onnx/frontend/src/op/loop.cpp @@ -28,24 +28,13 @@ namespace { /// As a result ngraph Loop shape inference is able to handle more /// cases. /// -/// \param[in] body_out_cond Termination loop condition input of the body of -/// the Loop (value updated during Loop iterations). +/// \param[in] cond_in boolean input to the loop body depicting loop termination condition /// -/// \return true if termination condition is true and it cannot be changed -/// during Loop iterations, false otherwise. -bool is_termination_condition_always_true(const Output& body_out_cond) { - // If body termination condition input matches Indentity op pattern the has - // value of loop_cond - true - // Identity op for boolean value is represented by LogicalOr op whose second - // input is always false - if (ov::is_type(body_out_cond.get_node_shared_ptr())) { - const auto second_input = body_out_cond.get_node_shared_ptr()->input_value(1).get_node_shared_ptr(); - if (ngraph::op::is_constant(second_input) && second_input->get_element_type() == element::boolean && - ov::as_type_ptr(second_input)->cast_vector().at(0) == false) { - return true; - } - } - return false; +/// \param[in] cond_out loop termination condition computed after each iteration +/// +/// \return true if termination condition is not modified during loop iterations, false otherwise. +bool is_termination_condition_always_true(const ngraph::Node* cond_in, const ngraph::Node* cond_out) { + return cond_in == cond_out; } } // namespace @@ -112,9 +101,10 @@ OutputVector loop(const Node& node) { body_outputs[i] = std::make_shared(body_outputs[i], concat_axis_const); } - const auto& body_loop_out_cond = body_outputs.at(0).get_node_shared_ptr(); + const auto& cond_in = body_inputs[1]; + const auto& cond_out = body_outputs[0]; // optimization allow to improve nG Loop shape inference - if (is_termination_condition_always_true(body_loop_out_cond)) { + if (is_termination_condition_always_true(cond_in.get(), cond_out.get_node())) { body_outputs[0] = ngraph::op::Constant::create(ngraph::element::boolean, {1}, {true}); } diff --git a/ngraph/test/models/onnx/tensor_names.prototxt b/ngraph/test/models/onnx/tensor_names.prototxt index 184674b3bf0..784beb19916 100644 --- a/ngraph/test/models/onnx/tensor_names.prototxt +++ b/ngraph/test/models/onnx/tensor_names.prototxt @@ -9,6 +9,12 @@ graph { } node { input: "relu_t" + output: "abs_t" + op_type: "Abs" + name: "abs" + } + node { + input: "abs_t" output: "final_output" name: "ident" op_type: "Identity" diff --git a/ngraph/test/onnx/onnx_tensor_names.cpp b/ngraph/test/onnx/onnx_tensor_names.cpp index 198f8f4d2a1..1dcbb65a552 100644 --- a/ngraph/test/onnx/onnx_tensor_names.cpp +++ b/ngraph/test/onnx/onnx_tensor_names.cpp @@ -22,14 +22,20 @@ NGRAPH_TEST(onnx_tensor_names, simple_model) { auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/tensor_names.onnx")); auto ops = function->get_ordered_ops(); + // Parameter ASSERT_EQ(ops[0]->get_friendly_name(), "input"); ASSERT_EQ(ops[0]->get_output_tensor(0).get_names(), std::unordered_set{"input"}); + ASSERT_EQ(ops[1]->get_friendly_name(), "relu_t"); ASSERT_EQ(ops[1]->get_output_tensor(0).get_names(), std::unordered_set{"relu_t"}); - // ops[2] is a constant created in the ONNX importer as part of Identity operator + + // should be abs_t, but Identity operator gets cut out and that makes Abs become the 'final_output' + ASSERT_EQ(ops[2]->get_friendly_name(), "final_output"); + ASSERT_EQ(ops[2]->get_output_tensor(0).get_names(), std::unordered_set{"final_output"}); + + // Result node ASSERT_EQ(ops[3]->get_friendly_name(), "final_output"); ASSERT_EQ(ops[3]->get_output_tensor(0).get_names(), std::unordered_set{"final_output"}); - ASSERT_EQ(ops[4]->get_friendly_name(), "final_output"); ASSERT_EQ(function->get_result()->get_input_tensor(0).get_names(), std::unordered_set{"final_output"}); ASSERT_EQ(function->get_result()->input_value(0).get_tensor().get_names(),