Cleanup in ngraph_test_utils.hpp/cpp (#3959)

Co-authored-by: Patryk Elszkowski <patryk.elszkowki@intel.com>
This commit is contained in:
Patryk Elszkowski
2021-01-22 11:02:33 +01:00
committed by GitHub
parent 5993383e31
commit cce0328947
2 changed files with 135 additions and 83 deletions

View File

@@ -1,33 +1,39 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <string>
#include <memory>
#include <queue>
#include <assert.h>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/op/util/op_types.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include "ngraph_test_utils.hpp"
#include <cassert>
#include <memory>
#include <queue>
#include <string>
#include <ngraph/function.hpp>
#include <ngraph/op/util/op_types.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pass/visualize_tree.hpp>
namespace {
bool isTypeRelaxed(const std::string& type) {
return type.find_first_of("TypeRelaxed") == 0;
}
bool compareTypeInfo(const ngraph::DiscreteTypeInfo& info1, const ngraph::DiscreteTypeInfo& info2) {
if (!isTypeRelaxed(info1.name) && !isTypeRelaxed(info2.name) && (info1.version != info2.version)) {
if (!isTypeRelaxed(info1.name) && !isTypeRelaxed(info2.name) &&
(info1.version != info2.version)) {
return false;
}
const std::string info1Name = isTypeRelaxed(info1.name) && (info1.parent != nullptr) ? info1.parent->name : info1.name;
const std::string info2Name = isTypeRelaxed(info2.name) && (info2.parent != nullptr) ? info2.parent->name : info2.name;
const std::string info1Name =
isTypeRelaxed(info1.name) && (info1.parent != nullptr) ? info1.parent->name : info1.name;
const std::string info2Name =
isTypeRelaxed(info2.name) && (info2.parent != nullptr) ? info2.parent->name : info2.name;
return info1Name == info2Name;
}
bool compare_rt_keys(const std::shared_ptr<ngraph::Node>& node1, const std::shared_ptr<ngraph::Node>& node2) {
template <typename Node>
bool compare_rt_keys(const Node& node1, const Node& node2) {
const auto& first_node_rt_info = node1->get_rt_info();
const auto& second_node_rt_info = node2->get_rt_info();
@@ -52,6 +58,44 @@ bool compare_rt_keys(const std::shared_ptr<ngraph::Node>& node1, const std::shar
return true;
}
bool less_by_name(
const std::shared_ptr<ngraph::op::v0::Result>& l,
const std::shared_ptr<ngraph::op::v0::Result>& r) {
return l->get_friendly_name() < r->get_friendly_name();
}
template <typename T>
std::string to_str(const T& v) {
return std::to_string(v);
}
std::pair<bool, std::string> error(std::string s) {
return {false, std::move(s)};
}
std::string typeInfoToStr(const ngraph::Node::type_info_t& typeInfo) {
return std::string(typeInfo.name) + "/" + to_str(typeInfo.version);
}
template <typename Node>
std::string name(const Node& n) {
return n->get_friendly_name();
}
template <typename Constant>
bool equal(const Constant& c1, const Constant& c2) {
const auto equal_float_str = [](const std::string& s1, const std::string s2) {
return std::abs(std::stof(s1) - std::stof(s2)) < 0.001;
};
const auto& c1v = c1.get_value_strings();
const auto& c2v = c2.get_value_strings();
return c1v.size() == c2v.size() &&
std::equal(begin(c1v), end(c1v), begin(c2v), equal_float_str);
}
} // namespace
std::pair<bool, std::string> compare_functions(
const std::shared_ptr<ngraph::Function>& f1,
const std::shared_ptr<ngraph::Function>& f2,
@@ -70,36 +114,37 @@ std::pair<bool, std::string> compare_functions(
auto f1_results = f1->get_results();
auto f2_results = f2->get_results();
auto compare_nodes_by_name = [](const std::shared_ptr<ngraph::Node> & l, const std::shared_ptr<ngraph::Node> & r)
{ return l->get_friendly_name() < r->get_friendly_name(); };
std::sort(f1_results.begin(), f1_results.end(), compare_nodes_by_name);
std::sort(f2_results.begin(), f2_results.end(), compare_nodes_by_name);
std::sort(f1_results.begin(), f1_results.end(), less_by_name);
std::sort(f2_results.begin(), f2_results.end(), less_by_name);
if (f1_results.size() != f2_results.size()) {
return { false, "Number of results is different: " + std::to_string(f1_results.size()) + " and " + std::to_string(f2_results.size()) };
return error(
"Number of results is different: " + to_str(f1_results.size()) + " and " + to_str(f2_results.size()));
}
const auto& f1_sinks = f1->get_sinks();
const auto& f2_sinks = f2->get_sinks();
if (f1_sinks.size() != f2_sinks.size()) {
return { false, "Number of sinks is different: " + std::to_string(f1_sinks.size()) + " and " + std::to_string(f2_sinks.size()) };
return error(
"Number of sinks is different: " + to_str(f1_sinks.size()) + " and " + to_str(f2_sinks.size()));
}
auto typeInfoToStr = [](const ngraph::Node::type_info_t & typeInfo) {
return std::string(typeInfo.name) + "/" + std::to_string(typeInfo.version);
};
std::ostringstream err_log;
std::queue<std::pair<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>> q;
using ComparedNodes = std::pair<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>;
std::queue<ComparedNodes> q;
for (size_t i = 0; i < f1_results.size(); ++i) {
if (compareNames) {
if (f1_results[i]->get_input_node_shared_ptr(0)->get_friendly_name() !=
f2_results[i]->get_input_node_shared_ptr(0)->get_friendly_name()) {
return { false, "Different output names: " + f1_results[i]->get_input_node_shared_ptr(0)->get_friendly_name()
+ " and " + f2_results[i]->get_input_node_shared_ptr(0)->get_friendly_name() };
if (name(f1_results[i]->get_input_node_shared_ptr(0)) !=
name(f2_results[i]->get_input_node_shared_ptr(0))) {
return error(
"Different output names: " + name(f1_results[i]->get_input_node_shared_ptr(0)) +
" and " + name(f2_results[i]->get_input_node_shared_ptr(0)));
}
}
q.push({ f1_results[i], f2_results[i] });
q.push({f1_results[i], f2_results[i]});
}
while (!q.empty()) {
@@ -111,7 +156,7 @@ std::pair<bool, std::string> compare_functions(
auto type_info2 = node2->get_type_info();
if (!compareTypeInfo(type_info1, type_info2)) {
return {false, typeInfoToStr(type_info1) + " != " + typeInfoToStr(type_info2)};
return error(typeInfoToStr(type_info1) + " != " + typeInfoToStr(type_info2));
}
auto subgraph1 = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node1);
@@ -127,19 +172,23 @@ std::pair<bool, std::string> compare_functions(
const auto& dependencies_1 = node1->get_control_dependencies();
const auto& dependencies_2 = node2->get_control_dependencies();
if (dependencies_1.size() != dependencies_2.size()) {
return {false, "Number of dependencies is different: " + std::to_string(dependencies_1.size()) + " for " + node1->get_friendly_name() +
+ " and " + std::to_string(dependencies_2.size()) + " for " + node2->get_friendly_name()};
return error(
"Number of dependencies is different: " + to_str(dependencies_1.size()) + " for " +
name(node1) + " and " + to_str(dependencies_2.size()) + " for " + name(node2));
}
if (node1->inputs().size() != node2->inputs().size()) {
return {false, "Number of inputs is different: " + std::to_string(node1->inputs().size()) + " for " + node1->get_friendly_name() +
+ " and " + std::to_string(node2->inputs().size()) + " for " + node2->get_friendly_name()};
return error(
"Number of inputs is different: " + to_str(node1->inputs().size()) + " for " +
name(node1) + " and " + to_str(node2->inputs().size()) + " for " + name(node2));
}
if (node1->outputs().size() != node2->outputs().size()) {
return { false, "Number of outputs is different: " + std::to_string(node1->inputs().size()) + " for " + node1->get_friendly_name() +
+ " and " + std::to_string(node2->inputs().size()) + " for " + node2->get_friendly_name()};
return error(
"Number of outputs is different: " + to_str(node1->inputs().size()) + " for " +
name(node1) + " and " + to_str(node2->inputs().size()) + " for " + name(node2));
}
for (int i = 0; i < node1->inputs().size(); ++i) {
@@ -148,19 +197,8 @@ std::pair<bool, std::string> compare_functions(
auto const1 = ngraph::as_type_ptr<Constant>(node1->get_input_node_shared_ptr(i));
auto const2 = ngraph::as_type_ptr<Constant>(node2->get_input_node_shared_ptr(i));
const auto equal = [](const Constant &c1, const Constant &c2) {
const auto equal_float_str = [](const std::string &s1, const std::string s2) {
return std::abs(std::stof(s1) - std::stof(s2)) < 0.001;
};
const auto &c1v = c1.get_value_strings();
const auto &c2v = c2.get_value_strings();
return c1v.size() == c2v.size()
&& std::equal(begin(c1v), end(c1v), begin(c2v), equal_float_str);
};
if (const1 && const2 && !equal(*const1, *const2)) {
err_log << "Different Constant values detected \n"
err_log << "Different Constant values detected\n"
<< node1->description() << " Input(" << i << ") and "
<< node2->description() << " Input(" << i << ")" << std::endl;
}
@@ -168,56 +206,71 @@ std::pair<bool, std::string> compare_functions(
if (comparePrecisions) {
if (node1->input(i).get_element_type() != node2->input(i).get_element_type()) {
err_log << "Different element type detected" << std::endl
<< node1->get_friendly_name() << " Input(" << i << ") " << node1->input(i).get_element_type() << " and "
<< node2->get_friendly_name() << " Input(" << i << ") " << node2->input(i).get_element_type() << std::endl;
err_log << "Different element type detected\n"
<< name(node1) << " Input(" << i << ") "
<< node1->input(i).get_element_type() << " and " << name(node2)
<< " Input(" << i << ") " << node2->input(i).get_element_type()
<< std::endl;
}
}
if (!node1->input(i).get_partial_shape().same_scheme(node2->input(i).get_partial_shape())) {
err_log << "Different shape detected" << std::endl
<< node1->get_friendly_name() << " Input(" << i << ") " << node1->input(i).get_partial_shape() << " and "
<< node2->get_friendly_name() << " Input(" << i << ") " << node2->input(i).get_partial_shape() << std::endl;
if (!node1->input(i).get_partial_shape().same_scheme(
node2->input(i).get_partial_shape())) {
err_log << "Different shape detected\n"
<< name(node1) << " Input(" << i << ") "
<< node1->input(i).get_partial_shape() << " and " << name(node2)
<< " Input(" << i << ") " << node2->input(i).get_partial_shape()
<< std::endl;
}
if (node1->get_input_source_output(i).get_index() != node2->get_input_source_output(i).get_index()) {
if (node1->get_input_source_output(i).get_index() !=
node2->get_input_source_output(i).get_index()) {
auto idx1 = node1->get_input_source_output(i).get_index();
auto idx2 = node2->get_input_source_output(i).get_index();
err_log << "Different ports detected" << std::endl
<< node1->get_friendly_name() << " Input(" << i << ") connected to parent port " << idx1 << " and "
<< node2->get_friendly_name() << " Input(" << i << ") connected to parent port " << idx2 << std::endl;
err_log << "Different ports detected\n"
<< name(node1) << " Input(" << i << ") connected to parent port " << idx1
<< " and " << name(node2) << " Input(" << i << ") connected to parent port "
<< idx2 << std::endl;
}
if (compareRuntimeKeys && !compare_rt_keys(node1, node2)) {
err_log << "Different runtime info detected" << std::endl
<< node1->get_friendly_name() << " and " << node2->get_friendly_name() << " not equal runttime info." << std::endl;;
err_log << "Different runtime info detected\n"
<< name(node1) << " and " << name(node2) << " not equal runtime info."
<< std::endl;
}
q.push({node1->input_value(i).get_node_shared_ptr(), node2->input_value(i).get_node_shared_ptr()});
q.push(
{node1->input_value(i).get_node_shared_ptr(),
node2->input_value(i).get_node_shared_ptr()});
}
for (int i = 0; i < node1->outputs().size(); ++i) {
if (!node1->output(i).get_partial_shape().same_scheme(node2->output(i).get_partial_shape())) {
err_log << "Different shape detected" << std::endl
<< node1->get_friendly_name() << " Output(" << i << ") " << node1->output(i).get_partial_shape() << " and "
<< node2->get_friendly_name() << " Output(" << i << ") " << node2->output(i).get_partial_shape() << std::endl;
if (!node1->output(i).get_partial_shape().same_scheme(
node2->output(i).get_partial_shape())) {
err_log << "Different shape detected\n"
<< name(node1) << " Output(" << i << ") "
<< node1->output(i).get_partial_shape() << " and " << name(node2)
<< " Output(" << i << ") " << node2->output(i).get_partial_shape()
<< std::endl;
}
}
}
return {err_log.str().empty(), err_log.str()};
}
void check_rt_info(const std::shared_ptr<ngraph::Function> & f) {
void check_rt_info(const std::shared_ptr<ngraph::Function>& f) {
static const std::vector<std::string> attrs_to_check{"Variant::RuntimeAttribute::FusedNames"};
std::ostringstream err_log;
for (auto & op : f->get_ops()) {
for (auto& op : f->get_ops()) {
if (ngraph::op::is_constant(op)) continue;
const auto & rt_info = op->get_rt_info();
for (const auto & attr_name : attrs_to_check) {
const auto& rt_info = op->get_rt_info();
for (const auto& attr_name : attrs_to_check) {
if (!rt_info.count(attr_name)) {
err_log << "Node: " << op->get_friendly_name() << " has no attribute: " << attr_name << std::endl;
err_log << "Node: " << op->get_friendly_name() << " has no attribute: " << attr_name
<< std::endl;
}
}
}

View File

@@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@@ -7,9 +7,9 @@
#include <memory>
#include <queue>
#include <ngraph/dimension.hpp>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/dimension.hpp>
#include <ngraph/pass/pass.hpp>
#include "test_common.hpp"
@@ -26,10 +26,9 @@ std::pair<bool, std::string> compare_functions(
const bool compareRuntimeKeys = false,
const bool comparePrecisions = true);
void check_rt_info(const std::shared_ptr<ngraph::Function> & f);
void check_rt_info(const std::shared_ptr<ngraph::Function>& f);
template<typename T>
template <typename T>
std::vector<std::shared_ptr<T>> get(const std::shared_ptr<ngraph::Function>& f) {
std::vector<std::shared_ptr<T>> nodes;
@@ -57,17 +56,17 @@ std::vector<std::shared_ptr<T>> get(const std::shared_ptr<ngraph::Function>& f)
namespace ngraph {
namespace pass {
class InjectionPass;
} // namespace pass
} // namespace ngraph
} // namespace pass
} // namespace ngraph
class ngraph::pass::InjectionPass : public ngraph::pass::FunctionPass {
public:
using injection_callback = std::function<void(std::shared_ptr<ngraph::Function>)>;
explicit InjectionPass(injection_callback callback) : FunctionPass(), m_callback(std::move(callback)) {}
explicit InjectionPass(injection_callback callback)
: FunctionPass(), m_callback(std::move(callback)) {}
bool run_on_function(std::shared_ptr<ngraph::Function> f) override {
m_callback(f);
@@ -105,8 +104,8 @@ public:
set_output_type(1, get_input_element_type(1), get_input_partial_shape(1));
}
std::shared_ptr<Node>
clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
std::shared_ptr<Node> clone_with_new_inputs(
const ngraph::OutputVector& new_args) const override {
return std::make_shared<TestOpMultiOut>(new_args.at(0), new_args.at(1));
}
};