[LPT] Fixes for the cases with convert before subtraction constant (#12835)

This commit is contained in:
Vladislav Golubev 2022-09-06 19:41:29 +02:00 committed by GitHub
parent 9d55355daf
commit 88e4ac5e53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 224 additions and 5 deletions

View File

@ -128,6 +128,10 @@ public:
const element::Type deqPrecision = element::f32, const element::Type deqPrecision = element::f32,
std::shared_ptr<ngraph::Node> input = nullptr); std::shared_ptr<ngraph::Node> input = nullptr);
// Creates a Subtract node computing `parent - subtract_constant`.
// If the element type of `subtract_constant` differs from the element type of `parent`,
// the constant is passed through a Convert to the parent's element type before the Subtract;
// otherwise it is used as the second Subtract input directly.
static std::shared_ptr<ngraph::Node> makeDequantizationSubtract(
const ngraph::Output<ngraph::Node>& parent,
const ngraph::Output<ngraph::Node>& subtract_constant);
static FakeQuantizeDequantization createDequantizationFromFakeQuantize( static FakeQuantizeDequantization createDequantizationFromFakeQuantize(
std::shared_ptr<opset1::FakeQuantize> fq, std::shared_ptr<opset1::FakeQuantize> fq,
element::Type precision, element::Type precision,
@ -156,7 +160,7 @@ public:
static std::shared_ptr<opset1::Constant> normalizeDequantizationShape( static std::shared_ptr<opset1::Constant> normalizeDequantizationShape(
const std::shared_ptr<Node>& eltwise, const std::shared_ptr<Node>& eltwise,
const bool convertIsExpected = false); const bool convertIsExpected = true);
// 1. remove Convert if possible // 1. remove Convert if possible
// 2. optimize Constant if possible // 2. optimize Constant if possible

View File

@ -1187,6 +1187,16 @@ FakeQuantizeDequantization NetworkHelper::makeDequantization(
return FakeQuantizeDequantization(input, convert, subtract, nullptr, subtractConstant, multiply, multiplyConstant); return FakeQuantizeDequantization(input, convert, subtract, nullptr, subtractConstant, multiply, multiplyConstant);
} }
// Builds the dequantization Subtract for `parent` and `subtract_constant`.
// When both outputs already share an element type, the constant feeds the Subtract directly.
// Otherwise a Convert to the parent's element type is inserted between the constant and the Subtract.
std::shared_ptr<ov::Node> NetworkHelper::makeDequantizationSubtract(
    const ov::Output<ov::Node>& parent,
    const ov::Output<ov::Node>& subtract_constant) {
    const auto parent_type = parent.get_element_type();

    // Matching precisions: no conversion is needed.
    if (subtract_constant.get_element_type() == parent_type) {
        return std::make_shared<opset1::Subtract>(parent, subtract_constant);
    }

    // Mismatching precisions: align the constant with the parent first.
    const auto converted_constant = std::make_shared<opset1::Convert>(subtract_constant, parent_type);
    return std::make_shared<opset1::Subtract>(parent, converted_constant);
}
FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize( FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize(
std::shared_ptr<opset1::FakeQuantize> fq, std::shared_ptr<opset1::FakeQuantize> fq,
element::Type precision, element::Type precision,
@ -1644,6 +1654,9 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
op::TemporaryReplaceOutputType(foldConvert(dequantization.subtractConstant, parentPrecision), element::f32).get()); op::TemporaryReplaceOutputType(foldConvert(dequantization.subtractConstant, parentPrecision), element::f32).get());
ngraph::copy_runtime_info({ newOperation, parent }, parent); ngraph::copy_runtime_info({ newOperation, parent }, parent);
} else { } else {
// Subtract constant could be changed (including a shape) before propagation in some cases
// so it's necessary to compute the shape for a subtractConvert before creating a new subtract
dequantization.subtractConvert->validate_and_infer_types();
parent = std::make_shared<opset1::Subtract>(parent, dequantization.subtractConvert); parent = std::make_shared<opset1::Subtract>(parent, dequantization.subtractConvert);
ngraph::copy_runtime_info({ newOperation, parent }, parent); ngraph::copy_runtime_info({ newOperation, parent }, parent);
} }
@ -1736,6 +1749,9 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationBefor
foldConvert(subtractConstant, parentPrecision), element::f32).get()); foldConvert(subtractConstant, parentPrecision), element::f32).get());
parent->set_friendly_name(dequantization.subtract->get_friendly_name() + "_" + std::to_string(i + 1)); parent->set_friendly_name(dequantization.subtract->get_friendly_name() + "_" + std::to_string(i + 1));
} else { } else {
// Subtract constant could be changed (including a shape) before propagation in some cases
// so it's necessary to compute the shape for a subtractConvert before creating a new subtract
dequantization.subtractConvert->validate_and_infer_types();
parent = std::make_shared<opset1::Subtract>(parent, dequantization.subtractConvert); parent = std::make_shared<opset1::Subtract>(parent, dequantization.subtractConvert);
} }
ngraph::copy_runtime_info(dequantization.subtract, parent); ngraph::copy_runtime_info(dequantization.subtract, parent);

