[LPT] StridedSlice dequantization improvement (#17563)

* [LPT] StridedSlice dequantization improvement

* review comments: refactoring & simplification
This commit is contained in:
Edward Shogulin 2023-06-02 08:47:36 +01:00 committed by GitHub
parent 031f2cc7d1
commit 43d67b0a32
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 194 additions and 70 deletions

View File

@ -18,62 +18,81 @@ namespace low_precision {
namespace {
std::shared_ptr<ov::opset1::Constant> stridedSliceDeqConstant(
const std::shared_ptr<ngraph::Node> strSlice,
const std::shared_ptr<ngraph::Node> dequantizaitonConstant) {
auto constant = ov::as_type_ptr<ov::opset1::Constant>(dequantizaitonConstant);
auto constantShape = constant->get_shape();
if (shape_size(constantShape) == 1ul) {
const std::shared_ptr<Node> node,
const std::shared_ptr<Node> dequantizaiton_constant) {
const auto constant = ov::as_type_ptr<ov::opset1::Constant>(dequantizaiton_constant);
const auto& original_constant_shape = constant->get_shape();
if (shape_size(original_constant_shape) == 1ul) {
return NetworkHelper::toScalar(constant);
}
const auto stridedSlicePShape = strSlice->get_input_partial_shape(0);
const size_t rank = stridedSlicePShape.rank().get_length();
if (rank != constantShape.size()) {
ngraph::Shape newConstantShape;
if (ngraph::shape_size(constantShape) == 1) {
newConstantShape = ngraph::Shape(rank, 1);
} else {
newConstantShape = constantShape;
// step #1: align shapes
std::shared_ptr<ov::opset1::Constant> new_constant = constant;
const size_t rank = node->get_input_partial_shape(0).size();
Shape new_constant_shape = original_constant_shape;
if (rank != new_constant_shape.size()) {
// case when constant shape without batch
if (original_constant_shape.size() < rank) {
new_constant_shape.insert(new_constant_shape.begin(), 1);
}
// case when constShape without batch
if ((constantShape.size() > 1) &&
(constantShape.size() < rank)) {
newConstantShape.insert(newConstantShape.begin(), 1);
if (original_constant_shape != new_constant_shape) {
const auto result = fold<ov::opset1::Broadcast>(
constant,
ov::opset1::Constant::create(ov::element::i32, { new_constant_shape.size() }, new_constant_shape));
new_constant = ov::as_type_ptr<ov::opset1::Constant>(result);
}
}
// step #2: update original begin & end & strides
auto cast_vector = [](const std::shared_ptr<ov::opset1::StridedSlice>& strided_slice, const size_t i) {
OPENVINO_SUPPRESS_DEPRECATED_START
const auto constant = ov::get_constant_from_source(strided_slice->get_input_source_output(i));
OPENVINO_SUPPRESS_DEPRECATED_END
assert(constant != nullptr);
return constant->cast_vector<int64_t>();
};
const auto strided_slice = ov::as_type_ptr<ov::opset1::StridedSlice>(node);
auto begin = cast_vector(strided_slice, 1);
auto end = cast_vector(strided_slice, 2);
auto strides = cast_vector(strided_slice, 3);
auto begin_mask = strided_slice->get_begin_mask();
auto end_mask = strided_slice->get_end_mask();
for (auto i = 0ull; i < new_constant_shape.size(); ++i) {
// don't slice constant if current dimension is 1
if (new_constant_shape[i] == 1ull) {
if (i < begin.size()) {
begin[i] = 0;
}
if (i < end.size()) {
end[i] = 0;
}
if (i < strides.size()) {
strides[i] = 1;
}
if (i < begin_mask.size()) {
begin_mask[i] = 1;
}
if (i < end_mask.size()) {
end_mask[i] = 1;
}
}
constantShape = newConstantShape;
const auto newConstant = fold<ov::opset1::Broadcast>(
constant,
ov::opset1::Constant::create(ngraph::element::i32, { newConstantShape.size() }, newConstantShape));
constant = ov::as_type_ptr<ov::opset1::Constant>(newConstant);
}
const auto stridedSlice = ov::as_type_ptr<ov::opset1::StridedSlice>(strSlice);
auto beginMask = stridedSlice->get_begin_mask();
auto endMask = stridedSlice->get_end_mask();
for (size_t i = 0; i < constantShape.size(); ++i) {
if ((beginMask.size() <= i) && (endMask.size() <= i)) {
break;
}
// don't slice constant if current dimension is 1
if (constantShape[i] == 1ul) {
beginMask[i] = 1ul;
endMask[i] = 1ul;
}
}
// step #3: final step: dequantization constant folding
const auto result = fold<ov::opset1::StridedSlice>(
constant,
stridedSlice->input_value(1),
stridedSlice->input_value(2),
stridedSlice->input_value(3),
beginMask,
endMask,
stridedSlice->get_new_axis_mask(),
stridedSlice->get_shrink_axis_mask(),
stridedSlice->get_ellipsis_mask());
new_constant,
std::make_shared<ov::opset1::Constant>(element::i64, Shape{ begin.size() }, begin),
std::make_shared<ov::opset1::Constant>(element::i64, Shape{ end.size() }, end),
std::make_shared<ov::opset1::Constant>(element::i64, Shape{ strides.size() }, strides),
begin_mask,
end_mask,
strided_slice->get_new_axis_mask(),
strided_slice->get_shrink_axis_mask(),
strided_slice->get_ellipsis_mask());
return ov::as_type_ptr<ov::opset1::Constant>(NetworkHelper::toScalarIfPossible(result));
}
@ -84,7 +103,7 @@ StridedSliceTransformation::StridedSliceTransformation(const Params& params) : L
MATCHER_SCOPE(StridedSliceTransformation);
auto matcher = ngraph::pattern::wrap_type<ov::opset1::StridedSlice>();
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
ov::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto op = m.get_match_root();
if (transformation_callback(op)) {
return false;
@ -92,29 +111,29 @@ StridedSliceTransformation::StridedSliceTransformation(const Params& params) : L
return transform(*context, m);
};
auto m = std::make_shared<ngraph::pattern::Matcher>(matcher, matcher_name);
auto m = std::make_shared<ov::pass::pattern::Matcher>(matcher, matcher_name);
this->register_matcher(m, callback);
}
bool StridedSliceTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
bool StridedSliceTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher& m) {
if (!StridedSliceTransformation::canBeTransformed(context, m.get_match_root())) {
return false;
}
const auto stridedSlice = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions);
auto dequantization = NetworkHelper::getDequantization(stridedSlice, defaultPrecisions);
const auto strided_slice = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions);
auto dequantization = NetworkHelper::getDequantization(strided_slice, defaultPrecisions);
if (dequantization.subtract) {
const auto newSubConst = stridedSliceDeqConstant(stridedSlice, dequantization.subtractConstant);
replace_node(dequantization.subtractConstant, newSubConst);
dequantization.subtractConstant = newSubConst;
const auto new_sub_const = stridedSliceDeqConstant(strided_slice, dequantization.subtractConstant);
replace_node(dequantization.subtractConstant, new_sub_const);
dequantization.subtractConstant = new_sub_const;
}
const auto newMulConst = stridedSliceDeqConstant(stridedSlice, dequantization.multiplyConstant);
replace_node(dequantization.multiplyConstant, newMulConst);
dequantization.multiplyConstant = newMulConst;
const auto new_mul_const = stridedSliceDeqConstant(strided_slice, dequantization.multiplyConstant);
replace_node(dequantization.multiplyConstant, new_mul_const);
dequantization.multiplyConstant = new_mul_const;
moveDequantizationAfter(context, stridedSlice, NetworkHelper::getDequantization(stridedSlice, defaultPrecisions), false);
moveDequantizationAfter(context, strided_slice, NetworkHelper::getDequantization(strided_slice, defaultPrecisions), false);
return true;
}
@ -128,13 +147,21 @@ bool StridedSliceTransformation::canBeTransformed(const TransformationContext& c
return false;
}
if (operation->get_input_partial_shape(0).rank().is_dynamic() &&
((dequantization.subtract && ngraph::shape_size(dequantization.subtractConstant->get_shape()) > 1) ||
(dequantization.multiply && ngraph::shape_size(dequantization.multiplyConstant->get_shape()) > 1))) {
const auto is_dequantization_scalar =
((dequantization.subtract && shape_size(dequantization.subtractConstant->get_shape()) == 1ull) &&
(dequantization.multiply && shape_size(dequantization.multiplyConstant->get_shape()) == 1ull));
if (operation->get_input_partial_shape(0).rank().is_dynamic() && !is_dequantization_scalar) {
return false;
}
return true;
OPENVINO_SUPPRESS_DEPRECATED_START
return
is_dequantization_scalar ||
(ov::get_constant_from_source(operation->get_input_source_output(1)) &&
ov::get_constant_from_source(operation->get_input_source_output(2)) &&
ov::get_constant_from_source(operation->get_input_source_output(3)));
OPENVINO_SUPPRESS_DEPRECATED_END
}
bool StridedSliceTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {

View File

@ -136,6 +136,8 @@ StridedSliceTransformationTestValues::LayerParams channelSlice = {
{} // elipsisMask
};
namespace inputs_4d {
StridedSliceTransformationTestValues::LayerParams channelSlice2D = {
{0, 0}, // begin
{0, 2}, // end
@ -177,7 +179,7 @@ StridedSliceTransformationTestValues::LayerParams sliceWithRemovedAxis = {
{ 1, 0, 1, 1 }, // endMask
{ 0, 0, 0, 0 }, // newAxisMask
{ 0, 1, 0, 0 }, // shrinkAxisMask
{ 0, 0, 0, 0 } // elipsisMask
{ 0, 0, 0, 0 } // elipsisMask
};
StridedSliceTransformationTestValues::LayerParams sliceWithAdditionalAxis = {
@ -191,7 +193,6 @@ StridedSliceTransformationTestValues::LayerParams sliceWithAdditionalAxis = {
{ 0, 0, 0, 0 } // elipsisMask
};
namespace testValues1 {
const std::vector<ngraph::PartialShape> inputShapes = {
{1, 3, 24, 24},
{-1, -1, -1, -1}
@ -549,9 +550,9 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(stridedSliceTransformationTestValues)),
StridedSliceTransformation::getTestCaseName);
} // namespace testValues1
} // namespace inputs_4d
namespace testValues2 {
namespace inputs_4d_spatial {
const std::vector<ngraph::PartialShape> inputShapes = {
{ -1, -1, -1, -1 },
{ 1, 3, 4, 4 }
@ -590,9 +591,9 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(testValuesWithDQBySpatialDimension)),
StridedSliceTransformation::getTestCaseName);
} // namespace testValues2
} // namespace inputs_4d_spatial
namespace testValues3 {
namespace dynamic_inputs {
const std::vector<ngraph::PartialShape> inputShapesWithDynamicChannels = {
PartialShape::dynamic()
};
@ -637,5 +638,92 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(inputShapesWithDynamicChannels),
::testing::ValuesIn(testValues)),
StridedSliceTransformation::getTestCaseName);
} // namespace testValues3
} // namespace dynamic_inputs
// Test cases for StridedSliceTransformation on rank-3 inputs, where the
// StridedSlice begin/end/strides vectors (2 elements) are shorter than the
// input rank (3).
namespace inputs_3d {
// One fully static shape and one with -1 in dimension 1
// (presumably a dynamic dimension, as in the other shape lists of this file).
const std::vector<ngraph::PartialShape> inputShapes = {
{ 1, 3, 4 },
{ 1, -1, 4 }
};
// Layer parameters covering the first two dimensions only:
// dimension 0 has its begin/end mask bits set, dimension 1 uses begin = 1, end = 2
// with the shrink-axis bit set.
StridedSliceTransformationTestValues::LayerParams slice = {
{ 0, 1 }, // begin
{ 0, 2 }, // end
{ 1, 1 }, // strides
{ 1, 0 }, // beginMask
{ 1, 0 }, // endMask
{ 0, 0 }, // newAxisMask
{ 0, 1 }, // shrinkAxisMask
{ 0, 0 } // elipsisMask
};
// NOTE(review): currently identical to `slice` above, value for value;
// kept as a separate object because the second test case refers to it by name.
StridedSliceTransformationTestValues::LayerParams slice2 = {
{ 0, 1 }, // begin
{ 0, 2 }, // end
{ 1, 1 }, // strides
{ 1, 0 }, // beginMask
{ 1, 0 }, // endMask
{ 0, 0 }, // newAxisMask
{ 0, 1 }, // shrinkAxisMask
{ 0, 0 } // elipsisMask
};
// Each entry: transformation params, layer params, then two brace groups that are
// presumably the "actual" (input precision + dequantization before the layer) and
// the "expected" (precisions and dequantization after the transformation) — TODO confirm
// against the StridedSliceTransformationTestValues declaration.
const std::vector<StridedSliceTransformationTestValues> testValuesWithDQBySpatialDimension = {
// U8: channel slice, quantization by spatial dimension
// Constants of shape {1, 1, 4} are expected to come out with shape {1, 4} and unchanged values.
{
LayerTransformation::createParamsU8I8(),
slice,
{
ngraph::element::u8,
{
{ngraph::element::f32},
{{1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {1, 1, 4}},
{{1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {1, 1, 4}}
}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{ngraph::element::f32},
{{1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {1, 4}},
{{1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {1, 4}}
}
}
},
// U8: channel slice, quantization by channel dimension
// Constants of shape {1, 3, 1} are expected to collapse to the single value 2.f
// (the element at index 1 of dimension 1) with an empty (scalar) shape.
{
LayerTransformation::createParamsU8I8(),
slice2,
{
ngraph::element::u8,
{
{ngraph::element::f32},
{{1.f, 2.f, 3.f}, ngraph::element::f32, {1, 3, 1}},
{{1.f, 2.f, 3.f}, ngraph::element::f32, {1, 3, 1}}
}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{ngraph::element::f32},
{{2.f}, ngraph::element::f32, {}},
{{2.f}, ngraph::element::f32, {}}
}
}
}
};
// Instantiates the parameterized suite for every (input shape, test values) combination.
INSTANTIATE_TEST_SUITE_P(
smoke_LPT,
StridedSliceTransformation,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(testValuesWithDQBySpatialDimension)),
StridedSliceTransformation::getTestCaseName);
} // namespace inputs_3d
} // namespace

View File

@ -34,8 +34,11 @@ std::shared_ptr<ngraph::Function> StridedSliceFunction::getOriginal(
const auto deq = makeDequantization(input, dequantization);
const auto beginParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ begin.size() }, begin);
beginParam->set_friendly_name("begin");
const auto endParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ end.size() }, end);
endParam->set_friendly_name("end");
const auto stridesParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ strides.size() }, strides);
stridesParam->set_friendly_name("strides");
const auto stridedSlice = std::make_shared<ngraph::opset1::StridedSlice>(
deq, beginParam, endParam, stridesParam,
@ -69,8 +72,11 @@ std::shared_ptr<ngraph::Function> StridedSliceFunction::getOriginal(
const auto fqOnData = makeFakeQuantize(input, inputPrecision, fakeQuantize);
const auto beginParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ begin.size() }, begin);
beginParam->set_friendly_name("begin");
const auto endParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ end.size() }, end);
endParam->set_friendly_name("end");
const auto stridesParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ strides.size() }, strides);
stridesParam->set_friendly_name("strides");
const auto stridedSlice = std::make_shared<ngraph::opset1::StridedSlice>(
fqOnData, beginParam, endParam, stridesParam,
@ -106,8 +112,11 @@ std::shared_ptr<ngraph::Function> StridedSliceFunction::getReference(
const auto deqBefore = makeDequantization(input, dequantizationBefore);
const auto beginParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ begin.size() }, begin);
beginParam->set_friendly_name("begin");
const auto endParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ end.size() }, end);
endParam->set_friendly_name("end");
const auto stridesParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ strides.size() }, strides);
stridesParam->set_friendly_name("strides");
const auto stridedSlice = std::make_shared<ngraph::opset1::StridedSlice>(
deqBefore, beginParam, endParam, stridesParam,