[NGRAPH] Fix ReduceSum decompose pass

Signed-off-by: Alexander Peskov <alexander.peskov@intel.com>
This commit is contained in:
Alexander Peskov 2020-08-31 15:50:20 +03:00
parent 030e0f46fe
commit 40923893b6

View File

@ -231,21 +231,34 @@ ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
new_ops.push_back(input.get_node_shared_ptr());
} else if (std::is_same<T, ngraph::opset1::ReduceSum>()) {
// Fallback to real type because of potential data loss in case of integer AVG Pool
bool fallback_to_real = input.get_element_type().is_integral();
if (fallback_to_real) {
input = std::make_shared<ngraph::opset1::Convert>(input, ngraph::element::f32);
new_ops.push_back(input.get_node_shared_ptr());
}
input = std::make_shared<ngraph::opset1::AvgPool>(input,
strides,
pads_begin,
pads_end,
kernel,
true,
ngraph::op::RoundingType::FLOOR);
strides,
pads_begin,
pads_end,
kernel,
true,
ngraph::op::RoundingType::FLOOR);
input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool");
new_ops.push_back(input.get_node_shared_ptr());
input = std::make_shared<ngraph::opset1::Multiply>(input,
ngraph::opset1::Constant::create(reduce->input(0).get_element_type(), ngraph::Shape{1}, {reduction_dims_count}));
ngraph::opset1::Constant::create(input.get_element_type(), ngraph::Shape{1}, {reduction_dims_count}));
input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/mul");
new_ops.push_back(input.get_node_shared_ptr());
if (fallback_to_real) {
input = std::make_shared<ngraph::opset1::Convert>(input, reduce->output(0).get_element_type());
new_ops.push_back(input.get_node_shared_ptr());
}
} else {
return false;
}