ConvertPrecision - saturate Constant's value to std::numeric_limits<dst… (#2206)

* ConvertPrecision - saturate Constant's value to std::numeric_limits<dst_type>::lowest() if it's below that limit.

* Remove clamping to std::numeric_limits<int32_t>::lowest() in U32/U64 case
This commit is contained in:
Mateusz Tabaka 2020-09-18 09:56:11 +02:00 committed by GitHub
parent 1b7dfc6e4c
commit a34b6e38f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 92 additions and 8 deletions

View File

@ -297,8 +297,36 @@ bool extend_select_type(std::shared_ptr<ngraph::Node> & node, ngraph::element::T
return false;
}
template <typename src_type, typename dst_type>
inline dst_type convert_value(src_type val) {
if (val > std::numeric_limits<dst_type>::max()) {
return std::numeric_limits<dst_type>::max();
} else if (val < std::numeric_limits<dst_type>::lowest()) {
return std::numeric_limits<dst_type>::lowest();
}
return static_cast<dst_type>(val);
}
// We need to treat U64->I32 and U32->I32 as a separate case, because of C++'s implicit promotion from signed to unsigned,
// and we don't need to compare and clamp the input to std::numeric_limits<int32_t>::lowest()
template <>
inline int32_t convert_value<uint64_t, int32_t>(uint64_t val) {
if (val > std::numeric_limits<int32_t>::max()) {
return std::numeric_limits<int32_t>::max();
}
return static_cast<int32_t>(val);
}
template <>
inline int32_t convert_value<uint32_t, int32_t>(uint32_t val) {
if (val > std::numeric_limits<int32_t>::max()) {
return std::numeric_limits<int32_t>::max();
}
return static_cast<int32_t>(val);
}
template <element::Type_t PREC_FROM, element::Type_t PREC_TO>
std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant> & constant) {
static std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant>& constant) {
using src_type = typename element_type_traits<PREC_FROM>::value_type;
using dst_type = typename element_type_traits<PREC_TO>::value_type;
@ -310,12 +338,7 @@ std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant
std::vector<dst_type> final_data;
for (size_t i = 0; i < size; ++i) {
const auto & val = src_data[i];
if (val > std::numeric_limits<dst_type>::max()) {
dst_data[i] = std::numeric_limits<dst_type>::max();
} else {
dst_data[i] = static_cast<dst_type>(val);
}
dst_data[i] = convert_value<src_type, dst_type>(src_data[i]);
}
return new_constant;
}

View File

@ -560,4 +560,65 @@ TEST(TransformationTests, ConvertPrecision_Variables) {
}
ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
}
}
template <typename From, typename To>
void constant_convert_test(element::Type_t type_from, element::Type_t type_to, From value, To expected) {
std::shared_ptr<ngraph::Function> f(nullptr);
{
auto c = opset4::Constant::create(type_from, Shape{}, {value});
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});
pass::Manager manager;
manager.register_pass<ngraph::pass::ConvertPrecision>(type_from, type_to);
manager.run_passes(f);
}
auto ops = f->get_ordered_ops();
auto c = std::dynamic_pointer_cast<opset4::Constant>(ops[0]);
ASSERT_NE(c, nullptr);
auto actual = c->cast_vector<To>()[0];
ASSERT_EQ(expected, actual);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I64MinToI32) {
constant_convert_test(element::Type_t::i64, element::Type_t::i32,
std::numeric_limits<int64_t>::min(),
std::numeric_limits<int32_t>::min());
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I64MaxToI32) {
constant_convert_test(element::Type_t::i64, element::Type_t::i32,
std::numeric_limits<int64_t>::max(),
std::numeric_limits<int32_t>::max());
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64MinToI32) {
constant_convert_test(element::Type_t::u64, element::Type_t::i32,
std::numeric_limits<uint64_t>::min(), 0);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64MaxToI32) {
constant_convert_test(element::Type_t::u64, element::Type_t::i32,
std::numeric_limits<uint64_t>::max(),
std::numeric_limits<int32_t>::max());
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64ToI32) {
constant_convert_test(element::Type_t::u64, element::Type_t::i32, 42, 42);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32MinToI32) {
constant_convert_test(element::Type_t::u32, element::Type_t::i32,
std::numeric_limits<uint32_t>::min(), 0);
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32MaxToI32) {
constant_convert_test(element::Type_t::u32, element::Type_t::i32,
std::numeric_limits<uint32_t>::max(),
std::numeric_limits<int32_t>::max());
}
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32ToI32) {
constant_convert_test(element::Type_t::u32, element::Type_t::i32, 42, 42);
}