From 40923893b607d1cba59d8e05bd856817c49eb0ea Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Mon, 31 Aug 2020 15:50:20 +0300 Subject: [PATCH] [NGRAPH] Fix ReduceSum decompose pass Signed-off-by: Alexander Peskov --- .../convert_reduce_to_pooling.hpp | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/inference-engine/src/transformations/include/transformations/convert_reduce_to_pooling.hpp b/inference-engine/src/transformations/include/transformations/convert_reduce_to_pooling.hpp index 94386d03a36..23dffdbdf5a 100644 --- a/inference-engine/src/transformations/include/transformations/convert_reduce_to_pooling.hpp +++ b/inference-engine/src/transformations/include/transformations/convert_reduce_to_pooling.hpp @@ -231,21 +231,34 @@ ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() { input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool"); new_ops.push_back(input.get_node_shared_ptr()); } else if (std::is_same()) { + // Fallback to real type because of potential data loss in case of integer AVG Pool + bool fallback_to_real = input.get_element_type().is_integral(); + + if (fallback_to_real) { + input = std::make_shared(input, ngraph::element::f32); + new_ops.push_back(input.get_node_shared_ptr()); + } + input = std::make_shared(input, - strides, - pads_begin, - pads_end, - kernel, - true, - ngraph::op::RoundingType::FLOOR); + strides, + pads_begin, + pads_end, + kernel, + true, + ngraph::op::RoundingType::FLOOR); input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/pool"); new_ops.push_back(input.get_node_shared_ptr()); input = std::make_shared(input, - ngraph::opset1::Constant::create(reduce->input(0).get_element_type(), ngraph::Shape{1}, {reduction_dims_count})); + ngraph::opset1::Constant::create(input.get_element_type(), ngraph::Shape{1}, {reduction_dims_count})); input.get_node_shared_ptr()->set_friendly_name(reduce->get_friendly_name() + "/mul"); new_ops.push_back(input.get_node_shared_ptr()); + + if (fallback_to_real) { + input = std::make_shared(input, reduce->output(0).get_element_type()); + new_ops.push_back(input.get_node_shared_ptr()); + } } else { return false; }