[CPU] Fix mixing VEX and non-VEX instructions (#7238)

Quote: The Skylake microarchitecture implements a different state
machine than prior generations to manage the YMM state transition
associated with mixing SSE and AVX instructions.
It no longer saves the entire upper YMM state when executing
an SSE instruction when in “Modified and Unsaved” state,
but saves the upper bits of individual register.
As a result, mixing SSE and AVX instructions will experience
a penalty associated with partial register dependency of
the destination registers being used and additional blend
operation on the upper bits of the destination registers.

Such penalties have a huge impact on OpenVINO's and oneDNN's kernels.
Basically, mixing VEX and non-VEX instructions should be avoided.
This commit is contained in:
Egor Duplensky
2021-09-22 10:46:37 +03:00
committed by GitHub
parent 05c641ff7c
commit 6d634d09a4
14 changed files with 148 additions and 146 deletions

View File

@@ -85,8 +85,8 @@ void jit_mul_add_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const
if (isa == cpu::x64::sse41) {
h->uni_vmovups(vmm_dst, vmm_src0);
h->mulps(vmm_dst, vmm_src1);
h->addps(vmm_dst, vmm_src2);
h->uni_vmulps(vmm_dst, vmm_src1);
h->uni_vaddps(vmm_dst, vmm_dst, vmm_src2);
} else {
Vmm vmm_mul0;
if (vmm_dst.getIdx() == vmm_src0.getIdx()) {
@@ -656,7 +656,7 @@ void jit_equal_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const s
} else if (isa == cpu::x64::avx2) {
h->vcmpeqps(vmm_aux0, vmm_src0, vmm_src1);
h->uni_vmovups(vmm_dst, table_val("zero"));
h->vblendvps(vmm_dst, vmm_dst, table_val("one"), vmm_aux0);
h->uni_vblendvps(vmm_dst, vmm_dst, table_val("one"), vmm_aux0);
} else {
h->vcmpps(k_mask, vmm_src0, vmm_src1, _cmp_eq_oq);
h->uni_vmovups(vmm_dst, table_val("zero"));

View File

@@ -190,7 +190,7 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int o
}
if (bytes_to_load >= 8 && bytes_to_load < 16)
h->pinsrq(xmm, addr(start_bytes), 0);
h->uni_vpinsrq(xmm, xmm, addr(start_bytes), 0);
else if (bytes_to_load == 16)
h->uni_vmovdqu(xmm, addr(start_bytes));
@@ -202,17 +202,17 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int o
h->uni_vpinsrw(xmm, xmm, addr(start_bytes), 0);
h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 2), 2);
break;
case 4: h->pinsrd(xmm, addr(start_bytes), 0); break;
case 4: h->uni_vpinsrd(xmm, xmm, addr(start_bytes), 0); break;
case 5:
h->pinsrd(xmm, addr(start_bytes), 0);
h->uni_vpinsrd(xmm, xmm, addr(start_bytes), 0);
h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 4), 4);
break;
case 6:
h->pinsrd(xmm, addr(start_bytes), 0);
h->uni_vpinsrd(xmm, xmm, addr(start_bytes), 0);
h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 4), 2);
break;
case 7:
h->pinsrd(xmm, addr(start_bytes), 0);
h->uni_vpinsrd(xmm, xmm, addr(start_bytes), 0);
h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 4), 2);
h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 6), 6);
break;
@@ -223,17 +223,17 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int o
h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 8), 4);
h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 10), 10);
break;
case 12: h->pinsrd(xmm, addr(start_bytes + 8), 2); break;
case 12: h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2); break;
case 13:
h->pinsrd(xmm, addr(start_bytes + 8), 2);
h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2);
h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 12), 12);
break;
case 14:
h->pinsrd(xmm, addr(start_bytes + 8), 2);
h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2);
h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 12), 6);
break;
case 15:
h->pinsrd(xmm, addr(start_bytes + 8), 2);
h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2);
h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 12), 6);
h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 14), 14);
break;
@@ -465,10 +465,7 @@ template <typename Vmm>
if (is_xmm || is_ymm) {
uint8 imm = 1;
imm = ~((imm << load_num) - imm); // shift load_num bit
if (is_xmm)
h->blendps(vmm, table_val(fill_value), imm);
else
h->vblendps(vmm, vmm, table_val(fill_value), imm);
h->uni_vblendps(vmm, vmm, table_val(fill_value), imm);
} else if (is_zmm) {
uint64_t tail_mask = 1;
tail_mask = ~((tail_mask << load_num) - tail_mask);
@@ -668,7 +665,7 @@ template <typename Vmm>
}
if (bytes_to_store >= 8 && bytes_to_store < 16)
h->pextrq(addr(start_bytes), xmm, 0);
h->uni_vpextrq(addr(start_bytes), xmm, 0);
else if (bytes_to_store == 16)
h->uni_vmovdqu(addr(start_bytes), xmm);
@@ -682,17 +679,17 @@ template <typename Vmm>
h->uni_vpextrw(addr(start_bytes), xmm, 0);
h->uni_vpextrb(addr(start_bytes + 2), xmm, 2);
break;
case 4: h->pextrd(addr(start_bytes), xmm, 0); break;
case 4: h->uni_vpextrd(addr(start_bytes), xmm, 0); break;
case 5:
h->pextrd(addr(start_bytes), xmm, 0);
h->uni_vpextrd(addr(start_bytes), xmm, 0);
h->uni_vpextrb(addr(start_bytes + 4), xmm, 4);
break;
case 6:
h->pextrd(addr(start_bytes), xmm, 0);
h->uni_vpextrd(addr(start_bytes), xmm, 0);
h->uni_vpextrw(addr(start_bytes + 4), xmm, 2);
break;
case 7:
h->pextrd(addr(start_bytes), xmm, 0);
h->uni_vpextrd(addr(start_bytes), xmm, 0);
h->uni_vpextrw(addr(start_bytes + 4), xmm, 2);
h->uni_vpextrb(addr(start_bytes + 6), xmm, 6);
break;
@@ -703,17 +700,17 @@ template <typename Vmm>
h->uni_vpextrw(addr(start_bytes + 8), xmm, 4);
h->uni_vpextrb(addr(start_bytes + 10), xmm, 10);
break;
case 12: h->pextrd(addr(start_bytes + 8), xmm, 2); break;
case 12: h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); break;
case 13:
h->pextrd(addr(start_bytes + 8), xmm, 2);
h->uni_vpextrd(addr(start_bytes + 8), xmm, 2);
h->uni_vpextrb(addr(start_bytes + 12), xmm, 12);
break;
case 14:
h->pextrd(addr(start_bytes + 8), xmm, 2);
h->uni_vpextrd(addr(start_bytes + 8), xmm, 2);
h->uni_vpextrw(addr(start_bytes + 12), xmm, 6);
break;
case 15:
h->pextrd(addr(start_bytes + 8), xmm, 2);
h->uni_vpextrd(addr(start_bytes + 8), xmm, 2);
h->uni_vpextrw(addr(start_bytes + 12), xmm, 6);
h->uni_vpextrb(addr(start_bytes + 14), xmm, 14);
break;