View File

@ -42,7 +42,7 @@ bool ShuffleChannelsTransformation::transform(TransformationContext& context, ng
auto dequantization = NetworkHelper::getDequantization(shuffleChannels, defaultPrecisions); auto dequantization = NetworkHelper::getDequantization(shuffleChannels, defaultPrecisions);
const auto shuffleDequantizationConstant = [&](const std::shared_ptr<Node>& eltwise) { const auto shuffleDequantizationConstant = [&](const std::shared_ptr<Node>& eltwise) {
const auto normalizedConst = NetworkHelper::normalizeDequantizationShape(eltwise); const auto normalizedConst = NetworkHelper::normalizeDequantizationShape(eltwise, true);
const auto constShape = normalizedConst->get_shape(); const auto constShape = normalizedConst->get_shape();
if (shape_size(constShape) == 1ul) { if (shape_size(constShape) == 1ul) {

View File

@ -90,7 +90,7 @@ bool SplitTransformation::transform(TransformationContext& context, ngraph::patt
} }
if (dequantization.subtract) { if (dequantization.subtract) {
const auto subtract = std::make_shared<opset1::Subtract>(parent, splitedSub[i]); const auto subtract = NetworkHelper::makeDequantizationSubtract(parent, splitedSub[i]);
copy_runtime_info({ newSplit, subtract }, subtract); copy_runtime_info({ newSplit, subtract }, subtract);
parent = subtract; parent = subtract;
} }

View File

@ -125,8 +125,9 @@ bool TransposeTransformation::canBeTransformed(const TransformationContext& cont
return true; return true;
}(); }();
const auto values = constant->cast_vector<float>(); // TODO: remove legacy limitation
if (!isPerTensor) { if (!isPerTensor) {
const auto values = constant->cast_vector<float>();
if ((values.size() < 2ul) || (values[0] != 0) || (values[1] != 1)) { if ((values.size() < 2ul) || (values[0] != 0) || (values[1] != 1)) {
return false; return false;
} }

View File

@ -315,6 +315,28 @@ const std::vector<PadTransformationTestValues> deqWithSub = {
ngraph::element::u8, ngraph::element::u8,
{{ngraph::element::f32}, {{128.f, 64.f, 32.f}}, {{3.f, 1.f, 2.f}}} {{ngraph::element::f32}, {{128.f, 64.f, 32.f}}, {{3.f, 1.f, 2.f}}}
} }
},
// int8 subtraction with Convert from u8 to fp32
{
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{
{ngraph::element::f32},
{{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true},
{3.f}
}
},
{
ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{
{ngraph::element::f32},
{{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true},
{3.f}
}
}
} }
}; };

View File

@ -250,6 +250,31 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
} }
} }
}, },
// U8: with subtract 4D -> 5D: channels are affected:
// per-channel subtraction with Convert from u8 to fp32 and identical values
{
{ 1, 4, 10, 10 },
{ 1, 2, 2, 10, 10},
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{
{ngraph::element::f32},
{{128.f}, element::undefined, {1, 4, 1, 1}, false, 1ul, element::u8, true},
{3.f}
}
},
{
ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{
{ngraph::element::f32},
{{128.f}, element::undefined, {}, false, 1ul, element::u8, true},
{3.f}
}
}
},
// U8: with subtract 3D -> 4D: channels are not affected, dynamic batch // U8: with subtract 3D -> 4D: channels are not affected, dynamic batch
{ {
{ -1, 3, 20 }, { -1, 3, 20 },

View File

@ -139,6 +139,30 @@ const std::vector<ShuffleChannelsTransformationTestValues> testValues = {
{{ngraph::element::f32}, {{128.f, 64.f, 32.f}}, {{0.01f, 0.02f, 0.03f}}} {{ngraph::element::f32}, {{128.f, 64.f, 32.f}}, {{0.01f, 0.02f, 0.03f}}}
} }
}, },
// subtraction with Convert from u8 to fp32
{
LayerTransformation::createParamsU8I8(),
1,
1,
{
ngraph::element::u8,
{
{ngraph::element::f32},
{{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true},
{3.f}
}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{ngraph::element::f32},
{{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true},
{3.f}
}
}
},
// U8 quantization by spatial dimension, shuffling by the same dimension // U8 quantization by spatial dimension, shuffling by the same dimension
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),

