From 1aa58b4c7dec13dfd6570ca288bd1d0232843187 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Wed, 28 Jul 2021 16:11:41 +0300 Subject: [PATCH] [CPU] Enabled MatMul+Transpose transformations and reduced MatMul inference overheads (#6570) --- .../src/mkldnn_plugin/mkldnn_plugin.cpp | 2 + .../nodes/mkldnn_matmul_node.cpp | 94 +++++++++++-------- .../mkldnn_plugin/nodes/mkldnn_matmul_node.h | 33 ++++++- .../single_layer_tests/mat_mul.cpp | 4 +- 4 files changed, 91 insertions(+), 42 deletions(-) diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp index 59a29ebf40a..2d7299aed92 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp @@ -58,6 +58,7 @@ #include #include #include +#include #include #include #include @@ -167,6 +168,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.register_pass(); if (useLpt) { diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.cpp b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.cpp index b7f2c0a4277..3ad3e9aef55 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.cpp +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.cpp @@ -55,8 +55,8 @@ MKLDNNMatMulNode::MKLDNNMatMulNode(const std::shared_ptr& op, cons errorPrefix = "Gemm node with name '" + getName() + "'"; const auto matMul = std::dynamic_pointer_cast(op); - alpha = 1; - beta = 1; + alpha = 1.f; + beta = 0.f; transposeA = matMul->get_transpose_a(); transposeB = matMul->get_transpose_b(); } else { @@ -179,6 +179,34 @@ void MKLDNNMatMulNode::createPrimitive() { IE_THROW() << errorPrefix << " did not allocate input memory"; if (getSelectedPrimitiveDescriptor() == nullptr) IE_THROW() << errorPrefix << " did not set preferable primitive descriptor"; + + auto inDims0 = src0MemPtr->GetDims(); + auto outDims = dstMemPtr->GetDims(); + + params.src0_mem_ptr = src0MemPtr; + params.src1_mem_ptr = src1MemPtr; + params.dst_mem_ptr = dstMemPtr; + + params.ndims = outDims.size(); + + params.MB1 = 1; + params.MB2 = outDims.size() > 3 ? outDims[params.ndims - 3] : 1; + + params.M = outDims[yAxis]; + params.N = outDims[xAxis]; + params.K = transposeA ? inDims0[yAxis] : inDims0[xAxis]; + + params.transa = transposeA ? 'T' : 'N'; + params.transb = transposeB ? 'T' : 'N'; + + params.lda = transposeA ? params.M : params.K; + params.ldb = transposeB ? params.K : params.N; + params.ldc = params.N; + + params.shift1 = params.M * params.N * params.MB2; + params.shift2 = params.M * params.N; + + runtimePrecision = getParentEdgeAt(0)->getDesc().getPrecision(); } inline void process_gemm(char transa, char transb, int M, int N, int K, float alpha, const float *A, int lda, @@ -212,67 +240,57 @@ inline void process_gemm(char transa, char transb, int M, int N, int K, float al } template -void MKLDNNMatMulNode::process_data() { - auto inDims0 = getParentEdgeAt(0)->getDims(); - auto inDims1 = getParentEdgeAt(1)->getDims(); - auto outDims = getChildEdgeAt(0)->getDims(); +inline void MKLDNNMatMulNode::process_data() { + const T0* src0_ptr = reinterpret_cast(params.src0_mem_ptr->GetPtr()); + const T1* src1_ptr = reinterpret_cast(params.src1_mem_ptr->GetPtr()); + float* dst_ptr = reinterpret_cast(params.dst_mem_ptr->GetPtr()); - auto& srcMemory0 = getParentEdgeAt(0)->getMemory(); - auto& srcMemory1 = getParentEdgeAt(1)->getMemory(); - auto& dstMemory0 = getChildEdgeAt(0)->getMemory(); + const int MB = batchToProcess(); + if (params.ndims == 4) { + params.MB1 = MB; + } else if (params.ndims == 3) { + params.shift1 = params.shift1 * MB / params.MB2; + params.MB2 = MB; + } - const T0 *src0_ptr = reinterpret_cast(srcMemory0.GetPtr()); - const T1 *src1_ptr = reinterpret_cast(srcMemory1.GetData()); - float *dst_ptr = reinterpret_cast(dstMemory0.GetData()); - - int MB1 = outDims.ndims() == 4 ? batchToProcess() : 1; - int MB2 = outDims.ndims() == 3 ? batchToProcess() : outDims.ndims() > 3 ? outDims[outDims.ndims() - 3] : 1; - int M = outDims[yAxis]; - int N = outDims[xAxis]; - int K = transposeA ? inDims0[yAxis] : inDims0[xAxis]; - - const char transa = transposeA ? 'T' : 'N'; - const char transb = transposeB ? 'T' : 'N'; - - int lda = transposeA ? M : K; - int ldb = transposeB ? K : N; - int ldc = N; - - beta = 0.f; - - for (int b1 = 0; b1 < MB1; b1++) { + for (int b1 = 0; b1 < params.MB1; ++b1) { const T0 *a_ptr = src0_ptr; const T1 *b_ptr = src1_ptr; float *d_ptr = dst_ptr; - for (int b2 = 0; b2 < MB2; b2++) { - process_gemm(transa, transb, M, N, K, alpha, a_ptr, lda, b_ptr, ldb, beta, d_ptr, ldc); + for (int b2 = 0; b2 < params.MB2; ++b2) { + process_gemm(params.transa, params.transb, params.M, params.N, params.K, + alpha, a_ptr, params.lda, b_ptr, params.ldb, beta, d_ptr, params.ldc); a_ptr += aOffsets[0]; b_ptr += bOffsets[0]; - d_ptr += M * N; + d_ptr += params.shift2; } src0_ptr += aOffsets[1]; src1_ptr += bOffsets[1]; - dst_ptr += MB2 * M * N; + dst_ptr += params.shift1; } } void MKLDNNMatMulNode::execute(mkldnn::stream strm) { - switch (getParentEdgeAt(0)->getDesc().getPrecision()) { - case Precision::FP32: + switch (runtimePrecision) { + case Precision::FP32: { process_data(); break; - case Precision::BF16: + } + case Precision::BF16: { process_data(); break; - case Precision::I8: + } + case Precision::I8: { process_data(); break; - case Precision::U8: + } + case Precision::U8: { process_data(); break; + } default: IE_THROW() << errorPrefix << " has incorrect precision on first input"; } diff --git a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.h b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.h index 6196665aabc..3f056cc9953 100644 --- a/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.h +++ b/inference-engine/src/mkldnn_plugin/nodes/mkldnn_matmul_node.h @@ -28,8 +28,8 @@ public: static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; private: - float alpha = 1.0f; - float beta = 1.0f; + float alpha = 1.f; + float beta = 0.f; bool transposeA = false; bool transposeB = false; @@ -40,9 +40,36 @@ private: std::vector bOffsets; std::vector cOffsets; - template void process_data(); + InferenceEngine::Precision runtimePrecision; + + template inline void process_data(); std::string errorPrefix; + + struct { + MKLDNNMemoryPtr src0_mem_ptr = nullptr; + MKLDNNMemoryPtr src1_mem_ptr = nullptr; + MKLDNNMemoryPtr dst_mem_ptr = nullptr; + + char transa = 'N'; + char transb = 'N'; + + int MB1 = 1; + int MB2 = 1; + + int M = 0; + int N = 0; + int K = 0; + + int lda = 0; + int ldb = 0; + int ldc = 0; + + int shift1 = 0; + int shift2 = 0; + + size_t ndims = 0; + } params; }; } // namespace MKLDNNPlugin diff --git a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/mat_mul.cpp b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/mat_mul.cpp index 680276b2f72..3241ebef007 100644 --- a/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/mat_mul.cpp +++ b/inference-engine/tests/functional/plugin/cpu/shared_tests_instances/single_layer_tests/mat_mul.cpp @@ -18,6 +18,8 @@ const std::vector shapeRelatedParams = { { { {1, 4, 5, 6}, false }, { {1, 4, 6, 4}, false } }, { { {4, 5, 6}, false }, { {6, 3}, false } }, { { {9, 9, 9}, false }, { {9, 9}, false } }, + { { {1, 2, 3}, false }, { {1, 10, 3}, true } }, + { { {1, 2, 3}, false }, { {1, 3, 10}, false } }, { { {1, 2, 3}, false }, { {1, 1, 3, 2}, false } }, { { {1, 3, 2, 4}, false }, { {2, 1, 4, 2}, false } }, { { {2, 1, 2, 4}, false }, { {1, 3, 4, 2}, false } }, @@ -30,7 +32,7 @@ const std::vector shapeRelatedParams = { { { {2, 2, 1, 3}, false }, { {3}, false } }, { { {1, 5}, false }, { {5, 1}, false } }, { { {5, 1}, true }, { {5, 1}, false } }, - { { {1, 5}, false }, { {1, 5}, true } }, + { { {1, 5}, false }, { {10, 5}, true } }, { { {1, 5}, false }, { {5}, false } }, { { {5}, false }, { {5, 1}, false } }, { { {5}, false }, { {5}, false } },