From a32513f78d7c3c6d03540f2a8755dff462329939 Mon Sep 17 00:00:00 2001 From: Mateusz Bencer Date: Thu, 1 Jul 2021 11:26:31 +0200 Subject: [PATCH] [ONNX] Handle const scalar during unsqueeze (#6323) * handle scalar during unsqueeze * info about reverting * use get_constant_from_source * evaluate unsqueeze in onnx slice --- ngraph/frontend/onnx_import/src/op/slice.cpp | 15 +-- .../onnx/slice_const_axes_source.prototxt | 114 ++++++++++++++++++ ngraph/test/onnx/onnx_import.in.cpp | 11 ++ 3 files changed, 131 insertions(+), 9 deletions(-) create mode 100644 ngraph/test/models/onnx/slice_const_axes_source.prototxt diff --git a/ngraph/frontend/onnx_import/src/op/slice.cpp b/ngraph/frontend/onnx_import/src/op/slice.cpp index 26010a5bda0..e344a849390 100644 --- a/ngraph/frontend/onnx_import/src/op/slice.cpp +++ b/ngraph/frontend/onnx_import/src/op/slice.cpp @@ -12,6 +12,7 @@ #include "ngraph/node.hpp" #include "ngraph/op/constant.hpp" #include "ngraph/op/util/op_types.hpp" +#include "ngraph/validation_util.hpp" #include "op/gather.hpp" #include "utils/common.hpp" @@ -171,13 +172,12 @@ namespace ngraph auto ends = inputs.at(2); // Slice is calculated over all axes as default - Output axes; + std::shared_ptr axes_const; if (inputs.size() >= 4 && !is_null(inputs.at(3))) // axes input provided { - axes = inputs.at(3); - CHECK_VALID_NODE(node, - ngraph::op::is_constant(axes.get_node()), - "Axes input must be constant"); + axes_const = ngraph::get_constant_from_source(inputs.at(3)); + CHECK_VALID_NODE( + node, axes_const != nullptr, "Axes input must be constant"); } else { @@ -186,14 +186,11 @@ namespace ngraph data_rank.is_static(), "Data rank must be static when axes input is not provided"); const size_t data_rank_value = data_rank.get_length(); - axes = default_opset::Constant::create( + axes_const = default_opset::Constant::create( element::i64, {data_rank_value}, common::get_monotonic_range(data_rank_value)); } - - const auto axes_const = - as_type_ptr(axes.get_node_shared_ptr()); auto raw_axes_vec = axes_const->cast_vector(); std::vector axes_vec = get_normalized_axes_vector(node, data_rank, raw_axes_vec); diff --git a/ngraph/test/models/onnx/slice_const_axes_source.prototxt b/ngraph/test/models/onnx/slice_const_axes_source.prototxt new file mode 100644 index 00000000000..e6da15faddb --- /dev/null +++ b/ngraph/test/models/onnx/slice_const_axes_source.prototxt @@ -0,0 +1,114 @@ +ir_version: 7 +producer_name: "backend-test" +graph { + name: "test_slice_with_unsqueeze_axes" + initializer { + data_type: 7 + int64_data: 1 + name: "x" + } + initializer { + dims: 1 + data_type: 7 + int64_data: 1 + name: "starts" + } + initializer { + dims: 1 + data_type: 7 + int64_data: 3 + name: "ends" + } + initializer { + dims: 1 + data_type: 7 + int64_data: 1 + name: "steps" + } + node { + input: "x" + output: "slice_axes" + op_type: "Unsqueeze" + attribute { + name: "axes" + ints: 0 + type: INTS + } + } + input { + name: "data" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value:2 + } + dim { + dim_value:4 + } + } + } + } + } + input { + name: "starts" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "ends" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + } + } + } + } + input { + name: "steps" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 1 + } + } + } + } + } + node { + input: "data" + input: "starts" + input: "ends" + input: "slice_axes" + input: "steps" + output: "sliced" + name: "Slice" + op_type: "Slice" + } + output { + name: "sliced" + type { + tensor_type { + elem_type: 1 + } + } + } +} +opset_import { + version: 10 +} diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp index 118b646b3c6..f16e0f2dc45 100644 --- a/ngraph/test/onnx/onnx_import.in.cpp +++ b/ngraph/test/onnx/onnx_import.in.cpp @@ -4334,6 +4334,17 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_multiple_slices_last_layer) test_case.run(); } +NGRAPH_TEST(${BACKEND_NAME}, onnx_slice_const_axes_source) +{ + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/slice_const_axes_source.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_input(std::vector{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + test_case.add_expected_output(Shape{2, 2}, {2.f, 3.f, 6.f, 7.f}); + test_case.run(); +} + NGRAPH_TEST(${BACKEND_NAME}, onnx_softmax_crossentropy_loss_mean) { auto function = onnx_import::import_onnx_model(