Test nms fix (#19156)
This commit is contained in:
parent
89956b65e3
commit
014691de4c
@ -198,6 +198,11 @@ std::vector<TRShape> shape_infer(const Node* op,
|
||||
const auto min_selected_boxes =
|
||||
std::min(num_boxes.get_length(), static_cast<V>(max_out_boxes_per_class_val));
|
||||
selected_boxes = static_output ? TDim{min_selected_boxes} : TDim{0, min_selected_boxes};
|
||||
} else if (scores_rank.is_static() && num_boxes.get_max_length() != -1 &&
|
||||
scores_shape[0].get_max_length() != -1 && scores_shape[1].get_max_length() != -1) {
|
||||
const auto min_selected_boxes =
|
||||
std::min(num_boxes.get_max_length(), static_cast<V>(max_out_boxes_per_class_val));
|
||||
selected_boxes = static_output ? TDim{min_selected_boxes} : TDim{0, min_selected_boxes};
|
||||
}
|
||||
|
||||
if (scores_rank.is_static()) {
|
||||
|
@ -361,8 +361,8 @@ TYPED_TEST_P(NMSDynamicOutputTest, interval_shapes_labels) {
|
||||
Property("Scores type", &Output<Node>::get_element_type, element::f32),
|
||||
Property("Outputs type", &Output<Node>::get_element_type, element::i64)));
|
||||
EXPECT_THAT(op->outputs(),
|
||||
ElementsAre(Property("Indicies shape", &Output<Node>::get_partial_shape, PartialShape({-1, 3})),
|
||||
Property("Scores shape", &Output<Node>::get_partial_shape, PartialShape({-1, 3})),
|
||||
ElementsAre(Property("Indicies shape", &Output<Node>::get_partial_shape, PartialShape({{0, 70}, 3})),
|
||||
Property("Scores shape", &Output<Node>::get_partial_shape, PartialShape({{0, 70}, 3})),
|
||||
Property("Outputs shape", &Output<Node>::get_partial_shape, PartialShape({1}))));
|
||||
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), Each(no_label));
|
||||
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(1)), Each(no_label));
|
||||
|
Loading…
Reference in New Issue
Block a user