enable 1-D axis in cumsum (#17650)

* enable 1-D axis in cumsum

* fix according to comments; add testcases
This commit is contained in:
Xiuchuan Zhai 2023-07-12 19:54:00 +08:00 committed by GitHub
parent 7bdaedf4f8
commit 5630d3a8b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 165 additions and 1 deletions

View File

@ -7,6 +7,7 @@
#include <memory>
#include "default_opset.hpp"
#include "utils/reshape.hpp"
namespace ngraph {
namespace onnx_import {
@ -20,7 +21,9 @@ OutputVector cum_sum(const Node& node) {
Output<ngraph::Node> axis;
if (inputs.size() > 1) {
axis = inputs.at(1); // optional input, 0-D tensor
// optional input, 0-D or 1-D tensor
const auto& axis_shape = inputs.at(1).get_partial_shape();
axis = axis_shape.is_dynamic() ? inputs.at(1) : ngraph::onnx_import::reshape::interpret_as_scalar(inputs.at(1));
} else {
axis = default_opset::Constant::create(element::i64, Shape{}, {0}); // default
}

View File

@ -0,0 +1,69 @@
ir_version: 5
producer_name: "nGraph ONNX Importer"
graph {
node {
output: "axis"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
data_type: 7
int64_data: -1
}
type: TENSOR
}
}
node {
input: "x"
input: "axis"
output: "y"
op_type: "CumSum"
attribute {
name: "exclusive"
i: 0
type: INT
}
attribute {
name: "reverse"
i: 0
type: INT
}
}
name: "test_cum_sum"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 11
}

View File

@ -0,0 +1,69 @@
ir_version: 5
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
input: "axis"
output: "y"
op_type: "CumSum"
attribute {
name: "exclusive"
i: 0
type: INT
}
attribute {
name: "reverse"
i: 0
type: INT
}
}
name: "test_cum_sum"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "axis"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 11
}

View File

@ -641,6 +641,29 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_cum_sum_2d_dynamic_axis_input) {
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_cum_sum_2d_axis_input_1d) {
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/cum_sum_2d_axis_input_1d.onnx"));
auto test_case = test::TestCase(function, s_device);
test_case.add_input<float>({1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test_case.add_expected_output<float>(Shape{2, 3}, {1.f, 3.f, 6.f, 4.f, 9.f, 15.f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_cum_sum_2d_dynamic_axis_input_1d) {
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/cum_sum_2d_dynamic_axis_input_1d.onnx"));
auto test_case = test::TestCase(function, s_device);
test_case.add_input<float>({1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test_case.add_input<std::int64_t>({0});
test_case.add_expected_output<float>(Shape{2, 3}, {1.f, 2.f, 3.f, 5.f, 7.f, 9.f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_cum_sum_3d_exclusive_reverse) {
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
SERIALIZED_ZOO,