Introduce Quantize-Dequantize to FakeQuantize transformation (#1849)

* Introduce Quantize-Dequantize to FakeQuantize transformation

* Revert changes in DequantizeLinear

* apply code format

* Changes after review:

- description for transformation
- remove NGRAPH_CHECK and move some checks from callback to predicates in pattern
- check if out_low/high are broadcastable for FQ's first input
- fix params to copy_runtime_info

* Add type_matches and type_matches_any predicates

* Use get_single_value

* Changes after review:

- add brief description of transformation
- use get_pattern_value_map instead of get_pattern_map
- change opset1 to opset4
- fix params to copy_runtime_info

* Check result of dynamic_pointer_cast
This commit is contained in:
Mateusz Tabaka 2020-08-26 10:51:51 +02:00 committed by GitHub
parent 4673dc9b9c
commit a6076a1fd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 678 additions and 0 deletions

View File

@ -0,0 +1,36 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertQuantizeDequantize;
} // namespace pass
} // namespace ngraph
/**
 * @ingroup ie_transformation_common_api
 * @brief ConvertQuantizeDequantize transformation replaces the following graph:
 * FakeQuantize->Convert->Convert->Subtract->Multiply with a single FakeQuantize.
 * Restrictions:
 * - quantized data type must be i8 or u8
 * - 'levels' attribute of FakeQuantize must be equal to 256
 * - (output_low, output_high) must be (-128, 127) for i8 or (0, 255) for u8
 * - 'zero_point' and 'scale' must be broadcastable to FakeQuantize's output
 */
class ngraph::pass::ConvertQuantizeDequantize: public ngraph::pass::MatcherPass {
public:
    ConvertQuantizeDequantize();
};

View File

@ -19,6 +19,7 @@
#include "transformations/softplus_fusion.hpp"
#include "transformations/swish_fusion.hpp"
#include "transformations/hswish_fusion.hpp"
#include "transformations/convert_quantize_dequantize.hpp"
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
@ -33,6 +34,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::ConvertPriorBox>(); // WA: ConvertPriorBox must be executed before CF
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::RemoveFilteringBoxesBySize>(); // Resolves dynamism (replaces NonZero), CF needed
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<ngraph::pass::StridedSliceOptimization>(); // depends on CF
manager.register_pass<ngraph::pass::NopElimination>(); // may introduce fake dynamism

View File

