[ONNX] check if 3rd input to Pad op is not null (#12797)
This commit is contained in:
parent
1f66077fc3
commit
2d2ac7fafa
61
src/core/tests/models/onnx/pad_optional_constant.prototxt
Normal file
61
src/core/tests/models/onnx/pad_optional_constant.prototxt
Normal file
@ -0,0 +1,61 @@
|
||||
ir_version: 6
|
||||
producer_name: "backend-test"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
input: "pads"
|
||||
input: ""
|
||||
output: "y"
|
||||
op_type: "Pad"
|
||||
attribute {
|
||||
name: "mode"
|
||||
s: "constant"
|
||||
type: STRING
|
||||
}
|
||||
}
|
||||
name: "test_constant_pad"
|
||||
initializer {
|
||||
dims: 4
|
||||
data_type: 7
|
||||
int64_data: 0
|
||||
int64_data: 2
|
||||
int64_data: 0
|
||||
int64_data: 0
|
||||
name: "pads"
|
||||
}
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 11
|
||||
}
|
@ -14,6 +14,7 @@
|
||||
#include "ngraph/op/convert.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "onnx_import/core/null_node.hpp"
|
||||
#include "op/pad.hpp"
|
||||
#include "utils/convpool.hpp"
|
||||
|
||||
@ -64,14 +65,15 @@ OutputVector pad(const Node& node) {
|
||||
} // namespace set_1
|
||||
namespace set_11 {
|
||||
OutputVector pad(const Node& node) {
|
||||
auto data = node.get_ng_inputs().at(0);
|
||||
auto pads = node.get_ng_inputs().at(1);
|
||||
const auto inputs = node.get_ng_inputs();
|
||||
const auto& data = inputs[0];
|
||||
const auto& pads = inputs[1];
|
||||
Output<ngraph::Node> values;
|
||||
Output<ngraph::Node> padding_begin;
|
||||
Output<ngraph::Node> padding_end;
|
||||
|
||||
if (node.get_ng_inputs().size() == 3) {
|
||||
values = node.get_ng_inputs().at(2);
|
||||
if (inputs.size() == 3 && !ngraph::op::is_null(inputs[2])) {
|
||||
values = inputs[2];
|
||||
} else {
|
||||
values = default_opset::Constant::create(data.get_element_type(), ngraph::Shape{}, {0});
|
||||
}
|
||||
|
@ -3316,6 +3316,19 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pad_constant) {
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pad_optional_constant) {
|
||||
const auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/pad_optional_constant.onnx"));
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
|
||||
test_case.add_input<float>({1.f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f});
|
||||
test_case.add_expected_output<float>(Shape{3, 4},
|
||||
{0.f, 0.f, 1.f, 1.2f, 0.f, 0.f, 2.3f, 3.4f, 0.f, 0.f, 4.5f, 5.7f});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_pow_float32_float32) {
|
||||
const auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
|
Loading…
Reference in New Issue
Block a user