[CPU] Enabled MatMul+Transpose transformations and reduced MatMul inference overheads (#6570)

Alexandra Sidorova 2021-07-28 16:11:41 +03:00 committed by GitHub
parent 1471095bdb
commit 1aa58b4c7d
4 changed files with 91 additions and 42 deletions

View File

@@ -58,6 +58,7 @@
 #include <transformations/op_conversions/convert_previous_nms_to_nms_5.hpp>
 #include <transformations/op_conversions/convert_nms_to_nms_ie_internal.hpp>
 #include <transformations/op_conversions/convert_deformable_conv_v8_to_v1.hpp>
+#include <transformations/smart_reshape/matmul_sr.hpp>
 #include <transformations/convert_precision.hpp>
 #include <transformations/init_node_info.hpp>
 #include <transformations/rt_info/fused_names_attribute.hpp>
@@ -167,6 +168,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
     manager.register_pass<ngraph::pass::ConvertNMS3ToNMS5>();
     manager.register_pass<ngraph::pass::ConvertNMS4ToNMS5>();
     manager.register_pass<ngraph::pass::ConvertNMSToNMSIEInternal>();
+    manager.register_pass<ngraph::pass::TransposeMatMul>();
     manager.register_pass<ngraph::pass::ConstantFolding>();
 
     if (useLpt) {
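For context: ngraph::pass::TransposeMatMul (declared in matmul_sr.hpp, included above) folds a Transpose feeding a MatMul input into the MatMul's transpose_a/transpose_b attribute, so the CPU plugin sees a single MatMul with transpose flags instead of a separate Transpose node. Below is a minimal sketch of exercising the pass on a standalone ngraph::Function; the helper names and shapes are illustrative (the shapes mirror one of the new test cases added in this commit), not taken from the plugin code:

#include <memory>

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/smart_reshape/matmul_sr.hpp>

// Build MatMul(A, Transpose(B)) with a last-two-dims transpose order, which
// is the pattern TransposeMatMul can fold into transpose_b (illustrative).
std::shared_ptr<ngraph::Function> make_example() {
    auto A = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3});
    auto B = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 10, 3});
    auto order = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
    auto transpose = std::make_shared<ngraph::opset1::Transpose>(B, order);
    auto matmul = std::make_shared<ngraph::opset1::MatMul>(A, transpose, false, false);
    return std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{A, B});
}

void run_pass(const std::shared_ptr<ngraph::Function>& f) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::TransposeMatMul>();
    manager.run_passes(f);  // if the pattern matches, Transpose is gone and transpose_b == true
}

After run_passes, the transposition is expressed through the MatMul attribute, which is what lets the MKLDNN MatMul node below handle it inside GEMM.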

View File

@@ -55,8 +55,8 @@ MKLDNNMatMulNode::MKLDNNMatMulNode(const std::shared_ptr<ngraph::Node>& op, cons
         errorPrefix = "Gemm node with name '" + getName() + "'";
         const auto matMul = std::dynamic_pointer_cast<const ngraph::opset1::MatMul>(op);
 
-        alpha = 1;
-        beta = 1;
+        alpha = 1.f;
+        beta = 0.f;
         transposeA = matMul->get_transpose_a();
         transposeB = matMul->get_transpose_b();
     } else {
@@ -179,6 +179,34 @@ void MKLDNNMatMulNode::createPrimitive() {
         IE_THROW() << errorPrefix << " did not allocate input memory";
     if (getSelectedPrimitiveDescriptor() == nullptr)
         IE_THROW() << errorPrefix << " did not set preferable primitive descriptor";
+
+    auto inDims0 = src0MemPtr->GetDims();
+    auto outDims = dstMemPtr->GetDims();
+
+    params.src0_mem_ptr = src0MemPtr;
+    params.src1_mem_ptr = src1MemPtr;
+    params.dst_mem_ptr = dstMemPtr;
+
+    params.ndims = outDims.size();
+
+    params.MB1 = 1;
+    params.MB2 = outDims.size() > 3 ? outDims[params.ndims - 3] : 1;
+
+    params.M = outDims[yAxis];
+    params.N = outDims[xAxis];
+    params.K = transposeA ? inDims0[yAxis] : inDims0[xAxis];
+
+    params.transa = transposeA ? 'T' : 'N';
+    params.transb = transposeB ? 'T' : 'N';
+
+    params.lda = transposeA ? params.M : params.K;
+    params.ldb = transposeB ? params.K : params.N;
+    params.ldc = params.N;
+
+    params.shift1 = params.M * params.N * params.MB2;
+    params.shift2 = params.M * params.N;
+
+    runtimePrecision = getParentEdgeAt(0)->getDesc().getPrecision();
 }
 
 inline void process_gemm(char transa, char transb, int M, int N, int K, float alpha, const float *A, int lda,
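The values cached into params above follow standard row-major GEMM conventions: lda/ldb are the leading dimensions of A and B as they sit in memory (a transposed A stored as K x M has leading dimension M), and shift2/shift1 are the element strides between consecutive inner and outer output batches. A small standalone sketch of the same arithmetic; the struct and function names are local to this example:

// Illustrative recomputation of the cached GEMM parameters for
// C[M x N] = op(A) * op(B) in row-major layout.
struct GemmParams {
    char transa, transb;
    int M, N, K;
    int lda, ldb, ldc;
    int shift1, shift2;
};

GemmParams make_params(int M, int N, int K, bool transposeA, bool transposeB, int MB2) {
    GemmParams p{};
    p.M = M; p.N = N; p.K = K;
    p.transa = transposeA ? 'T' : 'N';
    p.transb = transposeB ? 'T' : 'N';
    // Leading dimension = row length of the matrix as stored:
    // A is kept as M x K (lda = K) or, when transposed, as K x M (lda = M).
    p.lda = transposeA ? M : K;
    p.ldb = transposeB ? K : N;
    p.ldc = N;
    // Output strides: one inner batch is an M x N block,
    // one outer batch spans MB2 such blocks.
    p.shift2 = M * N;
    p.shift1 = MB2 * M * N;
    return p;
}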
@@ -212,67 +240,57 @@ inline void process_gemm(char transa, char transb, int M, int N, int K, float al
 }
 
 template<typename T0, typename T1>
-void MKLDNNMatMulNode::process_data() {
-    auto inDims0 = getParentEdgeAt(0)->getDims();
-    auto inDims1 = getParentEdgeAt(1)->getDims();
-    auto outDims = getChildEdgeAt(0)->getDims();
-
-    auto& srcMemory0 = getParentEdgeAt(0)->getMemory();
-    auto& srcMemory1 = getParentEdgeAt(1)->getMemory();
-    auto& dstMemory0 = getChildEdgeAt(0)->getMemory();
-
-    const T0 *src0_ptr = reinterpret_cast<const T0*>(srcMemory0.GetPtr());
-    const T1 *src1_ptr = reinterpret_cast<const T1*>(srcMemory1.GetData());
-    float *dst_ptr = reinterpret_cast<float*>(dstMemory0.GetData());
-
-    int MB1 = outDims.ndims() == 4 ? batchToProcess() : 1;
-    int MB2 = outDims.ndims() == 3 ? batchToProcess() : outDims.ndims() > 3 ? outDims[outDims.ndims() - 3] : 1;
-    int M = outDims[yAxis];
-    int N = outDims[xAxis];
-    int K = transposeA ? inDims0[yAxis] : inDims0[xAxis];
-
-    const char transa = transposeA ? 'T' : 'N';
-    const char transb = transposeB ? 'T' : 'N';
-
-    int lda = transposeA ? M : K;
-    int ldb = transposeB ? K : N;
-    int ldc = N;
-
-    beta = 0.f;
-
-    for (int b1 = 0; b1 < MB1; b1++) {
+inline void MKLDNNMatMulNode::process_data() {
+    const T0* src0_ptr = reinterpret_cast<const T0*>(params.src0_mem_ptr->GetPtr());
+    const T1* src1_ptr = reinterpret_cast<const T1*>(params.src1_mem_ptr->GetPtr());
+    float* dst_ptr = reinterpret_cast<float*>(params.dst_mem_ptr->GetPtr());
+
+    const int MB = batchToProcess();
+    if (params.ndims == 4) {
+        params.MB1 = MB;
+    } else if (params.ndims == 3) {
+        params.shift1 = params.shift1 * MB / params.MB2;
+        params.MB2 = MB;
+    }
+
+    for (int b1 = 0; b1 < params.MB1; ++b1) {
         const T0 *a_ptr = src0_ptr;
         const T1 *b_ptr = src1_ptr;
         float *d_ptr = dst_ptr;
 
-        for (int b2 = 0; b2 < MB2; b2++) {
-            process_gemm(transa, transb, M, N, K, alpha, a_ptr, lda, b_ptr, ldb, beta, d_ptr, ldc);
+        for (int b2 = 0; b2 < params.MB2; ++b2) {
+            process_gemm(params.transa, params.transb, params.M, params.N, params.K,
+                         alpha, a_ptr, params.lda, b_ptr, params.ldb, beta, d_ptr, params.ldc);
 
             a_ptr += aOffsets[0];
             b_ptr += bOffsets[0];
-            d_ptr += M * N;
+            d_ptr += params.shift2;
         }
 
        src0_ptr += aOffsets[1];
        src1_ptr += bOffsets[1];
-       dst_ptr += MB2 * M * N;
+       dst_ptr += params.shift1;
     }
 }
 
 void MKLDNNMatMulNode::execute(mkldnn::stream strm) {
-    switch (getParentEdgeAt(0)->getDesc().getPrecision()) {
-        case Precision::FP32:
+    switch (runtimePrecision) {
+        case Precision::FP32: {
             process_data<float, float>();
             break;
-        case Precision::BF16:
+        }
+        case Precision::BF16: {
             process_data<uint16_t, uint16_t>();
             break;
-        case Precision::I8:
+        }
+        case Precision::I8: {
             process_data<int8_t, int8_t>();
             break;
-        case Precision::U8:
+        }
+        case Precision::U8: {
             process_data<uint8_t, int8_t>();
             break;
+        }
         default:
             IE_THROW() << errorPrefix << " has incorrect precision on first input";
     }
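With the shape-dependent work moved to createPrimitive(), process_data() reduces to two batch loops around one GEMM call per inner batch, advancing the destination by shift2 (one M x N block) per inner step and by shift1 (MB2 blocks) per outer step. Below is a minimal self-contained sketch of that loop structure, with the MKLDNN GEMM call replaced by a naive triple loop and the offset bookkeeping passed in explicitly; all names are local to this sketch:

// Naive stand-in for the GEMM call: C = A * B with A (M x K), B (K x N),
// both non-transposed and densely packed, so lda = K, ldb = N, ldc = N.
static void naive_gemm(int M, int N, int K, const float* A, const float* B, float* C) {
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
            float acc = 0.f;
            for (int k = 0; k < K; ++k)
                acc += A[m * K + k] * B[k * N + n];
            C[m * N + n] = acc;
        }
}

// Batched loop mirroring the shift1/shift2 bookkeeping: shift2 advances the
// destination by one M x N block per inner batch, shift1 by MB2 blocks per
// outer batch. aStride*/bStride* play the role of aOffsets/bOffsets.
void batched_matmul(int MB1, int MB2, int M, int N, int K,
                    const float* src0, int aStride0, int aStride1,
                    const float* src1, int bStride0, int bStride1,
                    float* dst) {
    const int shift2 = M * N;
    const int shift1 = MB2 * M * N;
    for (int b1 = 0; b1 < MB1; ++b1) {
        const float* a = src0;
        const float* b = src1;
        float* d = dst;
        for (int b2 = 0; b2 < MB2; ++b2) {
            naive_gemm(M, N, K, a, b, d);
            a += aStride0;
            b += bStride0;
            d += shift2;
        }
        src0 += aStride1;
        src1 += bStride1;
        dst += shift1;
    }
}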

View File

@@ -28,8 +28,8 @@ public:
     static bool isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept;
 
 private:
-    float alpha = 1.0f;
-    float beta = 1.0f;
+    float alpha = 1.f;
+    float beta = 0.f;
     bool transposeA = false;
     bool transposeB = false;
 
@@ -40,9 +40,36 @@ private:
     std::vector<int> bOffsets;
     std::vector<int> cOffsets;
 
-    template<typename T0, typename T1> void process_data();
+    InferenceEngine::Precision runtimePrecision;
+
+    template<typename T0, typename T1> inline void process_data();
 
     std::string errorPrefix;
+
+    struct {
+        MKLDNNMemoryPtr src0_mem_ptr = nullptr;
+        MKLDNNMemoryPtr src1_mem_ptr = nullptr;
+        MKLDNNMemoryPtr dst_mem_ptr = nullptr;
+
+        char transa = 'N';
+        char transb = 'N';
+
+        int MB1 = 1;
+        int MB2 = 1;
+
+        int M = 0;
+        int N = 0;
+        int K = 0;
+
+        int lda = 0;
+        int ldb = 0;
+        int ldc = 0;
+
+        int shift1 = 0;
+        int shift2 = 0;
+
+        size_t ndims = 0;
+    } params;
 };
 
 } // namespace MKLDNNPlugin

View File

@@ -18,6 +18,8 @@ const std::vector<ShapeRelatedParams> shapeRelatedParams = {
     { { {1, 4, 5, 6}, false }, { {1, 4, 6, 4}, false } },
     { { {4, 5, 6}, false }, { {6, 3}, false } },
     { { {9, 9, 9}, false }, { {9, 9}, false } },
+    { { {1, 2, 3}, false }, { {1, 10, 3}, true } },
+    { { {1, 2, 3}, false }, { {1, 3, 10}, false } },
     { { {1, 2, 3}, false }, { {1, 1, 3, 2}, false } },
     { { {1, 3, 2, 4}, false }, { {2, 1, 4, 2}, false } },
     { { {2, 1, 2, 4}, false }, { {1, 3, 4, 2}, false } },
@@ -30,7 +32,7 @@ const std::vector<ShapeRelatedParams> shapeRelatedParams = {
     { { {2, 2, 1, 3}, false }, { {3}, false } },
     { { {1, 5}, false }, { {5, 1}, false } },
     { { {5, 1}, true }, { {5, 1}, false } },
-    { { {1, 5}, false }, { {1, 5}, true } },
+    { { {1, 5}, false }, { {10, 5}, true } },
     { { {1, 5}, false }, { {5}, false } },
     { { {5}, false }, { {5, 1}, false } },
     { { {5}, false }, { {5}, false } },
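The two new shape pairs above exercise the folded-transpose path: with transpose_b = true, {1, 2, 3} x {1, 10, 3} contracts over the shared dimension of size 3 and yields {1, 2, 10}, the same result as the non-transposed {1, 2, 3} x {1, 3, 10} pair next to it. A quick standalone shape-inference check of that case using ngraph (illustrative only, not part of the test file):

#include <cassert>
#include <memory>

#include <ngraph/opsets/opset1.hpp>

// Illustrative check: MatMul({1, 2, 3}, {1, 10, 3}, transpose_b = true)
// infers an output shape of {1, 2, 10}.
void check_transposed_matmul_shape() {
    auto A = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 2, 3});
    auto B = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 10, 3});
    auto matmul = std::make_shared<ngraph::opset1::MatMul>(A, B, false, true);
    assert((matmul->get_output_shape(0) == ngraph::Shape{1, 2, 10}));
}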