[LPT] StridedSlice extending (#10148)

* [LPT] StridedSlice extending

* [LPT] tests
This commit is contained in:
Edward Shogulin 2022-02-09 11:23:18 +03:00 committed by GitHub
parent 9d40c5184f
commit c4e54d882b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 68 additions and 4 deletions

View File

@ -55,6 +55,9 @@ std::shared_ptr<opset1::Constant> stridedSliceDeqConstant(
auto beginMask = stridedSlice->get_begin_mask();
auto endMask = stridedSlice->get_end_mask();
for (size_t i = 0; i < constantShape.size(); ++i) {
if ((beginMask.size() <= i) && (endMask.size() <= i)) {
break;
}
// don't slice constant if current dimension is 1
if (constantShape[i] == 1ul) {
beginMask[i] = 1ul;

View File

@ -136,6 +136,17 @@ StridedSliceTransformationTestValues::LayerParams channelSlice = {
{} // elipsisMask
};
StridedSliceTransformationTestValues::LayerParams channelSlice2D = {
{0, 0}, // begin
{0, 2}, // end
{1, 1}, // strided
{1, 0}, // beginMask
{1, 0}, // endMask
{0, 0}, // newAxisMask
{0, 0}, // shrinkAxisMask
{0, 0} // elipsisMask
};
StridedSliceTransformationTestValues::LayerParams spatialDimensionSlice = {
{ 0, 0, 0, 0 },
{ 1, 3, 20, 24 },
@ -232,6 +243,21 @@ const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformati
{{ngraph::element::f32}, {{ 128.f, 64.f }}, {{ 0.1f, 0.01f }}}
}
},
// U8: channel slice, per-channel quantization with different values
{
LayerTransformation::createParamsU8I8(),
channelSlice2D,
{
ngraph::element::u8,
{{ngraph::element::f32}, {{ 128.f, 64.f, 128.f }}, {{ 0.1f, 0.01f, 1.f }}}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{{ngraph::element::f32}, {{ 128.f, 64.f }}, {{ 0.1f, 0.01f }}}
}
},
// U8: without subtract
{
LayerTransformation::createParamsU8I8(),

View File

@ -18,10 +18,7 @@ const std::vector<ngraph::element::Type> netPrecisions = {
};
const std::vector<ngraph::pass::low_precision::LayerTransformation::Params> trasformationParamValues = {
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams(),
// LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams().setUpdatePrecisions(false),
// LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsI8I8(),
// LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8()
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams()
};
const std::vector<LayerTestsDefinitions::StridedSliceTransformationParam> params = {
@ -68,6 +65,25 @@ const std::vector<LayerTestsDefinitions::StridedSliceTransformationParam> params
{},
{}
},
// channel slice, per-channel quantization
{
{
256ul,
ngraph::Shape{ 1, 3, 1, 1 },
{ 0.f, 0.f, 0.f },
{ 255.f, 25.5f, 2.55f },
{ 0.f, 0.f, 0.f },
{ 255.f, 25.5f, 2.55f },
},
{ 0, 0 },
{ 1, 2 },
{ 1, 1 },
{ 1, 0 },
{ 1, 0 },
{},
{},
{}
},
// special dimension slice, per-channel quantization
{
{

View File

@ -65,6 +65,25 @@ const std::vector<LayerTestsDefinitions::StridedSliceTransformationParam> params
{},
{}
},
// channel slice, per-channel quantization
{
{
256ul,
ngraph::Shape{ 1, 3, 1, 1 },
{ 0.f, 0.f, 0.f },
{ 255.f, 25.5f, 2.55f },
{ 0.f, 0.f, 0.f },
{ 255.f, 25.5f, 2.55f },
},
{ 0, 0 },
{ 1, 2 },
{ 1, 1 },
{ 1, 0 },
{ 1, 0 },
{},
{},
{}
},
// special dimension slice, per-channel quantization
{
{