From 430e898c339f5f0ad29e2e5fe2317bec98401ae8 Mon Sep 17 00:00:00 2001
From: Mateusz Tabaka
Date: Mon, 21 Feb 2022 09:37:17 +0100
Subject: [PATCH] Add bf16, f64, i4, u4, i16, u16 types to Equal's evaluate
 (#10508)

* Add f64 type to Equal's evaluate

Required by t2t-vit models.

Ticket: 79610.

* add also i16 u16 because prior_box tests fail with "Check eval_status failed at"

* code style

* add i4, u4, bf16 to equal's evaluate
---
 src/core/src/op/equal.cpp           |  8 +++++++-
 src/core/src/validation_util.cpp    |  3 ++-
 src/core/tests/constant_folding.cpp | 20 +++++++++++++++++++-
 3 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/src/core/src/op/equal.cpp b/src/core/src/op/equal.cpp
index 0fada0dc463..4ea2ccef238 100644
--- a/src/core/src/op/equal.cpp
+++ b/src/core/src/op/equal.cpp
@@ -35,14 +35,20 @@ bool evaluate_equal(const HostTensorPtr& arg0,
     out->set_broadcast(broadcast_spec, arg0, arg1, element::boolean);
     switch (arg0->get_element_type()) {
         NGRAPH_TYPE_CASE(evaluate_equal, boolean, arg0, arg1, out, broadcast_spec);
+        NGRAPH_TYPE_CASE(evaluate_equal, i4, arg0, arg1, out, broadcast_spec);
         NGRAPH_TYPE_CASE(evaluate_equal, i8, arg0, arg1, out, broadcast_spec);
-        NGRAPH_TYPE_CASE(evaluate_equal, u8, arg0, arg1, out, broadcast_spec);
+        NGRAPH_TYPE_CASE(evaluate_equal, i16, arg0, arg1, out, broadcast_spec);
         NGRAPH_TYPE_CASE(evaluate_equal, i32, arg0, arg1, out, broadcast_spec);
         NGRAPH_TYPE_CASE(evaluate_equal, i64, arg0, arg1, out, broadcast_spec);
+        NGRAPH_TYPE_CASE(evaluate_equal, u4, arg0, arg1, out, broadcast_spec);
+        NGRAPH_TYPE_CASE(evaluate_equal, u8, arg0, arg1, out, broadcast_spec);
+        NGRAPH_TYPE_CASE(evaluate_equal, u16, arg0, arg1, out, broadcast_spec);
         NGRAPH_TYPE_CASE(evaluate_equal, u32, arg0, arg1, out, broadcast_spec);
         NGRAPH_TYPE_CASE(evaluate_equal, u64, arg0, arg1, out, broadcast_spec);
+        NGRAPH_TYPE_CASE(evaluate_equal, bf16, arg0, arg1, out, broadcast_spec);
         NGRAPH_TYPE_CASE(evaluate_equal, f16, arg0, arg1, out, broadcast_spec);
         NGRAPH_TYPE_CASE(evaluate_equal, f32, arg0, arg1, out, broadcast_spec);
+        NGRAPH_TYPE_CASE(evaluate_equal, f64, arg0, arg1, out, broadcast_spec);
     default:
         rc = false;
         break;
diff --git a/src/core/src/validation_util.cpp b/src/core/src/validation_util.cpp
index 762b1b33fd8..7f4667791ca 100644
--- a/src/core/src/validation_util.cpp
+++ b/src/core/src/validation_util.cpp
@@ -1223,7 +1223,8 @@ bool are_equal(const HostTensorPtr& lhs, const HostTensorPtr& rhs, size_t max_el
         return false;
     auto mask = std::make_shared<HostTensor>(element::boolean, lhs_shape);
     const auto& param = std::make_shared<op::Parameter>(lhs_et, lhs_shape);
-    op::v1::Equal(param, param, ngraph::op::AutoBroadcastType::NUMPY).evaluate({mask}, {lhs, rhs});
+    bool eval_status = op::v1::Equal(param, param, ngraph::op::AutoBroadcastType::NUMPY).evaluate({mask}, {lhs, rhs});
+    OPENVINO_ASSERT(eval_status);
     auto equal = op::Constant(mask).cast_vector<bool>();
     return std::all_of(equal.begin(), equal.end(), [](bool i) {
         return i;
diff --git a/src/core/tests/constant_folding.cpp b/src/core/tests/constant_folding.cpp
index 9f8e3395393..9ffb9b2fd85 100644
--- a/src/core/tests/constant_folding.cpp
+++ b/src/core/tests/constant_folding.cpp
@@ -283,6 +283,11 @@ TEST(constant_folding, constant_unary_binary) {
     auto j = make_shared<op::Constant>(element::i8, Shape{2}, values_j);
     auto k = make_shared<op::Constant>(element::u8, Shape{2}, values_k);
     auto doubles = make_shared<op::Constant>(element::f64, Shape{2}, std::vector<double>{4.0, 9.0});
+    auto doubles2 = make_shared<op::Constant>(element::f64, Shape{2}, std::vector<double>{4.0, 1.0});
+    auto shorts = make_shared<op::Constant>(element::i16, Shape{3}, std::vector<int16_t>{14, -3, -3});
+    auto shorts2 = make_shared<op::Constant>(element::i16, Shape{1}, std::vector<int16_t>{-3});
+    auto unsigned_shorts = make_shared<op::Constant>(element::u16, Shape{3}, std::vector<uint16_t>{14, 300, 14});
+    auto unsigned_shorts2 = make_shared<op::Constant>(element::u16, Shape{1}, std::vector<uint16_t>{300});
 
     auto add = make_shared<op::v1::Add>(a, b);
     auto sub = make_shared<op::v1::Subtract>(a, b);
@@ -312,6 +317,10 @@
     auto doubles_sqrt = make_shared<op::Sqrt>(doubles);
     auto sub_int8 = make_shared<op::v1::Subtract>(j, j);
     auto sub_uint8 = make_shared<op::v1::Subtract>(k, k);
+    auto equal_doubles = make_shared<op::v1::Equal>(doubles, doubles2, op::AutoBroadcastType::NUMPY);
+    auto equal_shorts = make_shared<op::v1::Equal>(shorts, shorts2, op::AutoBroadcastType::NUMPY);
+    auto equal_unsigned_shorts =
+        make_shared<op::v1::Equal>(unsigned_shorts, unsigned_shorts2, op::AutoBroadcastType::NUMPY);
 
     auto neg_sqrt = make_shared<op::Sqrt>(c);
 
@@ -342,7 +351,10 @@
                                                  logical_xor_autob_numpy,
                                                  doubles_sqrt,
                                                  sub_int8,
-                                                 sub_uint8},
+                                                 sub_uint8,
+                                                 equal_doubles,
+                                                 equal_shorts,
+                                                 equal_unsigned_shorts},
                                       ParameterVector{});
 
     auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
@@ -378,6 +390,9 @@
     vector<double> doubles_sqrt_expected{2.0, 3.0};
     vector<int8_t> sub_int8_expected{0, 0};
     vector<uint8_t> sub_uint8_expected{0, 0};
+    vector<char> equal_doubles_expected{1, 0};
+    vector<char> equal_shorts_expected{0, 1, 1};
+    vector<char> equal_unsigned_shorts_expected{0, 1, 0};
 
     ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
     ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
@@ -407,6 +422,9 @@
     ASSERT_EQ(get_result_constant<double>(func, 25), doubles_sqrt_expected);
     ASSERT_EQ(get_result_constant<int8_t>(func, 26), sub_int8_expected);
     ASSERT_EQ(get_result_constant<uint8_t>(func, 27), sub_uint8_expected);
+    ASSERT_EQ(get_result_constant<char>(func, 28), equal_doubles_expected);
+    ASSERT_EQ(get_result_constant<char>(func, 29), equal_shorts_expected);
+    ASSERT_EQ(get_result_constant<char>(func, 30), equal_unsigned_shorts_expected);
 
     ASSERT_NO_THROW(pass_manager.run_passes(func_error));
 }
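
For illustration, a minimal standalone sketch of the behaviour this patch enables: constant folding of Equal over f64 constants, which previously fell through to the default branch of evaluate_equal and failed. This snippet is not part of the patch; it assumes the public ngraph opset8 headers and the pass::Manager / ConstantFolding API available at this revision, so include paths and namespaces may differ in other versions.

// Sketch only: fold Equal(f64, f64) to a boolean Constant via ConstantFolding.
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <ngraph/pass/manager.hpp>

#include <iostream>

int main() {
    using namespace ngraph;

    // Same values as the new test case: {4.0, 9.0} == {4.0, 1.0} -> {1, 0}
    auto lhs = opset8::Constant::create(element::f64, Shape{2}, {4.0, 9.0});
    auto rhs = opset8::Constant::create(element::f64, Shape{2}, {4.0, 1.0});
    auto equal = std::make_shared<opset8::Equal>(lhs, rhs, op::AutoBroadcastType::NUMPY);
    auto func = std::make_shared<Function>(NodeVector{equal}, ParameterVector{});

    pass::Manager pm;
    pm.register_pass<pass::ConstantFolding>();
    pm.run_passes(func);  // with this patch, the f64 Equal node folds away

    // The Result input should now be a boolean Constant holding {1, 0}.
    auto folded = std::dynamic_pointer_cast<opset8::Constant>(
        func->get_results()[0]->input_value(0).get_node_shared_ptr());
    if (folded) {
        for (char v : folded->cast_vector<char>())
            std::cout << int(v) << " ";
        std::cout << std::endl;
    }
    return 0;
}

The same pattern applies to the other newly supported element types (bf16, i4, u4, i16, u16); the added constant_folding test cases above exercise it for f64, i16 and u16.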