[CPU] Fix incorrect output for float to bf16 in avx2 isa (#19358)
This commit is contained in:
parent
14e0b1fd2c
commit
252afa3b6c
@ -75,6 +75,7 @@ private:
|
||||
h->uni_vpackusdw(aux, aux, aux);
|
||||
|
||||
if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx2) {
|
||||
h->vpermq(Ymm(aux.getIdx()), Ymm(aux.getIdx()), 0xD8); //11 01 10 00
|
||||
h->vextracti128(out, Ymm(aux.getIdx()), 0);
|
||||
} else {
|
||||
h->uni_vmovups(out, aux);
|
||||
|
@ -272,20 +272,19 @@ struct jit_variable_load_store_test_kernel {
|
||||
size_t size;
|
||||
};
|
||||
|
||||
template<size_t N, bool is_src>
|
||||
template<size_t N, size_t M, bool is_src>
|
||||
void test() {
|
||||
kernel_impl<N, is_src> kernel;
|
||||
kernel.init();
|
||||
|
||||
const size_t size = 3;
|
||||
ASSERT_GE(N, M);
|
||||
|
||||
std::array<SrcT, N> src {};
|
||||
std::array<DstT, N> result {};
|
||||
|
||||
Params args = { src.data(), result.data(), size };
|
||||
Params args = { src.data(), result.data(), M };
|
||||
|
||||
src.fill(static_cast<SrcT>(42));
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
for (size_t i = 0; i < M; ++i) {
|
||||
src[i] = static_cast<SrcT>(i);
|
||||
}
|
||||
|
||||
@ -293,7 +292,7 @@ struct jit_variable_load_store_test_kernel {
|
||||
|
||||
std::array<DstT, N> expected_result {};
|
||||
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
for (size_t i = 0; i < M; ++i) {
|
||||
expected_result[i] = static_cast<DstT>(i);
|
||||
}
|
||||
|
||||
@ -325,52 +324,52 @@ TEST(JitKernel, variable_load_and_store) {
|
||||
{
|
||||
jit_variable_load_store_test_kernel<uint8_t, float> kernel;
|
||||
if (mayiuse(cpu_isa_t::avx512_core)) {
|
||||
kernel.test<16, false>();
|
||||
kernel.test<16, 11, false>();
|
||||
}
|
||||
if (mayiuse(cpu_isa_t::avx2)) {
|
||||
kernel.test<8, false>();
|
||||
kernel.test<8, 5, false>();
|
||||
}
|
||||
if (mayiuse(cpu_isa_t::sse41)) {
|
||||
kernel.test<4, false>();
|
||||
kernel.test<4, 3, false>();
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
jit_variable_load_store_test_kernel<int8_t, int8_t> kernel;
|
||||
if (mayiuse(cpu_isa_t::avx512_core)) {
|
||||
kernel.test<16, false>();
|
||||
kernel.test<16, 11, false>();
|
||||
}
|
||||
if (mayiuse(cpu_isa_t::avx2)) {
|
||||
kernel.test<8, false>();
|
||||
kernel.test<8, 5, false>();
|
||||
}
|
||||
if (mayiuse(cpu_isa_t::sse41)) {
|
||||
kernel.test<4, false>();
|
||||
kernel.test<4, 3, false>();
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
jit_variable_load_store_test_kernel<float, bfloat16_t> kernel;
|
||||
if (mayiuse(cpu_isa_t::avx512_core)) {
|
||||
kernel.test<16, true>();
|
||||
kernel.test<16, 11, true>();
|
||||
}
|
||||
if (mayiuse(cpu_isa_t::avx2)) {
|
||||
kernel.test<8, true>();
|
||||
kernel.test<8, 5, true>();
|
||||
}
|
||||
if (mayiuse(cpu_isa_t::sse41)) {
|
||||
kernel.test<4, true>();
|
||||
kernel.test<4, 3, true>();
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
jit_variable_load_store_test_kernel<int32_t, bfloat16_t> kernel;
|
||||
if (mayiuse(cpu_isa_t::avx512_core)) {
|
||||
kernel.test<16, true>();
|
||||
kernel.test<16, 11, true>();
|
||||
}
|
||||
if (mayiuse(cpu_isa_t::avx2)) {
|
||||
kernel.test<8, true>();
|
||||
kernel.test<8, 5, true>();
|
||||
}
|
||||
if (mayiuse(cpu_isa_t::sse41)) {
|
||||
kernel.test<4, true>();
|
||||
kernel.test<4, 3, true>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user