[LPT] Fix Yolo v5 bug with Concat transformation subtract constant (#8491)

* add supporting convert after subtract constant in concat_transformation

* add check for the existence of subtract constant

* add convert fold for subtract convert & add tests for subtract with convert
This commit is contained in:
Nikita Demashov 2021-11-23 17:10:14 +03:00 committed by GitHub
parent dd38ffb387
commit 6156626706
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 95 additions and 2 deletions

View File

@ -110,9 +110,14 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
targetShape[1] = concat->get_input_partial_shape(i)[1].get_length();
if (!allDequantizationShiftAreZero) {
subtractNodes.push_back(dequantization.subtract == nullptr ?
auto subtractInput = dequantization.subtract == nullptr ?
std::make_shared<ngraph::opset1::Constant>(deqPrecision, targetShape, std::vector<float>({ 0.f })) :
broadcastElementWiseConst(dequantization.subtractConstant, targetShape));
broadcastElementWiseConst(dequantization.subtractConstant, targetShape);
if (dequantization.subtractConvert != nullptr) {
subtractInput = foldConvert(subtractInput, dequantization.subtractConvert->get_convert_element_type());
NetworkHelper::copyInfo(dequantization.subtractConvert, subtractInput);
}
subtractNodes.push_back(subtractInput);
}
if (!allDequantizationMultiplyAreZero) {

View File

@ -866,6 +866,94 @@ const std::vector<ConcatTransformationTestValues> testValues = {
{}
},
},
// U8: concat with subtract convert
{
LayerTransformation::createParamsU8I8(),
true,
1,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
{},
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} },
{ ngraph::element::u8 },
{
{ element::f32 },
{
{ 0 },
element::f32,
{},
false,
1ul,
ngraph::element::i8,
true,
{},
{}
},
{ 0.01f }
},
},
{
{
256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8,
{ make_shared_attribute_ptr<IntervalsAlignmentAttribute>(IntervalsAlignmentSharedValue::Interval{0.f, 2.55f}, 256ul) }
},
{},
{},
{
256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8,
{ make_shared_attribute_ptr<IntervalsAlignmentAttribute>(IntervalsAlignmentSharedValue::Interval{0.f, 2.55f}, 256ul) }
},
{},
{},
ngraph::element::u8,
{ ngraph::element::f32, { 0 }, { 0.01f } }
}
},
// U8: concat multi channels with subtract convert
{
LayerTransformation::createParamsU8I8(),
true,
1,
{
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
{},
{},
{ 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
{ ngraph::element::u8 },
{
{ element::f32 },
{
{ 128 },
element::f32,
{},
false,
1ul,
ngraph::element::u8,
true,
{},
{}
},
{ 0.01f }
},
},
{
{
256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8,
{ make_shared_attribute_ptr<IntervalsAlignmentAttribute>(IntervalsAlignmentSharedValue::Interval{-1.28f, 2.55f}, 256ul) }
},
{},
{},
{
256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}, ngraph::element::u8,
{ make_shared_attribute_ptr<IntervalsAlignmentAttribute>(IntervalsAlignmentSharedValue::Interval{-1.28f, 2.55f}, 256ul) }
},
{},
{},
ngraph::element::u8,
{ ngraph::element::f32, {{0.f, 0.f, 0.f, 128.f, 128.f, 128.f}}, { 0.01f } }
}
},
// U8: concat multi channels with subtract
// Features:
// 1. fakeQuantize1 defines precision