diff --git a/src/common/low_precision_transformations/include/low_precision/network_helper.hpp b/src/common/low_precision_transformations/include/low_precision/network_helper.hpp index 08cfc518d69..a29e404a3e3 100644 --- a/src/common/low_precision_transformations/include/low_precision/network_helper.hpp +++ b/src/common/low_precision_transformations/include/low_precision/network_helper.hpp @@ -128,6 +128,10 @@ public: const element::Type deqPrecision = element::f32, std::shared_ptr input = nullptr); + static std::shared_ptr makeDequantizationSubtract( + const ngraph::Output& parent, + const ngraph::Output& subtract_constant); + static FakeQuantizeDequantization createDequantizationFromFakeQuantize( std::shared_ptr fq, element::Type precision, @@ -156,7 +160,7 @@ public: static std::shared_ptr normalizeDequantizationShape( const std::shared_ptr& eltwise, - const bool convertIsExpected = false); + const bool convertIsExpected = true); // 1. remove Convert if possible // 2. optimize Constant if possible diff --git a/src/common/low_precision_transformations/src/network_helper.cpp b/src/common/low_precision_transformations/src/network_helper.cpp index 2ede193b4f9..13a36eeb67b 100644 --- a/src/common/low_precision_transformations/src/network_helper.cpp +++ b/src/common/low_precision_transformations/src/network_helper.cpp @@ -1187,6 +1187,16 @@ FakeQuantizeDequantization NetworkHelper::makeDequantization( return FakeQuantizeDequantization(input, convert, subtract, nullptr, subtractConstant, multiply, multiplyConstant); } +std::shared_ptr NetworkHelper::makeDequantizationSubtract( + const ov::Output& parent, + const ov::Output& subtract_constant) { + return subtract_constant.get_element_type() != parent.get_element_type() + ? std::dynamic_pointer_cast(std::make_shared( + parent, + std::make_shared(subtract_constant, parent.get_element_type()))) + : std::make_shared(parent, subtract_constant); +} + FakeQuantizeDequantization NetworkHelper::createDequantizationFromFakeQuantize( std::shared_ptr fq, element::Type precision, @@ -1644,6 +1654,9 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter op::TemporaryReplaceOutputType(foldConvert(dequantization.subtractConstant, parentPrecision), element::f32).get()); ngraph::copy_runtime_info({ newOperation, parent }, parent); } else { + // Subtract constant could be changed (including a shape) before propagation in some cases + // so it's necessary to compute the shape for a subtractConvert before creating a new subtract + dequantization.subtractConvert->validate_and_infer_types(); parent = std::make_shared(parent, dequantization.subtractConvert); ngraph::copy_runtime_info({ newOperation, parent }, parent); } @@ -1736,6 +1749,9 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationBefor foldConvert(subtractConstant, parentPrecision), element::f32).get()); parent->set_friendly_name(dequantization.subtract->get_friendly_name() + "_" + std::to_string(i + 1)); } else { + // Subtract constant could be changed (including a shape) before propagation in some cases + // so it's necessary to compute the shape for a subtractConvert before creating a new subtract + dequantization.subtractConvert->validate_and_infer_types(); parent = std::make_shared(parent, dequantization.subtractConvert); } ngraph::copy_runtime_info(dequantization.subtract, parent); diff --git a/src/common/low_precision_transformations/src/shuffle_channels.cpp b/src/common/low_precision_transformations/src/shuffle_channels.cpp index 7383717c298..4d35975ba55 100644 --- a/src/common/low_precision_transformations/src/shuffle_channels.cpp +++ b/src/common/low_precision_transformations/src/shuffle_channels.cpp @@ -42,7 +42,7 @@ bool ShuffleChannelsTransformation::transform(TransformationContext& context, ng auto dequantization = NetworkHelper::getDequantization(shuffleChannels, defaultPrecisions); const auto shuffleDequantizationConstant = [&](const std::shared_ptr& eltwise) { - const auto normalizedConst = NetworkHelper::normalizeDequantizationShape(eltwise); + const auto normalizedConst = NetworkHelper::normalizeDequantizationShape(eltwise, true); const auto constShape = normalizedConst->get_shape(); if (shape_size(constShape) == 1ul) { diff --git a/src/common/low_precision_transformations/src/split.cpp b/src/common/low_precision_transformations/src/split.cpp index 8587cb96f0c..086afc23218 100644 --- a/src/common/low_precision_transformations/src/split.cpp +++ b/src/common/low_precision_transformations/src/split.cpp @@ -90,7 +90,7 @@ bool SplitTransformation::transform(TransformationContext& context, ngraph::patt } if (dequantization.subtract) { - const auto subtract = std::make_shared(parent, splitedSub[i]); + const auto subtract = NetworkHelper::makeDequantizationSubtract(parent, splitedSub[i]); copy_runtime_info({ newSplit, subtract }, subtract); parent = subtract; } diff --git a/src/common/low_precision_transformations/src/transpose.cpp b/src/common/low_precision_transformations/src/transpose.cpp index 71a1609da00..ec95fdd73ae 100644 --- a/src/common/low_precision_transformations/src/transpose.cpp +++ b/src/common/low_precision_transformations/src/transpose.cpp @@ -125,8 +125,9 @@ bool TransposeTransformation::canBeTransformed(const TransformationContext& cont return true; }(); - const auto values = constant->cast_vector(); + // TODO: remove legacy limitation if (!isPerTensor) { + const auto values = constant->cast_vector(); if ((values.size() < 2ul) || (values[0] != 0) || (values[1] != 1)) { return false; } diff --git a/src/tests/functional/inference_engine/lp_transformations/pad_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/pad_transformation.cpp index a465218dff8..0e3f031b4e4 100644 --- a/src/tests/functional/inference_engine/lp_transformations/pad_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/pad_transformation.cpp @@ -315,6 +315,28 @@ const std::vector deqWithSub = { ngraph::element::u8, {{ngraph::element::f32}, {{128.f, 64.f, 32.f}}, {{3.f, 1.f, 2.f}}} } + }, + // int8 subtraction with Convert from u8 to fp32 + { + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + { + {ngraph::element::f32}, + {{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true}, + {3.f} + } + }, + { + ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + { + {ngraph::element::f32}, + {{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true}, + {3.f} + } + } } }; diff --git a/src/tests/functional/inference_engine/lp_transformations/reshape_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/reshape_transformation.cpp index a16620cd319..57b134b7d8d 100644 --- a/src/tests/functional/inference_engine/lp_transformations/reshape_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/reshape_transformation.cpp @@ -250,6 +250,31 @@ const std::vector testValues = { } } }, + // U8: no subtract 3D -> 4D: channels are not affected: + // per-channel subtraction with Convert from u8 to fp32 and identical values + { + { 1, 4, 10, 10 }, + { 1, 2, 2, 10, 10}, + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + { + {ngraph::element::f32}, + {{128.f}, element::undefined, {1, 4, 1, 1}, false, 1ul, element::u8, true}, + {3.f} + } + }, + { + ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + { + {ngraph::element::f32}, + {{128.f}, element::undefined, {}, false, 1ul, element::u8, true}, + {3.f} + } + } + }, // U8: with subtract 3D -> 4D: channels are not affected, dynamic batch { { -1, 3, 20 }, diff --git a/src/tests/functional/inference_engine/lp_transformations/shuffle_channels_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/shuffle_channels_transformation.cpp index 849253aa10f..f51d4e86974 100644 --- a/src/tests/functional/inference_engine/lp_transformations/shuffle_channels_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/shuffle_channels_transformation.cpp @@ -139,6 +139,30 @@ const std::vector testValues = { {{ngraph::element::f32}, {{128.f, 64.f, 32.f}}, {{0.01f, 0.02f, 0.03f}}} } }, + // subtraction with Convert from u8 to fp32 + { + LayerTransformation::createParamsU8I8(), + 1, + 1, + { + ngraph::element::u8, + { + {ngraph::element::f32}, + {{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true}, + {3.f} + } + }, + { + ngraph::element::u8, + {}, + ngraph::element::u8, + { + {ngraph::element::f32}, + {{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true}, + {3.f} + } + } + }, // U8 quantization by spatial dimension, shuffling by the same dimension { LayerTransformation::createParamsU8I8(), diff --git a/src/tests/functional/inference_engine/lp_transformations/split_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/split_transformation.cpp index b70c6973eed..45c61aa1596 100644 --- a/src/tests/functional/inference_engine/lp_transformations/split_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/split_transformation.cpp @@ -136,6 +136,46 @@ const std::vector testValues = { } } }, + // U8 per tensor quantization / int8 subtraction with Convert from u8 to fp32 + { + { 1, 3, 16, 16 }, std::int64_t{2}, size_t{2}, + LayerTransformation::createParamsU8I8(), + // ActualValues + { + ngraph::element::u8, + {{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::u8, true}, {3.f}} + }, + // ExpectedValues + { + ngraph::element::u8, + {}, + ngraph::element::u8, + { + {{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::u8, true}, {3.f}}, + {{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::u8, true}, {3.f}}, + } + } + }, + // U8 per tensor quantization / int8 subtraction with Convert from fp16 -> fp32 + { + { 1, 3, 16, 16 }, std::int64_t{2}, size_t{2}, + LayerTransformation::createParamsU8I8(), + // ActualValues + { + ngraph::element::u8, + {{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::f16, true}, {3.f}} + }, + // ExpectedValues + { + ngraph::element::u8, + {}, + ngraph::element::u8, + { + {{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::f16, true}, {3.f}}, + {{ngraph::element::f32}, {{128.f}, element::undefined, {}, false, 1ul, element::f16, true}, {3.f}}, + } + } + }, { { -1, -1, -1, -1 }, std::int64_t{2}, size_t{2}, LayerTransformation::createParamsU8I8(), diff --git a/src/tests/functional/inference_engine/lp_transformations/strided_slice_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/strided_slice_transformation.cpp index 4174969d93d..86725acba96 100644 --- a/src/tests/functional/inference_engine/lp_transformations/strided_slice_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/strided_slice_transformation.cpp @@ -228,6 +228,29 @@ const std::vector stridedSliceTransformati {{ngraph::element::f32}, { 128.f }, { 0.1f }} } }, + // U8: channel slice, per-channel quantization with the same values, subtraction with Convert from u8 to fp32 + { + LayerTransformation::createParamsU8I8(), + channelSlice, + { + ngraph::element::u8, + { + {ngraph::element::f32}, + {{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true}, + {3.f} + } + }, + { + ngraph::element::u8, + {}, + ngraph::element::u8, + { + {ngraph::element::f32}, + {{128.f}, element::undefined, {}, false, 1ul, element::u8, true}, + {3.f} + } + } + }, // U8: channel slice, per-channel quantization with different values { LayerTransformation::createParamsU8I8(), diff --git a/src/tests/functional/inference_engine/lp_transformations/transpose_transformation.cpp b/src/tests/functional/inference_engine/lp_transformations/transpose_transformation.cpp index 86bc20b4efe..0e960f49a21 100644 --- a/src/tests/functional/inference_engine/lp_transformations/transpose_transformation.cpp +++ b/src/tests/functional/inference_engine/lp_transformations/transpose_transformation.cpp @@ -168,6 +168,68 @@ const std::vector testValues = { } } }, + // U8: per-channel quantization with the same values, + // subtraction with Convert from u8 to fp32, transpose channel dimension + { + { 0, 3, 1, 2 }, + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + { + { ngraph::element::f32 }, + {{128.f}, element::undefined, {1, 3, 1, 1}, false, 1ul, element::u8, true}, + {{0.1}, ngraph::element::f32, { 1, 3, 1, 1 }} + } + }, + { + ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + { + { ngraph::element::f32 }, + {{128.f}, element::undefined, {1, 1, 3, 1}, false, 1ul, element::u8, true}, + {{0.1}, ngraph::element::f32, {1, 1, 3, 1}} + } + } + }, + // U8: per-tensor quantization, transpose channel dimension + { + { 0, 3, 1, 2 }, + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + {{ngraph::element::f32}, {128}, {0.1f}} + }, + { + ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + {{ngraph::element::f32}, {128}, {0.1f}} + } + }, + // U8: per-channel quantization, transpose channel dimension + { + { 0, 2, 1, 3 }, + LayerTransformation::createParamsU8I8(), + { + ngraph::element::u8, + { + { ngraph::element::f32 }, + {{ 128, 64, 32 }, ngraph::element::f32, { 1, 3, 1, 1 }}, + {{ 0.3f, 0.2f, 0.1f }, ngraph::element::f32, { 1, 3, 1, 1 }} + } + }, + { + ngraph::element::u8, + { + { ngraph::element::f32 }, + {{ 128, 64, 32 }, ngraph::element::f32, { 1, 3, 1, 1 }}, + {{ 0.3f, 0.2f, 0.1f }, ngraph::element::f32, { 1, 3, 1, 1 }} + }, + ngraph::element::f32, + {{}, {}, {}}, + } + }, // empty { { 0, 1, 3, 2 }, diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/common/builders.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/common/builders.cpp index b022bc73ba4..fb22618ff8a 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/common/builders.cpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/common/builders.cpp @@ -60,7 +60,9 @@ std::shared_ptr makeDequantization( if (dequantizationOperations.subtract.addConvert) { std::shared_ptr subtractConstConvert = std::make_shared( subtractConst, - dequantizationOperations.subtract.outPrecision); + dequantizationOperations.subtract.outPrecision == element::undefined ? + parent.get_element_type() : + dequantizationOperations.subtract.outPrecision); auto& rt = subtractConstConvert->get_rt_info(); for (const auto& attribute : dequantizationOperations.subtract.convertAttributes) {