[CPU] Optimize DFT operation (#11946)
This commit is contained in:
parent
e9e3044d99
commit
6fc23b4768
@ -2,12 +2,14 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "dft.h"
|
||||
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <dnnl_extension_utils.h>
|
||||
|
||||
#include "dft.h"
|
||||
#include "ie_parallel.hpp"
|
||||
#include "ie_precision.hpp"
|
||||
#include <onednn/dnnl.h>
|
||||
@ -15,7 +17,8 @@
|
||||
#include "common/cpu_memcpy.h"
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
|
||||
using namespace dnnl;
|
||||
using namespace dnnl::impl;
|
||||
using namespace dnnl::impl::cpu::x64;
|
||||
using namespace InferenceEngine;
|
||||
|
||||
namespace ov {
|
||||
@ -28,8 +31,8 @@ bool DFT::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, st
|
||||
errorMessage = "Doesn't support op with dynamic shapes";
|
||||
return false;
|
||||
}
|
||||
const auto interpDFT = std::dynamic_pointer_cast<const ngraph::opset7::DFT>(op);
|
||||
const auto interpIDFT = std::dynamic_pointer_cast<const ngraph::opset7::IDFT>(op);
|
||||
const auto interpDFT = ov::is_type<const op::v7::DFT>(op);
|
||||
const auto interpIDFT = ov::is_type<const op::v7::IDFT>(op);
|
||||
|
||||
if (!interpDFT && !interpIDFT) {
|
||||
errorMessage = "Only opset7 DFT/IDFT operation is supported";
|
||||
@ -74,7 +77,8 @@ DFT::DFT(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, Weigh
|
||||
}
|
||||
}
|
||||
|
||||
inverse = std::dynamic_pointer_cast<ngraph::opset7::DFT>(op) == nullptr;
|
||||
inverse = !ov::is_type<op::v7::DFT>(op);
|
||||
lastInverse = !inverse;
|
||||
}
|
||||
|
||||
void DFT::getSupportedDescriptors() {}
|
||||
@ -230,53 +234,72 @@ void copyDataToOutputWithSignalSize(const float* input, const std::vector<size_t
|
||||
} // namespace
|
||||
|
||||
void DFT::execute(dnnl::stream strm) {
|
||||
auto axesEdge = getParentEdgeAt(AXES_INDEX);
|
||||
const auto* axesStartPtr = reinterpret_cast<const int32_t*>(axesEdge->getMemoryPtr()->GetPtr());
|
||||
axes = std::vector<int32_t>(axesStartPtr, axesStartPtr + axesEdge->getMemory().getStaticDims()[0]);
|
||||
for (auto& axis : axes) {
|
||||
if (axis < 0) {
|
||||
axis += inputShape.size() - 1;
|
||||
}
|
||||
}
|
||||
std::sort(axes.begin(), axes.end());
|
||||
const auto& outputShape = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
|
||||
|
||||
outputShape = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
|
||||
const auto inputDataEdge = getParentEdgeAt(DATA_INDEX);
|
||||
const auto outputDataEdge = getChildEdgeAt(0);
|
||||
|
||||
const auto src = reinterpret_cast<const float*>(inputDataEdge->getMemoryPtr()->GetPtr());
|
||||
auto dst = reinterpret_cast<float*>(outputDataEdge->getMemoryPtr()->GetPtr());
|
||||
|
||||
const auto inputRank = inputDataEdge->getMemory().GetShape().getRank();
|
||||
|
||||
const auto& inputStrides = inputDataEdge->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides();
|
||||
const auto& outputStrides = outputDataEdge->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides();
|
||||
|
||||
size_t nComplexMaxFFT = 0;
|
||||
for (size_t axis : axes) {
|
||||
size_t nComplex = outputShape[axis];
|
||||
// FFT uses different twiddle factors
|
||||
if (twiddlesMap.find(nComplex) == twiddlesMap.end() && !IsPowerOfTwo(nComplex)) {
|
||||
twiddlesMap[nComplex] = generateTwiddles(nComplex);
|
||||
if (!IsPowerOfTwo(nComplex)) {
|
||||
if (twiddlesMapDFT.find(nComplex) == twiddlesMapDFT.end() || lastInverse != inverse) {
|
||||
twiddlesMapDFT[nComplex] = generateTwiddlesDFT(nComplex, inverse);
|
||||
}
|
||||
} else {
|
||||
if (nComplexMaxFFT < nComplex) {
|
||||
nComplexMaxFFT = nComplex;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto inputDataEdge = getParentEdgeAt(DATA_INDEX);
|
||||
auto outputDataEdge = getChildEdgeAt(0);
|
||||
const auto *input = reinterpret_cast<const float*>(inputDataEdge->getMemoryPtr()->GetPtr());
|
||||
auto *output = reinterpret_cast<float*>(outputDataEdge->getMemoryPtr()->GetPtr());
|
||||
if (nComplexMaxFFT > 0 && ((nComplexMaxFFT - 1) * 2 > twiddlesFFT.size() || lastInverse != inverse)) {
|
||||
updateTwiddlesFFT(nComplexMaxFFT, inverse);
|
||||
}
|
||||
|
||||
auto inputStrides = inputDataEdge->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides();
|
||||
auto outputStrides = outputDataEdge->getMemory().GetDescWithType<BlockedMemoryDesc>()->getStrides();
|
||||
if (inputShape != outputShape) {
|
||||
copyDataToOutputWithSignalSize(input, inputShape, inputStrides, output, outputShape, outputStrides);
|
||||
copyDataToOutputWithSignalSize(src, inputShape, inputStrides, dst, outputShape, outputStrides);
|
||||
} else {
|
||||
auto totalElements = std::accumulate(inputShape.begin(), inputShape.end(), size_t(1), std::multiplies<size_t>());
|
||||
cpu_memcpy(output, input, totalElements * sizeof(float));
|
||||
cpu_memcpy(dst, src, totalElements * sizeof(float));
|
||||
}
|
||||
|
||||
// 1d case
|
||||
if (inputDataEdge->getMemory().GetShape().getRank() == 2) {
|
||||
if (inputRank == 2) {
|
||||
size_t nComplex = outputShape[0];
|
||||
if (IsPowerOfTwo(nComplex)) {
|
||||
fft(output, nComplex * 2, true);
|
||||
std::vector<float> outputData(nComplex * 2);
|
||||
const float* resultBufPtr;
|
||||
|
||||
fft(dst, outputData.data(), nComplex * 2, inverse, true, &resultBufPtr);
|
||||
|
||||
if (resultBufPtr != dst) {
|
||||
cpu_memcpy(dst, resultBufPtr, nComplex * 2 * sizeof(float));
|
||||
}
|
||||
} else {
|
||||
naiveDFT(output, nComplex * 2);
|
||||
naiveDFT(dst, nComplex * 2, inverse);
|
||||
}
|
||||
} else {
|
||||
dftNd(output, outputStrides);
|
||||
dftNd(dst, outputShape, outputStrides, axes, inverse);
|
||||
}
|
||||
|
||||
lastInverse = inverse;
|
||||
}
|
||||
|
||||
void DFT::dftNd(float* output, const std::vector<size_t>& outputStrides) const {
|
||||
void DFT::dftNd(float* output,
|
||||
const VectorDims& outputShape,
|
||||
const VectorDims& outputStrides,
|
||||
const std::vector<int32_t>& axes,
|
||||
bool inverse) const {
|
||||
const std::vector<size_t> iterationRange(outputShape.begin(), outputShape.end() - 1);
|
||||
const size_t lastDimIndex = iterationRange.size() - 1;
|
||||
for (size_t axisIndex = 0; axisIndex < axes.size(); ++axisIndex) {
|
||||
@ -289,12 +312,13 @@ void DFT::dftNd(float* output, const std::vector<size_t>& outputStrides) const {
|
||||
size_t parallelDimIndex = lastDimIndex == currentAxis ? lastDimIndex - 1 : lastDimIndex;
|
||||
do {
|
||||
parallel_for(iterationRange[parallelDimIndex], [&](size_t dim) {
|
||||
std::vector<float> gatheredData(outputLen);
|
||||
std::vector<float> gatheredData(outputLen * 2);
|
||||
auto parallelIterationCounter = iterationCounter;
|
||||
parallelIterationCounter[parallelDimIndex] = dim;
|
||||
gatherToBufferND(gatheredData.data(), output, currentAxis, parallelIterationCounter, outputShape, outputStrides);
|
||||
fft(gatheredData.data(), outputLen);
|
||||
applyBufferND(gatheredData.data(), output, currentAxis, parallelIterationCounter, outputShape, outputStrides);
|
||||
const float* resultBufPtr;
|
||||
fft(gatheredData.data(), gatheredData.data() + outputLen, outputLen, inverse, false, &resultBufPtr);
|
||||
applyBufferND(resultBufPtr, output, currentAxis, parallelIterationCounter, outputShape, outputStrides);
|
||||
});
|
||||
iterationCounter[parallelDimIndex] = iterationRange[parallelDimIndex] - 1;
|
||||
} while (nextIterationStep(iterationCounter, iterationRange, currentAxis));
|
||||
@ -302,7 +326,7 @@ void DFT::dftNd(float* output, const std::vector<size_t>& outputStrides) const {
|
||||
std::vector<float> gatheredData(outputLen);
|
||||
do {
|
||||
gatherToBufferND(gatheredData.data(), output, currentAxis, iterationCounter, outputShape, outputStrides);
|
||||
naiveDFT(gatheredData.data(), outputLen);
|
||||
naiveDFT(gatheredData.data(), outputLen, inverse);
|
||||
applyBufferND(gatheredData.data(), output, currentAxis, iterationCounter, outputShape, outputStrides);
|
||||
} while (nextIterationStep(iterationCounter, iterationRange, currentAxis));
|
||||
}
|
||||
@ -310,118 +334,257 @@ void DFT::dftNd(float* output, const std::vector<size_t>& outputStrides) const {
|
||||
}
|
||||
|
||||
/* Cooley Tukey implementation of FFT */
|
||||
void DFT::fft(float* data, int64_t dataLength, bool parallelize) const {
|
||||
static int cacheSizeL3 = utils::get_cache_size(3, false);
|
||||
void DFT::fft(float* inBuffer,
|
||||
float* outBuffer,
|
||||
int64_t dataLength,
|
||||
bool inverse,
|
||||
bool parallelize,
|
||||
const float** resultBuf) const {
|
||||
static int cacheSizeL3 = dnnl::utils::get_cache_size(3, false);
|
||||
static int elementsPerCacheLine = cacheSizeL3 / sizeof(float);
|
||||
std::vector<float> bufferVector(dataLength * 2, 0);
|
||||
float* buffer = bufferVector.data();
|
||||
cpu_memcpy(buffer, data, dataLength * sizeof(float));
|
||||
|
||||
size_t nComplex = dataLength / 2;
|
||||
float* inBufferStart = buffer + dataLength;
|
||||
float* outBufferStart = buffer;
|
||||
|
||||
auto blockIteration = [&] (const size_t block, const size_t blockSize, const size_t nextIterationBlockSize, const float anglePart) {
|
||||
float* curInpBufferPtr = inBufferStart + block * blockSize;
|
||||
float* curOutBufferPtr = outBufferStart + block * nextIterationBlockSize;
|
||||
std::function<void(const size_t, const size_t, const size_t)> blockIteration;
|
||||
if (fftKernel != nullptr) {
|
||||
blockIteration = [&](const size_t block, const size_t numBlocks, const size_t nextIterationBlockSize) {
|
||||
auto arg = jit_args_fft();
|
||||
|
||||
const float angle = anglePart * block;
|
||||
const float twiddleReal = std::cos(angle);
|
||||
const float twiddleImag = -std::sin(angle);
|
||||
for (int64_t pair = 0; pair < blockSize / 2; pair += 2) {
|
||||
const float evenReal = curInpBufferPtr[pair];
|
||||
const float evenImag = curInpBufferPtr[pair + 1];
|
||||
arg.src = inBuffer + block * nextIterationBlockSize * 2;
|
||||
arg.dst = outBuffer + block * nextIterationBlockSize;
|
||||
arg.twiddles = &twiddlesFFT[(numBlocks + block - 1) * 2];
|
||||
arg.num_blocks = numBlocks;
|
||||
arg.work_amount = nextIterationBlockSize;
|
||||
arg.n_complex = nComplex;
|
||||
|
||||
const float oddReal = curInpBufferPtr[(blockSize / 2 + pair)];
|
||||
const float oddImag = curInpBufferPtr[(blockSize / 2 + pair) + 1];
|
||||
(*fftKernel)(&arg);
|
||||
};
|
||||
} else {
|
||||
blockIteration = [&](const size_t block, const size_t numBlocks, const size_t nextIterationBlockSize) {
|
||||
float* curInpBufferPtr = inBuffer + block * nextIterationBlockSize * 2;
|
||||
float* curOutBufferPtr = outBuffer + block * nextIterationBlockSize;
|
||||
|
||||
const float twiddledOddReal = getRealFromComplexProd(twiddleReal, twiddleImag, oddReal, oddImag);
|
||||
const float twiddledOddImag = getImaginaryFromComplexProd(twiddleReal, twiddleImag, oddReal, oddImag);
|
||||
for (size_t block = 0; block < numBlocks; ++block) {
|
||||
float twiddleReal = twiddlesFFT[(numBlocks + block - 1) * 2];
|
||||
float twiddleImag = twiddlesFFT[(numBlocks + block) * 2 - 1];
|
||||
|
||||
curOutBufferPtr[pair] = evenReal + twiddledOddReal;
|
||||
curOutBufferPtr[pair + 1] = evenImag + twiddledOddImag;
|
||||
for (size_t pair = 0; pair < nextIterationBlockSize; pair += 2) {
|
||||
const float evenReal = curInpBufferPtr[pair];
|
||||
const float evenImag = curInpBufferPtr[pair + 1];
|
||||
|
||||
curOutBufferPtr[nComplex + pair] = evenReal - twiddledOddReal;
|
||||
curOutBufferPtr[nComplex + pair + 1] = evenImag - twiddledOddImag;
|
||||
}
|
||||
};
|
||||
const float oddReal = curInpBufferPtr[(nextIterationBlockSize + pair)];
|
||||
const float oddImag = curInpBufferPtr[(nextIterationBlockSize + pair) + 1];
|
||||
|
||||
for (int64_t numBlocks = 1; numBlocks < nComplex; numBlocks *= 2) {
|
||||
const float anglePart = PI / numBlocks * (inverse ? -1 : 1);
|
||||
const float twiddledOddReal = getRealFromComplexProd(twiddleReal, twiddleImag, oddReal, oddImag);
|
||||
const float twiddledOddImag =
|
||||
getImaginaryFromComplexProd(twiddleReal, twiddleImag, oddReal, oddImag);
|
||||
|
||||
std::swap(inBufferStart, outBufferStart);
|
||||
const int64_t blockSize = dataLength / numBlocks;
|
||||
const int64_t nextIterationBlockSize = blockSize / 2;
|
||||
curOutBufferPtr[pair] = evenReal + twiddledOddReal;
|
||||
curOutBufferPtr[pair + 1] = evenImag + twiddledOddImag;
|
||||
|
||||
curOutBufferPtr[nComplex + pair] = evenReal - twiddledOddReal;
|
||||
curOutBufferPtr[nComplex + pair + 1] = evenImag - twiddledOddImag;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
size_t blockSize;
|
||||
size_t nextIterationBlockSize = dataLength;
|
||||
for (size_t numBlocks = 1; numBlocks < nComplex; numBlocks *= 2) {
|
||||
blockSize = nextIterationBlockSize;
|
||||
nextIterationBlockSize /= 2;
|
||||
if (parallelize && blockSize >= 4 * elementsPerCacheLine) {
|
||||
parallel_for(numBlocks, [&] (const size_t block) {
|
||||
blockIteration(block, blockSize, nextIterationBlockSize, anglePart);
|
||||
parallel_for(numBlocks, [&](const size_t block) {
|
||||
blockIteration(block, 1, nextIterationBlockSize);
|
||||
});
|
||||
} else {
|
||||
for (int64_t block = 0; block < numBlocks; ++block) {
|
||||
blockIteration(block, blockSize, nextIterationBlockSize, anglePart);
|
||||
}
|
||||
blockIteration(0, numBlocks, nextIterationBlockSize);
|
||||
}
|
||||
|
||||
std::swap(inBuffer, outBuffer);
|
||||
}
|
||||
if (inverse) {
|
||||
for (int64_t k = 0; k < dataLength; k++) {
|
||||
inBuffer[k] /= nComplex;
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t k = 0; k < dataLength; k++) {
|
||||
if (inverse) {
|
||||
outBufferStart[k] /= nComplex;
|
||||
}
|
||||
data[k] = outBufferStart[k];
|
||||
}
|
||||
*resultBuf = inBuffer;
|
||||
}
|
||||
|
||||
void DFT::naiveDFT(float* data, size_t dataLength) const {
|
||||
void DFT::naiveDFT(float* data, size_t dataLength, bool inverse) const {
|
||||
std::vector<float> outputBuffer(dataLength);
|
||||
const size_t nComplex = dataLength / 2;
|
||||
const auto& twiddles = twiddlesMap.find(nComplex)->second;
|
||||
const float reciprocalNComplex = 1.0f / nComplex;
|
||||
const auto& twiddles = twiddlesMapDFT.find(nComplex)->second;
|
||||
|
||||
parallel_for(nComplex, [&](size_t k) {
|
||||
float sumReal = 0.0f;
|
||||
float sumImag = 0.0f;
|
||||
for (size_t n = 0; n < nComplex; ++n) {
|
||||
auto it = twiddles[k * nComplex + n];
|
||||
float complexReal = it.first;
|
||||
float complexImag = it.second;
|
||||
std::function<void(size_t)> blockIteration;
|
||||
if (dftKernel != nullptr) {
|
||||
blockIteration = [&](size_t k) {
|
||||
auto arg = jit_args_dft();
|
||||
|
||||
arg.src = data;
|
||||
arg.dst = outputBuffer.data() + 2 * k;
|
||||
arg.twiddles = twiddles.data() + 2 * k * nComplex;
|
||||
arg.work_amount = nComplex;
|
||||
arg.index = k;
|
||||
|
||||
(*dftKernel)(&arg);
|
||||
|
||||
if (inverse) {
|
||||
complexImag *= -1; // conjugate
|
||||
outputBuffer[k * 2] *= reciprocalNComplex;
|
||||
outputBuffer[k * 2 + 1] *= reciprocalNComplex;
|
||||
}
|
||||
float complexProdReal = getRealFromComplexProd(data[2 * n], data[2 * n + 1], complexReal, complexImag);
|
||||
float complexProdImag = getImaginaryFromComplexProd(data[2 * n], data[2 * n + 1], complexReal, complexImag);
|
||||
};
|
||||
} else {
|
||||
blockIteration = [&](size_t k) {
|
||||
float sumReal = 0.0f;
|
||||
float sumImag = 0.0f;
|
||||
for (size_t n = 0; n < nComplex; ++n) {
|
||||
auto complexRef = &twiddles[2 * (k * nComplex + n)];
|
||||
float complexReal = *complexRef;
|
||||
float complexImag = *(complexRef + 1);
|
||||
|
||||
sumReal += complexProdReal;
|
||||
sumImag += complexProdImag;
|
||||
}
|
||||
float complexProdReal = getRealFromComplexProd(data[2 * n], data[2 * n + 1], complexReal, complexImag);
|
||||
float complexProdImag =
|
||||
getImaginaryFromComplexProd(data[2 * n], data[2 * n + 1], complexReal, complexImag);
|
||||
|
||||
if (inverse) {
|
||||
sumReal /= nComplex;
|
||||
sumImag /= nComplex;
|
||||
}
|
||||
outputBuffer[k * 2] = sumReal;
|
||||
outputBuffer[k * 2 + 1] = sumImag;
|
||||
});
|
||||
sumReal += complexProdReal;
|
||||
sumImag += complexProdImag;
|
||||
}
|
||||
|
||||
if (inverse) {
|
||||
sumReal *= reciprocalNComplex;
|
||||
sumImag *= reciprocalNComplex;
|
||||
}
|
||||
outputBuffer[k * 2] = sumReal;
|
||||
outputBuffer[k * 2 + 1] = sumImag;
|
||||
};
|
||||
}
|
||||
|
||||
parallel_for(nComplex, blockIteration);
|
||||
cpu_memcpy(data, outputBuffer.data(), dataLength * sizeof(float));
|
||||
}
|
||||
|
||||
std::vector<std::pair<float, float>> DFT::generateTwiddles(size_t n_complex) const {
|
||||
std::vector<std::pair<float, float>> twiddles(n_complex * n_complex);
|
||||
std::vector<float> DFT::generateTwiddlesDFT(size_t n_complex, bool inverse) const {
|
||||
std::vector<float> twiddles(n_complex * n_complex * 2);
|
||||
const float inverseMultiplier = inverse ? 1 : -1;
|
||||
parallel_for(n_complex, [&](const size_t k) {
|
||||
for (size_t n = 0; n < n_complex; ++n) {
|
||||
float phase = 2.0f * PI * static_cast<float>(n * k) / static_cast<float>(n_complex);
|
||||
float phase = 2.0f * PI * static_cast<float>(n * k) / static_cast<float>(n_complex);
|
||||
auto complexReal = std::cos(phase);
|
||||
auto complexImag = -std::sin(phase);
|
||||
twiddles[k * n_complex + n] = std::make_pair(complexReal, complexImag);
|
||||
auto complexImag = std::sin(phase) * inverseMultiplier;
|
||||
twiddles[2 * (k * n_complex + n)] = complexReal;
|
||||
twiddles[2 * (k * n_complex + n) + 1] = complexImag;
|
||||
}
|
||||
});
|
||||
return twiddles;
|
||||
}
|
||||
|
||||
void DFT::updateTwiddlesFFT(size_t n_complex, bool inverse) {
|
||||
const float inverseMultiplier = inverse ? 1 : -1;
|
||||
size_t numBlocks = 1;
|
||||
|
||||
twiddlesFFT.reserve((n_complex - 1) * 2);
|
||||
if (twiddlesFFT.size() == 0) {
|
||||
twiddlesFFT.emplace_back(1.0f); // cos(0)
|
||||
twiddlesFFT.emplace_back(-0.0f); // -sin(0)
|
||||
} else {
|
||||
for (size_t i = numBlocks; i < twiddlesFFT.size() / 2; i += numBlocks) {
|
||||
numBlocks *= 2;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = twiddlesFFT.size() / 2; i < n_complex - 1; i += numBlocks) {
|
||||
numBlocks *= 2;
|
||||
|
||||
for (size_t blockNum = 0; blockNum < numBlocks; blockNum++) {
|
||||
size_t copyIndex = twiddlesFFT.size() - blockNum - numBlocks;
|
||||
|
||||
twiddlesFFT.push_back(twiddlesFFT[copyIndex]);
|
||||
twiddlesFFT.push_back(twiddlesFFT[copyIndex + 1]);
|
||||
|
||||
blockNum++;
|
||||
|
||||
float angle = PI * blockNum / numBlocks;
|
||||
auto complexReal = std::cos(angle);
|
||||
auto complexImag = std::sin(angle) * inverseMultiplier;
|
||||
|
||||
twiddlesFFT.emplace_back(complexReal);
|
||||
twiddlesFFT.emplace_back(complexImag);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool DFT::created() const {
|
||||
return getType() == Type::DFT;
|
||||
}
|
||||
|
||||
void DFT::createPrimitive() {}
|
||||
void DFT::prepareParams() {
|
||||
bool hasDFT = false;
|
||||
bool hasFFT = false;
|
||||
|
||||
axes = getAxes();
|
||||
const auto outputShape = getChildEdgesAtPort(0)[0]->getMemory().getStaticDims();
|
||||
|
||||
for (size_t axis : axes) {
|
||||
size_t nComplex = outputShape[axis];
|
||||
if (!IsPowerOfTwo(nComplex)) {
|
||||
hasDFT = true;
|
||||
} else {
|
||||
hasFFT = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (mayiuse(cpu::x64::sse41)) {
|
||||
createJITKernels(hasDFT, hasFFT);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int32_t> DFT::getAxes() const {
|
||||
auto axesEdge = getParentEdgeAt(AXES_INDEX);
|
||||
const auto* axesStartPtr = reinterpret_cast<const int32_t*>(axesEdge->getMemoryPtr()->GetPtr());
|
||||
auto axes = std::vector<int32_t>(axesStartPtr, axesStartPtr + axesEdge->getMemory().getStaticDims()[0]);
|
||||
for (auto& axis : axes) {
|
||||
if (axis < 0) {
|
||||
axis += inputShape.size() - 1;
|
||||
}
|
||||
}
|
||||
std::sort(axes.begin(), axes.end());
|
||||
return axes;
|
||||
}
|
||||
|
||||
void DFT::createJITKernels(bool hasDFT, bool hasFFT) {
|
||||
if (hasDFT && dftKernel == nullptr) {
|
||||
if (mayiuse(cpu::x64::avx512_core)) {
|
||||
dftKernel.reset(new jit_uni_dft_kernel_f32<cpu::x64::avx512_core>());
|
||||
} else if (mayiuse(cpu::x64::avx2)) {
|
||||
dftKernel.reset(new jit_uni_dft_kernel_f32<cpu::x64::avx2>());
|
||||
} else if (mayiuse(cpu::x64::sse41)) {
|
||||
dftKernel.reset(new jit_uni_dft_kernel_f32<cpu::x64::sse41>());
|
||||
} else {
|
||||
IE_THROW() << "Can't create jit DFT kernel";
|
||||
}
|
||||
|
||||
if (dftKernel)
|
||||
dftKernel->create_ker();
|
||||
}
|
||||
|
||||
if (hasFFT && fftKernel == nullptr) {
|
||||
if (mayiuse(cpu::x64::avx512_core)) {
|
||||
fftKernel.reset(new jit_uni_fft_kernel_f32<cpu::x64::avx512_core>());
|
||||
} else if (mayiuse(cpu::x64::avx2)) {
|
||||
fftKernel.reset(new jit_uni_fft_kernel_f32<cpu::x64::avx2>());
|
||||
} else if (mayiuse(cpu::x64::sse41)) {
|
||||
fftKernel.reset(new jit_uni_fft_kernel_f32<cpu::x64::sse41>());
|
||||
} else {
|
||||
IE_THROW() << "Can't create jit FFT kernel";
|
||||
}
|
||||
|
||||
if (fftKernel)
|
||||
fftKernel->create_ker();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace node
|
||||
} // namespace intel_cpu
|
||||
|
@ -8,6 +8,8 @@
|
||||
#include <node.h>
|
||||
#include <string>
|
||||
|
||||
#include "kernels/dft_uni_kernel.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
namespace node {
|
||||
@ -19,29 +21,50 @@ public:
|
||||
|
||||
void getSupportedDescriptors() override;
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
void createPrimitive() override;
|
||||
void execute(dnnl::stream strm) override;
|
||||
bool created() const override;
|
||||
|
||||
void prepareParams() override;
|
||||
|
||||
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
||||
|
||||
private:
|
||||
void dftNd(float* output, const std::vector<size_t>& outputStrides) const;
|
||||
void fft(float* data, int64_t dataLength, bool parallelize = false) const;
|
||||
void naiveDFT(float* data, size_t dataLength) const;
|
||||
std::vector<int32_t> getAxes() const;
|
||||
void createJITKernels(bool hasDFT, bool hasFFT);
|
||||
|
||||
std::vector<std::pair<float, float>> generateTwiddles(size_t n_complex) const;
|
||||
void dftNd(float* output,
|
||||
const VectorDims& outputShape,
|
||||
const VectorDims& outputStrides,
|
||||
const std::vector<int32_t>& axes,
|
||||
bool inverse) const;
|
||||
|
||||
void fft(float* inBuffer,
|
||||
float* outBuffer,
|
||||
int64_t dataLength,
|
||||
bool inverse,
|
||||
bool parallelize,
|
||||
const float** resultBuf) const;
|
||||
void naiveDFT(float* data, size_t dataLength, bool inverse) const;
|
||||
|
||||
std::vector<float> generateTwiddlesDFT(size_t n_complex, bool inverse) const;
|
||||
void updateTwiddlesFFT(size_t n_complex, bool inverse);
|
||||
|
||||
std::unique_ptr<jit_uni_dft_kernel> dftKernel = nullptr;
|
||||
std::unique_ptr<jit_uni_fft_kernel> fftKernel = nullptr;
|
||||
|
||||
std::vector<float> twiddlesFFT;
|
||||
std::unordered_map<size_t, std::vector<float>> twiddlesMapDFT;
|
||||
|
||||
std::unordered_map<size_t, std::vector<std::pair<float, float>>> twiddlesMap;
|
||||
std::vector<int32_t> axes;
|
||||
std::vector<size_t> outputShape;
|
||||
std::vector<size_t> inputShape;
|
||||
std::string layerErrorPrefix;
|
||||
const size_t DATA_INDEX = 0;
|
||||
const size_t AXES_INDEX = 1;
|
||||
const size_t SIGNAL_SIZE_INDEX = 2;
|
||||
static constexpr float PI = 3.141592653589793238462643f;
|
||||
|
||||
bool inverse;
|
||||
bool lastInverse;
|
||||
};
|
||||
|
||||
} // namespace node
|
||||
|
246
src/plugins/intel_cpu/src/nodes/kernels/dft_uni_kernel.cpp
Normal file
246
src/plugins/intel_cpu/src/nodes/kernels/dft_uni_kernel.cpp
Normal file
@ -0,0 +1,246 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "dft_uni_kernel.hpp"
|
||||
|
||||
|
||||
using namespace dnnl::impl;
|
||||
using namespace dnnl::impl::utils;
|
||||
using namespace dnnl::impl::cpu::x64;
|
||||
|
||||
#define GET_OFF_DFT(field) offsetof(jit_args_dft, field)
|
||||
#define GET_OFF_FFT(field) offsetof(jit_args_fft, field)
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
|
||||
template <cpu::x64::cpu_isa_t isa>
|
||||
jit_uni_dft_kernel_f32<isa>::jit_uni_dft_kernel_f32() : jit_uni_dft_kernel(), jit_generator(jit_name()) {}
|
||||
|
||||
template <cpu::x64::cpu_isa_t isa>
|
||||
void jit_uni_dft_kernel_f32<isa>::create_ker() {
|
||||
jit_generator::create_kernel();
|
||||
ker_ = (decltype(ker_))jit_ker();
|
||||
}
|
||||
|
||||
template <cpu::x64::cpu_isa_t isa>
|
||||
void jit_uni_dft_kernel_f32<isa>::generate() {
|
||||
this->preamble();
|
||||
|
||||
mov(reg_src, ptr[reg_params + GET_OFF_DFT(src)]);
|
||||
mov(reg_dst, ptr[reg_params + GET_OFF_DFT(dst)]);
|
||||
mov(reg_twiddles, ptr[reg_params + GET_OFF_DFT(twiddles)]);
|
||||
mov(reg_work_amount, ptr[reg_params + GET_OFF_DFT(work_amount)]);
|
||||
mov(reg_index, ptr[reg_params + GET_OFF_DFT(index)]);
|
||||
|
||||
Xbyak::Label main_loop_label;
|
||||
Xbyak::Label main_loop_end_label;
|
||||
Xbyak::Label tail_loop_label;
|
||||
Xbyak::Label tail_loop_end_label;
|
||||
|
||||
uni_vpxor(vmm_sum, vmm_sum, vmm_sum);
|
||||
|
||||
int step = vlen / 8;
|
||||
|
||||
L(main_loop_label);
|
||||
{
|
||||
cmp(reg_work_amount, step);
|
||||
jl(main_loop_end_label, T_NEAR);
|
||||
|
||||
uni_vmovups(vmm_data_cache, ptr[reg_src]);
|
||||
uni_vmovups(vmm_twiddles_cache, ptr[reg_twiddles]);
|
||||
|
||||
uni_vshufps(vmm_data, vmm_data_cache, vmm_data_cache, 0b01000001);
|
||||
uni_vshufps(vmm_twiddles, vmm_twiddles_cache, vmm_twiddles_cache, 0b01000100);
|
||||
uni_vfmadd231ps(vmm_sum, vmm_data, vmm_twiddles);
|
||||
|
||||
uni_vshufps(vmm_data, vmm_data_cache, vmm_data_cache, 0b11101011);
|
||||
uni_vshufps(vmm_twiddles, vmm_twiddles_cache, vmm_twiddles_cache, 0b11101110);
|
||||
uni_vfmadd231ps(vmm_sum, vmm_data, vmm_twiddles);
|
||||
|
||||
add(reg_twiddles, 2 * step * sizeof(float));
|
||||
add(reg_src, 2 * step * sizeof(float));
|
||||
|
||||
sub(reg_work_amount, step);
|
||||
jmp(main_loop_label, T_NEAR);
|
||||
}
|
||||
L(main_loop_end_label);
|
||||
|
||||
if (mayiuse(cpu::x64::avx512_core)) {
|
||||
Xbyak::Zmm zmm_sum = Xbyak::Zmm(vmm_sum.getIdx());
|
||||
Xbyak::Ymm ymm_sum = Xbyak::Ymm(vmm_sum.getIdx());
|
||||
Xbyak::Ymm ymm_sum_2 = Xbyak::Ymm(vmm_sum_2.getIdx());
|
||||
|
||||
vextractf64x4(ymm_sum_2, zmm_sum, 1);
|
||||
vaddps(ymm_sum, ymm_sum, ymm_sum_2);
|
||||
}
|
||||
if (mayiuse(cpu::x64::avx2)) {
|
||||
Xbyak::Ymm ymm_sum = Xbyak::Ymm(vmm_sum.getIdx());
|
||||
|
||||
vextractf128(xmm_sum_2, ymm_sum, 1);
|
||||
vaddps(xmm_sum, xmm_sum, xmm_sum_2);
|
||||
}
|
||||
|
||||
L(tail_loop_label);
|
||||
{
|
||||
cmp(reg_work_amount, 1);
|
||||
jl(tail_loop_end_label, T_NEAR);
|
||||
|
||||
uni_vmovups(xmm_data, ptr[reg_src]);
|
||||
uni_vmovups(xmm_twiddles, ptr[reg_twiddles]);
|
||||
uni_vshufps(xmm_data, xmm_data, xmm_data, 0b01000001);
|
||||
uni_vshufps(xmm_twiddles, xmm_twiddles, xmm_twiddles, 0b01000100);
|
||||
uni_vfmadd231ps(xmm_sum, xmm_data, xmm_twiddles);
|
||||
|
||||
add(reg_twiddles, 2 * sizeof(float));
|
||||
add(reg_src, 2 * sizeof(float));
|
||||
|
||||
sub(reg_work_amount, 1);
|
||||
jmp(tail_loop_label, T_NEAR);
|
||||
}
|
||||
L(tail_loop_end_label);
|
||||
|
||||
uni_vmovhlps(xmm_sum_2, xmm_sum_2, xmm_sum);
|
||||
uni_vhsubps(xmm_sum_2, xmm_sum_2, xmm_sum_2);
|
||||
uni_vhaddps(xmm_sum, xmm_sum, xmm_sum);
|
||||
|
||||
uni_vmovss(ptr[reg_dst], xmm_sum_2);
|
||||
uni_vmovss(ptr[reg_dst + sizeof(float)], xmm_sum);
|
||||
|
||||
this->postamble();
|
||||
}
|
||||
|
||||
template struct jit_uni_dft_kernel_f32<cpu::x64::sse41>;
|
||||
template struct jit_uni_dft_kernel_f32<cpu::x64::avx2>;
|
||||
template struct jit_uni_dft_kernel_f32<cpu::x64::avx512_core>;
|
||||
|
||||
|
||||
template <cpu::x64::cpu_isa_t isa>
|
||||
jit_uni_fft_kernel_f32<isa>::jit_uni_fft_kernel_f32()
|
||||
: jit_uni_fft_kernel(),
|
||||
jit_generator(jit_name()) {}
|
||||
|
||||
template <cpu::x64::cpu_isa_t isa>
|
||||
void jit_uni_fft_kernel_f32<isa>::create_ker() {
|
||||
jit_generator::create_kernel();
|
||||
ker_ = (decltype(ker_))jit_ker();
|
||||
}
|
||||
|
||||
template <cpu::x64::cpu_isa_t isa>
|
||||
void jit_uni_fft_kernel_f32<isa>::generate() {
|
||||
this->preamble();
|
||||
|
||||
mov(reg_src, ptr[reg_params + GET_OFF_FFT(src)]);
|
||||
mov(reg_dst, ptr[reg_params + GET_OFF_FFT(dst)]);
|
||||
mov(reg_twiddles_addr, ptr[reg_params + GET_OFF_FFT(twiddles)]);
|
||||
|
||||
mov(reg_num_blocks, ptr[reg_params + GET_OFF_FFT(num_blocks)]);
|
||||
mov(reg_work_amount, ptr[reg_params + GET_OFF_FFT(work_amount)]);
|
||||
|
||||
mov(reg_even_in_diff, sizeof(float));
|
||||
mul(ptr[reg_params + GET_OFF_FFT(n_complex)]);
|
||||
mov(reg_even_out_diff, reg_even_in_diff);
|
||||
|
||||
mov(reg_even_in_diff, sizeof(float));
|
||||
mul(reg_work_amount);
|
||||
|
||||
Xbyak::Label block_loop_label;
|
||||
Xbyak::Label block_loop_end_label;
|
||||
|
||||
L(block_loop_label);
|
||||
{
|
||||
cmp(reg_num_blocks, 1);
|
||||
jl(block_loop_end_label, T_NEAR);
|
||||
|
||||
mov(aux_reg_work_amount, reg_work_amount);
|
||||
uni_vbroadcastss(vmm_twiddle_real, ptr[reg_twiddles_addr]);
|
||||
uni_vbroadcastss(vmm_twiddle_imag, ptr[reg_twiddles_addr + sizeof(float)]);
|
||||
|
||||
if (mayiuse(cpu::x64::avx2)) {
|
||||
loop_process<Vmm>(vlen / 4);
|
||||
}
|
||||
loop_process<Xbyak::Xmm>(4);
|
||||
loop_process<Xbyak::Xmm>(2);
|
||||
|
||||
add(reg_twiddles_addr, 2 * sizeof(float));
|
||||
add(reg_src, reg_even_in_diff);
|
||||
sub(reg_num_blocks, 1);
|
||||
|
||||
jmp(block_loop_label, T_NEAR);
|
||||
}
|
||||
L(block_loop_end_label);
|
||||
|
||||
this->postamble();
|
||||
}
|
||||
|
||||
template <cpu::x64::cpu_isa_t isa>
|
||||
template <typename T>
|
||||
void jit_uni_fft_kernel_f32<isa>::loop_process(int step) {
|
||||
T reg_data_odd_1 = T(vmm_data_odd_1.getIdx());
|
||||
T reg_data_odd_2 = T(vmm_data_odd_2.getIdx());
|
||||
T reg_twiddle_imag = T(vmm_twiddle_imag.getIdx());
|
||||
T reg_twiddle_real = T(vmm_twiddle_real.getIdx());
|
||||
T reg_data_even = T(vmm_data_even.getIdx());
|
||||
T reg_data_result = T(vmm_data_result.getIdx());
|
||||
|
||||
Xbyak::Label loop_label;
|
||||
Xbyak::Label loop_end_label;
|
||||
|
||||
L(loop_label);
|
||||
{
|
||||
cmp(aux_reg_work_amount, step);
|
||||
jl(loop_end_label, T_NEAR);
|
||||
|
||||
move_data(reg_data_odd_1, ptr[reg_src + reg_even_in_diff], step);
|
||||
uni_vshufps(reg_data_odd_2, reg_data_odd_1, reg_data_odd_1, 0b10110001);
|
||||
uni_vmulps(reg_data_odd_2, reg_data_odd_2, reg_twiddle_imag);
|
||||
|
||||
if (mayiuse(cpu::x64::avx512_core)) {
|
||||
vfmaddsub213ps(reg_data_odd_1, reg_twiddle_real, reg_data_odd_2);
|
||||
} else {
|
||||
uni_vmulps(reg_data_odd_1, reg_data_odd_1, reg_twiddle_real);
|
||||
uni_vaddsubps(reg_data_odd_1, reg_data_odd_1, reg_data_odd_2);
|
||||
}
|
||||
|
||||
move_data(reg_data_even, ptr[reg_src], step);
|
||||
|
||||
uni_vaddps(reg_data_result, reg_data_even, reg_data_odd_1);
|
||||
move_data(ptr[reg_dst], reg_data_result, step);
|
||||
|
||||
uni_vsubps(reg_data_result, reg_data_even, reg_data_odd_1);
|
||||
move_data(ptr[reg_dst + reg_even_out_diff], reg_data_result, step);
|
||||
|
||||
add(reg_src, step * sizeof(float));
|
||||
add(reg_dst, step * sizeof(float));
|
||||
|
||||
sub(aux_reg_work_amount, step);
|
||||
jmp(loop_label, T_NEAR);
|
||||
}
|
||||
L(loop_end_label);
|
||||
}
|
||||
|
||||
template <cpu::x64::cpu_isa_t isa>
|
||||
void jit_uni_fft_kernel_f32<isa>::move_data(const Xbyak::Address& addr, const Xbyak::Xmm& x, int count) {
|
||||
if (count == 2) {
|
||||
uni_vmovq(addr, x);
|
||||
} else {
|
||||
uni_vmovups(addr, x);
|
||||
}
|
||||
}
|
||||
|
||||
template <cpu::x64::cpu_isa_t isa>
|
||||
void jit_uni_fft_kernel_f32<isa>::move_data(const Xbyak::Xmm& x, const Xbyak::Address& addr, int count) {
|
||||
if (count == 2) {
|
||||
uni_vmovq(x, addr);
|
||||
} else {
|
||||
uni_vmovups(x, addr);
|
||||
}
|
||||
}
|
||||
|
||||
template struct jit_uni_fft_kernel_f32<cpu::x64::sse41>;
|
||||
template struct jit_uni_fft_kernel_f32<cpu::x64::avx2>;
|
||||
template struct jit_uni_fft_kernel_f32<cpu::x64::avx512_core>;
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
142
src/plugins/intel_cpu/src/nodes/kernels/dft_uni_kernel.hpp
Normal file
142
src/plugins/intel_cpu/src/nodes/kernels/dft_uni_kernel.hpp
Normal file
@ -0,0 +1,142 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cpu/x64/cpu_isa_traits.hpp>
|
||||
#include <cpu/x64/jit_generator.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
|
||||
struct jit_args_dft {
|
||||
const float* src;
|
||||
float* dst;
|
||||
const float* twiddles;
|
||||
|
||||
size_t work_amount;
|
||||
size_t index;
|
||||
};
|
||||
|
||||
struct jit_args_fft {
|
||||
const float* src;
|
||||
float* dst;
|
||||
const float* twiddles;
|
||||
|
||||
size_t num_blocks;
|
||||
size_t work_amount;
|
||||
size_t n_complex;
|
||||
};
|
||||
|
||||
struct jit_uni_dft_kernel {
|
||||
void (*ker_)(const jit_args_dft*);
|
||||
|
||||
void operator()(const jit_args_dft* args) {
|
||||
assert(ker_);
|
||||
ker_(args);
|
||||
}
|
||||
|
||||
jit_uni_dft_kernel() : ker_(nullptr) {}
|
||||
virtual ~jit_uni_dft_kernel() {}
|
||||
|
||||
virtual void create_ker() = 0;
|
||||
};
|
||||
|
||||
struct jit_uni_fft_kernel {
|
||||
void (*ker_)(const jit_args_fft*);
|
||||
|
||||
void operator()(const jit_args_fft* args) {
|
||||
assert(ker_);
|
||||
ker_(args);
|
||||
}
|
||||
|
||||
jit_uni_fft_kernel() : ker_(nullptr) {}
|
||||
virtual ~jit_uni_fft_kernel() {}
|
||||
|
||||
virtual void create_ker() = 0;
|
||||
};
|
||||
|
||||
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
|
||||
struct jit_uni_dft_kernel_f32 : public jit_uni_dft_kernel, public dnnl::impl::cpu::x64::jit_generator {
|
||||
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dft_kernel_f32)
|
||||
|
||||
jit_uni_dft_kernel_f32();
|
||||
|
||||
void create_ker() override;
|
||||
void generate() override;
|
||||
|
||||
private:
|
||||
using Vmm = typename dnnl::impl::utils::conditional3<isa == dnnl::impl::cpu::x64::sse41,
|
||||
Xbyak::Xmm,
|
||||
isa == dnnl::impl::cpu::x64::avx2,
|
||||
Xbyak::Ymm,
|
||||
Xbyak::Zmm>::type;
|
||||
size_t vlen = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen;
|
||||
|
||||
Xbyak::Reg64 reg_src = r8;
|
||||
Xbyak::Reg64 reg_dst = r9;
|
||||
Xbyak::Reg64 reg_twiddles = r10;
|
||||
Xbyak::Reg64 reg_work_amount = r11;
|
||||
Xbyak::Reg64 reg_index = r12;
|
||||
Xbyak::Reg64 reg_params = Xbyak::Reg64(dnnl::impl::cpu::x64::abi_param_regs[0]);
|
||||
|
||||
Vmm vmm_data = Vmm(0);
|
||||
Vmm vmm_twiddles = Vmm(1);
|
||||
Vmm vmm_sum = Vmm(2);
|
||||
Vmm vmm_sum_2 = vmm_data;
|
||||
Vmm vmm_data_cache = Vmm(3);
|
||||
Vmm vmm_twiddles_cache = Vmm(4);
|
||||
|
||||
Xbyak::Xmm xmm_data = Xbyak::Xmm(0);
|
||||
Xbyak::Xmm xmm_twiddles = Xbyak::Xmm(1);
|
||||
Xbyak::Xmm xmm_sum = Xbyak::Xmm(2);
|
||||
Xbyak::Xmm xmm_sum_2 = xmm_data;
|
||||
};
|
||||
|
||||
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
|
||||
struct jit_uni_fft_kernel_f32 : public jit_uni_fft_kernel, public dnnl::impl::cpu::x64::jit_generator {
|
||||
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_fft_kernel_f32)
|
||||
|
||||
jit_uni_fft_kernel_f32();
|
||||
|
||||
void create_ker() override;
|
||||
void generate() override;
|
||||
|
||||
private:
|
||||
using Vmm = typename dnnl::impl::utils::conditional3<isa == dnnl::impl::cpu::x64::sse41,
|
||||
Xbyak::Xmm,
|
||||
isa == dnnl::impl::cpu::x64::avx2,
|
||||
Xbyak::Ymm,
|
||||
Xbyak::Zmm>::type;
|
||||
const size_t vlen = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen;
|
||||
|
||||
Xbyak::Reg64 reg_even_in_diff = rax;
|
||||
Xbyak::Reg64 reg_even_out_diff = rbx;
|
||||
|
||||
Xbyak::Reg64 reg_src = r9;
|
||||
Xbyak::Reg64 reg_dst = r10;
|
||||
Xbyak::Reg64 reg_num_blocks = r11;
|
||||
Xbyak::Reg64 reg_work_amount = r12;
|
||||
Xbyak::Reg64 aux_reg_work_amount = r13;
|
||||
Xbyak::Reg64 reg_twiddles_addr = r14;
|
||||
Xbyak::Reg64 reg_params = Xbyak::Reg64(dnnl::impl::cpu::x64::abi_param_regs[0]);
|
||||
|
||||
Vmm vmm_data_odd_1 = Vmm(0);
|
||||
Vmm vmm_data_odd_2 = Vmm(1);
|
||||
Vmm vmm_twiddle_real = Vmm(2);
|
||||
Vmm vmm_twiddle_imag = Vmm(3);
|
||||
Vmm vmm_data_even = Vmm(4);
|
||||
|
||||
Vmm vmm_data_result = vmm_data_odd_2;
|
||||
|
||||
|
||||
template <typename T>
|
||||
void loop_process(int step);
|
||||
|
||||
void move_data(const Xbyak::Address& addr, const Xbyak::Xmm& x, int count);
|
||||
void move_data(const Xbyak::Xmm& x, const Xbyak::Address& addr, int count);
|
||||
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
Loading…
Reference in New Issue
Block a user