[ONNX] ArgMin/ArgMax support for select_last_index (#5661)
* Add Reverse Op to opset
* Worksave with Reverse path
* Add last_index support
* refactor argminmax factory
* Remove old xfail, add new one
* Fix proto file for argmax
* Rewrite test for select_last_index
* Add CPU tests to Manifest
* Update manifest
* Remove Reverse from opset7
* Refactor arg_min_max factory
* Added example comment in arg_min_max
* Codestyle changes
parent 39c08e40f6
commit c3ca8d048e
@@ -26,13 +26,6 @@ namespace ngraph
        {
            OutputVector argmax(const Node& node)
            {
                const auto select_last_index =
                    node.get_attribute_value<std::int64_t>("select_last_index", 0);
                CHECK_VALID_NODE(node,
                                 select_last_index == 0,
                                 "Mode 'select_last_index=1' is not supported by current "
                                 "implementation of ArgMax");

                const utils::ArgMinMaxFactory arg_factory(node);
                return {arg_factory.make_arg_max()};
            }

@@ -26,13 +26,6 @@ namespace ngraph
        {
            OutputVector argmin(const Node& node)
            {
                const auto select_last_index =
                    node.get_attribute_value<std::int64_t>("select_last_index", 0);
                CHECK_VALID_NODE(node,
                                 select_last_index == 0,
                                 "Mode 'select_last_index=1' is not supported by current "
                                 "implementation of ArgMin");

                const utils::ArgMinMaxFactory arg_factory(node);
                return {arg_factory.make_arg_min()};
            }

@@ -4,6 +4,7 @@

#include "utils/arg_min_max_factory.hpp"
#include "default_opset.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/validation_util.hpp"

namespace ngraph
@@ -14,9 +15,11 @@ namespace ngraph
    {
        ArgMinMaxFactory::ArgMinMaxFactory(const Node& node)
            : m_keep_dims{node.get_attribute_value<std::int64_t>("keepdims", 1)}
            , m_input_node{node.get_ng_inputs().at(0)}
            , m_axis{node.get_attribute_value<std::int64_t>("axis", 0)}
            , m_select_last_index{
                  node.get_attribute_value<std::int64_t>("select_last_index", 0)}
        {
            m_input_node = node.get_ng_inputs().at(0);
        }

        std::shared_ptr<ngraph::Node> ArgMinMaxFactory::make_arg_max() const
@@ -34,19 +37,80 @@ namespace ngraph
        {
            const auto k_node =
                default_opset::Constant::create(ngraph::element::i64, Shape{}, {1});

            if (m_select_last_index == 1)
            {
                // Example (ArgMin):
                // The goal is to get the index of the last occurrence of the
                // minimum value present in the given input tensor.
                //
                // Input:           [1, 2, 1, 3, 4, 4]
                // Expected output: [2]
                //
                // TopK always returns the "left-most" result. The trick is to
                // reverse the input to find the "right-most" occurrence, which is
                // equal to the last occurrence in the original input.
                // reverse = [4, 4, 3, 1, 2, 1]
                //
                // Run TopK on the reversed tensor; in the example the index output
                // is equal to:
                // topk->output(1) = 3
                //
                // Using ShapeOf and Gather on the input, obtain the length of the
                // input tensor along the axis; in the example this is equal to:
                // dims_on_axis = 6
                //
                // Now, using two Subtract ops, calculate the resulting index:
                // res_index = dims_on_axis - topk->output(1) = 6 - 3 = 3
                // result = res_index - 1 = 3 - 1 = 2

                const auto axis_node =
                    default_opset::Constant::create(ngraph::element::i64, Shape{1}, {m_axis});
                const auto reverse = std::make_shared<opset1::Reverse>(
                    m_input_node, axis_node, opset1::Reverse::Mode::INDEX);

                const auto topk = std::make_shared<default_opset::TopK>(
                    reverse, k_node, m_axis, mode, default_opset::TopK::SortType::NONE);

                const auto data_shape = std::make_shared<default_opset::ShapeOf>(m_input_node);
                const auto dims_on_axis = std::make_shared<default_opset::Gather>(
                    data_shape,
                    axis_node,
                    default_opset::Constant::create(ngraph::element::i64, Shape{}, {0}));

                const auto res_index = std::make_shared<default_opset::Subtract>(
                    dims_on_axis,
                    std::make_shared<default_opset::Convert>(topk->output(1), element::i64));
                const auto result = std::make_shared<default_opset::Subtract>(
                    res_index,
                    default_opset::Constant::create(ngraph::element::i64, Shape{1}, {1}));

                if (m_keep_dims == 0)
                {
                    const auto axis_to_remove = default_opset::Constant::create(
                        element::u64, Shape{}, {topk->get_axis()});

                    return std::make_shared<default_opset::Squeeze>(result, axis_to_remove);
                }

                return result;
            }

            const auto topk = std::make_shared<default_opset::TopK>(
                m_input_node, k_node, m_axis, mode, default_opset::TopK::SortType::NONE);

            const auto result =
                std::make_shared<default_opset::Convert>(topk->output(1), element::i64);

            if (m_keep_dims == 0)
            {
                const auto axis_to_remove =
                    default_opset::Constant::create(element::u64, Shape{}, {topk->get_axis()});
                const auto reshaped_indices =
                    std::make_shared<default_opset::Squeeze>(topk->output(1), axis_to_remove);

                return std::make_shared<default_opset::Convert>(reshaped_indices, element::i64);
                return std::make_shared<default_opset::Squeeze>(result, axis_to_remove);
            }
            return std::make_shared<default_opset::Convert>(topk->output(1), element::i64);

            return result;
        }
    } // namespace utils
} // namespace onnx_import
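The comment block in make_arg_topk above describes the trick in terms of the nGraph nodes the factory emits (Reverse, TopK, ShapeOf, Gather, Subtract). As a minimal standalone sketch of the same index arithmetic, the snippet below reproduces the 1-D example from that comment, with std::min_element standing in for TopK with k=1; the helper name last_occurrence_argmin is hypothetical and is not part of this change.

// Standalone sketch of the "reverse, take the left-most hit, re-map the index"
// arithmetic used when select_last_index == 1. Assumes a 1-D input; the nGraph
// graph applies the same computation along the requested axis.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: index of the LAST occurrence of the minimum value.
std::int64_t last_occurrence_argmin(const std::vector<float>& data)
{
    // reverse = [4, 4, 3, 1, 2, 1] for the input used in the code comment
    std::vector<float> reversed(data.rbegin(), data.rend());

    // std::min_element, like TopK, returns the left-most occurrence.
    const auto it = std::min_element(reversed.begin(), reversed.end());
    const auto reversed_index =
        static_cast<std::int64_t>(std::distance(reversed.begin(), it)); // 3

    const auto dims_on_axis = static_cast<std::int64_t>(data.size());   // 6

    // res_index = dims_on_axis - reversed_index = 6 - 3 = 3
    // result    = res_index - 1                 = 3 - 1 = 2
    return dims_on_axis - reversed_index - 1;
}

int main()
{
    const std::vector<float> input{1, 2, 1, 3, 4, 4};
    std::cout << last_occurrence_argmin(input) << "\n"; // prints 2
    return 0;
}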
@@ -39,6 +39,7 @@ namespace ngraph
            const std::int64_t m_keep_dims;
            Output<ngraph::Node> m_input_node;
            std::int64_t m_axis;
            std::int64_t m_select_last_index;
        };

    } // namespace utils

@@ -111,12 +111,12 @@ xfail_issue_44970 = xfail_test(reason="Assertion error")
xfail_issue_44976 = xfail_test(reason="E RuntimeError: Quantize layer with name:"
                                      "FakeQuantize_xxx has non const input on 1 port")
xfail_issue_46762 = xfail_test(reason="Incorrect result of Minimum op if uint data type is used")
xfail_issue_46765 = xfail_test(reason="select_last_index attribute is not supported by ArgMin and ArgMax")
xfail_issue_47323 = xfail_test(reason="RuntimeError: The plugin does not support FP64")
xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot")
xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output")
xfail_issue_51993 = xfail_test(reason="PRelu supports only 1D tensor for 'slope' input broadcasted"
                                      "by channel")
xfail_issue_55760 = xfail_test(reason="RuntimeError: Reversed axis have axes above the source space shape")

# Model MSFT issues:
xfail_issue_37957 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
@@ -50,7 +50,6 @@ from tests import (BACKEND_NAME,
                   xfail_issue_45180,
                   xfail_issue_45344,
                   xfail_issue_46762,
                   xfail_issue_46765,
                   xfail_issue_47323,
                   xfail_issue_47337,
                   xfail_issue_48052,
@@ -60,7 +59,8 @@ from tests import (BACKEND_NAME,
                   xfail_issue_49753,
                   xfail_issue_49754,
                   xfail_issue_52463,
                   xfail_issue_51993)
                   xfail_issue_51993,
                   xfail_issue_55760)


def expect_fail(test_case_path, xfail):  # type: (str) -> None
@@ -165,23 +165,11 @@ tests_expected_to_fail = [
     "OnnxBackendNodeModelTest.test_min_uint16_cpu",
     "OnnxBackendNodeModelTest.test_min_uint32_cpu",
     "OnnxBackendNodeModelTest.test_min_uint64_cpu"),
    (xfail_issue_46765,
    (xfail_issue_55760,
     "OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_example_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmax_keepdims_example_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmax_no_keepdims_example_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_example_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmin_keepdims_example_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmin_no_keepdims_example_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmax_default_axis_example_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmax_default_axis_random_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmax_keepdims_random_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_random_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmax_no_keepdims_random_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmin_default_axis_example_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmin_default_axis_random_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmin_keepdims_random_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu",
     "OnnxBackendNodeModelTest.test_argmin_no_keepdims_random_select_last_index_cpu"),
     "OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu"),
    (xfail_issue_38091,
     "OnnxBackendNodeModelTest.test_gather_negative_indices_cpu"),
    (xfail_issue_52463,
@@ -50,6 +50,9 @@ graph {
        dim {
          dim_value: 3
        }
        dim {
          dim_value: 1
        }
      }
    }
  }
@@ -2503,45 +2503,24 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_float)

NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmax_select_last_index)
{
    try
    {
        auto function = onnx_import::import_onnx_model(
            file_util::path_join(SERIALIZED_ZOO, "onnx/argmax_select_last_index.prototxt"));
        FAIL() << "Expected exception was not thrown";
    }
    catch (const ngraph::ngraph_error& e)
    {
        EXPECT_HAS_SUBSTRING(
            e.what(),
            std::string(
                "Mode 'select_last_index=1' is not supported by current implementation of ArgMax"));
    }
    catch (...)
    {
        FAIL() << "Expected OnnxNodeValidationFailure exception was not thrown";
    }

    auto test_case = test::TestCase<TestEngine>(function);
    test_case.add_input<float>(Shape{4, 3}, {1, 1, 1, 0.5, 3, 4, 0.5, 1, 1.1, 0, 3, 0});
    test_case.add_expected_output<std::int64_t>(Shape{1, 3}, {0, 3, 1});
    test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_select_last_index)
{
    try
    {
        auto function = onnx_import::import_onnx_model(
            file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_select_last_index.prototxt"));
        FAIL() << "Expected exception was not thrown";
    }
    catch (const ngraph::ngraph_error& e)
    {
        EXPECT_HAS_SUBSTRING(
            e.what(),
            std::string(
                "Mode 'select_last_index=1' is not supported by current implementation of ArgMin"));
        std::string what{e.what()};
    }
    catch (...)
    {
        FAIL() << "Expected OnnxNodeValidationFailure exception was not thrown";
    }

    auto test_case = test::TestCase<TestEngine>(function);
    test_case.add_input<float>(Shape{4, 3}, {1, 1, 1, 2, 3, 4, 2, 1, 1.1, 3, 3, 8});
    test_case.add_expected_output<std::int64_t>(Shape{4}, {2, 0, 1, 1});
    test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, onnx_model_top_k)
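As a quick sanity check on the expected outputs used by the two backend tests above (and not part of the commit), the sketch below recomputes the {0, 3, 1} ArgMax result with a plain last-occurrence scan over axis 0 of the 4x3 input; the analogous per-row scan with "<=" reproduces the {2, 0, 1, 1} ArgMin expectation. The helper name argmax_axis0_last is hypothetical.

// Reference computation for the last-occurrence ArgMax over a row-major matrix.
// Values and shapes are copied from onnx_model_argmax_select_last_index.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Last occurrence of the maximum in each column (axis 0) of a rows x cols matrix.
std::vector<std::int64_t> argmax_axis0_last(const std::vector<float>& m,
                                            std::size_t rows,
                                            std::size_t cols)
{
    std::vector<std::int64_t> out(cols, 0);
    for (std::size_t c = 0; c < cols; ++c)
    {
        std::size_t best = 0;
        for (std::size_t r = 1; r < rows; ++r)
        {
            // ">=" (rather than ">") keeps the LAST row holding the maximum.
            if (m[r * cols + c] >= m[best * cols + c])
                best = r;
        }
        out[c] = static_cast<std::int64_t>(best);
    }
    return out;
}

int main()
{
    // Input of onnx_model_argmax_select_last_index, Shape{4, 3}.
    const std::vector<float> data{1, 1, 1, 0.5, 3, 4, 0.5, 1, 1.1, 0, 3, 0};
    for (auto i : argmax_axis0_last(data, 4, 3))
        std::cout << i << ' '; // prints: 0 3 1
    std::cout << '\n';
    return 0;
}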
@@ -131,6 +131,8 @@ arg_max_dyn_shape

# Result mismatch
onnx_model_argmax_float
onnx_model_argmin_float
onnx_model_argmax_select_last_index
onnx_model_argmin_select_last_index

# Constant has zero dimension that is not allowable
onnx_dyn_shapes_transpose