Revise topk (#3819)

* Add visit_attribute and node validation check

* add type_prop test for default values

* style-apply

* Update node validation check for index_element_type

* Update type_prop test for default index_element_type

* Add index_element_type attribute to TopK_1 spec
This commit is contained in:
Piotr Szmelczynski 2021-01-25 12:19:59 +01:00 committed by GitHub
parent d3f0242f58
commit 96b2ffa9ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 1 deletions

View File

@ -32,6 +32,14 @@
* **Default value**: None
* **Required**: *yes*
* *index_element_type*
* **Description**: the type of output tensor with indices
* **Range of values**: "i64" or "i32"
* **Type**: string
* **Default value**: "i32"
* **Required**: *No*
**Inputs**:
* **1**: Arbitrary tensor. Required.

View File

@ -259,6 +259,7 @@ bool ngraph::op::v1::TopK::visit_attributes(AttributeVisitor& visitor)
visitor.on_attribute("axis", m_axis);
visitor.on_attribute("mode", m_mode);
visitor.on_attribute("sort", m_sort);
visitor.on_attribute("index_element_type", m_index_element_type);
return true;
}
@ -276,6 +277,12 @@ void op::v1::TopK::validate_and_infer_types()
NODE_VALIDATION_CHECK(
this, k_partial_shape.rank().compatible(0), "The 'K' input must be a scalar.");
NODE_VALIDATION_CHECK(this,
m_index_element_type == element::i32 ||
m_index_element_type == element::i64,
"Index element type attribute should be either \'i32\' or \'i64\'. Got: ",
m_index_element_type);
size_t k = 0;
if (op::is_constant(input_value(1).get_node()))
{

View File

@ -43,6 +43,17 @@ TYPED_TEST_P(topk_type_prop, topk_negative_axis_support)
ASSERT_EQ(topk->get_output_shape(1), expect_shape);
}
TYPED_TEST_P(topk_type_prop, topk_default_index_element_type)
{
const auto data_shape = Shape{1, 2, 3, 4};
const auto data = make_shared<op::Parameter>(element::f32, data_shape);
const auto k = op::Constant::create(element::i64, Shape{}, {2});
const int64_t axis = -2;
const auto op = make_shared<op::v1::TopK>(data, k, axis, "max", "value");
ASSERT_EQ(op->get_index_element_type(), element::i32);
}
TYPED_TEST_P(topk_type_prop, topk_negative_axis_dynamic_rank)
{
const auto data_shape = PartialShape::dynamic();
@ -109,7 +120,8 @@ REGISTER_TYPED_TEST_CASE_P(topk_type_prop,
topk_negative_axis_support,
topk_negative_axis_dynamic_rank,
topk_v1_partial_ouptut,
topk_rank_static_k_unknown);
topk_rank_static_k_unknown,
topk_default_index_element_type);
typedef ::testing::Types<op::v1::TopK, op::v3::TopK> TopKTypes;
INSTANTIATE_TYPED_TEST_CASE_P(type_prop, topk_type_prop, TopKTypes, );