[CPU] Don't fuse Add or Mul to FakeQuantize if the result is not supported by FakeQuantize node (#8585)

This commit is contained in:
Mateusz Tabaka 2021-11-24 07:50:47 +01:00 committed by GitHub
parent abbf0384ae
commit 585c6bcb24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 26 additions and 9 deletions

View File

@@ -26,9 +26,12 @@
#include <transformations/opset_conversions/convert_opset3_to_opset2.hpp>
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
#include <transformations/common_optimizations/add_fake_quantize_fusion.hpp>
#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/common_optimizations/fq_mul_fusion.hpp>
#include <transformations/common_optimizations/mul_fake_quantize_fusion.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include <transformations/common_optimizations/convert_quantize_dequantize.hpp>
#include <transformations/common_optimizations/nop_elimination.hpp>
#include <transformations/common_optimizations/wrap_interpolate_into_transposes.hpp>
#include <transformations/common_optimizations/transpose_sinking.hpp>
@@ -341,6 +344,13 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
pass_config->enable<ngraph::pass::ConvertGather8ToGather7>();
if (useLpt) {
pass_config->set_callback<ngraph::pass::AddFakeQuantizeFusion,
ngraph::pass::MulFakeQuantizeFusion,
ngraph::pass::FakeQuantizeMulFusion>([](const_node_ptr &node) -> bool {
std::string errMsg;
return !MKLDNNFakeQuantizeNode::isSupportedOperation(node, errMsg);
});
pass_config->set_callback<ngraph::pass::ConvertQuantizeDequantize>([](const_node_ptr &node) -> bool {
return ngraph::pass::low_precision::NetworkHelper::areQuantizeAndDequantizeSupportedForMultiply(node);
});

View File

@@ -101,12 +101,14 @@ ngraph::pass::AddFakeQuantizeFusion::AddFakeQuantizeFusion() {
std::shared_ptr<Node> new_input_high = get_constant_from_source(input_high_sub);
if (!new_input_high)
new_input_high = input_high_sub;
auto new_fq = register_new_node<opset5::FakeQuantize>(input,
new_input_low,
new_input_high,
fq->input_value(3),
fq->input_value(4),
fq->get_levels());
auto new_fq = fq->clone_with_new_inputs({input,
new_input_low,
new_input_high,
fq->input_value(3),
fq->input_value(4)});
if (transformation_callback(new_fq))
return false;
register_new_node(new_fq);
new_fq->set_friendly_name(fq->get_friendly_name());
copy_runtime_info({add_node, fq}, {new_input_low, new_input_high, new_fq});
replace_node(fq, new_fq);

View File

@@ -116,6 +116,8 @@ ngraph::pass::FakeQuantizeMulFusion::FakeQuantizeMulFusion() {
fq_node->input_value(2),
get_adjusted_output_range(original_output_low),
get_adjusted_output_range(original_output_high)});
if (transformation_callback(new_fq_node))
return false;
const auto mul_node = pattern_map.at(mul_node_p).get_node_shared_ptr();

View File

@@ -93,8 +93,11 @@ ngraph::pass::MulFakeQuantizeFusion::MulFakeQuantizeFusion() {
if (!new_input_high)
new_input_high = input_high_div;
auto new_fq = register_new_node<opset5::FakeQuantize>(input, new_input_low, new_input_high,
fq->input_value(3), fq->input_value(4), fq->get_levels());
auto new_fq = fq->clone_with_new_inputs({input, new_input_low, new_input_high,
fq->input_value(3), fq->input_value(4)});
if (transformation_callback(new_fq))
return false;
register_new_node(new_fq);
copy_runtime_info({pattern_value_map.at(mul_pattern).get_node_shared_ptr(), fq},
{new_const, new_input_low, new_input_high, new_fq});
new_fq->set_friendly_name(fq->get_friendly_name());