ONNX ConvInteger - handling of scalar zero points (#10057)
This commit is contained in:
parent
53af687a0c
commit
0700ba781b
103
src/core/tests/models/onnx/conv_integer_scalar_zp.prototxt
Normal file
103
src/core/tests/models/onnx/conv_integer_scalar_zp.prototxt
Normal file
@ -0,0 +1,103 @@
|
||||
ir_version: 7
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
input: "w"
|
||||
input: "x_zero_point"
|
||||
input: "w_zero_point"
|
||||
output: "y"
|
||||
op_type: "ConvInteger"
|
||||
}
|
||||
name: "ConvInt"
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "w"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "x_zero_point"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "w_zero_point"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 2
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 6
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
domain: ""
|
||||
version: 10
|
||||
}
|
@ -771,6 +771,27 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_conv_integer_simple_zero_point) {
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_conv_integer_scalar_zp) {
|
||||
auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/conv_integer_scalar_zp.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
|
||||
// clang-format off
|
||||
test_case.add_input(std::vector<uint8_t>{11, 22, 33,
|
||||
44, 55, 66,
|
||||
77, 88, 99}); // x
|
||||
test_case.add_input(std::vector<uint8_t>{5, 6,
|
||||
7, 8}); // w
|
||||
test_case.add_input(std::vector<uint8_t>{10}); // x_zero_point
|
||||
test_case.add_input(std::vector<uint8_t>{20}); // w_zero_point
|
||||
|
||||
test_case.add_expected_output({1, 1, 2, 2}, std::vector<int32_t>{-1165, -1759,
|
||||
-2947, -3541}); // y
|
||||
// clang-format on
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_conv_integer_int8) {
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(SERIALIZED_ZOO, "onnx/conv_integer_int8.onnx"));
|
||||
|
||||
|
@ -11,6 +11,30 @@
|
||||
|
||||
namespace ngraph {
|
||||
namespace onnx_import {
|
||||
namespace {
|
||||
std::shared_ptr<ngraph::Node> get_filter_zero_point(const OutputVector& inputs) {
|
||||
const auto& original_zero_point =
|
||||
(inputs.size() > 3) ? inputs.at(3) : ngraph::op::Constant::create(ngraph::element::i32, {}, {0});
|
||||
|
||||
const auto filter_zero_point_rank = original_zero_point.get_partial_shape().rank();
|
||||
if (filter_zero_point_rank.is_static() && filter_zero_point_rank.get_length() == 0) {
|
||||
return std::make_shared<default_opset::Convert>(original_zero_point, element::i32);
|
||||
} else {
|
||||
// in case of 1D zero point filter, it has to be unsqueezed to match the data input's rank
|
||||
const auto& converted_filter_zero_point =
|
||||
std::make_shared<default_opset::Convert>(original_zero_point, element::i32);
|
||||
const auto& input_shape = std::make_shared<default_opset::ShapeOf>(inputs.at(0), element::i32);
|
||||
const auto& input_rank = std::make_shared<default_opset::ShapeOf>(input_shape, element::i32);
|
||||
const auto& input_rank_scalar = reshape::interpret_as_scalar(input_rank);
|
||||
|
||||
const auto& one_node = ngraph::op::Constant::create(ngraph::element::i32, {}, {1});
|
||||
const auto& missing_dimensions =
|
||||
std::make_shared<default_opset::Range>(one_node, input_rank_scalar, one_node, element::i32);
|
||||
|
||||
return std::make_shared<default_opset::Unsqueeze>(converted_filter_zero_point, missing_dimensions);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
namespace op {
|
||||
namespace set_1 {
|
||||
|
||||
@ -20,28 +44,16 @@ OutputVector conv_integer(const Node& node) {
|
||||
const auto& input = inputs.at(0);
|
||||
const auto& filter = inputs.at(1);
|
||||
const auto& input_zero_point =
|
||||
(inputs.size() > 2) ? inputs.at(2) : ngraph::op::Constant::create(ngraph::element::i32, {1}, {0});
|
||||
const auto& filter_zero_point =
|
||||
(inputs.size() > 3) ? inputs.at(3) : ngraph::op::Constant::create(ngraph::element::i32, {1}, {0});
|
||||
(inputs.size() > 2) ? inputs.at(2) : ngraph::op::Constant::create(ngraph::element::i32, {}, {0});
|
||||
|
||||
const auto& converted_input = std::make_shared<default_opset::Convert>(input, element::i32);
|
||||
const auto& converted_filter = std::make_shared<default_opset::Convert>(filter, element::i32);
|
||||
|
||||
const auto& converted_input_zero_point = std::make_shared<default_opset::Convert>(input_zero_point, element::i32);
|
||||
const auto& converted_filter_zero_point = std::make_shared<default_opset::Convert>(filter_zero_point, element::i32);
|
||||
|
||||
const auto& input_shape = std::make_shared<default_opset::ShapeOf>(input, element::i32);
|
||||
const auto& input_rank = std::make_shared<default_opset::ShapeOf>(input_shape, element::i32);
|
||||
const auto& input_rank_scalar = reshape::interpret_as_scalar(input_rank);
|
||||
|
||||
const auto& one_node = ngraph::op::Constant::create(ngraph::element::i32, {}, {1});
|
||||
const auto& missing_dimensions =
|
||||
std::make_shared<default_opset::Range>(one_node, input_rank_scalar, one_node, element::i32);
|
||||
const auto& resized_filter_zero_point =
|
||||
std::make_shared<default_opset::Unsqueeze>(converted_filter_zero_point, missing_dimensions);
|
||||
const auto& filter_zero_point = get_filter_zero_point(inputs);
|
||||
|
||||
const auto& shifted_input = std::make_shared<default_opset::Subtract>(converted_input, converted_input_zero_point);
|
||||
const auto& shifted_filter = std::make_shared<default_opset::Subtract>(converted_filter, resized_filter_zero_point);
|
||||
const auto& shifted_filter = std::make_shared<default_opset::Subtract>(converted_filter, filter_zero_point);
|
||||
|
||||
const auto& groups = node.get_attribute_value<int64_t>("group", 1);
|
||||
const auto& strides = convpool::get_strides(node);
|
||||
|
Loading…
Reference in New Issue
Block a user