enable 1-D axis in cumsum (#17650)
* enable 1-D axis in cumsum * fix according to comments; add testcases
This commit is contained in:
parent
7bdaedf4f8
commit
5630d3a8b2
@ -7,6 +7,7 @@
|
||||
#include <memory>
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "utils/reshape.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace onnx_import {
|
||||
@ -20,7 +21,9 @@ OutputVector cum_sum(const Node& node) {
|
||||
Output<ngraph::Node> axis;
|
||||
|
||||
if (inputs.size() > 1) {
|
||||
axis = inputs.at(1); // optional input, 0-D tensor
|
||||
// optional input, 0-D or 1-D tensor
|
||||
const auto& axis_shape = inputs.at(1).get_partial_shape();
|
||||
axis = axis_shape.is_dynamic() ? inputs.at(1) : ngraph::onnx_import::reshape::interpret_as_scalar(inputs.at(1));
|
||||
} else {
|
||||
axis = default_opset::Constant::create(element::i64, Shape{}, {0}); // default
|
||||
}
|
||||
|
@ -0,0 +1,69 @@
|
||||
ir_version: 5
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
output: "axis"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: -1
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "x"
|
||||
input: "axis"
|
||||
output: "y"
|
||||
op_type: "CumSum"
|
||||
attribute {
|
||||
name: "exclusive"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "reverse"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "test_cum_sum"
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 11
|
||||
}
|
@ -0,0 +1,69 @@
|
||||
ir_version: 5
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
input: "axis"
|
||||
output: "y"
|
||||
op_type: "CumSum"
|
||||
attribute {
|
||||
name: "exclusive"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
attribute {
|
||||
name: "reverse"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "test_cum_sum"
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "axis"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 11
|
||||
}
|
@ -641,6 +641,29 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_cum_sum_2d_dynamic_axis_input) {
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_cum_sum_2d_axis_input_1d) {
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/cum_sum_2d_axis_input_1d.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
test_case.add_input<float>({1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||
test_case.add_expected_output<float>(Shape{2, 3}, {1.f, 3.f, 6.f, 4.f, 9.f, 15.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_cum_sum_2d_dynamic_axis_input_1d) {
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/cum_sum_2d_dynamic_axis_input_1d.onnx"));
|
||||
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
test_case.add_input<float>({1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
|
||||
test_case.add_input<std::int64_t>({0});
|
||||
test_case.add_expected_output<float>(Shape{2, 3}, {1.f, 2.f, 3.f, 5.f, 7.f, 9.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_cum_sum_3d_exclusive_reverse) {
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
|
Loading…
Reference in New Issue
Block a user