[LPT] StridedSliceTransformation accuracy degradation fix (#4300)
* [LPT] StridedSliceTransformation fix * added comments
This commit is contained in:
parent
5e17926604
commit
8a4aa1bafa
@ -17,9 +17,10 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
|
||||
const std::shared_ptr<ngraph::Node> strSlice,
|
||||
const std::shared_ptr<ngraph::Node> dequantizaitonConstant) {
|
||||
auto constant = as_type_ptr<ngraph::opset1::Constant>(dequantizaitonConstant);
|
||||
if (NetworkHelper::isScalarLike(constant)) {
|
||||
return NetworkHelper::toScalar(constant);
|
||||
}
|
||||
// issue #48857: constant is mistakenly recognized as a scalar. Uncomment after fix
|
||||
//if (NetworkHelper::isScalarLike(constant)) {
|
||||
// return NetworkHelper::toScalar(constant);
|
||||
//}
|
||||
|
||||
if (strSlice->get_input_shape(0).size() != constant->get_shape().size()) {
|
||||
const auto constantShape = constant->get_shape();
|
||||
@ -39,7 +40,7 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
|
||||
}
|
||||
|
||||
const auto stridedSlice = as_type_ptr<ngraph::opset1::StridedSlice>(strSlice);
|
||||
return fold<ngraph::opset1::StridedSlice>(
|
||||
const auto result = fold<ngraph::opset1::StridedSlice>(
|
||||
constant,
|
||||
stridedSlice->get_input_node_shared_ptr(1),
|
||||
stridedSlice->get_input_node_shared_ptr(2),
|
||||
@ -49,6 +50,8 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
|
||||
stridedSlice->get_new_axis_mask(),
|
||||
stridedSlice->get_shrink_axis_mask(),
|
||||
stridedSlice->get_ellipsis_mask());
|
||||
|
||||
return NetworkHelper::toScalarIfPossible(result);
|
||||
}
|
||||
|
||||
StridedSliceTransformation::StridedSliceTransformation(const Params& params) : LayerTransformation(params) {}
|
||||
|
Loading…
Reference in New Issue
Block a user