Add support for fp64 in Sqrt's evaluate method (#5913)
It's required for t2t-vit models.
This commit is contained in:
parent
7b2779c406
commit
dc0d482c23
@ -57,6 +57,7 @@ namespace sqrtop
|
|||||||
NGRAPH_TYPE_CASE(evaluate_sqrt, u64, arg0, out, count);
|
NGRAPH_TYPE_CASE(evaluate_sqrt, u64, arg0, out, count);
|
||||||
NGRAPH_TYPE_CASE(evaluate_sqrt, f16, arg0, out, count);
|
NGRAPH_TYPE_CASE(evaluate_sqrt, f16, arg0, out, count);
|
||||||
NGRAPH_TYPE_CASE(evaluate_sqrt, f32, arg0, out, count);
|
NGRAPH_TYPE_CASE(evaluate_sqrt, f32, arg0, out, count);
|
||||||
|
NGRAPH_TYPE_CASE(evaluate_sqrt, f64, arg0, out, count);
|
||||||
default: rc = false; break;
|
default: rc = false; break;
|
||||||
}
|
}
|
||||||
return rc;
|
return rc;
|
||||||
@ -79,7 +80,8 @@ bool op::Sqrt::has_evaluate() const
|
|||||||
case ngraph::element::u32:
|
case ngraph::element::u32:
|
||||||
case ngraph::element::u64:
|
case ngraph::element::u64:
|
||||||
case ngraph::element::f16:
|
case ngraph::element::f16:
|
||||||
case ngraph::element::f32: return true;
|
case ngraph::element::f32:
|
||||||
|
case ngraph::element::f64: return true;
|
||||||
default: break;
|
default: break;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -300,6 +300,7 @@ TEST(constant_folding, constant_unary_binary)
|
|||||||
auto g = make_shared<op::Constant>(element::i32, Shape{2}, values_g);
|
auto g = make_shared<op::Constant>(element::i32, Shape{2}, values_g);
|
||||||
auto h = make_shared<op::Constant>(element::boolean, Shape{2, 2}, values_h);
|
auto h = make_shared<op::Constant>(element::boolean, Shape{2, 2}, values_h);
|
||||||
auto i = make_shared<op::Constant>(element::boolean, Shape{2}, values_i);
|
auto i = make_shared<op::Constant>(element::boolean, Shape{2}, values_i);
|
||||||
|
auto doubles = make_shared<op::Constant>(element::f64, Shape{2}, std::vector<double>{4.0, 9.0});
|
||||||
|
|
||||||
auto add = make_shared<op::v1::Add>(a, b);
|
auto add = make_shared<op::v1::Add>(a, b);
|
||||||
auto sub = make_shared<op::v1::Subtract>(a, b);
|
auto sub = make_shared<op::v1::Subtract>(a, b);
|
||||||
@ -328,6 +329,7 @@ TEST(constant_folding, constant_unary_binary)
|
|||||||
auto logical_or_autob_numpy =
|
auto logical_or_autob_numpy =
|
||||||
make_shared<op::v1::LogicalOr>(h, i, op::AutoBroadcastType::NUMPY);
|
make_shared<op::v1::LogicalOr>(h, i, op::AutoBroadcastType::NUMPY);
|
||||||
auto logical_xor_autob_numpy = make_shared<op::Xor>(h, i, op::AutoBroadcastType::NUMPY);
|
auto logical_xor_autob_numpy = make_shared<op::Xor>(h, i, op::AutoBroadcastType::NUMPY);
|
||||||
|
auto doubles_sqrt = make_shared<op::Sqrt>(doubles);
|
||||||
|
|
||||||
auto neg_sqrt = make_shared<op::Sqrt>(c);
|
auto neg_sqrt = make_shared<op::Sqrt>(c);
|
||||||
|
|
||||||
@ -355,7 +357,8 @@ TEST(constant_folding, constant_unary_binary)
|
|||||||
less_autob_numpy,
|
less_autob_numpy,
|
||||||
less_eq_autob_numpy,
|
less_eq_autob_numpy,
|
||||||
logical_or_autob_numpy,
|
logical_or_autob_numpy,
|
||||||
logical_xor_autob_numpy},
|
logical_xor_autob_numpy,
|
||||||
|
doubles_sqrt},
|
||||||
ParameterVector{});
|
ParameterVector{});
|
||||||
auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
|
auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
|
||||||
|
|
||||||
@ -388,6 +391,7 @@ TEST(constant_folding, constant_unary_binary)
|
|||||||
vector<char> less_eq_autob_numpy_expected{1, 1, 0, 1};
|
vector<char> less_eq_autob_numpy_expected{1, 1, 0, 1};
|
||||||
vector<char> logical_or_autob_numpy_expected{0, 1, 1, 1};
|
vector<char> logical_or_autob_numpy_expected{0, 1, 1, 1};
|
||||||
vector<char> logical_xor_autob_numpy_expected{0, 1, 1, 0};
|
vector<char> logical_xor_autob_numpy_expected{0, 1, 1, 0};
|
||||||
|
vector<double> doubles_sqrt_expected{2.0, 3.0};
|
||||||
|
|
||||||
ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
|
ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
|
||||||
ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
|
ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
|
||||||
@ -414,6 +418,7 @@ TEST(constant_folding, constant_unary_binary)
|
|||||||
ASSERT_EQ(get_result_constant<char>(func, 22), less_eq_autob_numpy_expected);
|
ASSERT_EQ(get_result_constant<char>(func, 22), less_eq_autob_numpy_expected);
|
||||||
ASSERT_EQ(get_result_constant<char>(func, 23), logical_or_autob_numpy_expected);
|
ASSERT_EQ(get_result_constant<char>(func, 23), logical_or_autob_numpy_expected);
|
||||||
ASSERT_EQ(get_result_constant<char>(func, 24), logical_xor_autob_numpy_expected);
|
ASSERT_EQ(get_result_constant<char>(func, 24), logical_xor_autob_numpy_expected);
|
||||||
|
ASSERT_EQ(get_result_constant<double>(func, 25), doubles_sqrt_expected);
|
||||||
ASSERT_NO_THROW(pass_manager.run_passes(func_error));
|
ASSERT_NO_THROW(pass_manager.run_passes(func_error));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user