[CPU] Fix sporadic SEGFAULT in GridSample. (#15009)
parent a48b4fc2b5
commit 188dda668f
@@ -202,12 +202,16 @@ void GridSample::prepareParams() {

         auto& p = execParamsPerThread[ithr];

+        p.workAmount = dstEnd - dstStart;
+        if (p.workAmount == 0lu) {
+            return;
+        }
+
         p.batchNum      = srcDataShape[0];
         p.channelsNum   = srcDataShape[1];
         p.srcHeightF[0] = srcDataShape[2];
         p.srcWidthF[0]  = srcDataShape[3];

-        p.workAmount = dstEnd - dstStart;
         p.gridStartB = dstStart * 2 * gridTypeSize;
         p.dstStartB  = dstStart * dataTypeSize;

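The hunk above is the functional core of the fix: each worker thread now computes its slice size first and returns immediately when the slice is empty, instead of filling in the remaining per-thread parameters and later invoking the JIT kernel with workAmount == 0, which appears to be what sporadically crashed. A minimal stand-alone sketch of the pattern (the helper name and threading scheme are hypothetical, not the plugin's actual API):

    #include <algorithm>
    #include <cstddef>
    #include <thread>
    #include <vector>

    // Illustrative only: split `total` work items over `nthr` threads; a thread whose
    // slice turns out empty returns early and never touches the shared buffers.
    void parallelSlices(std::size_t total, std::size_t nthr,
                        void (*kernel)(std::size_t start, std::size_t amount)) {
        std::vector<std::thread> pool;
        const std::size_t chunk = (total + nthr - 1) / nthr;
        for (std::size_t ithr = 0; ithr < nthr; ++ithr) {
            pool.emplace_back([=]() {
                const std::size_t start = std::min(ithr * chunk, total);
                const std::size_t end   = std::min(start + chunk, total);
                if (end - start == 0) {
                    return;  // mirrors the added `if (p.workAmount == 0lu) return;`
                }
                kernel(start, end - start);
            });
        }
        for (auto& t : pool) {
            t.join();
        }
    }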
@@ -76,10 +76,8 @@ void GridSampleKernel<x64::avx512_core>::initVectors() {
     mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]);
     uni_vpbroadcastd(vSrcHeightF, ptr[rAux]);

-    if (one_of(jcp.paddingMode, GridSamplePaddingMode::ZEROS, GridSamplePaddingMode::BORDER)) {
-        vZeros = getVmm();
-        uni_vpxor(vZeros, vZeros, vZeros);
-    }
+    vZeros = getVmm();
+    uni_vpxor(vZeros, vZeros, vZeros);

     if (one_of(jcp.interpolationMode, GridSampleInterpolationMode::BICUBIC, GridSampleInterpolationMode::BILINEAR)) {
         vOnesF = getVmm();

@@ -430,7 +428,7 @@ void GridSampleKernel<x64::avx512_core>::getTailCoordinates(const Vmm& vHCoord,
         cmp(rAux, 0);
         jle(lEnd, T_NEAR);

-        fillRestWorkMask(kTailMask, vAux, rAux);
+        fillRestWorkMask(kTailMask, rAux);
         uni_vmovups((Vmm)vAux | kTailMask, ptr[regGrid]);
         vpermd(vAux, vGridPermMask, vAux);
         Xbyak::Ymm ymmAux(vAux.getIdx());

@@ -441,7 +439,7 @@ void GridSampleKernel<x64::avx512_core>::getTailCoordinates(const Vmm& vHCoord,
     }
     L(lRest);
     {
-        fillRestWorkMask(kTailMask, vAux, rAux);
+        fillRestWorkMask(kTailMask, rAux);
         uni_vmovups(vWCoord | kTailMask, ptr[regGrid]);
         vpermd(vWCoord, vGridPermMask, vWCoord);
         vshuff64x2(vHCoord, vWCoord, vHCoord, 0B11101110); // Extract Y component

@@ -454,7 +452,7 @@ void GridSampleKernel<x64::avx512_core>::getTailCoordinates(const Vmm& vHCoord,

     L(lEnd);

-    fillRestWorkMask(kTailMask, vAux, regWorkAmount);
+    fillRestWorkMask(kTailMask, regWorkAmount);
 }

 template <>

@@ -672,14 +670,14 @@ void GridSampleKernel<isa>::denormalizeRawCoordinates(const Vmm& vWCoord, const

 template <>
 void GridSampleKernel<x64::avx512_core>::zerosPaddingW(const Vmask& kDst, const Vmm& vCoord) {
-    vcmpps(kDst, vCoord, vSrcWidthF, 0x1); // vCoord < vUpperBound
-    vcmpps(kDst | kDst, vZeros, vCoord, 0x2); // vCoord >= vZeros
+    vcmpps(kDst, vCoord, vSrcWidthF, CMP_LT_PS); // vCoord < vUpperBound
+    vcmpps(kDst | kDst, vZeros, vCoord, CMP_LE_PS); // vCoord >= vZeros
 }

 template <>
 void GridSampleKernel<x64::avx512_core>::zerosPaddingH(const Vmask& kDst, const Vmm& vCoord, const Vmask& kMaskW) {
-    vcmpps(kDst | kMaskW, vCoord, vSrcHeightF, 0x1); // vCoord < vUpperBound
-    vcmpps(kDst | kDst, vZeros, vCoord, 0x2); // vCoord >= vZeros
+    vcmpps(kDst | kMaskW, vCoord, vSrcHeightF, CMP_LT_PS); // vCoord < vUpperBound
+    vcmpps(kDst | kDst, vZeros, vCoord, CMP_LE_PS); // vCoord >= vZeros
 }

 template <>

@@ -693,15 +691,15 @@ void GridSampleKernel<x64::sse41>::zerosPaddingW(const Vmask& kDst, const Vmm& v
     auto vAux = getVmm();

     if (vSrcWidthF.isInitialized()) {
-        uni_vcmpps(vAux, vWCoord, vSrcWidthF, 0x1); // vWCoord < vSrcWidthF
+        uni_vcmpps(vAux, vWCoord, vSrcWidthF, CMP_LT_PS); // vWCoord < vSrcWidthF
     } else {
         auto rAux = getReg64();
         mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]);
-        uni_vcmpps(vAux, vWCoord, ptr[rAux], 0x1); // vWCoord < vSrcWidthF
+        uni_vcmpps(vAux, vWCoord, ptr[rAux], CMP_LT_PS); // vWCoord < vSrcWidthF
     }

     uni_vpxor(kDst, kDst, kDst);
-    uni_vcmpps(kDst, kDst, vWCoord, 0x2); // vWCoord >= vZeros
+    uni_vcmpps(kDst, kDst, vWCoord, CMP_LE_PS); // vWCoord >= vZeros
     uni_vpand(kDst, kDst, vAux); // vZeros <= vWCoord < vSrcWidthF
 }

@@ -710,17 +708,17 @@ void GridSampleKernel<x64::sse41>::zerosPaddingH(const Vmask& kDst, const Vmm& v
     auto vAux = getVmm();

     if (vSrcHeightF.isInitialized()) {
-        uni_vcmpps(vAux, vHCoord, vSrcHeightF, 0x1); // vHCoord < vSrcHeightF
+        uni_vcmpps(vAux, vHCoord, vSrcHeightF, CMP_LT_PS); // vHCoord < vSrcHeightF
     } else {
         auto rAux = getReg64();
         mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]);
-        uni_vcmpps(vAux, vHCoord, ptr[rAux], 0x1); // vHCoord < vSrcHeightF
+        uni_vcmpps(vAux, vHCoord, ptr[rAux], CMP_LT_PS); // vHCoord < vSrcHeightF
     }

     uni_vmovups(kDst, kMaskW);
     uni_vpand(kDst, kDst, vAux); // vHCoord < vSrcHeightF && vZeros <= vWCoord < vSrcWidthF
     uni_vpxor(vAux, vAux, vAux);
-    uni_vcmpps(vAux, vAux, vHCoord, 0x2); // vHCoord >= vZeros
+    uni_vcmpps(vAux, vAux, vHCoord, CMP_LE_PS); // vHCoord >= vZeros
     uni_vpand(kDst, kDst, vAux); // vZeros <= vHCoord < vSrcHeightF && vZeros <= vWCoord < vSrcWidthF
 }

@@ -744,14 +742,14 @@ void GridSampleKernel<isa>::zerosPaddingW(const Vmask& kDst, const Vmm& vCoord)
     }

     if (vSrcWidthF.isInitialized()) {
-        uni_vcmpps(vAux, vCoord, vSrcWidthF, 0x1); // vWCoord < vSrcWidthF
+        uni_vcmpps(vAux, vCoord, vSrcWidthF, CMP_LT_PS); // vWCoord < vSrcWidthF
     } else {
         auto rAux = getReg64();
         mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]);
-        uni_vcmpps(vAux, vCoord, ptr[rAux], 0x1); // vWCoord < vSrcWidthF
+        uni_vcmpps(vAux, vCoord, ptr[rAux], CMP_LT_PS); // vWCoord < vSrcWidthF
     }

-    uni_vcmpps(kDst, vZerosTmp, vCoord, 0x2); // vWCoord >= vZeros
+    uni_vcmpps(kDst, vZerosTmp, vCoord, CMP_LE_PS); // vWCoord >= vZeros
     uni_vandps(kDst, kDst, vAux); // vZeros <= vWCoord < vSrcWidthF
 }

@@ -769,15 +767,15 @@ void GridSampleKernel<isa>::zerosPaddingH(const Vmask& kDst, const Vmm& vCoord,
     }

     if (vSrcHeightF.isInitialized()) {
-        uni_vcmpps(vAux, vCoord, vSrcHeightF, 0x1); // vHCoord < vSrcHeightF
+        uni_vcmpps(vAux, vCoord, vSrcHeightF, CMP_LT_PS); // vHCoord < vSrcHeightF
     } else {
         auto rAux = getReg64();
         mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]);
-        uni_vcmpps(vAux, vCoord, ptr[rAux], 0x1); // vHCoord < vSrcHeightF
+        uni_vcmpps(vAux, vCoord, ptr[rAux], CMP_LT_PS); // vHCoord < vSrcHeightF
     }

     uni_vandps(kDst, kMaskW, vAux);
-    uni_vcmpps(vAux, vZerosTmp, vCoord, 0x2); // vHCoord >= vZeros
+    uni_vcmpps(vAux, vZerosTmp, vCoord, CMP_LE_PS); // vHCoord >= vZeros
     uni_vandps(kDst, kDst, vAux);
 }

@@ -831,7 +829,7 @@ void GridSampleKernel<isa>::borderPadding(const Vmm& vCoordDst, const Vmm& vCoor
         }
     }

-    uni_vcmpps(vAux, vCoordOrigin, vSub1F, 0x2); // vCoord <= vUpperBound
+    uni_vcmpps(vAux, vCoordOrigin, vSub1F, CMP_LE_PS); // vCoord <= vUpperBound
     uni_vandps(vCoordDst, vCoordOrigin, vAux);
     uni_vandnps(vAux, vAux, vSub1F);
     uni_vaddps(vCoordDst, vCoordDst, vAux);
@@ -857,14 +855,20 @@ void GridSampleKernel<isa>::borderPadding(const Vmm& vCoordDst, const Vmm& vCoor
 template <>
 void GridSampleKernel<x64::avx512_core>::reflectionPadding(const Vmm& vCoordDst, const Vmm& vCoordOrigin, const coord dim) {
     auto vAux = getVmm();
+    auto kAux = getMask();
     const auto& vSrcDimMul2Sub1F = dim == coord::w ? vSrcWidthMul2Sub1F : vSrcHeightMul2Sub1F;

     if (jcp.alignCorners) {
         // abs(x) % D21
         uni_vandps(vCoordDst, vCoordOrigin, vAbsMask); // abs(x)
         uni_vdivps(vAux, vCoordDst, vSrcDimMul2Sub1F);
         uni_vroundps(vAux, vAux, 0x3); // Truncation
         uni_vfnmadd231ps(vCoordDst, vAux, vSrcDimMul2Sub1F); // abs(x) % D21
+
+        // Check that the result does not exceed the divisor.
+        vcmpps(kAux, vSrcDimMul2Sub1F, vCoordDst, CMP_LE_PS);
+        uni_vmovups(vCoordDst | kAux, vZeros);
+        vrangeps(vCoordDst, vCoordDst, vZeros, 0x1);
     } else {
         const auto& vSrcDimMul2F = dim == coord::w ? vSrcWidthMul2F : vSrcHeightMul2F;
         // (x % D2 + D2) % D2

@@ -877,12 +881,16 @@ void GridSampleKernel<x64::avx512_core>::reflectionPadding(const Vmm& vCoordDst,
         uni_vdivps(vAux, vCoordDst, vSrcDimMul2F);
         uni_vroundps(vAux, vAux, 0x3); // Truncation
         uni_vfnmadd231ps(vCoordDst, vAux, vSrcDimMul2F); // (x % D2 + D2) % D2
+
+        // Check that the result does not exceed the divisor.
+        vcmpps(kAux, vSrcDimMul2F, vCoordDst, CMP_LE_PS);
+        uni_vmovups(vCoordDst | kAux, vZeros);
+        vrangeps(vCoordDst, vCoordDst, vZeros, 0x1);
     }

-    auto kAux = getMask();
     uni_vsubps(vAux, vSrcDimMul2Sub1F, vCoordDst);
-    vcmpps(kAux, dim == coord::w ? vSrcWidthF : vSrcHeightF, vCoordDst, 0x2); // vCoordDst >= vSrcDimF
-    vmovups(vCoordDst | kAux, vAux);
+    vcmpps(kAux, dim == coord::w ? vSrcWidthF : vSrcHeightF, vCoordDst, CMP_LE_PS); // vCoordDst >= vSrcDimF
+    uni_vmovups(vCoordDst | kAux, vAux);
 }

 template <x64::cpu_isa_t isa> // Works for AVX2, AVX, SSE41

@@ -925,6 +933,14 @@ void GridSampleKernel<isa>::reflectionPadding(const Vmm& vCoordDst, const Vmm& v
         uni_vdivps(vAux0, vCoordDst, vMul2Sub1);
         uni_vroundps(vAux0, vAux0, 0x3); // Truncation
         uni_vfnmadd231ps(vCoordDst, vAux0, vMul2Sub1); // abs(x) % D21
+
+        // Check that the result does not exceed the divisor.
+        uni_vcmpps(vAux0, vCoordDst, vMul2Sub1, CMP_LT_PS);
+        uni_vandps(vCoordDst, vCoordDst, vAux0);
+        uni_vxorps(vAux0, vAux0, vAux0);
+        uni_vcmpps(vAux0, vAux0, vCoordDst, CMP_LE_PS);
+        uni_vandps(vCoordDst, vCoordDst, vAux0);
+
         uni_vsubps(vAux0, vCoordDst, vMul2Sub1); // abs(x) % D21 - D21
     } else {
         // x' = (x % D2 + D2) % D2 - D21

@@ -956,6 +972,13 @@ void GridSampleKernel<isa>::reflectionPadding(const Vmm& vCoordDst, const Vmm& v
         uni_vroundps(vAux0, vAux0, 0x3); // Truncation
         uni_vfnmadd231ps(vCoordDst, vAux0, vMul2); // (x % D2 + D2) % D2
+
+        // Check that the result does not exceed the divisor.
+        uni_vcmpps(vAux0, vCoordDst, vMul2, CMP_LT_PS);
+        uni_vandps(vCoordDst, vCoordDst, vAux0);
+        uni_vxorps(vAux0, vAux0, vAux0);
+        uni_vcmpps(vAux0, vAux0, vCoordDst, CMP_LE_PS);
+        uni_vandps(vCoordDst, vCoordDst, vAux0);

         if (dim == coord::w) {
             if (vSrcWidthMul2Sub1F.isInitialized()) {
                 uni_vsubps(vAux0, vCoordDst, vSrcWidthMul2Sub1F);
|
|||||||
|
|
||||||
if (dim == coord::w) {
|
if (dim == coord::w) {
|
||||||
if (vSrcWidthF.isInitialized()) {
|
if (vSrcWidthF.isInitialized()) {
|
||||||
uni_vcmpps(vAux1, vCoordDst, vSrcWidthF, 0x1); // vCoordDst < vUpperBound
|
uni_vcmpps(vAux1, vCoordDst, vSrcWidthF, CMP_LT_PS); // vCoordDst < vUpperBound
|
||||||
} else {
|
} else {
|
||||||
mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]);
|
mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]);
|
||||||
uni_vcmpps(vAux1, vCoordDst, ptr[rAux], 0x1); // vCoordDst < vUpperBound
|
uni_vcmpps(vAux1, vCoordDst, ptr[rAux], CMP_LT_PS); // vCoordDst < vUpperBound
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (vSrcHeightF.isInitialized()) {
|
if (vSrcHeightF.isInitialized()) {
|
||||||
uni_vcmpps(vAux1, vCoordDst, vSrcHeightF, 0x1); // vCoordDst < vUpperBound
|
uni_vcmpps(vAux1, vCoordDst, vSrcHeightF, CMP_LT_PS); // vCoordDst < vUpperBound
|
||||||
} else {
|
} else {
|
||||||
mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]);
|
mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]);
|
||||||
uni_vcmpps(vAux1, vCoordDst, ptr[rAux], 0x1); // vCoordDst < vUpperBound
|
uni_vcmpps(vAux1, vCoordDst, ptr[rAux], CMP_LT_PS); // vCoordDst < vUpperBound
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1246,23 +1269,21 @@ void GridSampleKernel<isa>::nearestInterpolation(const Vmm& vWCoord, const Vmm&
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
void GridSampleKernel<x64::avx512_core>::bilinearInterpolation(const Vmm& vWCoord, const Vmm& vHCoord, bool tail) {
|
void GridSampleKernel<x64::avx512_core>::bilinearInterpolation(const Vmm& vWCoord, const Vmm& vHCoord, bool tail) {
|
||||||
auto vDX = getVmm();
|
const auto& vDX = vWCoord;
|
||||||
auto vDY = getVmm();
|
const auto& vDY = vHCoord;
|
||||||
const auto& shift00 = vWCoord;
|
auto shift00 = getVmm();
|
||||||
const auto& shift01 = vHCoord;
|
auto shift01 = getVmm();
|
||||||
auto shift10 = getVmm();
|
auto shift10 = getVmm();
|
||||||
auto shift11 = getVmm();
|
auto shift11 = getVmm();
|
||||||
auto vAux = getVmm();
|
auto vAux = getVmm();
|
||||||
RegistersPool::Reg<Vmask> kMask00, kMask01, kMask10, kMask11;
|
RegistersPool::Reg<Vmask> kMask00, kMask01, kMask10, kMask11;
|
||||||
|
|
||||||
uni_vmovups(vDX, vWCoord);
|
uni_vroundps(shift00, vWCoord, 0x1); // Round floor
|
||||||
uni_vmovups(vDY, vHCoord);
|
uni_vroundps(shift01, vHCoord, 0x1); // Round floor
|
||||||
uni_vroundps(vWCoord, vWCoord, 0x1); // Round floor
|
uni_vsubps(vDX, vWCoord, shift00);
|
||||||
uni_vroundps(vHCoord, vHCoord, 0x1); // Round floor
|
uni_vsubps(vDY, vHCoord, shift01);
|
||||||
uni_vsubps(vDX, vDX, vWCoord);
|
uni_vaddps(shift10, shift00, vOnesF);
|
||||||
uni_vsubps(vDY, vDY, vHCoord);
|
uni_vaddps(shift11, shift01, vOnesF);
|
||||||
uni_vaddps(shift10, vWCoord, vOnesF);
|
|
||||||
uni_vaddps(shift11, vHCoord, vOnesF);
|
|
||||||
|
|
||||||
bool useMask = false, zeroFill = false;
|
bool useMask = false, zeroFill = false;
|
||||||
if (jcp.paddingMode == GridSamplePaddingMode::ZEROS) {
|
if (jcp.paddingMode == GridSamplePaddingMode::ZEROS) {
|
||||||
@ -1272,31 +1293,31 @@ void GridSampleKernel<x64::avx512_core>::bilinearInterpolation(const Vmm& vWCoor
|
|||||||
kMask10 = getMask();
|
kMask10 = getMask();
|
||||||
kMask11 = getMask();
|
kMask11 = getMask();
|
||||||
|
|
||||||
zerosPadding(kMask00, vHCoord, vWCoord); // (y; x)
|
zerosPadding(kMask00, shift01, shift00); // (y; x)
|
||||||
zerosPadding(kMask01, vHCoord, shift10); // (y; x + 1)
|
zerosPadding(kMask01, shift01, shift10); // (y; x + 1)
|
||||||
zerosPadding(kMask11, shift11, shift10); // (y + 1; x + 1)
|
zerosPadding(kMask11, shift11, shift10); // (y + 1; x + 1)
|
||||||
zerosPadding(kMask10, shift11, vWCoord); // (y + 1; x)
|
zerosPadding(kMask10, shift11, shift00); // (y + 1; x)
|
||||||
|
|
||||||
hwShiftPs2dq(shift00, vHCoord, vWCoord, vSrcWidthF);
|
hwShiftPs2dq(shift00, shift01, shift00, vSrcWidthF);
|
||||||
uni_vpaddd(shift01, shift00, vDataTypeSizeB);
|
uni_vpaddd(shift01, shift00, vDataTypeSizeB);
|
||||||
uni_vpaddd(shift10, shift00, vSrcWidthB); // shift11??
|
uni_vpaddd(shift10, shift00, vSrcWidthB);
|
||||||
uni_vpaddd(shift11, shift10, vDataTypeSizeB); // sub??
|
uni_vpaddd(shift11, shift10, vDataTypeSizeB);
|
||||||
} else if (jcp.paddingMode == GridSamplePaddingMode::BORDER) {
|
} else if (jcp.paddingMode == GridSamplePaddingMode::BORDER) {
|
||||||
borderPadding(vWCoord, vWCoord, coord::w);
|
borderPadding(shift00, shift00, coord::w);
|
||||||
borderPadding(vHCoord, vHCoord, coord::h);
|
borderPadding(shift01, shift01, coord::h);
|
||||||
borderPadding(shift10, shift10, coord::w);
|
borderPadding(shift10, shift10, coord::w);
|
||||||
borderPadding(shift11, shift11, coord::h);
|
borderPadding(shift11, shift11, coord::h);
|
||||||
} else if (jcp.paddingMode == GridSamplePaddingMode::REFLECTION) {
|
} else if (jcp.paddingMode == GridSamplePaddingMode::REFLECTION) {
|
||||||
reflectionPadding(vWCoord, vWCoord, coord::w);
|
reflectionPadding(shift00, shift00, coord::w);
|
||||||
reflectionPadding(vHCoord, vHCoord, coord::h);
|
reflectionPadding(shift01, shift01, coord::h);
|
||||||
reflectionPadding(shift10, shift10, coord::w);
|
reflectionPadding(shift10, shift10, coord::w);
|
||||||
reflectionPadding(shift11, shift11, coord::h);
|
reflectionPadding(shift11, shift11, coord::h);
|
||||||
}
|
}
|
||||||
if (jcp.paddingMode == GridSamplePaddingMode::BORDER || jcp.paddingMode == GridSamplePaddingMode::REFLECTION) {
|
if (jcp.paddingMode == GridSamplePaddingMode::BORDER || jcp.paddingMode == GridSamplePaddingMode::REFLECTION) {
|
||||||
// W * y + x
|
// W * y + x
|
||||||
hwShiftPs2dq(vAux, shift11, vWCoord, vSrcWidthF);
|
hwShiftPs2dq(vAux, shift11, shift00, vSrcWidthF);
|
||||||
hwShiftPs2dq(vWCoord, vHCoord, vWCoord, vSrcWidthF);
|
hwShiftPs2dq(shift00, shift01, shift00, vSrcWidthF);
|
||||||
hwShiftPs2dq(vHCoord, vHCoord, shift10, vSrcWidthF);
|
hwShiftPs2dq(shift01, shift01, shift10, vSrcWidthF);
|
||||||
hwShiftPs2dq(shift11, shift11, shift10, vSrcWidthF);
|
hwShiftPs2dq(shift11, shift11, shift10, vSrcWidthF);
|
||||||
uni_vmovups(shift10, vAux);
|
uni_vmovups(shift10, vAux);
|
||||||
}
|
}
|
||||||
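In the AVX-512 bilinear path above, vWCoord/vHCoord are no longer aliased as shift00/shift01: the four shift vectors now own their registers, and vDX/vDY simply alias the incoming coordinate registers, so the fractional weights and the integer shifts can no longer clobber one another between the padding calls and the gather step. For reference, a scalar sketch of the bilinear sampling the kernel performs (illustrative, not the plugin's code):

    #include <cmath>

    // Bilinear sample of one point from a W x H single-channel plane with zero padding.
    float bilinearSample(const float* src, int W, int H, float x, float y) {
        const float x0 = std::floor(x), y0 = std::floor(y);
        const float dx = x - x0, dy = y - y0;    // fractional weights (vDX, vDY)
        const int ix0 = static_cast<int>(x0), iy0 = static_cast<int>(y0);
        const int ix1 = ix0 + 1, iy1 = iy0 + 1;  // shifted taps (x + 1, y + 1)

        // Out-of-range taps contribute zero, as the kMask00..kMask11 masks ensure above.
        const auto at = [&](int ix, int iy) -> float {
            return (ix < 0 || iy < 0 || ix >= W || iy >= H) ? 0.0f : src[iy * W + ix];
        };

        const float v00 = at(ix0, iy0), v01 = at(ix1, iy0);
        const float v10 = at(ix0, iy1), v11 = at(ix1, iy1);
        return (v00 * (1.0f - dx) + v01 * dx) * (1.0f - dy) +
               (v10 * (1.0f - dx) + v11 * dx) * dy;
    }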
@@ -1658,8 +1679,8 @@ void GridSampleKernel<x64::avx512_core>::bicubicInterpolation(const Vmm& vWCoord
     // (y - 1 + h; x - 1)
     if (jcp.paddingMode == GridSamplePaddingMode::ZEROS) {
         Xbyak::Opmask maskH = kMaskH;
-        vcmpps(kMaskH, vHCoord, vSrcHeightF, 0x1);
-        vcmpps(maskH | maskH, vZeros, vHCoord, 0x2);
+        vcmpps(kMaskH, vHCoord, vSrcHeightF, CMP_LT_PS);
+        vcmpps(maskH | maskH, vZeros, vHCoord, CMP_LE_PS);
         kandw(kAuxMask, kMaskH, wMasks[0]);
         uni_vmulps(vSrcShift0, vHCoord, vSrcWidthF);
         uni_vmovups(vWCoord, vWLeft);
@@ -286,26 +286,80 @@ void JitKernelBase::uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &
 }

 void JitKernelBase::fillRestWorkMask(const Xbyak::Opmask& dstMask,
-                                     const Xbyak::Zmm& zAux,
                                      const Xbyak::Reg64& rWorkRest) {
-    auto rAux0 = getReg64();
-    auto rAux1 = getReg64();
-    Xbyak::Label lKmov;
-    Xbyak::Reg32 rOnes(rAux1.getIdx());
-    const uint64_t typeSize = 4;
-    const uint64_t elPerVec = x64::cpu_isa_traits<x64::avx512_core>::vlen / typeSize;
+    auto rOnes = getReg64();

-    mov(rOnes, 0x0000FFFF);
-    cmp(rWorkRest, elPerVec);
-    jge(lKmov);
-    {
-        Xbyak::Reg32 rShift(rAux0.getIdx());
-        mov(rShift, elPerVec);
-        sub(rShift, rWorkRest);
-        shrx(rOnes, rOnes, rShift);
-    }
-    L(lKmov);
-    kmovw(dstMask, rOnes);
+    mov(rOnes, 0xFFFFFFFFFFFFFFFF);
+    shlx(rOnes, rOnes, rWorkRest);
+    not_(rOnes);
+    kmovq(dstMask, rOnes);
+}
+
+void JitKernelBase::fillRestWorkMask(const Xbyak::Xmm& xmmDstMask,
+                                     const Xbyak::Reg64& rWorkRest,
+                                     const uint64_t typeSize) {
+    if (!one_of(typeSize, 1, 2, 4, 8)) {
+        IE_THROW() << "Could not fill data with type size " << typeSize;
+    }
+    Xbyak::Label lEnd;
+    auto r32Ones = getReg32();
+    Xbyak::Reg64 r64Ones(r32Ones.getIdx());
+    auto elPerVec = x64::cpu_isa_traits<x64::sse41>::vlen / typeSize;
+
+    mov(r64Ones, 0xFFFFFFFFFFFFFFFF);
+    for (uint8_t i = 0; i < elPerVec; i++) {
+        cmp(rWorkRest, i);
+        jle(lEnd, T_NEAR);
+
+        if (typeSize == 1) {
+            pinsrb(xmmDstMask, r32Ones, i);
+        } else if (typeSize == 2) {
+            pinsrw(xmmDstMask, r32Ones, i);
+        } else if (typeSize == 4) {
+            pinsrd(xmmDstMask, r32Ones, i);
+        } else if (typeSize == 8) {
+            pinsrq(xmmDstMask, r64Ones, i);
+        }
+    }
+    L(lEnd);
+}
+
+void JitKernelBase::fillRestWorkMask(const Xbyak::Ymm& ymmDstMask,
+                                     const Xbyak::Reg64& rWorkRest,
+                                     const uint64_t typeSize) {
+    if (!one_of(typeSize, 1, 2, 4, 8)) {
+        IE_THROW() << "Could not fill data with type size " << typeSize;
+    }
+    Xbyak::Label lEnd;
+    auto elPerVec = x64::cpu_isa_traits<x64::sse41>::vlen / typeSize;
+    auto r32Ones = getReg32();
+    Xbyak::Reg64 r64Ones(r32Ones.getIdx());
+    Xbyak::Xmm xmmDstMask(ymmDstMask.getIdx());
+
+    mov(r64Ones, 0xFFFFFFFFFFFFFFFF);
+    uni_vpxor(ymmDstMask, ymmDstMask, ymmDstMask);
+    for (uint8_t i = 0; i < 2; i++) {
+        Xbyak::Label lPerm;
+        for (uint8_t j = 0; j < elPerVec; j++) {
+            cmp(rWorkRest, i * elPerVec + j);
+            jle(i == 0 ? lEnd : lPerm, T_NEAR);
+
+            if (typeSize == 1) {
+                pinsrb(xmmDstMask, r32Ones, j);
+            } else if (typeSize == 2) {
+                pinsrw(xmmDstMask, r32Ones, j);
+            } else if (typeSize == 4) {
+                pinsrd(xmmDstMask, r32Ones, j);
+            } else if (typeSize == 8) {
+                pinsrq(xmmDstMask, r64Ones, j);
+            }
+        }
+        cmp(rWorkRest, elPerVec);
+        je(lEnd, T_NEAR);
+        L(lPerm);
+        vperm2f128(ymmDstMask, ymmDstMask, ymmDstMask, 0x1);
+    }
+    L(lEnd);
 }

 void JitKernelBase::load(const Xbyak::Xmm& vDst,
|
|||||||
namespace intel_cpu {
|
namespace intel_cpu {
|
||||||
|
|
||||||
#define getReg64() RegistersPool::Reg<Xbyak::Reg64>(registersPool)
|
#define getReg64() RegistersPool::Reg<Xbyak::Reg64>(registersPool)
|
||||||
|
#define getReg32() RegistersPool::Reg<Xbyak::Reg32>(registersPool)
|
||||||
#define getVmm() RegistersPool::Reg<Vmm>(registersPool)
|
#define getVmm() RegistersPool::Reg<Vmm>(registersPool)
|
||||||
#define getMask() RegistersPool::Reg<Vmask>(registersPool)
|
#define getMask() RegistersPool::Reg<Vmask>(registersPool)
|
||||||
|
|
||||||
@ -84,9 +85,16 @@ public:
|
|||||||
const bool zeroFill = false);
|
const bool zeroFill = false);
|
||||||
|
|
||||||
void fillRestWorkMask(const Xbyak::Opmask& kDstMask,
|
void fillRestWorkMask(const Xbyak::Opmask& kDstMask,
|
||||||
const Xbyak::Zmm& zAux,
|
|
||||||
const Xbyak::Reg64& rWorkRest);
|
const Xbyak::Reg64& rWorkRest);
|
||||||
|
|
||||||
|
void fillRestWorkMask(const Xbyak::Xmm& ymmDstMask,
|
||||||
|
const Xbyak::Reg64& rWorkRest,
|
||||||
|
const uint64_t typeSize = 4);
|
||||||
|
|
||||||
|
void fillRestWorkMask(const Xbyak::Ymm& ymmDstMask,
|
||||||
|
const Xbyak::Reg64& rWorkRest,
|
||||||
|
const uint64_t typeSize = 4);
|
||||||
|
|
||||||
void load(const Xbyak::Xmm& vDst,
|
void load(const Xbyak::Xmm& vDst,
|
||||||
const Xbyak::Address& srcAddr,
|
const Xbyak::Address& srcAddr,
|
||||||
const Xbyak::Reg64& rLoadNum,
|
const Xbyak::Reg64& rLoadNum,
|
||||||
@ -133,6 +141,18 @@ protected:
|
|||||||
}
|
}
|
||||||
|
|
||||||
RegistersPool::Ptr registersPool;
|
RegistersPool::Ptr registersPool;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
// Comparison predicate operand (immediate byte) for single-precision floating-point values.
|
||||||
|
CMP_EQ_PS = 0, // Equal (ordered, non-signaling)
|
||||||
|
CMP_LT_PS, // Less-than (ordered, signaling)
|
||||||
|
CMP_LE_PS, // Less-than-or-equal (ordered, signaling)
|
||||||
|
CMP_UNORD_PS, // Unordered (non-signaling)
|
||||||
|
CMP_NEQ_PS, // Not-equal (unordered, non-signaling)
|
||||||
|
CMP_NLT_PS, // Not-less-than (unordered, signaling)
|
||||||
|
CMP_NLE_PS, // Not-less-than-or-equal (unordered, signaling)
|
||||||
|
CMP_ORD_PS // Ordered (non-signaling)
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace intel_cpu
|
} // namespace intel_cpu
|
||||||
|
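The new CMP_*_PS constants name the cmpps/vcmpps comparison-predicate immediates (0 = EQ, 1 = LT, 2 = LE, 3 = UNORD, 4 = NEQ, 5 = NLT, 6 = NLE, 7 = ORD) that the kernel previously passed as bare 0x1/0x2 literals. The same encoding is exposed by the compiler intrinsics, which makes the mapping easy to sanity-check; a small stand-alone sketch (assumes an AVX-capable build, e.g. compiled with -mavx):

    #include <immintrin.h>
    #include <cstdio>

    // Predicate 0x2 (less-than-or-equal) with swapped operands expresses "coord >= 0",
    // which is how the kernel uses CMP_LE_PS for its lower-bound checks.
    int main() {
        const __m256 coord = _mm256_set_ps(7.f, 6.f, 5.f, 4.f, 3.f, -1.f, 1.f, 0.f);
        const __m256 zeros = _mm256_setzero_ps();
        const __m256 mask = _mm256_cmp_ps(zeros, coord, _CMP_LE_OS);  // _CMP_LE_OS == 0x2
        const int bits = _mm256_movemask_ps(mask);                    // one bit per lane: coord >= 0
        std::printf("mask bits: 0x%02x\n", static_cast<unsigned>(bits));  // the -1.f lane stays 0
        return 0;
    }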
@@ -184,9 +184,6 @@ std::vector<std::string> disabledTestPatterns() {
     // The kernel does not have such garbage. The diff 0.000000745 is taken into account in calculations and affects further type conversion.
     // Reorder->GridSample->Reorder also does not work here. Potential fix is to use nearest conversion instead of truncation.
     R"(.*GridSampleLayerTestCPU.*(BILINEAR|BICUBIC).*(i32|i8).*)",
-    // 94989. BF16 Reference produces different results.
-    // GridSample regression on bf16 data.
-    R"(.*GridSampleLayerTestCPU.*(BILINEAR|BICUBIC).*bf16.*)",
     // // Issue: 95915
     R"(smoke_dynamic/AUGRUCellCPUTest.CompareWithRefs/IS=\(\[\?\.1\]_\[\?\.1\]_\[\?\.1\]_\)_TS=\{\(1\.1\)_\(1\.1\)_\(1\.1\)\}_\{\(3\.1\)_\(3\.1\)_\(3\.1\)\}_\{\(5\.1\)_\(5\.1\)_\(5\.1\)\}_decompose=0_activations=\(sigmoid\.tanh\)_clip=0_linear=0_netPrec=f32__inFmts=nc\.nc_outFmts=nc_primitive=ref_any_PluginConf_ENFORCE_BF16=YES)", // NOLINT
     R"(smoke_dynamic/GRUCellCPUTest.CompareWithRefs/IS=\(\[\?.1\]_\[\?\.1\]_\)_TS=\{\(1\.1\)_\(1\.1\)\}_\{\(3\.1\)_\(3\.1\)\}_\{\(5\.1\)_\(5\.1\)\}_decompose=0_activations=\(sigmoid\.tanh\)_clip=0_linear=0_netPrec=f32__inFmts=nc\.nc_outFmts=nc_primitive=ref_any_PluginConf_ENFORCE_BF16=YES)", // NOLINT
@@ -94,6 +94,9 @@ protected:
             auto execType = dataPrecision == ov::element::i32 ? ov::element::i32 : ov::element::f32;
             selectedType = makeSelectedTypeStr(selectedType, execType);
         }
+        if (gridPrecision == ov::element::bf16) {
+            rel_threshold = 0.01f;
+        }

         auto params = ngraph::builder::makeDynamicParams({dataPrecision, gridPrecision}, inputDynamicShapes);
         params[0]->set_friendly_name("data");

@@ -272,12 +275,35 @@ INSTANTIATE_TEST_SUITE_P(smoke_static, GridSampleLayerTestCPU,
                 ::testing::ValuesIn(interpolateMode),
                 ::testing::ValuesIn(paddingMode),
                 ::testing::ValuesIn(alignCorners),
-                ::testing::ValuesIn({ElementType::f32, ElementType::bf16, ElementType::i32, ElementType::i8}),
+                ::testing::ValuesIn({ElementType::f32, ElementType::i32}),
+                ::testing::ValuesIn({ElementType::f32}),
+                ::testing::ValuesIn(getCPUInfo()),
+                ::testing::Values(additionalConfig[0])),
+        GridSampleLayerTestCPU::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(nightly_static_1, GridSampleLayerTestCPU,
+        ::testing::Combine(
+                ::testing::ValuesIn(getStaticShapes()),
+                ::testing::ValuesIn(interpolateMode),
+                ::testing::ValuesIn(paddingMode),
+                ::testing::ValuesIn(alignCorners),
+                ::testing::ValuesIn({ElementType::bf16, ElementType::i8}),
                 ::testing::ValuesIn({ElementType::f32, ElementType::bf16}),
                 ::testing::ValuesIn(getCPUInfo()),
                 ::testing::Values(additionalConfig[0])),
         GridSampleLayerTestCPU::getTestCaseName);

+INSTANTIATE_TEST_SUITE_P(nightly_static_2, GridSampleLayerTestCPU,
+        ::testing::Combine(
+                ::testing::ValuesIn(getStaticShapes()),
+                ::testing::ValuesIn(interpolateMode),
+                ::testing::ValuesIn(paddingMode),
+                ::testing::ValuesIn(alignCorners),
+                ::testing::ValuesIn({ElementType::f32}),
+                ::testing::ValuesIn({ElementType::bf16}),
+                ::testing::ValuesIn(getCPUInfo()),
+                ::testing::Values(additionalConfig[0])),
+        GridSampleLayerTestCPU::getTestCaseName);
+
 const std::vector<std::vector<InputShape>> dynamicInSapes = {
     { { { ov::Dimension(1, 15), -1, -1, -1 },  // Dynamic shape 0