[MKLDNN_PLUGIN] Convolution node: skip initializing of primitive descriptors for planar layout if there is already jit primitive (#672)
This commit is contained in:
parent
158d32139f
commit
e53b1b7fbc
@ -272,10 +272,11 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MKLDNNMemoryDesc in_candidate, out_candidate;
|
||||||
if (canBeExecutedInInt8()) {
|
if (canBeExecutedInInt8()) {
|
||||||
MKLDNNMemoryDesc in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
|
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||||
getParentEdgeAt(0)->getDims().ndims() == 5 ? memory::ndhwc : memory::nhwc);
|
getParentEdgeAt(0)->getDims().ndims() == 5 ? memory::ndhwc : memory::nhwc);
|
||||||
MKLDNNMemoryDesc out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
|
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
|
||||||
getParentEdgeAt(0)->getDims().ndims() == 5 ? memory::ndhwc : memory::nhwc);
|
getParentEdgeAt(0)->getDims().ndims() == 5 ? memory::ndhwc : memory::nhwc);
|
||||||
createDescriptor({in_candidate}, {out_candidate});
|
createDescriptor({in_candidate}, {out_candidate});
|
||||||
} else {
|
} else {
|
||||||
@ -308,13 +309,9 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
|
|||||||
Layout layout = convLayer->input()->getLayout();
|
Layout layout = convLayer->input()->getLayout();
|
||||||
|
|
||||||
if (layout == NCHW || layout == NHWC) {
|
if (layout == NCHW || layout == NHWC) {
|
||||||
MKLDNNMemoryDesc in_candidate(getParentEdgeAt(0)->getDims(), inputDataType,
|
|
||||||
layout == NCHW ? memory::nchw : memory::nhwc);
|
|
||||||
MKLDNNMemoryDesc out_candidate(getChildEdgeAt(0)->getDims(), outputDataType,
|
|
||||||
layout == NCHW ? memory::nchw : memory::nhwc);
|
|
||||||
createDescriptor({in_candidate}, {out_candidate});
|
|
||||||
|
|
||||||
if (IC == 3 || IC == 1) {
|
if (IC == 3 || IC == 1) {
|
||||||
|
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||||
|
layout == NCHW ? memory::nchw : memory::nhwc);
|
||||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw16c);
|
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw16c);
|
||||||
createDescriptor({in_candidate}, {out_candidate});
|
createDescriptor({in_candidate}, {out_candidate});
|
||||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw8c);
|
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw8c);
|
||||||
@ -327,13 +324,15 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
|
|||||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw8c);
|
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nChw8c);
|
||||||
createDescriptor({in_candidate}, {out_candidate});
|
createDescriptor({in_candidate}, {out_candidate});
|
||||||
}
|
}
|
||||||
} else if (layout == NCDHW || layout == NDHWC) {
|
|
||||||
MKLDNNMemoryDesc in_candidate(getParentEdgeAt(0)->getDims(), inputDataType,
|
|
||||||
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
|
|
||||||
MKLDNNMemoryDesc out_candidate(getChildEdgeAt(0)->getDims(), outputDataType,
|
|
||||||
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
|
|
||||||
createDescriptor({in_candidate}, {out_candidate});
|
|
||||||
|
|
||||||
|
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||||
|
layout == NCHW ? memory::nchw : memory::nhwc);
|
||||||
|
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
|
||||||
|
layout == NCHW ? memory::nchw : memory::nhwc);
|
||||||
|
createDescriptor({in_candidate}, {out_candidate});
|
||||||
|
} else if (layout == NCDHW || layout == NDHWC) {
|
||||||
|
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||||
|
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
|
||||||
if (IC == 3 || IC == 1) {
|
if (IC == 3 || IC == 1) {
|
||||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nCdhw16c);
|
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nCdhw16c);
|
||||||
createDescriptor({in_candidate}, {out_candidate});
|
createDescriptor({in_candidate}, {out_candidate});
|
||||||
@ -347,6 +346,12 @@ void MKLDNNConvolutionNode::getSupportedDescriptors() {
|
|||||||
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nCdhw8c);
|
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType, memory::nCdhw8c);
|
||||||
createDescriptor({in_candidate}, {out_candidate});
|
createDescriptor({in_candidate}, {out_candidate});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
in_candidate = MKLDNNMemoryDesc(getParentEdgeAt(0)->getDims(), inputDataType,
|
||||||
|
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
|
||||||
|
out_candidate = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), outputDataType,
|
||||||
|
layout == NCDHW ? memory::ncdhw : memory::ndhwc);
|
||||||
|
createDescriptor({in_candidate}, {out_candidate});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -556,7 +561,11 @@ void MKLDNNConvolutionNode::initSupportedPrimitiveDescriptors() {
|
|||||||
addZeroPoints(attr);
|
addZeroPoints(attr);
|
||||||
setPostOps(attr);
|
setPostOps(attr);
|
||||||
|
|
||||||
|
bool containJitImpl = false;
|
||||||
|
|
||||||
for (auto& desc : descs) {
|
for (auto& desc : descs) {
|
||||||
|
if (containJitImpl && isPossibleToSkipInitConfig(desc))
|
||||||
|
continue;
|
||||||
auto itpd = desc.createPrimitiveDescriptorIterator(getEngine(), attr);
|
auto itpd = desc.createPrimitiveDescriptorIterator(getEngine(), attr);
|
||||||
while (itpd.is_not_end()) {
|
while (itpd.is_not_end()) {
|
||||||
InferenceEngine::LayerConfig config;
|
InferenceEngine::LayerConfig config;
|
||||||
@ -610,6 +619,8 @@ void MKLDNNConvolutionNode::initSupportedPrimitiveDescriptors() {
|
|||||||
outFormats.emplace_back(static_cast<memory::format>(itpd.dst_primitive_desc().desc().data.format));
|
outFormats.emplace_back(static_cast<memory::format>(itpd.dst_primitive_desc().desc().data.format));
|
||||||
}
|
}
|
||||||
impl_desc_type impl_type = parse_impl_name(itpd.get_impl_info_str());
|
impl_desc_type impl_type = parse_impl_name(itpd.get_impl_info_str());
|
||||||
|
if (impl_type & jit)
|
||||||
|
containJitImpl = true;
|
||||||
|
|
||||||
supportedPrimitiveDescriptors.emplace_back(config, impl_type, outFormats);
|
supportedPrimitiveDescriptors.emplace_back(config, impl_type, outFormats);
|
||||||
itpd++;
|
itpd++;
|
||||||
@ -790,8 +801,13 @@ void MKLDNNConvolutionNode::initDescriptor(const InferenceEngine::LayerConfig& c
|
|||||||
|
|
||||||
InferenceEngine::LayerConfig rightConfig = selectedPD->getConfig();
|
InferenceEngine::LayerConfig rightConfig = selectedPD->getConfig();
|
||||||
size_t selected_count = 0;
|
size_t selected_count = 0;
|
||||||
|
|
||||||
|
bool containJitImpl = false;
|
||||||
|
|
||||||
for (size_t i = 0; i < descs.size(); i++) {
|
for (size_t i = 0; i < descs.size(); i++) {
|
||||||
const auto& desc = descs[i];
|
auto& desc = descs[i];
|
||||||
|
if (containJitImpl && isPossibleToSkipInitConfig(desc))
|
||||||
|
continue;
|
||||||
auto itpd = desc.createPrimitiveDescriptorIterator(getEngine(), attr);
|
auto itpd = desc.createPrimitiveDescriptorIterator(getEngine(), attr);
|
||||||
while (itpd.is_not_end()) {
|
while (itpd.is_not_end()) {
|
||||||
InferenceEngine::LayerConfig cfg;
|
InferenceEngine::LayerConfig cfg;
|
||||||
@ -836,6 +852,8 @@ void MKLDNNConvolutionNode::initDescriptor(const InferenceEngine::LayerConfig& c
|
|||||||
cfg.outConfs.push_back(dataConfig);
|
cfg.outConfs.push_back(dataConfig);
|
||||||
}
|
}
|
||||||
impl_desc_type impl_type = parse_impl_name(itpd.get_impl_info_str());
|
impl_desc_type impl_type = parse_impl_name(itpd.get_impl_info_str());
|
||||||
|
if (impl_type & jit)
|
||||||
|
containJitImpl = true;
|
||||||
|
|
||||||
if (selected_count == selectedPrimitiveDescriptorIndex) {
|
if (selected_count == selectedPrimitiveDescriptorIndex) {
|
||||||
if (impl_type != selectedPD->getImplementationType()) {
|
if (impl_type != selectedPD->getImplementationType()) {
|
||||||
@ -888,6 +906,41 @@ void MKLDNNConvolutionNode::filterSupportedDescriptors() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool MKLDNNConvolutionNode::isPossibleToSkipInitConfig(MKLDNNDescriptor &desc) {
|
||||||
|
// WA: In some cases, we can predict in advance the type of primitive that will be called in the future.
|
||||||
|
// In particular, isPossibleToSkipInitConfig() checks whether we can skip the creation of primitives with
|
||||||
|
// gemm implementation, which significantly increase the network load time.
|
||||||
|
if (!inputMemoryFormatsFilter.empty() || !outputMemoryFormatsFilter.empty())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (getCnnLayer()->params.find("PrimitivesPriority") != getCnnLayer()->params.end())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// Here we check that we will not delete jit_planar_conv primitive by mistake.
|
||||||
|
// It requires:
|
||||||
|
// 1) strides equal 1;
|
||||||
|
// 2) not grouped;
|
||||||
|
// 3) first dim of weights is not 1.
|
||||||
|
bool isPossibleJitPlanar = true;
|
||||||
|
if (isGrouped || weightDims[0] != 1)
|
||||||
|
isPossibleJitPlanar = false;
|
||||||
|
for (int i = 0; i < stride.size(); i++)
|
||||||
|
if (stride[i] != 1)
|
||||||
|
isPossibleJitPlanar = false;
|
||||||
|
|
||||||
|
std::shared_ptr<mkldnn::convolution_forward::desc> convDesc(desc);
|
||||||
|
auto srcMemFmt = convDesc->data.src_desc.format;
|
||||||
|
auto dstMemFmt = convDesc->data.dst_desc.format;
|
||||||
|
auto srcDataType = convDesc->data.src_desc.data_type;
|
||||||
|
auto dstDataType = convDesc->data.dst_desc.data_type;
|
||||||
|
bool isPlanarFloatConv = (srcMemFmt == memory::nchw || srcMemFmt == memory::ncdhw)
|
||||||
|
&& (dstMemFmt == memory::nchw || dstMemFmt == memory::ncdhw)
|
||||||
|
&& srcDataType == memory::f32
|
||||||
|
&& dstDataType == memory::f32;
|
||||||
|
|
||||||
|
return !isPossibleJitPlanar && isPlanarFloatConv;
|
||||||
|
}
|
||||||
|
|
||||||
MKLDNNMemoryDesc MKLDNNConvolutionNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) {
|
MKLDNNMemoryDesc MKLDNNConvolutionNode::getSrcMemDesc(mkldnn::primitive_desc_iterator &primitive_desc_it, size_t idx) {
|
||||||
InferenceEngine::TensorDesc desc = idx > 0 ? MKLDNNMemoryDesc(primitive_desc_it.weights_primitive_desc(idx - 1).desc())
|
InferenceEngine::TensorDesc desc = idx > 0 ? MKLDNNMemoryDesc(primitive_desc_it.weights_primitive_desc(idx - 1).desc())
|
||||||
: MKLDNNMemoryDesc(primitive_desc_it.src_primitive_desc(idx).desc());
|
: MKLDNNMemoryDesc(primitive_desc_it.src_primitive_desc(idx).desc());
|
||||||
|
@ -27,6 +27,7 @@ public:
|
|||||||
void initSupportedPrimitiveDescriptors() override;
|
void initSupportedPrimitiveDescriptors() override;
|
||||||
void filterSupportedPrimitiveDescriptors() override;
|
void filterSupportedPrimitiveDescriptors() override;
|
||||||
void filterSupportedDescriptors();
|
void filterSupportedDescriptors();
|
||||||
|
bool isPossibleToSkipInitConfig(MKLDNNDescriptor &desc);
|
||||||
bool created() const override;
|
bool created() const override;
|
||||||
bool canBeInPlace() const override {
|
bool canBeInPlace() const override {
|
||||||
return false;
|
return false;
|
||||||
|
@ -31,6 +31,7 @@ struct conv_test_params {
|
|||||||
size_t num_prim_desc;
|
size_t num_prim_desc;
|
||||||
|
|
||||||
int selectedType;
|
int selectedType;
|
||||||
|
bool defaultPrimitivesPriority;
|
||||||
vector<MKLDNNPlugin::impl_desc_type> preferTypes;
|
vector<MKLDNNPlugin::impl_desc_type> preferTypes;
|
||||||
|
|
||||||
vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
|
vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
|
||||||
@ -149,7 +150,7 @@ class MKLDNNGraphConvolutionTests: public TestsCommon,
|
|||||||
<convolution _AP_ kernel="_K_"
|
<convolution _AP_ kernel="_K_"
|
||||||
pads_begin="_PB_" pads_end="_PE_"
|
pads_begin="_PB_" pads_end="_PE_"
|
||||||
strides="_KS_"
|
strides="_KS_"
|
||||||
output="_OC_" group="_GC_" PrimitivesPriority="_IMPLS_"/>
|
output="_OC_" group="_GC_" _PRIM_PRIORITY_/>
|
||||||
|
|
||||||
<weights offset="0" size="_S1_" />
|
<weights offset="0" size="_S1_" />
|
||||||
<biases offset="_S1_" size="_S2_" />
|
<biases offset="_S1_" size="_S2_" />
|
||||||
@ -216,13 +217,17 @@ protected:
|
|||||||
REPLACE_WITH_NUM(model, "_S1_", w_data_size);
|
REPLACE_WITH_NUM(model, "_S1_", w_data_size);
|
||||||
REPLACE_WITH_NUM(model, "_S2_", b_data_size);
|
REPLACE_WITH_NUM(model, "_S2_", b_data_size);
|
||||||
|
|
||||||
std::string impls;
|
std::string primitivesPriorityStr;
|
||||||
for (const auto& preferType : p.preferTypes) {
|
if (!p.defaultPrimitivesPriority) {
|
||||||
if (!impls.empty())
|
std::string impls;
|
||||||
impls += ",";
|
for (const auto& preferType : p.preferTypes) {
|
||||||
impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
|
if (!impls.empty())
|
||||||
|
impls += ",";
|
||||||
|
impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
|
||||||
|
}
|
||||||
|
primitivesPriorityStr = "PrimitivesPriority=\"" + impls + "\"";
|
||||||
}
|
}
|
||||||
REPLACE_WITH_STR(model, "_IMPLS_", impls);
|
REPLACE_WITH_STR(model, "_PRIM_PRIORITY_", primitivesPriorityStr);
|
||||||
|
|
||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
@ -263,6 +268,10 @@ protected:
|
|||||||
if (node->getType() == MKLDNNPlugin::Convolution) {
|
if (node->getType() == MKLDNNPlugin::Convolution) {
|
||||||
ASSERT_LE(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
|
ASSERT_LE(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
|
||||||
for (const auto prim : node->getSupportedPrimitiveDescriptors()) {
|
for (const auto prim : node->getSupportedPrimitiveDescriptors()) {
|
||||||
|
if (p.defaultPrimitivesPriority) {
|
||||||
|
if (prim.getImplementationType() & MKLDNNPlugin::impl_desc_type::gemm)
|
||||||
|
FAIL() << "There should be no gemm implementation in supportedPrimitiveDescriptors";
|
||||||
|
}
|
||||||
std::cout << MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(prim.getImplementationType()) << " ";
|
std::cout << MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(prim.getImplementationType()) << " ";
|
||||||
}
|
}
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
@ -335,44 +344,29 @@ TEST_P(MKLDNNGraphConvolutionTests, TestsConvolution) {}
|
|||||||
INSTANTIATE_TEST_CASE_P(
|
INSTANTIATE_TEST_CASE_P(
|
||||||
TestConvolution, MKLDNNGraphConvolutionTests,
|
TestConvolution, MKLDNNGraphConvolutionTests,
|
||||||
::testing::Values(
|
::testing::Values(
|
||||||
/*0*/ conv_test_params{{1, 9, 16, 32},
|
/*0*/ conv_test_params{{1, 9, 16, 32}, {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 6,
|
||||||
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 6, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1 },
|
MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1, false },
|
||||||
conv_test_params{{1, 9, 32, 16},
|
conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {1, 1}, {0, 2}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||||
{2, 4}, {1, 1}, {1, 1}, {0, 2}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
|
conv_test_params{{1, 9, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||||
conv_test_params{{1, 9, 32, 16},
|
conv_test_params{{1, 3, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||||
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
|
conv_test_params{{1, 1, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||||
conv_test_params{{1, 3, 40, 40},
|
conv_test_params{{1, 1, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||||
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
|
conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any, false,
|
||||||
conv_test_params{{1, 1, 40, 40},
|
|
||||||
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
|
|
||||||
conv_test_params{{1, 1, 32, 16},
|
|
||||||
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit },
|
|
||||||
conv_test_params{{1, 9, 32, 16},
|
|
||||||
{2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any,
|
|
||||||
{MKLDNNPlugin::impl_desc_type::ref_any} },
|
{MKLDNNPlugin::impl_desc_type::ref_any} },
|
||||||
conv_test_params{{1, 4, 54, 96},
|
conv_test_params{{1, 4, 54, 96}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, 64, 1, "", 3, MKLDNNPlugin::impl_desc_type::ref_any, false,
|
||||||
{3, 3}, {1, 1}, {1, 1}, {0, 0}, 64, 1, "", 3, MKLDNNPlugin::impl_desc_type::ref_any,
|
|
||||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd, MKLDNNPlugin::impl_desc_type::ref_any}},
|
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd, MKLDNNPlugin::impl_desc_type::ref_any}},
|
||||||
// 5D
|
// 5D
|
||||||
/*8*/ conv_test_params{{1, 3, 15, 20, 20},
|
/*8*/ conv_test_params{{1, 3, 15, 20, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false,
|
||||||
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any,
|
|
||||||
{MKLDNNPlugin::impl_desc_type::ref_any} },
|
{MKLDNNPlugin::impl_desc_type::ref_any} },
|
||||||
conv_test_params{{1, 24, 15, 20, 20},
|
conv_test_params{{1, 24, 15, 20, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false,
|
||||||
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any,
|
|
||||||
{MKLDNNPlugin::impl_desc_type::ref_any} },
|
{MKLDNNPlugin::impl_desc_type::ref_any} },
|
||||||
conv_test_params{{1, 32, 15, 20, 20},
|
conv_test_params{{1, 32, 15, 20, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false,
|
||||||
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any,
|
|
||||||
{MKLDNNPlugin::impl_desc_type::ref_any} },
|
{MKLDNNPlugin::impl_desc_type::ref_any} },
|
||||||
conv_test_params{{1, 3, 15, 25, 20},
|
conv_test_params{{1, 3, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||||
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit },
|
conv_test_params{{1, 24, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||||
conv_test_params{{1, 24, 15, 25, 20},
|
/*13*/ conv_test_params{{1, 32, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||||
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit },
|
conv_test_params{{1, 16, 30, 30, 10}, {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, false },
|
||||||
/*13*/ conv_test_params{{1, 32, 15, 25, 20},
|
conv_test_params{{1, 16, 30, 30, 10}, {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any, false,
|
||||||
{3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit },
|
|
||||||
/*20*/ conv_test_params{{1, 16, 30, 30, 10},
|
|
||||||
{5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit },
|
|
||||||
conv_test_params{{1, 16, 30, 30, 10},
|
|
||||||
{5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::ref_any,
|
|
||||||
{MKLDNNPlugin::impl_desc_type::ref_any} } ));
|
{MKLDNNPlugin::impl_desc_type::ref_any} } ));
|
||||||
|
|
||||||
#ifdef USE_MKL
|
#ifdef USE_MKL
|
||||||
@ -380,29 +374,45 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
MKLTestConvolution, MKLDNNGraphConvolutionTests,
|
MKLTestConvolution, MKLDNNGraphConvolutionTests,
|
||||||
::testing::Values(
|
::testing::Values(
|
||||||
conv_test_params{{1, 9, 16, 32},
|
conv_test_params{{1, 9, 16, 32},
|
||||||
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 6, MKLDNNPlugin::impl_desc_type::gemm,
|
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 6, MKLDNNPlugin::impl_desc_type::gemm, false,
|
||||||
{MKLDNNPlugin::impl_desc_type::gemm_any,
|
{MKLDNNPlugin::impl_desc_type::gemm_any,
|
||||||
MKLDNNPlugin::impl_desc_type::gemm_blas,
|
MKLDNNPlugin::impl_desc_type::gemm_blas,
|
||||||
MKLDNNPlugin::impl_desc_type::gemm_avx512,
|
MKLDNNPlugin::impl_desc_type::gemm_avx512,
|
||||||
MKLDNNPlugin::impl_desc_type::gemm_avx2,
|
MKLDNNPlugin::impl_desc_type::gemm_avx2,
|
||||||
MKLDNNPlugin::impl_desc_type::gemm_sse42} },
|
MKLDNNPlugin::impl_desc_type::gemm_sse42} },
|
||||||
conv_test_params{{1, 5, 15, 20, 20},
|
conv_test_params{{1, 5, 15, 20, 20},
|
||||||
{3, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
|
{3, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
|
||||||
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
|
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
|
||||||
conv_test_params{{1, 5, 15, 20, 20},
|
conv_test_params{{1, 5, 15, 20, 20},
|
||||||
{3, 3, 3}, {3, 2, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
|
{3, 3, 3}, {3, 2, 1}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
|
||||||
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
|
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
|
||||||
// conv_test_params{{1, 5, 15, 20, 20},
|
// conv_test_params{{1, 5, 15, 20, 20},
|
||||||
// {3, 3, 3}, {1, 1, 1}, {2, 2, 2}, {1, 1, 1}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
|
// {3, 3, 3}, {1, 1, 1}, {2, 2, 2}, {1, 1, 1}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
|
||||||
// {MKLDNNPlugin::impl_desc_type::gemm_blas} },
|
// {MKLDNNPlugin::impl_desc_type::gemm_blas} },
|
||||||
conv_test_params{{1, 16, 30, 30, 10},
|
conv_test_params{{1, 16, 30, 30, 10},
|
||||||
{5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
|
{5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
|
||||||
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
|
{MKLDNNPlugin::impl_desc_type::gemm_blas} },
|
||||||
conv_test_params{{1, 4, 16, 16, 16},
|
conv_test_params{{1, 4, 16, 16, 16},
|
||||||
{3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, 8, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas,
|
{3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, 8, 1, "", 2, MKLDNNPlugin::impl_desc_type::gemm_blas, false,
|
||||||
{MKLDNNPlugin::impl_desc_type::gemm_blas} } ));
|
{MKLDNNPlugin::impl_desc_type::gemm_blas} } ));
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(
|
||||||
|
TestConvolutionDefaultPrimitivesPriority, MKLDNNGraphConvolutionTests,
|
||||||
|
::testing::Values(
|
||||||
|
/*0*/ conv_test_params{{1, 9, 16, 32}, {1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 6,
|
||||||
|
MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1, true },
|
||||||
|
conv_test_params{{1, 9, 32, 16}, {2, 4}, {1, 1}, {1, 1}, {0, 2}, 17, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||||
|
conv_test_params{{1, 9, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||||
|
conv_test_params{{1, 3, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||||
|
conv_test_params{{1, 1, 40, 40}, {3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||||
|
conv_test_params{{1, 1, 32, 16}, {2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 4, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||||
|
// 5D
|
||||||
|
/*6*/ conv_test_params{{1, 3, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||||
|
conv_test_params{{1, 24, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||||
|
conv_test_params{{1, 32, 15, 25, 20}, {3, 3, 3}, {2, 2, 2}, {0, 0, 0}, {0, 0, 0}, 64, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true },
|
||||||
|
conv_test_params{{1, 16, 30, 30, 10}, {5, 5, 5}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, 16, 1, "", 2, MKLDNNPlugin::impl_desc_type::jit, true } ));
|
||||||
|
|
||||||
|
|
||||||
class MKLDNNGraphDynBatchConvolutionTests: public MKLDNNGraphConvolutionTests {
|
class MKLDNNGraphDynBatchConvolutionTests: public MKLDNNGraphConvolutionTests {
|
||||||
protected:
|
protected:
|
||||||
@ -490,31 +500,31 @@ INSTANTIATE_TEST_CASE_P(
|
|||||||
::testing::Values(
|
::testing::Values(
|
||||||
conv_test_params{{1, 8, 16, 32},
|
conv_test_params{{1, 8, 16, 32},
|
||||||
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 7, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1,
|
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "same_upper", 7, MKLDNNPlugin::impl_desc_type::jit | MKLDNNPlugin::impl_desc_type::_1x1,
|
||||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd}},
|
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd}},
|
||||||
conv_test_params{{1, 9, 32, 16},
|
conv_test_params{{1, 9, 32, 16},
|
||||||
{2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
{2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
||||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||||
conv_test_params{{1, 9, 32, 16},
|
conv_test_params{{1, 9, 32, 16},
|
||||||
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
||||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||||
conv_test_params{{1, 3, 40, 40},
|
conv_test_params{{1, 3, 40, 40},
|
||||||
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
||||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||||
conv_test_params{{1, 1, 40, 40},
|
conv_test_params{{1, 1, 40, 40},
|
||||||
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
{3, 3}, {1, 2}, {0, 0}, {0, 0}, 20, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
||||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||||
conv_test_params{{1, 1, 32, 16},
|
conv_test_params{{1, 1, 32, 16},
|
||||||
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
{2, 4}, {2, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::jit,
|
||||||
{MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
false, {MKLDNNPlugin::impl_desc_type::jit_avx512_winograd} },
|
||||||
conv_test_params{{1, 9, 32, 16},
|
conv_test_params{{1, 9, 32, 16},
|
||||||
{2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any,
|
{2, 4}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 5, MKLDNNPlugin::impl_desc_type::ref_any,
|
||||||
{MKLDNNPlugin::impl_desc_type::ref_any} } ));
|
false, {MKLDNNPlugin::impl_desc_type::ref_any} } ));
|
||||||
#ifdef USE_MKL
|
#ifdef USE_MKL
|
||||||
INSTANTIATE_TEST_CASE_P(
|
INSTANTIATE_TEST_CASE_P(
|
||||||
MKLTestDynBatchConvolution, MKLDNNGraphDynBatchConvolutionTests,
|
MKLTestDynBatchConvolution, MKLDNNGraphDynBatchConvolutionTests,
|
||||||
::testing::Values(
|
::testing::Values(
|
||||||
conv_test_params{{1, 9, 16, 32},
|
conv_test_params{{1, 9, 16, 32},
|
||||||
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 7, MKLDNNPlugin::impl_desc_type::gemm,
|
{1, 1}, {1, 1}, {0, 0}, {0, 0}, 17, 1, "", 7, MKLDNNPlugin::impl_desc_type::gemm, false,
|
||||||
{MKLDNNPlugin::impl_desc_type::gemm_any,
|
{MKLDNNPlugin::impl_desc_type::gemm_any,
|
||||||
MKLDNNPlugin::impl_desc_type::gemm_blas,
|
MKLDNNPlugin::impl_desc_type::gemm_blas,
|
||||||
MKLDNNPlugin::impl_desc_type::gemm_avx512,
|
MKLDNNPlugin::impl_desc_type::gemm_avx512,
|
||||||
|
Loading…
Reference in New Issue
Block a user