diff --git a/src/core/reference/src/runtime/reference/multiclass_nms.cpp b/src/core/reference/src/runtime/reference/multiclass_nms.cpp index c85a37b2fed..e08cffbb00f 100644 --- a/src/core/reference/src/runtime/reference/multiclass_nms.cpp +++ b/src/core/reference/src/runtime/reference/multiclass_nms.cpp @@ -446,6 +446,10 @@ void multiclass_nms(const float* boxes_data, for (int64_t i = 0; i < num_images; i++) { std::vector selected_boxes; if (shared) { + OPENVINO_ASSERT(boxes_data_shape[0] == scores_data_shape[0], + "Expect batch size of boxes and scores are the same."); + OPENVINO_ASSERT(boxes_data_shape[1] == scores_data_shape[2], + "Expect box numbers of boxes and scores are the same."); const auto num_boxes = boxes_data_shape[1]; const auto num_classes = scores_data_shape[1]; @@ -462,6 +466,10 @@ void multiclass_nms(const float* boxes_data, continue; } + OPENVINO_ASSERT(boxes_data_shape[0] == scores_data_shape[0], + "Expect class numbers of boxes and scores are the same."); + OPENVINO_ASSERT(boxes_data_shape[1] == scores_data_shape[1], + "Expect box numbers of boxes and scores are the same."); const auto num_classes = boxes_data_shape[0]; const auto boxes = slice_image(boxes_data, boxes_data_shape, head, roisnum_data[i]); diff --git a/src/core/shape_inference/include/multiclass_nms_shape_inference.hpp b/src/core/shape_inference/include/multiclass_nms_shape_inference.hpp index ec7c3921ea4..269107e1b0c 100644 --- a/src/core/shape_inference/include/multiclass_nms_shape_inference.hpp +++ b/src/core/shape_inference/include/multiclass_nms_shape_inference.hpp @@ -81,7 +81,7 @@ void shape_infer(const ov::op::util::MulticlassNmsBase* op, const auto num_batches_scores = scores_ps[0]; NODE_VALIDATION_CHECK(op, - num_batches_boxes.same_scheme(num_batches_scores), + num_batches_boxes.compatible(num_batches_scores), "The first dimension of both 'boxes' and 'scores' must match. Boxes: ", num_batches_boxes, "; Scores: ", @@ -90,7 +90,7 @@ void shape_infer(const ov::op::util::MulticlassNmsBase* op, const auto num_boxes_boxes = boxes_ps[1]; const auto num_boxes_scores = scores_ps[2]; NODE_VALIDATION_CHECK(op, - num_boxes_boxes.same_scheme(num_boxes_scores), + num_boxes_boxes.compatible(num_boxes_scores), "'boxes' and 'scores' input shapes must match at the second and third " "dimension respectively. Boxes: ", num_boxes_boxes, @@ -106,7 +106,7 @@ void shape_infer(const ov::op::util::MulticlassNmsBase* op, const auto num_classes_boxes = boxes_ps[0]; const auto num_classes_scores = scores_ps[0]; NODE_VALIDATION_CHECK(op, - num_classes_boxes.same_scheme(num_classes_scores), + num_classes_boxes.compatible(num_classes_scores), "'boxes' and 'scores' input shapes must match. Boxes: ", num_classes_boxes, "; Scores: ", @@ -115,7 +115,7 @@ void shape_infer(const ov::op::util::MulticlassNmsBase* op, const auto num_boxes_boxes = boxes_ps[1]; const auto num_boxes_scores = scores_ps[1]; NODE_VALIDATION_CHECK(op, - num_boxes_boxes.same_scheme(num_boxes_scores), + num_boxes_boxes.compatible(num_boxes_scores), "'boxes' and 'scores' input shapes must match. Boxes: ", num_boxes_boxes, "; Scores: ", diff --git a/src/plugins/intel_cpu/src/nodes/multiclass_nms.cpp b/src/plugins/intel_cpu/src/nodes/multiclass_nms.cpp index 556a1bd3d84..3c46b2c5992 100644 --- a/src/plugins/intel_cpu/src/nodes/multiclass_nms.cpp +++ b/src/plugins/intel_cpu/src/nodes/multiclass_nms.cpp @@ -78,18 +78,20 @@ MultiClassNms::MultiClassNms(const std::shared_ptr& op, const dnnl::en // boxes [C, M, 4], scores [C, M], roisnum [N] opset9 const auto& boxes_dims = getInputShapeAtPort(NMS_BOXES).getDims(); const auto& scores_dims = getInputShapeAtPort(NMS_SCORES).getDims(); + auto boxes_ps = PartialShape(boxes_dims); + auto scores_ps = PartialShape(scores_dims); if (boxes_dims.size() != 3) IE_THROW() << m_errorPrefix << "has unsupported 'boxes' input rank: " << boxes_dims.size(); if (boxes_dims[2] != 4) IE_THROW() << m_errorPrefix << "has unsupported 'boxes' input 3rd dimension size: " << boxes_dims[2]; if (scores_dims.size() == 3) { - if (boxes_dims[0] != scores_dims[0] || boxes_dims[1] != scores_dims[2]) - IE_THROW() << m_errorPrefix << "has incompatible 'boxes' and 'scores' shape " << PartialShape(boxes_dims) << " v.s. " << PartialShape(scores_dims); + if (!boxes_ps[0].compatible(scores_ps[0]) || !boxes_ps[1].compatible(scores_ps[2])) + IE_THROW() << m_errorPrefix << "has incompatible 'boxes' and 'scores' shape " << boxes_ps << " v.s. " << scores_ps; } else if (scores_dims.size() == 2) { if (op->get_type_info() == ov::op::v8::MulticlassNms::get_type_info_static()) IE_THROW() << m_errorPrefix << "has unsupported 'scores' input rank: " << scores_dims.size(); - if (boxes_dims[0] != scores_dims[0] || boxes_dims[1] != scores_dims[1]) - IE_THROW() << m_errorPrefix << "has incompatible 'boxes' and 'scores' shape " << PartialShape(boxes_dims) << " v.s. " << PartialShape(scores_dims); + if (!boxes_ps[0].compatible(scores_ps[0]) || !boxes_ps[1].compatible(scores_ps[1])) + IE_THROW() << m_errorPrefix << "has incompatible 'boxes' and 'scores' shape " << boxes_ps << " v.s. " << scores_ps; if (getOriginalInputsNumber() != 3) IE_THROW() << m_errorPrefix << "has incorrect number of input edges: " << getOriginalInputsNumber() << " when input 'scores' is 2D."; } else {