From 8a4aa1bafaf7b520409eaccc146b64b1da020621 Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Tue, 16 Feb 2021 16:41:37 +0300 Subject: [PATCH] [LPT] StridedSliceTransformation accuracy degradation fix (#4300) * [LPT] StridedSliceTransformation fix * added comments --- .../src/strided_slice.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/inference-engine/src/low_precision_transformations/src/strided_slice.cpp b/inference-engine/src/low_precision_transformations/src/strided_slice.cpp index 2acc5bb05b2..10ebb60d391 100644 --- a/inference-engine/src/low_precision_transformations/src/strided_slice.cpp +++ b/inference-engine/src/low_precision_transformations/src/strided_slice.cpp @@ -17,9 +17,10 @@ std::shared_ptr stridedSliceDeqConstant( const std::shared_ptr strSlice, const std::shared_ptr dequantizaitonConstant) { auto constant = as_type_ptr(dequantizaitonConstant); - if (NetworkHelper::isScalarLike(constant)) { - return NetworkHelper::toScalar(constant); - } + // issue #48857: constant is mistakenly recognized as a scalar. Uncomment after fix + //if (NetworkHelper::isScalarLike(constant)) { + // return NetworkHelper::toScalar(constant); + //} if (strSlice->get_input_shape(0).size() != constant->get_shape().size()) { const auto constantShape = constant->get_shape(); @@ -39,7 +40,7 @@ std::shared_ptr stridedSliceDeqConstant( } const auto stridedSlice = as_type_ptr(strSlice); - return fold( + const auto result = fold( constant, stridedSlice->get_input_node_shared_ptr(1), stridedSlice->get_input_node_shared_ptr(2), @@ -49,6 +50,8 @@ std::shared_ptr stridedSliceDeqConstant( stridedSlice->get_new_axis_mask(), stridedSlice->get_shrink_axis_mask(), stridedSlice->get_ellipsis_mask()); + + return NetworkHelper::toScalarIfPossible(result); } StridedSliceTransformation::StridedSliceTransformation(const Params& params) : LayerTransformation(params) {}