Add bf16, f64, i4, u4, i16, u16 types to Equal's evaluate (#10508)
* Add f64 type to Equal's evaluate — required by t2t-vit models. Ticket: 79610. * Also add i16 and u16, because prior_box tests fail with "Check eval_status failed at". * Code style. * Add i4, u4, and bf16 to Equal's evaluate.
This commit is contained in:
parent
1fa5d44769
commit
430e898c33
@@ -35,14 +35,20 @@ bool evaluate_equal(const HostTensorPtr& arg0,
|
||||
out->set_broadcast(broadcast_spec, arg0, arg1, element::boolean);
|
||||
switch (arg0->get_element_type()) {
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, boolean, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, i4, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, i8, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, u8, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, i16, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, i32, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, i64, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, u4, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, u8, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, u16, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, u32, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, u64, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, bf16, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, f16, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, f32, arg0, arg1, out, broadcast_spec);
|
||||
NGRAPH_TYPE_CASE(evaluate_equal, f64, arg0, arg1, out, broadcast_spec);
|
||||
default:
|
||||
rc = false;
|
||||
break;
|
||||
|
@@ -1223,7 +1223,8 @@ bool are_equal(const HostTensorPtr& lhs, const HostTensorPtr& rhs, size_t max_el
|
||||
return false;
|
||||
auto mask = std::make_shared<HostTensor>(element::boolean, lhs_shape);
|
||||
const auto& param = std::make_shared<op::Parameter>(lhs_et, lhs_shape);
|
||||
op::v1::Equal(param, param, ngraph::op::AutoBroadcastType::NUMPY).evaluate({mask}, {lhs, rhs});
|
||||
bool eval_status = op::v1::Equal(param, param, ngraph::op::AutoBroadcastType::NUMPY).evaluate({mask}, {lhs, rhs});
|
||||
OPENVINO_ASSERT(eval_status);
|
||||
auto equal = op::Constant(mask).cast_vector<bool>();
|
||||
return std::all_of(equal.begin(), equal.end(), [](bool i) {
|
||||
return i;
|
||||
|
@@ -283,6 +283,11 @@ TEST(constant_folding, constant_unary_binary) {
|
||||
auto j = make_shared<op::Constant>(element::i8, Shape{2}, values_j);
|
||||
auto k = make_shared<op::Constant>(element::u8, Shape{2}, values_k);
|
||||
auto doubles = make_shared<op::Constant>(element::f64, Shape{2}, std::vector<double>{4.0, 9.0});
|
||||
auto doubles2 = make_shared<op::Constant>(element::f64, Shape{2}, std::vector<double>{4.0, 1.0});
|
||||
auto shorts = make_shared<op::Constant>(element::i16, Shape{3}, std::vector<int16_t>{14, -3, -3});
|
||||
auto shorts2 = make_shared<op::Constant>(element::i16, Shape{1}, std::vector<int16_t>{-3});
|
||||
auto unsigned_shorts = make_shared<op::Constant>(element::u16, Shape{3}, std::vector<uint16_t>{14, 300, 14});
|
||||
auto unsigned_shorts2 = make_shared<op::Constant>(element::u16, Shape{1}, std::vector<uint16_t>{300});
|
||||
|
||||
auto add = make_shared<op::v1::Add>(a, b);
|
||||
auto sub = make_shared<op::v1::Subtract>(a, b);
|
||||
@@ -312,6 +317,10 @@ TEST(constant_folding, constant_unary_binary) {
|
||||
auto doubles_sqrt = make_shared<op::Sqrt>(doubles);
|
||||
auto sub_int8 = make_shared<op::v1::Subtract>(j, j);
|
||||
auto sub_uint8 = make_shared<op::v1::Subtract>(k, k);
|
||||
auto equal_doubles = make_shared<op::v1::Equal>(doubles, doubles2, op::AutoBroadcastType::NUMPY);
|
||||
auto equal_shorts = make_shared<op::v1::Equal>(shorts, shorts2, op::AutoBroadcastType::NUMPY);
|
||||
auto equal_unsigned_shorts =
|
||||
make_shared<op::v1::Equal>(unsigned_shorts, unsigned_shorts2, op::AutoBroadcastType::NUMPY);
|
||||
|
||||
auto neg_sqrt = make_shared<op::Sqrt>(c);
|
||||
|
||||
@@ -342,7 +351,10 @@ TEST(constant_folding, constant_unary_binary) {
|
||||
logical_xor_autob_numpy,
|
||||
doubles_sqrt,
|
||||
sub_int8,
|
||||
sub_uint8},
|
||||
sub_uint8,
|
||||
equal_doubles,
|
||||
equal_shorts,
|
||||
equal_unsigned_shorts},
|
||||
ParameterVector{});
|
||||
auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
|
||||
|
||||
@@ -378,6 +390,9 @@ TEST(constant_folding, constant_unary_binary) {
|
||||
vector<double> doubles_sqrt_expected{2.0, 3.0};
|
||||
vector<int8_t> sub_int8_expected{0, 0};
|
||||
vector<uint8_t> sub_uint8_expected{0, 0};
|
||||
vector<char> equal_doubles_expected{1, 0};
|
||||
vector<char> equal_shorts_expected{0, 1, 1};
|
||||
vector<char> equal_unsigned_shorts_expected{0, 1, 0};
|
||||
|
||||
ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
|
||||
ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
|
||||
@@ -407,6 +422,9 @@ TEST(constant_folding, constant_unary_binary) {
|
||||
ASSERT_EQ(get_result_constant<double>(func, 25), doubles_sqrt_expected);
|
||||
ASSERT_EQ(get_result_constant<int8_t>(func, 26), sub_int8_expected);
|
||||
ASSERT_EQ(get_result_constant<uint8_t>(func, 27), sub_uint8_expected);
|
||||
ASSERT_EQ(get_result_constant<char>(func, 28), equal_doubles_expected);
|
||||
ASSERT_EQ(get_result_constant<char>(func, 29), equal_shorts_expected);
|
||||
ASSERT_EQ(get_result_constant<char>(func, 30), equal_unsigned_shorts_expected);
|
||||
ASSERT_NO_THROW(pass_manager.run_passes(func_error));
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user