[CPU] DefConv fixed for small channels (#9718)

This commit is contained in:
Yury Gaydaychuk 2022-01-25 17:26:00 +03:00 committed by GitHub
parent 0282f11165
commit 365bd7c46e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 6 deletions

View File

@ -674,9 +674,8 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
auto &weiDims = getInputShapeAtPort(WEI_ID).getDims();
if (weiDims[1] == Shape::UNDEFINED_DIM || weiDims[0] == Shape::UNDEFINED_DIM ||
defConvAttr.group != 1 || // temporary workaround until jit impl. will correctly handle multigroup cases
(weiDims[1] % simd_w != 0) // in_channels_per_gr !% simd_w
|| ((weiDims[0] / defConvAttr.group) % simd_w != 0)) { // out_channels_per_gr !% simd_w
(defConvAttr.group != 1 && ((weiDims[1] % simd_w != 0) // in_channels_per_gr !% simd_w
|| ((weiDims[0] / defConvAttr.group) % simd_w != 0)))) { // out_channels_per_gr !% simd_w
enforceRef = true;
} else {
enforceRef = false;
@ -822,8 +821,8 @@ void MKLDNNDeformableConvolutionNode::DefConvExecutor::prepareSamplingWeights(
lh = (h_high < cur_h_end ? lh : 0);
lw = (w_high < cur_w_end ? lw : 0);
const int h_off_low = h_ind_low * srcStrides[2] / srcStrides[3];
const int h_off_high = h_ind_high * srcStrides[2] / srcStrides[3];
const int h_off_low = h_ind_low * (srcStrides[2] / srcStrides[3]);
const int h_off_high = h_ind_high * (srcStrides[2] / srcStrides[3]);
const int w_off_low = w_ind_low;
const int w_off_high = w_ind_high;
pSampledCoordsVector[sampledCoordIndex] = h_off_high + w_off_high;

View File

@ -336,6 +336,12 @@ const std::vector<std::vector<size_t>> channelParamsSingleGr = {
{16, 32}, // in. ch. per gr.
{16, 32} // out. ch. per gr.
};
const std::vector<std::vector<size_t>> channelParamsSingleGr2 = {
{1}, // gr. 2,4
{1}, // def. gr. 1,2
{3}, // in. ch. per gr.
{3} // out. ch. per gr.
};
const std::vector<std::vector<size_t>> channelParamsMulGr = {
{2, 4}, // gr. 2,4
{1, 2}, // def. gr. 1,2
@ -503,7 +509,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_DefConvLayoutTest9, DefConvLayerCPUTest, params9_
const auto params1 = ::testing::Combine(
::testing::Combine(
addSpParams,
::testing::ValuesIn(static_shapes_to_test_representation(buildStaticParams(spatParams1, channelParamsSingleGr))),
::testing::ValuesIn(static_shapes_to_test_representation(buildStaticParams(spatParams1, channelParamsSingleGr2))),
defConvSpecificParams,
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)),