Proper handling of ONNX Squeeze without the axes input (#5425)
This commit is contained in:
parent
7cad047f53
commit
36f407dd18
@ -38,14 +38,10 @@ namespace ngraph
|
|||||||
{
|
{
|
||||||
OutputVector squeeze(const Node& node)
|
OutputVector squeeze(const Node& node)
|
||||||
{
|
{
|
||||||
auto inputs = node.get_ng_inputs();
|
const auto inputs = node.get_ng_inputs();
|
||||||
if (inputs.size() < 2)
|
if (inputs.size() < 2)
|
||||||
{
|
{
|
||||||
std::vector<int64_t> axes{};
|
return {std::make_shared<default_opset::Squeeze>(inputs.at(0))};
|
||||||
auto axes_node = std::make_shared<default_opset::Constant>(
|
|
||||||
element::Type_t::u64, Shape{}, axes);
|
|
||||||
|
|
||||||
return {std::make_shared<default_opset::Squeeze>(inputs.at(0), axes_node)};
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
54
ngraph/test/models/onnx/squeeze_opset13_no_axes.prototxt
Normal file
54
ngraph/test/models/onnx/squeeze_opset13_no_axes.prototxt
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
ir_version: 7
|
||||||
|
producer_name: "nGraph ONNX Importer"
|
||||||
|
graph {
|
||||||
|
node {
|
||||||
|
input: "A"
|
||||||
|
output: "B"
|
||||||
|
op_type: "Squeeze"
|
||||||
|
}
|
||||||
|
name: "compute_graph"
|
||||||
|
input {
|
||||||
|
name: "A"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 1
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output {
|
||||||
|
name: "B"
|
||||||
|
type {
|
||||||
|
tensor_type {
|
||||||
|
elem_type: 1
|
||||||
|
shape {
|
||||||
|
dim {
|
||||||
|
dim_value: 4
|
||||||
|
}
|
||||||
|
dim {
|
||||||
|
dim_value: 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opset_import {
|
||||||
|
version: 13
|
||||||
|
}
|
@ -372,7 +372,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_space_to_depth_no_blocksize)
|
|||||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze)
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze)
|
||||||
{
|
{
|
||||||
auto function = onnx_import::import_onnx_model(
|
auto function = onnx_import::import_onnx_model(
|
||||||
file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze_duplicate_axes.prototxt"));
|
file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze.prototxt"));
|
||||||
|
|
||||||
// {1, 4, 1, 1, 2}
|
// {1, 4, 1, 1, 2}
|
||||||
auto input = test::NDArray<float, 5>(
|
auto input = test::NDArray<float, 5>(
|
||||||
@ -390,6 +390,18 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze)
|
|||||||
test_case.run();
|
test_case.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_squeeze_opset13_no_axes)
|
||||||
|
{
|
||||||
|
const auto function = onnx_import::import_onnx_model(
|
||||||
|
file_util::path_join(SERIALIZED_ZOO, "onnx/squeeze_opset13_no_axes.prototxt"));
|
||||||
|
|
||||||
|
auto test_case = test::TestCase<TestEngine>(function);
|
||||||
|
const std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
|
||||||
|
test_case.add_input<float>(Shape{1, 4, 1, 1, 2}, data);
|
||||||
|
test_case.add_expected_output<float>(Shape{4, 2}, data);
|
||||||
|
test_case.run();
|
||||||
|
}
|
||||||
|
|
||||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_unsqueeze)
|
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_unsqueeze)
|
||||||
{
|
{
|
||||||
auto function = onnx_import::import_onnx_model(
|
auto function = onnx_import::import_onnx_model(
|
||||||
|
Loading…
Reference in New Issue
Block a user