@ -0,0 +1,153 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/convert_quantize_dequantize.hpp"
#include "transformations/utils/utils.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
// ConvertQuantizeDequantize converts Quantize/Dequantize pair to a single FakeQuantize.
// Since Quantize is decomposed to FakeQuantize and Dequantize is decomposed to Subtract->Multiply,
// the full pattern to match is presented on the left hand side of the graph below.
// On the right hand side is the graph after transformation.
// Currently transformation supports only i8 and u8 quantized data type.
// That implies 'levels' attribute to be 256, as well as (output_low, output_high) be (-128, 127) or (0, 255) (depends on sign of the quantized data type).
// Another limitation is that 'zero_point' and 'scale' have to be broadcastable to the output of FakeQuantize.
//
//
// | | | | |
// | | | | |
// v v v v v
// +------------+
// |FakeQuantize|
// +------------+
// |
// v
// +---------------------+
// | Convert |
// |(e.g. from f32 to u8)|
// +---------+-----------+ | | | | |
// | | | | | |
// v v v v v v
// +---------------------+ +------------+
// | Convert | ====> |FakeQuantize|
// | (from u8 to f32) | +------------+
// +---------+-----------+ |
// | v
// v
// +----------+ +------------+
// |zero point|--->| Subtract |
// +----------+ +-----+------+
// |
// v
// +---------+ +------------+
// | scale |--->| Multiply |
// +---------+ +-----+------+
// |
// v
//
ngraph::pass::ConvertQuantizeDequantize::ConvertQuantizeDequantize() {
    // Quantize part of the pattern: FakeQuantize with constant output range,
    // followed by a Convert to the quantized integer type.
    auto data_pattern = ngraph::pattern::any_input();
    auto input_low_pattern = ngraph::pattern::any_input();
    auto input_high_pattern = ngraph::pattern::any_input();
    auto output_low_pattern = ngraph::pattern::wrap_type<opset4::Constant>();
    auto output_high_pattern = ngraph::pattern::wrap_type<opset4::Constant>();
    auto fq_pattern = ngraph::pattern::wrap_type<opset4::FakeQuantize>({data_pattern, input_low_pattern,
                                                                        input_high_pattern, output_low_pattern,
                                                                        output_high_pattern});
    // Only i8/u8 quantized types are supported; the second Convert goes back to f32.
    auto convert1_pattern = ngraph::pattern::wrap_type<opset4::Convert>({fq_pattern}, pattern::type_matches_any({element::i8, element::u8}));
    auto convert2_pattern = ngraph::pattern::wrap_type<opset4::Convert>({convert1_pattern}, pattern::type_matches(element::f32));
    // Dequantize part of the pattern: (x - zero_point) * scale.
    auto zero_point_pattern = ngraph::pattern::any_input();
    auto sub_pattern = ngraph::pattern::wrap_type<opset4::Subtract>({convert2_pattern, zero_point_pattern}, pattern::consumers_count(1));
    auto scale_pattern = ngraph::pattern::any_input();
    auto mul_pattern = ngraph::pattern::wrap_type<opset4::Multiply>({sub_pattern, scale_pattern});
    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto pattern_map = m.get_pattern_value_map();
        auto data = pattern_map[data_pattern];
        auto input_low = pattern_map[input_low_pattern];
        auto input_high = pattern_map[input_high_pattern];
        auto output_low = std::dynamic_pointer_cast<opset4::Constant>(pattern_map[output_low_pattern].get_node_shared_ptr());
        if (!output_low)
            return false;
        auto output_high = std::dynamic_pointer_cast<opset4::Constant>(pattern_map[output_high_pattern].get_node_shared_ptr());
        if (!output_high)
            return false;
        auto fq = std::dynamic_pointer_cast<opset4::FakeQuantize>(pattern_map[fq_pattern].get_node_shared_ptr());
        if (!fq)
            return false;
        auto zero_point = pattern_map[zero_point_pattern];
        auto scale = pattern_map[scale_pattern];
        auto convert1 = pattern_map[convert1_pattern];
        auto convert2 = pattern_map[convert2_pattern];
        auto mul = pattern_map[mul_pattern].get_node_shared_ptr();
        // convert1 and convert2 should each have only one consumer (target input);
        // otherwise removing them would drop values still used elsewhere in the graph
        if (convert1.get_target_inputs().size() != 1)
            return false;
        if (convert2.get_target_inputs().size() != 1)
            return false;
        // we support only i8 or u8 so 'levels' attribute must be 256
        size_t levels = fq->get_levels();
        if (levels != 256)
            return false;
        // check if (out_low_val, out_high_val) is (-128, 127) or (0, 255);
        // both constants must be scalar (or single-element) to read a single value
        float out_low_val;
        if (!op::util::get_single_value(output_low, out_low_val))
            return false;
        float out_high_val;
        if (!op::util::get_single_value(output_high, out_high_val))
            return false;
        const auto& type = convert1.get_element_type();
        switch (type) {
        case element::Type_t::i8:
            if (out_low_val != -128 || out_high_val != 127)
                return false;
            break;
        case element::Type_t::u8:
            if (out_low_val != 0 || out_high_val != 255)
                return false;
            break;
        default:
            return false;
        }
        // Fold the dequantization into FakeQuantize's output range:
        // new_range = (old_range - zero_point) * scale
        auto new_out_low = std::make_shared<ngraph::opset4::Multiply>(
            std::make_shared<ngraph::opset4::Subtract>(output_low, zero_point), scale);
        auto new_out_high = std::make_shared<ngraph::opset4::Multiply>(
            std::make_shared<ngraph::opset4::Subtract>(output_high, zero_point), scale);
        // check if new_out_low/high shapes are broadcastable to FQ's input;
        // a higher rank than the data would change the output shape of the fused node
        auto data_shape = data.get_partial_shape();
        if (data_shape.rank().is_dynamic())
            return false;
        auto out_low_shape = new_out_low->get_output_partial_shape(0);
        if (out_low_shape.rank().is_dynamic() || out_low_shape.rank().get_length() > data_shape.rank().get_length())
            return false;
        auto out_high_shape = new_out_high->get_output_partial_shape(0);
        if (out_high_shape.rank().is_dynamic() || out_high_shape.rank().get_length() > data_shape.rank().get_length())
            return false;
        auto new_fq = std::make_shared<ngraph::opset4::FakeQuantize>(data, input_low, input_high, new_out_low, new_out_high, levels);
        // the fused node takes over the name of the last node in the matched chain
        new_fq->set_friendly_name(mul->get_friendly_name());
        copy_runtime_info({fq, convert1.get_node_shared_ptr(), convert2.get_node_shared_ptr()}, new_fq);
        replace_node(mul, new_fq);
        return true;
    };
    auto m = std::make_shared<ngraph::pattern::Matcher>(mul_pattern, "ConvertQuantizeDequantize");
    this->register_matcher(m, callback);
}

