[ONNX FE] Fix Squeeze v1 (#16865)
This commit is contained in:
@@ -5,9 +5,7 @@
|
||||
#include "ngraph/op/squeeze.hpp"
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "exceptions.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "op/squeeze.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
@@ -16,16 +14,14 @@ namespace op {
|
||||
namespace set_1 {
|
||||
OutputVector squeeze(const Node& node) {
|
||||
auto data = node.get_ng_inputs().at(0);
|
||||
std::vector<std::int64_t> axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
|
||||
const auto data_rank = data.get_partial_shape().rank();
|
||||
const auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
std::vector<std::size_t> normalized_axes = ngraph::normalize_axes(node.get_description(), axes, data_rank);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
auto axes_node =
|
||||
std::make_shared<default_opset::Constant>(element::u64, Shape{normalized_axes.size()}, normalized_axes);
|
||||
|
||||
return {std::make_shared<default_opset::Squeeze>(data, axes_node)};
|
||||
if (axes.empty()) {
|
||||
return {std::make_shared<default_opset::Squeeze>(data)};
|
||||
} else {
|
||||
const auto axes_const = std::make_shared<default_opset::Constant>(element::i64, Shape{axes.size()}, axes);
|
||||
return {std::make_shared<default_opset::Squeeze>(data, axes_const)};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace set_1
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
ir_version: 6
|
||||
producer_name: "OV ONNX FE"
|
||||
graph {
|
||||
node {
|
||||
input: "A"
|
||||
output: "B"
|
||||
op_type: "Squeeze"
|
||||
attribute {
|
||||
name: "axes"
|
||||
type: INTS
|
||||
}
|
||||
}
|
||||
name: "compute_graph"
|
||||
input {
|
||||
name: "A"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "B"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 10
|
||||
}
|
||||
@@ -367,6 +367,18 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze) {
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze_empty_axes_attribute) {
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/squeeze_empty_axes_attribute.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
const std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
|
||||
test_case.add_input<float>(Shape{1, 4, 1, 1, 2}, data);
|
||||
test_case.add_expected_output<float>(Shape{4, 2}, data);
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze_opset13_no_axes) {
|
||||
const auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
|
||||
Reference in New Issue
Block a user