diff --git a/src/common/offline_transformations/src/compress_quantize_weigths.cpp b/src/common/offline_transformations/src/compress_quantize_weigths.cpp index 294178c979c..8c0faa85a84 100644 --- a/src/common/offline_transformations/src/compress_quantize_weigths.cpp +++ b/src/common/offline_transformations/src/compress_quantize_weigths.cpp @@ -9,6 +9,7 @@ #include #include #include +#include static bool has_dequantization_subgraph(const std::shared_ptr& first_convert) { auto first_convert_users = first_convert->get_users(); @@ -65,6 +66,11 @@ ngraph::pass::CompressQuantizeWeights::CompressQuantizeWeights() { const auto& pattern_value_map = m.get_pattern_value_map(); const auto& input_type = fq->get_element_type(); + const auto& fq_data_input = fq->get_input_node_shared_ptr(0); + bool are_weights_decompressed = is_decompression(fq_data_input); + if (are_weights_decompressed) { + unmark_as_decompression(fq_data_input); + } // skip dequantize part if there is already dequantization subgraph after FakeQuantize auto fq_users = fq->get_users(); @@ -83,6 +89,9 @@ ngraph::pass::CompressQuantizeWeights::CompressQuantizeWeights() { } return true; } else { + if (are_weights_decompressed) { + mark_as_decompression(fq_data_input); + } return false; } } else { @@ -102,9 +111,6 @@ ngraph::pass::CompressQuantizeWeights::CompressQuantizeWeights() { const auto& weights_const = pattern_value_map.at(weights_const_pattern); Output input_low = pattern_value_map.at(input_low_pattern); Output input_high = pattern_value_map.at(input_high_pattern); - const auto& fq_data_input = pattern_value_map.count(weigths_convert_pattern) - ? pattern_value_map.at(weigths_convert_pattern) - : weights_const; auto quantize = fq->clone_with_new_inputs({fq_data_input, input_low, input_high, new_output_low, new_output_high}); // Convert quantized weights to low precision type @@ -115,6 +121,9 @@ ngraph::pass::CompressQuantizeWeights::CompressQuantizeWeights() { OPENVINO_SUPPRESS_DEPRECATED_END new_weights = constant; } else { + if (are_weights_decompressed) { + mark_as_decompression(fq_data_input); + } return false; } new_weights->set_friendly_name(weights_const.get_node()->get_friendly_name()); diff --git a/src/common/transformations/tests/utils/compress_quantize_weights.cpp b/src/common/transformations/tests/utils/compress_quantize_weights.cpp index df5f60ece77..55c35b7205a 100644 --- a/src/common/transformations/tests/utils/compress_quantize_weights.cpp +++ b/src/common/transformations/tests/utils/compress_quantize_weights.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "common_test_utils/ngraph_test_utils.hpp" @@ -41,8 +42,10 @@ class CompressQuantizeWeightsTests std::tie(param, data_prc) = GetParam(); { std::shared_ptr data = opset8::Constant::create(data_prc, param.shape, param.weights); - if (data_prc == element::f16) + if (data_prc == element::f16) { data = std::make_shared(data, element::f32); + ov::mark_as_decompression(data); + } auto input_low = opset8::Constant::create(element::f32, Shape{}, {param.in_low}); auto input_high = opset8::Constant::create(element::f32, Shape{}, {param.in_high}); auto output_low = opset8::Constant::create(element::f32, Shape{}, {param.out_low}); @@ -159,6 +162,41 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithDequantizationSubgraph) comparator.enable(FunctionsComparator::CmpValues::ACCURACY); } +TEST_F(TransformationTestsF, CompressQuantizeWeightsWithDequantizationSubgraphFP16) { + { + auto data = opset8::Constant::create(element::f16, Shape{2, 4, 1, 1}, {-1, 0, 1, 2, 3, 4, 5, 11}); + auto convert_to_f32 = std::make_shared(data, element::f32); + ov::mark_as_decompression(convert_to_f32); + auto input_low = opset8::Constant::create(element::f32, Shape{}, {1}); + auto input_high = opset8::Constant::create(element::f32, Shape{}, {9}); + auto output_low = opset8::Constant::create(element::f32, Shape{}, {-128}); + auto output_high = opset8::Constant::create(element::f32, Shape{}, {127}); + auto fq = + std::make_shared(convert_to_f32, input_low, input_high, output_low, output_high, 256); + auto convert = std::make_shared(fq, element::i8); + auto second_convert = std::make_shared(convert, element::f32); + auto scale = opset8::Constant::create(element::f32, Shape{}, {10.0 / 255}); + auto zero_point = opset8::Constant::create(element::f32, Shape{}, {2 - 255.0 / 10}); + auto sub = std::make_shared(second_convert, zero_point); + auto mul = std::make_shared(sub, scale); + + function = std::make_shared(NodeVector{mul}, ParameterVector{}); + + manager.register_pass(); + } + { + auto data = opset8::Constant::create(element::i8, Shape{2, 4, 1, 1}, {-128, -128, -128, -96, -64, -32, 0, 127}); + auto convert = std::make_shared(data, element::f32); + auto scale = opset8::Constant::create(element::f32, Shape{}, {10.0 / 255}); + auto zero_point = opset8::Constant::create(element::f32, Shape{}, {2 - 255.0 / 10}); + auto sub = std::make_shared(convert, zero_point); + auto mul = std::make_shared(sub, scale); + function_ref = std::make_shared(NodeVector{mul}, ParameterVector{}); + } + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::ACCURACY); +} + TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) { { auto data = opset8::Constant::create(element::f32, Shape{3, 1, 1, 1}, {-0.144816, 0.0858578, 0.110928});