Dequantize (Sub, Mul) to FakeQuantize (#4189)

* Dequantize (Sub, Mul) to FakeQuantize

* disable for CPU/GPU
Evgenya Stepyreva 2021-02-10 17:08:11 +03:00 committed by GitHub
parent a327b72481
commit a313c0c3ee
5 changed files with 110 additions and 0 deletions


@@ -30,6 +30,7 @@
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include <transformations/op_conversions/convert_depth_to_space.hpp>
#include <transformations/op_conversions/convert_space_to_depth.hpp>
@@ -279,6 +280,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
pass_config->disable<ngraph::pass::SoftPlusDecomposition>();
pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
pass_config->disable<ngraph::pass::ConvertBroadcast3>();
pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();
pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();


@@ -33,6 +33,7 @@
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
#include <transformations/common_optimizations/common_optimizations.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
#include <transformations/common_optimizations/depth_to_space_fusion.hpp>
#include <transformations/op_conversions/convert_depth_to_space.hpp>
@@ -228,6 +229,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
pass_config->disable<ngraph::pass::ConvertMod>();
pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
pass_config->disable<ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher>();
pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();
pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();


@@ -0,0 +1,30 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API WeightsDequantizeToFakeQuantize;
} // namespace pass
} // namespace ngraph
/**
 * @ingroup ie_transformation_common_api
 * @brief WeightsDequantizeToFakeQuantize transformation replaces a weights
 * dequantization sub-graph (Constant(i8 weights) -> Convert -> optional Subtract of a
 * zero point -> Multiply by a scale) with a single FakeQuantize operation.
 */
class ngraph::pass::WeightsDequantizeToFakeQuantize : public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    WeightsDequantizeToFakeQuantize();
};
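
For orientation, here is a minimal usage sketch (not part of this commit) showing how the new pass can be run through an nGraph pass manager; the wrapper function name and the function `f` are illustrative assumptions:

#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>

// Hypothetical helper: applies the new matcher pass to an existing nGraph function.
void apply_weights_dequantize_to_fq(const std::shared_ptr<ngraph::Function> &f) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::WeightsDequantizeToFakeQuantize>();
    manager.run_passes(f);
}

Plugins that do not want the conversion (as the CPU and GPU hunks above show) keep the pass registered in CommonOptimizations but turn it off through their pass config.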


@@ -55,6 +55,7 @@
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::CommonOptimizations, "CommonOptimizations", 0);
@@ -69,6 +70,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
// TODO: move to KMB
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
manager.register_pass<ngraph::pass::WeightsDequantizeToFakeQuantize>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF


@@ -0,0 +1,74 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp"
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/rt_info.hpp>
#include "itt.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::WeightsDequantizeToFakeQuantize, "WeightsDequantizeToFakeQuantize", 0);

ngraph::pass::WeightsDequantizeToFakeQuantize::WeightsDequantizeToFakeQuantize() {
    MATCHER_SCOPE(WeightsDequantizeToFakeQuantize);
    // Pattern: Constant (i8 weights) -> Convert -> [Subtract (zero point)] -> Multiply (scale)
    const auto weights = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
    const auto convert = ngraph::pattern::wrap_type<ngraph::opset6::Convert>({weights});
    const auto sub_c = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
    const auto sub = ngraph::pattern::wrap_type<ngraph::opset6::Subtract>({convert, sub_c});
    const auto sub_or_convert = std::make_shared<pattern::op::Or>(OutputVector{convert, sub});
    const auto mul_c = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
    const auto mul = ngraph::pattern::wrap_type<ngraph::opset6::Multiply>({sub_or_convert, mul_c});

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher &m) {
        auto pattern_map = m.get_pattern_map();

        const auto &weights_node = as_type_ptr<opset6::Constant>(pattern_map.at(weights));
        const auto &convert_node = pattern_map.at(convert);
        const auto &multiply_node = pattern_map.at(mul);
        const auto &scale_node = pattern_map.at(mul_c);
        if (!weights_node || !convert_node || !multiply_node || !scale_node) {
            return false;
        }

        // Weights that use the full int8 range (contain -128) map onto 256 levels;
        // otherwise a symmetric 255-level range [-127, 127] is used.
        const auto *data = weights_node->get_data_ptr<int8_t>();
        const int8_t weights_minimum = *std::min_element(data, data + shape_size(weights_node->get_shape()));
        int64_t levels = (weights_minimum == static_cast<int8_t>(-128)) ? 256 : 255;
        int64_t in_low = -(levels / 2), in_high = levels + in_low - 1;

        const auto &input_low = opset6::Constant::create(convert_node->get_element_type(), {}, {in_low});
        const auto &input_high = opset6::Constant::create(convert_node->get_element_type(), {}, {in_high});

        // The Subtract branch is optional; fall back to a zero constant when no zero point was matched.
        std::shared_ptr<Node> zero_point;
        if (pattern_map.count(sub_c))
            zero_point = pattern_map.at(sub_c);
        else
            zero_point = opset6::Constant::create(convert_node->get_element_type(), {}, {0});

        const auto &output_low = std::make_shared<opset6::Multiply>(
                std::make_shared<opset6::Subtract>(input_low, zero_point), scale_node);
        const auto &output_high = std::make_shared<opset6::Multiply>(
                std::make_shared<opset6::Subtract>(input_high, zero_point), scale_node);

        auto fq = std::make_shared<opset6::FakeQuantize>(
                convert_node, input_low, input_high, output_low, output_high, levels);

        NodeVector nodes_to_copy_RT_info_from{multiply_node, scale_node, zero_point};
        if (pattern_map.count(sub))
            nodes_to_copy_RT_info_from.push_back(pattern_map.at(sub));

        ngraph::copy_runtime_info(fq, nodes_to_copy_RT_info_from);
        multiply_node->output(0).replace(fq->output(0));

        // The Convert no longer needs to be protected from constant folding.
        if (convert_node->get_rt_info().count("DISABLED_CONSTANT_FOLDING"))
            convert_node->get_rt_info().erase("DISABLED_CONSTANT_FOLDING");
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "WeightsDequantizeToFakeQuantize");
    register_matcher(m, callback);
}
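
To make the matched pattern concrete, the following sketch (not part of this commit) builds the dequantization sub-graph the matcher above targets and runs the pass over it; the shapes, constant values, and function name are illustrative assumptions:

#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>

using namespace ngraph;

// Builds Constant(i8) -> Convert(f32) -> Subtract(zero point) -> Multiply(scale),
// i.e. the shape of graph the matcher looks for, then runs the pass on it.
std::shared_ptr<Function> dequantize_weights_example() {
    auto weights = opset6::Constant::create(element::i8, Shape{2, 2}, {-128, -64, 64, 127});
    auto convert = std::make_shared<opset6::Convert>(weights, element::f32);
    auto zero_point = opset6::Constant::create(element::f32, Shape{}, {1.f});
    auto sub = std::make_shared<opset6::Subtract>(convert, zero_point);
    auto scale = opset6::Constant::create(element::f32, Shape{}, {0.02f});
    auto mul = std::make_shared<opset6::Multiply>(sub, scale);
    auto f = std::make_shared<Function>(NodeVector{mul}, ParameterVector{});

    // After run_passes, the Subtract/Multiply pair is expected to be replaced by a
    // FakeQuantize fed by the Convert; levels == 256 here because the weights
    // contain -128 (the full int8 range), giving an input range of [-128, 127].
    pass::Manager manager;
    manager.register_pass<pass::WeightsDequantizeToFakeQuantize>();
    manager.run_passes(f);
    return f;
}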