[CPU] Fix sporadic SIGSEGV (segmentation fault) in GridSample. (#15009)

This commit is contained in:
Nikolay Shchegolev 2023-02-07 17:57:34 +04:00 committed by GitHub
parent a48b4fc2b5
commit 188dda668f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 209 additions and 87 deletions

View File

@ -202,12 +202,16 @@ void GridSample::prepareParams() {
auto& p = execParamsPerThread[ithr]; auto& p = execParamsPerThread[ithr];
p.workAmount = dstEnd - dstStart;
if (p.workAmount == 0lu) {
return;
}
p.batchNum = srcDataShape[0]; p.batchNum = srcDataShape[0];
p.channelsNum = srcDataShape[1]; p.channelsNum = srcDataShape[1];
p.srcHeightF[0] = srcDataShape[2]; p.srcHeightF[0] = srcDataShape[2];
p.srcWidthF[0] = srcDataShape[3]; p.srcWidthF[0] = srcDataShape[3];
p.workAmount = dstEnd - dstStart;
p.gridStartB = dstStart * 2 * gridTypeSize; p.gridStartB = dstStart * 2 * gridTypeSize;
p.dstStartB = dstStart * dataTypeSize; p.dstStartB = dstStart * dataTypeSize;

View File

@ -76,10 +76,8 @@ void GridSampleKernel<x64::avx512_core>::initVectors() {
mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]); mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]);
uni_vpbroadcastd(vSrcHeightF, ptr[rAux]); uni_vpbroadcastd(vSrcHeightF, ptr[rAux]);
if (one_of(jcp.paddingMode, GridSamplePaddingMode::ZEROS, GridSamplePaddingMode::BORDER)) { vZeros = getVmm();
vZeros = getVmm(); uni_vpxor(vZeros, vZeros, vZeros);
uni_vpxor(vZeros, vZeros, vZeros);
}
if (one_of(jcp.interpolationMode, GridSampleInterpolationMode::BICUBIC, GridSampleInterpolationMode::BILINEAR)) { if (one_of(jcp.interpolationMode, GridSampleInterpolationMode::BICUBIC, GridSampleInterpolationMode::BILINEAR)) {
vOnesF = getVmm(); vOnesF = getVmm();
@ -430,7 +428,7 @@ void GridSampleKernel<x64::avx512_core>::getTailCoordinates(const Vmm& vHCoord,
cmp(rAux, 0); cmp(rAux, 0);
jle(lEnd, T_NEAR); jle(lEnd, T_NEAR);
fillRestWorkMask(kTailMask, vAux, rAux); fillRestWorkMask(kTailMask, rAux);
uni_vmovups((Vmm)vAux | kTailMask, ptr[regGrid]); uni_vmovups((Vmm)vAux | kTailMask, ptr[regGrid]);
vpermd(vAux, vGridPermMask, vAux); vpermd(vAux, vGridPermMask, vAux);
Xbyak::Ymm ymmAux(vAux.getIdx()); Xbyak::Ymm ymmAux(vAux.getIdx());
@ -441,7 +439,7 @@ void GridSampleKernel<x64::avx512_core>::getTailCoordinates(const Vmm& vHCoord,
} }
L(lRest); L(lRest);
{ {
fillRestWorkMask(kTailMask, vAux, rAux); fillRestWorkMask(kTailMask, rAux);
uni_vmovups(vWCoord | kTailMask, ptr[regGrid]); uni_vmovups(vWCoord | kTailMask, ptr[regGrid]);
vpermd(vWCoord, vGridPermMask, vWCoord); vpermd(vWCoord, vGridPermMask, vWCoord);
vshuff64x2(vHCoord, vWCoord, vHCoord, 0B11101110); // Extract Y component vshuff64x2(vHCoord, vWCoord, vHCoord, 0B11101110); // Extract Y component
@ -454,7 +452,7 @@ void GridSampleKernel<x64::avx512_core>::getTailCoordinates(const Vmm& vHCoord,
L(lEnd); L(lEnd);
fillRestWorkMask(kTailMask, vAux, regWorkAmount); fillRestWorkMask(kTailMask, regWorkAmount);
} }
template <> template <>
@ -672,14 +670,14 @@ void GridSampleKernel<isa>::denormalizeRawCoordinates(const Vmm& vWCoord, const
template <> template <>
void GridSampleKernel<x64::avx512_core>::zerosPaddingW(const Vmask& kDst, const Vmm& vCoord) { void GridSampleKernel<x64::avx512_core>::zerosPaddingW(const Vmask& kDst, const Vmm& vCoord) {
vcmpps(kDst, vCoord, vSrcWidthF, 0x1); // vCoord < vUpperBound vcmpps(kDst, vCoord, vSrcWidthF, CMP_LT_PS); // vCoord < vUpperBound
vcmpps(kDst | kDst, vZeros, vCoord, 0x2); // vCoord >= vZeros vcmpps(kDst | kDst, vZeros, vCoord, CMP_LE_PS); // vCoord >= vZeros
} }
template <> template <>
void GridSampleKernel<x64::avx512_core>::zerosPaddingH(const Vmask& kDst, const Vmm& vCoord, const Vmask& kMaskW) { void GridSampleKernel<x64::avx512_core>::zerosPaddingH(const Vmask& kDst, const Vmm& vCoord, const Vmask& kMaskW) {
vcmpps(kDst | kMaskW, vCoord, vSrcHeightF, 0x1); // vCoord < vUpperBound vcmpps(kDst | kMaskW, vCoord, vSrcHeightF, CMP_LT_PS); // vCoord < vUpperBound
vcmpps(kDst | kDst, vZeros, vCoord, 0x2); // vCoord >= vZeros vcmpps(kDst | kDst, vZeros, vCoord, CMP_LE_PS); // vCoord >= vZeros
} }
template <> template <>
@ -693,15 +691,15 @@ void GridSampleKernel<x64::sse41>::zerosPaddingW(const Vmask& kDst, const Vmm& v
auto vAux = getVmm(); auto vAux = getVmm();
if (vSrcWidthF.isInitialized()) { if (vSrcWidthF.isInitialized()) {
uni_vcmpps(vAux, vWCoord, vSrcWidthF, 0x1); // vWCoord < vSrcWidthF uni_vcmpps(vAux, vWCoord, vSrcWidthF, CMP_LT_PS); // vWCoord < vSrcWidthF
} else { } else {
auto rAux = getReg64(); auto rAux = getReg64();
mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]); mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]);
uni_vcmpps(vAux, vWCoord, ptr[rAux], 0x1); // vWCoord < vSrcWidthF uni_vcmpps(vAux, vWCoord, ptr[rAux], CMP_LT_PS); // vWCoord < vSrcWidthF
} }
uni_vpxor(kDst, kDst, kDst); uni_vpxor(kDst, kDst, kDst);
uni_vcmpps(kDst, kDst, vWCoord, 0x2); // vWCoord >= vZeros uni_vcmpps(kDst, kDst, vWCoord, CMP_LE_PS); // vWCoord >= vZeros
uni_vpand(kDst, kDst, vAux); // vZeros <= vWCoord < vSrcWidthF uni_vpand(kDst, kDst, vAux); // vZeros <= vWCoord < vSrcWidthF
} }
@ -710,17 +708,17 @@ void GridSampleKernel<x64::sse41>::zerosPaddingH(const Vmask& kDst, const Vmm& v
auto vAux = getVmm(); auto vAux = getVmm();
if (vSrcHeightF.isInitialized()) { if (vSrcHeightF.isInitialized()) {
uni_vcmpps(vAux, vHCoord, vSrcHeightF, 0x1); // vHCoord < vSrcHeightF uni_vcmpps(vAux, vHCoord, vSrcHeightF, CMP_LT_PS); // vHCoord < vSrcHeightF
} else { } else {
auto rAux = getReg64(); auto rAux = getReg64();
mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]); mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]);
uni_vcmpps(vAux, vHCoord, ptr[rAux], 0x1); // vHCoord < vSrcHeightF uni_vcmpps(vAux, vHCoord, ptr[rAux], CMP_LT_PS); // vHCoord < vSrcHeightF
} }
uni_vmovups(kDst, kMaskW); uni_vmovups(kDst, kMaskW);
uni_vpand(kDst, kDst, vAux); // vHCoord < vSrcHeightF && vZeros <= vWCoord < vSrcWidthF uni_vpand(kDst, kDst, vAux); // vHCoord < vSrcHeightF && vZeros <= vWCoord < vSrcWidthF
uni_vpxor(vAux, vAux, vAux); uni_vpxor(vAux, vAux, vAux);
uni_vcmpps(vAux, vAux, vHCoord, 0x2); // vHCoord >= vZeros uni_vcmpps(vAux, vAux, vHCoord, CMP_LE_PS); // vHCoord >= vZeros
uni_vpand(kDst, kDst, vAux); // vZeros <= vHCoord < vSrcHeightF && vZeros <= vWCoord < vSrcWidthF uni_vpand(kDst, kDst, vAux); // vZeros <= vHCoord < vSrcHeightF && vZeros <= vWCoord < vSrcWidthF
} }
@ -744,14 +742,14 @@ void GridSampleKernel<isa>::zerosPaddingW(const Vmask& kDst, const Vmm& vCoord)
} }
if (vSrcWidthF.isInitialized()) { if (vSrcWidthF.isInitialized()) {
uni_vcmpps(vAux, vCoord, vSrcWidthF, 0x1); // vWCoord < vSrcWidthF uni_vcmpps(vAux, vCoord, vSrcWidthF, CMP_LT_PS); // vWCoord < vSrcWidthF
} else { } else {
auto rAux = getReg64(); auto rAux = getReg64();
mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]); mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]);
uni_vcmpps(vAux, vCoord, ptr[rAux], 0x1); // vWCoord < vSrcWidthF uni_vcmpps(vAux, vCoord, ptr[rAux], CMP_LT_PS); // vWCoord < vSrcWidthF
} }
uni_vcmpps(kDst, vZerosTmp, vCoord, 0x2); // vWCoord >= vZeros uni_vcmpps(kDst, vZerosTmp, vCoord, CMP_LE_PS); // vWCoord >= vZeros
uni_vandps(kDst, kDst, vAux); // vZeros <= vWCoord < vSrcWidthF uni_vandps(kDst, kDst, vAux); // vZeros <= vWCoord < vSrcWidthF
} }
@ -769,15 +767,15 @@ void GridSampleKernel<isa>::zerosPaddingH(const Vmask& kDst, const Vmm& vCoord,
} }
if (vSrcHeightF.isInitialized()) { if (vSrcHeightF.isInitialized()) {
uni_vcmpps(vAux, vCoord, vSrcHeightF, 0x1); // vHCoord < vSrcHeightF uni_vcmpps(vAux, vCoord, vSrcHeightF, CMP_LT_PS); // vHCoord < vSrcHeightF
} else { } else {
auto rAux = getReg64(); auto rAux = getReg64();
mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]); mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]);
uni_vcmpps(vAux, vCoord, ptr[rAux], 0x1); // vHCoord < vSrcHeightF uni_vcmpps(vAux, vCoord, ptr[rAux], CMP_LT_PS); // vHCoord < vSrcHeightF
} }
uni_vandps(kDst, kMaskW, vAux); uni_vandps(kDst, kMaskW, vAux);
uni_vcmpps(vAux, vZerosTmp, vCoord, 0x2); // vHCoord >= vZeros uni_vcmpps(vAux, vZerosTmp, vCoord, CMP_LE_PS); // vHCoord >= vZeros
uni_vandps(kDst, kDst, vAux); uni_vandps(kDst, kDst, vAux);
} }
@ -831,7 +829,7 @@ void GridSampleKernel<isa>::borderPadding(const Vmm& vCoordDst, const Vmm& vCoor
} }
} }
uni_vcmpps(vAux, vCoordOrigin, vSub1F, 0x2); // vCoord <= vUpperBound uni_vcmpps(vAux, vCoordOrigin, vSub1F, CMP_LE_PS); // vCoord <= vUpperBound
uni_vandps(vCoordDst, vCoordOrigin, vAux); uni_vandps(vCoordDst, vCoordOrigin, vAux);
uni_vandnps(vAux, vAux, vSub1F); uni_vandnps(vAux, vAux, vSub1F);
uni_vaddps(vCoordDst, vCoordDst, vAux); uni_vaddps(vCoordDst, vCoordDst, vAux);
@ -857,14 +855,20 @@ void GridSampleKernel<isa>::borderPadding(const Vmm& vCoordDst, const Vmm& vCoor
template <> template <>
void GridSampleKernel<x64::avx512_core>::reflectionPadding(const Vmm& vCoordDst, const Vmm& vCoordOrigin, const coord dim) { void GridSampleKernel<x64::avx512_core>::reflectionPadding(const Vmm& vCoordDst, const Vmm& vCoordOrigin, const coord dim) {
auto vAux = getVmm(); auto vAux = getVmm();
auto kAux = getMask();
const auto& vSrcDimMul2Sub1F = dim == coord::w ? vSrcWidthMul2Sub1F : vSrcHeightMul2Sub1F; const auto& vSrcDimMul2Sub1F = dim == coord::w ? vSrcWidthMul2Sub1F : vSrcHeightMul2Sub1F;
if (jcp.alignCorners) { if (jcp.alignCorners) {
// abs(x) % D21 // abs(x) % D21
uni_vandps(vCoordDst, vCoordOrigin, vAbsMask); // abs(x) uni_vandps(vCoordDst, vCoordOrigin, vAbsMask); // abs(x)
uni_vdivps(vAux, vCoordDst, vSrcDimMul2Sub1F); uni_vdivps(vAux, vCoordDst, vSrcDimMul2Sub1F);
uni_vroundps(vAux, vAux, 0x3); // Truncation uni_vroundps(vAux, vAux, 0x3); // Truncation
uni_vfnmadd231ps(vCoordDst, vAux, vSrcDimMul2Sub1F); // abs(x) % D21 uni_vfnmadd231ps(vCoordDst, vAux, vSrcDimMul2Sub1F); // abs(x) % D21
// Check that the result does not exceed the divisor.
vcmpps(kAux, vSrcDimMul2Sub1F, vCoordDst, CMP_LE_PS);
uni_vmovups(vCoordDst | kAux, vZeros);
vrangeps(vCoordDst, vCoordDst, vZeros, 0x1);
} else { } else {
const auto& vSrcDimMul2F = dim == coord::w ? vSrcWidthMul2F : vSrcHeightMul2F; const auto& vSrcDimMul2F = dim == coord::w ? vSrcWidthMul2F : vSrcHeightMul2F;
// (x % D2 + D2) % D2 // (x % D2 + D2) % D2
@ -877,12 +881,16 @@ void GridSampleKernel<x64::avx512_core>::reflectionPadding(const Vmm& vCoordDst,
uni_vdivps(vAux, vCoordDst, vSrcDimMul2F); uni_vdivps(vAux, vCoordDst, vSrcDimMul2F);
uni_vroundps(vAux, vAux, 0x3); // Truncation uni_vroundps(vAux, vAux, 0x3); // Truncation
uni_vfnmadd231ps(vCoordDst, vAux, vSrcDimMul2F); // (x % D2 + D2) % D2 uni_vfnmadd231ps(vCoordDst, vAux, vSrcDimMul2F); // (x % D2 + D2) % D2
// Check that the result does not exceed the divisor.
vcmpps(kAux, vSrcDimMul2F, vCoordDst, CMP_LE_PS);
uni_vmovups(vCoordDst | kAux, vZeros);
vrangeps(vCoordDst, vCoordDst, vZeros, 0x1);
} }
auto kAux = getMask();
uni_vsubps(vAux, vSrcDimMul2Sub1F, vCoordDst); uni_vsubps(vAux, vSrcDimMul2Sub1F, vCoordDst);
vcmpps(kAux, dim == coord::w ? vSrcWidthF : vSrcHeightF, vCoordDst, 0x2); // vCoordDst >= vSrcDimF vcmpps(kAux, dim == coord::w ? vSrcWidthF : vSrcHeightF, vCoordDst, CMP_LE_PS); // vCoordDst >= vSrcDimF
vmovups(vCoordDst | kAux, vAux); uni_vmovups(vCoordDst | kAux, vAux);
} }
template <x64::cpu_isa_t isa> // Works for AVX2, AVX, SSE41 template <x64::cpu_isa_t isa> // Works for AVX2, AVX, SSE41
@ -925,6 +933,14 @@ void GridSampleKernel<isa>::reflectionPadding(const Vmm& vCoordDst, const Vmm& v
uni_vdivps(vAux0, vCoordDst, vMul2Sub1); uni_vdivps(vAux0, vCoordDst, vMul2Sub1);
uni_vroundps(vAux0, vAux0, 0x3); // Truncation uni_vroundps(vAux0, vAux0, 0x3); // Truncation
uni_vfnmadd231ps(vCoordDst, vAux0, vMul2Sub1); // abs(x) % D21 uni_vfnmadd231ps(vCoordDst, vAux0, vMul2Sub1); // abs(x) % D21
// Check that the result does not exceed the divisor.
uni_vcmpps(vAux0, vCoordDst, vMul2Sub1, CMP_LT_PS);
uni_vandps(vCoordDst, vCoordDst, vAux0);
uni_vxorps(vAux0, vAux0, vAux0);
uni_vcmpps(vAux0, vAux0, vCoordDst, CMP_LE_PS);
uni_vandps(vCoordDst, vCoordDst, vAux0);
uni_vsubps(vAux0, vCoordDst, vMul2Sub1); // abs(x) % D21 - D21 uni_vsubps(vAux0, vCoordDst, vMul2Sub1); // abs(x) % D21 - D21
} else { } else {
// x' = (x % D2 + D2) % D2 - D21 // x' = (x % D2 + D2) % D2 - D21
@ -956,6 +972,13 @@ void GridSampleKernel<isa>::reflectionPadding(const Vmm& vCoordDst, const Vmm& v
uni_vroundps(vAux0, vAux0, 0x3); // Truncation uni_vroundps(vAux0, vAux0, 0x3); // Truncation
uni_vfnmadd231ps(vCoordDst, vAux0, vMul2); // (x % D2 + D2) % D2 uni_vfnmadd231ps(vCoordDst, vAux0, vMul2); // (x % D2 + D2) % D2
// Check that the result does not exceed the divisor.
uni_vcmpps(vAux0, vCoordDst, vMul2, CMP_LT_PS);
uni_vandps(vCoordDst, vCoordDst, vAux0);
uni_vxorps(vAux0, vAux0, vAux0);
uni_vcmpps(vAux0, vAux0, vCoordDst, CMP_LE_PS);
uni_vandps(vCoordDst, vCoordDst, vAux0);
if (dim == coord::w) { if (dim == coord::w) {
if (vSrcWidthMul2Sub1F.isInitialized()) { if (vSrcWidthMul2Sub1F.isInitialized()) {
uni_vsubps(vAux0, vCoordDst, vSrcWidthMul2Sub1F); uni_vsubps(vAux0, vCoordDst, vSrcWidthMul2Sub1F);
@ -975,17 +998,17 @@ void GridSampleKernel<isa>::reflectionPadding(const Vmm& vCoordDst, const Vmm& v
if (dim == coord::w) { if (dim == coord::w) {
if (vSrcWidthF.isInitialized()) { if (vSrcWidthF.isInitialized()) {
uni_vcmpps(vAux1, vCoordDst, vSrcWidthF, 0x1); // vCoordDst < vUpperBound uni_vcmpps(vAux1, vCoordDst, vSrcWidthF, CMP_LT_PS); // vCoordDst < vUpperBound
} else { } else {
mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]); mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]);
uni_vcmpps(vAux1, vCoordDst, ptr[rAux], 0x1); // vCoordDst < vUpperBound uni_vcmpps(vAux1, vCoordDst, ptr[rAux], CMP_LT_PS); // vCoordDst < vUpperBound
} }
} else { } else {
if (vSrcHeightF.isInitialized()) { if (vSrcHeightF.isInitialized()) {
uni_vcmpps(vAux1, vCoordDst, vSrcHeightF, 0x1); // vCoordDst < vUpperBound uni_vcmpps(vAux1, vCoordDst, vSrcHeightF, CMP_LT_PS); // vCoordDst < vUpperBound
} else { } else {
mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]); mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]);
uni_vcmpps(vAux1, vCoordDst, ptr[rAux], 0x1); // vCoordDst < vUpperBound uni_vcmpps(vAux1, vCoordDst, ptr[rAux], CMP_LT_PS); // vCoordDst < vUpperBound
} }
} }
@ -1246,23 +1269,21 @@ void GridSampleKernel<isa>::nearestInterpolation(const Vmm& vWCoord, const Vmm&
template <> template <>
void GridSampleKernel<x64::avx512_core>::bilinearInterpolation(const Vmm& vWCoord, const Vmm& vHCoord, bool tail) { void GridSampleKernel<x64::avx512_core>::bilinearInterpolation(const Vmm& vWCoord, const Vmm& vHCoord, bool tail) {
auto vDX = getVmm(); const auto& vDX = vWCoord;
auto vDY = getVmm(); const auto& vDY = vHCoord;
const auto& shift00 = vWCoord; auto shift00 = getVmm();
const auto& shift01 = vHCoord; auto shift01 = getVmm();
auto shift10 = getVmm(); auto shift10 = getVmm();
auto shift11 = getVmm(); auto shift11 = getVmm();
auto vAux = getVmm(); auto vAux = getVmm();
RegistersPool::Reg<Vmask> kMask00, kMask01, kMask10, kMask11; RegistersPool::Reg<Vmask> kMask00, kMask01, kMask10, kMask11;
uni_vmovups(vDX, vWCoord); uni_vroundps(shift00, vWCoord, 0x1); // Round floor
uni_vmovups(vDY, vHCoord); uni_vroundps(shift01, vHCoord, 0x1); // Round floor
uni_vroundps(vWCoord, vWCoord, 0x1); // Round floor uni_vsubps(vDX, vWCoord, shift00);
uni_vroundps(vHCoord, vHCoord, 0x1); // Round floor uni_vsubps(vDY, vHCoord, shift01);
uni_vsubps(vDX, vDX, vWCoord); uni_vaddps(shift10, shift00, vOnesF);
uni_vsubps(vDY, vDY, vHCoord); uni_vaddps(shift11, shift01, vOnesF);
uni_vaddps(shift10, vWCoord, vOnesF);
uni_vaddps(shift11, vHCoord, vOnesF);
bool useMask = false, zeroFill = false; bool useMask = false, zeroFill = false;
if (jcp.paddingMode == GridSamplePaddingMode::ZEROS) { if (jcp.paddingMode == GridSamplePaddingMode::ZEROS) {
@ -1272,31 +1293,31 @@ void GridSampleKernel<x64::avx512_core>::bilinearInterpolation(const Vmm& vWCoor
kMask10 = getMask(); kMask10 = getMask();
kMask11 = getMask(); kMask11 = getMask();
zerosPadding(kMask00, vHCoord, vWCoord); // (y; x) zerosPadding(kMask00, shift01, shift00); // (y; x)
zerosPadding(kMask01, vHCoord, shift10); // (y; x + 1) zerosPadding(kMask01, shift01, shift10); // (y; x + 1)
zerosPadding(kMask11, shift11, shift10); // (y + 1; x + 1) zerosPadding(kMask11, shift11, shift10); // (y + 1; x + 1)
zerosPadding(kMask10, shift11, vWCoord); // (y + 1; x) zerosPadding(kMask10, shift11, shift00); // (y + 1; x)
hwShiftPs2dq(shift00, vHCoord, vWCoord, vSrcWidthF); hwShiftPs2dq(shift00, shift01, shift00, vSrcWidthF);
uni_vpaddd(shift01, shift00, vDataTypeSizeB); uni_vpaddd(shift01, shift00, vDataTypeSizeB);
uni_vpaddd(shift10, shift00, vSrcWidthB); // shift11?? uni_vpaddd(shift10, shift00, vSrcWidthB);
uni_vpaddd(shift11, shift10, vDataTypeSizeB); // sub?? uni_vpaddd(shift11, shift10, vDataTypeSizeB);
} else if (jcp.paddingMode == GridSamplePaddingMode::BORDER) { } else if (jcp.paddingMode == GridSamplePaddingMode::BORDER) {
borderPadding(vWCoord, vWCoord, coord::w); borderPadding(shift00, shift00, coord::w);
borderPadding(vHCoord, vHCoord, coord::h); borderPadding(shift01, shift01, coord::h);
borderPadding(shift10, shift10, coord::w); borderPadding(shift10, shift10, coord::w);
borderPadding(shift11, shift11, coord::h); borderPadding(shift11, shift11, coord::h);
} else if (jcp.paddingMode == GridSamplePaddingMode::REFLECTION) { } else if (jcp.paddingMode == GridSamplePaddingMode::REFLECTION) {
reflectionPadding(vWCoord, vWCoord, coord::w); reflectionPadding(shift00, shift00, coord::w);
reflectionPadding(vHCoord, vHCoord, coord::h); reflectionPadding(shift01, shift01, coord::h);
reflectionPadding(shift10, shift10, coord::w); reflectionPadding(shift10, shift10, coord::w);
reflectionPadding(shift11, shift11, coord::h); reflectionPadding(shift11, shift11, coord::h);
} }
if (jcp.paddingMode == GridSamplePaddingMode::BORDER || jcp.paddingMode == GridSamplePaddingMode::REFLECTION) { if (jcp.paddingMode == GridSamplePaddingMode::BORDER || jcp.paddingMode == GridSamplePaddingMode::REFLECTION) {
// W * y + x // W * y + x
hwShiftPs2dq(vAux, shift11, vWCoord, vSrcWidthF); hwShiftPs2dq(vAux, shift11, shift00, vSrcWidthF);
hwShiftPs2dq(vWCoord, vHCoord, vWCoord, vSrcWidthF); hwShiftPs2dq(shift00, shift01, shift00, vSrcWidthF);
hwShiftPs2dq(vHCoord, vHCoord, shift10, vSrcWidthF); hwShiftPs2dq(shift01, shift01, shift10, vSrcWidthF);
hwShiftPs2dq(shift11, shift11, shift10, vSrcWidthF); hwShiftPs2dq(shift11, shift11, shift10, vSrcWidthF);
uni_vmovups(shift10, vAux); uni_vmovups(shift10, vAux);
} }
@ -1658,8 +1679,8 @@ void GridSampleKernel<x64::avx512_core>::bicubicInterpolation(const Vmm& vWCoord
// (y - 1 + h; x - 1) // (y - 1 + h; x - 1)
if (jcp.paddingMode == GridSamplePaddingMode::ZEROS) { if (jcp.paddingMode == GridSamplePaddingMode::ZEROS) {
Xbyak::Opmask maskH = kMaskH; Xbyak::Opmask maskH = kMaskH;
vcmpps(kMaskH, vHCoord, vSrcHeightF, 0x1); vcmpps(kMaskH, vHCoord, vSrcHeightF, CMP_LT_PS);
vcmpps(maskH | maskH, vZeros, vHCoord, 0x2); vcmpps(maskH | maskH, vZeros, vHCoord, CMP_LE_PS);
kandw(kAuxMask, kMaskH, wMasks[0]); kandw(kAuxMask, kMaskH, wMasks[0]);
uni_vmulps(vSrcShift0, vHCoord, vSrcWidthF); uni_vmulps(vSrcShift0, vHCoord, vSrcWidthF);
uni_vmovups(vWCoord, vWLeft); uni_vmovups(vWCoord, vWLeft);

View File

@ -286,26 +286,80 @@ void JitKernelBase::uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &
} }
void JitKernelBase::fillRestWorkMask(const Xbyak::Opmask& dstMask, void JitKernelBase::fillRestWorkMask(const Xbyak::Opmask& dstMask,
const Xbyak::Zmm& zAux, const Xbyak::Reg64& rWorkRest) {
const Xbyak::Reg64& rWorkRest) { auto rOnes = getReg64();
auto rAux0 = getReg64();
auto rAux1 = getReg64();
Xbyak::Label lKmov;
Xbyak::Reg32 rOnes(rAux1.getIdx());
const uint64_t typeSize = 4;
const uint64_t elPerVec = x64::cpu_isa_traits<x64::avx512_core>::vlen / typeSize;
mov(rOnes, 0x0000FFFF); mov(rOnes, 0xFFFFFFFFFFFFFFFF);
cmp(rWorkRest, elPerVec); shlx(rOnes, rOnes, rWorkRest);
jge(lKmov); not_(rOnes);
{ kmovq(dstMask, rOnes);
Xbyak::Reg32 rShift(rAux0.getIdx()); }
mov(rShift, elPerVec);
sub(rShift, rWorkRest); void JitKernelBase::fillRestWorkMask(const Xbyak::Xmm& xmmDstMask,
shrx(rOnes, rOnes, rShift); const Xbyak::Reg64& rWorkRest,
const uint64_t typeSize) {
if (!one_of(typeSize, 1, 2, 4, 8)) {
IE_THROW() << "Could not fill data with type size " << typeSize;
} }
L(lKmov); Xbyak::Label lEnd;
kmovw(dstMask, rOnes); auto r32Ones = getReg32();
Xbyak::Reg64 r64Ones(r32Ones.getIdx());
auto elPerVec = x64::cpu_isa_traits<x64::sse41>::vlen / typeSize;
mov(r64Ones, 0xFFFFFFFFFFFFFFFF);
for (uint8_t i = 0; i < elPerVec; i++) {
cmp(rWorkRest, i);
jle(lEnd, T_NEAR);
if (typeSize == 1) {
pinsrb(xmmDstMask, r32Ones, i);
} else if (typeSize == 2) {
pinsrw(xmmDstMask, r32Ones, i);
} else if (typeSize == 4) {
pinsrd(xmmDstMask, r32Ones, i);
} else if (typeSize == 8) {
pinsrq(xmmDstMask, r64Ones, i);
}
}
L(lEnd);
}
void JitKernelBase::fillRestWorkMask(const Xbyak::Ymm& ymmDstMask,
const Xbyak::Reg64& rWorkRest,
const uint64_t typeSize) {
if (!one_of(typeSize, 1, 2, 4, 8)) {
IE_THROW() << "Could not fill data with type size " << typeSize;
}
Xbyak::Label lEnd;
auto elPerVec = x64::cpu_isa_traits<x64::sse41>::vlen / typeSize;
auto r32Ones = getReg32();
Xbyak::Reg64 r64Ones(r32Ones.getIdx());
Xbyak::Xmm xmmDstMask(ymmDstMask.getIdx());
mov(r64Ones, 0xFFFFFFFFFFFFFFFF);
uni_vpxor(ymmDstMask, ymmDstMask, ymmDstMask);
for (uint8_t i = 0; i < 2; i++) {
Xbyak::Label lPerm;
for (uint8_t j = 0; j < elPerVec; j++) {
cmp(rWorkRest, i * elPerVec + j);
jle(i == 0 ? lEnd : lPerm, T_NEAR);
if (typeSize == 1) {
pinsrb(xmmDstMask, r32Ones, j);
} else if (typeSize == 2) {
pinsrw(xmmDstMask, r32Ones, j);
} else if (typeSize == 4) {
pinsrd(xmmDstMask, r32Ones, j);
} else if (typeSize == 8) {
pinsrq(xmmDstMask, r64Ones, j);
}
}
cmp(rWorkRest, elPerVec);
je(lEnd, T_NEAR);
L(lPerm);
vperm2f128(ymmDstMask, ymmDstMask, ymmDstMask, 0x1);
}
L(lEnd);
} }
void JitKernelBase::load(const Xbyak::Xmm& vDst, void JitKernelBase::load(const Xbyak::Xmm& vDst,

View File

@ -11,6 +11,7 @@ namespace ov {
namespace intel_cpu { namespace intel_cpu {
#define getReg64() RegistersPool::Reg<Xbyak::Reg64>(registersPool) #define getReg64() RegistersPool::Reg<Xbyak::Reg64>(registersPool)
#define getReg32() RegistersPool::Reg<Xbyak::Reg32>(registersPool)
#define getVmm() RegistersPool::Reg<Vmm>(registersPool) #define getVmm() RegistersPool::Reg<Vmm>(registersPool)
#define getMask() RegistersPool::Reg<Vmask>(registersPool) #define getMask() RegistersPool::Reg<Vmask>(registersPool)
@ -84,9 +85,16 @@ public:
const bool zeroFill = false); const bool zeroFill = false);
void fillRestWorkMask(const Xbyak::Opmask& kDstMask, void fillRestWorkMask(const Xbyak::Opmask& kDstMask,
const Xbyak::Zmm& zAux,
const Xbyak::Reg64& rWorkRest); const Xbyak::Reg64& rWorkRest);
void fillRestWorkMask(const Xbyak::Xmm& ymmDstMask,
const Xbyak::Reg64& rWorkRest,
const uint64_t typeSize = 4);
void fillRestWorkMask(const Xbyak::Ymm& ymmDstMask,
const Xbyak::Reg64& rWorkRest,
const uint64_t typeSize = 4);
void load(const Xbyak::Xmm& vDst, void load(const Xbyak::Xmm& vDst,
const Xbyak::Address& srcAddr, const Xbyak::Address& srcAddr,
const Xbyak::Reg64& rLoadNum, const Xbyak::Reg64& rLoadNum,
@ -133,6 +141,18 @@ protected:
} }
RegistersPool::Ptr registersPool; RegistersPool::Ptr registersPool;
enum {
// Comparison predicate operand (immediate byte) for single-precision floating-point values.
CMP_EQ_PS = 0, // Equal (ordered, non-signaling)
CMP_LT_PS, // Less-than (ordered, signaling)
CMP_LE_PS, // Less-than-or-equal (ordered, signaling)
CMP_UNORD_PS, // Unordered (non-signaling)
CMP_NEQ_PS, // Not-equal (unordered, non-signaling)
CMP_NLT_PS, // Not-less-than (unordered, signaling)
CMP_NLE_PS, // Not-less-than-or-equal (unordered, signaling)
CMP_ORD_PS // Ordered (non-signaling)
};
}; };
} // namespace intel_cpu } // namespace intel_cpu

View File

@ -184,9 +184,6 @@ std::vector<std::string> disabledTestPatterns() {
// The kernel does not have such garbage. The diff 0.000000745 is taken into account in calculations and affects further type conversion. // The kernel does not have such garbage. The diff 0.000000745 is taken into account in calculations and affects further type conversion.
// Reorder->GridSample->Reorder also does not work here. Potential fix is to use nearest conversion instead of truncation. // Reorder->GridSample->Reorder also does not work here. Potential fix is to use nearest conversion instead of truncation.
R"(.*GridSampleLayerTestCPU.*(BILINEAR|BICUBIC).*(i32|i8).*)", R"(.*GridSampleLayerTestCPU.*(BILINEAR|BICUBIC).*(i32|i8).*)",
// 94989. BF16 Reference produces different results.
// GridSample regression on bf16 data.
R"(.*GridSampleLayerTestCPU.*(BILINEAR|BICUBIC).*bf16.*)",
// // Issue: 95915 // // Issue: 95915
R"(smoke_dynamic/AUGRUCellCPUTest.CompareWithRefs/IS=\(\[\?\.1\]_\[\?\.1\]_\[\?\.1\]_\)_TS=\{\(1\.1\)_\(1\.1\)_\(1\.1\)\}_\{\(3\.1\)_\(3\.1\)_\(3\.1\)\}_\{\(5\.1\)_\(5\.1\)_\(5\.1\)\}_decompose=0_activations=\(sigmoid\.tanh\)_clip=0_linear=0_netPrec=f32__inFmts=nc\.nc_outFmts=nc_primitive=ref_any_PluginConf_ENFORCE_BF16=YES)", // NOLINT R"(smoke_dynamic/AUGRUCellCPUTest.CompareWithRefs/IS=\(\[\?\.1\]_\[\?\.1\]_\[\?\.1\]_\)_TS=\{\(1\.1\)_\(1\.1\)_\(1\.1\)\}_\{\(3\.1\)_\(3\.1\)_\(3\.1\)\}_\{\(5\.1\)_\(5\.1\)_\(5\.1\)\}_decompose=0_activations=\(sigmoid\.tanh\)_clip=0_linear=0_netPrec=f32__inFmts=nc\.nc_outFmts=nc_primitive=ref_any_PluginConf_ENFORCE_BF16=YES)", // NOLINT
R"(smoke_dynamic/GRUCellCPUTest.CompareWithRefs/IS=\(\[\?.1\]_\[\?\.1\]_\)_TS=\{\(1\.1\)_\(1\.1\)\}_\{\(3\.1\)_\(3\.1\)\}_\{\(5\.1\)_\(5\.1\)\}_decompose=0_activations=\(sigmoid\.tanh\)_clip=0_linear=0_netPrec=f32__inFmts=nc\.nc_outFmts=nc_primitive=ref_any_PluginConf_ENFORCE_BF16=YES)", // NOLINT R"(smoke_dynamic/GRUCellCPUTest.CompareWithRefs/IS=\(\[\?.1\]_\[\?\.1\]_\)_TS=\{\(1\.1\)_\(1\.1\)\}_\{\(3\.1\)_\(3\.1\)\}_\{\(5\.1\)_\(5\.1\)\}_decompose=0_activations=\(sigmoid\.tanh\)_clip=0_linear=0_netPrec=f32__inFmts=nc\.nc_outFmts=nc_primitive=ref_any_PluginConf_ENFORCE_BF16=YES)", // NOLINT

View File

@ -94,6 +94,9 @@ protected:
auto execType = dataPrecision == ov::element::i32 ? ov::element::i32 : ov::element::f32; auto execType = dataPrecision == ov::element::i32 ? ov::element::i32 : ov::element::f32;
selectedType = makeSelectedTypeStr(selectedType, execType); selectedType = makeSelectedTypeStr(selectedType, execType);
} }
if (gridPrecision == ov::element::bf16) {
rel_threshold = 0.01f;
}
auto params = ngraph::builder::makeDynamicParams({dataPrecision, gridPrecision}, inputDynamicShapes); auto params = ngraph::builder::makeDynamicParams({dataPrecision, gridPrecision}, inputDynamicShapes);
params[0]->set_friendly_name("data"); params[0]->set_friendly_name("data");
@ -272,12 +275,35 @@ INSTANTIATE_TEST_SUITE_P(smoke_static, GridSampleLayerTestCPU,
::testing::ValuesIn(interpolateMode), ::testing::ValuesIn(interpolateMode),
::testing::ValuesIn(paddingMode), ::testing::ValuesIn(paddingMode),
::testing::ValuesIn(alignCorners), ::testing::ValuesIn(alignCorners),
::testing::ValuesIn({ElementType::f32, ElementType::bf16, ElementType::i32, ElementType::i8}), ::testing::ValuesIn({ElementType::f32, ElementType::i32}),
::testing::ValuesIn({ElementType::f32}),
::testing::ValuesIn(getCPUInfo()),
::testing::Values(additionalConfig[0])),
GridSampleLayerTestCPU::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(nightly_static_1, GridSampleLayerTestCPU,
::testing::Combine(
::testing::ValuesIn(getStaticShapes()),
::testing::ValuesIn(interpolateMode),
::testing::ValuesIn(paddingMode),
::testing::ValuesIn(alignCorners),
::testing::ValuesIn({ElementType::bf16, ElementType::i8}),
::testing::ValuesIn({ElementType::f32, ElementType::bf16}), ::testing::ValuesIn({ElementType::f32, ElementType::bf16}),
::testing::ValuesIn(getCPUInfo()), ::testing::ValuesIn(getCPUInfo()),
::testing::Values(additionalConfig[0])), ::testing::Values(additionalConfig[0])),
GridSampleLayerTestCPU::getTestCaseName); GridSampleLayerTestCPU::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(nightly_static_2, GridSampleLayerTestCPU,
::testing::Combine(
::testing::ValuesIn(getStaticShapes()),
::testing::ValuesIn(interpolateMode),
::testing::ValuesIn(paddingMode),
::testing::ValuesIn(alignCorners),
::testing::ValuesIn({ElementType::f32}),
::testing::ValuesIn({ElementType::bf16}),
::testing::ValuesIn(getCPUInfo()),
::testing::Values(additionalConfig[0])),
GridSampleLayerTestCPU::getTestCaseName);
const std::vector<std::vector<InputShape>> dynamicInSapes = { const std::vector<std::vector<InputShape>> dynamicInSapes = {
{ { { ov::Dimension(1, 15), -1, -1, -1 }, // Dynamic shape 0 { { { ov::Dimension(1, 15), -1, -1, -1 }, // Dynamic shape 0