View File

@@ -47,21 +47,21 @@ struct jit_uni_permute_kernel_f32 : public jit_uni_permute_kernel, public jit_ge
void load(const Xbyak::Xmm &xmm, const Xbyak::Address &addr) {
switch (jcp.data_size) {
case 16: movups(xmm, addr); break;
case 8: movsd(xmm, addr); break;
case 4: movss(xmm, addr); break;
case 2: pinsrw(xmm, addr, 0x0); break;
case 1: pinsrb(xmm, addr, 0x0); break;
case 16: uni_vmovups(xmm, addr); break;
case 8: uni_vmovsd(xmm, addr); break;
case 4: uni_vmovss(xmm, addr); break;
case 2: uni_vpinsrw(xmm, xmm, addr, 0x0); break;
case 1: uni_vpinsrb(xmm, xmm, addr, 0x0); break;
}
}
void store(const Xbyak::Address &addr, const Xbyak::Xmm &xmm) {
switch (jcp.data_size) {
case 16: movups(addr, xmm); break;
case 8: movsd(addr, xmm); break;
case 4: movss(addr, xmm); break;
case 2: pextrw(addr, xmm, 0x0); break;
case 1: pextrb(addr, xmm, 0x0); break;
case 16: uni_vmovups(addr, xmm); break;
case 8: uni_vmovsd(addr, xmm); break;
case 4: uni_vmovss(addr, xmm); break;
case 2: uni_vpextrw(addr, xmm, 0x0); break;
case 1: uni_vpextrb(addr, xmm, 0x0); break;
}
}

