LPT fixes (#6214)
* [LPT] StridedSlice fix * [LPT] separateInStandaloneBranch fix
This commit is contained in:
committed by
GitHub
parent
9e7d98fca9
commit
6799a31911
@@ -549,7 +549,11 @@ std::shared_ptr<ngraph::Node> NetworkHelper::separateInStandaloneBranch(std::sha
|
||||
}
|
||||
|
||||
std::vector<Output<Node>> inputs = node->input_values();
|
||||
const size_t inputIndex = NetworkHelper::getChildInputIndex(dequantization.multiply, node);
|
||||
const auto originalParent = dequantization.multiply ?
|
||||
dequantization.multiply->shared_from_this() :
|
||||
dequantization.subtract->shared_from_this();
|
||||
|
||||
const size_t inputIndex = NetworkHelper::getChildInputIndex(originalParent, node);
|
||||
inputs[inputIndex] = parent;
|
||||
const std::shared_ptr<Node> newNode = node->clone_with_new_inputs(inputs);
|
||||
|
||||
|
||||
@@ -17,13 +17,12 @@ std::shared_ptr<Node> stridedSliceDeqConstant(
|
||||
const std::shared_ptr<ngraph::Node> strSlice,
|
||||
const std::shared_ptr<ngraph::Node> dequantizaitonConstant) {
|
||||
auto constant = as_type_ptr<ngraph::opset1::Constant>(dequantizaitonConstant);
|
||||
// issue #48857: constant is mistakenly recognized as a scalar. Uncomment after fix
|
||||
//if (NetworkHelper::isScalarLike(constant)) {
|
||||
// return NetworkHelper::toScalar(constant);
|
||||
//}
|
||||
auto constantShape = constant->get_shape();
|
||||
if (ngraph::shape_size(constantShape) == 1ul) {
|
||||
return NetworkHelper::toScalar(constant);
|
||||
}
|
||||
|
||||
const auto stridedSliceShape = strSlice->get_input_shape(0);
|
||||
auto constantShape = constant->get_shape();
|
||||
if (stridedSliceShape.size() != constantShape.size()) {
|
||||
ngraph::Shape newConstantShape;
|
||||
if (ngraph::shape_size(constantShape) == 1) {
|
||||
|
||||
@@ -81,7 +81,6 @@ public:
|
||||
"SeparateInStandaloneBranchTransformation");
|
||||
};
|
||||
actualFunction = createActualFunction(testValues.precisionBefore, shape, testValues.dequantization);
|
||||
|
||||
const auto result = actualFunction->get_results()[0];
|
||||
ngraph::pass::low_precision::NetworkHelper::separateInStandaloneBranch(result->get_input_node_shared_ptr(0));
|
||||
|
||||
@@ -143,6 +142,11 @@ std::vector<SeparateInStandaloneBranchTransformationTestValues> testValues = {
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, { 127.f }, { 0.02f } }
|
||||
},
|
||||
{
|
||||
LayerTransformation::createParamsU8U8(),
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, { 127.f }, {} }
|
||||
},
|
||||
{
|
||||
LayerTransformation::createParamsU8U8().setSupportAsymmetricQuantization(true),
|
||||
ngraph::element::u8,
|
||||
|
||||
@@ -117,7 +117,8 @@ public:
|
||||
testValues.inputShape << testValues.actual.inputPrecision << "_" << toString(testValues.params) <<
|
||||
testValues.actual.dequantization << "_strided_slice_params_" << testValues.layerParams.begin <<
|
||||
testValues.layerParams.end << testValues.layerParams.beginMask <<
|
||||
testValues.layerParams.endMask << testValues.layerParams.strides;
|
||||
testValues.layerParams.endMask << testValues.layerParams.strides <<
|
||||
testValues.layerParams.shrinkAxisMask << testValues.layerParams.newAxisMask;
|
||||
return result.str();
|
||||
}
|
||||
};
|
||||
@@ -161,6 +162,28 @@ StridedSliceTransformationTestValues::LayerParams specialDimensionEndSlice = {
|
||||
{}
|
||||
};
|
||||
|
||||
StridedSliceTransformationTestValues::LayerParams sliceWithRemovedAxis = {
|
||||
{ 0, 1, 0, 0 }, // begin
|
||||
{ 1, 2, 1, 1 }, // end
|
||||
{ 1, 1, 1, 1 }, // strided
|
||||
{ 1, 0, 1, 1 }, // beginMask
|
||||
{ 1, 0, 1, 1 }, // endMask
|
||||
{ 0, 0, 0, 0 }, // newAxisMask
|
||||
{ 0, 1, 0, 0 }, // shrinkAxisMask
|
||||
{ 0, 0, 0, 0 } // elipsisMask
|
||||
};
|
||||
|
||||
StridedSliceTransformationTestValues::LayerParams sliceWithAdditionalAxis = {
|
||||
{ 0, 1, 0, 0 }, // begin
|
||||
{ 1, 2, 1, 1 }, // end
|
||||
{ 1, 1, 1, 1 }, // strided
|
||||
{ 1, 0, 1, 1 }, // beginMask
|
||||
{ 1, 0, 1, 1 }, // endMask
|
||||
{ 0, 1, 0, 0 }, // newAxisMask
|
||||
{ 0, 0, 0, 0 }, // shrinkAxisMask
|
||||
{ 0, 0, 0, 0 } // elipsisMask
|
||||
};
|
||||
|
||||
const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformationTestValues = {
|
||||
// U8: channel slice, per-tensor quantization
|
||||
{
|
||||
@@ -442,6 +465,74 @@ const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformati
|
||||
{{ngraph::element::f32}, {}, { {0.1f, 0.01f}, ngraph::element::f32, {1, 2, 1, 1} }}
|
||||
}
|
||||
},
|
||||
// U8: channel slice, per-tensor quantization
|
||||
{
|
||||
ngraph::Shape{1, 3, 16, 1200},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
sliceWithRemovedAxis,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {128.f}, {0.1f}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {128.f}, {0.1f}}
|
||||
}
|
||||
},
|
||||
// U8: channel slice, per-channel quantization
|
||||
{
|
||||
ngraph::Shape{1, 3, 16, 1200},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
sliceWithRemovedAxis,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, { {128.f, 64.f, 32.f} }, { {0.1f, 0.2f, 0.3f} }}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {64.f}, {0.2f}},
|
||||
}
|
||||
},
|
||||
// U8: channel slice, per-tensor quantization
|
||||
{
|
||||
ngraph::Shape{1, 3, 16, 1200},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
sliceWithAdditionalAxis,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {128.f}, {0.1f}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {128.f}, {0.1f}}
|
||||
}
|
||||
},
|
||||
// U8: channel slice, per-channel quantization
|
||||
{
|
||||
ngraph::Shape{1, 3, 16, 1200},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
sliceWithAdditionalAxis,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, { {128.f, 64.f, 32.f} }, { {0.1f, 0.2f, 0.3f} }}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{ {128.f, 64.f, 32.f}, ngraph::element::f32, {1, 1, 3, 1, 1} },
|
||||
{ {0.1f, 0.2f, 0.3f}, ngraph::element::f32, {1, 1, 3, 1, 1} }
|
||||
},
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
|
||||
Reference in New Issue
Block a user