From a34b6e38f3758722b88875f05c1c1198c8c94484 Mon Sep 17 00:00:00 2001 From: Mateusz Tabaka Date: Fri, 18 Sep 2020 09:56:11 +0200 Subject: [PATCH] =?UTF-8?q?ConvertPrecision=20-=20saturate=20Constant's=20?= =?UTF-8?q?value=20to=20std::numeric=5Flimits::lowest() if it's below that limit. * Remove clamping to std::numeric_limits::lowest() in U32/U64 case --- .../src/transformations/convert_precision.cpp | 37 ++++++++--- .../transformations/convert_precision.cpp | 63 ++++++++++++++++++- 2 files changed, 92 insertions(+), 8 deletions(-) diff --git a/inference-engine/src/transformations/src/transformations/convert_precision.cpp b/inference-engine/src/transformations/src/transformations/convert_precision.cpp index 132cf9691e5..41804a03057 100644 --- a/inference-engine/src/transformations/src/transformations/convert_precision.cpp +++ b/inference-engine/src/transformations/src/transformations/convert_precision.cpp @@ -297,8 +297,36 @@ bool extend_select_type(std::shared_ptr & node, ngraph::element::T return false; } +template +inline dst_type convert_value(src_type val) { + if (val > std::numeric_limits::max()) { + return std::numeric_limits::max(); + } else if (val < std::numeric_limits::lowest()) { + return std::numeric_limits::lowest(); + } + return static_cast(val); +} + +// We need to treat U64->I32 and U32->I32 as a separate case, because of C++'s implicit promotion from signed to unsigned, +// and we don't need to compare and clamp the input to std::numeric_limits::lowest() +template <> +inline int32_t convert_value(uint64_t val) { + if (val > std::numeric_limits::max()) { + return std::numeric_limits::max(); + } + return static_cast(val); +} + +template <> +inline int32_t convert_value(uint32_t val) { + if (val > std::numeric_limits::max()) { + return std::numeric_limits::max(); + } + return static_cast(val); +} + template -std::shared_ptr change_constant_precision(std::shared_ptr & constant) { +static std::shared_ptr change_constant_precision(std::shared_ptr& constant) { using src_type = typename element_type_traits::value_type; using dst_type = typename element_type_traits::value_type; @@ -310,12 +338,7 @@ std::shared_ptr change_constant_precision(std::shared_ptr final_data; for (size_t i = 0; i < size; ++i) { - const auto & val = src_data[i]; - if (val > std::numeric_limits::max()) { - dst_data[i] = std::numeric_limits::max(); - } else { - dst_data[i] = static_cast(val); - } + dst_data[i] = convert_value(src_data[i]); } return new_constant; } diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_precision.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_precision.cpp index 6671f346296..d67d67d29ae 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/convert_precision.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/convert_precision.cpp @@ -560,4 +560,65 @@ TEST(TransformationTests, ConvertPrecision_Variables) { } ASSERT_FALSE(has_type(f)); -} \ No newline at end of file +} + +template +void constant_convert_test(element::Type_t type_from, element::Type_t type_to, From value, To expected) { + std::shared_ptr f(nullptr); + { + auto c = opset4::Constant::create(type_from, Shape{}, {value}); + f = std::make_shared(NodeVector{c}, ParameterVector{}); + + pass::Manager manager; + manager.register_pass(type_from, type_to); + manager.run_passes(f); + } + auto ops = f->get_ordered_ops(); + auto c = std::dynamic_pointer_cast(ops[0]); + ASSERT_NE(c, nullptr); + + auto actual = c->cast_vector()[0]; + ASSERT_EQ(expected, actual); +} + +TEST(TransformationTests, ConvertPrecision_ConstantConversion_I64MinToI32) { + constant_convert_test(element::Type_t::i64, element::Type_t::i32, + std::numeric_limits::min(), + std::numeric_limits::min()); +} + +TEST(TransformationTests, ConvertPrecision_ConstantConversion_I64MaxToI32) { + constant_convert_test(element::Type_t::i64, element::Type_t::i32, + std::numeric_limits::max(), + std::numeric_limits::max()); +} + +TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64MinToI32) { + constant_convert_test(element::Type_t::u64, element::Type_t::i32, + std::numeric_limits::min(), 0); +} + +TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64MaxToI32) { + constant_convert_test(element::Type_t::u64, element::Type_t::i32, + std::numeric_limits::max(), + std::numeric_limits::max()); +} + +TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64ToI32) { + constant_convert_test(element::Type_t::u64, element::Type_t::i32, 42, 42); +} + +TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32MinToI32) { + constant_convert_test(element::Type_t::u32, element::Type_t::i32, + std::numeric_limits::min(), 0); +} + +TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32MaxToI32) { + constant_convert_test(element::Type_t::u32, element::Type_t::i32, + std::numeric_limits::max(), + std::numeric_limits::max()); +} + +TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32ToI32) { + constant_convert_test(element::Type_t::u32, element::Type_t::i32, 42, 42); +}