[ONNX] Handle const scalar during unsqueeze (#6323)
* handle scalar during unsqueeze * info about reverting * use get_constant_from_source * evaluate unsqueeze in onnx slice
This commit is contained in:
parent
5323200110
commit
a32513f78d
@ -12,6 +12,7 @@
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "op/gather.hpp"
|
||||
#include "utils/common.hpp"
|
||||
|
||||
@ -171,13 +172,12 @@ namespace ngraph
|
||||
auto ends = inputs.at(2);
|
||||
|
||||
// Slice is calculated over all axes as default
|
||||
Output<ngraph::Node> axes;
|
||||
std::shared_ptr<default_opset::Constant> axes_const;
|
||||
if (inputs.size() >= 4 && !is_null(inputs.at(3))) // axes input provided
|
||||
{
|
||||
axes = inputs.at(3);
|
||||
CHECK_VALID_NODE(node,
|
||||
ngraph::op::is_constant(axes.get_node()),
|
||||
"Axes input must be constant");
|
||||
axes_const = ngraph::get_constant_from_source(inputs.at(3));
|
||||
CHECK_VALID_NODE(
|
||||
node, axes_const != nullptr, "Axes input must be constant");
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -186,14 +186,11 @@ namespace ngraph
|
||||
data_rank.is_static(),
|
||||
"Data rank must be static when axes input is not provided");
|
||||
const size_t data_rank_value = data_rank.get_length();
|
||||
axes = default_opset::Constant::create(
|
||||
axes_const = default_opset::Constant::create(
|
||||
element::i64,
|
||||
{data_rank_value},
|
||||
common::get_monotonic_range<int64_t>(data_rank_value));
|
||||
}
|
||||
|
||||
const auto axes_const =
|
||||
as_type_ptr<default_opset::Constant>(axes.get_node_shared_ptr());
|
||||
auto raw_axes_vec = axes_const->cast_vector<int64_t>();
|
||||
std::vector<uint64_t> axes_vec =
|
||||
get_normalized_axes_vector(node, data_rank, raw_axes_vec);
|
||||
|
114
ngraph/test/models/onnx/slice_const_axes_source.prototxt
Normal file
114
ngraph/test/models/onnx/slice_const_axes_source.prototxt
Normal file
@ -0,0 +1,114 @@
|
||||
ir_version: 7
|
||||
producer_name: "backend-test"
|
||||
graph {
|
||||
name: "test_slice_with_unsqueeze_axes"
|
||||
initializer {
|
||||
data_type: 7
|
||||
int64_data: 1
|
||||
name: "x"
|
||||
}
|
||||
initializer {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 1
|
||||
name: "starts"
|
||||
}
|
||||
initializer {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 3
|
||||
name: "ends"
|
||||
}
|
||||
initializer {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 1
|
||||
name: "steps"
|
||||
}
|
||||
node {
|
||||
input: "x"
|
||||
output: "slice_axes"
|
||||
op_type: "Unsqueeze"
|
||||
attribute {
|
||||
name: "axes"
|
||||
ints: 0
|
||||
type: INTS
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "data"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value:2
|
||||
}
|
||||
dim {
|
||||
dim_value:4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "starts"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "ends"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "steps"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "data"
|
||||
input: "starts"
|
||||
input: "ends"
|
||||
input: "slice_axes"
|
||||
input: "steps"
|
||||
output: "sliced"
|
||||
name: "Slice"
|
||||
op_type: "Slice"
|
||||
}
|
||||
output {
|
||||
name: "sliced"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 10
|
||||
}
|
@ -4334,6 +4334,17 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_multiple_slices_last_layer)
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_slice_const_axes_source)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/slice_const_axes_source.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input<float>(std::vector<float>{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
|
||||
test_case.add_expected_output<float>(Shape{2, 2}, {2.f, 3.f, 6.f, 7.f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_softmax_crossentropy_loss_mean)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
|
Loading…
Reference in New Issue
Block a user