ONNX DequantizeLinear op (#1123)
* DequantizeLinear 10 as a subgraph * Enable DequantizeLinear from opset 13 * Exclude the failing tests * Re-enable dequantize linear UTs * Validation helper
This commit is contained in:
parent
59579eb437
commit
ae6cfe12bb
@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
# ******************************************************************************
|
||||
|
||||
set(ONNX_OPSET_VERSION 11 CACHE INTERNAL "Supported version of ONNX operator set")
|
||||
set(ONNX_OPSET_VERSION 13 CACHE INTERNAL "Supported version of ONNX operator set")
|
||||
|
||||
add_library(onnx_importer SHARED
|
||||
core/node.cpp
|
||||
@ -86,8 +86,8 @@ add_library(onnx_importer SHARED
|
||||
op/cum_sum.hpp
|
||||
op/depth_to_space.cpp
|
||||
op/depth_to_space.hpp
|
||||
# op/dequantize_linear.cpp
|
||||
# op/dequantize_linear.hpp
|
||||
op/dequantize_linear.cpp
|
||||
op/dequantize_linear.hpp
|
||||
op/div.hpp
|
||||
op/dropout.hpp
|
||||
op/elu.cpp
|
||||
|
@ -23,9 +23,9 @@
|
||||
#include "ngraph/builder/make_constant.hpp"
|
||||
#include "ngraph/op/convert.hpp"
|
||||
#include "ngraph/op/dequantize.hpp"
|
||||
#include "ngraph/opsets/opset0.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "utils/common.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -33,54 +33,177 @@ namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace
|
||||
{
|
||||
std::shared_ptr<ngraph::Node> get_zero_point(const NodeVector& inputs)
|
||||
{
|
||||
if (inputs.size() == 3 && !inputs[2]->is_null())
|
||||
{
|
||||
auto zero_point = inputs[2];
|
||||
|
||||
if (zero_point->get_element_type() != element::f32)
|
||||
{
|
||||
zero_point =
|
||||
std::make_shared<default_opset::Convert>(zero_point, element::f32);
|
||||
}
|
||||
|
||||
return zero_point;
|
||||
}
|
||||
else
|
||||
{
|
||||
return default_opset::Constant::create(element::f32, Shape{}, {0});
|
||||
}
|
||||
}
|
||||
}
|
||||
namespace set_1
|
||||
{
|
||||
/// \brief Imports DequantizeLinear (registered for opset 1) as the subgraph
///        (convert(x, f32) - zero_point) * scale.
///
/// \param[in] node  The ONNX node; it must have 2 or 3 inputs.
///
/// \return A single-element NodeVector holding the Multiply node.
NodeVector dequantize_linear(const Node& node)
{
    const NodeVector inputs{node.get_ng_inputs()};

    NGRAPH_CHECK(
        2 <= inputs.size() && inputs.size() <= 3,
        "The DequantizeLinear op expects 2 required and one optional input. Got: ",
        inputs.size());

    const auto& data = inputs[0];
    const auto& scale = inputs[1];
    const auto zero_point = get_zero_point(inputs);

    // in this version of the operator both the scale and the zero point are scalars
    common::validate_scalar_input("Dequantization scale", scale, {element::f32});
    common::validate_scalar_input("Zero point", zero_point);

    const auto data_as_f32 = std::make_shared<default_opset::Convert>(data, element::f32);
    const auto shifted_data =
        std::make_shared<default_opset::Subtract>(data_as_f32, zero_point);

    return {std::make_shared<default_opset::Multiply>(shifted_data, scale)};
}
|
||||
}
|
||||
|
||||
namespace set_13
|
||||
{
|
||||
namespace
|
||||
{
|
||||
/// \brief Validates the scale input of DequantizeLinear against the data input.
/// \note Any failed validation step is reported through NGRAPH_CHECK.
///
/// \param[in] scale  The scale input; it has to be a scalar or a vector.
/// \param[in] x      The data input. Its rank is indexed with `axis`, so it is
///                   expected to be static (the opset 13 entry point in this file
///                   checks that before calling this helper).
/// \param[in] axis   The index of the dimension of x the scale applies to; expected
///                   to be already normalized by the caller.
void validate_scale(const std::shared_ptr<ngraph::Node> scale,
                    const std::shared_ptr<ngraph::Node> x,
                    const int64_t axis)
{
    const auto& scale_shape = scale->get_output_partial_shape(0);
    const auto scale_rank = scale_shape.rank();

    // The length may only be queried for a static rank, so a dynamic rank is
    // rejected by the same check instead of failing inside get_length().
    const bool scale_is_scalar_or_vector =
        scale_rank.is_static() &&
        (scale_rank.get_length() == 0 || scale_rank.get_length() == 1);
    NGRAPH_CHECK(scale_is_scalar_or_vector,
                 "Dequantization scale needs to be a scalar or a vector.");

    if (scale_rank.get_length() == 1)
    {
        const auto& scale_dim = scale_shape[0];
        const auto& x_shape = x->get_output_partial_shape(0);
        const auto& x_dim_at_axis = x_shape[axis];

        NGRAPH_CHECK(scale_dim.same_scheme(x_dim_at_axis),
                     "The number of dequantization scale elements '",
                     scale_dim,
                     "' must match the input shape dimension '",
                     x_dim_at_axis,
                     "' pointed to by the axis attribute: ",
                     axis);
    }
}
|
||||
|
||||
Shape y_scale_shape = x_scale->get_shape();
|
||||
Shape y_zero_point_shape = zero_point->get_shape();
|
||||
|
||||
// get axis twice with two default values to see if it is set
|
||||
int64_t axis_0{node.get_attribute_value<int64_t>("axis", 0)};
|
||||
int64_t axis_1{node.get_attribute_value<int64_t>("axis", 1)};
|
||||
|
||||
const auto data_rank = x->get_output_partial_shape(0).rank();
|
||||
AxisSet axes;
|
||||
// if axis attribute is set
|
||||
if (axis_0 == axis_1)
|
||||
/// \brief Validates the zero point input of DequantizeLinear against the data input.
/// \note Any failed validation step is reported through NGRAPH_CHECK.
///
/// \param[in] zero_point  The zero point; it has to be a scalar or a vector.
/// \param[in] x           The data input. Its rank is indexed with `axis`, so it is
///                        expected to be static (the opset 13 entry point in this
///                        file checks that before calling this helper).
/// \param[in] axis        The index of the dimension of x the zero point applies to;
///                        expected to be already normalized by the caller.
void validate_zero_point(const std::shared_ptr<ngraph::Node> zero_point,
                         const std::shared_ptr<ngraph::Node> x,
                         const int64_t axis)
{
    const auto& zero_point_shape = zero_point->get_output_partial_shape(0);
    const auto zero_point_rank = zero_point_shape.rank();

    // The length may only be queried for a static rank, so a dynamic rank is
    // rejected by the same check instead of failing inside get_length().
    const bool zero_point_is_scalar_or_vector =
        zero_point_rank.is_static() &&
        (zero_point_rank.get_length() == 0 || zero_point_rank.get_length() == 1);
    NGRAPH_CHECK(zero_point_is_scalar_or_vector,
                 "Zero point needs to be a scalar or a vector.");

    if (zero_point_rank.get_length() == 1)
    {
        const auto& zero_point_dim = zero_point_shape[0];
        const auto& x_shape = x->get_output_partial_shape(0);
        const auto& x_dim_at_axis = x_shape[axis];

        NGRAPH_CHECK(zero_point_dim.same_scheme(x_dim_at_axis),
                     "The number of zero point elements '",
                     zero_point_dim,
                     "' must match the input shape dimension '",
                     x_dim_at_axis,
                     "' pointed to by the axis attribute: ",
                     axis);
    }
}
|
||||
|
||||
if (x->get_element_type() != zero_point->get_element_type())
|
||||
std::shared_ptr<ngraph::Node>
|
||||
reshape_input(const std::shared_ptr<ngraph::Node> input,
|
||||
const int64_t axis,
|
||||
const PartialShape& x_shape)
|
||||
{
|
||||
zero_point = std::make_shared<default_opset::Convert>(
|
||||
zero_point, x->get_element_type());
|
||||
}
|
||||
std::vector<int64_t> target_dims;
|
||||
|
||||
return {std::make_shared<ngraph::opset0::Dequantize>(
|
||||
x, x_scale, zero_point, x_scale->get_element_type(), axes)};
|
||||
for (size_t i = 0; i < axis; ++i)
|
||||
{
|
||||
target_dims.push_back(1);
|
||||
}
|
||||
|
||||
// copy dimension at axis from input X
|
||||
if (x_shape[axis].is_static())
|
||||
{
|
||||
target_dims.push_back(x_shape[axis].get_length());
|
||||
}
|
||||
else
|
||||
{
|
||||
target_dims.push_back(0);
|
||||
}
|
||||
|
||||
for (size_t i = axis + 1; i < x_shape.rank().get_length(); ++i)
|
||||
{
|
||||
target_dims.push_back(1);
|
||||
}
|
||||
|
||||
const auto target_shape = default_opset::Constant::create(
|
||||
element::i64, Shape{target_dims.size()}, target_dims);
|
||||
|
||||
return std::make_shared<default_opset::Reshape>(input, target_shape, true);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace set_1
|
||||
/// \brief Imports DequantizeLinear (registered for opset 13) as the subgraph
///        (convert(x, f32) - zero_point) * scale, where the scale and the zero
///        point are first reshaped to apply along the "axis" attribute.
///
/// \param[in] node  The ONNX node; it must have 2 or 3 inputs and its data input
///                  must have a static rank. The "axis" attribute defaults to 1.
///
/// \return A single-element NodeVector holding the Multiply node.
NodeVector dequantize_linear(const Node& node)
{
    const NodeVector inputs{node.get_ng_inputs()};

    NGRAPH_CHECK(2 <= inputs.size() && inputs.size() <= 3,
                 "The DequantizeLinear op expects 2 required and one optional "
                 "input. Got: ",
                 inputs.size());

    const auto& x = inputs[0];
    const auto& raw_scale = inputs[1];
    const auto raw_zero_point = get_zero_point(inputs);

    const auto x_shape = x->get_output_partial_shape(0);

    NGRAPH_CHECK(x_shape.rank().is_static(),
                 "Rank of the input data tensor has to be known (static).");

    const int64_t requested_axis{node.get_attribute_value<int64_t>("axis", 1)};
    const int64_t axis =
        ngraph::normalize_axis(node.get_description(), requested_axis, x_shape.rank());

    validate_scale(raw_scale, x, axis);
    validate_zero_point(raw_zero_point, x, axis);

    // these reshapes make sure that dequantization happens over the specified axis
    const auto scale = reshape_input(raw_scale, axis, x_shape);
    const auto zero_point = reshape_input(raw_zero_point, axis, x_shape);

    const auto x_as_f32 = std::make_shared<default_opset::Convert>(x, element::f32);
    const auto shifted_x = std::make_shared<default_opset::Subtract>(x_as_f32, zero_point);

    return {std::make_shared<default_opset::Multiply>(shifted_x, scale)};
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -31,6 +31,11 @@ namespace ngraph
|
||||
|
||||
} // namespace set_1
|
||||
|
||||
namespace set_13
|
||||
{
|
||||
NodeVector dequantize_linear(const Node& node);
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
|
||||
} // namespace onnx_import
|
||||
|
@ -48,7 +48,7 @@
|
||||
#include "op/cosh.hpp"
|
||||
#include "op/cum_sum.hpp"
|
||||
#include "op/depth_to_space.hpp"
|
||||
// #include "op/dequantize_linear.hpp"
|
||||
#include "op/dequantize_linear.hpp"
|
||||
#include "op/div.hpp"
|
||||
#include "op/dropout.hpp"
|
||||
#include "op/elu.hpp"
|
||||
@ -278,7 +278,8 @@ namespace ngraph
|
||||
REGISTER_OPERATOR("Cosh", 1, cosh);
|
||||
REGISTER_OPERATOR("CumSum", 1, cum_sum);
|
||||
REGISTER_OPERATOR("DepthToSpace", 1, depth_to_space);
|
||||
// REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear);
|
||||
REGISTER_OPERATOR("DequantizeLinear", 1, dequantize_linear);
|
||||
REGISTER_OPERATOR("DequantizeLinear", 13, dequantize_linear);
|
||||
REGISTER_OPERATOR("Div", 1, div);
|
||||
REGISTER_OPERATOR("Div", 7, div);
|
||||
REGISTER_OPERATOR("Dropout", 1, dropout);
|
||||
|
@ -67,6 +67,26 @@ namespace ngraph
|
||||
default_opset::Constant::create(element::i64, {}, {step}));
|
||||
}
|
||||
|
||||
void validate_scalar_input(const char* input_name,
|
||||
const std::shared_ptr<ngraph::Node> input,
|
||||
const std::set<element::Type> allowed_types)
|
||||
{
|
||||
const auto validated_input_rank = input->get_output_partial_shape(0).rank();
|
||||
|
||||
NGRAPH_CHECK(
|
||||
validated_input_rank.same_scheme({0}), input_name, " needs to be a scalar.");
|
||||
|
||||
if (!allowed_types.empty())
|
||||
{
|
||||
const bool data_type_ok = allowed_types.count(input->get_element_type());
|
||||
NGRAPH_CHECK(data_type_ok,
|
||||
"Incorrect data type of the ",
|
||||
input_name,
|
||||
" input: ",
|
||||
input->get_element_type());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
|
@ -127,6 +127,16 @@ namespace ngraph
|
||||
return shifted_square_identity(Shape{n, n}, type, 0);
|
||||
}
|
||||
|
||||
/// \brief Performs validation of an input that is expected to be a scalar.
|
||||
/// \note This function throws an exception if any of the validation steps fails.
|
||||
///
|
||||
/// \param[in] input_name A human-readable name of an input (used for logging)
|
||||
/// \param[in] input An input node to be validated
|
||||
/// \param[in] allowed_types An optional set of allowed element types for this input
|
||||
void validate_scalar_input(const char* input_name,
|
||||
const std::shared_ptr<ngraph::Node> input,
|
||||
const std::set<element::Type> allowed_types = {});
|
||||
|
||||
} // namespace common
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
|
@ -75,5 +75,5 @@ graph {
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 10
|
||||
version: 13
|
||||
}
|
||||
|
@ -75,5 +75,5 @@ graph {
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 10
|
||||
version: 13
|
||||
}
|
||||
|
@ -87,5 +87,5 @@ graph {
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 10
|
||||
version: 13
|
||||
}
|
||||
|
@ -75,5 +75,5 @@ graph {
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 10
|
||||
version: 13
|
||||
}
|
||||
|
@ -156,7 +156,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_uint8)
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_2.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
auto test_case = ngraph::test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input(std::vector<uint8_t>{0, 1, 2, 3, 0, 1, 2, 3, 0, 10, 20, 30}); // x
|
||||
test_case.add_input(std::vector<float>{1.0f, 2.0f, 4.0f}); // scale
|
||||
@ -174,7 +174,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_int8)
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_3.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
auto test_case = ngraph::test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input(std::vector<int8_t>{0, 1, 2, 3, 0, 2, 4, 6, 0, 10, 20, 30}); // x
|
||||
test_case.add_input(std::vector<float>{1.0f, 2.0f, 4.0f, 8.0f}); // scale
|
||||
@ -192,7 +192,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_int8_4d)
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_4.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
auto test_case = ngraph::test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input(std::vector<uint8_t>{7, 9, 10, 10, 5, 8, 9, 1, 8, 6, 7, 9, 10, 0, 7, 10, 8,
|
||||
2, 6, 0, 5, 9, 8, 1, 2, 7, 5, 3, 2, 4, 1, 3, 8, 7,
|
||||
@ -216,7 +216,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_dequantize_linear_1d_zero_scale_uint8_ne
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/dequantize_linear_5.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
auto test_case = ngraph::test::TestCase<TestEngine>(function);
|
||||
|
||||
test_case.add_input(std::vector<uint8_t>{0, 1, 2, 3, 0, 1, 2, 3, 0, 10, 20, 30}); // x
|
||||
test_case.add_input(std::vector<float>{1.0f, 2.0f, 4.0f}); // scale
|
||||
|
@ -17,14 +17,12 @@ onnx_model_quantize_linear_zero_point
|
||||
onnx_model_quantize_linear_axis_zero
|
||||
onnx_model_quantize_linear_axis_negative
|
||||
|
||||
# Not supported ONNX op: DequantizeLinear
|
||||
onnx_model_dequantize_linear
|
||||
onnx_model_dequantize_linear_scalar_zero_scale_uint8
|
||||
onnx_model_dequantize_linear_scalar_zero_scale_int8
|
||||
onnx_model_dequantize_linear_1d_zero_scale_uint8
|
||||
onnx_model_dequantize_linear_1d_zero_scale_int8
|
||||
onnx_model_dequantize_linear_1d_zero_scale_int8_4d
|
||||
onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis
|
||||
# DequantizeLinear:
|
||||
# C++ exception with description "Unsupported precisions!
|
||||
IE_CPU.onnx_model_dequantize_linear_scalar_zero_scale_int8
|
||||
IE_CPU.onnx_model_dequantize_linear_1d_zero_scale_int8
|
||||
# C++ exception with description "Input data precision not supported. Expected float.
|
||||
IE_CPU.onnx_model_dequantize_linear_1d_zero_scale_int8_4d
|
||||
|
||||
# Not supported ONNX op: QLinearConv
|
||||
onnx_model_quant_conv_linear
|
||||
|
@ -91,13 +91,6 @@ INTERPRETER.onnx_model_quantize_linear
|
||||
INTERPRETER.onnx_model_quantize_linear_zero_point
|
||||
INTERPRETER.onnx_model_quantize_linear_axis_zero
|
||||
INTERPRETER.onnx_model_quantize_linear_axis_negative
|
||||
INTERPRETER.onnx_model_dequantize_linear
|
||||
INTERPRETER.onnx_model_dequantize_linear_scalar_zero_scale_uint8
|
||||
INTERPRETER.onnx_model_dequantize_linear_scalar_zero_scale_int8
|
||||
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8
|
||||
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_int8
|
||||
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_int8_4d
|
||||
INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis
|
||||
INTERPRETER.onnx_model_quant_conv_linear_2d
|
||||
INTERPRETER.onnx_model_quant_conv_linear_3d
|
||||
INTERPRETER.onnx_model_conv_integer
|
||||
|
Loading…
Reference in New Issue
Block a user