relax the class number check in paddle multiclass_nms op (#11857)

* relax the class number check in paddle multiclass_nms op

* relax checks in paddle multiclass_nms op
This commit is contained in:
mei, yang 2022-06-16 11:29:15 +08:00 committed by GitHub
parent 54fe2d1a3f
commit 39981bf2b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 8 deletions

View File

@ -446,6 +446,10 @@ void multiclass_nms(const float* boxes_data,
for (int64_t i = 0; i < num_images; i++) { for (int64_t i = 0; i < num_images; i++) {
std::vector<BoxInfo> selected_boxes; std::vector<BoxInfo> selected_boxes;
if (shared) { if (shared) {
OPENVINO_ASSERT(boxes_data_shape[0] == scores_data_shape[0],
"Expect batch size of boxes and scores are the same.");
OPENVINO_ASSERT(boxes_data_shape[1] == scores_data_shape[2],
"Expect box numbers of boxes and scores are the same.");
const auto num_boxes = boxes_data_shape[1]; const auto num_boxes = boxes_data_shape[1];
const auto num_classes = scores_data_shape[1]; const auto num_classes = scores_data_shape[1];
@ -462,6 +466,10 @@ void multiclass_nms(const float* boxes_data,
continue; continue;
} }
OPENVINO_ASSERT(boxes_data_shape[0] == scores_data_shape[0],
"Expect class numbers of boxes and scores are the same.");
OPENVINO_ASSERT(boxes_data_shape[1] == scores_data_shape[1],
"Expect box numbers of boxes and scores are the same.");
const auto num_classes = boxes_data_shape[0]; const auto num_classes = boxes_data_shape[0];
const auto boxes = slice_image(boxes_data, boxes_data_shape, head, roisnum_data[i]); const auto boxes = slice_image(boxes_data, boxes_data_shape, head, roisnum_data[i]);

View File

@ -81,7 +81,7 @@ void shape_infer(const ov::op::util::MulticlassNmsBase* op,
const auto num_batches_scores = scores_ps[0]; const auto num_batches_scores = scores_ps[0];
NODE_VALIDATION_CHECK(op, NODE_VALIDATION_CHECK(op,
num_batches_boxes.same_scheme(num_batches_scores), num_batches_boxes.compatible(num_batches_scores),
"The first dimension of both 'boxes' and 'scores' must match. Boxes: ", "The first dimension of both 'boxes' and 'scores' must match. Boxes: ",
num_batches_boxes, num_batches_boxes,
"; Scores: ", "; Scores: ",
@ -90,7 +90,7 @@ void shape_infer(const ov::op::util::MulticlassNmsBase* op,
const auto num_boxes_boxes = boxes_ps[1]; const auto num_boxes_boxes = boxes_ps[1];
const auto num_boxes_scores = scores_ps[2]; const auto num_boxes_scores = scores_ps[2];
NODE_VALIDATION_CHECK(op, NODE_VALIDATION_CHECK(op,
num_boxes_boxes.same_scheme(num_boxes_scores), num_boxes_boxes.compatible(num_boxes_scores),
"'boxes' and 'scores' input shapes must match at the second and third " "'boxes' and 'scores' input shapes must match at the second and third "
"dimension respectively. Boxes: ", "dimension respectively. Boxes: ",
num_boxes_boxes, num_boxes_boxes,
@ -106,7 +106,7 @@ void shape_infer(const ov::op::util::MulticlassNmsBase* op,
const auto num_classes_boxes = boxes_ps[0]; const auto num_classes_boxes = boxes_ps[0];
const auto num_classes_scores = scores_ps[0]; const auto num_classes_scores = scores_ps[0];
NODE_VALIDATION_CHECK(op, NODE_VALIDATION_CHECK(op,
num_classes_boxes.same_scheme(num_classes_scores), num_classes_boxes.compatible(num_classes_scores),
"'boxes' and 'scores' input shapes must match. Boxes: ", "'boxes' and 'scores' input shapes must match. Boxes: ",
num_classes_boxes, num_classes_boxes,
"; Scores: ", "; Scores: ",
@ -115,7 +115,7 @@ void shape_infer(const ov::op::util::MulticlassNmsBase* op,
const auto num_boxes_boxes = boxes_ps[1]; const auto num_boxes_boxes = boxes_ps[1];
const auto num_boxes_scores = scores_ps[1]; const auto num_boxes_scores = scores_ps[1];
NODE_VALIDATION_CHECK(op, NODE_VALIDATION_CHECK(op,
num_boxes_boxes.same_scheme(num_boxes_scores), num_boxes_boxes.compatible(num_boxes_scores),
"'boxes' and 'scores' input shapes must match. Boxes: ", "'boxes' and 'scores' input shapes must match. Boxes: ",
num_boxes_boxes, num_boxes_boxes,
"; Scores: ", "; Scores: ",

View File

@ -78,18 +78,20 @@ MultiClassNms::MultiClassNms(const std::shared_ptr<ov::Node>& op, const dnnl::en
// boxes [C, M, 4], scores [C, M], roisnum [N] opset9 // boxes [C, M, 4], scores [C, M], roisnum [N] opset9
const auto& boxes_dims = getInputShapeAtPort(NMS_BOXES).getDims(); const auto& boxes_dims = getInputShapeAtPort(NMS_BOXES).getDims();
const auto& scores_dims = getInputShapeAtPort(NMS_SCORES).getDims(); const auto& scores_dims = getInputShapeAtPort(NMS_SCORES).getDims();
auto boxes_ps = PartialShape(boxes_dims);
auto scores_ps = PartialShape(scores_dims);
if (boxes_dims.size() != 3) if (boxes_dims.size() != 3)
IE_THROW() << m_errorPrefix << "has unsupported 'boxes' input rank: " << boxes_dims.size(); IE_THROW() << m_errorPrefix << "has unsupported 'boxes' input rank: " << boxes_dims.size();
if (boxes_dims[2] != 4) if (boxes_dims[2] != 4)
IE_THROW() << m_errorPrefix << "has unsupported 'boxes' input 3rd dimension size: " << boxes_dims[2]; IE_THROW() << m_errorPrefix << "has unsupported 'boxes' input 3rd dimension size: " << boxes_dims[2];
if (scores_dims.size() == 3) { if (scores_dims.size() == 3) {
if (boxes_dims[0] != scores_dims[0] || boxes_dims[1] != scores_dims[2]) if (!boxes_ps[0].compatible(scores_ps[0]) || !boxes_ps[1].compatible(scores_ps[2]))
IE_THROW() << m_errorPrefix << "has incompatible 'boxes' and 'scores' shape " << PartialShape(boxes_dims) << " v.s. " << PartialShape(scores_dims); IE_THROW() << m_errorPrefix << "has incompatible 'boxes' and 'scores' shape " << boxes_ps << " v.s. " << scores_ps;
} else if (scores_dims.size() == 2) { } else if (scores_dims.size() == 2) {
if (op->get_type_info() == ov::op::v8::MulticlassNms::get_type_info_static()) if (op->get_type_info() == ov::op::v8::MulticlassNms::get_type_info_static())
IE_THROW() << m_errorPrefix << "has unsupported 'scores' input rank: " << scores_dims.size(); IE_THROW() << m_errorPrefix << "has unsupported 'scores' input rank: " << scores_dims.size();
if (boxes_dims[0] != scores_dims[0] || boxes_dims[1] != scores_dims[1]) if (!boxes_ps[0].compatible(scores_ps[0]) || !boxes_ps[1].compatible(scores_ps[1]))
IE_THROW() << m_errorPrefix << "has incompatible 'boxes' and 'scores' shape " << PartialShape(boxes_dims) << " v.s. " << PartialShape(scores_dims); IE_THROW() << m_errorPrefix << "has incompatible 'boxes' and 'scores' shape " << boxes_ps << " v.s. " << scores_ps;
if (getOriginalInputsNumber() != 3) if (getOriginalInputsNumber() != 3)
IE_THROW() << m_errorPrefix << "has incorrect number of input edges: " << getOriginalInputsNumber() << " when input 'scores' is 2D."; IE_THROW() << m_errorPrefix << "has incorrect number of input edges: " << getOriginalInputsNumber() << " when input 'scores' is 2D.";
} else { } else {