[CPU] Reduce node: fix illegal instruction on SSE4.1 (#5782)

This commit is contained in:
Chen Xu
2021-06-01 15:01:20 +08:00
committed by GitHub
parent 42afbfb4a3
commit de8742ad8e

View File

@@ -316,8 +316,12 @@ private:
}
// reduce
reduce_main_loop();
if (jcp_.reduce_mode == ReduceOr && isa != avx512_common) {
vcmpneqps(vmm_dst, vmm_dst, vmm_zero);
if (jcp_.reduce_mode == ReduceOr && isa != cpu::x64::avx512_common) {
if (isa == cpu::x64::avx2) {
vcmpneqps(vmm_dst, vmm_dst, vmm_zero);
} else if (isa == cpu::x64::sse41) {
cmpneqps(vmm_dst, vmm_zero);
}
uni_vandps(vmm_dst, vmm_dst, vmm_aux);
}
// store
@@ -361,7 +365,11 @@ private:
// reduce
reduce_kernel_scalar(xmm_src, xmm_dst);
if (jcp_.reduce_mode == ReduceOr) {
vcmpneqps(xmm_dst, xmm_dst, xmm_zero);
if (isa == cpu::x64::sse41) {
cmpneqps(xmm_dst, xmm_zero);
} else {
vcmpneqps(xmm_dst, xmm_dst, xmm_zero);
}
uni_vandps(xmm_dst, xmm_dst, xmm_aux);
}
@@ -400,7 +408,11 @@ private:
reduce_kernel_scalar(xmm_src, xmm_dst);
if (jcp_.reduce_mode == ReduceOr) {
vcmpneqps(xmm_dst, xmm_dst, xmm_zero);
if (isa == cpu::x64::sse41) {
cmpneqps(xmm_dst, xmm_zero);
} else {
vcmpneqps(xmm_dst, xmm_dst, xmm_zero);
}
uni_vandps(xmm_dst, xmm_dst, xmm_aux);
}
@@ -448,11 +460,13 @@ private:
inline void reduce_kernel(Vmm vmm_src, Vmm vmm_dst) {
switch (jcp_.reduce_mode) {
case ReduceAnd:
if (isa == avx512_common) {
if (isa == cpu::x64::avx512_common) {
vcmpps(k_mask, vmm_src, vmm_zero, _cmp_neq_uq);
vblendmps(vmm_src | k_mask, vmm_zero, vmm_aux);
} else {
} else if (isa == cpu::x64::avx2) {
vcmpneqps(vmm_src, vmm_src, vmm_zero);
} else {
cmpneqps(vmm_src, vmm_zero);
}
uni_vandps(vmm_dst, vmm_dst, vmm_src);
break;
@@ -481,7 +495,7 @@ private:
uni_vaddps(vmm_dst, vmm_dst, vmm_src);
break;
case ReduceOr:
if (isa == avx512_common) {
if (isa == cpu::x64::avx512_common) {
vcmpps(k_mask, vmm_src, vmm_zero, _cmp_neq_uq);
vblendmps(vmm_src | k_mask, vmm_zero, vmm_aux);
}
@@ -498,7 +512,11 @@ private:
inline void reduce_kernel_scalar(Xmm xmm_src, Xmm xmm_dst) {
switch (jcp_.reduce_mode) {
case ReduceAnd:
vcmpneqps(xmm_src, xmm_src, xmm_zero);
if (isa == cpu::x64::sse41) {
cmpneqps(xmm_src, xmm_zero);
} else {
vcmpneqps(xmm_src, xmm_src, xmm_zero);
}
uni_vandps(xmm_dst, xmm_dst, xmm_src);
break;
case ReduceL1:
@@ -543,11 +561,16 @@ private:
}
inline void store_dst_vector() {
if (jcp_.reduce_mode == ReduceOr && isa != avx512_common) {
vcmpneqps(vmm_dst, vmm_dst, vmm_zero);
if (jcp_.reduce_mode == ReduceOr && isa != cpu::x64::avx512_common) {
if (isa == cpu::x64::avx2) {
vcmpneqps(vmm_dst, vmm_dst, vmm_zero);
} else if (isa == cpu::x64::sse41) {
cmpneqps(vmm_dst, vmm_zero);
}
uni_vandps(vmm_dst, vmm_dst, vmm_aux);
if (isa == cpu::x64::sse41) {
vcmpneqps(vmm_dst_aux, vmm_dst_aux, vmm_zero);
cmpneqps(vmm_dst_aux, vmm_zero);
uni_vandps(vmm_dst_aux, vmm_dst_aux, vmm_aux);
}
}
@@ -628,7 +651,7 @@ private:
vmovdqu16(op, ymm_dst);
break;
case memory::data_type::s8:
if (isa == avx512_common) {
if (isa == cpu::x64::avx512_common) {
vmaxps(vmm_dst, vmm_zero, vmm_dst);
vpmovsdb(op, vmm_dst);
} else {
@@ -643,7 +666,7 @@ private:
}
break;
case memory::data_type::u8:
if (isa == avx512_common) {
if (isa == cpu::x64::avx512_common) {
vpmovusdb(op, vmm_dst);
} else {
uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
@@ -719,24 +742,20 @@ private:
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
switch (dst_dt) {
case memory::data_type::f32:
case memory::data_type::bf16:
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
horiz_ps(xmm_dst, xmm_aux3);
store_scalar(ptr[reg_dst], xmm_dst, dst_dt);
break;
case memory::data_type::s32:
movss(xmm_aux3, ptr[reg_dst]);
uni_vcvtdq2ps(xmm_aux3, xmm_aux3);
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
movss(ptr[reg_dst], xmm_dst);
break;
case memory::data_type::u8:
vpbroadcastb(xmm_aux3, ptr[reg_dst]);
uni_vpmovzxbd(xmm_aux3, xmm_aux3);
uni_vcvtdq2ps(xmm_aux3, xmm_aux3);
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
@@ -744,9 +763,6 @@ private:
pextrb(ptr[reg_dst], xmm_dst, 0);
break;
case memory::data_type::s8:
vpbroadcastb(xmm_aux3, ptr[reg_dst]);
uni_vpmovsxbd(xmm_aux3, xmm_aux3);
uni_vcvtdq2ps(xmm_aux3, xmm_aux3);
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);
@@ -1102,7 +1118,7 @@ private:
vmovdqu16(op, ymm_dst);
break;
case memory::data_type::s8:
if (isa == avx512_common) {
if (isa == cpu::x64::avx512_common) {
vmaxps(vmm_dst, vmm_zero, vmm_dst);
vpmovsdb(op, vmm_dst);
} else {
@@ -1117,7 +1133,7 @@ private:
}
break;
case memory::data_type::u8:
if (isa == avx512_common) {
if (isa == cpu::x64::avx512_common) {
vpmovusdb(op, vmm_dst);
} else {
uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
@@ -1249,24 +1265,20 @@ private:
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2),f(2,2),f(3,4),f(4,4)
movhlps(xmm_aux3, xmm_dst); // aux3:f(3,4),f(4,4),4,4
horiz_ps(xmm_dst, xmm_aux3); // dst:f(1,2,3,4),...
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
switch (dst_dt) {
case memory::data_type::f32:
case memory::data_type::bf16:
load_scalar(xmm_aux3, ptr[reg_dst], dst_dt);
horiz_ps(xmm_dst, xmm_aux3);
store_scalar(ptr[reg_dst], xmm_dst, dst_dt);
break;
case memory::data_type::s32:
movss(xmm_aux3, ptr[reg_dst]);
uni_vcvtdq2ps(xmm_aux3, xmm_aux3);
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
movss(ptr[reg_dst], xmm_dst);
break;
case memory::data_type::u8:
vpbroadcastb(xmm_aux3, ptr[reg_dst]);
uni_vpmovzxbd(xmm_aux3, xmm_aux3);
uni_vcvtdq2ps(xmm_aux3, xmm_aux3);
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst);
@@ -1274,9 +1286,6 @@ private:
pextrb(ptr[reg_dst], xmm_dst, 0);
break;
case memory::data_type::s8:
vpbroadcastb(xmm_aux3, ptr[reg_dst]);
uni_vpmovsxbd(xmm_aux3, xmm_aux3);
uni_vcvtdq2ps(xmm_aux3, xmm_aux3);
horiz_ps(xmm_dst, xmm_aux3);
uni_vcvtps2dq(xmm_dst, xmm_dst);
uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst);