diff --git a/inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.bin b/inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.bin
new file mode 100644
index 00000000000..16a41a06ca8
Binary files /dev/null and b/inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.bin differ
diff --git a/inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.xml b/inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.xml
new file mode 100644
index 00000000000..c0ecd4025c5
--- /dev/null
+++ b/inference-engine/tests/functional/inference_engine/ir_serialization/models/add_abc_initializers_nan_const.xml
@@ -0,0 +1,93 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ 2
+ 2
+
+
+ 2
+ 2
+
+
+
+
+
+
+
+ 2
+ 2
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/inference-engine/tests/functional/inference_engine/ir_serialization/serialize.cpp b/inference-engine/tests/functional/inference_engine/ir_serialization/serialize.cpp
index 067ed94bc7d..254622157ad 100644
--- a/inference-engine/tests/functional/inference_engine/ir_serialization/serialize.cpp
+++ b/inference-engine/tests/functional/inference_engine/ir_serialization/serialize.cpp
@@ -64,6 +64,7 @@ INSTANTIATE_TEST_CASE_P(IRSerialization, SerializationTest,
std::make_tuple("split_equal_parts_2d.xml", "split_equal_parts_2d.bin"),
std::make_tuple("addmul_abc.xml", "addmul_abc.bin"),
std::make_tuple("add_abc_initializers.xml", "add_abc_initializers.bin"),
+ std::make_tuple("add_abc_initializers_nan_const.xml", "add_abc_initializers_nan_const.bin"),
std::make_tuple("experimental_detectron_roi_feature_extractor.xml", ""),
std::make_tuple("experimental_detectron_roi_feature_extractor_opset6.xml", ""),
std::make_tuple("experimental_detectron_detection_output.xml", ""),
diff --git a/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp b/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp
index 4045c8fd00e..81a207d9c6b 100644
--- a/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp
+++ b/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp
@@ -324,10 +324,30 @@ struct Equal {
}
};
+template <>
+struct Equal {
+ static bool equal_value(ngraph::bfloat16 lhs, ngraph::bfloat16 rhs) {
+ if (lhs.to_bits() == rhs.to_bits()) {
+ return true;
+ }
+ return std::abs(lhs - rhs) < 1e-3;
+ }
+};
+
+template <>
+struct Equal {
+ static bool equal_value(ngraph::float16 lhs, ngraph::float16 rhs) {
+ if (lhs.to_bits() == rhs.to_bits()) {
+ return true;
+ }
+ return std::abs(lhs - rhs) < 1e-3;
+ }
+};
+
template <>
struct Equal {
static bool equal_value(float lhs, float rhs) {
- return std::abs(lhs - rhs) < 1e-5;
+ return std::abs(lhs - rhs) < 1e-4;
}
};
@@ -338,19 +358,11 @@ struct Equal {
}
};
-template <>
-struct Equal> {
- static bool equal_value(const std::vector& lhs, const std::vector& rhs) {
+template
+struct Equal> {
+ static bool equal_value(const std::vector& lhs, const std::vector& rhs) {
return lhs.size() == rhs.size() &&
- std::equal(begin(lhs), end(lhs), begin(rhs), Equal::equal_value);
- }
-};
-
-template <>
-struct Equal> {
- static bool equal_value(const std::vector& lhs, const std::vector& rhs) {
- return lhs.size() == rhs.size() &&
- std::equal(begin(lhs), end(lhs), begin(rhs), Equal::equal_value);
+ std::equal(begin(lhs), end(lhs), begin(rhs), Equal::equal_value);
}
};
@@ -439,6 +451,45 @@ struct Equal {
}
};
+using Constant = ngraph::opset1::Constant;
+template <> struct Equal> {
+ static bool equal_value(const std::shared_ptr& lhs,
+ const std::shared_ptr& rhs) {
+ const auto lhs_t = lhs->get_element_type();
+ const auto rhs_t = rhs->get_element_type();
+ if (lhs_t != rhs_t) {
+ return false;
+ }
+
+ switch (lhs_t) {
+ case ngraph::element::Type_t::bf16: {
+ auto lhs_v = lhs->cast_vector();
+ auto rhs_v = rhs->cast_vector();
+ return Equal>::equal_value(lhs_v, rhs_v);
+ break;
+ }
+ case ngraph::element::Type_t::f16: {
+ const auto &lhs_v = lhs->cast_vector();
+ const auto &rhs_v = rhs->cast_vector();
+ return Equal>::equal_value(lhs_v, rhs_v);
+ break;
+ }
+ case ngraph::element::Type_t::f32: {
+ const auto &lhs_v = lhs->cast_vector();
+ const auto &rhs_v = rhs->cast_vector();
+ return Equal>::equal_value(lhs_v, rhs_v);
+ break;
+ }
+ default: {
+ const auto &lhs_v = lhs->cast_vector();
+ const auto &rhs_v = rhs->cast_vector();
+ return Equal>::equal_value(lhs_v, rhs_v);
+ break;
+ }
+ }
+ return false;
+ }
+};
} // namespace equal
namespace str {
@@ -741,22 +792,13 @@ FunctionsComparator::Result FunctionsComparator::compare(
using Constant = ngraph::opset1::Constant;
auto const1 = ngraph::as_type_ptr(node1->get_input_node_shared_ptr(i));
auto const2 = ngraph::as_type_ptr(node2->get_input_node_shared_ptr(i));
-
- const auto equal = [](std::shared_ptr c1, std::shared_ptr c2) {
- const auto& c1v = c1->cast_vector();
- const auto& c2v = c2->cast_vector();
-
- return c1v.size() == c2v.size() && std::equal(
- begin(c1v), end(c1v), begin(c2v),
- [](const double& s1, const double& s2) {
- return std::abs(s1 - s2) < 0.001;
- });
- };
-
- if (const1 && const2 && !equal(const1, const2)) {
+ using namespace ::attr_comparison::equal;
+ if (const1 && const2 &&
+ !Equal>::equal_value(const1, const2)) {
err_log << "Different Constant values detected\n"
<< node1->description() << " Input(" << i << ") and "
- << node2->description() << " Input(" << i << ")" << std::endl;
+ << node2->description() << " Input(" << i << ")"
+ << std::endl;
}
}