[ONNX] Quantize linear using FakeQuantize (#1169)

This commit is contained in:
Adam Osewski 2020-07-14 10:55:07 +02:00 committed by GitHub
parent b16c8faceb
commit ed4bbb3a0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 323 additions and 45 deletions

View File

@ -173,8 +173,8 @@ add_library(onnx_importer SHARED
op/qlinear_matmul.hpp
# op/quant_conv.cpp
# op/quant_conv.hpp
# op/quantize_linear.cpp
# op/quantize_linear.hpp
op/quantize_linear.cpp
op/quantize_linear.hpp
op/range.cpp
op/range.hpp
op/reciprocal.cpp

View File

@ -16,11 +16,18 @@
#include <cstdint>
#include <memory>
#include <numeric>
#include <tuple>
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/validation_util.hpp"
#include "quantize_linear.hpp"
#include "utils/reshape.hpp"
namespace ngraph
{
@ -28,49 +35,218 @@ namespace ngraph
{
namespace op
{
namespace detail
{
namespace
{
std::shared_ptr<ngraph::Node> get_zero_point(const NodeVector& inputs)
{
if (inputs.size() > 2)
{
return inputs.at(2);
}
else
{
return std::make_shared<default_opset::Constant>(
element::u8, Shape{1}, std::uint8_t(0));
}
}
void validate_zero_point_type(const Node& onnx_node,
const std::shared_ptr<ngraph::Node>& y_zero_point)
{
const auto& y_zero_point_et = y_zero_point->get_element_type();
CHECK_VALID_NODE(
onnx_node,
y_zero_point_et.is_static() &&
(y_zero_point_et == element::u8 || y_zero_point_et == element::i8),
"\"y_zero_point\" input data type must be static and of 8-bit "
"integer type.");
}
std::shared_ptr<ngraph::Node>
validate_scale(const Node& onnx_node,
const std::shared_ptr<ngraph::Node>& y_scale)
{
const auto& y_scale_et = y_scale->get_element_type();
CHECK_VALID_NODE(onnx_node,
y_scale_et.is_static(),
"\"y_scale\" input data type must be static.");
if (y_scale_et != element::f32)
{
return std::make_shared<default_opset::Convert>(y_scale, element::f32);
}
return y_scale;
}
std::shared_ptr<ngraph::Node> validate_data(const Node& onnx_node,
std::shared_ptr<ngraph::Node>& data)
{
const auto& data_et = data->get_element_type();
CHECK_VALID_NODE(onnx_node,
data_et.is_static(),
"\"x\" input data type must be static.");
if (data_et != element::f32)
{
return std::make_shared<default_opset::Convert>(data, element::f32);
}
return data;
}
std::tuple<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>
get_output_bands(const element::Type& destination_type,
const element::Type& data_type)
{
std::shared_ptr<ngraph::Node> output_low;
std::shared_ptr<ngraph::Node> output_high;
if (destination_type == element::i8)
{
output_low = std::make_shared<default_opset::Constant>(
data_type, Shape{1}, -128);
output_high =
std::make_shared<default_opset::Constant>(data_type, Shape{1}, 127);
}
else
{
output_low =
std::make_shared<default_opset::Constant>(data_type, Shape{1}, 0);
output_high =
std::make_shared<default_opset::Constant>(data_type, Shape{1}, 255);
}
return std::make_tuple(output_low, output_high);
}
std::tuple<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>
get_input_bands(const std::shared_ptr<ngraph::Node>& y_scale,
const std::shared_ptr<ngraph::Node>& y_zero_point,
const std::shared_ptr<ngraph::Node>& output_low,
const std::shared_ptr<ngraph::Node>& output_high,
const element::Type& data_type)
{
std::shared_ptr<ngraph::Node> input_low;
std::shared_ptr<ngraph::Node> input_high;
const auto& zero_point =
std::make_shared<default_opset::Convert>(y_zero_point, data_type);
input_low = std::make_shared<default_opset::Multiply>(
y_scale,
std::make_shared<default_opset::Subtract>(output_low, zero_point));
input_high = std::make_shared<default_opset::Multiply>(
y_scale,
std::make_shared<default_opset::Subtract>(output_high, zero_point));
return std::make_tuple(input_low, input_high);
}
std::shared_ptr<ngraph::Node>
make_fake_quantize(const std::shared_ptr<ngraph::Node>& y_scale,
const std::shared_ptr<ngraph::Node>& y_zero_point,
const std::shared_ptr<ngraph::Node>& data)
{
const element::Type& destination_type = y_zero_point->get_element_type();
const element::Type& data_type = data->get_element_type();
std::shared_ptr<ngraph::Node> output_low;
std::shared_ptr<ngraph::Node> output_high;
std::tie(output_low, output_high) =
detail::get_output_bands(destination_type, data_type);
std::shared_ptr<ngraph::Node> input_low;
std::shared_ptr<ngraph::Node> input_high;
std::tie(input_low, input_high) = detail::get_input_bands(
y_scale, y_zero_point, output_low, output_high, data_type);
const std::size_t levels = 1 << destination_type.bitwidth();
return std::make_shared<default_opset::Convert>(
std::make_shared<default_opset::FakeQuantize>(
data, input_low, input_high, output_low, output_high, levels),
destination_type);
}
}
}
namespace set_1
{
NodeVector quantize_linear(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
std::shared_ptr<ngraph::Node> x = inputs.at(0);
std::shared_ptr<ngraph::Node> y_scale = inputs.at(1);
std::shared_ptr<ngraph::Node> y_zero_point = inputs.at(2);
auto x = inputs.at(0);
auto y_scale = inputs.at(1);
auto y_zero_point = detail::get_zero_point(inputs);
// get axis twice with two default values to see if it is set
int64_t axis_0{node.get_attribute_value<int64_t>("axis", 0)};
int64_t axis_1{node.get_attribute_value<int64_t>("axis", 1)};
x = detail::validate_data(node, x);
detail::validate_zero_point_type(node, y_zero_point);
y_scale = detail::validate_scale(node, y_scale);
AxisSet axes;
return {detail::make_fake_quantize(y_scale, y_zero_point, x)};
}
} // namespace set_1
// if axis attribute is set
if (axis_0 == axis_1)
namespace set_13
{
NodeVector quantize_linear(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
auto x = inputs.at(0);
auto y_scale = inputs.at(1);
auto y_zero_point = detail::get_zero_point(inputs);
x = detail::validate_data(node, x);
detail::validate_zero_point_type(node, y_zero_point);
y_scale = detail::validate_scale(node, y_scale);
const auto& x_shape = x->get_output_partial_shape(0);
int64_t axis{node.get_attribute_value<int64_t>("axis", 1)};
axis = normalize_axis(node.get_description(), axis, x_shape.rank());
const auto& y_scale_shape = y_scale->get_output_partial_shape(0);
const auto& y_zero_point_shape = y_zero_point->get_output_partial_shape(0);
if (y_scale_shape.rank().is_static() &&
y_scale_shape.rank().get_length() == 1 && x_shape.rank().is_static() &&
x_shape[axis].is_static())
{
// positive axis
if (axis_0 >= 0)
{
axes.insert(axis_0);
}
// negative axis
else if (axis_0 < 0)
{
axes.insert(x->get_shape().size() + axis_0);
}
CHECK_VALID_NODE(
node,
y_scale_shape[0].same_scheme(x_shape[axis]),
"The number of quantization scale elements ",
y_scale_shape[0],
" must match the number of respective input data axis size: ",
x_shape[axis]);
Shape target_shape(x_shape.rank().get_length(), 1);
target_shape[axis] = static_cast<size_t>(x_shape[axis].get_length());
y_scale = builder::opset1::reshape(y_scale, target_shape);
}
Shape y_scale_shape = y_scale->get_shape();
Shape y_zero_point_shape = y_zero_point->get_shape();
if (y_zero_point_shape.rank().is_static() &&
y_zero_point_shape.rank().get_length() == 1 && x_shape.rank().is_static() &&
x_shape[axis].is_static())
{
CHECK_VALID_NODE(
node,
y_zero_point_shape[0].same_scheme(x_shape[axis]),
"The number of quantization zero point elements ",
y_zero_point_shape[0],
" must match the number of respective input data axis size: ",
x_shape[axis]);
return {std::make_shared<ngraph::opset0::Quantize>(
x,
y_scale,
y_zero_point,
y_zero_point->get_element_type(),
axes,
ngraph::opset0::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)};
Shape target_shape(x_shape.rank().get_length(), 1);
target_shape[axis] = static_cast<size_t>(x_shape[axis].get_length());
y_zero_point = builder::opset1::reshape(y_zero_point, target_shape);
}
return {detail::make_fake_quantize(y_scale, y_zero_point, x)};
}
} // namespace set_1
} // namespace set_13
} // namespace op

View File

@ -31,6 +31,12 @@ namespace ngraph
} // namespace set_1
namespace set_13
{
NodeVector quantize_linear(const Node& node);
} // namespace set_13
} // namespace op
} // namespace onnx_import

