CompressQuantizeWeights - handle decompressed inputs to FakeQuantize (#19061)
The Decompression attribute (present in models with FP16 precision) prevents the weights from being constant-folded. Constant-folding of the weights is required by CompressQuantizeWeights to compress them to a low-precision format. Ticket: CVS-117310
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/validation_util.hpp>
|
||||
#include <openvino/pass/constant_folding.hpp>
|
||||
#include <transformations/rt_info/decompression.hpp>
|
||||
|
||||
static bool has_dequantization_subgraph(const std::shared_ptr<ngraph::Node>& first_convert) {
|
||||
auto first_convert_users = first_convert->get_users();
|
||||
@@ -65,6 +66,11 @@ ngraph::pass::CompressQuantizeWeights::CompressQuantizeWeights() {
|
||||
|
||||
const auto& pattern_value_map = m.get_pattern_value_map();
|
||||
const auto& input_type = fq->get_element_type();
|
||||
const auto& fq_data_input = fq->get_input_node_shared_ptr(0);
|
||||
bool are_weights_decompressed = is_decompression(fq_data_input);
|
||||
if (are_weights_decompressed) {
|
||||
unmark_as_decompression(fq_data_input);
|
||||
}
|
||||
|
||||
// skip dequantize part if there is already dequantization subgraph after FakeQuantize
|
||||
auto fq_users = fq->get_users();
|
||||
@@ -83,6 +89,9 @@ ngraph::pass::CompressQuantizeWeights::CompressQuantizeWeights() {
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
if (are_weights_decompressed) {
|
||||
mark_as_decompression(fq_data_input);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
@@ -102,9 +111,6 @@ ngraph::pass::CompressQuantizeWeights::CompressQuantizeWeights() {
|
||||
const auto& weights_const = pattern_value_map.at(weights_const_pattern);
|
||||
Output<Node> input_low = pattern_value_map.at(input_low_pattern);
|
||||
Output<Node> input_high = pattern_value_map.at(input_high_pattern);
|
||||
const auto& fq_data_input = pattern_value_map.count(weigths_convert_pattern)
|
||||
? pattern_value_map.at(weigths_convert_pattern)
|
||||
: weights_const;
|
||||
auto quantize =
|
||||
fq->clone_with_new_inputs({fq_data_input, input_low, input_high, new_output_low, new_output_high});
|
||||
// Convert quantized weights to low precision type
|
||||
@@ -115,6 +121,9 @@ ngraph::pass::CompressQuantizeWeights::CompressQuantizeWeights() {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
new_weights = constant;
|
||||
} else {
|
||||
if (are_weights_decompressed) {
|
||||
mark_as_decompression(fq_data_input);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
new_weights->set_friendly_name(weights_const.get_node()->get_friendly_name());
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/rt_info/decompression.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
@@ -41,8 +42,10 @@ class CompressQuantizeWeightsTests
|
||||
std::tie(param, data_prc) = GetParam();
|
||||
{
|
||||
std::shared_ptr<Node> data = opset8::Constant::create(data_prc, param.shape, param.weights);
|
||||
if (data_prc == element::f16)
|
||||
if (data_prc == element::f16) {
|
||||
data = std::make_shared<opset8::Convert>(data, element::f32);
|
||||
ov::mark_as_decompression(data);
|
||||
}
|
||||
auto input_low = opset8::Constant::create(element::f32, Shape{}, {param.in_low});
|
||||
auto input_high = opset8::Constant::create(element::f32, Shape{}, {param.in_high});
|
||||
auto output_low = opset8::Constant::create(element::f32, Shape{}, {param.out_low});
|
||||
@@ -159,6 +162,41 @@ TEST_F(TransformationTestsF, CompressQuantizeWeightsWithDequantizationSubgraph)
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, CompressQuantizeWeightsWithDequantizationSubgraphFP16) {
|
||||
{
|
||||
auto data = opset8::Constant::create(element::f16, Shape{2, 4, 1, 1}, {-1, 0, 1, 2, 3, 4, 5, 11});
|
||||
auto convert_to_f32 = std::make_shared<opset8::Convert>(data, element::f32);
|
||||
ov::mark_as_decompression(convert_to_f32);
|
||||
auto input_low = opset8::Constant::create(element::f32, Shape{}, {1});
|
||||
auto input_high = opset8::Constant::create(element::f32, Shape{}, {9});
|
||||
auto output_low = opset8::Constant::create(element::f32, Shape{}, {-128});
|
||||
auto output_high = opset8::Constant::create(element::f32, Shape{}, {127});
|
||||
auto fq =
|
||||
std::make_shared<opset8::FakeQuantize>(convert_to_f32, input_low, input_high, output_low, output_high, 256);
|
||||
auto convert = std::make_shared<opset8::Convert>(fq, element::i8);
|
||||
auto second_convert = std::make_shared<opset8::Convert>(convert, element::f32);
|
||||
auto scale = opset8::Constant::create(element::f32, Shape{}, {10.0 / 255});
|
||||
auto zero_point = opset8::Constant::create(element::f32, Shape{}, {2 - 255.0 / 10});
|
||||
auto sub = std::make_shared<opset8::Subtract>(second_convert, zero_point);
|
||||
auto mul = std::make_shared<opset8::Multiply>(sub, scale);
|
||||
|
||||
function = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
|
||||
|
||||
manager.register_pass<pass::CompressQuantizeWeights>();
|
||||
}
|
||||
{
|
||||
auto data = opset8::Constant::create(element::i8, Shape{2, 4, 1, 1}, {-128, -128, -128, -96, -64, -32, 0, 127});
|
||||
auto convert = std::make_shared<opset8::Convert>(data, element::f32);
|
||||
auto scale = opset8::Constant::create(element::f32, Shape{}, {10.0 / 255});
|
||||
auto zero_point = opset8::Constant::create(element::f32, Shape{}, {2 - 255.0 / 10});
|
||||
auto sub = std::make_shared<opset8::Subtract>(convert, zero_point);
|
||||
auto mul = std::make_shared<opset8::Multiply>(sub, scale);
|
||||
function_ref = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, CompressQuantizeWeightsWithZeroPointOptimizer) {
|
||||
{
|
||||
auto data = opset8::Constant::create(element::f32, Shape{3, 1, 1, 1}, {-0.144816, 0.0858578, 0.110928});
|
||||
|
||||
Reference in New Issue
Block a user