[LPT] Fixes for the cases with convert before subtraction constant (#12835)
This commit is contained in:
parent
9d55355daf
commit
88e4ac5e53
@ -128,6 +128,10 @@ public:
|
||||
const element::Type deqPrecision = element::f32,
|
||||
std::shared_ptr<ngraph::Node> input = nullptr);
|
||||
|
||||
// Creates the Subtract node of a dequantization chain on top of `parent`.
// If the element type of `subtract_constant` differs from the element type of `parent`,
// the constant is passed through a Convert to the parent's element type before subtraction;
// otherwise it is subtracted directly.
static std::shared_ptr<ngraph::Node> makeDequantizationSubtract(
|
||||
const ngraph::Output<ngraph::Node>& parent,
|
||||
const ngraph::Output<ngraph::Node>& subtract_constant);
|
||||
|
||||
static FakeQuantizeDequantization createDequantizationFromFakeQuantize(
|
||||
std::shared_ptr<opset1::FakeQuantize> fq,
|
||||
element::Type precision,
|
||||
@ -156,7 +160,7 @@ public:
|
||||
|
||||
static std::shared_ptr<opset1::Constant> normalizeDequantizationShape(
|
||||
const std::shared_ptr<Node>& eltwise,
|
||||
const bool convertIsExpected = false);
|
||||
const bool convertIsExpected = true);
|
||||
|
||||
// 1. remove Convert if possible
|
||||
// 2. optimize Constant if possible
|
||||
|
@ -1187,6 +1187,16 @@ FakeQuantizeDequantization NetworkHelper::makeDequantization(
|
||||
return FakeQuantizeDequantization(input, convert, subtract, nullptr, subtractConstant, multiply, multiplyConstant);
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Node> NetworkHelper::makeDequantizationSubtract(
|
||||
const ov::Output<ov::Node>& parent,
|
||||
const ov::Output<ov::Node>& subtract_constant) {
|
||||
return subtract_constant.get_element_type() != parent.get_element_type()
|
||||
? std::dynamic_pointer_cast<ov::Node>(std::make_shared<opset1::Subtract>(
|
||||
parent,
|
||||
std::make_shared<opset1::Convert>(subtract_constant, parent.get_element_type())))
|
||||
: std::make_shared<opset1::Subtract>(parent, subtract_constant);
|
||||
}
|
||||
|
||||
FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize(
|
||||
std::shared_ptr<opset1::FakeQuantize> fq,
|
||||
element::Type precision,
|
||||
@ -1644,6 +1654,9 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
|
||||
op::TemporaryReplaceOutputType(foldConvert(dequantization.subtractConstant, parentPrecision), element::f32).get());
|
||||
ngraph::copy_runtime_info({ newOperation, parent }, parent);
|
||||
} else {
|
||||
// Subtract constant could be changed (including a shape) before propagation in some cases
|
||||
// so it's necessary to compute the shape for a subtractConvert before creating a new subtract
|
||||
dequantization.subtractConvert->validate_and_infer_types();
|
||||
parent = std::make_shared<opset1::Subtract>(parent, dequantization.subtractConvert);
|
||||
ngraph::copy_runtime_info({ newOperation, parent }, parent);
|
||||
}
|
||||
@ -1736,6 +1749,9 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationBefor
|
||||
foldConvert(subtractConstant, parentPrecision), element::f32).get());
|
||||
parent->set_friendly_name(dequantization.subtract->get_friendly_name() + "_" + std::to_string(i + 1));
|
||||
} else {
|
||||
// Subtract constant could be changed (including a shape) before propagation in some cases
|
||||
// so it's necessary to compute the shape for a subtractConvert before creating a new subtract
|
||||
dequantization.subtractConvert->validate_and_infer_types();
|
||||
parent = std::make_shared<opset1::Subtract>(parent, dequantization.subtractConvert);
|
||||
}
|
||||
ngraph::copy_runtime_info(dequantization.subtract, parent);
|
||||
|
@ -42,7 +42,7 @@ bool ShuffleChannelsTransformation::transform(TransformationContext& context, ng
|
||||
auto dequantization = NetworkHelper::getDequantization(shuffleChannels, defaultPrecisions);
|
||||
|
||||
const auto shuffleDequantizationConstant = [&](const std::shared_ptr<Node>& eltwise) {
|
||||
const auto normalizedConst = NetworkHelper::normalizeDequantizationShape(eltwise);
|
||||
const auto normalizedConst = NetworkHelper::normalizeDequantizationShape(eltwise, true);
|
||||
const auto constShape = normalizedConst->get_shape();
|
||||
|
||||
if (shape_size(constShape) == 1ul) {
|
||||
|
@ -90,7 +90,7 @@ bool SplitTransformation::transform(TransformationContext& context, ngraph::patt
|
||||
}
|
||||
|
||||
if (dequantization.subtract) {
|
||||
const auto subtract = std::make_shared<opset1::Subtract>(parent, splitedSub[i]);
|
||||
const auto subtract = NetworkHelper::makeDequantizationSubtract(parent, splitedSub[i]);
|
||||
copy_runtime_info({ newSplit, subtract }, subtract);
|
||||
parent = subtract;
|
||||
}
|
||||
|
@ -125,8 +125,9 @@ bool TransposeTransformation::canBeTransformed(const TransformationContext& cont
|
||||
return true;
|
||||
}();
|
||||
|
||||
const auto values = constant->cast_vector<float>();
|
||||
// TODO: remove legacy limitation
|
||||
if (!isPerTensor) {
|
||||
const auto values = constant->cast_vector<float>();
|
||||
if ((values.size() < 2ul) || (values[0] != 0) || (values[1] != 1)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -315,6 +315,28 @@ const std::vector<PadTransformationTestValues> deqWithSub = {
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{128.f, 64.f, 32.f}}, {{3.f, 1.f, 2.f}}}
|
||||
}
|
||||
},
|
||||
// int8 subtraction with Convert from u8 to fp32
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true},
|
||||
{3.f}
|
||||
}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true},
|
||||
{3.f}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -250,6 +250,31 @@ const std::vector<ReshapeTransformationTestValues> testValues = {
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8: with subtract 4D -> 5D, channel dimension is split:
|
||||
// per-channel subtraction with Convert from u8 to fp32 and identical values
|
||||
{
|
||||
{ 1, 4, 10, 10 },
|
||||
{ 1, 2, 2, 10, 10},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{128.f}, element::undefined, {1, 4, 1, 1}, false, 1ul, element::u8, true},
|
||||
{3.f}
|
||||
}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{128.f}, element::undefined, {}, false, 1ul, element::u8, true},
|
||||
{3.f}
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8: with subtract 3D -> 4D: channels are not affected, dynamic batch
|
||||
{
|
||||
{ -1, 3, 20 },
|
||||
|
@ -139,6 +139,30 @@ const std::vector<ShuffleChannelsTransformationTestValues> testValues = {
|
||||
{{ngraph::element::f32}, {{128.f, 64.f, 32.f}}, {{0.01f, 0.02f, 0.03f}}}
|
||||
}
|
||||
},
|
||||
// subtraction with Convert from u8 to fp32
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
1,
|
||||
1,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true},
|
||||
{3.f}
|
||||
}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true},
|
||||
{3.f}
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8 quantization by spatial dimension, shuffling by the same dimension
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
|
@ -136,6 +136,46 @@ const std::vector<SplitTransformationTestValues> testValues = {
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8 per tensor quantization / int8 subtraction with Convert from u8 to fp32
|
||||
{
|
||||
{ 1, 3, 16, 16 }, std::int64_t{2}, size_t{2},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
// ActualValues
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::u8, true}, {3.f}}
|
||||
},
|
||||
// ExpectedValues
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::u8, true}, {3.f}},
|
||||
{{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::u8, true}, {3.f}},
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8 per tensor quantization / int8 subtraction with Convert from fp16 -> fp32
|
||||
{
|
||||
{ 1, 3, 16, 16 }, std::int64_t{2}, size_t{2},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
// ActualValues
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::f16, true}, {3.f}}
|
||||
},
|
||||
// ExpectedValues
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::f16, true}, {3.f}},
|
||||
{{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::f16, true}, {3.f}},
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
{ -1, -1, -1, -1 }, std::int64_t{2}, size_t{2},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
|
@ -228,6 +228,29 @@ const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformati
|
||||
{{ngraph::element::f32}, { 128.f }, { 0.1f }}
|
||||
}
|
||||
},
|
||||
// U8: channel slice, per-channel quantization with the same values, subtraction with Convert from u8 to fp32
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
channelSlice,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true},
|
||||
{3.f}
|
||||
}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{128.f}, element::undefined, {}, false, 1ul, element::u8, true},
|
||||
{3.f}
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8: channel slice, per-channel quantization with different values
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
|
@ -168,6 +168,68 @@ const std::vector<TransposeTransformationTestValues> testValues = {
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8: per-channel quantization with the same values,
|
||||
// subtraction with Convert from u8 to fp32, transpose channel dimension
|
||||
{
|
||||
{ 0, 3, 1, 2 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ ngraph::element::f32 },
|
||||
{{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true},
|
||||
{{0.1}, ngraph::element::f32, { 1, 3, 1, 1 }}
|
||||
}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ ngraph::element::f32 },
|
||||
{{128.f}, element::undefined, {1, 1, 3, 1}, false, 1ul, element::u8, true},
|
||||
{{0.1}, ngraph::element::f32, {1, 1, 3, 1}}
|
||||
}
|
||||
}
|
||||
},
|
||||
// U8: per-tensor quantization, transpose channel dimension
|
||||
{
|
||||
{ 0, 3, 1, 2 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {128}, {0.1f}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {128}, {0.1f}}
|
||||
}
|
||||
},
|
||||
// U8: per-channel quantization, transpose channel dimension
|
||||
{
|
||||
{ 0, 2, 1, 3 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ ngraph::element::f32 },
|
||||
{{ 128, 64, 32 }, ngraph::element::f32, { 1, 3, 1, 1 }},
|
||||
{{ 0.3f, 0.2f, 0.1f }, ngraph::element::f32, { 1, 3, 1, 1 }}
|
||||
}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{ ngraph::element::f32 },
|
||||
{{ 128, 64, 32 }, ngraph::element::f32, { 1, 3, 1, 1 }},
|
||||
{{ 0.3f, 0.2f, 0.1f }, ngraph::element::f32, { 1, 3, 1, 1 }}
|
||||
},
|
||||
ngraph::element::f32,
|
||||
{{}, {}, {}},
|
||||
}
|
||||
},
|
||||
// empty
|
||||
{
|
||||
{ 0, 1, 3, 2 },
|
||||
|
@ -60,7 +60,9 @@ std::shared_ptr<Node> makeDequantization(
|
||||
if (dequantizationOperations.subtract.addConvert) {
|
||||
std::shared_ptr<Node> subtractConstConvert = std::make_shared<ngraph::opset1::Convert>(
|
||||
subtractConst,
|
||||
dequantizationOperations.subtract.outPrecision);
|
||||
dequantizationOperations.subtract.outPrecision == element::undefined ?
|
||||
parent.get_element_type() :
|
||||
dequantizationOperations.subtract.outPrecision);
|
||||
|
||||
auto& rt = subtractConstConvert->get_rt_info();
|
||||
for (const auto& attribute : dequantizationOperations.subtract.convertAttributes) {
|
||||
|
Loading…
Reference in New Issue
Block a user