[LPT] StridedSliceTransformation accuracy degradation fix (#4300)
* [LPT] StridedSliceTransformation fix * added comments
This commit is contained in:
parent
5e17926604
commit
8a4aa1bafa
@ -17,9 +17,10 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
|
|||||||
const std::shared_ptr<ngraph::Node> strSlice,
|
const std::shared_ptr<ngraph::Node> strSlice,
|
||||||
const std::shared_ptr<ngraph::Node> dequantizaitonConstant) {
|
const std::shared_ptr<ngraph::Node> dequantizaitonConstant) {
|
||||||
auto constant = as_type_ptr<ngraph::opset1::Constant>(dequantizaitonConstant);
|
auto constant = as_type_ptr<ngraph::opset1::Constant>(dequantizaitonConstant);
|
||||||
if (NetworkHelper::isScalarLike(constant)) {
|
// issue #48857: constant is mistakenly recognized as a scalar. Uncomment after fix
|
||||||
return NetworkHelper::toScalar(constant);
|
//if (NetworkHelper::isScalarLike(constant)) {
|
||||||
}
|
// return NetworkHelper::toScalar(constant);
|
||||||
|
//}
|
||||||
|
|
||||||
if (strSlice->get_input_shape(0).size() != constant->get_shape().size()) {
|
if (strSlice->get_input_shape(0).size() != constant->get_shape().size()) {
|
||||||
const auto constantShape = constant->get_shape();
|
const auto constantShape = constant->get_shape();
|
||||||
@ -39,7 +40,7 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const auto stridedSlice = as_type_ptr<ngraph::opset1::StridedSlice>(strSlice);
|
const auto stridedSlice = as_type_ptr<ngraph::opset1::StridedSlice>(strSlice);
|
||||||
return fold<ngraph::opset1::StridedSlice>(
|
const auto result = fold<ngraph::opset1::StridedSlice>(
|
||||||
constant,
|
constant,
|
||||||
stridedSlice->get_input_node_shared_ptr(1),
|
stridedSlice->get_input_node_shared_ptr(1),
|
||||||
stridedSlice->get_input_node_shared_ptr(2),
|
stridedSlice->get_input_node_shared_ptr(2),
|
||||||
@ -49,6 +50,8 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
|
|||||||
stridedSlice->get_new_axis_mask(),
|
stridedSlice->get_new_axis_mask(),
|
||||||
stridedSlice->get_shrink_axis_mask(),
|
stridedSlice->get_shrink_axis_mask(),
|
||||||
stridedSlice->get_ellipsis_mask());
|
stridedSlice->get_ellipsis_mask());
|
||||||
|
|
||||||
|
return NetworkHelper::toScalarIfPossible(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
StridedSliceTransformation::StridedSliceTransformation(const Params& params) : LayerTransformation(params) {}
|
StridedSliceTransformation::StridedSliceTransformation(const Params& params) : LayerTransformation(params) {}
|
||||||
|
Loading…
Reference in New Issue
Block a user