[LPT] [NPU] Multiply support (#19859)

* [LPT] [NPU] Multiply support

* [LPT] [NPU] Multiply support documentation

* 1) FakeQuantize support 2) refactoring

* [LPT] DisableCleanup attribute + cleanup transformations extension

* [LPT] DisableCleanup usage

* [LPT] Tests infrastructure support

* [LPT] infrastructure quick fix

* [LPT] Recurrent Cell Transformation fix

* refactoring & comment fixes
This commit is contained in:
Edward Shogulin 2023-10-03 15:31:33 +01:00 committed by GitHub
parent 2b07576e2b
commit ae3b19d034
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
47 changed files with 2570 additions and 1250 deletions

View File

@ -200,7 +200,7 @@ Transformations:
* :doc:`GatherTransformation <openvino_docs_OV_UG_lpt_GatherTransformation>`
* :doc:`MatMulTransformation <openvino_docs_OV_UG_lpt_MatMulTransformation>`
* :doc:`MaxPoolTransformation <openvino_docs_OV_UG_lpt_MaxPoolTransformation>`
* :doc:`MultiplyTransformation <openvino_docs_OV_UG_lpt_MultiplyTransformation>`
* :doc:`MultiplyPartialTransformation <openvino_docs_OV_UG_lpt_MultiplyPartialTransformation>`
* :doc:`MVNTransformation <openvino_docs_OV_UG_lpt_MVNTransformation>`
* :doc:`NormalizeL2Transformation <openvino_docs_OV_UG_lpt_NormalizeL2Transformation>`
* :doc:`PReluTransformation <openvino_docs_OV_UG_lpt_PReluTransformation>`

View File

@ -26,7 +26,7 @@
GatherTransformation <openvino_docs_OV_UG_lpt_GatherTransformation>
MatMulTransformation <openvino_docs_OV_UG_lpt_MatMulTransformation>
MaxPoolTransformation <openvino_docs_OV_UG_lpt_MaxPoolTransformation>
MultiplyTransformation <openvino_docs_OV_UG_lpt_MultiplyTransformation>
MultiplyPartialTransformation <openvino_docs_OV_UG_lpt_MultiplyPartialTransformation>
MVNTransformation <openvino_docs_OV_UG_lpt_MVNTransformation>
NormalizeL2Transformation <openvino_docs_OV_UG_lpt_NormalizeL2Transformation>
PadTransformation<openvino_docs_OV_UG_lpt_PadTransformation>
@ -45,7 +45,7 @@
TransposeTransformation <openvino_docs_OV_UG_lpt_TransposeTransformation>
UnsqueezeTransformation <openvino_docs_OV_UG_lpt_UnsqueezeTransformation>
VariadicSplitTransformation <openvino_docs_OV_UG_lpt_VariadicSplitTransformation>
Main transformations are the majority of low precision transformations. Transformations operate with dequantization operations. Main transformations include:
@ -64,7 +64,7 @@ Main transformations are the majority of low precision transformations. Transfor
* :doc:`GatherTransformation <openvino_docs_OV_UG_lpt_GatherTransformation>`
* :doc:`MatMulTransformation <openvino_docs_OV_UG_lpt_MatMulTransformation>`
* :doc:`MaxPoolTransformation <openvino_docs_OV_UG_lpt_MaxPoolTransformation>`
* :doc:`MultiplyTransformation <openvino_docs_OV_UG_lpt_MultiplyTransformation>`
* :doc:`MultiplyPartialTransformation <openvino_docs_OV_UG_lpt_MultiplyPartialTransformation>`
* :doc:`MVNTransformation <openvino_docs_OV_UG_lpt_MVNTransformation>`
* :doc:`NormalizeL2Transformation <openvino_docs_OV_UG_lpt_NormalizeL2Transformation>`
* :doc:`PadTransformation<openvino_docs_OV_UG_lpt_PadTransformation>`

View File

@ -0,0 +1,3 @@
# MultiplyPartialTransformation transformation {#openvino_docs_OV_UG_lpt_MultiplyPartialTransformation}
ov::pass::low_precision::MultiplyPartialTransformation class represents the partial `Multiply` operation transformation.

View File

@ -0,0 +1,30 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "layer_transformation.hpp"
namespace ov {
namespace pass {
namespace low_precision {
/**
 * @ingroup ie_transformation_common_api
 * @brief Base class for cleanup low precision transformation.
 *
 * Cleanup transformations derive from this class so that they share one
 * opt-out mechanism: a node can be excluded from cleanup by marking it with
 * DisableCleanupAttribute, which is checked in canBeTransformedStatic.
 */
class LP_TRANSFORMATIONS_API CleanupTransformation : public LayerTransformation {
public:
CleanupTransformation(const Params& params);
virtual ~CleanupTransformation() = default;
// Returns false when the layer must be skipped; forwards to canBeTransformedStatic.
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
// Context-free variant usable from matcher callbacks: returns false when the
// layer carries DisableCleanupAttribute in its runtime info.
static bool canBeTransformedStatic(
const std::shared_ptr<Node>& layer,
const std::vector<ov::element::Type>& defaultPrecisions = precision_set::get_int8_support());
};
} // namespace low_precision
} // namespace pass
} // namespace ov

View File