View File

@ -136,6 +136,46 @@ const std::vector<SplitTransformationTestValues> testValues = {
} }
} }
}, },
// U8 per tensor quantization / int8 subtraction with Convert from u8 to fp32
{
{ 1, 3, 16, 16 }, std::int64_t{2}, size_t{2},
LayerTransformation::createParamsU8I8(),
// ActualValues
{
ngraph::element::u8,
{{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::u8, true}, {3.f}}
},
// ExpectedValues
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::u8, true}, {3.f}},
{{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::u8, true}, {3.f}},
}
}
},
// U8 per tensor quantization / int8 subtraction with Convert from fp16 -> fp32
{
{ 1, 3, 16, 16 }, std::int64_t{2}, size_t{2},
LayerTransformation::createParamsU8I8(),
// ActualValues
{
ngraph::element::u8,
{{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::f16, true}, {3.f}}
},
// ExpectedValues
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::f16, true}, {3.f}},
{{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::f16, true}, {3.f}},
}
}
},
{ {
{ -1, -1, -1, -1 }, std::int64_t{2}, size_t{2}, { -1, -1, -1, -1 }, std::int64_t{2}, size_t{2},
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),

View File

@ -228,6 +228,29 @@ const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformati
{{ngraph::element::f32}, { 128.f }, { 0.1f }} {{ngraph::element::f32}, { 128.f }, { 0.1f }}
} }
}, },
// U8: channel slice, per-channel quantization with the same values, subtraction with Convert from u8 to fp32
{
LayerTransformation::createParamsU8I8(),
channelSlice,
{
ngraph::element::u8,
{
{ngraph::element::f32},
{{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true},
{3.f}
}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{
{ngraph::element::f32},
{{128.f}, element::undefined, {}, false, 1ul, element::u8, true},
{3.f}
}
}
},
// U8: channel slice, per-channel quantization with different values // U8: channel slice, per-channel quantization with different values
{ {
LayerTransformation::createParamsU8I8(), LayerTransformation::createParamsU8I8(),

View File

@ -168,6 +168,68 @@ const std::vector<TransposeTransformationTestValues> testValues = {
} }
} }
}, },
// U8: per-channel quantization with the same values,
// subtraction with Convert from u8 to fp32, transpose channel dimension
{
{ 0, 3, 1, 2 },
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{
{ ngraph::element::f32 },
{{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true},
{{0.1}, ngraph::element::f32, { 1, 3, 1, 1 }}
}
},
{
ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{
{ ngraph::element::f32 },
{{128.f}, element::undefined, {1, 1, 3, 1}, false, 1ul, element::u8, true},
{{0.1}, ngraph::element::f32, {1, 1, 3, 1}}
}
}
},
// U8: per-tensor quantization, transpose channel dimension
{
{ 0, 3, 1, 2 },
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{{ngraph::element::f32}, {128}, {0.1f}}
},
{
ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{{ngraph::element::f32}, {128}, {0.1f}}
}
},
// U8: per-channel quantization, transpose channel dimension
{
{ 0, 2, 1, 3 },
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{
{ ngraph::element::f32 },
{{ 128, 64, 32 }, ngraph::element::f32, { 1, 3, 1, 1 }},
{{ 0.3f, 0.2f, 0.1f }, ngraph::element::f32, { 1, 3, 1, 1 }}
}
},
{
ngraph::element::u8,
{
{ ngraph::element::f32 },
{{ 128, 64, 32 }, ngraph::element::f32, { 1, 3, 1, 1 }},
{{ 0.3f, 0.2f, 0.1f }, ngraph::element::f32, { 1, 3, 1, 1 }}
},
ngraph::element::f32,
{{}, {}, {}},
}
},
// empty // empty
{ {
{ 0, 1, 3, 2 }, { 0, 1, 3, 2 },

View File

@ -60,7 +60,9 @@ std::shared_ptr<Node> makeDequantization(
if (dequantizationOperations.subtract.addConvert) { if (dequantizationOperations.subtract.addConvert) {
std::shared_ptr<Node> subtractConstConvert = std::make_shared<ngraph::opset1::Convert>( std::shared_ptr<Node> subtractConstConvert = std::make_shared<ngraph::opset1::Convert>(
subtractConst, subtractConst,
dequantizationOperations.subtract.outPrecision); dequantizationOperations.subtract.outPrecision == element::undefined ?
parent.get_element_type() :
dequantizationOperations.subtract.outPrecision);
auto& rt = subtractConstConvert->get_rt_info(); auto& rt = subtractConstConvert->get_rt_info();
for (const auto& attribute : dequantizationOperations.subtract.convertAttributes) { for (const auto& attribute : dequantizationOperations.subtract.convertAttributes) {