View File

@ -0,0 +1,205 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <transformations/convert_quantize_dequantize.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
// Builds a Quantize->Dequantize chain (FakeQuantize -> Convert to T ->
// Convert to f32 -> Subtract(zero_point) -> Multiply(scale)) as the ONNX
// importer would produce for a QuantizeLinear/DequantizeLinear pair.
template <typename T>
std::shared_ptr<Function> create_q_dq_function(const Shape& data_shape, float in_low, float in_high, float out_low, float out_high,
                                               const Shape& zero_point_shape, std::vector<T> zero_point_values,
                                               const Shape& scale_shape, std::vector<float> scale_values, size_t levels) {
    const auto quantized_type = element::from<T>();
    auto param = std::make_shared<opset1::Parameter>(element::f32, data_shape);
    auto fq = std::make_shared<opset1::FakeQuantize>(param,
                                                     opset1::Constant::create(element::f32, Shape{}, {in_low}),
                                                     opset1::Constant::create(element::f32, Shape{}, {in_high}),
                                                     opset1::Constant::create(element::f32, Shape{}, {out_low}),
                                                     opset1::Constant::create(element::f32, Shape{}, {out_high}),
                                                     levels);
    // Quantize to T, then immediately dequantize back to f32.
    auto to_quantized = std::make_shared<opset1::Convert>(fq, quantized_type);
    auto to_float = std::make_shared<opset1::Convert>(to_quantized, quantized_type == element::f32 ? element::f32 : element::f32);
    auto zero_point = std::make_shared<opset1::Convert>(
        opset1::Constant::create(quantized_type, zero_point_shape, zero_point_values), element::f32);
    auto shifted = std::make_shared<opset1::Subtract>(to_float, zero_point);
    auto scaled = std::make_shared<opset1::Multiply>(
        shifted, opset1::Constant::create(element::f32, scale_shape, scale_values));
    return std::make_shared<Function>(NodeVector{scaled}, ParameterVector{param});
}
template <typename T>
void positive_test(const Shape& data_shape, float in_low, float in_high, float out_low, float out_high,
const Shape& zero_point_shape, std::vector<T> zero_point_values,
const Shape& scale_shape, std::vector<float> scale_values, size_t levels) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
{
f = create_q_dq_function(data_shape, in_low, in_high, out_low, out_high,
zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::ConvertQuantizeDequantize>();
m.register_pass<pass::ConstantFolding>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto data = std::make_shared<opset1::Parameter>(element::f32, data_shape);
auto input_low = opset1::Constant::create(element::f32, Shape{}, {in_low});
auto input_high = opset1::Constant::create(element::f32, Shape{}, {in_high});
auto output_low = opset1::Constant::create(element::f32, Shape{}, {(out_low - zero_point_values[0]) * scale_values[0]});
auto output_high = opset1::Constant::create(element::f32, Shape{}, {(out_high - zero_point_values[0]) * scale_values[0]});
auto fq = std::make_shared<opset1::FakeQuantize>(data, input_low,
input_high, output_low,
output_high, levels);
f_ref = std::make_shared<Function>(NodeVector{fq}, ParameterVector{data});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertQuantizeDequantizeINT8) {
    // i8 chain with levels == 256 and (out_low, out_high) == (-128, 127)
    // satisfies every restriction, so the transformation must fuse it.
    // (Removed unused locals f/f_ref — the helper builds its own functions.)
    Shape data_shape{3, 1, 2};
    float in_low = 0;
    float in_high = 5;
    float out_low = -128;
    float out_high = 127;
    Shape zero_point_shape{};
    std::vector<int8_t> zero_point_values{2};
    Shape scale_shape{};
    std::vector<float> scale_values{3};
    size_t levels = 256;
    positive_test(data_shape, in_low, in_high, out_low, out_high,
                  zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}
TEST(TransformationTests, ConvertQuantizeDequantizeUINT8) {
    // u8 chain with levels == 256 and (out_low, out_high) == (0, 255)
    // satisfies every restriction, so the transformation must fuse it.
    // (Removed unused locals f/f_ref — the helper builds its own functions.)
    Shape data_shape{3, 1, 2};
    float in_low = 0;
    float in_high = 5;
    float out_low = 0;
    float out_high = 255;
    Shape zero_point_shape{};
    std::vector<uint8_t> zero_point_values{2};
    Shape scale_shape{};
    std::vector<float> scale_values{3};
    size_t levels = 256;
    positive_test(data_shape, in_low, in_high, out_low, out_high,
                  zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}
template <typename T>
void negative_test(const Shape& data_shape, float in_low, float in_high, float out_low, float out_high,
const Shape& zero_point_shape, std::vector<T> zero_point_values,
const Shape& scale_shape, std::vector<float> scale_values, size_t levels) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
{
f = create_q_dq_function(data_shape, in_low, in_high, out_low, out_high,
zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::ConvertQuantizeDequantize>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
// negative test so the transformation does not fire and reference is the same graph as original
f_ref = create_q_dq_function(data_shape, in_low, in_high, out_low, out_high,
zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, ConvertQuantizeDequantizeZeroPointNotBroadcastable) {
    // Rank-4 zero point cannot broadcast to the rank-3 FakeQuantize output,
    // so the transformation must not fire.
    // (Removed unused locals f/f_ref — the helper builds its own functions.)
    Shape data_shape{3, 1, 2};
    float in_low = 0;
    float in_high = 5;
    float out_low = -128;
    float out_high = 127;
    Shape zero_point_shape{1, 1, 1, 1};
    std::vector<int8_t> zero_point_values{2};
    Shape scale_shape{1};
    std::vector<float> scale_values{3};
    size_t levels = 256;
    negative_test(data_shape, in_low, in_high, out_low, out_high,
                  zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}
TEST(TransformationTests, ConvertQuantizeDequantizeScaleNotBroadcastable) {
    // Rank-4 scale cannot broadcast to the rank-3 FakeQuantize output,
    // so the transformation must not fire.
    // (Removed unused locals f/f_ref — the helper builds its own functions.)
    Shape data_shape{3, 1, 2};
    float in_low = 0;
    float in_high = 5;
    float out_low = -128;
    float out_high = 127;
    Shape zero_point_shape{};
    std::vector<int8_t> zero_point_values{2};
    Shape scale_shape{1, 1, 1, 1};
    std::vector<float> scale_values{3};
    size_t levels = 256;
    negative_test(data_shape, in_low, in_high, out_low, out_high,
                  zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}
TEST(TransformationTests, ConvertQuantizeDequantizeInvalidLevels) {
    // levels != 256 is not an i8/u8 quantization, so the transformation
    // must not fire.
    // (Removed unused locals f/f_ref — the helper builds its own functions.)
    Shape data_shape{3, 1, 2};
    float in_low = 0;
    float in_high = 5;
    float out_low = -128;
    float out_high = 127;
    Shape zero_point_shape{};
    std::vector<int8_t> zero_point_values{2};
    Shape scale_shape{};
    std::vector<float> scale_values{3};
    size_t levels = 127;
    negative_test(data_shape, in_low, in_high, out_low, out_high,
                  zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}
TEST(TransformationTests, ConvertQuantizeDequantizeInvalidOutLowOutHigh) {
    // (Removed unused locals f/f_ref — the helper builds its own functions.)
    Shape data_shape{3, 1, 2};
    float in_low = 0;
    float in_high = 5;
    // (-128, 127) is invalid for the uint8_t data type (u8 requires (0, 255)),
    // so the transformation must not fire.
    float out_low = -128;
    float out_high = 127;
    Shape zero_point_shape{};
    std::vector<uint8_t> zero_point_values{2};
    Shape scale_shape{};
    std::vector<float> scale_values{3};
    size_t levels = 256;
    negative_test(data_shape, in_low, in_high, out_low, out_high,
                  zero_point_shape, zero_point_values, scale_shape, scale_values, levels);
}

View File

@ -61,6 +61,12 @@ namespace ngraph
NGRAPH_API
std::function<bool(Output<Node>)> has_static_shape();
NGRAPH_API
std::function<bool(Output<Node>)> type_matches(const element::Type& type);
NGRAPH_API
std::function<bool(Output<Node>)> type_matches_any(const std::vector<element::Type>& types);
namespace op
{
using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;

View File

@ -94,5 +94,21 @@ namespace ngraph
return
[=](Output<Node> output) -> bool { return output.get_partial_shape().is_static(); };
}
std::function<bool(Output<Node>)> type_matches(const element::Type& type)
{
    // Pattern predicate: accepts an output only when its element type
    // equals the expected 'type'.
    return [=](Output<Node> output) -> bool {
        return type == output.get_element_type();
    };
}
std::function<bool(Output<Node>)>
    type_matches_any(const std::vector<element::Type>& expected_types)
{
    // Pattern predicate: accepts an output when its element type is one of
    // 'expected_types' (the list is captured by copy into the closure).
    return [=](Output<Node> output) -> bool {
        const auto& actual_type = output.get_element_type();
        for (const auto& candidate : expected_types)
        {
            if (candidate == actual_type)
            {
                return true;
            }
        }
        return false;
    };
}
}
}

View File

@ -0,0 +1,108 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "data"
input: "scale"
input: "zero_point"
output: "quantization_out"
name: "quantization"
op_type: "QuantizeLinear"
}
node {
input: "quantization_out"
input: "scale"
input: "zero_point"
output: "dequantization_out"
name: "dequantization"
op_type: "DequantizeLinear"
}
node {
input: "dequantization_out"
input: "x"
output: "mul_out"
name: "mul"
op_type: "Mul"
}
name: "test_graph"
input {
name: "data"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
input {
name: "scale"
type {
tensor_type {
elem_type: 1
shape {
dim {
}
}
}
}
}
input {
name: "zero_point"
type {
tensor_type {
elem_type: 3
shape {
dim {
}
}
}
}
}
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "mul_out"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
}
}
}
}
initializer {
dims: 0
data_type: 1
name: "scale"
float_data: 3
}
initializer {
dims: 0
data_type: 2
name: "zero_point"
int32_data: 10
}
}
opset_import {
version: 11
}

View File

@ -0,0 +1,125 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "data"
input: "scale"
input: "zero_point"
output: "quantization_out"
name: "quantization"
op_type: "QuantizeLinear"
attribute {
name: "axis"
i: 1
type: INT
}
}
node {
input: "quantization_out"
input: "scale"
input: "zero_point"
output: "dequantization_out"
name: "dequantization"
op_type: "DequantizeLinear"
}
node {
input: "dequantization_out"
input: "x"
output: "mul_out"
name: "mul"
op_type: "Mul"
}
name: "test_graph"
input {
name: "data"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "scale"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
}
}
}
}
input {
name: "zero_point"
type {
tensor_type {
elem_type: 3
shape {
dim {
dim_value: 3
}
}
}
}
}
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "mul_out"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
initializer {
dims: 3
data_type: 1
name: "scale"
float_data: 2
float_data: 3
float_data: 4
}
initializer {
dims: 3
data_type: 2
name: "zero_point"
int32_data: 10
int32_data: 20
int32_data: 30
}
}
opset_import {
version: 13
}