View File

@ -101,7 +101,7 @@
#include "op/prelu.hpp"
#include "op/qlinear_matmul.hpp"
// #include "op/quant_conv.hpp"
// #include "op/quantize_linear.hpp"
#include "op/quantize_linear.hpp"
#include "op/range.hpp"
#include "op/reciprocal.hpp"
#include "op/reduce.hpp"
@ -339,7 +339,8 @@ namespace ngraph
REGISTER_OPERATOR("PRelu", 1, prelu);
// REGISTER_OPERATOR("QLinearConv", 1, quant_conv);
REGISTER_OPERATOR("QLinearMatMul", 1, qlinear_matmul);
// REGISTER_OPERATOR("QuantizeLinear", 1, quantize_linear);
REGISTER_OPERATOR("QuantizeLinear", 1, quantize_linear);
REGISTER_OPERATOR("QuantizeLinear", 13, quantize_linear);
REGISTER_OPERATOR("Range", 1, range);
REGISTER_OPERATOR("Reciprocal", 1, reciprocal);
REGISTER_OPERATOR("ReduceLogSum", 1, reduce_log_sum);

View File

@ -149,7 +149,7 @@ NodeVector op::FakeQuantize::decompose_op() const
zero_point,
element::i32,
axes,
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY);
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
quantized_data = make_shared<op::Convert>(quantized_data, input_data_type);

