[MKLDNN_PLUGIN] Convolution node: skip initializing of primitive descriptors for planar layout if there is already jit primitive (#672)
This commit is contained in:
parent
158d32139f
commit
e53b1b7fbc
@ -272,10 +272,11 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
|
||||
}
|
||||
}
|
||||
|
||||
MKLDNNMemoryDesc in_candidate, out_candidate;
|
||||
if (canBeExecutedInInt8()) {
|
||||
MKLDNNMemoryDesc in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||
getParentEdgeAt(0)->getDims().ndims() == 5 ? memory::ndhwc : memory::nhwc);
|
||||
MKLDNNMemoryDesc out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
|
||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
|
||||
getParentEdgeAt(0)->getDims().ndims() == 5 ? memory::ndhwc : memory::nhwc);
|
||||
createDescriptor({in_candidate}, {out_candidate});
|
||||
} else {
|
||||
@ -308,13 +309,9 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
|
||||
Layout layout = convLayer->input()->getLayout();
|
||||
|
||||
if (layout == NCHW || layout == NHWC) {
|
||||
MKLDNNMemoryDesc in_candidate(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||
layout == NCHW ? memory::nchw : memory::nhwc);
|
||||
MKLDNNMemoryDesc out_candidate(getChildEdgeAt(0)->getDims(), outputDataType,
|
||||
layout == NCHW ? memory::nchw : memory::nhwc);
|
||||
createDescriptor({in_candidate}, {out_candidate});
|
||||
|
||||
if (IC == 3 || IC == 1) {
|
||||
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||
layout == NCHW ? memory::nchw : memory::nhwc);
|
||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw16c);
|
||||
createDescriptor({in_candidate}, {out_candidate});
|
||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw8c);
|
||||
@ -327,13 +324,15 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
|
||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw8c);
|
||||
createDescriptor({in_candidate}, {out_candidate});
|
||||
}
|
||||
} else if (layout == NCDHW || layout == NDHWC) {
|
||||
MKLDNNMemoryDesc in_candidate(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
|
||||
MKLDNNMemoryDesc out_candidate(getChildEdgeAt(0)->getDims(), outputDataType,
|
||||
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
|
||||
createDescriptor({in_candidate}, {out_candidate});
|
||||
|
||||
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||
layout == NCHW ? memory::nchw : memory::nhwc);
|
||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
|
||||
layout == NCHW ? memory::nchw : memory::nhwc);
|
||||
createDescriptor({in_candidate}, {out_candidate});
|
||||
} else if (layout == NCDHW || layout == NDHWC) {
|
||||
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
|
||||
if (IC == 3 || IC == 1) {
|
||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nCdhw16c);
|
||||
createDescriptor({in_candidate}, {out_candidate});
|
||||
@ -347,6 +346,12 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
|
||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nCdhw8c);
|
||||
createDescriptor({in_candidate}, {out_candidate});
|
||||
}
|
||||
|
||||
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
|
||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
|
||||
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
|
||||
createDescriptor({in_candidate}, {out_candidate});
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -556,7 +561,11 @@ void MKLDNNConvolutionNode::initSupportedPrimitiveDescriptors() {
|
||||
addZeroPoints(attr);
|
||||
setPostOps(attr);
|
||||
|
||||
bool containJitImpl = false;
|
||||
|
||||
for (auto& desc : descs) {
|
||||
if (containJitImpl && isPossibleToSkipInitConfig(desc))
|
||||
continue;
|
||||
auto itpd = desc.createPrimitiveDescriptorIterator(getEngine(), attr);
|
||||
while (itpd.is_not_end()) {
|
||||
InferenceEngine::LayerConfig config;
|
||||
@ -610,6 +619,8 @@ void MKLDNNConvolutionNode::initSupportedPrimitiveDescriptors() {
|
||||
outFormats.emplace_back(static_cast<memory::format>(itpd.dst_primitive_desc().desc().data.format));
|
||||
}
|
||||
impl_desc_type impl_type = parse_impl_name(itpd.get_impl_info_str());
|
||||
if (impl_type & jit)
|
||||
containJitImpl = true;
|
||||
|
||||
supportedPrimitiveDescriptors.emplace_back(config, impl_type, outFormats);
|
||||
itpd++;
|
||||
@ -790,8 +801,13 @@ void MKLDNNConvolutionNode::initDescriptor(const InferenceEngine::LayerConfig& c
|
||||
|
||||
InferenceEngine::LayerConfig rightConfig = selectedPD->getConfig();
|
||||
size_t selected_count = 0;
|
||||
|
||||
bool containJitImpl = false;
|
||||
|
||||
for (size_t i = 0; i < descs.size(); i++) {
|
||||
const auto& desc = descs[i];
|
||||
auto& desc = descs[i];
|
||||
if (containJitImpl && isPossibleToSkipInitConfig(desc))
|
||||
continue;
|
||||
auto itpd = desc.createPrimitiveDescriptorIterator(getEngine(), attr);
|
||||
while (itpd.is_not_end()) {
|
||||
InferenceEngine::LayerConfig cfg;
|
||||
@ -836,6 +852,8 @@ void MKLDNNConvolutionNode::initDescriptor(const InferenceEngine::LayerConfig& c
|
||||
cfg.outConfs.push_back(dataConfig);
|
||||
}
|
||||
impl_desc_type impl_type = parse_impl_name(itpd.get_impl_info_str());
|
||||
if (impl_type & jit)
|
||||
containJitImpl = true;
|
||||
|
||||
if (selected_count == selectedPrimitiveDescriptorIndex) {
|
||||
if (impl_type != selectedPD->getImplementationType()) {
|
||||
@ -888,6 +906,41 @@ void MKLDNNConvolutionNode::filterSupportedDescriptors() {
|
||||
}
|
||||
}
|
||||
|
||||
bool MKLDNNConvolutionNode::isPossibleToSkipInitConfig(MKLDNNDescriptor &desc) {
|
||||
// WA: In some cases, we can predict in advance the type of primitive that will be called in the future.
|
||||
// In particular, isPossibleToSkipInitConfig() checks whether we can skip the creation of primitives with
|
||||
// gemm implementation, which significantly increase the network load time.
|
||||
if (!inputMemoryFormatsFilter.empty() || !outputMemoryFormatsFilter.empty())
|
||||
return false;
|
||||
|
||||
if (getCnnLayer()->params.find("PrimitivesPriority") != getCnnLayer()->params.end())
|
||||
return false;
|
||||
|
||||
// Here we check that we will not delete jit_planar_conv primitive by mistake.
|
||||
// It requires:
|
||||
// 1) strides equal 1;
|
||||
// 2) not grouped;
|
||||
// 3) first dim of weights is not 1.
|
||||
bool isPossibleJitPlanar = true;
|
||||
if (isGrouped || weightDims[0] != 1)
|
||||
isPossibleJitPlanar = false;
|
||||
for (int i = 0; i < stride.size(); i++)
|
||||
if (stride[i] != 1)
|
||||
isPossibleJitPlanar = false;
|
||||
|
||||
std::shared_ptr<mkldnn::convolution_forward::desc> convDesc(desc);
|
||||
auto srcMemFmt = convDesc->data.src_desc.format;
|
||||
auto dstMemFmt = convDesc->data.dst_desc.format;
|
||||
auto srcDataType = convDesc->data.src_desc.data_type;
|
||||
auto dstDataType = convDesc->data.dst_desc.data_type;
|
||||
bool isPlanarFloatConv = (srcMemFmt == memory::nchw || srcMemFmt == memory::ncdhw)
|
||||
&& (dstMemFmt == memory::nchw || dstMemFmt == memory::ncdhw)
|
||||
&& srcDataType == memory::f32
|
||||
&& dstDataType == memory::f32;
|
||||
|
||||
return !isPossibleJitPlanar && isPlanarFloatConv;
|
||||
}
|
||||
|
||||
MKLDNNMemoryDesc MKLDNNConvolutionNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) {
|
||||
InferenceEngine::TensorDesc desc = idx > 0 ? MKLDNNMemoryDesc(primitive_desc_it.weights_primitive_desc(idx - 1).desc())
|
||||
: MKLDNNMemoryDesc(primitive_desc_it.src_primitive_desc(idx).desc());
|
||||
|
@ -27,6 +27,7 @@ public:
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
void filterSupportedPrimitiveDescriptors() override;
|
||||
void filterSupportedDescriptors();
|
||||
bool isPossibleToSkipInitConfig(MKLDNNDescriptor &desc);
|
||||
bool created() const override;
|
||||
bool canBeInPlace() const override {
|
||||
return false;
|
||||
|
@ -31,6 +31,7 @@ struct conv_test_params {
|
||||
size_t num_prim_desc;
|
||||
|
||||
int selectedType;
|
||||
bool defaultPrimitivesPriority;
|
||||
vector<MKLDNNPlugin::impl_desc_type> preferTypes;
|
||||
|
||||
vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
|
||||
@ -149,7 +150,7 @@ class MKLDNNGraphConvolutionTests: public TestsCommon,
|
||||
<convolution _AP_ kernel="_K_"
|
||||
pads_begin="_PB_" pads_end="_PE_"
|
||||
strides="_KS_"
|
||||
output="_OC_" group="_GC_" PrimitivesPriority="_IMPLS_"/>
|
||||
output="_OC_" group="_GC_" _PRIM_PRIORITY_/>
|
||||
|
||||
<weights offset="0" size="_S1_" />
|
||||
<biases offset="_S1_" size="_S2_" />
|
||||
@ -216,13 +217,17 @@ protected:
|
||||
REPLACE_WITH_NUM(model, "_S1_", w_data_size);
|
||||
REPLACE_WITH_NUM(model, "_S2_", b_data_size);
|
||||
|
||||
std::string primitivesPriorityStr;
|
||||
if (!p.defaultPrimitivesPriority) {
|
||||
std::string impls;
|
||||
for (const auto& preferType : p.preferTypes) {
|
||||
if (!impls.empty())
|
||||
impls += ",";
|
||||
impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
|
||||
}
|
||||
REPLACE_WITH_STR(model, "_IMPLS_", impls);
|
||||
primitivesPriorityStr = "PrimitivesPriority=\"" + impls + "\"";
|
||||
}
|
||||
REPLACE_WITH_STR(model, "_PRIM_PRIORITY_", primitivesPriorityStr);
|
||||
|
||||
return model;
|
||||
}
|
||||
@ -263,6 +268,10 @@ protected:
|
||||
if (node->getType() == MKLDNNPlugin::Convolution) {
|
||||
ASSERT_LE(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
|
||||
for (const auto prim : node->getSupportedPrimitiveDescriptors()) {
|
||||
if (p.defaultPrimitivesPriority) {
|
||||
if (prim.getImplementationType() & MKLDNNPlugin::impl_desc_type::gemm)
|
||||
FAIL() << "There should be no gemm implementation in supportedPrimitiveDescriptors";
|
||||
}
|
||||
std::cout << MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(prim.getImplementationType()) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
@ -335,44 +344,29 @@ TEST_P(MKLDNNGraphConvolutionTests, TestsConvolution) {}
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
TestConvolution, MKLDNNGraphConvolutionTests,
|
||||
::testing::Values(
|
||||
/*0*/ conv_test_params{{1, 9, 16, 32},
|
||||
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 6, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1 },
|
||||
conv_test_params{{1, 9, 32, 16},
|
||||
{2, 4}, {1, 1}, {1, 1}, {0, 2}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
|
||||
conv_test_params{{1, 9, 32, 16},
|
||||
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
|
||||
conv_test_params{{1, 3, 40, 40},
|
||||
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
|
||||
conv_test_params{{1, 1, 40, 40},
|
||||
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
|
||||
conv_test_params{{1, 1, 32, 16},
|
||||
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
|
||||
conv_test_params{{1, 9, 32, 16},
|
||||
{2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any,
|
||||
/*0*/ conv_test_params{{1, 9, 16, 32}, {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 6,
|
||||
MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1, false },
|
||||
conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {1, 1}, {0, 2}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||
conv_test_params{{1, 9, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||
conv_test_params{{1, 3, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||
conv_test_params{{1, 1, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||
conv_test_params{{1, 1, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||
conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any, false,
|
||||
{MKLDNNPlugin::impl_desc_type::ref_any} },
|
||||
conv_test_params{{1, 4, 54, 96},
|
||||
{3, 3}, {1, 1}, {1, 1}, {0, 0}, 64, 1, "", 3, MKLDNNPlugin::impl_desc_type::ref_any,
|
||||
conv_test_params{{1, 4, 54, 96}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, 64, 1, "", 3, MKLDNNPlugin::impl_desc_type::ref_any, false,
|
||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd, MKLDNNPlugin::impl_desc_type::ref_any}},
|
||||
// 5D
|
||||
/*8*/ conv_test_params{{1, 3, 15, 20, 20},
|
||||
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any,
|
||||
/*8*/ conv_test_params{{1, 3, 15, 20, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false,
|
||||
{MKLDNNPlugin::impl_desc_type::ref_any} },
|
||||
conv_test_params{{1, 24, 15, 20, 20},
|
||||
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any,
|
||||
conv_test_params{{1, 24, 15, 20, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false,
|
||||
{MKLDNNPlugin::impl_desc_type::ref_any} },
|
||||
conv_test_params{{1, 32, 15, 20, 20},
|
||||
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any,
|
||||
conv_test_params{{1, 32, 15, 20, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false,
|
||||
{MKLDNNPlugin::impl_desc_type::ref_any} },
|
||||
conv_test_params{{1, 3, 15, 25, 20},
|
||||
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit },
|
||||
conv_test_params{{1, 24, 15, 25, 20},
|
||||
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit },
|
||||
/*13*/ conv_test_params{{1, 32, 15, 25, 20},
|
||||
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit },
|
||||
/*20*/ conv_test_params{{1, 16, 30, 30, 10},
|
||||
{5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit },
|
||||
conv_test_params{{1, 16, 30, 30, 10},
|
||||
{5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any,
|
||||
conv_test_params{{1, 3, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||
conv_test_params{{1, 24, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||
/*13*/ conv_test_params{{1, 32, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||
conv_test_params{{1, 16, 30, 30, 10}, {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||
conv_test_params{{1, 16, 30, 30, 10}, {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false,
|
||||
{MKLDNNPlugin::impl_desc_type::ref_any} } ));
|
||||
|
||||
#ifdef USE_MKL
|
||||
@ -380,29 +374,45 @@ INSTANTIATE_TEST_CASE_P(
|
||||
MKLTestConvolution, MKLDNNGraphConvolutionTests,
|
||||
::testing::Values(
|
||||
conv_test_params{{1, 9, 16, 32},
|
||||
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 6, MKLDNNPlugin::impl_desc_type::gemm,
|
||||
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 6, MKLDNNPlugin::impl_desc_type::gemm, false,
|
||||
{MKLDNNPlugin::impl_desc_type::gemm_any,
|
||||
MKLDNNPlugin::impl_desc_type::gemm_blas,
|
||||
MKLDNNPlugin::impl_desc_type::gemm_avx512,
|
||||
MKLDNNPlugin::impl_desc_type::gemm_avx2,
|
||||
MKLDNNPlugin::impl_desc_type::gemm_sse42} },
|
||||
conv_test_params{{1, 5, 15, 20, 20},
|
||||
{3, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
|
||||
{3, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
|
||||
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
|
||||
conv_test_params{{1, 5, 15, 20, 20},
|
||||
{3, 3, 3}, {3, 2, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
|
||||
{3, 3, 3}, {3, 2, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
|
||||
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
|
||||
// conv_test_params{{1, 5, 15, 20, 20},
|
||||
// {3, 3, 3}, {1, 1, 1}, {2, 2, 2}, {1, 1, 1}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
|
||||
// {3, 3, 3}, {1, 1, 1}, {2, 2, 2}, {1, 1, 1}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
|
||||
// {MKLDNNPlugin::impl_desc_type::gemm_blas} },
|
||||
conv_test_params{{1, 16, 30, 30, 10},
|
||||
{5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
|
||||
{5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
|
||||
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
|
||||
conv_test_params{{1, 4, 16, 16, 16},
|
||||
{3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, 8, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
|
||||
{3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, 8, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
|
||||
{MKLDNNPlugin::impl_desc_type::gemm_blas} } ));
|
||||
#endif
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
TestConvolutionDefaultPrimitivesPriority, MKLDNNGraphConvolutionTests,
|
||||
::testing::Values(
|
||||
/*0*/ conv_test_params{{1, 9, 16, 32}, {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 6,
|
||||
MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1, true },
|
||||
conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {1, 1}, {0, 2}, 17, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||
conv_test_params{{1, 9, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||
conv_test_params{{1, 3, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||
conv_test_params{{1, 1, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||
conv_test_params{{1, 1, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||
// 5D
|
||||
/*6*/ conv_test_params{{1, 3, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||
conv_test_params{{1, 24, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||
conv_test_params{{1, 32, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||
conv_test_params{{1, 16, 30, 30, 10}, {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true } ));
|
||||
|
||||
|
||||
class MKLDNNGraphDynBatchConvolutionTests: public MKLDNNGraphConvolutionTests {
|
||||
protected:
|
||||
@ -490,31 +500,31 @@ INSTANTIATE_TEST_CASE_P(
|
||||
::testing::Values(
|
||||
conv_test_params{{1, 8, 16, 32},
|
||||
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 7, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1,
|
||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd}},
|
||||
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd}},
|
||||
conv_test_params{{1, 9, 32, 16},
|
||||
{2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||
conv_test_params{{1, 9, 32, 16},
|
||||
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||
conv_test_params{{1, 3, 40, 40},
|
||||
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||
conv_test_params{{1, 1, 40, 40},
|
||||
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||
conv_test_params{{1, 1, 32, 16},
|
||||
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||
conv_test_params{{1, 9, 32, 16},
|
||||
{2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any,
|
||||
{MKLDNNPlugin::impl_desc_type::ref_any} } ));
|
||||
false, {MKLDNNPlugin::impl_desc_type::ref_any} } ));
|
||||
#ifdef USE_MKL
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
MKLTestDynBatchConvolution, MKLDNNGraphDynBatchConvolutionTests,
|
||||
::testing::Values(
|
||||
conv_test_params{{1, 9, 16, 32},
|
||||
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 7, MKLDNNPlugin::impl_desc_type::gemm,
|
||||
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 7, MKLDNNPlugin::impl_desc_type::gemm, false,
|
||||
{MKLDNNPlugin::impl_desc_type::gemm_any,
|
||||
MKLDNNPlugin::impl_desc_type::gemm_blas,
|
||||
MKLDNNPlugin::impl_desc_type::gemm_avx512,
|
||||
|
Loading…
Reference in New Issue
Block a user