[Transformations] FuseU4WeightsAndZeroPoint transformation (#20503)

This commit is contained in:
Vladislav Golubev 2023-10-24 07:44:26 +02:00 committed by GitHub
parent d490ab68d1
commit afda7ad70f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 322 additions and 1 deletions

View File

@ -0,0 +1,26 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API ConvertU4WeightsZeroPointToScalar;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief Converts U4 weights zero point to scalar if all values are equal
*/
class ov::pass::ConvertU4WeightsZeroPointToScalar : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ConvertU4WeightsZeroPointToScalar", "0");
ConvertU4WeightsZeroPointToScalar();
};

View File

@ -0,0 +1,80 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.hpp"
#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/reference/autobroadcast_binop.hpp"
#include "transformations/utils/utils.hpp"
ov::pass::ConvertU4WeightsZeroPointToScalar::ConvertU4WeightsZeroPointToScalar() {
MATCHER_SCOPE(ConvertU4WeightsZeroPointToScalar);
auto weights_m = pattern::wrap_type<ov::op::v0::Constant>(pattern::type_matches(ov::element::u4));
auto convert_m = pattern::wrap_type<ov::op::v0::Convert>({weights_m}, pattern::consumers_count(1));
auto float_zp_predicate = [](ov::Output<ov::Node> output) -> bool {
return pattern::type_matches_any({ov::element::f32, ov::element::f16})(output) &&
pattern::consumers_count(1)(output);
};
auto float_zero_point_m = pattern::wrap_type<ov::op::v0::Constant>(float_zp_predicate);
auto u4_zp_predicate = [](ov::Output<ov::Node> output) -> bool {
return pattern::type_matches(ov::element::u4)(output) && pattern::consumers_count(1)(output);
};
auto u4_zero_point_m = pattern::wrap_type<ov::op::v0::Constant>(u4_zp_predicate);
auto zero_point_convert_m = pattern::wrap_type<ov::op::v0::Convert>({u4_zero_point_m}, float_zp_predicate);
auto zero_point_m = std::make_shared<pattern::op::Or>(OutputVector{float_zero_point_m, zero_point_convert_m});
auto subtract_m = pattern::wrap_type<ov::op::v1::Subtract>({convert_m, zero_point_m});
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
auto& pattern_map = m.get_pattern_value_map();
auto weights = ov::as_type_ptr<ov::op::v0::Constant>(pattern_map.at(weights_m).get_node_shared_ptr());
std::shared_ptr<ov::op::v0::Constant> zero_point;
if (pattern_map.count(float_zero_point_m)) {
const auto& float_zp = pattern_map.at(float_zero_point_m);
zero_point = ov::as_type_ptr<ov::op::v0::Constant>(float_zp.get_node_shared_ptr());
} else {
const auto& u4_zp = pattern_map.at(u4_zero_point_m);
zero_point = ov::as_type_ptr<ov::op::v0::Constant>(u4_zp.get_node_shared_ptr());
}
if (!weights || !zero_point)
return false;
// Due to the matcher specific and Subtract branches similarity,
// weights and zero_point might be mixed up with each other
if (ov::shape_size(weights->get_shape()) < ov::shape_size(zero_point->get_shape()))
std::swap(zero_point, weights);
auto zero_point_shape = zero_point->get_shape();
if (ov::shape_size(zero_point_shape) == 1)
return false;
const auto& weights_shape = weights->get_shape();
const size_t weights_rank = weights_shape.size();
const size_t zero_point_rank = zero_point_shape.size();
// Zero point constant can be converted into scalar only if this does not affect Subtract output shape
if (weights_rank < zero_point_rank)
return false;
zero_point_shape.insert(zero_point_shape.begin(), weights_rank - zero_point_rank, 1);
for (size_t i = 0; i < weights_rank; ++i) {
if (zero_point_shape[i] > weights_shape[i])
return false;
}
float zp_value;
if (!ov::op::util::get_single_value(zero_point, zp_value))
return false;
const auto new_zp = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value});
return ov::replace_node_update_name(zero_point, new_zp);
};
auto m = std::make_shared<ov::pass::pattern::Matcher>(subtract_m, matcher_name);
register_matcher(m, callback);
}

