diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.cpp index 606e59914de..7cc72943001 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.cpp @@ -272,10 +272,11 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() { } } + MKLDNNMemoryDesc in_candidate, out_candidate; if (canBeExecutedInInt8()) { - MKLDNNMemoryDesc in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, + in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, getParentEdgeAt(0)->getDims().ndims() == 5 ? memory::ndhwc : memory::nhwc); - MKLDNNMemoryDesc out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, + out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, getParentEdgeAt(0)->getDims().ndims() == 5 ? memory::ndhwc : memory::nhwc); createDescriptor({in_candidate}, {out_candidate}); } else { @@ -308,13 +309,9 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() { Layout layout = convLayer->input()->getLayout(); if (layout == NCHW || layout == NHWC) { - MKLDNNMemoryDesc in_candidate(getParentEdgeAt(0)->getDims(), inputDataType, - layout == NCHW ? memory::nchw : memory::nhwc); - MKLDNNMemoryDesc out_candidate(getChildEdgeAt(0)->getDims(), outputDataType, - layout == NCHW ? memory::nchw : memory::nhwc); - createDescriptor({in_candidate}, {out_candidate}); - if (IC == 3 || IC == 1) { + in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, + layout == NCHW ? memory::nchw : memory::nhwc); out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw16c); createDescriptor({in_candidate}, {out_candidate}); out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw8c); @@ -327,13 +324,15 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() { out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw8c); createDescriptor({in_candidate}, {out_candidate}); } - } else if (layout == NCDHW || layout == NDHWC) { - MKLDNNMemoryDesc in_candidate(getParentEdgeAt(0)->getDims(), inputDataType, - layout == NCDHW ? memory::ncdhw : memory::ndhwc); - MKLDNNMemoryDesc out_candidate(getChildEdgeAt(0)->getDims(), outputDataType, - layout == NCDHW ? memory::ncdhw : memory::ndhwc); - createDescriptor({in_candidate}, {out_candidate}); + in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, + layout == NCHW ? memory::nchw : memory::nhwc); + out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, + layout == NCHW ? memory::nchw : memory::nhwc); + createDescriptor({in_candidate}, {out_candidate}); + } else if (layout == NCDHW || layout == NDHWC) { + in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, + layout == NCDHW ? memory::ncdhw : memory::ndhwc); if (IC == 3 || IC == 1) { out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nCdhw16c); createDescriptor({in_candidate}, {out_candidate}); @@ -347,6 +346,12 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() { out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nCdhw8c); createDescriptor({in_candidate}, {out_candidate}); } + + in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType, + layout == NCDHW ? memory::ncdhw : memory::ndhwc); + out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, + layout == NCDHW ? memory::ncdhw : memory::ndhwc); + createDescriptor({in_candidate}, {out_candidate}); } } } @@ -556,7 +561,11 @@ void MKLDNNConvolutionNode::initSupportedPrimitiveDescriptors() { addZeroPoints(attr); setPostOps(attr); + bool containJitImpl = false; + for (auto& desc : descs) { + if (containJitImpl && isPossibleToSkipInitConfig(desc)) + continue; auto itpd = desc.createPrimitiveDescriptorIterator(getEngine(), attr); while (itpd.is_not_end()) { InferenceEngine::LayerConfig config; @@ -610,6 +619,8 @@ void MKLDNNConvolutionNode::initSupportedPrimitiveDescriptors() { outFormats.emplace_back(static_cast(itpd.dst_primitive_desc().desc().data.format)); } impl_desc_type impl_type = parse_impl_name(itpd.get_impl_info_str()); + if (impl_type & jit) + containJitImpl = true; supportedPrimitiveDescriptors.emplace_back(config, impl_type, outFormats); itpd++; @@ -790,8 +801,13 @@ void MKLDNNConvolutionNode::initDescriptor(const InferenceEngine::LayerConfig& c InferenceEngine::LayerConfig rightConfig = selectedPD->getConfig(); size_t selected_count = 0; + + bool containJitImpl = false; + for (size_t i = 0; i < descs.size(); i++) { - const auto& desc = descs[i]; + auto& desc = descs[i]; + if (containJitImpl && isPossibleToSkipInitConfig(desc)) + continue; auto itpd = desc.createPrimitiveDescriptorIterator(getEngine(), attr); while (itpd.is_not_end()) { InferenceEngine::LayerConfig cfg; @@ -836,6 +852,8 @@ void MKLDNNConvolutionNode::initDescriptor(const InferenceEngine::LayerConfig& c cfg.outConfs.push_back(dataConfig); } impl_desc_type impl_type = parse_impl_name(itpd.get_impl_info_str()); + if (impl_type & jit) + containJitImpl = true; if (selected_count == selectedPrimitiveDescriptorIndex) { if (impl_type != selectedPD->getImplementationType()) { @@ -888,6 +906,41 @@ void MKLDNNConvolutionNode::filterSupportedDescriptors() { } } +bool MKLDNNConvolutionNode::isPossibleToSkipInitConfig(MKLDNNDescriptor &desc) { + // WA: In some cases, we can predict in advance the type of primitive that will be called in the future. + // In particular, isPossibleToSkipInitConfig() checks whether we can skip the creation of primitives with + // gemm implementation, which significantly increase the network load time. + if (!inputMemoryFormatsFilter.empty() || !outputMemoryFormatsFilter.empty()) + return false; + + if (getCnnLayer()->params.find("PrimitivesPriority") != getCnnLayer()->params.end()) + return false; + + // Here we check that we will not delete jit_planar_conv primitive by mistake. + // It requires: + // 1) strides equal 1; + // 2) not grouped; + // 3) first dim of weights is not 1. + bool isPossibleJitPlanar = true; + if (isGrouped || weightDims[0] != 1) + isPossibleJitPlanar = false; + for (int i = 0; i < stride.size(); i++) + if (stride[i] != 1) + isPossibleJitPlanar = false; + + std::shared_ptr convDesc(desc); + auto srcMemFmt = convDesc->data.src_desc.format; + auto dstMemFmt = convDesc->data.dst_desc.format; + auto srcDataType = convDesc->data.src_desc.data_type; + auto dstDataType = convDesc->data.dst_desc.data_type; + bool isPlanarFloatConv = (srcMemFmt == memory::nchw || srcMemFmt == memory::ncdhw) + && (dstMemFmt == memory::nchw || dstMemFmt == memory::ncdhw) + && srcDataType == memory::f32 + && dstDataType == memory::f32; + + return !isPossibleJitPlanar && isPlanarFloatConv; +} + MKLDNNMemoryDesc MKLDNNConvolutionNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) { InferenceEngine::TensorDesc desc = idx > 0 ? MKLDNNMemoryDesc(primitive_desc_it.weights_primitive_desc(idx - 1).desc()) : MKLDNNMemoryDesc(primitive_desc_it.src_primitive_desc(idx).desc()); diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.h index ecd2193c17e..d031724d9cb 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.h +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_conv_node.h @@ -27,6 +27,7 @@ public: void initSupportedPrimitiveDescriptors() override; void filterSupportedPrimitiveDescriptors() override; void filterSupportedDescriptors(); + bool isPossibleToSkipInitConfig(MKLDNNDescriptor &desc); bool created() const override; bool canBeInPlace() const override { return false; diff --git a/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/internal/graph_conv_test.cpp b/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/internal/graph_conv_test.cpp index 3102143798c..92f8bfbaef7 100644 --- a/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/internal/graph_conv_test.cpp +++ b/inference-engine/tests_deprecated/unit/engines/mkldnn/graph/layers/internal/graph_conv_test.cpp @@ -31,6 +31,7 @@ struct conv_test_params { size_t num_prim_desc; int selectedType; + bool defaultPrimitivesPriority; vector preferTypes; vector> comp; @@ -149,7 +150,7 @@ class MKLDNNGraphConvolutionTests: public TestsCommon, + output="_OC_" group="_GC_" _PRIM_PRIORITY_/> @@ -216,13 +217,17 @@ protected: REPLACE_WITH_NUM(model, "_S1_", w_data_size); REPLACE_WITH_NUM(model, "_S2_", b_data_size); - std::string impls; - for (const auto& preferType : p.preferTypes) { - if (!impls.empty()) - impls += ","; - impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType); + std::string primitivesPriorityStr; + if (!p.defaultPrimitivesPriority) { + std::string impls; + for (const auto& preferType : p.preferTypes) { + if (!impls.empty()) + impls += ","; + impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType); + } + primitivesPriorityStr = "PrimitivesPriority=\"" + impls + "\""; } - REPLACE_WITH_STR(model, "_IMPLS_", impls); + REPLACE_WITH_STR(model, "_PRIM_PRIORITY_", primitivesPriorityStr); return model; } @@ -263,6 +268,10 @@ protected: if (node->getType() == MKLDNNPlugin::Convolution) { ASSERT_LE(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size()); for (const auto prim : node->getSupportedPrimitiveDescriptors()) { + if (p.defaultPrimitivesPriority) { + if (prim.getImplementationType() & MKLDNNPlugin::impl_desc_type::gemm) + FAIL() << "There should be no gemm implementation in supportedPrimitiveDescriptors"; + } std::cout << MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(prim.getImplementationType()) << " "; } std::cout << std::endl; @@ -335,44 +344,29 @@ TEST_P(MKLDNNGraphConvolutionTests, TestsConvolution) {} INSTANTIATE_TEST_CASE_P( TestConvolution, MKLDNNGraphConvolutionTests, ::testing::Values( - /*0*/ conv_test_params{{1, 9, 16, 32}, - {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 6, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1 }, - conv_test_params{{1, 9, 32, 16}, - {2, 4}, {1, 1}, {1, 1}, {0, 2}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit }, - conv_test_params{{1, 9, 32, 16}, - {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit }, - conv_test_params{{1, 3, 40, 40}, - {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit }, - conv_test_params{{1, 1, 40, 40}, - {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit }, - conv_test_params{{1, 1, 32, 16}, - {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit }, - conv_test_params{{1, 9, 32, 16}, - {2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any, + /*0*/ conv_test_params{{1, 9, 16, 32}, {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 6, + MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1, false }, + conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {1, 1}, {0, 2}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false }, + conv_test_params{{1, 9, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false }, + conv_test_params{{1, 3, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false }, + conv_test_params{{1, 1, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false }, + conv_test_params{{1, 1, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false }, + conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any, false, {MKLDNNPlugin::impl_desc_type::ref_any} }, - conv_test_params{{1, 4, 54, 96}, - {3, 3}, {1, 1}, {1, 1}, {0, 0}, 64, 1, "", 3, MKLDNNPlugin::impl_desc_type::ref_any, + conv_test_params{{1, 4, 54, 96}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, 64, 1, "", 3, MKLDNNPlugin::impl_desc_type::ref_any, false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd, MKLDNNPlugin::impl_desc_type::ref_any}}, // 5D - /*8*/ conv_test_params{{1, 3, 15, 20, 20}, - {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, + /*8*/ conv_test_params{{1, 3, 15, 20, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false, {MKLDNNPlugin::impl_desc_type::ref_any} }, - conv_test_params{{1, 24, 15, 20, 20}, - {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, + conv_test_params{{1, 24, 15, 20, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false, {MKLDNNPlugin::impl_desc_type::ref_any} }, - conv_test_params{{1, 32, 15, 20, 20}, - {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, + conv_test_params{{1, 32, 15, 20, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false, {MKLDNNPlugin::impl_desc_type::ref_any} }, - conv_test_params{{1, 3, 15, 25, 20}, - {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit }, - conv_test_params{{1, 24, 15, 25, 20}, - {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit }, - /*13*/ conv_test_params{{1, 32, 15, 25, 20}, - {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit }, - /*20*/ conv_test_params{{1, 16, 30, 30, 10}, - {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit }, - conv_test_params{{1, 16, 30, 30, 10}, - {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, + conv_test_params{{1, 3, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false }, + conv_test_params{{1, 24, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false }, + /*13*/ conv_test_params{{1, 32, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false }, + conv_test_params{{1, 16, 30, 30, 10}, {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false }, + conv_test_params{{1, 16, 30, 30, 10}, {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false, {MKLDNNPlugin::impl_desc_type::ref_any} } )); #ifdef USE_MKL @@ -380,29 +374,45 @@ INSTANTIATE_TEST_CASE_P( MKLTestConvolution, MKLDNNGraphConvolutionTests, ::testing::Values( conv_test_params{{1, 9, 16, 32}, - {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 6, MKLDNNPlugin::impl_desc_type::gemm, + {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 6, MKLDNNPlugin::impl_desc_type::gemm, false, {MKLDNNPlugin::impl_desc_type::gemm_any, MKLDNNPlugin::impl_desc_type::gemm_blas, MKLDNNPlugin::impl_desc_type::gemm_avx512, MKLDNNPlugin::impl_desc_type::gemm_avx2, MKLDNNPlugin::impl_desc_type::gemm_sse42} }, conv_test_params{{1, 5, 15, 20, 20}, - {3, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, + {3, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false, {MKLDNNPlugin::impl_desc_type::gemm_blas} }, conv_test_params{{1, 5, 15, 20, 20}, - {3, 3, 3}, {3, 2, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, + {3, 3, 3}, {3, 2, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false, {MKLDNNPlugin::impl_desc_type::gemm_blas} }, // conv_test_params{{1, 5, 15, 20, 20}, - // {3, 3, 3}, {1, 1, 1}, {2, 2, 2}, {1, 1, 1}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, + // {3, 3, 3}, {1, 1, 1}, {2, 2, 2}, {1, 1, 1}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false, // {MKLDNNPlugin::impl_desc_type::gemm_blas} }, conv_test_params{{1, 16, 30, 30, 10}, - {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, + {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false, {MKLDNNPlugin::impl_desc_type::gemm_blas} }, conv_test_params{{1, 4, 16, 16, 16}, - {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, 8, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, + {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, 8, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false, {MKLDNNPlugin::impl_desc_type::gemm_blas} } )); #endif +INSTANTIATE_TEST_CASE_P( + TestConvolutionDefaultPrimitivesPriority, MKLDNNGraphConvolutionTests, + ::testing::Values( + /*0*/ conv_test_params{{1, 9, 16, 32}, {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 6, + MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1, true }, + conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {1, 1}, {0, 2}, 17, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true }, + conv_test_params{{1, 9, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true }, + conv_test_params{{1, 3, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true }, + conv_test_params{{1, 1, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true }, + conv_test_params{{1, 1, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true }, + // 5D + /*6*/ conv_test_params{{1, 3, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true }, + conv_test_params{{1, 24, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true }, + conv_test_params{{1, 32, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true }, + conv_test_params{{1, 16, 30, 30, 10}, {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true } )); + class MKLDNNGraphDynBatchConvolutionTests: public MKLDNNGraphConvolutionTests { protected: @@ -490,31 +500,31 @@ INSTANTIATE_TEST_CASE_P( ::testing::Values( conv_test_params{{1, 8, 16, 32}, {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 7, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1, - {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd}}, + false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd}}, conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, - {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} }, + false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} }, conv_test_params{{1, 9, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, - {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} }, + false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} }, conv_test_params{{1, 3, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, - {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} }, + false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} }, conv_test_params{{1, 1, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, - {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} }, + false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} }, conv_test_params{{1, 1, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, - {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} }, + false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} }, conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any, - {MKLDNNPlugin::impl_desc_type::ref_any} } )); + false, {MKLDNNPlugin::impl_desc_type::ref_any} } )); #ifdef USE_MKL INSTANTIATE_TEST_CASE_P( MKLTestDynBatchConvolution, MKLDNNGraphDynBatchConvolutionTests, ::testing::Values( conv_test_params{{1, 9, 16, 32}, - {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 7, MKLDNNPlugin::impl_desc_type::gemm, + {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 7, MKLDNNPlugin::impl_desc_type::gemm, false, {MKLDNNPlugin::impl_desc_type::gemm_any, MKLDNNPlugin::impl_desc_type::gemm_blas, MKLDNNPlugin::impl_desc_type::gemm_avx512,