ONNX Slice operator support types dynamically (#4507)
This commit is contained in:
parent
f88f81c6ba
commit
b0043bb599
@ -139,15 +139,16 @@ namespace ngraph
|
||||
// expected_output_shape: {3, 3, 1, 1}
|
||||
OutputVector adjusted_indices(slice_indices_length);
|
||||
std::vector<uint64_t> target_axes(axes);
|
||||
const auto gather_axis = default_opset::Constant::create(element::i64, {}, {0});
|
||||
const auto gather_axis =
|
||||
default_opset::Constant::create(indices.get_element_type(), {}, {0});
|
||||
|
||||
int added_indices_number = 0;
|
||||
for (int i = 0; i < slice_indices_length; ++i)
|
||||
{
|
||||
if (std::find(std::begin(axes), std::end(axes), i) == axes.end())
|
||||
{
|
||||
adjusted_indices[i] =
|
||||
default_opset::Constant::create(element::i64, {1}, {fill_in_value});
|
||||
adjusted_indices[i] = default_opset::Constant::create(
|
||||
indices.get_element_type(), {1}, {fill_in_value});
|
||||
target_axes.insert(std::next(target_axes.begin(), i), i);
|
||||
++added_indices_number;
|
||||
}
|
||||
@ -156,7 +157,7 @@ namespace ngraph
|
||||
adjusted_indices[i] = std::make_shared<default_opset::Gather>(
|
||||
indices,
|
||||
default_opset::Constant::create(
|
||||
element::i64, {1}, {i - added_indices_number}),
|
||||
indices.get_element_type(), {1}, {i - added_indices_number}),
|
||||
gather_axis);
|
||||
}
|
||||
}
|
||||
|
@ -159,7 +159,6 @@ xfail_issue_47330 = xfail_test(reason="RuntimeError: Eltwise node with name `[na
|
||||
"FP64 precision.")
|
||||
xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot")
|
||||
xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output")
|
||||
xfail_issue_49113 = xfail_test(reason="NLL Loss error: While validating ONNX node '<Node(Slice):")
|
||||
xfail_issue_48098 = xfail_test(reason="ngraph.exceptions.UserInputError: ('Expected %s parameters, "
|
||||
"received %s.', <value1>, <value2>)")
|
||||
xfail_issue_48100 = xfail_test(reason="RuntimeError: cpu_convert can't convert from: "
|
||||
|
@ -75,7 +75,6 @@ from tests import (BACKEND_NAME,
|
||||
xfail_issue_47330,
|
||||
xfail_issue_47337,
|
||||
xfail_issue_48052,
|
||||
xfail_issue_49113,
|
||||
xfail_issue_48098,
|
||||
xfail_issue_48100,
|
||||
xfail_issue_49207,
|
||||
@ -295,111 +294,6 @@ tests_expected_to_fail = [
|
||||
"OnnxBackendNodeModelTest.test_resize_downsample_sizes_nearest_cpu"),
|
||||
(xfail_issue_33581,
|
||||
"OnnxBackendNodeModelTest.test_gather_elements_negative_indices_cpu"),
|
||||
(xfail_issue_49113,
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NC_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NC_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1_ii_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1_ii_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1_mean_weight_negative_ii_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1_mean_weight_negative_ii_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_ii_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_ii_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_no_weight_reduction_mean_ii_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_mean_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_mean_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_sum_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_sum_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_mean_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_mean_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_ii_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_none_no_weight_negative_ii_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_sum_weight_high_ii_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_mean_weight_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_mean_weight_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_none_no_weight_cpu",
|
||||
"OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1_mean_weight_negative_ii_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1_mean_weight_negative_ii_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1_mean_weight_negative_ii_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_none_no_weight_negative_ii_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_sum_weight_high_ii_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_sum_weight_high_ii_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_mean_weight_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_mean_weight_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_mean_weight_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_none_no_weight_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_none_no_weight_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_3d_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_3d_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_3d_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_3d_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_3d_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_3d_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_3d_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_3d_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_4d_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_4d_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_4d_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_4d_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_3d_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_3d_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_3d_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_3d_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_4d_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_4d_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_4d_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_4d_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_ii_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_mean_weight_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_none_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_none_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_none_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_none_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_none_weights_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_none_weights_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_none_weights_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_none_weights_log_prob_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_sum_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_sum_expanded_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_sum_log_prob_cpu",
|
||||
"OnnxBackendNodeModelTest.test_sce_sum_log_prob_expanded_cpu"),
|
||||
(xfail_issue_38712,
|
||||
"OnnxBackendNodeModelTest.test_mod_mixed_sign_int16_cpu",
|
||||
"OnnxBackendNodeModelTest.test_mod_uint8_cpu",
|
||||
|
64
ngraph/test/models/onnx/negativelog_likelihood_loss.prototxt
Normal file
64
ngraph/test/models/onnx/negativelog_likelihood_loss.prototxt
Normal file
@ -0,0 +1,64 @@
|
||||
ir_version: 7
|
||||
producer_name: "backend-test"
|
||||
graph {
|
||||
node {
|
||||
input: "input"
|
||||
input: "target"
|
||||
output: "loss"
|
||||
op_type: "NegativeLogLikelihoodLoss"
|
||||
attribute {
|
||||
name: "reduction"
|
||||
s: "mean"
|
||||
type: STRING
|
||||
}
|
||||
}
|
||||
name: "test_nllloss_NCd1"
|
||||
input {
|
||||
name: "input"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "target"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "loss"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
@ -0,0 +1,58 @@
|
||||
ir_version: 7
|
||||
producer_name: "backend-test"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
input: "y"
|
||||
output: "z"
|
||||
op_type: "SoftmaxCrossEntropyLoss"
|
||||
attribute {
|
||||
name: "reduction"
|
||||
s: "mean"
|
||||
type: STRING
|
||||
}
|
||||
}
|
||||
name: "test_sce_mean"
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
input {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 7
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "z"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 13
|
||||
}
|
@ -3982,3 +3982,50 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_multiple_slices_last_layer)
|
||||
test_case.add_expected_output<float>(Shape{1, 320, 320, 9}, o2);
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_softmax_crossentropy_loss_mean)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/softmax_crossentropy_loss_mean.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input<float>({0.54881352186203,
|
||||
0.7151893377304077,
|
||||
0.6027633547782898,
|
||||
0.5448831915855408,
|
||||
0.42365479469299316,
|
||||
0.6458941102027893,
|
||||
0.4375872015953064,
|
||||
0.891772985458374,
|
||||
0.9636627435684204,
|
||||
0.3834415078163147,
|
||||
0.7917250394821167,
|
||||
0.5288949012756348,
|
||||
0.5680445432662964,
|
||||
0.9255966544151306,
|
||||
0.07103605568408966});
|
||||
test_case.add_input<int64_t>({1, 4, 3});
|
||||
test_case.add_expected_output<float>(Shape{}, {1.561384797096252441});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_negativelog_likelihood_loss)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/negativelog_likelihood_loss.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
test_case.add_input<float>({
|
||||
0.54881352186203, 0.7151893377304077, 0.6027633547782898, 0.5448831915855408,
|
||||
0.42365479469299316, 0.6458941102027893, 0.4375872015953064, 0.891772985458374,
|
||||
0.9636627435684204, 0.3834415078163147, 0.7917250394821167, 0.5288949012756348,
|
||||
0.5680445432662964, 0.9255966544151306, 0.07103605568408966, 0.08712930232286453,
|
||||
0.020218396559357643, 0.832619845867157, 0.7781567573547363, 0.8700121641159058,
|
||||
0.978618323802948, 0.7991585731506348, 0.4614793658256531, 0.7805292010307312,
|
||||
0.11827442795038223, 0.6399210095405579, 0.14335328340530396, 0.9446688890457153,
|
||||
0.5218483209609985, 0.4146619439125061,
|
||||
});
|
||||
test_case.add_input<int64_t>({3, 3, 2, 4, 2, 0});
|
||||
test_case.add_expected_output<float>(Shape{}, {-0.531306922435760498});
|
||||
test_case.run();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user