diff --git a/ngraph/core/src/op/sqrt.cpp b/ngraph/core/src/op/sqrt.cpp index e706e4ae7c2..339a8b74706 100644 --- a/ngraph/core/src/op/sqrt.cpp +++ b/ngraph/core/src/op/sqrt.cpp @@ -57,6 +57,7 @@ namespace sqrtop NGRAPH_TYPE_CASE(evaluate_sqrt, u64, arg0, out, count); NGRAPH_TYPE_CASE(evaluate_sqrt, f16, arg0, out, count); NGRAPH_TYPE_CASE(evaluate_sqrt, f32, arg0, out, count); + NGRAPH_TYPE_CASE(evaluate_sqrt, f64, arg0, out, count); default: rc = false; break; } return rc; @@ -79,7 +80,8 @@ bool op::Sqrt::has_evaluate() const case ngraph::element::u32: case ngraph::element::u64: case ngraph::element::f16: - case ngraph::element::f32: return true; + case ngraph::element::f32: + case ngraph::element::f64: return true; default: break; } return false; diff --git a/ngraph/test/constant_folding.cpp b/ngraph/test/constant_folding.cpp index c34dcb12c6e..d6b2d98ee9f 100644 --- a/ngraph/test/constant_folding.cpp +++ b/ngraph/test/constant_folding.cpp @@ -300,6 +300,7 @@ TEST(constant_folding, constant_unary_binary) auto g = make_shared(element::i32, Shape{2}, values_g); auto h = make_shared(element::boolean, Shape{2, 2}, values_h); auto i = make_shared(element::boolean, Shape{2}, values_i); + auto doubles = make_shared(element::f64, Shape{2}, std::vector{4.0, 9.0}); auto add = make_shared(a, b); auto sub = make_shared(a, b); @@ -328,6 +329,7 @@ TEST(constant_folding, constant_unary_binary) auto logical_or_autob_numpy = make_shared(h, i, op::AutoBroadcastType::NUMPY); auto logical_xor_autob_numpy = make_shared(h, i, op::AutoBroadcastType::NUMPY); + auto doubles_sqrt = make_shared(doubles); auto neg_sqrt = make_shared(c); @@ -355,7 +357,8 @@ TEST(constant_folding, constant_unary_binary) less_autob_numpy, less_eq_autob_numpy, logical_or_autob_numpy, - logical_xor_autob_numpy}, + logical_xor_autob_numpy, + doubles_sqrt}, ParameterVector{}); auto func_error = make_shared(NodeVector{neg_sqrt}, ParameterVector{}); @@ -388,6 +391,7 @@ TEST(constant_folding, constant_unary_binary) vector less_eq_autob_numpy_expected{1, 1, 0, 1}; vector logical_or_autob_numpy_expected{0, 1, 1, 1}; vector logical_xor_autob_numpy_expected{0, 1, 1, 0}; + vector doubles_sqrt_expected{2.0, 3.0}; ASSERT_EQ(get_result_constant(func, 0), add_expected); ASSERT_EQ(get_result_constant(func, 1), sub_expected); @@ -414,6 +418,7 @@ TEST(constant_folding, constant_unary_binary) ASSERT_EQ(get_result_constant(func, 22), less_eq_autob_numpy_expected); ASSERT_EQ(get_result_constant(func, 23), logical_or_autob_numpy_expected); ASSERT_EQ(get_result_constant(func, 24), logical_xor_autob_numpy_expected); + ASSERT_EQ(get_result_constant(func, 25), doubles_sqrt_expected); ASSERT_NO_THROW(pass_manager.run_passes(func_error)); }