[CPU] RDFT kernel optimizations (#14060)

This commit is contained in:
Mateusz Tabaka 2022-12-16 09:18:27 +01:00 committed by GitHub
parent 7facf8b90b
commit 43164a6b25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 209 additions and 228 deletions

View File

@ -15,8 +15,6 @@ void jit_dft_kernel_f32<isa>::generate() {
using namespace Xbyak::util;
using Xbyak::Label;
using Xbyak::Xmm;
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xbyak::Xmm,
isa == cpu::x64::avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
this->preamble();
@ -37,8 +35,9 @@ void jit_dft_kernel_f32<isa>::generate() {
output_type_size = type_size;
break;
}
int vlen = cpu_isa_traits<isa>::vlen;
const int simd_size = vlen / output_type_size;
int simd_size = vlen / output_type_size;
if (kernel_type_ == complex_to_complex)
simd_size = vlen / type_size;
mov(input_ptr, ptr[param1 + GET_OFF(input)]);
mov(input_size, ptr[param1 + GET_OFF(input_size)]);
@ -63,23 +62,40 @@ void jit_dft_kernel_f32<isa>::generate() {
Vmm vmm_signal_size = Vmm(reg_idx);
if (is_inverse_) {
reg_idx++;
uni_vbroadcastss(Vmm(reg_idx), ptr[param1 + GET_OFF(signal_size)]);
uni_vcvtdq2ps(vmm_signal_size, Vmm(reg_idx));
uni_vbroadcastss(vmm_signal_size, ptr[param1 + GET_OFF(signal_size)]);
uni_vcvtdq2ps(vmm_signal_size, vmm_signal_size);
}
Vmm vmm_neg_mask = Vmm(reg_idx);
Xmm xmm_neg_mask = Xmm(reg_idx);
Xmm neg_mask = Xmm(reg_idx);
if (kernel_type_ == complex_to_complex) {
reg_idx++;
if (!is_inverse_) {
mov(rax, 1ULL << 31);
} else {
mov(rax, 1ULL << 63);
}
uni_vmovq(xmm_neg_mask, rax);
uni_vbroadcastsd(vmm_neg_mask, xmm_neg_mask);
uni_vpxor(neg_mask, neg_mask, neg_mask);
mov(rax, 1ULL << 63);
uni_vmovq(neg_mask, rax);
}
size_t vmm_reg_idx = reg_idx;
Vmm inp_real = Vmm(vmm_reg_idx++);
Vmm inp_imag = Vmm(vmm_reg_idx++);
Vmm cos = Vmm(vmm_reg_idx++);
Vmm sin = Vmm(vmm_reg_idx++);
const Vmm& twiddles = cos;
Vmm tmp = Vmm(vmm_reg_idx++);
Vmm output_real = Vmm(vmm_reg_idx++);
Vmm output_imag = Vmm(vmm_reg_idx++);
const Vmm& output = output_real;
perm_low = Vmm(vmm_reg_idx++);
perm_high = Vmm(vmm_reg_idx++);
mov(rax, reinterpret_cast<uint64_t>(perm_low_values.data()));
uni_vmovups(perm_low, ptr[rax]);
mov(rax, reinterpret_cast<uint64_t>(perm_high_values.data()));
uni_vmovups(perm_high, ptr[rax]);
Xmm xmm_input = Xbyak::Xmm(reg_idx++);
Xmm xmm_twiddles = Xbyak::Xmm(reg_idx++);
Xmm xmm_output = Xbyak::Xmm(reg_idx++);
mov(rax, signal_size);
and_(rax, 1);
setz(is_signal_size_even);
@ -89,60 +105,68 @@ void jit_dft_kernel_f32<isa>::generate() {
Label loop_simd;
Label loop_nonsimd;
auto simd_loop = [this, vlen, simd_size,
input_type_size, reg_idx,
&vmm_signal_size,
&xmm_neg_mask,
&vmm_neg_mask] {
size_t idx = reg_idx;
Vmm result = Vmm(idx++);
Vmm inp_real = Vmm(idx++);
Vmm inp_imag = Vmm(idx++);
const Vmm& input = inp_real;
const Vmm& input_perm = inp_imag;
Vmm twiddles = Vmm(idx++);
const Vmm& cos = twiddles;
Vmm sin = Vmm(idx++);
Xmm tmp = Xmm(idx++);
uni_vpxor(result, result, result);
if (kernel_type_ == complex_to_complex && is_inverse_) {
mov(rdx, 1ULL << 63);
uni_vmovq(xmm_neg_mask, rdx);
uni_vbroadcastsd(vmm_neg_mask, xmm_neg_mask);
auto simd_loop = [&] {
if (kernel_type_ == complex_to_complex) {
uni_vpxor(output_real, output_real, output_real);
uni_vpxor(output_imag, output_imag, output_imag);
} else {
uni_vpxor(output, output, output);
}
auto c2r_kernel = [&] (bool backwards) {
// if backwards == false:
// output_real += input_real * cos(..) - input_imag * sin(..)
// else:
// output_real += input_real * cos(..) + input_imag * sin(..)
uni_vbroadcastss(inp_real, ptr[input_ptr]);
uni_vbroadcastss(inp_imag, ptr[input_ptr + type_size]);
uni_vmovups(cos, ptr[twiddles_ptr]);
uni_vmovups(sin, ptr[twiddles_ptr + vlen]);
uni_vfmadd231ps(output, inp_real, cos);
if (!backwards) {
uni_vfnmadd231ps(output, inp_imag, sin);
} else {
uni_vfmadd231ps(output, inp_imag, sin);
}
add(twiddles_ptr, 2 * vlen);
};
auto c2c_kernel = [&] (bool backwards) {
// if backwards == false:
// output_real += input_real * cos(..) - input_imag * sin(..)
// output_imag += input_imag * cos(..) + input_real * sin(..)
// else:
// output_real += input_real * cos(..) + input_imag * sin(..)
// output_imag += input_imag * cos(..) - input_real * sin(..)
uni_vbroadcastss(inp_real, ptr[input_ptr]);
uni_vbroadcastss(inp_imag, ptr[input_ptr + type_size]);
uni_vmovups(cos, ptr[twiddles_ptr]);
uni_vmovups(sin, ptr[twiddles_ptr + vlen]);
uni_vfmadd231ps(output_real, inp_real, cos);
uni_vfmadd231ps(output_imag, inp_imag, cos);
if (!backwards) {
uni_vfnmadd231ps(output_real, inp_imag, sin);
uni_vfmadd231ps(output_imag, inp_real, sin);
} else {
uni_vfmadd231ps(output_real, inp_imag, sin);
uni_vfnmadd231ps(output_imag, inp_real, sin);
}
add(twiddles_ptr, 2 * vlen);
};
Label loop;
L(loop);
{
if (kernel_type_ == real_to_complex) {
uni_vbroadcastss(inp_real, ptr[input_ptr]);
uni_vmovups(twiddles, ptr[twiddles_ptr]);
uni_vfmadd231ps(result, inp_real, twiddles);
uni_vfmadd231ps(output, inp_real, twiddles);
add(twiddles_ptr, vlen);
} else if (kernel_type_ == complex_to_real) {
uni_vbroadcastss(inp_real, ptr[input_ptr]);
uni_vbroadcastss(inp_imag, ptr[input_ptr + type_size]);
uni_vmovups(cos, ptr[twiddles_ptr]);
uni_vmovups(sin, ptr[twiddles_ptr + vlen]);
uni_vfmadd231ps(result, inp_real, cos);
uni_vfmadd231ps(result, inp_imag, sin);
add(twiddles_ptr, 2 * vlen);
c2r_kernel(false);
} else if (kernel_type_ == complex_to_complex) {
// output_real += input_real * cos(..) - input_imag * sin(..)
// output_imag += input_imag * cos(..) + input_real * sin(..)
uni_vbroadcastsd(input, ptr[input_ptr]);
uni_vpermilps(input_perm, input, 0b10110001); // swap real with imag
uni_vpxor(input_perm, input_perm, vmm_neg_mask); // negate imag part (or real part if is_inverse == true)
load_and_broadcast_every_other_elem(cos, twiddles_ptr, tmp);
load_and_broadcast_every_other_elem(sin, twiddles_ptr + vlen / 2, tmp);
uni_vfmadd231ps(result, input, cos);
uni_vfmadd231ps(result, input_perm, sin);
add(twiddles_ptr, vlen);
c2c_kernel(false);
}
add(input_ptr, input_type_size);
@ -159,12 +183,6 @@ void jit_dft_kernel_f32<isa>::generate() {
mov(input_size, signal_size);
sub(input_size, ptr[param1 + GET_OFF(input_size)]);
if (kernel_type_ == complex_to_complex) {
mov(rdx, 1ULL << 31);
uni_vmovq(xmm_neg_mask, rdx);
uni_vbroadcastsd(vmm_neg_mask, xmm_neg_mask);
}
test(is_signal_size_even, 1);
jz(loop_backwards);
@ -176,26 +194,11 @@ void jit_dft_kernel_f32<isa>::generate() {
je(loop_backwards_exit, T_NEAR);
sub(input_ptr, input_type_size);
if (kernel_type_ == complex_to_real) {
uni_vbroadcastss(inp_real, ptr[input_ptr]);
uni_vbroadcastss(inp_imag, ptr[input_ptr + type_size]);
uni_vmovups(cos, ptr[twiddles_ptr]);
uni_vmovups(sin, ptr[twiddles_ptr + vlen]);
uni_vfmadd231ps(result, inp_real, cos);
uni_vfnmadd231ps(result, inp_imag, sin);
add(twiddles_ptr, 2 * vlen);
if (kernel_type_ == complex_to_real) {
c2r_kernel(true);
} else if (kernel_type_ == complex_to_complex) {
// output_real += input_real * cos(..) - input_imag * sin(..)
// output_imag += input_imag * cos(..) + input_real * sin(..)
uni_vbroadcastsd(input, ptr[input_ptr]);
uni_vpermilps(input_perm, input, 0b10110001); // swap real with imag
uni_vpxor(input_perm, input_perm, vmm_neg_mask); // negate imag part
load_and_broadcast_every_other_elem(cos, twiddles_ptr, tmp);
load_and_broadcast_every_other_elem(sin, twiddles_ptr + vlen / 2, tmp);
uni_vfmadd231ps(result, input, cos);
uni_vfmadd231ps(result, input_perm, sin);
add(twiddles_ptr, vlen);
c2c_kernel(true);
}
dec(input_size);
@ -205,83 +208,77 @@ void jit_dft_kernel_f32<isa>::generate() {
}
if (is_inverse_) {
uni_vdivps(result, result, vmm_signal_size);
uni_vdivps(output_real, output_real, vmm_signal_size);
uni_vdivps(output_imag, output_imag, vmm_signal_size);
}
// store the results
if (kernel_type_ == complex_to_complex) {
interleave_and_store(output_real, output_imag, output_ptr, tmp);
add(output_ptr, 2 * vlen);
} else {
uni_vmovups(ptr[output_ptr], output);
add(output_ptr, vlen);
}
// store the results
uni_vmovups(ptr[output_ptr], result);
add(output_ptr, vlen);
sub(output_end, simd_size);
};
auto nonsimd_loop = [this,
input_type_size,
output_type_size,
&xmm_signal_size,
reg_idx] {
size_t idx = reg_idx;
Xmm xmm_inp_real = Xbyak::Xmm(idx++);
Xmm xmm_inp_imag = Xbyak::Xmm(idx++);
Xmm xmm_real = Xbyak::Xmm(idx++);
Xmm xmm_imag = Xbyak::Xmm(idx++);
Xmm xmm_cos = Xbyak::Xmm(idx++);
Xmm xmm_sin = Xbyak::Xmm(idx++);
auto nonsimd_loop = [&] {
uni_vxorps(xmm_output, xmm_output, xmm_output);
if (kernel_type_ != complex_to_real) {
xorps(xmm_real, xmm_real);
xorps(xmm_imag, xmm_imag);
} else {
xorps(xmm_real, xmm_real);
}
auto c2r_kernel = [&] (bool backwards) {
// if backwards == false:
// output_real += input_real * cos(..) - input_imag * sin(..)
// else:
// output_real += input_real * cos(..) + input_imag * sin(..)
uni_vmovq(xmm_input, ptr[input_ptr]);
uni_vmovq(xmm_twiddles, ptr[twiddles_ptr]);
uni_vmulps(xmm_input, xmm_input, xmm_twiddles);
if (!backwards) {
uni_vhsubps(xmm_input, xmm_input, xmm_input);
} else {
uni_vhaddps(xmm_input, xmm_input, xmm_input);
}
uni_vaddss(xmm_output, xmm_output, xmm_input);
};
auto c2c_kernel = [&] (bool backwards) {
// if backwards == false:
// output_real += input_real * cos(..) - input_imag * sin(..)
// output_imag += input_imag * cos(..) + input_real * sin(..)
// else:
// output_real += input_real * cos(..) + input_imag * sin(..)
// output_imag += input_imag * cos(..) - input_real * sin(..)
uni_vmovq(xmm_input, ptr[input_ptr]);
uni_vshufps(xmm_input, xmm_input, xmm_input, 0b00010100);
uni_vmovq(xmm_twiddles, ptr[twiddles_ptr]);
uni_vshufps(xmm_twiddles, xmm_twiddles, xmm_twiddles, 0b01000100);
uni_vxorps(xmm_twiddles, xmm_twiddles, neg_mask);
uni_vmulps(xmm_input, xmm_input, xmm_twiddles);
if (!backwards) {
uni_vhaddps(xmm_input, xmm_input, xmm_input);
} else {
uni_vhsubps(xmm_input, xmm_input, xmm_input);
}
uni_vaddps(xmm_output, xmm_output, xmm_input);
};
Label loop;
L(loop);
{
movss(xmm_cos, ptr[twiddles_ptr]);
movss(xmm_sin, ptr[twiddles_ptr + type_size]);
if (kernel_type_ == real_to_complex) {
movss(xmm_inp_real, ptr[input_ptr]);
// output_real += input_real * cos(..)
mulss(xmm_cos, xmm_inp_real);
addss(xmm_real, xmm_cos);
// output_imag += input_real * sin(..)
mulss(xmm_sin, xmm_inp_real);
addss(xmm_imag, xmm_sin);
uni_vmovq(xmm_twiddles, ptr[twiddles_ptr]);
uni_vmovd(xmm_input, ptr[input_ptr]);
uni_vshufps(xmm_input, xmm_input, xmm_input, 0);
uni_vmulps(xmm_input, xmm_input, xmm_twiddles);
uni_vaddps(xmm_output, xmm_output, xmm_input);
} else if (kernel_type_ == complex_to_real) {
movss(xmm_inp_real, ptr[input_ptr]);
movss(xmm_inp_imag, ptr[input_ptr + type_size]);
// output += real * cos(..) + imag * sin(..)
mulss(xmm_cos, xmm_inp_real);
mulss(xmm_sin, xmm_inp_imag);
addss(xmm_cos, xmm_sin);
addss(xmm_real, xmm_cos);
c2r_kernel(false);
} else if (kernel_type_ == complex_to_complex) {
// output_real += input_real * cos(..) - input_imag * sin(..)
movss(xmm_inp_real, ptr[input_ptr]);
movss(xmm_inp_imag, ptr[input_ptr + type_size]);
mulss(xmm_inp_real, xmm_cos);
mulss(xmm_inp_imag, xmm_sin);
if (!is_inverse_) {
subss(xmm_inp_real, xmm_inp_imag);
} else {
addss(xmm_inp_real, xmm_inp_imag);
}
addss(xmm_real, xmm_inp_real);
// output_imag += input_imag * cos(..) + input_real * sin(..)
movss(xmm_inp_real, ptr[input_ptr]);
movss(xmm_inp_imag, ptr[input_ptr + type_size]);
mulss(xmm_inp_imag, xmm_cos);
mulss(xmm_inp_real, xmm_sin);
if (!is_inverse_) {
addss(xmm_inp_imag, xmm_inp_real);
} else {
subss(xmm_inp_imag, xmm_inp_real);
}
addss(xmm_imag, xmm_inp_imag);
c2c_kernel(false);
}
// increment indexes for next iteration
@ -312,33 +309,10 @@ void jit_dft_kernel_f32<isa>::generate() {
sub(input_ptr, input_type_size);
movss(xmm_cos, ptr[twiddles_ptr]);
movss(xmm_sin, ptr[twiddles_ptr + type_size]);
movss(xmm_inp_real, ptr[input_ptr]);
movss(xmm_inp_imag, ptr[input_ptr + type_size]);
if (kernel_type_ == complex_to_real) {
// output += real * cos(..) - imag * sin(..)
mulss(xmm_cos, xmm_inp_real);
mulss(xmm_sin, xmm_inp_imag);
subss(xmm_cos, xmm_sin);
addss(xmm_real, xmm_cos);
c2r_kernel(true);
} else if (kernel_type_ == complex_to_complex) {
// output_real += input_real * cos(..) - input_imag * sin(..)
movss(xmm_inp_real, ptr[input_ptr]);
movss(xmm_inp_imag, ptr[input_ptr + type_size]);
mulss(xmm_inp_real, xmm_cos);
mulss(xmm_inp_imag, xmm_sin);
subss(xmm_inp_real, xmm_inp_imag);
addss(xmm_real, xmm_inp_real);
// output_imag += input_imag * cos(..) + input_real * sin(..)
movss(xmm_inp_real, ptr[input_ptr]);
movss(xmm_inp_imag, ptr[input_ptr + type_size]);
mulss(xmm_inp_imag, xmm_cos);
mulss(xmm_inp_real, xmm_sin);
addss(xmm_inp_imag, xmm_inp_real);
addss(xmm_imag, xmm_inp_imag);
c2c_kernel(true);
}
add(twiddles_ptr, complex_type_size<float>());
@ -350,18 +324,16 @@ void jit_dft_kernel_f32<isa>::generate() {
if (kernel_type_ == complex_to_real) {
if (is_inverse_) {
divss(xmm_real, xmm_signal_size);
uni_vdivss(xmm_output, xmm_output, xmm_signal_size);
}
// store the result
movss(ptr[output_ptr], xmm_real);
uni_vmovss(ptr[output_ptr], xmm_output);
} else {
if (is_inverse_) {
divss(xmm_real, xmm_signal_size);
divss(xmm_imag, xmm_signal_size);
uni_vdivps(xmm_output, xmm_output, xmm_signal_size);
}
// store the results
movss(ptr[output_ptr], xmm_real);
movss(ptr[output_ptr + type_size], xmm_imag);
uni_vmovq(ptr[output_ptr], xmm_output);
}
add(output_ptr, output_type_size);
@ -393,50 +365,43 @@ void jit_dft_kernel_f32<isa>::generate() {
this->postamble();
}
template <cpu_isa_t isa>
void jit_dft_kernel_f32<isa>::uni_vbroadcastsd(const Xbyak::Xmm& x, const Xbyak::Operand& op) {
movsd(x, op);
shufpd(x, x, 0x0);
// Interleave real and imag registers and store in memory.
// For example (for AVX):
// real = [1, 2, 3, 4, 5, 6, 7, 8]
// imag = [11, 12, 13, 14, 15, 16, 17, 18]
// interleaved = [1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18]
template <>
void jit_dft_kernel_f32<avx512_core>::interleave_and_store(const Vmm& real, const Vmm& imag, const Xbyak::RegExp& reg_exp, const Vmm& tmp) {
const Vmm& low = tmp;
const Vmm& high = real;
uni_vmovups(low, real);
vpermt2ps(low, perm_low, imag);
vpermt2ps(high, perm_high, imag);
uni_vmovups(ptr[reg_exp], low);
uni_vmovups(ptr[reg_exp + vlen], high);
}
template <cpu_isa_t isa>
void jit_dft_kernel_f32<isa>::uni_vbroadcastsd(const Xbyak::Ymm& x, const Xbyak::Operand& op) {
vbroadcastsd(x, op);
template <>
void jit_dft_kernel_f32<avx2>::interleave_and_store(const Vmm& real, const Vmm& imag, const Xbyak::RegExp& reg_exp, const Vmm& tmp) {
const Vmm& low = real;
const Vmm& high = imag;
vunpcklps(tmp, real, imag);
vunpckhps(high, real, imag);
vinsertf128(low, tmp, Xbyak::Xmm(high.getIdx()), 1);
vperm2f128(high, tmp, high, 0b00110001);
uni_vmovups(ptr[reg_exp], low);
uni_vmovups(ptr[reg_exp + vlen], high);
}
template <cpu_isa_t isa>
void jit_dft_kernel_f32<isa>::uni_vpermilps(const Xbyak::Xmm& x, const Xbyak::Operand& op, int8_t control) {
movups(x, op);
shufps(x, x, control);
}
template <cpu_isa_t isa>
void jit_dft_kernel_f32<isa>::uni_vpermilps(const Xbyak::Ymm& x, const Xbyak::Operand& op, int8_t control) {
vpermilps(x, op, control);
}
template <cpu_isa_t isa>
void jit_dft_kernel_f32<isa>::load_and_broadcast_every_other_elem(const Xbyak::Zmm& x, const Xbyak::RegExp& reg_exp, const Xbyak::Xmm& tmp) {
for (int i = 0; i < 4; i++) {
movq(tmp, ptr[reg_exp + type_size * i * 2]);
shufps(tmp, tmp, 0b01010000);
vinsertf32x4(x, x, tmp, i);
}
}
template <cpu_isa_t isa>
void jit_dft_kernel_f32<isa>::load_and_broadcast_every_other_elem(const Xbyak::Ymm& x, const Xbyak::RegExp& reg_exp, const Xbyak::Xmm& tmp) {
for (int i = 0; i < 2; i++) {
movq(tmp, ptr[reg_exp + type_size * i * 2]);
shufps(tmp, tmp, 0b01010000);
vinsertf128(x, x, tmp, i);
}
}
template <cpu_isa_t isa>
void jit_dft_kernel_f32<isa>::load_and_broadcast_every_other_elem(const Xbyak::Xmm& x, const Xbyak::RegExp& reg_exp, const Xbyak::Xmm& tmp) {
movq(x, ptr[reg_exp]);
shufps(x, x, 0b01010000);
template <>
void jit_dft_kernel_f32<sse41>::interleave_and_store(const Vmm& real, const Vmm& imag, const Xbyak::RegExp& reg_exp, const Vmm& tmp) {
const Vmm& low = tmp;
const Vmm& high = real;
uni_vmovups(low, real);
unpcklps(low, imag);
unpckhps(high, imag);
uni_vmovups(ptr[reg_exp], low);
uni_vmovups(ptr[reg_exp + vlen], high);
}
template struct jit_dft_kernel_f32<cpu::x64::sse41>;

View File

@ -60,7 +60,17 @@ struct jit_dft_kernel_f32 : public jit_dft_kernel, public jit_generator {
public:
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_dft_kernel_f32)
jit_dft_kernel_f32(bool is_inverse, enum dft_type type) : jit_dft_kernel(is_inverse, type), jit_generator(jit_name()) {}
jit_dft_kernel_f32(bool is_inverse, enum dft_type type) : jit_dft_kernel(is_inverse, type), jit_generator(jit_name()) {
constexpr int simd_size = vlen / type_size;
perm_low_values.reserve(simd_size);
perm_high_values.reserve(simd_size);
for (int i = 0; i < simd_size / 2; i++) {
perm_low_values.push_back(i);
perm_low_values.push_back(i + simd_size);
perm_high_values.push_back(i + simd_size / 2);
perm_high_values.push_back(i + simd_size / 2 + simd_size);
}
}
void create_ker() override {
jit_generator::create_kernel();
@ -70,17 +80,13 @@ struct jit_dft_kernel_f32 : public jit_dft_kernel, public jit_generator {
void generate() override;
private:
void uni_vbroadcastsd(const Xbyak::Xmm& x, const Xbyak::Operand& op);
void uni_vbroadcastsd(const Xbyak::Ymm& x, const Xbyak::Operand& op);
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xbyak::Xmm,
isa == cpu::x64::avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
void uni_vpermilps(const Xbyak::Xmm& x, const Xbyak::Operand& op, int8_t control);
void uni_vpermilps(const Xbyak::Ymm& x, const Xbyak::Operand& op, int8_t control);
void interleave_and_store(const Vmm& real, const Vmm& imag, const Xbyak::RegExp& reg_exp, const Vmm& tmp);
void load_and_broadcast_every_other_elem(const Xbyak::Zmm& x, const Xbyak::RegExp& reg_exp, const Xbyak::Xmm& tmp);
void load_and_broadcast_every_other_elem(const Xbyak::Ymm& x, const Xbyak::RegExp& reg_exp, const Xbyak::Xmm& tmp);
void load_and_broadcast_every_other_elem(const Xbyak::Xmm& x, const Xbyak::RegExp& reg_exp, const Xbyak::Xmm& tmp);
int type_size = sizeof(float);
static constexpr int type_size = sizeof(float);
static constexpr int vlen = cpu_isa_traits<isa>::vlen;
Xbyak::Reg8 is_signal_size_even = al;
Xbyak::Reg64 input_ptr = rbx;
@ -90,6 +96,12 @@ struct jit_dft_kernel_f32 : public jit_dft_kernel, public jit_generator {
Xbyak::Reg64 signal_size = r11;
Xbyak::Reg64 output_start = r12;
Xbyak::Reg64 output_end = r13;
std::vector<int> perm_low_values;
std::vector<int> perm_high_values;
Vmm perm_low;
Vmm perm_high;
};
} // namespace intel_cpu

View File

@ -684,7 +684,7 @@ struct RDFTJitExecutor : public RDFTExecutor {
std::vector<float> generateTwiddlesDFT(size_t inputSize, size_t outputSize, enum dft_type type) override {
std::vector<float> twiddles(inputSize * outputSize * 2);
int simdSize = vlen / sizeof(float);
if (type == real_to_complex || type == complex_to_complex) {
if (type == real_to_complex) {
simdSize /= 2; // there are two floats per one complex element in the output
}
@ -702,7 +702,7 @@ struct RDFTJitExecutor : public RDFTExecutor {
}
for (size_t k = 0; k < simdSize; k++) {
double angle = 2 * PI * (K * simdSize + k) * n / inputSize;
twiddles[((K * inputSize + n) * 2 + 1) * simdSize + k] = -std::sin(angle);
twiddles[((K * inputSize + n) * 2 + 1) * simdSize + k] = isInverse ? std::sin(angle) : -std::sin(angle);
}
}
});
@ -712,7 +712,7 @@ struct RDFTJitExecutor : public RDFTExecutor {
k += start;
double angle = 2 * PI * k * n / inputSize;
twiddles[2 * (k * inputSize + n)] = std::cos(angle);
twiddles[2 * (k * inputSize + n) + 1] = -std::sin(angle);
twiddles[2 * (k * inputSize + n) + 1] = isInverse ? std::sin(angle) : -std::sin(angle);
});
}
return twiddles;

View File

@ -436,6 +436,8 @@ std::vector<RDFTTestCPUParams> getParams4D() {
params.push_back({{1, 192, 36, 64}, {3, 2}, {}, false, cpuParams});
params.push_back({{1, 192, 36, 64}, {-2, -1}, {36, 64}, false, cpuParams});
params.push_back({{1, 192, 36, 64}, {0, 1, 2, 3}, {}, false, cpuParams});
params.push_back({{1, 120, 64, 64}, {-2, -1}, {64, 33}, false, cpuParams});
params.push_back({{1, 120, 96, 96}, {-2, -1}, {96, 49}, false, cpuParams});
params.push_back({{2, 192, 36, 33, 2}, {0}, {}, true, cpuParams});
params.push_back({{1, 192, 36, 33, 2}, {1}, {}, true, cpuParams});
params.push_back({{1, 192, 36, 33, 2}, {2}, {}, true, cpuParams});
@ -444,6 +446,8 @@ std::vector<RDFTTestCPUParams> getParams4D() {
params.push_back({{1, 192, 36, 33, 2}, {3, 2}, {}, true, cpuParams});
params.push_back({{1, 192, 36, 33, 2}, {-2, -1}, {36, 64}, true, cpuParams});
params.push_back({{1, 192, 36, 33, 2}, {0, 1, 2, 3}, {}, true, cpuParams});
params.push_back({{1, 120, 64, 33, 2}, {-2, -1}, {64, 64}, true, cpuParams});
params.push_back({{1, 120, 96, 49, 2}, {-2, -1}, {96, 96}, true, cpuParams});
return params;
}