ConvertPrecision - saturate Constant's value to std::numeric_limits<dst… (#2206)
* ConvertPrecision - saturate Constant's value to std::numeric_limits<dst_type>::lowest() if it's below that limit. * Remove clamping to std::numeric_limits<int32_t>::lowest() in U32/U64 case
This commit is contained in:
parent
1b7dfc6e4c
commit
a34b6e38f3
@ -297,8 +297,36 @@ bool extend_select_type(std::shared_ptr<ngraph::Node> & node, ngraph::element::T
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename src_type, typename dst_type>
|
||||
inline dst_type convert_value(src_type val) {
|
||||
if (val > std::numeric_limits<dst_type>::max()) {
|
||||
return std::numeric_limits<dst_type>::max();
|
||||
} else if (val < std::numeric_limits<dst_type>::lowest()) {
|
||||
return std::numeric_limits<dst_type>::lowest();
|
||||
}
|
||||
return static_cast<dst_type>(val);
|
||||
}
|
||||
|
||||
// We need to treat U64->I32 and U32->I32 as a separate case, because of C++'s implicit promotion from signed to unsigned,
|
||||
// and we don't need to compare and clamp the input to std::numeric_limits<int32_t>::lowest()
|
||||
template <>
|
||||
inline int32_t convert_value<uint64_t, int32_t>(uint64_t val) {
|
||||
if (val > std::numeric_limits<int32_t>::max()) {
|
||||
return std::numeric_limits<int32_t>::max();
|
||||
}
|
||||
return static_cast<int32_t>(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline int32_t convert_value<uint32_t, int32_t>(uint32_t val) {
|
||||
if (val > std::numeric_limits<int32_t>::max()) {
|
||||
return std::numeric_limits<int32_t>::max();
|
||||
}
|
||||
return static_cast<int32_t>(val);
|
||||
}
|
||||
|
||||
template <element::Type_t PREC_FROM, element::Type_t PREC_TO>
|
||||
std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant> & constant) {
|
||||
static std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant>& constant) {
|
||||
using src_type = typename element_type_traits<PREC_FROM>::value_type;
|
||||
using dst_type = typename element_type_traits<PREC_TO>::value_type;
|
||||
|
||||
@ -310,12 +338,7 @@ std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant
|
||||
|
||||
std::vector<dst_type> final_data;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
const auto & val = src_data[i];
|
||||
if (val > std::numeric_limits<dst_type>::max()) {
|
||||
dst_data[i] = std::numeric_limits<dst_type>::max();
|
||||
} else {
|
||||
dst_data[i] = static_cast<dst_type>(val);
|
||||
}
|
||||
dst_data[i] = convert_value<src_type, dst_type>(src_data[i]);
|
||||
}
|
||||
return new_constant;
|
||||
}
|
||||
|
@ -560,4 +560,65 @@ TEST(TransformationTests, ConvertPrecision_Variables) {
|
||||
}
|
||||
|
||||
ASSERT_FALSE(has_type<ngraph::element::Type_t::f16>(f));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename From, typename To>
|
||||
void constant_convert_test(element::Type_t type_from, element::Type_t type_to, From value, To expected) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr);
|
||||
{
|
||||
auto c = opset4::Constant::create(type_from, Shape{}, {value});
|
||||
f = std::make_shared<Function>(NodeVector{c}, ParameterVector{});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::ConvertPrecision>(type_from, type_to);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
auto ops = f->get_ordered_ops();
|
||||
auto c = std::dynamic_pointer_cast<opset4::Constant>(ops[0]);
|
||||
ASSERT_NE(c, nullptr);
|
||||
|
||||
auto actual = c->cast_vector<To>()[0];
|
||||
ASSERT_EQ(expected, actual);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I64MinToI32) {
|
||||
constant_convert_test(element::Type_t::i64, element::Type_t::i32,
|
||||
std::numeric_limits<int64_t>::min(),
|
||||
std::numeric_limits<int32_t>::min());
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_I64MaxToI32) {
|
||||
constant_convert_test(element::Type_t::i64, element::Type_t::i32,
|
||||
std::numeric_limits<int64_t>::max(),
|
||||
std::numeric_limits<int32_t>::max());
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64MinToI32) {
|
||||
constant_convert_test(element::Type_t::u64, element::Type_t::i32,
|
||||
std::numeric_limits<uint64_t>::min(), 0);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64MaxToI32) {
|
||||
constant_convert_test(element::Type_t::u64, element::Type_t::i32,
|
||||
std::numeric_limits<uint64_t>::max(),
|
||||
std::numeric_limits<int32_t>::max());
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U64ToI32) {
|
||||
constant_convert_test(element::Type_t::u64, element::Type_t::i32, 42, 42);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32MinToI32) {
|
||||
constant_convert_test(element::Type_t::u32, element::Type_t::i32,
|
||||
std::numeric_limits<uint32_t>::min(), 0);
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32MaxToI32) {
|
||||
constant_convert_test(element::Type_t::u32, element::Type_t::i32,
|
||||
std::numeric_limits<uint32_t>::max(),
|
||||
std::numeric_limits<int32_t>::max());
|
||||
}
|
||||
|
||||
TEST(TransformationTests, ConvertPrecision_ConstantConversion_U32ToI32) {
|
||||
constant_convert_test(element::Type_t::u32, element::Type_t::i32, 42, 42);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user