[LPT] reshape 4D->4D per-tensor quantization fix (#3644)
This commit is contained in:
parent
631d452258
commit
6326eb348a
@ -250,9 +250,11 @@ bool ReshapeTransformation::canBeTransformed(
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < 2ul; ++i) {
|
||||
if (inputShape[i] != outputShape[i]) {
|
||||
return false;
|
||||
if (ngraph::shape_size(subtractShape) > 1 || ngraph::shape_size(multiplyShape) > 1) {
|
||||
for (size_t i = 0; i < 2ul; ++i) {
|
||||
if (inputShape[i] != outputShape[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -533,6 +533,38 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {1, 1}}}
|
||||
}
|
||||
},
|
||||
// U8: no subtract 4D -> 4D: channels are affected
|
||||
{
|
||||
ngraph::Shape({ 1, 64, 320, 1 }),
|
||||
{ 0, 2, 3, 1},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {}}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {}, {{0.1f}, ngraph::element::f32, {}}}
|
||||
}
|
||||
},
|
||||
// U8: with subtract 4D -> 4D: channels are affected
|
||||
{
|
||||
ngraph::Shape({ 1, 64, 320, 1 }),
|
||||
{ 0, 2, 3, 1},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{128.f}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{128.f}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -55,7 +55,12 @@ void TransposeTransformation::validate() {
|
||||
const auto output = transformed->get_output_op(0);
|
||||
const auto layer = output->get_input_node_shared_ptr(0);
|
||||
const std::string typeName = layer->get_type_name();
|
||||
ASSERT_EQ("Reshape", typeName);
|
||||
|
||||
if (testValues.fqOnData.outputLowValues.size() > 1 || testValues.fqOnData.outputHighValues.size() > 1) {
|
||||
ASSERT_EQ("Reshape", typeName);
|
||||
} else {
|
||||
ASSERT_EQ("ScaleShiftIE", typeName);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(TransposeTransformation, CompareWithRefImpl) {
|
||||
|
Loading…
Reference in New Issue
Block a user