[LPT] StridedSlice dequantization improvement (#17563)
* [LPT] StridedSlice dequantization improvement * review comments: refactoring & simplification
This commit is contained in:
parent
031f2cc7d1
commit
43d67b0a32
@ -18,62 +18,81 @@ namespace low_precision {
|
||||
namespace {
|
||||
|
||||
std::shared_ptr<ov::opset1::Constant> stridedSliceDeqConstant(
|
||||
const std::shared_ptr<ngraph::Node> strSlice,
|
||||
const std::shared_ptr<ngraph::Node> dequantizaitonConstant) {
|
||||
auto constant = ov::as_type_ptr<ov::opset1::Constant>(dequantizaitonConstant);
|
||||
auto constantShape = constant->get_shape();
|
||||
if (shape_size(constantShape) == 1ul) {
|
||||
const std::shared_ptr<Node> node,
|
||||
const std::shared_ptr<Node> dequantizaiton_constant) {
|
||||
const auto constant = ov::as_type_ptr<ov::opset1::Constant>(dequantizaiton_constant);
|
||||
const auto& original_constant_shape = constant->get_shape();
|
||||
if (shape_size(original_constant_shape) == 1ul) {
|
||||
return NetworkHelper::toScalar(constant);
|
||||
}
|
||||
|
||||
const auto stridedSlicePShape = strSlice->get_input_partial_shape(0);
|
||||
const size_t rank = stridedSlicePShape.rank().get_length();
|
||||
if (rank != constantShape.size()) {
|
||||
ngraph::Shape newConstantShape;
|
||||
if (ngraph::shape_size(constantShape) == 1) {
|
||||
newConstantShape = ngraph::Shape(rank, 1);
|
||||
} else {
|
||||
newConstantShape = constantShape;
|
||||
// step #1: align shapes
|
||||
std::shared_ptr<ov::opset1::Constant> new_constant = constant;
|
||||
const size_t rank = node->get_input_partial_shape(0).size();
|
||||
Shape new_constant_shape = original_constant_shape;
|
||||
if (rank != new_constant_shape.size()) {
|
||||
// case when the constant shape has no batch dimension
|
||||
if (original_constant_shape.size() < rank) {
|
||||
new_constant_shape.insert(new_constant_shape.begin(), 1);
|
||||
}
|
||||
|
||||
// case when constShape without batch
|
||||
if ((constantShape.size() > 1) &&
|
||||
(constantShape.size() < rank)) {
|
||||
newConstantShape.insert(newConstantShape.begin(), 1);
|
||||
if (original_constant_shape != new_constant_shape) {
|
||||
const auto result = fold<ov::opset1::Broadcast>(
|
||||
constant,
|
||||
ov::opset1::Constant::create(ov::element::i32, { new_constant_shape.size() }, new_constant_shape));
|
||||
new_constant = ov::as_type_ptr<ov::opset1::Constant>(result);
|
||||
}
|
||||
}
|
||||
|
||||
// step #2: update original begin & end & strides
|
||||
auto cast_vector = [](const std::shared_ptr<ov::opset1::StridedSlice>& strided_slice, const size_t i) {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
const auto constant = ov::get_constant_from_source(strided_slice->get_input_source_output(i));
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
assert(constant != nullptr);
|
||||
return constant->cast_vector<int64_t>();
|
||||
};
|
||||
|
||||
const auto strided_slice = ov::as_type_ptr<ov::opset1::StridedSlice>(node);
|
||||
auto begin = cast_vector(strided_slice, 1);
|
||||
auto end = cast_vector(strided_slice, 2);
|
||||
auto strides = cast_vector(strided_slice, 3);
|
||||
auto begin_mask = strided_slice->get_begin_mask();
|
||||
auto end_mask = strided_slice->get_end_mask();
|
||||
for (auto i = 0ull; i < new_constant_shape.size(); ++i) {
|
||||
// don't slice constant if current dimension is 1
|
||||
if (new_constant_shape[i] == 1ull) {
|
||||
if (i < begin.size()) {
|
||||
begin[i] = 0;
|
||||
}
|
||||
if (i < end.size()) {
|
||||
end[i] = 0;
|
||||
}
|
||||
if (i < strides.size()) {
|
||||
strides[i] = 1;
|
||||
}
|
||||
|
||||
if (i < begin_mask.size()) {
|
||||
begin_mask[i] = 1;
|
||||
}
|
||||
|
||||
if (i < end_mask.size()) {
|
||||
end_mask[i] = 1;
|
||||
}
|
||||
}
|
||||
constantShape = newConstantShape;
|
||||
|
||||
const auto newConstant = fold<ov::opset1::Broadcast>(
|
||||
constant,
|
||||
ov::opset1::Constant::create(ngraph::element::i32, { newConstantShape.size() }, newConstantShape));
|
||||
constant = ov::as_type_ptr<ov::opset1::Constant>(newConstant);
|
||||
}
|
||||
|
||||
const auto stridedSlice = ov::as_type_ptr<ov::opset1::StridedSlice>(strSlice);
|
||||
|
||||
auto beginMask = stridedSlice->get_begin_mask();
|
||||
auto endMask = stridedSlice->get_end_mask();
|
||||
for (size_t i = 0; i < constantShape.size(); ++i) {
|
||||
if ((beginMask.size() <= i) && (endMask.size() <= i)) {
|
||||
break;
|
||||
}
|
||||
// don't slice constant if current dimension is 1
|
||||
if (constantShape[i] == 1ul) {
|
||||
beginMask[i] = 1ul;
|
||||
endMask[i] = 1ul;
|
||||
}
|
||||
}
|
||||
|
||||
// step #3: final step: dequantization constant folding
|
||||
const auto result = fold<ov::opset1::StridedSlice>(
|
||||
constant,
|
||||
stridedSlice->input_value(1),
|
||||
stridedSlice->input_value(2),
|
||||
stridedSlice->input_value(3),
|
||||
beginMask,
|
||||
endMask,
|
||||
stridedSlice->get_new_axis_mask(),
|
||||
stridedSlice->get_shrink_axis_mask(),
|
||||
stridedSlice->get_ellipsis_mask());
|
||||
new_constant,
|
||||
std::make_shared<ov::opset1::Constant>(element::i64, Shape{ begin.size() }, begin),
|
||||
std::make_shared<ov::opset1::Constant>(element::i64, Shape{ end.size() }, end),
|
||||
std::make_shared<ov::opset1::Constant>(element::i64, Shape{ strides.size() }, strides),
|
||||
begin_mask,
|
||||
end_mask,
|
||||
strided_slice->get_new_axis_mask(),
|
||||
strided_slice->get_shrink_axis_mask(),
|
||||
strided_slice->get_ellipsis_mask());
|
||||
|
||||
return ov::as_type_ptr<ov::opset1::Constant>(NetworkHelper::toScalarIfPossible(result));
|
||||
}
|
||||
@ -84,7 +103,7 @@ StridedSliceTransformation::StridedSliceTransformation(const Params& params) : L
|
||||
MATCHER_SCOPE(StridedSliceTransformation);
|
||||
auto matcher = ngraph::pattern::wrap_type<ov::opset1::StridedSlice>();
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
ov::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
auto op = m.get_match_root();
|
||||
if (transformation_callback(op)) {
|
||||
return false;
|
||||
@ -92,29 +111,29 @@ StridedSliceTransformation::StridedSliceTransformation(const Params& params) : L
|
||||
return transform(*context, m);
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(matcher, matcher_name);
|
||||
auto m = std::make_shared<ov::pass::pattern::Matcher>(matcher, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
bool StridedSliceTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
|
||||
bool StridedSliceTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher& m) {
|
||||
if (!StridedSliceTransformation::canBeTransformed(context, m.get_match_root())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto stridedSlice = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions);
|
||||
auto dequantization = NetworkHelper::getDequantization(stridedSlice, defaultPrecisions);
|
||||
const auto strided_slice = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions);
|
||||
auto dequantization = NetworkHelper::getDequantization(strided_slice, defaultPrecisions);
|
||||
|
||||
if (dequantization.subtract) {
|
||||
const auto newSubConst = stridedSliceDeqConstant(stridedSlice, dequantization.subtractConstant);
|
||||
replace_node(dequantization.subtractConstant, newSubConst);
|
||||
dequantization.subtractConstant = newSubConst;
|
||||
const auto new_sub_const = stridedSliceDeqConstant(strided_slice, dequantization.subtractConstant);
|
||||
replace_node(dequantization.subtractConstant, new_sub_const);
|
||||
dequantization.subtractConstant = new_sub_const;
|
||||
}
|
||||
|
||||
const auto newMulConst = stridedSliceDeqConstant(stridedSlice, dequantization.multiplyConstant);
|
||||
replace_node(dequantization.multiplyConstant, newMulConst);
|
||||
dequantization.multiplyConstant = newMulConst;
|
||||
const auto new_mul_const = stridedSliceDeqConstant(strided_slice, dequantization.multiplyConstant);
|
||||
replace_node(dequantization.multiplyConstant, new_mul_const);
|
||||
dequantization.multiplyConstant = new_mul_const;
|
||||
|
||||
moveDequantizationAfter(context, stridedSlice, NetworkHelper::getDequantization(stridedSlice, defaultPrecisions), false);
|
||||
moveDequantizationAfter(context, strided_slice, NetworkHelper::getDequantization(strided_slice, defaultPrecisions), false);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -128,13 +147,21 @@ bool StridedSliceTransformation::canBeTransformed(const TransformationContext& c
|
||||
return false;
|
||||
}
|
||||
|
||||
if (operation->get_input_partial_shape(0).rank().is_dynamic() &&
|
||||
((dequantization.subtract && ngraph::shape_size(dequantization.subtractConstant->get_shape()) > 1) ||
|
||||
(dequantization.multiply && ngraph::shape_size(dequantization.multiplyConstant->get_shape()) > 1))) {
|
||||
const auto is_dequantization_scalar =
|
||||
((dequantization.subtract && shape_size(dequantization.subtractConstant->get_shape()) == 1ull) &&
|
||||
(dequantization.multiply && shape_size(dequantization.multiplyConstant->get_shape()) == 1ull));
|
||||
|
||||
if (operation->get_input_partial_shape(0).rank().is_dynamic() && !is_dequantization_scalar) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
return
|
||||
is_dequantization_scalar ||
|
||||
(ov::get_constant_from_source(operation->get_input_source_output(1)) &&
|
||||
ov::get_constant_from_source(operation->get_input_source_output(2)) &&
|
||||
ov::get_constant_from_source(operation->get_input_source_output(3)));
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
bool StridedSliceTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
|
||||
|
@ -136,6 +136,8 @@ StridedSliceTransformationTestValues::LayerParams channelSlice = {
|
||||
{} // elipsisMask
|
||||
};
|
||||
|
||||
namespace inputs_4d {
|
||||
|
||||
StridedSliceTransformationTestValues::LayerParams channelSlice2D = {
|
||||
{0, 0}, // begin
|
||||
{0, 2}, // end
|
||||
@ -177,7 +179,7 @@ StridedSliceTransformationTestValues::LayerParams sliceWithRemovedAxis = {
|
||||
{ 1, 0, 1, 1 }, // endMask
|
||||
{ 0, 0, 0, 0 }, // newAxisMask
|
||||
{ 0, 1, 0, 0 }, // shrinkAxisMask
|
||||
{ 0, 0, 0, 0 } // elipsisMask
|
||||
{ 0, 0, 0, 0 } // elipsisMask
|
||||
};
|
||||
|
||||
StridedSliceTransformationTestValues::LayerParams sliceWithAdditionalAxis = {
|
||||
@ -191,7 +193,6 @@ StridedSliceTransformationTestValues::LayerParams sliceWithAdditionalAxis = {
|
||||
{ 0, 0, 0, 0 } // elipsisMask
|
||||
};
|
||||
|
||||
namespace testValues1 {
|
||||
const std::vector<ngraph::PartialShape> inputShapes = {
|
||||
{1, 3, 24, 24},
|
||||
{-1, -1, -1, -1}
|
||||
@ -549,9 +550,9 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
::testing::ValuesIn(inputShapes),
|
||||
::testing::ValuesIn(stridedSliceTransformationTestValues)),
|
||||
StridedSliceTransformation::getTestCaseName);
|
||||
} // namespace testValues1
|
||||
} // namespace inputs_4d
|
||||
|
||||
namespace testValues2 {
|
||||
namespace inputs_4d_spatial {
|
||||
const std::vector<ngraph::PartialShape> inputShapes = {
|
||||
{ -1, -1, -1, -1 },
|
||||
{ 1, 3, 4, 4 }
|
||||
@ -590,9 +591,9 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
::testing::ValuesIn(inputShapes),
|
||||
::testing::ValuesIn(testValuesWithDQBySpatialDimension)),
|
||||
StridedSliceTransformation::getTestCaseName);
|
||||
} // namespace testValues2
|
||||
} // namespace inputs_4d_spatial
|
||||
|
||||
namespace testValues3 {
|
||||
namespace dynamic_inputs {
|
||||
const std::vector<ngraph::PartialShape> inputShapesWithDynamicChannels = {
|
||||
PartialShape::dynamic()
|
||||
};
|
||||
@ -637,5 +638,92 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
::testing::ValuesIn(inputShapesWithDynamicChannels),
|
||||
::testing::ValuesIn(testValues)),
|
||||
StridedSliceTransformation::getTestCaseName);
|
||||
} // namespace testValues3
|
||||
} // namespace dynamic_inputs
|
||||
|
||||
namespace inputs_3d {
|
||||
const std::vector<ngraph::PartialShape> inputShapes = {
|
||||
{ 1, 3, 4 },
|
||||
{ 1, -1, 4 }
|
||||
};
|
||||
|
||||
StridedSliceTransformationTestValues::LayerParams slice = {
|
||||
{ 0, 1 }, // begin
|
||||
{ 0, 2 }, // end
|
||||
{ 1, 1 }, // strided
|
||||
{ 1, 0 }, // beginMask
|
||||
{ 1, 0 }, // endMask
|
||||
{ 0, 0 }, // newAxisMask
|
||||
{ 0, 1 }, // shrinkAxisMask
|
||||
{ 0, 0 } // elipsisMask
|
||||
};
|
||||
|
||||
StridedSliceTransformationTestValues::LayerParams slice2 = {
|
||||
{ 0, 1 }, // begin
|
||||
{ 0, 2 }, // end
|
||||
{ 1, 1 }, // strided
|
||||
{ 1, 0 }, // beginMask
|
||||
{ 1, 0 }, // endMask
|
||||
{ 0, 0 }, // newAxisMask
|
||||
{ 0, 1 }, // shrinkAxisMask
|
||||
{ 0, 0 } // elipsisMask
|
||||
};
|
||||
|
||||
const std::vector<StridedSliceTransformationTestValues> testValuesWithDQBySpatialDimension = {
|
||||
// U8: channel slice, quantization by spatial dimension
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
slice,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {1, 1, 4}},
|
||||
{{1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {1, 1, 4}}
|
||||
}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {1, 4}},
|
||||
{{1.f, 2.f, 3.f, 4.f}, ngraph::element::f32, {1, 4}}
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8: channel slice, quantization by channel dimension
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
slice2,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{1.f, 2.f, 3.f}, ngraph::element::f32, {1, 3, 1}},
|
||||
{{1.f, 2.f, 3.f}, ngraph::element::f32, {1, 3, 1}}
|
||||
}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{2.f}, ngraph::element::f32, {}},
|
||||
{{2.f}, ngraph::element::f32, {}}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
smoke_LPT,
|
||||
StridedSliceTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapes),
|
||||
::testing::ValuesIn(testValuesWithDQBySpatialDimension)),
|
||||
StridedSliceTransformation::getTestCaseName);
|
||||
} // namespace inputs_3d
|
||||
|
||||
} // namespace
|
||||
|
@ -34,8 +34,11 @@ std::shared_ptr<ngraph::Function> StridedSliceFunction::getOriginal(
|
||||
const auto deq = makeDequantization(input, dequantization);
|
||||
|
||||
const auto beginParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ begin.size() }, begin);
|
||||
beginParam->set_friendly_name("begin");
|
||||
const auto endParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ end.size() }, end);
|
||||
endParam->set_friendly_name("end");
|
||||
const auto stridesParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ strides.size() }, strides);
|
||||
stridesParam->set_friendly_name("strides");
|
||||
|
||||
const auto stridedSlice = std::make_shared<ngraph::opset1::StridedSlice>(
|
||||
deq, beginParam, endParam, stridesParam,
|
||||
@ -69,8 +72,11 @@ std::shared_ptr<ngraph::Function> StridedSliceFunction::getOriginal(
|
||||
const auto fqOnData = makeFakeQuantize(input, inputPrecision, fakeQuantize);
|
||||
|
||||
const auto beginParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ begin.size() }, begin);
|
||||
beginParam->set_friendly_name("begin");
|
||||
const auto endParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ end.size() }, end);
|
||||
endParam->set_friendly_name("end");
|
||||
const auto stridesParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ strides.size() }, strides);
|
||||
stridesParam->set_friendly_name("strides");
|
||||
|
||||
const auto stridedSlice = std::make_shared<ngraph::opset1::StridedSlice>(
|
||||
fqOnData, beginParam, endParam, stridesParam,
|
||||
@ -106,8 +112,11 @@ std::shared_ptr<ngraph::Function> StridedSliceFunction::getReference(
|
||||
const auto deqBefore = makeDequantization(input, dequantizationBefore);
|
||||
|
||||
const auto beginParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ begin.size() }, begin);
|
||||
beginParam->set_friendly_name("begin");
|
||||
const auto endParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ end.size() }, end);
|
||||
endParam->set_friendly_name("end");
|
||||
const auto stridesParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ strides.size() }, strides);
|
||||
stridesParam->set_friendly_name("strides");
|
||||
|
||||
const auto stridedSlice = std::make_shared<ngraph::opset1::StridedSlice>(
|
||||
deqBefore, beginParam, endParam, stridesParam,
|
||||
|
Loading…
Reference in New Issue
Block a user