Updated ConvertPrecision pass not to make extra copy of Constant data (#2108)

This commit is contained in:
Gleb Kazantaev 2020-09-10 17:11:10 +03:00 committed by GitHub
parent 6c3f7fd654
commit 18abd6cfd7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -300,17 +300,22 @@ std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant
using src_type = typename element_type_traits<PREC_FROM>::value_type;
using dst_type = typename element_type_traits<PREC_TO>::value_type;
std::vector<src_type> data(std::move(constant->get_vector<src_type>()));
const auto * src_data = constant->get_data_ptr<src_type>();
const auto size = shape_size(constant->get_shape());
auto new_constant = std::make_shared<ngraph::opset4::Constant>(PREC_TO, constant->get_shape());
auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
std::vector<dst_type> final_data;
std::transform(data.begin(), data.end(), std::back_inserter(final_data),
[](src_type val) {
for (size_t i = 0; i < size; ++i) {
const auto & val = src_data[i];
if (val > std::numeric_limits<dst_type>::max()) {
return std::numeric_limits<dst_type>::max();
dst_data[i] = std::numeric_limits<dst_type>::max();
} else {
return static_cast<dst_type>(val);
dst_data[i] = static_cast<dst_type>(val);
}
});
return std::make_shared<ngraph::opset4::Constant>(PREC_TO, constant->get_shape(), final_data);
}
return new_constant;
}
bool fuse_type_to_constant(std::shared_ptr<Node> & node, element::Type to, const std::vector<Input<Node>> & consumers) {