[ONNX FE] Fix Squeeze v1 (#16865)

This commit is contained in:
Mateusz Bencer
2023-04-12 14:33:49 +02:00
committed by GitHub
parent 997f60f1c3
commit e737e18b02
3 changed files with 77 additions and 11 deletions

View File

@@ -5,9 +5,7 @@
#include "ngraph/op/squeeze.hpp"
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
#include "op/squeeze.hpp"
namespace ngraph {
@@ -16,16 +14,14 @@ namespace op {
namespace set_1 {
OutputVector squeeze(const Node& node) {
auto data = node.get_ng_inputs().at(0);
std::vector<std::int64_t> axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
const auto data_rank = data.get_partial_shape().rank();
const auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
OPENVINO_SUPPRESS_DEPRECATED_START
std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(node.get_description(), axes, data_rank);
OPENVINO_SUPPRESS_DEPRECATED_END
auto axes_node =
std::make_shared<default_opset::Constant>(element::u64, Shape{normalized_axes.size()}, normalized_axes);
return {std::make_shared<default_opset::Squeeze>(data, axes_node)};
if (axes.empty()) {
return {std::make_shared<default_opset::Squeeze>(data)};
} else {
const auto axes_const = std::make_shared<default_opset::Constant>(element::i64, Shape{axes.size()}, axes);
return {std::make_shared<default_opset::Squeeze>(data, axes_const)};
}
}
} // namespace set_1

View File

@@ -0,0 +1,58 @@
ir_version: 6
producer_name: "OV ONNX FE"
graph {
node {
input: "A"
output: "B"
op_type: "Squeeze"
attribute {
name: "axes"
type: INTS
}
}
name: "compute_graph"
input {
name: "A"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 4
}
dim {
dim_value: 1
}
dim {
dim_value: 1
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 10
}

View File

@@ -367,6 +367,18 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze) {
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze_empty_axes_attribute) {
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/squeeze_empty_axes_attribute.onnx"));
auto test_case = test::TestCase(function, s_device);
const std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
test_case.add_input<float>(Shape{1, 4, 1, 1, 2}, data);
test_case.add_expected_output<float>(Shape{4, 2}, data);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze_opset13_no_axes) {
const auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
SERIALIZED_ZOO,