Proper handling of ONNX Squeeze without the axes input (#5425)

This commit is contained in:
Tomasz Dołbniak 2021-04-28 20:59:32 +02:00 committed by GitHub
parent 7cad047f53
commit 36f407dd18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 69 additions and 7 deletions

View File

@ -38,14 +38,10 @@ namespace ngraph
{
OutputVector squeeze(const Node& node)
{
auto inputs = node.get_ng_inputs();
const auto inputs = node.get_ng_inputs();
if (inputs.size() < 2)
{
std::vector<int64_t> axes{};
auto axes_node = std::make_shared<default_opset::Constant>(
element::Type_t::u64, Shape{}, axes);
return {std::make_shared<default_opset::Squeeze>(inputs.at(0), axes_node)};
return {std::make_shared<default_opset::Squeeze>(inputs.at(0))};
}
else
{

View File

@ -0,0 +1,54 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
output: "B"
op_type: "Squeeze"
}
name: "compute_graph"
input {
name: "A"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 4
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -372,7 +372,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_space_to_depth_no_blocksize)
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze_duplicate_axes.prototxt"));
file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze.prototxt"));
// {1, 4, 1, 1, 2}
auto input = test::NDArray<float, 5>(
@ -390,6 +390,18 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze_opset13_no_axes)
{
const auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze_opset13_no_axes.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
const std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
test_case.add_input<float>(Shape{1, 4, 1, 1, 2}, data);
test_case.add_expected_output<float>(Shape{4, 2}, data);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_unsqueeze)
{
auto function = onnx_import::import_onnx_model(