Proper handling of ONNX Squeeze without the axes input (#5425)
This commit is contained in:
parent
7cad047f53
commit
36f407dd18
@ -38,14 +38,10 @@ namespace ngraph
|
||||
{
|
||||
OutputVector squeeze(const Node& node)
|
||||
{
|
||||
auto inputs = node.get_ng_inputs();
|
||||
const auto inputs = node.get_ng_inputs();
|
||||
if (inputs.size() < 2)
|
||||
{
|
||||
std::vector<int64_t> axes{};
|
||||
auto axes_node = std::make_shared<default_opset::Constant>(
|
||||
element::Type_t::u64, Shape{}, axes);
|
||||
|
||||
return {std::make_shared<default_opset::Squeeze>(inputs.at(0), axes_node)};
|
||||
return {std::make_shared<default_opset::Squeeze>(inputs.at(0))};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
54
ngraph/test/models/onnx/squeeze_opset13_no_axes.prototxt
Normal file
54
ngraph/test/models/onnx/squeeze_opset13_no_axes.prototxt
Normal file
@ -0,0 +1,54 @@
|
||||
ir_version: 7
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "A"
|
||||
output: "B"
|
||||
op_type: "Squeeze"
|
||||
}
|
||||
name: "compute_graph"
|
||||
input {
|
||||
name: "A"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "B"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
@ -372,7 +372,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_space_to_depth_no_blocksize)
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze_duplicate_axes.prototxt"));
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze.prototxt"));
|
||||
|
||||
// {1, 4, 1, 1, 2}
|
||||
auto input = test::NDArray<float, 5>(
|
||||
@ -390,6 +390,18 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze)
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze_opset13_no_axes)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze_opset13_no_axes.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
const std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
|
||||
test_case.add_input<float>(Shape{1, 4, 1, 1, 2}, data);
|
||||
test_case.add_expected_output<float>(Shape{4, 2}, data);
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_unsqueeze)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
|
Loading…
Reference in New Issue
Block a user