Revert "[CPU] Fix mixing VEX and non-VEX instructions (#7238)" (#8381)

Reverting for now because a few issues were detected on AVX platforms
(unknown instruction) and on AVX2 platforms (accuracy).
This commit is contained in:
Egor Duplensky
2021-11-03 13:51:26 +03:00
committed by GitHub
parent 6fdc6b4c16
commit 3cdebfcb7c
14 changed files with 146 additions and 148 deletions

View File

@@ -85,8 +85,8 @@ void jit_mul_add_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const
if (isa == cpu::x64::sse41) {
h->uni_vmovups(vmm_dst, vmm_src0);
h->uni_vmulps(vmm_dst, vmm_src1);
h->uni_vaddps(vmm_dst, vmm_dst, vmm_src2);
h->mulps(vmm_dst, vmm_src1);
h->addps(vmm_dst, vmm_src2);
} else {
Vmm vmm_mul0;
if (vmm_dst.getIdx() == vmm_src0.getIdx()) {
@@ -656,7 +656,7 @@ void jit_equal_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const s
} else if (isa == cpu::x64::avx2) {
h->vcmpeqps(vmm_aux0, vmm_src0, vmm_src1);
h->uni_vmovups(vmm_dst, table_val("zero"));
h->uni_vblendvps(vmm_dst, vmm_dst, table_val("one"), vmm_aux0);
h->vblendvps(vmm_dst, vmm_dst, table_val("one"), vmm_aux0);
} else {
h->vcmpps(k_mask, vmm_src0, vmm_src1, _cmp_eq_oq);
h->uni_vmovups(vmm_dst, table_val("zero"));

View File

@@ -190,7 +190,7 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int o
}
if (bytes_to_load >= 8 && bytes_to_load < 16)
h->uni_vpinsrq(xmm, xmm, addr(start_bytes), 0);
h->pinsrq(xmm, addr(start_bytes), 0);
else if (bytes_to_load == 16)
h->uni_vmovdqu(xmm, addr(start_bytes));
@@ -202,17 +202,17 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int o
h->uni_vpinsrw(xmm, xmm, addr(start_bytes), 0);
h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 2), 2);
break;
case 4: h->uni_vpinsrd(xmm, xmm, addr(start_bytes), 0); break;
case 4: h->pinsrd(xmm, addr(start_bytes), 0); break;
case 5:
h->uni_vpinsrd(xmm, xmm, addr(start_bytes), 0);
h->pinsrd(xmm, addr(start_bytes), 0);
h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 4), 4);
break;
case 6:
h->uni_vpinsrd(xmm, xmm, addr(start_bytes), 0);
h->pinsrd(xmm, addr(start_bytes), 0);
h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 4), 2);
break;
case 7:
h->uni_vpinsrd(xmm, xmm, addr(start_bytes), 0);
h->pinsrd(xmm, addr(start_bytes), 0);
h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 4), 2);
h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 6), 6);
break;
@@ -223,17 +223,17 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int o
h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 8), 4);
h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 10), 10);
break;
case 12: h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2); break;
case 12: h->pinsrd(xmm, addr(start_bytes + 8), 2); break;
case 13:
h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2);
h->pinsrd(xmm, addr(start_bytes + 8), 2);
h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 12), 12);
break;
case 14:
h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2);
h->pinsrd(xmm, addr(start_bytes + 8), 2);
h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 12), 6);
break;
case 15:
h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2);
h->pinsrd(xmm, addr(start_bytes + 8), 2);
h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 12), 6);
h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 14), 14);
break;
@@ -465,7 +465,10 @@ template <typename Vmm>
if (is_xmm || is_ymm) {
uint8 imm = 1;
imm = ~((imm << load_num) - imm); // shift load_num bit
h->uni_vblendps(vmm, vmm, table_val(fill_value), imm);
if (is_xmm)
h->blendps(vmm, table_val(fill_value), imm);
else
h->vblendps(vmm, vmm, table_val(fill_value), imm);
} else if (is_zmm) {
uint64_t tail_mask = 1;
tail_mask = ~((tail_mask << load_num) - tail_mask);
@@ -665,7 +668,7 @@ template <typename Vmm>
}
if (bytes_to_store >= 8 && bytes_to_store < 16)
h->uni_vpextrq(addr(start_bytes), xmm, 0);
h->pextrq(addr(start_bytes), xmm, 0);
else if (bytes_to_store == 16)
h->uni_vmovdqu(addr(start_bytes), xmm);
@@ -679,17 +682,17 @@ template <typename Vmm>
h->uni_vpextrw(addr(start_bytes), xmm, 0);
h->uni_vpextrb(addr(start_bytes + 2), xmm, 2);
break;
case 4: h->uni_vpextrd(addr(start_bytes), xmm, 0); break;
case 4: h->pextrd(addr(start_bytes), xmm, 0); break;
case 5:
h->uni_vpextrd(addr(start_bytes), xmm, 0);
h->pextrd(addr(start_bytes), xmm, 0);
h->uni_vpextrb(addr(start_bytes + 4), xmm, 4);
break;
case 6:
h->uni_vpextrd(addr(start_bytes), xmm, 0);
h->pextrd(addr(start_bytes), xmm, 0);
h->uni_vpextrw(addr(start_bytes + 4), xmm, 2);
break;
case 7:
h->uni_vpextrd(addr(start_bytes), xmm, 0);
h->pextrd(addr(start_bytes), xmm, 0);
h->uni_vpextrw(addr(start_bytes + 4), xmm, 2);
h->uni_vpextrb(addr(start_bytes + 6), xmm, 6);
break;
@@ -700,17 +703,17 @@ template <typename Vmm>
h->uni_vpextrw(addr(start_bytes + 8), xmm, 4);
h->uni_vpextrb(addr(start_bytes + 10), xmm, 10);
break;
case 12: h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); break;
case 12: h->pextrd(addr(start_bytes + 8), xmm, 2); break;
case 13:
h->uni_vpextrd(addr(start_bytes + 8), xmm, 2);
h->pextrd(addr(start_bytes + 8), xmm, 2);
h->uni_vpextrb(addr(start_bytes + 12), xmm, 12);
break;
case 14:
h->uni_vpextrd(addr(start_bytes + 8), xmm, 2);
h->pextrd(addr(start_bytes + 8), xmm, 2);
h->uni_vpextrw(addr(start_bytes + 12), xmm, 6);
break;
case 15:
h->uni_vpextrd(addr(start_bytes + 8), xmm, 2);
h->pextrd(addr(start_bytes + 8), xmm, 2);
h->uni_vpextrw(addr(start_bytes + 12), xmm, 6);
h->uni_vpextrb(addr(start_bytes + 14), xmm, 14);
break;

