[CPU] Fix avx2 gather of bfloat16 (#20683)
This commit is contained in:
parent
63b07b9357
commit
a785bd7b99
@ -744,8 +744,6 @@ void jitUniGatherKernel<isa>::process16b(bool isShortIdx, bool blocked) {
|
||||
|
||||
mov(regAux1, reinterpret_cast<uintptr_t>(shufMask16bitUni));
|
||||
uni_vmovups(vShufMask, ptr[regAux1]);
|
||||
mov(regAux1, reinterpret_cast<uintptr_t>(permMask16bitUni));
|
||||
uni_vmovups(vPermMask, ptr[regAux1]);
|
||||
|
||||
// First iteration
|
||||
shiftIdxAndGather(vmmAuxContainer, isShortIdx, false, blocked);
|
||||
@ -755,6 +753,9 @@ void jitUniGatherKernel<isa>::process16b(bool isShortIdx, bool blocked) {
|
||||
vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask);
|
||||
|
||||
vshufps(vmmAuxContainer[0], vBuff0, vmmAuxContainer[0], 0x44);
|
||||
// vPermMask (vmm1) is overwritten in shiftIdxAndGather, so reload the mask here for correctness.
|
||||
mov(regAux1, reinterpret_cast<uintptr_t>(permMask16bitUni));
|
||||
uni_vmovups(vPermMask, ptr[regAux1]);
|
||||
vpermd(vmmAuxContainer[0], vPermMask, vmmAuxContainer[0]);
|
||||
|
||||
uni_vmovups(ptr[regDst], vmmAuxContainer[0]);
|
||||
@ -774,6 +775,11 @@ void jitUniGatherKernel<isa>::process16b(bool isShortIdx, bool blocked) {
|
||||
vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask);
|
||||
|
||||
vshufps(vmmAuxContainer[0], vBuff0, vmmAuxContainer[0], 0x44);
|
||||
if (isa == x64::avx2) {
|
||||
// Register vPermMask is invalidated by shiftIdxAndGather and must be initialized again.
|
||||
mov(regAux1, reinterpret_cast<uintptr_t>(permMask16bitUni));
|
||||
uni_vmovups(vPermMask, ptr[regAux1]);
|
||||
}
|
||||
vpermd(vmmAuxContainer[0], vPermMask, vmmAuxContainer[0]);
|
||||
|
||||
uni_vmovups(ptr[regDst], vmmAuxContainer[0]);
|
||||
|
Loading…
Reference in New Issue
Block a user