View File

@ -21,6 +21,7 @@
#include "transformations/common_optimizations/conv_to_binary_conv.hpp"
#include "transformations/common_optimizations/convert_nms_gather_path_to_unsigned.hpp"
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include "transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.hpp"
#include "transformations/common_optimizations/convolution_to_group_convolution_fusion.hpp"
#include "transformations/common_optimizations/depth_to_space_fusion.hpp"
#include "transformations/common_optimizations/dilated_convolution_converter.hpp"
@ -212,6 +213,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
ADD_MATCHER(common_fusions, ShuffleChannelsFusion, !m_use_shapes)
ADD_MATCHER(common_fusions, NonZeroHorizontalFusion)
ADD_MATCHER(common_fusions, AdaptivePoolToReduce)
ADD_MATCHER(common_fusions, ConvertU4WeightsZeroPointToScalar)
common_fusions->set_name("ov::pass::CommonFusions");
REGISTER_PASS(manager, BinarizeWeights)

View File

@ -31,6 +31,8 @@ bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float
return util::normalize_single_value(const_node->get_vector<bfloat16>(), value, check_value_range);
case element::Type_t::f64:
return util::normalize_single_value(const_node->get_vector<double>(), value, check_value_range);
case element::Type_t::i4:
return util::normalize_single_value(const_node->cast_vector<int8_t>(), value, check_value_range);
case element::Type_t::i8:
return util::normalize_single_value(const_node->get_vector<int8_t>(), value, check_value_range);
case element::Type_t::i16:
@ -39,6 +41,8 @@ bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float
return util::normalize_single_value(const_node->get_vector<int32_t>(), value, check_value_range);
case element::Type_t::i64:
return util::normalize_single_value(const_node->get_vector<int64_t>(), value, check_value_range);
case element::Type_t::u4:
return util::normalize_single_value(const_node->cast_vector<int8_t>(), value, check_value_range);
case element::Type_t::u8:
return util::normalize_single_value(const_node->get_vector<uint8_t>(), value, check_value_range);
case element::Type_t::u16:

View File