View File

@@ -47,21 +47,21 @@ struct jit_uni_permute_kernel_f32 : public jit_uni_permute_kernel, public jit_ge
void load(const Xbyak::Xmm &xmm, const Xbyak::Address &addr) {
switch (jcp.data_size) {
case 16: uni_vmovups(xmm, addr); break;
case 8: uni_vmovsd(xmm, addr); break;
case 4: uni_vmovss(xmm, addr); break;
case 2: uni_vpinsrw(xmm, xmm, addr, 0x0); break;
case 1: uni_vpinsrb(xmm, xmm, addr, 0x0); break;
case 16: movups(xmm, addr); break;
case 8: movsd(xmm, addr); break;
case 4: movss(xmm, addr); break;
case 2: pinsrw(xmm, addr, 0x0); break;
case 1: pinsrb(xmm, addr, 0x0); break;
}
}
void store(const Xbyak::Address &addr, const Xbyak::Xmm &xmm) {
switch (jcp.data_size) {
case 16: uni_vmovups(addr, xmm); break;
case 8: uni_vmovsd(addr, xmm); break;
case 4: uni_vmovss(addr, xmm); break;
case 2: uni_vpextrw(addr, xmm, 0x0); break;
case 1: uni_vpextrb(addr, xmm, 0x0); break;
case 16: movups(addr, xmm); break;
case 8: movsd(addr, xmm); break;
case 4: movss(addr, xmm); break;
case 2: pextrw(addr, xmm, 0x0); break;
case 1: pextrb(addr, xmm, 0x0); break;
}
}

View File

