[ONNX] remove unnecessary nodes when handling Identity op (#8060)
This commit is contained in:
parent
c84db94697
commit
bcc1404c58
@ -15,13 +15,7 @@ namespace onnx_import {
|
|||||||
namespace op {
|
namespace op {
|
||||||
namespace set_1 {
|
namespace set_1 {
|
||||||
inline OutputVector identity(const Node& node) {
|
inline OutputVector identity(const Node& node) {
|
||||||
auto input = node.get_ng_inputs().at(0);
|
return node.get_ng_inputs();
|
||||||
if (input.get_element_type() == ngraph::element::boolean) {
|
|
||||||
const auto logic_zero = default_opset::Constant::create(ngraph::element::boolean, {}, {false});
|
|
||||||
return {std::make_shared<default_opset::LogicalOr>(input, logic_zero)};
|
|
||||||
}
|
|
||||||
const auto zero = default_opset::Constant::create(input.get_element_type(), {}, {0});
|
|
||||||
return {std::make_shared<default_opset::Add>(input, zero)};
|
|
||||||
}
|
}
|
||||||
} // namespace set_1
|
} // namespace set_1
|
||||||
|
|
||||||
|
@ -28,24 +28,13 @@ namespace {
|
|||||||
/// As a result ngraph Loop shape inference is able to handle more
|
/// As a result ngraph Loop shape inference is able to handle more
|
||||||
/// cases.
|
/// cases.
|
||||||
///
|
///
|
||||||
/// \param[in] body_out_cond Termination loop condition input of the body of
|
/// \param[in] cond_in boolean input to the loop body depicting loop termination condition
|
||||||
/// the Loop (value updated during Loop iterations).
|
|
||||||
///
|
///
|
||||||
/// \return true if termination condition is true and it cannot be changed
|
/// \param[in] cond_out loop termination condition computed after each iteration
|
||||||
/// during Loop iterations, false otherwise.
|
///
|
||||||
bool is_termination_condition_always_true(const Output<ngraph::Node>& body_out_cond) {
|
/// \return true if termination condition is not modified during loop iterations, false otherwise.
|
||||||
// If body termination condition input matches Indentity op pattern the has
|
bool is_termination_condition_always_true(const ngraph::Node* cond_in, const ngraph::Node* cond_out) {
|
||||||
// value of loop_cond - true
|
return cond_in == cond_out;
|
||||||
// Identity op for boolean value is represented by LogicalOr op whose second
|
|
||||||
// input is always false
|
|
||||||
if (ov::is_type<default_opset::LogicalOr>(body_out_cond.get_node_shared_ptr())) {
|
|
||||||
const auto second_input = body_out_cond.get_node_shared_ptr()->input_value(1).get_node_shared_ptr();
|
|
||||||
if (ngraph::op::is_constant(second_input) && second_input->get_element_type() == element::boolean &&
|
|
||||||
ov::as_type_ptr<default_opset::Constant>(second_input)->cast_vector<bool>().at(0) == false) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -112,9 +101,10 @@ OutputVector loop(const Node& node) {
|
|||||||
body_outputs[i] = std::make_shared<default_opset::Unsqueeze>(body_outputs[i], concat_axis_const);
|
body_outputs[i] = std::make_shared<default_opset::Unsqueeze>(body_outputs[i], concat_axis_const);
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& body_loop_out_cond = body_outputs.at(0).get_node_shared_ptr();
|
const auto& cond_in = body_inputs[1];
|
||||||
|
const auto& cond_out = body_outputs[0];
|
||||||
// optimization allow to improve nG Loop shape inference
|
// optimization allow to improve nG Loop shape inference
|
||||||
if (is_termination_condition_always_true(body_loop_out_cond)) {
|
if (is_termination_condition_always_true(cond_in.get(), cond_out.get_node())) {
|
||||||
body_outputs[0] = ngraph::op::Constant::create(ngraph::element::boolean, {1}, {true});
|
body_outputs[0] = ngraph::op::Constant::create(ngraph::element::boolean, {1}, {true});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9,6 +9,12 @@ graph {
|
|||||||
}
|
}
|
||||||
node {
|
node {
|
||||||
input: "relu_t"
|
input: "relu_t"
|
||||||
|
output: "abs_t"
|
||||||
|
op_type: "Abs"
|
||||||
|
name: "abs"
|
||||||
|
}
|
||||||
|
node {
|
||||||
|
input: "abs_t"
|
||||||
output: "final_output"
|
output: "final_output"
|
||||||
name: "ident"
|
name: "ident"
|
||||||
op_type: "Identity"
|
op_type: "Identity"
|
||||||
|
@ -22,14 +22,20 @@ NGRAPH_TEST(onnx_tensor_names, simple_model) {
|
|||||||
auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/tensor_names.onnx"));
|
auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/tensor_names.onnx"));
|
||||||
|
|
||||||
auto ops = function->get_ordered_ops();
|
auto ops = function->get_ordered_ops();
|
||||||
|
// Parameter
|
||||||
ASSERT_EQ(ops[0]->get_friendly_name(), "input");
|
ASSERT_EQ(ops[0]->get_friendly_name(), "input");
|
||||||
ASSERT_EQ(ops[0]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"input"});
|
ASSERT_EQ(ops[0]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"input"});
|
||||||
|
|
||||||
ASSERT_EQ(ops[1]->get_friendly_name(), "relu_t");
|
ASSERT_EQ(ops[1]->get_friendly_name(), "relu_t");
|
||||||
ASSERT_EQ(ops[1]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"relu_t"});
|
ASSERT_EQ(ops[1]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"relu_t"});
|
||||||
// ops[2] is a constant created in the ONNX importer as part of Identity operator
|
|
||||||
|
// should be abs_t, but Identity operator gets cut out and that makes Abs become the 'final_output'
|
||||||
|
ASSERT_EQ(ops[2]->get_friendly_name(), "final_output");
|
||||||
|
ASSERT_EQ(ops[2]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
|
||||||
|
|
||||||
|
// Result node
|
||||||
ASSERT_EQ(ops[3]->get_friendly_name(), "final_output");
|
ASSERT_EQ(ops[3]->get_friendly_name(), "final_output");
|
||||||
ASSERT_EQ(ops[3]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
|
ASSERT_EQ(ops[3]->get_output_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
|
||||||
ASSERT_EQ(ops[4]->get_friendly_name(), "final_output");
|
|
||||||
|
|
||||||
ASSERT_EQ(function->get_result()->get_input_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
|
ASSERT_EQ(function->get_result()->get_input_tensor(0).get_names(), std::unordered_set<std::string>{"final_output"});
|
||||||
ASSERT_EQ(function->get_result()->input_value(0).get_tensor().get_names(),
|
ASSERT_EQ(function->get_result()->input_value(0).get_tensor().get_names(),
|
||||||
|
Loading…
Reference in New Issue
Block a user