Revise topk (#3819)
* Add visit_attribute and node validation check * add type_prop test for default values * style-apply * Update node validation check for index_element_type * Update type_prop test for default index_element_type * Add index_element_type attribute to TopK_1 spec
This commit is contained in:
parent
d3f0242f58
commit
96b2ffa9ab
@ -32,6 +32,14 @@
|
||||
* **Default value**: None
|
||||
* **Required**: *yes*
|
||||
|
||||
* *index_element_type*
|
||||
|
||||
* **Description**: the type of output tensor with indices
|
||||
* **Range of values**: "i64" or "i32"
|
||||
* **Type**: string
|
||||
* **Default value**: "i32"
|
||||
* **Required**: *No*
|
||||
|
||||
**Inputs**:
|
||||
|
||||
* **1**: Arbitrary tensor. Required.
|
||||
|
@ -259,6 +259,7 @@ bool ngraph::op::v1::TopK::visit_attributes(AttributeVisitor& visitor)
|
||||
visitor.on_attribute("axis", m_axis);
|
||||
visitor.on_attribute("mode", m_mode);
|
||||
visitor.on_attribute("sort", m_sort);
|
||||
visitor.on_attribute("index_element_type", m_index_element_type);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -276,6 +277,12 @@ void op::v1::TopK::validate_and_infer_types()
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, k_partial_shape.rank().compatible(0), "The 'K' input must be a scalar.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
m_index_element_type == element::i32 ||
|
||||
m_index_element_type == element::i64,
|
||||
"Index element type attribute should be either \'i32\' or \'i64\'. Got: ",
|
||||
m_index_element_type);
|
||||
|
||||
size_t k = 0;
|
||||
if (op::is_constant(input_value(1).get_node()))
|
||||
{
|
||||
|
@ -43,6 +43,17 @@ TYPED_TEST_P(topk_type_prop, topk_negative_axis_support)
|
||||
ASSERT_EQ(topk->get_output_shape(1), expect_shape);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(topk_type_prop, topk_default_index_element_type)
|
||||
{
|
||||
const auto data_shape = Shape{1, 2, 3, 4};
|
||||
const auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
const auto k = op::Constant::create(element::i64, Shape{}, {2});
|
||||
const int64_t axis = -2;
|
||||
|
||||
const auto op = make_shared<op::v1::TopK>(data, k, axis, "max", "value");
|
||||
ASSERT_EQ(op->get_index_element_type(), element::i32);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(topk_type_prop, topk_negative_axis_dynamic_rank)
|
||||
{
|
||||
const auto data_shape = PartialShape::dynamic();
|
||||
@ -109,7 +120,8 @@ REGISTER_TYPED_TEST_CASE_P(topk_type_prop,
|
||||
topk_negative_axis_support,
|
||||
topk_negative_axis_dynamic_rank,
|
||||
topk_v1_partial_ouptut,
|
||||
topk_rank_static_k_unknown);
|
||||
topk_rank_static_k_unknown,
|
||||
topk_default_index_element_type);
|
||||
|
||||
typedef ::testing::Types<op::v1::TopK, op::v3::TopK> TopKTypes;
|
||||
INSTANTIATE_TYPED_TEST_CASE_P(type_prop, topk_type_prop, TopKTypes, );
|
||||
|
Loading…
Reference in New Issue
Block a user