Refactoring FunctionsComparator - extract node comparison part (#4175)
* Refactoring FunctionsComparator - extract node comparison part * try to fix logic and CentOS bulids * Add negative test for precision * Use fixed ngraph::descriptor::Tensor type instead template type * reorganize ngraph_test_utils.cpp * Cleanup after merge master into branch Co-authored-by: Patryk Elszkowski <patryk.elszkowki@intel.com>
This commit is contained in:
committed by
GitHub
parent
769bb77095
commit
a73e997480
@@ -522,3 +522,42 @@ TEST(TransformationTests, DummyOpNegativeNotSupportedType) {
|
||||
EXPECT_FALSE(res.valid);
|
||||
EXPECT_THAT(res.message, HasSubstr(" [drop `void` comparison which is '"));
|
||||
}
|
||||
|
||||
TEST(TransformationTests, DifferentPrecisionVersusAttributes) {
|
||||
const auto createReadValueFunc = [](ngraph::element::Type t) {
|
||||
using namespace ngraph::opset5;
|
||||
|
||||
auto input1 = std::make_shared<Parameter>(t, ngraph::Shape{15, 20, 3});
|
||||
auto node = std::make_shared<ReadValue>(input1, "1");
|
||||
|
||||
return std::make_shared<Function>(OutputVector{node}, ParameterVector{input1});
|
||||
};
|
||||
|
||||
const auto& f1 = createReadValueFunc(ngraph::element::f16);
|
||||
const auto& f2 = createReadValueFunc(ngraph::element::i16);
|
||||
|
||||
///
|
||||
/// if FunctionComparator::ATTRIBUTES is select error from Attribute comparator override error
|
||||
/// found when FunctionComparator::PRECISION is enabled
|
||||
///
|
||||
|
||||
{ // check precision only
|
||||
const auto fc = FunctionsComparator::no_default().enable(FunctionsComparator::PRECISIONS);
|
||||
const auto res = fc.compare(f1, f2);
|
||||
EXPECT_FALSE(res.valid);
|
||||
EXPECT_THAT(res.message, HasSubstr("Different element type detected"));
|
||||
EXPECT_THAT(res.message, HasSubstr("f16"));
|
||||
EXPECT_THAT(res.message, HasSubstr("i16"));
|
||||
}
|
||||
|
||||
{ // check precision and attributes
|
||||
const auto fc = FunctionsComparator::no_default()
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::ATTRIBUTES);
|
||||
const auto res = fc.compare(f1, f2);
|
||||
EXPECT_FALSE(res.valid);
|
||||
EXPECT_THAT(res.message, HasSubstr("Comparison of attributes failed for nodes "));
|
||||
EXPECT_THAT(res.message, HasSubstr("f16"));
|
||||
EXPECT_THAT(res.message, HasSubstr("i16"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,10 +77,6 @@ std::string to_str(const T& v) {
|
||||
return std::to_string(v);
|
||||
}
|
||||
|
||||
FunctionsComparator::Result error(std::string s) {
|
||||
return {false, std::move(s)};
|
||||
}
|
||||
|
||||
std::string typeInfoToStr(const ngraph::Node::type_info_t& typeInfo) {
|
||||
return std::string(typeInfo.name) + "/" + to_str(typeInfo.version);
|
||||
}
|
||||
@@ -90,6 +86,61 @@ std::string name(const Node& n) {
|
||||
return n->get_friendly_name();
|
||||
}
|
||||
|
||||
std::string tensor_names(const ngraph::descriptor::Tensor& t) {
|
||||
std::string n;
|
||||
const char* glue = "";
|
||||
for (const auto& name : t.get_names()) {
|
||||
n.append(glue).append(name);
|
||||
glue = ", ";
|
||||
}
|
||||
return "\"" + n + "\"";
|
||||
}
|
||||
|
||||
class Comparator {
|
||||
public:
|
||||
using CmpValues = FunctionsComparator::CmpValues;
|
||||
using Result = FunctionsComparator::Result;
|
||||
using ComparedNodes = std::pair<ngraph::Node*, ngraph::Node*>;
|
||||
|
||||
explicit Comparator(CmpValues f) : m_comparition_flags(f) {}
|
||||
|
||||
Result compare(
|
||||
const std::shared_ptr<ngraph::Function>& f1, const std::shared_ptr<ngraph::Function>& f2);
|
||||
|
||||
Result compare(ngraph::Node* node1, ngraph::Node* node2) {
|
||||
std::stringstream errors;
|
||||
const auto result = compare(node1, node2, errors);
|
||||
if (!result.valid) {
|
||||
return result;
|
||||
}
|
||||
const auto msg = errors.str();
|
||||
return msg.empty() ? Result::ok() : Result::error(msg);
|
||||
}
|
||||
|
||||
Comparator recreate() const {
|
||||
return Comparator(m_comparition_flags);
|
||||
}
|
||||
|
||||
private:
|
||||
bool should_compare(CmpValues f) const noexcept {
|
||||
return m_comparition_flags & f;
|
||||
}
|
||||
|
||||
///
|
||||
/// \param err_log - will be fill by minor errors if happen
|
||||
/// \return only fatality error if some minor one appears it will be add to err_log
|
||||
///
|
||||
Result compare(ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log);
|
||||
|
||||
void add_nodes_inputs_to_queue(ngraph::Node* node1, ngraph::Node* node2);
|
||||
|
||||
//-- DATA --
|
||||
CmpValues m_comparition_flags;
|
||||
|
||||
std::queue<ComparedNodes> q;
|
||||
std::unordered_set<ngraph::Node*> used;
|
||||
};
|
||||
|
||||
namespace attr_comparison {
|
||||
|
||||
using AttrName = std::string;
|
||||
@@ -240,13 +291,20 @@ public:
|
||||
class ReadAndStoreAttributes : public ngraph::AttributeVisitor, protected storage::Storage {
|
||||
public:
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override {
|
||||
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
||||
if (auto inputs =
|
||||
ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
||||
insert(name, inputs->get());
|
||||
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
||||
} else if (
|
||||
auto outputs =
|
||||
ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
||||
insert(name, outputs->get());
|
||||
} else if (auto ports = ngraph::as_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
||||
} else if (
|
||||
auto ports = ngraph::as_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
||||
insert(name, ports->get());
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(&adapter)) {
|
||||
} else if (
|
||||
auto a = ngraph::as_type<
|
||||
ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
|
||||
&adapter)) {
|
||||
const auto beg = static_cast<unsigned char*>(a->get()->get_ptr());
|
||||
const auto end = beg + a->get()->size();
|
||||
insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)});
|
||||
@@ -326,22 +384,22 @@ struct Equal {
|
||||
|
||||
template <>
|
||||
struct Equal<ngraph::bfloat16> {
|
||||
static bool equal_value(ngraph::bfloat16 lhs, ngraph::bfloat16 rhs) {
|
||||
if (lhs.to_bits() == rhs.to_bits()) {
|
||||
return true;
|
||||
static bool equal_value(ngraph::bfloat16 lhs, ngraph::bfloat16 rhs) {
|
||||
if (lhs.to_bits() == rhs.to_bits()) {
|
||||
return true;
|
||||
}
|
||||
return std::abs(lhs - rhs) < 1e-3;
|
||||
}
|
||||
return std::abs(lhs - rhs) < 1e-3;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Equal<ngraph::float16> {
|
||||
static bool equal_value(ngraph::float16 lhs, ngraph::float16 rhs) {
|
||||
if (lhs.to_bits() == rhs.to_bits()) {
|
||||
return true;
|
||||
static bool equal_value(ngraph::float16 lhs, ngraph::float16 rhs) {
|
||||
if (lhs.to_bits() == rhs.to_bits()) {
|
||||
return true;
|
||||
}
|
||||
return std::abs(lhs - rhs) < 1e-3;
|
||||
}
|
||||
return std::abs(lhs - rhs) < 1e-3;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -452,9 +510,10 @@ struct Equal<SpecialBodyPorts> {
|
||||
};
|
||||
|
||||
using Constant = ngraph::opset1::Constant;
|
||||
template <> struct Equal<std::shared_ptr<Constant>> {
|
||||
static bool equal_value(const std::shared_ptr<Constant>& lhs,
|
||||
const std::shared_ptr<Constant>& rhs) {
|
||||
template <>
|
||||
struct Equal<std::shared_ptr<Constant>> {
|
||||
static bool equal_value(
|
||||
const std::shared_ptr<Constant>& lhs, const std::shared_ptr<Constant>& rhs) {
|
||||
const auto lhs_t = lhs->get_element_type();
|
||||
const auto rhs_t = rhs->get_element_type();
|
||||
if (lhs_t != rhs_t) {
|
||||
@@ -469,20 +528,20 @@ template <> struct Equal<std::shared_ptr<Constant>> {
|
||||
break;
|
||||
}
|
||||
case ngraph::element::Type_t::f16: {
|
||||
const auto &lhs_v = lhs->cast_vector<ngraph::float16>();
|
||||
const auto &rhs_v = rhs->cast_vector<ngraph::float16>();
|
||||
const auto& lhs_v = lhs->cast_vector<ngraph::float16>();
|
||||
const auto& rhs_v = rhs->cast_vector<ngraph::float16>();
|
||||
return Equal<std::vector<ngraph::float16>>::equal_value(lhs_v, rhs_v);
|
||||
break;
|
||||
}
|
||||
case ngraph::element::Type_t::f32: {
|
||||
const auto &lhs_v = lhs->cast_vector<float>();
|
||||
const auto &rhs_v = rhs->cast_vector<float>();
|
||||
const auto& lhs_v = lhs->cast_vector<float>();
|
||||
const auto& rhs_v = rhs->cast_vector<float>();
|
||||
return Equal<std::vector<float>>::equal_value(lhs_v, rhs_v);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
const auto &lhs_v = lhs->cast_vector<double>();
|
||||
const auto &rhs_v = rhs->cast_vector<double>();
|
||||
const auto& lhs_v = lhs->cast_vector<double>();
|
||||
const auto& rhs_v = rhs->cast_vector<double>();
|
||||
return Equal<std::vector<double>>::equal_value(lhs_v, rhs_v);
|
||||
break;
|
||||
}
|
||||
@@ -551,13 +610,20 @@ public:
|
||||
return;
|
||||
}
|
||||
m_visited_attributes.insert(name);
|
||||
if (auto inputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
||||
if (auto inputs =
|
||||
ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpInputDescription>>(&adapter)) {
|
||||
verify(name, inputs->get());
|
||||
} else if (auto outputs = ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
||||
} else if (
|
||||
auto outputs =
|
||||
ngraph::as_type<ngraph::AttributeAdapter<SubGraphOpOutputDescription>>(&adapter)) {
|
||||
verify(name, outputs->get());
|
||||
} else if (auto ports = ngraph::as_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
||||
} else if (
|
||||
auto ports = ngraph::as_type<ngraph::AttributeAdapter<SpecialBodyPorts>>(&adapter)) {
|
||||
verify(name, ports->get());
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(&adapter)) {
|
||||
} else if (
|
||||
auto a = ngraph::as_type<
|
||||
ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
|
||||
&adapter)) {
|
||||
m_visited_attributes.insert(name);
|
||||
const auto ref_value = m_attr_ref.get<storage::MemoryChunk>(name);
|
||||
if (!ref_value) {
|
||||
@@ -660,7 +726,7 @@ class CompareNodesAttributes {
|
||||
public:
|
||||
CompareNodesAttributes() : m_compare_attr(m_store_attr) {}
|
||||
|
||||
attr_comparison::ReadAndStoreAttributes& get_ref_reder() {
|
||||
attr_comparison::ReadAndStoreAttributes& get_ref_reader() {
|
||||
return m_store_attr;
|
||||
}
|
||||
|
||||
@@ -691,11 +757,8 @@ private:
|
||||
attr_comparison::ReadAndCompareAttributes m_compare_attr;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
FunctionsComparator::Result FunctionsComparator::compare(
|
||||
const std::shared_ptr<ngraph::Function>& f1,
|
||||
const std::shared_ptr<ngraph::Function>& f2) const {
|
||||
Comparator::Result Comparator::compare(
|
||||
const std::shared_ptr<ngraph::Function>& f1, const std::shared_ptr<ngraph::Function>& f2) {
|
||||
/*
|
||||
* This function compares two nGraph functions and requires them to have exactly one output
|
||||
* + Check nodes types
|
||||
@@ -712,7 +775,7 @@ FunctionsComparator::Result FunctionsComparator::compare(
|
||||
std::sort(f2_results.begin(), f2_results.end(), less_by_name);
|
||||
|
||||
if (f1_results.size() != f2_results.size()) {
|
||||
return error(
|
||||
return Result::error(
|
||||
"Number of results is different: " + to_str(f1_results.size()) + " and " +
|
||||
to_str(f2_results.size()));
|
||||
}
|
||||
@@ -720,22 +783,16 @@ FunctionsComparator::Result FunctionsComparator::compare(
|
||||
const auto& f1_sinks = f1->get_sinks();
|
||||
const auto& f2_sinks = f2->get_sinks();
|
||||
if (f1_sinks.size() != f2_sinks.size()) {
|
||||
return error(
|
||||
return Result::error(
|
||||
"Number of sinks is different: " + to_str(f1_sinks.size()) + " and " +
|
||||
to_str(f2_sinks.size()));
|
||||
}
|
||||
|
||||
std::ostringstream err_log;
|
||||
|
||||
using ComparedNodes = std::pair<ngraph::Node*, ngraph::Node*>;
|
||||
std::queue<ComparedNodes> q;
|
||||
std::unordered_set<ngraph::Node*> used;
|
||||
|
||||
for (size_t i = 0; i < f1_results.size(); ++i) {
|
||||
if (should_compare(NAMES)) {
|
||||
if (should_compare(CmpValues::NAMES)) {
|
||||
if (name(f1_results[i]->get_input_node_shared_ptr(0)) !=
|
||||
name(f2_results[i]->get_input_node_shared_ptr(0))) {
|
||||
return error(
|
||||
return Result::error(
|
||||
"Different output names: " + name(f1_results[i]->get_input_node_shared_ptr(0)) +
|
||||
" and " + name(f2_results[i]->get_input_node_shared_ptr(0)));
|
||||
}
|
||||
@@ -744,150 +801,162 @@ FunctionsComparator::Result FunctionsComparator::compare(
|
||||
used.insert(f1_results[i].get());
|
||||
}
|
||||
|
||||
std::stringstream errors;
|
||||
|
||||
while (!q.empty()) {
|
||||
auto node1 = q.front().first;
|
||||
auto node2 = q.front().second;
|
||||
ngraph::Node* const node1 = q.front().first;
|
||||
ngraph::Node* const node2 = q.front().second;
|
||||
q.pop();
|
||||
|
||||
auto type_info1 = node1->get_type_info();
|
||||
auto type_info2 = node2->get_type_info();
|
||||
|
||||
if (!compareTypeInfo(type_info1, type_info2)) {
|
||||
return error(typeInfoToStr(type_info1) + " != " + typeInfoToStr(type_info2));
|
||||
const auto result = compare(node1, node2, errors);
|
||||
if (!result.valid) {
|
||||
return result;
|
||||
}
|
||||
|
||||
auto subgraph1 = dynamic_cast<ngraph::op::util::SubGraphOp*>(node1);
|
||||
auto subgraph2 = dynamic_cast<ngraph::op::util::SubGraphOp*>(node2);
|
||||
add_nodes_inputs_to_queue(node1, node2);
|
||||
}
|
||||
const auto msg = errors.str();
|
||||
return msg.empty() ? Result::ok() : Result::error(msg);
|
||||
}
|
||||
|
||||
if (subgraph1 && subgraph2) {
|
||||
auto result = compare(subgraph1->get_function(), subgraph2->get_function());
|
||||
if (!result.valid) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
Comparator::Result Comparator::compare(
|
||||
ngraph::Node* node1, ngraph::Node* node2, std::ostream& err_log) {
|
||||
auto type_info1 = node1->get_type_info();
|
||||
auto type_info2 = node2->get_type_info();
|
||||
|
||||
const auto& dependencies_1 = node1->get_control_dependencies();
|
||||
const auto& dependencies_2 = node2->get_control_dependencies();
|
||||
if (!compareTypeInfo(type_info1, type_info2)) {
|
||||
return Result::error(typeInfoToStr(type_info1) + " != " + typeInfoToStr(type_info2));
|
||||
}
|
||||
|
||||
if (dependencies_1.size() != dependencies_2.size()) {
|
||||
return error(
|
||||
"Number of dependencies is different: " + to_str(dependencies_1.size()) + " for " +
|
||||
name(node1) + " and " + to_str(dependencies_2.size()) + " for " + name(node2));
|
||||
}
|
||||
auto subgraph1 = dynamic_cast<ngraph::op::util::SubGraphOp*>(node1);
|
||||
auto subgraph2 = dynamic_cast<ngraph::op::util::SubGraphOp*>(node2);
|
||||
|
||||
if (node1->inputs().size() != node2->inputs().size()) {
|
||||
return error(
|
||||
"Number of inputs is different: " + to_str(node1->inputs().size()) + " for " +
|
||||
name(node1) + " and " + to_str(node2->inputs().size()) + " for " + name(node2));
|
||||
}
|
||||
|
||||
if (node1->outputs().size() != node2->outputs().size()) {
|
||||
return error(
|
||||
"Number of outputs is different: " + to_str(node1->inputs().size()) + " for " +
|
||||
name(node1) + " and " + to_str(node2->inputs().size()) + " for " + name(node2));
|
||||
}
|
||||
|
||||
for (int i = 0; i < node1->inputs().size(); ++i) {
|
||||
if (should_compare(CONST_VALUES)) {
|
||||
using Constant = ngraph::opset1::Constant;
|
||||
auto const1 = ngraph::as_type_ptr<Constant>(node1->get_input_node_shared_ptr(i));
|
||||
auto const2 = ngraph::as_type_ptr<Constant>(node2->get_input_node_shared_ptr(i));
|
||||
using namespace ::attr_comparison::equal;
|
||||
if (const1 && const2 &&
|
||||
!Equal<std::shared_ptr<Constant>>::equal_value(const1, const2)) {
|
||||
err_log << "Different Constant values detected\n"
|
||||
<< node1->description() << " Input(" << i << ") and "
|
||||
<< node2->description() << " Input(" << i << ")"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if (should_compare(PRECISIONS)) {
|
||||
if (node1->input(i).get_element_type() != node2->input(i).get_element_type()) {
|
||||
err_log << "Different element type detected\n"
|
||||
<< name(node1) << " Input(" << i << ") "
|
||||
<< node1->input(i).get_element_type() << " and " << name(node2)
|
||||
<< " Input(" << i << ") " << node2->input(i).get_element_type()
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if (!node1->input(i).get_partial_shape().same_scheme(
|
||||
node2->input(i).get_partial_shape())) {
|
||||
err_log << "Different shape detected\n"
|
||||
<< name(node1) << " Input(" << i << ") "
|
||||
<< node1->input(i).get_partial_shape() << " and " << name(node2)
|
||||
<< " Input(" << i << ") " << node2->input(i).get_partial_shape()
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if (node1->get_input_source_output(i).get_index() !=
|
||||
node2->get_input_source_output(i).get_index()) {
|
||||
auto idx1 = node1->get_input_source_output(i).get_index();
|
||||
auto idx2 = node2->get_input_source_output(i).get_index();
|
||||
err_log << "Different ports detected\n"
|
||||
<< name(node1) << " Input(" << i << ") connected to parent port " << idx1
|
||||
<< " and " << name(node2) << " Input(" << i << ") connected to parent port "
|
||||
<< idx2 << std::endl;
|
||||
}
|
||||
|
||||
if (should_compare(RUNTIME_KEYS) && !compare_rt_keys(node1, node2)) {
|
||||
err_log << "Different runtime info detected\n"
|
||||
<< name(node1) << " and " << name(node2) << " not equal runtime info."
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if (!used.count(node1->input_value(i).get_node())) {
|
||||
q.push({node1->input_value(i).get_node(), node2->input_value(i).get_node()});
|
||||
used.insert(node1->input_value(i).get_node());
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < node1->outputs().size(); ++i) {
|
||||
const auto& tensor1 = node1->output(i).get_tensor();
|
||||
const auto& tensor2 = node2->output(i).get_tensor();
|
||||
|
||||
if (tensor1.get_names() != tensor2.get_names()) {
|
||||
std::string names1 = "";
|
||||
for (const auto& name : tensor1.get_names()) {
|
||||
if (!names1.empty())
|
||||
names1 += ", ";
|
||||
names1 += name;
|
||||
}
|
||||
names1 = "\"" + names1 + "\"";
|
||||
std::string names2 = "";
|
||||
for (const auto& name : tensor2.get_names()) {
|
||||
if (!names2.empty())
|
||||
names2 += ", ";
|
||||
names2 += name;
|
||||
}
|
||||
names2 = "\"" + names2 + "\"";
|
||||
err_log << "Output tensors names " << names1 << " and " << names2 << " are different for nodes: "
|
||||
<< node1->get_friendly_name() << " and " << node2->get_friendly_name() << std::endl;
|
||||
}
|
||||
if (!node1->output(i).get_partial_shape().same_scheme(
|
||||
node2->output(i).get_partial_shape())) {
|
||||
err_log << "Different shape detected\n"
|
||||
<< name(node1) << " Output(" << i << ") "
|
||||
<< node1->output(i).get_partial_shape() << " and " << name(node2)
|
||||
<< " Output(" << i << ") " << node2->output(i).get_partial_shape()
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if (should_compare(ATTRIBUTES)) {
|
||||
CompareNodesAttributes compare_nodes;
|
||||
node1->visit_attributes(compare_nodes.get_ref_reder());
|
||||
node2->visit_attributes(compare_nodes.get_cmp_reader());
|
||||
if (!compare_nodes.equal()) {
|
||||
return error(
|
||||
"Comparison of attributes failed for nodes " + name(node1) + ", " +
|
||||
name(node2) + " [cmp status: " + to_string(compare_nodes) + "]");
|
||||
}
|
||||
if (subgraph1 && subgraph2) {
|
||||
auto result = recreate().compare(subgraph1->get_function(), subgraph2->get_function());
|
||||
if (!result.valid) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return {err_log.str().empty(), err_log.str()};
|
||||
|
||||
const auto& dependencies_1 = node1->get_control_dependencies();
|
||||
const auto& dependencies_2 = node2->get_control_dependencies();
|
||||
|
||||
if (dependencies_1.size() != dependencies_2.size()) {
|
||||
return Result::error(
|
||||
"Number of dependencies is different: " + to_str(dependencies_1.size()) + " for " +
|
||||
name(node1) + " and " + to_str(dependencies_2.size()) + " for " + name(node2));
|
||||
}
|
||||
|
||||
if (node1->inputs().size() != node2->inputs().size()) {
|
||||
return Result::error(
|
||||
"Number of inputs is different: " + to_str(node1->inputs().size()) + " for " +
|
||||
name(node1) + " and " + to_str(node2->inputs().size()) + " for " + name(node2));
|
||||
}
|
||||
|
||||
if (node1->outputs().size() != node2->outputs().size()) {
|
||||
return Result::error(
|
||||
"Number of outputs is different: " + to_str(node1->inputs().size()) + " for " +
|
||||
name(node1) + " and " + to_str(node2->inputs().size()) + " for " + name(node2));
|
||||
}
|
||||
|
||||
for (int i = 0; i < node1->inputs().size(); ++i) {
|
||||
if (should_compare(CmpValues::CONST_VALUES)) {
|
||||
using Constant = ngraph::opset1::Constant;
|
||||
auto const1 = ngraph::as_type_ptr<Constant>(node1->get_input_node_shared_ptr(i));
|
||||
auto const2 = ngraph::as_type_ptr<Constant>(node2->get_input_node_shared_ptr(i));
|
||||
using namespace ::attr_comparison::equal;
|
||||
if (const1 && const2 &&
|
||||
!Equal<std::shared_ptr<Constant>>::equal_value(const1, const2)) {
|
||||
err_log << "Different Constant values detected\n"
|
||||
<< node1->description() << " Input(" << i << ") and "
|
||||
<< node2->description() << " Input(" << i << ")" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if (should_compare(CmpValues::PRECISIONS)) {
|
||||
if (node1->input(i).get_element_type() != node2->input(i).get_element_type()) {
|
||||
err_log << "Different element type detected\n"
|
||||
<< name(node1) << " Input(" << i << ") "
|
||||
<< node1->input(i).get_element_type() << " and " << name(node2) << " Input("
|
||||
<< i << ") " << node2->input(i).get_element_type() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if (!node1->input(i).get_partial_shape().same_scheme(node2->input(i).get_partial_shape())) {
|
||||
err_log << "Different shape detected\n"
|
||||
<< name(node1) << " Input(" << i << ") " << node1->input(i).get_partial_shape()
|
||||
<< " and " << name(node2) << " Input(" << i << ") "
|
||||
<< node2->input(i).get_partial_shape() << std::endl;
|
||||
}
|
||||
|
||||
if (node1->get_input_source_output(i).get_index() !=
|
||||
node2->get_input_source_output(i).get_index()) {
|
||||
auto idx1 = node1->get_input_source_output(i).get_index();
|
||||
auto idx2 = node2->get_input_source_output(i).get_index();
|
||||
err_log << "Different ports detected\n"
|
||||
<< name(node1) << " Input(" << i << ") connected to parent port " << idx1
|
||||
<< " and " << name(node2) << " Input(" << i << ") connected to parent port "
|
||||
<< idx2 << std::endl;
|
||||
}
|
||||
|
||||
if (should_compare(CmpValues::RUNTIME_KEYS) && !compare_rt_keys(node1, node2)) {
|
||||
err_log << "Different runtime info detected\n"
|
||||
<< name(node1) << " and " << name(node2) << " not equal runtime info."
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < node1->outputs().size(); ++i) {
|
||||
const auto& tensor1 = node1->output(i).get_tensor();
|
||||
const auto& tensor2 = node2->output(i).get_tensor();
|
||||
|
||||
if (tensor1.get_names() != tensor2.get_names()) {
|
||||
err_log << "Output tensors names " << tensor_names(tensor1) << " and "
|
||||
<< tensor_names(tensor2)
|
||||
<< " are different for nodes: " << node1->get_friendly_name() << " and "
|
||||
<< node2->get_friendly_name() << std::endl;
|
||||
}
|
||||
|
||||
if (!node1->output(i).get_partial_shape().same_scheme(
|
||||
node2->output(i).get_partial_shape())) {
|
||||
err_log << "Different shape detected\n"
|
||||
<< name(node1) << " Output(" << i << ") "
|
||||
<< node1->output(i).get_partial_shape() << " and " << name(node2) << " Output("
|
||||
<< i << ") " << node2->output(i).get_partial_shape() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if (should_compare(CmpValues::ATTRIBUTES)) {
|
||||
CompareNodesAttributes compare_nodes;
|
||||
node1->visit_attributes(compare_nodes.get_ref_reader());
|
||||
node2->visit_attributes(compare_nodes.get_cmp_reader());
|
||||
if (!compare_nodes.equal()) {
|
||||
return Result::error(
|
||||
"Comparison of attributes failed for nodes " + name(node1) + ", " + name(node2) +
|
||||
" [cmp status: " + to_string(compare_nodes) + "]");
|
||||
}
|
||||
}
|
||||
|
||||
return Result::ok("Check if any minor error was log in to err_log");
|
||||
}
|
||||
|
||||
void Comparator::add_nodes_inputs_to_queue(ngraph::Node* node1, ngraph::Node* node2) {
|
||||
for (int i = 0; i < node1->inputs().size(); ++i) {
|
||||
if (!used.count(node1->input_value(i).get_node())) {
|
||||
q.push({node1->input_value(i).get_node(), node2->input_value(i).get_node()});
|
||||
used.insert(node1->input_value(i).get_node());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
FunctionsComparator::Result FunctionsComparator::compare(
|
||||
const std::shared_ptr<ngraph::Function>& f1,
|
||||
const std::shared_ptr<ngraph::Function>& f2) const {
|
||||
return Comparator(m_comparition_flags).compare(f1, f2);
|
||||
}
|
||||
|
||||
void check_rt_info(const std::shared_ptr<ngraph::Function>& f) {
|
||||
static const std::vector<std::string> attrs_to_check{"Variant::RuntimeAttribute::FusedNames"};
|
||||
|
||||
|
||||
@@ -32,6 +32,13 @@ public:
|
||||
struct Result {
|
||||
bool valid;
|
||||
std::string message;
|
||||
|
||||
static Result ok(std::string msg = {}) {
|
||||
return {true, std::move(msg)};
|
||||
}
|
||||
static Result error(std::string msg) {
|
||||
return {false, std::move(msg)};
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr FunctionsComparator no_default() noexcept {
|
||||
|
||||
Reference in New Issue
Block a user