[LPT] Fix Yolo v5 bug with Concat transformation subtract constant (#8491)
* add supporting convert after subtract constant in concat_transformation * add check for the existence of subtract constant * add convert fold for subtract convert & add tests for subtract with convert
This commit is contained in:
parent
dd38ffb387
commit
6156626706
@ -110,9 +110,14 @@ bool ConcatTransformation::transform(TransformationContext& context, ngraph::pat
|
||||
targetShape[1] = concat->get_input_partial_shape(i)[1].get_length();
|
||||
|
||||
if (!allDequantizationShiftAreZero) {
|
||||
subtractNodes.push_back(dequantization.subtract == nullptr ?
|
||||
auto subtractInput = dequantization.subtract == nullptr ?
|
||||
std::make_shared<ngraph::opset1::Constant>(deqPrecision, targetShape, std::vector<float>({ 0.f })) :
|
||||
broadcastElementWiseConst(dequantization.subtractConstant, targetShape));
|
||||
broadcastElementWiseConst(dequantization.subtractConstant, targetShape);
|
||||
if (dequantization.subtractConvert != nullptr) {
|
||||
subtractInput = foldConvert(subtractInput, dequantization.subtractConvert->get_convert_element_type());
|
||||
NetworkHelper::copyInfo(dequantization.subtractConvert, subtractInput);
|
||||
}
|
||||
subtractNodes.push_back(subtractInput);
|
||||
}
|
||||
|
||||
if (!allDequantizationMultiplyAreZero) {
|
||||
|
@ -866,6 +866,94 @@ const std::vector<ConcatTransformationTestValues> testValues = {
|
||||
{}
|
||||
},
|
||||
},
|
||||
// U8: concat with subtract convert
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
true,
|
||||
1,
|
||||
{
|
||||
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{},
|
||||
{},
|
||||
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f} },
|
||||
{ ngraph::element::u8 },
|
||||
{
|
||||
{ element::f32 },
|
||||
{
|
||||
{ 0 },
|
||||
element::f32,
|
||||
{},
|
||||
false,
|
||||
1ul,
|
||||
ngraph::element::i8,
|
||||
true,
|
||||
{},
|
||||
{}
|
||||
},
|
||||
{ 0.01f }
|
||||
},
|
||||
},
|
||||
{
|
||||
{
|
||||
256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8,
|
||||
{ make_shared_attribute_ptr<IntervalsAlignmentAttribute>(IntervalsAlignmentSharedValue::Interval{0.f, 2.55f}, 256ul) }
|
||||
},
|
||||
{},
|
||||
{},
|
||||
{
|
||||
256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8,
|
||||
{ make_shared_attribute_ptr<IntervalsAlignmentAttribute>(IntervalsAlignmentSharedValue::Interval{0.f, 2.55f}, 256ul) }
|
||||
},
|
||||
{},
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, { 0 }, { 0.01f } }
|
||||
}
|
||||
},
|
||||
// U8: concat multi channels with subtract convert
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
true,
|
||||
1,
|
||||
{
|
||||
{ 256ul, {}, {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{},
|
||||
{},
|
||||
{ 256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f} },
|
||||
{ ngraph::element::u8 },
|
||||
{
|
||||
{ element::f32 },
|
||||
{
|
||||
{ 128 },
|
||||
element::f32,
|
||||
{},
|
||||
false,
|
||||
1ul,
|
||||
ngraph::element::u8,
|
||||
true,
|
||||
{},
|
||||
{}
|
||||
},
|
||||
{ 0.01f }
|
||||
},
|
||||
},
|
||||
{
|
||||
{
|
||||
256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}, ngraph::element::u8,
|
||||
{ make_shared_attribute_ptr<IntervalsAlignmentAttribute>(IntervalsAlignmentSharedValue::Interval{-1.28f, 2.55f}, 256ul) }
|
||||
},
|
||||
{},
|
||||
{},
|
||||
{
|
||||
256ul, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f}, ngraph::element::u8,
|
||||
{ make_shared_attribute_ptr<IntervalsAlignmentAttribute>(IntervalsAlignmentSharedValue::Interval{-1.28f, 2.55f}, 256ul) }
|
||||
},
|
||||
{},
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {{0.f, 0.f, 0.f, 128.f, 128.f, 128.f}}, { 0.01f } }
|
||||
}
|
||||
},
|
||||
// U8: concat multi channels with subtract
|
||||
// Features:
|
||||
// 1. fakeQuantize1 defines precision
|
||||
|
Loading…
Reference in New Issue
Block a user