nGraph: fix TopK output shape inference (#2967)
* nGraph: Fix TopK output shape inference * nGraph: Correct TopK output shape inference TopK lower bound of output shape at the axis was mistakenly calculated basing on max_lenght instead of min_lenght. * nGraph: Correct TopK output shape inference * nGraph: Correct TopK type prop test The topk_negative_axis_support type properties test was comparing incompatible variables carrying the same value. So it was passing ok. * nGraph: Add TopK type prop test * nGraph: Fix code style * nGraph: Follow review guidelines Improve variables meaning. Enforce rigid test pass condition. * nGraph: Remove magic numbers
This commit is contained in:
parent
dc2ac0fb9e
commit
2966910dac
@ -242,7 +242,13 @@ void op::v1::TopK::validate_and_infer_types()
|
||||
auto max_k = maximum_value(input_value(1));
|
||||
if (max_k.first)
|
||||
{
|
||||
output_shape[m_normalized_axis] &= Dimension(0, max_k.second);
|
||||
const auto in_min = output_shape[m_normalized_axis].get_min_length();
|
||||
const auto in_max = output_shape[m_normalized_axis].get_max_length();
|
||||
const auto lower = std::min<Dimension::value_type>(in_min, max_k.second);
|
||||
const auto upper = in_max < 0
|
||||
? Dimension::dynamic().get_max_length()
|
||||
: std::max<Dimension::value_type>(in_max, max_k.second);
|
||||
output_shape[m_normalized_axis] = Dimension(lower, upper);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -38,7 +38,9 @@ TYPED_TEST_P(topk_type_prop, topk_negative_axis_support)
|
||||
const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
|
||||
|
||||
ASSERT_EQ(topk->get_provided_axis(), axis);
|
||||
ASSERT_EQ(topk->get_axis(), data_shape.at(1));
|
||||
const auto expect_shape = Shape{1, 2, 2, 4};
|
||||
ASSERT_EQ(topk->get_output_shape(0), expect_shape);
|
||||
ASSERT_EQ(topk->get_output_shape(1), expect_shape);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(topk_type_prop, topk_negative_axis_dynamic_rank)
|
||||
@ -75,14 +77,39 @@ TYPED_TEST_P(topk_type_prop, topk_v1_partial_ouptut)
|
||||
{
|
||||
auto k = make_shared<op::Constant>(element::i32, Shape{}, 3);
|
||||
auto topk = make_shared<TypeParam>(data, k, 1, "max", "value");
|
||||
EXPECT_EQ(topk->get_output_shape(0), Shape({2, 3}));
|
||||
EXPECT_EQ(topk->get_output_partial_shape(0), PartialShape({2, 3}));
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST_P(topk_type_prop, topk_rank_static_k_unknown)
|
||||
{
|
||||
const int64_t axis = 1;
|
||||
const auto data_shape = Shape{1, 10, 100};
|
||||
const auto data = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
|
||||
{
|
||||
const auto k = make_shared<op::Parameter>(element::i32, PartialShape({}));
|
||||
const auto topk = make_shared<TypeParam>(data, k, axis, "max", "value");
|
||||
|
||||
const PartialShape fully_dynamic_axis_shape{1, Dimension::dynamic(), 100};
|
||||
EXPECT_EQ(topk->get_output_partial_shape(0), fully_dynamic_axis_shape);
|
||||
}
|
||||
{
|
||||
const auto k = make_shared<op::v0::Constant>(element::i64, Shape{}, 5);
|
||||
const auto convert_k = make_shared<op::v0::Convert>(k, element::i32);
|
||||
const auto topk = make_shared<TypeParam>(data, convert_k, axis, "max", "value");
|
||||
|
||||
const PartialShape ranged_dynamic_axis_shape{1, Dimension{5, 10}, 100};
|
||||
EXPECT_EQ(topk->get_output_partial_shape(0), ranged_dynamic_axis_shape);
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_TYPED_TEST_CASE_P(topk_type_prop,
|
||||
topk_negative_axis_support,
|
||||
topk_negative_axis_dynamic_rank,
|
||||
topk_v1_partial_ouptut);
|
||||
topk_v1_partial_ouptut,
|
||||
topk_rank_static_k_unknown);
|
||||
|
||||
typedef ::testing::Types<op::v1::TopK, op::v3::TopK> TopKTypes;
|
||||
INSTANTIATE_TYPED_TEST_CASE_P(type_prop, topk_type_prop, TopKTypes, );
|
||||
|
Loading…
Reference in New Issue
Block a user