[CPU] Optimize DFT operation (#11946)

This commit is contained in:
avoskoboinyk-lohika 2022-11-04 08:46:46 +02:00 committed by GitHub
parent e9e3044d99
commit 6fc23b4768
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 686 additions and 112 deletions

View File

@ -2,12 +2,14 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "dft.h"
#include <string>
#include <thread>
#include <vector>
#include <cmath>
#include <dnnl_extension_utils.h>
#include "dft.h"
#include "ie_parallel.hpp"
#include "ie_precision.hpp"
#include <onednn/dnnl.h>
@ -15,7 +17,8 @@
#include "common/cpu_memcpy.h"
#include <ngraph/opsets/opset7.hpp>
using namespace dnnl;
using namespace dnnl::impl;
using namespace dnnl::impl::cpu::x64;
using namespace InferenceEngine;
namespace ov {
@ -28,8 +31,8 @@ bool DFT::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, st
errorMessage = "Doesn't support op with dynamic shapes";
return false;
}
const auto interpDFT = std::dynamic_pointer_cast<const ngraph::opset7::DFT>(op);
const auto interpIDFT = std::dynamic_pointer_cast<const ngraph::opset7::IDFT>(op);
const auto interpDFT = ov::is_type<const op::v7::DFT>(op);
const auto interpIDFT = ov::is_type<const op::v7::IDFT>(op);
if (!interpDFT && !interpIDFT) {
errorMessage = "Only opset7 DFT/IDFT operation is supported";
@ -74,7 +77,8 @@ DFT::DFT(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, Weigh
}
}
inverse = std::dynamic_pointer_cast<ngraph::opset7::DFT>(op) == nullptr;
inverse = !ov::is_type<op::v7::DFT>(op);
lastInverse = !inverse;
}
void DFT::getSupportedDescriptors() {}
@ -230,53 +234,72 @@ void copyDataToOutputWithSignalSize(const float* input, const std::vector<size_t
} // namespace
void DFT::execute(dnnl::stream strm) {
auto axesEdge = getParentEdgeAt(AXES_INDEX);
const auto* axesStartPtr = reinterpret_cast<const int32_t*>(axesEdge->getMemoryPtr()->GetPtr());
axes = std::vector<int32_t>(axesStartPtr, axesStartPtr + axesEdge->getMemory().getStaticDims()[0]);
for (auto& axis : axes) {
if (axis < 0) {
axis += inputShape.size() - 1;
}
}
std::sort(axes.begin(), axes.end());
const auto& outputShape = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
outputShape = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
const auto inputDataEdge = getParentEdgeAt(DATA_INDEX);
const auto outputDataEdge = getChildEdgeAt(0);
const auto src = reinterpret_cast<const float*>(inputDataEdge->getMemoryPtr()->GetPtr());
auto dst = reinterpret_cast<float*>(outputDataEdge->getMemoryPtr()->GetPtr());
const auto inputRank = inputDataEdge->getMemory().GetShape().getRank();
const auto& inputStrides = inputDataEdge->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides();
const auto& outputStrides = outputDataEdge->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides();
size_t nComplexMaxFFT = 0;
for (size_t axis : axes) {
size_t nComplex = outputShape[axis];
// FFT uses different twiddle factors
if (twiddlesMap.find(nComplex) == twiddlesMap.end() && !IsPowerOfTwo(nComplex)) {
twiddlesMap[nComplex] = generateTwiddles(nComplex);
if (!IsPowerOfTwo(nComplex)) {
if (twiddlesMapDFT.find(nComplex) == twiddlesMapDFT.end() || lastInverse != inverse) {
twiddlesMapDFT[nComplex] = generateTwiddlesDFT(nComplex, inverse);
}
} else {
if (nComplexMaxFFT < nComplex) {
nComplexMaxFFT = nComplex;
}
}
}
auto inputDataEdge = getParentEdgeAt(DATA_INDEX);
auto outputDataEdge = getChildEdgeAt(0);
const auto *input = reinterpret_cast<const float*>(inputDataEdge->getMemoryPtr()->GetPtr());
auto *output = reinterpret_cast<float*>(outputDataEdge->getMemoryPtr()->GetPtr());
if (nComplexMaxFFT > 0 && ((nComplexMaxFFT - 1) * 2 > twiddlesFFT.size() || lastInverse != inverse)) {
updateTwiddlesFFT(nComplexMaxFFT, inverse);
}
auto inputStrides = inputDataEdge->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides();
auto outputStrides = outputDataEdge->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides();
if (inputShape != outputShape) {
copyDataToOutputWithSignalSize(input, inputShape, inputStrides, output, outputShape, outputStrides);
copyDataToOutputWithSignalSize(src, inputShape, inputStrides, dst, outputShape, outputStrides);
} else {
auto totalElements = std::accumulate(inputShape.begin(), inputShape.end(), size_t(1), std::multiplies<size_t>());
cpu_memcpy(output, input, totalElements * sizeof(float));
cpu_memcpy(dst, src, totalElements * sizeof(float));
}
// 1d case
if (inputDataEdge->getMemory().GetShape().getRank() == 2) {
if (inputRank == 2) {
size_t nComplex = outputShape[0];
if (IsPowerOfTwo(nComplex)) {
fft(output, nComplex * 2, true);
std::vector<float> outputData(nComplex * 2);
const float* resultBufPtr;
fft(dst, outputData.data(), nComplex * 2, inverse, true, &resultBufPtr);
if (resultBufPtr != dst) {
cpu_memcpy(dst, resultBufPtr, nComplex * 2 * sizeof(float));
}
} else {
naiveDFT(output, nComplex * 2);
naiveDFT(dst, nComplex * 2, inverse);
}
} else {
dftNd(output, outputStrides);
dftNd(dst, outputShape, outputStrides, axes, inverse);
}
lastInverse = inverse;
}
void DFT::dftNd(float* output, const std::vector<size_t>& outputStrides) const {
void DFT::dftNd(float* output,
const VectorDims& outputShape,
const VectorDims& outputStrides,
const std::vector<int32_t>& axes,
bool inverse) const {
const std::vector<size_t> iterationRange(outputShape.begin(), outputShape.end() - 1);
const size_t lastDimIndex = iterationRange.size() - 1;
for (size_t axisIndex = 0; axisIndex < axes.size(); ++axisIndex) {
@ -289,12 +312,13 @@ void DFT::dftNd(float* output, const std::vector<size_t>& outputStrides) const {
size_t parallelDimIndex = lastDimIndex == currentAxis ? lastDimIndex - 1 : lastDimIndex;
do {
parallel_for(iterationRange[parallelDimIndex], [&](size_t dim) {
std::vector<float> gatheredData(outputLen);
std::vector<float> gatheredData(outputLen * 2);
auto parallelIterationCounter = iterationCounter;
parallelIterationCounter[parallelDimIndex] = dim;
gatherToBufferND(gatheredData.data(), output, currentAxis, parallelIterationCounter, outputShape, outputStrides);
fft(gatheredData.data(), outputLen);
applyBufferND(gatheredData.data(), output, currentAxis, parallelIterationCounter, outputShape, outputStrides);
const float* resultBufPtr;
fft(gatheredData.data(), gatheredData.data() + outputLen, outputLen, inverse, false, &resultBufPtr);
applyBufferND(resultBufPtr, output, currentAxis, parallelIterationCounter, outputShape, outputStrides);
});
iterationCounter[parallelDimIndex] = iterationRange[parallelDimIndex] - 1;
} while (nextIterationStep(iterationCounter, iterationRange, currentAxis));
@ -302,7 +326,7 @@ void DFT::dftNd(float* output, const std::vector<size_t>& outputStrides) const {
std::vector<float> gatheredData(outputLen);
do {
gatherToBufferND(gatheredData.data(), output, currentAxis, iterationCounter, outputShape, outputStrides);
naiveDFT(gatheredData.data(), outputLen);
naiveDFT(gatheredData.data(), outputLen, inverse);
applyBufferND(gatheredData.data(), output, currentAxis, iterationCounter, outputShape, outputStrides);
} while (nextIterationStep(iterationCounter, iterationRange, currentAxis));
}
@ -310,118 +334,257 @@ void DFT::dftNd(float* output, const std::vector<size_t>& outputStrides) const {
}
/* Cooley Tukey implementation of FFT */
void DFT::fft(float* data, int64_t dataLength, bool parallelize) const {
static int cacheSizeL3 = utils::get_cache_size(3, false);
void DFT::fft(float* inBuffer,
float* outBuffer,
int64_t dataLength,
bool inverse,
bool parallelize,
const float** resultBuf) const {
static int cacheSizeL3 = dnnl::utils::get_cache_size(3, false);
static int elementsPerCacheLine = cacheSizeL3 / sizeof(float);
std::vector<float> bufferVector(dataLength * 2, 0);
float* buffer = bufferVector.data();
cpu_memcpy(buffer, data, dataLength * sizeof(float));
size_t nComplex = dataLength / 2;
float* inBufferStart = buffer + dataLength;
float* outBufferStart = buffer;
auto blockIteration = [&] (const size_t block, const size_t blockSize, const size_t nextIterationBlockSize, const float anglePart) {
float* curInpBufferPtr = inBufferStart + block * blockSize;
float* curOutBufferPtr = outBufferStart + block * nextIterationBlockSize;
std::function<void(const size_t, const size_t, const size_t)> blockIteration;
if (fftKernel != nullptr) {
blockIteration = [&](const size_t block, const size_t numBlocks, const size_t nextIterationBlockSize) {
auto arg = jit_args_fft();
const float angle = anglePart * block;
const float twiddleReal = std::cos(angle);
const float twiddleImag = -std::sin(angle);
for (int64_t pair = 0; pair < blockSize / 2; pair += 2) {
const float evenReal = curInpBufferPtr[pair];
const float evenImag = curInpBufferPtr[pair + 1];
arg.src = inBuffer + block * nextIterationBlockSize * 2;
arg.dst = outBuffer + block * nextIterationBlockSize;
arg.twiddles = &twiddlesFFT[(numBlocks + block - 1) * 2];
arg.num_blocks = numBlocks;
arg.work_amount = nextIterationBlockSize;
arg.n_complex = nComplex;
const float oddReal = curInpBufferPtr[(blockSize / 2 + pair)];
const float oddImag = curInpBufferPtr[(blockSize / 2 + pair) + 1];
(*fftKernel)(&arg);
};
} else {
blockIteration = [&](const size_t block, const size_t numBlocks, const size_t nextIterationBlockSize) {
float* curInpBufferPtr = inBuffer + block * nextIterationBlockSize * 2;
float* curOutBufferPtr = outBuffer + block * nextIterationBlockSize;
const float twiddledOddReal = getRealFromComplexProd(twiddleReal, twiddleImag, oddReal, oddImag);
const float twiddledOddImag = getImaginaryFromComplexProd(twiddleReal, twiddleImag, oddReal, oddImag);
for (size_t block = 0; block < numBlocks; ++block) {
float twiddleReal = twiddlesFFT[(numBlocks + block - 1) * 2];
float twiddleImag = twiddlesFFT[(numBlocks + block) * 2 - 1];
curOutBufferPtr[pair] = evenReal + twiddledOddReal;
curOutBufferPtr[pair + 1] = evenImag + twiddledOddImag;
for (size_t pair = 0; pair < nextIterationBlockSize; pair += 2) {
const float evenReal = curInpBufferPtr[pair];
const float evenImag = curInpBufferPtr[pair + 1];
curOutBufferPtr[nComplex + pair] = evenReal - twiddledOddReal;
curOutBufferPtr[nComplex + pair + 1] = evenImag - twiddledOddImag;
}
};
const float oddReal = curInpBufferPtr[(nextIterationBlockSize + pair)];
const float oddImag = curInpBufferPtr[(nextIterationBlockSize + pair) + 1];
for (int64_t numBlocks = 1; numBlocks < nComplex; numBlocks *= 2) {
const float anglePart = PI / numBlocks * (inverse ? -1 : 1);
const float twiddledOddReal = getRealFromComplexProd(twiddleReal, twiddleImag, oddReal, oddImag);
const float twiddledOddImag =
getImaginaryFromComplexProd(twiddleReal, twiddleImag, oddReal, oddImag);
std::swap(inBufferStart, outBufferStart);
const int64_t blockSize = dataLength / numBlocks;
const int64_t nextIterationBlockSize = blockSize / 2;
curOutBufferPtr[pair] = evenReal + twiddledOddReal;
curOutBufferPtr[pair + 1] = evenImag + twiddledOddImag;
curOutBufferPtr[nComplex + pair] = evenReal - twiddledOddReal;
curOutBufferPtr[nComplex + pair + 1] = evenImag - twiddledOddImag;
}
}
};
}
size_t blockSize;
size_t nextIterationBlockSize = dataLength;
for (size_t numBlocks = 1; numBlocks < nComplex; numBlocks *= 2) {
blockSize = nextIterationBlockSize;
nextIterationBlockSize /= 2;
if (parallelize && blockSize >= 4 * elementsPerCacheLine) {
parallel_for(numBlocks, [&] (const size_t block) {
blockIteration(block, blockSize, nextIterationBlockSize, anglePart);
parallel_for(numBlocks, [&](const size_t block) {
blockIteration(block, 1, nextIterationBlockSize);
});
} else {
for (int64_t block = 0; block < numBlocks; ++block) {
blockIteration(block, blockSize, nextIterationBlockSize, anglePart);
}
blockIteration(0, numBlocks, nextIterationBlockSize);
}
std::swap(inBuffer, outBuffer);
}
if (inverse) {
for (int64_t k = 0; k < dataLength; k++) {
inBuffer[k] /= nComplex;
}
}
for (int64_t k = 0; k < dataLength; k++) {
if (inverse) {
outBufferStart[k] /= nComplex;
}
data[k] = outBufferStart[k];
}
*resultBuf = inBuffer;
}
void DFT::naiveDFT(float* data, size_t dataLength) const {
void DFT::naiveDFT(float* data, size_t dataLength, bool inverse) const {
std::vector<float> outputBuffer(dataLength);
const size_t nComplex = dataLength / 2;
const auto& twiddles = twiddlesMap.find(nComplex)->second;
const float reciprocalNComplex = 1.0f / nComplex;
const auto& twiddles = twiddlesMapDFT.find(nComplex)->second;
parallel_for(nComplex, [&](size_t k) {
float sumReal = 0.0f;
float sumImag = 0.0f;
for (size_t n = 0; n < nComplex; ++n) {
auto it = twiddles[k * nComplex + n];
float complexReal = it.first;
float complexImag = it.second;
std::function<void(size_t)> blockIteration;
if (dftKernel != nullptr) {
blockIteration = [&](size_t k) {
auto arg = jit_args_dft();
arg.src = data;
arg.dst = outputBuffer.data() + 2 * k;
arg.twiddles = twiddles.data() + 2 * k * nComplex;
arg.work_amount = nComplex;
arg.index = k;
(*dftKernel)(&arg);
if (inverse) {
complexImag *= -1; // conjugate
outputBuffer[k * 2] *= reciprocalNComplex;
outputBuffer[k * 2 + 1] *= reciprocalNComplex;
}
float complexProdReal = getRealFromComplexProd(data[2 * n], data[2 * n + 1], complexReal, complexImag);
float complexProdImag = getImaginaryFromComplexProd(data[2 * n], data[2 * n + 1], complexReal, complexImag);
};
} else {
blockIteration = [&](size_t k) {
float sumReal = 0.0f;
float sumImag = 0.0f;
for (size_t n = 0; n < nComplex; ++n) {
auto complexRef = &twiddles[2 * (k * nComplex + n)];
float complexReal = *complexRef;
float complexImag = *(complexRef + 1);
sumReal += complexProdReal;
sumImag += complexProdImag;
}
float complexProdReal = getRealFromComplexProd(data[2 * n], data[2 * n + 1], complexReal, complexImag);
float complexProdImag =
getImaginaryFromComplexProd(data[2 * n], data[2 * n + 1], complexReal, complexImag);
if (inverse) {
sumReal /= nComplex;
sumImag /= nComplex;
}
outputBuffer[k * 2] = sumReal;
outputBuffer[k * 2 + 1] = sumImag;
});
sumReal += complexProdReal;
sumImag += complexProdImag;
}
if (inverse) {
sumReal *= reciprocalNComplex;
sumImag *= reciprocalNComplex;
}
outputBuffer[k * 2] = sumReal;
outputBuffer[k * 2 + 1] = sumImag;
};
}
parallel_for(nComplex, blockIteration);
cpu_memcpy(data, outputBuffer.data(), dataLength * sizeof(float));
}
std::vector<std::pair<float, float>> DFT::generateTwiddles(size_t n_complex) const {
std::vector<std::pair<float, float>> twiddles(n_complex * n_complex);
std::vector<float> DFT::generateTwiddlesDFT(size_t n_complex, bool inverse) const {
std::vector<float> twiddles(n_complex * n_complex * 2);
const float inverseMultiplier = inverse ? 1 : -1;
parallel_for(n_complex, [&](const size_t k) {
for (size_t n = 0; n < n_complex; ++n) {
float phase = 2.0f * PI * static_cast<float>(n * k) / static_cast<float>(n_complex);
float phase = 2.0f * PI * static_cast<float>(n * k) / static_cast<float>(n_complex);
auto complexReal = std::cos(phase);
auto complexImag = -std::sin(phase);
twiddles[k * n_complex + n] = std::make_pair(complexReal, complexImag);
auto complexImag = std::sin(phase) * inverseMultiplier;
twiddles[2 * (k * n_complex + n)] = complexReal;
twiddles[2 * (k * n_complex + n) + 1] = complexImag;
}
});
return twiddles;
}
// Grows the cached FFT twiddle table (twiddlesFFT) so that it covers every butterfly stage
// needed for a transform of n_complex complex values.
//
// The table is a flat array of (real, imag) float pairs laid out stage after stage: the stage
// that works on numBlocks blocks (numBlocks = 1, 2, 4, ...) contributes numBlocks pairs, and the
// pair for block b of that stage is (cos(PI * b / numBlocks), +/-sin(PI * b / numBlocks)).
// The sign of the imaginary part is positive when `inverse` is true and negative otherwise.
//
// NOTE(review): this function only appends. Pairs that are already stored keep whatever sign
// they were generated with, even if it is now called with a different `inverse` value -
// confirm that callers never flip the direction on a non-empty table.
// NOTE(review): (n_complex - 1) wraps around for n_complex == 0; callers are expected to pass
// a positive value.
void DFT::updateTwiddlesFFT(size_t n_complex, bool inverse) {
const float inverseMultiplier = inverse ? 1 : -1;
size_t numBlocks = 1;
twiddlesFFT.reserve((n_complex - 1) * 2);
if (twiddlesFFT.size() == 0) {
// Seed the table with the single pair of the first stage (numBlocks == 1).
twiddlesFFT.emplace_back(1.0f); // cos(0)
twiddlesFFT.emplace_back(-0.0f); // -sin(0)
} else {
// Recover numBlocks of the last stage that is already stored: i walks over the cumulative
// pair count 1, 3, 7, ... while numBlocks doubles.
for (size_t i = numBlocks; i < twiddlesFFT.size() / 2; i += numBlocks) {
numBlocks *= 2;
}
}
// i is the number of pairs stored before the stage being generated; each pass appends one
// complete stage of numBlocks pairs.
for (size_t i = twiddlesFFT.size() / 2; i < n_complex - 1; i += numBlocks) {
numBlocks *= 2;
// blockNum advances by two per pass (once in the loop header, once in the body):
// the even entry is copied, the following odd entry is computed.
for (size_t blockNum = 0; blockNum < numBlocks; blockNum++) {
// Even entry b equals entry b / 2 of the previous stage (same angle), so it is copied;
// copyIndex is the float index of that earlier pair.
size_t copyIndex = twiddlesFFT.size() - blockNum - numBlocks;
twiddlesFFT.push_back(twiddlesFFT[copyIndex]);
twiddlesFFT.push_back(twiddlesFFT[copyIndex + 1]);
blockNum++;
// Odd entry: has no counterpart in the previous stage and is computed directly.
float angle = PI * blockNum / numBlocks;
auto complexReal = std::cos(angle);
auto complexImag = std::sin(angle) * inverseMultiplier;
twiddlesFFT.emplace_back(complexReal);
twiddlesFFT.emplace_back(complexImag);
}
}
}
// Reports whether this node was created as a DFT node, i.e. whether getType() is Type::DFT.
bool DFT::created() const {
return getType() == Type::DFT;
}
// Nothing is built at this point; the JIT kernels are created from prepareParams().
void DFT::createPrimitive() {}
// Refreshes the cached transform axes and makes sure the JIT kernels required for them exist.
//
// Every transformed axis is classified by its length in the output tensor: power-of-two lengths
// are handled by the FFT path, all other lengths by the naive DFT path. The matching kernels are
// then requested from createJITKernels(), but only when the CPU supports at least SSE4.1;
// otherwise no kernel is created here.
void DFT::prepareParams() {
    axes = getAxes();

    // Bound by const reference and given its own name: the dims are only read here, so copying
    // the vector is unnecessary, and a local called `outputShape` would hide the class member.
    const auto& outputDims = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();

    bool hasDFT = false;
    bool hasFFT = false;
    for (const auto axis : axes) {
        const size_t nComplex = outputDims[static_cast<size_t>(axis)];
        if (IsPowerOfTwo(nComplex)) {
            hasFFT = true;
        } else {
            hasDFT = true;
        }
    }

    if (mayiuse(cpu::x64::sse41)) {
        createJITKernels(hasDFT, hasFFT);
    }
}
// Reads the axes input (port AXES_INDEX) and returns the axes normalised and sorted ascending.
//
// The axes memory is read as int32_t values; the number of values is the first static dimension
// of that memory. A negative axis is shifted by (inputShape.size() - 1).
std::vector<int32_t> DFT::getAxes() const {
    auto axesEdge = getParentEdgeAt(AXES_INDEX);
    const auto* axesStartPtr = reinterpret_cast<const int32_t*>(axesEdge->getMemoryPtr()->GetPtr());
    const size_t axesCount = axesEdge->getMemory().getStaticDims()[0];

    // Named differently from the `axes` member so the member is not shadowed.
    std::vector<int32_t> normalizedAxes(axesStartPtr, axesStartPtr + axesCount);

    // Do the shift in signed arithmetic. Adding the size_t expression directly to an int32_t
    // converts the negative axis to a huge unsigned value and relies on wrap-around plus a
    // narrowing conversion to land on the right result.
    const int32_t axisShift = static_cast<int32_t>(inputShape.size()) - 1;
    for (auto& axis : normalizedAxes) {
        if (axis < 0) {
            axis += axisShift;
        }
    }

    std::sort(normalizedAxes.begin(), normalizedAxes.end());
    return normalizedAxes;
}
// Lazily instantiates the JIT kernels for the naive DFT path (hasDFT) and the power-of-two FFT
// path (hasFFT). A kernel that already exists is kept. The widest supported instruction set is
// chosen in the order AVX-512 core, AVX2, SSE4.1; if none is available an exception is thrown.
void DFT::createJITKernels(bool hasDFT, bool hasFFT) {
    const bool dftKernelMissing = hasDFT && dftKernel == nullptr;
    if (dftKernelMissing) {
        if (mayiuse(cpu::x64::avx512_core)) {
            dftKernel.reset(new jit_uni_dft_kernel_f32<cpu::x64::avx512_core>());
        } else if (mayiuse(cpu::x64::avx2)) {
            dftKernel.reset(new jit_uni_dft_kernel_f32<cpu::x64::avx2>());
        } else if (mayiuse(cpu::x64::sse41)) {
            dftKernel.reset(new jit_uni_dft_kernel_f32<cpu::x64::sse41>());
        } else {
            IE_THROW() << "Can't create jit DFT kernel";
        }

        // Reaching this line means one of the branches above installed a kernel.
        dftKernel->create_ker();
    }

    const bool fftKernelMissing = hasFFT && fftKernel == nullptr;
    if (fftKernelMissing) {
        if (mayiuse(cpu::x64::avx512_core)) {
            fftKernel.reset(new jit_uni_fft_kernel_f32<cpu::x64::avx512_core>());
        } else if (mayiuse(cpu::x64::avx2)) {
            fftKernel.reset(new jit_uni_fft_kernel_f32<cpu::x64::avx2>());
        } else if (mayiuse(cpu::x64::sse41)) {
            fftKernel.reset(new jit_uni_fft_kernel_f32<cpu::x64::sse41>());
        } else {
            IE_THROW() << "Can't create jit FFT kernel";
        }

        // Same reasoning as for the DFT kernel: a kernel is guaranteed to be present here.
        fftKernel->create_ker();
    }
}
} // namespace node
} // namespace intel_cpu

View File

@ -8,6 +8,8 @@
#include <node.h>
#include <string>
#include "kernels/dft_uni_kernel.hpp"
namespace ov {
namespace intel_cpu {
namespace node {
@ -19,29 +21,50 @@ public:
void getSupportedDescriptors() override;
void initSupportedPrimitiveDescriptors() override;
void createPrimitive() override;
void execute(dnnl::stream strm) override;
bool created() const override;
void prepareParams() override;
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
private:
void dftNd(float* output, const std::vector<size_t>& outputStrides) const;
void fft(float* data, int64_t dataLength, bool parallelize = false) const;
void naiveDFT(float* data, size_t dataLength) const;
std::vector<int32_t> getAxes() const;
void createJITKernels(bool hasDFT, bool hasFFT);
std::vector<std::pair<float, float>> generateTwiddles(size_t n_complex) const;
void dftNd(float* output,
const VectorDims& outputShape,
const VectorDims& outputStrides,
const std::vector<int32_t>& axes,
bool inverse) const;
void fft(float* inBuffer,
float* outBuffer,
int64_t dataLength,
bool inverse,
bool parallelize,
const float** resultBuf) const;
void naiveDFT(float* data, size_t dataLength, bool inverse) const;
std::vector<float> generateTwiddlesDFT(size_t n_complex, bool inverse) const;
void updateTwiddlesFFT(size_t n_complex, bool inverse);
std::unique_ptr<jit_uni_dft_kernel> dftKernel = nullptr;
std::unique_ptr<jit_uni_fft_kernel> fftKernel = nullptr;
std::vector<float> twiddlesFFT;
std::unordered_map<size_t, std::vector<float>> twiddlesMapDFT;
std::unordered_map<size_t, std::vector<std::pair<float, float>>> twiddlesMap;
std::vector<int32_t> axes;
std::vector<size_t> outputShape;
std::vector<size_t> inputShape;
std::string layerErrorPrefix;
const size_t DATA_INDEX = 0;
const size_t AXES_INDEX = 1;
const size_t SIGNAL_SIZE_INDEX = 2;
static constexpr float PI = 3.141592653589793238462643f;
bool inverse;
bool lastInverse;
};
} // namespace node

View File

@ -0,0 +1,246 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "dft_uni_kernel.hpp"
using namespace dnnl::impl;
using namespace dnnl::impl::utils;
using namespace dnnl::impl::cpu::x64;
#define GET_OFF_DFT(field) offsetof(jit_args_dft, field)
#define GET_OFF_FFT(field) offsetof(jit_args_fft, field)
namespace ov {
namespace intel_cpu {
// Constructs the kernel object. Both bases are initialised here; the code generator receives
// the kernel name returned by jit_name(). No machine code is emitted by the constructor itself.
template <cpu::x64::cpu_isa_t isa>
jit_uni_dft_kernel_f32<isa>::jit_uni_dft_kernel_f32() : jit_uni_dft_kernel(), jit_generator(jit_name()) {}
// Builds the kernel through the generator base class and stores its entry point in ker_,
// which is what operator() of the base kernel struct calls.
template <cpu::x64::cpu_isa_t isa>
void jit_uni_dft_kernel_f32<isa>::create_ker() {
    using kernel_fn_t = decltype(ker_);

    jit_generator::create_kernel();
    ker_ = (kernel_fn_t)jit_ker();
}
// Emits the naive-DFT kernel: one call computes a single complex output value as the dot
// product of `work_amount` complex inputs (src) with `work_amount` complex twiddles (twiddles)
// and stores it as (real, imag) at dst.
//
// Accumulator layout (per 128-bit lane of vmm_sum), for input d = (dr, di) and twiddle
// t = (tr, ti):
//   lane 0: di * tr    lane 1: dr * ti    -> lane0 + lane1 = imaginary part
//   lane 2: dr * tr    lane 3: di * ti    -> lane2 - lane3 = real part
template <cpu::x64::cpu_isa_t isa>
void jit_uni_dft_kernel_f32<isa>::generate() {
    this->preamble();

    mov(reg_src, ptr[reg_params + GET_OFF_DFT(src)]);
    mov(reg_dst, ptr[reg_params + GET_OFF_DFT(dst)]);
    mov(reg_twiddles, ptr[reg_params + GET_OFF_DFT(twiddles)]);
    mov(reg_work_amount, ptr[reg_params + GET_OFF_DFT(work_amount)]);
    mov(reg_index, ptr[reg_params + GET_OFF_DFT(index)]);

    Xbyak::Label main_loop_label;
    Xbyak::Label main_loop_end_label;
    Xbyak::Label tail_loop_label;
    Xbyak::Label tail_loop_end_label;

    uni_vpxor(vmm_sum, vmm_sum, vmm_sum);

    // Number of complex values (8 bytes each) that fit into one full vector register.
    int step = vlen / 8;

    L(main_loop_label);
    {
        cmp(reg_work_amount, step);
        jl(main_loop_end_label, T_NEAR);

        uni_vmovups(vmm_data_cache, ptr[reg_src]);
        uni_vmovups(vmm_twiddles_cache, ptr[reg_twiddles]);

        // First complex value of every 128-bit lane: (di, dr, dr, di) * (tr, ti, tr, ti).
        uni_vshufps(vmm_data, vmm_data_cache, vmm_data_cache, 0b01000001);
        uni_vshufps(vmm_twiddles, vmm_twiddles_cache, vmm_twiddles_cache, 0b01000100);
        uni_vfmadd231ps(vmm_sum, vmm_data, vmm_twiddles);

        // Second complex value of every 128-bit lane, same arrangement.
        uni_vshufps(vmm_data, vmm_data_cache, vmm_data_cache, 0b11101011);
        uni_vshufps(vmm_twiddles, vmm_twiddles_cache, vmm_twiddles_cache, 0b11101110);
        uni_vfmadd231ps(vmm_sum, vmm_data, vmm_twiddles);

        add(reg_twiddles, 2 * step * sizeof(float));
        add(reg_src, 2 * step * sizeof(float));

        sub(reg_work_amount, step);
        jmp(main_loop_label, T_NEAR);
    }
    L(main_loop_end_label);

    // Fold the wide accumulator down to 128 bits. The steps are selected by the ISA this kernel
    // was instantiated for (which fixes the width of vmm_sum), not by what the host CPU supports.
    if (isa == cpu::x64::avx512_core) {
        Xbyak::Zmm zmm_sum = Xbyak::Zmm(vmm_sum.getIdx());
        Xbyak::Ymm ymm_sum = Xbyak::Ymm(vmm_sum.getIdx());
        Xbyak::Ymm ymm_sum_2 = Xbyak::Ymm(vmm_sum_2.getIdx());

        vextractf64x4(ymm_sum_2, zmm_sum, 1);
        vaddps(ymm_sum, ymm_sum, ymm_sum_2);
    }
    if (isa == cpu::x64::avx512_core || isa == cpu::x64::avx2) {
        Xbyak::Ymm ymm_sum = Xbyak::Ymm(vmm_sum.getIdx());

        vextractf128(xmm_sum_2, ymm_sum, 1);
        vaddps(xmm_sum, xmm_sum, xmm_sum_2);
    }

    // Tail: one complex value per iteration.
    L(tail_loop_label);
    {
        cmp(reg_work_amount, 1);
        jl(tail_loop_end_label, T_NEAR);

        // Load exactly one complex value (8 bytes). A full 16-byte load would read past the end
        // of the data and twiddle buffers when the last element is processed. The shuffles below
        // only use elements 0 and 1, so the zeroed upper half is never observed.
        uni_vmovq(xmm_data, ptr[reg_src]);
        uni_vmovq(xmm_twiddles, ptr[reg_twiddles]);
        uni_vshufps(xmm_data, xmm_data, xmm_data, 0b01000001);
        uni_vshufps(xmm_twiddles, xmm_twiddles, xmm_twiddles, 0b01000100);
        uni_vfmadd231ps(xmm_sum, xmm_data, xmm_twiddles);

        add(reg_twiddles, 2 * sizeof(float));
        add(reg_src, 2 * sizeof(float));

        sub(reg_work_amount, 1);
        jmp(tail_loop_label, T_NEAR);
    }
    L(tail_loop_end_label);

    // real = lane2 - lane3 (taken from the upper half), imag = lane0 + lane1.
    uni_vmovhlps(xmm_sum_2, xmm_sum_2, xmm_sum);
    uni_vhsubps(xmm_sum_2, xmm_sum_2, xmm_sum_2);
    uni_vhaddps(xmm_sum, xmm_sum, xmm_sum);

    uni_vmovss(ptr[reg_dst], xmm_sum_2);
    uni_vmovss(ptr[reg_dst + sizeof(float)], xmm_sum);

    this->postamble();
}
template struct jit_uni_dft_kernel_f32<cpu::x64::sse41>;
template struct jit_uni_dft_kernel_f32<cpu::x64::avx2>;
template struct jit_uni_dft_kernel_f32<cpu::x64::avx512_core>;
// Constructs the kernel object. Both bases are initialised here; the code generator receives
// the kernel name returned by jit_name(). No machine code is emitted by the constructor itself.
template <cpu::x64::cpu_isa_t isa>
jit_uni_fft_kernel_f32<isa>::jit_uni_fft_kernel_f32()
: jit_uni_fft_kernel(),
jit_generator(jit_name()) {}
// Builds the kernel through the generator base class and stores its entry point in ker_,
// which is what operator() of the base kernel struct calls.
template <cpu::x64::cpu_isa_t isa>
void jit_uni_fft_kernel_f32<isa>::create_ker() {
    using kernel_fn_t = decltype(ker_);

    jit_generator::create_kernel();
    ker_ = (kernel_fn_t)jit_ker();
}
// Emits the FFT butterfly kernel. One call processes `num_blocks` consecutive blocks of a
// single FFT stage; each block consumes `work_amount` floats of "even" input followed by
// `work_amount` floats of "odd" input and uses its own twiddle pair.
template <cpu::x64::cpu_isa_t isa>
void jit_uni_fft_kernel_f32<isa>::generate() {
this->preamble();
mov(reg_src, ptr[reg_params + GET_OFF_FFT(src)]);
mov(reg_dst, ptr[reg_params + GET_OFF_FFT(dst)]);
mov(reg_twiddles_addr, ptr[reg_params + GET_OFF_FFT(twiddles)]);
mov(reg_num_blocks, ptr[reg_params + GET_OFF_FFT(num_blocks)]);
mov(reg_work_amount, ptr[reg_params + GET_OFF_FFT(work_amount)]);
// reg_even_out_diff = n_complex * sizeof(float): byte distance in dst between the
// "even + odd" result and the matching "even - odd" result.
// The one-operand mul multiplies the accumulator (reg_even_in_diff is declared as rax)
// by its operand and also overwrites rdx.
// NOTE(review): the memory operand below carries no explicit size - confirm the operand
// width Xbyak emits for it is the intended one.
mov(reg_even_in_diff, sizeof(float));
mul(ptr[reg_params + GET_OFF_FFT(n_complex)]);
mov(reg_even_out_diff, reg_even_in_diff);
// reg_even_in_diff = work_amount * sizeof(float): byte distance in src between an even
// element and the odd element it is paired with.
mov(reg_even_in_diff, sizeof(float));
mul(reg_work_amount);
Xbyak::Label block_loop_label;
Xbyak::Label block_loop_end_label;
L(block_loop_label);
{
cmp(reg_num_blocks, 1);
jl(block_loop_end_label, T_NEAR);
// Every block starts with the full amount of work and its own twiddle (real, imag) pair,
// broadcast to all lanes.
mov(aux_reg_work_amount, reg_work_amount);
uni_vbroadcastss(vmm_twiddle_real, ptr[reg_twiddles_addr]);
uni_vbroadcastss(vmm_twiddle_imag, ptr[reg_twiddles_addr + sizeof(float)]);
// Widest step first, then 4 floats, then 2 floats (one complex value) for the remainder;
// each loop_process() call only runs while at least `step` floats are left.
if (mayiuse(cpu::x64::avx2)) {
loop_process<Vmm>(vlen / 4);
}
loop_process<Xbyak::Xmm>(4);
loop_process<Xbyak::Xmm>(2);
// Next twiddle pair; reg_src already advanced over the even half inside loop_process(),
// so skip the odd half as well to reach the next block.
add(reg_twiddles_addr, 2 * sizeof(float));
add(reg_src, reg_even_in_diff);
sub(reg_num_blocks, 1);
jmp(block_loop_label, T_NEAR);
}
L(block_loop_end_label);
this->postamble();
}
// Emits one butterfly loop that handles `step` floats per iteration using registers of type T
// (the same register indices as the Vmm members, viewed at the width of T). The loop runs while
// aux_reg_work_amount >= step and leaves any remainder for a following, narrower call.
//
// Per complex pair (even, odd) with twiddle t = (tr, ti):
//   odd'            = odd * t            (complex product)
//   dst[i]          = even + odd'
//   dst[i + offset] = even - odd'        (offset = reg_even_out_diff bytes)
template <cpu::x64::cpu_isa_t isa>
template <typename T>
void jit_uni_fft_kernel_f32<isa>::loop_process(int step) {
T reg_data_odd_1 = T(vmm_data_odd_1.getIdx());
T reg_data_odd_2 = T(vmm_data_odd_2.getIdx());
T reg_twiddle_imag = T(vmm_twiddle_imag.getIdx());
T reg_twiddle_real = T(vmm_twiddle_real.getIdx());
T reg_data_even = T(vmm_data_even.getIdx());
// Shares its register index with reg_data_odd_2 (vmm_data_result is declared as an alias of
// vmm_data_odd_2); it is only written after reg_data_odd_2 has been consumed.
T reg_data_result = T(vmm_data_result.getIdx());
Xbyak::Label loop_label;
Xbyak::Label loop_end_label;
L(loop_label);
{
cmp(aux_reg_work_amount, step);
jl(loop_end_label, T_NEAR);
// Odd half lives reg_even_in_diff bytes after the even half.
move_data(reg_data_odd_1, ptr[reg_src + reg_even_in_diff], step);
// odd_2 = (im, re) pairs scaled by ti; odd_1 = (re, im) pairs scaled by tr.
uni_vshufps(reg_data_odd_2, reg_data_odd_1, reg_data_odd_1, 0b10110001);
uni_vmulps(reg_data_odd_2, reg_data_odd_2, reg_twiddle_imag);
// Subtract in even lanes, add in odd lanes:
// (re * tr - im * ti, im * tr + re * ti) = odd * t.
if (mayiuse(cpu::x64::avx512_core)) {
vfmaddsub213ps(reg_data_odd_1, reg_twiddle_real, reg_data_odd_2);
} else {
uni_vmulps(reg_data_odd_1, reg_data_odd_1, reg_twiddle_real);
uni_vaddsubps(reg_data_odd_1, reg_data_odd_1, reg_data_odd_2);
}
move_data(reg_data_even, ptr[reg_src], step);
uni_vaddps(reg_data_result, reg_data_even, reg_data_odd_1);
move_data(ptr[reg_dst], reg_data_result, step);
uni_vsubps(reg_data_result, reg_data_even, reg_data_odd_1);
move_data(ptr[reg_dst + reg_even_out_diff], reg_data_result, step);
add(reg_src, step * sizeof(float));
add(reg_dst, step * sizeof(float));
sub(aux_reg_work_amount, step);
jmp(loop_label, T_NEAR);
}
L(loop_end_label);
}
// Emits a store of register x to memory: a 64-bit move when exactly two floats are requested,
// an unaligned full-register move for any other count.
template <cpu::x64::cpu_isa_t isa>
void jit_uni_fft_kernel_f32<isa>::move_data(const Xbyak::Address& addr, const Xbyak::Xmm& x, int count) {
    const bool singleComplex = (count == 2);
    if (!singleComplex) {
        uni_vmovups(addr, x);
        return;
    }
    uni_vmovq(addr, x);
}
// Emits a load from memory into register x: a 64-bit move when exactly two floats are
// requested, an unaligned full-register move for any other count.
template <cpu::x64::cpu_isa_t isa>
void jit_uni_fft_kernel_f32<isa>::move_data(const Xbyak::Xmm& x, const Xbyak::Address& addr, int count) {
    const bool singleComplex = (count == 2);
    if (!singleComplex) {
        uni_vmovups(x, addr);
        return;
    }
    uni_vmovq(x, addr);
}
template struct jit_uni_fft_kernel_f32<cpu::x64::sse41>;
template struct jit_uni_fft_kernel_f32<cpu::x64::avx2>;
template struct jit_uni_fft_kernel_f32<cpu::x64::avx512_core>;
} // namespace intel_cpu
} // namespace ov

View File

@ -0,0 +1,142 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cpu/x64/cpu_isa_traits.hpp>
#include <cpu/x64/jit_generator.hpp>
namespace ov {
namespace intel_cpu {
// Argument block handed to the JIT naive-DFT kernel. The generated code reads the fields
// through offsetof(), so the field set and order are part of the kernel contract.
struct jit_args_dft {
// Input sequence of interleaved (real, imag) floats.
const float* src;
// Where the single resulting (real, imag) pair is written.
float* dst;
// Twiddle factors as interleaved (real, imag) floats, one pair per input element.
const float* twiddles;
// Number of complex elements to accumulate.
size_t work_amount;
// Index of the output element being computed; loaded by the kernel but not otherwise used
// in its generated code.
size_t index;
};
// Argument block handed to the JIT FFT butterfly kernel. The generated code reads the fields
// through offsetof(), so the field set and order are part of the kernel contract.
struct jit_args_fft {
// Start of the first block to read (even half followed by odd half per block).
const float* src;
// Start of the output for the first block.
float* dst;
// First (real, imag) twiddle pair; the kernel advances by one pair per block.
const float* twiddles;
// Number of consecutive blocks processed by one kernel call.
size_t num_blocks;
// Number of floats in one half (even or odd) of a block.
size_t work_amount;
// Multiplied by sizeof(float) inside the kernel to obtain the byte offset between the
// "even + odd" and "even - odd" outputs.
size_t n_complex;
};
// ISA-independent handle for the JIT naive-DFT kernel. Concrete kernels fill in ker_ from
// create_ker(); callers invoke the kernel through operator().
struct jit_uni_dft_kernel {
    // Entry point of the generated code; stays null until create_ker() has run.
    void (*ker_)(const jit_args_dft*) = nullptr;

    jit_uni_dft_kernel() = default;
    virtual ~jit_uni_dft_kernel() = default;

    // Generates the kernel and sets ker_.
    virtual void create_ker() = 0;

    // Runs the generated kernel on the given argument block.
    void operator()(const jit_args_dft* args) {
        assert(ker_);
        ker_(args);
    }
};
// ISA-independent handle for the JIT FFT butterfly kernel. Concrete kernels fill in ker_ from
// create_ker(); callers invoke the kernel through operator().
struct jit_uni_fft_kernel {
    // Entry point of the generated code; stays null until create_ker() has run.
    void (*ker_)(const jit_args_fft*) = nullptr;

    jit_uni_fft_kernel() = default;
    virtual ~jit_uni_fft_kernel() = default;

    // Generates the kernel and sets ker_.
    virtual void create_ker() = 0;

    // Runs the generated kernel on the given argument block.
    void operator()(const jit_args_fft* args) {
        assert(ker_);
        ker_(args);
    }
};
// JIT implementation of the naive-DFT kernel for one instruction set (SSE4.1, AVX2 or
// AVX-512 core). Code emission lives in generate(); create_ker() builds the kernel and
// publishes its entry point through the jit_uni_dft_kernel base.
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
struct jit_uni_dft_kernel_f32 : public jit_uni_dft_kernel, public dnnl::impl::cpu::x64::jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dft_kernel_f32)

    jit_uni_dft_kernel_f32();

    void create_ker() override;
    void generate() override;

private:
    // Widest vector register type of the selected ISA: Xmm, Ymm or Zmm.
    using Vmm = typename dnnl::impl::utils::conditional3<isa == dnnl::impl::cpu::x64::sse41,
                                                         Xbyak::Xmm,
                                                         isa == dnnl::impl::cpu::x64::avx2,
                                                         Xbyak::Ymm,
                                                         Xbyak::Zmm>::type;

    // Vector length in bytes for the selected ISA. Never modified after construction, hence
    // const - this also matches the declaration in jit_uni_fft_kernel_f32.
    const size_t vlen = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen;

    // General-purpose registers; each holds the jit_args_dft field of the same name.
    Xbyak::Reg64 reg_src = r8;
    Xbyak::Reg64 reg_dst = r9;
    Xbyak::Reg64 reg_twiddles = r10;
    Xbyak::Reg64 reg_work_amount = r11;
    Xbyak::Reg64 reg_index = r12;
    Xbyak::Reg64 reg_params = Xbyak::Reg64(dnnl::impl::cpu::x64::abi_param_regs[0]);

    // Vector registers of the main loop.
    Vmm vmm_data = Vmm(0);
    Vmm vmm_twiddles = Vmm(1);
    Vmm vmm_sum = Vmm(2);
    // Scratch for the horizontal reduction; deliberately aliases vmm_data (register 0), which
    // is free once the main loop has finished.
    Vmm vmm_sum_2 = vmm_data;
    Vmm vmm_data_cache = Vmm(3);
    Vmm vmm_twiddles_cache = Vmm(4);

    // 128-bit views of registers 0..2, used by the tail loop and the final reduction.
    Xbyak::Xmm xmm_data = Xbyak::Xmm(0);
    Xbyak::Xmm xmm_twiddles = Xbyak::Xmm(1);
    Xbyak::Xmm xmm_sum = Xbyak::Xmm(2);
    // Aliases xmm_data, for the same reason as vmm_sum_2.
    Xbyak::Xmm xmm_sum_2 = xmm_data;
};
// JIT implementation of the FFT butterfly kernel for one instruction set (SSE4.1, AVX2 or
// AVX-512 core). Code emission lives in generate() and loop_process(); create_ker() builds the
// kernel and publishes its entry point through the jit_uni_fft_kernel base.
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
struct jit_uni_fft_kernel_f32 : public jit_uni_fft_kernel, public dnnl::impl::cpu::x64::jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_fft_kernel_f32)
jit_uni_fft_kernel_f32();
void create_ker() override;
void generate() override;
private:
// Widest vector register type of the selected ISA: Xmm, Ymm or Zmm.
using Vmm = typename dnnl::impl::utils::conditional3<isa == dnnl::impl::cpu::x64::sse41,
Xbyak::Xmm,
isa == dnnl::impl::cpu::x64::avx2,
Xbyak::Ymm,
Xbyak::Zmm>::type;
// Vector length in bytes for the selected ISA.
const size_t vlen = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen;
// Byte distance between an even input element and its odd partner. Must stay rax:
// generate() computes it with the one-operand mul, whose result lands in rax.
Xbyak::Reg64 reg_even_in_diff = rax;
// Byte distance between the "even + odd" and "even - odd" outputs.
Xbyak::Reg64 reg_even_out_diff = rbx;
Xbyak::Reg64 reg_src = r9;
Xbyak::Reg64 reg_dst = r10;
Xbyak::Reg64 reg_num_blocks = r11;
Xbyak::Reg64 reg_work_amount = r12;
// Per-block countdown copy of reg_work_amount.
Xbyak::Reg64 aux_reg_work_amount = r13;
Xbyak::Reg64 reg_twiddles_addr = r14;
Xbyak::Reg64 reg_params = Xbyak::Reg64(dnnl::impl::cpu::x64::abi_param_regs[0]);
Vmm vmm_data_odd_1 = Vmm(0);
Vmm vmm_data_odd_2 = Vmm(1);
Vmm vmm_twiddle_real = Vmm(2);
Vmm vmm_twiddle_imag = Vmm(3);
Vmm vmm_data_even = Vmm(4);
// Deliberately aliases vmm_data_odd_2 (register 1).
Vmm vmm_data_result = vmm_data_odd_2;
// Emits one butterfly loop that consumes `step` floats per iteration with registers of type T.
template <typename T>
void loop_process(int step);
// Emit a store (first overload) or a load (second overload) of `count` floats.
void move_data(const Xbyak::Address& addr, const Xbyak::Xmm& x, int count);
void move_data(const Xbyak::Xmm& x, const Xbyak::Address& addr, int count);
};
} // namespace intel_cpu
} // namespace ov