* [LPT] StridedSlice fix

* [LPT] separateInStandaloneBranch fix
This commit is contained in:
Vladislav Golubev
2021-06-21 18:02:23 +03:00
committed by GitHub
parent 9e7d98fca9
commit 6799a31911
4 changed files with 106 additions and 8 deletions

View File

@@ -549,7 +549,11 @@ std::shared_ptr<ngraph::Node> NetworkHelper::separateInStandaloneBranch(std::sha
}
std::vector<Output<Node>> inputs = node->input_values();
const size_t inputIndex = NetworkHelper::getChildInputIndex(dequantization.multiply, node);
const auto originalParent = dequantization.multiply ?
dequantization.multiply->shared_from_this() :
dequantization.subtract->shared_from_this();
const size_t inputIndex = NetworkHelper::getChildInputIndex(originalParent, node);
inputs[inputIndex] = parent;
const std::shared_ptr<Node> newNode = node->clone_with_new_inputs(inputs);

View File

@@ -17,13 +17,12 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
const std::shared_ptr<ngraph::Node> strSlice,
const std::shared_ptr<ngraph::Node> dequantizaitonConstant) {
auto constant = as_type_ptr<ngraph::opset1::Constant>(dequantizaitonConstant);
// issue #48857: constant is mistakenly recognized as a scalar. Uncomment after fix
//if (NetworkHelper::isScalarLike(constant)) {
// return NetworkHelper::toScalar(constant);
//}
auto constantShape = constant->get_shape();
if (ngraph::shape_size(constantShape) == 1ul) {
return NetworkHelper::toScalar(constant);
}
const auto stridedSliceShape = strSlice->get_input_shape(0);
auto constantShape = constant->get_shape();
if (stridedSliceShape.size() != constantShape.size()) {
ngraph::Shape newConstantShape;
if (ngraph::shape_size(constantShape) == 1) {

View File

@@ -81,7 +81,6 @@ public:
"SeparateInStandaloneBranchTransformation");
};
actualFunction = createActualFunction(testValues.precisionBefore, shape, testValues.dequantization);
const auto result = actualFunction->get_results()[0];
ngraph::pass::low_precision::NetworkHelper::separateInStandaloneBranch(result->get_input_node_shared_ptr(0));
@@ -143,6 +142,11 @@ std::vector<SeparateInStandaloneBranchTransformationTestValues> testValues = {
ngraph::element::u8,
{ ngraph::element::f32, { 127.f }, { 0.02f } }
},
{
LayerTransformation::createParamsU8U8(),
ngraph::element::u8,
{ ngraph::element::f32, { 127.f }, {} }
},
{
LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
ngraph::element::u8,

View File

@@ -117,7 +117,8 @@ public:
testValues.inputShape << testValues.actual.inputPrecision << "_" << toString(testValues.params) <<
testValues.actual.dequantization << "_strided_slice_params_" << testValues.layerParams.begin <<
testValues.layerParams.end << testValues.layerParams.beginMask <<
testValues.layerParams.endMask << testValues.layerParams.strides;
testValues.layerParams.endMask << testValues.layerParams.strides <<
testValues.layerParams.shrinkAxisMask << testValues.layerParams.newAxisMask;
return result.str();
}
};
@@ -161,6 +162,28 @@ StridedSliceTransformationTestValues::LayerParams specialDimensionEndSlice = {
{}
};
StridedSliceTransformationTestValues::LayerParams sliceWithRemovedAxis = {
{ 0, 1, 0, 0 }, // begin
{ 1, 2, 1, 1 }, // end
{ 1, 1, 1, 1 }, // strided
{ 1, 0, 1, 1 }, // beginMask
{ 1, 0, 1, 1 }, // endMask
{ 0, 0, 0, 0 }, // newAxisMask
{ 0, 1, 0, 0 }, // shrinkAxisMask
{ 0, 0, 0, 0 } // elipsisMask
};
StridedSliceTransformationTestValues::LayerParams sliceWithAdditionalAxis = {
{ 0, 1, 0, 0 }, // begin
{ 1, 2, 1, 1 }, // end
{ 1, 1, 1, 1 }, // strided
{ 1, 0, 1, 1 }, // beginMask
{ 1, 0, 1, 1 }, // endMask
{ 0, 1, 0, 0 }, // newAxisMask
{ 0, 0, 0, 0 }, // shrinkAxisMask
{ 0, 0, 0, 0 } // elipsisMask
};
const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformationTestValues = {
// U8: channel slice, per-tensor quantization
{
@@ -442,6 +465,74 @@ const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformati
{{ngraph::element::f32}, {}, { {0.1f, 0.01f}, ngraph::element::f32, {1, 2, 1, 1} }}
}
},
// U8: channel slice, per-tensor quantization
{
ngraph::Shape{1, 3, 16, 1200},
LayerTransformation::createParamsU8I8(),
sliceWithRemovedAxis,
{
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {0.1f}}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {0.1f}}
}
},
// U8: channel slice, per-channel quantization
{
ngraph::Shape{1, 3, 16, 1200},
LayerTransformation::createParamsU8I8(),
sliceWithRemovedAxis,
{
ngraph::element::u8,
{{ngraph::element::f32}, { {128.f, 64.f, 32.f} }, { {0.1f, 0.2f, 0.3f} }}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{{ngraph::element::f32}, {64.f}, {0.2f}},
}
},
// U8: channel slice, per-tensor quantization
{
ngraph::Shape{1, 3, 16, 1200},
LayerTransformation::createParamsU8I8(),
sliceWithAdditionalAxis,
{
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {0.1f}}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{{ngraph::element::f32}, {128.f}, {0.1f}}
}
},
// U8: channel slice, per-channel quantization
{
ngraph::Shape{1, 3, 16, 1200},
LayerTransformation::createParamsU8I8(),
sliceWithAdditionalAxis,
{
ngraph::element::u8,
{{ngraph::element::f32}, { {128.f, 64.f, 32.f} }, { {0.1f, 0.2f, 0.3f} }}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{ngraph::element::f32},
{ {128.f, 64.f, 32.f}, ngraph::element::f32, {1, 1, 3, 1, 1} },
{ {0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 1, 3, 1, 1} }
},
}
},
};
INSTANTIATE_TEST_CASE_P(