[ONNX] Handle const scalar during unsqueeze (#6323)

* handle scalar during unsqueeze

* info about reverting

* use get_constant_from_source

* evaluate unsqueeze in onnx slice
This commit is contained in:
Mateusz Bencer 2021-07-01 11:26:31 +02:00 committed by GitHub
parent 5323200110
commit a32513f78d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 131 additions and 9 deletions

View File

@ -12,6 +12,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/validation_util.hpp"
#include "op/gather.hpp"
#include "utils/common.hpp"
@ -171,13 +172,12 @@ namespace ngraph
auto ends = inputs.at(2);
// Slice is calculated over all axes as default
Output<ngraph::Node> axes;
std::shared_ptr<default_opset::Constant> axes_const;
if (inputs.size() >= 4 && !is_null(inputs.at(3))) // axes input provided
{
axes = inputs.at(3);
CHECK_VALID_NODE(node,
ngraph::op::is_constant(axes.get_node()),
"Axes input must be constant");
axes_const = ngraph::get_constant_from_source(inputs.at(3));
CHECK_VALID_NODE(
node, axes_const != nullptr, "Axes input must be constant");
}
else
{
@ -186,14 +186,11 @@ namespace ngraph
data_rank.is_static(),
"Data rank must be static when axes input is not provided");
const size_t data_rank_value = data_rank.get_length();
axes = default_opset::Constant::create(
axes_const = default_opset::Constant::create(
element::i64,
{data_rank_value},
common::get_monotonic_range<int64_t>(data_rank_value));
}
const auto axes_const =
as_type_ptr<default_opset::Constant>(axes.get_node_shared_ptr());
auto raw_axes_vec = axes_const->cast_vector<int64_t>();
std::vector<uint64_t> axes_vec =
get_normalized_axes_vector(node, data_rank, raw_axes_vec);

View File

@ -0,0 +1,114 @@
ir_version: 7
producer_name: "backend-test"
graph {
name: "test_slice_with_unsqueeze_axes"
initializer {
data_type: 7
int64_data: 1
name: "x"
}
initializer {
dims: 1
data_type: 7
int64_data: 1
name: "starts"
}
initializer {
dims: 1
data_type: 7
int64_data: 3
name: "ends"
}
initializer {
dims: 1
data_type: 7
int64_data: 1
name: "steps"
}
node {
input: "x"
output: "slice_axes"
op_type: "Unsqueeze"
attribute {
name: "axes"
ints: 0
type: INTS
}
}
input {
name: "data"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value:2
}
dim {
dim_value:4
}
}
}
}
}
input {
name: "starts"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
input {
name: "ends"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
input {
name: "steps"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
node {
input: "data"
input: "starts"
input: "ends"
input: "slice_axes"
input: "steps"
output: "sliced"
name: "Slice"
op_type: "Slice"
}
output {
name: "sliced"
type {
tensor_type {
elem_type: 1
}
}
}
}
opset_import {
version: 10
}

View File

@ -4334,6 +4334,17 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_multiple_slices_last_layer)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_slice_const_axes_source)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/slice_const_axes_source.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(std::vector<float>{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
test_case.add_expected_output<float>(Shape{2, 2}, {2.f, 3.f, 6.f, 7.f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_softmax_crossentropy_loss_mean)
{
auto function = onnx_import::import_onnx_model(