Drop comparison of inputs and outputs for SubGraph. (#4716)

SubGraph might have inputs and outputs in different order and still be
the same Function.

Co-authored-by: Patryk Elszkowski <patryk.elszkowki@intel.com>
This commit is contained in:
Patryk Elszkowski 2021-03-12 05:11:25 +01:00 committed by GitHub
parent 18dd574864
commit 1799df4cc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -139,6 +139,10 @@ private:
///
Result compare(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log);
void compare_inputs(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log);
void compare_outputs(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log);
void add_nodes_inputs_to_queue(ngraph::Node* node1, ngraph::Node* node2);
//-- DATA --
@ -1360,7 +1364,9 @@ Comparator::Result Comparator::compare(
auto subgraph1 = dynamic_cast<ngraph::op::util::SubGraphOp*>(node1);
auto subgraph2 = dynamic_cast<ngraph::op::util::SubGraphOp*>(node2);
if (subgraph1 && subgraph2) {
const bool subgraph_nodes = subgraph1 && subgraph2;
if (subgraph_nodes) {
const auto result = subgraph::compare_io(subgraph1, subgraph2);
if (!result.valid) {
return result;
@ -1388,6 +1394,23 @@ Comparator::Result Comparator::compare(
name(node1) + " and " + to_str(node2->inputs().size()) + " for " + name(node2));
}
if (!subgraph_nodes) {
compare_inputs(node1, node2, err_log);
compare_outputs(node1, node2, err_log);
}
if (should_compare(CmpValues::ATTRIBUTES)) {
const auto result = attributes::compare(node1, node2, m_comparition_flags);
if (!result.valid) {
return result;
}
}
return Result::ok("Check if any minor error was log in to err_log");
}
void Comparator::compare_inputs(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) {
for (size_t i = 0; i < node1->inputs().size(); ++i) {
if (should_compare(CmpValues::CONST_VALUES)) {
using Constant = ngraph::opset1::Constant;
@ -1435,7 +1458,9 @@ Comparator::Result Comparator::compare(
<< std::endl;
}
}
}
void Comparator::compare_outputs(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) {
for (int i = 0; i < node1->outputs().size(); ++i) {
const auto& tensor1 = node1->output(i).get_tensor();
const auto& tensor2 = node2->output(i).get_tensor();
@ -1455,15 +1480,6 @@ Comparator::Result Comparator::compare(
<< i << ") " << node2->output(i).get_partial_shape() << std::endl;
}
}
if (should_compare(CmpValues::ATTRIBUTES)) {
const auto result = attributes::compare(node1, node2, m_comparition_flags);
if (!result.valid) {
return result;
}
}
return Result::ok("Check if any minor error was log in to err_log");
}
void Comparator::add_nodes_inputs_to_queue(ngraph::Node* node1, ngraph::Node* node2) {