@ -0,0 +1,208 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/convert_u4_weights_zero_point_to_scalar.hpp"
#include <gtest/gtest.h>
#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/core/model.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/manager.hpp"
using namespace testing;
using namespace ov;
TEST_F(TransformationTestsF, ConvertU4WeightsFloatZeroPointToScalar) {
auto weights_precision = ov::element::u4;
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{32, 1, 64};
{
auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {8.1f});
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
{
ov::Shape scalar_shape{};
auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(decompression_precision, scalar_shape, {8.1f});
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model_ref = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
}
comparator.enable(FunctionsComparator::ACCURACY);
comparator.enable(FunctionsComparator::CONST_VALUES);
}
TEST_F(TransformationTestsF, ConvertU4WeightsU4ZeroPointToScalar) {
auto weights_precision = ov::element::u4;
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{32, 1, 64};
{
auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, {8});
auto zero_point_convert = std::make_shared<ov::op::v0::Convert>(zero_point, decompression_precision);
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point_convert);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
{
ov::Shape scalar_shape{};
auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(weights_precision, scalar_shape, {8});
auto zero_point_convert = std::make_shared<ov::op::v0::Convert>(zero_point, decompression_precision);
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point_convert);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model_ref = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
}
comparator.enable(FunctionsComparator::ACCURACY);
comparator.enable(FunctionsComparator::CONST_VALUES);
}
TEST_F(TransformationTestsF, ConvertU4WeightsFloatZeroPointToScalarWeightsWithBiggerRank) {
auto weights_precision = ov::element::u4;
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{64};
{
auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {8});
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
{
ov::Shape scalar_shape{};
auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(decompression_precision, scalar_shape, {8});
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model_ref = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
}
comparator.enable(FunctionsComparator::ACCURACY);
comparator.enable(FunctionsComparator::CONST_VALUES);
}
TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointNotScalarLikeZP) {
auto weights_precision = ov::element::u8;
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{32, 1, 64};
auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
std::vector<std::int8_t> zero_point_values(ov::shape_size(decompression_shape), 8);
zero_point_values.back() = 6;
auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, zero_point_values);
auto zero_point_convert = std::make_shared<ov::op::v0::Convert>(zero_point, decompression_precision);
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point_convert);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
TEST_F(TransformationTestsF, FuseU4WeightsAndZeroPointNotU4Weights) {
auto weights_precision = ov::element::u8;
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{32, 1, 64};
auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, {8});
auto zero_point_convert = std::make_shared<ov::op::v0::Convert>(zero_point, decompression_precision);
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point_convert);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply}, ParameterVector{});
manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
TEST_F(TransformationTestsF, ConvertU4WeightsFloatZeroPointToScalarAdditionalZPConsumer) {
auto weights_precision = ov::element::u4;
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{32, 1, 64};
auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {8});
auto zero_point_consumer = std::make_shared<ov::op::v3::ShapeOf>(zero_point);
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply, zero_point_consumer}, ParameterVector{});
manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
TEST_F(TransformationTestsF, ConvertU4WeightsU4ZeroPointToScalarAdditionalZPConsumer) {
auto weights_precision = ov::element::u4;
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{32, 1, 64};
auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, {8});
auto zero_point_consumer = std::make_shared<ov::op::v3::ShapeOf>(zero_point);
auto zero_point_convert = std::make_shared<ov::op::v0::Convert>(zero_point, decompression_precision);
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point_convert);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply, zero_point_consumer}, ParameterVector{});
manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
TEST_F(TransformationTestsF, ConvertU4WeightsU4ZeroPointToScalarAdditionalZPConvertConsumer) {
auto weights_precision = ov::element::u4;
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{32, 1, 64};
auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, {8});
auto zero_point_convert = std::make_shared<ov::op::v0::Convert>(zero_point, decompression_precision);
auto zero_point_convert_consumer = std::make_shared<ov::op::v3::ShapeOf>(zero_point_convert);
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point_convert);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply, zero_point_convert_consumer}, ParameterVector{});
manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}
TEST_F(TransformationTestsF, ConvertU4WeightsU4ZeroPointToScalarZPWithBiggerRank) {
auto weights_precision = ov::element::u4;
auto decompression_precision = ov::element::f32;
ov::Shape weights_shape{32, 128, 64};
ov::Shape decompression_shape{1, 32, 1, 64};
auto weights = ov::op::v0::Constant::create(weights_precision, weights_shape, {4});
auto convert = std::make_shared<ov::op::v0::Convert>(weights, decompression_precision);
auto zero_point = ov::op::v0::Constant::create(weights_precision, decompression_shape, {8});
auto zero_point_convert = std::make_shared<ov::op::v0::Convert>(zero_point, decompression_precision);
auto zero_point_convert_consumer = std::make_shared<ov::op::v3::ShapeOf>(zero_point_convert);
auto subtract = std::make_shared<ov::op::v1::Subtract>(convert, zero_point_convert);
auto scale = ov::op::v0::Constant::create(decompression_precision, decompression_shape, {3.f});
auto multiply = std::make_shared<ov::op::v1::Multiply>(subtract, scale);
model = std::make_shared<Model>(NodeVector{multiply, zero_point_convert_consumer}, ParameterVector{});
manager.register_pass<ov::pass::ConvertU4WeightsZeroPointToScalar>();
}

View File

@ -319,7 +319,8 @@ bool replace_output_update_name(Output<Node> output, const Output<Node>& replace
bool replace_node_update_name(const std::shared_ptr<Node>& target, const std::shared_ptr<Node>& replacement) {
for (auto& output : target->output(0).get_target_inputs()) {
if (ov::as_type<op::v0::Parameter>(replacement->input_value(0).get_node()) &&
if (replacement->get_input_size() > 0 &&
ov::as_type<op::v0::Parameter>(replacement->input_value(0).get_node()) &&
ov::as_type<op::v0::Result>(output.get_node())) {
return false;
}