nGraph: fix TopK output shape inference (#2967)

* nGraph: Fix TopK output shape inference

* nGraph: Correct TopK output shape inference

TopK lower bound of output shape at the axis was mistakenly calculated
basing on max_lenght instead of min_lenght.

* nGraph: Correct TopK output shape inference

* nGraph: Correct TopK type prop test

The topk_negative_axis_support type properties test was comparing
incompatible variables carrying the same value. So it was passing ok.

* nGraph: Add TopK type prop test

* nGraph: Fix code style

* nGraph: Follow review guidelines

Improve variables meaning.
Enforce rigid test pass condition.

* nGraph: Remove magic numbers
This commit is contained in:
Tomasz Jankowski 2020-11-13 15:10:42 +01:00 committed by GitHub
parent dc2ac0fb9e
commit 2966910dac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 3 deletions

View File

@ -242,7 +242,13 @@ void op::v1::TopK::validate_and_infer_types()
auto max_k = maximum_value(input_value(1));
if (max_k.first)
{
output_shape[m_normalized_axis] &= Dimension(0, max_k.second);
const auto in_min = output_shape[m_normalized_axis].get_min_length();
const auto in_max = output_shape[m_normalized_axis].get_max_length();
const auto lower = std::min<Dimension::value_type>(in_min, max_k.second);
const auto upper = in_max < 0
? Dimension::dynamic().get_max_length()
: std::max<Dimension::value_type>(in_max, max_k.second);
output_shape[m_normalized_axis] = Dimension(lower, upper);
}
else
{

View File

@ -38,7 +38,9 @@ TYPED_TEST_P(topk_type_prop, topk_negative_axis_support)
const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
ASSERT_EQ(topk->get_provided_axis(), axis);
ASSERT_EQ(topk->get_axis(), data_shape.at(1));
const auto expect_shape = Shape{1, 2, 2, 4};
ASSERT_EQ(topk->get_output_shape(0), expect_shape);
ASSERT_EQ(topk->get_output_shape(1), expect_shape);
}
TYPED_TEST_P(topk_type_prop, topk_negative_axis_dynamic_rank)
@ -75,14 +77,39 @@ TYPED_TEST_P(topk_type_prop, topk_v1_partial_ouptut)
{
auto k = make_shared<op::Constant>(element::i32, Shape{}, 3);
auto topk = make_shared<TypeParam>(data, k, 1, "max", "value");
EXPECT_EQ(topk->get_output_shape(0), Shape({2, 3}));
EXPECT_EQ(topk->get_output_partial_shape(0), PartialShape({2, 3}));
}
}
TYPED_TEST_P(topk_type_prop, topk_rank_static_k_unknown)
{
const int64_t axis = 1;
const auto data_shape = Shape{1, 10, 100};
const auto data = make_shared<op::Parameter>(element::f32, data_shape);
{
const auto k = make_shared<op::Parameter>(element::i32, PartialShape({}));
const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
const PartialShape fully_dynamic_axis_shape{1, Dimension::dynamic(), 100};
EXPECT_EQ(topk->get_output_partial_shape(0), fully_dynamic_axis_shape);
}
{
const auto k = make_shared<op::v0::Constant>(element::i64, Shape{}, 5);
const auto convert_k = make_shared<op::v0::Convert>(k, element::i32);
const auto topk = make_shared<TypeParam>(data, convert_k, axis, "max", "value");
const PartialShape ranged_dynamic_axis_shape{1, Dimension{5, 10}, 100};
EXPECT_EQ(topk->get_output_partial_shape(0), ranged_dynamic_axis_shape);
}
}
REGISTER_TYPED_TEST_CASE_P(topk_type_prop,
topk_negative_axis_support,
topk_negative_axis_dynamic_rank,
topk_v1_partial_ouptut);
topk_v1_partial_ouptut,
topk_rank_static_k_unknown);
typedef ::testing::Types<op::v1::TopK, op::v3::TopK> TopKTypes;
INSTANTIATE_TYPED_TEST_CASE_P(type_prop, topk_type_prop, TopKTypes, );