[LPT] reshape 4D->4D per-tensor quantization fix (#3644)

This commit is contained in:
Vladislav Golubev 2020-12-29 12:40:31 +03:00 committed by GitHub
parent 631d452258
commit 6326eb348a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 4 deletions

View File

@ -250,9 +250,11 @@ bool ReshapeTransformation::canBeTransformed(
return false;
}
} else {
for (size_t i = 0; i < 2ul; ++i) {
if (inputShape[i] != outputShape[i]) {
return false;
if (ngraph::shape_size(subtractShape) > 1 || ngraph::shape_size(multiplyShape) > 1) {
for (size_t i = 0; i < 2ul; ++i) {
if (inputShape[i] != outputShape[i]) {
return false;
}
}
}

View File

@ -533,6 +533,38 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
ngraph::element::u8,
{{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1}}}
}
},
// U8: no subtract 4D -> 4D: channels are affected
{
ngraph::Shape({ 1, 64, 320, 1 }),
{ 0, 2, 3, 1},
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {}}}
},
{
ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {}}}
}
},
// U8: with subtract 4D -> 4D: channels are affected
{
ngraph::Shape({ 1, 64, 320, 1 }),
{ 0, 2, 3, 1},
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32}, {{128.f}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}
},
{
ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{{ngraph::element::f32}, {{128.f}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}
}
}
};

View File

@ -55,7 +55,12 @@ void TransposeTransformation::validate() {
const auto output = transformed->get_output_op(0);
const auto layer = output->get_input_node_shared_ptr(0);
const std::string typeName = layer->get_type_name();
ASSERT_EQ("Reshape", typeName);
if (testValues.fqOnData.outputLowValues.size() > 1 || testValues.fqOnData.outputHighValues.size() > 1) {
ASSERT_EQ("Reshape", typeName);
} else {
ASSERT_EQ("ScaleShiftIE", typeName);
}
}
TEST_P(TransposeTransformation, CompareWithRefImpl) {