[CPU] Fix avx2 gather of bfloat16 (#20683)
This commit is contained in:
parent
63b07b9357
commit
a785bd7b99
@ -744,8 +744,6 @@ void jitUniGatherKernel<isa>::process16b(bool isShortIdx, bool blocked) {
|
|||||||
|
|
||||||
mov(regAux1, reinterpret_cast<uintptr_t>(shufMask16bitUni));
|
mov(regAux1, reinterpret_cast<uintptr_t>(shufMask16bitUni));
|
||||||
uni_vmovups(vShufMask, ptr[regAux1]);
|
uni_vmovups(vShufMask, ptr[regAux1]);
|
||||||
mov(regAux1, reinterpret_cast<uintptr_t>(permMask16bitUni));
|
|
||||||
uni_vmovups(vPermMask, ptr[regAux1]);
|
|
||||||
|
|
||||||
// First iteration
|
// First iteration
|
||||||
shiftIdxAndGather(vmmAuxContainer, isShortIdx, false, blocked);
|
shiftIdxAndGather(vmmAuxContainer, isShortIdx, false, blocked);
|
||||||
@ -755,6 +753,9 @@ void jitUniGatherKernel<isa>::process16b(bool isShortIdx, bool blocked) {
|
|||||||
vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask);
|
vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask);
|
||||||
|
|
||||||
vshufps(vmmAuxContainer[0], vBuff0, vmmAuxContainer[0], 0x44);
|
vshufps(vmmAuxContainer[0], vBuff0, vmmAuxContainer[0], 0x44);
|
||||||
|
// vPermMask (vmm1) is overwritten in shiftIdxAndGather, so reload the mask here for correctness
|
||||||
|
mov(regAux1, reinterpret_cast<uintptr_t>(permMask16bitUni));
|
||||||
|
uni_vmovups(vPermMask, ptr[regAux1]);
|
||||||
vpermd(vmmAuxContainer[0], vPermMask, vmmAuxContainer[0]);
|
vpermd(vmmAuxContainer[0], vPermMask, vmmAuxContainer[0]);
|
||||||
|
|
||||||
uni_vmovups(ptr[regDst], vmmAuxContainer[0]);
|
uni_vmovups(ptr[regDst], vmmAuxContainer[0]);
|
||||||
@ -774,6 +775,11 @@ void jitUniGatherKernel<isa>::process16b(bool isShortIdx, bool blocked) {
|
|||||||
vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask);
|
vpshufb(vmmAuxContainer[0], vmmAuxContainer[2], vShufMask);
|
||||||
|
|
||||||
vshufps(vmmAuxContainer[0], vBuff0, vmmAuxContainer[0], 0x44);
|
vshufps(vmmAuxContainer[0], vBuff0, vmmAuxContainer[0], 0x44);
|
||||||
|
if (isa == x64::avx2) {
|
||||||
|
// Register vPermMask is invalidated by shiftIdxAndGather and must be initialized again.
|
||||||
|
mov(regAux1, reinterpret_cast<uintptr_t>(permMask16bitUni));
|
||||||
|
uni_vmovups(vPermMask, ptr[regAux1]);
|
||||||
|
}
|
||||||
vpermd(vmmAuxContainer[0], vPermMask, vmmAuxContainer[0]);
|
vpermd(vmmAuxContainer[0], vPermMask, vmmAuxContainer[0]);
|
||||||
|
|
||||||
uni_vmovups(ptr[regDst], vmmAuxContainer[0]);
|
uni_vmovups(ptr[regDst], vmmAuxContainer[0]);
|
||||||
|
Loading…
Reference in New Issue
Block a user