[CPU] Fix incorrect output for float to bf16 in avx2 isa (#19358)

This commit is contained in:
River Li 2023-09-07 13:09:16 +08:00 committed by GitHub
parent 14e0b1fd2c
commit 252afa3b6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 18 deletions

View File

@ -75,6 +75,7 @@ private:
h->uni_vpackusdw(aux, aux, aux);
if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx2) {
h->vpermq(Ymm(aux.getIdx()), Ymm(aux.getIdx()), 0xD8); //11 01 10 00
h->vextracti128(out, Ymm(aux.getIdx()), 0);
} else {
h->uni_vmovups(out, aux);

View File

@ -272,20 +272,19 @@ struct jit_variable_load_store_test_kernel {
size_t size;
};
template<size_t N, bool is_src>
template<size_t N, size_t M, bool is_src>
void test() {
kernel_impl<N, is_src> kernel;
kernel.init();
const size_t size = 3;
ASSERT_GE(N, M);
std::array<SrcT, N> src {};
std::array<DstT, N> result {};
Params args = { src.data(), result.data(), size };
Params args = { src.data(), result.data(), M };
src.fill(static_cast<SrcT>(42));
for (size_t i = 0; i < size; ++i) {
for (size_t i = 0; i < M; ++i) {
src[i] = static_cast<SrcT>(i);
}
@ -293,7 +292,7 @@ struct jit_variable_load_store_test_kernel {
std::array<DstT, N> expected_result {};
for (size_t i = 0; i < size; ++i) {
for (size_t i = 0; i < M; ++i) {
expected_result[i] = static_cast<DstT>(i);
}
@ -325,52 +324,52 @@ TEST(JitKernel, variable_load_and_store) {
{
jit_variable_load_store_test_kernel<uint8_t, float> kernel;
if (mayiuse(cpu_isa_t::avx512_core)) {
kernel.test<16, false>();
kernel.test<16, 11, false>();
}
if (mayiuse(cpu_isa_t::avx2)) {
kernel.test<8, false>();
kernel.test<8, 5, false>();
}
if (mayiuse(cpu_isa_t::sse41)) {
kernel.test<4, false>();
kernel.test<4, 3, false>();
}
}
{
jit_variable_load_store_test_kernel<int8_t, int8_t> kernel;
if (mayiuse(cpu_isa_t::avx512_core)) {
kernel.test<16, false>();
kernel.test<16, 11, false>();
}
if (mayiuse(cpu_isa_t::avx2)) {
kernel.test<8, false>();
kernel.test<8, 5, false>();
}
if (mayiuse(cpu_isa_t::sse41)) {
kernel.test<4, false>();
kernel.test<4, 3, false>();
}
}
{
jit_variable_load_store_test_kernel<float, bfloat16_t> kernel;
if (mayiuse(cpu_isa_t::avx512_core)) {
kernel.test<16, true>();
kernel.test<16, 11, true>();
}
if (mayiuse(cpu_isa_t::avx2)) {
kernel.test<8, true>();
kernel.test<8, 5, true>();
}
if (mayiuse(cpu_isa_t::sse41)) {
kernel.test<4, true>();
kernel.test<4, 3, true>();
}
}
{
jit_variable_load_store_test_kernel<int32_t, bfloat16_t> kernel;
if (mayiuse(cpu_isa_t::avx512_core)) {
kernel.test<16, true>();
kernel.test<16, 11, true>();
}
if (mayiuse(cpu_isa_t::avx2)) {
kernel.test<8, true>();
kernel.test<8, 5, true>();
}
if (mayiuse(cpu_isa_t::sse41)) {
kernel.test<4, true>();
kernel.test<4, 3, true>();
}
}
}