View File

@ -2447,3 +2447,30 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_empty_initializers_handling)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, quant_dequant_pattern)
{
    // QuantizeLinear followed by DequantizeLinear with scale == 3.0 and
    // zero point == 10 rounds each input to the nearest multiple of the scale.
    const auto model = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/quant_dequant_pattern.prototxt"));
    auto tc = test::TestCase<TestEngine>(model);
    tc.add_input<float>({9.0, 10.0, 15.0, 20.0, 30.0});
    tc.add_input<float>({1});
    tc.add_expected_output<float>(Shape{5}, {9.0, 9.0, 15.0, 21.0, 30.0});
    tc.run();
}
NGRAPH_TEST(${BACKEND_NAME}, quant_dequant_pattern_axis)
{
    const auto function = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/quant_dequant_pattern_axis.prototxt"));
    auto test_case = test::TestCase<TestEngine>(function);
    // axis = 1
    // scale == {2.0, 3.0, 4.0}
    // zero point == {10, 20, 30}
    test_case.add_input<float>({1.0, 2.0, 3.0, 10.0, 20.0, 30.0, 40.0, 50.0, 100.0});
    // Register the second graph input before the expected output, matching
    // the quant_dequant_pattern test (inputs were split around the expected
    // output here, which is fragile and inconsistent).
    test_case.add_input<float>({1});
    test_case.add_expected_output<float>(Shape{3, 3}, {0, 3, 4, 10, 21, 32, 40, 51, 100});
    test_case.run();
}
}