diff --git a/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp b/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp index 867c921c53e..4b8873a8062 100644 --- a/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp +++ b/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp @@ -139,6 +139,10 @@ private: /// Result compare(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log); + void compare_inputs(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log); + + void compare_outputs(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log); + void add_nodes_inputs_to_queue(ngraph::Node* node1, ngraph::Node* node2); //-- DATA -- @@ -1360,7 +1364,9 @@ Comparator::Result Comparator::compare( auto subgraph1 = dynamic_cast(node1); auto subgraph2 = dynamic_cast(node2); - if (subgraph1 && subgraph2) { + const bool subgraph_nodes = subgraph1 && subgraph2; + + if (subgraph_nodes) { const auto result = subgraph::compare_io(subgraph1, subgraph2); if (!result.valid) { return result; @@ -1388,6 +1394,23 @@ Comparator::Result Comparator::compare( name(node1) + " and " + to_str(node2->inputs().size()) + " for " + name(node2)); } + if (!subgraph_nodes) { + compare_inputs(node1, node2, err_log); + compare_outputs(node1, node2, err_log); + } + + if (should_compare(CmpValues::ATTRIBUTES)) { + const auto result = attributes::compare(node1, node2, m_comparition_flags); + if (!result.valid) { + return result; + } + } + + return Result::ok("Check if any minor error was log in to err_log"); +} + + +void Comparator::compare_inputs(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) { for (size_t i = 0; i < node1->inputs().size(); ++i) { if (should_compare(CmpValues::CONST_VALUES)) { using Constant = ngraph::opset1::Constant; @@ -1435,7 +1458,9 @@ Comparator::Result Comparator::compare( << std::endl; } } +} +void Comparator::compare_outputs(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) { for (int i = 0; i < node1->outputs().size(); ++i) { const auto& tensor1 = node1->output(i).get_tensor(); const auto& tensor2 = node2->output(i).get_tensor(); @@ -1455,15 +1480,6 @@ Comparator::Result Comparator::compare( << i << ") " << node2->output(i).get_partial_shape() << std::endl; } } - - if (should_compare(CmpValues::ATTRIBUTES)) { - const auto result = attributes::compare(node1, node2, m_comparition_flags); - if (!result.valid) { - return result; - } - } - - return Result::ok("Check if any minor error was log in to err_log"); } void Comparator::add_nodes_inputs_to_queue(ngraph::Node* node1, ngraph::Node* node2) {