Add bf16, f64, i4, u4, i16, u16 types to Equal's evaluate (#10508)

* Add f64 type to Equal's evaluate

Required by t2t-vit models.

Ticket: 79610.

* add also i16 u16 because prior_box tests fail with "Check eval_status failed at"

* code style

* add i4, u4, bf16 to equal's evaluate
This commit is contained in:
Mateusz Tabaka 2022-02-21 09:37:17 +01:00 committed by GitHub
parent 1fa5d44769
commit 430e898c33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 3 deletions

View File

@ -35,14 +35,20 @@ bool evaluate_equal(const HostTensorPtr& arg0,
out->set_broadcast(broadcast_spec, arg0, arg1, element::boolean);
switch (arg0->get_element_type()) {
NGRAPH_TYPE_CASE(evaluate_equal, boolean, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, i4, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, i8, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, u8, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, i16, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, i32, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, i64, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, u4, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, u8, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, u16, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, u32, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, u64, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, bf16, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, f16, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, f32, arg0, arg1, out, broadcast_spec);
NGRAPH_TYPE_CASE(evaluate_equal, f64, arg0, arg1, out, broadcast_spec);
default:
rc = false;
break;

View File

@ -1223,7 +1223,8 @@ bool are_equal(const HostTensorPtr& lhs, const HostTensorPtr& rhs, size_t max_el
return false;
auto mask = std::make_shared<HostTensor>(element::boolean, lhs_shape);
const auto& param = std::make_shared<op::Parameter>(lhs_et, lhs_shape);
op::v1::Equal(param, param, ngraph::op::AutoBroadcastType::NUMPY).evaluate({mask}, {lhs, rhs});
bool eval_status = op::v1::Equal(param, param, ngraph::op::AutoBroadcastType::NUMPY).evaluate({mask}, {lhs, rhs});
OPENVINO_ASSERT(eval_status);
auto equal = op::Constant(mask).cast_vector<bool>();
return std::all_of(equal.begin(), equal.end(), [](bool i) {
return i;

View File

@ -283,6 +283,11 @@ TEST(constant_folding, constant_unary_binary) {
auto j = make_shared<op::Constant>(element::i8, Shape{2}, values_j);
auto k = make_shared<op::Constant>(element::u8, Shape{2}, values_k);
auto doubles = make_shared<op::Constant>(element::f64, Shape{2}, std::vector<double>{4.0, 9.0});
auto doubles2 = make_shared<op::Constant>(element::f64, Shape{2}, std::vector<double>{4.0, 1.0});
auto shorts = make_shared<op::Constant>(element::i16, Shape{3}, std::vector<int16_t>{14, -3, -3});
auto shorts2 = make_shared<op::Constant>(element::i16, Shape{1}, std::vector<int16_t>{-3});
auto unsigned_shorts = make_shared<op::Constant>(element::u16, Shape{3}, std::vector<uint16_t>{14, 300, 14});
auto unsigned_shorts2 = make_shared<op::Constant>(element::u16, Shape{1}, std::vector<uint16_t>{300});
auto add = make_shared<op::v1::Add>(a, b);
auto sub = make_shared<op::v1::Subtract>(a, b);
@ -312,6 +317,10 @@ TEST(constant_folding, constant_unary_binary) {
auto doubles_sqrt = make_shared<op::Sqrt>(doubles);
auto sub_int8 = make_shared<op::v1::Subtract>(j, j);
auto sub_uint8 = make_shared<op::v1::Subtract>(k, k);
auto equal_doubles = make_shared<op::v1::Equal>(doubles, doubles2, op::AutoBroadcastType::NUMPY);
auto equal_shorts = make_shared<op::v1::Equal>(shorts, shorts2, op::AutoBroadcastType::NUMPY);
auto equal_unsigned_shorts =
make_shared<op::v1::Equal>(unsigned_shorts, unsigned_shorts2, op::AutoBroadcastType::NUMPY);
auto neg_sqrt = make_shared<op::Sqrt>(c);
@ -342,7 +351,10 @@ TEST(constant_folding, constant_unary_binary) {
logical_xor_autob_numpy,
doubles_sqrt,
sub_int8,
sub_uint8},
sub_uint8,
equal_doubles,
equal_shorts,
equal_unsigned_shorts},
ParameterVector{});
auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
@ -378,6 +390,9 @@ TEST(constant_folding, constant_unary_binary) {
vector<double> doubles_sqrt_expected{2.0, 3.0};
vector<int8_t> sub_int8_expected{0, 0};
vector<uint8_t> sub_uint8_expected{0, 0};
vector<char> equal_doubles_expected{1, 0};
vector<char> equal_shorts_expected{0, 1, 1};
vector<char> equal_unsigned_shorts_expected{0, 1, 0};
ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
@ -407,6 +422,9 @@ TEST(constant_folding, constant_unary_binary) {
ASSERT_EQ(get_result_constant<double>(func, 25), doubles_sqrt_expected);
ASSERT_EQ(get_result_constant<int8_t>(func, 26), sub_int8_expected);
ASSERT_EQ(get_result_constant<uint8_t>(func, 27), sub_uint8_expected);
ASSERT_EQ(get_result_constant<char>(func, 28), equal_doubles_expected);
ASSERT_EQ(get_result_constant<char>(func, 29), equal_shorts_expected);
ASSERT_EQ(get_result_constant<char>(func, 30), equal_unsigned_shorts_expected);
ASSERT_NO_THROW(pass_manager.run_passes(func_error));
}