[CPU] Fix incorrect output for float to bf16 in avx2 isa (#19358)

This commit is contained in:
River Li 2023-09-07 13:09:16 +08:00 committed by GitHub
parent 14e0b1fd2c
commit 252afa3b6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 18 deletions

View File

@@ -75,6 +75,7 @@ private:
h->uni_vpackusdw(aux, aux, aux); h->uni_vpackusdw(aux, aux, aux);
if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx2) { if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx2) {
h->vpermq(Ymm(aux.getIdx()), Ymm(aux.getIdx()), 0xD8); //11 01 10 00
h->vextracti128(out, Ymm(aux.getIdx()), 0); h->vextracti128(out, Ymm(aux.getIdx()), 0);
} else { } else {
h->uni_vmovups(out, aux); h->uni_vmovups(out, aux);

View File

@@ -272,20 +272,19 @@ struct jit_variable_load_store_test_kernel {
size_t size; size_t size;
}; };
template<size_t N, bool is_src> template<size_t N, size_t M, bool is_src>
void test() { void test() {
kernel_impl<N, is_src> kernel; kernel_impl<N, is_src> kernel;
kernel.init(); kernel.init();
ASSERT_GE(N, M);
const size_t size = 3;
std::array<SrcT, N> src {}; std::array<SrcT, N> src {};
std::array<DstT, N> result {}; std::array<DstT, N> result {};
Params args = { src.data(), result.data(), size }; Params args = { src.data(), result.data(), M };
src.fill(static_cast<SrcT>(42)); src.fill(static_cast<SrcT>(42));
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < M; ++i) {
src[i] = static_cast<SrcT>(i); src[i] = static_cast<SrcT>(i);
} }
@@ -293,7 +292,7 @@ struct jit_variable_load_store_test_kernel {
std::array<DstT, N> expected_result {}; std::array<DstT, N> expected_result {};
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < M; ++i) {
expected_result[i] = static_cast<DstT>(i); expected_result[i] = static_cast<DstT>(i);
} }
@@ -325,52 +324,52 @@ TEST(JitKernel, variable_load_and_store) {
{ {
jit_variable_load_store_test_kernel<uint8_t, float> kernel; jit_variable_load_store_test_kernel<uint8_t, float> kernel;
if (mayiuse(cpu_isa_t::avx512_core)) { if (mayiuse(cpu_isa_t::avx512_core)) {
kernel.test<16, false>(); kernel.test<16, 11, false>();
} }
if (mayiuse(cpu_isa_t::avx2)) { if (mayiuse(cpu_isa_t::avx2)) {
kernel.test<8, false>(); kernel.test<8, 5, false>();
} }
if (mayiuse(cpu_isa_t::sse41)) { if (mayiuse(cpu_isa_t::sse41)) {
kernel.test<4, false>(); kernel.test<4, 3, false>();
} }
} }
{ {
jit_variable_load_store_test_kernel<int8_t, int8_t> kernel; jit_variable_load_store_test_kernel<int8_t, int8_t> kernel;
if (mayiuse(cpu_isa_t::avx512_core)) { if (mayiuse(cpu_isa_t::avx512_core)) {
kernel.test<16, false>(); kernel.test<16, 11, false>();
} }
if (mayiuse(cpu_isa_t::avx2)) { if (mayiuse(cpu_isa_t::avx2)) {
kernel.test<8, false>(); kernel.test<8, 5, false>();
} }
if (mayiuse(cpu_isa_t::sse41)) { if (mayiuse(cpu_isa_t::sse41)) {
kernel.test<4, false>(); kernel.test<4, 3, false>();
} }
} }
{ {
jit_variable_load_store_test_kernel<float, bfloat16_t> kernel; jit_variable_load_store_test_kernel<float, bfloat16_t> kernel;
if (mayiuse(cpu_isa_t::avx512_core)) { if (mayiuse(cpu_isa_t::avx512_core)) {
kernel.test<16, true>(); kernel.test<16, 11, true>();
} }
if (mayiuse(cpu_isa_t::avx2)) { if (mayiuse(cpu_isa_t::avx2)) {
kernel.test<8, true>(); kernel.test<8, 5, true>();
} }
if (mayiuse(cpu_isa_t::sse41)) { if (mayiuse(cpu_isa_t::sse41)) {
kernel.test<4, true>(); kernel.test<4, 3, true>();
} }
} }
{ {
jit_variable_load_store_test_kernel<int32_t, bfloat16_t> kernel; jit_variable_load_store_test_kernel<int32_t, bfloat16_t> kernel;
if (mayiuse(cpu_isa_t::avx512_core)) { if (mayiuse(cpu_isa_t::avx512_core)) {
kernel.test<16, true>(); kernel.test<16, 11, true>();
} }
if (mayiuse(cpu_isa_t::avx2)) { if (mayiuse(cpu_isa_t::avx2)) {
kernel.test<8, true>(); kernel.test<8, 5, true>();
} }
if (mayiuse(cpu_isa_t::sse41)) { if (mayiuse(cpu_isa_t::sse41)) {
kernel.test<4, true>(); kernel.test<4, 3, true>();
} }
} }
} }