relax the class number check in paddle multiclass_nms op (#11857)
* relax the class number check in paddle multiclass_nms op * relax checks in paddle multiclass_nms op
This commit is contained in:
parent
54fe2d1a3f
commit
39981bf2b8
@ -446,6 +446,10 @@ void multiclass_nms(const float* boxes_data,
|
||||
for (int64_t i = 0; i < num_images; i++) {
|
||||
std::vector<BoxInfo> selected_boxes;
|
||||
if (shared) {
|
||||
OPENVINO_ASSERT(boxes_data_shape[0] == scores_data_shape[0],
|
||||
"Expect batch size of boxes and scores are the same.");
|
||||
OPENVINO_ASSERT(boxes_data_shape[1] == scores_data_shape[2],
|
||||
"Expect box numbers of boxes and scores are the same.");
|
||||
const auto num_boxes = boxes_data_shape[1];
|
||||
const auto num_classes = scores_data_shape[1];
|
||||
|
||||
@ -462,6 +466,10 @@ void multiclass_nms(const float* boxes_data,
|
||||
continue;
|
||||
}
|
||||
|
||||
OPENVINO_ASSERT(boxes_data_shape[0] == scores_data_shape[0],
|
||||
"Expect class numbers of boxes and scores are the same.");
|
||||
OPENVINO_ASSERT(boxes_data_shape[1] == scores_data_shape[1],
|
||||
"Expect box numbers of boxes and scores are the same.");
|
||||
const auto num_classes = boxes_data_shape[0];
|
||||
|
||||
const auto boxes = slice_image(boxes_data, boxes_data_shape, head, roisnum_data[i]);
|
||||
|
@ -81,7 +81,7 @@ void shape_infer(const ov::op::util::MulticlassNmsBase* op,
|
||||
const auto num_batches_scores = scores_ps[0];
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
num_batches_boxes.same_scheme(num_batches_scores),
|
||||
num_batches_boxes.compatible(num_batches_scores),
|
||||
"The first dimension of both 'boxes' and 'scores' must match. Boxes: ",
|
||||
num_batches_boxes,
|
||||
"; Scores: ",
|
||||
@ -90,7 +90,7 @@ void shape_infer(const ov::op::util::MulticlassNmsBase* op,
|
||||
const auto num_boxes_boxes = boxes_ps[1];
|
||||
const auto num_boxes_scores = scores_ps[2];
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
num_boxes_boxes.same_scheme(num_boxes_scores),
|
||||
num_boxes_boxes.compatible(num_boxes_scores),
|
||||
"'boxes' and 'scores' input shapes must match at the second and third "
|
||||
"dimension respectively. Boxes: ",
|
||||
num_boxes_boxes,
|
||||
@ -106,7 +106,7 @@ void shape_infer(const ov::op::util::MulticlassNmsBase* op,
|
||||
const auto num_classes_boxes = boxes_ps[0];
|
||||
const auto num_classes_scores = scores_ps[0];
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
num_classes_boxes.same_scheme(num_classes_scores),
|
||||
num_classes_boxes.compatible(num_classes_scores),
|
||||
"'boxes' and 'scores' input shapes must match. Boxes: ",
|
||||
num_classes_boxes,
|
||||
"; Scores: ",
|
||||
@ -115,7 +115,7 @@ void shape_infer(const ov::op::util::MulticlassNmsBase* op,
|
||||
const auto num_boxes_boxes = boxes_ps[1];
|
||||
const auto num_boxes_scores = scores_ps[1];
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
num_boxes_boxes.same_scheme(num_boxes_scores),
|
||||
num_boxes_boxes.compatible(num_boxes_scores),
|
||||
"'boxes' and 'scores' input shapes must match. Boxes: ",
|
||||
num_boxes_boxes,
|
||||
"; Scores: ",
|
||||
|
@ -78,18 +78,20 @@ MultiClassNms::MultiClassNms(const std::shared_ptr<ov::Node>& op, const dnnl::en
|
||||
// boxes [C, M, 4], scores [C, M], roisnum [N] opset9
|
||||
const auto& boxes_dims = getInputShapeAtPort(NMS_BOXES).getDims();
|
||||
const auto& scores_dims = getInputShapeAtPort(NMS_SCORES).getDims();
|
||||
auto boxes_ps = PartialShape(boxes_dims);
|
||||
auto scores_ps = PartialShape(scores_dims);
|
||||
if (boxes_dims.size() != 3)
|
||||
IE_THROW() << m_errorPrefix << "has unsupported 'boxes' input rank: " << boxes_dims.size();
|
||||
if (boxes_dims[2] != 4)
|
||||
IE_THROW() << m_errorPrefix << "has unsupported 'boxes' input 3rd dimension size: " << boxes_dims[2];
|
||||
if (scores_dims.size() == 3) {
|
||||
if (boxes_dims[0] != scores_dims[0] || boxes_dims[1] != scores_dims[2])
|
||||
IE_THROW() << m_errorPrefix << "has incompatible 'boxes' and 'scores' shape " << PartialShape(boxes_dims) << " v.s. " << PartialShape(scores_dims);
|
||||
if (!boxes_ps[0].compatible(scores_ps[0]) || !boxes_ps[1].compatible(scores_ps[2]))
|
||||
IE_THROW() << m_errorPrefix << "has incompatible 'boxes' and 'scores' shape " << boxes_ps << " v.s. " << scores_ps;
|
||||
} else if (scores_dims.size() == 2) {
|
||||
if (op->get_type_info() == ov::op::v8::MulticlassNms::get_type_info_static())
|
||||
IE_THROW() << m_errorPrefix << "has unsupported 'scores' input rank: " << scores_dims.size();
|
||||
if (boxes_dims[0] != scores_dims[0] || boxes_dims[1] != scores_dims[1])
|
||||
IE_THROW() << m_errorPrefix << "has incompatible 'boxes' and 'scores' shape " << PartialShape(boxes_dims) << " v.s. " << PartialShape(scores_dims);
|
||||
if (!boxes_ps[0].compatible(scores_ps[0]) || !boxes_ps[1].compatible(scores_ps[1]))
|
||||
IE_THROW() << m_errorPrefix << "has incompatible 'boxes' and 'scores' shape " << boxes_ps << " v.s. " << scores_ps;
|
||||
if (getOriginalInputsNumber() != 3)
|
||||
IE_THROW() << m_errorPrefix << "has incorrect number of input edges: " << getOriginalInputsNumber() << " when input 'scores' is 2D.";
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user