diff --git a/src/common/transformations/include/transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.hpp b/src/common/transformations/include/transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.hpp new file mode 100644 index 00000000000..0b8d31b4040 --- /dev/null +++ b/src/common/transformations/include/transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API ConvertU4WeightsZeroPointToScalar; + +} // namespace pass +} // namespace ov + +/** + * @ingroup ie_transformation_common_api + * @brief Converts U4 weights zero point to scalar if all values are equal + */ +class ov::pass::ConvertU4WeightsZeroPointToScalar : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ConvertU4WeightsZeroPointToScalar", "0"); + ConvertU4WeightsZeroPointToScalar(); +}; diff --git a/src/common/transformations/src/transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.cpp b/src/common/transformations/src/transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.cpp new file mode 100644 index 00000000000..6313db127ac --- /dev/null +++ b/src/common/transformations/src/transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.cpp @@ -0,0 +1,80 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.hpp" + +#include "itt.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "openvino/reference/autobroadcast_binop.hpp" +#include "transformations/utils/utils.hpp" + +ov::pass::ConvertU4WeightsZeroPointToScalar::ConvertU4WeightsZeroPointToScalar() { + MATCHER_SCOPE(ConvertU4WeightsZeroPointToScalar); + auto weights_m = pattern::wrap_type(pattern::type_matches(ov::element::u4)); + auto convert_m = pattern::wrap_type({weights_m}, pattern::consumers_count(1)); + + auto float_zp_predicate = [](ov::Output output) -> bool { + return pattern::type_matches_any({ov::element::f32, ov::element::f16})(output) && + pattern::consumers_count(1)(output); + }; + auto float_zero_point_m = pattern::wrap_type(float_zp_predicate); + + auto u4_zp_predicate = [](ov::Output output) -> bool { + return pattern::type_matches(ov::element::u4)(output) && pattern::consumers_count(1)(output); + }; + auto u4_zero_point_m = pattern::wrap_type(u4_zp_predicate); + auto zero_point_convert_m = pattern::wrap_type({u4_zero_point_m}, float_zp_predicate); + + auto zero_point_m = std::make_shared(OutputVector{float_zero_point_m, zero_point_convert_m}); + auto subtract_m = pattern::wrap_type({convert_m, zero_point_m}); + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + auto& pattern_map = m.get_pattern_value_map(); + auto weights = ov::as_type_ptr(pattern_map.at(weights_m).get_node_shared_ptr()); + std::shared_ptr zero_point; + if (pattern_map.count(float_zero_point_m)) { + const auto& float_zp = pattern_map.at(float_zero_point_m); + zero_point = ov::as_type_ptr(float_zp.get_node_shared_ptr()); + } else { + const auto& u4_zp = pattern_map.at(u4_zero_point_m); + zero_point = ov::as_type_ptr(u4_zp.get_node_shared_ptr()); + } + if (!weights || !zero_point) + return false; + // Due to the matcher specific and Subtract branches similarity, + // weights and zero_point might be mixed up with each other + if (ov::shape_size(weights->get_shape()) < ov::shape_size(zero_point->get_shape())) + std::swap(zero_point, weights); + + auto zero_point_shape = zero_point->get_shape(); + if (ov::shape_size(zero_point_shape) == 1) + return false; + + const auto& weights_shape = weights->get_shape(); + const size_t weights_rank = weights_shape.size(); + const size_t zero_point_rank = zero_point_shape.size(); + // Zero point constant can be converted into scalar only if this does not affect Subtract output shape + if (weights_rank < zero_point_rank) + return false; + + zero_point_shape.insert(zero_point_shape.begin(), weights_rank - zero_point_rank, 1); + for (size_t i = 0; i < weights_rank; ++i) { + if (zero_point_shape[i] > weights_shape[i]) + return false; + } + + float zp_value; + if (!ov::op::util::get_single_value(zero_point, zp_value)) + return false; + const auto new_zp = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + return ov::replace_node_update_name(zero_point, new_zp); + }; + + auto m = std::make_shared(subtract_m, matcher_name); + register_matcher(m, callback); +} diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp index 068e1f27a29..86746f176ca 100644 --- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp @@ -21,6 +21,7 @@ #include "transformations/common_optimizations/conv_to_binary_conv.hpp" #include "transformations/common_optimizations/convert_nms_gather_path_to_unsigned.hpp" #include "transformations/common_optimizations/convert_quantize_dequantize.hpp" +#include "transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.hpp" #include "transformations/common_optimizations/convolution_to_group_convolution_fusion.hpp" #include "transformations/common_optimizations/depth_to_space_fusion.hpp" #include "transformations/common_optimizations/dilated_convolution_converter.hpp" @@ -212,6 +213,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr ADD_MATCHER(common_fusions, ShuffleChannelsFusion, !m_use_shapes) ADD_MATCHER(common_fusions, NonZeroHorizontalFusion) ADD_MATCHER(common_fusions, AdaptivePoolToReduce) + ADD_MATCHER(common_fusions, ConvertU4WeightsZeroPointToScalar) common_fusions->set_name("ov::pass::CommonFusions"); REGISTER_PASS(manager, BinarizeWeights) diff --git a/src/common/transformations/src/transformations/utils/utils.cpp b/src/common/transformations/src/transformations/utils/utils.cpp index 62b1765e7ba..b7cde395a66 100644 --- a/src/common/transformations/src/transformations/utils/utils.cpp +++ b/src/common/transformations/src/transformations/utils/utils.cpp @@ -31,6 +31,8 @@ bool get_single_value(const std::shared_ptr& const_node, float return util::normalize_single_value(const_node->get_vector(), value, check_value_range); case element::Type_t::f64: return util::normalize_single_value(const_node->get_vector(), value, check_value_range); + case element::Type_t::i4: + return util::normalize_single_value(const_node->cast_vector(), value, check_value_range); case element::Type_t::i8: return util::normalize_single_value(const_node->get_vector(), value, check_value_range); case element::Type_t::i16: @@ -39,6 +41,8 @@ bool get_single_value(const std::shared_ptr& const_node, float return util::normalize_single_value(const_node->get_vector(), value, check_value_range); case element::Type_t::i64: return util::normalize_single_value(const_node->get_vector(), value, check_value_range); + case element::Type_t::u4: + return util::normalize_single_value(const_node->cast_vector(), value, check_value_range); case element::Type_t::u8: return util::normalize_single_value(const_node->get_vector(), value, check_value_range); case element::Type_t::u16: diff --git a/src/common/transformations/tests/common_optimizations/convert_u4_weights_zero_point_to_scalar.cpp b/src/common/transformations/tests/common_optimizations/convert_u4_weights_zero_point_to_scalar.cpp new file mode 100644 index 00000000000..8fc896065e9 --- /dev/null +++ b/src/common/transformations/tests/common_optimizations/convert_u4_weights_zero_point_to_scalar.cpp @@ -0,0 +1,208 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.hpp" + +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "openvino/core/model.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/pass/manager.hpp" + +using namespace testing; +using namespace ov; + +TEST_F(TransformationTestsF, ConvertU4WeightsFloatZeroPointToScalar) { + auto weights_precision = ov::element::u4; + auto decompression_precision = ov::element::f32; + ov::Shape weights_shape{32, 128, 64}; + ov::Shape decompression_shape{32, 1, 64}; + { + auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4}); + auto convert = std::make_shared(weights, decompression_precision); + auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {8.1f}); + auto subtract = std::make_shared(convert, zero_point); + auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f}); + auto multiply = std::make_shared(subtract, scale); + model = std::make_shared(NodeVector{multiply}, ParameterVector{}); + manager.register_pass(); + } + { + ov::Shape scalar_shape{}; + auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4}); + auto convert = std::make_shared(weights, decompression_precision); + auto zero_point = ov::op::v0::Constant::create(decompression_precision, scalar_shape, {8.1f}); + auto subtract = std::make_shared(convert, zero_point); + auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f}); + auto multiply = std::make_shared(subtract, scale); + model_ref = std::make_shared(NodeVector{multiply}, ParameterVector{}); + } + comparator.enable(FunctionsComparator::ACCURACY); + comparator.enable(FunctionsComparator::CONST_VALUES); +} + +TEST_F(TransformationTestsF, ConvertU4WeightsU4ZeroPointToScalar) { + auto weights_precision = ov::element::u4; + auto decompression_precision = ov::element::f32; + ov::Shape weights_shape{32, 128, 64}; + ov::Shape decompression_shape{32, 1, 64}; + { + auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4}); + auto convert = std::make_shared(weights, decompression_precision); + auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, {8}); + auto zero_point_convert = std::make_shared(zero_point, decompression_precision); + auto subtract = std::make_shared(convert, zero_point_convert); + auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f}); + auto multiply = std::make_shared(subtract, scale); + model = std::make_shared(NodeVector{multiply}, ParameterVector{}); + manager.register_pass(); + } + { + ov::Shape scalar_shape{}; + auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4}); + auto convert = std::make_shared(weights, decompression_precision); + auto zero_point = ov::op::v0::Constant::create(weights_precision, scalar_shape, {8}); + auto zero_point_convert = std::make_shared(zero_point, decompression_precision); + auto subtract = std::make_shared(convert, zero_point_convert); + auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f}); + auto multiply = std::make_shared(subtract, scale); + model_ref = std::make_shared(NodeVector{multiply}, ParameterVector{}); + } + comparator.enable(FunctionsComparator::ACCURACY); + comparator.enable(FunctionsComparator::CONST_VALUES); +} + +TEST_F(TransformationTestsF, ConvertU4WeightsFloatZeroPointToScalarWeightsWithBiggerRank) { + auto weights_precision = ov::element::u4; + auto decompression_precision = ov::element::f32; + ov::Shape weights_shape{32, 128, 64}; + ov::Shape decompression_shape{64}; + { + auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4}); + auto convert = std::make_shared(weights, decompression_precision); + auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {8}); + auto subtract = std::make_shared(convert, zero_point); + auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f}); + auto multiply = std::make_shared(subtract, scale); + model = std::make_shared(NodeVector{multiply}, ParameterVector{}); + manager.register_pass(); + } + { + ov::Shape scalar_shape{}; + auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4}); + auto convert = std::make_shared(weights, decompression_precision); + auto zero_point = ov::op::v0::Constant::create(decompression_precision, scalar_shape, {8}); + auto subtract = std::make_shared(convert, zero_point); + auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f}); + auto multiply = std::make_shared(subtract, scale); + model_ref = std::make_shared(NodeVector{multiply}, ParameterVector{}); + } + comparator.enable(FunctionsComparator::ACCURACY); + comparator.enable(FunctionsComparator::CONST_VALUES); +} + +TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointNotScalarLikeZP) { + auto weights_precision = ov::element::u8; + auto decompression_precision = ov::element::f32; + ov::Shape weights_shape{32, 128, 64}; + ov::Shape decompression_shape{32, 1, 64}; + auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4}); + auto convert = std::make_shared(weights, decompression_precision); + std::vector zero_point_values(ov::shape_size(decompression_shape), 8); + zero_point_values.back() = 6; + auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, zero_point_values); + auto zero_point_convert = std::make_shared(zero_point, decompression_precision); + auto subtract = std::make_shared(convert, zero_point_convert); + auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f}); + auto multiply = std::make_shared(subtract, scale); + model = std::make_shared(NodeVector{multiply}, ParameterVector{}); + manager.register_pass(); +} + +TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointNotU4Weights) { + auto weights_precision = ov::element::u8; + auto decompression_precision = ov::element::f32; + ov::Shape weights_shape{32, 128, 64}; + ov::Shape decompression_shape{32, 1, 64}; + auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4}); + auto convert = std::make_shared(weights, decompression_precision); + auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, {8}); + auto zero_point_convert = std::make_shared(zero_point, decompression_precision); + auto subtract = std::make_shared(convert, zero_point_convert); + auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f}); + auto multiply = std::make_shared(subtract, scale); + model = std::make_shared(NodeVector{multiply}, ParameterVector{}); + manager.register_pass(); +} + +TEST_F(TransformationTestsF, ConvertU4WeightsFloatZeroPointToScalarAdditionalZPConsumer) { + auto weights_precision = ov::element::u4; + auto decompression_precision = ov::element::f32; + ov::Shape weights_shape{32, 128, 64}; + ov::Shape decompression_shape{32, 1, 64}; + auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4}); + auto convert = std::make_shared(weights, decompression_precision); + auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {8}); + auto zero_point_consumer = std::make_shared(zero_point); + auto subtract = std::make_shared(convert, zero_point); + auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f}); + auto multiply = std::make_shared(subtract, scale); + model = std::make_shared(NodeVector{multiply, zero_point_consumer}, ParameterVector{}); + manager.register_pass(); +} + +TEST_F(TransformationTestsF, ConvertU4WeightsU4ZeroPointToScalarAdditionalZPConsumer) { + auto weights_precision = ov::element::u4; + auto decompression_precision = ov::element::f32; + ov::Shape weights_shape{32, 128, 64}; + ov::Shape decompression_shape{32, 1, 64}; + auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4}); + auto convert = std::make_shared(weights, decompression_precision); + auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, {8}); + auto zero_point_consumer = std::make_shared(zero_point); + auto zero_point_convert = std::make_shared(zero_point, decompression_precision); + auto subtract = std::make_shared(convert, zero_point_convert); + auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f}); + auto multiply = std::make_shared(subtract, scale); + model = std::make_shared(NodeVector{multiply, zero_point_consumer}, ParameterVector{}); + manager.register_pass(); +} + +TEST_F(TransformationTestsF, ConvertU4WeightsU4ZeroPointToScalarAdditionalZPConvertConsumer) { + auto weights_precision = ov::element::u4; + auto decompression_precision = ov::element::f32; + ov::Shape weights_shape{32, 128, 64}; + ov::Shape decompression_shape{32, 1, 64}; + auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4}); + auto convert = std::make_shared(weights, decompression_precision); + auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, {8}); + auto zero_point_convert = std::make_shared(zero_point, decompression_precision); + auto zero_point_convert_consumer = std::make_shared(zero_point_convert); + auto subtract = std::make_shared(convert, zero_point_convert); + auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f}); + auto multiply = std::make_shared(subtract, scale); + model = std::make_shared(NodeVector{multiply, zero_point_convert_consumer}, ParameterVector{}); + manager.register_pass(); +} + +TEST_F(TransformationTestsF, ConvertU4WeightsU4ZeroPointToScalarZPWithBiggerRank) { + auto weights_precision = ov::element::u4; + auto decompression_precision = ov::element::f32; + ov::Shape weights_shape{32, 128, 64}; + ov::Shape decompression_shape{1, 32, 1, 64}; + auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4}); + auto convert = std::make_shared(weights, decompression_precision); + auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, {8}); + auto zero_point_convert = std::make_shared(zero_point, decompression_precision); + auto zero_point_convert_consumer = std::make_shared(zero_point_convert); + auto subtract = std::make_shared(convert, zero_point_convert); + auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f}); + auto multiply = std::make_shared(subtract, scale); + model = std::make_shared(NodeVector{multiply, zero_point_convert_consumer}, ParameterVector{}); + manager.register_pass(); +} diff --git a/src/core/src/graph_util.cpp b/src/core/src/graph_util.cpp index 8001678dab2..4c6a4d0f33e 100644 --- a/src/core/src/graph_util.cpp +++ b/src/core/src/graph_util.cpp @@ -319,7 +319,8 @@ bool replace_output_update_name(Output output, const Output& replace bool replace_node_update_name(const std::shared_ptr& target, const std::shared_ptr& replacement) { for (auto& output : target->output(0).get_target_inputs()) { - if (ov::as_type(replacement->input_value(0).get_node()) && + if (replacement->get_input_size() > 0 && + ov::as_type(replacement->input_value(0).get_node()) && ov::as_type(output.get_node())) { return false; }