[CPU] Enabled MatMul+Transpose transformations and reduced MatMul inference overheads (#6570)
This commit is contained in:
parent
1471095bdb
commit
1aa58b4c7d
@ -58,6 +58,7 @@
|
||||
#include <transformations/op_conversions/convert_previous_nms_to_nms_5.hpp>
|
||||
#include <transformations/op_conversions/convert_nms_to_nms_ie_internal.hpp>
|
||||
#include <transformations/op_conversions/convert_deformable_conv_v8_to_v1.hpp>
|
||||
#include <transformations/smart_reshape/matmul_sr.hpp>
|
||||
#include <transformations/convert_precision.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/rt_info/fused_names_attribute.hpp>
|
||||
@ -167,6 +168,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
|
||||
manager.register_pass<ngraph::pass::ConvertNMS3ToNMS5>();
|
||||
manager.register_pass<ngraph::pass::ConvertNMS4ToNMS5>();
|
||||
manager.register_pass<ngraph::pass::ConvertNMSToNMSIEInternal>();
|
||||
manager.register_pass<ngraph::pass::TransposeMatMul>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
|
||||
if (useLpt) {
|
||||
|
@ -55,8 +55,8 @@ MKLDNNMatMulNode::MKLDNNMatMulNode(const std::shared_ptr<ngraph::Node>& op, cons
|
||||
errorPrefix = "Gemm node with name '" + getName() + "'";
|
||||
|
||||
const auto matMul = std::dynamic_pointer_cast<const ngraph::opset1::MatMul>(op);
|
||||
alpha = 1;
|
||||
beta = 1;
|
||||
alpha = 1.f;
|
||||
beta = 0.f;
|
||||
transposeA = matMul->get_transpose_a();
|
||||
transposeB = matMul->get_transpose_b();
|
||||
} else {
|
||||
@ -179,6 +179,34 @@ void MKLDNNMatMulNode::createPrimitive() {
|
||||
IE_THROW() << errorPrefix << " did not allocate input memory";
|
||||
if (getSelectedPrimitiveDescriptor() == nullptr)
|
||||
IE_THROW() << errorPrefix << " did not set preferable primitive descriptor";
|
||||
|
||||
auto inDims0 = src0MemPtr->GetDims();
|
||||
auto outDims = dstMemPtr->GetDims();
|
||||
|
||||
params.src0_mem_ptr = src0MemPtr;
|
||||
params.src1_mem_ptr = src1MemPtr;
|
||||
params.dst_mem_ptr = dstMemPtr;
|
||||
|
||||
params.ndims = outDims.size();
|
||||
|
||||
params.MB1 = 1;
|
||||
params.MB2 = outDims.size() > 3 ? outDims[params.ndims - 3] : 1;
|
||||
|
||||
params.M = outDims[yAxis];
|
||||
params.N = outDims[xAxis];
|
||||
params.K = transposeA ? inDims0[yAxis] : inDims0[xAxis];
|
||||
|
||||
params.transa = transposeA ? 'T' : 'N';
|
||||
params.transb = transposeB ? 'T' : 'N';
|
||||
|
||||
params.lda = transposeA ? params.M : params.K;
|
||||
params.ldb = transposeB ? params.K : params.N;
|
||||
params.ldc = params.N;
|
||||
|
||||
params.shift1 = params.M * params.N * params.MB2;
|
||||
params.shift2 = params.M * params.N;
|
||||
|
||||
runtimePrecision = getParentEdgeAt(0)->getDesc().getPrecision();
|
||||
}
|
||||
|
||||
inline void process_gemm(char transa, char transb, int M, int N, int K, float alpha, const float *A, int lda,
|
||||
@ -212,67 +240,57 @@ inline void process_gemm(char transa, char transb, int M, int N, int K, float al
|
||||
}
|
||||
|
||||
template<typename T0, typename T1>
|
||||
void MKLDNNMatMulNode::process_data() {
|
||||
auto inDims0 = getParentEdgeAt(0)->getDims();
|
||||
auto inDims1 = getParentEdgeAt(1)->getDims();
|
||||
auto outDims = getChildEdgeAt(0)->getDims();
|
||||
inline void MKLDNNMatMulNode::process_data() {
|
||||
const T0* src0_ptr = reinterpret_cast<const T0*>(params.src0_mem_ptr->GetPtr());
|
||||
const T1* src1_ptr = reinterpret_cast<const T1*>(params.src1_mem_ptr->GetPtr());
|
||||
float* dst_ptr = reinterpret_cast<float*>(params.dst_mem_ptr->GetPtr());
|
||||
|
||||
auto& srcMemory0 = getParentEdgeAt(0)->getMemory();
|
||||
auto& srcMemory1 = getParentEdgeAt(1)->getMemory();
|
||||
auto& dstMemory0 = getChildEdgeAt(0)->getMemory();
|
||||
const int MB = batchToProcess();
|
||||
if (params.ndims == 4) {
|
||||
params.MB1 = MB;
|
||||
} else if (params.ndims == 3) {
|
||||
params.shift1 = params.shift1 * MB / params.MB2;
|
||||
params.MB2 = MB;
|
||||
}
|
||||
|
||||
const T0 *src0_ptr = reinterpret_cast<const T0*>(srcMemory0.GetPtr());
|
||||
const T1 *src1_ptr = reinterpret_cast<const T1*>(srcMemory1.GetData());
|
||||
float *dst_ptr = reinterpret_cast<float*>(dstMemory0.GetData());
|
||||
|
||||
int MB1 = outDims.ndims() == 4 ? batchToProcess() : 1;
|
||||
int MB2 = outDims.ndims() == 3 ? batchToProcess() : outDims.ndims() > 3 ? outDims[outDims.ndims() - 3] : 1;
|
||||
int M = outDims[yAxis];
|
||||
int N = outDims[xAxis];
|
||||
int K = transposeA ? inDims0[yAxis] : inDims0[xAxis];
|
||||
|
||||
const char transa = transposeA ? 'T' : 'N';
|
||||
const char transb = transposeB ? 'T' : 'N';
|
||||
|
||||
int lda = transposeA ? M : K;
|
||||
int ldb = transposeB ? K : N;
|
||||
int ldc = N;
|
||||
|
||||
beta = 0.f;
|
||||
|
||||
for (int b1 = 0; b1 < MB1; b1++) {
|
||||
for (int b1 = 0; b1 < params.MB1; ++b1) {
|
||||
const T0 *a_ptr = src0_ptr;
|
||||
const T1 *b_ptr = src1_ptr;
|
||||
float *d_ptr = dst_ptr;
|
||||
|
||||
for (int b2 = 0; b2 < MB2; b2++) {
|
||||
process_gemm(transa, transb, M, N, K, alpha, a_ptr, lda, b_ptr, ldb, beta, d_ptr, ldc);
|
||||
for (int b2 = 0; b2 < params.MB2; ++b2) {
|
||||
process_gemm(params.transa, params.transb, params.M, params.N, params.K,
|
||||
alpha, a_ptr, params.lda, b_ptr, params.ldb, beta, d_ptr, params.ldc);
|
||||
|
||||
a_ptr += aOffsets[0];
|
||||
b_ptr += bOffsets[0];
|
||||
d_ptr += M * N;
|
||||
d_ptr += params.shift2;
|
||||
}
|
||||
|
||||
src0_ptr += aOffsets[1];
|
||||
src1_ptr += bOffsets[1];
|
||||
dst_ptr += MB2 * M * N;
|
||||
dst_ptr += params.shift1;
|
||||
}
|
||||
}
|
||||
|
||||
void MKLDNNMatMulNode::execute(mkldnn::stream strm) {
|
||||
switch (getParentEdgeAt(0)->getDesc().getPrecision()) {
|
||||
case Precision::FP32:
|
||||
switch (runtimePrecision) {
|
||||
case Precision::FP32: {
|
||||
process_data<float, float>();
|
||||
break;
|
||||
case Precision::BF16:
|
||||
}
|
||||
case Precision::BF16: {
|
||||
process_data<uint16_t, uint16_t>();
|
||||
break;
|
||||
case Precision::I8:
|
||||
}
|
||||
case Precision::I8: {
|
||||
process_data<int8_t, int8_t>();
|
||||
break;
|
||||
case Precision::U8:
|
||||
}
|
||||
case Precision::U8: {
|
||||
process_data<uint8_t, int8_t>();
|
||||
break;
|
||||
}
|
||||
default:
|
||||
IE_THROW() << errorPrefix << " has incorrect precision on first input";
|
||||
}
|
||||
|
@ -28,8 +28,8 @@ public:
|
||||
static bool isSupportedOperation(const std::shared_ptr<ngraph::Node>& op, std::string& errorMessage) noexcept;
|
||||
|
||||
private:
|
||||
float alpha = 1.0f;
|
||||
float beta = 1.0f;
|
||||
float alpha = 1.f;
|
||||
float beta = 0.f;
|
||||
bool transposeA = false;
|
||||
bool transposeB = false;
|
||||
|
||||
@ -40,9 +40,36 @@ private:
|
||||
std::vector<int> bOffsets;
|
||||
std::vector<int> cOffsets;
|
||||
|
||||
template<typename T0, typename T1> void process_data();
|
||||
InferenceEngine::Precision runtimePrecision;
|
||||
|
||||
template<typename T0, typename T1> inline void process_data();
|
||||
|
||||
std::string errorPrefix;
|
||||
|
||||
struct {
|
||||
MKLDNNMemoryPtr src0_mem_ptr = nullptr;
|
||||
MKLDNNMemoryPtr src1_mem_ptr = nullptr;
|
||||
MKLDNNMemoryPtr dst_mem_ptr = nullptr;
|
||||
|
||||
char transa = 'N';
|
||||
char transb = 'N';
|
||||
|
||||
int MB1 = 1;
|
||||
int MB2 = 1;
|
||||
|
||||
int M = 0;
|
||||
int N = 0;
|
||||
int K = 0;
|
||||
|
||||
int lda = 0;
|
||||
int ldb = 0;
|
||||
int ldc = 0;
|
||||
|
||||
int shift1 = 0;
|
||||
int shift2 = 0;
|
||||
|
||||
size_t ndims = 0;
|
||||
} params;
|
||||
};
|
||||
|
||||
} // namespace MKLDNNPlugin
|
||||
|
@ -18,6 +18,8 @@ const std::vector<ShapeRelatedParams> shapeRelatedParams = {
|
||||
{ { {1, 4, 5, 6}, false }, { {1, 4, 6, 4}, false } },
|
||||
{ { {4, 5, 6}, false }, { {6, 3}, false } },
|
||||
{ { {9, 9, 9}, false }, { {9, 9}, false } },
|
||||
{ { {1, 2, 3}, false }, { {1, 10, 3}, true } },
|
||||
{ { {1, 2, 3}, false }, { {1, 3, 10}, false } },
|
||||
{ { {1, 2, 3}, false }, { {1, 1, 3, 2}, false } },
|
||||
{ { {1, 3, 2, 4}, false }, { {2, 1, 4, 2}, false } },
|
||||
{ { {2, 1, 2, 4}, false }, { {1, 3, 4, 2}, false } },
|
||||
@ -30,7 +32,7 @@ const std::vector<ShapeRelatedParams> shapeRelatedParams = {
|
||||
{ { {2, 2, 1, 3}, false }, { {3}, false } },
|
||||
{ { {1, 5}, false }, { {5, 1}, false } },
|
||||
{ { {5, 1}, true }, { {5, 1}, false } },
|
||||
{ { {1, 5}, false }, { {1, 5}, true } },
|
||||
{ { {1, 5}, false }, { {10, 5}, true } },
|
||||
{ { {1, 5}, false }, { {5}, false } },
|
||||
{ { {5}, false }, { {5, 1}, false } },
|
||||
{ { {5}, false }, { {5}, false } },
|
||||
|
Loading…
Reference in New Issue
Block a user