[CPU] Don't fuse Add or Mul to FakeQuantize if the result is not supported b… (#8585)
commit 585c6bcb24 (parent abbf0384ae)
@@ -26,9 +26,12 @@
 #include <transformations/opset_conversions/convert_opset3_to_opset2.hpp>
 #include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
 
+#include <transformations/common_optimizations/add_fake_quantize_fusion.hpp>
 #include <transformations/common_optimizations/common_optimizations.hpp>
+#include <transformations/common_optimizations/fq_mul_fusion.hpp>
+#include <transformations/common_optimizations/mul_fake_quantize_fusion.hpp>
 #include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
-#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
+#include <transformations/common_optimizations/convert_quantize_dequantize.hpp>
 #include <transformations/common_optimizations/nop_elimination.hpp>
 #include <transformations/common_optimizations/wrap_interpolate_into_transposes.hpp>
 #include <transformations/common_optimizations/transpose_sinking.hpp>
@@ -341,6 +344,13 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
     pass_config->enable<ngraph::pass::ConvertGather8ToGather7>();
 
     if (useLpt) {
+        pass_config->set_callback<ngraph::pass::AddFakeQuantizeFusion,
+                                  ngraph::pass::MulFakeQuantizeFusion,
+                                  ngraph::pass::FakeQuantizeMulFusion>([](const_node_ptr &node) -> bool {
+            std::string errMsg;
+            return !MKLDNNFakeQuantizeNode::isSupportedOperation(node, errMsg);
+        });
+
         pass_config->set_callback<ngraph::pass::ConvertQuantizeDequantize>([](const_node_ptr &node) -> bool {
             return ngraph::pass::low_precision::NetworkHelper::areQuantizeAndDequantizeSupportedForMultiply(node);
        });
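Note (editorial): the hunk above is where the CPU plugin vetoes the three fusion passes: pass_config->set_callback registers a predicate, and a pass that honors it skips its rewrite whenever the predicate returns true for the node the pass is about to create (see the transformation_callback checks added in the hunks below). A minimal standalone sketch of that contract, using stand-in types rather than the real ngraph API:

    #include <functional>
    #include <iostream>
    #include <memory>
    #include <string>

    // Stand-in for an ngraph node; only a type name is needed for this sketch.
    struct Node { std::string type; };
    using NodePtr = std::shared_ptr<const Node>;

    // Stand-in for PassConfig: holds a "should this rewrite be skipped?" predicate.
    struct PassConfig {
        std::function<bool(const NodePtr&)> callback;
    };

    // A fusion pass consults the callback on its *candidate* result node and
    // abandons the rewrite when the callback returns true.
    bool run_fusion(const PassConfig& cfg, const NodePtr& candidate) {
        if (cfg.callback && cfg.callback(candidate))
            return false;                    // rewrite vetoed; graph left untouched
        std::cout << "fused: " << candidate->type << "\n";
        return true;
    }

    int main() {
        PassConfig cfg;
        // Mirrors the plugin's lambda: veto when the fused FakeQuantize would not
        // be executable by the CPU kernel (approximated here by a name check).
        cfg.callback = [](const NodePtr& n) { return n->type == "FakeQuantize/unsupported"; };
        run_fusion(cfg, std::make_shared<Node>(Node{"FakeQuantize"}));              // applied
        run_fusion(cfg, std::make_shared<Node>(Node{"FakeQuantize/unsupported"}));  // skipped
    }

This split keeps device knowledge out of the generic passes: the pass only asks "may I?", and the plugin answers using its own kernel-support check.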
@@ -101,12 +101,14 @@ ngraph::pass::AddFakeQuantizeFusion::AddFakeQuantizeFusion() {
     std::shared_ptr<Node> new_input_high = get_constant_from_source(input_high_sub);
     if (!new_input_high)
         new_input_high = input_high_sub;
-    auto new_fq = register_new_node<opset5::FakeQuantize>(input,
-                                                          new_input_low,
-                                                          new_input_high,
-                                                          fq->input_value(3),
-                                                          fq->input_value(4),
-                                                          fq->get_levels());
+    auto new_fq = fq->clone_with_new_inputs({input,
+                                             new_input_low,
+                                             new_input_high,
+                                             fq->input_value(3),
+                                             fq->input_value(4)});
+    if (transformation_callback(new_fq))
+        return false;
+    register_new_node(new_fq);
     new_fq->set_friendly_name(fq->get_friendly_name());
     copy_runtime_info({add_node, fq}, {new_input_low, new_input_high, new_fq});
     replace_node(fq, new_fq);
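Note (editorial): two things change in the hunk above. First, the candidate node is built with fq->clone_with_new_inputs(...) instead of a hand-constructed opset5::FakeQuantize, so every attribute of the original node (levels, broadcast spec, anything a derived type adds) is carried over implicitly and only the operands are replaced. Second, the candidate is shown to transformation_callback before register_new_node/replace_node mutate the graph. A simplified standalone sketch of the clone semantics (stand-in types, not the real ngraph API):

    #include <cassert>
    #include <memory>
    #include <utility>
    #include <vector>

    // Stand-in FakeQuantize: operands plus attributes that must survive a rewrite.
    struct FakeQuantize {
        std::vector<int> inputs;   // stand-in operands (data, in_low, in_high, out_low, out_high)
        std::size_t levels;        // attribute NOT passed through clone_with_new_inputs
        bool per_channel;          // another attribute carried along implicitly

        std::shared_ptr<FakeQuantize> clone_with_new_inputs(std::vector<int> new_inputs) const {
            auto copy = std::make_shared<FakeQuantize>(*this); // attributes copied wholesale
            copy->inputs = std::move(new_inputs);              // only the operands change
            return copy;
        }
    };

    int main() {
        FakeQuantize fq{{1, 2, 3, 4, 5}, 255, true};
        // Rewrite: new input range, same output range, same attributes.
        auto new_fq = fq.clone_with_new_inputs({1, 20, 30, fq.inputs[3], fq.inputs[4]});
        assert(new_fq->levels == 255 && new_fq->per_channel); // nothing to re-specify or forget
    }

Constructing opset5::FakeQuantize by hand would force the pass to re-specify each attribute and would silently drop any it does not know about.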
@@ -116,6 +116,8 @@ ngraph::pass::FakeQuantizeMulFusion::FakeQuantizeMulFusion() {
                                                              fq_node->input_value(2),
                                                              get_adjusted_output_range(original_output_low),
                                                              get_adjusted_output_range(original_output_high)});
+    if (transformation_callback(new_fq_node))
+        return false;
 
     const auto mul_node = pattern_map.at(mul_node_p).get_node_shared_ptr();
 
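Note (editorial): FakeQuantizeMulFusion already built its candidate via fq_node->clone_with_new_inputs (the context lines above), so this hunk only adds the early exit. What matters is the ordering: the veto fires before copy_runtime_info/replace_node run, leaving the graph untouched when the fusion is rejected. A standalone sketch of that check-before-mutate ordering (simplified types, not the real API):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // Stand-in graph: the pass must not touch it unless the candidate is approved.
    struct Graph { std::vector<std::string> nodes; };

    bool try_fuse(Graph& g, const std::string& candidate,
                  const std::function<bool(const std::string&)>& veto) {
        if (veto(candidate))        // mirrors: if (transformation_callback(new_fq_node)) return false;
            return false;           // bail out BEFORE copy_runtime_info/replace_node would run
        g.nodes.push_back(candidate);
        return true;
    }

    int main() {
        Graph g;
        auto veto = [](const std::string& n) { return n.find("per-channel") != std::string::npos; };
        std::cout << try_fuse(g, "FakeQuantize*scale", veto) << "\n";             // 1: applied
        std::cout << try_fuse(g, "FakeQuantize*per-channel-scale", veto) << "\n"; // 0: skipped
        std::cout << g.nodes.size() << "\n";                                      // 1: veto left graph unchanged
    }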
@@ -93,8 +93,11 @@ ngraph::pass::MulFakeQuantizeFusion::MulFakeQuantizeFusion() {
     if (!new_input_high)
         new_input_high = input_high_div;
 
-    auto new_fq = register_new_node<opset5::FakeQuantize>(input, new_input_low, new_input_high,
-                                                          fq->input_value(3), fq->input_value(4), fq->get_levels());
+    auto new_fq = fq->clone_with_new_inputs({input, new_input_low, new_input_high,
+                                             fq->input_value(3), fq->input_value(4)});
+    if (transformation_callback(new_fq))
+        return false;
+    register_new_node(new_fq);
     copy_runtime_info({pattern_value_map.at(mul_pattern).get_node_shared_ptr(), fq},
                       {new_const, new_input_low, new_input_high, new_fq});
     new_fq->set_friendly_name(fq->get_friendly_name());
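Note (editorial): all three passes now defer to the same plugin-side check; the callback registered in the first hunk returns !MKLDNNFakeQuantizeNode::isSupportedOperation(node, errMsg), i.e. "skip the fusion when the resulting FakeQuantize could not be executed by the CPU kernel". A hypothetical standalone sketch of that contract; only the shape of the call mirrors the diff, and the rejection criteria below are invented for illustration:

    #include <iostream>
    #include <memory>
    #include <string>

    // Stand-in node; the real check inspects an ngraph::Node's shapes and constants.
    struct Node { int rank; bool const_ranges; };

    // Hypothetical sketch of the support-check contract used by the callback above:
    // return false AND fill errorMessage when the fused FakeQuantize cannot run on CPU.
    static bool isSupportedOperation(const std::shared_ptr<const Node>& op, std::string& errorMessage) {
        if (op->rank > 5) {
            errorMessage = "FakeQuantize with rank > 5 is not supported";      // illustrative criterion only
            return false;
        }
        if (!op->const_ranges) {
            errorMessage = "only constant input/output ranges are supported";  // illustrative criterion only
            return false;
        }
        return true;
    }

    int main() {
        std::string errMsg;
        auto node = std::make_shared<Node>(Node{6, true});
        // Mirrors the callback: returning true here means "skip the fusion".
        bool skip_fusion = !isSupportedOperation(node, errMsg);
        if (skip_fusion)
            std::cout << "fusion skipped: " << errMsg << "\n";
    }

The errMsg out-parameter lets the plugin reuse one predicate both for vetoing fusions here and for reporting why a node falls back elsewhere.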