[MKLDNN_PLUGIN] Convolution node: skip initializing of primitive descriptors for planar layout if there is already jit primitive (#672)

This commit is contained in:
Anton Voronov 2020-06-04 08:06:14 +03:00 committed by GitHub
parent 158d32139f
commit e53b1b7fbc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 132 additions and 68 deletions

View File

@ -272,10 +272,11 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
}
}
MKLDNNMemoryDesc in_candidate, out_candidate;
if (canBeExecutedInInt8()) {
MKLDNNMemoryDesc in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
getParentEdgeAt(0)->getDims().ndims() == 5 ? memory::ndhwc : memory::nhwc);
MKLDNNMemoryDesc out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
getParentEdgeAt(0)->getDims().ndims() == 5 ? memory::ndhwc : memory::nhwc);
createDescriptor({in_candidate}, {out_candidate});
} else {
@ -308,13 +309,9 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
Layout layout = convLayer->input()->getLayout();
if (layout == NCHW || layout == NHWC) {
MKLDNNMemoryDesc in_candidate(getParentEdgeAt(0)->getDims(), inputDataType,
layout == NCHW ? memory::nchw : memory::nhwc);
MKLDNNMemoryDesc out_candidate(getChildEdgeAt(0)->getDims(), outputDataType,
layout == NCHW ? memory::nchw : memory::nhwc);
createDescriptor({in_candidate}, {out_candidate});
if (IC == 3 || IC == 1) {
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
layout == NCHW ? memory::nchw : memory::nhwc);
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw16c);
createDescriptor({in_candidate}, {out_candidate});
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw8c);
@ -327,13 +324,15 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw8c);
createDescriptor({in_candidate}, {out_candidate});
}
} else if (layout == NCDHW || layout == NDHWC) {
MKLDNNMemoryDesc in_candidate(getParentEdgeAt(0)->getDims(), inputDataType,
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
MKLDNNMemoryDesc out_candidate(getChildEdgeAt(0)->getDims(), outputDataType,
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
createDescriptor({in_candidate}, {out_candidate});
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
layout == NCHW ? memory::nchw : memory::nhwc);
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
layout == NCHW ? memory::nchw : memory::nhwc);
createDescriptor({in_candidate}, {out_candidate});
} else if (layout == NCDHW || layout == NDHWC) {
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
if (IC == 3 || IC == 1) {
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nCdhw16c);
createDescriptor({in_candidate}, {out_candidate});
@ -347,6 +346,12 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nCdhw8c);
createDescriptor({in_candidate}, {out_candidate});
}
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
createDescriptor({in_candidate}, {out_candidate});
}
}
}
@ -556,7 +561,11 @@ void MKLDNNConvolutionNode::initSupportedPrimitiveDescriptors() {
addZeroPoints(attr);
setPostOps(attr);
bool containJitImpl = false;
for (auto& desc : descs) {
if (containJitImpl && isPossibleToSkipInitConfig(desc))
continue;
auto itpd = desc.createPrimitiveDescriptorIterator(getEngine(), attr);
while (itpd.is_not_end()) {
InferenceEngine::LayerConfig config;
@ -610,6 +619,8 @@ void MKLDNNConvolutionNode::initSupportedPrimitiveDescriptors() {
outFormats.emplace_back(static_cast<memory::format>(itpd.dst_primitive_desc().desc().data.format));
}
impl_desc_type impl_type = parse_impl_name(itpd.get_impl_info_str());
if (impl_type & jit)
containJitImpl = true;
supportedPrimitiveDescriptors.emplace_back(config, impl_type, outFormats);
itpd++;
@ -790,8 +801,13 @@ void MKLDNNConvolutionNode::initDescriptor(const InferenceEngine::LayerConfig& c
InferenceEngine::LayerConfig rightConfig = selectedPD->getConfig();
size_t selected_count = 0;
bool containJitImpl = false;
for (size_t i = 0; i < descs.size(); i++) {
const auto& desc = descs[i];
auto& desc = descs[i];
if (containJitImpl && isPossibleToSkipInitConfig(desc))
continue;
auto itpd = desc.createPrimitiveDescriptorIterator(getEngine(), attr);
while (itpd.is_not_end()) {
InferenceEngine::LayerConfig cfg;
@ -836,6 +852,8 @@ void MKLDNNConvolutionNode::initDescriptor(const InferenceEngine::LayerConfig& c
cfg.outConfs.push_back(dataConfig);
}
impl_desc_type impl_type = parse_impl_name(itpd.get_impl_info_str());
if (impl_type & jit)
containJitImpl = true;
if (selected_count == selectedPrimitiveDescriptorIndex) {
if (impl_type != selectedPD->getImplementationType()) {
@ -888,6 +906,41 @@ void MKLDNNConvolutionNode::filterSupportedDescriptors() {
}
}
bool MKLDNNConvolutionNode::isPossibleToSkipInitConfig(MKLDNNDescriptor &desc) {
// WA: In some cases, we can predict in advance the type of primitive that will be called in the future.
// In particular, isPossibleToSkipInitConfig() checks whether we can skip the creation of primitives with
// gemm implementation, which significantly increase the network load time.
if (!inputMemoryFormatsFilter.empty() || !outputMemoryFormatsFilter.empty())
return false;
if (getCnnLayer()->params.find("PrimitivesPriority") != getCnnLayer()->params.end())
return false;
// Here we check that we will not delete jit_planar_conv primitive by mistake.
// It requires:
// 1) strides equal 1;
// 2) not grouped;
// 3) first dim of weights is not 1.
bool isPossibleJitPlanar = true;
if (isGrouped || weightDims[0] != 1)
isPossibleJitPlanar = false;
for (int i = 0; i < stride.size(); i++)
if (stride[i] != 1)
isPossibleJitPlanar = false;
std::shared_ptr<mkldnn::convolution_forward::desc> convDesc(desc);
auto srcMemFmt = convDesc->data.src_desc.format;
auto dstMemFmt = convDesc->data.dst_desc.format;
auto srcDataType = convDesc->data.src_desc.data_type;
auto dstDataType = convDesc->data.dst_desc.data_type;
bool isPlanarFloatConv = (srcMemFmt == memory::nchw || srcMemFmt == memory::ncdhw)
&& (dstMemFmt == memory::nchw || dstMemFmt == memory::ncdhw)
&& srcDataType == memory::f32
&& dstDataType == memory::f32;
return !isPossibleJitPlanar && isPlanarFloatConv;
}
MKLDNNMemoryDesc MKLDNNConvolutionNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) {
InferenceEngine::TensorDesc desc = idx > 0 ? MKLDNNMemoryDesc(primitive_desc_it.weights_primitive_desc(idx - 1).desc())
: MKLDNNMemoryDesc(primitive_desc_it.src_primitive_desc(idx).desc());

View File

@ -27,6 +27,7 @@ public:
void initSupportedPrimitiveDescriptors() override;
void filterSupportedPrimitiveDescriptors() override;
void filterSupportedDescriptors();
bool isPossibleToSkipInitConfig(MKLDNNDescriptor &desc);
bool created() const override;
bool canBeInPlace() const override {
return false;

View File

@ -31,6 +31,7 @@ struct conv_test_params {
size_t num_prim_desc;
int selectedType;
bool defaultPrimitivesPriority;
vector<MKLDNNPlugin::impl_desc_type> preferTypes;
vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
@ -149,7 +150,7 @@ class MKLDNNGraphConvolutionTests: public TestsCommon,
<convolution _AP_ kernel="_K_"
pads_begin="_PB_" pads_end="_PE_"
strides="_KS_"
output="_OC_" group="_GC_" PrimitivesPriority="_IMPLS_"/>
output="_OC_" group="_GC_" _PRIM_PRIORITY_/>
<weights offset="0" size="_S1_" />
<biases offset="_S1_" size="_S2_" />
@ -216,13 +217,17 @@ protected:
REPLACE_WITH_NUM(model, "_S1_", w_data_size);
REPLACE_WITH_NUM(model, "_S2_", b_data_size);
std::string primitivesPriorityStr;
if (!p.defaultPrimitivesPriority) {
std::string impls;
for (const auto& preferType : p.preferTypes) {
if (!impls.empty())
impls += ",";
impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
}
REPLACE_WITH_STR(model, "_IMPLS_", impls);
primitivesPriorityStr = "PrimitivesPriority=\"" + impls + "\"";
}
REPLACE_WITH_STR(model, "_PRIM_PRIORITY_", primitivesPriorityStr);
return model;
}
@ -263,6 +268,10 @@ protected:
if (node->getType() == MKLDNNPlugin::Convolution) {
ASSERT_LE(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
for (const auto prim : node->getSupportedPrimitiveDescriptors()) {
if (p.defaultPrimitivesPriority) {
if (prim.getImplementationType() & MKLDNNPlugin::impl_desc_type::gemm)
FAIL() << "There should be no gemm implementation in supportedPrimitiveDescriptors";
}
std::cout << MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(prim.getImplementationType()) << " ";
}
std::cout << std::endl;
@ -335,44 +344,29 @@ TEST_P(MKLDNNGraphConvolutionTests, TestsConvolution) {}
INSTANTIATE_TEST_CASE_P(
TestConvolution, MKLDNNGraphConvolutionTests,
::testing::Values(
/*0*/ conv_test_params{{1, 9, 16, 32},
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 6, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1 },
conv_test_params{{1, 9, 32, 16},
{2, 4}, {1, 1}, {1, 1}, {0, 2}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
conv_test_params{{1, 9, 32, 16},
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
conv_test_params{{1, 3, 40, 40},
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
conv_test_params{{1, 1, 40, 40},
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
conv_test_params{{1, 1, 32, 16},
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
conv_test_params{{1, 9, 32, 16},
{2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any,
/*0*/ conv_test_params{{1, 9, 16, 32}, {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 6,
MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1, false },
conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {1, 1}, {0, 2}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
conv_test_params{{1, 9, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
conv_test_params{{1, 3, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
conv_test_params{{1, 1, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
conv_test_params{{1, 1, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any, false,
{MKLDNNPlugin::impl_desc_type::ref_any} },
conv_test_params{{1, 4, 54, 96},
{3, 3}, {1, 1}, {1, 1}, {0, 0}, 64, 1, "", 3, MKLDNNPlugin::impl_desc_type::ref_any,
conv_test_params{{1, 4, 54, 96}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, 64, 1, "", 3, MKLDNNPlugin::impl_desc_type::ref_any, false,
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd, MKLDNNPlugin::impl_desc_type::ref_any}},
// 5D
/*8*/ conv_test_params{{1, 3, 15, 20, 20},
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any,
/*8*/ conv_test_params{{1, 3, 15, 20, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false,
{MKLDNNPlugin::impl_desc_type::ref_any} },
conv_test_params{{1, 24, 15, 20, 20},
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any,
conv_test_params{{1, 24, 15, 20, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false,
{MKLDNNPlugin::impl_desc_type::ref_any} },
conv_test_params{{1, 32, 15, 20, 20},
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any,
conv_test_params{{1, 32, 15, 20, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false,
{MKLDNNPlugin::impl_desc_type::ref_any} },
conv_test_params{{1, 3, 15, 25, 20},
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit },
conv_test_params{{1, 24, 15, 25, 20},
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit },
/*13*/ conv_test_params{{1, 32, 15, 25, 20},
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit },
/*20*/ conv_test_params{{1, 16, 30, 30, 10},
{5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit },
conv_test_params{{1, 16, 30, 30, 10},
{5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any,
conv_test_params{{1, 3, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false },
conv_test_params{{1, 24, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false },
/*13*/ conv_test_params{{1, 32, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false },
conv_test_params{{1, 16, 30, 30, 10}, {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false },
conv_test_params{{1, 16, 30, 30, 10}, {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false,
{MKLDNNPlugin::impl_desc_type::ref_any} } ));
#ifdef USE_MKL
@ -380,29 +374,45 @@ INSTANTIATE_TEST_CASE_P(
MKLTestConvolution, MKLDNNGraphConvolutionTests,
::testing::Values(
conv_test_params{{1, 9, 16, 32},
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 6, MKLDNNPlugin::impl_desc_type::gemm,
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 6, MKLDNNPlugin::impl_desc_type::gemm, false,
{MKLDNNPlugin::impl_desc_type::gemm_any,
MKLDNNPlugin::impl_desc_type::gemm_blas,
MKLDNNPlugin::impl_desc_type::gemm_avx512,
MKLDNNPlugin::impl_desc_type::gemm_avx2,
MKLDNNPlugin::impl_desc_type::gemm_sse42} },
conv_test_params{{1, 5, 15, 20, 20},
{3, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
{3, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
conv_test_params{{1, 5, 15, 20, 20},
{3, 3, 3}, {3, 2, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
{3, 3, 3}, {3, 2, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
// conv_test_params{{1, 5, 15, 20, 20},
// {3, 3, 3}, {1, 1, 1}, {2, 2, 2}, {1, 1, 1}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
// {3, 3, 3}, {1, 1, 1}, {2, 2, 2}, {1, 1, 1}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
// {MKLDNNPlugin::impl_desc_type::gemm_blas} },
conv_test_params{{1, 16, 30, 30, 10},
{5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
{5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
conv_test_params{{1, 4, 16, 16, 16},
{3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, 8, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
{3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, 8, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
{MKLDNNPlugin::impl_desc_type::gemm_blas} } ));
#endif
INSTANTIATE_TEST_CASE_P(
TestConvolutionDefaultPrimitivesPriority, MKLDNNGraphConvolutionTests,
::testing::Values(
/*0*/ conv_test_params{{1, 9, 16, 32}, {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 6,
MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1, true },
conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {1, 1}, {0, 2}, 17, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
conv_test_params{{1, 9, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
conv_test_params{{1, 3, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
conv_test_params{{1, 1, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
conv_test_params{{1, 1, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
// 5D
/*6*/ conv_test_params{{1, 3, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true },
conv_test_params{{1, 24, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true },
conv_test_params{{1, 32, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true },
conv_test_params{{1, 16, 30, 30, 10}, {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true } ));
class MKLDNNGraphDynBatchConvolutionTests: public MKLDNNGraphConvolutionTests {
protected:
@ -490,31 +500,31 @@ INSTANTIATE_TEST_CASE_P(
::testing::Values(
conv_test_params{{1, 8, 16, 32},
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 7, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1,
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd}},
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd}},
conv_test_params{{1, 9, 32, 16},
{2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
conv_test_params{{1, 9, 32, 16},
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
conv_test_params{{1, 3, 40, 40},
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
conv_test_params{{1, 1, 40, 40},
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
conv_test_params{{1, 1, 32, 16},
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
conv_test_params{{1, 9, 32, 16},
{2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any,
{MKLDNNPlugin::impl_desc_type::ref_any} } ));
false, {MKLDNNPlugin::impl_desc_type::ref_any} } ));
#ifdef USE_MKL
INSTANTIATE_TEST_CASE_P(
MKLTestDynBatchConvolution, MKLDNNGraphDynBatchConvolutionTests,
::testing::Values(
conv_test_params{{1, 9, 16, 32},
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 7, MKLDNNPlugin::impl_desc_type::gemm,
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 7, MKLDNNPlugin::impl_desc_type::gemm, false,
{MKLDNNPlugin::impl_desc_type::gemm_any,
MKLDNNPlugin::impl_desc_type::gemm_blas,
MKLDNNPlugin::impl_desc_type::gemm_avx512,