Add support for fp64 in Sqrt's evaluate method (#5913)

It's required for t2t-vit models.
This commit is contained in:
Mateusz Tabaka 2021-06-16 12:45:36 +02:00 committed by GitHub
parent 7b2779c406
commit dc0d482c23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 2 deletions

View File

@ -57,6 +57,7 @@ namespace sqrtop
NGRAPH_TYPE_CASE(evaluate_sqrt, u64, arg0, out, count); NGRAPH_TYPE_CASE(evaluate_sqrt, u64, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_sqrt, f16, arg0, out, count); NGRAPH_TYPE_CASE(evaluate_sqrt, f16, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_sqrt, f32, arg0, out, count); NGRAPH_TYPE_CASE(evaluate_sqrt, f32, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_sqrt, f64, arg0, out, count);
default: rc = false; break; default: rc = false; break;
} }
return rc; return rc;
@ -79,7 +80,8 @@ bool op::Sqrt::has_evaluate() const
case ngraph::element::u32: case ngraph::element::u32:
case ngraph::element::u64: case ngraph::element::u64:
case ngraph::element::f16: case ngraph::element::f16:
case ngraph::element::f32: return true; case ngraph::element::f32:
case ngraph::element::f64: return true;
default: break; default: break;
} }
return false; return false;

View File

@ -300,6 +300,7 @@ TEST(constant_folding, constant_unary_binary)
auto g = make_shared<op::Constant>(element::i32, Shape{2}, values_g); auto g = make_shared<op::Constant>(element::i32, Shape{2}, values_g);
auto h = make_shared<op::Constant>(element::boolean, Shape{2, 2}, values_h); auto h = make_shared<op::Constant>(element::boolean, Shape{2, 2}, values_h);
auto i = make_shared<op::Constant>(element::boolean, Shape{2}, values_i); auto i = make_shared<op::Constant>(element::boolean, Shape{2}, values_i);
auto doubles = make_shared<op::Constant>(element::f64, Shape{2}, std::vector<double>{4.0, 9.0});
auto add = make_shared<op::v1::Add>(a, b); auto add = make_shared<op::v1::Add>(a, b);
auto sub = make_shared<op::v1::Subtract>(a, b); auto sub = make_shared<op::v1::Subtract>(a, b);
@ -328,6 +329,7 @@ TEST(constant_folding, constant_unary_binary)
auto logical_or_autob_numpy = auto logical_or_autob_numpy =
make_shared<op::v1::LogicalOr>(h, i, op::AutoBroadcastType::NUMPY); make_shared<op::v1::LogicalOr>(h, i, op::AutoBroadcastType::NUMPY);
auto logical_xor_autob_numpy = make_shared<op::Xor>(h, i, op::AutoBroadcastType::NUMPY); auto logical_xor_autob_numpy = make_shared<op::Xor>(h, i, op::AutoBroadcastType::NUMPY);
auto doubles_sqrt = make_shared<op::Sqrt>(doubles);
auto neg_sqrt = make_shared<op::Sqrt>(c); auto neg_sqrt = make_shared<op::Sqrt>(c);
@ -355,7 +357,8 @@ TEST(constant_folding, constant_unary_binary)
less_autob_numpy, less_autob_numpy,
less_eq_autob_numpy, less_eq_autob_numpy,
logical_or_autob_numpy, logical_or_autob_numpy,
logical_xor_autob_numpy}, logical_xor_autob_numpy,
doubles_sqrt},
ParameterVector{}); ParameterVector{});
auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{}); auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
@ -388,6 +391,7 @@ TEST(constant_folding, constant_unary_binary)
vector<char> less_eq_autob_numpy_expected{1, 1, 0, 1}; vector<char> less_eq_autob_numpy_expected{1, 1, 0, 1};
vector<char> logical_or_autob_numpy_expected{0, 1, 1, 1}; vector<char> logical_or_autob_numpy_expected{0, 1, 1, 1};
vector<char> logical_xor_autob_numpy_expected{0, 1, 1, 0}; vector<char> logical_xor_autob_numpy_expected{0, 1, 1, 0};
vector<double> doubles_sqrt_expected{2.0, 3.0};
ASSERT_EQ(get_result_constant<int>(func, 0), add_expected); ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected); ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
@ -414,6 +418,7 @@ TEST(constant_folding, constant_unary_binary)
ASSERT_EQ(get_result_constant<char>(func, 22), less_eq_autob_numpy_expected); ASSERT_EQ(get_result_constant<char>(func, 22), less_eq_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 23), logical_or_autob_numpy_expected); ASSERT_EQ(get_result_constant<char>(func, 23), logical_or_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 24), logical_xor_autob_numpy_expected); ASSERT_EQ(get_result_constant<char>(func, 24), logical_xor_autob_numpy_expected);
ASSERT_EQ(get_result_constant<double>(func, 25), doubles_sqrt_expected);
ASSERT_NO_THROW(pass_manager.run_passes(func_error)); ASSERT_NO_THROW(pass_manager.run_passes(func_error));
} }