[CPU] RDFT kernel optimizations (#14060)
This commit is contained in:
parent
7facf8b90b
commit
43164a6b25
@ -15,8 +15,6 @@ void jit_dft_kernel_f32<isa>::generate() {
|
||||
using namespace Xbyak::util;
|
||||
using Xbyak::Label;
|
||||
using Xbyak::Xmm;
|
||||
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xbyak::Xmm,
|
||||
isa == cpu::x64::avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
|
||||
|
||||
this->preamble();
|
||||
|
||||
@ -37,8 +35,9 @@ void jit_dft_kernel_f32<isa>::generate() {
|
||||
output_type_size = type_size;
|
||||
break;
|
||||
}
|
||||
int vlen = cpu_isa_traits<isa>::vlen;
|
||||
const int simd_size = vlen / output_type_size;
|
||||
int simd_size = vlen / output_type_size;
|
||||
if (kernel_type_ == complex_to_complex)
|
||||
simd_size = vlen / type_size;
|
||||
|
||||
mov(input_ptr, ptr[param1 + GET_OFF(input)]);
|
||||
mov(input_size, ptr[param1 + GET_OFF(input_size)]);
|
||||
@ -63,23 +62,40 @@ void jit_dft_kernel_f32<isa>::generate() {
|
||||
Vmm vmm_signal_size = Vmm(reg_idx);
|
||||
if (is_inverse_) {
|
||||
reg_idx++;
|
||||
uni_vbroadcastss(Vmm(reg_idx), ptr[param1 + GET_OFF(signal_size)]);
|
||||
uni_vcvtdq2ps(vmm_signal_size, Vmm(reg_idx));
|
||||
uni_vbroadcastss(vmm_signal_size, ptr[param1 + GET_OFF(signal_size)]);
|
||||
uni_vcvtdq2ps(vmm_signal_size, vmm_signal_size);
|
||||
}
|
||||
|
||||
Vmm vmm_neg_mask = Vmm(reg_idx);
|
||||
Xmm xmm_neg_mask = Xmm(reg_idx);
|
||||
Xmm neg_mask = Xmm(reg_idx);
|
||||
if (kernel_type_ == complex_to_complex) {
|
||||
reg_idx++;
|
||||
if (!is_inverse_) {
|
||||
mov(rax, 1ULL << 31);
|
||||
} else {
|
||||
mov(rax, 1ULL << 63);
|
||||
}
|
||||
uni_vmovq(xmm_neg_mask, rax);
|
||||
uni_vbroadcastsd(vmm_neg_mask, xmm_neg_mask);
|
||||
uni_vpxor(neg_mask, neg_mask, neg_mask);
|
||||
mov(rax, 1ULL << 63);
|
||||
uni_vmovq(neg_mask, rax);
|
||||
}
|
||||
|
||||
size_t vmm_reg_idx = reg_idx;
|
||||
Vmm inp_real = Vmm(vmm_reg_idx++);
|
||||
Vmm inp_imag = Vmm(vmm_reg_idx++);
|
||||
Vmm cos = Vmm(vmm_reg_idx++);
|
||||
Vmm sin = Vmm(vmm_reg_idx++);
|
||||
const Vmm& twiddles = cos;
|
||||
Vmm tmp = Vmm(vmm_reg_idx++);
|
||||
Vmm output_real = Vmm(vmm_reg_idx++);
|
||||
Vmm output_imag = Vmm(vmm_reg_idx++);
|
||||
const Vmm& output = output_real;
|
||||
perm_low = Vmm(vmm_reg_idx++);
|
||||
perm_high = Vmm(vmm_reg_idx++);
|
||||
|
||||
mov(rax, reinterpret_cast<uint64_t>(perm_low_values.data()));
|
||||
uni_vmovups(perm_low, ptr[rax]);
|
||||
mov(rax, reinterpret_cast<uint64_t>(perm_high_values.data()));
|
||||
uni_vmovups(perm_high, ptr[rax]);
|
||||
|
||||
Xmm xmm_input = Xbyak::Xmm(reg_idx++);
|
||||
Xmm xmm_twiddles = Xbyak::Xmm(reg_idx++);
|
||||
Xmm xmm_output = Xbyak::Xmm(reg_idx++);
|
||||
|
||||
mov(rax, signal_size);
|
||||
and_(rax, 1);
|
||||
setz(is_signal_size_even);
|
||||
@ -89,60 +105,68 @@ void jit_dft_kernel_f32<isa>::generate() {
|
||||
Label loop_simd;
|
||||
Label loop_nonsimd;
|
||||
|
||||
auto simd_loop = [this, vlen, simd_size,
|
||||
input_type_size, reg_idx,
|
||||
&vmm_signal_size,
|
||||
&xmm_neg_mask,
|
||||
&vmm_neg_mask] {
|
||||
size_t idx = reg_idx;
|
||||
Vmm result = Vmm(idx++);
|
||||
Vmm inp_real = Vmm(idx++);
|
||||
Vmm inp_imag = Vmm(idx++);
|
||||
const Vmm& input = inp_real;
|
||||
const Vmm& input_perm = inp_imag;
|
||||
Vmm twiddles = Vmm(idx++);
|
||||
const Vmm& cos = twiddles;
|
||||
Vmm sin = Vmm(idx++);
|
||||
Xmm tmp = Xmm(idx++);
|
||||
|
||||
uni_vpxor(result, result, result);
|
||||
|
||||
if (kernel_type_ == complex_to_complex && is_inverse_) {
|
||||
mov(rdx, 1ULL << 63);
|
||||
uni_vmovq(xmm_neg_mask, rdx);
|
||||
uni_vbroadcastsd(vmm_neg_mask, xmm_neg_mask);
|
||||
auto simd_loop = [&] {
|
||||
if (kernel_type_ == complex_to_complex) {
|
||||
uni_vpxor(output_real, output_real, output_real);
|
||||
uni_vpxor(output_imag, output_imag, output_imag);
|
||||
} else {
|
||||
uni_vpxor(output, output, output);
|
||||
}
|
||||
|
||||
auto c2r_kernel = [&] (bool backwards) {
|
||||
// if backwards == false:
|
||||
// output_real += input_real * cos(..) - input_imag * sin(..)
|
||||
// else:
|
||||
// output_real += input_real * cos(..) + input_imag * sin(..)
|
||||
uni_vbroadcastss(inp_real, ptr[input_ptr]);
|
||||
uni_vbroadcastss(inp_imag, ptr[input_ptr + type_size]);
|
||||
uni_vmovups(cos, ptr[twiddles_ptr]);
|
||||
uni_vmovups(sin, ptr[twiddles_ptr + vlen]);
|
||||
uni_vfmadd231ps(output, inp_real, cos);
|
||||
if (!backwards) {
|
||||
uni_vfnmadd231ps(output, inp_imag, sin);
|
||||
} else {
|
||||
uni_vfmadd231ps(output, inp_imag, sin);
|
||||
}
|
||||
add(twiddles_ptr, 2 * vlen);
|
||||
};
|
||||
|
||||
auto c2c_kernel = [&] (bool backwards) {
|
||||
// if backwards == false:
|
||||
// output_real += input_real * cos(..) - input_imag * sin(..)
|
||||
// output_imag += input_imag * cos(..) + input_real * sin(..)
|
||||
// else:
|
||||
// output_real += input_real * cos(..) + input_imag * sin(..)
|
||||
// output_imag += input_imag * cos(..) - input_real * sin(..)
|
||||
uni_vbroadcastss(inp_real, ptr[input_ptr]);
|
||||
uni_vbroadcastss(inp_imag, ptr[input_ptr + type_size]);
|
||||
uni_vmovups(cos, ptr[twiddles_ptr]);
|
||||
uni_vmovups(sin, ptr[twiddles_ptr + vlen]);
|
||||
uni_vfmadd231ps(output_real, inp_real, cos);
|
||||
uni_vfmadd231ps(output_imag, inp_imag, cos);
|
||||
if (!backwards) {
|
||||
uni_vfnmadd231ps(output_real, inp_imag, sin);
|
||||
uni_vfmadd231ps(output_imag, inp_real, sin);
|
||||
} else {
|
||||
uni_vfmadd231ps(output_real, inp_imag, sin);
|
||||
uni_vfnmadd231ps(output_imag, inp_real, sin);
|
||||
}
|
||||
|
||||
add(twiddles_ptr, 2 * vlen);
|
||||
};
|
||||
|
||||
Label loop;
|
||||
L(loop);
|
||||
{
|
||||
if (kernel_type_ == real_to_complex) {
|
||||
uni_vbroadcastss(inp_real, ptr[input_ptr]);
|
||||
uni_vmovups(twiddles, ptr[twiddles_ptr]);
|
||||
uni_vfmadd231ps(result, inp_real, twiddles);
|
||||
|
||||
uni_vfmadd231ps(output, inp_real, twiddles);
|
||||
add(twiddles_ptr, vlen);
|
||||
} else if (kernel_type_ == complex_to_real) {
|
||||
uni_vbroadcastss(inp_real, ptr[input_ptr]);
|
||||
uni_vbroadcastss(inp_imag, ptr[input_ptr + type_size]);
|
||||
uni_vmovups(cos, ptr[twiddles_ptr]);
|
||||
uni_vmovups(sin, ptr[twiddles_ptr + vlen]);
|
||||
uni_vfmadd231ps(result, inp_real, cos);
|
||||
uni_vfmadd231ps(result, inp_imag, sin);
|
||||
|
||||
add(twiddles_ptr, 2 * vlen);
|
||||
c2r_kernel(false);
|
||||
} else if (kernel_type_ == complex_to_complex) {
|
||||
// output_real += input_real * cos(..) - input_imag * sin(..)
|
||||
// output_imag += input_imag * cos(..) + input_real * sin(..)
|
||||
uni_vbroadcastsd(input, ptr[input_ptr]);
|
||||
uni_vpermilps(input_perm, input, 0b10110001); // swap real with imag
|
||||
uni_vpxor(input_perm, input_perm, vmm_neg_mask); // negate imag part (or real part if is_inverse == true)
|
||||
load_and_broadcast_every_other_elem(cos, twiddles_ptr, tmp);
|
||||
load_and_broadcast_every_other_elem(sin, twiddles_ptr + vlen / 2, tmp);
|
||||
uni_vfmadd231ps(result, input, cos);
|
||||
uni_vfmadd231ps(result, input_perm, sin);
|
||||
|
||||
add(twiddles_ptr, vlen);
|
||||
c2c_kernel(false);
|
||||
}
|
||||
|
||||
add(input_ptr, input_type_size);
|
||||
@ -159,12 +183,6 @@ void jit_dft_kernel_f32<isa>::generate() {
|
||||
mov(input_size, signal_size);
|
||||
sub(input_size, ptr[param1 + GET_OFF(input_size)]);
|
||||
|
||||
if (kernel_type_ == complex_to_complex) {
|
||||
mov(rdx, 1ULL << 31);
|
||||
uni_vmovq(xmm_neg_mask, rdx);
|
||||
uni_vbroadcastsd(vmm_neg_mask, xmm_neg_mask);
|
||||
}
|
||||
|
||||
test(is_signal_size_even, 1);
|
||||
jz(loop_backwards);
|
||||
|
||||
@ -176,26 +194,11 @@ void jit_dft_kernel_f32<isa>::generate() {
|
||||
je(loop_backwards_exit, T_NEAR);
|
||||
|
||||
sub(input_ptr, input_type_size);
|
||||
if (kernel_type_ == complex_to_real) {
|
||||
uni_vbroadcastss(inp_real, ptr[input_ptr]);
|
||||
uni_vbroadcastss(inp_imag, ptr[input_ptr + type_size]);
|
||||
uni_vmovups(cos, ptr[twiddles_ptr]);
|
||||
uni_vmovups(sin, ptr[twiddles_ptr + vlen]);
|
||||
|
||||
uni_vfmadd231ps(result, inp_real, cos);
|
||||
uni_vfnmadd231ps(result, inp_imag, sin);
|
||||
add(twiddles_ptr, 2 * vlen);
|
||||
if (kernel_type_ == complex_to_real) {
|
||||
c2r_kernel(true);
|
||||
} else if (kernel_type_ == complex_to_complex) {
|
||||
// output_real += input_real * cos(..) - input_imag * sin(..)
|
||||
// output_imag += input_imag * cos(..) + input_real * sin(..)
|
||||
uni_vbroadcastsd(input, ptr[input_ptr]);
|
||||
uni_vpermilps(input_perm, input, 0b10110001); // swap real with imag
|
||||
uni_vpxor(input_perm, input_perm, vmm_neg_mask); // negate imag part
|
||||
load_and_broadcast_every_other_elem(cos, twiddles_ptr, tmp);
|
||||
load_and_broadcast_every_other_elem(sin, twiddles_ptr + vlen / 2, tmp);
|
||||
uni_vfmadd231ps(result, input, cos);
|
||||
uni_vfmadd231ps(result, input_perm, sin);
|
||||
add(twiddles_ptr, vlen);
|
||||
c2c_kernel(true);
|
||||
}
|
||||
|
||||
dec(input_size);
|
||||
@ -205,83 +208,77 @@ void jit_dft_kernel_f32<isa>::generate() {
|
||||
}
|
||||
|
||||
if (is_inverse_) {
|
||||
uni_vdivps(result, result, vmm_signal_size);
|
||||
uni_vdivps(output_real, output_real, vmm_signal_size);
|
||||
uni_vdivps(output_imag, output_imag, vmm_signal_size);
|
||||
}
|
||||
|
||||
// store the results
|
||||
if (kernel_type_ == complex_to_complex) {
|
||||
interleave_and_store(output_real, output_imag, output_ptr, tmp);
|
||||
add(output_ptr, 2 * vlen);
|
||||
} else {
|
||||
uni_vmovups(ptr[output_ptr], output);
|
||||
add(output_ptr, vlen);
|
||||
}
|
||||
// store the results
|
||||
uni_vmovups(ptr[output_ptr], result);
|
||||
|
||||
add(output_ptr, vlen);
|
||||
sub(output_end, simd_size);
|
||||
};
|
||||
|
||||
auto nonsimd_loop = [this,
|
||||
input_type_size,
|
||||
output_type_size,
|
||||
&xmm_signal_size,
|
||||
reg_idx] {
|
||||
size_t idx = reg_idx;
|
||||
Xmm xmm_inp_real = Xbyak::Xmm(idx++);
|
||||
Xmm xmm_inp_imag = Xbyak::Xmm(idx++);
|
||||
Xmm xmm_real = Xbyak::Xmm(idx++);
|
||||
Xmm xmm_imag = Xbyak::Xmm(idx++);
|
||||
Xmm xmm_cos = Xbyak::Xmm(idx++);
|
||||
Xmm xmm_sin = Xbyak::Xmm(idx++);
|
||||
auto nonsimd_loop = [&] {
|
||||
uni_vxorps(xmm_output, xmm_output, xmm_output);
|
||||
|
||||
if (kernel_type_ != complex_to_real) {
|
||||
xorps(xmm_real, xmm_real);
|
||||
xorps(xmm_imag, xmm_imag);
|
||||
} else {
|
||||
xorps(xmm_real, xmm_real);
|
||||
}
|
||||
auto c2r_kernel = [&] (bool backwards) {
|
||||
// if backwards == false:
|
||||
// output_real += input_real * cos(..) - input_imag * sin(..)
|
||||
// else:
|
||||
// output_real += input_real * cos(..) + input_imag * sin(..)
|
||||
uni_vmovq(xmm_input, ptr[input_ptr]);
|
||||
uni_vmovq(xmm_twiddles, ptr[twiddles_ptr]);
|
||||
uni_vmulps(xmm_input, xmm_input, xmm_twiddles);
|
||||
if (!backwards) {
|
||||
uni_vhsubps(xmm_input, xmm_input, xmm_input);
|
||||
} else {
|
||||
uni_vhaddps(xmm_input, xmm_input, xmm_input);
|
||||
}
|
||||
uni_vaddss(xmm_output, xmm_output, xmm_input);
|
||||
};
|
||||
|
||||
auto c2c_kernel = [&] (bool backwards) {
|
||||
// if backwards == false:
|
||||
// output_real += input_real * cos(..) - input_imag * sin(..)
|
||||
// output_imag += input_imag * cos(..) + input_real * sin(..)
|
||||
// else:
|
||||
// output_real += input_real * cos(..) + input_imag * sin(..)
|
||||
// output_imag += input_imag * cos(..) - input_real * sin(..)
|
||||
uni_vmovq(xmm_input, ptr[input_ptr]);
|
||||
uni_vshufps(xmm_input, xmm_input, xmm_input, 0b00010100);
|
||||
uni_vmovq(xmm_twiddles, ptr[twiddles_ptr]);
|
||||
uni_vshufps(xmm_twiddles, xmm_twiddles, xmm_twiddles, 0b01000100);
|
||||
uni_vxorps(xmm_twiddles, xmm_twiddles, neg_mask);
|
||||
uni_vmulps(xmm_input, xmm_input, xmm_twiddles);
|
||||
if (!backwards) {
|
||||
uni_vhaddps(xmm_input, xmm_input, xmm_input);
|
||||
} else {
|
||||
uni_vhsubps(xmm_input, xmm_input, xmm_input);
|
||||
}
|
||||
uni_vaddps(xmm_output, xmm_output, xmm_input);
|
||||
};
|
||||
|
||||
Label loop;
|
||||
L(loop);
|
||||
{
|
||||
movss(xmm_cos, ptr[twiddles_ptr]);
|
||||
movss(xmm_sin, ptr[twiddles_ptr + type_size]);
|
||||
if (kernel_type_ == real_to_complex) {
|
||||
movss(xmm_inp_real, ptr[input_ptr]);
|
||||
|
||||
// output_real += input_real * cos(..)
|
||||
mulss(xmm_cos, xmm_inp_real);
|
||||
addss(xmm_real, xmm_cos);
|
||||
|
||||
// output_imag += input_real * sin(..)
|
||||
mulss(xmm_sin, xmm_inp_real);
|
||||
addss(xmm_imag, xmm_sin);
|
||||
uni_vmovq(xmm_twiddles, ptr[twiddles_ptr]);
|
||||
uni_vmovd(xmm_input, ptr[input_ptr]);
|
||||
uni_vshufps(xmm_input, xmm_input, xmm_input, 0);
|
||||
uni_vmulps(xmm_input, xmm_input, xmm_twiddles);
|
||||
uni_vaddps(xmm_output, xmm_output, xmm_input);
|
||||
} else if (kernel_type_ == complex_to_real) {
|
||||
movss(xmm_inp_real, ptr[input_ptr]);
|
||||
movss(xmm_inp_imag, ptr[input_ptr + type_size]);
|
||||
|
||||
// output += real * cos(..) + imag * sin(..)
|
||||
mulss(xmm_cos, xmm_inp_real);
|
||||
mulss(xmm_sin, xmm_inp_imag);
|
||||
addss(xmm_cos, xmm_sin);
|
||||
addss(xmm_real, xmm_cos);
|
||||
c2r_kernel(false);
|
||||
} else if (kernel_type_ == complex_to_complex) {
|
||||
// output_real += input_real * cos(..) - input_imag * sin(..)
|
||||
movss(xmm_inp_real, ptr[input_ptr]);
|
||||
movss(xmm_inp_imag, ptr[input_ptr + type_size]);
|
||||
mulss(xmm_inp_real, xmm_cos);
|
||||
mulss(xmm_inp_imag, xmm_sin);
|
||||
if (!is_inverse_) {
|
||||
subss(xmm_inp_real, xmm_inp_imag);
|
||||
} else {
|
||||
addss(xmm_inp_real, xmm_inp_imag);
|
||||
}
|
||||
addss(xmm_real, xmm_inp_real);
|
||||
|
||||
// output_imag += input_imag * cos(..) + input_real * sin(..)
|
||||
movss(xmm_inp_real, ptr[input_ptr]);
|
||||
movss(xmm_inp_imag, ptr[input_ptr + type_size]);
|
||||
mulss(xmm_inp_imag, xmm_cos);
|
||||
mulss(xmm_inp_real, xmm_sin);
|
||||
if (!is_inverse_) {
|
||||
addss(xmm_inp_imag, xmm_inp_real);
|
||||
} else {
|
||||
subss(xmm_inp_imag, xmm_inp_real);
|
||||
}
|
||||
addss(xmm_imag, xmm_inp_imag);
|
||||
c2c_kernel(false);
|
||||
}
|
||||
|
||||
// increment indexes for next iteration
|
||||
@ -312,33 +309,10 @@ void jit_dft_kernel_f32<isa>::generate() {
|
||||
|
||||
sub(input_ptr, input_type_size);
|
||||
|
||||
movss(xmm_cos, ptr[twiddles_ptr]);
|
||||
movss(xmm_sin, ptr[twiddles_ptr + type_size]);
|
||||
movss(xmm_inp_real, ptr[input_ptr]);
|
||||
movss(xmm_inp_imag, ptr[input_ptr + type_size]);
|
||||
|
||||
if (kernel_type_ == complex_to_real) {
|
||||
// output += real * cos(..) - imag * sin(..)
|
||||
mulss(xmm_cos, xmm_inp_real);
|
||||
mulss(xmm_sin, xmm_inp_imag);
|
||||
subss(xmm_cos, xmm_sin);
|
||||
addss(xmm_real, xmm_cos);
|
||||
c2r_kernel(true);
|
||||
} else if (kernel_type_ == complex_to_complex) {
|
||||
// output_real += input_real * cos(..) - input_imag * sin(..)
|
||||
movss(xmm_inp_real, ptr[input_ptr]);
|
||||
movss(xmm_inp_imag, ptr[input_ptr + type_size]);
|
||||
mulss(xmm_inp_real, xmm_cos);
|
||||
mulss(xmm_inp_imag, xmm_sin);
|
||||
subss(xmm_inp_real, xmm_inp_imag);
|
||||
addss(xmm_real, xmm_inp_real);
|
||||
|
||||
// output_imag += input_imag * cos(..) + input_real * sin(..)
|
||||
movss(xmm_inp_real, ptr[input_ptr]);
|
||||
movss(xmm_inp_imag, ptr[input_ptr + type_size]);
|
||||
mulss(xmm_inp_imag, xmm_cos);
|
||||
mulss(xmm_inp_real, xmm_sin);
|
||||
addss(xmm_inp_imag, xmm_inp_real);
|
||||
addss(xmm_imag, xmm_inp_imag);
|
||||
c2c_kernel(true);
|
||||
}
|
||||
|
||||
add(twiddles_ptr, complex_type_size<float>());
|
||||
@ -350,18 +324,16 @@ void jit_dft_kernel_f32<isa>::generate() {
|
||||
|
||||
if (kernel_type_ == complex_to_real) {
|
||||
if (is_inverse_) {
|
||||
divss(xmm_real, xmm_signal_size);
|
||||
uni_vdivss(xmm_output, xmm_output, xmm_signal_size);
|
||||
}
|
||||
// store the result
|
||||
movss(ptr[output_ptr], xmm_real);
|
||||
uni_vmovss(ptr[output_ptr], xmm_output);
|
||||
} else {
|
||||
if (is_inverse_) {
|
||||
divss(xmm_real, xmm_signal_size);
|
||||
divss(xmm_imag, xmm_signal_size);
|
||||
uni_vdivps(xmm_output, xmm_output, xmm_signal_size);
|
||||
}
|
||||
// store the results
|
||||
movss(ptr[output_ptr], xmm_real);
|
||||
movss(ptr[output_ptr + type_size], xmm_imag);
|
||||
uni_vmovq(ptr[output_ptr], xmm_output);
|
||||
}
|
||||
|
||||
add(output_ptr, output_type_size);
|
||||
@ -393,50 +365,43 @@ void jit_dft_kernel_f32<isa>::generate() {
|
||||
this->postamble();
|
||||
}
|
||||
|
||||
template <cpu_isa_t isa>
|
||||
void jit_dft_kernel_f32<isa>::uni_vbroadcastsd(const Xbyak::Xmm& x, const Xbyak::Operand& op) {
|
||||
movsd(x, op);
|
||||
shufpd(x, x, 0x0);
|
||||
// Interleave real and imag registers and store in memory.
|
||||
// For example (for AVX):
|
||||
// real = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
// imag = [11, 12, 13, 14, 15, 16, 17, 18]
|
||||
// interleaved = [1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18]
|
||||
template <>
|
||||
void jit_dft_kernel_f32<avx512_core>::interleave_and_store(const Vmm& real, const Vmm& imag, const Xbyak::RegExp& reg_exp, const Vmm& tmp) {
|
||||
const Vmm& low = tmp;
|
||||
const Vmm& high = real;
|
||||
uni_vmovups(low, real);
|
||||
vpermt2ps(low, perm_low, imag);
|
||||
vpermt2ps(high, perm_high, imag);
|
||||
uni_vmovups(ptr[reg_exp], low);
|
||||
uni_vmovups(ptr[reg_exp + vlen], high);
|
||||
}
|
||||
|
||||
template <cpu_isa_t isa>
|
||||
void jit_dft_kernel_f32<isa>::uni_vbroadcastsd(const Xbyak::Ymm& x, const Xbyak::Operand& op) {
|
||||
vbroadcastsd(x, op);
|
||||
template <>
|
||||
void jit_dft_kernel_f32<avx2>::interleave_and_store(const Vmm& real, const Vmm& imag, const Xbyak::RegExp& reg_exp, const Vmm& tmp) {
|
||||
const Vmm& low = real;
|
||||
const Vmm& high = imag;
|
||||
vunpcklps(tmp, real, imag);
|
||||
vunpckhps(high, real, imag);
|
||||
vinsertf128(low, tmp, Xbyak::Xmm(high.getIdx()), 1);
|
||||
vperm2f128(high, tmp, high, 0b00110001);
|
||||
uni_vmovups(ptr[reg_exp], low);
|
||||
uni_vmovups(ptr[reg_exp + vlen], high);
|
||||
}
|
||||
|
||||
template <cpu_isa_t isa>
|
||||
void jit_dft_kernel_f32<isa>::uni_vpermilps(const Xbyak::Xmm& x, const Xbyak::Operand& op, int8_t control) {
|
||||
movups(x, op);
|
||||
shufps(x, x, control);
|
||||
}
|
||||
|
||||
template <cpu_isa_t isa>
|
||||
void jit_dft_kernel_f32<isa>::uni_vpermilps(const Xbyak::Ymm& x, const Xbyak::Operand& op, int8_t control) {
|
||||
vpermilps(x, op, control);
|
||||
}
|
||||
|
||||
template <cpu_isa_t isa>
|
||||
void jit_dft_kernel_f32<isa>::load_and_broadcast_every_other_elem(const Xbyak::Zmm& x, const Xbyak::RegExp& reg_exp, const Xbyak::Xmm& tmp) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
movq(tmp, ptr[reg_exp + type_size * i * 2]);
|
||||
shufps(tmp, tmp, 0b01010000);
|
||||
vinsertf32x4(x, x, tmp, i);
|
||||
}
|
||||
}
|
||||
|
||||
template <cpu_isa_t isa>
|
||||
void jit_dft_kernel_f32<isa>::load_and_broadcast_every_other_elem(const Xbyak::Ymm& x, const Xbyak::RegExp& reg_exp, const Xbyak::Xmm& tmp) {
|
||||
for (int i = 0; i < 2; i++) {
|
||||
movq(tmp, ptr[reg_exp + type_size * i * 2]);
|
||||
shufps(tmp, tmp, 0b01010000);
|
||||
vinsertf128(x, x, tmp, i);
|
||||
}
|
||||
}
|
||||
|
||||
template <cpu_isa_t isa>
|
||||
void jit_dft_kernel_f32<isa>::load_and_broadcast_every_other_elem(const Xbyak::Xmm& x, const Xbyak::RegExp& reg_exp, const Xbyak::Xmm& tmp) {
|
||||
movq(x, ptr[reg_exp]);
|
||||
shufps(x, x, 0b01010000);
|
||||
template <>
|
||||
void jit_dft_kernel_f32<sse41>::interleave_and_store(const Vmm& real, const Vmm& imag, const Xbyak::RegExp& reg_exp, const Vmm& tmp) {
|
||||
const Vmm& low = tmp;
|
||||
const Vmm& high = real;
|
||||
uni_vmovups(low, real);
|
||||
unpcklps(low, imag);
|
||||
unpckhps(high, imag);
|
||||
uni_vmovups(ptr[reg_exp], low);
|
||||
uni_vmovups(ptr[reg_exp + vlen], high);
|
||||
}
|
||||
|
||||
template struct jit_dft_kernel_f32<cpu::x64::sse41>;
|
||||
|
@ -60,7 +60,17 @@ struct jit_dft_kernel_f32 : public jit_dft_kernel, public jit_generator {
|
||||
public:
|
||||
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_dft_kernel_f32)
|
||||
|
||||
jit_dft_kernel_f32(bool is_inverse, enum dft_type type) : jit_dft_kernel(is_inverse, type), jit_generator(jit_name()) {}
|
||||
jit_dft_kernel_f32(bool is_inverse, enum dft_type type) : jit_dft_kernel(is_inverse, type), jit_generator(jit_name()) {
|
||||
constexpr int simd_size = vlen / type_size;
|
||||
perm_low_values.reserve(simd_size);
|
||||
perm_high_values.reserve(simd_size);
|
||||
for (int i = 0; i < simd_size / 2; i++) {
|
||||
perm_low_values.push_back(i);
|
||||
perm_low_values.push_back(i + simd_size);
|
||||
perm_high_values.push_back(i + simd_size / 2);
|
||||
perm_high_values.push_back(i + simd_size / 2 + simd_size);
|
||||
}
|
||||
}
|
||||
|
||||
void create_ker() override {
|
||||
jit_generator::create_kernel();
|
||||
@ -70,17 +80,13 @@ struct jit_dft_kernel_f32 : public jit_dft_kernel, public jit_generator {
|
||||
void generate() override;
|
||||
|
||||
private:
|
||||
void uni_vbroadcastsd(const Xbyak::Xmm& x, const Xbyak::Operand& op);
|
||||
void uni_vbroadcastsd(const Xbyak::Ymm& x, const Xbyak::Operand& op);
|
||||
using Vmm = typename conditional3<isa == cpu::x64::sse41, Xbyak::Xmm,
|
||||
isa == cpu::x64::avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
|
||||
|
||||
void uni_vpermilps(const Xbyak::Xmm& x, const Xbyak::Operand& op, int8_t control);
|
||||
void uni_vpermilps(const Xbyak::Ymm& x, const Xbyak::Operand& op, int8_t control);
|
||||
void interleave_and_store(const Vmm& real, const Vmm& imag, const Xbyak::RegExp& reg_exp, const Vmm& tmp);
|
||||
|
||||
void load_and_broadcast_every_other_elem(const Xbyak::Zmm& x, const Xbyak::RegExp& reg_exp, const Xbyak::Xmm& tmp);
|
||||
void load_and_broadcast_every_other_elem(const Xbyak::Ymm& x, const Xbyak::RegExp& reg_exp, const Xbyak::Xmm& tmp);
|
||||
void load_and_broadcast_every_other_elem(const Xbyak::Xmm& x, const Xbyak::RegExp& reg_exp, const Xbyak::Xmm& tmp);
|
||||
|
||||
int type_size = sizeof(float);
|
||||
static constexpr int type_size = sizeof(float);
|
||||
static constexpr int vlen = cpu_isa_traits<isa>::vlen;
|
||||
|
||||
Xbyak::Reg8 is_signal_size_even = al;
|
||||
Xbyak::Reg64 input_ptr = rbx;
|
||||
@ -90,6 +96,12 @@ struct jit_dft_kernel_f32 : public jit_dft_kernel, public jit_generator {
|
||||
Xbyak::Reg64 signal_size = r11;
|
||||
Xbyak::Reg64 output_start = r12;
|
||||
Xbyak::Reg64 output_end = r13;
|
||||
|
||||
std::vector<int> perm_low_values;
|
||||
std::vector<int> perm_high_values;
|
||||
|
||||
Vmm perm_low;
|
||||
Vmm perm_high;
|
||||
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
|
@ -684,7 +684,7 @@ struct RDFTJitExecutor : public RDFTExecutor {
|
||||
std::vector<float> generateTwiddlesDFT(size_t inputSize, size_t outputSize, enum dft_type type) override {
|
||||
std::vector<float> twiddles(inputSize * outputSize * 2);
|
||||
int simdSize = vlen / sizeof(float);
|
||||
if (type == real_to_complex || type == complex_to_complex) {
|
||||
if (type == real_to_complex) {
|
||||
simdSize /= 2; // there are two floats per one complex element in the output
|
||||
}
|
||||
|
||||
@ -702,7 +702,7 @@ struct RDFTJitExecutor : public RDFTExecutor {
|
||||
}
|
||||
for (size_t k = 0; k < simdSize; k++) {
|
||||
double angle = 2 * PI * (K * simdSize + k) * n / inputSize;
|
||||
twiddles[((K * inputSize + n) * 2 + 1) * simdSize + k] = -std::sin(angle);
|
||||
twiddles[((K * inputSize + n) * 2 + 1) * simdSize + k] = isInverse ? std::sin(angle) : -std::sin(angle);
|
||||
}
|
||||
}
|
||||
});
|
||||
@ -712,7 +712,7 @@ struct RDFTJitExecutor : public RDFTExecutor {
|
||||
k += start;
|
||||
double angle = 2 * PI * k * n / inputSize;
|
||||
twiddles[2 * (k * inputSize + n)] = std::cos(angle);
|
||||
twiddles[2 * (k * inputSize + n) + 1] = -std::sin(angle);
|
||||
twiddles[2 * (k * inputSize + n) + 1] = isInverse ? std::sin(angle) : -std::sin(angle);
|
||||
});
|
||||
}
|
||||
return twiddles;
|
||||
|
@ -436,6 +436,8 @@ std::vector<RDFTTestCPUParams> getParams4D() {
|
||||
params.push_back({{1, 192, 36, 64}, {3, 2}, {}, false, cpuParams});
|
||||
params.push_back({{1, 192, 36, 64}, {-2, -1}, {36, 64}, false, cpuParams});
|
||||
params.push_back({{1, 192, 36, 64}, {0, 1, 2, 3}, {}, false, cpuParams});
|
||||
params.push_back({{1, 120, 64, 64}, {-2, -1}, {64, 33}, false, cpuParams});
|
||||
params.push_back({{1, 120, 96, 96}, {-2, -1}, {96, 49}, false, cpuParams});
|
||||
params.push_back({{2, 192, 36, 33, 2}, {0}, {}, true, cpuParams});
|
||||
params.push_back({{1, 192, 36, 33, 2}, {1}, {}, true, cpuParams});
|
||||
params.push_back({{1, 192, 36, 33, 2}, {2}, {}, true, cpuParams});
|
||||
@ -444,6 +446,8 @@ std::vector<RDFTTestCPUParams> getParams4D() {
|
||||
params.push_back({{1, 192, 36, 33, 2}, {3, 2}, {}, true, cpuParams});
|
||||
params.push_back({{1, 192, 36, 33, 2}, {-2, -1}, {36, 64}, true, cpuParams});
|
||||
params.push_back({{1, 192, 36, 33, 2}, {0, 1, 2, 3}, {}, true, cpuParams});
|
||||
params.push_back({{1, 120, 64, 33, 2}, {-2, -1}, {64, 64}, true, cpuParams});
|
||||
params.push_back({{1, 120, 96, 49, 2}, {-2, -1}, {96, 96}, true, cpuParams});
|
||||
|
||||
return params;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user