[CPU][WA] Supported 3D layout for FullyConnected primitive

Extended jit uni depthwise primitive to support 3D inputs
Author: dmitry-gorokhov, 2020-05-29 16:53:59 +03:00 (committed by Alexander Peskov)
Parent: b6f2c06b26
Commit: bcd38100db
5 changed files with 132 additions and 34 deletions
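In practice this means a FullyConnected over a rank-3 input (e.g. a tnc/ntc sequence tensor) keeps 2D weights of shape {out_num, inDims[2]}, with the two leading dimensions acting as a combined batch. A conceptual sketch of that semantics, assuming row-major layout and omitting bias (fc3d is an illustrative name, not plugin code):

#include <cstddef>
#include <vector>

// Conceptual reference for a 3D FullyConnected: input [D0, D1, C] is treated
// as (D0*D1) independent rows multiplied by weights [OC, C], giving [D0, D1, OC].
std::vector<float> fc3d(const std::vector<float>& src, size_t D0, size_t D1, size_t C,
                        const std::vector<float>& weights, size_t OC) {
    std::vector<float> dst(D0 * D1 * OC, 0.f);
    for (size_t row = 0; row < D0 * D1; row++)        // combined batch of D0*D1 rows
        for (size_t oc = 0; oc < OC; oc++)
            for (size_t c = 0; c < C; c++)            // inner dim feeds the weights
                dst[row * OC + oc] += src[row * C + c] * weights[oc * C + c];
    return dst;
}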


@@ -780,27 +780,67 @@ void MKLDNNGraphOptimizer::FuseFullyConnectedAndSimpleOperation(MKLDNNGraph &gra
                node->getChildEdges().size() == 1;
     };

-    auto isSutableChildNode = [&](MKLDNNNodePtr node) {
-        if (!node->getCnnLayer())
+    auto isSutableChildNode = [&](MKLDNNNodePtr parentNode, MKLDNNNodePtr childNode) {
+        if (!childNode->getCnnLayer())
             return false;

-        if (node->getType() == Quantize) {
-            auto* quantizeNode = dynamic_cast<MKLDNNQuantizeNode*>(node.get());
+        if (childNode->getType() == Quantize) {
+            auto* quantizeNode = dynamic_cast<MKLDNNQuantizeNode*>(childNode.get());
             if (quantizeNode == nullptr)
-                THROW_IE_EXCEPTION << "Cannot get quantize layer " << node->getName();
-            return !quantizeNode->isBinarization();
-        } else if (node->getType() == Depthwise) {
-            auto* depthwiseNode = dynamic_cast<MKLDNNDepthwiseNode*>(node.get());
+                THROW_IE_EXCEPTION << "Cannot get quantize layer " << childNode->getName();
+
+            if (parentNode->getParentEdgesAtPort(0)[0]->getDims().ndims() != 3) {
+                return !quantizeNode->isBinarization();
+            } else {
+                return (quantizeNode->isInputLowBroadcast() && quantizeNode->isInputHighBroadcast() &&
+                        quantizeNode->isOutputLowBroadcast() && quantizeNode->isOutputHighBroadcast() &&
+                        !quantizeNode->isBinarization());
+            }
+        } else if (childNode->getType() == Depthwise) {
+            auto* depthwiseNode = dynamic_cast<MKLDNNDepthwiseNode*>(childNode.get());
             if (depthwiseNode == nullptr)
-                THROW_IE_EXCEPTION << "Cannot get depthwise layer " << node->getName();
-            return ((depthwiseNode->getAlgorithm() == mkldnn::algorithm::depthwise_scale_shift && depthwiseNode->isWithBiases()) ||
-                    (depthwiseNode->getAlgorithm() == mkldnn::algorithm::depthwise_prelu));
-        } else if (node->getType() == Activation) {
-            auto* activationNode = dynamic_cast<MKLDNNActivationNode*>(node.get());
+                THROW_IE_EXCEPTION << "Cannot get depthwise layer " << childNode->getName();
+
+            if (parentNode->getParentEdgesAtPort(0)[0]->getDims().ndims() != 3) {
+                return ((depthwiseNode->getAlgorithm() == mkldnn::algorithm::depthwise_scale_shift &&
+                         depthwiseNode->isWithBiases()) ||
+                        (depthwiseNode->getAlgorithm() == mkldnn::algorithm::depthwise_prelu));
+            } else {
+                const auto &depthwiseLayer = depthwiseNode->getCnnLayer();
+                if (depthwiseLayer == nullptr)
+                    THROW_IE_EXCEPTION << "Cannot get scale shift layer " << depthwiseNode->getName();
+
+                if (depthwiseNode->getAlgorithm() != mkldnn::algorithm::depthwise_scale_shift)
+                    return false;
+
+                Blob::Ptr scalesBlob = depthwiseLayer->blobs["weights"];
+                if (scalesBlob == nullptr)
+                    return false;
+
+                Blob::Ptr shiftsBlob = depthwiseLayer->blobs["biases"];
+                if (shiftsBlob == nullptr)
+                    return false;
+
+                const float* scalesBufferPtr = scalesBlob->buffer().as<float*>();
+                const float* shiftsBufferPtr = shiftsBlob->buffer().as<float*>();
+
+                if (scalesBlob->size() != shiftsBlob->size())
+                    return false;
+
+                for (int i = 1; i < scalesBlob->size(); i++)
+                    if (scalesBufferPtr[0] != scalesBufferPtr[i])
+                        return false;
+
+                for (int i = 1; i < shiftsBlob->size(); i++)
+                    if (shiftsBufferPtr[0] != shiftsBufferPtr[i])
+                        return false;
+
+                return true;
+            }
+        } else if (childNode->getType() == Activation) {
+            auto* activationNode = dynamic_cast<MKLDNNActivationNode*>(childNode.get());
             if (activationNode == nullptr)
-                THROW_IE_EXCEPTION << "Cannot get activation layer " << node->getName();
+                THROW_IE_EXCEPTION << "Cannot get activation layer " << childNode->getName();
             return isOneOf(activationNode->getAlgorithm(), {eltwise_relu, eltwise_gelu, eltwise_elu, eltwise_logistic, eltwise_bounded_relu, eltwise_clamp});
         }
@@ -817,7 +857,7 @@ void MKLDNNGraphOptimizer::FuseFullyConnectedAndSimpleOperation(MKLDNNGraph &gra
         }

         auto childNode = parentNode->getChildEdgeAt(0)->getChild();
-        if (!isSutableChildNode(childNode)) {
+        if (!isSutableChildNode(parentNode, childNode)) {
             parent++;
             continue;
         }
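With a 3D parent, fusion becomes conditional: a Quantize child must have all four of its low/high ranges broadcast, and a Depthwise child must be a ScaleShift whose scales and shifts are each a single repeated value. A minimal standalone sketch of that uniformity test (isEffectivelyScalar is a hypothetical name; the patch inlines the equivalent loops above):

#include <cstddef>

// Returns true when every element equals the first one, i.e. the blob is
// effectively a single broadcast value and safe to fuse in the 3D path.
static bool isEffectivelyScalar(const float* data, size_t size) {
    for (size_t i = 1; i < size; i++)
        if (data[i] != data[0])
            return false;
    return true;   // also true for size == 0, matching the patch behavior
}

// Example: {0.5f, 0.5f, 0.5f} -> true (fusion allowed),
//          {0.5f, 0.25f, 0.5f} -> false (fusion is skipped).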


@@ -103,7 +103,9 @@ void MKLDNNFullyConnectedNode::getSupportedDescriptors() {
     MKLDNNDims inDims(fcLayer->input()->getDims());
     if (inDims.ndims() == 2) {
-        weightsDims = {fcLayer->_out_num, static_cast<size_t>(inDims.size(1))};
+        weightsDims = {fcLayer->_out_num, static_cast<size_t>(inDims[1])};
+    } else if (inDims.ndims() == 3) {
+        weightsDims = {fcLayer->_out_num, static_cast<size_t>(inDims[2])};
     } else if (inDims.ndims() == 4) {
         weightsDims = {fcLayer->_out_num, static_cast<size_t>(inDims[1]), static_cast<size_t>(inDims[2]),
                        static_cast<size_t>(inDims[3])};
@@ -196,7 +198,8 @@ void MKLDNNFullyConnectedNode::setPostOps(mkldnn::primitive_attr &attr, bool ini
         if (depthwiseNode) {
             if (initWeights) {
                 auto* depthwiseLayer = reinterpret_cast<WeightableLayer*>(depthwiseNode->getCnnLayer().get());
-                MKLDNNDims depthwiseDims({static_cast<ptrdiff_t>(rnd_up(getChildEdgeAt(0)->getDims()[1], 16))});
+                int ndims = getParentEdgeAt(0)->getDims().ndims();
+                MKLDNNDims depthwiseDims({static_cast<ptrdiff_t>(rnd_up(ndims == 3 ? getChildEdgeAt(0)->getDims()[2] : getChildEdgeAt(0)->getDims()[1], 16))});

                 PostOpsIntBlobMemory.push_back(MKLDNNMemoryPtr(new MKLDNNMemory(getEngine())));
                 PostOpsIntBlobMemory[blob_idx]->Create(depthwiseDims, memory::data_type::f32, memory::format::x);
@@ -206,7 +209,7 @@ void MKLDNNFullyConnectedNode::setPostOps(mkldnn::primitive_attr &attr, bool ini
                                                              depthwiseLayer->_weights->size() *
                                                              MKLDNNExtensionUtils::sizeOfDataType(memory::data_type::f32));

-                if (depthwiseNode->isBroadcast()) {
+                if (depthwiseNode->isBroadcast() || ndims == 3) {
                     float broadcastValue = static_cast<float *>(PostOpsIntBlobMemory[blob_idx]->GetData())[0];
                     for (int i = 1; i < PostOpsIntBlobMemory[blob_idx]->GetPrimitiveDescriptor().desc().data.dims[0]; i++) {
                         static_cast<float *>(PostOpsIntBlobMemory[blob_idx]->GetData())[i] = broadcastValue;
@@ -222,7 +225,7 @@ void MKLDNNFullyConnectedNode::setPostOps(mkldnn::primitive_attr &attr, bool ini
                                                              depthwiseLayer->_biases->size() *
                                                              MKLDNNExtensionUtils::sizeOfDataType(memory::data_type::f32));

-                if (depthwiseNode->isBroadcast()) {
+                if (depthwiseNode->isBroadcast() || ndims == 3) {
                     float broadcastValue = static_cast<float *>(PostOpsIntBlobMemory[blob_idx + 1]->GetData())[0];
                     for (int i = 1; i < PostOpsIntBlobMemory[blob_idx + 1]->GetPrimitiveDescriptor().desc().data.dims[0]; i++) {
                         static_cast<float *>(PostOpsIntBlobMemory[blob_idx + 1]->GetData())[i] = broadcastValue;
@@ -270,6 +273,8 @@ memory::format MKLDNNFullyConnectedNode::weightsFormatForSrcFormat(memory::forma
         case memory::format::x:
             return memory::format::x;
         case memory::format::nc:
+        case memory::format::tnc:
+        case memory::format::ntc:
             return memory::format::oi;
         case memory::format::nchw:
             return memory::format::oihw;
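The depthwise post-op buffer is padded to a multiple of 16 channels via rnd_up so the jit kernel can work on whole vector lanes; in the 3D case the channel count is read from dims()[2] instead of dims()[1]. A minimal sketch of the assumed rnd_up arithmetic (the real helper lives in the plugin utilities):

#include <cstddef>

// Assumed semantics: round value up to the next multiple of factor.
inline ptrdiff_t rnd_up(ptrdiff_t value, ptrdiff_t factor) {
    return ((value + factor - 1) / factor) * factor;
}

// Example: a 3D FC output with 19 channels gets a padded post-op buffer of
// rnd_up(19, 16) == 32 floats; the tail is filled with the broadcast value.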


@@ -163,6 +163,39 @@ void MKLDNNQuantizeNode::init() {
                 binarizationOutputMask[i] = outputHighData[isOutputHighBroadcasted ? 0 : i] == 1.f ? 0xffffffff : 0x00000000;
             }
         } else {
+            auto allElementsAreEqual = [&](const float* data, size_t size) {
+                if (size == 0)
+                    return true;
+
+                auto first = data[0];
+                for (int i = 1; i < size; i++) {
+                    if (data[i] != first)
+                        return false;
+                }
+
+                return true;
+            };
+
+            if (allElementsAreEqual(inputLowData, inputLowAxisSize)) {
+                inputLowAxisSize = 1;
+                isInputLowBroadcasted = true;
+            }
+
+            if (allElementsAreEqual(inputHighData, inputHighAxisSize)) {
+                inputHighAxisSize = 1;
+                isInputHighBroadcasted = true;
+            }
+
+            if (allElementsAreEqual(outputLowData, outputLowAxisSize)) {
+                outputLowAxisSize = 1;
+                isOutputLowBroadcasted = true;
+            }
+
+            if (allElementsAreEqual(outputHighData, outputHighAxisSize)) {
+                outputHighAxisSize = 1;
+                isOutputHighBroadcasted = true;
+            }
+
             cropLow.resize(inputLowAxisSize);
             cropHigh.resize(inputHighAxisSize);
             inputScale.resize(std::max(inputLowAxisSize, inputHighAxisSize));
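Collapsing uniform per-channel ranges to a single broadcast value is what lets the optimizer's isInputLowBroadcast()/isOutputHighBroadcast() checks above succeed for 3D FullyConnected fusion. A standalone illustration with assumed sample values:

#include <cassert>
#include <cstddef>

int main() {
    float inputLow[] = {-1.f, -1.f, -1.f, -1.f};   // per-channel, but uniform
    size_t inputLowAxisSize = 4;
    bool isInputLowBroadcasted = false;

    auto allElementsAreEqual = [](const float* data, size_t size) {
        for (size_t i = 1; i < size; i++)
            if (data[i] != data[0])
                return false;
        return true;
    };

    if (allElementsAreEqual(inputLow, inputLowAxisSize)) {
        inputLowAxisSize = 1;          // keep only the broadcast value
        isInputLowBroadcasted = true;  // the fusion check in the optimizer now passes
    }
    assert(inputLowAxisSize == 1 && isInputLowBroadcasted);
    return 0;
}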


@@ -19,7 +19,7 @@ using std::function;
 struct depthwise_test_params {
     algorithm alg;

-    // Formats: NC, NCHW, NCDHW
+    // Formats: NC, CHW (actually NCH), NCHW, NCDHW
     vector<size_t> dims;

     bool isBroadcast;
@@ -40,8 +40,9 @@ void ref_depthwise(const InferenceEngine::TBlob<data_t> &src, const data_t *weig
     size_t MB = src.getTensorDesc().getDims()[0];
     size_t IC = src.getTensorDesc().getDims()[1];
     size_t ID = dims_size == 5 ? src.getTensorDesc().getDims()[2] : 1u;
-    size_t IH = dims_size == 2 ? 1 : src.getTensorDesc().getDims()[dims_size - 2];
-    size_t IW = dims_size == 2 ? 1 : src.getTensorDesc().getDims()[dims_size - 1];
+    size_t IH = dims_size < 3 ? 1 : dims_size == 3 ? src.getTensorDesc().getDims()[dims_size - 1]
+                                                   : src.getTensorDesc().getDims()[dims_size - 2];
+    size_t IW = dims_size < 4 ? 1 : src.getTensorDesc().getDims()[dims_size - 1];

     const data_t *src_data = src.readOnly();
     const data_t *weights_data = weights;
@@ -129,25 +130,22 @@ protected:
         std::string model = model_t;
         auto dims_size = p.dims.size();

-        if (dims_size == 4) {
+        if (dims_size < 5)
             REMOVE_LINE(model, "<dim>_ID_</dim>");
-        } else if (dims_size == 2) {
-            REMOVE_LINE(model, "<dim>_ID_</dim>");
-            REMOVE_LINE(model, "<dim>_IH_</dim>");
+        if (dims_size < 4)
             REMOVE_LINE(model, "<dim>_IW_</dim>");
-        }
+        if (dims_size < 3)
+            REMOVE_LINE(model, "<dim>_IH_</dim>");

         REPLACE_WITH_NUM(model, "_IN_", p.dims[0]);
         REPLACE_WITH_NUM(model, "_IC_", p.dims[1]);
-        if (dims_size > 2) {
+        if (dims_size > 2)
+            REPLACE_WITH_NUM(model, "_IH_", dims_size == 3 ? p.dims[dims_size - 1] : p.dims[dims_size - 2]);
+        if (dims_size > 3)
             REPLACE_WITH_NUM(model, "_IW_", p.dims[dims_size - 1]);
-            REPLACE_WITH_NUM(model, "_IH_", p.dims[dims_size - 2]);
-        }
-        if (dims_size > 4) {
+        if (dims_size > 4)
             REPLACE_WITH_NUM(model, "_ID_", p.dims[dims_size - 3]);
-        }

         if (p.alg == depthwise_scale_shift) {
             REPLACE_WITH_STR(model, "_LT_", "ScaleShift");
@@ -214,6 +212,8 @@ protected:
         InferenceEngine::Layout layout = InferenceEngine::ANY;
         switch (p.dims.size()) {
             case 2: layout = InferenceEngine::NC; break;
+            // InferenceEngine::Layout doesn't have an alias for the 3D NCH layout, so we use CHW instead
+            case 3: layout = InferenceEngine::CHW; break;
             case 4: layout = InferenceEngine::NCHW; break;
             case 5: layout = InferenceEngine::NCDHW; break;
         }
@@ -317,6 +317,26 @@ INSTANTIATE_TEST_CASE_P(
                 depthwise_test_params{depthwise_prelu, {4, 4, 4, 10, 10}, true, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref_any}}
         ));

+INSTANTIATE_TEST_CASE_P(
+        TestsDepthwise3D, MKLDNNGraphDepthwiseTests,
+        ::testing::Values(
+                depthwise_test_params{depthwise_scale_shift, {1, 32, 16}, false, num_2d_impl, jit},
+                depthwise_test_params{depthwise_scale_shift, {8, 32, 16}, false, num_2d_impl, jit},
+                depthwise_test_params{depthwise_scale_shift, {4, 3, 2}, true, num_2d_impl, jit},
+                depthwise_test_params{depthwise_scale_shift, {1, 1, 1}, false, num_2d_impl, jit},
+                depthwise_test_params{depthwise_scale_shift, {37, 35, 17}, false, num_2d_impl, jit},
+                depthwise_test_params{depthwise_prelu, {128, 32, 19}, false, num_2d_impl, jit},
+                depthwise_test_params{depthwise_prelu, {4, 3, 2}, true, num_2d_impl, jit},
+                depthwise_test_params{depthwise_prelu, {1, 1, 1}, false, num_2d_impl, jit},
+                depthwise_test_params{depthwise_prelu, {37, 35, 17}, false, num_2d_impl, jit},
+                depthwise_test_params{depthwise_scale_shift, {128, 32, 19}, false, num_2d_impl, ref, {ref_any}},
+                depthwise_test_params{depthwise_scale_shift, {4, 3, 2}, true, num_2d_impl, ref, {ref_any}},
+                depthwise_test_params{depthwise_scale_shift, {1, 1, 1}, false, num_2d_impl, ref, {ref_any}},
+                depthwise_test_params{depthwise_prelu, {128, 32, 17}, false, num_2d_impl, ref, {ref_any}},
+                depthwise_test_params{depthwise_prelu, {4, 3, 19}, true, num_2d_impl, ref, {ref_any}},
+                depthwise_test_params{depthwise_prelu, {1, 1, 1}, false, num_2d_impl, ref, {ref_any}}
+        ));
+
 class MKLDNNGraphDynBatchDepthwiseTests: public MKLDNNGraphDepthwiseTests {
 protected:
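For 3D inputs the test suite treats the tensor as N x C x H with W fixed at 1 (reported as CHW because InferenceEngine::Layout has no NCH alias). A worked example of the IH/IW resolution, using the {8, 32, 16} instance above:

#include <cstddef>
#include <cstdio>
#include <vector>

// Mirrors the ternaries in ref_depthwise for a 3D case; dims are sample values.
int main() {
    std::vector<size_t> dims = {8, 32, 16};           // N, C, H
    size_t n = dims.size();

    size_t MB = dims[0];                              // 8
    size_t IC = dims[1];                              // 32
    size_t ID = n == 5 ? dims[2] : 1u;                // 1 (no depth axis)
    size_t IH = n < 3 ? 1 : n == 3 ? dims[n - 1]      // 16 (last axis is H in 3D)
                                   : dims[n - 2];
    size_t IW = n < 4 ? 1 : dims[n - 1];              // 1 (no width axis)

    std::printf("MB=%zu IC=%zu ID=%zu IH=%zu IW=%zu\n", MB, IC, ID, IH, IW);
    return 0;
}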

The thirdparty mkl-dnn submodule is bumped to pick up the extended jit uni depthwise primitive (3D input support) referenced in the commit message:

@@ -1 +1 @@
-Subproject commit 32bafef642e246f3f3c333bbe4b021a744bee2d9
+Subproject commit afabfc9d4975a13d2229a64ee782260d43d148ef