ONNX DequantizeLinear op (#1123)

* DequantizeLinear (opset 10) implemented as a subgraph

* Enable DequantizeLinear from opset 13

* Exclude the failing tests

* Re-enable dequantize linear UTs

* Validation helper
Tomasz Dołbniak 2020-07-07 13:08:08 +02:00 committed by GitHub
parent 59579eb437
commit ae6cfe12bb
13 changed files with 213 additions and 63 deletions
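For orientation: instead of mapping DequantizeLinear onto a dedicated Dequantize op, the importer now builds a plain Convert -> Subtract -> Multiply subgraph, i.e. y = (float(x) - zero_point) * scale. A minimal standalone sketch of that arithmetic (the function below is an illustration, not importer code; a uint8 input and a per-tensor scale are assumptions):

    #include <cstdint>
    #include <vector>

    // Illustrative only: the arithmetic performed by the Convert -> Subtract -> Multiply
    // subgraph, y = (float(x) - zero_point) * scale, here for a per-tensor scale.
    std::vector<float> dequantize_per_tensor(const std::vector<std::uint8_t>& x,
                                             float scale,
                                             float zero_point)
    {
        std::vector<float> y;
        y.reserve(x.size());
        for (const auto q : x)
        {
            y.push_back((static_cast<float>(q) - zero_point) * scale);
        }
        return y;
    }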

View File

@@ -14,7 +14,7 @@
# limitations under the License.
# ******************************************************************************
set(ONNX_OPSET_VERSION 11 CACHE INTERNAL "Supported version of ONNX operator set")
set(ONNX_OPSET_VERSION 13 CACHE INTERNAL "Supported version of ONNX operator set")
add_library(onnx_importer SHARED
core/node.cpp
@@ -86,8 +86,8 @@ add_library(onnx_importer SHARED
op/cum_sum.hpp
op/depth_to_space.cpp
op/depth_to_space.hpp
# op/dequantize_linear.cpp
# op/dequantize_linear.hpp
op/dequantize_linear.cpp
op/dequantize_linear.hpp
op/div.hpp
op/dropout.hpp
op/elu.cpp

View File

@@ -23,9 +23,9 @@
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/opsets/opset0.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"
#include "utils/common.hpp"
namespace ngraph
{
@@ -33,54 +33,177 @@ namespace ngraph
{
namespace op
{
namespace
{
std::shared_ptr<ngraph::Node> get_zero_point(const NodeVector& inputs)
{
if (inputs.size() == 3 && !inputs[2]->is_null())
{
auto zero_point = inputs[2];
if (zero_point->get_element_type() != element::f32)
{
zero_point =
std::make_shared<default_opset::Convert>(zero_point, element::f32);
}
return zero_point;
}
else
{
return default_opset::Constant::create(element::f32, Shape{}, {0});
}
}
}
namespace set_1
{
NodeVector dequantize_linear(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
std::shared_ptr<ngraph::Node> x = inputs.at(0);
std::shared_ptr<ngraph::Node> x_scale = inputs.at(1);
std::shared_ptr<ngraph::Node> zero_point;
if (inputs.size() == 3 && !inputs.at(2)->is_null())
const NodeVector inputs{node.get_ng_inputs()};
NGRAPH_CHECK(
2 <= inputs.size() && inputs.size() <= 3,
"The DequantizeLinear op expects 2 required and one optional input. Got: ",
inputs.size());
const auto x = inputs[0];
const auto scale = inputs[1];
const auto zero_point = get_zero_point(inputs);
common::validate_scalar_input("Dequantization scale", scale, {element::f32});
common::validate_scalar_input("Zero point", zero_point);
const auto converted_x =
std::make_shared<default_opset::Convert>(x, element::f32);
return {std::make_shared<default_opset::Multiply>(
std::make_shared<default_opset::Subtract>(converted_x, zero_point), scale)};
}
}
namespace set_13
{
zero_point = inputs.at(2);
namespace
{
void validate_scale(const std::shared_ptr<ngraph::Node> scale,
const std::shared_ptr<ngraph::Node> x,
const int64_t axis)
{
const auto& scale_shape = scale->get_output_partial_shape(0);
NGRAPH_CHECK(scale_shape.rank().get_length() == 0 ||
scale_shape.rank().get_length() == 1,
"Dequantization scale needs to be a scalar or a vector.");
if (scale_shape.rank().get_length() == 1)
{
const auto& scale_dim = scale_shape[0];
const auto& x_shape = x->get_output_partial_shape(0);
const auto& x_dim_at_axis = x_shape[axis];
NGRAPH_CHECK(scale_dim.same_scheme(x_dim_at_axis),
"The number of dequantization scale elements '",
scale_dim,
"' must match the input shape dimension '",
x_dim_at_axis,
" pointed to by the axis attribute: ",
axis);
}
}
void validate_zero_point(const std::shared_ptr<ngraph::Node> zero_point,
const std::shared_ptr<ngraph::Node> x,
const int64_t axis)
{
const auto& zero_point_shape = zero_point->get_output_partial_shape(0);
NGRAPH_CHECK(zero_point_shape.rank().get_length() == 0 ||
zero_point_shape.rank().get_length() == 1,
"Zero point needs to be a scalar or a vector.");
if (zero_point_shape.rank().get_length() == 1)
{
const auto& zero_point_dim = zero_point_shape[0];
const auto& x_shape = x->get_output_partial_shape(0);
const auto& x_dim_at_axis = x_shape[axis];
NGRAPH_CHECK(zero_point_dim.same_scheme(x_dim_at_axis),
"The number of zero point elements '",
zero_point_dim,
"' must match the input shape dimension '",
x_dim_at_axis,
" pointed to by the axis attribute: ",
axis);
}
}
std::shared_ptr<ngraph::Node>
reshape_input(const std::shared_ptr<ngraph::Node> input,
const int64_t axis,
const PartialShape& x_shape)
{
std::vector<int64_t> target_dims;
for (size_t i = 0; i < axis; ++i)
{
target_dims.push_back(1);
}
// copy dimension at axis from input X
if (x_shape[axis].is_static())
{
target_dims.push_back(x_shape[axis].get_length());
}
else
{
zero_point =
ngraph::builder::make_constant(x->get_element_type(), Shape{}, 0);
target_dims.push_back(0);
}
Shape y_scale_shape = x_scale->get_shape();
Shape y_zero_point_shape = zero_point->get_shape();
// get axis twice with two default values to see if it is set
int64_t axis_0{node.get_attribute_value<int64_t>("axis", 0)};
int64_t axis_1{node.get_attribute_value<int64_t>("axis", 1)};
const auto data_rank = x->get_output_partial_shape(0).rank();
AxisSet axes;
// if axis attribute is set
if (axis_0 == axis_1)
for (size_t i = axis + 1; i < x_shape.rank().get_length(); ++i)
{
axes.insert(
ngraph::normalize_axis(node.get_description(), axis_0, data_rank));
target_dims.push_back(1);
}
if (x->get_element_type() != zero_point->get_element_type())
const auto target_shape = default_opset::Constant::create(
element::i64, Shape{target_dims.size()}, target_dims);
return std::make_shared<default_opset::Reshape>(input, target_shape, true);
}
}
NodeVector dequantize_linear(const Node& node)
{
zero_point = std::make_shared<default_opset::Convert>(
zero_point, x->get_element_type());
const NodeVector inputs{node.get_ng_inputs()};
NGRAPH_CHECK(2 <= inputs.size() && inputs.size() <= 3,
"The DequantizeLinear op expects 2 required and one optional "
"input. Got: ",
inputs.size());
const auto x = inputs[0];
auto scale = inputs[1];
auto zero_point = get_zero_point(inputs);
const auto x_shape = x->get_output_partial_shape(0);
NGRAPH_CHECK(x_shape.rank().is_static(),
"Rank of the input data tensor has to be known (static).");
int64_t axis{node.get_attribute_value<int64_t>("axis", 1)};
axis = ngraph::normalize_axis(node.get_description(), axis, x_shape.rank());
validate_scale(scale, x, axis);
validate_zero_point(zero_point, x, axis);
// these reshapes make sure that dequantization happens over the specified axis
scale = reshape_input(scale, axis, x_shape);
zero_point = reshape_input(zero_point, axis, x_shape);
const auto converted_x =
std::make_shared<default_opset::Convert>(x, element::f32);
return {std::make_shared<default_opset::Multiply>(
std::make_shared<default_opset::Subtract>(converted_x, zero_point), scale)};
}
}
}
}
return {std::make_shared<ngraph::opset0::Dequantize>(
x, x_scale, zero_point, x_scale->get_element_type(), axes)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
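The core of the new opset-13 path above is reshape_input: a 1-D scale or zero point of length d is reshaped to a shape with 1s in every position except the dequantization axis, so the subsequent Subtract and Multiply broadcast it along that axis. A standalone sketch of that target-shape computation, assuming a static input rank (the function name and plain types below are illustrative, not nGraph API):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Illustrative only: the reshape target computed for a 1-D scale/zero point so
    // that broadcasting applies it along 'axis' of an input with 'rank' dimensions.
    // E.g. rank = 4, axis = 1, dim_at_axis = 3  ->  {1, 3, 1, 1}.
    std::vector<std::int64_t> per_axis_target_shape(std::size_t rank,
                                                    std::size_t axis,
                                                    std::int64_t dim_at_axis)
    {
        std::vector<std::int64_t> target_dims(rank, 1);
        target_dims[axis] = dim_at_axis; // the diff pushes 0 here when the dimension is dynamic
        return target_dims;
    }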

View File

@@ -31,6 +31,11 @@ namespace ngraph
} // namespace set_1
namespace set_13
{
NodeVector dequantize_linear(const Node& node);
}
} // namespace op
} // namespace onnx_import

View File

@@ -48,7 +48,7 @@
#include "op/cosh.hpp"
#include "op/cum_sum.hpp"
#include "op/depth_to_space.hpp"
// #include "op/dequantize_linear.hpp"
#include "op/dequantize_linear.hpp"
#include "op/div.hpp"
#include "op/dropout.hpp"
#include "op/elu.hpp"
@@ -278,7 +278,8 @@ namespace ngraph
REGISTER_OPERATOR("Cosh", 1, cosh);
REGISTER_OPERATOR("CumSum", 1, cum_sum);
REGISTER_OPERATOR("DepthToSpace", 1, depth_to_space);
// REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear);
REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear);
REGISTER_OPERATOR("DequantizeLinear", 13, dequantize_linear);
REGISTER_OPERATOR("Div", 1, div);
REGISTER_OPERATOR("Div", 7, div);
REGISTER_OPERATOR("Dropout", 1, dropout);

View File

@@ -67,6 +67,26 @@ namespace ngraph
default_opset::Constant::create(element::i64, {}, {step}));
}
void validate_scalar_input(const char* input_name,
const std::shared_ptr<ngraph::Node> input,
const std::set<element::Type> allowed_types)
{
const auto validated_input_rank = input->get_output_partial_shape(0).rank();
NGRAPH_CHECK(
validated_input_rank.same_scheme({0}), input_name, " needs to be a scalar.");
if (!allowed_types.empty())
{
const bool data_type_ok = allowed_types.count(input->get_element_type());
NGRAPH_CHECK(data_type_ok,
"Incorrect data type of the ",
input_name,
" input: ",
input->get_element_type());
}
}
} // namespace common
} // namespace onnx_import
} // namespace ngraph

View File

@@ -127,6 +127,16 @@ namespace ngraph
return shifted_square_identity(Shape{n, n}, type, 0);
}
/// \brief Performs validation of an input that is expected to be a scalar.
/// \note This function throws an exception if any of the validation steps fails.
///
/// \param[in] input_name A human-readable name of an input (used for logging)
/// \param[in] input An input node to be validated
/// \param[in] allowed_types An optional set of allowed element types for this input
void validate_scalar_input(const char* input_name,
const std::shared_ptr<ngraph::Node> input,
const std::set<element::Type> allowed_types = {});
} // namespace common
} // namespace onnx_import
} // namespace ngraph
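This helper is the "Validation helper" from the commit message; dequantize_linear.cpp calls it as common::validate_scalar_input("Dequantization scale", scale, {element::f32}) and common::validate_scalar_input("Zero point", zero_point). Below is a standalone restatement of the contract it enforces (check_scalar and the string-based element types are illustrative, not nGraph API):

    #include <cstddef>
    #include <set>
    #include <stdexcept>
    #include <string>

    // Illustrative only: the contract enforced by validate_scalar_input, restated
    // over plain values instead of nGraph nodes. An empty 'allowed_types' set
    // means any element type is accepted, matching the default argument above.
    void check_scalar(const std::string& input_name,
                      std::size_t rank,
                      const std::string& element_type,
                      const std::set<std::string>& allowed_types = {})
    {
        if (rank != 0)
        {
            throw std::runtime_error(input_name + " needs to be a scalar.");
        }
        if (!allowed_types.empty() && allowed_types.count(element_type) == 0)
        {
            throw std::runtime_error("Incorrect data type of the " + input_name +
                                     " input: " + element_type);
        }
    }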

View File

@@ -75,5 +75,5 @@ graph {
}
}
opset_import {
version: 10
version: 13
}

View File

@@ -75,5 +75,5 @@ graph {
}
}
opset_import {
version: 10
version: 13
}

View File

@@ -87,5 +87,5 @@ graph {
}
}
opset_import {
version: 10
version: 13
}

View File

@@ -75,5 +75,5 @@ graph {
}
}
opset_import {
version: 10
version: 13
}

View File

@@ -156,7 +156,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_uint8)
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_2.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
auto test_case = ngraph::test::TestCase<TestEngine>(function);
test_case.add_input(std::vector<uint8_t>{0, 1, 2, 3, 0, 1, 2, 3, 0, 10, 20, 30}); // x
test_case.add_input(std::vector<float>{1.0f, 2.0f, 4.0f}); // scale
@@ -174,7 +174,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_int8)
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_3.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
auto test_case = ngraph::test::TestCase<TestEngine>(function);
test_case.add_input(std::vector<int8_t>{0, 1, 2, 3, 0, 2, 4, 6, 0, 10, 20, 30}); // x
test_case.add_input(std::vector<float>{1.0f, 2.0f, 4.0f, 8.0f}); // scale
@@ -192,7 +192,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_int8_4d)
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_4.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
auto test_case = ngraph::test::TestCase<TestEngine>(function);
test_case.add_input(std::vector<uint8_t>{7, 9, 10, 10, 5, 8, 9, 1, 8, 6, 7, 9, 10, 0, 7, 10, 8,
2, 6, 0, 5, 9, 8, 1, 2, 7, 5, 3, 2, 4, 1, 3, 8, 7,
@@ -216,7 +216,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis)
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_5.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
auto test_case = ngraph::test::TestCase<TestEngine>(function);
test_case.add_input(std::vector<uint8_t>{0, 1, 2, 3, 0, 1, 2, 3, 0, 10, 20, 30}); // x
test_case.add_input(std::vector<float>{1.0f, 2.0f, 4.0f}); // scale
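For intuition about what these per-axis tests check: treating the 12 input values above as 3 rows of 4 and applying the per-row scale {1, 2, 4} with a zero point of 0 gives (x - 0) * scale row by row. A standalone computation of that expectation (the 3x4 layout and the zero point are assumptions for illustration; the tests' actual expected outputs are in lines not shown in this hunk):

    #include <array>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        // Illustrative only: x viewed as 3 rows of 4, dequantized per row (axis 0).
        const std::array<std::array<std::uint8_t, 4>, 3> x{
            {{0, 1, 2, 3}, {0, 1, 2, 3}, {0, 10, 20, 30}}};
        const std::array<float, 3> scale{1.0f, 2.0f, 4.0f};
        const float zero_point = 0.0f; // assumed for this illustration

        for (std::size_t row = 0; row < x.size(); ++row)
        {
            for (std::size_t col = 0; col < x[row].size(); ++col)
            {
                std::printf("%g ", (static_cast<float>(x[row][col]) - zero_point) * scale[row]);
            }
            std::printf("\n"); // prints: 0 1 2 3 / 0 2 4 6 / 0 40 80 120
        }
        return 0;
    }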

View File

@@ -17,14 +17,12 @@ onnx_model_quantize_linear_zero_point
onnx_model_quantize_linear_axis_zero
onnx_model_quantize_linear_axis_negative
# Not supported ONNX op: DequantizeLinear
onnx_model_dequantize_linear
onnx_model_dequantize_linear_scalar_zero_scale_uint8
onnx_model_dequantize_linear_scalar_zero_scale_int8
onnx_model_dequantize_linear_1d_zero_scale_uint8
onnx_model_dequantize_linear_1d_zero_scale_int8
onnx_model_dequantize_linear_1d_zero_scale_int8_4d
onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis
# DequantizeLinear:
# C++ exception with description "Unsupported precisions!
IE_CPU.onnx_model_dequantize_linear_scalar_zero_scale_int8
IE_CPU.onnx_model_dequantize_linear_1d_zero_scale_int8
# C++ exception with description "Input data precision not supported. Expected float.
IE_CPU.onnx_model_dequantize_linear_1d_zero_scale_int8_4d
# Not supported ONNX op: QLinearConv
onnx_model_quant_conv_linear

View File

@@ -91,13 +91,6 @@ INTERPRETER.onnx_model_quantize_linear
INTERPRETER.onnx_model_quantize_linear_zero_point
INTERPRETER.onnx_model_quantize_linear_axis_zero
INTERPRETER.onnx_model_quantize_linear_axis_negative
INTERPRETER.onnx_model_dequantize_linear
INTERPRETER.onnx_model_dequantize_linear_scalar_zero_scale_uint8
INTERPRETER.onnx_model_dequantize_linear_scalar_zero_scale_int8
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_int8
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_int8_4d
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis
INTERPRETER.onnx_model_quant_conv_linear_2d
INTERPRETER.onnx_model_quant_conv_linear_3d
INTERPRETER.onnx_model_conv_integer