From ae3b19d034075eae43ec28302d2d4827a1a5608b Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 3 Oct 2023 15:31:33 +0100 Subject: [PATCH] [LPT] [NPU] Multiply support (#19859) * [LPT] [NPU] Multiply support * [LPT] [NPU] Multiply support documentation * 1) FakeQuantize support 2) refactoring * [LPT] DisableCleanup attribute + cleanup transformations extension * [LPT] DisableCleanup usage * [LPT] Tests infrastructure support * [LPT] infrastructure quick fix * [LPT] Recurrent Cell Transformation fix * refactoring & comment fixes --- .../low_precision_transformations.md | 2 +- .../step3_main.md | 6 +- .../step3_main/arithmetic/multiply_partial.md | 3 + .../low_precision/cleanup_transformation.hpp | 30 + .../common/precisions_restriction.hpp | 2 +- .../low_precision/eliminate_fake_quantize.hpp | 4 +- .../include/low_precision/fold_convert.hpp | 4 +- .../include/low_precision/fuse_convert.hpp | 6 +- .../fuse_elementwise_to_fake_quantize.hpp | 29 + .../fuse_multiply_to_fake_quantize.hpp | 5 +- .../fuse_subtract_to_fake_quantize.hpp | 5 +- .../low_precision/layer_transformation.hpp | 2 +- .../include/low_precision/low_precision.hpp | 15 +- .../include/low_precision/multiply.hpp | 8 +- .../low_precision/multiply_partial.hpp | 32 + .../multiply_to_group_convolution.hpp | 4 +- .../rt_info/disable_cleanup_attribute.hpp | 27 + .../rt_info/skip_cleanup_attribute.hpp | 17 - .../weightable_layer_transformation.hpp | 27 +- .../src/cleanup_transformation.cpp | 26 + .../src/convolution.cpp | 6 + .../src/convolution_backprop_data.cpp | 6 + .../src/eliminate_fake_quantize.cpp | 6 +- .../src/fake_quantize.cpp | 5 + .../src/fold_convert.cpp | 7 +- .../src/fuse_convert.cpp | 7 +- .../src/fuse_elementwise_to_fake_quantize.cpp | 52 + .../src/fuse_multiply_to_fake_quantize.cpp | 37 +- .../src/fuse_subtract_to_fake_quantize.cpp | 48 +- .../src/layer_transformation.cpp | 6 +- .../src/low_precision.cpp | 8 +- .../src/multiply.cpp | 192 +-- .../src/multiply_partial.cpp | 174 ++ 
.../src/multiply_to_group_convolution.cpp | 6 +- .../src/recurrent_cell.cpp | 12 +- .../src/rt_info/skip_cleanup_attribute.cpp | 20 - .../src/weightable_layer_transformation.cpp | 56 +- .../lpt_avoid_shapeof_propagation_test.cpp | 4 +- .../tests/multiply_partial_transformation.cpp | 1007 +++++++++++ .../tests/multiply_transformation.cpp | 1493 ++++++++--------- .../simple_low_precision_transformer.cpp | 27 +- .../simple_low_precision_transformer.hpp | 3 +- .../multiply_transformation.cpp | 4 +- .../multiply_function.hpp | 43 +- .../multiply_partial_function.hpp | 60 + .../src/multiply_function.cpp | 123 +- .../src/multiply_partial_function.cpp | 154 ++ 47 files changed, 2570 insertions(+), 1250 deletions(-) create mode 100644 docs/articles_en/documentation/openvino_extensibility/openvino_plugin_library/detailed_guides/low_precision_transformations/step3_main/arithmetic/multiply_partial.md create mode 100644 src/common/low_precision_transformations/include/low_precision/cleanup_transformation.hpp create mode 100644 src/common/low_precision_transformations/include/low_precision/fuse_elementwise_to_fake_quantize.hpp create mode 100644 src/common/low_precision_transformations/include/low_precision/multiply_partial.hpp create mode 100644 src/common/low_precision_transformations/include/low_precision/rt_info/disable_cleanup_attribute.hpp delete mode 100644 src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp create mode 100644 src/common/low_precision_transformations/src/cleanup_transformation.cpp create mode 100644 src/common/low_precision_transformations/src/fuse_elementwise_to_fake_quantize.cpp create mode 100644 src/common/low_precision_transformations/src/multiply_partial.cpp delete mode 100644 src/common/low_precision_transformations/src/rt_info/skip_cleanup_attribute.cpp create mode 100644 src/common/low_precision_transformations/tests/multiply_partial_transformation.cpp create mode 100644 
src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/multiply_partial_function.hpp create mode 100644 src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_partial_function.cpp diff --git a/docs/articles_en/documentation/openvino_extensibility/openvino_plugin_library/detailed_guides/low_precision_transformations.md b/docs/articles_en/documentation/openvino_extensibility/openvino_plugin_library/detailed_guides/low_precision_transformations.md index 4b8dfd6fd5e..af9ffffed4c 100644 --- a/docs/articles_en/documentation/openvino_extensibility/openvino_plugin_library/detailed_guides/low_precision_transformations.md +++ b/docs/articles_en/documentation/openvino_extensibility/openvino_plugin_library/detailed_guides/low_precision_transformations.md @@ -200,7 +200,7 @@ Transformations: * :doc:`GatherTransformation ` * :doc:`MatMulTransformation ` * :doc:`MaxPoolTransformation ` -* :doc:`MultiplyTransformation ` +* :doc:`MultiplyPartialTransformation ` * :doc:`MVNTransformation ` * :doc:`NormalizeL2Transformation ` * :doc:`PReluTransformation ` diff --git a/docs/articles_en/documentation/openvino_extensibility/openvino_plugin_library/detailed_guides/low_precision_transformations/step3_main.md b/docs/articles_en/documentation/openvino_extensibility/openvino_plugin_library/detailed_guides/low_precision_transformations/step3_main.md index 162ba3ebfce..8bc0c5a0a50 100644 --- a/docs/articles_en/documentation/openvino_extensibility/openvino_plugin_library/detailed_guides/low_precision_transformations/step3_main.md +++ b/docs/articles_en/documentation/openvino_extensibility/openvino_plugin_library/detailed_guides/low_precision_transformations/step3_main.md @@ -26,7 +26,7 @@ GatherTransformation MatMulTransformation MaxPoolTransformation - MultiplyTransformation + MultiplyPartialTransformation MVNTransformation NormalizeL2Transformation PadTransformation @@ -45,7 +45,7 @@ TransposeTransformation UnsqueezeTransformation VariadicSplitTransformation - + Main 
transformations are the majority of low precision transformations. Transformations operate with dequantization operations. Main transformations include: @@ -64,7 +64,7 @@ Main transformations are the majority of low precision transformations. Transfor * :doc:`GatherTransformation ` * :doc:`MatMulTransformation ` * :doc:`MaxPoolTransformation ` -* :doc:`MultiplyTransformation ` +* :doc:`MultiplyPartialTransformation ` * :doc:`MVNTransformation ` * :doc:`NormalizeL2Transformation ` * :doc:`PadTransformation` diff --git a/docs/articles_en/documentation/openvino_extensibility/openvino_plugin_library/detailed_guides/low_precision_transformations/step3_main/arithmetic/multiply_partial.md b/docs/articles_en/documentation/openvino_extensibility/openvino_plugin_library/detailed_guides/low_precision_transformations/step3_main/arithmetic/multiply_partial.md new file mode 100644 index 00000000000..1d4b348100f --- /dev/null +++ b/docs/articles_en/documentation/openvino_extensibility/openvino_plugin_library/detailed_guides/low_precision_transformations/step3_main/arithmetic/multiply_partial.md @@ -0,0 +1,3 @@ +# MultiplyPartialTransformation transformation {#openvino_docs_OV_UG_lpt_MultiplyPartialTransformation} + +ov::pass::low_precision::MultiplyPartialTransformation class represents the `Multiply` operation transformation.
\ No newline at end of file diff --git a/src/common/low_precision_transformations/include/low_precision/cleanup_transformation.hpp b/src/common/low_precision_transformations/include/low_precision/cleanup_transformation.hpp new file mode 100644 index 00000000000..80e045e386b --- /dev/null +++ b/src/common/low_precision_transformations/include/low_precision/cleanup_transformation.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "layer_transformation.hpp" + +namespace ov { +namespace pass { +namespace low_precision { + +/** + * @ingroup ie_transformation_common_api + * @brief Base class for cleanup low precision transformation. + */ +class LP_TRANSFORMATIONS_API CleanupTransformation : public LayerTransformation { +public: + CleanupTransformation(const Params& params); + virtual ~CleanupTransformation() = default; + + bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; + static bool canBeTransformedStatic( + const std::shared_ptr& layer, + const std::vector& defaultPrecisions = precision_set::get_int8_support()); +}; + +} // namespace low_precision +} // namespace pass +} // namespace ov diff --git a/src/common/low_precision_transformations/include/low_precision/common/precisions_restriction.hpp b/src/common/low_precision_transformations/include/low_precision/common/precisions_restriction.hpp index 31d820d1eb4..7301d13b27b 100644 --- a/src/common/low_precision_transformations/include/low_precision/common/precisions_restriction.hpp +++ b/src/common/low_precision_transformations/include/low_precision/common/precisions_restriction.hpp @@ -74,7 +74,7 @@ public: } template - static PrecisionsByPorts getPrecisionsByOperationType(std::vector& restrictions) { + static PrecisionsByPorts getPrecisionsByOperationType(const std::vector& restrictions) { for (const auto& restriction : restrictions) { if (restriction.operationType == 
T::get_type_info_static()) { return restriction.precisionsByPorts; diff --git a/src/common/low_precision_transformations/include/low_precision/eliminate_fake_quantize.hpp b/src/common/low_precision_transformations/include/low_precision/eliminate_fake_quantize.hpp index 2741d6b15cc..9b3d1f9e0fe 100644 --- a/src/common/low_precision_transformations/include/low_precision/eliminate_fake_quantize.hpp +++ b/src/common/low_precision_transformations/include/low_precision/eliminate_fake_quantize.hpp @@ -6,7 +6,7 @@ #include -#include "low_precision/layer_transformation.hpp" +#include "low_precision/cleanup_transformation.hpp" namespace ov { namespace pass { @@ -20,7 +20,7 @@ namespace low_precision { * [EliminateFakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_EliminateFakeQuantizeTransformation) page * in the Inference Engine Developer Guide. */ -class LP_TRANSFORMATIONS_API EliminateFakeQuantizeTransformation : public LayerTransformation { +class LP_TRANSFORMATIONS_API EliminateFakeQuantizeTransformation : public CleanupTransformation { public: OPENVINO_RTTI("EliminateFakeQuantizeTransformation", "0"); EliminateFakeQuantizeTransformation(const Params& params = Params()); diff --git a/src/common/low_precision_transformations/include/low_precision/fold_convert.hpp b/src/common/low_precision_transformations/include/low_precision/fold_convert.hpp index e5fcfd639f7..640cdda59e6 100644 --- a/src/common/low_precision_transformations/include/low_precision/fold_convert.hpp +++ b/src/common/low_precision_transformations/include/low_precision/fold_convert.hpp @@ -6,7 +6,7 @@ #include -#include "low_precision/layer_transformation.hpp" +#include "low_precision/cleanup_transformation.hpp" namespace ov { namespace pass { @@ -20,7 +20,7 @@ namespace low_precision { * [FoldConvertTransformation](@ref openvino_docs_OV_UG_lpt_FoldConvertTransformation) page * in the Inference Engine Developer Guide. 
*/ -class LP_TRANSFORMATIONS_API FoldConvertTransformation : public LayerTransformation { +class LP_TRANSFORMATIONS_API FoldConvertTransformation : public CleanupTransformation { public: OPENVINO_RTTI("FoldConvertTransformation", "0"); FoldConvertTransformation(const Params& params = Params()); diff --git a/src/common/low_precision_transformations/include/low_precision/fuse_convert.hpp b/src/common/low_precision_transformations/include/low_precision/fuse_convert.hpp index 76e5a8e4195..09c4692198d 100644 --- a/src/common/low_precision_transformations/include/low_precision/fuse_convert.hpp +++ b/src/common/low_precision_transformations/include/low_precision/fuse_convert.hpp @@ -4,9 +4,7 @@ #pragma once - -#include "low_precision/layer_transformation.hpp" -#include "low_precision/eltwise_base_transformation.hpp" +#include "low_precision/cleanup_transformation.hpp" namespace ov { namespace pass { @@ -20,7 +18,7 @@ namespace low_precision { * [FuseConvertTransformation](@ref openvino_docs_OV_UG_lpt_FuseConvertTransformation) page * in the Inference Engine Developer Guide. 
*/ -class LP_TRANSFORMATIONS_API FuseConvertTransformation : public LayerTransformation { +class LP_TRANSFORMATIONS_API FuseConvertTransformation : public CleanupTransformation { public: OPENVINO_RTTI("FuseConvertTransformation", "0"); FuseConvertTransformation(const Params& params = Params()); diff --git a/src/common/low_precision_transformations/include/low_precision/fuse_elementwise_to_fake_quantize.hpp b/src/common/low_precision_transformations/include/low_precision/fuse_elementwise_to_fake_quantize.hpp new file mode 100644 index 00000000000..d615d0f13bb --- /dev/null +++ b/src/common/low_precision_transformations/include/low_precision/fuse_elementwise_to_fake_quantize.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "low_precision/cleanup_transformation.hpp" + +namespace ov { +namespace pass { +namespace low_precision { + +/** + * @ingroup ie_transformation_common_api + * @brief Base class for fuse elementwise to FakeQuantize low precision transformation. 
+ */ +class LP_TRANSFORMATIONS_API FuseElementwiseToFakeQuantizeTransformation : public CleanupTransformation { +public: + FuseElementwiseToFakeQuantizeTransformation(const Params& params); + virtual ~FuseElementwiseToFakeQuantizeTransformation() = default; + + bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; +}; + +} // namespace low_precision +} // namespace pass +} // namespace ov diff --git a/src/common/low_precision_transformations/include/low_precision/fuse_multiply_to_fake_quantize.hpp b/src/common/low_precision_transformations/include/low_precision/fuse_multiply_to_fake_quantize.hpp index 34259bb87c9..af0e152db0b 100644 --- a/src/common/low_precision_transformations/include/low_precision/fuse_multiply_to_fake_quantize.hpp +++ b/src/common/low_precision_transformations/include/low_precision/fuse_multiply_to_fake_quantize.hpp @@ -6,7 +6,7 @@ #include -#include "low_precision/layer_transformation.hpp" +#include "low_precision/fuse_elementwise_to_fake_quantize.hpp" namespace ov { namespace pass { @@ -20,12 +20,11 @@ namespace low_precision { * [FuseMultiplyToFakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_FuseMultiplyToFakeQuantizeTransformation) page * in the Inference Engine Developer Guide. 
*/ -class LP_TRANSFORMATIONS_API FuseMultiplyToFakeQuantizeTransformation : public LayerTransformation { +class LP_TRANSFORMATIONS_API FuseMultiplyToFakeQuantizeTransformation : public FuseElementwiseToFakeQuantizeTransformation { public: OPENVINO_RTTI("FuseMultiplyToFakeQuantizeTransformation", "0"); FuseMultiplyToFakeQuantizeTransformation(const Params& params = Params()); bool transform(TransformationContext& context, ov::pass::pattern::Matcher &m) override; - bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; bool isPrecisionPreserved(std::shared_ptr layer) const noexcept override; }; diff --git a/src/common/low_precision_transformations/include/low_precision/fuse_subtract_to_fake_quantize.hpp b/src/common/low_precision_transformations/include/low_precision/fuse_subtract_to_fake_quantize.hpp index 98527defe13..4b06f6cc7de 100644 --- a/src/common/low_precision_transformations/include/low_precision/fuse_subtract_to_fake_quantize.hpp +++ b/src/common/low_precision_transformations/include/low_precision/fuse_subtract_to_fake_quantize.hpp @@ -6,7 +6,7 @@ #include -#include "low_precision/layer_transformation.hpp" +#include "low_precision/fuse_elementwise_to_fake_quantize.hpp" namespace ov { namespace pass { @@ -20,12 +20,11 @@ namespace low_precision { * [FuseSubtractToFakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_FuseSubtractToFakeQuantizeTransformation) page * in the Inference Engine Developer Guide. 
*/ -class LP_TRANSFORMATIONS_API FuseSubtractToFakeQuantizeTransformation : public LayerTransformation { +class LP_TRANSFORMATIONS_API FuseSubtractToFakeQuantizeTransformation : public FuseElementwiseToFakeQuantizeTransformation { public: OPENVINO_RTTI("FuseSubtractToFakeQuantizeTransformation", "0"); FuseSubtractToFakeQuantizeTransformation(const Params& params = Params()); bool transform(TransformationContext& context, ov::pass::pattern::Matcher &m) override; - bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; bool isPrecisionPreserved(std::shared_ptr layer) const noexcept override; }; diff --git a/src/common/low_precision_transformations/include/low_precision/layer_transformation.hpp b/src/common/low_precision_transformations/include/low_precision/layer_transformation.hpp index 80096cdf7c1..e68f395049c 100644 --- a/src/common/low_precision_transformations/include/low_precision/layer_transformation.hpp +++ b/src/common/low_precision_transformations/include/low_precision/layer_transformation.hpp @@ -371,7 +371,7 @@ protected: const bool updatePrecision, const bool moveSubtract = true) const; - void updateOutput( + bool updateOutput( TransformationContext &context, std::shared_ptr lastNode, std::shared_ptr originalNode) const; diff --git a/src/common/low_precision_transformations/include/low_precision/low_precision.hpp b/src/common/low_precision_transformations/include/low_precision/low_precision.hpp index 9236113c731..f40c92605d6 100644 --- a/src/common/low_precision_transformations/include/low_precision/low_precision.hpp +++ b/src/common/low_precision_transformations/include/low_precision/low_precision.hpp @@ -48,9 +48,9 @@ public: const AttributeParameters& params); bool run_on_model(const std::shared_ptr& m) override; private: - const std::vector& precisionRestrictions; - const std::vector& quantizationRestrictions; - const AttributeParameters& params; + const std::vector precisionRestrictions; + const 
std::vector quantizationRestrictions; + const AttributeParameters params; }; class ov::pass::low_precision::TypeRelaxedReplacer : public ov::pass::GraphRewrite { @@ -71,9 +71,18 @@ public: static bool isFunctionQuantized(const std::shared_ptr& model); static bool isFQLevelsPresent(const std::shared_ptr& model, const std::set& levels); + template + std::shared_ptr add_main(Args&&... args) { + const auto tr = std::make_shared(std::forward(args)...); + additional_main_passes.push_back(tr); + return tr; + } + protected: std::vector precisionRestrictions; std::vector quantizationRestrictions; // remove LayerTransformation::Params params; + + std::vector> additional_main_passes; }; diff --git a/src/common/low_precision_transformations/include/low_precision/multiply.hpp b/src/common/low_precision_transformations/include/low_precision/multiply.hpp index 3dc4a26d056..55484b041d6 100644 --- a/src/common/low_precision_transformations/include/low_precision/multiply.hpp +++ b/src/common/low_precision_transformations/include/low_precision/multiply.hpp @@ -5,7 +5,7 @@ #pragma once -#include "low_precision/eltwise_base_transformation.hpp" +#include "low_precision/weightable_layer_transformation.hpp" namespace ov { namespace pass { @@ -19,12 +19,14 @@ namespace low_precision { * [MultiplyTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyTransformation) page * in the Inference Engine Developer Guide. 
*/ -class LP_TRANSFORMATIONS_API MultiplyTransformation : public EltwiseBaseTransformation { +class LP_TRANSFORMATIONS_API MultiplyTransformation : public WeightableLayerTransformation { public: OPENVINO_RTTI("MultiplyTransformation", "0"); MultiplyTransformation(const Params& params = Params()); bool transform(TransformationContext& context, ov::pass::pattern::Matcher &m) override; - bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; + +protected: + size_t getInputChannels(const std::shared_ptr op) const override; }; } // namespace low_precision diff --git a/src/common/low_precision_transformations/include/low_precision/multiply_partial.hpp b/src/common/low_precision_transformations/include/low_precision/multiply_partial.hpp new file mode 100644 index 00000000000..c3db52ce5d9 --- /dev/null +++ b/src/common/low_precision_transformations/include/low_precision/multiply_partial.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "low_precision/eltwise_base_transformation.hpp" + +namespace ov { +namespace pass { +namespace low_precision { + +/** + * @ingroup ie_transformation_common_api + * @brief MultiplyPartialTransformation propagates dequantization operations through Multiply operation. + * + * For more details about the transformation, refer to + * [MultiplyPartialTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyPartialTransformation) page + * in the Inference Engine Developer Guide. 
+ */ +class LP_TRANSFORMATIONS_API MultiplyPartialTransformation : public EltwiseBaseTransformation { +public: + OPENVINO_RTTI("MultiplyPartialTransformation", "0"); + MultiplyPartialTransformation(const Params& params = Params()); + bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override; + bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; +}; + +} // namespace low_precision +} // namespace pass +} // namespace ov diff --git a/src/common/low_precision_transformations/include/low_precision/multiply_to_group_convolution.hpp b/src/common/low_precision_transformations/include/low_precision/multiply_to_group_convolution.hpp index b107b7de041..d76f2ef1088 100644 --- a/src/common/low_precision_transformations/include/low_precision/multiply_to_group_convolution.hpp +++ b/src/common/low_precision_transformations/include/low_precision/multiply_to_group_convolution.hpp @@ -5,7 +5,7 @@ #pragma once #include -#include "low_precision/layer_transformation.hpp" +#include "low_precision/cleanup_transformation.hpp" #include "common/precisions_restriction.hpp" namespace ov { @@ -20,7 +20,7 @@ namespace low_precision { * [MultiplyToGroupConvolutionTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyToGroupConvolutionTransformation) page * in the Inference Engine Developer Guide. 
*/ -class LP_TRANSFORMATIONS_API MultiplyToGroupConvolutionTransformation : public LayerTransformation { +class LP_TRANSFORMATIONS_API MultiplyToGroupConvolutionTransformation : public CleanupTransformation { public: OPENVINO_RTTI("MultiplyToGroupConvolutionTransformation", "0"); MultiplyToGroupConvolutionTransformation( diff --git a/src/common/low_precision_transformations/include/low_precision/rt_info/disable_cleanup_attribute.hpp b/src/common/low_precision_transformations/include/low_precision/rt_info/disable_cleanup_attribute.hpp new file mode 100644 index 00000000000..71df996fe15 --- /dev/null +++ b/src/common/low_precision_transformations/include/low_precision/rt_info/disable_cleanup_attribute.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/core/runtime_attribute.hpp" +#include "low_precision/lpt_visibility.hpp" + +namespace ov { + +class LP_TRANSFORMATIONS_API DisableCleanupAttribute : public ov::RuntimeAttribute { +public: + OPENVINO_RTTI("LowPrecision::DisableCleanup", "", ov::RuntimeAttribute); + DisableCleanupAttribute() = default; + + static ov::Any create(const std::shared_ptr& node) { + auto& rt = node->get_rt_info(); + return (rt[DisableCleanupAttribute::get_type_info_static()] = DisableCleanupAttribute()); + } + + bool is_copyable() const override { + return false; + } +}; +} // namespace ov diff --git a/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp b/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp deleted file mode 100644 index 39e0bf46e3a..00000000000 --- a/src/common/low_precision_transformations/include/low_precision/rt_info/skip_cleanup_attribute.hpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include 
"openvino/core/node.hpp" - -#include "low_precision/rt_info/attribute_parameters.hpp" - -namespace ov { -class LP_TRANSFORMATIONS_API SkipCleanupAttribute : public ov::RuntimeAttribute { -public: - OPENVINO_RTTI("LowPrecision::SkipCleanup", "", ov::RuntimeAttribute); - static ov::Any create(const std::shared_ptr& node); -}; -} // namespace ov diff --git a/src/common/low_precision_transformations/include/low_precision/weightable_layer_transformation.hpp b/src/common/low_precision_transformations/include/low_precision/weightable_layer_transformation.hpp index 8abad779628..4655a940120 100644 --- a/src/common/low_precision_transformations/include/low_precision/weightable_layer_transformation.hpp +++ b/src/common/low_precision_transformations/include/low_precision/weightable_layer_transformation.hpp @@ -19,7 +19,29 @@ namespace low_precision { */ class LP_TRANSFORMATIONS_API WeightableLayerTransformation : public LayerTransformation { public: - WeightableLayerTransformation(const Params& params); + struct LP_TRANSFORMATIONS_API CanBeTransformedParams { + CanBeTransformedParams( + const bool constantWeight = true, + const bool perTensorQuantizationOnData = true, + const bool limitWeightsDataPrecision = true, + const bool dynamicWeights = false) : + constantWeight(constantWeight), + perTensorQuantizationOnData(perTensorQuantizationOnData), + limitWeightsDataPrecision(limitWeightsDataPrecision), + dynamicWeights(dynamicWeights) { + } + + // weights on constant path only + const bool constantWeight; + // data with per-tensor quantization only + const bool perTensorQuantizationOnData; + // limit weights by expected precisions + const bool limitWeightsDataPrecision; + const bool dynamicWeights; + }; + + WeightableLayerTransformation(const Params& params, const CanBeTransformedParams& canBeTransformedParams = {}); + bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; bool canConvolutionBeTransformed(const TransformationContext& 
context, std::shared_ptr layer, const std::vector& defaultPrecisions) const; @@ -48,6 +70,9 @@ public: static DataPrecision getDataPrecisionOnWeights(const std::shared_ptr& node, const std::vector& defaultPrecisions); static bool isAsymmetricOnWeights(const std::shared_ptr& node, const std::vector& defaultPrecisions = precision_set::get_int8_support()); + +private: + const CanBeTransformedParams canBeTransformedParams; }; } // namespace low_precision diff --git a/src/common/low_precision_transformations/src/cleanup_transformation.cpp b/src/common/low_precision_transformations/src/cleanup_transformation.cpp new file mode 100644 index 00000000000..3a7cb0da5d5 --- /dev/null +++ b/src/common/low_precision_transformations/src/cleanup_transformation.cpp @@ -0,0 +1,26 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "low_precision/cleanup_transformation.hpp" +#include "low_precision/network_helper.hpp" +#include "low_precision/rt_info/disable_cleanup_attribute.hpp" + +namespace ov { +namespace pass { +namespace low_precision { + +CleanupTransformation::CleanupTransformation(const Params& params) : LayerTransformation(params) { +} + +bool CleanupTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const { + return canBeTransformedStatic(layer); +} + +bool CleanupTransformation::canBeTransformedStatic(const std::shared_ptr& layer, const std::vector& defaultPrecisions) { + return getAttribute(layer).empty(); +} + +} // namespace low_precision +} // namespace pass +} // namespace ov diff --git a/src/common/low_precision_transformations/src/convolution.cpp b/src/common/low_precision_transformations/src/convolution.cpp index e6044b6d9ed..2c80e75e156 100644 --- a/src/common/low_precision_transformations/src/convolution.cpp +++ b/src/common/low_precision_transformations/src/convolution.cpp @@ -13,6 +13,7 @@ #include "openvino/pass/pattern/op/wrap_type.hpp" #include 
"openvino/pass/pattern/op/or.hpp" #include "low_precision/network_helper.hpp" +#include "low_precision/rt_info/disable_cleanup_attribute.hpp" #include "transformations/rt_info/disable_constant_folding.hpp" #include "itt.hpp" @@ -333,6 +334,11 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ov::pa ov::copy_runtime_info({ convolution, finalDequantization }, finalDequantization); updateOutput(context, finalDequantization, convolution); + const auto onActiviation = convolution->get_input_node_shared_ptr(0); + if (ov::is_type(onActiviation)) { + DisableCleanupAttribute::create(onActiviation); + } + auto onWeights = convolution->get_input_node_shared_ptr(1); if (ov::is_type(onWeights)) { onWeights = onWeights->get_input_node_shared_ptr(0); diff --git a/src/common/low_precision_transformations/src/convolution_backprop_data.cpp b/src/common/low_precision_transformations/src/convolution_backprop_data.cpp index 890bff9d231..3e232b5c840 100644 --- a/src/common/low_precision_transformations/src/convolution_backprop_data.cpp +++ b/src/common/low_precision_transformations/src/convolution_backprop_data.cpp @@ -13,6 +13,7 @@ #include "openvino/pass/pattern/op/wrap_type.hpp" #include "openvino/pass/pattern/op/or.hpp" #include "low_precision/network_helper.hpp" +#include "low_precision/rt_info/disable_cleanup_attribute.hpp" #include "transformations/rt_info/disable_constant_folding.hpp" #include "itt.hpp" @@ -220,6 +221,11 @@ bool ConvolutionBackpropDataTransformation::transform(TransformationContext &con ov::copy_runtime_info({ convolutionBackpropData, finalDequantization }, finalDequantization); updateOutput(context, finalDequantization, convolutionBackpropData); + const auto onActiviation = convolutionBackpropData->get_input_node_shared_ptr(0); + if (ov::is_type(onActiviation)) { + DisableCleanupAttribute::create(onActiviation); + } + auto onWeights = convolutionBackpropData->get_input_node_shared_ptr(1); if (ov::is_type(onWeights)) { onWeights = 
onWeights->get_input_node_shared_ptr(0); diff --git a/src/common/low_precision_transformations/src/eliminate_fake_quantize.cpp b/src/common/low_precision_transformations/src/eliminate_fake_quantize.cpp index bfa83bb0f44..3010ea213d0 100644 --- a/src/common/low_precision_transformations/src/eliminate_fake_quantize.cpp +++ b/src/common/low_precision_transformations/src/eliminate_fake_quantize.cpp @@ -15,7 +15,7 @@ namespace ov { namespace pass { namespace low_precision { -EliminateFakeQuantizeTransformation::EliminateFakeQuantizeTransformation(const Params& params) : LayerTransformation(params) { +EliminateFakeQuantizeTransformation::EliminateFakeQuantizeTransformation(const Params& params) : CleanupTransformation(params) { MATCHER_SCOPE(FuseMultiplyToFakeQuantizeTransformation); const auto matcher = pattern::wrap_type({ pattern::any_input(), @@ -112,6 +112,10 @@ bool check_intervals(const std::shared_ptr& fakeQuanti } // namespace bool EliminateFakeQuantizeTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr operation) const { + if (!CleanupTransformation::canBeTransformed(context, operation)) { + return false; + } + const auto fakeQuantize = ov::as_type_ptr(operation); OPENVINO_ASSERT(fakeQuantize != nullptr, "unexpected operation type"); diff --git a/src/common/low_precision_transformations/src/fake_quantize.cpp b/src/common/low_precision_transformations/src/fake_quantize.cpp index 28e32076181..a60c3bfcd93 100644 --- a/src/common/low_precision_transformations/src/fake_quantize.cpp +++ b/src/common/low_precision_transformations/src/fake_quantize.cpp @@ -11,6 +11,7 @@ #include "low_precision/network_helper.hpp" #include "low_precision/rt_info/bias_attribute.hpp" +#include "low_precision/rt_info/disable_cleanup_attribute.hpp" #include "itt.hpp" namespace ov { @@ -167,6 +168,10 @@ std::shared_ptr FakeQuantizeTransformation::fuseElementwis return nullptr; } + if (!getAttribute(eltwise).empty()) { + return nullptr; + } + std::shared_ptr 
inputLowConst_f32 = foldConvert(fakeQuantize->input_value(1), element::f32); std::shared_ptr inputHighConst_f32 = foldConvert(fakeQuantize->input_value(2), element::f32); diff --git a/src/common/low_precision_transformations/src/fold_convert.cpp b/src/common/low_precision_transformations/src/fold_convert.cpp index 35b3385e2eb..4054b0fad4e 100644 --- a/src/common/low_precision_transformations/src/fold_convert.cpp +++ b/src/common/low_precision_transformations/src/fold_convert.cpp @@ -14,7 +14,7 @@ namespace ov { namespace pass { namespace low_precision { -FoldConvertTransformation::FoldConvertTransformation(const Params& params) : LayerTransformation(params) { +FoldConvertTransformation::FoldConvertTransformation(const Params& params) : CleanupTransformation(params) { MATCHER_SCOPE(FoldConvertTransformation); auto subtract = pattern::wrap_type(); auto matcher = std::make_shared(subtract, matcher_name); @@ -57,10 +57,11 @@ bool FoldConvertTransformation::transform(TransformationContext& context, ov::pa bool FoldConvertTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr operation) const { return - (ov::is_type(operation->get_input_node_ptr(1)) && + CleanupTransformation::canBeTransformed(context, operation) && + ((ov::is_type(operation->get_input_node_ptr(1)) && ov::is_type(operation->get_input_node_ptr(1)->get_input_node_ptr(0))) || (ov::is_type(operation->get_input_node_ptr(0)) && - ov::is_type(operation->get_input_node_ptr(0)->get_input_node_ptr(0))); + ov::is_type(operation->get_input_node_ptr(0)->get_input_node_ptr(0)))); } bool FoldConvertTransformation::isPrecisionPreserved(std::shared_ptr layer) const noexcept { diff --git a/src/common/low_precision_transformations/src/fuse_convert.cpp b/src/common/low_precision_transformations/src/fuse_convert.cpp index 9c17f38074e..372476aeabe 100644 --- a/src/common/low_precision_transformations/src/fuse_convert.cpp +++ b/src/common/low_precision_transformations/src/fuse_convert.cpp @@ 
-12,14 +12,15 @@ #include "low_precision/common/ie_lpt_exception.hpp" #include "low_precision/network_helper.hpp" +#include "low_precision/rt_info/disable_cleanup_attribute.hpp" + #include "itt.hpp" -#include "low_precision/rt_info/skip_cleanup_attribute.hpp" namespace ov { namespace pass { namespace low_precision { -FuseConvertTransformation::FuseConvertTransformation(const Params& params) : LayerTransformation(params) { +FuseConvertTransformation::FuseConvertTransformation(const Params& params) : CleanupTransformation(params) { MATCHER_SCOPE(FuseConvertTransformation); auto multiply = pattern::wrap_type({ pattern::wrap_type(), pattern::wrap_type() }); auto subtract = pattern::wrap_type({ pattern::wrap_type(), pattern::wrap_type() }); @@ -114,7 +115,7 @@ bool FuseConvertTransformation::transform(TransformationContext& context, ov::pa } bool FuseConvertTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr op) const { - if (!getAttribute(op).empty()) { + if (!CleanupTransformation::canBeTransformed(context, op)) { return false; } diff --git a/src/common/low_precision_transformations/src/fuse_elementwise_to_fake_quantize.cpp b/src/common/low_precision_transformations/src/fuse_elementwise_to_fake_quantize.cpp new file mode 100644 index 00000000000..c641824bf53 --- /dev/null +++ b/src/common/low_precision_transformations/src/fuse_elementwise_to_fake_quantize.cpp @@ -0,0 +1,52 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "low_precision/fuse_elementwise_to_fake_quantize.hpp" + +#include +#include "low_precision/fake_quantize.hpp" +#include "low_precision/network_helper.hpp" + +namespace ov { +namespace pass { +namespace low_precision { + +FuseElementwiseToFakeQuantizeTransformation::FuseElementwiseToFakeQuantizeTransformation(const Params& params) : CleanupTransformation(params) { +} + +bool FuseElementwiseToFakeQuantizeTransformation::canBeTransformed(const 
TransformationContext& context, std::shared_ptr operation) const { + if (!CleanupTransformation::canBeTransformed(context, operation)) { + return false; + } + + if (!ov::is_type(operation->get_input_node_shared_ptr(1))) { + return false; + } + + if (!FakeQuantizeTransformation::checkElementwise(operation)) { + return false; + } + + const auto parent = operation->get_input_node_shared_ptr(0); + auto fq = ov::as_type_ptr(parent); + const auto convert = ov::as_type_ptr(parent); + + if (convert) { + fq = ov::as_type_ptr(convert->get_input_node_shared_ptr(0)); + } + + if (!fq) { + return false; + } + + if (fq->get_output_target_inputs(0).size() != 1) { + return false; + } + + return true; +} + +} // namespace low_precision +} // namespace pass +} // namespace ov diff --git a/src/common/low_precision_transformations/src/fuse_multiply_to_fake_quantize.cpp b/src/common/low_precision_transformations/src/fuse_multiply_to_fake_quantize.cpp index fd316f0068d..ccc21649009 100644 --- a/src/common/low_precision_transformations/src/fuse_multiply_to_fake_quantize.cpp +++ b/src/common/low_precision_transformations/src/fuse_multiply_to_fake_quantize.cpp @@ -9,13 +9,14 @@ #include "low_precision/fake_quantize.hpp" #include "low_precision/network_helper.hpp" #include "itt.hpp" -#include "low_precision/rt_info/skip_cleanup_attribute.hpp" +#include "low_precision/rt_info/disable_cleanup_attribute.hpp" namespace ov { namespace pass { namespace low_precision { -FuseMultiplyToFakeQuantizeTransformation::FuseMultiplyToFakeQuantizeTransformation(const Params& params) : LayerTransformation(params) { +FuseMultiplyToFakeQuantizeTransformation::FuseMultiplyToFakeQuantizeTransformation(const Params& params) + : FuseElementwiseToFakeQuantizeTransformation(params) { MATCHER_SCOPE(FuseMultiplyToFakeQuantizeTransformation); auto matcher = pattern::wrap_type(); @@ -89,38 +90,6 @@ bool FuseMultiplyToFakeQuantizeTransformation::transform(TransformationContext& return true; } -bool 
FuseMultiplyToFakeQuantizeTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr operation) const { - if (!ov::is_type(operation->get_input_node_shared_ptr(1))) { - return false; - } - - if (!FakeQuantizeTransformation::checkElementwise(operation)) { - return false; - } - - if (!getAttribute(operation).empty()) { - return false; - } - - const auto parent = operation->get_input_node_shared_ptr(0); - auto fq = ov::as_type_ptr(parent); - const auto convert = ov::as_type_ptr(parent); - - if (convert) { - fq = ov::as_type_ptr(convert->get_input_node_shared_ptr(0)); - } - - if (!fq) { - return false; - } - - if (fq->get_output_target_inputs(0).size() != 1) { - return false; - } - - return true; -} - bool FuseMultiplyToFakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr layer) const noexcept { return false; } diff --git a/src/common/low_precision_transformations/src/fuse_subtract_to_fake_quantize.cpp b/src/common/low_precision_transformations/src/fuse_subtract_to_fake_quantize.cpp index d75fde32ee1..56ed774ba36 100644 --- a/src/common/low_precision_transformations/src/fuse_subtract_to_fake_quantize.cpp +++ b/src/common/low_precision_transformations/src/fuse_subtract_to_fake_quantize.cpp @@ -9,13 +9,14 @@ #include "low_precision/fake_quantize.hpp" #include "low_precision/network_helper.hpp" #include "itt.hpp" -#include "low_precision/rt_info/skip_cleanup_attribute.hpp" +#include "low_precision/rt_info/disable_cleanup_attribute.hpp" namespace ov { namespace pass { namespace low_precision { -FuseSubtractToFakeQuantizeTransformation::FuseSubtractToFakeQuantizeTransformation(const Params& params) : LayerTransformation(params) { +FuseSubtractToFakeQuantizeTransformation::FuseSubtractToFakeQuantizeTransformation(const Params& params) + : FuseElementwiseToFakeQuantizeTransformation(params) { MATCHER_SCOPE(FuseSubtractToFakeQuantizeTransformation); auto matcher = pattern::wrap_type(); @@ -84,49 +85,6 @@ bool 
FuseSubtractToFakeQuantizeTransformation::transform(TransformationContext& return true; } -bool FuseSubtractToFakeQuantizeTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr operation) const { - if (!ov::is_type(operation->get_input_node_shared_ptr(1))) { - return false; - } - - if (!FakeQuantizeTransformation::checkElementwise(operation)) { - return false; - } - - if (!getAttribute(operation).empty()) { - return false; - } - - const auto children = operation->get_output_target_inputs(0); - - for (const auto& target : children) { - const auto convolution = ov::is_type(target.get_node()); - const auto groupConvolution = ov::is_type(target.get_node()); - const auto convolutionBackpropData = ov::is_type(target.get_node()); - if (convolution || groupConvolution || convolutionBackpropData) { - return false; - } - } - - const auto parent = operation->get_input_node_shared_ptr(0); - auto fq = ov::as_type_ptr(parent); - const auto convert = ov::as_type_ptr(parent); - - if (convert) { - fq = ov::as_type_ptr(convert->get_input_node_shared_ptr(0)); - } - - if (!fq) { - return false; - } - - if (fq->get_output_target_inputs(0).size() != 1) { - return false; - } - - return true; -} - bool FuseSubtractToFakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr layer) const noexcept { return false; } diff --git a/src/common/low_precision_transformations/src/layer_transformation.cpp b/src/common/low_precision_transformations/src/layer_transformation.cpp index 60ee21c1b34..86c2ba9e7df 100644 --- a/src/common/low_precision_transformations/src/layer_transformation.cpp +++ b/src/common/low_precision_transformations/src/layer_transformation.cpp @@ -422,21 +422,23 @@ std::shared_ptr LayerTransformation::moveDequantizationBefore( return result.newOperation; } -void LayerTransformation::updateOutput( +bool LayerTransformation::updateOutput( TransformationContext &context, std::shared_ptr lastNode, std::shared_ptr originalNode) const { - // TODO: 
not tested!!! + bool was_updated = false; for (auto output : lastNode->outputs()) { for (auto input : output.get_target_inputs()) { if (ov::is_type(input.get_node())) { const std::string originalName = originalNode->get_friendly_name(); originalNode->set_friendly_name(originalName + LayerTransformation::originalLayerPostfix); lastNode->set_friendly_name(originalName); + was_updated = true; break; } } } + return was_updated; } void LayerTransformation::updateOutput( diff --git a/src/common/low_precision_transformations/src/low_precision.cpp b/src/common/low_precision_transformations/src/low_precision.cpp index 9b84bb15dae..0f46c41c817 100644 --- a/src/common/low_precision_transformations/src/low_precision.cpp +++ b/src/common/low_precision_transformations/src/low_precision.cpp @@ -53,7 +53,7 @@ #include "low_precision/interpolate.hpp" #include "low_precision/mat_mul.hpp" #include "low_precision/max_pool.hpp" -#include "low_precision/multiply.hpp" +#include "low_precision/multiply_partial.hpp" #include "low_precision/mvn.hpp" #include "low_precision/normalize_l2.hpp" #include "low_precision/pad.hpp" @@ -251,7 +251,7 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptradd_matcher(tr); + } + std::shared_ptr cleanup = manager.register_pass(); ADD_MATCHER(cleanup, EliminateFakeQuantizeTransformation, params) ADD_MATCHER(cleanup, FoldConvertTransformation, params) diff --git a/src/common/low_precision_transformations/src/multiply.cpp b/src/common/low_precision_transformations/src/multiply.cpp index 6d336e659e7..cc654d3deff 100644 --- a/src/common/low_precision_transformations/src/multiply.cpp +++ b/src/common/low_precision_transformations/src/multiply.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2023 Intel Corporation +// Copyright (C) 2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -15,6 +15,7 @@ #include "openvino/pass/pattern/op/wrap_type.hpp" #include "low_precision/common/ie_lpt_exception.hpp" +#include 
"low_precision/rt_info/disable_cleanup_attribute.hpp" #include "low_precision/network_helper.hpp" #include "itt.hpp" @@ -22,7 +23,8 @@ namespace ov { namespace pass { namespace low_precision { -MultiplyTransformation::MultiplyTransformation(const Params& params) : EltwiseBaseTransformation(params) { +MultiplyTransformation::MultiplyTransformation(const Params& params) : + WeightableLayerTransformation(params, CanBeTransformedParams(false, false, false, true)) { MATCHER_SCOPE(MultiplyTransformation); auto matcher = pattern::wrap_type(); @@ -38,135 +40,107 @@ MultiplyTransformation::MultiplyTransformation(const Params& params) : EltwiseBa this->register_matcher(m, callback); } -bool MultiplyTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher &m) { +bool MultiplyTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher& m) { auto multiply = m.get_match_root(); if (!canBeTransformed(context, multiply)) { return false; } + multiply = NetworkHelper::separateInStandaloneBranch(multiply, defaultPrecisions); + decomposeFakeQuantizeForWeightsPath(multiply); + NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 0)); NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 1)); - multiply = NetworkHelper::separateInStandaloneBranch(multiply, defaultPrecisions); - auto newMultiply = multiply; + const auto dequantization1 = NetworkHelper::getDequantization(multiply, defaultPrecisions, 0); + const auto dequantization2 = NetworkHelper::getDequantization(multiply, defaultPrecisions, 1); - auto fold_fake_quantizes = [](std::shared_ptr& multiply, const size_t index) { - auto fakeQuantizeOnWeights = ov::as_type_ptr(multiply->get_input_node_shared_ptr(index)); - if (fakeQuantizeOnWeights != nullptr) { - auto result = NetworkHelper::fold_fake_quantize(fakeQuantizeOnWeights); - if (ov::is_type(result)) { - 
replace_node(fakeQuantizeOnWeights, result); - } + if ((dequantization1.multiplyConstant == nullptr) && (dequantization2.multiplyConstant == nullptr)) { + return false; + } + + // before: y = (deq_scales1 * (x1 - zero_point1)) * (deq_scales2 * (x2 - zero_point2)) + // after : y = deq_scales1 * deq_scales2 * (x1 - zero_point1) * (x2 - zero_point2) + + auto new_scales_values = fold( + dequantization1.empty() ? dequantization1.data : dequantization1.multiplyConstant, + dequantization2.empty() ? dequantization2.data : dequantization2.multiplyConstant); + + if (!ov::is_type(new_scales_values)) { + return false; + } + + const auto init_input = [&new_scales_values](const FakeQuantizeDequantization& dequantization) -> Output { + if (dequantization.empty()) { + return new_scales_values; } + + if (dequantization.subtract == nullptr) { + return dequantization.data; + } + + const auto subtract = NetworkHelper::optimizeSubtract(dequantization.subtract); + if (subtract != nullptr) { + DisableCleanupAttribute::create(subtract); + } + + return subtract == nullptr ? dequantization.data : subtract; }; - fold_fake_quantizes(multiply, 0ul); - fold_fake_quantizes(multiply, 1ul); + if ((dequantization1.empty() && (ov::is_type(dequantization1.data.get_node()))) || + (dequantization2.empty() && (ov::is_type(dequantization2.data.get_node())))) { + // one input is constant + const Output in1 = init_input(dequantization1); + const Output in2 = init_input(dequantization2); - const int fullPathIndex = getNotEmpty(multiply); - if (fullPathIndex == -1) { - const auto multiplyBranch = getMultiplyConstBranch(multiply); - if (multiplyBranch.first != -1) { - NetworkHelper::foldDequantization(multiply, multiplyBranch.first == 0 ? 1 : 0, defaultPrecisions); - } + const auto new_multiply = (in1.get_element_type() == multiply->get_output_element_type(0)) && + (in2.get_element_type() == multiply->get_output_element_type(0)) ? 
+ std::make_shared(in1, in2) : + std::make_shared>( + std::vector{ deqPrecision, deqPrecision }, + std::vector{ multiply->get_output_element_type(0) }, + ov::op::TemporaryReplaceOutputType(in1, deqPrecision).get(), + ov::op::TemporaryReplaceOutputType(in2, deqPrecision).get()); - if (multiplyBranch.first == -1 || multiplyBranch.second == -1) { - // constant folding on dequantization ops (for example: Convert on Subtract) - NetworkHelper::foldDequantization(multiply, 0, defaultPrecisions); - NetworkHelper::foldDequantization(multiply, 1, defaultPrecisions); - return false; - } + replace_node(multiply, new_multiply); + updateOutput(context, new_multiply, multiply); - auto multiplyParent = multiply->input_value(multiplyBranch.first); - auto constParent = multiply->input_value(multiplyBranch.first == 0 ? 1 : 0); - auto multiplyParentParent = multiplyParent.get_node_shared_ptr()->input_value(multiplyBranch.second); - auto multiplyParentConst = multiplyParent.get_node_shared_ptr()->input_value(multiplyBranch.second == 0 ? 1 : 0); - - newMultiply = std::make_shared>( - std::vector{ element::f32, element::f32 }, - std::vector{ multiply->get_output_element_type(0) }, - ov::op::TemporaryReplaceOutputType(multiplyParentParent, element::f32).get(), - ov::op::TemporaryReplaceOutputType( - fold( - foldConvert(multiplyParentConst, element::f32), - foldConvert(constParent, element::f32)), - element::f32).get()); - - NetworkHelper::copyInfo(multiplyParent.get_node_shared_ptr(), newMultiply); - NetworkHelper::copyInfo(multiply, newMultiply); - } else { - const int emptyPathIndex = fullPathIndex == 0 ? 
1 : 0; - - if (updatePrecisions) { - const FakeQuantizeDequantization dequantizationEmptyPath = NetworkHelper::getDequantization(multiply, defaultPrecisions, emptyPathIndex); - if (!dequantizationEmptyPath.empty() && !dequantizationEmptyPath.isLowPrecision()) { - return false; - } - } - - FakeQuantizeDequantization dequantizationEmptyPath = NetworkHelper::foldDequantization(multiply, emptyPathIndex, defaultPrecisions); - std::shared_ptr subtractValuesEmptyPath; - std::shared_ptr multiplyValuesEmptyPath; - std::tie(subtractValuesEmptyPath, multiplyValuesEmptyPath) = NetworkHelper::createEmptyValues(dequantizationEmptyPath, deqPrecision); - - // check if empty path shifts are not zero - if (!NetworkHelper::isZeroConst(subtractValuesEmptyPath)) { - return false; - } - - FakeQuantizeDequantization dequantizationFullPath = NetworkHelper::foldDequantization(multiply, fullPathIndex, defaultPrecisions); - std::shared_ptr subtractValuesFullPath; - std::shared_ptr multiplyValuesFullPath; - std::tie(subtractValuesFullPath, multiplyValuesFullPath) = NetworkHelper::createEmptyValues(dequantizationFullPath, deqPrecision); - - - // before: Y = (SC1 * (X1 - SH1)) * (SC2 * X2) - // after : Y = (SC1' * (X1 - SH1)) * (X2) , where : - // SC1' = SC1 * SC2 - auto newMultiplyValuesFullPath = fold(multiplyValuesEmptyPath, multiplyValuesFullPath); - OutputVector inputs{ {}, {} }; - inputs[emptyPathIndex] = dequantizationEmptyPath.data; - inputs[fullPathIndex] = std::make_shared( - dequantizationFullPath.subtract == nullptr ? - (dequantizationFullPath.convert == nullptr ? 
- dequantizationFullPath.data : dequantizationFullPath.convert) : - dequantizationFullPath.subtract, - newMultiplyValuesFullPath); - - newMultiply = std::make_shared>( - std::vector{element::f32, element::f32}, - std::vector{ multiply->get_output_element_type(0) }, - ov::op::TemporaryReplaceOutputType(inputs[0], element::f32).get(), - ov::op::TemporaryReplaceOutputType(inputs[1], element::f32).get()); - NetworkHelper::copyInfo(multiply, newMultiply); + return true; } - replace_node(multiply, newMultiply); - updateOutput(context, newMultiply, multiply); + Output in1 = init_input(dequantization1); + Output in2 = init_input(dequantization2); - if (fullPathIndex != -1) { - NetworkHelper::foldDequantization(newMultiply, fullPathIndex, defaultPrecisions); - } + // in1 & in2 can have different input types + const auto new_multiply = (in1.get_element_type() == deqPrecision) && + (in2.get_element_type() == deqPrecision) ? + std::make_shared(in1, in2) : + std::make_shared>( + std::vector{ deqPrecision, deqPrecision }, + std::vector{ deqPrecision }, + ov::op::TemporaryReplaceOutputType(in1, deqPrecision).get(), + ov::op::TemporaryReplaceOutputType(in2, deqPrecision).get()); + + DisableCleanupAttribute::create(new_multiply); + + auto new_scales = (new_multiply->get_output_element_type(0) == multiply->get_output_element_type(0)) && + (new_scales_values->get_output_element_type(0) == multiply->get_output_element_type(0)) ? 
+ std::make_shared(new_multiply, new_scales_values) : + std::make_shared>( + ov::opset1::Multiply(new_multiply, new_scales_values), + multiply->get_output_element_type(0)); + + replace_node(multiply, new_scales); + const auto was_updated = updateOutput(context, new_scales, multiply); + NetworkHelper::copyInfo(multiply, new_multiply, !was_updated); return true; } -bool MultiplyTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const { - FakeQuantizeDequantization dequantization1 = pass::low_precision::NetworkHelper::getDequantization(layer, defaultPrecisions, 0ul); - FakeQuantizeDequantization dequantization2 = pass::low_precision::NetworkHelper::getDequantization(layer, defaultPrecisions, 1ul); - - if (dequantization1.data.get_node() == nullptr || dequantization2.data.get_node() == nullptr) { - return false; - } - - const bool nonConstantData = !ov::is_type(dequantization1.data.get_node_shared_ptr()) && - !ov::is_type(dequantization2.data.get_node_shared_ptr()); - - if (((dequantization1.empty() || dequantization2.empty()) && nonConstantData)) { - return false; - } - - return EltwiseBaseTransformation::canBeTransformed(context, layer); +size_t MultiplyTransformation::getInputChannels(const std::shared_ptr op) const { + const auto channels = op->get_input_partial_shape(1)[1]; + assert(channels.is_static()); + return channels.get_length(); } } // namespace low_precision diff --git a/src/common/low_precision_transformations/src/multiply_partial.cpp b/src/common/low_precision_transformations/src/multiply_partial.cpp new file mode 100644 index 00000000000..cbe66273929 --- /dev/null +++ b/src/common/low_precision_transformations/src/multiply_partial.cpp @@ -0,0 +1,174 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "low_precision/multiply_partial.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include 
"openvino/pass/pattern/op/wrap_type.hpp" + +#include "low_precision/common/ie_lpt_exception.hpp" +#include "low_precision/network_helper.hpp" +#include "itt.hpp" + +namespace ov { +namespace pass { +namespace low_precision { + +MultiplyPartialTransformation::MultiplyPartialTransformation(const Params& params) : EltwiseBaseTransformation(params) { + MATCHER_SCOPE(MultiplyPartialTransformation); + auto matcher = pattern::wrap_type(); + + ov::graph_rewrite_callback callback = [this](pattern::Matcher& m) { + auto op = m.get_match_root(); + if (transformation_callback(op)) { + return false; + } + return transform(*context, m); + }; + + auto m = std::make_shared(matcher, matcher_name); + this->register_matcher(m, callback); +} + +bool MultiplyPartialTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher& m) { + auto multiply = m.get_match_root(); + if (!canBeTransformed(context, multiply)) { + return false; + } + + NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 0)); + NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 1)); + + multiply = NetworkHelper::separateInStandaloneBranch(multiply, defaultPrecisions); + auto newMultiply = multiply; + + auto fold_fake_quantizes = [](std::shared_ptr& multiply, const size_t index) { + auto fakeQuantizeOnWeights = ov::as_type_ptr(multiply->get_input_node_shared_ptr(index)); + if (fakeQuantizeOnWeights != nullptr) { + auto result = NetworkHelper::fold_fake_quantize(fakeQuantizeOnWeights); + if (ov::is_type(result)) { + replace_node(fakeQuantizeOnWeights, result); + } + } + }; + + fold_fake_quantizes(multiply, 0ul); + fold_fake_quantizes(multiply, 1ul); + + const int fullPathIndex = getNotEmpty(multiply); + if (fullPathIndex == -1) { + const auto multiplyBranch = getMultiplyConstBranch(multiply); + if (multiplyBranch.first != -1) { + NetworkHelper::foldDequantization(multiply, multiplyBranch.first == 0 ? 
1 : 0, defaultPrecisions); + } + + if (multiplyBranch.first == -1 || multiplyBranch.second == -1) { + // constant folding on dequantization ops (for example: Convert on Subtract) + NetworkHelper::foldDequantization(multiply, 0, defaultPrecisions); + NetworkHelper::foldDequantization(multiply, 1, defaultPrecisions); + return false; + } + + auto multiplyParent = multiply->input_value(multiplyBranch.first); + auto constParent = multiply->input_value(multiplyBranch.first == 0 ? 1 : 0); + auto multiplyParentParent = multiplyParent.get_node_shared_ptr()->input_value(multiplyBranch.second); + auto multiplyParentConst = multiplyParent.get_node_shared_ptr()->input_value(multiplyBranch.second == 0 ? 1 : 0); + + newMultiply = std::make_shared>( + std::vector{ element::f32, element::f32 }, + std::vector{ multiply->get_output_element_type(0) }, + ov::op::TemporaryReplaceOutputType(multiplyParentParent, element::f32).get(), + ov::op::TemporaryReplaceOutputType( + fold( + foldConvert(multiplyParentConst, element::f32), + foldConvert(constParent, element::f32)), + element::f32).get()); + + NetworkHelper::copyInfo(multiplyParent.get_node_shared_ptr(), newMultiply); + NetworkHelper::copyInfo(multiply, newMultiply); + } else { + const int emptyPathIndex = fullPathIndex == 0 ? 
1 : 0; + + if (updatePrecisions) { + const FakeQuantizeDequantization dequantizationEmptyPath = NetworkHelper::getDequantization(multiply, defaultPrecisions, emptyPathIndex); + if (!dequantizationEmptyPath.empty() && !dequantizationEmptyPath.isLowPrecision()) { + return false; + } + } + + FakeQuantizeDequantization dequantizationEmptyPath = NetworkHelper::foldDequantization(multiply, emptyPathIndex, defaultPrecisions); + std::shared_ptr subtractValuesEmptyPath; + std::shared_ptr multiplyValuesEmptyPath; + std::tie(subtractValuesEmptyPath, multiplyValuesEmptyPath) = NetworkHelper::createEmptyValues(dequantizationEmptyPath, deqPrecision); + + // check if empty path shifts are not zero + if (!NetworkHelper::isZeroConst(subtractValuesEmptyPath)) { + return false; + } + + FakeQuantizeDequantization dequantizationFullPath = NetworkHelper::foldDequantization(multiply, fullPathIndex, defaultPrecisions); + std::shared_ptr subtractValuesFullPath; + std::shared_ptr multiplyValuesFullPath; + std::tie(subtractValuesFullPath, multiplyValuesFullPath) = NetworkHelper::createEmptyValues(dequantizationFullPath, deqPrecision); + + + // before: Y = (SC1 * (X1 - SH1)) * (SC2 * X2) + // after : Y = (SC1' * (X1 - SH1)) * (X2) , where : + // SC1' = SC1 * SC2 + auto newMultiplyValuesFullPath = fold(multiplyValuesEmptyPath, multiplyValuesFullPath); + OutputVector inputs{ {}, {} }; + inputs[emptyPathIndex] = dequantizationEmptyPath.data; + inputs[fullPathIndex] = std::make_shared( + dequantizationFullPath.subtract == nullptr ? + (dequantizationFullPath.convert == nullptr ? 
+ dequantizationFullPath.data : dequantizationFullPath.convert) : + dequantizationFullPath.subtract, + newMultiplyValuesFullPath); + + newMultiply = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ multiply->get_output_element_type(0) }, + ov::op::TemporaryReplaceOutputType(inputs[0], element::f32).get(), + ov::op::TemporaryReplaceOutputType(inputs[1], element::f32).get()); + NetworkHelper::copyInfo(multiply, newMultiply); + } + + replace_node(multiply, newMultiply); + updateOutput(context, newMultiply, multiply); + + if (fullPathIndex != -1) { + NetworkHelper::foldDequantization(newMultiply, fullPathIndex, defaultPrecisions); + } + + return true; +} + +bool MultiplyPartialTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const { + FakeQuantizeDequantization dequantization1 = pass::low_precision::NetworkHelper::getDequantization(layer, defaultPrecisions, 0ul); + FakeQuantizeDequantization dequantization2 = pass::low_precision::NetworkHelper::getDequantization(layer, defaultPrecisions, 1ul); + + if (dequantization1.data.get_node() == nullptr || dequantization2.data.get_node() == nullptr) { + return false; + } + + const bool nonConstantData = !ov::is_type(dequantization1.data.get_node_shared_ptr()) && + !ov::is_type(dequantization2.data.get_node_shared_ptr()); + + if (((dequantization1.empty() || dequantization2.empty()) && nonConstantData)) { + return false; + } + + return EltwiseBaseTransformation::canBeTransformed(context, layer); +} + +} // namespace low_precision +} // namespace pass +} // namespace ov diff --git a/src/common/low_precision_transformations/src/multiply_to_group_convolution.cpp b/src/common/low_precision_transformations/src/multiply_to_group_convolution.cpp index 62a51ef6193..a8999aeff8e 100644 --- a/src/common/low_precision_transformations/src/multiply_to_group_convolution.cpp +++ b/src/common/low_precision_transformations/src/multiply_to_group_convolution.cpp @@ -15,7 +15,7 
@@ namespace low_precision { MultiplyToGroupConvolutionTransformation::MultiplyToGroupConvolutionTransformation( const Params& params, - const PrecisionsRestriction::PrecisionsByPorts& restrictions) : LayerTransformation(params), restrictions(restrictions), groupSize(1ul) { + const PrecisionsRestriction::PrecisionsByPorts& restrictions) : CleanupTransformation(params), restrictions(restrictions), groupSize(1ul) { MATCHER_SCOPE(MultiplyToGroupConvolutionTransformation); auto matcher = pattern::wrap_type(); @@ -143,6 +143,10 @@ bool MultiplyToGroupConvolutionTransformation::transform(TransformationContext& } bool MultiplyToGroupConvolutionTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr operation) const { + if (!CleanupTransformation::canBeTransformed(context, operation)) { + return false; + } + const PartialShape outPShape = operation->get_output_partial_shape(0); const auto rank = outPShape.rank(); if (rank.is_dynamic()) { diff --git a/src/common/low_precision_transformations/src/recurrent_cell.cpp b/src/common/low_precision_transformations/src/recurrent_cell.cpp index cb961a3de40..7fd40cf2071 100644 --- a/src/common/low_precision_transformations/src/recurrent_cell.cpp +++ b/src/common/low_precision_transformations/src/recurrent_cell.cpp @@ -14,7 +14,7 @@ #include "openvino/pass/pattern/op/or.hpp" #include "low_precision/network_helper.hpp" -#include "low_precision/rt_info/skip_cleanup_attribute.hpp" +#include "low_precision/rt_info/disable_cleanup_attribute.hpp" namespace ov { namespace pass { @@ -96,6 +96,7 @@ bool RecurrentCellTransformation::transform(TransformationContext& context, ov:: if (!canBeTransformed(context, lstm)) { return false; } + for (size_t parentIndex = 0ul; parentIndex < lstm->get_input_size(); parentIndex++) { auto lstm_parent = lstm->get_input_node_shared_ptr(parentIndex); if (is_type(lstm_parent)) { @@ -108,7 +109,7 @@ bool RecurrentCellTransformation::transform(TransformationContext& context, ov:: ? 
defaultPrecisions : precisionsAttribute.as().value(); const DataPrecision dataPrecision = getDataPrecision(lstm_parent, quantizationDetails, precisions); - if (dataPrecision.empty()) { + if (dataPrecision.empty() || dataPrecision.hasZeroPoint) { return false; } @@ -148,6 +149,7 @@ bool RecurrentCellTransformation::transform(TransformationContext& context, ov:: continue; } } + return true; } @@ -172,12 +174,12 @@ bool RecurrentCellTransformation::isPrecisionPreserved(std::shared_ptr) co } void RecurrentCellTransformation::propagateSkipCleanupAttribute(std::shared_ptr multiply) { - SkipCleanupAttribute::create(multiply); + DisableCleanupAttribute::create(multiply); auto multiply_parent = multiply->get_input_node_shared_ptr(0); - SkipCleanupAttribute::create(multiply_parent); + DisableCleanupAttribute::create(multiply_parent); if (is_type(multiply_parent)) { auto subtract_parent = multiply_parent->get_input_node_shared_ptr(0); - SkipCleanupAttribute::create(subtract_parent); + DisableCleanupAttribute::create(subtract_parent); } } diff --git a/src/common/low_precision_transformations/src/rt_info/skip_cleanup_attribute.cpp b/src/common/low_precision_transformations/src/rt_info/skip_cleanup_attribute.cpp deleted file mode 100644 index 1d7d4a1549a..00000000000 --- a/src/common/low_precision_transformations/src/rt_info/skip_cleanup_attribute.cpp +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "low_precision/rt_info/skip_cleanup_attribute.hpp" - -#include -#include -#include -#include -#include - -using namespace ov; -using namespace ov; - -ov::Any SkipCleanupAttribute::create( - const std::shared_ptr& node) { - auto& rt = node->get_rt_info(); - return (rt[SkipCleanupAttribute::get_type_info_static()] = SkipCleanupAttribute()); -} diff --git a/src/common/low_precision_transformations/src/weightable_layer_transformation.cpp 
b/src/common/low_precision_transformations/src/weightable_layer_transformation.cpp index ac945339755..d3dd47d2107 100644 --- a/src/common/low_precision_transformations/src/weightable_layer_transformation.cpp +++ b/src/common/low_precision_transformations/src/weightable_layer_transformation.cpp @@ -15,6 +15,7 @@ namespace pass { namespace low_precision { namespace { +// used in isQuantizedStatic static method, can not be virtual method std::vector getWeightsDequantizationIdces(const std::shared_ptr weightableLayer) { if (ov::is_type(weightableLayer)) { return std::vector{0}; @@ -22,7 +23,9 @@ std::vector getWeightsDequantizationIdces(const std::shared_ptr{1}; } else if (ov::is_type(weightableLayer)) { return ov::is_type(weightableLayer->get_input_node_shared_ptr(1)) ? std::vector{0} - : std::vector{0, 1}; + : std::vector{0, 1}; + } else if (ov::is_type(weightableLayer)) { + return std::vector{}; } else { THROW_IE_LPT_EXCEPTION(*weightableLayer) << "getWeightsDequantizationIdces is called for unexpected layer"; } @@ -41,7 +44,10 @@ bool checkConstShape(const std::vector& idcesToCheck, const std::shared_ } } // namespace -WeightableLayerTransformation::WeightableLayerTransformation(const Params& params) : LayerTransformation(params) {} +WeightableLayerTransformation::WeightableLayerTransformation(const Params& params, const CanBeTransformedParams& canBeTransformedParams) : + LayerTransformation(params), + canBeTransformedParams(canBeTransformedParams) { +} bool WeightableLayerTransformation::canConvolutionBeTransformed(const TransformationContext& context, std::shared_ptr layer, const std::vector& defaultPrecisions) const { @@ -88,7 +94,7 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext } // dynamic activations rank and dynamic weights aren't supported - if (layer->get_input_partial_shape(0).rank().is_dynamic() || layer->get_input_partial_shape(1).is_dynamic()) { + if (!canBeTransformedParams.dynamicWeights && 
(layer->get_input_partial_shape(0).rank().is_dynamic() || layer->get_input_partial_shape(1).is_dynamic())) { return false; } @@ -138,14 +144,16 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext return false; } - // exactly cast vector as original code has a conversion; - // optimize cast: - // two branches depending on real type of the constant? - const auto scalesBuffer = dequantization.multiplyConstant->cast_vector(); - size_t scalesBufferSize = shape_size(dequantization.multiplyConstant->get_shape()); - for (size_t i = 1ul; i < scalesBufferSize; ++i) { - if (scalesBuffer[i - 1] != scalesBuffer[i]) { - return false; + if (canBeTransformedParams.perTensorQuantizationOnData) { + // exactly cast vector as original code has a conversion; + // optimize cast: + // two branches depending on real type of the constant? + const auto scalesBuffer = dequantization.multiplyConstant->cast_vector(); + size_t scalesBufferSize = shape_size(dequantization.multiplyConstant->get_shape()); + for (size_t i = 1ul; i < scalesBufferSize; ++i) { + if (scalesBuffer[i - 1] != scalesBuffer[i]) { + return false; + } } } } @@ -213,8 +221,11 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext dqVolume *= constChannels; } } - if (shape_size(constShape) != 1 && shape_size(constShape) != dqVolume) { - return false; + + if (!dqIdces.empty()) { + if (shape_size(constShape) != 1 && shape_size(constShape) != dqVolume) { + return false; + } } } else { // TODO: LPT: is it possible to share with isQuantized? 
@@ -225,13 +236,16 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext return false; } - const auto weightsData = ov::as_type_ptr(dequantizationOnWeights.data.get_node_shared_ptr()); - if (weightsData == nullptr) { - return false; + const auto weightsData = dequantizationOnWeights.data.get_node_shared_ptr(); + if (canBeTransformedParams.constantWeight) { + const auto constantWeightsData = ov::as_type_ptr(weightsData); + if (constantWeightsData == nullptr) { + return false; + } } const auto weightsDataPrecision = weightsData->get_element_type(); - if (!DataPrecision::isSupported(weightsDataPrecision)) { + if (canBeTransformedParams.limitWeightsDataPrecision && !DataPrecision::isSupported(weightsDataPrecision)) { return false; } @@ -243,9 +257,11 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext } const auto dqIdces = getWeightsDequantizationIdces(layer); - if ((dequantizationOnWeights.subtract && !checkConstShape(dqIdces, dequantizationOnWeights.subtractConstant)) || - (dequantizationOnWeights.multiply && !checkConstShape(dqIdces, dequantizationOnWeights.multiplyConstant))) { - return false; + if (!dqIdces.empty()) { + if ((dequantizationOnWeights.subtract && !checkConstShape(dqIdces, dequantizationOnWeights.subtractConstant)) || + (dequantizationOnWeights.multiply && !checkConstShape(dqIdces, dequantizationOnWeights.multiplyConstant))) { + return false; + } } } diff --git a/src/common/low_precision_transformations/tests/lpt_avoid_shapeof_propagation_test.cpp b/src/common/low_precision_transformations/tests/lpt_avoid_shapeof_propagation_test.cpp index 431c4459a4c..f2459620019 100644 --- a/src/common/low_precision_transformations/tests/lpt_avoid_shapeof_propagation_test.cpp +++ b/src/common/low_precision_transformations/tests/lpt_avoid_shapeof_propagation_test.cpp @@ -22,7 +22,7 @@ #include "low_precision/interpolate.hpp" #include "low_precision/mat_mul.hpp" #include "low_precision/max_pool.hpp" 
-#include "low_precision/multiply.hpp" +#include "low_precision/multiply_partial.hpp" #include "low_precision/mvn.hpp" #include "low_precision/network_helper.hpp" #include "low_precision/normalize_l2.hpp" @@ -361,7 +361,7 @@ TEST(LPT, AvoidDequantizationToShapeOfPropagationMultiplyTransformation) { auto f = std::make_shared(ResultVector{result1, result2}, ParameterVector{input1, input2}); pass::Manager m; - m.register_pass(); + m.register_pass(); m.run_passes(f); auto dqBeforeShapeOf = ov::pass::low_precision::NetworkHelper::getDequantization(result2->get_input_node_shared_ptr(0)); diff --git a/src/common/low_precision_transformations/tests/multiply_partial_transformation.cpp b/src/common/low_precision_transformations/tests/multiply_partial_transformation.cpp new file mode 100644 index 00000000000..1e556df70bc --- /dev/null +++ b/src/common/low_precision_transformations/tests/multiply_partial_transformation.cpp @@ -0,0 +1,1007 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "layer_transformation.hpp" + +#include +#include +#include + +#include + +#include +#include +#include +#include "low_precision/multiply_partial.hpp" +#include "lpt_ngraph_functions/common/dequantization_operations.hpp" + +#include "common_test_utils/ov_test_utils.hpp" +#include "simple_low_precision_transformer.hpp" +#include "lpt_ngraph_functions/multiply_partial_function.hpp" + +namespace { +using namespace testing; +using namespace ov; +using namespace ov::pass; +using namespace ngraph::builder::subgraph; + +class MultiplyPartialTransformationTestValues { +public: + TestTransformationParams transformationParams; + MultiplyPartialValues actual; + MultiplyPartialValues expected; + + MultiplyPartialTransformationTestValues() = default; + + MultiplyPartialTransformationTestValues( + TestTransformationParams transformationParams, + MultiplyPartialValues actual, + MultiplyPartialValues expected): + 
transformationParams(std::move(transformationParams)), + actual(std::move(actual)), + expected(std::move(expected)) {} +}; + +typedef std::tuple< + ov::element::Type, + MultiplyPartialTransformationTestValues> MultiplyPartialTransformationParams; + +class MultiplyPartialTransformation : public LayerTransformation, public testing::WithParamInterface { +public: + void SetUp() override { + const ov::element::Type precision = std::get<0>(GetParam()); + const MultiplyPartialTransformationTestValues testParams = std::get<1>(GetParam()); + + actualFunction = MultiplyPartialFunction::get(precision, testParams.actual); + + SimpleLowPrecisionTransformer transform; + transform.add(testParams.transformationParams); + transform.transform(actualFunction); + + referenceFunction = MultiplyPartialFunction::get(precision, testParams.expected); + } + + static std::string getTestCaseName(testing::TestParamInfo obj) { + const ov::element::Type precision = std::get<0>(obj.param); + const MultiplyPartialTransformationTestValues testParams = std::get<1>(obj.param); + + std::ostringstream result; + result << + LayerTransformation::getTestCaseNameByParams(precision, testParams.expected.branch1.inputShape, testParams.transformationParams) << + testParams.actual << + testParams.expected; + return result.str(); + } +}; + +TEST_P(MultiplyPartialTransformation, CompareFunctions) { + actualFunction->validate_nodes_and_infer_types(); + auto res = compare_functions(actualFunction, referenceFunction, true, true, false); + ASSERT_TRUE(res.first) << res.second; + + ASSERT_TRUE(LayerTransformation::allNamesAreUnique(actualFunction)) << "Not all names are unique"; +} + +const std::vector precisions = { + ov::element::f32, + ov::element::f16 +}; + +const std::vector multiplyTransformationTestValues = { + // U8 + { + LayerTransformation::createParamsU8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + 
{ov::element::f32, { 3.f }, { 7.f }} + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {ov::element::f32, { 3.f }, { 7.f }} + }, + false + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::u8, + {ov::element::f32, { 3.f }, { 7.f }} + }, + false + }, + { + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::u8, + {ov::element::f32, { 3.f }, { 7.f }} + }, + false + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {ov::element::f32, { }, { 7.f }} + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { 70.f }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {} + }, + false + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::u8, + {ov::element::f32, { }, { 7.f }} + }, + false + }, + { + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { 70.f 
}} + }, + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::u8, + {} + }, + false + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + { ov::element::f32, { }, { 10.f }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + { ov::element::f32, { }, { 7.f } } + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {ov::element::f32, { }, { 70.f }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {} + }, + false + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {ov::element::f32, { }, { 7.f } } + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { 7.f }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::u8, + {} + }, + false + } + }, + { + LayerTransformation::createParamsU8I8(), + { + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { }} + }, + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::u8, + {ov::element::f32, { }, { 7.f } } + }, + false + }, + { + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { 7.f }} + }, + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::u8, + {} + }, + false + } + }, + { + LayerTransformation::createParamsU8I8(), + { + { + PartialShape::dynamic(), + {}, + ov::element::u8, + {ov::element::f32, { 2.f }, { }} + }, + { + PartialShape::dynamic(), + {}, + ov::element::u8, + {ov::element::f32, { }, { 7.f } } + }, + false + }, + { + { + PartialShape::dynamic(), + {}, + 
ov::element::u8, + {ov::element::f32, { 2.f }, { }} + }, + { + PartialShape::dynamic(), + {}, + ov::element::u8, + {ov::element::f32, { }, { 7.f } } + }, + false + } + }, + + // I8 + { + LayerTransformation::createParamsI8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { 3.f }, { 7.f }} + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { 3.f }, { 7.f } } + }, + false + } + }, + + // Actual: + // + // Parameter + // |I8 + // | + // Convert Constant Parameter + // \FP32 /FP32 |I8 + // \ / | + // Subtract Constant Convert Constant + // \FP32 /FP32 \FP32 /FP32 + // \ / \ / + // Multiply Multiply + // \FP32 /FP32 + // \ / + // Multiply + // Transformed: + // + // Parameter + // |I8 + // | + // Convert Constant + // \FP32 /FP32 + // \ / + // Subtract Constant + // \FP32 /FP32 + // \ / + // Multiply Parameter + // \FP32 /I8 + // \ / + // Multiply + { + LayerTransformation::createParamsI8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { }, { 7.f }} + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { 2.f }, { 70.f }}, + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {} + }, + false + } + }, + + // Actual: + // + // Parameter Constant + // |I8 |I8 + // | | + // Convert Convert Parameter + // \FP32 /FP32 |I8 + // \ / | + // Subtract Constant Convert Constant + // \FP32 /FP32 \FP32 /FP32 + // \ / \ / + // Multiply Multiply + // \FP32 /FP32 + // \ / + // Multiply + // Transformed: + // + // Parameter + // |I8 + // | + // Convert Constant + // \FP32 /FP32 + // \ / + // Subtract Constant + // \FP32 /FP32 + // \ / + // Multiply Parameter + // \FP32 /I8 + // \ 
/ + // Multiply + { + LayerTransformation::createParamsI8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + { + ov::element::f32, + { {2.f}, ov::element::f32, {}, true, 1ul, ov::element::i8, true }, + { 10.f } + } + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { }, { 7.f }} + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { 2.f }, { 70.f }}, + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {} + }, + false + } + }, + + { + LayerTransformation::createParamsI8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { }, { 10.f }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { }, { 7.f } } + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + { ov::element::f32, { }, { 70.f }} + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + { } + }, + false + } + }, + + { + LayerTransformation::createParamsI8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { 2.f }, { }}, + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { }, { 7.f } }, + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { 2.f }, { 7.f }}, + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {} + }, + false + } + }, + + // Constant as input + { + LayerTransformation::createParamsU8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { }, { 10.f }}, + }, + { + {}, + {{ 7.f }, ov::element::f32}, // Constant as input + ov::element::f32, + {} + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, {}, {}}, + }, + { + {}, + {{ 70.f }, ov::element::f32}, + ov::element::f32, + {} + }, + true + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::i8, + {ov::element::f32, { }, { 10.f }}, + }, + 
{ + {}, + {{ 7.f }, ov::element::f32}, // Constant as input + ov::element::f32, + {} + }, + false + }, + { + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::i8, + {ov::element::f32, {}, {}}, + }, + { + {}, + {{ 70.f }, ov::element::f32}, + ov::element::f32, + {} + }, + true + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { 18.f }, { 10.f }}, + }, + { + {}, + {{ 7.f }, ov::element::f32}, + ov::element::f32, + {} + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { 18.f }, { }}, + }, + { + {}, + {{ 70.f }, ov::element::f32}, + ov::element::f32, + {} + }, + true + } + }, + + // Constant as input with empty shape + { + LayerTransformation::createParamsU8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, { }, { 0.2f }}, + }, + { + {}, + {{ 7.f }, ov::element::i8}, // Constant as input + ov::element::i8, + {ov::element::f32, { }, { 0.5f }}, + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + {ov::element::f32, {}, {}}, + }, + { + {}, + {{ 0.7f }, ov::element::f32}, + ov::element::f32, + {} + }, + true + } + }, + + // Constant as input with 1 dimension shape + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + {{ 7.f, 8.f, 9.f }, ov::element::i8, ov::Shape{3}}, // Constant as input + ov::element::i8, + {ov::element::f32, { }, { {0.1f, 0.2f, 0.3f}, element::f32, ov::Shape{3} }}, + }, + { + { 1, 2, 3 }, + {}, + ov::element::f32, + {{}, {}, {{0.2f, 0.3f, 0.4f}, element::f32, ov::Shape{3}}}, + }, + false + }, + { + { + { 1, 2, 3 }, + {}, + ov::element::f32, + {}, + }, + { + {}, + { {0.14f, 0.48f, 1.08f}, ov::element::f32, ov::Shape{3}}, // Constant as input + {}, + {}, + }, + true + } + }, + + // Parameter as input with, Constant with 1 dimension shape + { + LayerTransformation::createParamsU8I8(), + { + { + { 1, 2, 3 }, + {}, + 
ov::element::f32, + {{}, {}, {{0.2f, 0.3f, 0.4f}, element::f32, ov::Shape{3}}}, + }, + { + {}, + {{ 7.f, 8.f, 9.f }, ov::element::i8, ov::Shape{3}}, // Constant as input + ov::element::i8, + {ov::element::f32, { }, { {0.1f, 0.2f, 0.3f}, element::f32, ov::Shape{3} }}, + }, + false + }, + { + { + { 1, 2, 3 }, + {}, + ov::element::f32, + {}, + }, + { + {}, + { {0.14f, 0.48f, 1.08f}, ov::element::f32, ov::Shape{3}}, // Constant as input + {}, + {}, + }, + true + } + }, + + // Actual: + // + // Parameter Constant Constant Constant + // |I8 |I8 |I8 |I8 + // | | | | + // Convert Convert Convert Convert + // \FP32 /FP32 |I8 /FP32 + // \ / | / + // Subtract Constant Subtract Constant + // \FP32 /FP32 \FP32 /FP32 + // \ / \ / + // Multiply Multiply + // \FP32 /FP32 + // \ / + // Multiply + // Transformed: + // + // Parameter Constant + // |I8 |I8 + // | | + // Convert Convert + // \FP32 /FP32 + // \ / + // Subtract Constant + // \FP32 /FP32 + // \ / + // Multiply + // + { + LayerTransformation::createParamsU8I8(), + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + { + ov::element::f32, + { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, + { 0.2f } + }, + }, + { + {}, + {{ 7.f }, ov::element::i8}, // Constant as input + ov::element::i8, + { + ov::element::f32, + { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, + { 0.5f } + }, + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + { + ov::element::f32, + { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, + {} + }, + }, + { + {}, + {{ -12.f }, ov::element::f32}, + ov::element::f32, + {} + }, + true + } + }, + + // Actual: + // + // Constant Constant Parameter Constant + // |I8 |I8 |I8 |I8 + // | | | | + // Convert Convert Convert Convert + // \FP32 /FP32 |I8 /FP32 + // \ / | / + // Subtract Constant Subtract Constant + // \FP32 /FP32 \FP32 /FP32 + // \ / \ / + // Multiply Multiply + // \FP32 /FP32 + // \ / + // Multiply + // Transformed: + // + // 
Parameter Constant + // |I8 |I8 + // | | + // Convert Convert + // \FP32 /FP32 + // \ / + // Subtract Constant + // \FP32 /FP32 + // \ / + // Multiply + // + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + {{ 7.f }, ov::element::i8}, // Constant as input + ov::element::i8, + { + ov::element::f32, + { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, + { 0.5f } + }, + }, + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + { + ov::element::f32, + { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, + { 0.2f } + }, + }, + false + }, + { + { + { 1, 3, 8, 16 }, + {}, + ov::element::i8, + { + ov::element::f32, + { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, + {} + }, + }, + { + {}, + {{ -12.f }, ov::element::f32}, + ov::element::f32, + {} + }, + true + } + }, + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + {{ 7.f }, ov::element::i8}, // Constant as input + ov::element::i8, + { + ov::element::f32, + { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, + { 0.5f } + }, + }, + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::i8, + { + ov::element::f32, + { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, + { 0.2f } + }, + }, + false + }, + { + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {}, + ov::element::i8, + { + ov::element::f32, + { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, + {} + }, + }, + { + {}, + {{ -12.f }, ov::element::f32}, + ov::element::f32, + {} + }, + true + } + }, +}; + +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + MultiplyPartialTransformation, + ::testing::Combine( + ::testing::ValuesIn(precisions), + ::testing::ValuesIn(multiplyTransformationTestValues)), + MultiplyPartialTransformation::getTestCaseName); +} // namespace diff --git a/src/common/low_precision_transformations/tests/multiply_transformation.cpp 
b/src/common/low_precision_transformations/tests/multiply_transformation.cpp index f8925c75c9e..3ea4563f62b 100644 --- a/src/common/low_precision_transformations/tests/multiply_transformation.cpp +++ b/src/common/low_precision_transformations/tests/multiply_transformation.cpp @@ -14,6 +14,7 @@ #include "transformations/utils/utils.hpp" #include "transformations/init_node_info.hpp" #include "low_precision/multiply.hpp" +#include "low_precision/multiply_to_group_convolution.hpp" #include "lpt_ngraph_functions/common/dequantization_operations.hpp" #include "common_test_utils/ov_test_utils.hpp" @@ -26,8 +27,49 @@ using namespace ov; using namespace ov::pass; using namespace ngraph::builder::subgraph; +class MultiplyBranch { +public: + ngraph::builder::subgraph::Constant constant; + ngraph::element::Type input_precision; + ngraph::builder::subgraph::DequantizationOperations dequantization; + ngraph::builder::subgraph::FakeQuantizeOnData fake_quantize; +}; + +inline std::ostream& operator<<(std::ostream& out, const MultiplyBranch& branch) { + if (branch.input_precision != element::undefined) { + out << "_input=" << branch.input_precision; + } + if (!branch.constant.empty()) { + out << "_constant=" << branch.constant; + } + if (!branch.dequantization.empty()) { + out << "_dequantization=" << branch.dequantization; + } + if (!branch.fake_quantize.empty()) { + out << "_fake_quantize=" << branch.constant; + } + return out; +} + +class MultiplyValues { +public: + MultiplyBranch branch1; + MultiplyBranch branch2; + ngraph::builder::subgraph::DequantizationOperations after_dequantization; +}; + +inline std::ostream& operator<<(std::ostream& out, const MultiplyValues& values) { + return out << "_branch1=" << values.branch1 << "_branch2=" << values.branch2 << "_after=" << values.after_dequantization; +} + class MultiplyTransformationTestValues { public: + // use this value in test case declaration to set precision as input precision + static const ov::element::Type 
input_precision; + + // use this value in test case declaration to set precision as model precision + static const ov::element::Type model_precision; + TestTransformationParams transformationParams; MultiplyValues actual; MultiplyValues expected; @@ -43,35 +85,107 @@ public: expected(std::move(expected)) {} }; +const ov::element::Type MultiplyTransformationTestValues::input_precision = ov::element::undefined; +const ov::element::Type MultiplyTransformationTestValues::model_precision = ov::element::undefined; + typedef std::tuple< - ov::element::Type, + ov::element::Type, // model precision + std::pair, // input_shapes + std::pair, // input precisions MultiplyTransformationTestValues> MultiplyTransformationParams; class MultiplyTransformation : public LayerTransformation, public testing::WithParamInterface { public: void SetUp() override { - const ov::element::Type precision = std::get<0>(GetParam()); - const MultiplyTransformationTestValues testParams = std::get<1>(GetParam()); + const auto model_precision = std::get<0>(GetParam()); + const auto input_shapes = std::get<1>(GetParam()); + const auto input_precisions = std::get<2>(GetParam()); + MultiplyTransformationTestValues testParams = std::get<3>(GetParam()); - actualFunction = MultiplyFunction::get(precision, testParams.actual); - SimpleLowPrecisionTransformer transform; + update_input_precisions(input_precisions, testParams); + update_dequantization_precision(model_precision, testParams); + + // output precision has to be defined by model precision + if (testParams.expected.after_dequantization.multiply.outPrecision == MultiplyTransformationTestValues::model_precision) { + testParams.expected.after_dequantization.multiply.outPrecision = model_precision; + } + + const auto to_multiply_values = [&input_shapes, &input_precisions](const MultiplyValues& values) { + return ngraph::builder::subgraph::MultiplyValues( + ngraph::builder::subgraph::MultiplyBranch( + input_shapes.first, values.branch1.constant, 
input_precisions.first, values.branch1.dequantization, values.branch1.fake_quantize), + ngraph::builder::subgraph::MultiplyBranch( + input_shapes.second, values.branch2.constant, input_precisions.second, values.branch2.dequantization, values.branch2.fake_quantize), + ngraph::builder::subgraph::DequantizationOperations(values.after_dequantization)); + }; + + actualFunction = MultiplyFunction::get(model_precision, to_multiply_values(testParams.actual)); + + SimpleLowPrecisionTransformer transform({}, {}, AttributeParameters(), true); transform.add(testParams.transformationParams); + transform.cleanup->get_pass_config()->disable(); transform.transform(actualFunction); - referenceFunction = MultiplyFunction::get(precision, testParams.expected); + referenceFunction = MultiplyFunction::get(model_precision, to_multiply_values(testParams.expected)); } static std::string getTestCaseName(testing::TestParamInfo obj) { - const ov::element::Type precision = std::get<0>(obj.param); - const MultiplyTransformationTestValues testParams = std::get<1>(obj.param); + const auto model_precision = std::get<0>(obj.param); + const auto input_shapes = std::get<1>(obj.param); + const auto input_precisions = std::get<2>(obj.param); + MultiplyTransformationTestValues testParams = std::get<3>(obj.param); std::ostringstream result; - result << - LayerTransformation::getTestCaseNameByParams(precision, testParams.expected.branch1.inputShape, testParams.transformationParams) << - testParams.actual << - testParams.expected; + result << LayerTransformation::getTestCaseNameByParams(model_precision, input_shapes.first, testParams.transformationParams) << + "_SH1=" << input_shapes.first << + "_TY1=" << input_precisions.first << + "_SH2=" << input_shapes.second << + "_TY2=" << input_precisions.second; + + update_input_precisions(input_precisions, testParams); + update_dequantization_precision(model_precision, testParams); + + result << testParams.actual << testParams.expected; return result.str(); } + 
+private: + // dequantization output precision has to be defined by input precision + static void update_dequantization_precision(const ov::element::Type& dequantization_precision, + MultiplyTransformationTestValues& test_values) { + if (!test_values.actual.after_dequantization.multiply.empty() && + test_values.actual.after_dequantization.multiply.outPrecision == MultiplyTransformationTestValues::input_precision) { + test_values.actual.after_dequantization.multiply.outPrecision = dequantization_precision; + } + + if (!test_values.expected.after_dequantization.multiply.empty() && + test_values.expected.after_dequantization.multiply.outPrecision == MultiplyTransformationTestValues::input_precision) { + test_values.expected.after_dequantization.multiply.outPrecision = dequantization_precision; + } + } + + // low precision has to be defined by tests parameters + static void update_input_precisions(const std::pair& input_precisions, + MultiplyTransformationTestValues& test_values) { + const auto update_values = [](const std::pair& input_precisions, MultiplyValues& values) { + const auto update_branch = [](const ov::element::Type& input_precision, MultiplyBranch& branch) { + if (branch.input_precision == MultiplyTransformationTestValues::input_precision) { + branch.input_precision = input_precision; + } + + if (!branch.constant.empty() && + (branch.constant.outPrecision == MultiplyTransformationTestValues::input_precision)) { + branch.constant.outPrecision = input_precision; + } + }; + + update_branch(input_precisions.first, values.branch1); + update_branch(input_precisions.second, values.branch2); + }; + + update_values(input_precisions, test_values.actual); + update_values(input_precisions, test_values.expected); + } }; TEST_P(MultiplyTransformation, CompareFunctions) { @@ -82,44 +196,121 @@ TEST_P(MultiplyTransformation, CompareFunctions) { ASSERT_TRUE(LayerTransformation::allNamesAreUnique(actualFunction)) << "Not all names are unique"; } -const std::vector 
precisions = { +const std::vector model_precisions = { ov::element::f32, ov::element::f16 }; +const std::vector> input_shapes = { + {{ 1, 3, 8, 16 }, { 1, 3, 8, 16 }}, + {{ 1, 3, 8, 16 }, { 1, 3, 1, 1 }}, + {{ 1, 3, 1, 1 }, { 1, 3, 8, 16 }}, + { + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() } + }, + { + { Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic() }, + { Dimension::dynamic(), 3, Dimension::dynamic(), Dimension::dynamic() } + } +}; + +namespace multiply_channel_fq { + const std::vector> input_precisions = { + { ov::element::u8, ov::element::f32 }, + { ov::element::u8, ov::element::f16 }, + { ov::element::i8, ov::element::f32 }, + { ov::element::i8, ov::element::f16 } + }; + + const std::vector multiplyTransformationTestValues = { + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {{ 0.f, 1.27f, 2.55f }, MultiplyTransformationTestValues::input_precision, ov::Shape{1, 3, 1, 1}}, // Constant as input, + {}, + {}, + { + 256, + ov::Shape{1, 3, 1, 1}, + {0.f, 0.f, 0.f}, + {2.55f, 2.55f, 2.55f}, + {0.f, 0.f, 0.f}, + {2.55f, 2.55f, 2.55f}, + MultiplyTransformationTestValues::input_precision + } // FakeQuantize + }, + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} + }, + { + {{ 0, 127, 255 }, ov::element::u8, ov::Shape{1, 3, 1, 1}}, // Constant as input, + {}, + {} + }, + {{}, {}, {{0.1f, 0.1f, 0.1f}}} + }, + }, + }; + + INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + MultiplyTransformation, + ::testing::Combine( + ::testing::ValuesIn(model_precisions), + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(input_precisions), + ::testing::ValuesIn(multiplyTransformationTestValues)), + MultiplyTransformation::getTestCaseName); +} // namespace 
multiply_channel_fq + +const std::vector> input_precisions = { + { ov::element::u8, ov::element::u8 }, + { ov::element::i8, ov::element::i8 }, + { ov::element::u8, ov::element::i8 }, + { ov::element::i8, ov::element::u8 }, + { ov::element::f32, ov::element::f32 }, + { ov::element::f16, ov::element::f16 }, +}; + +namespace multiply_channel { const std::vector multiplyTransformationTestValues = { - // U8 { LayerTransformation::createParamsU8I8(), { { - { 1, 3, 8, 16 }, {}, - ov::element::u8, + MultiplyTransformationTestValues::input_precision, {ov::element::f32, { 2.f }, { 10.f }} }, { - { 1, 3, 8, 16 }, {}, - ov::element::u8, + MultiplyTransformationTestValues::input_precision, {ov::element::f32, { 3.f }, { 7.f }} }, - false }, { { - { 1, 3, 8, 16 }, {}, - ov::element::u8, - {ov::element::f32, { 2.f }, { 10.f }} + MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} }, { - { 1, 3, 8, 16 }, {}, - ov::element::u8, - {ov::element::f32, { 3.f }, { 7.f }} + MultiplyTransformationTestValues::input_precision, + {{}, {{3.f}, ov::element::f32}, {}} }, - false + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} } }, @@ -127,33 +318,28 @@ const std::vector multiplyTransformationTestVa LayerTransformation::createParamsU8I8(), { { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, {}, - ov::element::u8, + MultiplyTransformationTestValues::input_precision, {ov::element::f32, { 2.f }, { 10.f }} }, { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {{ 7.f, 8.f, 9.f }, MultiplyTransformationTestValues::input_precision, ov::Shape{1, 3, 1, 1}}, // Constant as input, {}, - ov::element::u8, {ov::element::f32, { 3.f }, { 7.f }} }, - false }, { { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, {}, - ov::element::u8, - {ov::element::f32, { 2.f }, { 10.f }} + 
MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} }, { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, + {{ 7.f, 8.f, 9.f }, MultiplyTransformationTestValues::input_precision, ov::Shape{1, 3, 1, 1}}, // Constant as input, {}, - ov::element::u8, - {ov::element::f32, { 3.f }, { 7.f }} + {{}, {{3.f}, ov::element::f32}, {}} }, - false + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} } }, @@ -161,33 +347,115 @@ const std::vector multiplyTransformationTestVa LayerTransformation::createParamsU8I8(), { { - { 1, 3, 8, 16 }, + {{ 7.f, 8.f, 9.f }, MultiplyTransformationTestValues::input_precision, ov::Shape{1, 3, 1, 1}}, // Constant as input, {}, - ov::element::u8, + {ov::element::f32, { 3.f }, { 7.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + } + }, + { + { + {{ 7.f, 8.f, 9.f }, MultiplyTransformationTestValues::input_precision, ov::Shape{1, 3, 1, 1}}, // Constant as input, + {}, + {{}, {{3.f}, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {{ 1.f, 2.f, 3.f }}, {{ 10.f, 11.f, 12.f }}} + }, + { + {}, + 
MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {{ 3.f, 4.f, 5.f }}, {{ 7.f, 8.f, 9.f }}} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{1.f, 2.f, 3.f}, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{3.f, 4.f, 5.f }, ov::element::f32}, {}} + }, + {{}, {}, {{70.f, 88.f, 108.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, {ov::element::f32, { 2.f }, { 10.f }} }, { - { 1, 3, 8, 16 }, {}, - ov::element::u8, + MultiplyTransformationTestValues::input_precision, {ov::element::f32, { }, { 7.f }} - }, - false + } }, { { - { 1, 3, 8, 16 }, {}, - ov::element::u8, - {ov::element::f32, { 2.f }, { 70.f }} + MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} }, { - { 1, 3, 8, 16 }, {}, - ov::element::u8, + MultiplyTransformationTestValues::input_precision, {} }, - false + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} } }, @@ -195,803 +463,28 @@ const std::vector multiplyTransformationTestVa LayerTransformation::createParamsU8I8(), { { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, {}, - ov::element::u8, - {ov::element::f32, { 2.f }, { 10.f }} + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} }, { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, {}, - ov::element::u8, - {ov::element::f32, { }, { 7.f }} - }, - false - }, - { - { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, - {}, - ov::element::u8, - {ov::element::f32, { 2.f }, { 70.f }} - }, - { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, - {}, - ov::element::u8, - {} - }, - false - } - }, - - { - 
LayerTransformation::createParamsU8I8(), - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::u8, - { ov::element::f32, { }, { 10.f }} - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::u8, - { ov::element::f32, { }, { 7.f } } - }, - false - }, - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::u8, - {ov::element::f32, { }, { 70.f }} - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::u8, - {} - }, - false - } - }, - - { - LayerTransformation::createParamsU8I8(), - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::u8, - {ov::element::f32, { 2.f }, { }} - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::u8, - {ov::element::f32, { }, { 7.f } } - }, - false - }, - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::u8, - {ov::element::f32, { 2.f }, { 7.f }} - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::u8, - {} - }, - false - } - }, - { - LayerTransformation::createParamsU8I8(), - { - { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, - {}, - ov::element::u8, - {ov::element::f32, { 2.f }, { }} - }, - { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, - {}, - ov::element::u8, - {ov::element::f32, { }, { 7.f } } - }, - false - }, - { - { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, - {}, - ov::element::u8, - {ov::element::f32, { 2.f }, { 7.f }} - }, - { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, - {}, - ov::element::u8, - {} - }, - false - } - }, - { - LayerTransformation::createParamsU8I8(), - { - { - PartialShape::dynamic(), - {}, - ov::element::u8, - {ov::element::f32, { 2.f }, { }} - }, - { - PartialShape::dynamic(), - {}, - ov::element::u8, - {ov::element::f32, { }, { 7.f } } - }, - false - }, - { - { - PartialShape::dynamic(), - {}, - ov::element::u8, - {ov::element::f32, { 2.f }, { }} - }, - { - PartialShape::dynamic(), - {}, - ov::element::u8, - {ov::element::f32, { }, { 7.f } } - }, - false - } - 
}, - - // I8 - { - LayerTransformation::createParamsI8I8(), - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { 2.f }, { 10.f }} - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, + MultiplyTransformationTestValues::input_precision, {ov::element::f32, { 3.f }, { 7.f }} - }, - false + } }, { { - { 1, 3, 8, 16 }, {}, - ov::element::i8, - {ov::element::f32, { 2.f }, { 10.f }} - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { 3.f }, { 7.f } } - }, - false - } - }, - - // Actual: - // - // Parameter - // |I8 - // | - // Convert Constant Parameter - // \FP32 /FP32 |I8 - // \ / | - // Subtract Constant Convert Constant - // \FP32 /FP32 \FP32 /FP32 - // \ / \ / - // Multiply Multiply - // \FP32 /FP32 - // \ / - // Multiply - // Transformed: - // - // Parameter - // |I8 - // | - // Convert Constant - // \FP32 /FP32 - // \ / - // Subtract Constant - // \FP32 /FP32 - // \ / - // Multiply Parameter - // \FP32 /I8 - // \ / - // Multiply - { - LayerTransformation::createParamsI8I8(), - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { 2.f }, { 10.f }} - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { }, { 7.f }} - }, - false - }, - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { 2.f }, { 70.f }}, - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, + MultiplyTransformationTestValues::input_precision, {} }, - false - } - }, - - // Actual: - // - // Parameter Constant - // |I8 |I8 - // | | - // Convert Convert Parameter - // \FP32 /FP32 |I8 - // \ / | - // Subtract Constant Convert Constant - // \FP32 /FP32 \FP32 /FP32 - // \ / \ / - // Multiply Multiply - // \FP32 /FP32 - // \ / - // Multiply - // Transformed: - // - // Parameter - // |I8 - // | - // Convert Constant - // \FP32 /FP32 - // \ / - // Subtract Constant - // \FP32 /FP32 - // \ / - // Multiply Parameter - // \FP32 /I8 - // \ / - // Multiply - { - LayerTransformation::createParamsI8I8(), - { - { - 
{ 1, 3, 8, 16 }, - {}, - ov::element::i8, - { - ov::element::f32, - { {2.f}, ov::element::f32, {}, true, 1ul, ov::element::i8, true }, - { 10.f } - } - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { }, { 7.f }} - }, - false - }, - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { 2.f }, { 70.f }}, - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {} - }, - false - } - }, - - { - LayerTransformation::createParamsI8I8(), - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { }, { 10.f }} - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { }, { 7.f } } - }, - false - }, - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - { ov::element::f32, { }, { 70.f }} - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - { } - }, - false - } - }, - - { - LayerTransformation::createParamsI8I8(), - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { 2.f }, { }}, - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { }, { 7.f } }, - }, - false - }, - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { 2.f }, { 7.f }}, - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {} - }, - false - } - }, - - // Constant as input - { - LayerTransformation::createParamsU8I8(), - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { }, { 10.f }}, - }, { {}, - {{ 7.f }, ov::element::f32}, // Constant as input - ov::element::f32, - {} + MultiplyTransformationTestValues::input_precision, + {{}, {{3.f}, ov::element::f32}, {}} }, - false - }, - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, {}, {}}, - }, - { - {}, - {{ 70.f }, ov::element::f32}, - ov::element::f32, - {} - }, - true - } - }, - - { - LayerTransformation::createParamsU8I8(), - { - { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, - {}, - ov::element::i8, - {ov::element::f32, { }, { 10.f 
}}, - }, - { - {}, - {{ 7.f }, ov::element::f32}, // Constant as input - ov::element::f32, - {} - }, - false - }, - { - { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, - {}, - ov::element::i8, - {ov::element::f32, {}, {}}, - }, - { - {}, - {{ 70.f }, ov::element::f32}, - ov::element::f32, - {} - }, - true - } - }, - - { - LayerTransformation::createParamsU8I8(), - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { 18.f }, { 10.f }}, - }, - { - {}, - {{ 7.f }, ov::element::f32}, - ov::element::f32, - {} - }, - false - }, - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { 18.f }, { }}, - }, - { - {}, - {{ 70.f }, ov::element::f32}, - ov::element::f32, - {} - }, - true - } - }, - - // Constant as input with empty shape - { - LayerTransformation::createParamsU8I8(), - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, { }, { 0.2f }}, - }, - { - {}, - {{ 7.f }, ov::element::i8}, // Constant as input - ov::element::i8, - {ov::element::f32, { }, { 0.5f }}, - }, - false - }, - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - {ov::element::f32, {}, {}}, - }, - { - {}, - {{ 0.7f }, ov::element::f32}, - ov::element::f32, - {} - }, - true - } - }, - - // Constant as input with 1 dimension shape - { - LayerTransformation::createParamsU8I8(), - { - { - {}, - {{ 7.f, 8.f, 9.f }, ov::element::i8, ov::Shape{3}}, // Constant as input - ov::element::i8, - {ov::element::f32, { }, { {0.1f, 0.2f, 0.3f}, element::f32, ov::Shape{3} }}, - }, - { - { 1, 2, 3 }, - {}, - ov::element::f32, - {{}, {}, {{0.2f, 0.3f, 0.4f}, element::f32, ov::Shape{3}}}, - }, - false - }, - { - { - { 1, 2, 3 }, - {}, - ov::element::f32, - {}, - }, - { - {}, - { {0.14f, 0.48f, 1.08f}, ov::element::f32, ov::Shape{3}}, // Constant as input - {}, - {}, - }, - true - } - }, - - // Parameter as input with, Constant with 1 dimension shape - { - LayerTransformation::createParamsU8I8(), - { - { - { 1, 2, 3 
}, - {}, - ov::element::f32, - {{}, {}, {{0.2f, 0.3f, 0.4f}, element::f32, ov::Shape{3}}}, - }, - { - {}, - {{ 7.f, 8.f, 9.f }, ov::element::i8, ov::Shape{3}}, // Constant as input - ov::element::i8, - {ov::element::f32, { }, { {0.1f, 0.2f, 0.3f}, element::f32, ov::Shape{3} }}, - }, - false - }, - { - { - { 1, 2, 3 }, - {}, - ov::element::f32, - {}, - }, - { - {}, - { {0.14f, 0.48f, 1.08f}, ov::element::f32, ov::Shape{3}}, // Constant as input - {}, - {}, - }, - true - } - }, - - // Actual: - // - // Parameter Constant Constant Constant - // |I8 |I8 |I8 |I8 - // | | | | - // Convert Convert Convert Convert - // \FP32 /FP32 |I8 /FP32 - // \ / | / - // Subtract Constant Subtract Constant - // \FP32 /FP32 \FP32 /FP32 - // \ / \ / - // Multiply Multiply - // \FP32 /FP32 - // \ / - // Multiply - // Transformed: - // - // Parameter Constant - // |I8 |I8 - // | | - // Convert Convert - // \FP32 /FP32 - // \ / - // Subtract Constant - // \FP32 /FP32 - // \ / - // Multiply - // - { - LayerTransformation::createParamsU8I8(), - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - { - ov::element::f32, - { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, - { 0.2f } - }, - }, - { - {}, - {{ 7.f }, ov::element::i8}, // Constant as input - ov::element::i8, - { - ov::element::f32, - { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, - { 0.5f } - }, - }, - false - }, - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - { - ov::element::f32, - { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, - {} - }, - }, - { - {}, - {{ -12.f }, ov::element::f32}, - ov::element::f32, - {} - }, - true - } - }, - - // Actual: - // - // Constant Constant Parameter Constant - // |I8 |I8 |I8 |I8 - // | | | | - // Convert Convert Convert Convert - // \FP32 /FP32 |I8 /FP32 - // \ / | / - // Subtract Constant Subtract Constant - // \FP32 /FP32 \FP32 /FP32 - // \ / \ / - // Multiply Multiply - // \FP32 /FP32 - // \ / - // Multiply - // Transformed: - // - 
// Parameter Constant - // |I8 |I8 - // | | - // Convert Convert - // \FP32 /FP32 - // \ / - // Subtract Constant - // \FP32 /FP32 - // \ / - // Multiply - // - { - LayerTransformation::createParamsU8I8(), - { - { - {}, - {{ 7.f }, ov::element::i8}, // Constant as input - ov::element::i8, - { - ov::element::f32, - { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, - { 0.5f } - }, - }, - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - { - ov::element::f32, - { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, - { 0.2f } - }, - }, - false - }, - { - { - { 1, 3, 8, 16 }, - {}, - ov::element::i8, - { - ov::element::f32, - { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, - {} - }, - }, - { - {}, - {{ -12.f }, ov::element::f32}, - ov::element::f32, - {} - }, - true - } - }, - { - LayerTransformation::createParamsU8I8(), - { - { - {}, - {{ 7.f }, ov::element::i8}, // Constant as input - ov::element::i8, - { - ov::element::f32, - { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, - { 0.5f } - }, - }, - { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, - {}, - ov::element::i8, - { - ov::element::f32, - { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, - { 0.2f } - }, - }, - false - }, - { - { - { Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic() }, - {}, - ov::element::i8, - { - ov::element::f32, - { {127.f}, ov::element::f32, {}, false, 1, ov::element::i8, true }, - {} - }, - }, - { - {}, - {{ -12.f }, ov::element::f32}, - ov::element::f32, - {} - }, - true + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} } }, }; @@ -1000,7 +493,337 @@ INSTANTIATE_TEST_SUITE_P( smoke_LPT, MultiplyTransformation, ::testing::Combine( - ::testing::ValuesIn(precisions), + ::testing::ValuesIn(model_precisions), + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(input_precisions), 
::testing::ValuesIn(multiplyTransformationTestValues)), MultiplyTransformation::getTestCaseName); -} // namespace +} // namespace multiply_channel + +namespace broadcast_right { +const std::vector> input_shapes = { + {{ 1, 3, 8, 16 }, { 1, 1, 1, 1 }} +}; + +const std::vector multiplyTransformationTestValues = { + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + }, + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 2.f }, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 3.f }, ov::element::f32}, {}} + }, + {{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {{ 1.f, 2.f, 3.f }}, {{ 10.f, 11.f, 12.f }}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 1.f, 2.f, 3.f }, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 3.f }, ov::element::f32}, {}} + }, + {{}, {}, {{70.f, 77.f, 84.f}, MultiplyTransformationTestValues::model_precision}} + } + 
}, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{3.f}, ov::element::f32}, {}} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, +}; + +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + MultiplyTransformation, + ::testing::Combine( + ::testing::ValuesIn(model_precisions), + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(input_precisions), + ::testing::ValuesIn(multiplyTransformationTestValues)), + MultiplyTransformation::getTestCaseName); +} // namespace broadcast_right + +namespace broadcast_left { +const std::vector> input_shapes = { + {{ 1, 1, 1, 1 }, { 1, 3, 8, 16 }} +}; + +const std::vector multiplyTransformationTestValues = { + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + }, + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 2.f }, ov::element::f32}, {}} + }, + { + {}, + 
MultiplyTransformationTestValues::input_precision, + {{}, {{ 3.f }, ov::element::f32}, {}} + }, + {{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{ 70.f }, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {{ 3.f, 4.f, 5.f }}, {{ 7.f, 8.f, 9.f }}} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 2.f }, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{ 3.f, 4.f, 5.f }, ov::element::f32}, {}} + }, + {{}, {}, {{70.f, 80.f, 90.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 2.f }, { 10.f }} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{2.f}, ov::element::f32}, {}} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, + + { + LayerTransformation::createParamsU8I8(), + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {ov::element::f32, {}, { 10.f }} + }, + { + {}, + 
MultiplyTransformationTestValues::input_precision, + {ov::element::f32, { 3.f }, { 7.f }} + } + }, + { + { + {}, + MultiplyTransformationTestValues::input_precision, + {} + }, + { + {}, + MultiplyTransformationTestValues::input_precision, + {{}, {{3.f}, ov::element::f32}, {}} + }, + {{}, {}, {{70.f}, MultiplyTransformationTestValues::model_precision}} + } + }, +}; + +INSTANTIATE_TEST_SUITE_P( + smoke_LPT, + MultiplyTransformation, + ::testing::Combine( + ::testing::ValuesIn(model_precisions), + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(input_precisions), + ::testing::ValuesIn(multiplyTransformationTestValues)), + MultiplyTransformation::getTestCaseName); +} // namespace broadcast_left + +} // namespace \ No newline at end of file diff --git a/src/common/low_precision_transformations/tests/simple_low_precision_transformer.cpp b/src/common/low_precision_transformations/tests/simple_low_precision_transformer.cpp index a805aadb8f4..9f39cc64de5 100644 --- a/src/common/low_precision_transformations/tests/simple_low_precision_transformer.cpp +++ b/src/common/low_precision_transformations/tests/simple_low_precision_transformer.cpp @@ -12,17 +12,29 @@ #include "low_precision/markup_quantization_granularity.hpp" #include "low_precision/transformation_context.hpp" +// cleanup transformations +#include "low_precision/convert.hpp" +#include "low_precision/eliminate_fake_quantize.hpp" +#include "low_precision/fold_convert.hpp" +#include "low_precision/fold_fake_quantize.hpp" +#include "low_precision/fuse_convert.hpp" +#include "low_precision/fuse_multiply_to_fake_quantize.hpp" +#include "low_precision/fuse_subtract_to_fake_quantize.hpp" +#include "low_precision/multiply_to_group_convolution.hpp" + #include using namespace testing; using namespace ov::pass; +using namespace ov::pass::low_precision; OPENVINO_SUPPRESS_DEPRECATED_START SimpleLowPrecisionTransformer::SimpleLowPrecisionTransformer( const std::vector& precisionRestrictions, const std::vector& 
quantizationRestrictions, - const AttributeParameters& params) { + const AttributeParameters& params, + const bool addCleanup) { auto passConfig = get_pass_config(); // TODO: use one pass manager @@ -39,7 +51,20 @@ SimpleLowPrecisionTransformer::SimpleLowPrecisionTransformer( common = std::make_shared(passConfig); commonGraphRewrite = common->register_pass(); + cleanup = common->register_pass(); + if (addCleanup) { + ov::pass::low_precision::LayerTransformation::Params params; + cleanup->add_matcher(params); + cleanup->add_matcher(params); + cleanup->add_matcher(params); + cleanup->add_matcher(params); + cleanup->add_matcher(params); + + cleanup->add_matcher( + params, + PrecisionsRestriction::getPrecisionsByOperationType(precisionRestrictions)); + } } void SimpleLowPrecisionTransformer::transform(std::shared_ptr& model) { diff --git a/src/common/low_precision_transformations/tests/simple_low_precision_transformer.hpp b/src/common/low_precision_transformations/tests/simple_low_precision_transformer.hpp index d7f49649b01..2c65f0b316b 100644 --- a/src/common/low_precision_transformations/tests/simple_low_precision_transformer.hpp +++ b/src/common/low_precision_transformations/tests/simple_low_precision_transformer.hpp @@ -19,7 +19,8 @@ public: SimpleLowPrecisionTransformer( const std::vector& precisionRestrictions = {}, const std::vector& quantizationRestrictions = {}, - const AttributeParameters& params = AttributeParameters()); + const AttributeParameters& params = AttributeParameters(), + const bool addCleanup = false); template void add(const TestTransformationParams& params) { diff --git a/src/tests/functional/plugin/shared/src/low_precision_transformations/multiply_transformation.cpp b/src/tests/functional/plugin/shared/src/low_precision_transformations/multiply_transformation.cpp index 26846b5f97c..2088d4db876 100644 --- a/src/tests/functional/plugin/shared/src/low_precision_transformations/multiply_transformation.cpp +++ 
b/src/tests/functional/plugin/shared/src/low_precision_transformations/multiply_transformation.cpp @@ -11,7 +11,7 @@ #include #include -#include "lpt_ngraph_functions/multiply_function.hpp" +#include "lpt_ngraph_functions/multiply_partial_function.hpp" #include "ngraph_functions/subgraph_builders.hpp" @@ -56,7 +56,7 @@ void MultiplyTransformation::SetUp() { MultiplyTestValues param; std::tie(precision, inputShape, targetDevice, param) = this->GetParam(); - function = ngraph::builder::subgraph::MultiplyFunction::getOriginal( + function = ngraph::builder::subgraph::MultiplyPartialFunction::get( precision, inputShape, param.broadcast1, diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/multiply_function.hpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/multiply_function.hpp index b5b4c22e5fc..553a34b02d1 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/multiply_function.hpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/multiply_function.hpp @@ -17,42 +17,39 @@ namespace subgraph { class MultiplyBranch { public: + MultiplyBranch(const PartialShape& inputShape, + const ngraph::builder::subgraph::Constant& constant, + const ngraph::element::Type& input_precision, + const ngraph::builder::subgraph::DequantizationOperations& dequantization, + const ngraph::builder::subgraph::FakeQuantizeOnData& fake_quantize) + : inputShape(inputShape), + constant(constant), + input_precision(input_precision), + dequantization(dequantization), + fake_quantize(fake_quantize) {} + PartialShape inputShape; ngraph::builder::subgraph::Constant constant; - ngraph::element::Type precisionBeforeDequantization; + ngraph::element::Type input_precision; ngraph::builder::subgraph::DequantizationOperations dequantization; + ngraph::builder::subgraph::FakeQuantizeOnData fake_quantize; }; -inline std::ostream& operator<<(std::ostream& out, const MultiplyBranch& 
branch) { - return out << "_" << branch.constant << "_" << branch.precisionBeforeDequantization << "_" << branch.dequantization; -} - class MultiplyValues { public: + MultiplyValues(const MultiplyBranch& branch1, + const MultiplyBranch& branch2, + const ngraph::builder::subgraph::DequantizationOperations& after_dequantization) + : branch1(branch1), branch2(branch2), after_dequantization(after_dequantization) {} + MultiplyBranch branch1; MultiplyBranch branch2; - bool isDequantization; + ngraph::builder::subgraph::DequantizationOperations after_dequantization; }; -inline std::ostream& operator<<(std::ostream& out, const MultiplyValues& values) { - return out << "_" << values.branch1 << "_" << values.branch2 << (values.isDequantization ? "_isDequantization" : ""); -} - class MultiplyFunction : public ElementwiseFunction { public: - static std::shared_ptr get( - const element::Type precision, - const MultiplyValues& actualValues); - - static std::shared_ptr getOriginal( - const ngraph::element::Type precision, - const ngraph::PartialShape& inputShape, - const bool broadcast1, - const ngraph::builder::subgraph::FakeQuantizeOnData& fq1, - const bool broadcast2, - const ngraph::builder::subgraph::FakeQuantizeOnData& fq2, - const ngraph::builder::subgraph::FakeQuantizeOnData& fqAfter, - const bool secondInputIsConstant = false); + static std::shared_ptr get(const element::Type model_precision, const MultiplyValues& actualValues); }; } // namespace subgraph diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/multiply_partial_function.hpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/multiply_partial_function.hpp new file mode 100644 index 00000000000..878554dd1df --- /dev/null +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/multiply_partial_function.hpp @@ -0,0 +1,60 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once 
+ +#include +#include + +#include "elementwise_function.hpp" +#include "lpt_ngraph_functions/common/constant.hpp" +#include "lpt_ngraph_functions/common/dequantization_operations.hpp" + +namespace ngraph { +namespace builder { +namespace subgraph { + +class MultiplyPartialBranch { +public: + PartialShape inputShape; + ngraph::builder::subgraph::Constant constant; + ngraph::element::Type precisionBeforeDequantization; + ngraph::builder::subgraph::DequantizationOperations dequantization; +}; + +inline std::ostream& operator<<(std::ostream& out, const MultiplyPartialBranch& branch) { + return out << "_" << branch.constant << "_" << branch.precisionBeforeDequantization << "_" << branch.dequantization; +} + +class MultiplyPartialValues { +public: + MultiplyPartialBranch branch1; + MultiplyPartialBranch branch2; + bool isDequantization; +}; + +inline std::ostream& operator<<(std::ostream& out, const MultiplyPartialValues& values) { + return out << "_" << values.branch1 << "_" << values.branch2 << (values.isDequantization ? 
"_isDequantization" : ""); +} + +class MultiplyPartialFunction : public ElementwiseFunction { +public: + static std::shared_ptr get( + const element::Type precision, + const MultiplyPartialValues& actualValues); + + static std::shared_ptr get( + const ngraph::element::Type precision, + const ngraph::PartialShape& inputShape, + const bool broadcast1, + const ngraph::builder::subgraph::FakeQuantizeOnData& fq1, + const bool broadcast2, + const ngraph::builder::subgraph::FakeQuantizeOnData& fq2, + const ngraph::builder::subgraph::FakeQuantizeOnData& fqAfter, + const bool secondInputIsConstant = false); +}; + +} // namespace subgraph +} // namespace builder +} // namespace ngraph diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_function.cpp index 4628acb8f27..e4ff86359f8 100644 --- a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_function.cpp +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_function.cpp @@ -4,6 +4,8 @@ #include "lpt_ngraph_functions/multiply_function.hpp" +#include + #include #include #include "ngraph_functions/subgraph_builders.hpp" @@ -18,49 +20,52 @@ namespace ngraph { namespace builder { namespace subgraph { +namespace multiply_function { struct BranchNodes { std::shared_ptr input; std::shared_ptr dequantization; }; -BranchNodes getBranch(const MultiplyBranch& branch) { - const std::shared_ptr parent = branch.constant.empty() ? - std::make_shared(branch.precisionBeforeDequantization, branch.inputShape) : +BranchNodes makeBranch(const MultiplyBranch& branch) { + std::shared_ptr parent = branch.constant.empty() ? 
+ std::make_shared(branch.input_precision, branch.inputShape) : std::dynamic_pointer_cast(std::make_shared( branch.constant.outPrecision, branch.constant.shape, branch.constant.values)); + if (!branch.fake_quantize.empty()) { + if ((parent->get_output_element_type(0) != element::f32) && + (parent->get_output_element_type(0) != element::f16)) { + throw std::runtime_error("unexpected precision before FakeQuantize"); + } + parent = makeFakeQuantize(parent, parent->get_output_element_type(0), branch.fake_quantize); + } + const auto dequantization = makeDequantization(parent, branch.dequantization); + return {parent, dequantization}; } +} // namespace multiply_function -std::shared_ptr MultiplyFunction::get( - const element::Type precision, - const MultiplyValues& actualValues) { - auto branch1Structure = actualValues.branch1; - branch1Structure.precisionBeforeDequantization = precision; - branch1Structure.dequantization.multiply.outPrecision = precision; - auto branch2Structure = actualValues.branch2; - branch2Structure.precisionBeforeDequantization = precision; - branch2Structure.dequantization.multiply.outPrecision = precision; +std::shared_ptr MultiplyFunction::get(const element::Type model_precision, const MultiplyValues& actualValues) { + const auto branchNodes1 = multiply_function::makeBranch(actualValues.branch1); + const auto branchNodes2 = multiply_function::makeBranch(actualValues.branch2); - const BranchNodes branchNodes1 = getBranch(actualValues.branch1); - const BranchNodes branchNodes2 = getBranch(actualValues.branch2); - - auto multiplyOriginal = opset1::Multiply( + // branchNodes1.dequantization & branchNodes2.dequantization can have different input types + std::shared_ptr parent = std::make_shared>( + std::vector{ element::f32, element::f32 }, + std::vector{ actualValues.after_dequantization.empty() ? 
model_precision : element::f32 }, ov::op::TemporaryReplaceOutputType(branchNodes1.dequantization, element::f32).get(), ov::op::TemporaryReplaceOutputType(branchNodes2.dequantization, element::f32).get()); - const std::shared_ptr multiply = std::make_shared>( - multiplyOriginal, - std::vector{element::f32, element::f32}, - std::vector{precision}); - auto& rtInfo = multiply->get_rt_info(); + auto& rtInfo = parent->get_rt_info(); rtInfo["Variant::std::string"] = "multiply"; - multiply->set_friendly_name("output"); - ngraph::ResultVector results{ std::make_shared(multiply) }; + parent = makeDequantization(parent, actualValues.after_dequantization); + parent->set_friendly_name("output"); + + ngraph::ResultVector results{ std::make_shared(parent) }; ngraph::ParameterVector inputs; if (is_type(branchNodes1.input)) { @@ -73,78 +78,6 @@ std::shared_ptr MultiplyFunction::get( return std::make_shared(results, inputs, "MultiplyTransformation"); } -std::shared_ptr MultiplyFunction::getOriginal( - const ngraph::element::Type precision, - const ngraph::PartialShape& inputShape, - const bool broadcast1, - const ngraph::builder::subgraph::FakeQuantizeOnData& fq1, - const bool broadcast2, - const ngraph::builder::subgraph::FakeQuantizeOnData& fq2, - const ngraph::builder::subgraph::FakeQuantizeOnData& fqAfter, - const bool secondInputIsConstant) { - auto inputShape1 = inputShape; - if (broadcast1) { - inputShape1[2] = 1; - inputShape1[3] = 1; - } - - ngraph::PartialShape inputShape2; - if (secondInputIsConstant) { - inputShape2 = {}; - } else { - inputShape2 = inputShape; - if (broadcast2) { - inputShape2[2] = 1; - inputShape2[3] = 1; - } - } - - const auto input1 = std::make_shared(precision, inputShape1); - const auto fakeQuantize1 = fq1.empty() ? 
- nullptr : - ngraph::builder::makeFakeQuantize( - input1, precision, fq1.quantizationLevel, fq1.constantShape, - fq1.inputLowValues, fq1.inputHighValues, fq1.outputLowValues, fq1.outputHighValues); - if (fakeQuantize1 != nullptr) { - fakeQuantize1->set_friendly_name("fakeQuantize1"); - } - - const std::shared_ptr input2 = secondInputIsConstant ? - makeConstant(element::f32, Shape{}, std::vector{0.5f}, false) : - std::make_shared(precision, inputShape2); - const auto fakeQuantize2 = fq2.empty() ? - nullptr : - ngraph::builder::makeFakeQuantize( - input2, precision, fq2.quantizationLevel, fq2.constantShape, - fq2.inputLowValues, fq2.inputHighValues, fq2.outputLowValues, fq2.outputHighValues); - if (fakeQuantize2 != nullptr) { - fakeQuantize2->set_friendly_name("fakeQuantize2"); - } - - const auto multiply = std::make_shared( - fq1.empty() ? input1 : fakeQuantize1, - fq2.empty() ? input2 : fakeQuantize2); - multiply->set_friendly_name("multiply"); - - auto const fakeQuantizeAfter = fqAfter.empty() ? - nullptr : - makeFakeQuantize(multiply, precision, fqAfter); - if (fakeQuantizeAfter != nullptr) { - fakeQuantizeAfter->set_friendly_name("fakeQuantizeAfter"); - } - - const std::shared_ptr result = fakeQuantizeAfter == nullptr ? std::dynamic_pointer_cast(multiply) : fakeQuantizeAfter; - ngraph::ResultVector results{ std::make_shared(result) }; - std::shared_ptr function = std::make_shared( - results, - secondInputIsConstant ? 
- ngraph::ParameterVector{ input1 } : - ngraph::ParameterVector{ input1, ngraph::as_type_ptr(input2) }, - "MultiplyTransformation"); - - return function; -} - } // namespace subgraph } // namespace builder } // namespace ngraph diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_partial_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_partial_function.cpp new file mode 100644 index 00000000000..e41d340a634 --- /dev/null +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/multiply_partial_function.cpp @@ -0,0 +1,154 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "lpt_ngraph_functions/multiply_partial_function.hpp" + +#include + +#include +#include +#include "ngraph_functions/subgraph_builders.hpp" +#include "low_precision/network_helper.hpp" + +#include "lpt_ngraph_functions/common/builders.hpp" +#include "lpt_ngraph_functions/common/dequantization_operations.hpp" + +using namespace ov::pass::low_precision; + +namespace ngraph { +namespace builder { +namespace subgraph { + +namespace multiply_partial_function { +struct BranchNodes { + std::shared_ptr input; + std::shared_ptr dequantization; +}; + +BranchNodes getBranch(const MultiplyPartialBranch& branch) { + const std::shared_ptr parent = branch.constant.empty() ? 
+ std::make_shared(branch.precisionBeforeDequantization, branch.inputShape) : + std::dynamic_pointer_cast(std::make_shared( + branch.constant.outPrecision, + branch.constant.shape, + branch.constant.values)); + + const auto dequantization = makeDequantization(parent, branch.dequantization); + return {parent, dequantization}; +} +} // namespace multiply_partial_function + +std::shared_ptr MultiplyPartialFunction::get( + const element::Type precision, + const MultiplyPartialValues& actualValues) { + auto branch1Structure = actualValues.branch1; + branch1Structure.precisionBeforeDequantization = precision; + branch1Structure.dequantization.multiply.outPrecision = precision; + auto branch2Structure = actualValues.branch2; + branch2Structure.precisionBeforeDequantization = precision; + branch2Structure.dequantization.multiply.outPrecision = precision; + + const auto branchNodes1 = multiply_partial_function::getBranch(actualValues.branch1); + const auto branchNodes2 = multiply_partial_function::getBranch(actualValues.branch2); + + auto multiplyOriginal = opset1::Multiply( + ov::op::TemporaryReplaceOutputType(branchNodes1.dequantization, element::f32).get(), + ov::op::TemporaryReplaceOutputType(branchNodes2.dequantization, element::f32).get()); + + const std::shared_ptr multiply = std::make_shared>( + multiplyOriginal, + std::vector{element::f32, element::f32}, + std::vector{precision}); + auto& rtInfo = multiply->get_rt_info(); + rtInfo["Variant::std::string"] = "multiply"; + multiply->set_friendly_name("output"); + + ngraph::ResultVector results{ std::make_shared(multiply) }; + + ngraph::ParameterVector inputs; + if (is_type(branchNodes1.input)) { + inputs.push_back(std::dynamic_pointer_cast(branchNodes1.input)); + } + if (is_type(branchNodes2.input)) { + inputs.push_back(std::dynamic_pointer_cast(branchNodes2.input)); + } + + return std::make_shared(results, inputs, "MultiplyTransformation"); +} + +std::shared_ptr MultiplyPartialFunction::get( + const 
ngraph::element::Type precision, + const ngraph::PartialShape& inputShape, + const bool broadcast1, + const ngraph::builder::subgraph::FakeQuantizeOnData& fq1, + const bool broadcast2, + const ngraph::builder::subgraph::FakeQuantizeOnData& fq2, + const ngraph::builder::subgraph::FakeQuantizeOnData& fqAfter, + const bool secondInputIsConstant) { + auto inputShape1 = inputShape; + if (broadcast1) { + inputShape1[2] = 1; + inputShape1[3] = 1; + } + + ngraph::PartialShape inputShape2; + if (secondInputIsConstant) { + inputShape2 = {}; + } else { + inputShape2 = inputShape; + if (broadcast2) { + inputShape2[2] = 1; + inputShape2[3] = 1; + } + } + + const auto input1 = std::make_shared(precision, inputShape1); + const auto fakeQuantize1 = fq1.empty() ? + nullptr : + ngraph::builder::makeFakeQuantize( + input1, precision, fq1.quantizationLevel, fq1.constantShape, + fq1.inputLowValues, fq1.inputHighValues, fq1.outputLowValues, fq1.outputHighValues); + if (fakeQuantize1 != nullptr) { + fakeQuantize1->set_friendly_name("fakeQuantize1"); + } + + const std::shared_ptr input2 = secondInputIsConstant ? + makeConstant(element::f32, Shape{}, std::vector{0.5f}, false) : + std::make_shared(precision, inputShape2); + const auto fakeQuantize2 = fq2.empty() ? + nullptr : + ngraph::builder::makeFakeQuantize( + input2, precision, fq2.quantizationLevel, fq2.constantShape, + fq2.inputLowValues, fq2.inputHighValues, fq2.outputLowValues, fq2.outputHighValues); + if (fakeQuantize2 != nullptr) { + fakeQuantize2->set_friendly_name("fakeQuantize2"); + } + + const auto multiply = std::make_shared( + fq1.empty() ? input1 : fakeQuantize1, + fq2.empty() ? input2 : fakeQuantize2); + multiply->set_friendly_name("multiply"); + + auto const fakeQuantizeAfter = fqAfter.empty() ? 
+ nullptr : + makeFakeQuantize(multiply, precision, fqAfter); + if (fakeQuantizeAfter != nullptr) { + fakeQuantizeAfter->set_friendly_name("fakeQuantizeAfter"); + } + + const std::shared_ptr result = fakeQuantizeAfter == nullptr ? std::dynamic_pointer_cast(multiply) : fakeQuantizeAfter; + ngraph::ResultVector results{ std::make_shared(result) }; + std::shared_ptr function = std::make_shared( + results, + secondInputIsConstant ? + ngraph::ParameterVector{ input1 } : + ngraph::ParameterVector{ input1, ngraph::as_type_ptr(input2) }, + "MultiplyTransformation"); + + return function; +} + +} // namespace subgraph +} // namespace builder +} // namespace ngraph