[ONNX] remove unnecessary nodes when handling Identity op (#8060)
This commit is contained in:
parent
c84db94697
commit
bcc1404c58
@ -15,13 +15,7 @@ namespace onnx_import {
|
||||
namespace op {
|
||||
namespace set_1 {
|
||||
inline OutputVector identity(const Node& node) {
|
||||
auto input = node.get_ng_inputs().at(0);
|
||||
if (input.get_element_type() == ngraph::element::boolean) {
|
||||
const auto logic_zero = default_opset::Constant::create(ngraph::element::boolean, {}, {false});
|
||||
return {std::make_shared<default_opset::LogicalOr>(input, logic_zero)};
|
||||
}
|
||||
const auto zero = default_opset::Constant::create(input.get_element_type(), {}, {0});
|
||||
return {std::make_shared<default_opset::Add>(input, zero)};
|
||||
return node.get_ng_inputs();
|
||||
}
|
||||
} // namespace set_1
|
||||
|
||||
|
@ -28,24 +28,13 @@ namespace {
|
||||
/// As a result ngraph Loop shape inference is able to handle more
|
||||
/// cases.
|
||||
///
|
||||
/// \param[in] body_out_cond Termination loop condition input of the body of
|
||||
/// the Loop (value updated during Loop iterations).
|
||||
/// \param[in] cond_in boolean input to the loop body depicting loop termination condition
|
||||
///
|
||||
/// \return true if termination condition is true and it cannot be changed
|
||||
/// during Loop iterations, false otherwise.
|
||||
bool is_termination_condition_always_true(const Output<ngraph::Node>& body_out_cond) {
|
||||
// If body termination condition input matches Indentity op pattern the has
|
||||
// value of loop_cond - true
|
||||
// Identity op for boolean value is represented by LogicalOr op whose second
|
||||
// input is always false
|
||||
if (ov::is_type<default_opset::LogicalOr>(body_out_cond.get_node_shared_ptr())) {
|
||||
const auto second_input = body_out_cond.get_node_shared_ptr()->input_value(1).get_node_shared_ptr();
|
||||
if (ngraph::op::is_constant(second_input) && second_input->get_element_type() == element::boolean &&
|
||||
ov::as_type_ptr<default_opset::Constant>(second_input)->cast_vector<bool>().at(0) == false) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
/// \param[in] cond_out loop termination condition computed after each iteration
|
||||
///
|
||||
/// \return true if termination condition is not modified during loop iterations, false otherwise.
|
||||
bool is_termination_condition_always_true(const ngraph::Node* cond_in, const ngraph::Node* cond_out) {
|
||||
return cond_in == cond_out;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@ -112,9 +101,10 @@ OutputVector loop(const Node& node) {
|
||||
body_outputs[i] = std::make_shared<default_opset::Unsqueeze>(body_outputs[i], concat_axis_const);
|
||||
}
|
||||
|
||||
const auto& body_loop_out_cond = body_outputs.at(0).get_node_shared_ptr();
|
||||
const auto& cond_in = body_inputs[1];
|
||||
const auto& cond_out = body_outputs[0];
|
||||
// optimization allow to improve nG Loop shape inference
|
||||
if (is_termination_condition_always_true(body_loop_out_cond)) {
|
||||
if (is_termination_condition_always_true(cond_in.get(), cond_out.get_node())) {
|
||||
body_outputs[0] = ngraph::op::Constant::create(ngraph::element::boolean, {1}, {true});
|
||||
}
|
||||
|
||||
|
@ -9,6 +9,12 @@ graph {
|
||||
}
|
||||
node {
|
||||
input: "relu_t"
|
||||
output: "abs_t"
|
||||
op_type: "Abs"
|
||||
name: "abs"
|
||||
}
|
||||
node {
|
||||
input: "abs_t"
|
||||
output: "final_output"
|
||||
name: "ident"
|
||||
op_type: "Identity"
|
||||
|
@ -22,14 +22,20 @@ NGRAPH_TEST(onnx_tensor_names, simple_model) {
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/tensor_names.onnx"));
|
||||
|
||||
auto ops = function->get_ordered_ops();
|
||||
// Parameter
|
||||
ASSERT_EQ(ops[0]->get_friendly_name(), "input");
|
||||
ASSERT_EQ(ops[0]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"input"});
|
||||
|
||||
ASSERT_EQ(ops[1]->get_friendly_name(), "relu_t");
|
||||
ASSERT_EQ(ops[1]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"relu_t"});
|
||||
// ops[2] is a constant created in the ONNX importer as part of Identity operator
|
||||
|
||||
// should be abs_t, but Identity operator gets cut out and that makes Abs become the 'final_output'
|
||||
ASSERT_EQ(ops[2]->get_friendly_name(), "final_output");
|
||||
ASSERT_EQ(ops[2]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
|
||||
|
||||
// Result node
|
||||
ASSERT_EQ(ops[3]->get_friendly_name(), "final_output");
|
||||
ASSERT_EQ(ops[3]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
|
||||
ASSERT_EQ(ops[4]->get_friendly_name(), "final_output");
|
||||
|
||||
ASSERT_EQ(function->get_result()->get_input_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
|
||||
ASSERT_EQ(function->get_result()->input_value(0).get_tensor().get_names(),
|
||||
|
Loading…
Reference in New Issue
Block a user