@@ -215,7 +215,7 @@ private:
case memory::data_type::s32:
if (scalar_load) {
mov(reg_tmp_32, op);
uni_vmovq(xmm_in, reg_tmp_64);
movq(xmm_in, reg_tmp_64);
} else {
uni_vmovups(vmm_in, op);
}
@@ -223,7 +223,7 @@ private:
case memory::data_type::s8:
if (scalar_load) {
movsx(reg_tmp_32, op);
uni_vmovq(xmm_in, reg_tmp_64);
movq(xmm_in, reg_tmp_64);
} else {
uni_vpmovsxbd(vmm_in, op);
}
@@ -231,7 +231,7 @@ private:
case memory::data_type::u8:
if (scalar_load) {
movzx(reg_tmp_32, op);
uni_vmovq(xmm_in, reg_tmp_64);
movq(xmm_in, reg_tmp_64);
} else {
uni_vpmovzxbd(vmm_in, op);
}
@@ -541,7 +541,7 @@ private:
if (jcp_.exclude_pad) {
mov(reg_shift, kw_padding[jj]);
imul(reg_shift, reg_tmp_32);
uni_vmovq(Xmm(vmm_shift.getIdx()), reg_shift);
movq(Xmm(vmm_shift.getIdx()), reg_shift);
uni_vbroadcastss(vmm_shift, Xmm(vmm_shift.getIdx()));
uni_vcvtdq2ps(vmm_shift, vmm_shift);
}
@@ -612,7 +612,7 @@ private:
} else {
Ymm ymm_prev_dst = Ymm(vmm_sum.getIdx());
vperm2i128(ymm_prev_dst, ymm_prev_dst, ymm_prev_dst, 0x01);
uni_vpslldq(vmm_sum, vmm_sum, (oc - jcp_.oc_block / 2) * sizeof(float));
vpslldq(vmm_sum, vmm_sum, (oc - jcp_.oc_block / 2) * sizeof(float));
}
uni_vaddps(vmm_dst, vmm_dst, vmm_sum);

View File

