[CPU] Fix avx2 gather of bfloat16 (#20683)

This commit is contained in:
Zhang Yi 2023-11-02 13:46:03 +08:00 committed by Alexander Nesterov
parent 63b07b9357
commit a785bd7b99

View File

@ -744,8 +744,6 @@ void jitUniGatherKernel<isa>::process16b(bool isShortIdx, bool blocked) {
mov(regAux1, reinterpret_cast<uintptr_t>(shufMask16bitUni));
uni_vmovups(vShufMask, ptr[regAux1]);
mov(regAux1, reinterpret_cast<uintptr_t>(permMask16bitUni));
uni_vmovups(vPermMask, ptr[regAux1]);
// First iteration
shiftIdxAndGather(vmmAuxContainer, isShortIdx, false, blocked);
@ -755,6 +753,9 @@ void jitUniGatherKernel<isa>::process16b(bool isShortIdx, bool blocked) {
vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask);
vshufps(vmmAuxContainer[0], vBuff0, vmmAuxContainer[0], 0x44);
// vPermMask(vmm1) is override in shiftIdxAndGather, load the mask here for correctness
mov(regAux1, reinterpret_cast<uintptr_t>(permMask16bitUni));
uni_vmovups(vPermMask, ptr[regAux1]);
vpermd(vmmAuxContainer[0], vPermMask, vmmAuxContainer[0]);
uni_vmovups(ptr[regDst], vmmAuxContainer[0]);
@ -774,6 +775,11 @@ void jitUniGatherKernel<isa>::process16b(bool isShortIdx, bool blocked) {
vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask);
vshufps(vmmAuxContainer[0], vBuff0, vmmAuxContainer[0], 0x44);
if (isa == x64::avx2) {
// Register vPermMask is invalidated by shiftIdxAndGather and must be initialized again.
mov(regAux1, reinterpret_cast<uintptr_t>(permMask16bitUni));
uni_vmovups(vPermMask, ptr[regAux1]);
}
vpermd(vmmAuxContainer[0], vPermMask, vmmAuxContainer[0]);
uni_vmovups(ptr[regDst], vmmAuxContainer[0]);