[ONNX] Quantize linear using FakeQuantize (#1169)
This commit is contained in:
parent
b16c8faceb
commit
ed4bbb3a0a
@ -173,8 +173,8 @@ add_library(onnx_importer SHARED
|
||||
op/qlinear_matmul.hpp
|
||||
# op/quant_conv.cpp
|
||||
# op/quant_conv.hpp
|
||||
# op/quantize_linear.cpp
|
||||
# op/quantize_linear.hpp
|
||||
op/quantize_linear.cpp
|
||||
op/quantize_linear.hpp
|
||||
op/range.cpp
|
||||
op/range.hpp
|
||||
op/reciprocal.cpp
|
||||
|
@ -16,11 +16,18 @@
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <tuple>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "exceptions.hpp"
|
||||
#include "ngraph/axis_set.hpp"
|
||||
#include "ngraph/opsets/opset0.hpp"
|
||||
#include "ngraph/builder/reshape.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "quantize_linear.hpp"
|
||||
#include "utils/reshape.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -28,49 +35,218 @@ namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace detail
|
||||
{
|
||||
namespace
|
||||
{
|
||||
std::shared_ptr<ngraph::Node> get_zero_point(const NodeVector& inputs)
|
||||
{
|
||||
if (inputs.size() > 2)
|
||||
{
|
||||
return inputs.at(2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::make_shared<default_opset::Constant>(
|
||||
element::u8, Shape{1}, std::uint8_t(0));
|
||||
}
|
||||
}
|
||||
|
||||
void validate_zero_point_type(const Node& onnx_node,
|
||||
const std::shared_ptr<ngraph::Node>& y_zero_point)
|
||||
{
|
||||
const auto& y_zero_point_et = y_zero_point->get_element_type();
|
||||
CHECK_VALID_NODE(
|
||||
onnx_node,
|
||||
y_zero_point_et.is_static() &&
|
||||
(y_zero_point_et == element::u8 || y_zero_point_et == element::i8),
|
||||
"\"y_zero_point\" input data type must be static and of 8-bit "
|
||||
"integer type.");
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node>
|
||||
validate_scale(const Node& onnx_node,
|
||||
const std::shared_ptr<ngraph::Node>& y_scale)
|
||||
{
|
||||
const auto& y_scale_et = y_scale->get_element_type();
|
||||
CHECK_VALID_NODE(onnx_node,
|
||||
y_scale_et.is_static(),
|
||||
"\"y_scale\" input data type must be static.");
|
||||
if (y_scale_et != element::f32)
|
||||
{
|
||||
return std::make_shared<default_opset::Convert>(y_scale, element::f32);
|
||||
}
|
||||
return y_scale;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> validate_data(const Node& onnx_node,
|
||||
std::shared_ptr<ngraph::Node>& data)
|
||||
{
|
||||
const auto& data_et = data->get_element_type();
|
||||
CHECK_VALID_NODE(onnx_node,
|
||||
data_et.is_static(),
|
||||
"\"x\" input data type must be static.");
|
||||
|
||||
if (data_et != element::f32)
|
||||
{
|
||||
return std::make_shared<default_opset::Convert>(data, element::f32);
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
std::tuple<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>
|
||||
get_output_bands(const element::Type& destination_type,
|
||||
const element::Type& data_type)
|
||||
{
|
||||
std::shared_ptr<ngraph::Node> output_low;
|
||||
std::shared_ptr<ngraph::Node> output_high;
|
||||
|
||||
if (destination_type == element::i8)
|
||||
{
|
||||
output_low = std::make_shared<default_opset::Constant>(
|
||||
data_type, Shape{1}, -128);
|
||||
output_high =
|
||||
std::make_shared<default_opset::Constant>(data_type, Shape{1}, 127);
|
||||
}
|
||||
else
|
||||
{
|
||||
output_low =
|
||||
std::make_shared<default_opset::Constant>(data_type, Shape{1}, 0);
|
||||
output_high =
|
||||
std::make_shared<default_opset::Constant>(data_type, Shape{1}, 255);
|
||||
}
|
||||
|
||||
return std::make_tuple(output_low, output_high);
|
||||
}
|
||||
|
||||
std::tuple<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>
|
||||
get_input_bands(const std::shared_ptr<ngraph::Node>& y_scale,
|
||||
const std::shared_ptr<ngraph::Node>& y_zero_point,
|
||||
const std::shared_ptr<ngraph::Node>& output_low,
|
||||
const std::shared_ptr<ngraph::Node>& output_high,
|
||||
const element::Type& data_type)
|
||||
{
|
||||
std::shared_ptr<ngraph::Node> input_low;
|
||||
std::shared_ptr<ngraph::Node> input_high;
|
||||
const auto& zero_point =
|
||||
std::make_shared<default_opset::Convert>(y_zero_point, data_type);
|
||||
|
||||
input_low = std::make_shared<default_opset::Multiply>(
|
||||
y_scale,
|
||||
std::make_shared<default_opset::Subtract>(output_low, zero_point));
|
||||
input_high = std::make_shared<default_opset::Multiply>(
|
||||
y_scale,
|
||||
std::make_shared<default_opset::Subtract>(output_high, zero_point));
|
||||
|
||||
return std::make_tuple(input_low, input_high);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node>
|
||||
make_fake_quantize(const std::shared_ptr<ngraph::Node>& y_scale,
|
||||
const std::shared_ptr<ngraph::Node>& y_zero_point,
|
||||
const std::shared_ptr<ngraph::Node>& data)
|
||||
{
|
||||
const element::Type& destination_type = y_zero_point->get_element_type();
|
||||
const element::Type& data_type = data->get_element_type();
|
||||
|
||||
std::shared_ptr<ngraph::Node> output_low;
|
||||
std::shared_ptr<ngraph::Node> output_high;
|
||||
std::tie(output_low, output_high) =
|
||||
detail::get_output_bands(destination_type, data_type);
|
||||
|
||||
std::shared_ptr<ngraph::Node> input_low;
|
||||
std::shared_ptr<ngraph::Node> input_high;
|
||||
std::tie(input_low, input_high) = detail::get_input_bands(
|
||||
y_scale, y_zero_point, output_low, output_high, data_type);
|
||||
|
||||
const std::size_t levels = 1 << destination_type.bitwidth();
|
||||
|
||||
return std::make_shared<default_opset::Convert>(
|
||||
std::make_shared<default_opset::FakeQuantize>(
|
||||
data, input_low, input_high, output_low, output_high, levels),
|
||||
destination_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace set_1
|
||||
{
|
||||
NodeVector quantize_linear(const Node& node)
|
||||
{
|
||||
NodeVector inputs{node.get_ng_inputs()};
|
||||
std::shared_ptr<ngraph::Node> x = inputs.at(0);
|
||||
std::shared_ptr<ngraph::Node> y_scale = inputs.at(1);
|
||||
std::shared_ptr<ngraph::Node> y_zero_point = inputs.at(2);
|
||||
auto x = inputs.at(0);
|
||||
auto y_scale = inputs.at(1);
|
||||
auto y_zero_point = detail::get_zero_point(inputs);
|
||||
|
||||
// get axis twice with two default values to see if it is set
|
||||
int64_t axis_0{node.get_attribute_value<int64_t>("axis", 0)};
|
||||
int64_t axis_1{node.get_attribute_value<int64_t>("axis", 1)};
|
||||
x = detail::validate_data(node, x);
|
||||
detail::validate_zero_point_type(node, y_zero_point);
|
||||
y_scale = detail::validate_scale(node, y_scale);
|
||||
|
||||
AxisSet axes;
|
||||
return {detail::make_fake_quantize(y_scale, y_zero_point, x)};
|
||||
}
|
||||
} // namespace set_1
|
||||
|
||||
// if axis attribute is set
|
||||
if (axis_0 == axis_1)
|
||||
namespace set_13
|
||||
{
|
||||
NodeVector quantize_linear(const Node& node)
|
||||
{
|
||||
NodeVector inputs{node.get_ng_inputs()};
|
||||
auto x = inputs.at(0);
|
||||
auto y_scale = inputs.at(1);
|
||||
auto y_zero_point = detail::get_zero_point(inputs);
|
||||
|
||||
x = detail::validate_data(node, x);
|
||||
detail::validate_zero_point_type(node, y_zero_point);
|
||||
y_scale = detail::validate_scale(node, y_scale);
|
||||
|
||||
const auto& x_shape = x->get_output_partial_shape(0);
|
||||
|
||||
int64_t axis{node.get_attribute_value<int64_t>("axis", 1)};
|
||||
axis = normalize_axis(node.get_description(), axis, x_shape.rank());
|
||||
|
||||
const auto& y_scale_shape = y_scale->get_output_partial_shape(0);
|
||||
const auto& y_zero_point_shape = y_zero_point->get_output_partial_shape(0);
|
||||
|
||||
if (y_scale_shape.rank().is_static() &&
|
||||
y_scale_shape.rank().get_length() == 1 && x_shape.rank().is_static() &&
|
||||
x_shape[axis].is_static())
|
||||
{
|
||||
// positive axis
|
||||
if (axis_0 >= 0)
|
||||
{
|
||||
axes.insert(axis_0);
|
||||
}
|
||||
// negative axis
|
||||
else if (axis_0 < 0)
|
||||
{
|
||||
axes.insert(x->get_shape().size() + axis_0);
|
||||
}
|
||||
CHECK_VALID_NODE(
|
||||
node,
|
||||
y_scale_shape[0].same_scheme(x_shape[axis]),
|
||||
"The number of quantization scale elements ",
|
||||
y_scale_shape[0],
|
||||
" must match the number of respective input data axis size: ",
|
||||
x_shape[axis]);
|
||||
|
||||
Shape target_shape(x_shape.rank().get_length(), 1);
|
||||
target_shape[axis] = static_cast<size_t>(x_shape[axis].get_length());
|
||||
|
||||
y_scale = builder::opset1::reshape(y_scale, target_shape);
|
||||
}
|
||||
|
||||
Shape y_scale_shape = y_scale->get_shape();
|
||||
Shape y_zero_point_shape = y_zero_point->get_shape();
|
||||
if (y_zero_point_shape.rank().is_static() &&
|
||||
y_zero_point_shape.rank().get_length() == 1 && x_shape.rank().is_static() &&
|
||||
x_shape[axis].is_static())
|
||||
{
|
||||
CHECK_VALID_NODE(
|
||||
node,
|
||||
y_zero_point_shape[0].same_scheme(x_shape[axis]),
|
||||
"The number of quantization zero point elements ",
|
||||
y_zero_point_shape[0],
|
||||
" must match the number of respective input data axis size: ",
|
||||
x_shape[axis]);
|
||||
|
||||
return {std::make_shared<ngraph::opset0::Quantize>(
|
||||
x,
|
||||
y_scale,
|
||||
y_zero_point,
|
||||
y_zero_point->get_element_type(),
|
||||
axes,
|
||||
ngraph::opset0::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN)};
|
||||
Shape target_shape(x_shape.rank().get_length(), 1);
|
||||
target_shape[axis] = static_cast<size_t>(x_shape[axis].get_length());
|
||||
|
||||
y_zero_point = builder::opset1::reshape(y_zero_point, target_shape);
|
||||
}
|
||||
|
||||
return {detail::make_fake_quantize(y_scale, y_zero_point, x)};
|
||||
}
|
||||
|
||||
} // namespace set_1
|
||||
} // namespace set_13
|
||||
|
||||
} // namespace op
|
||||
|
||||
|
@ -31,6 +31,12 @@ namespace ngraph
|
||||
|
||||
} // namespace set_1
|
||||
|
||||
namespace set_13
|
||||
{
|
||||
NodeVector quantize_linear(const Node& node);
|
||||
|
||||
} // namespace set_13
|
||||
|
||||
} // namespace op
|
||||
|
||||
} // namespace onnx_import
|
||||
|
@ -101,7 +101,7 @@
|
||||
#include "op/prelu.hpp"
|
||||
#include "op/qlinear_matmul.hpp"
|
||||
// #include "op/quant_conv.hpp"
|
||||
// #include "op/quantize_linear.hpp"
|
||||
#include "op/quantize_linear.hpp"
|
||||
#include "op/range.hpp"
|
||||
#include "op/reciprocal.hpp"
|
||||
#include "op/reduce.hpp"
|
||||
@ -339,7 +339,8 @@ namespace ngraph
|
||||
REGISTER_OPERATOR("PRelu", 1, prelu);
|
||||
// REGISTER_OPERATOR("QLinearConv", 1, quant_conv);
|
||||
REGISTER_OPERATOR("QLinearMatMul", 1, qlinear_matmul);
|
||||
// REGISTER_OPERATOR("QuantizeLinear", 1, quantize_linear);
|
||||
REGISTER_OPERATOR("QuantizeLinear", 1, quantize_linear);
|
||||
REGISTER_OPERATOR("QuantizeLinear", 13, quantize_linear);
|
||||
REGISTER_OPERATOR("Range", 1, range);
|
||||
REGISTER_OPERATOR("Reciprocal", 1, reciprocal);
|
||||
REGISTER_OPERATOR("ReduceLogSum", 1, reduce_log_sum);
|
||||
|
@ -149,7 +149,7 @@ NodeVector op::FakeQuantize::decompose_op() const
|
||||
zero_point,
|
||||
element::i32,
|
||||
axes,
|
||||
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_INFINITY);
|
||||
op::Quantize::RoundMode::ROUND_NEAREST_TOWARD_EVEN);
|
||||
|
||||
quantized_data = make_shared<op::Convert>(quantized_data, input_data_type);
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
ir_version: 3
|
||||
ir_version: 6
|
||||
producer_name: "ngraph ONNXImporter"
|
||||
graph {
|
||||
node {
|
||||
@ -75,5 +75,5 @@ graph {
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 10
|
||||
version: 13
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
ir_version: 3
|
||||
ir_version: 6
|
||||
producer_name: "ngraph ONNXImporter"
|
||||
graph {
|
||||
node {
|
||||
@ -75,5 +75,5 @@ graph {
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 10
|
||||
version: 13
|
||||
}
|
||||
|
78
ngraph/test/models/onnx/quantize_linear_const.prototxt
Normal file
78
ngraph/test/models/onnx/quantize_linear_const.prototxt
Normal file
@ -0,0 +1,78 @@
|
||||
ir_version: 3
|
||||
producer_name: "ngraph ONNXImporter"
|
||||
graph {
|
||||
node {
|
||||
input: "X"
|
||||
input: "y_scale"
|
||||
input: "y_zero_point"
|
||||
output: "Y"
|
||||
name: "QuantizeLinear"
|
||||
op_type: "QuantizeLinear"
|
||||
}
|
||||
name: "test_graph"
|
||||
initializer {
|
||||
data_type: 2
|
||||
name: "y_zero_point"
|
||||
raw_data: "\000"
|
||||
}
|
||||
initializer {
|
||||
name: "y_scale"
|
||||
data_type: 1
|
||||
float_data: 0.5
|
||||
}
|
||||
input {
|
||||
name: "X"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y_scale"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y_zero_point"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 10
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
ir_version: 3
|
||||
ir_version: 6
|
||||
producer_name: "ngraph ONNXImporter"
|
||||
graph {
|
||||
node {
|
||||
|
@ -44,6 +44,18 @@ using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
|
||||
using Inputs = std::vector<std::vector<float>>;
|
||||
using Outputs = std::vector<std::vector<float>>;
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_quantize_linear_const_scale_const_zero_p)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/quantize_linear_const.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input(std::vector<float>{32.25f, 48.34f, 50.f, 83.f});
|
||||
|
||||
test_case.add_expected_output(std::vector<std::uint8_t>{64, 97, 100, 166});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_quantize_linear)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
|
@ -10,8 +10,10 @@
|
||||
#
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
# Segmentation fault
|
||||
onnx_model_quantize_linear_const_scale_const_zero_p
|
||||
|
||||
# Not supported ONNX op: QuantizeLinear
|
||||
# Quantize layer input 'Multiply_7' doesn't have blobs
|
||||
onnx_model_quantize_linear
|
||||
onnx_model_quantize_linear_zero_point
|
||||
onnx_model_quantize_linear_axis_zero
|
||||
|
@ -87,10 +87,13 @@ INTERPRETER.convolution_2d_1item_5o3i_data_dilated
|
||||
INTERPRETER.convolution_2d_2item_5o3i_data_dilated
|
||||
|
||||
# Removed opset0 operations
|
||||
INTERPRETER.onnx_model_quantize_linear
|
||||
INTERPRETER.onnx_model_quantize_linear_zero_point
|
||||
INTERPRETER.onnx_model_quantize_linear_axis_zero
|
||||
INTERPRETER.onnx_model_quantize_linear_axis_negative
|
||||
INTERPRETER.onnx_model_dequantize_linear
|
||||
INTERPRETER.onnx_model_dequantize_linear_scalar_zero_scale_uint8
|
||||
INTERPRETER.onnx_model_dequantize_linear_scalar_zero_scale_int8
|
||||
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8
|
||||
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_int8
|
||||
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_int8_4d
|
||||
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis
|
||||
INTERPRETER.onnx_model_quant_conv_linear_2d
|
||||
INTERPRETER.onnx_model_quant_conv_linear_3d
|
||||
INTERPRETER.onnx_model_conv_integer
|
||||
|
Loading…
Reference in New Issue
Block a user