@@ -583,7 +583,7 @@ private:
switch (src_prc) {
case Precision::FP32:
case Precision::I32:
uni_vmovss(xmm_src, op);
movss(xmm_src, op);
break;
case Precision::BF16:
uni_vpinsrw(xmm_src, xmm_src, op, 0);
@@ -599,11 +599,11 @@ private:
break;
case Precision::I8:
movsx(reg_tmp_32, op);
uni_vmovq(xmm_src, reg_tmp_64);
movq(xmm_src, reg_tmp_64);
break;
case Precision::U8:
movzx(reg_tmp_32, op);
uni_vmovq(xmm_src, reg_tmp_64);
movq(xmm_src, reg_tmp_64);
break;
default:
assert(!"unknown src_prc");
@@ -730,7 +730,7 @@ private:
switch (dst_prc) {
case Precision::FP32:
case Precision::I32:
uni_vmovss(op, xmm_dst);
movss(op, xmm_dst);
break;
case Precision::BF16:
uni_vpsrld(xmm_dst, xmm_dst, 16);

View File

@@ -150,9 +150,9 @@ struct jit_uni_binarization_kernel : public jit_uni_quantize_kernel, public jit_
uni_vpxor(xmm_wei(0), xmm_wei(0), xmm_wei(0));
uni_vpxor(xmm_mask(0), xmm_mask(0), xmm_mask(0));
uni_vmovss(xmm_src(0), ptr[reg_from + c * sizeof(float)]);
uni_vmovss(xmm_wei(0), ptr[reg_thresholds + c * sizeof(float)]);
uni_vmovss(xmm_mask(0), ptr[reg_output_mask + c * sizeof(float)]);
movss(xmm_src(0), ptr[reg_from + c * sizeof(float)]);
movss(xmm_wei(0), ptr[reg_thresholds + c * sizeof(float)]);
movss(xmm_mask(0), ptr[reg_output_mask + c * sizeof(float)]);
uni_vcmpgtps(xmm_src(0), xmm_src(0), xmm_wei(0));
uni_vpcmpeqd(xmm_src(0), xmm_src(0), xmm_mask(0));
uni_vmovmskps(reg_src_32, xmm_src(0));
@@ -591,13 +591,13 @@ private:
jle(exit_label, T_NEAR);
for (int i = 0; i < jqp_.c % tail4_simd_w; i++) {
uni_vmovss(xmm_crop_low(0), ptr[reg_crop_low + i * wei_type_size]);
uni_vmovss(xmm_crop_high(0), ptr[reg_crop_high + i * wei_type_size]);
uni_vmovss(xmm_input_scale(0), ptr[reg_input_scale + i * wei_type_size]);
uni_vmovss(xmm_input_shift(0), ptr[reg_input_shift + i * wei_type_size]);
movss(xmm_crop_low(0), ptr[reg_crop_low + i * wei_type_size]);
movss(xmm_crop_high(0), ptr[reg_crop_high + i * wei_type_size]);
movss(xmm_input_scale(0), ptr[reg_input_scale + i * wei_type_size]);
movss(xmm_input_shift(0), ptr[reg_input_shift + i * wei_type_size]);
if (do_dequantization) {
uni_vmovss(xmm_output_scale(0), ptr[reg_output_scale + i * wei_type_size]);
uni_vmovss(xmm_output_shift(0), ptr[reg_output_shift + i * wei_type_size]);
movss(xmm_output_scale(0), ptr[reg_output_scale + i * wei_type_size]);
movss(xmm_output_shift(0), ptr[reg_output_shift + i * wei_type_size]);
}
load_scalar(xmm_val(0), ptr[aux_reg_from + i * src_type_size], jqp_.src_prc);
@@ -688,15 +688,15 @@ private:
switch (src_prc) {
case Precision::FP32:
case Precision::I32:
uni_vmovss(xmm_src, op);
movss(xmm_src, op);
break;
case Precision::I8:
movsx(reg_tmp_32, op);
uni_vmovq(xmm_src, reg_tmp_64);
movq(xmm_src, reg_tmp_64);
break;
case Precision::U8:
movzx(reg_tmp_32, op);
uni_vmovq(xmm_src, reg_tmp_64);
movq(xmm_src, reg_tmp_64);
break;
default:
assert(!"unknown src_prc");
@@ -797,7 +797,7 @@ private:
switch (dst_prc) {
case Precision::FP32:
case Precision::I32:
uni_vmovss(op, xmm_dst);
movss(op, xmm_dst);
break;
case Precision::I8:
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);

View File

@@ -104,13 +104,13 @@ protected:
auto b = xmm2;
auto c = xmm3;
uni_vmovdqu(a, xword[src]); // load 4 floats
uni_vmovdqu(b, a); // b = a
uni_vmovdqu(c, a); // c = a
uni_vpcmpeqd(b, b, zero); // if (a == 0) b = 1 else b = 0
uni_vpand(c, c, mask); // c = a & 01111111100000000000000000000000
uni_vpcmpeqd(c, c, zero); // if (c == 0) c = 1 else c = 0
uni_vtestps(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0
movdqu(a, xword[src]); // load 4 floats
movdqu(b, a); // b = a
movdqu(c, a); // c = a
pcmpeqd(b, zero); // if (a == 0) b = 1 else b = 0
pand(c, mask); // c = a & 01111111100000000000000000000000
pcmpeqd(c, zero); // if (c == 0) c = 1 else c = 0
ptest(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0
}
template<cpu_isa_t isa>

View File

@@ -1199,10 +1199,10 @@ private:
jl(tail_loop_end_label, T_NEAR);
// get idx for input
uni_vmovss(Xmm(vmm_tbl_y.getIdx()), ptr[reg_tbl_y]);
movss(Xmm(vmm_tbl_y.getIdx()), ptr[reg_tbl_y]);
gather_i32_indices(vmm_index_in_y, reg_index_y, 0, vmm_tbl_y, 1, memory::data_type::s32, true);
uni_vmovss(Xmm(vmm_val.getIdx()), ptr[reg_tbl_x]);
movss(Xmm(vmm_val.getIdx()), ptr[reg_tbl_x]);
gather_i32_indices(vmm_index_in_x, reg_index, 0, vmm_val, 1, memory::data_type::s32, true);
// gather weightX by input idx, used in y0-y3
gather_i32_indices(vmm_weightX0, reg_weight_x, 0, vmm_val, grid_len, memory::data_type::f32, true);
@@ -1430,18 +1430,18 @@ private:
switch (src_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
uni_vmovss(xmm_src, op);
movss(xmm_src, op);
break;
case memory::data_type::s8:
movsx(reg_tmp_32, op);
uni_vmovq(xmm_src, reg_tmp_64);
movq(xmm_src, reg_tmp_64);
break;
case memory::data_type::u8:
movzx(reg_tmp_32, op);
uni_vmovq(xmm_src, reg_tmp_64);
movq(xmm_src, reg_tmp_64);
break;
case memory::data_type::bf16:
uni_vpinsrw(xmm_src, xmm_src, op, 0x0);
pinsrw(xmm_src, op, 0x0);
uni_vpslld(xmm_src, xmm_src, 16);
break;
default:
@@ -1536,7 +1536,7 @@ private:
switch (dst_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
uni_vmovss(op, xmm_dst);
movss(op, xmm_dst);
break;
case memory::data_type::s8:
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
@@ -1552,7 +1552,7 @@ private:
break;
case memory::data_type::bf16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
uni_vpextrw(op, xmm_dst, 0x0);
pextrw(op, xmm_dst, 0x0);
break;
default:
assert(!"unknown dst_dt");

View File

@@ -102,17 +102,17 @@ struct jit_uni_mvn_mean_variance_kernel_f32 : public jit_uni_mvn_mean_variance_k
Xbyak::Ymm ymm_sum = Xbyak::Ymm(vmm_dst.getIdx());
vextractf128(xmm_aux1, ymm_sum, 0);
vextractf128(xmm_aux2, ymm_sum, 1);
uni_vaddps(xmm_aux1, xmm_aux1, xmm_aux2);
addps(xmm_aux1, xmm_aux2);
hsum_store(xmm_aux1);
} else {
Xbyak::Zmm zmm_sum = Xbyak::Zmm(vmm_dst.getIdx());
vextractf32x4(xmm_aux1, zmm_sum, 0);
vextractf32x4(xmm_aux2, zmm_sum, 1);
uni_vaddps(xmm_aux1, xmm_aux1, xmm_aux2);
addps(xmm_aux1, xmm_aux2);
vextractf32x4(xmm_aux2, zmm_sum, 2);
vextractf32x4(xmm_aux3, zmm_sum, 3);
uni_vaddps(xmm_aux2, xmm_aux2, xmm_aux3);
uni_vaddps(xmm_aux1, xmm_aux1, xmm_aux2);
addps(xmm_aux2, xmm_aux3);
addps(xmm_aux1, xmm_aux2);
hsum_store(xmm_aux1);
}
} else {
@@ -342,14 +342,14 @@ private:
}
inline void hsum_store(Xbyak::Xmm xmm_sum) {
uni_vmovshdup(xmm_aux3, xmm_sum); // sum:1,2,3,4; aux3:2,2,4,4
uni_vaddps(xmm_sum, xmm_sum, xmm_aux3); // sum:1+2,2+2,3+4,4+4
uni_vmovhlps(xmm_aux3, xmm_sum); // aux3:3+4,4+4,4,4
uni_vaddps(xmm_sum, xmm_sum, xmm_aux3); // sum:1+2+3+4,...
movshdup(xmm_aux3, xmm_sum); // sum:1,2,3,4; aux3:2,2,4,4
addps(xmm_sum, xmm_aux3); // sum:1+2,2+2,3+4,4+4
movhlps(xmm_aux3, xmm_sum); // aux3:3+4,4+4,4,4
addps(xmm_sum, xmm_aux3); // sum:1+2+3+4,...
if (jcp_.normalize_variance) {
uni_vmovss(ptr[reg_variance], xmm_sum);
movss(ptr[reg_variance], xmm_sum);
} else {
uni_vmovss(ptr[reg_sum], xmm_sum);
movss(ptr[reg_sum], xmm_sum);
}
}
};

View File

@@ -88,21 +88,17 @@ struct jit_uni_normalize_modulo_kernel_f32 : public jit_uni_normalize_modulo_ker
Xbyak::Ymm ymm_sqr_sum = Xbyak::Ymm(vmm_sqr_sum.getIdx());
vextractf128(xmm_aux1, ymm_sqr_sum, 0);
vextractf128(xmm_aux2, ymm_sqr_sum, 1);
// vaddps(xmm_aux1, xmm_aux2);
uni_vaddps(xmm_aux1, xmm_aux1, xmm_aux2);
addps(xmm_aux1, xmm_aux2);
hsum_store(xmm_aux1);
} else {
Xbyak::Zmm zmm_sqr_sum = Xbyak::Zmm(vmm_sqr_sum.getIdx());
vextractf32x4(xmm_aux1, zmm_sqr_sum, 0);
vextractf32x4(xmm_aux2, zmm_sqr_sum, 1);
// vaddps(xmm_aux1, xmm_aux2);
uni_vaddps(xmm_aux1, xmm_aux1, xmm_aux2);
addps(xmm_aux1, xmm_aux2);
vextractf32x4(xmm_aux2, zmm_sqr_sum, 2);
vextractf32x4(xmm_aux3, zmm_sqr_sum, 3);
// vaddps(xmm_aux2, xmm_aux3);
// vaddps(xmm_aux1, xmm_aux2);
uni_vaddps(xmm_aux2, xmm_aux2, xmm_aux3);
uni_vaddps(xmm_aux1, xmm_aux1, xmm_aux2);
addps(xmm_aux2, xmm_aux3);
addps(xmm_aux1, xmm_aux2);
hsum_store(xmm_aux1);
}
}
@@ -128,11 +124,11 @@ private:
Xbyak::Xmm xmm_aux3 = Xbyak::Xmm(4);
inline void hsum_store(Xbyak::Xmm xmm_sqr_sum) {
uni_vmovshdup(xmm_aux3, xmm_sqr_sum); // sqrt_sum:1,2,3,4; aux3:2,2,4,4
uni_vaddps(xmm_sqr_sum, xmm_sqr_sum, xmm_aux3); // sqrt_sum:1+2,2+2,3+4,4+4
uni_vmovhlps(xmm_aux3, xmm_sqr_sum); // aux3:3+4,4+4,4,4
uni_vaddps(xmm_sqr_sum, xmm_sqr_sum, xmm_aux3); // sqrt_sum:1+2+3+4,...
uni_vmovss(ptr[reg_modulo], xmm_sqr_sum);
movshdup(xmm_aux3, xmm_sqr_sum); // sqrt_sum:1,2,3,4; aux3:2,2,4,4
addps(xmm_sqr_sum, xmm_aux3); // sqrt_sum:1+2,2+2,3+4,4+4
movhlps(xmm_aux3, xmm_sqr_sum); // aux3:3+4,4+4,4,4
addps(xmm_sqr_sum, xmm_aux3); // sqrt_sum:1+2+3+4,...
movss(ptr[reg_modulo], xmm_sqr_sum);
}
inline void load_vector(Vmm vmm_src, const Xbyak::Address &op, memory::data_type src_dt) {
@@ -363,7 +359,6 @@ private:
load_scalar(xmm_val, ptr[reg_src], jcp_.src_dt);
uni_vmulps(xmm_val, xmm_val, xmm_fused_factor);
if (attr_.post_ops_.len() != 0) {
apply_post_ops(jcp_.dst_dt, 0);
add(reg_oc_off, step * sizeof(float));
@@ -498,19 +493,19 @@ private:
switch (src_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
uni_vmovss(xmm_src, op);
movss(xmm_src, op);
break;
case memory::data_type::bf16:
uni_vpinsrw(xmm_src, xmm_src, op, 0x0);
pinsrw(xmm_src, op, 0x0);
uni_vpslld(xmm_src, xmm_src, 16);
break;
case memory::data_type::s8:
movsx(reg_tmp_32, op);
uni_vmovq(xmm_src, reg_tmp_64);
movq(xmm_src, reg_tmp_64);
break;
case memory::data_type::u8:
movzx(reg_tmp_32, op);
uni_vmovq(xmm_src, reg_tmp_64);
movq(xmm_src, reg_tmp_64);
break;
default:
assert(!"unknown dst_dt");
@@ -573,11 +568,11 @@ private:
switch (dst_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
uni_vmovss(op, xmm_dst);
movss(op, xmm_dst);
break;
case memory::data_type::bf16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
uni_vpextrw(op, xmm_dst, 0x0);
pextrw(op, xmm_dst, 0x0);
break;
case memory::data_type::s8:
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);

View File

@@ -607,19 +607,19 @@ private:
switch (src_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
uni_vmovss(xmm_src, op);
movss(xmm_src, op);
break;
case memory::data_type::bf16:
uni_vpinsrw(xmm_src, xmm_src, op, 0x0);
pinsrw(xmm_src, op, 0x0);
uni_vpslld(xmm_src, xmm_src, 16);
break;
case memory::data_type::s8:
movsx(reg_tmp_32, op);
uni_vmovq(xmm_src, reg_tmp_64);
movq(xmm_src, reg_tmp_64);
break;
case memory::data_type::u8:
movzx(reg_tmp_32, op);
uni_vmovq(xmm_src, reg_tmp_64);
movq(xmm_src, reg_tmp_64);
break;
default:
assert(!"unknown src_dt");
@@ -692,11 +692,11 @@ private:
switch (dst_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
uni_vmovss(op, xmm_dst);
movss(op, xmm_dst);
break;
case memory::data_type::bf16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
uni_vpextrw(op, xmm_dst, 0x0);
pextrw(op, xmm_dst, 0x0);
break;
case memory::data_type::s8:
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
@@ -707,7 +707,7 @@ private:
case memory::data_type::u8:
uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
vmovq(reg_tmp_64, xmm_dst);
movq(reg_tmp_64, xmm_dst);
mov(op, reg_tmp_8);
break;
default:
@@ -738,9 +738,9 @@ private:
}
inline void load_embedded_horiz_store(Xbyak::Xmm xmm_dst, memory::data_type dst_dt) {
uni_vmovshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
movshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
uni_vmovhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
@@ -753,21 +753,21 @@ private:
case memory::data_type::s32:
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vmovss(ptr[reg_dst], xmm_dst);
movss(ptr[reg_dst], xmm_dst);
break;
case memory::data_type::u8:
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
uni_vpextrb(ptr[reg_dst], xmm_dst, 0);
pextrb(ptr[reg_dst], xmm_dst, 0);
break;
case memory::data_type::s8:
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
uni_vpextrb(ptr[reg_dst], xmm_dst, 0);
pextrb(ptr[reg_dst], xmm_dst, 0);
break;
default:
assert(!"unknown dst_dt");
@@ -777,7 +777,7 @@ private:
inline void horiz_ps(const Xmm& xmm, const Operand& op) {
switch (jcp_.reduce_mode) {
case ReduceAnd:
uni_vandps(xmm, xmm, op);
andps(xmm, op);
break;
case ReduceL1:
case ReduceL2:
@@ -786,19 +786,19 @@ private:
case ReduceSum:
case ReduceSumSquare:
case ReduceLogSumExp:
uni_vaddps(xmm, xmm, op);
addps(xmm, op);
break;
case ReduceMax:
uni_vmaxps(xmm, op);
maxps(xmm, op);
break;
case ReduceMin:
uni_vminps(xmm, op);
minps(xmm, op);
break;
case ReduceOr:
uni_vorps(xmm, xmm, op);
orps(xmm, op);
break;
case ReduceProd:
uni_vmulps(xmm, op);
mulps(xmm, op);
break;
default:
assert(!"unsupported reduce mode");
@@ -1074,19 +1074,19 @@ private:
switch (src_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
uni_vmovss(xmm_src, op);
movss(xmm_src, op);
break;
case memory::data_type::bf16:
uni_vpinsrw(xmm_src, xmm_src, op, 0x0);
pinsrw(xmm_src, op, 0x0);
uni_vpslld(xmm_src, xmm_src, 16);
break;
case memory::data_type::s8:
movsx(reg_tmp_32, op);
uni_vmovq(xmm_src, reg_tmp_64);
movq(xmm_src, reg_tmp_64);
break;
case memory::data_type::u8:
movzx(reg_tmp_32, op);
uni_vmovq(xmm_src, reg_tmp_64);
movq(xmm_src, reg_tmp_64);
break;
default:
assert(!"unknown src_dt");
@@ -1159,11 +1159,11 @@ private:
switch (dst_dt) {
case memory::data_type::f32:
case memory::data_type::s32:
uni_vmovss(op, xmm_dst);
movss(op, xmm_dst);
break;
case memory::data_type::bf16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
uni_vpextrw(op, xmm_dst, 0x0);
pextrw(op, xmm_dst, 0x0);
break;
case memory::data_type::s8:
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
@@ -1205,33 +1205,33 @@ private:
}
inline void horize_store(Xbyak::Xmm xmm_dst, memory::data_type dst_dt) {
uni_vmovshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
movshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
uni_vmovhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
switch (dst_dt) {
case memory::data_type::f32:
uni_vmovss(ptr[reg_dst], xmm_dst);
movss(ptr[reg_dst], xmm_dst);
break;
case memory::data_type::bf16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
uni_vpextrw(ptr[reg_dst], xmm_dst, 0x0);
pextrw(ptr[reg_dst], xmm_dst, 0x0);
break;
case memory::data_type::s32:
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vmovss(ptr[reg_dst], xmm_dst);
movss(ptr[reg_dst], xmm_dst);
break;
case memory::data_type::u8:
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
uni_vpextrb(ptr[reg_dst], xmm_dst, 0);
pextrb(ptr[reg_dst], xmm_dst, 0);
break;
case memory::data_type::s8:
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
uni_vpextrb(ptr[reg_dst], xmm_dst, 0);
pextrb(ptr[reg_dst], xmm_dst, 0);
break;
default:
assert(!"unknown dst_dt");
@@ -1261,9 +1261,9 @@ private:
}
inline void load_embedded_horiz_store(Xbyak::Xmm xmm_dst, memory::data_type dst_dt) {
uni_vmovshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
movshdup(xmm_aux3, xmm_dst); // dst:1,2,3,4; aux3:2,2,4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
uni_vmovhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
@@ -1276,21 +1276,21 @@ private:
case memory::data_type::s32:
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vmovss(ptr[reg_dst], xmm_dst);
movss(ptr[reg_dst], xmm_dst);
break;
case memory::data_type::u8:
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
uni_vpextrb(ptr[reg_dst], xmm_dst, 0);
pextrb(ptr[reg_dst], xmm_dst, 0);
break;
case memory::data_type::s8:
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
uni_vpextrb(ptr[reg_dst], xmm_dst, 0);
pextrb(ptr[reg_dst], xmm_dst, 0);
break;
default:
assert(!"unknown dst_dt");
@@ -1300,7 +1300,7 @@ private:
inline void horiz_ps(const Xmm& xmm, const Operand& op) {
switch (jcp_.reduce_mode) {
case ReduceAnd:
uni_vandps(xmm, xmm, op);
andps(xmm, op);
break;
case ReduceL1:
case ReduceL2:
@@ -1309,19 +1309,19 @@ private:
case ReduceSum:
case ReduceSumSquare:
case ReduceLogSumExp:
uni_vaddps(xmm, xmm, op);
addps(xmm, op);
break;
case ReduceMax:
uni_vmaxps(xmm, op);
maxps(xmm, op);
break;
case ReduceMin:
uni_vminps(xmm, op);
minps(xmm, op);
break;
case ReduceOr:
uni_vorps(xmm, xmm, op);
orps(xmm, op);
break;
case ReduceProd:
uni_vmulps(xmm, op);
mulps(xmm, op);
break;
default:
assert(!"unsupported reduce mode");

View File

@@ -202,10 +202,10 @@ private:
inline void load_scalar(Xbyak::Xmm xmm_src, const Xbyak::Address &op, InferenceEngine::Precision src_dt) {
switch (src_dt) {
case InferenceEngine::Precision::FP32:
uni_vmovss(xmm_src, op);
movss(xmm_src, op);
break;
case InferenceEngine::Precision::BF16:
uni_vpinsrw(xmm_src, xmm_src, op, 0x0);
pinsrw(xmm_src, op, 0x0);
uni_vpslld(xmm_src, xmm_src, 16);
break;
default:
@@ -215,11 +215,11 @@ private:
inline void store_scalar(const Xbyak::Address &op, Xbyak::Xmm xmm_dst, InferenceEngine::Precision dst_dt) {
switch (dst_dt) {
case InferenceEngine::Precision::FP32:
uni_vmovss(op, xmm_dst);
movss(op, xmm_dst);
break;
case InferenceEngine::Precision::BF16:
uni_vpsrld(xmm_dst, xmm_dst, 16);
uni_vpextrw(op, xmm_dst, 0x0);
pextrw(op, xmm_dst, 0x0);
break;
default:
assert(!"unknown dst_dt");

View File

@@ -210,9 +210,9 @@ private:
}
void roi_pool_bilinear(int c_blocks) {
uni_vmovq(xmm_yf, reg_yf);
movq(xmm_yf, reg_yf);
uni_vbroadcastss(vmm_yf, xmm_yf);
uni_vmovq(xmm_xf, reg_xf);
movq(xmm_xf, reg_xf);
uni_vbroadcastss(vmm_xf, xmm_xf);
Vmm vmm_src00 = get_src_reg(0);