View File

@ -1,4 +1,4 @@
ir_version: 3
ir_version: 6
producer_name: "ngraph ONNXImporter"
graph {
node {
@ -75,5 +75,5 @@ graph {
}
}
opset_import {
version: 10
version: 13
}

View File

@ -1,4 +1,4 @@
ir_version: 3
ir_version: 6
producer_name: "ngraph ONNXImporter"
graph {
node {
@ -75,5 +75,5 @@ graph {
}
}
opset_import {
version: 10
version: 13
}

View File

@ -0,0 +1,78 @@
ir_version: 3
producer_name: "ngraph ONNXImporter"
graph {
node {
input: "X"
input: "y_scale"
input: "y_zero_point"
output: "Y"
name: "QuantizeLinear"
op_type: "QuantizeLinear"
}
name: "test_graph"
initializer {
data_type: 2
name: "y_zero_point"
raw_data: "\000"
}
initializer {
name: "y_scale"
data_type: 1
float_data: 0.5
}
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "y_scale"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
input {
name: "y_zero_point"
type {
tensor_type {
elem_type: 2
shape {
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 10
}

View File

@ -1,4 +1,4 @@
ir_version: 3
ir_version: 6
producer_name: "ngraph ONNXImporter"
graph {
node {

View File

@ -44,6 +44,18 @@ using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
using Inputs = std::vector<std::vector<float>>;
using Outputs = std::vector<std::vector<float>>;
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_quantize_linear_const_scale_const_zero_p)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/quantize_linear_const.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input(std::vector<float>{32.25f, 48.34f, 50.f, 83.f});
test_case.add_expected_output(std::vector<std::uint8_t>{64, 97, 100, 166});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_quantize_linear)
{
auto function = onnx_import::import_onnx_model(

View File

@ -10,8 +10,10 @@
#
#-------------------------------------------------------------------------------
# Segmentation fault
onnx_model_quantize_linear_const_scale_const_zero_p
# Not supported ONNX op: QuantizeLinear
# Quantize layer input 'Multiply_7' doesn't have blobs
onnx_model_quantize_linear
onnx_model_quantize_linear_zero_point
onnx_model_quantize_linear_axis_zero

View File

@ -87,10 +87,13 @@ INTERPRETER.convolution_2d_1item_5o3i_data_dilated
INTERPRETER.convolution_2d_2item_5o3i_data_dilated
# Removed opset0 operations
INTERPRETER.onnx_model_quantize_linear
INTERPRETER.onnx_model_quantize_linear_zero_point
INTERPRETER.onnx_model_quantize_linear_axis_zero
INTERPRETER.onnx_model_quantize_linear_axis_negative
INTERPRETER.onnx_model_dequantize_linear
INTERPRETER.onnx_model_dequantize_linear_scalar_zero_scale_uint8
INTERPRETER.onnx_model_dequantize_linear_scalar_zero_scale_int8
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_int8
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_int8_4d
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis
INTERPRETER.onnx_model_quant_conv_linear_2d
INTERPRETER.onnx_model_quant_conv_linear_3d
INTERPRETER.onnx_model_conv_integer