Dequantize (Sub, Mul) to FakeQuantize (#4189)
* Dequantize (Sub, Mul) to FakeQuantize * disable for CPU/GPU
This commit is contained in:
parent
a327b72481
commit
a313c0c3ee
@ -30,6 +30,7 @@
|
||||
#include <transformations/control_flow/unroll_tensor_iterator.hpp>
|
||||
|
||||
#include <transformations/common_optimizations/common_optimizations.hpp>
|
||||
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
|
||||
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
|
||||
#include <transformations/op_conversions/convert_depth_to_space.hpp>
|
||||
#include <transformations/op_conversions/convert_space_to_depth.hpp>
|
||||
@ -279,6 +280,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
|
||||
pass_config->disable<ngraph::pass::SoftPlusDecomposition>();
|
||||
pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
|
||||
pass_config->disable<ngraph::pass::ConvertBroadcast3>();
|
||||
pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();
|
||||
|
||||
pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();
|
||||
|
||||
|
@ -33,6 +33,7 @@
|
||||
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
|
||||
|
||||
#include <transformations/common_optimizations/common_optimizations.hpp>
|
||||
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
|
||||
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
|
||||
#include <transformations/common_optimizations/depth_to_space_fusion.hpp>
|
||||
#include <transformations/op_conversions/convert_depth_to_space.hpp>
|
||||
@ -228,6 +229,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
|
||||
pass_config->disable<ngraph::pass::ConvertMod>();
|
||||
pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
|
||||
pass_config->disable<ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher>();
|
||||
pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();
|
||||
|
||||
pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();
|
||||
|
||||
|
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
// Forward declaration carrying the export macro; the class body is defined
// further down using its fully qualified name.
class TRANSFORMATIONS_API WeightsDequantizeToFakeQuantize;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
* @ingroup ie_transformation_common_api
* @brief WeightsDequantizeToFakeQuantize transformation replaces the
* dequantization subgraph applied to constant weights:
*     Constant -> Convert -> [Subtract (zero point constant)] -> Multiply (scale constant)
* with a FakeQuantize operation that consumes the Convert output.
* The weights constant is read as int8 data: the FakeQuantize gets 256 levels
* when the smallest weight equals -128 and 255 levels otherwise. Its output
* range is the input range shifted by the zero point (0 when there is no
* Subtract) and multiplied by the scale.
*/
|
||||
class ngraph::pass::WeightsDequantizeToFakeQuantize: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
// Builds the pattern and registers the matcher callback.
WeightsDequantizeToFakeQuantize();
|
||||
};
|
@ -55,6 +55,7 @@
|
||||
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::CommonOptimizations, "CommonOptimizations", 0);
|
||||
|
||||
@ -69,6 +70,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
|
||||
// TODO: move to KMB
|
||||
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
|
||||
manager.register_pass<ngraph::pass::WeightsDequantizeToFakeQuantize>();
|
||||
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
|
||||
|
@ -0,0 +1,74 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include "itt.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::WeightsDequantizeToFakeQuantize, "WeightsDequantizeToFakeQuantize", 0);
|
||||
|
||||
ngraph::pass::WeightsDequantizeToFakeQuantize::WeightsDequantizeToFakeQuantize() {
|
||||
MATCHER_SCOPE(WeightsDequantizeToFakeQuantize);
|
||||
|
||||
const auto weights = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
|
||||
const auto convert = ngraph::pattern::wrap_type<ngraph::opset6::Convert>({weights});
|
||||
const auto sub_c = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
|
||||
const auto sub = ngraph::pattern::wrap_type<ngraph::opset6::Subtract>({convert, sub_c});
|
||||
|
||||
const auto sub_or_convert = std::make_shared<pattern::op::Or>(OutputVector{convert, sub});
|
||||
|
||||
const auto mul_c = ngraph::pattern::wrap_type<ngraph::opset6::Constant>();
|
||||
const auto mul = ngraph::pattern::wrap_type<ngraph::opset6::Multiply>({sub_or_convert, mul_c});
|
||||
|
||||
ngraph::matcher_pass_callback callback;
|
||||
callback = [=](ngraph::pattern::Matcher &m) {
|
||||
auto pattern_map = m.get_pattern_map();
|
||||
|
||||
const auto &weights_node = as_type_ptr<opset6::Constant>(pattern_map.at(weights));
|
||||
const auto &convert_node = pattern_map.at(convert);
|
||||
const auto &multiply_node = pattern_map.at(mul);
|
||||
const auto &scale_node = pattern_map.at(mul_c);
|
||||
if (!weights_node || !convert_node || !multiply_node || !scale_node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto *data = weights_node->get_data_ptr<int8_t>();
|
||||
const int8_t weights_minimum = *std::min_element(data, data + shape_size(weights_node->get_shape()));
|
||||
int64_t levels = (weights_minimum == static_cast<int8_t>(-128)) ? 256 : 255;
|
||||
int64_t in_low = -(levels / 2), in_high = levels + in_low - 1;
|
||||
|
||||
const auto &input_low = opset6::Constant::create(convert_node->get_element_type(), {}, {in_low});
|
||||
const auto &input_high = opset6::Constant::create(convert_node->get_element_type(), {}, {in_high});
|
||||
|
||||
auto &zero_point = pattern_map.at(sub_c);
|
||||
if (!zero_point)
|
||||
zero_point = opset6::Constant::create(convert_node->get_element_type(), {}, {0});
|
||||
|
||||
const auto &output_low = std::make_shared<opset6::Multiply>(
|
||||
std::make_shared<opset6::Subtract>(input_low, zero_point), scale_node);
|
||||
const auto &output_high = std::make_shared<opset6::Multiply>(
|
||||
std::make_shared<opset6::Subtract>(input_high, zero_point), scale_node);
|
||||
|
||||
auto fq = std::make_shared<opset6::FakeQuantize>(
|
||||
convert_node, input_low, input_high, output_low, output_high, levels);
|
||||
|
||||
NodeVector nodes_to_copy_RT_info_from{multiply_node, scale_node, zero_point};
|
||||
if (pattern_map.at(sub))
|
||||
nodes_to_copy_RT_info_from.push_back(sub);
|
||||
|
||||
ngraph::copy_runtime_info(fq, nodes_to_copy_RT_info_from);
|
||||
multiply_node->output(0).replace(fq->output(0));
|
||||
|
||||
if (convert_node->get_rt_info().count("DISABLED_CONSTANT_FOLDING"))
|
||||
convert_node->get_rt_info().erase("DISABLED_CONSTANT_FOLDING");
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "WeightsDequantizeToFakeQuantize");
|
||||
register_matcher(m, callback);
|
||||
}
|
Loading…
Reference in New Issue
Block a user