ONNX ConvInteger - handling of scalar zero points (#10057)

This commit is contained in:
Tomasz Dołbniak 2022-02-02 12:16:08 +01:00 committed by GitHub
parent 53af687a0c
commit 0700ba781b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 151 additions and 15 deletions

View File

@ -0,0 +1,103 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
# The only node in this graph: ConvInteger wired with all four inputs
# (data, filter, data zero point, filter zero point) and producing "y".
node {
input: "x"
input: "w"
input: "x_zero_point"
input: "w_zero_point"
output: "y"
op_type: "ConvInteger"
}
name: "ConvInt"
# Data input "x": four dims, 1 x 1 x 3 x 3.
# elem_type 2 is UINT8 in ONNX's TensorProto.DataType; the unit test feeds uint8_t values.
input {
name: "x"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 3
}
}
}
}
}
# Filter input "w": four dims, 1 x 1 x 2 x 2, same element type (2) as the data input.
input {
name: "w"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
# Zero point for the data input, declared as a scalar.
input {
name: "x_zero_point"
type {
tensor_type {
elem_type: 2
# Empty shape message: no dims are listed, i.e. a rank-0 (scalar) tensor.
shape {
}
}
}
}
# Zero point for the filter input, declared as a scalar.
input {
name: "w_zero_point"
type {
tensor_type {
elem_type: 2
# Empty shape message: no dims are listed, i.e. a rank-0 (scalar) tensor.
shape {
}
}
}
}
# Graph output "y": four dims, 1 x 1 x 2 x 2.
# elem_type 6 is INT32 in ONNX's TensorProto.DataType; the unit test expects int32_t values.
output {
name: "y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
domain: ""
version: 10
}

View File

@ -771,6 +771,27 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_conv_integer_simple_zero_point) {
test_case.run();
}
// Imports the ConvInteger model whose zero points are declared as scalars and checks
// the result of convolving (x - x_zero_point) with (w - w_zero_point).
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_conv_integer_scalar_zp) {
    const auto model_path = file_util::path_join(SERIALIZED_ZOO, "onnx/conv_integer_scalar_zp.onnx");
    auto model = onnx_import::import_onnx_model(model_path);
    auto test_case = test::TestCase(model, s_device);

    // clang-format off
    const std::vector<uint8_t> x{11, 22, 33,
                                 44, 55, 66,
                                 77, 88, 99};
    const std::vector<uint8_t> w{5, 6,
                                 7, 8};
    const std::vector<uint8_t> x_zero_point{10};
    const std::vector<uint8_t> w_zero_point{20};
    // e.g. top-left element: 1*(-15) + 12*(-14) + 34*(-13) + 45*(-12) = -1165
    const std::vector<int32_t> y{-1165, -1759,
                                 -2947, -3541};
    // clang-format on

    // inputs are added in the order they are listed by the ConvInteger node
    test_case.add_input(x);
    test_case.add_input(w);
    test_case.add_input(x_zero_point);
    test_case.add_input(w_zero_point);
    test_case.add_expected_output({1, 1, 2, 2}, y);

    test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_conv_integer_int8) {
auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/conv_integer_int8.onnx"));

View File

@ -11,6 +11,30 @@
namespace ngraph {
namespace onnx_import {
namespace {
/// \brief Builds the i32 filter zero point that gets subtracted from the converted filter.
///
/// The zero point is taken from the 4th input (index 3) when it is present; otherwise
/// a scalar i32 constant equal to 0 is used. It is always converted to i32.
///
/// \param inputs  Inputs of the ConvInteger node; index 0 is read to obtain the data rank.
///
/// \return The converted zero point as-is when its rank is statically known to be 0,
///         otherwise the converted zero point unsqueezed along axes [1, rank(inputs[0])).
std::shared_ptr<ngraph::Node> get_filter_zero_point(const OutputVector& inputs) {
    const auto& zero_point =
        (inputs.size() > 3) ? inputs.at(3) : ngraph::op::Constant::create(ngraph::element::i32, {}, {0});

    const auto zero_point_rank = zero_point.get_partial_shape().rank();
    const bool is_static_scalar = zero_point_rank.is_static() && zero_point_rank.get_length() == 0;

    const auto zero_point_i32 = std::make_shared<default_opset::Convert>(zero_point, element::i32);

    // a scalar needs no reshaping before it is subtracted from the filter
    if (is_static_scalar) {
        return zero_point_i32;
    }

    // A non-scalar filter zero point (this path is also taken when the rank is dynamic)
    // has to be unsqueezed to match the data input's rank: the new axes are 1, 2, ..., rank - 1.
    const auto data_shape = std::make_shared<default_opset::ShapeOf>(inputs.at(0), element::i32);
    const auto data_rank = std::make_shared<default_opset::ShapeOf>(data_shape, element::i32);
    const auto data_rank_scalar = reshape::interpret_as_scalar(data_rank);
    const auto step_one = ngraph::op::Constant::create(ngraph::element::i32, {}, {1});
    const auto axes_to_add =
        std::make_shared<default_opset::Range>(step_one, data_rank_scalar, step_one, element::i32);

    return std::make_shared<default_opset::Unsqueeze>(zero_point_i32, axes_to_add);
}
} // namespace
namespace op {
namespace set_1 {
@ -20,28 +44,16 @@ OutputVector conv_integer(const Node& node) {
const auto& input = inputs.at(0);
const auto& filter = inputs.at(1);
const auto& input_zero_point =
(inputs.size() > 2) ? inputs.at(2) : ngraph::op::Constant::create(ngraph::element::i32, {1}, {0});
const auto& filter_zero_point =
(inputs.size() > 3) ? inputs.at(3) : ngraph::op::Constant::create(ngraph::element::i32, {1}, {0});
(inputs.size() > 2) ? inputs.at(2) : ngraph::op::Constant::create(ngraph::element::i32, {}, {0});
const auto& converted_input = std::make_shared<default_opset::Convert>(input, element::i32);
const auto& converted_filter = std::make_shared<default_opset::Convert>(filter, element::i32);
const auto& converted_input_zero_point = std::make_shared<default_opset::Convert>(input_zero_point, element::i32);
const auto& converted_filter_zero_point = std::make_shared<default_opset::Convert>(filter_zero_point, element::i32);
const auto& input_shape = std::make_shared<default_opset::ShapeOf>(input, element::i32);
const auto& input_rank = std::make_shared<default_opset::ShapeOf>(input_shape, element::i32);
const auto& input_rank_scalar = reshape::interpret_as_scalar(input_rank);
const auto& one_node = ngraph::op::Constant::create(ngraph::element::i32, {}, {1});
const auto& missing_dimensions =
std::make_shared<default_opset::Range>(one_node, input_rank_scalar, one_node, element::i32);
const auto& resized_filter_zero_point =
std::make_shared<default_opset::Unsqueeze>(converted_filter_zero_point, missing_dimensions);
const auto& filter_zero_point = get_filter_zero_point(inputs);
const auto& shifted_input = std::make_shared<default_opset::Subtract>(converted_input, converted_input_zero_point);
const auto& shifted_filter = std::make_shared<default_opset::Subtract>(converted_filter, resized_filter_zero_point);
const auto& shifted_filter = std::make_shared<default_opset::Subtract>(converted_filter, filter_zero_point);
const auto& groups = node.get_attribute_value<int64_t>("group", 1);
const auto& strides = convpool::get_strides(node);