ONNX Slice operator support types dynamically (#4507)

This commit is contained in:
Bartosz Sledz 2021-02-26 16:28:51 +01:00 committed by GitHub
parent f88f81c6ba
commit b0043bb599
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 174 additions and 111 deletions

View File

@ -139,15 +139,16 @@ namespace ngraph
// expected_output_shape: {3, 3, 1, 1}
OutputVector adjusted_indices(slice_indices_length);
std::vector<uint64_t> target_axes(axes);
const auto gather_axis = default_opset::Constant::create(element::i64, {}, {0});
const auto gather_axis =
default_opset::Constant::create(indices.get_element_type(), {}, {0});
int added_indices_number = 0;
for (int i = 0; i < slice_indices_length; ++i)
{
if (std::find(std::begin(axes), std::end(axes), i) == axes.end())
{
adjusted_indices[i] =
default_opset::Constant::create(element::i64, {1}, {fill_in_value});
adjusted_indices[i] = default_opset::Constant::create(
indices.get_element_type(), {1}, {fill_in_value});
target_axes.insert(std::next(target_axes.begin(), i), i);
++added_indices_number;
}
@ -156,7 +157,7 @@ namespace ngraph
adjusted_indices[i] = std::make_shared<default_opset::Gather>(
indices,
default_opset::Constant::create(
element::i64, {1}, {i - added_indices_number}),
indices.get_element_type(), {1}, {i - added_indices_number}),
gather_axis);
}
}

View File

@ -159,7 +159,6 @@ xfail_issue_47330 = xfail_test(reason="RuntimeError: Eltwise node with name `[na
"FP64 precision.")
xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot")
xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output")
xfail_issue_49113 = xfail_test(reason="NLL Loss error: While validating ONNX node '<Node(Slice):")
xfail_issue_48098 = xfail_test(reason="ngraph.exceptions.UserInputError: ('Expected %s parameters, "
"received %s.', <value1>, <value2>)")
xfail_issue_48100 = xfail_test(reason="RuntimeError: cpu_convert can't convert from: "

View File

@ -75,7 +75,6 @@ from tests import (BACKEND_NAME,
xfail_issue_47330,
xfail_issue_47337,
xfail_issue_48052,
xfail_issue_49113,
xfail_issue_48098,
xfail_issue_48100,
xfail_issue_49207,
@ -295,111 +294,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_resize_downsample_sizes_nearest_cpu"),
(xfail_issue_33581,
"OnnxBackendNodeModelTest.test_gather_elements_negative_indices_cpu"),
(xfail_issue_49113,
"OnnxBackendNodeModelTest.test_nllloss_NC_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NC_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1_ii_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1_ii_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1_mean_weight_negative_ii_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1_mean_weight_negative_ii_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_ii_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_ii_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_no_weight_reduction_mean_ii_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_mean_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_mean_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_sum_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_sum_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_mean_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_mean_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_ii_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_none_no_weight_negative_ii_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_sum_weight_high_ii_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_mean_weight_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_mean_weight_expanded_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_none_no_weight_cpu",
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1_mean_weight_negative_ii_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1_mean_weight_negative_ii_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1_mean_weight_negative_ii_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_none_no_weight_negative_ii_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_sum_weight_high_ii_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_sum_weight_high_ii_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_mean_weight_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_mean_weight_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_mean_weight_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_none_no_weight_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_none_no_weight_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_3d_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_3d_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_3d_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_3d_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_3d_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_3d_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_3d_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_3d_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_4d_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_4d_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_4d_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_4d_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_3d_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_3d_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_3d_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_3d_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_4d_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_4d_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_4d_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_4d_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_mean_weight_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_none_cpu",
"OnnxBackendNodeModelTest.test_sce_none_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_none_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_none_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_none_weights_cpu",
"OnnxBackendNodeModelTest.test_sce_none_weights_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_none_weights_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_none_weights_log_prob_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_sum_cpu",
"OnnxBackendNodeModelTest.test_sce_sum_expanded_cpu",
"OnnxBackendNodeModelTest.test_sce_sum_log_prob_cpu",
"OnnxBackendNodeModelTest.test_sce_sum_log_prob_expanded_cpu"),
(xfail_issue_38712,
"OnnxBackendNodeModelTest.test_mod_mixed_sign_int16_cpu",
"OnnxBackendNodeModelTest.test_mod_uint8_cpu",

View File

@ -0,0 +1,64 @@
ir_version: 7
producer_name: "backend-test"
graph {
node {
input: "input"
input: "target"
output: "loss"
op_type: "NegativeLogLikelihoodLoss"
attribute {
name: "reduction"
s: "mean"
type: STRING
}
}
name: "test_nllloss_NCd1"
input {
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 5
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "target"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "loss"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -0,0 +1,58 @@
ir_version: 7
producer_name: "backend-test"
graph {
node {
input: "x"
input: "y"
output: "z"
op_type: "SoftmaxCrossEntropyLoss"
attribute {
name: "reduction"
s: "mean"
type: STRING
}
}
name: "test_sce_mean"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
input {
name: "y"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 3
}
}
}
}
}
output {
name: "z"
type {
tensor_type {
elem_type: 1
shape {
}
}
}
}
}
opset_import {
version: 13
}

View File

@ -3982,3 +3982,50 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_multiple_slices_last_layer)
test_case.add_expected_output<float>(Shape{1, 320, 320, 9}, o2);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_softmax_crossentropy_loss_mean)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_crossentropy_loss_mean.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({0.54881352186203,
0.7151893377304077,
0.6027633547782898,
0.5448831915855408,
0.42365479469299316,
0.6458941102027893,
0.4375872015953064,
0.891772985458374,
0.9636627435684204,
0.3834415078163147,
0.7917250394821167,
0.5288949012756348,
0.5680445432662964,
0.9255966544151306,
0.07103605568408966});
test_case.add_input<int64_t>({1, 4, 3});
test_case.add_expected_output<float>(Shape{}, {1.561384797096252441});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_negativelog_likelihood_loss)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/negativelog_likelihood_loss.prototxt"));
auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>({
0.54881352186203, 0.7151893377304077, 0.6027633547782898, 0.5448831915855408,
0.42365479469299316, 0.6458941102027893, 0.4375872015953064, 0.891772985458374,
0.9636627435684204, 0.3834415078163147, 0.7917250394821167, 0.5288949012756348,
0.5680445432662964, 0.9255966544151306, 0.07103605568408966, 0.08712930232286453,
0.020218396559357643, 0.832619845867157, 0.7781567573547363, 0.8700121641159058,
0.978618323802948, 0.7991585731506348, 0.4614793658256531, 0.7805292010307312,
0.11827442795038223, 0.6399210095405579, 0.14335328340530396, 0.9446688890457153,
0.5218483209609985, 0.4146619439125061,
});
test_case.add_input<int64_t>({3, 3, 2, 4, 2, 0});
test_case.add_expected_output<float>(Shape{}, {-0.531306922435760498});
test_case.run();
}