@ -74,7 +74,7 @@ public:
}
template <typename T>
static PrecisionsByPorts getPrecisionsByOperationType(std::vector<PrecisionsRestriction>& restrictions) {
static PrecisionsByPorts getPrecisionsByOperationType(const std::vector<PrecisionsRestriction>& restrictions) {
for (const auto& restriction : restrictions) {
if (restriction.operationType == T::get_type_info_static()) {
return restriction.precisionsByPorts;

View File

@ -6,7 +6,7 @@
#include <memory>
#include "low_precision/layer_transformation.hpp"
#include "low_precision/cleanup_transformation.hpp"
namespace ov {
namespace pass {
@ -20,7 +20,7 @@ namespace low_precision {
* [EliminateFakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_EliminateFakeQuantizeTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API EliminateFakeQuantizeTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API EliminateFakeQuantizeTransformation : public CleanupTransformation {
public:
OPENVINO_RTTI("EliminateFakeQuantizeTransformation", "0");
EliminateFakeQuantizeTransformation(const Params& params = Params());

View File

@ -6,7 +6,7 @@
#include <memory>
#include "low_precision/layer_transformation.hpp"
#include "low_precision/cleanup_transformation.hpp"
namespace ov {
namespace pass {
@ -20,7 +20,7 @@ namespace low_precision {
* [FoldConvertTransformation](@ref openvino_docs_OV_UG_lpt_FoldConvertTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API FoldConvertTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API FoldConvertTransformation : public CleanupTransformation {
public:
OPENVINO_RTTI("FoldConvertTransformation", "0");
FoldConvertTransformation(const Params& params = Params());

View File

@ -4,9 +4,7 @@
#pragma once
#include "low_precision/layer_transformation.hpp"
#include "low_precision/eltwise_base_transformation.hpp"
#include "low_precision/cleanup_transformation.hpp"
namespace ov {
namespace pass {
@ -20,7 +18,7 @@ namespace low_precision {
* [FuseConvertTransformation](@ref openvino_docs_OV_UG_lpt_FuseConvertTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API FuseConvertTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API FuseConvertTransformation : public CleanupTransformation {
public:
OPENVINO_RTTI("FuseConvertTransformation", "0");
FuseConvertTransformation(const Params& params = Params());

View File

@ -0,0 +1,29 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include "low_precision/cleanup_transformation.hpp"
namespace ov {
namespace pass {
namespace low_precision {
/**
 * @ingroup ie_transformation_common_api
 * @brief Base class for fuse elementwise to FakeQuantize low precision transformation.
 *
 * Shared base of FuseMultiplyToFakeQuantizeTransformation and
 * FuseSubtractToFakeQuantizeTransformation: hosts the common canBeTransformed
 * preconditions (constant second input, elementwise check, producing
 * FakeQuantize with a single consumer).
 */
class LP_TRANSFORMATIONS_API FuseElementwiseToFakeQuantizeTransformation : public CleanupTransformation {
public:
FuseElementwiseToFakeQuantizeTransformation(const Params& params);
virtual ~FuseElementwiseToFakeQuantizeTransformation() = default;
// Common precondition check for the fuse-to-FakeQuantize cleanups.
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
} // namespace low_precision
} // namespace pass
} // namespace ov

View File

@ -6,7 +6,7 @@
#include <memory>
#include "low_precision/layer_transformation.hpp"
#include "low_precision/fuse_elementwise_to_fake_quantize.hpp"
namespace ov {
namespace pass {
@ -20,12 +20,11 @@ namespace low_precision {
* [FuseMultiplyToFakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_FuseMultiplyToFakeQuantizeTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API FuseMultiplyToFakeQuantizeTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API FuseMultiplyToFakeQuantizeTransformation : public FuseElementwiseToFakeQuantizeTransformation {
public:
OPENVINO_RTTI("FuseMultiplyToFakeQuantizeTransformation", "0");
FuseMultiplyToFakeQuantizeTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ov::pass::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};

View File

@ -6,7 +6,7 @@
#include <memory>
#include "low_precision/layer_transformation.hpp"
#include "low_precision/fuse_elementwise_to_fake_quantize.hpp"
namespace ov {
namespace pass {
@ -20,12 +20,11 @@ namespace low_precision {
* [FuseSubtractToFakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_FuseSubtractToFakeQuantizeTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API FuseSubtractToFakeQuantizeTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API FuseSubtractToFakeQuantizeTransformation : public FuseElementwiseToFakeQuantizeTransformation {
public:
OPENVINO_RTTI("FuseSubtractToFakeQuantizeTransformation", "0");
FuseSubtractToFakeQuantizeTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ov::pass::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
};

View File

@ -371,7 +371,7 @@ protected:
const bool updatePrecision,
const bool moveSubtract = true) const;
void updateOutput(
bool updateOutput(
TransformationContext &context,
std::shared_ptr<ov::Node> lastNode,
std::shared_ptr<ov::Node> originalNode) const;

View File

@ -48,9 +48,9 @@ public:
const AttributeParameters& params);
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
private:
const std::vector<PrecisionsRestriction>& precisionRestrictions;
const std::vector<QuantizationGranularityRestriction>& quantizationRestrictions;
const AttributeParameters& params;
const std::vector<PrecisionsRestriction> precisionRestrictions;
const std::vector<QuantizationGranularityRestriction> quantizationRestrictions;
const AttributeParameters params;
};
class ov::pass::low_precision::TypeRelaxedReplacer : public ov::pass::GraphRewrite {
@ -71,9 +71,18 @@ public:
static bool isFunctionQuantized(const std::shared_ptr<const ov::Model>& model);
static bool isFQLevelsPresent(const std::shared_ptr<const ov::Model>& model, const std::set<size_t>& levels);
template <typename T, class... Args>
std::shared_ptr<T> add_main(Args&&... args) {
const auto tr = std::make_shared<T>(std::forward<Args>(args)...);
additional_main_passes.push_back(tr);
return tr;
}
protected:
std::vector<PrecisionsRestriction> precisionRestrictions;
std::vector<QuantizationGranularityRestriction> quantizationRestrictions;
// remove
LayerTransformation::Params params;
std::vector<std::shared_ptr<MatcherPass>> additional_main_passes;
};

View File

@ -5,7 +5,7 @@
#pragma once
#include "low_precision/eltwise_base_transformation.hpp"
#include "low_precision/weightable_layer_transformation.hpp"
namespace ov {
namespace pass {
@ -19,12 +19,14 @@ namespace low_precision {
* [MultiplyTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API MultiplyTransformation : public EltwiseBaseTransformation {
class LP_TRANSFORMATIONS_API MultiplyTransformation : public WeightableLayerTransformation {
public:
OPENVINO_RTTI("MultiplyTransformation", "0");
MultiplyTransformation(const Params& params = Params());
bool transform(TransformationContext& context, ov::pass::pattern::Matcher &m) override;
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
protected:
size_t getInputChannels(const std::shared_ptr<ov::Node> op) const override;
};
} // namespace low_precision

View File

@ -0,0 +1,32 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "low_precision/eltwise_base_transformation.hpp"
namespace ov {
namespace pass {
namespace low_precision {
/**
 * @ingroup ie_transformation_common_api
 * @brief MultiplyPartialTransformation propagates dequantization operations through Multiply operation.
 *
 * For more details about the transformation, refer to
 * [MultiplyPartialTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyPartialTransformation) page
 * in the Inference Engine Developer Guide.
 */
class LP_TRANSFORMATIONS_API MultiplyPartialTransformation : public EltwiseBaseTransformation {
public:
    OPENVINO_RTTI("MultiplyPartialTransformation", "0");
    MultiplyPartialTransformation(const Params& params = Params());
    // Consistency fix: use ov::pass::pattern::Matcher (as all sibling transformations do)
    // instead of the deprecated ngraph::pattern alias, so the override matches the base
    // signature without relying on the legacy ngraph namespace.
    bool transform(TransformationContext& context, ov::pass::pattern::Matcher& m) override;
    bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
} // namespace low_precision
} // namespace pass
} // namespace ov

View File

@ -5,7 +5,7 @@
#pragma once
#include <memory>
#include "low_precision/layer_transformation.hpp"
#include "low_precision/cleanup_transformation.hpp"
#include "common/precisions_restriction.hpp"
namespace ov {
@ -20,7 +20,7 @@ namespace low_precision {
* [MultiplyToGroupConvolutionTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyToGroupConvolutionTransformation) page
* in the Inference Engine Developer Guide.
*/
class LP_TRANSFORMATIONS_API MultiplyToGroupConvolutionTransformation : public LayerTransformation {
class LP_TRANSFORMATIONS_API MultiplyToGroupConvolutionTransformation : public CleanupTransformation {
public:
OPENVINO_RTTI("MultiplyToGroupConvolutionTransformation", "0");
MultiplyToGroupConvolutionTransformation(

View File

@ -0,0 +1,27 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "openvino/core/runtime_attribute.hpp"
#include "low_precision/lpt_visibility.hpp"
namespace ov {
/**
 * @brief Runtime attribute that excludes the node it is set on from LPT cleanup
 * transformations (checked via CleanupTransformation::canBeTransformedStatic).
 */
class LP_TRANSFORMATIONS_API DisableCleanupAttribute : public ov::RuntimeAttribute {
public:
OPENVINO_RTTI("LowPrecision::DisableCleanup", "", ov::RuntimeAttribute);
DisableCleanupAttribute() = default;
// Stores the attribute into the node's runtime info and returns the stored Any.
static ov::Any create(const std::shared_ptr<ov::Node>& node) {
auto& rt = node->get_rt_info();
return (rt[DisableCleanupAttribute::get_type_info_static()] = DisableCleanupAttribute());
}
// Non-copyable: the attribute is not propagated to cloned/replacement nodes,
// so the cleanup opt-out applies only where it was explicitly set.
bool is_copyable() const override {
return false;
}
};
} // namespace ov

View File

@ -1,17 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
#include "low_precision/rt_info/attribute_parameters.hpp"
namespace ov {
class LP_TRANSFORMATIONS_API SkipCleanupAttribute : public ov::RuntimeAttribute {
public:
OPENVINO_RTTI("LowPrecision::SkipCleanup", "", ov::RuntimeAttribute);
static ov::Any create(const std::shared_ptr<ov::Node>& node);
};
} // namespace ov

View File

@ -19,7 +19,29 @@ namespace low_precision {
*/
class LP_TRANSFORMATIONS_API WeightableLayerTransformation : public LayerTransformation {
public:
WeightableLayerTransformation(const Params& params);
struct LP_TRANSFORMATIONS_API CanBeTransformedParams {
CanBeTransformedParams(
const bool constantWeight = true,
const bool perTensorQuantizationOnData = true,
const bool limitWeightsDataPrecision = true,
const bool dynamicWeights = false) :
constantWeight(constantWeight),
perTensorQuantizationOnData(perTensorQuantizationOnData),
limitWeightsDataPrecision(limitWeightsDataPrecision),
dynamicWeights(dynamicWeights) {
}
// weights on constant path only
const bool constantWeight;
// data with per-tensor quantization only
const bool perTensorQuantizationOnData;
// limit weights by expected precisions
const bool limitWeightsDataPrecision;
const bool dynamicWeights;
};
WeightableLayerTransformation(const Params& params, const CanBeTransformedParams& canBeTransformedParams = {});
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
bool canConvolutionBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer,
const std::vector<ov::element::Type>& defaultPrecisions) const;
@ -48,6 +70,9 @@ public:
static DataPrecision getDataPrecisionOnWeights(const std::shared_ptr<Node>& node, const std::vector<ov::element::Type>& defaultPrecisions);
static bool isAsymmetricOnWeights(const std::shared_ptr<const Node>& node,
const std::vector<ov::element::Type>& defaultPrecisions = precision_set::get_int8_support());
private:
const CanBeTransformedParams canBeTransformedParams;
};
} // namespace low_precision

View File

@ -0,0 +1,26 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "low_precision/cleanup_transformation.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/disable_cleanup_attribute.hpp"
namespace ov {
namespace pass {
namespace low_precision {
CleanupTransformation::CleanupTransformation(const Params& params) : LayerTransformation(params) {
}
// Instance-level check: the context is not consulted, the decision is purely per-node.
bool CleanupTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
return canBeTransformedStatic(layer);
}
// A layer may be cleaned up only when it is not marked with DisableCleanupAttribute.
// NOTE(review): defaultPrecisions is not used in this check — confirm whether derived
// classes are expected to use it, or whether the parameter is reserved for overrides.
bool CleanupTransformation::canBeTransformedStatic(const std::shared_ptr<Node>& layer, const std::vector<ov::element::Type>& defaultPrecisions) {
return getAttribute<DisableCleanupAttribute>(layer).empty();
}
} // namespace low_precision
} // namespace pass
} // namespace ov

View File

@ -13,6 +13,7 @@
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/disable_cleanup_attribute.hpp"
#include "transformations/rt_info/disable_constant_folding.hpp"
#include "itt.hpp"
@ -333,6 +334,11 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ov::pa
ov::copy_runtime_info({ convolution, finalDequantization }, finalDequantization);
updateOutput(context, finalDequantization, convolution);
const auto onActiviation = convolution->get_input_node_shared_ptr(0);
if (ov::is_type<ov::opset1::Subtract>(onActiviation)) {
DisableCleanupAttribute::create(onActiviation);
}
auto onWeights = convolution->get_input_node_shared_ptr(1);
if (ov::is_type<ov::opset1::Reshape>(onWeights)) {
onWeights = onWeights->get_input_node_shared_ptr(0);

View File

@ -13,6 +13,7 @@
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/disable_cleanup_attribute.hpp"
#include "transformations/rt_info/disable_constant_folding.hpp"
#include "itt.hpp"
@ -220,6 +221,11 @@ bool ConvolutionBackpropDataTransformation::transform(TransformationContext &con
ov::copy_runtime_info({ convolutionBackpropData, finalDequantization }, finalDequantization);
updateOutput(context, finalDequantization, convolutionBackpropData);
const auto onActiviation = convolutionBackpropData->get_input_node_shared_ptr(0);
if (ov::is_type<ov::opset1::Subtract>(onActiviation)) {
DisableCleanupAttribute::create(onActiviation);
}
auto onWeights = convolutionBackpropData->get_input_node_shared_ptr(1);
if (ov::is_type<ov::opset1::Reshape>(onWeights)) {
onWeights = onWeights->get_input_node_shared_ptr(0);

View File

@ -15,7 +15,7 @@ namespace ov {
namespace pass {
namespace low_precision {
EliminateFakeQuantizeTransformation::EliminateFakeQuantizeTransformation(const Params& params) : LayerTransformation(params) {
EliminateFakeQuantizeTransformation::EliminateFakeQuantizeTransformation(const Params& params) : CleanupTransformation(params) {
MATCHER_SCOPE(FuseMultiplyToFakeQuantizeTransformation);
const auto matcher = pattern::wrap_type<ov::opset1::FakeQuantize>({
pattern::any_input(),
@ -112,6 +112,10 @@ bool check_intervals(const std::shared_ptr<ov::opset1::FakeQuantize>& fakeQuanti
} // namespace
bool EliminateFakeQuantizeTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
if (!CleanupTransformation::canBeTransformed(context, operation)) {
return false;
}
const auto fakeQuantize = ov::as_type_ptr<ov::opset1::FakeQuantize>(operation);
OPENVINO_ASSERT(fakeQuantize != nullptr, "unexpected operation type");

View File

@ -11,6 +11,7 @@
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/bias_attribute.hpp"
#include "low_precision/rt_info/disable_cleanup_attribute.hpp"
#include "itt.hpp"
namespace ov {
@ -167,6 +168,10 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
return nullptr;
}
if (!getAttribute<DisableCleanupAttribute>(eltwise).empty()) {
return nullptr;
}
std::shared_ptr<Node> inputLowConst_f32 = foldConvert(fakeQuantize->input_value(1), element::f32);
std::shared_ptr<Node> inputHighConst_f32 = foldConvert(fakeQuantize->input_value(2), element::f32);

View File

@ -14,7 +14,7 @@ namespace ov {
namespace pass {
namespace low_precision {
FoldConvertTransformation::FoldConvertTransformation(const Params& params) : LayerTransformation(params) {
FoldConvertTransformation::FoldConvertTransformation(const Params& params) : CleanupTransformation(params) {
MATCHER_SCOPE(FoldConvertTransformation);
auto subtract = pattern::wrap_type<ov::opset1::Subtract>();
auto matcher = std::make_shared<ov::pass::pattern::Matcher>(subtract, matcher_name);
@ -57,10 +57,11 @@ bool FoldConvertTransformation::transform(TransformationContext& context, ov::pa
bool FoldConvertTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
return
(ov::is_type<ov::opset1::Convert>(operation->get_input_node_ptr(1)) &&
CleanupTransformation::canBeTransformed(context, operation) &&
((ov::is_type<ov::opset1::Convert>(operation->get_input_node_ptr(1)) &&
ov::is_type<ov::opset1::Constant>(operation->get_input_node_ptr(1)->get_input_node_ptr(0))) ||
(ov::is_type<ov::opset1::Convert>(operation->get_input_node_ptr(0)) &&
ov::is_type<ov::opset1::Constant>(operation->get_input_node_ptr(0)->get_input_node_ptr(0)));
ov::is_type<ov::opset1::Constant>(operation->get_input_node_ptr(0)->get_input_node_ptr(0))));
}
bool FoldConvertTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {

View File

@ -12,14 +12,15 @@
#include "low_precision/common/ie_lpt_exception.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/disable_cleanup_attribute.hpp"
#include "itt.hpp"
#include "low_precision/rt_info/skip_cleanup_attribute.hpp"
namespace ov {
namespace pass {
namespace low_precision {
FuseConvertTransformation::FuseConvertTransformation(const Params& params) : LayerTransformation(params) {
FuseConvertTransformation::FuseConvertTransformation(const Params& params) : CleanupTransformation(params) {
MATCHER_SCOPE(FuseConvertTransformation);
auto multiply = pattern::wrap_type<ov::opset1::Multiply>({ pattern::wrap_type<ov::opset1::Convert>(), pattern::wrap_type<ov::opset1::Constant>() });
auto subtract = pattern::wrap_type<ov::opset1::Subtract>({ pattern::wrap_type<ov::opset1::Convert>(), pattern::wrap_type<ov::opset1::Constant>() });
@ -114,7 +115,7 @@ bool FuseConvertTransformation::transform(TransformationContext& context, ov::pa
}
bool FuseConvertTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
if (!getAttribute<SkipCleanupAttribute>(op).empty()) {
if (!CleanupTransformation::canBeTransformed(context, op)) {
return false;
}

View File

@ -0,0 +1,52 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "low_precision/fuse_elementwise_to_fake_quantize.hpp"
#include <memory>
#include "low_precision/fake_quantize.hpp"
#include "low_precision/network_helper.hpp"
namespace ov {
namespace pass {
namespace low_precision {
FuseElementwiseToFakeQuantizeTransformation::FuseElementwiseToFakeQuantizeTransformation(const Params& params) : CleanupTransformation(params) {
}
// Common preconditions for fusing an eltwise (Multiply/Subtract) into the FakeQuantize
// that produces its first input:
//   1. the node is not excluded from cleanup (CleanupTransformation check);
//   2. the second input is a Constant;
//   3. the eltwise passes FakeQuantizeTransformation::checkElementwise;
//   4. input 0 is produced by a FakeQuantize (possibly through a Convert)
//      whose output has exactly one consumer.
bool FuseElementwiseToFakeQuantizeTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
    if (!CleanupTransformation::canBeTransformed(context, operation) ||
        !ov::is_type<ov::opset1::Constant>(operation->get_input_node_shared_ptr(1)) ||
        !FakeQuantizeTransformation::checkElementwise(operation)) {
        return false;
    }
    // Look through an optional Convert to reach the producing FakeQuantize.
    auto producer = operation->get_input_node_shared_ptr(0);
    if (const auto conversion = ov::as_type_ptr<ov::opset1::Convert>(producer)) {
        producer = conversion->get_input_node_shared_ptr(0);
    }
    const auto fakeQuantize = ov::as_type_ptr<ov::opset1::FakeQuantize>(producer);
    return (fakeQuantize != nullptr) && (fakeQuantize->get_output_target_inputs(0).size() == 1);
}
} // namespace low_precision
} // namespace pass
} // namespace ov

View File

@ -9,13 +9,14 @@
#include "low_precision/fake_quantize.hpp"
#include "low_precision/network_helper.hpp"
#include "itt.hpp"
#include "low_precision/rt_info/skip_cleanup_attribute.hpp"
#include "low_precision/rt_info/disable_cleanup_attribute.hpp"
namespace ov {
namespace pass {
namespace low_precision {
FuseMultiplyToFakeQuantizeTransformation::FuseMultiplyToFakeQuantizeTransformation(const Params& params) : LayerTransformation(params) {
FuseMultiplyToFakeQuantizeTransformation::FuseMultiplyToFakeQuantizeTransformation(const Params& params)
: FuseElementwiseToFakeQuantizeTransformation(params) {
MATCHER_SCOPE(FuseMultiplyToFakeQuantizeTransformation);
auto matcher = pattern::wrap_type<ov::opset1::Multiply>();
@ -89,38 +90,6 @@ bool FuseMultiplyToFakeQuantizeTransformation::transform(TransformationContext&
return true;
}
bool FuseMultiplyToFakeQuantizeTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
if (!ov::is_type<ov::opset1::Constant>(operation->get_input_node_shared_ptr(1))) {
return false;
}
if (!FakeQuantizeTransformation::checkElementwise(operation)) {
return false;
}
if (!getAttribute<SkipCleanupAttribute>(operation).empty()) {
return false;
}
const auto parent = operation->get_input_node_shared_ptr(0);
auto fq = ov::as_type_ptr<ov::opset1::FakeQuantize>(parent);
const auto convert = ov::as_type_ptr<ov::opset1::Convert>(parent);
if (convert) {
fq = ov::as_type_ptr<ov::opset1::FakeQuantize>(convert->get_input_node_shared_ptr(0));
}
if (!fq) {
return false;
}
if (fq->get_output_target_inputs(0).size() != 1) {
return false;
}
return true;
}
bool FuseMultiplyToFakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
return false;
}

View File

@ -9,13 +9,14 @@
#include "low_precision/fake_quantize.hpp"
#include "low_precision/network_helper.hpp"
#include "itt.hpp"
#include "low_precision/rt_info/skip_cleanup_attribute.hpp"
#include "low_precision/rt_info/disable_cleanup_attribute.hpp"
namespace ov {
namespace pass {
namespace low_precision {
FuseSubtractToFakeQuantizeTransformation::FuseSubtractToFakeQuantizeTransformation(const Params& params) : LayerTransformation(params) {
FuseSubtractToFakeQuantizeTransformation::FuseSubtractToFakeQuantizeTransformation(const Params& params)
: FuseElementwiseToFakeQuantizeTransformation(params) {
MATCHER_SCOPE(FuseSubtractToFakeQuantizeTransformation);
auto matcher = pattern::wrap_type<ov::opset1::Subtract>();
@ -84,49 +85,6 @@ bool FuseSubtractToFakeQuantizeTransformation::transform(TransformationContext&
return true;
}
bool FuseSubtractToFakeQuantizeTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
if (!ov::is_type<ov::opset1::Constant>(operation->get_input_node_shared_ptr(1))) {
return false;
}
if (!FakeQuantizeTransformation::checkElementwise(operation)) {
return false;
}
if (!getAttribute<SkipCleanupAttribute>(operation).empty()) {
return false;
}
const auto children = operation->get_output_target_inputs(0);
for (const auto& target : children) {
const auto convolution = ov::is_type<ov::opset1::Convolution>(target.get_node());
const auto groupConvolution = ov::is_type<ov::opset1::GroupConvolution>(target.get_node());
const auto convolutionBackpropData = ov::is_type<ov::opset1::ConvolutionBackpropData>(target.get_node());
if (convolution || groupConvolution || convolutionBackpropData) {
return false;
}
}
const auto parent = operation->get_input_node_shared_ptr(0);
auto fq = ov::as_type_ptr<ov::opset1::FakeQuantize>(parent);
const auto convert = ov::as_type_ptr<ov::opset1::Convert>(parent);
if (convert) {
fq = ov::as_type_ptr<ov::opset1::FakeQuantize>(convert->get_input_node_shared_ptr(0));
}
if (!fq) {
return false;
}
if (fq->get_output_target_inputs(0).size() != 1) {
return false;
}
return true;
}
bool FuseSubtractToFakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
return false;
}

View File

@ -422,21 +422,23 @@ std::shared_ptr<ov::Node> LayerTransformation::moveDequantizationBefore(
return result.newOperation;
}
void LayerTransformation::updateOutput(
bool LayerTransformation::updateOutput(
TransformationContext &context,
std::shared_ptr<ov::Node> lastNode,
std::shared_ptr<ov::Node> originalNode) const {
// TODO: not tested!!!
bool was_updated = false;
for (auto output : lastNode->outputs()) {
for (auto input : output.get_target_inputs()) {
if (ov::is_type<ov::opset1::Result>(input.get_node())) {
const std::string originalName = originalNode->get_friendly_name();
originalNode->set_friendly_name(originalName + LayerTransformation::originalLayerPostfix);
lastNode->set_friendly_name(originalName);
was_updated = true;
break;
}
}
}
return was_updated;
}
void LayerTransformation::updateOutput(

View File

@ -53,7 +53,7 @@
#include "low_precision/interpolate.hpp"
#include "low_precision/mat_mul.hpp"
#include "low_precision/max_pool.hpp"
#include "low_precision/multiply.hpp"
#include "low_precision/multiply_partial.hpp"
#include "low_precision/mvn.hpp"
#include "low_precision/normalize_l2.hpp"
#include "low_precision/pad.hpp"
@ -251,7 +251,7 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptr<o
ADD_MATCHER(common, GroupConvolutionTransformation, params)
ADD_MATCHER(common, MatMulTransformation, params)
ADD_MATCHER(common, MaxPoolTransformation, params)
ADD_MATCHER(common, MultiplyTransformation, params)
ADD_MATCHER(common, MultiplyPartialTransformation, params)
ADD_MATCHER(common, MVNTransformation, params)
ADD_MATCHER(common, NormalizeL2Transformation, params)
ADD_MATCHER(common, PadTransformation, params)
@ -273,6 +273,10 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptr<o
ADD_MATCHER(common, UnsqueezeTransformation, params)
ADD_MATCHER(common, VariadicSplitTransformation, params)
for (const auto& tr : additional_main_passes) {
common->add_matcher(tr);
}
std::shared_ptr<ov::pass::GraphRewrite> cleanup = manager.register_pass<ov::pass::GraphRewrite>();
ADD_MATCHER(cleanup, EliminateFakeQuantizeTransformation, params)
ADD_MATCHER(cleanup, FoldConvertTransformation, params)

View File

@ -1,4 +1,4 @@
// Copyright (C) 2018-2023 Intel Corporation
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -15,6 +15,7 @@
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "low_precision/common/ie_lpt_exception.hpp"
#include "low_precision/rt_info/disable_cleanup_attribute.hpp"
#include "low_precision/network_helper.hpp"
#include "itt.hpp"
@ -22,7 +23,8 @@ namespace ov {
namespace pass {
namespace low_precision {
MultiplyTransformation::MultiplyTransformation(const Params& params) : EltwiseBaseTransformation(params) {
MultiplyTransformation::MultiplyTransformation(const Params& params) :
WeightableLayerTransformation(params, CanBeTransformedParams(false, false, false, true)) {
MATCHER_SCOPE(MultiplyTransformation);
auto matcher = pattern::wrap_type<ov::opset1::Multiply>();
@ -38,135 +40,107 @@ MultiplyTransformation::MultiplyTransformation(const Params& params) : EltwiseBa
this->register_matcher(m, callback);
}
bool MultiplyTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher &m) {
bool MultiplyTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher& m) {
auto multiply = m.get_match_root();
if (!canBeTransformed(context, multiply)) {
return false;
}
multiply = NetworkHelper::separateInStandaloneBranch(multiply, defaultPrecisions);
decomposeFakeQuantizeForWeightsPath(multiply);
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 0));
NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 1));
multiply = NetworkHelper::separateInStandaloneBranch(multiply, defaultPrecisions);
auto newMultiply = multiply;
const auto dequantization1 = NetworkHelper::getDequantization(multiply, defaultPrecisions, 0);
const auto dequantization2 = NetworkHelper::getDequantization(multiply, defaultPrecisions, 1);
auto fold_fake_quantizes = [](std::shared_ptr<Node>& multiply, const size_t index) {
auto fakeQuantizeOnWeights = ov::as_type_ptr<ov::opset1::FakeQuantize>(multiply->get_input_node_shared_ptr(index));
if (fakeQuantizeOnWeights != nullptr) {
auto result = NetworkHelper::fold_fake_quantize(fakeQuantizeOnWeights);
if (ov::is_type<ov::opset1::Constant>(result)) {
replace_node(fakeQuantizeOnWeights, result);
}
if ((dequantization1.multiplyConstant == nullptr) && (dequantization2.multiplyConstant == nullptr)) {
return false;
}
// before: y = (deq_scales1 * (x1 - zero_point1)) * (deq_scales2 * (x2 - zero_point2))
// after : y = deq_scales1 * deq_scales2 * (x1 - zero_point1) * (x2 - zero_point2)
auto new_scales_values = fold<ov::opset1::Multiply>(
dequantization1.empty() ? dequantization1.data : dequantization1.multiplyConstant,
dequantization2.empty() ? dequantization2.data : dequantization2.multiplyConstant);
if (!ov::is_type<ov::opset1::Constant>(new_scales_values)) {
return false;
}
const auto init_input = [&new_scales_values](const FakeQuantizeDequantization& dequantization) -> Output<Node> {
if (dequantization.empty()) {
return new_scales_values;
}
if (dequantization.subtract == nullptr) {
return dequantization.data;
}
const auto subtract = NetworkHelper::optimizeSubtract(dequantization.subtract);
if (subtract != nullptr) {
DisableCleanupAttribute::create(subtract);
}
return subtract == nullptr ? dequantization.data : subtract;
};
fold_fake_quantizes(multiply, 0ul);
fold_fake_quantizes(multiply, 1ul);
if ((dequantization1.empty() && (ov::is_type<ov::opset1::Constant>(dequantization1.data.get_node()))) ||
(dequantization2.empty() && (ov::is_type<ov::opset1::Constant>(dequantization2.data.get_node())))) {
// one input is constant
const Output<Node> in1 = init_input(dequantization1);
const Output<Node> in2 = init_input(dequantization2);
const int fullPathIndex = getNotEmpty(multiply);
if (fullPathIndex == -1) {
const auto multiplyBranch = getMultiplyConstBranch(multiply);
if (multiplyBranch.first != -1) {
NetworkHelper::foldDequantization(multiply, multiplyBranch.first == 0 ? 1 : 0, defaultPrecisions);
}
const auto new_multiply = (in1.get_element_type() == multiply->get_output_element_type(0)) &&
(in2.get_element_type() == multiply->get_output_element_type(0)) ?
std::make_shared<ov::opset1::Multiply>(in1, in2) :
std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
std::vector<ov::element::Type>{ deqPrecision, deqPrecision },
std::vector<ov::element::Type>{ multiply->get_output_element_type(0) },
ov::op::TemporaryReplaceOutputType(in1, deqPrecision).get(),
ov::op::TemporaryReplaceOutputType(in2, deqPrecision).get());
if (multiplyBranch.first == -1 || multiplyBranch.second == -1) {
// constant folding on dequantization ops (for example: Convert on Subtract)
NetworkHelper::foldDequantization(multiply, 0, defaultPrecisions);
NetworkHelper::foldDequantization(multiply, 1, defaultPrecisions);
return false;
}
replace_node(multiply, new_multiply);
updateOutput(context, new_multiply, multiply);
auto multiplyParent = multiply->input_value(multiplyBranch.first);
auto constParent = multiply->input_value(multiplyBranch.first == 0 ? 1 : 0);
auto multiplyParentParent = multiplyParent.get_node_shared_ptr()->input_value(multiplyBranch.second);
auto multiplyParentConst = multiplyParent.get_node_shared_ptr()->input_value(multiplyBranch.second == 0 ? 1 : 0);
newMultiply = std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
std::vector<ov::element::Type>{ element::f32, element::f32 },
std::vector<ov::element::Type>{ multiply->get_output_element_type(0) },
ov::op::TemporaryReplaceOutputType(multiplyParentParent, element::f32).get(),
ov::op::TemporaryReplaceOutputType(
fold<ov::opset1::Multiply>(
foldConvert(multiplyParentConst, element::f32),
foldConvert(constParent, element::f32)),
element::f32).get());
NetworkHelper::copyInfo(multiplyParent.get_node_shared_ptr(), newMultiply);
NetworkHelper::copyInfo(multiply, newMultiply);
} else {
const int emptyPathIndex = fullPathIndex == 0 ? 1 : 0;
if (updatePrecisions) {
const FakeQuantizeDequantization dequantizationEmptyPath = NetworkHelper::getDequantization(multiply, defaultPrecisions, emptyPathIndex);
if (!dequantizationEmptyPath.empty() && !dequantizationEmptyPath.isLowPrecision()) {
return false;
}
}
FakeQuantizeDequantization dequantizationEmptyPath = NetworkHelper::foldDequantization(multiply, emptyPathIndex, defaultPrecisions);
std::shared_ptr<Node> subtractValuesEmptyPath;
std::shared_ptr<Node> multiplyValuesEmptyPath;
std::tie(subtractValuesEmptyPath, multiplyValuesEmptyPath) = NetworkHelper::createEmptyValues(dequantizationEmptyPath, deqPrecision);
// check if empty path shifts are not zero
if (!NetworkHelper::isZeroConst(subtractValuesEmptyPath)) {
return false;
}
FakeQuantizeDequantization dequantizationFullPath = NetworkHelper::foldDequantization(multiply, fullPathIndex, defaultPrecisions);
std::shared_ptr<Node> subtractValuesFullPath;
std::shared_ptr<Node> multiplyValuesFullPath;
std::tie(subtractValuesFullPath, multiplyValuesFullPath) = NetworkHelper::createEmptyValues(dequantizationFullPath, deqPrecision);
// before: Y = (SC1 * (X1 - SH1)) * (SC2 * X2)
// after : Y = (SC1' * (X1 - SH1)) * (X2) , where :
// SC1' = SC1 * SC2
auto newMultiplyValuesFullPath = fold<ov::opset1::Multiply>(multiplyValuesEmptyPath, multiplyValuesFullPath);
OutputVector inputs{ {}, {} };
inputs[emptyPathIndex] = dequantizationEmptyPath.data;
inputs[fullPathIndex] = std::make_shared<ov::opset1::Multiply>(
dequantizationFullPath.subtract == nullptr ?
(dequantizationFullPath.convert == nullptr ?
dequantizationFullPath.data : dequantizationFullPath.convert) :
dequantizationFullPath.subtract,
newMultiplyValuesFullPath);
newMultiply = std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
std::vector<element::Type>{element::f32, element::f32},
std::vector<element::Type>{ multiply->get_output_element_type(0) },
ov::op::TemporaryReplaceOutputType(inputs[0], element::f32).get(),
ov::op::TemporaryReplaceOutputType(inputs[1], element::f32).get());
NetworkHelper::copyInfo(multiply, newMultiply);
return true;
}
replace_node(multiply, newMultiply);
updateOutput(context, newMultiply, multiply);
Output<Node> in1 = init_input(dequantization1);
Output<Node> in2 = init_input(dequantization2);
if (fullPathIndex != -1) {
NetworkHelper::foldDequantization(newMultiply, fullPathIndex, defaultPrecisions);
}
// in1 & in2 can have different input types
const auto new_multiply = (in1.get_element_type() == deqPrecision) &&
(in2.get_element_type() == deqPrecision) ?
std::make_shared<ov::opset1::Multiply>(in1, in2) :
std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
std::vector<ov::element::Type>{ deqPrecision, deqPrecision },
std::vector<ov::element::Type>{ deqPrecision },
ov::op::TemporaryReplaceOutputType(in1, deqPrecision).get(),
ov::op::TemporaryReplaceOutputType(in2, deqPrecision).get());
DisableCleanupAttribute::create(new_multiply);
auto new_scales = (new_multiply->get_output_element_type(0) == multiply->get_output_element_type(0)) &&
(new_scales_values->get_output_element_type(0) == multiply->get_output_element_type(0)) ?
std::make_shared<ov::opset1::Multiply>(new_multiply, new_scales_values) :
std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
ov::opset1::Multiply(new_multiply, new_scales_values),
multiply->get_output_element_type(0));
replace_node(multiply, new_scales);
const auto was_updated = updateOutput(context, new_scales, multiply);
NetworkHelper::copyInfo(multiply, new_multiply, !was_updated);
return true;
}
bool MultiplyTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
FakeQuantizeDequantization dequantization1 = pass::low_precision::NetworkHelper::getDequantization(layer, defaultPrecisions, 0ul);
FakeQuantizeDequantization dequantization2 = pass::low_precision::NetworkHelper::getDequantization(layer, defaultPrecisions, 1ul);
if (dequantization1.data.get_node() == nullptr || dequantization2.data.get_node() == nullptr) {
return false;
}
const bool nonConstantData = !ov::is_type<ov::opset1::Constant>(dequantization1.data.get_node_shared_ptr()) &&
!ov::is_type<ov::opset1::Constant>(dequantization2.data.get_node_shared_ptr());
if (((dequantization1.empty() || dequantization2.empty()) && nonConstantData)) {
return false;
}
return EltwiseBaseTransformation::canBeTransformed(context, layer);
size_t MultiplyTransformation::getInputChannels(const std::shared_ptr<ov::Node> op) const {
    // Channel count is read from dimension 1 of the second input.
    // The dimension is expected to be static here (checked by the assert);
    // callers must ensure canBeTransformed filtered out dynamic shapes.
    const auto channel_dimension = op->get_input_partial_shape(1)[1];
    assert(channel_dimension.is_static());
    return channel_dimension.get_length();
}
} // namespace low_precision

View File

@ -0,0 +1,174 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "low_precision/multiply_partial.hpp"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include <cassert>
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "low_precision/common/ie_lpt_exception.hpp"
#include "low_precision/network_helper.hpp"
#include "itt.hpp"
namespace ov {
namespace pass {
namespace low_precision {
MultiplyPartialTransformation::MultiplyPartialTransformation(const Params& params) : EltwiseBaseTransformation(params) {
    MATCHER_SCOPE(MultiplyPartialTransformation);
    // Match any opset1::Multiply; detailed eligibility is decided in transform()
    // via canBeTransformed().
    auto multiply_pattern = pattern::wrap_type<ov::opset1::Multiply>();

    ov::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
        const auto op = m.get_match_root();
        // The user-supplied transformation_callback can veto this node.
        return transformation_callback(op) ? false : transform(*context, m);
    };

    this->register_matcher(std::make_shared<ov::pass::pattern::Matcher>(multiply_pattern, matcher_name), callback);
}
// Transforms an opset1::Multiply with dequantization operations on its inputs.
// Two cases, selected by getNotEmpty():
//  * fullPathIndex == -1: no "full" dequantization branch was found; try the
//    (x * c1) * c2 pattern via getMultiplyConstBranch() and fuse the two
//    constants into a single TypeRelaxed Multiply: x * (c1 * c2);
//  * otherwise: move the empty path's multiply scale onto the full path:
//      before: Y = (SC1 * (X1 - SH1)) * (SC2 * X2)
//      after : Y = ((SC1 * SC2) * (X1 - SH1)) * X2
// Returns true iff the graph was modified.
bool MultiplyPartialTransformation::transform(TransformationContext& context, ov::pass::pattern::Matcher& m) {
    auto multiply = m.get_match_root();
    if (!canBeTransformed(context, multiply)) {
        return false;
    }

    // Normalize the dequantization subgraphs on both inputs, then detach the
    // matched Multiply into a standalone branch so it can be rewritten without
    // affecting other consumers of its parents.
    NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 0));
    NetworkHelper::normalizeDequantization(NetworkHelper::getDequantization(multiply, defaultPrecisions, 1));

    multiply = NetworkHelper::separateInStandaloneBranch(multiply, defaultPrecisions);
    auto newMultiply = multiply;

    // If an input is a FakeQuantize that constant-folds to a Constant, replace
    // it in the graph so the logic below sees a plain constant branch.
    auto fold_fake_quantizes = [](std::shared_ptr<Node>& multiply, const size_t index) {
        auto fakeQuantizeOnWeights = ov::as_type_ptr<ov::opset1::FakeQuantize>(multiply->get_input_node_shared_ptr(index));
        if (fakeQuantizeOnWeights != nullptr) {
            auto result = NetworkHelper::fold_fake_quantize(fakeQuantizeOnWeights);
            if (ov::is_type<ov::opset1::Constant>(result)) {
                replace_node(fakeQuantizeOnWeights, result);
            }
        }
    };

    fold_fake_quantizes(multiply, 0ul);
    fold_fake_quantizes(multiply, 1ul);

    const int fullPathIndex = getNotEmpty(multiply);
    if (fullPathIndex == -1) {
        // No full dequantization branch: look for a (Multiply, Constant) pair.
        const auto multiplyBranch = getMultiplyConstBranch(multiply);
        if (multiplyBranch.first != -1) {
            NetworkHelper::foldDequantization(multiply, multiplyBranch.first == 0 ? 1 : 0, defaultPrecisions);
        }

        if (multiplyBranch.first == -1 || multiplyBranch.second == -1) {
            // constant folding on dequantization ops (for example: Convert on Subtract)
            NetworkHelper::foldDequantization(multiply, 0, defaultPrecisions);
            NetworkHelper::foldDequantization(multiply, 1, defaultPrecisions);
            return false;
        }

        auto multiplyParent = multiply->input_value(multiplyBranch.first);
        auto constParent = multiply->input_value(multiplyBranch.first == 0 ? 1 : 0);
        auto multiplyParentParent = multiplyParent.get_node_shared_ptr()->input_value(multiplyBranch.second);
        auto multiplyParentConst = multiplyParent.get_node_shared_ptr()->input_value(multiplyBranch.second == 0 ? 1 : 0);

        // (x * c1) * c2  ==>  x * (c1 * c2); compute in f32, keep the original
        // output element type via TypeRelaxed.
        newMultiply = std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
            std::vector<ov::element::Type>{ element::f32, element::f32 },
            std::vector<ov::element::Type>{ multiply->get_output_element_type(0) },
            ov::op::TemporaryReplaceOutputType(multiplyParentParent, element::f32).get(),
            ov::op::TemporaryReplaceOutputType(
                fold<ov::opset1::Multiply>(
                    foldConvert(multiplyParentConst, element::f32),
                    foldConvert(constParent, element::f32)),
                element::f32).get());

        NetworkHelper::copyInfo(multiplyParent.get_node_shared_ptr(), newMultiply);
        NetworkHelper::copyInfo(multiply, newMultiply);
    } else {
        const int emptyPathIndex = fullPathIndex == 0 ? 1 : 0;

        if (updatePrecisions) {
            // When precisions are being updated, a non-empty "empty path" must
            // already be in low precision (see isLowPrecision()).
            const FakeQuantizeDequantization dequantizationEmptyPath = NetworkHelper::getDequantization(multiply, defaultPrecisions, emptyPathIndex);
            if (!dequantizationEmptyPath.empty() && !dequantizationEmptyPath.isLowPrecision()) {
                return false;
            }
        }

        FakeQuantizeDequantization dequantizationEmptyPath = NetworkHelper::foldDequantization(multiply, emptyPathIndex, defaultPrecisions);
        std::shared_ptr<Node> subtractValuesEmptyPath;
        std::shared_ptr<Node> multiplyValuesEmptyPath;
        std::tie(subtractValuesEmptyPath, multiplyValuesEmptyPath) = NetworkHelper::createEmptyValues(dequantizationEmptyPath, deqPrecision);

        // check if empty path shifts are not zero
        if (!NetworkHelper::isZeroConst(subtractValuesEmptyPath)) {
            return false;
        }

        FakeQuantizeDequantization dequantizationFullPath = NetworkHelper::foldDequantization(multiply, fullPathIndex, defaultPrecisions);
        std::shared_ptr<Node> subtractValuesFullPath;
        std::shared_ptr<Node> multiplyValuesFullPath;
        std::tie(subtractValuesFullPath, multiplyValuesFullPath) = NetworkHelper::createEmptyValues(dequantizationFullPath, deqPrecision);

        // before: Y = (SC1 * (X1 - SH1)) * (SC2 * X2)
        // after : Y = (SC1' * (X1 - SH1)) * (X2) , where :
        //         SC1' = SC1 * SC2
        auto newMultiplyValuesFullPath = fold<ov::opset1::Multiply>(multiplyValuesEmptyPath, multiplyValuesFullPath);
        OutputVector inputs{ {}, {} };
        inputs[emptyPathIndex] = dequantizationEmptyPath.data;
        // Re-attach the fused scale after Subtract (or Convert, or raw data) on
        // the full path, whichever is present.
        inputs[fullPathIndex] = std::make_shared<ov::opset1::Multiply>(
            dequantizationFullPath.subtract == nullptr ?
                (dequantizationFullPath.convert == nullptr ?
                    dequantizationFullPath.data : dequantizationFullPath.convert) :
                dequantizationFullPath.subtract,
            newMultiplyValuesFullPath);

        newMultiply = std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
            std::vector<element::Type>{element::f32, element::f32},
            std::vector<element::Type>{ multiply->get_output_element_type(0) },
            ov::op::TemporaryReplaceOutputType(inputs[0], element::f32).get(),
            ov::op::TemporaryReplaceOutputType(inputs[1], element::f32).get());
        NetworkHelper::copyInfo(multiply, newMultiply);
    }

    replace_node(multiply, newMultiply);
    updateOutput(context, newMultiply, multiply);

    if (fullPathIndex != -1) {
        // Constant-fold what remains of the dequantization on the rewritten full path.
        NetworkHelper::foldDequantization(newMultiply, fullPathIndex, defaultPrecisions);
    }
    return true;
}
bool MultiplyPartialTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const {
FakeQuantizeDequantization dequantization1 = pass::low_precision::NetworkHelper::getDequantization(layer, defaultPrecisions, 0ul);
FakeQuantizeDequantization dequantization2 = pass::low_precision::NetworkHelper::getDequantization(layer, defaultPrecisions, 1ul);
if (dequantization1.data.get_node() == nullptr || dequantization2.data.get_node() == nullptr) {
return false;
}
const bool nonConstantData = !ov::is_type<ov::opset1::Constant>(dequantization1.data.get_node_shared_ptr()) &&
!ov::is_type<ov::opset1::Constant>(dequantization2.data.get_node_shared_ptr());
if (((dequantization1.empty() || dequantization2.empty()) && nonConstantData)) {
return false;
}
return EltwiseBaseTransformation::canBeTransformed(context, layer);
}
} // namespace low_precision
} // namespace pass
} // namespace ov

View File

@ -15,7 +15,7 @@ namespace low_precision {
MultiplyToGroupConvolutionTransformation::MultiplyToGroupConvolutionTransformation(
const Params& params,
const PrecisionsRestriction::PrecisionsByPorts& restrictions) : LayerTransformation(params), restrictions(restrictions), groupSize(1ul) {
const PrecisionsRestriction::PrecisionsByPorts& restrictions) : CleanupTransformation(params), restrictions(restrictions), groupSize(1ul) {
MATCHER_SCOPE(MultiplyToGroupConvolutionTransformation);
auto matcher = pattern::wrap_type<ov::opset1::Multiply>();
@ -143,6 +143,10 @@ bool MultiplyToGroupConvolutionTransformation::transform(TransformationContext&
}
bool MultiplyToGroupConvolutionTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
if (!CleanupTransformation::canBeTransformed(context, operation)) {
return false;
}
const PartialShape outPShape = operation->get_output_partial_shape(0);
const auto rank = outPShape.rank();
if (rank.is_dynamic()) {

View File

@ -14,7 +14,7 @@
#include "openvino/pass/pattern/op/or.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/skip_cleanup_attribute.hpp"
#include "low_precision/rt_info/disable_cleanup_attribute.hpp"
namespace ov {
namespace pass {
@ -96,6 +96,7 @@ bool RecurrentCellTransformation::transform(TransformationContext& context, ov::
if (!canBeTransformed(context, lstm)) {
return false;
}
for (size_t parentIndex = 0ul; parentIndex < lstm->get_input_size(); parentIndex++) {
auto lstm_parent = lstm->get_input_node_shared_ptr(parentIndex);
if (is_type<ov::opset1::FakeQuantize>(lstm_parent)) {
@ -108,7 +109,7 @@ bool RecurrentCellTransformation::transform(TransformationContext& context, ov::
? defaultPrecisions
: precisionsAttribute.as<PrecisionsAttribute>().value();
const DataPrecision dataPrecision = getDataPrecision(lstm_parent, quantizationDetails, precisions);
if (dataPrecision.empty()) {
if (dataPrecision.empty() || dataPrecision.hasZeroPoint) {
return false;
}
@ -148,6 +149,7 @@ bool RecurrentCellTransformation::transform(TransformationContext& context, ov::
continue;
}
}
return true;
}
@ -172,12 +174,12 @@ bool RecurrentCellTransformation::isPrecisionPreserved(std::shared_ptr<Node>) co
}
void RecurrentCellTransformation::propagateSkipCleanupAttribute(std::shared_ptr<Node> multiply) {
SkipCleanupAttribute::create(multiply);
DisableCleanupAttribute::create(multiply);
auto multiply_parent = multiply->get_input_node_shared_ptr(0);
SkipCleanupAttribute::create(multiply_parent);
DisableCleanupAttribute::create(multiply_parent);
if (is_type<ov::opset1::Subtract>(multiply_parent)) {
auto subtract_parent = multiply_parent->get_input_node_shared_ptr(0);
SkipCleanupAttribute::create(subtract_parent);
DisableCleanupAttribute::create(subtract_parent);
}
}

View File

@ -1,20 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "low_precision/rt_info/skip_cleanup_attribute.hpp"
#include <memory>
#include <string>
#include <unordered_map>
#include <iterator>
#include <vector>
using namespace ov;
using namespace ov;
ov::Any SkipCleanupAttribute::create(
const std::shared_ptr<ov::Node>& node) {
auto& rt = node->get_rt_info();
return (rt[SkipCleanupAttribute::get_type_info_static()] = SkipCleanupAttribute());
}

View File

@ -15,6 +15,7 @@ namespace pass {
namespace low_precision {
namespace {
// used in the isQuantizedStatic static method, so it cannot be a virtual method
std::vector<size_t> getWeightsDequantizationIdces(const std::shared_ptr<const Node> weightableLayer) {
if (ov::is_type<ov::opset1::Convolution>(weightableLayer)) {
return std::vector<size_t>{0};
@ -22,7 +23,9 @@ std::vector<size_t> getWeightsDequantizationIdces(const std::shared_ptr<const No
return std::vector<size_t>{1};
} else if (ov::is_type<ov::opset1::GroupConvolution>(weightableLayer)) {
return ov::is_type<ov::opset1::Reshape>(weightableLayer->get_input_node_shared_ptr(1)) ? std::vector<size_t>{0}
: std::vector<size_t>{0, 1};
: std::vector<size_t>{0, 1};
} else if (ov::is_type<ov::opset1::Multiply>(weightableLayer)) {
return std::vector<size_t>{};
} else {
THROW_IE_LPT_EXCEPTION(*weightableLayer) << "getWeightsDequantizationIdces is called for unexpected layer";
}
@ -41,7 +44,10 @@ bool checkConstShape(const std::vector<size_t>& idcesToCheck, const std::shared_
}
} // namespace
WeightableLayerTransformation::WeightableLayerTransformation(const Params& params) : LayerTransformation(params) {}
WeightableLayerTransformation::WeightableLayerTransformation(const Params& params, const CanBeTransformedParams& canBeTransformedParams) :
LayerTransformation(params),
canBeTransformedParams(canBeTransformedParams) {
}
bool WeightableLayerTransformation::canConvolutionBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer,
const std::vector<ov::element::Type>& defaultPrecisions) const {
@ -88,7 +94,7 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext
}
// dynamic activations rank and dynamic weights aren't supported
if (layer->get_input_partial_shape(0).rank().is_dynamic() || layer->get_input_partial_shape(1).is_dynamic()) {
if (!canBeTransformedParams.dynamicWeights && (layer->get_input_partial_shape(0).rank().is_dynamic() || layer->get_input_partial_shape(1).is_dynamic())) {
return false;
}
@ -138,14 +144,16 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext
return false;
}
// exactly cast vector as original code has a conversion;
// optimize cast:
// two branches depending on real type of the constant?
const auto scalesBuffer = dequantization.multiplyConstant->cast_vector<float>();
size_t scalesBufferSize = shape_size(dequantization.multiplyConstant->get_shape());
for (size_t i = 1ul; i < scalesBufferSize; ++i) {
if (scalesBuffer[i - 1] != scalesBuffer[i]) {
return false;
if (canBeTransformedParams.perTensorQuantizationOnData) {
// exactly cast vector as original code has a conversion;
// optimize cast:
// two branches depending on real type of the constant?
const auto scalesBuffer = dequantization.multiplyConstant->cast_vector<float>();
size_t scalesBufferSize = shape_size(dequantization.multiplyConstant->get_shape());
for (size_t i = 1ul; i < scalesBufferSize; ++i) {
if (scalesBuffer[i - 1] != scalesBuffer[i]) {
return false;
}
}
}
}
@ -213,8 +221,11 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext
dqVolume *= constChannels;
}
}
if (shape_size(constShape) != 1 && shape_size(constShape) != dqVolume) {
return false;
if (!dqIdces.empty()) {
if (shape_size(constShape) != 1 && shape_size(constShape) != dqVolume) {
return false;
}
}
} else {
// TODO: LPT: is it possible to share with isQuantized?
@ -225,13 +236,16 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext
return false;
}
const auto weightsData = ov::as_type_ptr<ov::opset1::Constant>(dequantizationOnWeights.data.get_node_shared_ptr());
if (weightsData == nullptr) {
return false;
const auto weightsData = dequantizationOnWeights.data.get_node_shared_ptr();
if (canBeTransformedParams.constantWeight) {
const auto constantWeightsData = ov::as_type_ptr<ov::opset1::Constant>(weightsData);
if (constantWeightsData == nullptr) {
return false;
}
}
const auto weightsDataPrecision = weightsData->get_element_type();
if (!DataPrecision::isSupported(weightsDataPrecision)) {
if (canBeTransformedParams.limitWeightsDataPrecision && !DataPrecision::isSupported(weightsDataPrecision)) {
return false;
}
@ -243,9 +257,11 @@ bool WeightableLayerTransformation::canBeTransformed(const TransformationContext
}
const auto dqIdces = getWeightsDequantizationIdces(layer);
if ((dequantizationOnWeights.subtract && !checkConstShape(dqIdces, dequantizationOnWeights.subtractConstant)) ||
(dequantizationOnWeights.multiply && !checkConstShape(dqIdces, dequantizationOnWeights.multiplyConstant))) {
return false;
if (!dqIdces.empty()) {
if ((dequantizationOnWeights.subtract && !checkConstShape(dqIdces, dequantizationOnWeights.subtractConstant)) ||
(dequantizationOnWeights.multiply && !checkConstShape(dqIdces, dequantizationOnWeights.multiplyConstant))) {
return false;
}
}
}

View File

@ -22,7 +22,7 @@
#include "low_precision/interpolate.hpp"
#include "low_precision/mat_mul.hpp"
#include "low_precision/max_pool.hpp"
#include "low_precision/multiply.hpp"
#include "low_precision/multiply_partial.hpp"
#include "low_precision/mvn.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/normalize_l2.hpp"
@ -361,7 +361,7 @@ TEST(LPT, AvoidDequantizationToShapeOfPropagationMultiplyTransformation) {
auto f = std::make_shared<Model>(ResultVector{result1, result2}, ParameterVector{input1, input2});
pass::Manager m;
m.register_pass<ov::pass::low_precision::MultiplyTransformation>();
m.register_pass<ov::pass::low_precision::MultiplyPartialTransformation>();
m.run_passes(f);
auto dqBeforeShapeOf = ov::pass::low_precision::NetworkHelper::getDequantization(result2->get_input_node_shared_ptr(0));

View File

@ -12,17 +12,29 @@
#include "low_precision/markup_quantization_granularity.hpp"
#include "low_precision/transformation_context.hpp"
// cleanup transformations
#include "low_precision/convert.hpp"
#include "low_precision/eliminate_fake_quantize.hpp"
#include "low_precision/fold_convert.hpp"
#include "low_precision/fold_fake_quantize.hpp"
#include "low_precision/fuse_convert.hpp"
#include "low_precision/fuse_multiply_to_fake_quantize.hpp"
#include "low_precision/fuse_subtract_to_fake_quantize.hpp"
#include "low_precision/multiply_to_group_convolution.hpp"
#include <string>
using namespace testing;
using namespace ov::pass;
using namespace ov::pass::low_precision;
OPENVINO_SUPPRESS_DEPRECATED_START
SimpleLowPrecisionTransformer::SimpleLowPrecisionTransformer(
const std::vector<ov::pass::low_precision::PrecisionsRestriction>& precisionRestrictions,
const std::vector<ov::pass::low_precision::QuantizationGranularityRestriction>& quantizationRestrictions,
const AttributeParameters& params) {
const AttributeParameters& params,
const bool addCleanup) {
auto passConfig = get_pass_config();
// TODO: use one pass manager
@ -39,7 +51,20 @@ SimpleLowPrecisionTransformer::SimpleLowPrecisionTransformer(
common = std::make_shared<ov::pass::Manager>(passConfig);
commonGraphRewrite = common->register_pass<ov::pass::GraphRewrite>();
cleanup = common->register_pass<ov::pass::GraphRewrite>();
if (addCleanup) {
ov::pass::low_precision::LayerTransformation::Params params;
cleanup->add_matcher<EliminateFakeQuantizeTransformation>(params);
cleanup->add_matcher<FoldConvertTransformation>(params);
cleanup->add_matcher<FuseConvertTransformation>(params);
cleanup->add_matcher<FuseSubtractToFakeQuantizeTransformation>(params);
cleanup->add_matcher<FuseMultiplyToFakeQuantizeTransformation>(params);
cleanup->add_matcher<MultiplyToGroupConvolutionTransformation>(
params,
PrecisionsRestriction::getPrecisionsByOperationType<opset1::GroupConvolution>(precisionRestrictions));
}
}
void SimpleLowPrecisionTransformer::transform(std::shared_ptr<ov::Model>& model) {

View File

@ -19,7 +19,8 @@ public:
SimpleLowPrecisionTransformer(
const std::vector<ov::pass::low_precision::PrecisionsRestriction>& precisionRestrictions = {},
const std::vector<ov::pass::low_precision::QuantizationGranularityRestriction>& quantizationRestrictions = {},
const AttributeParameters& params = AttributeParameters());
const AttributeParameters& params = AttributeParameters(),
const bool addCleanup = false);
template <class T, class Operation>
void add(const TestTransformationParams& params) {

View File

@ -11,7 +11,7 @@
#include <ie_core.hpp>
#include <transformations/init_node_info.hpp>
#include "lpt_ngraph_functions/multiply_function.hpp"
#include "lpt_ngraph_functions/multiply_partial_function.hpp"
#include "ngraph_functions/subgraph_builders.hpp"
@ -56,7 +56,7 @@ void MultiplyTransformation::SetUp() {
MultiplyTestValues param;
std::tie(precision, inputShape, targetDevice, param) = this->GetParam();
function = ngraph::builder::subgraph::MultiplyFunction::getOriginal(
function = ngraph::builder::subgraph::MultiplyPartialFunction::get(
precision,
inputShape,
param.broadcast1,

View File

@ -17,42 +17,39 @@ namespace subgraph {
class MultiplyBranch {
public:
MultiplyBranch(const PartialShape& inputShape,
const ngraph::builder::subgraph::Constant& constant,
const ngraph::element::Type& input_precision,
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
const ngraph::builder::subgraph::FakeQuantizeOnData& fake_quantize)
: inputShape(inputShape),
constant(constant),
input_precision(input_precision),
dequantization(dequantization),
fake_quantize(fake_quantize) {}
PartialShape inputShape;
ngraph::builder::subgraph::Constant constant;
ngraph::element::Type precisionBeforeDequantization;
ngraph::element::Type input_precision;
ngraph::builder::subgraph::DequantizationOperations dequantization;
ngraph::builder::subgraph::FakeQuantizeOnData fake_quantize;
};
inline std::ostream& operator<<(std::ostream& out, const MultiplyBranch& branch) {
return out << "_" << branch.constant << "_" << branch.precisionBeforeDequantization << "_" << branch.dequantization;
}
class MultiplyValues {
public:
MultiplyValues(const MultiplyBranch& branch1,
const MultiplyBranch& branch2,
const ngraph::builder::subgraph::DequantizationOperations& after_dequantization)
: branch1(branch1), branch2(branch2), after_dequantization(after_dequantization) {}
MultiplyBranch branch1;
MultiplyBranch branch2;
bool isDequantization;
ngraph::builder::subgraph::DequantizationOperations after_dequantization;
};
inline std::ostream& operator<<(std::ostream& out, const MultiplyValues& values) {
return out << "_" << values.branch1 << "_" << values.branch2 << (values.isDequantization ? "_isDequantization" : "");
}
class MultiplyFunction : public ElementwiseFunction {
public:
static std::shared_ptr<ngraph::Function> get(
const element::Type precision,
const MultiplyValues& actualValues);
static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::element::Type precision,
const ngraph::PartialShape& inputShape,
const bool broadcast1,
const ngraph::builder::subgraph::FakeQuantizeOnData& fq1,
const bool broadcast2,
const ngraph::builder::subgraph::FakeQuantizeOnData& fq2,
const ngraph::builder::subgraph::FakeQuantizeOnData& fqAfter,
const bool secondInputIsConstant = false);
static std::shared_ptr<ngraph::Function> get(const element::Type model_precision, const MultiplyValues& actualValues);
};
} // namespace subgraph

View File

@ -0,0 +1,60 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <ngraph/ngraph.hpp>
#include "elementwise_function.hpp"
#include "lpt_ngraph_functions/common/constant.hpp"
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
namespace ngraph {
namespace builder {
namespace subgraph {
// Describes one input branch of the tested Multiply node.
class MultiplyPartialBranch {
public:
    // Shape of the branch input.
    PartialShape inputShape;
    // Constant for this branch; when empty, the branch presumably starts from a
    // Parameter instead (mirrors the sibling MultiplyBranch builder) — confirm in the .cpp.
    ngraph::builder::subgraph::Constant constant;
    // Element type of the branch input before the dequantization operations.
    ngraph::element::Type precisionBeforeDequantization;
    // Dequantization operations applied to the branch input.
    ngraph::builder::subgraph::DequantizationOperations dequantization;
};
// Streams a compact, underscore-separated description of the branch
// (used to build readable test-case names).
inline std::ostream& operator<<(std::ostream& out, const MultiplyPartialBranch& branch) {
    out << "_" << branch.constant;
    out << "_" << branch.precisionBeforeDequantization;
    out << "_" << branch.dequantization;
    return out;
}
// Input description for MultiplyPartialFunction::get: the two Multiply branches
// plus a flag that is only appended to the printed test name (see operator<< below).
class MultiplyPartialValues {
public:
    MultiplyPartialBranch branch1;
    MultiplyPartialBranch branch2;
    bool isDequantization;
};
// Streams a test-name fragment for MultiplyPartialValues: both branches, plus
// an "_isDequantization" suffix when the flag is set.
inline std::ostream& operator<<(std::ostream& out, const MultiplyPartialValues& values) {
    out << "_" << values.branch1 << "_" << values.branch2;
    if (values.isDequantization) {
        out << "_isDequantization";
    }
    return out;
}
// Builds ngraph::Function test models around a (possibly FakeQuantized) Multiply
// for the low-precision-transformation (LPT) tests.
class MultiplyPartialFunction : public ElementwiseFunction {
public:
    // Builds the model from an explicit description of both Multiply branches
    // (constant vs. parameter input, precisions, dequantization operations).
    static std::shared_ptr<ngraph::Function> get(
        const element::Type precision,
        const MultiplyPartialValues& actualValues);
    // Builds Parameter[->FQ] x (Parameter|Constant)[->FQ] -> Multiply [-> FQ].
    // broadcast1/broadcast2 collapse spatial dimensions 2 and 3 of the matching
    // input to 1; with secondInputIsConstant the second input is a scalar constant.
    static std::shared_ptr<ngraph::Function> get(
        const ngraph::element::Type precision,
        const ngraph::PartialShape& inputShape,
        const bool broadcast1,
        const ngraph::builder::subgraph::FakeQuantizeOnData& fq1,
        const bool broadcast2,
        const ngraph::builder::subgraph::FakeQuantizeOnData& fq2,
        const ngraph::builder::subgraph::FakeQuantizeOnData& fqAfter,
        const bool secondInputIsConstant = false);
};
} // namespace subgraph
} // namespace builder
} // namespace ngraph

View File

@ -4,6 +4,8 @@
#include "lpt_ngraph_functions/multiply_function.hpp"
#include <memory>
#include <ngraph/opsets/opset1.hpp>
#include <ov_ops/type_relaxed.hpp>
#include "ngraph_functions/subgraph_builders.hpp"
@ -18,49 +20,52 @@ namespace ngraph {
namespace builder {
namespace subgraph {
namespace multiply_function {
// Pair of nodes produced for one Multiply branch: the branch input
// (Parameter or Constant) and the last node of its dequantization chain.
struct BranchNodes {
    std::shared_ptr<Node> input;
    std::shared_ptr<Node> dequantization;
};
BranchNodes getBranch(const MultiplyBranch& branch) {
const std::shared_ptr<Node> parent = branch.constant.empty() ?
std::make_shared<ngraph::opset1::Parameter>(branch.precisionBeforeDequantization, branch.inputShape) :
BranchNodes makeBranch(const MultiplyBranch& branch) {
std::shared_ptr<Node> parent = branch.constant.empty() ?
std::make_shared<ngraph::opset1::Parameter>(branch.input_precision, branch.inputShape) :
std::dynamic_pointer_cast<Node>(std::make_shared<ngraph::opset1::Constant>(
branch.constant.outPrecision,
branch.constant.shape,
branch.constant.values));
if (!branch.fake_quantize.empty()) {
if ((parent->get_output_element_type(0) != element::f32) &&
(parent->get_output_element_type(0) != element::f16)) {
throw std::runtime_error("unexpected precision before FakeQuantize");
}
parent = makeFakeQuantize(parent, parent->get_output_element_type(0), branch.fake_quantize);
}
const auto dequantization = makeDequantization(parent, branch.dequantization);
return {parent, dequantization};
}
} // namespace multiply_function
std::shared_ptr<ngraph::Function> MultiplyFunction::get(
const element::Type precision,
const MultiplyValues& actualValues) {
auto branch1Structure = actualValues.branch1;
branch1Structure.precisionBeforeDequantization = precision;
branch1Structure.dequantization.multiply.outPrecision = precision;
auto branch2Structure = actualValues.branch2;
branch2Structure.precisionBeforeDequantization = precision;
branch2Structure.dequantization.multiply.outPrecision = precision;
std::shared_ptr<ngraph::Function> MultiplyFunction::get(const element::Type model_precision, const MultiplyValues& actualValues) {
const auto branchNodes1 = multiply_function::makeBranch(actualValues.branch1);
const auto branchNodes2 = multiply_function::makeBranch(actualValues.branch2);
const BranchNodes branchNodes1 = getBranch(actualValues.branch1);
const BranchNodes branchNodes2 = getBranch(actualValues.branch2);
auto multiplyOriginal = opset1::Multiply(
// branchNodes1.dequantization & branchNodes2.dequantization can have different input types
std::shared_ptr<ngraph::Node> parent = std::make_shared<ov::op::TypeRelaxed<ov::opset1::Multiply>>(
std::vector<ngraph::element::Type>{ element::f32, element::f32 },
std::vector<ngraph::element::Type>{ actualValues.after_dequantization.empty() ? model_precision : element::f32 },
ov::op::TemporaryReplaceOutputType(branchNodes1.dequantization, element::f32).get(),
ov::op::TemporaryReplaceOutputType(branchNodes2.dequantization, element::f32).get());
const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ov::op::TypeRelaxed<ngraph::opset1::Multiply>>(
multiplyOriginal,
std::vector<element::Type>{element::f32, element::f32},
std::vector<element::Type>{precision});
auto& rtInfo = multiply->get_rt_info();
auto& rtInfo = parent->get_rt_info();
rtInfo["Variant::std::string"] = "multiply";
multiply->set_friendly_name("output");
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
parent = makeDequantization(parent, actualValues.after_dequantization);
parent->set_friendly_name("output");
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(parent) };
ngraph::ParameterVector inputs;
if (is_type<opset1::Parameter>(branchNodes1.input)) {
@ -73,78 +78,6 @@ std::shared_ptr<ngraph::Function> MultiplyFunction::get(
return std::make_shared<ngraph::Function>(results, inputs, "MultiplyTransformation");
}
// Builds the "original" (not yet LPT-transformed) test model:
//   Parameter   [-> FakeQuantize1] --\
//                                     Multiply [-> FakeQuantizeAfter] -> Result
//   Param/Const [-> FakeQuantize2] --/
// broadcast1/broadcast2 collapse spatial dimensions 2 and 3 of the matching
// input shape to 1, so the Multiply broadcasts along them.
// NOTE(review): indexing inputShape[2]/[3] assumes a static rank of at least 4 —
// confirm at call sites.
std::shared_ptr<ngraph::Function> MultiplyFunction::getOriginal(
    const ngraph::element::Type precision,
    const ngraph::PartialShape& inputShape,
    const bool broadcast1,
    const ngraph::builder::subgraph::FakeQuantizeOnData& fq1,
    const bool broadcast2,
    const ngraph::builder::subgraph::FakeQuantizeOnData& fq2,
    const ngraph::builder::subgraph::FakeQuantizeOnData& fqAfter,
    const bool secondInputIsConstant) {
    auto inputShape1 = inputShape;
    if (broadcast1) {
        inputShape1[2] = 1;
        inputShape1[3] = 1;
    }
    // The second input is either a scalar constant or a Parameter with an
    // optionally broadcast shape.
    ngraph::PartialShape inputShape2;
    if (secondInputIsConstant) {
        inputShape2 = {};
    } else {
        inputShape2 = inputShape;
        if (broadcast2) {
            inputShape2[2] = 1;
            inputShape2[3] = 1;
        }
    }
    const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape1);
    // Optional FakeQuantize on branch 1 (nullptr when fq1 is empty).
    const auto fakeQuantize1 = fq1.empty() ?
        nullptr :
        ngraph::builder::makeFakeQuantize(
            input1, precision, fq1.quantizationLevel, fq1.constantShape,
            fq1.inputLowValues, fq1.inputHighValues, fq1.outputLowValues, fq1.outputHighValues);
    if (fakeQuantize1 != nullptr) {
        fakeQuantize1->set_friendly_name("fakeQuantize1");
    }
    const std::shared_ptr<ngraph::Node> input2 = secondInputIsConstant ?
        makeConstant(element::f32, Shape{}, std::vector<float>{0.5f}, false) :
        std::make_shared<ngraph::opset1::Parameter>(precision, inputShape2);
    // Optional FakeQuantize on branch 2 (nullptr when fq2 is empty).
    const auto fakeQuantize2 = fq2.empty() ?
        nullptr :
        ngraph::builder::makeFakeQuantize(
            input2, precision, fq2.quantizationLevel, fq2.constantShape,
            fq2.inputLowValues, fq2.inputHighValues, fq2.outputLowValues, fq2.outputHighValues);
    if (fakeQuantize2 != nullptr) {
        fakeQuantize2->set_friendly_name("fakeQuantize2");
    }
    const auto multiply = std::make_shared<ngraph::opset1::Multiply>(
        fq1.empty() ? input1 : fakeQuantize1,
        fq2.empty() ? input2 : fakeQuantize2);
    multiply->set_friendly_name("multiply");
    // Optional FakeQuantize on the Multiply output.
    auto const fakeQuantizeAfter = fqAfter.empty() ?
        nullptr :
        makeFakeQuantize(multiply, precision, fqAfter);
    if (fakeQuantizeAfter != nullptr) {
        fakeQuantizeAfter->set_friendly_name("fakeQuantizeAfter");
    }
    const std::shared_ptr<Node> result = fakeQuantizeAfter == nullptr ? std::dynamic_pointer_cast<Node>(multiply) : fakeQuantizeAfter;
    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(result) };
    // A constant second input does not contribute a function Parameter.
    std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
        results,
        secondInputIsConstant ?
            ngraph::ParameterVector{ input1 } :
            ngraph::ParameterVector{ input1, ngraph::as_type_ptr<ngraph::opset1::Parameter>(input2) },
        "MultiplyTransformation");
    return function;
}
} // namespace subgraph
} // namespace builder
} // namespace ngraph

View File

@ -0,0 +1,154 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "lpt_ngraph_functions/multiply_partial_function.hpp"
#include <memory>
#include <ngraph/opsets/opset1.hpp>
#include <ov_ops/type_relaxed.hpp>
#include "ngraph_functions/subgraph_builders.hpp"
#include "low_precision/network_helper.hpp"
#include "lpt_ngraph_functions/common/builders.hpp"
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
using namespace ov::pass::low_precision;
namespace ngraph {
namespace builder {
namespace subgraph {
namespace multiply_partial_function {
// Pair of nodes produced for one Multiply branch: the branch input
// (Parameter or Constant) and the last node of its dequantization chain.
struct BranchNodes {
    std::shared_ptr<Node> input;
    std::shared_ptr<Node> dequantization;
};
// Builds one Multiply branch: a Parameter (when branch.constant is empty) or a
// Constant, followed by the branch's dequantization operations. Returns both
// the branch input and the tail of the dequantization chain.
BranchNodes getBranch(const MultiplyPartialBranch& branch) {
    std::shared_ptr<Node> parent;
    if (branch.constant.empty()) {
        parent = std::make_shared<ngraph::opset1::Parameter>(
            branch.precisionBeforeDequantization,
            branch.inputShape);
    } else {
        parent = std::make_shared<ngraph::opset1::Constant>(
            branch.constant.outPrecision,
            branch.constant.shape,
            branch.constant.values);
    }
    const auto dequantization = makeDequantization(parent, branch.dequantization);
    return {parent, dequantization};
}
} // namespace multiply_partial_function
// Builds the test model from an explicit description of both Multiply branches.
// The Multiply is type-relaxed: both inputs are treated as f32 and the output
// element type is forced to 'precision'.
std::shared_ptr<ngraph::Function> MultiplyPartialFunction::get(
    const element::Type precision,
    const MultiplyPartialValues& actualValues) {
    // NOTE(review): the previous implementation copied actualValues.branch1/branch2
    // and overrode precisionBeforeDequantization / dequantization.multiply.outPrecision
    // on the copies, but then built the graph from the untouched actualValues — the
    // copies were dead stores and have been removed. Confirm the override was not
    // meant to be applied before the getBranch calls.
    const BranchNodes branchNodes1 = multiply_partial_function::getBranch(actualValues.branch1);
    const BranchNodes branchNodes2 = multiply_partial_function::getBranch(actualValues.branch2);

    auto multiplyOriginal = opset1::Multiply(
        ov::op::TemporaryReplaceOutputType(branchNodes1.dequantization, element::f32).get(),
        ov::op::TemporaryReplaceOutputType(branchNodes2.dequantization, element::f32).get());
    const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ov::op::TypeRelaxed<ngraph::opset1::Multiply>>(
        multiplyOriginal,
        std::vector<element::Type>{element::f32, element::f32},
        std::vector<element::Type>{precision});

    auto& rtInfo = multiply->get_rt_info();
    rtInfo["Variant::std::string"] = "multiply";
    multiply->set_friendly_name("output");

    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };

    // Only Parameter branch inputs become function parameters; a constant
    // branch contributes none.
    ngraph::ParameterVector inputs;
    if (is_type<opset1::Parameter>(branchNodes1.input)) {
        inputs.push_back(std::dynamic_pointer_cast<opset1::Parameter>(branchNodes1.input));
    }
    if (is_type<opset1::Parameter>(branchNodes2.input)) {
        inputs.push_back(std::dynamic_pointer_cast<opset1::Parameter>(branchNodes2.input));
    }

    return std::make_shared<ngraph::Function>(results, inputs, "MultiplyTransformation");
}
// Builds the "original" (not yet LPT-transformed) test model:
//   Parameter   [-> FakeQuantize1] --\
//                                     Multiply [-> FakeQuantizeAfter] -> Result
//   Param/Const [-> FakeQuantize2] --/
// broadcast1/broadcast2 collapse spatial dimensions 2 and 3 of the matching
// input shape to 1; with secondInputIsConstant the second input is a scalar
// constant and contributes no function Parameter.
std::shared_ptr<ngraph::Function> MultiplyPartialFunction::get(
    const ngraph::element::Type precision,
    const ngraph::PartialShape& inputShape,
    const bool broadcast1,
    const ngraph::builder::subgraph::FakeQuantizeOnData& fq1,
    const bool broadcast2,
    const ngraph::builder::subgraph::FakeQuantizeOnData& fq2,
    const ngraph::builder::subgraph::FakeQuantizeOnData& fqAfter,
    const bool secondInputIsConstant) {
    // Shape of the first input, spatial dims collapsed when broadcasting.
    auto shape1 = inputShape;
    if (broadcast1) {
        shape1[2] = 1;
        shape1[3] = 1;
    }

    // Shape of the second input: scalar for a constant second input.
    ngraph::PartialShape shape2;
    if (secondInputIsConstant) {
        shape2 = {};
    } else {
        shape2 = inputShape;
        if (broadcast2) {
            shape2[2] = 1;
            shape2[3] = 1;
        }
    }

    // Branch 1: Parameter, optionally followed by FakeQuantize.
    const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, shape1);
    std::shared_ptr<ngraph::Node> branch1 = input1;
    if (!fq1.empty()) {
        branch1 = ngraph::builder::makeFakeQuantize(
            input1, precision, fq1.quantizationLevel, fq1.constantShape,
            fq1.inputLowValues, fq1.inputHighValues, fq1.outputLowValues, fq1.outputHighValues);
        branch1->set_friendly_name("fakeQuantize1");
    }

    // Branch 2: Parameter or scalar Constant, optionally followed by FakeQuantize.
    std::shared_ptr<ngraph::Node> input2;
    if (secondInputIsConstant) {
        input2 = makeConstant(element::f32, Shape{}, std::vector<float>{0.5f}, false);
    } else {
        input2 = std::make_shared<ngraph::opset1::Parameter>(precision, shape2);
    }
    std::shared_ptr<ngraph::Node> branch2 = input2;
    if (!fq2.empty()) {
        branch2 = ngraph::builder::makeFakeQuantize(
            input2, precision, fq2.quantizationLevel, fq2.constantShape,
            fq2.inputLowValues, fq2.inputHighValues, fq2.outputLowValues, fq2.outputHighValues);
        branch2->set_friendly_name("fakeQuantize2");
    }

    const auto multiply = std::make_shared<ngraph::opset1::Multiply>(branch1, branch2);
    multiply->set_friendly_name("multiply");

    // Optional FakeQuantize on the Multiply output.
    std::shared_ptr<Node> result = multiply;
    if (!fqAfter.empty()) {
        result = makeFakeQuantize(multiply, precision, fqAfter);
        result->set_friendly_name("fakeQuantizeAfter");
    }

    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(result) };
    ngraph::ParameterVector parameters;
    if (secondInputIsConstant) {
        parameters = ngraph::ParameterVector{ input1 };
    } else {
        parameters = ngraph::ParameterVector{ input1, ngraph::as_type_ptr<ngraph::opset1::Parameter>(input2) };
    }
    return std::make_shared<ngraph::Function>(results, parameters, "MultiplyTransformation");
}
} // namespace subgraph
} // namespace builder
} // namespace ngraph