[LPT] StridedSliceTransformation accuracy degradation fix (#4300)

* [LPT] StridedSliceTransformation fix

* added comments
This commit is contained in:
Vladislav Golubev 2021-02-16 16:41:37 +03:00 committed by GitHub
parent 5e17926604
commit 8a4aa1bafa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,9 +17,10 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
const std::shared_ptr<ngraph::Node> strSlice, const std::shared_ptr<ngraph::Node> strSlice,
const std::shared_ptr<ngraph::Node> dequantizaitonConstant) { const std::shared_ptr<ngraph::Node> dequantizaitonConstant) {
auto constant = as_type_ptr<ngraph::opset1::Constant>(dequantizaitonConstant); auto constant = as_type_ptr<ngraph::opset1::Constant>(dequantizaitonConstant);
if (NetworkHelper::isScalarLike(constant)) { // issue #48857: constant is mistakenly recognized as a scalar. Uncomment after fix
return NetworkHelper::toScalar(constant); //if (NetworkHelper::isScalarLike(constant)) {
} // return NetworkHelper::toScalar(constant);
//}
if (strSlice->get_input_shape(0).size() != constant->get_shape().size()) { if (strSlice->get_input_shape(0).size() != constant->get_shape().size()) {
const auto constantShape = constant->get_shape(); const auto constantShape = constant->get_shape();
@ -39,7 +40,7 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
} }
const auto stridedSlice = as_type_ptr<ngraph::opset1::StridedSlice>(strSlice); const auto stridedSlice = as_type_ptr<ngraph::opset1::StridedSlice>(strSlice);
return fold<ngraph::opset1::StridedSlice>( const auto result = fold<ngraph::opset1::StridedSlice>(
constant, constant,
stridedSlice->get_input_node_shared_ptr(1), stridedSlice->get_input_node_shared_ptr(1),
stridedSlice->get_input_node_shared_ptr(2), stridedSlice->get_input_node_shared_ptr(2),
@ -49,6 +50,8 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
stridedSlice->get_new_axis_mask(), stridedSlice->get_new_axis_mask(),
stridedSlice->get_shrink_axis_mask(), stridedSlice->get_shrink_axis_mask(),
stridedSlice->get_ellipsis_mask()); stridedSlice->get_ellipsis_mask());
return NetworkHelper::toScalarIfPossible(result);
} }
StridedSliceTransformation::StridedSliceTransformation(const Params& params) : LayerTransformation(params) {} StridedSliceTransformation::StridedSliceTransformation(const Params& params) : LayerTransformation(params) {}