View File

@@ -215,7 +215,7 @@ private:
case memory::data_type::s32:
if (scalar_load) {
mov(reg_tmp_32, op);
movq(xmm_in, reg_tmp_64);
uni_vmovq(xmm_in, reg_tmp_64);
} else {
uni_vmovups(vmm_in, op);
}
@@ -223,7 +223,7 @@ private:
case memory::data_type::s8:
if (scalar_load) {
movsx(reg_tmp_32, op);
movq(xmm_in, reg_tmp_64);
uni_vmovq(xmm_in, reg_tmp_64);
} else {
uni_vpmovsxbd(vmm_in, op);
}
@@ -231,7 +231,7 @@ private:
case memory::data_type::u8:
if (scalar_load) {
movzx(reg_tmp_32, op);
movq(xmm_in, reg_tmp_64);
uni_vmovq(xmm_in, reg_tmp_64);
} else {
uni_vpmovzxbd(vmm_in, op);
}
@@ -541,7 +541,7 @@ private:
if (jcp_.exclude_pad) {
mov(reg_shift, kw_padding[jj]);
imul(reg_shift, reg_tmp_32);
movq(Xmm(vmm_shift.getIdx()), reg_shift);
uni_vmovq(Xmm(vmm_shift.getIdx()), reg_shift);
uni_vbroadcastss(vmm_shift, Xmm(vmm_shift.getIdx()));
uni_vcvtdq2ps(vmm_shift, vmm_shift);
}
@@ -612,7 +612,7 @@ private:
} else {
Ymm ymm_prev_dst = Ymm(vmm_sum.getIdx());
vperm2i128(ymm_prev_dst, ymm_prev_dst, ymm_prev_dst, 0x01);
vpslldq(vmm_sum, vmm_sum, (oc - jcp_.oc_block / 2) * sizeof(float));
uni_vpslldq(vmm_sum, vmm_sum, (oc - jcp_.oc_block / 2) * sizeof(float));
}
uni_vaddps(vmm_dst, vmm_dst, vmm_sum);

View File

