nGraph reference for the operation RDFT. (#11175)

* Written nGraph reference for the operation RDFT.

* Used std::reverse() algorithm to simplify the function reverse_shape() from fft_common.cpp.

* Added assert into the function offset_from_coords_and_strides().

* Deleted redundant variable.

* Deleted redundant functions from the reference implementation of (I)DFT.

* Renamed the method reverse_shape() in fft_common.hpp.

* Code style fix.
This commit is contained in:
Vladimir Gavrilov 2022-03-30 09:38:05 +03:00 committed by GitHub
parent 1386f52dd6
commit e7b35c3b00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 1025 additions and 71 deletions

View File

@ -64,6 +64,7 @@
#include <ngraph/runtime/reference/prior_box.hpp>
#include <ngraph/runtime/reference/proposal.hpp>
#include <ngraph/runtime/reference/psroi_pooling.hpp>
#include <ngraph/runtime/reference/rdft.hpp>
#include <ngraph/runtime/reference/region_yolo.hpp>
#include <ngraph/runtime/reference/reorg_yolo.hpp>
#include <ngraph/runtime/reference/reverse_sequence.hpp>
@ -2105,6 +2106,73 @@ bool evaluate(const shared_ptr<op::v7::IDFT>& op, const HostTensorVector& output
return true;
}
namespace rfft_v9 {
struct InfoForRFFT9 {
std::vector<float> input_data;
std::vector<int64_t> axes_data;
Shape input_data_shape;
Shape axes_data_shape;
Shape fft_output_shape;
Shape output_shape;
};
InfoForRFFT9 get_info_for_rfft9_eval(const std::vector<std::shared_ptr<HostTensor>>& inputs) {
InfoForRFFT9 result;
result.input_data_shape = inputs[0]->get_shape();
result.axes_data_shape = inputs[1]->get_shape();
result.input_data = get_floats(inputs[0], result.input_data_shape);
result.axes_data = get_integers(inputs[1], result.axes_data_shape);
auto fft_output_shape = result.input_data_shape;
auto output_shape = result.input_data_shape;
int64_t input_rank = static_cast<int64_t>(result.input_data_shape.size());
auto canonicalized_axes =
runtime::reference::canonicalize_axes(result.axes_data.data(), result.axes_data_shape, input_rank);
size_t num_of_axes = result.axes_data.size();
auto signal_size = fft_v7::get_signal_size(inputs, num_of_axes);
const auto last_axis = canonicalized_axes.back();
for (size_t i = 0; i < num_of_axes; ++i) {
int64_t current_axis = canonicalized_axes[i];
int64_t current_signal_size = signal_size[i];
if (current_signal_size != -1) {
fft_output_shape[current_axis] = current_signal_size;
output_shape[current_axis] = current_signal_size;
}
}
output_shape[last_axis] = fft_output_shape[last_axis] / 2 + 1;
output_shape.push_back(2);
fft_output_shape.push_back(2);
result.fft_output_shape = fft_output_shape;
result.output_shape = output_shape;
result.axes_data = canonicalized_axes;
return result;
}
} // namespace rfft_v9
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v9::RDFT>& op, const HostTensorVector& outputs, const HostTensorVector& inputs) {
auto info = rfft_v9::get_info_for_rfft9_eval(inputs);
outputs[0]->set_shape(info.output_shape);
std::vector<float> rfft_result(shape_size(info.output_shape), 0.0f);
runtime::reference::rdft(info.input_data,
info.input_data_shape,
info.axes_data,
info.fft_output_shape,
rfft_result.data());
const auto output_type = op->get_input_element_type(0);
runtime::reference::fft_postprocessing(outputs, output_type, rfft_result);
return true;
}
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v0::LRN>& op, const HostTensorVector& outputs, const HostTensorVector& inputs) {
using T = typename element_type_traits<ET>::value_type;

View File

@ -130,3 +130,5 @@ NGRAPH_OP(Exp, op::v0)
NGRAPH_OP(Log, op::v0)
NGRAPH_OP(PriorBox, ngraph::op::v8)
NGRAPH_OP(PRelu, op::v0)
NGRAPH_OP(RDFT, op::v9)

View File

@ -0,0 +1,704 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "base_reference_test.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/rdft.hpp"
using namespace reference_tests;
using namespace ov;
namespace {
struct RDFTParams {
template <class T>
RDFTParams(const Shape& input_shape,
const Shape& expected_shape,
const element::Type_t& input_type,
const element::Type_t& expected_type,
const std::vector<T>& input_value,
const std::vector<T>& expected_value,
const std::shared_ptr<op::v0::Constant>& axes,
const std::shared_ptr<op::v0::Constant>& signal) {
m_input_shape = input_shape;
m_expected_shape = expected_shape;
m_input_type = input_type;
m_expected_type = expected_type;
m_input_value = CreateTensor(input_type, input_value);
m_expected_value = CreateTensor(expected_type, expected_value);
m_axes = axes;
m_signal = signal;
}
Shape m_input_shape;
Shape m_expected_shape;
element::Type_t m_input_type;
element::Type_t m_expected_type;
ov::Tensor m_input_value;
ov::Tensor m_expected_value;
std::shared_ptr<op::v0::Constant> m_axes;
std::shared_ptr<op::v0::Constant> m_signal;
};
class ReferenceRDFTLayerTest : public testing::TestWithParam<RDFTParams>, public CommonReferenceTest {
public:
void SetUp() override {
auto params = GetParam();
if (params.m_signal != NULL) {
function = CreateFunctionWithSignal(params);
} else {
function = CreateFunction(params);
}
inputData = {params.m_input_value};
refOutData = {params.m_expected_value};
}
static std::string getTestCaseName(const testing::TestParamInfo<RDFTParams>& obj) {
const auto param = obj.param;
std::ostringstream result;
result << "input_shape1=" << param.m_input_shape << "; ";
result << "output_shape=" << param.m_expected_shape << "; ";
result << "input_type1=" << param.m_input_type << "; ";
result << "output_type=" << param.m_expected_type << "; ";
result << "transpose1=" << param.m_axes;
return result.str();
}
private:
static std::shared_ptr<Model> CreateFunction(RDFTParams& p) {
auto in = std::make_shared<op::v0::Parameter>(p.m_input_type, p.m_input_shape);
auto rdft = std::make_shared<op::v9::RDFT>(in, p.m_axes);
return std::make_shared<ov::Model>(rdft, ParameterVector{in});
}
static std::shared_ptr<Model> CreateFunctionWithSignal(RDFTParams& p) {
auto in = std::make_shared<op::v0::Parameter>(p.m_input_type, p.m_input_shape);
auto rdft = std::make_shared<op::v9::RDFT>(in, p.m_axes, p.m_signal);
return std::make_shared<ov::Model>(rdft, ParameterVector{in});
}
};
TEST_P(ReferenceRDFTLayerTest, CompareWithHardcodedRefs) {
Exec();
}
static const std::vector<float> input_data = {
0.10606491, 0.7454715, 0.57231355, 0.4582412, 0.3847059, 0.27398932, 0.66796243, 0.395475,
0.2815729, 0.7799197, 0.59909415, 0.12294636, 0.38957402, 0.97498834, 0.46759892, 0.14017141,
0.04206858, 0.7279963, 0.61560553, 0.9027321, 0.6226334, 0.2601217, 0.5555177, 0.40498647,
0.14175586, 0.57774633, 0.52652127, 0.9385691, 0.9588788, 0.9844318, 0.23095612, 0.09707925,
0.24574867, 0.6907577, 0.1974319, 0.8295272, 0.34612727, 0.51401484, 0.66115797, 0.9336245,
0.06690067, 0.7468897, 0.39028263, 0.53575844, 0.060429193, 0.8913558, 0.77787375, 0.6701197,
0.7350527, 0.6636995, 0.18176624, 0.8629976, 0.45142895, 0.6497297, 0.159372, 0.40598175,
0.7988516, 0.7291543, 0.07090418, 0.7697132, 0.4972157, 0.7669217, 0.67975855, 0.13026066,
0.6587437, 0.24532892, 0.24545169, 0.83795583, 0.105490535, 0.7264323, 0.94568557, 0.7216649,
0.14389831, 0.7930531, 0.70895344, 0.9724701, 0.9775157, 0.49999878, 0.65569246, 0.26876843,
0.63248956, 0.85201293, 0.5689624, 0.023386303, 0.5546464, 0.36860028, 0.9603114, 0.39123482,
0.0380728, 0.89212376, 0.14387614, 0.63858676, 0.10003748, 0.8906635, 0.06681054, 0.7458642,
0.45452347, 0.54724604, 0.6496482, 0.7818356, 0.6608355, 0.77711326, 0.24588613, 0.013456763,
0.355845, 0.80388206, 0.027993264, 0.73677206, 0.52755004, 0.9052324, 0.54311025, 0.5367805,
0.4131242, 0.7752338, 0.109669454, 0.13664648, 0.7828739, 0.9083969, 0.5247593, 0.7493595,
0.19275227, 0.007190853, 0.6087981, 0.344136, 0.46909887, 0.41924855, 0.7072913, 0.19932869,
0.5303847, 0.651384, 0.06686331, 0.9717932, 0.65702224, 0.11786682, 0.3154073, 0.88923013,
0.5564087, 0.91047823, 0.28466642, 0.0934668, 0.88953066, 0.9919338, 0.18322521, 0.8185455,
0.566391, 0.014207997, 0.29673064, 0.6347744, 0.6801958, 0.39601147, 0.34374171, 0.7216888,
0.6152569, 0.76679546, 0.5860851, 0.4276813, 0.79339284, 0.13130653, 0.68764234, 0.053128112,
0.02611321, 0.2982243, 0.7618372, 0.3331729, 0.5468192, 0.15707079, 0.28592056, 0.15286565,
0.9368963, 0.350671, 0.4336494, 0.08934934, 0.41172776, 0.5850259, 0.70730376, 0.8598349,
0.088788144, 0.26711187, 0.8002491, 0.19422275, 0.8312039, 0.5198718, 0.40111357, 0.98375803,
0.77703434, 0.037818834, 0.704231, 0.689808, 0.17102319, 0.42153922, 0.7278252, 0.8030207,
0.9101717, 0.0199644, 0.13768466, 0.55669, 0.17991355, 0.6720098, 0.7733328, 0.20881335};
static const std::vector<float> expected_rdft1d_results_1 = {
4.6657147, -1.1622906e-06, 0.21456887, -0.14946258, -0.20476034, -0.37063062,
-0.31414136, 0.5099413, -1.1779613, 0.07057127, -0.64047664, -1.0058284e-07,
4.982774, -1.1771917e-06, 0.6607505, 0.18829148, -0.9772357, 1.4243596,
0.8640026, 0.34923682, 0.33401352, 0.25859502, -0.7548928, 8.940697e-08,
5.9711604, -1.4901161e-06, 0.5638976, 1.5429841, -0.52065414, 0.24638398,
-0.27140495, 0.5040715, 0.5360231, 0.3234269, -0.36054826, 1.7508864e-07,
4.7464237, -1.2218952e-06, -0.29650804, 0.80609477, -0.161426, 1.0022418,
-0.50812817, 0.7967348, 0.4394225, -0.1588624, -1.3835809, -7.4505806e-08,
5.53836, -1.7136335e-06, -0.38635445, 0.8284859, -0.23278837, -0.63777345,
-0.93614054, 0.3215857, -0.14075133, -0.67071164, -1.4772836, 2.0861626e-07,
5.0798974, -1.5944242e-06, 0.056767445, 0.03468219, -0.1497254, -0.9672509,
0.2603209, 0.69644475, -0.9208536, 0.006730467, -1.7552528, 2.682209e-07,
4.893558, -1.6242266e-06, 0.6719861, -0.13982919, 0.064845346, -0.39896214,
0.21785057, -0.5099982, -0.65526295, 1.4383471, -0.52023906, 2.5331974e-07,
6.687699, -1.5497208e-06, -0.7423769, 0.09968524, 1.052381, -0.21306956,
0.5875206, -0.3038844, 0.3991575, -1.1895186, 0.17579001, 3.874302e-07,
5.2818384, -1.1026859e-06, 0.5087582, 0.106959194, 1.1816688, -0.87592727,
0.03740315, 0.5197907, -1.3198637, 0.6398836, 0.22712436, 2.2351742e-08,
5.0190897, -1.5646219e-06, -0.087282926, 0.50819266, -0.28002462, 0.29240948,
-0.32303664, 0.38377762, -0.0051696897, -0.99301195, -2.189299, 2.0861626e-07,
5.0545654, -1.5795231e-06, 0.9146397, 0.83839166, 0.870533, 0.17405808,
-0.56308234, -0.7806684, 0.26397777, 0.6880482, -1.4183462, 2.682209e-07,
5.479953, -1.2665987e-06, 0.49444157, 0.7534672, -0.76784146, -0.4507342,
0.88815784, 0.6985409, -0.2727425, -0.25027415, -0.7328796, 2.682209e-07,
4.1296124, -5.662441e-07, -0.46133032, 0.30635798, -0.18225375, 0.42515472,
-0.5484285, 0.9704039, -0.35255045, 0.17549685, 0.8870368, -3.1292439e-07,
4.8632016, -1.8924475e-06, -0.6926452, 0.025076404, -0.039108217, -1.7492937,
-0.8120377, -0.85315156, -0.0022608787, 0.45002514, -1.1024668, 3.501773e-07,
5.4715447, -1.4901161e-06, 1.1176248, -0.2109062, -0.27492502, 0.08983741,
1.1903813, -1.007312, -0.20150042, -0.83919466, -0.23939973, 4.917383e-07,
5.1267176, -9.983778e-07, -0.44803134, -0.8066604, -0.3435102, -0.41692197,
-0.22457689, -0.1076939, -0.29129186, -1.1880502, 0.9255183, -1.6391277e-07,
3.8495903, -5.5134296e-07, 0.09505272, -0.12751618, -1.1264827, 0.5068884,
-1.055237, -0.19516481, -0.34035242, -0.15379356, 1.2655814, -2.6077032e-07,
4.4372616, -9.23872e-07, -0.72962606, -0.23475963, -0.04278487, 1.1032158,
-0.558924, -0.5300043, 1.0578637, -0.2466627, 0.44617313, -7.8231096e-08,
5.5374002, -1.4156103e-06, 0.016273111, -0.5989829, -0.19913958, 0.013256833,
1.8512837, 0.14526272, -0.39700353, -0.07573915, 0.23181, 2.9429793e-07,
4.989425, -1.4901161e-06, 1.0391837, 0.16554561, -0.22647032, -1.0689808,
-0.84556, -0.82779336, 0.9430445, 0.37618563, 0.4684292, -9.685755e-08};
static const std::vector<float> expected_rdft1d_results_2 = {
2.266797, -8.195639e-08, -0.37842733, -0.41015846, -0.48980892, -0.10356337,
2.5542018, -2.2351742e-08, -0.3223713, 0.671882, 0.54300576, -0.35418037,
1.985015, -2.2351742e-08, -0.030243821, -0.20105253, 0.59431964, 0.07358998,
1.4619737, -7.450581e-09, -0.4356845, 0.35701087, 0.28208786, -0.36424285,
1.8002605, -1.1920929e-07, -0.43280697, -0.56735414, -0.30007166, -0.541847,
2.3052943, -1.2293458e-07, -0.39316025, -0.5526293, -0.30507135, -0.6021758,
2.7329001, -6.7055225e-08, 0.28245124, -0.42586988, -0.40586215, 0.4590181,
3.3132548, -5.9604645e-08, 0.6297612, 0.3694744, 0.077824846, -0.6248544,
2.6314974, -2.9802322e-08, 0.58795106, -0.60349375, -0.3224758, 0.34408605,
1.8399743, -9.685755e-08, -0.43963802, -0.079073176, -0.120658875, -1.0880115,
2.0531366, -4.4703484e-08, 0.80112594, -0.53726834, -0.17560546, -0.026561722,
2.3779182, -9.685755e-08, -0.21852754, -0.19336401, 0.38734403, -0.5954362,
1.6219761, 7.450581e-09, -0.43100592, 0.28373614, 0.101898566, 0.52321124,
2.128953, -1.4901161e-07, -0.1622684, -0.94116735, -0.7350497, 0.12695336,
3.449626, -8.940697e-08, 0.56062996, -0.031283244, -0.06161648, -0.8543532,
3.033568, -8.195639e-08, -0.37023768, -0.03989461, -0.28719214, -0.22382751,
1.9661667, -1.4901161e-08, -0.59863573, -0.015534669, -0.31916466, 0.55380434,
2.227056, -5.2154064e-08, -0.12656188, 0.6895717, 0.097157195, 0.19840825,
3.5129817, -2.1234155e-07, 0.11158541, 0.5870459, 0.20993343, -0.40297145,
2.5986667, 0.0, 0.26602313, -1.1560227, 0.2542065, 0.45556274};
static const std::vector<float> expected_rdft1d_results_3 = {
4.665715, -1.6093254e-06, -0.5430559, -0.5752678, -0.37596112, -1.1571281,
-0.46793216, -0.94566363, 0.6854232, -0.3444838, -0.674704, 0.5946392,
-0.64047587, 1.3560057e-06, 4.9827743, -1.7434359e-06, -0.43517, -0.049020194,
-1.4773891, -1.0811031, 1.2506557, 0.5371344, 1.2869358, -0.14998645,
0.8555907, 0.3693859, -0.7548918, 1.5944242e-06, 5.971161, -1.5199184e-06,
-1.2643411, 0.85635287, -0.1801207, -1.7264944, 0.6412285, -0.4787441,
0.82227707, 0.65098876, 0.9114491, 0.40323836, -0.36054718, 1.2852252e-06,
4.7464237, -1.66893e-06, -1.5010594, 0.2253451, -0.87915635, -0.4252541,
0.4976693, -0.6554581, 0.928985, 0.8035921, 0.6578763, -0.15220329,
-1.3835799, 1.0430813e-06, 5.5383606, -1.4901161e-06, -1.619024, -0.10987502,
0.20661727, -1.3774645, -0.3057741, -1.0960662, 0.2971667, 0.46700704,
-0.20812088, -0.602368, -1.4772825, 9.3877316e-07, 5.0798974, -1.758337e-06,
-0.7421876, -0.61749315, 0.21938956, -1.3415859, -0.838238, -0.6598083,
1.0601404, -0.7129184, -0.27083004, 0.31763482, -1.7552516, 1.4677644e-06,
4.893558, -1.4975667e-06, -0.06445231, -0.55879503, 0.08908144, -1.2869594,
0.33623943, -0.7704663, -0.047739983, -1.0678453, 0.48350462, 1.5768427,
-0.52023804, 1.1697412e-06, 6.687699, -1.3113022e-06, -1.292419, -1.2920969,
1.2041754, -0.2943018, 1.1889167, -0.66985166, 1.1336832, -0.13731277,
0.008011267, -0.9506076, 0.1757915, 1.1026859e-06, 5.2818394, -1.4305115e-06,
-0.25987166, -0.48605326, 0.90237427, -0.8028362, -0.3040653, -1.6981151,
1.1215456, -0.7120959, -0.4195284, 1.3941492, 0.22712523, 8.046627e-07,
5.01909, -1.7881393e-06, -1.1856917, -0.10931289, -0.5164983, -0.9724103,
0.30577338, -0.72837675, 0.89680094, 0.21036407, -0.052024096, -0.9455472,
-2.1892984, 1.4305115e-06, 5.054565, -1.5050173e-06, -0.3471575, 0.40542153,
0.36438322, -0.9765247, 1.2703501, -1.7359983, -0.1160066, -0.25323528,
0.9753329, 0.5339062, -1.418345, 9.834766e-07, 5.4799523, -1.7285347e-06,
-0.7905842, 0.093313254, 0.068526804, -1.8504739, -0.01845923, 0.26084417,
1.5358877, -0.4159652, 0.089752786, 0.089908056, -0.7328786, 1.4007092e-06,
4.129612, -9.536743e-07, -1.2393575, -0.28046644, -0.58673245, -0.39608067,
-0.12385368, -0.53435826, 0.77853805, 0.7645384, -0.18040559, 0.6678516,
0.88703763, 8.046627e-07, 4.8632016, -1.0430813e-06, -1.1780663, -1.0952923,
1.1691413, -1.4023741, -0.546494, -0.92614484, -1.1796933, -0.31762218,
0.25592417, 0.0959474, -1.1024656, 1.013279e-06, 5.471545, -1.6987324e-06,
0.35812324, -0.66833705, 0.07725692, -1.6537004, 1.6561611, 0.051166296,
0.865453, -1.1392289, -0.23588535, -0.5480979, -0.2393986, 1.3411045e-06,
5.126718, -9.23872e-07, -0.6379836, -1.6675751, 0.013057679, -0.9891113,
0.20881936, -0.30439606, 0.37222707, 0.25244698, -0.9197892, -0.77782196,
0.9255192, 1.1101365e-06, 3.8495903, -7.4505806e-07, -0.63088936, -0.4556699,
-1.1905057, -1.2522144, 0.46207082, -0.31992733, -0.4309795, 0.74295896,
-0.6106033, 0.18823686, 1.2655822, 7.748604e-07, 4.4372616, -7.0780516e-07,
-1.1016369, -1.0079124, -0.6083025, -0.0011255145, 1.4406854, -0.2912693,
-0.26610214, 0.87299407, 0.69553405, -0.45576566, 0.44617438, 7.4505806e-07,
5.5374007, -1.5944242e-06, -0.32642078, -1.3683549, 0.079301864, -0.83741367,
0.67391664, 0.69433576, 1.6423957, -1.1923066, 0.0334223, 0.37603495,
0.23181117, 1.4156103e-06, 4.9894247, -7.748604e-07, 0.1788401, -0.39274544,
0.78422666, -2.1340246, 0.5487572, -0.8765497, -0.7899384, 0.5434137,
0.91613716, 0.08274247, 0.46843058, 8.34465e-07
};
const std::vector<float> expected_rdft2d_results = {
52.8665, -2.9623508e-05, 1.1642078, 3.826082, -0.22771922, -0.49822173,
-0.3857528, 3.2676966, -2.5112464, -0.27454787, -8.678656, 3.7550926e-06,
-0.818072, 0.8330209, 3.4618711, -0.2419473, 1.7408192, 5.744002,
1.8477443, 2.039329, 0.3268112, -2.7421296, 0.6809025, 1.7613728,
-2.294264, -0.8984407, -0.2868184, -3.2426705, -0.801461, -0.58971727,
-1.463435, -2.5413132, 0.116907075, -0.5013529, -2.8377397, -2.8455539,
-0.13475686, -1.3145845, -2.2820292, -0.199, -0.056986623, 0.12560216,
-0.589707, -1.7577857, -0.5274223, -1.0395792, 0.53813136, -1.7159984,
0.22503978, 2.902198, -1.8643543, -1.8789856, 2.1722724, -2.068454,
0.59446484, 0.6067899, 1.5525781, 1.7612485, 1.1877432, -0.48152098,
-0.16525066, 1.5497208e-06, 1.9815066, 0.55218977, 0.80434155, -3.575598,
-2.1471107, -0.57691807, -3.004384, 3.8775828, 3.1358109, -6.2584877e-07,
0.22504184, -2.9021916, 1.0378464, 0.9877456, 0.38395065, -1.6089694,
-0.5107449, 1.8621777, -4.960479, -1.8983803, 1.187743, 0.48151842,
-0.1347583, 1.3145843, -0.9968031, -1.3782079, 0.9922035, 1.6614089,
-0.83039653, -0.043888614, 1.9431384, -1.6448143, 0.5381324, 1.7159982,
-2.2942696, 0.8984335, 1.3057998, -0.26607463, -3.2994738, -1.9240448,
1.4963659, 2.8365738, -4.691832, 1.2995429, -2.8377357, 2.8455553,
-0.8180722, -0.8330165, -1.3755352, 0.34623986, -3.7555497, -0.9723124,
-1.1528367, -0.593254, -0.023679793, 1.8681414, 0.6809023, -1.7613728,
48.939255, -2.4735928e-05, 1.3455832, 0.11001387, -2.3319814, -1.3735183,
-0.6780232, -2.4875786, 0.40718403, -1.0639579, 0.7314569, -1.2665987e-07,
0.97006464, -0.30789328, 3.3290033, 2.7749023, -0.7520597, -0.98800826,
1.3100916, 1.1514524, 1.1085359, 4.348257, -2.839456, 2.4404035,
0.9518837, 2.1538901, 3.8438358, 2.410589, 3.0649068, 0.95690995,
2.2213395, 0.66509914, -0.4409917, -0.37408838, -0.6316552, -1.5842111,
-0.72352415, -2.5862057, 0.2678757, 0.610149, 2.9564474, 0.08470708,
-2.0889034, -8.370071, -0.16373271, 2.0413866, -3.3811545, 2.0487003,
0.0316903, -1.078939, -2.5515578, -0.16135174, -0.17406325, 1.2709827,
-0.67006403, -1.6342779, 0.42163712, 2.1418998, -0.96614444, 1.9175051,
-0.8538456, 2.8014183e-06, 2.0189362, 0.30467552, 0.5074463, 3.7919073,
2.427857, 0.7526233, -2.4620402, 0.65359443, 0.7219074, -2.3841858e-07,
0.03169757, 1.0789458, -2.1129081, -1.0250417, 4.8181386, -0.39162922,
-1.2349386, 1.8470186, -0.49495277, -1.5516026, -0.96614635, -1.9175065,
-0.7235237, 2.5862021, 0.677946, 2.0370173, -0.29536027, 0.6505451,
-2.8572361, 2.3176546, 3.4459226, 1.1869265, -3.3811545, -2.048697,
0.95187366, -2.1538982, 1.808088, -1.1755496, -2.7418838, -1.6770658,
-3.5766084, -2.8320727, -0.02944839, -1.6522555, -0.63165283, 1.5842092,
0.9700667, 0.30789307, 0.5195943, 2.4985125, 3.6537378, -0.5842519,
-0.4843334, 0.78346854, 0.84766304, 1.1503224, -2.839459, -2.440402};
const std::vector<float> expected_rdft2d_results_2 = {
25.904434, -8.46386e-06, -5.3626504, 0.3475349, -2.7060094, -5.767444,
1.615847, -2.6387978, 4.020789, 1.4271183, 1.5420923, 0.6126925,
-4.6167765, 5.5730343e-06, -0.753784, -0.19148755, 1.4881928, -2.7645326,
-0.39467168, 1.014636, 0.5598, -1.7654291, -0.91835654, -2.3019042,
-0.49356225, -0.8411435, 0.080773115, -1.2883577, -0.5341466, 1.4913602,
-0.30008763, -0.5831754, 1.7365295, 1.821624, -0.08851206, -1.622279,
-0.27249795, -0.834725, -0.6706438, 0.4766277, 0.62642634, 0.5483514,
-0.5341469, -1.4913592, 0.8286207, 0.35826343, -1.0869694, -1.4876881,
-1.6723244, -0.06565219, 0.16255295, 0.5317876, -0.75649667, 1.2447717,
0.6264261, -0.5483517, -0.7537827, 0.19148779, 0.6306459, -0.23442982,
0.57131517, -1.366768, -2.7544713, 1.3638397, 0.43463084, -0.5446956,
-2.9949086, 1.4802479, 0.080771565, 1.2883584, 24.998875, -7.390976e-06,
-3.1970425, -1.5453612, 1.0925753, -6.279154, 2.237704, -2.8844912,
1.8841789, -1.3615136, 0.90471864, 0.8395144, -2.6060505, 4.976988e-06,
1.1634235, 0.42319643, 2.678257, 2.4692535, 0.34259582, 0.43598562,
2.748452, 0.88622695, 2.2745323, -2.8840196, 1.8120161, -0.27884078,
-1.5445104, -0.7000726, -1.0264511, -0.7026249, -1.071573, 1.062395,
-0.64628685, -0.36214483, -0.5110928, -1.0534683, -2.786768, 2.6113648,
0.94799054, 0.53423727, -0.69832724, 2.1821892, -1.0264513, 0.70262754,
-0.41705567, -0.17140968, 1.4991179, 2.9674625, -0.012362838, -3.8260121,
-1.5786235, -0.32526863, 1.2857957, 1.7469958, -0.6983267, -2.1821907,
1.1634252, -0.42319855, 0.2716269, 0.21222934, -0.46608746, -1.6447732,
1.8890494, -1.8022469, -0.37335354, 0.69326025, -0.07385725, -0.1723765,
-1.5445105, 0.7000739};
const std::vector<float> expected_rdft3d_results = {
101.805756, -5.2273273e-05, 2.5097876, 3.936094, -2.5597036, -1.8717405,
-1.0637736, 0.7801182, -2.1040666, -1.3385094, -7.9471993, 2.026558e-06,
0.15199316, 0.52512753, 6.7908745, 2.5329556, 0.98875976, 4.755993,
3.157838, 3.190782, 1.4353466, 1.6061276, -2.158554, 4.201776,
-1.3423799, 1.2554499, 3.5570183, -0.8320818, 2.263445, 0.36719292,
0.7579028, -1.8762131, -0.32408538, -0.87544185, -3.4693956, -4.429764,
-0.85828185, -3.9007902, -2.0141544, 0.4111499, 2.8994608, 0.21030927,
-2.6786098, -10.127857, -0.6911557, 1.0018079, -2.8430226, 0.33270124,
0.25672907, 1.8232578, -4.4159126, -2.040338, 1.9982092, -0.7974717,
-0.07559925, -1.0274884, 1.9742157, 3.9031482, 0.22159882, 1.4359848,
-1.0190966, 3.2186508e-06, 4.0004425, 0.8568655, 1.3117876, 0.2163087,
0.28074512, 0.17570588, -5.466423, 4.531178, 3.857718, -1.2516975e-06,
0.2567385, -1.823246, -1.0750613, -0.037295938, 5.20209, -2.0005994,
-1.7456844, 3.7091968, -5.45543, -3.4499822, 0.22159535, -1.4359887,
-0.8582816, 3.9007854, -0.31885874, 0.65880924, 0.6968423, 2.3119528,
-3.6876333, 2.273767, 5.38906, -0.45788872, -2.8430223, -0.33269957,
-1.3423961, -1.2554631, 3.1138885, -1.4416232, -6.0413575, -3.6011095,
-2.080242, 0.0045015216, -4.7212796, -0.3527125, -3.4693892, 4.429763,
0.15199506, -0.52512354, -0.85594195, 2.8447511, -0.10181111, -1.5565643,
-1.6371696, 0.19021615, 0.8239815, 3.018465, -2.158556, -4.2017746,
3.9272437, -3.9339066e-06, -0.18137527, 3.7160687, 2.1042633, 0.8752967,
0.29226887, 5.755277, -2.9184306, 0.78941, -9.410112, 3.0100346e-06,
-1.7881365, 1.140914, 0.13286811, -3.01685, 2.4928799, 6.7320104,
0.5376528, 0.88787735, -0.78172505, -7.0903873, 3.5203578, -0.6790314,
-3.246148, -3.0523329, -4.1306543, -5.653259, -3.866367, -1.5466263,
-3.6847744, -3.2064118, 0.5578996, -0.12726665, -2.2060838, -1.2613428,
0.588767, 1.2716217, -2.5499039, -0.8091496, -3.0134337, 0.0408957,
1.4991964, 6.6122847, -0.36368948, -3.0809648, 3.9192853, -3.764699,
0.19334978, 3.9811373, 0.68720365, -1.717634, 2.346336, -3.3394372,
1.2645291, 2.241068, 1.1309403, -0.3806507, 2.1538877, -2.3990266,
0.6885946, -1.4901161e-06, -0.037429705, 0.24751475, 0.2968948, -7.367506,
-4.574969, -1.329541, -0.5423446, 3.2239883, 2.4139037, 2.9802322e-07,
0.19334424, -3.9811373, 3.1507545, 2.0127864, -4.4341884, -1.2173393,
0.72419256, 0.015158802, -4.4655256, -0.34677732, 2.1538897, 2.3990245,
0.5887663, -1.2716188, -1.6747494, -3.415226, 1.2875631, 1.0108626,
2.0268395, -2.3615427, -1.502785, -2.8317401, 3.919288, 3.764695,
-3.2461433, 3.0523314, -0.5022881, 0.9094755, -0.55759126, -0.24697942,
5.0729737, 5.668646, -4.662384, 2.9517999, -2.2060819, 1.2613468,
-1.7881389, -1.1409098, -1.8951292, -2.1522717, -7.4092865, -0.38806117,
-0.6685039, -1.3767233, -0.8713439, 0.71781945, 3.5203605, 0.6790297};
const std::vector<float> expected_rdft3d_results_2 = {
50.90331, -1.4543533e-05, -8.559692, -1.1978266, -1.6134334, -12.046599,
3.8535514, -5.5232873, 5.9049683, 0.065603495, 2.4468107, 1.4522064,
-7.222825, 1.2278557e-05, 0.40963984, 0.231709, 4.16645, -0.29528028,
-0.052075505, 1.450621, 3.3082519, -0.8792013, 1.356175, -5.1859245,
1.3184534, -1.1199851, -1.4637363, -1.9884299, -1.5605974, 0.7887349,
-1.3716602, 0.47921878, 1.0902424, 1.4594792, -0.59960556, -2.6757474,
-3.0592656, 1.7766399, 0.27734682, 1.0108652, -0.07190053, 2.7305403,
-1.5605986, -0.78873086, 0.41156515, 0.18685403, 0.4121489, 1.4797752,
-1.6846865, -3.8916636, -1.4160703, 0.20651829, 0.52929974, 2.9917672,
-0.07190076, -2.7305427, 0.4096415, -0.23171037, 0.9022726, -0.022200808,
0.10522783, -3.0115416, -0.8654218, -0.4384073, 0.061277367, 0.14856634,
-3.0687659, 1.3078697, -1.4637384, 1.9884316, 25.904425, -24.998884,
-6.9080105, 3.5445771, -8.985163, -6.860018, -1.2686447, -4.8765025,
2.6592734, -0.45706248, 2.3816066, -0.29202732, -4.6167727, 2.6060565,
-0.33058774, -1.3549114, 3.9574459, -5.44279, 0.041313916, 0.67204094,
1.446027, -4.5138807, -3.8023772, -4.576436, -0.7724026, -2.6531591,
-0.6192993, 0.25615194, -1.2367722, 2.5178113, 0.7623075, 0.48839718,
1.3743844, 2.4679115, -1.1419809, -1.1111865, 2.3388672, 1.9520425,
-0.13640736, -0.47136223, 2.8086162, 1.2466785, 0.16848034, -0.46490768,
0.6572111, 0.7753189, 1.8804929, -2.9868064, -5.498336, -0.053289652,
-0.16271627, 2.1104114, 0.9904991, -0.041024223, -1.5557647, 0.14997506,
-1.1769819, -0.9719368, 0.8428756, -0.5060569, -1.0734584, -0.9006812,
-4.556718, -0.5252099, 1.1278908, -0.17134166, -3.1672862, 1.5541049,
0.78084624, 2.8328683, 0.90555733, -1.3709068e-06, -2.1656086, 1.8928962,
-3.7985847, 0.511709, -0.62185717, 0.24569236, 2.1366088, 2.7886305,
0.6373716, -0.2268233, -2.0107267, 5.662441e-07, -1.9172084, -0.6146841,
-1.1900643, -5.233785, -0.73726743, 0.5786506, -2.188651, -2.6516552,
-3.1928902, 0.58211625, -2.305578, -0.5623034, 1.6252834, -0.58828497,
0.49230486, 2.1939852, 0.7714851, -1.6455705, 2.382816, 2.1837692,
0.4225806, -0.56881106, 2.514269, -3.4460905, -1.618634, -0.057608932,
1.3247533, -1.6338379, 0.49230492, -2.1939862, 1.2456759, 0.5296728,
-2.5860875, -4.45515, -1.659962, 3.7603593, 1.7411764, 0.8570565,
-2.0422916, -0.50222373, 1.3247528, 1.633839, -1.9172082, 0.6146865,
0.35901868, -0.44665974, 1.0374024, 0.27800465, -4.6435204, 3.1660864,
0.8079842, -1.2379556, -2.921052, 1.6526239, 1.6252828, 0.588284,
25.90444, 24.998867, -3.817289, -2.8495073, 3.573144, -4.6748676,
4.500339, -0.40109348, 5.382302, 3.3112957, 0.7025763, 1.5174108,
-4.616783, -2.6060438, -1.1769816, 0.97193646, -0.9810596, -0.086276084,
-0.83065766, 1.3572321, -0.3264265, 0.9830234, 1.9656628, -0.027371943,
-0.2147214, 0.9708719, 0.7808455, -2.8328671, 0.16847888, 0.46490908,
-1.3624828, -1.6547482, 2.0986745, 1.1753378, 0.9649557, -2.1333718,
-2.8838634, -3.6214924, -1.2048804, 1.4246187, -1.5557631, -0.14997569,
-1.2367743, -2.5178103, 1.0000296, -0.05879204, -4.0544314, 0.01142931,
2.153687, -0.078014135, 0.4878212, -1.0468364, -2.503492, 2.5305676,
2.808617, -1.2466786, -0.33058444, 1.3549128, 0.41841656, 0.03719666,
2.216088, -1.8328552, -0.95222485, 3.2528882, -0.25863037, -0.91804826,
-2.822532, 1.4063904, -0.6193025, -0.25615215};
template<class T>
static std::vector<T> convert(const std::vector<float>& v) {
if (v.empty()) {
return std::vector<T>();
}
size_t num_of_elems = v.size();
std::vector<T> converted(num_of_elems);
for (size_t i = 0; i < num_of_elems; ++i) {
converted[i] = static_cast<T>(v[i]);
}
return converted;
}
template <class T>
static std::vector<T> convert(const std::vector<float16>& v) {
if (v.empty()) {
return std::vector<T>();
}
size_t num_of_elems = v.size();
std::vector<T> converted(num_of_elems);
for (size_t i = 0; i < num_of_elems; ++i) {
converted[i] = static_cast<T>(v[i]);
}
return converted;
}
template <class T>
static std::vector<T> convert(const std::vector<bfloat16>& v) {
if (v.empty()) {
return std::vector<T>();
}
size_t num_of_elems = v.size();
std::vector<T> converted(num_of_elems);
for (size_t i = 0; i < num_of_elems; ++i) {
converted[i] = static_cast<T>(v[i]);
}
return converted;
}
template <element::Type_t ET>
std::vector<RDFTParams> generateParamsForRDFT() {
std::vector<RDFTParams> params{
// rdft1d_eval
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
NULL),
// rdft1d_eval_signal_size_0
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {10})),
// rdft1d_eval_signal_size_0_1
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {-1})),
// rdft1d_eval_1
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {-1}),
NULL),
// rdft1d_eval_signal_size_1
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 3, 2},
ET,
ET,
input_data,
expected_rdft1d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {5})),
// rdft1d_eval_signal_size_1_1
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 3, 2},
ET,
ET,
input_data,
expected_rdft1d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {-1}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {5})),
// rdft1d_eval_signal_size_2
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 7, 2},
ET,
ET,
input_data,
expected_rdft1d_results_3,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {12})),
// rdft1d_eval_signal_size_2_1
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 7, 2},
ET,
ET,
input_data,
expected_rdft1d_results_3,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {-1}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {12})),
// rdft2d_eval_1
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft2d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
NULL),
// rdft2d_eval_1_positive_negative_axes
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft2d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, -1}),
NULL),
// rdft2d_eval_1_negative_positive_axes
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft2d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, 2}),
NULL),
// rdft2d_eval_1_negative_negative_axes
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft2d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, -1}),
NULL),
// rdft2d_eval_1_signal_size_0_s10_10
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft2d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, 10})),
// rdft2d_eval_1_signal_size_0_s10_10_positive_negative_axes
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft2d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, -1}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, 10})),
// rdft2d_eval_1_signal_size_0_s10_10_negative_positive_axes
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft2d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, 10})),
// rdft2d_eval_1_signal_size_0_s10_10_negative_negative_axes
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft2d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, -1}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, 10})),
// rdft2d_eval_1_signal_size_0_s10_m1
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft2d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, -1})),
// rdft2d_eval_1_signal_size_0_sm1_10
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft2d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-1, 10})),
// rdft2d_eval_1_signal_size_0_sm1_m1
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft2d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-1, -1})),
// rdft2d_eval_2_signal_size
RDFTParams(Shape{2, 10, 10},
Shape{2, 5, 7, 2},
ET,
ET,
input_data,
expected_rdft2d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {5, 12})),
// rdft2d_eval_2_signal_size_positive_negative_axes
RDFTParams(Shape{2, 10, 10},
Shape{2, 5, 7, 2},
ET,
ET,
input_data,
expected_rdft2d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, -1}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {5, 12})),
// rdft2d_eval_2_signal_size_negative_positive_axes
RDFTParams(Shape{2, 10, 10},
Shape{2, 5, 7, 2},
ET,
ET,
input_data,
expected_rdft2d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {5, 12})),
// rdft2d_eval_2_signal_size_negative_negative_axes
RDFTParams(Shape{2, 10, 10},
Shape{2, 5, 7, 2},
ET,
ET,
input_data,
expected_rdft2d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, -1}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {5, 12})),
// rdft3d_eval_1
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft3d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {0, 1, 2}),
NULL),
// rdft3d_eval_1_negative_axes_and_signal_size
RDFTParams(Shape{2, 10, 10},
Shape{2, 10, 6, 2},
ET,
ET,
input_data,
expected_rdft3d_results,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-3, 1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-1, 10, -1})),
// rdft3d_eval_2
RDFTParams(Shape{2, 10, 10},
Shape{4, 5, 7, 2},
ET,
ET,
input_data,
expected_rdft3d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {0, 1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {4, 5, 12})),
// rdft3d_eval_2_negative_axes
RDFTParams(Shape{2, 10, 10},
Shape{4, 5, 7, 2},
ET,
ET,
input_data,
expected_rdft3d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-3, -2, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {4, 5, 12})),
};
return params;
}
std::vector<RDFTParams> generateCombinedParamsForRDFT() {
const std::vector<std::vector<RDFTParams>> allTypeParams{
generateParamsForRDFT<element::Type_t::f32>()
};
std::vector<RDFTParams> combinedParams;
for (const auto& params : allTypeParams) {
combinedParams.insert(combinedParams.end(), params.begin(), params.end());
}
return combinedParams;
}
INSTANTIATE_TEST_SUITE_P(
smoke_RDFT_With_Hardcoded_Refs,
ReferenceRDFTLayerTest,
::testing::ValuesIn(generateCombinedParamsForRDFT()),
ReferenceRDFTLayerTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,38 @@
//*****************************************************************************
// Copyright 2017-2022 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <ngraph/runtime/host_tensor.hpp>
#include <vector>
#include "ngraph/node.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/shape_util.hpp"
namespace ngraph {
namespace runtime {
namespace reference {
void rdft(const std::vector<float>& input_data,
const Shape& input_data_shape,
const std::vector<int64_t>& axes_data,
const Shape& output_fft_shape,
float* rdft_result);
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -0,0 +1,44 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cstdint>
#include <iterator>
#include <limits>
#include <stdexcept>
#include <type_traits>
#include <vector>
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph {
namespace runtime {
namespace reference {
namespace fft_common {
// To simplify calculation of strides for all axes of 'shape' of some complex
// tensor, we reverse numbers in 'shape'. Because we have no native support for
// complex numbers in tensors, we interpret float input tensors of the shape
// [N_0, ..., N_{r - 1}, 2] as a complex tensor with the shape
// [N_0, ..., N_{r - 1}]. Hence, we convert 'shape=[N_0, ..., N_{r - 1}, 2]'
// into [N_{r - 1}, ..., N_0].
// At this time, complex tensors are supported only for FFT-like operations, as
// DFT, IDFT, RDFT
std::vector<int64_t> reverse_shape_of_emulated_complex_tensor(const ngraph::Shape& shape);
// Calculates strides for all axes.
std::vector<int64_t> compute_strides(const std::vector<int64_t>& v);
// Calculating coordinates c_0, ..., c_{k - 1} from the index of the form
// c_0 * strides[0] + ... c_{k - 1} * strides[k - 1]
// where k is the number of strides.
std::vector<int64_t> coords_from_index(int64_t index, const std::vector<int64_t>& strides);
// Calculates offset of value using corresponding coordinates and strides.
int64_t offset_from_coords_and_strides(const std::vector<int64_t>& coords, const std::vector<int64_t>& strides);
} // namespace fft_common
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -22,6 +22,7 @@
#include <complex>
#include <cstring>
#include <functional>
#include <ngraph/runtime/reference/utils/fft_common.hpp>
#include <utility>
#include <vector>
@ -57,34 +58,6 @@ std::vector<int64_t> canonicalize_axes(const int64_t* axes_data,
namespace {
using complex_type = std::complex<float>;
// Calculates strides for all axes.
std::vector<int64_t> compute_strides(const std::vector<int64_t>& v) {
std::vector<int64_t> strides(v.size() + 1);
int64_t stride = 1;
for (size_t i = 0; i < v.size(); ++i) {
strides[i] = stride;
stride *= v[i];
}
strides.back() = stride;
return strides;
}
// To simplify calculation of strides for all axes of 'shape' of some complex
// tensor, we reverse numbers in 'shape'. Because we have no native support for
// complex numbers in tensors, we interpret FFT input tensors of the shape
// [N_0, ..., N_{r - 1}, 2] as a complex tensor with the shape
// [N_0, ..., N_{r - 1}]. Hence, we convert 'shape=[N_0, ..., N_{r - 1}, 2]'
// into [N_{r - 1}, ..., N_0].
std::vector<int64_t> reverse_shape(const Shape& shape) {
size_t complex_data_rank = shape.size() - 1;
std::vector<int64_t> reversed_shape(complex_data_rank);
for (size_t i = 0; i < complex_data_rank; ++i) {
reversed_shape[i] = static_cast<int64_t>(shape[complex_data_rank - i - 1]);
}
return reversed_shape;
}
// This function gets FFT axes from axes_data
std::vector<int64_t> get_axes(const int64_t* axes_data, const Shape& axes_data_shape, int64_t complex_data_rank) {
auto axes = canonicalize_axes(axes_data, axes_data_shape, complex_data_rank);
@ -148,24 +121,6 @@ int64_t compute_buffer_size(const std::vector<int64_t>& fft_lengths) {
return buffer_size;
}
// Calculating coordinates c_0, ..., c_{k - 1} from the index of the form
// c_0 * strides[0] + ... c_{k - 1} * strides[k - 1]
// where k is the number of strides.
std::vector<int64_t> coords_from_index(int64_t index, const std::vector<int64_t>& strides) {
int64_t num_of_axes = static_cast<int64_t>(strides.size()) - 1;
if (num_of_axes == 0) {
return std::vector<int64_t>{};
}
std::vector<int64_t> coords(num_of_axes);
int64_t curr = index;
for (int64_t j = num_of_axes - 1; j >= 1; --j) {
coords[j] = curr / strides[j];
curr %= strides[j];
}
coords[0] = curr;
return coords;
}
// This function gets a complex value from given coords of this value
complex_type get_value_from_input(const complex_type* input_data,
int64_t src_index,
@ -194,7 +149,7 @@ void copy_data_from_input(complex_type* result,
const std::vector<int64_t>& input_fft_lengths,
const std::vector<int64_t>& input_fft_strides) {
for (int64_t idx = 0; idx < fft_size; ++idx) {
auto coords = coords_from_index(idx, fft_strides);
auto coords = fft_common::coords_from_index(idx, fft_strides);
complex_type value = get_value_from_input(input_data, src_index, coords, input_fft_lengths, input_fft_strides);
result[idx] = value;
}
@ -210,16 +165,6 @@ bool blob_is_zero(const complex_type* data, int64_t blob_size) {
return true;
}
// Calculates offset of value using corresponding coordinates and strides.
int64_t offset_from_coords_and_strides(const std::vector<int64_t>& coords, const std::vector<int64_t>& strides) {
int64_t offset = 0;
int64_t num_of_axes = coords.size();
for (int64_t i = 0; i < num_of_axes; ++i) {
offset += coords[i] * strides[i];
}
return offset;
}
// Copying calculated data to the given memory domain.
void copy_data_to_output(complex_type* output,
const complex_type* data,
@ -228,9 +173,9 @@ void copy_data_to_output(complex_type* output,
const std::vector<int64_t>& fft_strides,
const std::vector<int64_t>& output_fft_strides) {
for (int64_t idx = 0; idx < fft_size; ++idx) {
auto coords = coords_from_index(idx, fft_strides);
auto coords = fft_common::coords_from_index(idx, fft_strides);
complex_type value = data[idx];
int64_t offset = offset_from_coords_and_strides(coords, output_fft_strides);
int64_t offset = fft_common::offset_from_coords_and_strides(coords, output_fft_strides);
output[dst_index + offset] = value;
}
@ -372,29 +317,29 @@ InfoForFFTCalculation get_info_for_calculation(const Shape& input_data_shape,
const int64_t complex_data_rank = static_cast<int64_t>(input_data_shape.size() - 1);
const auto reversed_output_shape = reverse_shape(output_shape);
const auto reversed_output_shape = fft_common::reverse_shape_of_emulated_complex_tensor(output_shape);
auto fft_axes = get_axes(axes_data, axes_data_shape, complex_data_rank);
reverse_fft_axes(fft_axes, complex_data_rank);
const int64_t fft_rank = fft_axes.size();
const auto fft_lengths = get_lengths(reversed_output_shape, fft_axes);
const auto fft_strides = compute_strides(fft_lengths);
const auto fft_strides = fft_common::compute_strides(fft_lengths);
const int64_t fft_size = fft_strides[fft_rank];
const auto outer_axes = get_outer_axes(fft_axes, complex_data_rank);
const int64_t outer_rank = outer_axes.size();
const auto outer_lengths = get_lengths(reversed_output_shape, outer_axes);
const auto outer_strides = compute_strides(outer_lengths);
const auto outer_strides = fft_common::compute_strides(outer_lengths);
const int64_t outer_size = outer_strides[outer_rank];
const int64_t buffer_size = compute_buffer_size(fft_lengths);
const auto output_strides = compute_strides(reversed_output_shape);
const auto output_strides = fft_common::compute_strides(reversed_output_shape);
const auto output_fft_strides = get_lengths(output_strides, fft_axes);
const auto output_outer_strides = get_lengths(output_strides, outer_axes);
const auto reversed_input_shape = reverse_shape(input_data_shape);
const auto reversed_input_shape = fft_common::reverse_shape_of_emulated_complex_tensor(input_data_shape);
const auto input_fft_lengths = get_lengths(reversed_input_shape, fft_axes);
const auto input_strides = compute_strides(reversed_input_shape);
const auto input_strides = fft_common::compute_strides(reversed_input_shape);
const auto input_fft_strides = get_lengths(input_strides, fft_axes);
const auto input_outer_strides = get_lengths(input_strides, outer_axes);
@ -461,8 +406,8 @@ void fft(const float* input_data,
// Loop along with 'outer' dimensions, that is along with
// not transformed dimensions.
for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
const auto outer_coords = coords_from_index(outer_idx, outer_strides);
int64_t outer_input_offset = offset_from_coords_and_strides(outer_coords, input_outer_strides);
const auto outer_coords = fft_common::coords_from_index(outer_idx, outer_strides);
int64_t outer_input_offset = fft_common::offset_from_coords_and_strides(outer_coords, input_outer_strides);
// Copying current data to transform
copy_data_from_input(data.data(),
@ -482,14 +427,14 @@ void fft(const float* input_data,
auto outer_fft_axes = lengths_except_given_axis(fft_axes, axis_idx);
int64_t outer_fft_size = fft_size / current_fft_length;
auto outer_fft_strides = compute_strides(outer_fft_lengths);
auto outer_fft_strides = fft_common::compute_strides(outer_fft_lengths);
auto fft_strides_for_outer_fft_axes = lengths_except_given_axis(fft_strides, axis_idx);
// Loop along with all FFT axes, except the current one.
for (int64_t outer_fft_idx = 0; outer_fft_idx < outer_fft_size; ++outer_fft_idx) {
const auto outer_fft_coords = coords_from_index(outer_fft_idx, outer_fft_strides);
const auto outer_fft_coords = fft_common::coords_from_index(outer_fft_idx, outer_fft_strides);
int64_t outer_fft_offset =
offset_from_coords_and_strides(outer_fft_coords, fft_strides_for_outer_fft_axes);
fft_common::offset_from_coords_and_strides(outer_fft_coords, fft_strides_for_outer_fft_axes);
// Calculation of 1D FFT
fft1d(current_fft_length,
outer_fft_offset,
@ -502,7 +447,7 @@ void fft(const float* input_data,
}
// Copying current calculated data to the output blob.
int64_t outer_output_offset = offset_from_coords_and_strides(outer_coords, output_outer_strides);
int64_t outer_output_offset = fft_common::offset_from_coords_and_strides(outer_coords, output_outer_strides);
copy_data_to_output(complex_output_ptr,
data.data(),
outer_output_offset,

View File

@ -0,0 +1,89 @@
//*****************************************************************************
// Copyright 2017-2022 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/reference/rdft.hpp"
#include <complex>
#include <ngraph/runtime/reference/utils/fft_common.hpp>
#include <vector>
#include "ngraph/runtime/reference/fft.hpp"
#include "ngraph/shape.hpp"
using namespace ngraph;
using namespace ngraph::runtime::reference;
namespace ngraph {
namespace runtime {
namespace reference {
namespace {
using complex_type = std::complex<float>;
// This function clips transformed axes and writes the result into output
void clip_and_write_result(const std::vector<int64_t>& axes_data,
const std::vector<float>& fft_result,
const Shape& output_fft_shape,
float* rdft_result) {
auto rdft_result_shape = output_fft_shape;
const auto last_axis = axes_data.back();
rdft_result_shape[last_axis] = rdft_result_shape[last_axis] / 2 + 1;
const auto reversed_rdft_result_shape = fft_common::reverse_shape_of_emulated_complex_tensor(rdft_result_shape);
const auto rdft_output_strides = fft_common::compute_strides(reversed_rdft_result_shape);
const auto reversed_output_fft_shape = fft_common::reverse_shape_of_emulated_complex_tensor(output_fft_shape);
const auto output_fft_strides = fft_common::compute_strides(reversed_output_fft_shape);
const auto rdft_output_size = rdft_output_strides.back();
complex_type* complex_output_ptr = reinterpret_cast<complex_type*>(rdft_result);
const complex_type* complex_input_ptr = reinterpret_cast<const complex_type*>(fft_result.data());
for (int64_t i = 0; i < rdft_output_size; ++i) {
const auto coords = fft_common::coords_from_index(i, rdft_output_strides);
const int64_t input_offset = fft_common::offset_from_coords_and_strides(coords, output_fft_strides);
complex_output_ptr[i] = complex_input_ptr[input_offset];
}
}
} // namespace
void rdft(const std::vector<float>& input_data,
const Shape& input_data_shape,
const std::vector<int64_t>& axes_data,
const Shape& output_fft_shape,
float* rdft_result) {
// Converting input data to complex type and calculation of DFT with such data.
size_t input_data_size = input_data.size();
std::vector<complex_type> complex_data(input_data_size);
for (size_t i = 0; i < input_data_size; ++i) {
complex_data[i] = complex_type{input_data[i], 0.0f};
}
auto input_shape_for_fft = input_data_shape;
input_shape_for_fft.push_back(2);
std::vector<float> fft_result(shape_size(output_fft_shape), 0.0f);
fft(reinterpret_cast<const float*>(complex_data.data()),
input_shape_for_fft,
axes_data.data(),
Shape{axes_data.size()},
fft_result.data(),
output_fft_shape,
FFTKind::Forward);
clip_and_write_result(axes_data, fft_result, output_fft_shape, rdft_result);
}
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -0,0 +1,64 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/runtime/reference/utils/fft_common.hpp"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <numeric>
#include "ngraph/check.hpp"
namespace ngraph {
namespace runtime {
namespace reference {
namespace fft_common {
std::vector<int64_t> reverse_shape_of_emulated_complex_tensor(const ngraph::Shape& shape) {
assert(shape.size() >= 2);
std::vector<int64_t> reversed_shape(shape.begin(), shape.end() - 1);
std::reverse(reversed_shape.begin(), reversed_shape.end());
return reversed_shape;
}
std::vector<int64_t> compute_strides(const std::vector<int64_t>& v) {
std::vector<int64_t> strides(v.size() + 1);
int64_t stride = 1;
for (size_t i = 0; i < v.size(); ++i) {
strides[i] = stride;
stride *= v[i];
}
strides.back() = stride;
return strides;
}
std::vector<int64_t> coords_from_index(int64_t index, const std::vector<int64_t>& strides) {
int64_t num_of_axes = static_cast<int64_t>(strides.size()) - 1;
if (num_of_axes == 0) {
return std::vector<int64_t>{};
}
std::vector<int64_t> coords(num_of_axes);
int64_t curr = index;
for (int64_t j = num_of_axes - 1; j >= 1; --j) {
coords[j] = curr / strides[j];
curr %= strides[j];
}
coords[0] = curr;
return coords;
}
int64_t offset_from_coords_and_strides(const std::vector<int64_t>& coords, const std::vector<int64_t>& strides) {
assert(coords.size() < strides.size());
int64_t offset = 0;
int64_t num_of_axes = coords.size();
for (int64_t i = 0; i < num_of_axes; ++i) {
offset += coords[i] * strides[i];
}
return offset;
}
} // namespace fft_common
} // namespace reference
} // namespace runtime
} // namespace ngraph