From 4a3ce48f7af1f0d5eb355be6cd9f1ff0aa54be00 Mon Sep 17 00:00:00 2001
From: Mateusz Tabaka
Date: Tue, 10 Oct 2023 11:10:05 +0200
Subject: [PATCH] CompressQuantizeWeights optimizations (#20025)

* Optimize CompressQuantizeWeights transformation

- remove CoordinateTransform usage from FakeQuantize reference implementation
- move ZeroPointOptimizer functionality inside CompressQuantizeWeights
- compute scale and zero point in the same loop

Ticket: CVS-119273

* review comments

* clang format

* fix comments
---
 .../core/offline_transformations.cpp          |   1 -
 .../include/compress_quantize_weights.hpp     |  29 +-
 .../src/compress_quantize_weigths.cpp         | 977 ++++++++++++++----
 .../tests/utils/compress_quantize_weights.cpp |  36 +-
 .../openvino/reference/fake_quantize.hpp      | 638 +++++++++---
 .../functional/op_reference/fake_quantize.cpp |  16 +
 6 files changed, 1307 insertions(+), 390 deletions(-)

diff --git a/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp b/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp
index 17a879c72b5..215a65da316 100644
--- a/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp
+++ b/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp
@@ -109,7 +109,6 @@ void regmodule_offline_transformations(py::module m) {
         [](std::shared_ptr<ov::Model> model) {
             ov::pass::Manager manager;
             manager.register_pass<ov::pass::CompressQuantizeWeights>();
-            manager.register_pass<ov::pass::ZeroPointOptimizer>();
             manager.run_passes(model);
         },
         py::arg("model"));
diff --git a/src/common/offline_transformations/include/compress_quantize_weights.hpp b/src/common/offline_transformations/include/compress_quantize_weights.hpp
index 62119cd907c..356ff01195a 100644
--- a/src/common/offline_transformations/include/compress_quantize_weights.hpp
+++ b/src/common/offline_transformations/include/compress_quantize_weights.hpp
@@ -10,7 +10,6 @@ namespace ov {
 namespace pass {
 class CompressQuantizeWeights;
-class ZeroPointOptimizer;
 }  // namespace pass
 }  // namespace ov
@@ -57,36 +56,10 @@ class ZeroPointOptimizer;
     Transformation prepares quantized constant data for Low Precision pipeline.
     Such constant data packing reduces IR size (.bin file size) in offline transformations.
     With that we can skip same calculations in the runtime and make loading of such sub-graphs to the plugin faster.
+    Additionally, the zero point can be fused into the weights if it doesn't affect accuracy.
 */
 class ov::pass::CompressQuantizeWeights : public ov::pass::MatcherPass {
 public:
     OPENVINO_RTTI("CompressQuantizeWeights", "0");
     CompressQuantizeWeights();
 };
-
-/*
-    if zero_point == 0 we can eliminate Subtract from following dequantization subgraph:
-
-      +-----------------+
-      |    Constant     |
-      | (low precision) |
-      +-----------------+
-              |
-              v
-      +------------------+
-      |     Convert      |
-      |  (to high prec)  |
-      +------------------+
-              |
-              v
-    +----------+    +------------+
-    |zero point|--->|  Subtract  |
-    +----------+    +-----+------+
-                          |
-                          v
-*/
-class ov::pass::ZeroPointOptimizer : public ov::pass::MatcherPass {
-public:
-    OPENVINO_RTTI("ZeroPointOptimizer");
-    ZeroPointOptimizer();
-};
diff --git a/src/common/offline_transformations/src/compress_quantize_weigths.cpp b/src/common/offline_transformations/src/compress_quantize_weigths.cpp
index f8ff0d9d9f8..6c9e4554782 100644
--- a/src/common/offline_transformations/src/compress_quantize_weigths.cpp
+++ b/src/common/offline_transformations/src/compress_quantize_weigths.cpp
@@ -5,95 +5,125 @@
 #include "compress_quantize_weights.hpp"
 
 #include "openvino/core/rt_info.hpp"
 #include "openvino/core/validation_util.hpp"
-#include "openvino/opsets/opset8.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/divide.hpp"
+#include "openvino/op/fake_quantize.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/subtract.hpp"
 #include "openvino/pass/constant_folding.hpp"
 #include "openvino/pass/pattern/op/or.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
-#include "transformations/rt_info/decompression.hpp"
+#include "openvino/reference/autobroadcast_binop.hpp"
+#include "openvino/reference/convert.hpp"
+#include "openvino/reference/fake_quantize.hpp"
+#include "validation_util.hpp"
 
-static bool has_dequantization_subgraph(const std::shared_ptr<ov::Node>& first_convert) {
-    auto first_convert_users = first_convert->get_users();
-    const auto second_convert = std::find_if(first_convert_users.begin(),
-                                             first_convert_users.end(),
-                                             [](const std::shared_ptr<ov::Node>& n) -> bool {
-                                                 return ov::is_type<ov::opset8::Convert>(n);
-                                             });
-    if (second_convert == first_convert_users.end())
-        return false;
-    auto convert_or_subtract_users = (*second_convert)->get_users();
-    const auto subtract = std::find_if(convert_or_subtract_users.begin(),
-                                       convert_or_subtract_users.end(),
-                                       [](const std::shared_ptr<ov::Node>& n) -> bool {
-                                           return ov::is_type<ov::opset8::Subtract>(n);
-                                       });
-    if (subtract != convert_or_subtract_users.end()) {
-        convert_or_subtract_users = (*subtract)->get_users();
-    }
-    const auto multiply = std::find_if(convert_or_subtract_users.begin(),
-                                       convert_or_subtract_users.end(),
-                                       [](const std::shared_ptr<ov::Node>& n) -> bool {
-                                           return ov::is_type<ov::opset8::Multiply>(n);
-                                       });
-    return multiply != convert_or_subtract_users.end();
-}
+static bool has_dequantization_subgraph(const std::shared_ptr<ov::Node>& fq,
+                                        std::shared_ptr<ov::Node>& convert_to_low_precision,
+                                        std::shared_ptr<ov::Node>& convert_to_high_precision,
+                                        std::shared_ptr<ov::Node>& zero_point);
+
+static bool compute_scale_and_zero_point(const std::shared_ptr<ov::op::v0::Constant>& output_low,
+                                         const std::shared_ptr<ov::op::v0::Constant>& output_high,
+                                         size_t levels,
+                                         ov::Tensor& scale_tensor,
+                                         ov::Tensor& zero_point_tensor,
+                                         bool& zero_point_is_zero);
+
+static std::shared_ptr<ov::op::v0::Constant> compress_quantized_weights(
+    const std::shared_ptr<ov::op::v0::Constant>& weights,
+    const std::shared_ptr<ov::op::v0::FakeQuantize>& fq,
+    const std::shared_ptr<ov::op::v0::Constant>& input_low,
+    const std::shared_ptr<ov::op::v0::Constant>& input_high,
+    const std::shared_ptr<ov::op::v0::Constant>& output_low,
+    const std::shared_ptr<ov::op::v0::Constant>& output_high,
+    const std::shared_ptr<ov::Node>& convert,
+    const std::shared_ptr<ov::Node>& zero_point,
    bool& can_fuse_zero_point);
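For reference, compute_scale_and_zero_point fills the scale and zero point tensors following the formulas documented later in this file: scale = (output_high - output_low) / (levels - 1) and zero_point = new_output_low - output_low / scale, where new_output_low = -(levels / 2). A minimal sketch of that arithmetic for a single element pair (hypothetical helper names, assuming NUMPY broadcasting has already been resolved):

    #include <cstddef>

    struct ScaleZeroPoint {
        float scale;
        float zero_point;
    };

    // Derive scale and zero point for one output_low/output_high pair.
    // new_output_low = -(levels / 2) is the low bound of the signed target range.
    ScaleZeroPoint derive_scale_zero_point(float output_low, float output_high, std::size_t levels) {
        const float input_range = static_cast<float>(levels - 1);
        const float new_output_low = -static_cast<float>(levels / 2);
        const float scale = (output_high - output_low) / input_range;
        // Where scale == 0 the pass also forces the zero point to 0.
        const float zero_point = scale != 0.0f ? new_output_low - output_low / scale : 0.0f;
        return {scale, zero_point};
    }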
+ +static std::shared_ptr compress_quantized_weights( + const std::shared_ptr& weights, + const std::shared_ptr& input_low, + const std::shared_ptr& input_high, + const ov::element::Type& low_precision_type, + size_t levels, + bool zero_point_is_zero, + const ov::Tensor& zero_point_tensor, + bool& can_fuse_zero_point); + +static void replace_with_dequantize_subgraph(const std::shared_ptr& fq, + const std::shared_ptr& new_weights, + const ov::element::Type& high_precision_type, + const ov::Shape& scale_or_zero_point_shape, + const ov::Tensor& scale_tensor, + bool zero_point_is_zero, + const ov::Tensor& zero_point_tensor = {}); ov::pass::CompressQuantizeWeights::CompressQuantizeWeights() { - auto weights_const_pattern = pattern::wrap_type(); - auto weigths_convert_pattern = pattern::wrap_type({weights_const_pattern}); - OutputVector weights_options{weights_const_pattern, weigths_convert_pattern}; + auto weights_const_pattern = pattern::wrap_type(); + auto weights_convert_pattern = pattern::wrap_type({weights_const_pattern}); + OutputVector weights_options{weights_const_pattern, weights_convert_pattern}; auto weights_pattern = std::make_shared(weights_options); - auto input_low_pattern = pattern::wrap_type(); - auto input_high_pattern = pattern::wrap_type(); - auto output_low_pattern = pattern::wrap_type(); - auto output_high_pattern = pattern::wrap_type(); - auto fq_pattern = pattern::wrap_type( + auto input_low_pattern = pattern::wrap_type(); + auto input_high_pattern = pattern::wrap_type(); + auto output_low_pattern = pattern::wrap_type(); + auto output_high_pattern = pattern::wrap_type(); + auto fq_pattern = pattern::wrap_type( {weights_pattern, input_low_pattern, input_high_pattern, output_low_pattern, output_high_pattern}); ov::matcher_pass_callback callback = [=](pattern::Matcher& m) { - auto fq = std::dynamic_pointer_cast(m.get_match_root()); + auto fq = std::dynamic_pointer_cast(m.get_match_root()); if (!fq) return false; - auto levels = fq->get_levels(); - if (levels <= 2 || levels > 256) - return false; - auto quantized_type = element::undefined; - // Currently we support two weights quantize types: i4 and i8 - if (levels <= 16) { - quantized_type = element::i4; - } else if (levels <= 256) { - quantized_type = element::i8; - } + const auto& high_precision_type = fq->get_element_type(); - const auto& pattern_value_map = m.get_pattern_value_map(); - const auto& input_type = fq->get_element_type(); - const auto& fq_data_input = fq->get_input_node_shared_ptr(0); - bool are_weights_decompressed = is_decompression(fq_data_input); - if (are_weights_decompressed) { - unmark_as_decompression(fq_data_input); - } + auto weights = ov::util::constantfold_subgraph(fq->get_input_node_shared_ptr(0)); + if (!weights) + return false; + auto input_low = ov::as_type_ptr(fq->get_input_node_shared_ptr(1)); + if (!input_low) + return false; + auto input_high = ov::as_type_ptr(fq->get_input_node_shared_ptr(2)); + if (!input_high) + return false; + auto output_low = ov::as_type_ptr(fq->get_input_node_shared_ptr(3)); + if (!output_low) + return false; + auto output_high = ov::as_type_ptr(fq->get_input_node_shared_ptr(4)); + if (!output_high) + return false; // skip dequantize part if there is already dequantization subgraph after FakeQuantize - auto fq_users = fq->get_users(); - if (fq_users.size() == 1 && has_dequantization_subgraph(fq_users[0])) { - auto& first_convert = fq_users[0]; - OPENVINO_SUPPRESS_DEPRECATED_START - if (auto new_weights = ov::get_constant_from_source(first_convert)) { - 
OPENVINO_SUPPRESS_DEPRECATED_END - new_weights->set_friendly_name(first_convert->get_friendly_name()); - replace_node(first_convert, new_weights); - copy_runtime_info(first_convert, new_weights); - // preserve dequantization subgraph for LP transformations - auto weights_users = new_weights->get_users(); - if (weights_users.size() == 1 && ov::is_type(weights_users[0])) { - ov::pass::disable_constant_folding(weights_users[0]); - } - return true; - } else { - if (are_weights_decompressed) { - mark_as_decompression(fq_data_input); - } + std::shared_ptr convert_to_low_precision; + std::shared_ptr convert_to_high_precision; + std::shared_ptr zero_point; + if (has_dequantization_subgraph(fq, convert_to_low_precision, convert_to_high_precision, zero_point)) { + bool can_fuse_zero_point = false; + auto new_weights = compress_quantized_weights(weights, + fq, + input_low, + input_high, + output_low, + output_high, + convert_to_low_precision, + zero_point, + can_fuse_zero_point); + if (!new_weights) return false; + + new_weights->set_friendly_name(convert_to_low_precision->get_friendly_name()); + replace_node(convert_to_low_precision, new_weights); + copy_runtime_info({fq, convert_to_low_precision}, new_weights); + // preserve dequantization subgraph for LP transformations + ov::pass::disable_constant_folding(convert_to_high_precision); + if (can_fuse_zero_point) { + auto subtract = convert_to_high_precision->get_users()[0]; + auto subtract_consumers = subtract->output(0).get_target_inputs(); + auto multiply = *(subtract_consumers.begin()); + multiply.replace_source_output(convert_to_high_precision); } + return true; } else { /* Quantize part @@ -103,33 +133,7 @@ ov::pass::CompressQuantizeWeights::CompressQuantizeWeights() { output_low = -levels / 2 output_high = levels - 1 + output_low The FakeQuantize result is converted to low precision type and then constant folded - */ - std::shared_ptr new_output_low = - op::v0::Constant::create(input_type, Shape{}, {-static_cast(levels / 2)}); - std::shared_ptr new_output_high = - std::make_shared(new_output_low, - op::v0::Constant::create(input_type, Shape{}, {levels - 1})); - const auto& weights_const = pattern_value_map.at(weights_const_pattern); - Output input_low = pattern_value_map.at(input_low_pattern); - Output input_high = pattern_value_map.at(input_high_pattern); - auto quantize = - fq->clone_with_new_inputs({fq_data_input, input_low, input_high, new_output_low, new_output_high}); - // Convert quantized weights to low precision type - std::shared_ptr new_weights = std::make_shared(quantize, quantized_type); - // Constant fold quantized weights - OPENVINO_SUPPRESS_DEPRECATED_START - if (auto constant = ov::get_constant_from_source(new_weights)) { - OPENVINO_SUPPRESS_DEPRECATED_END - new_weights = constant; - } else { - if (are_weights_decompressed) { - mark_as_decompression(fq_data_input); - } - return false; - } - new_weights->set_friendly_name(weights_const.get_node()->get_friendly_name()); - /* Dequantize part is performed by Convert(from low to high precision)->Subtract->Multiply subgraph. 
+-------------------------+ @@ -153,56 +157,65 @@ ov::pass::CompressQuantizeWeights::CompressQuantizeWeights() { scale = (output_high - output_low) / (new_output_high - new_output_low) zero_point = new_output_low - output_low / scale */ - Output output_low = pattern_value_map.at(output_low_pattern); - Output output_high = pattern_value_map.at(output_high_pattern); - const auto& fq_type = fq->get_output_element_type(0); - const bool should_convert = fq_type.is_real() && fq_type.size() < element::f32.size(); - if (should_convert) { - input_low = std::make_shared(input_low, element::f32); - input_high = std::make_shared(input_high, element::f32); - output_low = std::make_shared(output_low, element::f32); - output_high = std::make_shared(output_high, element::f32); - new_output_low = std::make_shared(new_output_low, element::f32); - new_output_high = std::make_shared(new_output_high, element::f32); - } - auto output_range = std::make_shared(output_high, output_low); - auto input_range = std::make_shared(new_output_high, new_output_low); - std::shared_ptr scale = std::make_shared(output_range, input_range); - auto descaled_output_low = std::make_shared(output_low, scale); - std::shared_ptr shift = std::make_shared(new_output_low, descaled_output_low); - OPENVINO_SUPPRESS_DEPRECATED_START - if (auto constant = ov::get_constant_from_source(scale)) { - OPENVINO_SUPPRESS_DEPRECATED_END - scale = constant; - } - auto zero = op::v0::Constant::create(scale->get_output_element_type(0), Shape{}, {0}); - auto scale_eq_zero = std::make_shared(scale, zero); - // shift equals to input_low - output_low / scale - // for positions where scale == 0, we put zero as shift - std::shared_ptr zero_point = std::make_shared(scale_eq_zero, zero, shift); - if (should_convert) { - scale = std::make_shared(scale, fq_type); - zero_point = std::make_shared(zero_point, fq_type); + auto levels = fq->get_levels(); + if (levels <= 2 || levels > 256) + return false; + auto low_precision_type = element::undefined; + // Currently we support two weights quantize types: i4 and i8 + if (levels <= 16) { + low_precision_type = element::i4; + } else if (levels <= 256) { + low_precision_type = element::i8; } - OPENVINO_SUPPRESS_DEPRECATED_START - if (auto constant = ov::get_constant_from_source(zero_point)) { - OPENVINO_SUPPRESS_DEPRECATED_END - zero_point = constant; + bool zero_point_is_zero = true; + PartialShape merged_shape{output_low->get_shape()}; + PartialShape::broadcast_merge_into(merged_shape, output_high->get_shape(), op::AutoBroadcastType::NUMPY); + Shape scale_or_zero_point_shape = merged_shape.to_shape(); + Tensor scale_tensor(high_precision_type, scale_or_zero_point_shape); + Tensor zero_point_tensor(high_precision_type, scale_or_zero_point_shape); + + if (!compute_scale_and_zero_point(output_low, + output_high, + levels, + scale_tensor, + zero_point_tensor, + zero_point_is_zero)) { + return false; } - OPENVINO_SUPPRESS_DEPRECATED_START - if (auto constant = ov::get_constant_from_source(scale)) { - OPENVINO_SUPPRESS_DEPRECATED_END - scale = constant; + + bool can_fuse_zero_point = false; + auto new_weights = compress_quantized_weights(weights, + input_low, + input_high, + low_precision_type, + levels, + zero_point_is_zero, + zero_point_tensor, + can_fuse_zero_point); + if (!new_weights) { + return false; } - auto convert_to_high_prec = std::make_shared(new_weights, input_type); - auto sub = register_new_node(convert_to_high_prec, zero_point); - auto mul = register_new_node(sub, scale); - 
mul->set_friendly_name(fq->get_friendly_name()); - copy_runtime_info(fq, {convert_to_high_prec, sub, mul}); - ov::pass::disable_constant_folding(convert_to_high_prec); - replace_node(fq, mul); + + if (zero_point_is_zero || can_fuse_zero_point) { + replace_with_dequantize_subgraph(fq, + new_weights, + high_precision_type, + scale_or_zero_point_shape, + scale_tensor, + true); + } else { + replace_with_dequantize_subgraph(fq, + new_weights, + high_precision_type, + scale_or_zero_point_shape, + scale_tensor, + zero_point_is_zero, + zero_point_tensor); + } + + return true; } return true; }; @@ -211,86 +224,622 @@ ov::pass::CompressQuantizeWeights::CompressQuantizeWeights() { this->register_matcher(m, callback); } -ov::pass::ZeroPointOptimizer::ZeroPointOptimizer() { - auto weights_pattern = pattern::wrap_type(); - auto zero_point_pattern = pattern::wrap_type(); - auto convert_pattern = pattern::wrap_type({weights_pattern}); - auto sub_pattern = pattern::wrap_type({convert_pattern, zero_point_pattern}); +static ov::Tensor tensor_from_constant(const std::shared_ptr& constant) { + return ov::Tensor(constant->get_element_type(), constant->get_shape(), const_cast(constant->get_data_ptr())); +} - ov::matcher_pass_callback callback = [=](pattern::Matcher& m) { - const auto& pattern_value_map = m.get_pattern_value_map(); - auto convert = pattern_value_map.at(convert_pattern).get_node_shared_ptr(); - auto sub = pattern_value_map.at(sub_pattern).get_node_shared_ptr(); - auto weights = - std::dynamic_pointer_cast(pattern_value_map.at(weights_pattern).get_node_shared_ptr()); - if (!weights || weights->get_element_type() != element::i8) - return false; - auto zero_point = - std::dynamic_pointer_cast(pattern_value_map.at(zero_point_pattern).get_node_shared_ptr()); - if (!zero_point) - return false; +static bool evaluate_node(const std::shared_ptr& node, + const ov::TensorVector& input_tensors, + ov::Tensor& output_tensor) { + if (node->get_output_size() != 1) + return false; - auto zp_value = zero_point->cast_vector(); - if (std::all_of(zp_value.begin(), zp_value.end(), [](float f) -> bool { - return std::fabs(f) <= std::numeric_limits::epsilon(); - })) { - copy_runtime_info(sub, convert); - replace_node(sub, convert); - } + ov::TensorVector output_tensors{ov::Tensor(node->get_output_element_type(0), node->get_output_shape(0))}; + if (!node->evaluate(output_tensors, input_tensors)) + return false; - auto int8_zero_point = std::make_shared( - std::make_shared(zero_point, opset8::Round::RoundMode::HALF_TO_EVEN), - weights->get_element_type()); - auto adj_zero_point = std::make_shared( - zero_point, - std::make_shared(int8_zero_point, convert->get_element_type())); + output_tensor = output_tensors[0]; - OPENVINO_SUPPRESS_DEPRECATED_START - auto adj_zero_point_const = ov::get_constant_from_source(adj_zero_point); - OPENVINO_SUPPRESS_DEPRECATED_END - if (!adj_zero_point_const) - return false; - auto adj_zero_point_val = adj_zero_point_const->cast_vector(); - bool is_adj_zero_point_close_to_zero = - std::all_of(adj_zero_point_val.begin(), adj_zero_point_val.end(), [](float f) -> bool { - return std::fabs(f) < 1e-4; - }); - if (!is_adj_zero_point_close_to_zero) - return false; + return true; +} - auto transformed = std::make_shared( - std::make_shared(std::make_shared(weights, int8_zero_point), - convert->get_element_type()), - adj_zero_point); - auto diff = std::make_shared(sub, transformed); - OPENVINO_SUPPRESS_DEPRECATED_START - auto diff_const = ov::get_constant_from_source(diff); - 
OPENVINO_SUPPRESS_DEPRECATED_END - if (!diff_const) - return false; - auto diff_val = diff_const->cast_vector(); - bool is_transformed_and_original_equal = std::all_of(diff_val.begin(), diff_val.end(), [](float f) -> bool { - return std::fabs(f) < std::numeric_limits::epsilon(); +static ov::TensorVector get_fake_quantize_input_tensors(const std::shared_ptr& fq) { + ov::Tensor weights_tensor; + + auto fq_input = fq->get_input_node_shared_ptr(0); + auto fq_input_constant = ov::as_type_ptr(fq_input); + + if (!fq_input_constant) { + auto weights = ov::as_type_ptr(fq_input->get_input_node_shared_ptr(0)); + if (!evaluate_node(fq_input, ov::TensorVector{tensor_from_constant(weights)}, weights_tensor)) + return {}; + } else { + weights_tensor = tensor_from_constant(fq_input_constant); + } + + auto in_low = ov::as_type_ptr(fq->get_input_node_shared_ptr(1)); + auto in_high = ov::as_type_ptr(fq->get_input_node_shared_ptr(2)); + auto out_low = ov::as_type_ptr(fq->get_input_node_shared_ptr(3)); + auto out_high = ov::as_type_ptr(fq->get_input_node_shared_ptr(4)); + + return ov::TensorVector{weights_tensor, + tensor_from_constant(in_low), + tensor_from_constant(in_high), + tensor_from_constant(out_low), + tensor_from_constant(out_high)}; +} + +template +static std::shared_ptr get_single_consumer_of_type(const std::shared_ptr& node) { + auto target_inputs = node->output(0).get_target_inputs(); + if (target_inputs.size() != 1) + return nullptr; + auto consumer = ov::as_type(target_inputs.begin()->get_node()); + if (!consumer) + return nullptr; + return consumer->shared_from_this(); +} + +bool has_dequantization_subgraph(const std::shared_ptr& fq, + std::shared_ptr& convert_to_low_precision, + std::shared_ptr& convert_to_high_precision, + std::shared_ptr& zero_point) { + convert_to_low_precision = get_single_consumer_of_type(fq); + if (!convert_to_low_precision) + return false; + convert_to_high_precision = get_single_consumer_of_type(convert_to_low_precision); + if (!convert_to_high_precision) + return false; + auto subtract = get_single_consumer_of_type(convert_to_high_precision); + if (subtract) { + zero_point = subtract->get_input_node_shared_ptr(1); + return get_single_consumer_of_type(subtract) != nullptr; + } else { + return get_single_consumer_of_type(convert_to_high_precision) != nullptr; + } +} + +static std::shared_ptr evaluate_fake_quantize(const std::shared_ptr& quantize, + const std::shared_ptr& convert) { + ov::Tensor quantize_output_tensor; + if (!evaluate_node(quantize, get_fake_quantize_input_tensors(quantize), quantize_output_tensor)) + return nullptr; + ov::Tensor new_weights_tensor; + if (!evaluate_node(convert, {quantize_output_tensor}, new_weights_tensor)) + return nullptr; + return std::make_shared(new_weights_tensor); +} + +void replace_with_dequantize_subgraph(const std::shared_ptr& fq, + const std::shared_ptr& new_weights, + const ov::element::Type& high_precision_type, + const ov::Shape& scale_or_zero_point_shape, + const ov::Tensor& scale_tensor, + bool zero_point_is_zero, + const ov::Tensor& zero_point_tensor) { + ov::pass::NodeRegistry node_registry; + auto convert = node_registry.make(new_weights, high_precision_type); + ov::pass::disable_constant_folding(convert); + std::shared_ptr mul; + auto scale = node_registry.make(scale_tensor); + if (!zero_point_is_zero) { + auto zero_point = node_registry.make(zero_point_tensor); + auto sub = node_registry.make(convert, zero_point); + mul = node_registry.make(sub, scale); + } else { + mul = node_registry.make(convert, scale); + } + 
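    // The registry now holds either Convert->Subtract->Multiply (explicit zero point) or
    // Convert->Multiply (zero point is zero or was already fused into the weights), and
    // constant folding stays disabled on the Convert so LPT still sees the dequantization.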
mul->set_friendly_name(fq->get_friendly_name()); + copy_runtime_info(fq, node_registry.get()); + replace_node(fq, mul); +} + +template +static void compute_scale_and_zero_point_internal(const std::shared_ptr& output_low, + const std::shared_ptr& output_high, + size_t levels, + ov::Tensor& scale_tensor, + ov::Tensor& zero_point_tensor, + bool& zero_point_is_zero) { + zero_point_is_zero = true; + float input_range = static_cast(levels - 1); + float new_output_low = -static_cast(levels / 2); + T* zero_point = zero_point_tensor.data(); + T* scale = scale_tensor.data(); + ov::reference::autobroadcast_binop( + output_low->get_data_ptr(), + output_high->get_data_ptr(), + scale, + output_low->get_shape(), + output_high->get_shape(), + ov::op::AutoBroadcastType::NUMPY, + [input_range, new_output_low, zero_point, &zero_point_is_zero](float output_low_value, + float output_high_value) mutable { + float output_range = output_high_value - output_low_value; + float scale = output_range / input_range; + float zero_point_value = (new_output_low - output_low_value / scale) * (scale != 0); + zero_point_is_zero = + zero_point_is_zero && std::fabs(zero_point_value) < std::numeric_limits::epsilon(); + *zero_point++ = zero_point_value; + return scale; }); - if (!is_transformed_and_original_equal) - return false; +} - std::shared_ptr new_weights = std::make_shared(weights, int8_zero_point); - OPENVINO_SUPPRESS_DEPRECATED_START - if (auto constant = ov::get_constant_from_source(new_weights)) { - OPENVINO_SUPPRESS_DEPRECATED_END - new_weights = constant; - } else { - return false; +bool compute_scale_and_zero_point(const std::shared_ptr& output_low, + const std::shared_ptr& output_high, + size_t levels, + ov::Tensor& scale_tensor, + ov::Tensor& zero_point_tensor, + bool& zero_point_is_zero) { + const auto type = output_low->get_element_type(); + switch (type) { + case ov::element::Type_t::f32: { + compute_scale_and_zero_point_internal(output_low, + output_high, + levels, + scale_tensor, + zero_point_tensor, + zero_point_is_zero); + break; + } + case ov::element::f16: { + compute_scale_and_zero_point_internal(output_low, + output_high, + levels, + scale_tensor, + zero_point_tensor, + zero_point_is_zero); + break; + } + default: + return false; + } + + return true; +} + +template +static void +transform(const T* first1, const T* const last1, const T* first2, const T* first3, const T* first4, U* out, F& f) { + while (first1 < last1) { + *out++ = f(*first1++, *first2++, *first3++, *first4++); + } +} + +template +static void transform(const T* first1, + const T* const last1, + const T* first2, + const T* first3, + const T* first4, + const T* first5, + const T* first6, + U* out, + F& f) { + while (first1 < last1) { + *out++ = f(*first1++, *first2++, *first3++, *first4++, *first5++, *first6++); + } +} + +template +static void numpy_broadcast_4inputs(const T* weights, + const ov::Shape& weights_shape, + const T* in_low, + const ov::Shape& in_low_shape, + const T* in_high, + const ov::Shape& in_high_shape, + const T* zero_point, + const ov::Shape& zero_point_shape, + U* new_weights, + F& f) { + using namespace ov::reference::fake_quantize_details; + + std::vector output_strides = compute_strides(weights_shape, weights_shape); + std::vector in_low_strides = compute_strides(weights_shape, in_low_shape); + std::vector in_high_strides = compute_strides(weights_shape, in_high_shape); + std::vector zero_point_strides = compute_strides(weights_shape, zero_point_shape); + + size_t num_elements = shape_size(weights_shape); + + 
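    // The shapes are split into an outer part (walked via get_outer_strides below) and a
    // contiguous inner part, mirroring the partitioning described in fake_quantize.hpp.
    // E.g. weights {2, 3, 4} with in_low {1, 3, 4} yields an inner stride of 12, so the
    // fast paths below process 12 contiguous weights per outer step.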
size_t weights_inner_stride = num_elements; + size_t in_low_inner_stride = 0; + size_t in_high_inner_stride = 0; + size_t zero_point_inner_stride = 0; + + std::tie(in_low_inner_stride, weights_inner_stride) = + get_inner_stride(num_elements, weights_shape, in_low_shape, weights_inner_stride); + std::tie(in_high_inner_stride, weights_inner_stride) = + get_inner_stride(num_elements, weights_shape, in_high_shape, weights_inner_stride); + std::tie(zero_point_inner_stride, weights_inner_stride) = + get_inner_stride(num_elements, weights_shape, zero_point_shape, weights_inner_stride); + + auto get_outer_strides = + [&output_strides, &in_low_strides, &in_high_strides, &zero_point_strides](size_t flat_index) { + size_t in_low_stride = 0; + size_t in_high_stride = 0; + size_t zero_point_stride = 0; + + for (size_t i = 0; i < output_strides.size(); i++) { + size_t div = flat_index / output_strides[i]; + flat_index = flat_index % output_strides[i]; + in_low_stride += div * in_low_strides[i]; + in_high_stride += div * in_high_strides[i]; + zero_point_stride += div * zero_point_strides[i]; + } + + return std::tuple{in_low_stride, in_high_stride, zero_point_stride}; + }; + + size_t in_low_stride = 0; + size_t in_high_stride = 0; + size_t zero_point_stride = 0; + + if (in_low_inner_stride * in_high_inner_stride * zero_point_inner_stride == 1) { + for (size_t i = 0; i < shape_size(weights_shape); i += weights_inner_stride) { + std::tie(in_low_stride, in_high_stride, zero_point_stride) = get_outer_strides(i); + T in_low_scalar = *(in_low + in_low_stride); + T in_high_scalar = *(in_high + in_high_stride); + T zero_point_scalar = *(zero_point + zero_point_stride); + std::transform(weights, + weights + weights_inner_stride, + new_weights, + [in_low_scalar, in_high_scalar, zero_point_scalar, &f](T w) { + return f(w, in_low_scalar, in_high_scalar, zero_point_scalar); + }); + weights += weights_inner_stride; + new_weights += weights_inner_stride; } - new_weights->set_friendly_name(weights->get_friendly_name()); - replace_node(weights, new_weights); + } else if (in_low_inner_stride > 1 && in_high_inner_stride > 1 && zero_point_inner_stride > 1) { + for (size_t i = 0; i < shape_size(weights_shape); i += weights_inner_stride) { + std::tie(in_low_stride, in_high_stride, zero_point_stride) = get_outer_strides(i); + transform(weights, + weights + weights_inner_stride, + in_low + in_low_stride, + in_high + in_high_stride, + zero_point + zero_point_stride, + new_weights, + f); + weights += weights_inner_stride; + new_weights += weights_inner_stride; + } + } else { + for (size_t i = 0; i < shape_size(weights_shape); i++) { + std::tie(in_low_stride, in_high_stride, zero_point_stride) = get_outer_strides(i); + *new_weights++ = f(*weights++, + *(in_low + in_low_stride), + *(in_high + in_high_stride), + *(zero_point + zero_point_stride)); + } + } +} - copy_runtime_info(sub, convert); - replace_node(sub, convert); - return true; +template +static void numpy_broadcast_6inputs(const T* weights, + const ov::Shape& weights_shape, + const T* in_low, + const ov::Shape& in_low_shape, + const T* in_high, + const ov::Shape& in_high_shape, + const T* out_low, + const ov::Shape& out_low_shape, + const T* out_high, + const ov::Shape& out_high_shape, + const T* zero_point, + const ov::Shape& zero_point_shape, + U* new_weights, + F& f) { + using namespace ov::reference::fake_quantize_details; + + std::vector output_strides = compute_strides(weights_shape, weights_shape); + std::vector in_low_strides = compute_strides(weights_shape, 
in_low_shape); + std::vector in_high_strides = compute_strides(weights_shape, in_high_shape); + std::vector out_low_strides = compute_strides(weights_shape, out_low_shape); + std::vector out_high_strides = compute_strides(weights_shape, out_high_shape); + std::vector zero_point_strides = compute_strides(weights_shape, zero_point_shape); + + auto get_outer_strides = + [&output_strides, &in_low_strides, &in_high_strides, &out_low_strides, &out_high_strides, &zero_point_strides]( + size_t flat_index) { + size_t in_low_stride = 0; + size_t in_high_stride = 0; + size_t out_low_stride = 0; + size_t out_high_stride = 0; + size_t zero_point_stride = 0; + + for (size_t i = 0; i < output_strides.size(); i++) { + size_t div = flat_index / output_strides[i]; + flat_index = flat_index % output_strides[i]; + in_low_stride += div * in_low_strides[i]; + in_high_stride += div * in_high_strides[i]; + out_low_stride += div * out_low_strides[i]; + out_high_stride += div * out_high_strides[i]; + zero_point_stride += div * zero_point_strides[i]; + } + + return std::tuple{in_low_stride, + in_high_stride, + out_low_stride, + out_high_stride, + zero_point_stride}; + }; + + size_t in_low_stride = 0; + size_t in_high_stride = 0; + size_t out_low_stride = 0; + size_t out_high_stride = 0; + size_t zero_point_stride = 0; + + for (size_t i = 0; i < shape_size(weights_shape); i++) { + std::tie(in_low_stride, in_high_stride, out_low_stride, out_high_stride, zero_point_stride) = + get_outer_strides(i); + *new_weights++ = f(*weights++, + *(in_low + in_low_stride), + *(in_high + in_high_stride), + *(out_low + out_low_stride), + *(out_high + out_high_stride), + *(zero_point + zero_point_stride)); + } +} + +static inline int8_t convert_to_int8(float val) { + return static_cast(std::nearbyint(val)); +} + +static inline int8_t convert_to_int4(float val) { + return static_cast(std::nearbyint(val)) & 0x0f; +} + +static std::shared_ptr create_weights_constant(const ov::Tensor& weights_tensor, + const ov::element::Type& type) { + auto weights = std::make_shared(weights_tensor); + if (weights->get_element_type() != type) { + return ov::util::constantfold_subgraph(std::make_shared(weights, type)); + } + return weights; +} + +template +static std::shared_ptr compress_quantized_weights_internal( + const ov::element::Type& low_precision_type, + const T* weights, + const ov::Shape& weights_shape, + const T* input_low, + const ov::Shape& input_low_shape, + const T* input_high, + const ov::Shape& input_high_shape, + const T* output_low, + const ov::Shape& output_low_shape, + const T* output_high, + const ov::Shape& output_high_shape, + const T* zero_point, + const ov::Shape& zero_point_shape, + size_t levels, + bool& can_fuse_zero_point) { + ov::Tensor compressed_weights_tensor(ov::element::i8, weights_shape); + int8_t* compressed_weights = compressed_weights_tensor.data(); + ov::Tensor compressed_weights_with_fused_zero_point_tensor(ov::element::i8, weights_shape); + int8_t* compressed_weights_with_fused_zero_point = compressed_weights_with_fused_zero_point_tensor.data(); + T levels_minus_one = static_cast(levels - 1); + can_fuse_zero_point = true; + const auto convert_to_low_precision = low_precision_type == ov::element::i4 ? 
convert_to_int4 : convert_to_int8; + + auto f = + [compressed_weights_with_fused_zero_point, levels_minus_one, convert_to_low_precision, &can_fuse_zero_point]( + T weights_value, + T input_low, + T input_high, + T output_low, + T output_high, + T zero_point) mutable { + int8_t compressed_weights_value = + convert_to_low_precision(ov::reference::fake_quantize_details::quantize(weights_value, + input_low, + input_high, + output_low, + output_high, + levels_minus_one)); + T weights_minus_zero_point = static_cast(compressed_weights_value) - zero_point; + int8_t compressed_weights_with_fused_zero_point_value = convert_to_low_precision(weights_minus_zero_point); + can_fuse_zero_point &= + std::fabs(compressed_weights_with_fused_zero_point_value - weights_minus_zero_point) < 1e-4; + *compressed_weights_with_fused_zero_point++ = compressed_weights_with_fused_zero_point_value; + return compressed_weights_value; + }; + + numpy_broadcast_6inputs(weights, + weights_shape, + input_low, + input_low_shape, + input_high, + input_high_shape, + output_low, + output_low_shape, + output_high, + output_high_shape, + zero_point, + zero_point_shape, + compressed_weights, + f); + + return create_weights_constant( + can_fuse_zero_point ? compressed_weights_with_fused_zero_point_tensor : compressed_weights_tensor, + low_precision_type); +} + +std::shared_ptr compress_quantized_weights( + const std::shared_ptr& weights, + const std::shared_ptr& fq, + const std::shared_ptr& input_low, + const std::shared_ptr& input_high, + const std::shared_ptr& output_low, + const std::shared_ptr& output_high, + const std::shared_ptr& convert, + const std::shared_ptr& zero_point, + bool& can_fuse_zero_point) { + std::shared_ptr new_weights; + const auto& weights_shape = weights->get_shape(); + const auto& type = weights->get_element_type(); + const auto& low_precision_type = convert->get_output_element_type(0); + + if (zero_point == nullptr) + return evaluate_fake_quantize(fq, convert); + + auto zero_point_constant = ov::util::constantfold_subgraph(zero_point); + if (!zero_point_constant) + return nullptr; + + switch (type) { + case ov::element::f32: { + new_weights = compress_quantized_weights_internal(low_precision_type, + weights->get_data_ptr(), + weights_shape, + input_low->get_data_ptr(), + input_low->get_shape(), + input_high->get_data_ptr(), + input_low->get_shape(), + output_low->get_data_ptr(), + output_low->get_shape(), + output_high->get_data_ptr(), + output_low->get_shape(), + zero_point_constant->get_data_ptr(), + zero_point_constant->get_shape(), + fq->get_levels(), + can_fuse_zero_point); + break; + } + case ov::element::f16: { + new_weights = compress_quantized_weights_internal(low_precision_type, + weights->get_data_ptr(), + weights_shape, + input_low->get_data_ptr(), + input_low->get_shape(), + input_high->get_data_ptr(), + input_low->get_shape(), + output_low->get_data_ptr(), + output_low->get_shape(), + output_high->get_data_ptr(), + output_low->get_shape(), + zero_point_constant->get_data_ptr(), + zero_point_constant->get_shape(), + fq->get_levels(), + can_fuse_zero_point); + break; + } + default: + return nullptr; + } + return new_weights; +} + +template +static std::shared_ptr compress_quantized_weights_internal( + const ov::element::Type& low_precision_type, + const T* weights, + const ov::Shape& weights_shape, + const T* input_low, + const ov::Shape& input_low_shape, + const T* input_high, + const ov::Shape& input_high_shape, + const T* zero_point, + const ov::Shape& zero_point_shape, + size_t levels, + 
bool zero_point_is_zero, + bool& can_fuse_zero_point) { + using namespace ov::reference::fake_quantize_details; + ov::Tensor compressed_weights_tensor(ov::element::i8, weights_shape); + int8_t* compressed_weights = compressed_weights_tensor.data(); + int8_t* compressed_weights_with_fused_zero_point = nullptr; + ov::Tensor compressed_weights_with_fused_zero_point_tensor; + if (!zero_point_is_zero) { + compressed_weights_with_fused_zero_point_tensor = ov::Tensor(ov::element::i8, weights_shape); + compressed_weights_with_fused_zero_point = compressed_weights_with_fused_zero_point_tensor.data(); + } + T levels_minus_one = static_cast(levels - 1); + T output_low = -static_cast(levels / 2); + T output_high = levels_minus_one + output_low; + can_fuse_zero_point = !zero_point_is_zero; + const auto convert_to_low_precision = low_precision_type == ov::element::i4 ? convert_to_int4 : convert_to_int8; + + auto f = [compressed_weights_with_fused_zero_point, + levels_minus_one, + output_low, + output_high, + zero_point_is_zero, + convert_to_low_precision, + &can_fuse_zero_point](T weights_value, T input_low, T input_high, T zero_point) mutable { + int8_t compressed_weights_value = convert_to_low_precision( + quantize(weights_value, input_low, input_high, output_low, output_high, levels_minus_one)); + if (!zero_point_is_zero && can_fuse_zero_point) { + T weights_minus_zero_point = static_cast(compressed_weights_value) - zero_point; + int8_t compressed_weights_with_fused_zero_point_value = convert_to_low_precision(weights_minus_zero_point); + can_fuse_zero_point &= + std::fabs(compressed_weights_with_fused_zero_point_value - weights_minus_zero_point) < 1e-4; + *compressed_weights_with_fused_zero_point++ = compressed_weights_with_fused_zero_point_value; + } + return compressed_weights_value; }; - auto m = std::make_shared(sub_pattern, "ZeroPointOptimizer"); - this->register_matcher(m, callback); + numpy_broadcast_4inputs(weights, + weights_shape, + input_low, + input_low_shape, + input_high, + input_high_shape, + zero_point, + zero_point_shape, + compressed_weights, + f); + + return create_weights_constant( + can_fuse_zero_point ? 
compressed_weights_with_fused_zero_point_tensor : compressed_weights_tensor, + low_precision_type); +} + +std::shared_ptr compress_quantized_weights( + const std::shared_ptr& weights, + const std::shared_ptr& input_low, + const std::shared_ptr& input_high, + const ov::element::Type& low_precision_type, + size_t levels, + bool zero_point_is_zero, + const ov::Tensor& zero_point_tensor, + bool& can_fuse_zero_point) { + std::shared_ptr new_weights; + const auto& weights_shape = weights->get_shape(); + const auto& type = weights->get_element_type(); + switch (type) { + case ov::element::f32: { + new_weights = compress_quantized_weights_internal(low_precision_type, + weights->get_data_ptr(), + weights_shape, + input_low->get_data_ptr(), + input_low->get_shape(), + input_high->get_data_ptr(), + input_low->get_shape(), + zero_point_tensor.data(), + zero_point_tensor.get_shape(), + levels, + zero_point_is_zero, + can_fuse_zero_point); + break; + } + case ov::element::f16: { + new_weights = compress_quantized_weights_internal(low_precision_type, + weights->get_data_ptr(), + weights_shape, + input_low->get_data_ptr(), + input_low->get_shape(), + input_high->get_data_ptr(), + input_low->get_shape(), + zero_point_tensor.data(), + zero_point_tensor.get_shape(), + levels, + zero_point_is_zero, + can_fuse_zero_point); + break; + } + default: + return nullptr; + } + return new_weights; } diff --git a/src/common/transformations/tests/utils/compress_quantize_weights.cpp b/src/common/transformations/tests/utils/compress_quantize_weights.cpp index 5a62b79bfaa..cc310173688 100644 --- a/src/common/transformations/tests/utils/compress_quantize_weights.cpp +++ b/src/common/transformations/tests/utils/compress_quantize_weights.cpp @@ -31,6 +31,7 @@ struct CompressQuantizeWeightsParams { std::vector expected_weights; float scale_val; float zero_point_val; + bool fuse_zero_point; }; class CompressQuantizeWeightsTests @@ -66,9 +67,14 @@ class CompressQuantizeWeightsTests auto data = opset8::Constant::create(param.expected_type, param.shape, param.expected_weights); auto convert = std::make_shared(data, element::f32); auto scale = opset8::Constant::create(element::f32, Shape{}, {param.scale_val}); - auto zero_point = opset8::Constant::create(element::f32, Shape{}, {param.zero_point_val}); - auto sub = std::make_shared(convert, zero_point); - auto mul = std::make_shared(sub, scale); + std::shared_ptr mul; + if (!param.fuse_zero_point) { + auto zero_point = opset8::Constant::create(element::f32, Shape{}, {param.zero_point_val}); + auto sub = std::make_shared(convert, zero_point); + mul = std::make_shared(sub, scale); + } else { + mul = std::make_shared(convert, scale); + } model_ref = std::make_shared(mul, ParameterVector{}); } comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); @@ -89,7 +95,8 @@ static std::vector params = { element::i4, {-1.0f, -1.0f, 0.0f, 0.0f, 0.0f, 1.0f}, 3.0f, - -0.666667f}, + -0.666667f, + false}, {Shape{2, 3, 1, 1}, {-1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f}, 0.0f, @@ -100,7 +107,8 @@ static std::vector params = { element::i4, {-8.0f, -5.0f, -4.0f, -2.0f, 0.0f, 7.0f}, 0.333333f, - -5.0f}, + -5.0f, + false}, {Shape{2, 4, 1, 1}, {-1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f}, 1.0f, @@ -109,9 +117,10 @@ static std::vector params = { 6.0f, 17, element::i8, - {-8.0f, -8.0f, -8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 8.0f}, + {-4.0f, -4.0f, -4.0f, -2.0f, 0.0f, 2.0f, 4.0f, 12.0f}, 0.5f, - -4.0f}, + -4.0f, + true}, {Shape{2, 4, 1, 1}, {-1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 11.0f}, 1.0f, @@ 
-122,7 +131,8 @@ static std::vector params = { element::i8, {-128.0f, -128.0f, -128.0f, -96.0f, -64.0f, -32.0f, 0.0f, 127.0f}, 0.0313725f, - -64.25f}, + -64.25f, + false}, }; static element::TypeVector data_precisions = {element::f32, element::f16}; @@ -198,7 +208,7 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithDequantizationSubgraphFP comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } -TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) { +TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointEliminated) { { auto data = opset8::Constant::create(element::f32, Shape{3, 1, 1, 1}, {-0.144816, 0.0858578, 0.110928}); auto input_low = opset8::Constant::create(element::f32, Shape{3, 1, 1, 1}, {-0.402659, -0.383148, -0.34054}); @@ -209,7 +219,6 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) { model = std::make_shared(NodeVector{fq}, ParameterVector{}); manager.register_pass(); - manager.register_pass(); } { @@ -223,7 +232,7 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) { comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } -TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizerFP16) { +TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointEliminatedFP16) { { auto data = opset8::Constant::create(element::f16, Shape{3, 1, 1, 1}, {0.2, 1.2, 1.2}); auto input_low = @@ -239,7 +248,6 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizerFP16) model = std::make_shared(NodeVector{fq}, ParameterVector{}); manager.register_pass(); - manager.register_pass(); } { @@ -253,7 +261,7 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizerFP16) comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } -TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsWithZeroPointOptimizer) { +TEST_F(TransformationTestsF, NegativeCompressQuantizeWeights) { { auto data = opset8::Constant::create(element::f32, Shape{2, 4, 1, 1}, {-1, 0, 1, 2, 3, 4, 5, 11}); auto input_low = opset8::Constant::create(element::f32, Shape{}, {1}); @@ -264,7 +272,6 @@ TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsWithZeroPointOptimiz model = std::make_shared(NodeVector{fq}, ParameterVector{}); manager.register_pass(); - manager.register_pass(); } { auto data = opset8::Constant::create(element::i8, Shape{2, 4, 1, 1}, {-128, -128, -128, -96, -64, -32, 0, 127}); @@ -289,7 +296,6 @@ TEST_F(TransformationTestsF, NegativeCompressQuantizeWeightsNonConstantInput) { model = std::make_shared(NodeVector{fq}, ParameterVector{data}); manager.register_pass(); - manager.register_pass(); comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); comparator.enable(FunctionsComparator::CmpValues::ACCURACY); diff --git a/src/core/reference/include/openvino/reference/fake_quantize.hpp b/src/core/reference/include/openvino/reference/fake_quantize.hpp index d0828cd2308..2fb30a4a5c4 100644 --- a/src/core/reference/include/openvino/reference/fake_quantize.hpp +++ b/src/core/reference/include/openvino/reference/fake_quantize.hpp @@ -21,31 +21,86 @@ namespace ov { namespace reference { namespace fake_quantize_details { template -inline T quantize(const T& arg, - const T& in_low, - const T& in_high, - const T& out_low, - const T& out_high, - const size_t& levels) { +static inline T quantize(const T arg, + const T in_low, + const T in_high, + const T out_low, + const T out_high, + const T levels_minus_one) { if (arg <= std::min(in_low, 
in_high)) { return out_low; } else if (arg > std::max(in_low, in_high)) { return out_high; } - return static_cast(std::nearbyint((arg - in_low) / (in_high - in_low) * (levels - 1)) / (levels - 1) * + return static_cast(std::nearbyint((arg - in_low) / (in_high - in_low) * levels_minus_one) / levels_minus_one * (out_high - out_low) + out_low); } +static std::vector compute_strides(const ov::Shape& out_shape, const ov::Shape& shape); + +static std::tuple get_inner_stride(size_t num_output_elements, + const ov::Shape& output_shape, + const ov::Shape& shape, + size_t current_output_inner_stride); + +template +static void fake_quantize_non_unit_inner_stride(const T* arg, + const T* in_low, + const T* in_high, + const T* out_low, + const T* out_high, + T* out, + const Shape& arg_shape, + T levels_minus_one, + size_t input_inner_stride, + const F& get_outer_strides); + +template +static void fake_quantize_unit_inner_stride(const T* arg, + const T* in_low, + const T* in_high, + const T* out_low, + const T* out_high, + T* out, + const Shape& arg_shape, + T levels_minus_one, + size_t input_inner_stride, + const F& get_outer_strides); + +template +static void fake_quantize_unit_output_intervals_inner_stride(const T* arg, + const T* in_low, + const T* in_high, + const T* out_low, + const T* out_high, + T* out, + const Shape& arg_shape, + T levels_minus_one, + size_t input_inner_stride, + const F& get_outer_strides); + +template +static void fake_quantize_unit_input_intervals_inner_stride(const T* arg, + const T* in_low, + const T* in_high, + const T* out_low, + const T* out_high, + T* out, + const Shape& arg_shape, + T levels_minus_one, + size_t input_inner_stride, + const F& get_outer_strides); + } // namespace fake_quantize_details template -void fake_quantize(const T* const arg, - const T* const in_low, - const T* const in_high, - const T* const out_low, - const T* const out_high, - T* const out, +void fake_quantize(const T* arg, + const T* in_low, + const T* in_high, + const T* out_low, + const T* out_high, + T* out, const Shape& arg_shape, const Shape& in_low_shape, const Shape& in_high_shape, @@ -55,133 +110,452 @@ void fake_quantize(const T* const arg, const op::AutoBroadcastSpec& broadcast) { using namespace fake_quantize_details; + T levels_minus_one = static_cast(levels - 1); + const size_t arg_size = shape_size(arg_shape); + if (shape_size(in_low_shape) == 1 && shape_size(in_high_shape) == 1 && shape_size(out_low_shape) == 1 && shape_size(out_high_shape) == 1) { - const size_t arg_size = shape_size(arg_shape); - const auto q = [=](const T& a) { - return quantize(a, *in_low, *in_high, *out_low, *out_high, levels); - }; for (size_t i = 0; i < arg_size; ++i) { - out[i] = q(arg[i]); + out[i] = quantize(arg[i], *in_low, *in_high, *out_low, *out_high, levels_minus_one); } + return; + } + + // clang-format off + /* + * --------------------------------------------------- + * Overview: + * Numpy broadcasted input tensors can be partitioned into two: outer and inner part (which also defines inner + * stride as a product of inner part), so N-dimensional tensors can be processed using two loops. + * + * For example with two inputs [2, 2, 3, 4] and [1, 1, 3, 4] we can have: + * input 1 with shape [2, 2, 3, 4] can be divided into outer part [2, 2] and inner part [3, 4] + * with inner stride = 12 (3 * 4). 
+ * input 2 with shape [1, 1, 3, 4] can be divided into outer part [1, 1] + * and inner part [3, 4] with inner stride = 12 (3 * 4) + * + * Having that, those inputs can be processed by the following: + * + * output_shape = {2, 2, 3, 4}; + * output_inner_stride = 12; + * for (i = 0; i < shape_size(shape); i += output_inner_stride) { + * first_input_stride = i; + * second_input_stride = 0; + * for (j = 0; j < 12; j++) { + * *out++ = f(first_input[first_input_stride + j], second_input[second_input_stride + j]); + * } + * } + * + * --------------------------------------------------- + * How the partitioning is done: + * Partitioning process starts with the last dimension of input tensor shape and it stops when either one of below + * occurs: + * - if the last dimension is equal to 1, partitioning stops at the dimension that is greater than 1 (this + * dimension is not included in the inner part), + * - if the last dimension is greater than 1, partitioning stops at the dimension that is equal to 1 (this + * dimension is not included in the inner part). + * + * Examples: + * tensor_shape=[2, 3, 4, 5], inner_part = [2, 3, 4, 5], inner_stride = 120 + * tensor_shape=[1, 1, 4, 5], inner_part = [4, 5], inner_stride = 20 + * tensor_shape=[2, 3, 1, 1], inner_part = [1, 1], inner_stride = 1 + * + * + * --------------------------------------------------- + * How the output inner stride is calculated: + * Inner part (and inner stride) for every input tensor is determined. Then the size of output inner part is the + * size of inner part with the fewest number of dimensions. + * + * Example with 5 inputs: + * input 1 shape [2, 3, 4, 5], inner_part = [2, 3, 4, 5], inner_stride = 120 + * input 2 shape [1, 3, 4, 5], inner_part = [3, 4, 5], inner_stride = 60 + * input 3 shape [2, 3, 1, 1], inner_part = [1, 1], inner_stride = 1 + * input 4 shape [2, 1, 1, 1], inner_part = [1, 1, 1], inner_stride = 1 + * input 5 shape [1, 1, 1, 1], inner_part = [1, 1, 1, 1], inner_stride = 1 + * + * output shape [2, 3, 4, 5], inner_part = [4, 5], inner_stride = 20 + * + * Inner part with fewest number of elements is [1, 1] for input 3. So the inner part for output shape is [4, 5] + * and output inner stride is 20. 
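 *
 * ---------------------------------------------------
 * Worked example of the stride bookkeeping (added for illustration):
 * compute_strides() returns a stride of 0 for every broadcast axis. For input 3 above,
 * compute_strides([2, 3, 4, 5], [2, 3, 1, 1]) = [3, 1, 0, 0], while the output strides
 * are [60, 20, 5, 1]. get_outer_strides() then maps an output flat index to an input
 * offset by accumulating (flat_index / output_stride[i]) * input_stride[i] over the axes,
 * so output index 67 -> 1 * 3 + 0 * 1 = 3, i.e. element [1, 0] of input 3's outer part.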
+ */ + // clang-format on + + std::vector output_strides = compute_strides(arg_shape, arg_shape); + std::vector in_low_strides = compute_strides(arg_shape, in_low_shape); + std::vector in_high_strides = compute_strides(arg_shape, in_high_shape); + std::vector out_low_strides = compute_strides(arg_shape, out_low_shape); + std::vector out_high_strides = compute_strides(arg_shape, out_high_shape); + + size_t input_inner_stride = arg_size; + size_t in_low_inner_stride = 0; + size_t in_high_inner_stride = 0; + size_t out_low_inner_stride = 0; + size_t out_high_inner_stride = 0; + + std::tie(in_low_inner_stride, input_inner_stride) = + get_inner_stride(arg_size, arg_shape, in_low_shape, input_inner_stride); + std::tie(in_high_inner_stride, input_inner_stride) = + get_inner_stride(arg_size, arg_shape, in_high_shape, input_inner_stride); + std::tie(out_low_inner_stride, input_inner_stride) = + get_inner_stride(arg_size, arg_shape, out_low_shape, input_inner_stride); + std::tie(out_high_inner_stride, input_inner_stride) = + get_inner_stride(arg_size, arg_shape, out_high_shape, input_inner_stride); + + auto get_outer_strides = + [&output_strides, &in_low_strides, &in_high_strides, &out_low_strides, &out_high_strides](size_t flat_index) { + size_t in_low_stride = 0; + size_t in_high_stride = 0; + size_t out_low_stride = 0; + size_t out_high_stride = 0; + + for (size_t i = 0; i < output_strides.size(); i++) { + size_t div = flat_index / output_strides[i]; + flat_index = flat_index % output_strides[i]; + in_low_stride += div * in_low_strides[i]; + in_high_stride += div * in_high_strides[i]; + out_low_stride += div * out_low_strides[i]; + out_high_stride += div * out_high_strides[i]; + } + + return std::tuple{in_low_stride, + in_high_stride, + out_low_stride, + out_high_stride}; + }; + + if (in_low_inner_stride > 1 && in_high_inner_stride > 1 && out_low_inner_stride > 1 && out_high_inner_stride > 1) { + fake_quantize_non_unit_inner_stride(arg, + in_low, + in_high, + out_low, + out_high, + out, + arg_shape, + levels_minus_one, + input_inner_stride, + get_outer_strides); + } else if (in_low_inner_stride == 1 && in_high_inner_stride == 1 && out_low_inner_stride == 1 && + out_high_inner_stride == 1) { + fake_quantize_unit_inner_stride(arg, + in_low, + in_high, + out_low, + out_high, + out, + arg_shape, + levels_minus_one, + input_inner_stride, + get_outer_strides); + + } else if (in_low_inner_stride > 1 && in_high_inner_stride > 1 && out_low_inner_stride == 1 && + out_high_inner_stride == 1) { + fake_quantize_unit_output_intervals_inner_stride(arg, + in_low, + in_high, + out_low, + out_high, + out, + arg_shape, + levels_minus_one, + input_inner_stride, + get_outer_strides); + + } else if (in_low_inner_stride == 1 && in_high_inner_stride == 1 && out_low_inner_stride > 1 && + out_high_inner_stride > 1) { + fake_quantize_unit_input_intervals_inner_stride(arg, + in_low, + in_high, + out_low, + out_high, + out, + arg_shape, + levels_minus_one, + input_inner_stride, + get_outer_strides); } else { - OPENVINO_ASSERT(in_low_shape.size() <= arg_shape.size() && in_high_shape.size() <= arg_shape.size() && - out_low_shape.size() <= arg_shape.size() && out_high_shape.size() <= arg_shape.size(), - "Tensors with input\\output ranges should have rank less or " - "equal to data tensor rank equal to ", - arg_shape.size()); + size_t in_low_stride = 0; + size_t in_high_stride = 0; + size_t out_low_stride = 0; + size_t out_high_stride = 0; - Shape arg0_padded_shape = arg_shape; - Shape arg1_padded_shape = in_low_shape; - Shape 
arg2_padded_shape = in_high_shape; - Shape arg3_padded_shape = out_low_shape; - Shape arg4_padded_shape = out_high_shape; - - size_t max_shape_size = arg_shape.size(); - - while (arg0_padded_shape.size() < max_shape_size) { - arg0_padded_shape.insert(arg0_padded_shape.begin(), 1); - } - - while (arg1_padded_shape.size() < max_shape_size) { - arg1_padded_shape.insert(arg1_padded_shape.begin(), 1); - } - - while (arg2_padded_shape.size() < max_shape_size) { - arg2_padded_shape.insert(arg2_padded_shape.begin(), 1); - } - - while (arg3_padded_shape.size() < max_shape_size) { - arg3_padded_shape.insert(arg3_padded_shape.begin(), 1); - } - - while (arg4_padded_shape.size() < max_shape_size) { - arg4_padded_shape.insert(arg4_padded_shape.begin(), 1); - } - - Shape arg0_squeezed_shape, arg1_squeezed_shape, arg2_squeezed_shape, arg3_squeezed_shape, arg4_squeezed_shape; - AxisSet arg0_squeezed_axes, arg1_squeezed_axes, arg2_squeezed_axes, arg3_squeezed_axes, arg4_squeezed_axes; - Shape output_shape; - - for (size_t i = 0; i < max_shape_size; i++) { - if (arg1_padded_shape[i] == 1) { - arg1_squeezed_axes.insert(i); - } else { - arg1_squeezed_shape.push_back(arg1_padded_shape[i]); - } - - if (arg2_padded_shape[i] == 1) { - arg2_squeezed_axes.insert(i); - } else { - arg2_squeezed_shape.push_back(arg2_padded_shape[i]); - } - - if (arg0_padded_shape[i] == 1) { - arg0_squeezed_axes.insert(i); - } else { - arg0_squeezed_shape.push_back(arg0_padded_shape[i]); - } - - if (arg3_padded_shape[i] == 1) { - arg3_squeezed_axes.insert(i); - } else { - arg3_squeezed_shape.push_back(arg3_padded_shape[i]); - } - - if (arg4_padded_shape[i] == 1) { - arg4_squeezed_axes.insert(i); - } else { - arg4_squeezed_shape.push_back(arg4_padded_shape[i]); - } - - output_shape.push_back(std::max({arg0_padded_shape[i], - arg2_padded_shape[i], - arg1_padded_shape[i], - arg3_padded_shape[i], - arg4_padded_shape[i]})); - } - - CoordinateTransformBasic arg0_transform(arg0_squeezed_shape); - CoordinateTransformBasic arg1_transform(arg1_squeezed_shape); - CoordinateTransformBasic arg2_transform(arg2_squeezed_shape); - CoordinateTransformBasic arg3_transform(arg3_squeezed_shape); - CoordinateTransformBasic arg4_transform(arg4_squeezed_shape); - CoordinateTransformBasic output_transform(output_shape); - - const auto arg0_strides = row_major_strides(arg0_squeezed_shape); - const auto arg1_strides = row_major_strides(arg1_squeezed_shape); - const auto arg2_strides = row_major_strides(arg2_squeezed_shape); - const auto arg3_strides = row_major_strides(arg3_squeezed_shape); - const auto arg4_strides = row_major_strides(arg4_squeezed_shape); - const auto output_strides = row_major_strides(output_shape); - - for (const Coordinate& output_coord : output_transform) { - const auto arg0_coord = util::reduce(output_coord, arg0_squeezed_axes); - const auto arg1_coord = util::reduce(output_coord, arg1_squeezed_axes); - const auto arg2_coord = util::reduce(output_coord, arg2_squeezed_axes); - const auto arg3_coord = util::reduce(output_coord, arg3_squeezed_axes); - const auto arg4_coord = util::reduce(output_coord, arg4_squeezed_axes); - - const size_t arg0_idx = - std::inner_product(arg0_coord.begin(), arg0_coord.end(), arg0_strides.begin(), uint64_t(0)); - const size_t arg1_idx = - std::inner_product(arg1_coord.begin(), arg1_coord.end(), arg1_strides.begin(), uint64_t(0)); - const size_t arg2_idx = - std::inner_product(arg2_coord.begin(), arg2_coord.end(), arg2_strides.begin(), uint64_t(0)); - const size_t arg3_idx = - 
-                std::inner_product(arg3_coord.begin(), arg3_coord.end(), arg3_strides.begin(), uint64_t(0));
-            const size_t arg4_idx =
-                std::inner_product(arg4_coord.begin(), arg4_coord.end(), arg4_strides.begin(), uint64_t(0));
-            const size_t output_idx =
-                std::inner_product(output_coord.begin(), output_coord.end(), output_strides.begin(), uint64_t(0));
-            out[output_idx] = quantize(arg[arg0_idx],
-                                       in_low[arg1_idx],
-                                       in_high[arg2_idx],
-                                       out_low[arg3_idx],
-                                       out_high[arg4_idx],
-                                       levels);
+        for (size_t i = 0; i < arg_size; i++) {
+            std::tie(in_low_stride, in_high_stride, out_low_stride, out_high_stride) = get_outer_strides(i);
+            *out++ = quantize(*arg++,
+                              *(in_low + in_low_stride),
+                              *(in_high + in_high_stride),
+                              *(out_low + out_low_stride),
+                              *(out_high + out_high_stride),
+                              levels_minus_one);
         }
     }
 }
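Editor's note: the stride trick above is the heart of this rewrite. For every
broadcast operand, a dimension that matches the data shape keeps its row-major
stride while a broadcast dimension gets stride 0, so decomposing one flat data
index yields each operand's offset directly, with no Coordinate machinery. A
minimal standalone sketch of the idea; the helper name, shapes and main() are
illustrative assumptions, not part of the patch:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Same contract as compute_strides above: one stride per output dimension,
    // 0 wherever `shape` is broadcast against `out_shape`.
    std::vector<size_t> broadcast_strides(const std::vector<size_t>& out_shape,
                                          const std::vector<size_t>& shape) {
        size_t stride = 1;
        const size_t out_rank = out_shape.size(), rank = shape.size();
        std::vector<size_t> strides(out_rank, 0);
        for (size_t i = 0; i < out_rank; i++) {
            if (i < rank && shape[rank - i - 1] == out_shape[out_rank - i - 1]) {
                strides[out_rank - i - 1] = stride;
                stride *= shape[rank - i - 1];
            }
        }
        return strides;
    }

    int main() {
        // Data shape {2, 3}, per-row parameter shape {2, 1}.
        const std::vector<size_t> data{2, 3}, param{2, 1};
        const auto param_strides = broadcast_strides(data, param);  // {1, 0}
        const auto data_strides = broadcast_strides(data, data);    // {3, 1}
        // Mirror of the get_outer_strides lambda, for a single operand.
        for (size_t flat = 0; flat < 6; flat++) {
            size_t rem = flat, offset = 0;
            for (size_t i = 0; i < data_strides.size(); i++) {
                offset += (rem / data_strides[i]) * param_strides[i];
                rem %= data_strides[i];
            }
            std::cout << offset << ' ';  // prints: 0 0 0 1 1 1
        }
    }

All three elements of a data row map to the same parameter offset; the lambda
in the patch does this for in_low, in_high, out_low and out_high in one pass.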
+
+namespace fake_quantize_details {
+std::vector<size_t> compute_strides(const ov::Shape& out_shape, const ov::Shape& shape) {
+    size_t stride = 1;
+    size_t out_rank = out_shape.size();
+    size_t shape_rank = shape.size();
+    std::vector<size_t> strides(out_rank);
+    for (size_t i = 0; i < out_rank; i++) {
+        if (i < shape_rank && shape[shape_rank - i - 1] == out_shape[out_rank - i - 1]) {
+            strides[out_rank - i - 1] = stride;
+            stride *= shape[shape_rank - i - 1];
+        } else {
+            strides[out_rank - i - 1] = 0;
+        }
+    }
+    return strides;
+}
+
+std::tuple<size_t, size_t> get_inner_stride(size_t num_output_elements,
+                                            const ov::Shape& output_shape,
+                                            const ov::Shape& shape,
+                                            size_t current_output_inner_stride) {
+    if (shape.size() == 0)
+        return std::tuple<size_t, size_t>{1, std::min(current_output_inner_stride, num_output_elements)};
+    const size_t last = shape.back();
+    auto it = std::find_if(shape.rbegin(), shape.rend(), [last](size_t dim) {
+        return (last == 1 && dim > 1) || (last > 1 && dim == 1);
+    });
+    if (it == shape.rend()) {
+        const size_t num_elements = shape_size(shape);
+        return std::tuple<size_t, size_t>{
+            num_elements,
+            last == 1 ? current_output_inner_stride : std::min(current_output_inner_stride, num_elements)};
+    }
+    const size_t idx = std::distance(it, shape.rbegin()) + static_cast<ptrdiff_t>(shape.size());
+    const size_t inner_stride =
+        std::accumulate(shape.begin() + idx, shape.end(), static_cast<size_t>(1), std::multiplies<size_t>());
+    const size_t output_inner_stride = std::accumulate(output_shape.begin() + output_shape.size() - shape.size() + idx,
+                                                       output_shape.end(),
+                                                       static_cast<size_t>(1),
+                                                       std::multiplies<size_t>());
+    return std::tuple<size_t, size_t>{inner_stride, std::min(current_output_inner_stride, output_inner_stride)};
+}
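Editor's note: a worked example of the get_inner_stride contract, using the
shapes of the new test case at the end of this patch. The wrapper below is a
hedged sketch; it assumes the helpers stay visible to including code as
declared above:

    #include <cstddef>
    #include <iostream>
    #include <tuple>
    #include "openvino/reference/fake_quantize.hpp"

    int main() {
        using ov::Shape;
        using namespace ov::reference::fake_quantize_details;
        size_t interval_inner = 0, data_inner = 0;
        // Per-channel intervals {1, 2, 1, 1} on data {1, 2, 4, 4}: one interval
        // value covers the trailing 4 * 4 = 16 consecutive data elements.
        std::tie(interval_inner, data_inner) = get_inner_stride(32, Shape{1, 2, 4, 4}, Shape{1, 2, 1, 1}, 32);
        std::cout << interval_inner << ' ' << data_inner << '\n';  // 1 16 -> scalar-interval ("unit") kernels
        // Element-wise intervals {1, 2, 4, 4}: no trailing broadcast run.
        std::tie(interval_inner, data_inner) = get_inner_stride(32, Shape{1, 2, 4, 4}, Shape{1, 2, 4, 4}, 32);
        std::cout << interval_inner << ' ' << data_inner << '\n';  // 32 32 -> fake_quantize_non_unit_inner_stride
    }

The second tuple element is also the cap on input_inner_stride, which is why
the dispatch threads it through all four calls before picking a kernel.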
+
+template <typename T, typename F>
+static void transform(const T* first1, const T* const last1, const T* first2, const T* first3, T* out, const F& f) {
+    while (first1 < last1) {
+        *out++ = f(*first1++, *first2++, *first3++);
+    }
+}
+
+template <typename T, typename F>
+static void transform(const T* first1,
+                      const T* const last1,
+                      const T* first2,
+                      const T* first3,
+                      const T* first4,
+                      const T* first5,
+                      T* out,
+                      const F& f) {
+    while (first1 < last1) {
+        *out++ = f(*first1++, *first2++, *first3++, *first4++, *first5++);
+    }
+}
+
+template <typename T, typename F1, typename F2>
+static void fake_quantize_loop(const Shape& arg_shape,
+                               const T* arg,
+                               const T* in_low,
+                               const T* in_high,
+                               const T* out_low,
+                               const T* out_high,
+                               T* out,
+                               size_t input_inner_stride,
+                               const F1& get_outer_strides,
+                               const F2& quantize_loop) {
+    size_t in_low_stride = 0;
+    size_t in_high_stride = 0;
+    size_t out_low_stride = 0;
+    size_t out_high_stride = 0;
+
+    for (size_t i = 0; i < shape_size(arg_shape); i += input_inner_stride) {
+        std::tie(in_low_stride, in_high_stride, out_low_stride, out_high_stride) = get_outer_strides(i);
+        quantize_loop(arg,
+                      arg + input_inner_stride,
+                      in_low + in_low_stride,
+                      in_high + in_high_stride,
+                      out_low + out_low_stride,
+                      out_high + out_high_stride,
+                      out);
+        arg += input_inner_stride;
+        out += input_inner_stride;
+    }
+}
+
+template <typename T, typename F>
+void fake_quantize_non_unit_inner_stride(const T* arg,
+                                         const T* in_low,
+                                         const T* in_high,
+                                         const T* out_low,
+                                         const T* out_high,
+                                         T* out,
+                                         const Shape& arg_shape,
+                                         T levels_minus_one,
+                                         size_t input_inner_stride,
+                                         const F& get_outer_strides) {
+    fake_quantize_loop(arg_shape,
+                       arg,
+                       in_low,
+                       in_high,
+                       out_low,
+                       out_high,
+                       out,
+                       input_inner_stride,
+                       get_outer_strides,
+                       [levels_minus_one](const T* input,
+                                          const T* const input_end,
+                                          const T* in_low,
+                                          const T* in_high,
+                                          const T* out_low,
+                                          const T* out_high,
+                                          T* out) {
+                           transform(input,
+                                     input_end,
+                                     in_low,
+                                     in_high,
+                                     out_low,
+                                     out_high,
+                                     out,
+                                     [levels_minus_one](T input, T in_low, T in_high, T out_low, T out_high) {
+                                         return quantize(input, in_low, in_high, out_low, out_high, levels_minus_one);
+                                     });
+                       });
+}
+
+template <typename T, typename F>
+void fake_quantize_unit_inner_stride(const T* arg,
+                                     const T* in_low,
+                                     const T* in_high,
+                                     const T* out_low,
+                                     const T* out_high,
+                                     T* out,
+                                     const Shape& arg_shape,
+                                     T levels_minus_one,
+                                     size_t input_inner_stride,
+                                     const F& get_outer_strides) {
+    auto quantize_with_scalar_intervals = [levels_minus_one](const T* input,
+                                                             const T* const input_end,
+                                                             const T* in_low,
+                                                             const T* in_high,
+                                                             const T* out_low,
+                                                             const T* out_high,
+                                                             T* out) {
+        const auto in_low_scalar = *in_low;
+        const auto in_high_scalar = *in_high;
+        const auto out_low_scalar = *out_low;
+        const auto out_high_scalar = *out_high;
+        std::transform(input,
+                       input_end,
+                       out,
+                       [levels_minus_one, in_low_scalar, in_high_scalar, out_low_scalar, out_high_scalar](T input) {
+                           return quantize(input,
+                                           in_low_scalar,
+                                           in_high_scalar,
+                                           out_low_scalar,
+                                           out_high_scalar,
+                                           levels_minus_one);
+                       });
+    };
+
+    fake_quantize_loop(arg_shape,
+                       arg,
+                       in_low,
+                       in_high,
+                       out_low,
+                       out_high,
+                       out,
+                       input_inner_stride,
+                       get_outer_strides,
+                       quantize_with_scalar_intervals);
+}
+
+template <typename T, typename F>
+void fake_quantize_unit_output_intervals_inner_stride(const T* arg,
+                                                      const T* in_low,
+                                                      const T* in_high,
+                                                      const T* out_low,
+                                                      const T* out_high,
+                                                      T* out,
+                                                      const Shape& arg_shape,
+                                                      T levels_minus_one,
+                                                      size_t input_inner_stride,
+                                                      const F& get_outer_strides) {
+    auto quantize_with_scalar_output_intervals = [levels_minus_one](const T* input,
+                                                                    const T* const input_end,
+                                                                    const T* in_low,
+                                                                    const T* in_high,
+                                                                    const T* out_low,
+                                                                    const T* out_high,
+                                                                    T* out) {
+        const auto out_low_scalar = *out_low;
+        const auto out_high_scalar = *out_high;
+        transform(input,
+                  input_end,
+                  in_low,
+                  in_high,
+                  out,
+                  [levels_minus_one, out_low_scalar, out_high_scalar](T input, T in_low, T in_high) {
+                      return quantize(input, in_low, in_high, out_low_scalar, out_high_scalar, levels_minus_one);
+                  });
+    };
+
+    fake_quantize_loop(arg_shape,
+                       arg,
+                       in_low,
+                       in_high,
+                       out_low,
+                       out_high,
+                       out,
+                       input_inner_stride,
+                       get_outer_strides,
+                       quantize_with_scalar_output_intervals);
+}
+
+template <typename T, typename F>
+void fake_quantize_unit_input_intervals_inner_stride(const T* arg,
+                                                     const T* in_low,
+                                                     const T* in_high,
+                                                     const T* out_low,
+                                                     const T* out_high,
+                                                     T* out,
+                                                     const Shape& arg_shape,
+                                                     T levels_minus_one,
+                                                     size_t input_inner_stride,
+                                                     const F& get_outer_strides) {
+    auto quantize_with_scalar_input_intervals = [levels_minus_one](const T* input,
+                                                                   const T* const input_end,
+                                                                   const T* in_low,
+                                                                   const T* in_high,
+                                                                   const T* out_low,
+                                                                   const T* out_high,
+                                                                   T* out) {
+        const auto in_low_scalar = *in_low;
+        const auto in_high_scalar = *in_high;
+        transform(input,
+                  input_end,
+                  out_low,
+                  out_high,
+                  out,
+                  [levels_minus_one, in_low_scalar, in_high_scalar](T input, T out_low, T out_high) {
+                      return quantize(input, in_low_scalar, in_high_scalar, out_low, out_high, levels_minus_one);
+                  });
+    };
+
+    fake_quantize_loop(arg_shape,
+                       arg,
+                       in_low,
+                       in_high,
+                       out_low,
+                       out_high,
+                       out,
+                       input_inner_stride,
+                       get_outer_strides,
+                       quantize_with_scalar_input_intervals);
+}
+
+}  // namespace fake_quantize_details
+
 }  // namespace reference
 }  // namespace ov
diff --git a/src/plugins/template/tests/functional/op_reference/fake_quantize.cpp b/src/plugins/template/tests/functional/op_reference/fake_quantize.cpp
index 418d02c23c5..d2ea203ff8e 100644
--- a/src/plugins/template/tests/functional/op_reference/fake_quantize.cpp
+++ b/src/plugins/template/tests/functional/op_reference/fake_quantize.cpp
@@ -253,6 +253,22 @@ std::vector<FakeQuantizeParams> generateParamsForFakeQuantize() {
                             }),
                             16,
                             op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY)),
+        FakeQuantizeParams(
+            ov::Shape{1, 2, 4, 4},
+            ov::Shape{1, 2, 4, 4},
+            IN_ET,
+            IN_ET,
+            iota_vector<T>(shape_size(Shape{1, 2, 4, 4})),
+            std::vector<T>{
+                0,     0,     0,    0,    0,    0,    0,    0,    0,     8.75,  8.75,  8.75,  8.75, 8.75, 8.75, 17.5,
+                23.75, 23.75, 27.5, 27.5, 27.5, 27.5, 27.5, 31.25, 31.25, 31.25, 31.25, 31.25, 35,   35,   35,   35,
+            },
+            op::v0::Constant::create(IN_ET, Shape{1, 2, 1, 1}, {5.f, 10.f}),
+            op::v0::Constant::create(IN_ET, Shape{1, 1}, {30.f}),
+            op::v0::Constant::create(IN_ET, Shape{2, 1, 1}, {0.f, 20.f}),
+            op::v0::Constant::create(IN_ET, Shape{1}, {35.f}),
+            5),
+
     };
     return params;
 }
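Editor's note: the expected vector in this test can be checked by hand. The
nine leading zeros imply iota_vector fills 0, 1, 2, ..., so each input value
equals its flat index; with levels = 5 the output snaps to four equal steps
between out_low and out_high per channel. A small self-contained sketch,
re-implemented from the documented FakeQuantize formula rather than taken from
the reference code:

    #include <algorithm>
    #include <cmath>
    #include <iostream>

    // Clamp to the input range, snap to one of `levels` values, rescale to the
    // output range (levels_minus_one = levels - 1).
    float fq(float x, float in_low, float in_high, float out_low, float out_high, float levels_minus_one) {
        if (x <= std::min(in_low, in_high))
            return out_low;
        if (x > std::max(in_low, in_high))
            return out_high;
        return std::nearbyint((x - in_low) / (in_high - in_low) * levels_minus_one) / levels_minus_one *
                   (out_high - out_low) +
               out_low;
    }

    int main() {
        // Channel 0 uses in_low = 5, in_high = 30, out_low = 0, out_high = 35:
        std::cout << fq(9.f, 5.f, 30.f, 0.f, 35.f, 4.f) << '\n';    // 8.75, matches element 9
        // Channel 1 uses in_low = 10, in_high = 30, out_low = 20, out_high = 35:
        std::cout << fq(18.f, 10.f, 30.f, 20.f, 35.f, 4.f) << '\n'; // 27.5, matches element 18
    }

Since the interval constants have ranks 4, 2, 3 and 1, this case exercises the
mixed-rank NUMPY broadcasting paths introduced by the rewritten reference.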