Reference implementation of DFT and IDFT operations (#4938)
* Started to write the header file for (I)DFT reference implementation. * Continued to write the header file for (I)DFT reference implementation. * Renamed the header file for the reference implementation of (I)DFT. * Started to write an implementation file for the reference implementation of (I)DFT. * Continued to write an implementation file for the reference implementation of (I)DFT. * Continued to write an implementation file for the reference implementation of (I)DFT. * Small fix. * Written copying data from input and copying data to output. * Code style fixes. * Small fix. * Some fixes. * Some fixes. * Small fix. * Written naive version of (I)DFT calculation. * Some fixes. * Some fixes. * Some fixes. * Some fixes. * Some fixes. * Small fixes. * Written the draft of the reference implementation of (I)DFT. * Small fix. * Small fix. * Code style fixes. * Added evaluation of (I)DFT to evaluation_map.cpp. * Small fixes. * Some fixes. * Written test for evaluation of 1D DFT. * Fixed ngraph/test/CMakeLists.txt. * Disabled DFT evaluation test in CPU, because (I)DFT has not implemented yet in CPU. * Added debug prints to evaluation test of DFT. * Added debug prints into evaluate_map.cpp for DFT evaluation. * Added some debug prints into FFT calculation. * Added more debug prints. * Added more debug prints. * Added more debug prints. * Added more debug prints. * Added more debug prints. * Small fix. * Added more debug prints. * Added more debug prints. * Small change. * Some fixes. * Small fix. * Some changes. * Added more tests. * Added test for IDFT 1D calculation. * Some fixes. * Added more debug prints. * Small fix. * Small fix. * Some fixes. * Some fix. * Small fix. * Added tests for 2D case of IDFT. * Some fixes. * Written tests for 3D case of DFT. * Some fixes. * Small fix. * Added test for 3D case of IDFT. * Some fixes. * Deleted debug prints from tests for IDFT. * Deleted debug prints from tests for DFT. * Deleted debug prints from the reference implementation of (I)DFT. * Code style fixes. * Deleted debug prints from evaluates_map.cpp. * Written the header file for the base class of DFT and IDFT operations. * Written an implementation of the base class of DFT and IDFT. * Now nGraph IDFT operation class is a derived class of FFTBase. * Now the nGraph operation DFT is a derived class of FFTBase. * Added assert for axes in (I)DFT reference. * Small refactoring. * Deleted commented code. * Small refactoring. * Small fix. * Initialization of calculations of the reference implementation of (I)DFT was moved in the separate function. * Small fix. * Code style fix. * Small fix. * Now evaluate() of (I)DFT uses canonicalize_axes() from the reference implementation. * Code style fix. * Deleted commented code. * Added tests for i32 axes of DFT. * Added test for i32 axes of 2D DFT. * Added i32 axes case to test for 3D DFT. * Added test for i32 axes in tests for IDFT. * Written signal_size case test for 1D DFT. * Small fix. * Written test for signal_size case for 2D DFT. * Added test for bfloat16 input data of 1D DFT. * Small fix. * Small fix. * Small fix. * Some fixes. * Some fix. * Added bfloat16 input tests for 2D DFT. * Some fixes. * Written tests for bfloat16 input of 3D DFT. * Some fixes. * Some fixes. * Added tests for bfloat16 input of 1D IDFT. * Some fixes. * Added tests for bfloat16 input of 2D IDFT. * Added test for bfloat16 input of 3D IDFT. * Small fix. * Some fixes. * Added tests for float16 input of 1D DFT. * Small fix. * Written tests for float16 input of 2D and 3D DFT. * Small fix. * Some fixes. * Some fixes. * Written tests for float16 inputs of 1D, 2D, 3D IDFT. * Some fixes. * Some fixes. * Some fixes. * Some fixes. * Deleted redundant include. * Some fixes. * Added tests of 1D and 2D DFT for the case when some axes lengths are powers of 2. * Added tests for 3D DFT and 1D, 2D, 3D IDFT in the case when lengths of some axes are powers of 2. * Small fix. * Added some comments. * Added some comments.
This commit is contained in:
parent
1c4428e945
commit
f2366f7072
@ -21,6 +21,7 @@
|
||||
#include "ngraph/attribute_adapter.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/op/util/fft_base.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -29,7 +30,7 @@ namespace ngraph
|
||||
namespace v7
|
||||
{
|
||||
/// \brief An operation DFT that computes the discrete Fourier transformation.
|
||||
class NGRAPH_API DFT : public Op
|
||||
class NGRAPH_API DFT : public util::FFTBase
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
@ -52,13 +53,8 @@ namespace ngraph
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
protected:
|
||||
void validate();
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "ngraph/attribute_adapter.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/op/util/fft_base.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -17,7 +18,7 @@ namespace ngraph
|
||||
namespace v7
|
||||
{
|
||||
/// \brief An operation IDFT that computes the inverse discrete Fourier transformation.
|
||||
class NGRAPH_API IDFT : public Op
|
||||
class NGRAPH_API IDFT : public util::FFTBase
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
@ -40,13 +41,8 @@ namespace ngraph
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
protected:
|
||||
void validate();
|
||||
};
|
||||
}
|
||||
}
|
||||
|
46
ngraph/core/include/ngraph/op/util/fft_base.hpp
Normal file
46
ngraph/core/include/ngraph/op/util/fft_base.hpp
Normal file
@ -0,0 +1,46 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace util
|
||||
{
|
||||
/// \brief Base class for operations DFT and DFT.
|
||||
class NGRAPH_API FFTBase : public Op
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
FFTBase() = default;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
protected:
|
||||
/// \brief Constructs an FFT operation. FFT is performed for full size axes.
|
||||
///
|
||||
/// \param data Input data
|
||||
/// \param axes Axes to perform FFT
|
||||
FFTBase(const Output<Node>& data, const Output<Node>& axes);
|
||||
|
||||
/// \brief Constructs a FFT operation.
|
||||
///
|
||||
/// \param data Input data
|
||||
/// \param axes Axes to perform FFT
|
||||
/// \param signal_size Signal sizes for 'axes'
|
||||
FFTBase(const Output<Node>& data,
|
||||
const Output<Node>& axes,
|
||||
const Output<Node>& signal_size);
|
||||
|
||||
void validate();
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <ngraph/runtime/host_tensor.hpp>
|
||||
#include <vector>
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/ops.hpp"
|
||||
#include "ngraph/shape_util.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
enum class FFTKind
|
||||
{
|
||||
Forward,
|
||||
Inverse
|
||||
};
|
||||
|
||||
void fft(const float* input_data,
|
||||
const Shape& input_data_shape,
|
||||
const int64_t* axes_data,
|
||||
const Shape& axes_data_shape,
|
||||
float* fft_result,
|
||||
const Shape& output_shape,
|
||||
FFTKind fft_kind);
|
||||
|
||||
void fft_postprocessing(const HostTensorVector& outputs,
|
||||
const ngraph::element::Type output_type,
|
||||
const std::vector<float>& fft_result);
|
||||
|
||||
std::vector<int64_t> canonicalize_axes(const int64_t* axes_data,
|
||||
const Shape& axes_data_shape,
|
||||
int64_t complex_data_rank);
|
||||
}
|
||||
}
|
||||
}
|
635
ngraph/core/reference/src/runtime/reference/fft.cpp
Normal file
635
ngraph/core/reference/src/runtime/reference/fft.cpp
Normal file
@ -0,0 +1,635 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/runtime/reference/fft.hpp"
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace ngraph::runtime::reference;
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
// FFT operation supports for negative axes to transform. More precisely, according to
|
||||
// the FFT operation specification, axes should be integers from -(r - 1) to (r - 2)
|
||||
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
|
||||
// 'r - 1 + a'. The reason is the following: real input tensor of the shape
|
||||
// [n_0, ..., n_{r - 1}, 2] is interpreted as a complex tensor with the shape
|
||||
// [n_0, ..., n_{r - 1}]. To simplify calculations, we need to convert negative axes to
|
||||
// positive axes using the formula 'r - 1 + a'.
|
||||
std::vector<int64_t> canonicalize_axes(const int64_t* axes_data,
|
||||
const Shape& axes_data_shape,
|
||||
int64_t complex_data_rank)
|
||||
{
|
||||
size_t num_of_fft_axes = axes_data_shape[0];
|
||||
|
||||
std::vector<int64_t> result(axes_data, axes_data + num_of_fft_axes);
|
||||
for (int64_t& axis : result)
|
||||
{
|
||||
if (axis < 0)
|
||||
{
|
||||
axis += complex_data_rank;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
using complex_type = std::complex<float>;
|
||||
|
||||
// Calculates strides for all axes.
|
||||
std::vector<int64_t> compute_strides(const std::vector<int64_t>& v)
|
||||
{
|
||||
std::vector<int64_t> strides(v.size() + 1);
|
||||
int64_t stride = 1;
|
||||
for (size_t i = 0; i < v.size(); ++i)
|
||||
{
|
||||
strides[i] = stride;
|
||||
stride *= v[i];
|
||||
}
|
||||
strides.back() = stride;
|
||||
return strides;
|
||||
}
|
||||
|
||||
// To simplify calculation of strides for all axes of 'shape' of some complex
|
||||
// tensor, we reverse numbers in 'shape'. Because we have no native support for
|
||||
// complex numbers in tensors, we interpret FFT input tensors of the shape
|
||||
// [N_0, ..., N_{r - 1}, 2] as a complex tensor with the shape
|
||||
// [N_0, ..., N_{r - 1}]. Hence, we convert 'shape=[N_0, ..., N_{r - 1}, 2]'
|
||||
// into [N_{r - 1}, ..., N_0].
|
||||
std::vector<int64_t> reverse_shape(const Shape& shape)
|
||||
{
|
||||
size_t complex_data_rank = shape.size() - 1;
|
||||
|
||||
std::vector<int64_t> reversed_shape(complex_data_rank);
|
||||
for (size_t i = 0; i < complex_data_rank; ++i)
|
||||
{
|
||||
reversed_shape[i] = static_cast<int64_t>(shape[complex_data_rank - i - 1]);
|
||||
}
|
||||
return reversed_shape;
|
||||
}
|
||||
|
||||
// This function gets FFT axes from axes_data
|
||||
std::vector<int64_t> get_axes(const int64_t* axes_data,
|
||||
const Shape& axes_data_shape,
|
||||
int64_t complex_data_rank)
|
||||
{
|
||||
auto axes = canonicalize_axes(axes_data, axes_data_shape, complex_data_rank);
|
||||
std::sort(axes.begin(), axes.end(), std::greater<int64_t>{});
|
||||
return axes;
|
||||
}
|
||||
|
||||
// When we reverted shape, we need to revert FFT axes.
|
||||
void reverse_fft_axes(std::vector<int64_t>& axes, int64_t complex_data_rank)
|
||||
{
|
||||
for (int64_t& axis : axes)
|
||||
{
|
||||
axis = complex_data_rank - 1 - axis;
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to get only length with respect to given axes.
|
||||
std::vector<int64_t> get_lengths(const std::vector<int64_t>& shape,
|
||||
const std::vector<int64_t>& axes)
|
||||
{
|
||||
std::vector<int64_t> lengths;
|
||||
for (int64_t axis : axes)
|
||||
{
|
||||
lengths.push_back(shape[axis]);
|
||||
}
|
||||
return lengths;
|
||||
}
|
||||
|
||||
// This function calculates 'outer axes', that is axes that are
|
||||
// not transformed by FFT.
|
||||
std::vector<int64_t> get_outer_axes(const std::vector<int64_t>& inner_axes,
|
||||
int64_t complex_data_rank)
|
||||
{
|
||||
int64_t num_of_inner_axes = static_cast<int64_t>(inner_axes.size());
|
||||
int64_t num_of_outer_axes = complex_data_rank - num_of_inner_axes;
|
||||
|
||||
std::vector<int64_t> outer_axes(num_of_outer_axes);
|
||||
|
||||
int64_t fft_axes_as_bitset = 0;
|
||||
for (int64_t axis : inner_axes)
|
||||
{
|
||||
assert(axis < 64);
|
||||
fft_axes_as_bitset |= static_cast<int64_t>(1) << axis;
|
||||
}
|
||||
|
||||
for (int64_t j = 0, i = 0; i < complex_data_rank; ++i)
|
||||
{
|
||||
if ((fft_axes_as_bitset & (static_cast<int64_t>(1) << i)) == 0)
|
||||
{
|
||||
outer_axes[j] = i;
|
||||
++j;
|
||||
}
|
||||
}
|
||||
|
||||
return outer_axes;
|
||||
}
|
||||
|
||||
inline bool is_power_of_two(int64_t x) { return (x != 0) && ((x & (x - 1)) == 0); }
|
||||
|
||||
// This function calculates internal FFT buffer size using lengths of FFT axes.
|
||||
int64_t compute_buffer_size(const std::vector<int64_t>& fft_lengths)
|
||||
{
|
||||
int64_t buffer_size = 0;
|
||||
|
||||
for (int64_t length : fft_lengths)
|
||||
{
|
||||
int64_t current_size = is_power_of_two(length) ? (2 * length) : length;
|
||||
buffer_size = std::max(buffer_size, current_size);
|
||||
}
|
||||
|
||||
return buffer_size;
|
||||
}
|
||||
|
||||
// Calculating coordinates c_0, ..., c_{k - 1} from the index of the form
|
||||
// c_0 * strides[0] + ... c_{k - 1} * strides[k - 1]
|
||||
// where k is the number of strides.
|
||||
std::vector<int64_t> coords_from_index(int64_t index,
|
||||
const std::vector<int64_t>& strides)
|
||||
{
|
||||
int64_t num_of_axes = static_cast<int64_t>(strides.size()) - 1;
|
||||
if (num_of_axes == 0)
|
||||
{
|
||||
return std::vector<int64_t>{};
|
||||
}
|
||||
std::vector<int64_t> coords(num_of_axes);
|
||||
int64_t curr = index;
|
||||
for (int64_t j = num_of_axes - 1; j >= 1; --j)
|
||||
{
|
||||
coords[j] = curr / strides[j];
|
||||
curr %= strides[j];
|
||||
}
|
||||
coords[0] = curr;
|
||||
return coords;
|
||||
}
|
||||
|
||||
// This function gets a complex value from given coords of this value
|
||||
complex_type get_value_from_input(const complex_type* input_data,
|
||||
int64_t src_index,
|
||||
const std::vector<int64_t>& coords,
|
||||
const std::vector<int64_t>& input_fft_lengths,
|
||||
const std::vector<int64_t>& input_fft_strides)
|
||||
{
|
||||
int64_t offset = 0;
|
||||
int64_t num_of_fft_axes = static_cast<int64_t>(coords.size());
|
||||
for (int64_t i = 0; i < num_of_fft_axes; ++i)
|
||||
{
|
||||
int64_t coord = coords[i];
|
||||
if (coord >= input_fft_lengths[i])
|
||||
{
|
||||
return complex_type{0.0f, 0.0f};
|
||||
}
|
||||
offset += coord * input_fft_strides[i];
|
||||
}
|
||||
|
||||
return input_data[src_index + offset];
|
||||
}
|
||||
|
||||
// Copying input data to the given memory domain.
|
||||
void copy_data_from_input(complex_type* result,
|
||||
const complex_type* input_data,
|
||||
int64_t src_index,
|
||||
int64_t fft_size,
|
||||
const std::vector<int64_t>& fft_strides,
|
||||
const std::vector<int64_t>& input_fft_lengths,
|
||||
const std::vector<int64_t>& input_fft_strides)
|
||||
{
|
||||
for (int64_t idx = 0; idx < fft_size; ++idx)
|
||||
{
|
||||
auto coords = coords_from_index(idx, fft_strides);
|
||||
complex_type value = get_value_from_input(
|
||||
input_data, src_index, coords, input_fft_lengths, input_fft_strides);
|
||||
result[idx] = value;
|
||||
}
|
||||
}
|
||||
|
||||
// This function checks whether data of given complex blob are only zeros.
|
||||
bool blob_is_zero(const complex_type* data, int64_t blob_size)
|
||||
{
|
||||
for (int64_t i = 0; i < blob_size; ++i)
|
||||
{
|
||||
if (data[i] != complex_type{0.0f, 0.0f})
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Calculates offset of value using corresponding coordinates and strides.
|
||||
int64_t offset_from_coords_and_strides(const std::vector<int64_t>& coords,
|
||||
const std::vector<int64_t>& strides)
|
||||
{
|
||||
int64_t offset = 0;
|
||||
int64_t num_of_axes = coords.size();
|
||||
for (int64_t i = 0; i < num_of_axes; ++i)
|
||||
{
|
||||
offset += coords[i] * strides[i];
|
||||
}
|
||||
return offset;
|
||||
}
|
||||
|
||||
// Copying calculated data to the given memory domain.
|
||||
void copy_data_to_output(complex_type* output,
|
||||
const complex_type* data,
|
||||
int64_t dst_index,
|
||||
int64_t fft_size,
|
||||
const std::vector<int64_t>& fft_strides,
|
||||
const std::vector<int64_t>& output_fft_strides)
|
||||
{
|
||||
for (int64_t idx = 0; idx < fft_size; ++idx)
|
||||
{
|
||||
auto coords = coords_from_index(idx, fft_strides);
|
||||
complex_type value = data[idx];
|
||||
int64_t offset = offset_from_coords_and_strides(coords, output_fft_strides);
|
||||
|
||||
output[dst_index + offset] = value;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr float pi = 3.141592653589793238462643f;
|
||||
|
||||
// This function calculates e^{-2i\pi k / length} for the forward FFT, and
|
||||
// e^{2i\pi k / length} otherwise. Here 'i' is an imaginary unit.
|
||||
complex_type twiddle(int64_t k, int64_t length, FFTKind fft_kind)
|
||||
{
|
||||
float angle = -2.0f * pi * static_cast<float>(k) / static_cast<float>(length);
|
||||
complex_type result = std::exp(complex_type(0.0f, angle));
|
||||
return (fft_kind == FFTKind::Inverse) ? std::conj(result) : result;
|
||||
}
|
||||
|
||||
// This function gathers data from the input of 1D FFT to the contiguous buffer
|
||||
void gather_to_buffer(const complex_type* data,
|
||||
int64_t length,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
complex_type* buffer)
|
||||
{
|
||||
for (int64_t k = 0; k < length; ++k)
|
||||
{
|
||||
complex_type value = data[start + k * stride];
|
||||
buffer[k] = value;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<complex_type> generate_twiddles(int64_t length, FFTKind fft_kind)
|
||||
{
|
||||
std::vector<complex_type> twiddles(length / 2);
|
||||
for (int64_t k = 0; k < length / 2; ++k)
|
||||
{
|
||||
twiddles[k] = twiddle(k, length, fft_kind);
|
||||
}
|
||||
return twiddles;
|
||||
}
|
||||
|
||||
// Non-recursive implementation of the Cooley-Tukey radix-2 decimation in
|
||||
// time. Performs 1D FFT transform for the lengths, which are powers of 2.
|
||||
// Runs in O(length * log(length)) time. Uses the same parameters as the naive
|
||||
// implementation above, except that the preallocated buffer must be at least
|
||||
// twice as big as the length of the transform, because the buffer is used to
|
||||
// hold both input and output values for each stage of the transform.
|
||||
void optimized_fft1d(int64_t length,
|
||||
int64_t fft_offset,
|
||||
int64_t stride,
|
||||
complex_type* data,
|
||||
complex_type* buffer,
|
||||
FFTKind fft_kind)
|
||||
{
|
||||
gather_to_buffer(data, length, fft_offset, stride, buffer);
|
||||
if (blob_is_zero(buffer, length))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t in_base = length;
|
||||
int64_t out_base = 0;
|
||||
for (int64_t num_blocks = 1; num_blocks < length; num_blocks *= 2)
|
||||
{
|
||||
std::swap(in_base, out_base);
|
||||
|
||||
auto twiddles = generate_twiddles(num_blocks * 2, fft_kind);
|
||||
const int64_t block_size = length / num_blocks;
|
||||
const int64_t next_iteration_block_size = block_size / 2;
|
||||
for (int64_t block = 0; block < num_blocks; block++)
|
||||
{
|
||||
const int64_t in_offset = in_base + block * block_size;
|
||||
const int64_t out_offset = out_base + block * next_iteration_block_size;
|
||||
|
||||
for (int64_t pair = 0; pair < block_size / 2; pair++)
|
||||
{
|
||||
const complex_type even = buffer[in_offset + pair];
|
||||
const complex_type odd = buffer[in_offset + block_size / 2 + pair];
|
||||
const complex_type twiddled_odd = twiddles[block] * odd;
|
||||
buffer[out_offset + pair] = even + twiddled_odd;
|
||||
buffer[out_offset + length / 2 + pair] = even - twiddled_odd;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t k = 0; k < length; k++)
|
||||
{
|
||||
complex_type value = buffer[out_base + k];
|
||||
if (fft_kind == FFTKind::Inverse)
|
||||
{
|
||||
value /= complex_type(length, 0.0f);
|
||||
}
|
||||
data[fft_offset + k * stride] = value;
|
||||
}
|
||||
}
|
||||
|
||||
// Naive implementation of 1D FFT
|
||||
void naive_fft1d(int64_t length,
|
||||
int64_t fft_offset,
|
||||
int64_t stride,
|
||||
complex_type* data,
|
||||
complex_type* buffer,
|
||||
FFTKind fft_kind)
|
||||
{
|
||||
gather_to_buffer(data, length, fft_offset, stride, buffer);
|
||||
if (blob_is_zero(buffer, length))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for (int64_t k = 0; k < length; ++k)
|
||||
{
|
||||
complex_type value = complex_type(0.0f, 0.0f);
|
||||
for (int64_t n = 0; n < length; ++n)
|
||||
{
|
||||
value += buffer[n] * twiddle(n * k, length, fft_kind);
|
||||
}
|
||||
if (fft_kind == FFTKind::Inverse)
|
||||
{
|
||||
value /= complex_type(length, 0.0f);
|
||||
}
|
||||
data[fft_offset + k * stride] = value;
|
||||
}
|
||||
}
|
||||
|
||||
void fft1d(int64_t length,
|
||||
int64_t fft_offset,
|
||||
int64_t stride,
|
||||
complex_type* data,
|
||||
complex_type* buffer,
|
||||
FFTKind fft_kind)
|
||||
{
|
||||
if (is_power_of_two(length))
|
||||
{
|
||||
optimized_fft1d(length, fft_offset, stride, data, buffer, fft_kind);
|
||||
}
|
||||
else
|
||||
{
|
||||
naive_fft1d(length, fft_offset, stride, data, buffer, fft_kind);
|
||||
}
|
||||
}
|
||||
|
||||
struct InfoForFFTCalculation
|
||||
{
|
||||
std::vector<int64_t> fft_axes;
|
||||
std::vector<int64_t> fft_lengths;
|
||||
std::vector<int64_t> fft_strides;
|
||||
std::vector<int64_t> outer_strides;
|
||||
std::vector<int64_t> output_fft_strides;
|
||||
std::vector<int64_t> output_outer_strides;
|
||||
std::vector<int64_t> input_fft_lengths;
|
||||
std::vector<int64_t> input_fft_strides;
|
||||
std::vector<int64_t> input_outer_strides;
|
||||
int64_t fft_rank;
|
||||
int64_t fft_size;
|
||||
int64_t outer_size;
|
||||
int64_t buffer_size;
|
||||
};
|
||||
|
||||
// This function builds information needed to calculate FFT.
|
||||
InfoForFFTCalculation get_info_for_calculation(const Shape& input_data_shape,
|
||||
const int64_t* axes_data,
|
||||
const Shape& axes_data_shape,
|
||||
const Shape& output_shape)
|
||||
{
|
||||
InfoForFFTCalculation result;
|
||||
|
||||
const int64_t complex_data_rank =
|
||||
static_cast<int64_t>(input_data_shape.size() - 1);
|
||||
|
||||
const auto reversed_output_shape = reverse_shape(output_shape);
|
||||
auto fft_axes = get_axes(axes_data, axes_data_shape, complex_data_rank);
|
||||
reverse_fft_axes(fft_axes, complex_data_rank);
|
||||
|
||||
const int64_t fft_rank = fft_axes.size();
|
||||
const auto fft_lengths = get_lengths(reversed_output_shape, fft_axes);
|
||||
const auto fft_strides = compute_strides(fft_lengths);
|
||||
const int64_t fft_size = fft_strides[fft_rank];
|
||||
|
||||
const auto outer_axes = get_outer_axes(fft_axes, complex_data_rank);
|
||||
const int64_t outer_rank = outer_axes.size();
|
||||
const auto outer_lengths = get_lengths(reversed_output_shape, outer_axes);
|
||||
const auto outer_strides = compute_strides(outer_lengths);
|
||||
const int64_t outer_size = outer_strides[outer_rank];
|
||||
|
||||
const int64_t buffer_size = compute_buffer_size(fft_lengths);
|
||||
|
||||
const auto output_strides = compute_strides(reversed_output_shape);
|
||||
const auto output_fft_strides = get_lengths(output_strides, fft_axes);
|
||||
const auto output_outer_strides = get_lengths(output_strides, outer_axes);
|
||||
const auto reversed_input_shape = reverse_shape(input_data_shape);
|
||||
const auto input_fft_lengths = get_lengths(reversed_input_shape, fft_axes);
|
||||
const auto input_strides = compute_strides(reversed_input_shape);
|
||||
const auto input_fft_strides = get_lengths(input_strides, fft_axes);
|
||||
const auto input_outer_strides = get_lengths(input_strides, outer_axes);
|
||||
|
||||
result.fft_axes = fft_axes;
|
||||
result.fft_lengths = fft_lengths;
|
||||
result.fft_strides = fft_strides;
|
||||
result.outer_strides = outer_strides;
|
||||
result.output_fft_strides = output_fft_strides;
|
||||
result.output_outer_strides = output_outer_strides;
|
||||
result.input_fft_lengths = input_fft_lengths;
|
||||
result.input_fft_strides = input_fft_strides;
|
||||
result.input_outer_strides = input_outer_strides;
|
||||
result.fft_rank = fft_rank;
|
||||
result.fft_size = fft_size;
|
||||
result.outer_size = outer_size;
|
||||
result.buffer_size = buffer_size;
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculation of FFT
|
||||
void fft(const float* input_data,
|
||||
const Shape& input_data_shape,
|
||||
const int64_t* axes_data,
|
||||
const Shape& axes_data_shape,
|
||||
float* fft_result,
|
||||
const Shape& output_shape,
|
||||
FFTKind fft_kind)
|
||||
{
|
||||
const complex_type* complex_input_data_ptr =
|
||||
reinterpret_cast<const complex_type*>(input_data);
|
||||
complex_type* complex_output_ptr = reinterpret_cast<complex_type*>(fft_result);
|
||||
|
||||
const auto info = get_info_for_calculation(
|
||||
input_data_shape, axes_data, axes_data_shape, output_shape);
|
||||
const auto& fft_axes = info.fft_axes;
|
||||
const int64_t fft_rank = info.fft_rank;
|
||||
const auto& fft_lengths = info.fft_lengths;
|
||||
const auto& fft_strides = info.fft_strides;
|
||||
const int64_t fft_size = info.fft_size;
|
||||
|
||||
if (fft_size <= 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<complex_type> data(fft_size);
|
||||
std::vector<complex_type> buffer(info.buffer_size);
|
||||
|
||||
const auto& output_fft_strides = info.output_fft_strides;
|
||||
const auto& outer_strides = info.outer_strides;
|
||||
const int64_t outer_size = info.outer_size;
|
||||
|
||||
const auto& output_outer_strides = info.output_outer_strides;
|
||||
const auto& input_fft_lengths = info.input_fft_lengths;
|
||||
const auto& input_fft_strides = info.input_fft_strides;
|
||||
const auto& input_outer_strides = info.input_outer_strides;
|
||||
|
||||
// Loop along with 'outer' dimensions, that is along with
|
||||
// not transformed dimensions.
|
||||
for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx)
|
||||
{
|
||||
const auto outer_coords = coords_from_index(outer_idx, outer_strides);
|
||||
int64_t outer_input_offset =
|
||||
offset_from_coords_and_strides(outer_coords, input_outer_strides);
|
||||
|
||||
// Copying current data to transform
|
||||
copy_data_from_input(data.data(),
|
||||
complex_input_data_ptr,
|
||||
outer_input_offset,
|
||||
fft_size,
|
||||
fft_strides,
|
||||
input_fft_lengths,
|
||||
input_fft_strides);
|
||||
|
||||
if (!blob_is_zero(data.data(), fft_size))
|
||||
{
|
||||
// The loop along with all transformed axes.
|
||||
for (int64_t axis_idx = 0; axis_idx < fft_rank; ++axis_idx)
|
||||
{
|
||||
int64_t current_fft_stride = fft_strides[axis_idx];
|
||||
int64_t current_fft_length = fft_lengths[axis_idx];
|
||||
|
||||
int64_t outer_fft_size = 1;
|
||||
std::vector<int64_t> outer_fft_lengths;
|
||||
std::vector<int64_t> outer_fft_axes;
|
||||
for (int64_t i = 0; i < fft_rank; ++i)
|
||||
{
|
||||
if (i == axis_idx)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
outer_fft_size *= fft_lengths[i];
|
||||
outer_fft_lengths.push_back(fft_lengths[i]);
|
||||
outer_fft_axes.push_back(fft_axes[i]);
|
||||
}
|
||||
auto outer_fft_strides = compute_strides(outer_fft_lengths);
|
||||
auto fft_strides_for_outer_fft_axes =
|
||||
get_lengths(fft_strides, outer_fft_axes);
|
||||
|
||||
// Loop along with all FFT axes, except the current one.
|
||||
for (int64_t outer_fft_idx = 0; outer_fft_idx < outer_fft_size;
|
||||
++outer_fft_idx)
|
||||
{
|
||||
const auto outer_fft_coords =
|
||||
coords_from_index(outer_fft_idx, outer_fft_strides);
|
||||
int64_t outer_fft_offset = offset_from_coords_and_strides(
|
||||
outer_fft_coords, fft_strides_for_outer_fft_axes);
|
||||
// Calculation of 1D FFT
|
||||
fft1d(current_fft_length,
|
||||
outer_fft_offset,
|
||||
current_fft_stride,
|
||||
data.data(),
|
||||
buffer.data(),
|
||||
fft_kind);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copying current calculated data to the output blob.
|
||||
int64_t outer_output_offset =
|
||||
offset_from_coords_and_strides(outer_coords, output_outer_strides);
|
||||
copy_data_to_output(complex_output_ptr,
|
||||
data.data(),
|
||||
outer_output_offset,
|
||||
fft_size,
|
||||
fft_strides,
|
||||
output_fft_strides);
|
||||
}
|
||||
}
|
||||
|
||||
void fft_postprocessing(const HostTensorVector& outputs,
|
||||
const ngraph::element::Type output_type,
|
||||
const std::vector<float>& fft_result)
|
||||
{
|
||||
size_t fft_result_size = fft_result.size();
|
||||
|
||||
switch (output_type)
|
||||
{
|
||||
case element::Type_t::bf16:
|
||||
{
|
||||
bfloat16* result_ptr = outputs[0]->get_data_ptr<bfloat16>();
|
||||
for (size_t i = 0; i < fft_result_size; ++i)
|
||||
{
|
||||
result_ptr[i] = bfloat16(fft_result[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case element::Type_t::f16:
|
||||
{
|
||||
float16* result_ptr = outputs[0]->get_data_ptr<float16>();
|
||||
for (size_t i = 0; i < fft_result_size; ++i)
|
||||
{
|
||||
result_ptr[i] = float16(fft_result[i]);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case element::Type_t::f32:
|
||||
{
|
||||
float* result_ptr = outputs[0]->get_data_ptr<float>();
|
||||
memcpy(result_ptr, fft_result.data(), fft_result_size * sizeof(float));
|
||||
}
|
||||
break;
|
||||
default:;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -30,10 +30,10 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v7::DFT, "DFT", 7);
|
||||
NGRAPH_RTTI_DEFINITION(op::v7::DFT, "DFT", 7, util::FFTBase);
|
||||
|
||||
op::v7::DFT::DFT(const Output<Node>& data, const Output<Node>& axes)
|
||||
: Op({data, axes})
|
||||
: FFTBase(data, axes)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
@ -41,7 +41,7 @@ op::v7::DFT::DFT(const Output<Node>& data, const Output<Node>& axes)
|
||||
op::v7::DFT::DFT(const Output<Node>& data,
|
||||
const Output<Node>& axes,
|
||||
const Output<Node>& signal_size)
|
||||
: Op({data, axes, signal_size})
|
||||
: FFTBase(data, axes, signal_size)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
@ -66,202 +66,3 @@ std::shared_ptr<Node> op::v7::DFT::clone_with_new_inputs(const OutputVector& new
|
||||
|
||||
return std::make_shared<op::v7::DFT>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
}
|
||||
|
||||
void op::v7::DFT::validate()
|
||||
{
|
||||
size_t num_of_inputs = get_input_size();
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, num_of_inputs == 2 || num_of_inputs == 3, "DFT must have 2 or 3 inputs.");
|
||||
|
||||
element::Type input_et = get_input_element_type(0);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_et == element::f32 || input_et == element::f16 ||
|
||||
input_et == element::bf16,
|
||||
"DFT input element type must be f32, f16, or bf16");
|
||||
|
||||
element::Type axes_et = get_input_element_type(1);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_et == element::i64 || axes_et == element::i32,
|
||||
"DFT axes element type must be i32 or i64");
|
||||
|
||||
const auto& input_shape = PartialShape(get_input_partial_shape(0));
|
||||
if (input_shape.rank().is_static())
|
||||
{
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_rank >= 2,
|
||||
"The input rank must be greater or equal to 2. Got input rank: ",
|
||||
input_rank);
|
||||
|
||||
auto last_dim_with_two = input_shape[input_rank - 1] & Dimension(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
!last_dim_with_two.get_interval().empty(),
|
||||
"The last dimension of input data must be 2. Got: ",
|
||||
input_shape[input_rank - 1]);
|
||||
}
|
||||
|
||||
const auto& axes_shape = PartialShape(get_input_partial_shape(1));
|
||||
if (axes_shape.rank().is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_shape.rank().get_length() == 1,
|
||||
"DFT axes input must be 1D tensor. Got axes input rank: ",
|
||||
axes_shape.rank().get_length());
|
||||
}
|
||||
|
||||
if (input_shape.rank().is_static() && axes_shape.is_static())
|
||||
{
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_rank >= axes_shape.to_shape()[0] + 1,
|
||||
"The input rank must be greater than number of DFT axes. Got "
|
||||
"input rank: ",
|
||||
input_rank,
|
||||
", number of axes: ",
|
||||
axes_shape.to_shape()[0]);
|
||||
}
|
||||
|
||||
if (input_shape.rank().is_static() && is_type<op::Constant>(input_value(1).get_node()))
|
||||
{
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
const auto& const_axes = get_constant_from_source(input_value(1));
|
||||
auto axes = const_axes->cast_vector<int64_t>();
|
||||
|
||||
// DFT operation supports for negative axes to transform. More precisely, according to
|
||||
// the DFT operation specification, axes should be integers from -(r - 1) to (r - 2)
|
||||
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
|
||||
//'r - 1 + a'. The reason is the following.
|
||||
for (int64_t& axis : axes)
|
||||
{
|
||||
if (axis < 0)
|
||||
{
|
||||
axis += input_rank - 1;
|
||||
}
|
||||
}
|
||||
|
||||
AxisVector axes_vector;
|
||||
AxisSet axes_set;
|
||||
for (const int64_t axis : axes)
|
||||
{
|
||||
axes_vector.push_back(static_cast<size_t>(axis));
|
||||
axes_set.insert(static_cast<size_t>(axis));
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, axes.size() == axes_set.size(), "DFT axes must be unique. Got: ", axes_vector);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
std::find(axes.begin(), axes.end(), input_rank - 1) == axes.end(),
|
||||
"DFT axes cannot contain the last axis. Got axes: ",
|
||||
axes_vector);
|
||||
}
|
||||
|
||||
if (num_of_inputs == 3)
|
||||
{
|
||||
element::Type signal_size_et = get_input_element_type(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
signal_size_et == element::i64 || signal_size_et == element::i32,
|
||||
"DFT signal_size element type must be i32 or i64");
|
||||
|
||||
const auto& signal_size_shape = PartialShape(get_input_partial_shape(2));
|
||||
if (signal_size_shape.rank().is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
signal_size_shape.rank().get_length() == 1,
|
||||
"DFT Signal size input must be 1D tensor. Got signal size "
|
||||
"input rank: ",
|
||||
signal_size_shape.rank().get_length());
|
||||
}
|
||||
|
||||
if (axes_shape.is_static() && signal_size_shape.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_shape.to_shape()[0] == signal_size_shape.to_shape()[0],
|
||||
"Sizes of inputs 'axes' and 'signal_size' must be equal. Got "
|
||||
"size of 'axes': ",
|
||||
axes_shape.to_shape()[0],
|
||||
"size of 'signal_size': ",
|
||||
signal_size_shape.to_shape()[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void op::v7::DFT::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_DFT_validate_and_infer_types);
|
||||
validate();
|
||||
|
||||
const auto& input_shape = PartialShape(get_input_partial_shape(0));
|
||||
const auto& axes_shape = PartialShape(get_input_partial_shape(1));
|
||||
PartialShape output_shape = input_shape;
|
||||
if (input_shape.rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
|
||||
if (axes_shape.rank().is_dynamic() || !is_type<op::Constant>(input_value(1).get_node()))
|
||||
{
|
||||
for (size_t i = 0; i < input_rank - 1; ++i)
|
||||
{
|
||||
output_shape[i] = Dimension::dynamic();
|
||||
}
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
if (input_values().size() == 2)
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& signal_size_shape = PartialShape(get_input_partial_shape(2));
|
||||
if (signal_size_shape.rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& const_axes = get_constant_from_source(input_value(1));
|
||||
auto axes = const_axes->cast_vector<int64_t>();
|
||||
// DFT operation supports for negative axes to transform. More precisely, according to
|
||||
// the DFT operation specification, axes should be integers from -(r - 1) to (r - 2)
|
||||
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
|
||||
//'r - 1 + a'. The reason is the following.
|
||||
for (int64_t& axis : axes)
|
||||
{
|
||||
if (axis < 0)
|
||||
{
|
||||
axis += input_rank - 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_type<op::Constant>(input_value(2).get_node()))
|
||||
{
|
||||
for (int64_t axis : axes)
|
||||
{
|
||||
output_shape[axis] = Dimension::dynamic();
|
||||
}
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& const_signal_size = get_constant_from_source(input_value(2));
|
||||
const auto signal_size = const_signal_size->cast_vector<int64_t>();
|
||||
|
||||
size_t num_of_axes = axes.size();
|
||||
for (size_t i = 0; i < num_of_axes; ++i)
|
||||
{
|
||||
if (signal_size[i] == -1)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
output_shape[axes[i]] = Dimension(signal_size[i]);
|
||||
}
|
||||
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
}
|
||||
|
@ -17,10 +17,10 @@
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v7::IDFT, "IDFT", 7);
|
||||
NGRAPH_RTTI_DEFINITION(op::v7::IDFT, "IDFT", 7, util::FFTBase);
|
||||
|
||||
op::v7::IDFT::IDFT(const Output<Node>& data, const Output<Node>& axes)
|
||||
: Op({data, axes})
|
||||
: FFTBase(data, axes)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
@ -28,7 +28,7 @@ op::v7::IDFT::IDFT(const Output<Node>& data, const Output<Node>& axes)
|
||||
op::v7::IDFT::IDFT(const Output<Node>& data,
|
||||
const Output<Node>& axes,
|
||||
const Output<Node>& signal_size)
|
||||
: Op({data, axes, signal_size})
|
||||
: FFTBase(data, axes, signal_size)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
@ -53,202 +53,3 @@ std::shared_ptr<Node> op::v7::IDFT::clone_with_new_inputs(const OutputVector& ne
|
||||
|
||||
return std::make_shared<op::v7::IDFT>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
}
|
||||
|
||||
void op::v7::IDFT::validate()
|
||||
{
|
||||
size_t num_of_inputs = get_input_size();
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, num_of_inputs == 2 || num_of_inputs == 3, "IDFT must have 2 or 3 inputs.");
|
||||
|
||||
element::Type input_et = get_input_element_type(0);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_et == element::f32 || input_et == element::f16 ||
|
||||
input_et == element::bf16,
|
||||
"IDFT input element type must be f32, f16, or bf16");
|
||||
|
||||
element::Type axes_et = get_input_element_type(1);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_et == element::i64 || axes_et == element::i32,
|
||||
"IDFT axes element type must be i32 or i64");
|
||||
|
||||
const auto& input_shape = PartialShape(get_input_partial_shape(0));
|
||||
if (input_shape.rank().is_static())
|
||||
{
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_rank >= 2,
|
||||
"The input rank must be greater or equal to 2. Got input rank: ",
|
||||
input_rank);
|
||||
|
||||
auto last_dim_with_two = input_shape[input_rank - 1] & Dimension(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
!last_dim_with_two.get_interval().empty(),
|
||||
"The last dimension of input data must be 2. Got: ",
|
||||
input_shape[input_rank - 1]);
|
||||
}
|
||||
|
||||
const auto& axes_shape = PartialShape(get_input_partial_shape(1));
|
||||
if (axes_shape.rank().is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_shape.rank().get_length() == 1,
|
||||
"IDFT axes input must be 1D tensor. Got axes input rank: ",
|
||||
axes_shape.rank().get_length());
|
||||
}
|
||||
|
||||
if (input_shape.rank().is_static() && axes_shape.is_static())
|
||||
{
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_rank >= axes_shape.to_shape()[0] + 1,
|
||||
"The input rank must be greater than number of IDFT axes. Got "
|
||||
"input rank: ",
|
||||
input_rank,
|
||||
", number of axes: ",
|
||||
axes_shape.to_shape()[0]);
|
||||
}
|
||||
|
||||
if (input_shape.rank().is_static() && is_type<op::Constant>(input_value(1).get_node()))
|
||||
{
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
const auto& const_axes = get_constant_from_source(input_value(1));
|
||||
auto axes = const_axes->cast_vector<int64_t>();
|
||||
|
||||
// IDFT operation supports for negative axes to transform. More precisely, according to
|
||||
// the IDFT operation specification, axes should be integers from -(r - 1) to (r - 2)
|
||||
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
|
||||
//'r - 1 + a'. The reason is the following.
|
||||
for (int64_t& axis : axes)
|
||||
{
|
||||
if (axis < 0)
|
||||
{
|
||||
axis += input_rank - 1;
|
||||
}
|
||||
}
|
||||
|
||||
AxisVector axes_vector;
|
||||
AxisSet axes_set;
|
||||
for (const int64_t axis : axes)
|
||||
{
|
||||
axes_vector.push_back(static_cast<size_t>(axis));
|
||||
axes_set.insert(static_cast<size_t>(axis));
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, axes.size() == axes_set.size(), "IDFT axes must be unique. Got: ", axes_vector);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
std::find(axes.begin(), axes.end(), input_rank - 1) == axes.end(),
|
||||
"IDFT axes cannot contain the last axis. Got axes: ",
|
||||
axes_vector);
|
||||
}
|
||||
|
||||
if (num_of_inputs == 3)
|
||||
{
|
||||
element::Type signal_size_et = get_input_element_type(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
signal_size_et == element::i64 || signal_size_et == element::i32,
|
||||
"IDFT signal_size element type must be i32 or i64");
|
||||
|
||||
const auto& signal_size_shape = PartialShape(get_input_partial_shape(2));
|
||||
if (signal_size_shape.rank().is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
signal_size_shape.rank().get_length() == 1,
|
||||
"IDFT Signal size input must be 1D tensor. Got signal size "
|
||||
"input rank: ",
|
||||
signal_size_shape.rank().get_length());
|
||||
}
|
||||
|
||||
if (axes_shape.is_static() && signal_size_shape.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_shape.to_shape()[0] == signal_size_shape.to_shape()[0],
|
||||
"Sizes of inputs 'axes' and 'signal_size' must be equal. Got "
|
||||
"size of 'axes': ",
|
||||
axes_shape.to_shape()[0],
|
||||
"size of 'signal_size': ",
|
||||
signal_size_shape.to_shape()[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void op::v7::IDFT::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_IDFT_validate_and_infer_types);
|
||||
validate();
|
||||
|
||||
const auto& input_shape = PartialShape(get_input_partial_shape(0));
|
||||
const auto& axes_shape = PartialShape(get_input_partial_shape(1));
|
||||
PartialShape output_shape = input_shape;
|
||||
if (input_shape.rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
|
||||
if (axes_shape.rank().is_dynamic() || !is_type<op::Constant>(input_value(1).get_node()))
|
||||
{
|
||||
for (size_t i = 0; i < input_rank - 1; ++i)
|
||||
{
|
||||
output_shape[i] = Dimension::dynamic();
|
||||
}
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
if (input_values().size() == 2)
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& signal_size_shape = PartialShape(get_input_partial_shape(2));
|
||||
if (signal_size_shape.rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& const_axes = get_constant_from_source(input_value(1));
|
||||
auto axes = const_axes->cast_vector<int64_t>();
|
||||
// IDFT operation supports for negative axes to transform. More precisely, according to
|
||||
// the IDFT operation specification, axes should be integers from -(r - 1) to (r - 2)
|
||||
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
|
||||
//'r - 1 + a'. The reason is the following.
|
||||
for (int64_t& axis : axes)
|
||||
{
|
||||
if (axis < 0)
|
||||
{
|
||||
axis += input_rank - 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_type<op::Constant>(input_value(2).get_node()))
|
||||
{
|
||||
for (int64_t axis : axes)
|
||||
{
|
||||
output_shape[axis] = Dimension::dynamic();
|
||||
}
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& const_signal_size = get_constant_from_source(input_value(2));
|
||||
const auto signal_size = const_signal_size->cast_vector<int64_t>();
|
||||
|
||||
size_t num_of_axes = axes.size();
|
||||
for (size_t i = 0; i < num_of_axes; ++i)
|
||||
{
|
||||
if (signal_size[i] == -1)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
output_shape[axes[i]] = Dimension(signal_size[i]);
|
||||
}
|
||||
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
}
|
||||
|
234
ngraph/core/src/op/util/fft_base.cpp
Normal file
234
ngraph/core/src/op/util/fft_base.cpp
Normal file
@ -0,0 +1,234 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/op/util/fft_base.hpp"
|
||||
#include <ngraph/validation_util.hpp>
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::util::FFTBase, "FFTBase", 0);
|
||||
|
||||
op::util::FFTBase::FFTBase(const Output<Node>& data, const Output<Node>& axes)
|
||||
: Op({data, axes})
|
||||
{
|
||||
}
|
||||
|
||||
op::util::FFTBase::FFTBase(const Output<Node>& data,
|
||||
const Output<Node>& axes,
|
||||
const Output<Node>& signal_size)
|
||||
: Op({data, axes, signal_size})
|
||||
{
|
||||
}
|
||||
|
||||
bool op::util::FFTBase::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(util_FFTBase_visit_attributes);
|
||||
return true;
|
||||
}
|
||||
|
||||
void op::util::FFTBase::validate()
|
||||
{
|
||||
size_t num_of_inputs = get_input_size();
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, num_of_inputs == 2 || num_of_inputs == 3, "FFT op must have 2 or 3 inputs.");
|
||||
|
||||
element::Type input_et = get_input_element_type(0);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_et == element::f32 || input_et == element::f16 ||
|
||||
input_et == element::bf16,
|
||||
"FFT op input element type must be f32, f16, or bf16");
|
||||
|
||||
element::Type axes_et = get_input_element_type(1);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_et == element::i64 || axes_et == element::i32,
|
||||
"FFT op axes element type must be i32 or i64");
|
||||
|
||||
const auto& input_shape = PartialShape(get_input_partial_shape(0));
|
||||
if (input_shape.rank().is_static())
|
||||
{
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_rank >= 2,
|
||||
"The input rank must be greater or equal to 2. Got input rank: ",
|
||||
input_rank);
|
||||
|
||||
auto last_dim_with_two = input_shape[input_rank - 1] & Dimension(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
!last_dim_with_two.get_interval().empty(),
|
||||
"The last dimension of input data must be 2. Got: ",
|
||||
input_shape[input_rank - 1]);
|
||||
}
|
||||
|
||||
const auto& axes_shape = PartialShape(get_input_partial_shape(1));
|
||||
if (axes_shape.rank().is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_shape.rank().get_length() == 1,
|
||||
"FFT op axes input must be 1D tensor. Got axes input rank: ",
|
||||
axes_shape.rank().get_length());
|
||||
}
|
||||
|
||||
if (input_shape.rank().is_static() && axes_shape.is_static())
|
||||
{
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_rank >= axes_shape.to_shape()[0] + 1,
|
||||
"The input rank must be greater than number of FFT op axes. Got "
|
||||
"input rank: ",
|
||||
input_rank,
|
||||
", number of axes: ",
|
||||
axes_shape.to_shape()[0]);
|
||||
}
|
||||
|
||||
if (input_shape.rank().is_static() && is_type<op::Constant>(input_value(1).get_node()))
|
||||
{
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
const auto& const_axes = get_constant_from_source(input_value(1));
|
||||
auto axes = const_axes->cast_vector<int64_t>();
|
||||
|
||||
// FFT operation supports for negative axes to transform. More precisely, according to
|
||||
// the FFT operation specification, axes should be integers from -(r - 1) to (r - 2)
|
||||
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
|
||||
// 'r - 1 + a'. The reason is the following: real input tensor of the shape
|
||||
// [n_0, ..., n_{r - 1}, 2] is interpreted as a complex tensor with the shape
|
||||
// [n_0, ..., n_{r - 1}].
|
||||
for (int64_t& axis : axes)
|
||||
{
|
||||
if (axis < 0)
|
||||
{
|
||||
axis += input_rank - 1;
|
||||
}
|
||||
}
|
||||
|
||||
AxisVector axes_vector;
|
||||
AxisSet axes_set;
|
||||
for (const int64_t axis : axes)
|
||||
{
|
||||
axes_vector.push_back(static_cast<size_t>(axis));
|
||||
axes_set.insert(static_cast<size_t>(axis));
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, axes.size() == axes_set.size(), "FFT op axes must be unique. Got: ", axes_vector);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
std::find(axes.begin(), axes.end(), input_rank - 1) == axes.end(),
|
||||
"FFT op axes cannot contain the last axis. Got axes: ",
|
||||
axes_vector);
|
||||
}
|
||||
|
||||
if (num_of_inputs == 3)
|
||||
{
|
||||
element::Type signal_size_et = get_input_element_type(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
signal_size_et == element::i64 || signal_size_et == element::i32,
|
||||
"FFT op signal_size element type must be i32 or i64");
|
||||
|
||||
const auto& signal_size_shape = PartialShape(get_input_partial_shape(2));
|
||||
if (signal_size_shape.rank().is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
signal_size_shape.rank().get_length() == 1,
|
||||
"FFT op signal size input must be 1D tensor. Got signal size "
|
||||
"input rank: ",
|
||||
signal_size_shape.rank().get_length());
|
||||
}
|
||||
|
||||
if (axes_shape.is_static() && signal_size_shape.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_shape.to_shape()[0] == signal_size_shape.to_shape()[0],
|
||||
"Sizes of inputs 'axes' and 'signal_size' must be equal. Got "
|
||||
"size of 'axes': ",
|
||||
axes_shape.to_shape()[0],
|
||||
"size of 'signal_size': ",
|
||||
signal_size_shape.to_shape()[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void op::util::FFTBase::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(util_FFTBase_validate_and_infer_types);
|
||||
validate();
|
||||
|
||||
const auto& input_shape = PartialShape(get_input_partial_shape(0));
|
||||
const auto& axes_shape = PartialShape(get_input_partial_shape(1));
|
||||
PartialShape output_shape = input_shape;
|
||||
if (input_shape.rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
|
||||
if (axes_shape.rank().is_dynamic() || !is_type<op::Constant>(input_value(1).get_node()))
|
||||
{
|
||||
for (size_t i = 0; i < input_rank - 1; ++i)
|
||||
{
|
||||
output_shape[i] = Dimension::dynamic();
|
||||
}
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
if (input_values().size() == 2)
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& signal_size_shape = PartialShape(get_input_partial_shape(2));
|
||||
if (signal_size_shape.rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& const_axes = get_constant_from_source(input_value(1));
|
||||
auto axes = const_axes->cast_vector<int64_t>();
|
||||
// FFT operation supports for negative axes to transform. More precisely, according to
|
||||
// the FFT operation specification, axes should be integers from -(r - 1) to (r - 2)
|
||||
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
|
||||
// 'r - 1 + a'. The reason is the following: real input tensor of the shape
|
||||
// [n_0, ..., n_{r - 1}, 2] is interpreted as a complex tensor with the shape
|
||||
// [n_0, ..., n_{r - 1}].
|
||||
for (int64_t& axis : axes)
|
||||
{
|
||||
if (axis < 0)
|
||||
{
|
||||
axis += input_rank - 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_type<op::Constant>(input_value(2).get_node()))
|
||||
{
|
||||
for (int64_t axis : axes)
|
||||
{
|
||||
output_shape[axis] = Dimension::dynamic();
|
||||
}
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& const_signal_size = get_constant_from_source(input_value(2));
|
||||
const auto signal_size = const_signal_size->cast_vector<int64_t>();
|
||||
|
||||
size_t num_of_axes = axes.size();
|
||||
for (size_t i = 0; i < num_of_axes; ++i)
|
||||
{
|
||||
if (signal_size[i] == -1)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
output_shape[axes[i]] = Dimension(signal_size[i]);
|
||||
}
|
||||
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
}
|
@ -323,6 +323,7 @@ set(MULTI_TEST_SRC
|
||||
backend/ctc_greedy_decoder_seq_len.in.cpp
|
||||
backend/cum_sum.in.cpp
|
||||
backend/detection_output.in.cpp
|
||||
backend/dft.in.cpp
|
||||
backend/divide.in.cpp
|
||||
backend/dyn_reshape.in.cpp
|
||||
backend/strided_slice.in.cpp
|
||||
@ -339,6 +340,7 @@ set(MULTI_TEST_SRC
|
||||
backend/group_convolution.in.cpp
|
||||
backend/group_convolution_backprop_data.in.cpp
|
||||
backend/hard_sigmoid.in.cpp
|
||||
backend/idft.in.cpp
|
||||
backend/interpolate.in.cpp
|
||||
backend/log.in.cpp
|
||||
backend/log_softmax.in.cpp
|
||||
|
1416
ngraph/test/backend/dft.in.cpp
Normal file
1416
ngraph/test/backend/dft.in.cpp
Normal file
File diff suppressed because it is too large
Load Diff
1272
ngraph/test/backend/idft.in.cpp
Normal file
1272
ngraph/test/backend/idft.in.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@ -1061,7 +1061,7 @@ rnn_cell_zero_bias_default_attrs
|
||||
# Activation function hardsigmoid is not supported
|
||||
gru_cell_hardsigmoid_activation_function
|
||||
|
||||
# Roll is not implemented yet for CPU, GPU
|
||||
# Roll is not implemented yet for CPU, GPU
|
||||
roll_2d_input
|
||||
roll_2d_input_negative_shift
|
||||
roll_repeated_axes
|
||||
@ -1164,6 +1164,41 @@ IE_CPU.onnx_model_reduce_sum_13_axes_as_0_dim_input
|
||||
# output mismatch
|
||||
IE_CPU.gather_nd_batch_1d_from_3d_negative
|
||||
|
||||
# DFT is not implemented yet in plugins
|
||||
IE_CPU.dft1d_eval
|
||||
IE_CPU.dft2d_eval
|
||||
IE_CPU.dft3d_eval
|
||||
IE_CPU.dft1d_eval_i32
|
||||
IE_CPU.dft2d_eval_i32
|
||||
IE_CPU.dft3d_eval_i32
|
||||
IE_CPU.dft1d_signal_size_eval
|
||||
IE_CPU.dft1d_eval_bfloat16
|
||||
IE_CPU.dft2d_eval_bfloat16
|
||||
IE_CPU.dft3d_eval_bfloat16
|
||||
IE_CPU.dft1d_eval_float16
|
||||
IE_CPU.dft2d_eval_float16
|
||||
IE_CPU.dft3d_eval_float16
|
||||
IE_CPU.dft1d_eval_1
|
||||
IE_CPU.dft2d_eval_1
|
||||
IE_CPU.dft3d_eval_1
|
||||
|
||||
# IDFT is not implemented yet in plugins
|
||||
IE_CPU.idft1d_eval
|
||||
IE_CPU.idft2d_eval
|
||||
IE_CPU.idft3d_eval
|
||||
IE_CPU.idft1d_eval_i32
|
||||
IE_CPU.idft2d_eval_i32
|
||||
IE_CPU.idft3d_eval_i32
|
||||
IE_CPU.idft1d_eval_bfloat16
|
||||
IE_CPU.idft2d_eval_bfloat16
|
||||
IE_CPU.idft3d_eval_bfloat16
|
||||
IE_CPU.idft1d_eval_float16
|
||||
IE_CPU.idft2d_eval_float16
|
||||
IE_CPU.idft3d_eval_float16
|
||||
IE_CPU.idft1d_eval_1
|
||||
IE_CPU.idft2d_eval_1
|
||||
IE_CPU.idft3d_eval_1
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
#
|
||||
# Inference Engine GPU plugin excludes
|
||||
|
@ -27,6 +27,7 @@
|
||||
#include <ngraph/runtime/reference/embedding_segments_sum.hpp>
|
||||
#include <ngraph/runtime/reference/extract_image_patches.hpp>
|
||||
#include <ngraph/runtime/reference/fake_quantize.hpp>
|
||||
#include <ngraph/runtime/reference/fft.hpp>
|
||||
#include <ngraph/runtime/reference/gather_elements.hpp>
|
||||
#include <ngraph/runtime/reference/gather_nd.hpp>
|
||||
#include <ngraph/runtime/reference/gather_tree.hpp>
|
||||
@ -878,6 +879,106 @@ namespace
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace fft_v7
|
||||
{
|
||||
struct InfoForFFT7
|
||||
{
|
||||
std::vector<float> input_data;
|
||||
std::vector<int64_t> axes_data;
|
||||
Shape input_data_shape;
|
||||
Shape axes_data_shape;
|
||||
Shape output_shape;
|
||||
};
|
||||
|
||||
std::vector<int64_t> get_signal_size(
|
||||
const std::vector<std::shared_ptr<HostTensor>>& inputs, size_t num_of_axes)
|
||||
{
|
||||
if (inputs.size() == 3)
|
||||
{
|
||||
return nms_v5::get_integers(inputs[2], inputs[2]->get_shape());
|
||||
}
|
||||
|
||||
return std::vector<int64_t>(num_of_axes, static_cast<int64_t>(-1));
|
||||
}
|
||||
|
||||
InfoForFFT7 get_info_for_fft7_eval(const std::vector<std::shared_ptr<HostTensor>>& inputs)
|
||||
{
|
||||
InfoForFFT7 result;
|
||||
|
||||
result.input_data_shape = inputs[0]->get_shape();
|
||||
result.axes_data_shape = inputs[1]->get_shape();
|
||||
result.input_data = nms_v5::get_floats(inputs[0], result.input_data_shape);
|
||||
result.axes_data = nms_v5::get_integers(inputs[1], result.axes_data_shape);
|
||||
|
||||
auto output_shape = result.input_data_shape;
|
||||
|
||||
int64_t input_rank = static_cast<int64_t>(result.input_data_shape.size());
|
||||
int64_t complex_data_rank = input_rank - 1;
|
||||
auto canonicalized_axes = runtime::reference::canonicalize_axes(result.axes_data.data(),
|
||||
result.axes_data_shape,
|
||||
complex_data_rank);
|
||||
|
||||
size_t num_of_axes = result.axes_data.size();
|
||||
auto signal_size = get_signal_size(inputs, num_of_axes);
|
||||
|
||||
for (size_t i = 0; i < num_of_axes; ++i)
|
||||
{
|
||||
int64_t current_axis = canonicalized_axes[i];
|
||||
int64_t current_signal_size = signal_size[i];
|
||||
if (current_signal_size != -1)
|
||||
{
|
||||
output_shape[current_axis] = current_signal_size;
|
||||
}
|
||||
}
|
||||
|
||||
result.output_shape = output_shape;
|
||||
|
||||
return result;
|
||||
}
|
||||
} // namespace fft_v7
|
||||
|
||||
template <element::Type_t ET>
|
||||
bool evaluate(const shared_ptr<op::v7::DFT>& op,
|
||||
const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs)
|
||||
{
|
||||
auto info = fft_v7::get_info_for_fft7_eval(inputs);
|
||||
|
||||
std::vector<float> fft_result(shape_size(info.output_shape), 0.0f);
|
||||
runtime::reference::fft(info.input_data.data(),
|
||||
info.input_data_shape,
|
||||
info.axes_data.data(),
|
||||
info.axes_data_shape,
|
||||
fft_result.data(),
|
||||
info.output_shape,
|
||||
runtime::reference::FFTKind::Forward);
|
||||
|
||||
const auto output_type = op->get_input_element_type(0);
|
||||
runtime::reference::fft_postprocessing(outputs, output_type, fft_result);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <element::Type_t ET>
|
||||
bool evaluate(const shared_ptr<op::v7::IDFT>& op,
|
||||
const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs)
|
||||
{
|
||||
auto info = fft_v7::get_info_for_fft7_eval(inputs);
|
||||
|
||||
std::vector<float> fft_result(shape_size(info.output_shape), 0.0f);
|
||||
runtime::reference::fft(info.input_data.data(),
|
||||
info.input_data_shape,
|
||||
info.axes_data.data(),
|
||||
info.axes_data_shape,
|
||||
fft_result.data(),
|
||||
info.output_shape,
|
||||
runtime::reference::FFTKind::Inverse);
|
||||
|
||||
const auto output_type = op->get_input_element_type(0);
|
||||
runtime::reference::fft_postprocessing(outputs, output_type, fft_result);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <element::Type_t ET>
|
||||
bool evaluate(const shared_ptr<op::v0::LRN>& op,
|
||||
const HostTensorVector& outputs,
|
||||
|
@ -86,4 +86,6 @@ NGRAPH_OP(CTCGreedyDecoderSeqLen, op::v6)
|
||||
NGRAPH_OP(GatherElements, op::v6)
|
||||
NGRAPH_OP(MVN, ngraph::op::v6)
|
||||
|
||||
NGRAPH_OP(DFT, op::v7)
|
||||
NGRAPH_OP(IDFT, op::v7)
|
||||
NGRAPH_OP(Roll, ngraph::op::v7)
|
||||
|
Loading…
Reference in New Issue
Block a user