Test nms fix (#19156)

This commit is contained in:
Katarzyna Mitrus 2023-08-16 10:32:04 +02:00 committed by GitHub
parent 89956b65e3
commit 014691de4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 2 deletions

View File

@ -198,6 +198,11 @@ std::vector<TRShape> shape_infer(const Node* op,
const auto min_selected_boxes =
std::min(num_boxes.get_length(), static_cast<V>(max_out_boxes_per_class_val));
selected_boxes = static_output ? TDim{min_selected_boxes} : TDim{0, min_selected_boxes};
} else if (scores_rank.is_static() && num_boxes.get_max_length() != -1 &&
scores_shape[0].get_max_length() != -1 && scores_shape[1].get_max_length() != -1) {
const auto min_selected_boxes =
std::min(num_boxes.get_max_length(), static_cast<V>(max_out_boxes_per_class_val));
selected_boxes = static_output ? TDim{min_selected_boxes} : TDim{0, min_selected_boxes};
}
if (scores_rank.is_static()) {

View File

@ -361,8 +361,8 @@ TYPED_TEST_P(NMSDynamicOutputTest, interval_shapes_labels) {
Property("Scores type", &Output<Node>::get_element_type, element::f32),
Property("Outputs type", &Output<Node>::get_element_type, element::i64)));
EXPECT_THAT(op->outputs(),
ElementsAre(Property("Indicies shape", &Output<Node>::get_partial_shape, PartialShape({-1, 3})),
Property("Scores shape", &Output<Node>::get_partial_shape, PartialShape({-1, 3})),
ElementsAre(Property("Indicies shape", &Output<Node>::get_partial_shape, PartialShape({{0, 70}, 3})),
Property("Scores shape", &Output<Node>::get_partial_shape, PartialShape({{0, 70}, 3})),
Property("Outputs shape", &Output<Node>::get_partial_shape, PartialShape({1}))));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), Each(no_label));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(1)), Each(no_label));