[ONNX] check if 3rd input to Pad op is not null (#12797)

This commit is contained in:
Mateusz Tabaka 2022-08-30 12:13:47 +02:00 committed by GitHub
parent 1f66077fc3
commit 2d2ac7fafa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 80 additions and 4 deletions

View File

@ -0,0 +1,61 @@
ir_version: 6
producer_name: "backend-test"
graph {
node {
input: "x"
input: "pads"
input: ""
output: "y"
op_type: "Pad"
attribute {
name: "mode"
s: "constant"
type: STRING
}
}
name: "test_constant_pad"
initializer {
dims: 4
data_type: 7
int64_data: 0
int64_data: 2
int64_data: 0
int64_data: 0
name: "pads"
}
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 11
}

View File

@ -14,6 +14,7 @@
#include "ngraph/op/convert.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
#include "onnx_import/core/null_node.hpp"
#include "op/pad.hpp"
#include "utils/convpool.hpp"
@ -64,14 +65,15 @@ OutputVector pad(const Node& node) {
} // namespace set_1
namespace set_11 {
OutputVector pad(const Node& node) {
auto data = node.get_ng_inputs().at(0);
auto pads = node.get_ng_inputs().at(1);
const auto inputs = node.get_ng_inputs();
const auto& data = inputs[0];
const auto& pads = inputs[1];
Output<ngraph::Node> values;
Output<ngraph::Node> padding_begin;
Output<ngraph::Node> padding_end;
if (node.get_ng_inputs().size() == 3) {
values = node.get_ng_inputs().at(2);
if (inputs.size() == 3 && !ngraph::op::is_null(inputs[2])) {
values = inputs[2];
} else {
values = default_opset::Constant::create(data.get_element_type(), ngraph::Shape{}, {0});
}

View File

@ -3316,6 +3316,19 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pad_constant) {
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pad_optional_constant) {
const auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/pad_optional_constant.onnx"));
auto test_case = test::TestCase(function, s_device);
test_case.add_input<float>({1.f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f});
test_case.add_expected_output<float>(Shape{3, 4},
{0.f, 0.f, 1.f, 1.2f, 0.f, 0.f, 2.3f, 3.4f, 0.f, 0.f, 4.5f, 5.7f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pow_float32_float32) {
const auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
SERIALIZED_ZOO,