[CPU] NMS optimization (#8312)
This commit is contained in:
@@ -15,8 +15,536 @@
|
||||
#include <ngraph_ops/nms_ie_internal.hpp>
|
||||
#include "utils/general_utils.h"
|
||||
|
||||
#include "cpu/x64/jit_generator.hpp"
|
||||
#include "emitters/jit_load_store_emitters.hpp"
|
||||
#include <cpu/x64/injectors/jit_uni_eltwise_injector.hpp>
|
||||
|
||||
using namespace MKLDNNPlugin;
|
||||
using namespace InferenceEngine;
|
||||
using namespace mkldnn;
|
||||
using namespace mkldnn::impl;
|
||||
using namespace mkldnn::impl::cpu::x64;
|
||||
using namespace mkldnn::impl::utils;
|
||||
using namespace Xbyak;
|
||||
|
||||
#define GET_OFF(field) offsetof(jit_nms_args, field)
|
||||
|
||||
template <cpu_isa_t isa>
|
||||
struct jit_uni_nms_kernel_f32 : public jit_uni_nms_kernel, public jit_generator {
|
||||
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_nms_kernel_f32)
|
||||
|
||||
explicit jit_uni_nms_kernel_f32(jit_nms_config_params jcp_) : jit_uni_nms_kernel(jcp_), jit_generator() {}
|
||||
|
||||
void create_ker() override {
|
||||
jit_generator::create_kernel();
|
||||
ker_ = (decltype(ker_))jit_ker();
|
||||
}
|
||||
|
||||
void generate() override {
|
||||
load_emitter.reset(new jit_load_emitter(this, isa, nullptr));
|
||||
store_emitter.reset(new jit_store_emitter(this, isa, nullptr));
|
||||
exp_injector.reset(new jit_uni_eltwise_injector_f32<isa>(this, mkldnn::impl::alg_kind::eltwise_exp, 0.f, 0.f, 1.0f));
|
||||
|
||||
this->preamble();
|
||||
|
||||
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
||||
|
||||
load_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx()), static_cast<size_t>(reg_load_table.getIdx())};
|
||||
store_pool_gpr_idxs = {static_cast<size_t>(reg_load_store_mask.getIdx())};
|
||||
store_pool_vec_idxs = {static_cast<size_t>(vmm_zero.getIdx())};
|
||||
|
||||
mov(reg_boxes_coord0, ptr[reg_params + GET_OFF(selected_boxes_coord[0])]);
|
||||
mov(reg_boxes_coord1, ptr[reg_params + GET_OFF(selected_boxes_coord[0]) + 1 * sizeof(size_t)]);
|
||||
mov(reg_boxes_coord2, ptr[reg_params + GET_OFF(selected_boxes_coord[0]) + 2 * sizeof(size_t)]);
|
||||
mov(reg_boxes_coord3, ptr[reg_params + GET_OFF(selected_boxes_coord[0]) + 3 * sizeof(size_t)]);
|
||||
mov(reg_candidate_box, ptr[reg_params + GET_OFF(candidate_box)]);
|
||||
mov(reg_candidate_status, ptr[reg_params + GET_OFF(candidate_status)]);
|
||||
mov(reg_boxes_num, ptr[reg_params + GET_OFF(selected_boxes_num)]);
|
||||
mov(reg_iou_threshold, ptr[reg_params + GET_OFF(iou_threshold)]);
|
||||
// soft
|
||||
mov(reg_score_threshold, ptr[reg_params + GET_OFF(score_threshold)]);
|
||||
mov(reg_score, ptr[reg_params + GET_OFF(score)]);
|
||||
mov(reg_scale, ptr[reg_params + GET_OFF(scale)]);
|
||||
|
||||
// could use rcx(reg_table) and rdi(reg_temp) now as abi parse finished
|
||||
mov(reg_table, l_table_constant);
|
||||
if (mayiuse(cpu::x64::avx512_common)) {
|
||||
kmovw(k_mask_one, word[reg_table + vlen]);
|
||||
}
|
||||
uni_vbroadcastss(vmm_iou_threshold, ptr[reg_iou_threshold]);
|
||||
uni_vbroadcastss(vmm_score_threshold, ptr[reg_score_threshold]);
|
||||
|
||||
uni_vbroadcastss(vmm_candidate_coord0, ptr[reg_candidate_box]);
|
||||
uni_vbroadcastss(vmm_candidate_coord1, ptr[reg_candidate_box + 1 * sizeof(float)]);
|
||||
uni_vbroadcastss(vmm_candidate_coord2, ptr[reg_candidate_box + 2 * sizeof(float)]);
|
||||
uni_vbroadcastss(vmm_candidate_coord3, ptr[reg_candidate_box + 3 * sizeof(float)]);
|
||||
|
||||
if (jcp.box_encode_type == NMSBoxEncodeType::CORNER) {
|
||||
// box format: y1, x1, y2, x2
|
||||
uni_vminps(vmm_temp1, vmm_candidate_coord0, vmm_candidate_coord2);
|
||||
uni_vmaxps(vmm_temp2, vmm_candidate_coord0, vmm_candidate_coord2);
|
||||
uni_vmovups(vmm_candidate_coord0, vmm_temp1);
|
||||
uni_vmovups(vmm_candidate_coord2, vmm_temp2);
|
||||
|
||||
uni_vminps(vmm_temp1, vmm_candidate_coord1, vmm_candidate_coord3);
|
||||
uni_vmaxps(vmm_temp2, vmm_candidate_coord1, vmm_candidate_coord3);
|
||||
uni_vmovups(vmm_candidate_coord1, vmm_temp1);
|
||||
uni_vmovups(vmm_candidate_coord3, vmm_temp2);
|
||||
} else {
|
||||
// box format: x_center, y_center, width, height --> y1, x1, y2, x2
|
||||
uni_vmulps(vmm_temp1, vmm_candidate_coord2, ptr[reg_table]); // width/2
|
||||
uni_vmulps(vmm_temp2, vmm_candidate_coord3, ptr[reg_table]); // height/2
|
||||
|
||||
uni_vaddps(vmm_temp3, vmm_candidate_coord0, vmm_temp1); // x_center + width/2
|
||||
uni_vmovups(vmm_candidate_coord3, vmm_temp3);
|
||||
|
||||
uni_vaddps(vmm_temp3, vmm_candidate_coord1, vmm_temp2); // y_center + height/2
|
||||
uni_vmovups(vmm_candidate_coord2, vmm_temp3);
|
||||
|
||||
uni_vsubps(vmm_temp3, vmm_candidate_coord0, vmm_temp1); // x_center - width/2
|
||||
uni_vsubps(vmm_temp4, vmm_candidate_coord1, vmm_temp2); // y_center - height/2
|
||||
|
||||
uni_vmovups(vmm_candidate_coord1, vmm_temp3);
|
||||
uni_vmovups(vmm_candidate_coord0, vmm_temp4);
|
||||
}
|
||||
|
||||
// check from last to first
|
||||
imul(reg_temp_64, reg_boxes_num, sizeof(float));
|
||||
add(reg_boxes_coord0, reg_temp_64); // y1
|
||||
add(reg_boxes_coord1, reg_temp_64); // x1
|
||||
add(reg_boxes_coord2, reg_temp_64); // y2
|
||||
add(reg_boxes_coord3, reg_temp_64); // x2
|
||||
|
||||
Xbyak::Label hard_nms_label;
|
||||
Xbyak::Label nms_end_label;
|
||||
|
||||
mov(reg_temp_32, ptr[reg_scale]);
|
||||
test(reg_temp_32, reg_temp_32);
|
||||
jz(hard_nms_label, T_NEAR);
|
||||
|
||||
soft_nms();
|
||||
|
||||
jmp(nms_end_label, T_NEAR);
|
||||
|
||||
L(hard_nms_label);
|
||||
|
||||
hard_nms();
|
||||
|
||||
L(nms_end_label);
|
||||
|
||||
this->postamble();
|
||||
|
||||
load_emitter->emit_data();
|
||||
store_emitter->emit_data();
|
||||
|
||||
prepare_table();
|
||||
exp_injector->prepare_table();
|
||||
}
|
||||
|
||||
private:
|
||||
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xbyak::Xmm, isa == cpu::x64::avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
|
||||
uint32_t vlen = cpu_isa_traits<isa>::vlen;
|
||||
|
||||
Xbyak::Reg64 reg_boxes_coord0 = r8;
|
||||
Xbyak::Reg64 reg_boxes_coord1 = r9;
|
||||
Xbyak::Reg64 reg_boxes_coord2 = r10;
|
||||
Xbyak::Reg64 reg_boxes_coord3 = r11;
|
||||
Xbyak::Reg64 reg_candidate_box = r12;
|
||||
Xbyak::Reg64 reg_candidate_status = r13;
|
||||
Xbyak::Reg64 reg_boxes_num = r14;
|
||||
Xbyak::Reg64 reg_iou_threshold = r15;
|
||||
// more for soft
|
||||
Xbyak::Reg64 reg_score_threshold = rdx;
|
||||
Xbyak::Reg64 reg_score = rbp;
|
||||
Xbyak::Reg64 reg_scale = rsi;
|
||||
|
||||
Xbyak::Reg64 reg_load_table = rax;
|
||||
Xbyak::Reg64 reg_load_store_mask = rbx;
|
||||
|
||||
// reuse
|
||||
Xbyak::Label l_table_constant;
|
||||
Xbyak::Reg64 reg_table = rcx;
|
||||
Xbyak::Reg64 reg_temp_64 = rdi;
|
||||
Xbyak::Reg32 reg_temp_32 = edi;
|
||||
|
||||
Xbyak::Reg64 reg_params = abi_param1;
|
||||
|
||||
std::unique_ptr<jit_load_emitter> load_emitter = nullptr;
|
||||
std::unique_ptr<jit_store_emitter> store_emitter = nullptr;
|
||||
|
||||
std::vector<size_t> store_pool_gpr_idxs;
|
||||
std::vector<size_t> store_pool_vec_idxs;
|
||||
std::vector<size_t> load_pool_gpr_idxs;
|
||||
|
||||
Vmm vmm_boxes_coord0 = Vmm(1);
|
||||
Vmm vmm_boxes_coord1 = Vmm(2);
|
||||
Vmm vmm_boxes_coord2 = Vmm(3);
|
||||
Vmm vmm_boxes_coord3 = Vmm(4);
|
||||
Vmm vmm_candidate_coord0 = Vmm(5);
|
||||
Vmm vmm_candidate_coord1 = Vmm(6);
|
||||
Vmm vmm_candidate_coord2 = Vmm(7);
|
||||
Vmm vmm_candidate_coord3 = Vmm(8);
|
||||
Vmm vmm_temp1 = Vmm(9);
|
||||
Vmm vmm_temp2 = Vmm(10);
|
||||
Vmm vmm_temp3 = Vmm(11);
|
||||
Vmm vmm_temp4 = Vmm(12);
|
||||
|
||||
Vmm vmm_iou_threshold = Vmm(13);
|
||||
Vmm vmm_zero = Vmm(15);
|
||||
|
||||
// soft
|
||||
Vmm vmm_score_threshold = Vmm(14);
|
||||
Vmm vmm_scale = Vmm(0);
|
||||
|
||||
Xbyak::Opmask k_mask = Xbyak::Opmask(7);
|
||||
Xbyak::Opmask k_mask_one = Xbyak::Opmask(6);
|
||||
|
||||
std::shared_ptr<jit_uni_eltwise_injector_f32<isa>> exp_injector;
|
||||
|
||||
inline void hard_nms() {
|
||||
int step = vlen / sizeof(float);
|
||||
Xbyak::Label main_loop_label_hard;
|
||||
Xbyak::Label main_loop_end_label_hard;
|
||||
Xbyak::Label tail_loop_label_hard;
|
||||
Xbyak::Label terminate_label_hard;
|
||||
L(main_loop_label_hard);
|
||||
{
|
||||
cmp(reg_boxes_num, step);
|
||||
jl(main_loop_end_label_hard, T_NEAR);
|
||||
|
||||
sub(reg_boxes_coord0, step * sizeof(float));
|
||||
sub(reg_boxes_coord1, step * sizeof(float));
|
||||
sub(reg_boxes_coord2, step * sizeof(float));
|
||||
sub(reg_boxes_coord3, step * sizeof(float));
|
||||
|
||||
// iou result is in vmm_temp3
|
||||
iou(step);
|
||||
|
||||
sub(reg_boxes_num, step);
|
||||
|
||||
suppressed_by_iou(false);
|
||||
|
||||
// if zero continue, else set result to suppressed and terminate
|
||||
jz(main_loop_label_hard, T_NEAR);
|
||||
|
||||
uni_vpextrd(ptr[reg_candidate_status], Xmm(vmm_zero.getIdx()), 0);
|
||||
|
||||
jmp(terminate_label_hard, T_NEAR);
|
||||
}
|
||||
L(main_loop_end_label_hard);
|
||||
|
||||
step = 1;
|
||||
L(tail_loop_label_hard);
|
||||
{
|
||||
cmp(reg_boxes_num, 1);
|
||||
jl(terminate_label_hard, T_NEAR);
|
||||
|
||||
sub(reg_boxes_coord0, step * sizeof(float));
|
||||
sub(reg_boxes_coord1, step * sizeof(float));
|
||||
sub(reg_boxes_coord2, step * sizeof(float));
|
||||
sub(reg_boxes_coord3, step * sizeof(float));
|
||||
|
||||
// iou result is in vmm_temp3
|
||||
iou(step);
|
||||
|
||||
sub(reg_boxes_num, step);
|
||||
|
||||
suppressed_by_iou(true);
|
||||
|
||||
jz(tail_loop_label_hard, T_NEAR);
|
||||
|
||||
uni_vpextrd(ptr[reg_candidate_status], Xmm(vmm_zero.getIdx()), 0);
|
||||
|
||||
jmp(terminate_label_hard, T_NEAR);
|
||||
}
|
||||
|
||||
L(terminate_label_hard);
|
||||
}
|
||||
|
||||
inline void soft_nms() {
|
||||
uni_vbroadcastss(vmm_scale, ptr[reg_scale]);
|
||||
|
||||
int step = vlen / sizeof(float);
|
||||
Xbyak::Label main_loop_label;
|
||||
Xbyak::Label main_loop_end_label;
|
||||
Xbyak::Label tail_loop_label;
|
||||
Xbyak::Label terminate_label;
|
||||
|
||||
Xbyak::Label main_loop_label_soft;
|
||||
Xbyak::Label tail_loop_label_soft;
|
||||
L(main_loop_label);
|
||||
{
|
||||
cmp(reg_boxes_num, step);
|
||||
jl(main_loop_end_label, T_NEAR);
|
||||
|
||||
sub(reg_boxes_coord0, step * sizeof(float));
|
||||
sub(reg_boxes_coord1, step * sizeof(float));
|
||||
sub(reg_boxes_coord2, step * sizeof(float));
|
||||
sub(reg_boxes_coord3, step * sizeof(float));
|
||||
|
||||
// result(iou and weight) is in vmm_temp3
|
||||
iou(step);
|
||||
sub(reg_boxes_num, step);
|
||||
|
||||
// soft suppressed by iou_threshold
|
||||
if (jcp.is_soft_suppressed_by_iou) {
|
||||
suppressed_by_iou(false);
|
||||
|
||||
// if zero continue soft suppression, else set result to suppressed and terminate
|
||||
jz(main_loop_label_soft, T_NEAR);
|
||||
|
||||
uni_vpextrd(ptr[reg_candidate_status], Xmm(vmm_zero.getIdx()), 0);
|
||||
|
||||
jmp(terminate_label, T_NEAR);
|
||||
|
||||
L(main_loop_label_soft);
|
||||
}
|
||||
|
||||
// weight: std::exp(scale * iou * iou)
|
||||
soft_coeff();
|
||||
|
||||
// vector weights multiply
|
||||
horizontal_mul();
|
||||
|
||||
uni_vbroadcastss(vmm_temp1, ptr[reg_score]);
|
||||
|
||||
// new score in vmm3[0]
|
||||
uni_vmulps(vmm_temp3, vmm_temp3, vmm_temp1);
|
||||
// store new score
|
||||
uni_vmovss(ptr[reg_score], vmm_temp3);
|
||||
|
||||
// cmpps(_CMP_LE_OS) if new score is less or equal than score_threshold
|
||||
suppressed_by_score();
|
||||
|
||||
jz(main_loop_label, T_NEAR);
|
||||
|
||||
uni_vpextrd(ptr[reg_candidate_status], Xmm(vmm_zero.getIdx()), 0);
|
||||
|
||||
jmp(terminate_label, T_NEAR);
|
||||
}
|
||||
L(main_loop_end_label);
|
||||
|
||||
step = 1;
|
||||
L(tail_loop_label);
|
||||
{
|
||||
cmp(reg_boxes_num, 1);
|
||||
jl(terminate_label, T_NEAR);
|
||||
|
||||
sub(reg_boxes_coord0, step * sizeof(float));
|
||||
sub(reg_boxes_coord1, step * sizeof(float));
|
||||
sub(reg_boxes_coord2, step * sizeof(float));
|
||||
sub(reg_boxes_coord3, step * sizeof(float));
|
||||
|
||||
iou(step);
|
||||
sub(reg_boxes_num, step);
|
||||
|
||||
// soft suppressed by iou_threshold
|
||||
if (jcp.is_soft_suppressed_by_iou) {
|
||||
suppressed_by_iou(true);
|
||||
|
||||
jz(tail_loop_label_soft, T_NEAR);
|
||||
|
||||
uni_vpextrd(ptr[reg_candidate_status], Xmm(vmm_zero.getIdx()), 0);
|
||||
|
||||
jmp(terminate_label, T_NEAR);
|
||||
|
||||
L(tail_loop_label_soft);
|
||||
}
|
||||
|
||||
soft_coeff();
|
||||
|
||||
uni_vbroadcastss(vmm_temp1, ptr[reg_score]);
|
||||
|
||||
// vmm3[0] is valide, no need horizontal mul.
|
||||
uni_vmulps(vmm_temp3, vmm_temp3, vmm_temp1);
|
||||
|
||||
uni_vmovss(ptr[reg_score], vmm_temp3);
|
||||
|
||||
// cmpps(_CMP_LE_OS) if new score is less or equal than score_threshold
|
||||
suppressed_by_score();
|
||||
|
||||
jz(tail_loop_label, T_NEAR);
|
||||
|
||||
uni_vpextrd(ptr[reg_candidate_status], Xmm(vmm_zero.getIdx()), 0);
|
||||
|
||||
jmp(terminate_label, T_NEAR);
|
||||
}
|
||||
|
||||
L(terminate_label);
|
||||
}
|
||||
|
||||
inline void suppressed_by_iou(bool is_scalar) {
|
||||
if (mayiuse(cpu::x64::avx512_common)) {
|
||||
vcmpps(k_mask, vmm_temp3, vmm_iou_threshold, 0x0D); // _CMP_GE_OS. vcmpps w/ kmask only on V5
|
||||
if (is_scalar)
|
||||
kandw(k_mask, k_mask, k_mask_one);
|
||||
kortestw(k_mask, k_mask); // bitwise check if all zero
|
||||
} else if (mayiuse(cpu::x64::avx)) {
|
||||
// vex instructions with xmm on avx and ymm on avx2
|
||||
vcmpps(vmm_temp4, vmm_temp3, vmm_iou_threshold, 0x0D); // xmm and ymm only on V1.
|
||||
if (is_scalar) {
|
||||
uni_vpextrd(reg_temp_32, Xmm(vmm_temp4.getIdx()), 0);
|
||||
test(reg_temp_32, reg_temp_32);
|
||||
} else {
|
||||
uni_vtestps(vmm_temp4, vmm_temp4); // vtestps: sign bit check if all zeros, ymm and xmm only on V1, N/A on V5
|
||||
}
|
||||
} else {
|
||||
// pure sse path, make sure don't spoil vmm_temp3, which may used in after soft-suppression
|
||||
uni_vmovups(vmm_temp4, vmm_temp3);
|
||||
cmpps(vmm_temp4, vmm_iou_threshold, 0x07); // order compare, 0 for unorders
|
||||
|
||||
uni_vmovups(vmm_temp2, vmm_temp3);
|
||||
cmpps(vmm_temp2, vmm_iou_threshold, 0x05); // _CMP_GE_US on sse, no direct _CMP_GE_OS supported.
|
||||
|
||||
uni_vandps(vmm_temp4, vmm_temp4, vmm_temp2);
|
||||
if (is_scalar) {
|
||||
uni_vpextrd(reg_temp_32, Xmm(vmm_temp4.getIdx()), 0);
|
||||
test(reg_temp_32, reg_temp_32);
|
||||
} else {
|
||||
uni_vtestps(vmm_temp4, vmm_temp4); // ptest: bitwise check if all zeros, on sse41
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void suppressed_by_score() {
|
||||
if (mayiuse(cpu::x64::avx512_common)) {
|
||||
vcmpps(k_mask, vmm_temp3, vmm_score_threshold, 0x02); // vcmpps w/ kmask only on V5, w/o kmask version N/A on V5
|
||||
kandw(k_mask, k_mask, k_mask_one);
|
||||
kortestw(k_mask, k_mask); // bitwise check if all zero
|
||||
} else if (mayiuse(cpu::x64::avx)) {
|
||||
vcmpps(vmm_temp4, vmm_temp3, vmm_score_threshold, 0x02);
|
||||
uni_vpextrd(reg_temp_32, Xmm(vmm_temp4.getIdx()), 0);
|
||||
test(reg_temp_32, reg_temp_32);
|
||||
} else {
|
||||
cmpps(vmm_temp3, vmm_score_threshold, 0x02); // _CMP_LE_OS on sse
|
||||
uni_vpextrd(reg_temp_32, Xmm(vmm_temp3.getIdx()), 0);
|
||||
test(reg_temp_32, reg_temp_32);
|
||||
}
|
||||
}
|
||||
|
||||
inline void iou(int ele_num) {
|
||||
auto load = [&](Xbyak::Reg64 reg_src, Vmm vmm_dst) {
|
||||
load_emitter->emit_code({static_cast<size_t>(reg_src.getIdx())}, {static_cast<size_t>(vmm_dst.getIdx())},
|
||||
std::make_shared<load_emitter_context>(Precision::FP32, Precision::FP32, ele_num),
|
||||
{}, {load_pool_gpr_idxs});
|
||||
};
|
||||
load(reg_boxes_coord0, vmm_boxes_coord0);
|
||||
load(reg_boxes_coord1, vmm_boxes_coord1);
|
||||
load(reg_boxes_coord2, vmm_boxes_coord2);
|
||||
load(reg_boxes_coord3, vmm_boxes_coord3);
|
||||
|
||||
if (jcp.box_encode_type == NMSBoxEncodeType::CORNER) {
|
||||
// box format: y1, x1, y2, x2
|
||||
uni_vminps(vmm_temp1, vmm_boxes_coord0, vmm_boxes_coord2);
|
||||
uni_vmaxps(vmm_temp2, vmm_boxes_coord0, vmm_boxes_coord2);
|
||||
uni_vmovups(vmm_boxes_coord0, vmm_temp1);
|
||||
uni_vmovups(vmm_boxes_coord2, vmm_temp2);
|
||||
|
||||
uni_vminps(vmm_temp1, vmm_boxes_coord1, vmm_boxes_coord3);
|
||||
uni_vmaxps(vmm_temp2, vmm_boxes_coord1, vmm_boxes_coord3);
|
||||
uni_vmovups(vmm_boxes_coord1, vmm_temp1);
|
||||
uni_vmovups(vmm_boxes_coord3, vmm_temp2);
|
||||
} else {
|
||||
// box format: x_center, y_center, width, height --> y1, x1, y2, x2
|
||||
uni_vmulps(vmm_temp1, vmm_boxes_coord2, ptr[reg_table]); // width/2
|
||||
uni_vmulps(vmm_temp2, vmm_boxes_coord3, ptr[reg_table]); // height/2
|
||||
|
||||
uni_vaddps(vmm_temp3, vmm_boxes_coord0, vmm_temp1); // x_center + width/2
|
||||
uni_vmovups(vmm_boxes_coord3, vmm_temp3);
|
||||
|
||||
uni_vaddps(vmm_temp3, vmm_boxes_coord1, vmm_temp2); // y_center + height/2
|
||||
uni_vmovups(vmm_boxes_coord2, vmm_temp3);
|
||||
|
||||
uni_vsubps(vmm_temp3, vmm_boxes_coord0, vmm_temp1); // x_center - width/2
|
||||
uni_vsubps(vmm_temp4, vmm_boxes_coord1, vmm_temp2); // y_center - height/2
|
||||
|
||||
uni_vmovups(vmm_boxes_coord1, vmm_temp3);
|
||||
uni_vmovups(vmm_boxes_coord0, vmm_temp4);
|
||||
}
|
||||
|
||||
uni_vsubps(vmm_temp1, vmm_boxes_coord2, vmm_boxes_coord0);
|
||||
uni_vsubps(vmm_temp2, vmm_boxes_coord3, vmm_boxes_coord1);
|
||||
uni_vmulps(vmm_temp1, vmm_temp1, vmm_temp2); // boxes area
|
||||
|
||||
uni_vsubps(vmm_temp2, vmm_candidate_coord2, vmm_candidate_coord0);
|
||||
uni_vsubps(vmm_temp3, vmm_candidate_coord3, vmm_candidate_coord1);
|
||||
uni_vmulps(vmm_temp2, vmm_temp2, vmm_temp3); // candidate(bc) area // candidate area calculate once and check if 0
|
||||
|
||||
uni_vaddps(vmm_temp1, vmm_temp1, vmm_temp2); // areaI + areaJ to free vmm_temp2
|
||||
|
||||
// y of intersection
|
||||
uni_vminps(vmm_temp3, vmm_boxes_coord2, vmm_candidate_coord2); // min(Ymax)
|
||||
uni_vmaxps(vmm_temp4, vmm_boxes_coord0, vmm_candidate_coord0); // max(Ymin)
|
||||
uni_vsubps(vmm_temp3, vmm_temp3, vmm_temp4); // min(Ymax) - max(Ymin)
|
||||
uni_vmaxps(vmm_temp3, vmm_temp3, vmm_zero);
|
||||
|
||||
// x of intersection
|
||||
uni_vminps(vmm_temp4, vmm_boxes_coord3, vmm_candidate_coord3); // min(Xmax)
|
||||
uni_vmaxps(vmm_temp2, vmm_boxes_coord1, vmm_candidate_coord1); // max(Xmin)
|
||||
uni_vsubps(vmm_temp4, vmm_temp4, vmm_temp2); // min(Xmax) - max(Xmin)
|
||||
uni_vmaxps(vmm_temp4, vmm_temp4, vmm_zero);
|
||||
|
||||
// intersection_area
|
||||
uni_vmulps(vmm_temp3, vmm_temp3, vmm_temp4);
|
||||
|
||||
// iou: intersection_area / (areaI + areaJ - intersection_area);
|
||||
uni_vsubps(vmm_temp1, vmm_temp1, vmm_temp3);
|
||||
uni_vdivps(vmm_temp3, vmm_temp3, vmm_temp1);
|
||||
}
|
||||
|
||||
// std::exp(scale * iou * iou)
|
||||
inline void soft_coeff() {
|
||||
uni_vmulps(vmm_temp3, vmm_temp3, vmm_temp3);
|
||||
uni_vmulps(vmm_temp3, vmm_temp3, vmm_scale);
|
||||
exp_injector->compute_vector_range(vmm_temp3.getIdx(), vmm_temp3.getIdx() + 1);
|
||||
}
|
||||
|
||||
inline void horizontal_mul_xmm(const Xbyak::Xmm &xmm_weight, const Xbyak::Xmm &xmm_aux) {
|
||||
uni_vmovshdup(xmm_aux, xmm_weight); // weight:1,2,3,4; aux:2,2,4,4
|
||||
uni_vmulps(xmm_weight, xmm_weight, xmm_aux); // weight:1*2,2*2,3*4,4*4
|
||||
uni_vmovhlps(xmm_aux, xmm_aux, xmm_weight); // aux:3*4,4*4,4,4
|
||||
uni_vmulps(xmm_weight, xmm_weight, xmm_aux); // weight:1*2*3*4,...
|
||||
}
|
||||
|
||||
// horizontal mul for vmm_weight(Vmm(3)), temp1 and temp2 as aux
|
||||
inline void horizontal_mul() {
|
||||
Xbyak::Xmm xmm_weight = Xbyak::Xmm(vmm_temp3.getIdx());
|
||||
Xbyak::Xmm xmm_temp1 = Xbyak::Xmm(vmm_temp1.getIdx());
|
||||
Xbyak::Xmm xmm_temp2 = Xbyak::Xmm(vmm_temp2.getIdx());
|
||||
if (isa == cpu::x64::sse41) {
|
||||
horizontal_mul_xmm(xmm_weight, xmm_temp1);
|
||||
} else if (isa == cpu::x64::avx2) {
|
||||
Xbyak::Ymm ymm_weight = Xbyak::Ymm(vmm_temp3.getIdx());
|
||||
vextractf128(xmm_temp1, ymm_weight, 0);
|
||||
vextractf128(xmm_temp2, ymm_weight, 1);
|
||||
uni_vmulps(xmm_weight, xmm_temp1, xmm_temp2);
|
||||
horizontal_mul_xmm(xmm_weight, xmm_temp1);
|
||||
} else {
|
||||
Xbyak::Zmm zmm_weight = Xbyak::Zmm(vmm_temp3.getIdx());
|
||||
vextractf32x4(xmm_temp1, zmm_weight, 0);
|
||||
vextractf32x4(xmm_temp2, zmm_weight, 1);
|
||||
uni_vmulps(xmm_temp1, xmm_temp1, xmm_temp2);
|
||||
vextractf32x4(xmm_temp2, zmm_weight, 2);
|
||||
vextractf32x4(xmm_weight, zmm_weight, 3);
|
||||
uni_vmulps(xmm_weight, xmm_weight, xmm_temp2);
|
||||
uni_vmulps(xmm_weight, xmm_weight, xmm_temp1);
|
||||
horizontal_mul_xmm(xmm_weight, xmm_temp1);
|
||||
}
|
||||
}
|
||||
|
||||
inline void prepare_table() {
|
||||
auto broadcast_d = [&](int val) {
|
||||
for (size_t d = 0; d < vlen / sizeof(int); ++d) {
|
||||
dd(val);
|
||||
}
|
||||
};
|
||||
|
||||
align(64);
|
||||
L(l_table_constant);
|
||||
broadcast_d(0x3f000000); // 0.5f
|
||||
dw(0x0001);
|
||||
}
|
||||
};
|
||||
|
||||
bool MKLDNNNonMaxSuppressionNode::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
|
||||
try {
|
||||
@@ -42,7 +570,7 @@ bool MKLDNNNonMaxSuppressionNode::isSupportedOperation(const std::shared_ptr<con
|
||||
}
|
||||
|
||||
MKLDNNNonMaxSuppressionNode::MKLDNNNonMaxSuppressionNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng,
|
||||
MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache) {
|
||||
MKLDNNWeightsSharing::Ptr &cache) : MKLDNNNode(op, eng, cache), isSoftSuppressedByIOU(true) {
|
||||
std::string errorMessage;
|
||||
if (!isSupportedOperation(op, errorMessage)) {
|
||||
IE_THROW(NotImplemented) << errorMessage;
|
||||
@@ -57,12 +585,12 @@ MKLDNNNonMaxSuppressionNode::MKLDNNNonMaxSuppressionNode(const std::shared_ptr<n
|
||||
IE_THROW() << errorPrefix << "has incorrect number of output edges: " << getOriginalOutputsNumber();
|
||||
|
||||
if (const auto nms5 = std::dynamic_pointer_cast<const ngraph::op::v5::NonMaxSuppression>(op)) {
|
||||
boxEncodingType = static_cast<boxEncoding>(nms5->get_box_encoding());
|
||||
sort_result_descending = nms5->get_sort_result_descending();
|
||||
boxEncodingType = static_cast<NMSBoxEncodeType>(nms5->get_box_encoding());
|
||||
sortResultDescending = nms5->get_sort_result_descending();
|
||||
// TODO [DS NMS]: remove when nodes from models where nms is not last node in model supports DS
|
||||
} else if (const auto nmsIe = std::dynamic_pointer_cast<const ngraph::op::internal::NonMaxSuppressionIEInternal>(op)) {
|
||||
boxEncodingType = nmsIe->m_center_point_box ? boxEncoding::CENTER : boxEncoding::CORNER;
|
||||
sort_result_descending = nmsIe->m_sort_result_descending;
|
||||
boxEncodingType = nmsIe->m_center_point_box ? NMSBoxEncodeType::CENTER : NMSBoxEncodeType::CORNER;
|
||||
sortResultDescending = nmsIe->m_sort_result_descending;
|
||||
} else {
|
||||
const auto &typeInfo = op->get_type_info();
|
||||
IE_THROW() << errorPrefix << " doesn't support NMS: " << typeInfo.name << " v" << typeInfo.version;
|
||||
@@ -125,32 +653,63 @@ void MKLDNNNonMaxSuppressionNode::initSupportedPrimitiveDescriptors() {
|
||||
outDataConf.emplace_back(LayoutType::ncsp, outPrecision);
|
||||
}
|
||||
|
||||
addSupportedPrimDesc(inDataConf, outDataConf, impl_desc_type::ref_any);
|
||||
impl_desc_type impl_type;
|
||||
if (mayiuse(cpu::x64::avx512_common)) {
|
||||
impl_type = impl_desc_type::jit_avx512;
|
||||
} else if (mayiuse(cpu::x64::avx2)) {
|
||||
impl_type = impl_desc_type::jit_avx2;
|
||||
} else if (mayiuse(cpu::x64::sse41)) {
|
||||
impl_type = impl_desc_type::jit_sse42;
|
||||
} else {
|
||||
impl_type = impl_desc_type::ref;
|
||||
}
|
||||
|
||||
addSupportedPrimDesc(inDataConf, outDataConf, impl_type);
|
||||
|
||||
// as only FP32 and ncsp is supported, and kernel is shape agnostic, we can create here. There is no need to recompilation.
|
||||
createJitKernel();
|
||||
}
|
||||
|
||||
void MKLDNNNonMaxSuppressionNode::prepareParams() {
|
||||
const auto& boxes_dims = isDynamicNode() ? getParentEdgesAtPort(NMS_BOXES)[0]->getMemory().getStaticDims() :
|
||||
const auto& boxesDims = isDynamicNode() ? getParentEdgesAtPort(NMS_BOXES)[0]->getMemory().getStaticDims() :
|
||||
getInputShapeAtPort(NMS_BOXES).getStaticDims();
|
||||
const auto& scores_dims = isDynamicNode() ? getParentEdgesAtPort(NMS_SCORES)[0]->getMemory().getStaticDims() :
|
||||
const auto& scoresDims = isDynamicNode() ? getParentEdgesAtPort(NMS_SCORES)[0]->getMemory().getStaticDims() :
|
||||
getInputShapeAtPort(NMS_SCORES).getStaticDims();
|
||||
|
||||
num_batches = boxes_dims[0];
|
||||
num_boxes = boxes_dims[1];
|
||||
num_classes = scores_dims[1];
|
||||
if (num_batches != scores_dims[0])
|
||||
IE_THROW() << errorPrefix << " num_batches is different in 'boxes' and 'scores' inputs";
|
||||
if (num_boxes != scores_dims[2])
|
||||
IE_THROW() << errorPrefix << " num_boxes is different in 'boxes' and 'scores' inputs";
|
||||
numBatches = boxesDims[0];
|
||||
numBoxes = boxesDims[1];
|
||||
numClasses = scoresDims[1];
|
||||
if (numBatches != scoresDims[0])
|
||||
IE_THROW() << errorPrefix << " numBatches is different in 'boxes' and 'scores' inputs";
|
||||
if (numBoxes != scoresDims[2])
|
||||
IE_THROW() << errorPrefix << " numBoxes is different in 'boxes' and 'scores' inputs";
|
||||
|
||||
numFiltBox.resize(num_batches);
|
||||
numFiltBox.resize(numBatches);
|
||||
for (auto & i : numFiltBox)
|
||||
i.resize(num_classes);
|
||||
i.resize(numClasses);
|
||||
}
|
||||
|
||||
bool MKLDNNNonMaxSuppressionNode::isExecutable() const {
|
||||
return isDynamicNode() || MKLDNNNode::isExecutable();
|
||||
}
|
||||
|
||||
void MKLDNNNonMaxSuppressionNode::createJitKernel() {
|
||||
auto jcp = jit_nms_config_params();
|
||||
jcp.box_encode_type = boxEncodingType;
|
||||
jcp.is_soft_suppressed_by_iou = isSoftSuppressedByIOU;
|
||||
|
||||
if (mayiuse(cpu::x64::avx512_common)) {
|
||||
nms_kernel.reset(new jit_uni_nms_kernel_f32<cpu::x64::avx512_common>(jcp));
|
||||
} else if (mayiuse(cpu::x64::avx2)) {
|
||||
nms_kernel.reset(new jit_uni_nms_kernel_f32<cpu::x64::avx2>(jcp));
|
||||
} else if (mayiuse(cpu::x64::sse41)) {
|
||||
nms_kernel.reset(new jit_uni_nms_kernel_f32<cpu::x64::sse41>(jcp));
|
||||
}
|
||||
|
||||
if (nms_kernel)
|
||||
nms_kernel->create_ker();
|
||||
}
|
||||
|
||||
void MKLDNNNonMaxSuppressionNode::executeDynamicImpl(mkldnn::stream strm) {
|
||||
if (hasEmptyInputTensors() || (inputShapes.size() > NMS_MAXOUTPUTBOXESPERCLASS &&
|
||||
reinterpret_cast<int *>(getParentEdgeAt(NMS_MAXOUTPUTBOXESPERCLASS)->getMemoryPtr()->GetPtr())[0] == 0)) {
|
||||
@@ -169,35 +728,35 @@ void MKLDNNNonMaxSuppressionNode::execute(mkldnn::stream strm) {
|
||||
const float *scores = reinterpret_cast<const float *>(getParentEdgeAt(NMS_SCORES)->getMemoryPtr()->GetPtr());
|
||||
|
||||
if (inputShapes.size() > NMS_MAXOUTPUTBOXESPERCLASS) {
|
||||
max_output_boxes_per_class = reinterpret_cast<int *>(getParentEdgeAt(NMS_MAXOUTPUTBOXESPERCLASS)->getMemoryPtr()->GetPtr())[0];
|
||||
maxOutputBoxesPerClass = reinterpret_cast<int *>(getParentEdgeAt(NMS_MAXOUTPUTBOXESPERCLASS)->getMemoryPtr()->GetPtr())[0];
|
||||
}
|
||||
|
||||
max_output_boxes_per_class = std::min(max_output_boxes_per_class, num_boxes);
|
||||
maxOutputBoxesPerClass = std::min(maxOutputBoxesPerClass, numBoxes);
|
||||
|
||||
if (max_output_boxes_per_class == 0) {
|
||||
if (maxOutputBoxesPerClass == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (inputShapes.size() > NMS_IOUTHRESHOLD)
|
||||
iou_threshold = reinterpret_cast<float *>(getParentEdgeAt(NMS_IOUTHRESHOLD)->getMemoryPtr()->GetPtr())[0];
|
||||
iouThreshold = reinterpret_cast<float *>(getParentEdgeAt(NMS_IOUTHRESHOLD)->getMemoryPtr()->GetPtr())[0];
|
||||
|
||||
if (inputShapes.size() > NMS_SCORETHRESHOLD)
|
||||
score_threshold = reinterpret_cast<float *>(getParentEdgeAt(NMS_SCORETHRESHOLD)->getMemoryPtr()->GetPtr())[0];
|
||||
scoreThreshold = reinterpret_cast<float *>(getParentEdgeAt(NMS_SCORETHRESHOLD)->getMemoryPtr()->GetPtr())[0];
|
||||
|
||||
if (inputShapes.size() > NMS_SOFTNMSSIGMA)
|
||||
soft_nms_sigma = reinterpret_cast<float *>(getParentEdgeAt(NMS_SOFTNMSSIGMA)->getMemoryPtr()->GetPtr())[0];
|
||||
softNMSSigma = reinterpret_cast<float *>(getParentEdgeAt(NMS_SOFTNMSSIGMA)->getMemoryPtr()->GetPtr())[0];
|
||||
scale = 0.0f;
|
||||
if (soft_nms_sigma > 0.0) {
|
||||
scale = -0.5f / soft_nms_sigma;
|
||||
if (softNMSSigma > 0.0) {
|
||||
scale = -0.5f / softNMSSigma;
|
||||
}
|
||||
|
||||
auto boxesStrides = getParentEdgeAt(NMS_BOXES)->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides();
|
||||
auto scoresStrides = getParentEdgeAt(NMS_SCORES)->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides();
|
||||
|
||||
const auto maxNumberOfBoxes = max_output_boxes_per_class * num_batches * num_classes;
|
||||
const auto maxNumberOfBoxes = maxOutputBoxesPerClass * numBatches * numClasses;
|
||||
std::vector<filteredBoxes> filtBoxes(maxNumberOfBoxes);
|
||||
|
||||
if (soft_nms_sigma == 0.0f) {
|
||||
if (softNMSSigma == 0.0f) {
|
||||
nmsWithoutSoftSigma(boxes, scores, boxesStrides, scoresStrides, filtBoxes);
|
||||
} else {
|
||||
nmsWithSoftSigma(boxes, scores, boxesStrides, scoresStrides, filtBoxes);
|
||||
@@ -205,9 +764,9 @@ void MKLDNNNonMaxSuppressionNode::execute(mkldnn::stream strm) {
|
||||
|
||||
size_t startOffset = numFiltBox[0][0];
|
||||
for (size_t b = 0; b < numFiltBox.size(); b++) {
|
||||
size_t batchOffset = b*num_classes*max_output_boxes_per_class;
|
||||
size_t batchOffset = b*numClasses*maxOutputBoxesPerClass;
|
||||
for (size_t c = (b == 0 ? 1 : 0); c < numFiltBox[b].size(); c++) {
|
||||
size_t offset = batchOffset + c*max_output_boxes_per_class;
|
||||
size_t offset = batchOffset + c*maxOutputBoxesPerClass;
|
||||
for (size_t i = 0; i < numFiltBox[b][c]; i++) {
|
||||
filtBoxes[startOffset + i] = filtBoxes[offset + i];
|
||||
}
|
||||
@@ -218,7 +777,7 @@ void MKLDNNNonMaxSuppressionNode::execute(mkldnn::stream strm) {
|
||||
|
||||
// need more particular comparator to get deterministic behaviour
|
||||
// escape situation when filtred boxes with same score have different position from launch to launch
|
||||
if (sort_result_descending) {
|
||||
if (sortResultDescending) {
|
||||
parallel_sort(filtBoxes.begin(), filtBoxes.end(),
|
||||
[](const filteredBoxes& l, const filteredBoxes& r) {
|
||||
return (l.score > r.score) ||
|
||||
@@ -273,7 +832,7 @@ bool MKLDNNNonMaxSuppressionNode::created() const {
|
||||
|
||||
float MKLDNNNonMaxSuppressionNode::intersectionOverUnion(const float *boxesI, const float *boxesJ) {
|
||||
float yminI, xminI, ymaxI, xmaxI, yminJ, xminJ, ymaxJ, xmaxJ;
|
||||
if (boxEncodingType == boxEncoding::CENTER) {
|
||||
if (boxEncodingType == NMSBoxEncodeType::CENTER) {
|
||||
// box format: x_center, y_center, width, height
|
||||
yminI = boxesI[1] - boxesI[3] / 2.f;
|
||||
xminI = boxesI[0] - boxesI[2] / 2.f;
|
||||
@@ -312,99 +871,205 @@ void MKLDNNNonMaxSuppressionNode::nmsWithSoftSigma(const float *boxes, const flo
|
||||
return l.score < r.score || ((l.score == r.score) && (l.idx > r.idx));
|
||||
};
|
||||
|
||||
// update score, if iou is 0, weight is 1, score does not change
|
||||
// if is_soft_suppressed_by_iou is false, apply for all iou, including iou>iou_threshold, soft suppressed when score < score_threshold
|
||||
// if is_soft_suppressed_by_iou is true, hard suppressed by iou_threshold, then soft suppress
|
||||
auto coeff = [&](float iou) {
|
||||
const float weight = std::exp(scale * iou * iou);
|
||||
return iou <= iou_threshold ? weight : 0.0f;
|
||||
if (isSoftSuppressedByIOU && iou > iouThreshold)
|
||||
return 0.0f;
|
||||
return std::exp(scale * iou * iou);
|
||||
};
|
||||
|
||||
parallel_for2d(num_batches, num_classes, [&](int batch_idx, int class_idx) {
|
||||
std::vector<filteredBoxes> fb;
|
||||
parallel_for2d(numBatches, numClasses, [&](int batch_idx, int class_idx) {
|
||||
std::vector<filteredBoxes> selectedBoxes;
|
||||
const float *boxesPtr = boxes + batch_idx * boxesStrides[0];
|
||||
const float *scoresPtr = scores + batch_idx * scoresStrides[0] + class_idx * scoresStrides[1];
|
||||
|
||||
std::priority_queue<boxInfo, std::vector<boxInfo>, decltype(less)> sorted_boxes(less);
|
||||
for (int box_idx = 0; box_idx < num_boxes; box_idx++) {
|
||||
if (scoresPtr[box_idx] > score_threshold)
|
||||
std::priority_queue<boxInfo, std::vector<boxInfo>, decltype(less)> sorted_boxes(less); // score, box_id, suppress_begin_index
|
||||
for (int box_idx = 0; box_idx < numBoxes; box_idx++) {
|
||||
if (scoresPtr[box_idx] > scoreThreshold)
|
||||
sorted_boxes.emplace(boxInfo({scoresPtr[box_idx], box_idx, 0}));
|
||||
}
|
||||
size_t sortedBoxSize = sorted_boxes.size();
|
||||
size_t maxSeletedBoxNum = std::min(sortedBoxSize, maxOutputBoxesPerClass);
|
||||
selectedBoxes.reserve(maxSeletedBoxNum);
|
||||
if (maxSeletedBoxNum > 0) {
|
||||
// include first directly
|
||||
boxInfo candidateBox = sorted_boxes.top();
|
||||
sorted_boxes.pop();
|
||||
selectedBoxes.push_back({ candidateBox.score, batch_idx, class_idx, candidateBox.idx });
|
||||
if (maxSeletedBoxNum > 1) {
|
||||
if (nms_kernel) {
|
||||
std::vector<float> boxCoord0(maxSeletedBoxNum, 0.0f);
|
||||
std::vector<float> boxCoord1(maxSeletedBoxNum, 0.0f);
|
||||
std::vector<float> boxCoord2(maxSeletedBoxNum, 0.0f);
|
||||
std::vector<float> boxCoord3(maxSeletedBoxNum, 0.0f);
|
||||
|
||||
fb.reserve(sorted_boxes.size());
|
||||
if (sorted_boxes.size() > 0) {
|
||||
while (fb.size() < max_output_boxes_per_class && !sorted_boxes.empty()) {
|
||||
boxInfo currBox = sorted_boxes.top();
|
||||
float origScore = currBox.score;
|
||||
sorted_boxes.pop();
|
||||
boxCoord0[0] = boxesPtr[candidateBox.idx * 4];
|
||||
boxCoord1[0] = boxesPtr[candidateBox.idx * 4 + 1];
|
||||
boxCoord2[0] = boxesPtr[candidateBox.idx * 4 + 2];
|
||||
boxCoord3[0] = boxesPtr[candidateBox.idx * 4 + 3];
|
||||
|
||||
bool box_is_selected = true;
|
||||
for (int idx = static_cast<int>(fb.size()) - 1; idx >= currBox.suppress_begin_index; idx--) {
|
||||
float iou = intersectionOverUnion(&boxesPtr[currBox.idx * 4], &boxesPtr[fb[idx].box_index * 4]);
|
||||
currBox.score *= coeff(iou);
|
||||
if (iou >= iou_threshold) {
|
||||
box_is_selected = false;
|
||||
break;
|
||||
auto arg = jit_nms_args();
|
||||
arg.iou_threshold = static_cast<float*>(&iouThreshold);
|
||||
arg.score_threshold = static_cast<float*>(&scoreThreshold);
|
||||
arg.scale = static_cast<float*>(&scale);
|
||||
while (selectedBoxes.size() < maxOutputBoxesPerClass && !sorted_boxes.empty()) {
|
||||
boxInfo candidateBox = sorted_boxes.top();
|
||||
float origScore = candidateBox.score;
|
||||
sorted_boxes.pop();
|
||||
|
||||
int candidateStatus = NMSCandidateStatus::SELECTED; // 0 for suppressed, 1 for selected, 2 for updated
|
||||
arg.score = static_cast<float*>(&candidateBox.score);
|
||||
arg.selected_boxes_num = selectedBoxes.size() - candidateBox.suppress_begin_index;
|
||||
arg.selected_boxes_coord[0] = static_cast<float*>(&boxCoord0[candidateBox.suppress_begin_index]);
|
||||
arg.selected_boxes_coord[1] = static_cast<float*>(&boxCoord1[candidateBox.suppress_begin_index]);
|
||||
arg.selected_boxes_coord[2] = static_cast<float*>(&boxCoord2[candidateBox.suppress_begin_index]);
|
||||
arg.selected_boxes_coord[3] = static_cast<float*>(&boxCoord3[candidateBox.suppress_begin_index]);
|
||||
arg.candidate_box = static_cast<const float*>(&boxesPtr[candidateBox.idx * 4]);
|
||||
arg.candidate_status = static_cast<int*>(&candidateStatus);
|
||||
(*nms_kernel)(&arg);
|
||||
|
||||
if (candidateStatus == NMSCandidateStatus::SUPPRESSED) {
|
||||
continue;
|
||||
} else {
|
||||
if (candidateBox.score == origScore) {
|
||||
selectedBoxes.push_back({ candidateBox.score, batch_idx, class_idx, candidateBox.idx });
|
||||
int selectedSize = selectedBoxes.size();
|
||||
boxCoord0[selectedSize - 1] = boxesPtr[candidateBox.idx * 4];
|
||||
boxCoord1[selectedSize - 1] = boxesPtr[candidateBox.idx * 4 + 1];
|
||||
boxCoord2[selectedSize - 1] = boxesPtr[candidateBox.idx * 4 + 2];
|
||||
boxCoord3[selectedSize - 1] = boxesPtr[candidateBox.idx * 4 + 3];
|
||||
} else {
|
||||
candidateBox.suppress_begin_index = selectedBoxes.size();
|
||||
sorted_boxes.push(candidateBox);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (currBox.score <= score_threshold)
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
while (selectedBoxes.size() < maxOutputBoxesPerClass && !sorted_boxes.empty()) {
|
||||
boxInfo candidateBox = sorted_boxes.top();
|
||||
float origScore = candidateBox.score;
|
||||
sorted_boxes.pop();
|
||||
|
||||
currBox.suppress_begin_index = fb.size();
|
||||
if (box_is_selected) {
|
||||
if (currBox.score == origScore) {
|
||||
fb.push_back({ currBox.score, batch_idx, class_idx, currBox.idx });
|
||||
continue;
|
||||
}
|
||||
if (currBox.score > score_threshold) {
|
||||
sorted_boxes.push(currBox);
|
||||
int candidateStatus = NMSCandidateStatus::SELECTED; // 0 for suppressed, 1 for selected, 2 for updated
|
||||
for (int selected_idx = static_cast<int>(selectedBoxes.size()) - 1; selected_idx >= candidateBox.suppress_begin_index; selected_idx--) {
|
||||
float iou = intersectionOverUnion(&boxesPtr[candidateBox.idx * 4], &boxesPtr[selectedBoxes[selected_idx].box_index * 4]);
|
||||
|
||||
// when is_soft_suppressed_by_iou is true, score is decayed to zero and implicitely suppressed if iou > iou_threshold.
|
||||
candidateBox.score *= coeff(iou);
|
||||
// soft suppressed
|
||||
if (candidateBox.score <= scoreThreshold) {
|
||||
candidateStatus = NMSCandidateStatus::SUPPRESSED;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (candidateStatus == NMSCandidateStatus::SUPPRESSED) {
|
||||
continue;
|
||||
} else {
|
||||
if (candidateBox.score == origScore) {
|
||||
selectedBoxes.push_back({ candidateBox.score, batch_idx, class_idx, candidateBox.idx });
|
||||
} else {
|
||||
candidateBox.suppress_begin_index = selectedBoxes.size();
|
||||
sorted_boxes.push(candidateBox);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
numFiltBox[batch_idx][class_idx] = fb.size();
|
||||
size_t offset = batch_idx*num_classes*max_output_boxes_per_class + class_idx*max_output_boxes_per_class;
|
||||
for (size_t i = 0; i < fb.size(); i++) {
|
||||
filtBoxes[offset + i] = fb[i];
|
||||
numFiltBox[batch_idx][class_idx] = selectedBoxes.size();
|
||||
size_t offset = batch_idx*numClasses*maxOutputBoxesPerClass + class_idx*maxOutputBoxesPerClass;
|
||||
for (size_t i = 0; i < selectedBoxes.size(); i++) {
|
||||
filtBoxes[offset + i] = selectedBoxes[i];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void MKLDNNNonMaxSuppressionNode::nmsWithoutSoftSigma(const float *boxes, const float *scores, const VectorDims &boxesStrides,
|
||||
const VectorDims &scoresStrides, std::vector<filteredBoxes> &filtBoxes) {
|
||||
int max_out_box = static_cast<int>(max_output_boxes_per_class);
|
||||
parallel_for2d(num_batches, num_classes, [&](int batch_idx, int class_idx) {
|
||||
int max_out_box = static_cast<int>(maxOutputBoxesPerClass);
|
||||
parallel_for2d(numBatches, numClasses, [&](int batch_idx, int class_idx) {
|
||||
const float *boxesPtr = boxes + batch_idx * boxesStrides[0];
|
||||
const float *scoresPtr = scores + batch_idx * scoresStrides[0] + class_idx * scoresStrides[1];
|
||||
|
||||
std::vector<std::pair<float, int>> sorted_boxes;
|
||||
for (int box_idx = 0; box_idx < num_boxes; box_idx++) {
|
||||
if (scoresPtr[box_idx] > score_threshold)
|
||||
std::vector<std::pair<float, int>> sorted_boxes; // score, box_idx
|
||||
for (int box_idx = 0; box_idx < numBoxes; box_idx++) {
|
||||
if (scoresPtr[box_idx] > scoreThreshold)
|
||||
sorted_boxes.emplace_back(std::make_pair(scoresPtr[box_idx], box_idx));
|
||||
}
|
||||
|
||||
int io_selection_size = 0;
|
||||
if (sorted_boxes.size() > 0) {
|
||||
size_t sortedBoxSize = sorted_boxes.size();
|
||||
if (sortedBoxSize > 0) {
|
||||
parallel_sort(sorted_boxes.begin(), sorted_boxes.end(),
|
||||
[](const std::pair<float, int>& l, const std::pair<float, int>& r) {
|
||||
return (l.first > r.first || ((l.first == r.first) && (l.second < r.second)));
|
||||
});
|
||||
int offset = batch_idx*num_classes*max_output_boxes_per_class + class_idx*max_output_boxes_per_class;
|
||||
int offset = batch_idx*numClasses*maxOutputBoxesPerClass + class_idx*maxOutputBoxesPerClass;
|
||||
filtBoxes[offset + 0] = filteredBoxes(sorted_boxes[0].first, batch_idx, class_idx, sorted_boxes[0].second);
|
||||
io_selection_size++;
|
||||
for (size_t box_idx = 1; (box_idx < sorted_boxes.size()) && (io_selection_size < max_out_box); box_idx++) {
|
||||
bool box_is_selected = true;
|
||||
for (int idx = io_selection_size - 1; idx >= 0; idx--) {
|
||||
float iou = intersectionOverUnion(&boxesPtr[sorted_boxes[box_idx].second * 4], &boxesPtr[filtBoxes[offset + idx].box_index * 4]);
|
||||
if (iou >= iou_threshold) {
|
||||
box_is_selected = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (sortedBoxSize > 1) {
|
||||
if (nms_kernel) {
|
||||
std::vector<float> boxCoord0(sortedBoxSize, 0.0f);
|
||||
std::vector<float> boxCoord1(sortedBoxSize, 0.0f);
|
||||
std::vector<float> boxCoord2(sortedBoxSize, 0.0f);
|
||||
std::vector<float> boxCoord3(sortedBoxSize, 0.0f);
|
||||
|
||||
if (box_is_selected) {
|
||||
filtBoxes[offset + io_selection_size] = filteredBoxes(sorted_boxes[box_idx].first, batch_idx, class_idx, sorted_boxes[box_idx].second);
|
||||
io_selection_size++;
|
||||
boxCoord0[0] = boxesPtr[sorted_boxes[0].second * 4];
|
||||
boxCoord1[0] = boxesPtr[sorted_boxes[0].second * 4 + 1];
|
||||
boxCoord2[0] = boxesPtr[sorted_boxes[0].second * 4 + 2];
|
||||
boxCoord3[0] = boxesPtr[sorted_boxes[0].second * 4 + 3];
|
||||
|
||||
auto arg = jit_nms_args();
|
||||
arg.iou_threshold = static_cast<float*>(&iouThreshold);
|
||||
arg.score_threshold = static_cast<float*>(&scoreThreshold);
|
||||
arg.scale = static_cast<float*>(&scale);
|
||||
// box start index do not change for hard supresion
|
||||
arg.selected_boxes_coord[0] = static_cast<float*>(&boxCoord0[0]);
|
||||
arg.selected_boxes_coord[1] = static_cast<float*>(&boxCoord1[0]);
|
||||
arg.selected_boxes_coord[2] = static_cast<float*>(&boxCoord2[0]);
|
||||
arg.selected_boxes_coord[3] = static_cast<float*>(&boxCoord3[0]);
|
||||
|
||||
for (size_t candidate_idx = 1; (candidate_idx < sortedBoxSize) && (io_selection_size < max_out_box); candidate_idx++) {
|
||||
int candidateStatus = NMSCandidateStatus::SELECTED; // 0 for suppressed, 1 for selected
|
||||
arg.selected_boxes_num = io_selection_size;
|
||||
arg.candidate_box = static_cast<const float*>(&boxesPtr[sorted_boxes[candidate_idx].second * 4]);
|
||||
arg.candidate_status = static_cast<int*>(&candidateStatus);
|
||||
(*nms_kernel)(&arg);
|
||||
if (candidateStatus == NMSCandidateStatus::SELECTED) {
|
||||
boxCoord0[io_selection_size] = boxesPtr[sorted_boxes[candidate_idx].second * 4];
|
||||
boxCoord1[io_selection_size] = boxesPtr[sorted_boxes[candidate_idx].second * 4 + 1];
|
||||
boxCoord2[io_selection_size] = boxesPtr[sorted_boxes[candidate_idx].second * 4 + 2];
|
||||
boxCoord3[io_selection_size] = boxesPtr[sorted_boxes[candidate_idx].second * 4 + 3];
|
||||
filtBoxes[offset + io_selection_size] =
|
||||
filteredBoxes(sorted_boxes[candidate_idx].first, batch_idx, class_idx, sorted_boxes[candidate_idx].second);
|
||||
io_selection_size++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t candidate_idx = 1; (candidate_idx < sortedBoxSize) && (io_selection_size < max_out_box); candidate_idx++) {
|
||||
int candidateStatus = NMSCandidateStatus::SELECTED; // 0 for suppressed, 1 for selected
|
||||
for (int selected_idx = io_selection_size - 1; selected_idx >= 0; selected_idx--) {
|
||||
float iou = intersectionOverUnion(&boxesPtr[sorted_boxes[candidate_idx].second * 4],
|
||||
&boxesPtr[filtBoxes[offset + selected_idx].box_index * 4]);
|
||||
if (iou >= iouThreshold) {
|
||||
candidateStatus = NMSCandidateStatus::SUPPRESSED;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (candidateStatus == NMSCandidateStatus::SELECTED) {
|
||||
filtBoxes[offset + io_selection_size] =
|
||||
filteredBoxes(sorted_boxes[candidate_idx].first, batch_idx, class_idx, sorted_boxes[candidate_idx].second);
|
||||
io_selection_size++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
numFiltBox[batch_idx][class_idx] = io_selection_size;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -10,10 +10,56 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#define BOX_COORD_NUM 4
|
||||
|
||||
using namespace InferenceEngine;
|
||||
|
||||
namespace MKLDNNPlugin {
|
||||
|
||||
enum class NMSBoxEncodeType {
|
||||
CORNER,
|
||||
CENTER
|
||||
};
|
||||
|
||||
enum NMSCandidateStatus {
|
||||
SUPPRESSED = 0,
|
||||
SELECTED = 1,
|
||||
UPDATED = 2
|
||||
};
|
||||
|
||||
struct jit_nms_config_params {
|
||||
NMSBoxEncodeType box_encode_type;
|
||||
bool is_soft_suppressed_by_iou;
|
||||
};
|
||||
|
||||
struct jit_nms_args {
|
||||
const void* selected_boxes_coord[BOX_COORD_NUM];
|
||||
size_t selected_boxes_num;
|
||||
const void* candidate_box;
|
||||
const void* iou_threshold;
|
||||
void* candidate_status;
|
||||
// for soft suppression, score *= scale * iou * iou;
|
||||
const void* score_threshold;
|
||||
const void* scale;
|
||||
void* score;
|
||||
};
|
||||
|
||||
struct jit_uni_nms_kernel {
|
||||
void (*ker_)(const jit_nms_args *);
|
||||
|
||||
void operator()(const jit_nms_args *args) {
|
||||
assert(ker_);
|
||||
ker_(args);
|
||||
}
|
||||
|
||||
explicit jit_uni_nms_kernel(jit_nms_config_params jcp_) : ker_(nullptr), jcp(jcp_) {}
|
||||
virtual ~jit_uni_nms_kernel() {}
|
||||
|
||||
virtual void create_ker() = 0;
|
||||
|
||||
jit_nms_config_params jcp;
|
||||
};
|
||||
|
||||
class MKLDNNNonMaxSuppressionNode : public MKLDNNNode {
|
||||
public:
|
||||
MKLDNNNonMaxSuppressionNode(const std::shared_ptr<ngraph::Node>& op, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache);
|
||||
@@ -73,23 +119,20 @@ private:
|
||||
NMS_VALIDOUTPUTS
|
||||
};
|
||||
|
||||
NMSBoxEncodeType boxEncodingType = NMSBoxEncodeType::CORNER;
|
||||
bool sortResultDescending = true;
|
||||
|
||||
enum class boxEncoding {
|
||||
CORNER,
|
||||
CENTER
|
||||
};
|
||||
boxEncoding boxEncodingType = boxEncoding::CORNER;
|
||||
bool sort_result_descending = true;
|
||||
size_t numBatches = 0;
|
||||
size_t numBoxes = 0;
|
||||
size_t numClasses = 0;
|
||||
|
||||
size_t num_batches = 0;
|
||||
size_t num_boxes = 0;
|
||||
size_t num_classes = 0;
|
||||
|
||||
size_t max_output_boxes_per_class = 0lu;
|
||||
float iou_threshold = 0.0f;
|
||||
float score_threshold = 0.0f;
|
||||
float soft_nms_sigma = 0.0f;
|
||||
size_t maxOutputBoxesPerClass = 0lu;
|
||||
float iouThreshold = 0.0f;
|
||||
float scoreThreshold = 0.0f;
|
||||
float softNMSSigma = 0.0f;
|
||||
float scale = 1.f;
|
||||
// control placeholder for NMS in new opset.
|
||||
bool isSoftSuppressedByIOU = true;
|
||||
|
||||
std::string errorPrefix;
|
||||
|
||||
@@ -99,6 +142,9 @@ private:
|
||||
void checkPrecision(const Precision& prec, const std::vector<Precision>& precList, const std::string& name, const std::string& type);
|
||||
void check1DInput(const Shape& shape, const std::vector<Precision>& precList, const std::string& name, const size_t port);
|
||||
void checkOutput(const Shape& shape, const std::vector<Precision>& precList, const std::string& name, const size_t port);
|
||||
|
||||
void createJitKernel();
|
||||
std::shared_ptr<jit_uni_nms_kernel> nms_kernel;
|
||||
};
|
||||
|
||||
} // namespace MKLDNNPlugin
|
||||
|
||||
Reference in New Issue
Block a user