[CPU] DefConv fixed for small channels (#9718)
This commit is contained in:
parent
0282f11165
commit
365bd7c46e
@ -674,9 +674,8 @@ void MKLDNNDeformableConvolutionNode::initSupportedPrimitiveDescriptors() {
|
||||
|
||||
auto &weiDims = getInputShapeAtPort(WEI_ID).getDims();
|
||||
if (weiDims[1] == Shape::UNDEFINED_DIM || weiDims[0] == Shape::UNDEFINED_DIM ||
|
||||
defConvAttr.group != 1 || // temporary workaround until jit impl. will correctly handle multigroup cases
|
||||
(weiDims[1] % simd_w != 0) // in_channels_per_gr !% simd_w
|
||||
|| ((weiDims[0] / defConvAttr.group) % simd_w != 0)) { // out_channels_per_gr !% simd_w
|
||||
(defConvAttr.group != 1 && ((weiDims[1] % simd_w != 0) // in_channels_per_gr !% simd_w
|
||||
|| ((weiDims[0] / defConvAttr.group) % simd_w != 0)))) { // out_channels_per_gr !% simd_w
|
||||
enforceRef = true;
|
||||
} else {
|
||||
enforceRef = false;
|
||||
@ -822,8 +821,8 @@ void MKLDNNDeformableConvolutionNode::DefConvExecutor::prepareSamplingWeights(
|
||||
lh = (h_high < cur_h_end ? lh : 0);
|
||||
lw = (w_high < cur_w_end ? lw : 0);
|
||||
|
||||
const int h_off_low = h_ind_low * srcStrides[2] / srcStrides[3];
|
||||
const int h_off_high = h_ind_high * srcStrides[2] / srcStrides[3];
|
||||
const int h_off_low = h_ind_low * (srcStrides[2] / srcStrides[3]);
|
||||
const int h_off_high = h_ind_high * (srcStrides[2] / srcStrides[3]);
|
||||
const int w_off_low = w_ind_low;
|
||||
const int w_off_high = w_ind_high;
|
||||
pSampledCoordsVector[sampledCoordIndex] = h_off_high + w_off_high;
|
||||
|
@ -336,6 +336,12 @@ const std::vector<std::vector<size_t>> channelParamsSingleGr = {
|
||||
{16, 32}, // in. ch. per gr.
|
||||
{16, 32} // out. ch. per gr.
|
||||
};
|
||||
const std::vector<std::vector<size_t>> channelParamsSingleGr2 = {
|
||||
{1}, // gr. 2,4
|
||||
{1}, // def. gr. 1,2
|
||||
{3}, // in. ch. per gr.
|
||||
{3} // out. ch. per gr.
|
||||
};
|
||||
const std::vector<std::vector<size_t>> channelParamsMulGr = {
|
||||
{2, 4}, // gr. 2,4
|
||||
{1, 2}, // def. gr. 1,2
|
||||
@ -503,7 +509,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_DefConvLayoutTest9, DefConvLayerCPUTest, params9_
|
||||
const auto params1 = ::testing::Combine(
|
||||
::testing::Combine(
|
||||
addSpParams,
|
||||
::testing::ValuesIn(static_shapes_to_test_representation(buildStaticParams(spatParams1, channelParamsSingleGr))),
|
||||
::testing::ValuesIn(static_shapes_to_test_representation(buildStaticParams(spatParams1, channelParamsSingleGr2))),
|
||||
defConvSpecificParams,
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
|
Loading…
Reference in New Issue
Block a user