[LPT] StridedSliceTransformation accuracy degradation fix (#4300)

* [LPT] StridedSliceTransformation fix

* added comments
This commit is contained in:
Vladislav Golubev 2021-02-16 16:41:37 +03:00 committed by GitHub
parent 5e17926604
commit 8a4aa1bafa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,9 +17,10 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
const std::shared_ptr<ngraph::Node> strSlice,
const std::shared_ptr<ngraph::Node> dequantizaitonConstant) {
auto constant = as_type_ptr<ngraph::opset1::Constant>(dequantizaitonConstant);
if (NetworkHelper::isScalarLike(constant)) {
return NetworkHelper::toScalar(constant);
}
// issue #48857: constant is mistakenly recognized as a scalar. Uncomment after fix
//if (NetworkHelper::isScalarLike(constant)) {
// return NetworkHelper::toScalar(constant);
//}
if (strSlice->get_input_shape(0).size() != constant->get_shape().size()) {
const auto constantShape = constant->get_shape();
@ -39,7 +40,7 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
}
const auto stridedSlice = as_type_ptr<ngraph::opset1::StridedSlice>(strSlice);
return fold<ngraph::opset1::StridedSlice>(
const auto result = fold<ngraph::opset1::StridedSlice>(
constant,
stridedSlice->get_input_node_shared_ptr(1),
stridedSlice->get_input_node_shared_ptr(2),
@ -49,6 +50,8 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
stridedSlice->get_new_axis_mask(),
stridedSlice->get_shrink_axis_mask(),
stridedSlice->get_ellipsis_mask());
return NetworkHelper::toScalarIfPossible(result);
}
StridedSliceTransformation::StridedSliceTransformation(const Params& params) : LayerTransformation(params) {}