@@ -583,7 +583,7 @@ private:
switch (src_prc) {
case Precision::FP32:
case Precision::I32:
movss(xmm_src, op);
uni_vmovss(xmm_src, op);
break;
case Precision::BF16:
uni_vpinsrw(xmm_src, xmm_src, op, 0);
@@ -599,11 +599,11 @@ private:
break;
case Precision::I8:
movsx(reg_tmp_32, op);
movq(xmm_src, reg_tmp_64);
uni_vmovq(xmm_src, reg_tmp_64);
break;
case Precision::U8:
movzx(reg_tmp_32, op);
movq(xmm_src, reg_tmp_64);
uni_vmovq(xmm_src, reg_tmp_64);
break;
default:
assert(!"unknown src_prc");
@@ -730,7 +730,7 @@ private:
switch (dst_prc) {
case Precision::FP32:
case Precision::I32:
movss(op, xmm_dst);
uni_vmovss(op, xmm_dst);
break;
case Precision::BF16:
uni_vpsrld(xmm_dst, xmm_dst, 16);

View File

@@ -150,9 +150,9 @@ struct jit_uni_binarization_kernel : public jit_uni_quantize_kernel, public jit_
uni_vpxor(xmm_wei(0), xmm_wei(0), xmm_wei(0));
uni_vpxor(xmm_mask(0), xmm_mask(0), xmm_mask(0));
movss(xmm_src(0), ptr[reg_from + c * sizeof(float)]);
movss(xmm_wei(0), ptr[reg_thresholds + c * sizeof(float)]);
movss(xmm_mask(0), ptr[reg_output_mask + c * sizeof(float)]);
uni_vmovss(xmm_src(0), ptr[reg_from + c * sizeof(float)]);
uni_vmovss(xmm_wei(0), ptr[reg_thresholds + c * sizeof(float)]);
uni_vmovss(xmm_mask(0), ptr[reg_output_mask + c * sizeof(float)]);
uni_vcmpgtps(xmm_src(0), xmm_src(0), xmm_wei(0));
uni_vpcmpeqd(xmm_src(0), xmm_src(0), xmm_mask(0));
uni_vmovmskps(reg_src_32, xmm_src(0));
@@ -591,13 +591,13 @@ private:
jle(exit_label, T_NEAR);
for (int i = 0; i < jqp_.c % tail4_simd_w; i++) {
movss(xmm_crop_low(0), ptr[reg_crop_low + i * wei_type_size]);
movss(xmm_crop_high(0), ptr[reg_crop_high + i * wei_type_size]);
movss(xmm_input_scale(0), ptr[reg_input_scale + i * wei_type_size]);
movss(xmm_input_shift(0), ptr[reg_input_shift + i * wei_type_size]);
uni_vmovss(xmm_crop_low(0), ptr[reg_crop_low + i * wei_type_size]);
uni_vmovss(xmm_crop_high(0), ptr[reg_crop_high + i * wei_type_size]);
uni_vmovss(xmm_input_scale(0), ptr[reg_input_scale + i * wei_type_size]);
uni_vmovss(xmm_input_shift(0), ptr[reg_input_shift + i * wei_type_size]);
if (do_dequantization) {
movss(xmm_output_scale(0), ptr[reg_output_scale + i * wei_type_size]);
movss(xmm_output_shift(0), ptr[reg_output_shift + i * wei_type_size]);
uni_vmovss(xmm_output_scale(0), ptr[reg_output_scale + i * wei_type_size]);
uni_vmovss(xmm_output_shift(0), ptr[reg_output_shift + i * wei_type_size]);
}
load_scalar(xmm_val(0), ptr[aux_reg_from + i * src_type_size], jqp_.src_prc);
@@ -688,15 +688,15 @@ private:
switch (src_prc) {
case Precision::FP32:
case Precision::I32:
movss(xmm_src, op);
uni_vmovss(xmm_src, op);
break;
case Precision::I8:
movsx(reg_tmp_32, op);
movq(xmm_src, reg_tmp_64);
uni_vmovq(xmm_src, reg_tmp_64);
break;
case Precision::U8:
movzx(reg_tmp_32, op);
movq(xmm_src, reg_tmp_64);
uni_vmovq(xmm_src, reg_tmp_64);
break;
default:
assert(!"unknown src_prc");
@@ -797,7 +797,7 @@ private:
switch (dst_prc) {
case Precision::FP32:
case Precision::I32:
movss(op, xmm_dst);
uni_vmovss(op, xmm_dst);
break;
case Precision::I8:
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);

View File

@@ -104,13 +104,13 @@ protected:
auto b = xmm2;
auto c = xmm3;
movdqu(a, xword[src]); // load 4 floats
movdqu(b, a); // b = a
movdqu(c, a); // c = a
pcmpeqd(b, zero); // if (a == 0) b = 1 else b = 0
pand(c, mask); // c = a & 01111111100000000000000000000000
pcmpeqd(c, zero); // if (c == 0) c = 1 else c = 0
ptest(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0
uni_vmovdqu(a, xword[src]); // load 4 floats
uni_vmovdqu(b, a); // b = a
uni_vmovdqu(c, a); // c = a
uni_vpcmpeqd(b, b, zero); // if (a == 0) b = 1 else b = 0
uni_vpand(c, c, mask); // c = a & 01111111100000000000000000000000
uni_vpcmpeqd(c, c, zero); // if (c == 0) c = 1 else c = 0
uni_vtestps(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0
}
template<cpu_isa_t isa>

View File

@@ -1199,10 +1199,10 @@ private:
jl(tail_loop_end_label, T_NEAR);
// get idx for input
movss(Xmm(vmm_tbl_y.getIdx()), ptr[reg_tbl_y]);
uni_vmovss(Xmm(vmm_tbl_y.getIdx()), ptr[reg_tbl_y]);
gather_i32_indices(vmm_index_in_y, reg_index_y, 0, vmm_tbl_y, 1, memory::data_type::s32, true);
movss(Xmm(vmm_val.getIdx()), ptr[reg_tbl_x]);
uni_vmovss(Xmm(vmm_val.getIdx()), ptr[reg_tbl_x]);
gather_i32_indices(vmm_index_in_x, reg_index, 0, vmm_val, 1, memory::data_type::s32, true);
// gather weightX by input idx, used in y0-y3
gather_i32_indices(vmm_weightX0, reg_weight_x, 0, vmm_val, grid_len, memory::data_type::f32, true);
@@ -1430,18 +1430,18 @@ private:
switch (src_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
movss(xmm_src, op);
uni_vmovss(xmm_src, op);
break;
case memory::data_type::s8:
movsx(reg_tmp_32, op);
movq(xmm_src, reg_tmp_64);
uni_vmovq(xmm_src, reg_tmp_64);
break;
case memory::data_type::u8:
movzx(reg_tmp_32, op);
movq(xmm_src, reg_tmp_64);
uni_vmovq(xmm_src, reg_tmp_64);
break;
case memory::data_type::bf16:
pinsrw(xmm_src, op, 0x0);
uni_vpinsrw(xmm_src, xmm_src, op, 0x0);
uni_vpslld(xmm_src, xmm_src, 16);
break;
default:
@@ -1536,7 +1536,7 @@ private:
switch (dst_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
movss(op, xmm_dst);
uni_vmovss(op, xmm_dst);
break;
case memory::data_type::s8:
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
@@ -1552,7 +1552,7 @@ private:
break;
case memory::data_type::bf16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
pextrw(op, xmm_dst, 0x0);
uni_vpextrw(op, xmm_dst, 0x0);
break;
default:
assert(!"unknown dst_dt");

View File

@@ -102,17 +102,17 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k
Xbyak::Ymm ymm_sum = Xbyak::Ymm(vmm_dst.getIdx());
vextractf128(xmm_aux1, ymm_sum, 0);
vextractf128(xmm_aux2, ymm_sum, 1);
addps(xmm_aux1, xmm_aux2);
uni_vaddps(xmm_aux1, xmm_aux1, xmm_aux2);
hsum_store(xmm_aux1);
} else {
Xbyak::Zmm zmm_sum = Xbyak::Zmm(vmm_dst.getIdx());
vextractf32x4(xmm_aux1, zmm_sum, 0);
vextractf32x4(xmm_aux2, zmm_sum, 1);
addps(xmm_aux1, xmm_aux2);
uni_vaddps(xmm_aux1, xmm_aux1, xmm_aux2);
vextractf32x4(xmm_aux2, zmm_sum, 2);
vextractf32x4(xmm_aux3, zmm_sum, 3);
addps(xmm_aux2, xmm_aux3);
addps(xmm_aux1, xmm_aux2);
uni_vaddps(xmm_aux2, xmm_aux2, xmm_aux3);
uni_vaddps(xmm_aux1, xmm_aux1, xmm_aux2);
hsum_store(xmm_aux1);
}
} else {
@@ -342,14 +342,14 @@ private:
}
inline void hsum_store(Xbyak::Xmm xmm_sum) {
movshdup(xmm_aux3, xmm_sum); // sum:1,2,3,4; aux3:2,2,4,4
addps(xmm_sum, xmm_aux3); // sum:1+2,2+2,3+4,4+4
movhlps(xmm_aux3, xmm_sum); // aux3:3+4,4+4,4,4
addps(xmm_sum, xmm_aux3); // sum:1+2+3+4,...
uni_vmovshdup(xmm_aux3, xmm_sum); // sum:1,2,3,4; aux3:2,2,4,4
uni_vaddps(xmm_sum, xmm_sum, xmm_aux3); // sum:1+2,2+2,3+4,4+4
uni_vmovhlps(xmm_aux3, xmm_sum); // aux3:3+4,4+4,4,4
uni_vaddps(xmm_sum, xmm_sum, xmm_aux3); // sum:1+2+3+4,...
if (jcp_.normalize_variance) {
movss(ptr[reg_variance], xmm_sum);
uni_vmovss(ptr[reg_variance], xmm_sum);
} else {
movss(ptr[reg_sum], xmm_sum);
uni_vmovss(ptr[reg_sum], xmm_sum);
}
}
};

View File

@@ -88,17 +88,21 @@ struct jit_uni_normalize_modulo_kernel_f32 : public jit_uni_normalize_modulo_ker
Xbyak::Ymm ymm_sqr_sum = Xbyak::Ymm(vmm_sqr_sum.getIdx());
vextractf128(xmm_aux1, ymm_sqr_sum, 0);
vextractf128(xmm_aux2, ymm_sqr_sum, 1);
addps(xmm_aux1, xmm_aux2);
// vaddps(xmm_aux1, xmm_aux2);
uni_vaddps(xmm_aux1, xmm_aux1, xmm_aux2);
hsum_store(xmm_aux1);
} else {
Xbyak::Zmm zmm_sqr_sum = Xbyak::Zmm(vmm_sqr_sum.getIdx());
vextractf32x4(xmm_aux1, zmm_sqr_sum, 0);
vextractf32x4(xmm_aux2, zmm_sqr_sum, 1);
addps(xmm_aux1, xmm_aux2);
// vaddps(xmm_aux1, xmm_aux2);
uni_vaddps(xmm_aux1, xmm_aux1, xmm_aux2);
vextractf32x4(xmm_aux2, zmm_sqr_sum, 2);
vextractf32x4(xmm_aux3, zmm_sqr_sum, 3);
addps(xmm_aux2, xmm_aux3);
addps(xmm_aux1, xmm_aux2);
// vaddps(xmm_aux2, xmm_aux3);
// vaddps(xmm_aux1, xmm_aux2);
uni_vaddps(xmm_aux2, xmm_aux2, xmm_aux3);
uni_vaddps(xmm_aux1, xmm_aux1, xmm_aux2);
hsum_store(xmm_aux1);
}
}
@@ -124,11 +128,11 @@ private:
Xbyak::Xmm xmm_aux3 = Xbyak::Xmm(4);
inline void hsum_store(Xbyak::Xmm xmm_sqr_sum) {
movshdup(xmm_aux3, xmm_sqr_sum); // sqrt_sum:1,2,3,4; aux3:2,2,4,4
addps(xmm_sqr_sum, xmm_aux3); // sqrt_sum:1+2,2+2,3+4,4+4
movhlps(xmm_aux3, xmm_sqr_sum); // aux3:3+4,4+4,4,4
addps(xmm_sqr_sum, xmm_aux3); // sqrt_sum:1+2+3+4,...
movss(ptr[reg_modulo], xmm_sqr_sum);
uni_vmovshdup(xmm_aux3, xmm_sqr_sum); // sqrt_sum:1,2,3,4; aux3:2,2,4,4
uni_vaddps(xmm_sqr_sum, xmm_sqr_sum, xmm_aux3); // sqrt_sum:1+2,2+2,3+4,4+4
uni_vmovhlps(xmm_aux3, xmm_sqr_sum); // aux3:3+4,4+4,4,4
uni_vaddps(xmm_sqr_sum, xmm_sqr_sum, xmm_aux3); // sqrt_sum:1+2+3+4,...
uni_vmovss(ptr[reg_modulo], xmm_sqr_sum);
}
inline void load_vector(Vmm vmm_src, const Xbyak::Address &op, memory::data_type src_dt) {
@@ -359,6 +363,7 @@ private:
load_scalar(xmm_val, ptr[reg_src], jcp_.src_dt);
uni_vmulps(xmm_val, xmm_val, xmm_fused_factor);
if (attr_.post_ops_.len() != 0) {
apply_post_ops(jcp_.dst_dt, 0);
add(reg_oc_off, step * sizeof(float));
@@ -493,19 +498,19 @@ private:
switch (src_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
movss(xmm_src, op);
uni_vmovss(xmm_src, op);
break;
case memory::data_type::bf16:
pinsrw(xmm_src, op, 0x0);
uni_vpinsrw(xmm_src, xmm_src, op, 0x0);
uni_vpslld(xmm_src, xmm_src, 16);
break;
case memory::data_type::s8:
movsx(reg_tmp_32, op);
movq(xmm_src, reg_tmp_64);
uni_vmovq(xmm_src, reg_tmp_64);
break;
case memory::data_type::u8:
movzx(reg_tmp_32, op);
movq(xmm_src, reg_tmp_64);
uni_vmovq(xmm_src, reg_tmp_64);
break;
default:
assert(!"unknown dst_dt");
@@ -568,11 +573,11 @@ private:
switch (dst_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
movss(op, xmm_dst);
uni_vmovss(op, xmm_dst);
break;
case memory::data_type::bf16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
pextrw(op, xmm_dst, 0x0);
uni_vpextrw(op, xmm_dst, 0x0);
break;
case memory::data_type::s8:
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);

View File

@@ -607,19 +607,19 @@ private:
switch (src_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
movss(xmm_src, op);
uni_vmovss(xmm_src, op);
break;
case memory::data_type::bf16:
pinsrw(xmm_src, op, 0x0);
uni_vpinsrw(xmm_src, xmm_src, op, 0x0);
uni_vpslld(xmm_src, xmm_src, 16);
break;
case memory::data_type::s8:
movsx(reg_tmp_32, op);
movq(xmm_src, reg_tmp_64);
uni_vmovq(xmm_src, reg_tmp_64);
break;
case memory::data_type::u8:
movzx(reg_tmp_32, op);
movq(xmm_src, reg_tmp_64);
uni_vmovq(xmm_src, reg_tmp_64);
break;
default:
assert(!"unknown src_dt");
@@ -692,11 +692,11 @@ private:
switch (dst_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
movss(op, xmm_dst);
uni_vmovss(op, xmm_dst);
break;
case memory::data_type::bf16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
pextrw(op, xmm_dst, 0x0);
uni_vpextrw(op, xmm_dst, 0x0);
break;
case memory::data_type::s8:
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
@@ -707,7 +707,7 @@ private:
case memory::data_type::u8:
uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
movq(reg_tmp_64, xmm_dst);
vmovq(reg_tmp_64, xmm_dst);
mov(op, reg_tmp_8);
break;
default:
@@ -738,9 +738,9 @@ private:
}
inline void load_embedded_horiz_store(Xbyak::Xmm xmm_dst, memory::data_type dst_dt) {
movshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
uni_vmovshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
uni_vmovhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
@@ -753,21 +753,21 @@ private:
case memory::data_type::s32:
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
movss(ptr[reg_dst], xmm_dst);
uni_vmovss(ptr[reg_dst], xmm_dst);
break;
case memory::data_type::u8:
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
pextrb(ptr[reg_dst], xmm_dst, 0);
uni_vpextrb(ptr[reg_dst], xmm_dst, 0);
break;
case memory::data_type::s8:
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
pextrb(ptr[reg_dst], xmm_dst, 0);
uni_vpextrb(ptr[reg_dst], xmm_dst, 0);
break;
default:
assert(!"unknown dst_dt");
@@ -777,7 +777,7 @@ private:
inline void horiz_ps(const Xmm& xmm, const Operand& op) {
switch (jcp_.reduce_mode) {
case ReduceAnd:
andps(xmm, op);
uni_vandps(xmm, xmm, op);
break;
case ReduceL1:
case ReduceL2:
@@ -786,19 +786,19 @@ private:
case ReduceSum:
case ReduceSumSquare:
case ReduceLogSumExp:
addps(xmm, op);
uni_vaddps(xmm, xmm, op);
break;
case ReduceMax:
maxps(xmm, op);
uni_vmaxps(xmm, op);
break;
case ReduceMin:
minps(xmm, op);
uni_vminps(xmm, op);
break;
case ReduceOr:
orps(xmm, op);
uni_vorps(xmm, xmm, op);
break;
case ReduceProd:
mulps(xmm, op);
uni_vmulps(xmm, op);
break;
default:
assert(!"unsupported reduce mode");
@@ -1074,19 +1074,19 @@ private:
switch (src_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
movss(xmm_src, op);
uni_vmovss(xmm_src, op);
break;
case memory::data_type::bf16:
pinsrw(xmm_src, op, 0x0);
uni_vpinsrw(xmm_src, xmm_src, op, 0x0);
uni_vpslld(xmm_src, xmm_src, 16);
break;
case memory::data_type::s8:
movsx(reg_tmp_32, op);
movq(xmm_src, reg_tmp_64);
uni_vmovq(xmm_src, reg_tmp_64);
break;
case memory::data_type::u8:
movzx(reg_tmp_32, op);
movq(xmm_src, reg_tmp_64);
uni_vmovq(xmm_src, reg_tmp_64);
break;
default:
assert(!"unknown src_dt");
@@ -1159,11 +1159,11 @@ private:
switch (dst_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
movss(op, xmm_dst);
uni_vmovss(op, xmm_dst);
break;
case memory::data_type::bf16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
pextrw(op, xmm_dst, 0x0);
uni_vpextrw(op, xmm_dst, 0x0);
break;
case memory::data_type::s8:
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
@@ -1205,33 +1205,33 @@ private:
}
inline void horize_store(Xbyak::Xmm xmm_dst, memory::data_type dst_dt) {
movshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
uni_vmovshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
uni_vmovhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
switch (dst_dt) {
case memory::data_type::f32:
movss(ptr[reg_dst], xmm_dst);
uni_vmovss(ptr[reg_dst], xmm_dst);
break;
case memory::data_type::bf16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
pextrw(ptr[reg_dst], xmm_dst, 0x0);
uni_vpextrw(ptr[reg_dst], xmm_dst, 0x0);
break;
case memory::data_type::s32:
uni_vcvtps2dq(xmm_dst, xmm_dst);
movss(ptr[reg_dst], xmm_dst);
uni_vmovss(ptr[reg_dst], xmm_dst);
break;
case memory::data_type::u8:
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
pextrb(ptr[reg_dst], xmm_dst, 0);
uni_vpextrb(ptr[reg_dst], xmm_dst, 0);
break;
case memory::data_type::s8:
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
pextrb(ptr[reg_dst], xmm_dst, 0);
uni_vpextrb(ptr[reg_dst], xmm_dst, 0);
break;
default:
assert(!"unknown dst_dt");
@@ -1261,9 +1261,9 @@ private:
}
inline void load_embedded_horiz_store(Xbyak::Xmm xmm_dst, memory::data_type dst_dt) {
movshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
uni_vmovshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
uni_vmovhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
@@ -1276,21 +1276,21 @@ private:
case memory::data_type::s32:
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
movss(ptr[reg_dst], xmm_dst);
uni_vmovss(ptr[reg_dst], xmm_dst);
break;
case memory::data_type::u8:
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
pextrb(ptr[reg_dst], xmm_dst, 0);
uni_vpextrb(ptr[reg_dst], xmm_dst, 0);
break;
case memory::data_type::s8:
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
pextrb(ptr[reg_dst], xmm_dst, 0);
uni_vpextrb(ptr[reg_dst], xmm_dst, 0);
break;
default:
assert(!"unknown dst_dt");
@@ -1300,7 +1300,7 @@ private:
inline void horiz_ps(const Xmm& xmm, const Operand& op) {
switch (jcp_.reduce_mode) {
case ReduceAnd:
andps(xmm, op);
uni_vandps(xmm, xmm, op);
break;
case ReduceL1:
case ReduceL2:
@@ -1309,19 +1309,19 @@ private:
case ReduceSum:
case ReduceSumSquare:
case ReduceLogSumExp:
addps(xmm, op);
uni_vaddps(xmm, xmm, op);
break;
case ReduceMax:
maxps(xmm, op);
uni_vmaxps(xmm, op);
break;
case ReduceMin:
minps(xmm, op);
uni_vminps(xmm, op);
break;
case ReduceOr:
orps(xmm, op);
uni_vorps(xmm, xmm, op);
break;
case ReduceProd:
mulps(xmm, op);
uni_vmulps(xmm, op);
break;
default:
assert(!"unsupported reduce mode");

View File

@@ -202,10 +202,10 @@ private:
inline void load_scalar(Xbyak::Xmm xmm_src, const Xbyak::Address &op, InferenceEngine::Precision src_dt) {
switch (src_dt) {
case InferenceEngine::Precision::FP32:
movss(xmm_src, op);
uni_vmovss(xmm_src, op);
break;
case InferenceEngine::Precision::BF16:
pinsrw(xmm_src, op, 0x0);
uni_vpinsrw(xmm_src, xmm_src, op, 0x0);
uni_vpslld(xmm_src, xmm_src, 16);
break;
default:
@@ -215,11 +215,11 @@ private:
inline void store_scalar(const Xbyak::Address &op, Xbyak::Xmm xmm_dst, InferenceEngine::Precision dst_dt) {
switch (dst_dt) {
case InferenceEngine::Precision::FP32:
movss(op, xmm_dst);
uni_vmovss(op, xmm_dst);
break;
case InferenceEngine::Precision::BF16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
pextrw(op, xmm_dst, 0x0);
uni_vpextrw(op, xmm_dst, 0x0);
break;
default:
assert(!"unknown dst_dt");

View File

@@ -210,9 +210,9 @@ private:
}
void roi_pool_bilinear(int c_blocks) {
movq(xmm_yf, reg_yf);
uni_vmovq(xmm_yf, reg_yf);
uni_vbroadcastss(vmm_yf, xmm_yf);
movq(xmm_xf, reg_xf);
uni_vmovq(xmm_xf, reg_xf);
uni_vbroadcastss(vmm_xf, xmm_xf);
Vmm vmm_src00 = get_src_reg(0);