Updated ConvertPrecision pass not to make extra copy of Constant data (#2108)
This commit is contained in:
parent
6c3f7fd654
commit
18abd6cfd7
@ -300,17 +300,22 @@ std::shared_ptr<Node> change_constant_precision(std::shared_ptr<opset4::Constant
|
||||
using src_type = typename element_type_traits<PREC_FROM>::value_type;
|
||||
using dst_type = typename element_type_traits<PREC_TO>::value_type;
|
||||
|
||||
std::vector<src_type> data(std::move(constant->get_vector<src_type>()));
|
||||
const auto * src_data = constant->get_data_ptr<src_type>();
|
||||
const auto size = shape_size(constant->get_shape());
|
||||
|
||||
auto new_constant = std::make_shared<ngraph::opset4::Constant>(PREC_TO, constant->get_shape());
|
||||
auto * dst_data = const_cast<dst_type *>(reinterpret_cast<const dst_type *>(new_constant->get_data_ptr()));
|
||||
|
||||
std::vector<dst_type> final_data;
|
||||
std::transform(data.begin(), data.end(), std::back_inserter(final_data),
|
||||
[](src_type val) {
|
||||
if (val > std::numeric_limits<dst_type>::max()) {
|
||||
return std::numeric_limits<dst_type>::max();
|
||||
} else {
|
||||
return static_cast<dst_type>(val);
|
||||
}
|
||||
});
|
||||
return std::make_shared<ngraph::opset4::Constant>(PREC_TO, constant->get_shape(), final_data);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
const auto & val = src_data[i];
|
||||
if (val > std::numeric_limits<dst_type>::max()) {
|
||||
dst_data[i] = std::numeric_limits<dst_type>::max();
|
||||
} else {
|
||||
dst_data[i] = static_cast<dst_type>(val);
|
||||
}
|
||||
}
|
||||
return new_constant;
|
||||
}
|
||||
|
||||
bool fuse_type_to_constant(std::shared_ptr<Node> & node, element::Type to, const std::vector<Input<Node>> & consumers) {
|
||||
|
Loading…
Reference in New Issue
Block a user