Axis normalizing in ArgMin/ArgMax operations (#8589)

This commit is contained in:
Artur Kulikowski 2021-11-15 16:10:24 +01:00 committed by GitHub
parent e279ec5962
commit 2e9f83d705
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 9 additions and 21 deletions

View File

@ -53,11 +53,17 @@ std::shared_ptr<ngraph::Node> ArgMinMaxFactory::make_topk_subgraph(default_opset
// res_index = dims_on_axis - topk->output(1) = 6 - 3 = 3
// result = res_index - 1 = 3 - 1 = 2
const auto axis_node = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {m_axis});
const int64_t normalized_axis =
normalize_axis(m_input_node.get_node(), m_axis, m_input_node.get_partial_shape().rank());
const auto axis_node = default_opset::Constant::create(ngraph::element::i64, Shape{1}, {normalized_axis});
const auto reverse = std::make_shared<opset1::Reverse>(m_input_node, axis_node, opset1::Reverse::Mode::INDEX);
const auto topk =
std::make_shared<default_opset::TopK>(reverse, k_node, m_axis, mode, default_opset::TopK::SortType::NONE);
const auto topk = std::make_shared<default_opset::TopK>(reverse,
k_node,
normalized_axis,
mode,
default_opset::TopK::SortType::NONE);
const auto data_shape = std::make_shared<default_opset::ShapeOf>(m_input_node);
const auto dims_on_axis = std::make_shared<default_opset::Gather>(

View File

@ -93,7 +93,6 @@ xfail_issue_44965 = xfail_test(reason="Expected: RuntimeError: value info has no
xfail_issue_44968 = xfail_test(reason="Expected: Unsupported dynamic op: Squeeze")
xfail_issue_47323 = xfail_test(reason="RuntimeError: The plugin does not support FP64")
xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot")
xfail_issue_55760 = xfail_test(reason="RuntimeError: Reversed axis have axes above the source space shape")
# Model MSFT issues:
xfail_issue_37957 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: "

View File

@ -40,7 +40,6 @@ from tests import (
xfail_issue_49207,
xfail_issue_49750,
xfail_issue_52463,
xfail_issue_55760,
xfail_issue_58033,
xfail_issue_63033,
xfail_issue_63036,
@ -144,13 +143,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_scatter_elements_with_negative_indices_cpu",
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu",
),
(
xfail_issue_55760,
"OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu",
),
(
xfail_issue_38091,
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu",

View File

@ -98,7 +98,6 @@ xfail_issue_44965 = xfail_test(reason="Expected: RuntimeError: value info has no
xfail_issue_44968 = xfail_test(reason="Expected: Unsupported dynamic op: Squeeze")
xfail_issue_47323 = xfail_test(reason="RuntimeError: The plugin does not support FP64")
xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot")
xfail_issue_55760 = xfail_test(reason="RuntimeError: Reversed axis have axes above the source space shape")
# Model MSFT issues:
xfail_issue_37957 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations: "

View File

@ -39,7 +39,6 @@ from tests_compatibility import (
xfail_issue_49207,
xfail_issue_49750,
xfail_issue_52463,
xfail_issue_55760,
xfail_issue_58033,
xfail_issue_63033,
xfail_issue_63036,
@ -129,13 +128,6 @@ tests_expected_to_fail = [
"OnnxBackendNodeModelTest.test_scatter_elements_with_negative_indices_cpu",
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu",
),
(
xfail_issue_55760,
"OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu",
),
(
xfail_issue_38091,
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu",