diff --git a/docs/template_plugin/backend/evaluates_map.cpp b/docs/template_plugin/backend/evaluates_map.cpp index cb3f1fc9e0d..0fd39155b05 100644 --- a/docs/template_plugin/backend/evaluates_map.cpp +++ b/docs/template_plugin/backend/evaluates_map.cpp @@ -64,6 +64,7 @@ #include #include #include +#include #include #include #include @@ -2105,6 +2106,73 @@ bool evaluate(const shared_ptr& op, const HostTensorVector& output return true; } +namespace rfft_v9 { +struct InfoForRFFT9 { + std::vector input_data; + std::vector axes_data; + Shape input_data_shape; + Shape axes_data_shape; + Shape fft_output_shape; + Shape output_shape; +}; + +InfoForRFFT9 get_info_for_rfft9_eval(const std::vector>& inputs) { + InfoForRFFT9 result; + + result.input_data_shape = inputs[0]->get_shape(); + result.axes_data_shape = inputs[1]->get_shape(); + result.input_data = get_floats(inputs[0], result.input_data_shape); + result.axes_data = get_integers(inputs[1], result.axes_data_shape); + + auto fft_output_shape = result.input_data_shape; + auto output_shape = result.input_data_shape; + + int64_t input_rank = static_cast(result.input_data_shape.size()); + auto canonicalized_axes = + runtime::reference::canonicalize_axes(result.axes_data.data(), result.axes_data_shape, input_rank); + + size_t num_of_axes = result.axes_data.size(); + auto signal_size = fft_v7::get_signal_size(inputs, num_of_axes); + + const auto last_axis = canonicalized_axes.back(); + for (size_t i = 0; i < num_of_axes; ++i) { + int64_t current_axis = canonicalized_axes[i]; + int64_t current_signal_size = signal_size[i]; + if (current_signal_size != -1) { + fft_output_shape[current_axis] = current_signal_size; + output_shape[current_axis] = current_signal_size; + } + } + output_shape[last_axis] = fft_output_shape[last_axis] / 2 + 1; + output_shape.push_back(2); + fft_output_shape.push_back(2); + + result.fft_output_shape = fft_output_shape; + result.output_shape = output_shape; + + result.axes_data = canonicalized_axes; + + return result; +} +} // namespace rfft_v9 + +template +bool evaluate(const shared_ptr& op, const HostTensorVector& outputs, const HostTensorVector& inputs) { + auto info = rfft_v9::get_info_for_rfft9_eval(inputs); + outputs[0]->set_shape(info.output_shape); + + std::vector rfft_result(shape_size(info.output_shape), 0.0f); + runtime::reference::rdft(info.input_data, + info.input_data_shape, + info.axes_data, + info.fft_output_shape, + rfft_result.data()); + + const auto output_type = op->get_input_element_type(0); + runtime::reference::fft_postprocessing(outputs, output_type, rfft_result); + return true; +} + template bool evaluate(const shared_ptr& op, const HostTensorVector& outputs, const HostTensorVector& inputs) { using T = typename element_type_traits::value_type; diff --git a/docs/template_plugin/backend/opset_int_tbl.hpp b/docs/template_plugin/backend/opset_int_tbl.hpp index fabfb5ee2aa..471abdfde06 100644 --- a/docs/template_plugin/backend/opset_int_tbl.hpp +++ b/docs/template_plugin/backend/opset_int_tbl.hpp @@ -130,3 +130,5 @@ NGRAPH_OP(Exp, op::v0) NGRAPH_OP(Log, op::v0) NGRAPH_OP(PriorBox, ngraph::op::v8) NGRAPH_OP(PRelu, op::v0) + +NGRAPH_OP(RDFT, op::v9) diff --git a/docs/template_plugin/tests/functional/op_reference/rdft.cpp b/docs/template_plugin/tests/functional/op_reference/rdft.cpp new file mode 100644 index 00000000000..82c6c2c3cf6 --- /dev/null +++ b/docs/template_plugin/tests/functional/op_reference/rdft.cpp @@ -0,0 +1,704 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "base_reference_test.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/rdft.hpp" + +using namespace reference_tests; +using namespace ov; + +namespace { +struct RDFTParams { + template + RDFTParams(const Shape& input_shape, + const Shape& expected_shape, + const element::Type_t& input_type, + const element::Type_t& expected_type, + const std::vector& input_value, + const std::vector& expected_value, + const std::shared_ptr& axes, + const std::shared_ptr& signal) { + m_input_shape = input_shape; + m_expected_shape = expected_shape; + m_input_type = input_type; + m_expected_type = expected_type; + m_input_value = CreateTensor(input_type, input_value); + m_expected_value = CreateTensor(expected_type, expected_value); + m_axes = axes; + m_signal = signal; + } + + Shape m_input_shape; + Shape m_expected_shape; + element::Type_t m_input_type; + element::Type_t m_expected_type; + ov::Tensor m_input_value; + ov::Tensor m_expected_value; + std::shared_ptr m_axes; + std::shared_ptr m_signal; +}; + +class ReferenceRDFTLayerTest : public testing::TestWithParam, public CommonReferenceTest { +public: + void SetUp() override { + auto params = GetParam(); + if (params.m_signal != NULL) { + function = CreateFunctionWithSignal(params); + } else { + function = CreateFunction(params); + } + + inputData = {params.m_input_value}; + refOutData = {params.m_expected_value}; + } + + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + const auto param = obj.param; + std::ostringstream result; + + result << "input_shape1=" << param.m_input_shape << "; "; + result << "output_shape=" << param.m_expected_shape << "; "; + result << "input_type1=" << param.m_input_type << "; "; + result << "output_type=" << param.m_expected_type << "; "; + result << "transpose1=" << param.m_axes; + + return result.str(); + } + +private: + static std::shared_ptr CreateFunction(RDFTParams& p) { + auto in = std::make_shared(p.m_input_type, p.m_input_shape); + auto rdft = std::make_shared(in, p.m_axes); + + return std::make_shared(rdft, ParameterVector{in}); + } + + static std::shared_ptr CreateFunctionWithSignal(RDFTParams& p) { + auto in = std::make_shared(p.m_input_type, p.m_input_shape); + auto rdft = std::make_shared(in, p.m_axes, p.m_signal); + + return std::make_shared(rdft, ParameterVector{in}); + } +}; + +TEST_P(ReferenceRDFTLayerTest, CompareWithHardcodedRefs) { + Exec(); +} + +static const std::vector input_data = { + 0.10606491, 0.7454715, 0.57231355, 0.4582412, 0.3847059, 0.27398932, 0.66796243, 0.395475, + 0.2815729, 0.7799197, 0.59909415, 0.12294636, 0.38957402, 0.97498834, 0.46759892, 0.14017141, + 0.04206858, 0.7279963, 0.61560553, 0.9027321, 0.6226334, 0.2601217, 0.5555177, 0.40498647, + 0.14175586, 0.57774633, 0.52652127, 0.9385691, 0.9588788, 0.9844318, 0.23095612, 0.09707925, + 0.24574867, 0.6907577, 0.1974319, 0.8295272, 0.34612727, 0.51401484, 0.66115797, 0.9336245, + 0.06690067, 0.7468897, 0.39028263, 0.53575844, 0.060429193, 0.8913558, 0.77787375, 0.6701197, + 0.7350527, 0.6636995, 0.18176624, 0.8629976, 0.45142895, 0.6497297, 0.159372, 0.40598175, + 0.7988516, 0.7291543, 0.07090418, 0.7697132, 0.4972157, 0.7669217, 0.67975855, 0.13026066, + 0.6587437, 0.24532892, 0.24545169, 0.83795583, 0.105490535, 0.7264323, 0.94568557, 0.7216649, + 0.14389831, 0.7930531, 0.70895344, 0.9724701, 0.9775157, 0.49999878, 0.65569246, 0.26876843, + 0.63248956, 0.85201293, 0.5689624, 0.023386303, 0.5546464, 0.36860028, 0.9603114, 0.39123482, + 0.0380728, 0.89212376, 0.14387614, 0.63858676, 0.10003748, 0.8906635, 0.06681054, 0.7458642, + 0.45452347, 0.54724604, 0.6496482, 0.7818356, 0.6608355, 0.77711326, 0.24588613, 0.013456763, + 0.355845, 0.80388206, 0.027993264, 0.73677206, 0.52755004, 0.9052324, 0.54311025, 0.5367805, + 0.4131242, 0.7752338, 0.109669454, 0.13664648, 0.7828739, 0.9083969, 0.5247593, 0.7493595, + 0.19275227, 0.007190853, 0.6087981, 0.344136, 0.46909887, 0.41924855, 0.7072913, 0.19932869, + 0.5303847, 0.651384, 0.06686331, 0.9717932, 0.65702224, 0.11786682, 0.3154073, 0.88923013, + 0.5564087, 0.91047823, 0.28466642, 0.0934668, 0.88953066, 0.9919338, 0.18322521, 0.8185455, + 0.566391, 0.014207997, 0.29673064, 0.6347744, 0.6801958, 0.39601147, 0.34374171, 0.7216888, + 0.6152569, 0.76679546, 0.5860851, 0.4276813, 0.79339284, 0.13130653, 0.68764234, 0.053128112, + 0.02611321, 0.2982243, 0.7618372, 0.3331729, 0.5468192, 0.15707079, 0.28592056, 0.15286565, + 0.9368963, 0.350671, 0.4336494, 0.08934934, 0.41172776, 0.5850259, 0.70730376, 0.8598349, + 0.088788144, 0.26711187, 0.8002491, 0.19422275, 0.8312039, 0.5198718, 0.40111357, 0.98375803, + 0.77703434, 0.037818834, 0.704231, 0.689808, 0.17102319, 0.42153922, 0.7278252, 0.8030207, + 0.9101717, 0.0199644, 0.13768466, 0.55669, 0.17991355, 0.6720098, 0.7733328, 0.20881335}; + +static const std::vector expected_rdft1d_results_1 = { + 4.6657147, -1.1622906e-06, 0.21456887, -0.14946258, -0.20476034, -0.37063062, + -0.31414136, 0.5099413, -1.1779613, 0.07057127, -0.64047664, -1.0058284e-07, + 4.982774, -1.1771917e-06, 0.6607505, 0.18829148, -0.9772357, 1.4243596, + 0.8640026, 0.34923682, 0.33401352, 0.25859502, -0.7548928, 8.940697e-08, + 5.9711604, -1.4901161e-06, 0.5638976, 1.5429841, -0.52065414, 0.24638398, + -0.27140495, 0.5040715, 0.5360231, 0.3234269, -0.36054826, 1.7508864e-07, + 4.7464237, -1.2218952e-06, -0.29650804, 0.80609477, -0.161426, 1.0022418, + -0.50812817, 0.7967348, 0.4394225, -0.1588624, -1.3835809, -7.4505806e-08, + 5.53836, -1.7136335e-06, -0.38635445, 0.8284859, -0.23278837, -0.63777345, + -0.93614054, 0.3215857, -0.14075133, -0.67071164, -1.4772836, 2.0861626e-07, + 5.0798974, -1.5944242e-06, 0.056767445, 0.03468219, -0.1497254, -0.9672509, + 0.2603209, 0.69644475, -0.9208536, 0.006730467, -1.7552528, 2.682209e-07, + 4.893558, -1.6242266e-06, 0.6719861, -0.13982919, 0.064845346, -0.39896214, + 0.21785057, -0.5099982, -0.65526295, 1.4383471, -0.52023906, 2.5331974e-07, + 6.687699, -1.5497208e-06, -0.7423769, 0.09968524, 1.052381, -0.21306956, + 0.5875206, -0.3038844, 0.3991575, -1.1895186, 0.17579001, 3.874302e-07, + 5.2818384, -1.1026859e-06, 0.5087582, 0.106959194, 1.1816688, -0.87592727, + 0.03740315, 0.5197907, -1.3198637, 0.6398836, 0.22712436, 2.2351742e-08, + 5.0190897, -1.5646219e-06, -0.087282926, 0.50819266, -0.28002462, 0.29240948, + -0.32303664, 0.38377762, -0.0051696897, -0.99301195, -2.189299, 2.0861626e-07, + 5.0545654, -1.5795231e-06, 0.9146397, 0.83839166, 0.870533, 0.17405808, + -0.56308234, -0.7806684, 0.26397777, 0.6880482, -1.4183462, 2.682209e-07, + 5.479953, -1.2665987e-06, 0.49444157, 0.7534672, -0.76784146, -0.4507342, + 0.88815784, 0.6985409, -0.2727425, -0.25027415, -0.7328796, 2.682209e-07, + 4.1296124, -5.662441e-07, -0.46133032, 0.30635798, -0.18225375, 0.42515472, + -0.5484285, 0.9704039, -0.35255045, 0.17549685, 0.8870368, -3.1292439e-07, + 4.8632016, -1.8924475e-06, -0.6926452, 0.025076404, -0.039108217, -1.7492937, + -0.8120377, -0.85315156, -0.0022608787, 0.45002514, -1.1024668, 3.501773e-07, + 5.4715447, -1.4901161e-06, 1.1176248, -0.2109062, -0.27492502, 0.08983741, + 1.1903813, -1.007312, -0.20150042, -0.83919466, -0.23939973, 4.917383e-07, + 5.1267176, -9.983778e-07, -0.44803134, -0.8066604, -0.3435102, -0.41692197, + -0.22457689, -0.1076939, -0.29129186, -1.1880502, 0.9255183, -1.6391277e-07, + 3.8495903, -5.5134296e-07, 0.09505272, -0.12751618, -1.1264827, 0.5068884, + -1.055237, -0.19516481, -0.34035242, -0.15379356, 1.2655814, -2.6077032e-07, + 4.4372616, -9.23872e-07, -0.72962606, -0.23475963, -0.04278487, 1.1032158, + -0.558924, -0.5300043, 1.0578637, -0.2466627, 0.44617313, -7.8231096e-08, + 5.5374002, -1.4156103e-06, 0.016273111, -0.5989829, -0.19913958, 0.013256833, + 1.8512837, 0.14526272, -0.39700353, -0.07573915, 0.23181, 2.9429793e-07, + 4.989425, -1.4901161e-06, 1.0391837, 0.16554561, -0.22647032, -1.0689808, + -0.84556, -0.82779336, 0.9430445, 0.37618563, 0.4684292, -9.685755e-08}; + +static const std::vector expected_rdft1d_results_2 = { + 2.266797, -8.195639e-08, -0.37842733, -0.41015846, -0.48980892, -0.10356337, + 2.5542018, -2.2351742e-08, -0.3223713, 0.671882, 0.54300576, -0.35418037, + 1.985015, -2.2351742e-08, -0.030243821, -0.20105253, 0.59431964, 0.07358998, + 1.4619737, -7.450581e-09, -0.4356845, 0.35701087, 0.28208786, -0.36424285, + 1.8002605, -1.1920929e-07, -0.43280697, -0.56735414, -0.30007166, -0.541847, + 2.3052943, -1.2293458e-07, -0.39316025, -0.5526293, -0.30507135, -0.6021758, + 2.7329001, -6.7055225e-08, 0.28245124, -0.42586988, -0.40586215, 0.4590181, + 3.3132548, -5.9604645e-08, 0.6297612, 0.3694744, 0.077824846, -0.6248544, + 2.6314974, -2.9802322e-08, 0.58795106, -0.60349375, -0.3224758, 0.34408605, + 1.8399743, -9.685755e-08, -0.43963802, -0.079073176, -0.120658875, -1.0880115, + 2.0531366, -4.4703484e-08, 0.80112594, -0.53726834, -0.17560546, -0.026561722, + 2.3779182, -9.685755e-08, -0.21852754, -0.19336401, 0.38734403, -0.5954362, + 1.6219761, 7.450581e-09, -0.43100592, 0.28373614, 0.101898566, 0.52321124, + 2.128953, -1.4901161e-07, -0.1622684, -0.94116735, -0.7350497, 0.12695336, + 3.449626, -8.940697e-08, 0.56062996, -0.031283244, -0.06161648, -0.8543532, + 3.033568, -8.195639e-08, -0.37023768, -0.03989461, -0.28719214, -0.22382751, + 1.9661667, -1.4901161e-08, -0.59863573, -0.015534669, -0.31916466, 0.55380434, + 2.227056, -5.2154064e-08, -0.12656188, 0.6895717, 0.097157195, 0.19840825, + 3.5129817, -2.1234155e-07, 0.11158541, 0.5870459, 0.20993343, -0.40297145, + 2.5986667, 0.0, 0.26602313, -1.1560227, 0.2542065, 0.45556274}; + +static const std::vector expected_rdft1d_results_3 = { + 4.665715, -1.6093254e-06, -0.5430559, -0.5752678, -0.37596112, -1.1571281, + -0.46793216, -0.94566363, 0.6854232, -0.3444838, -0.674704, 0.5946392, + -0.64047587, 1.3560057e-06, 4.9827743, -1.7434359e-06, -0.43517, -0.049020194, + -1.4773891, -1.0811031, 1.2506557, 0.5371344, 1.2869358, -0.14998645, + 0.8555907, 0.3693859, -0.7548918, 1.5944242e-06, 5.971161, -1.5199184e-06, + -1.2643411, 0.85635287, -0.1801207, -1.7264944, 0.6412285, -0.4787441, + 0.82227707, 0.65098876, 0.9114491, 0.40323836, -0.36054718, 1.2852252e-06, + 4.7464237, -1.66893e-06, -1.5010594, 0.2253451, -0.87915635, -0.4252541, + 0.4976693, -0.6554581, 0.928985, 0.8035921, 0.6578763, -0.15220329, + -1.3835799, 1.0430813e-06, 5.5383606, -1.4901161e-06, -1.619024, -0.10987502, + 0.20661727, -1.3774645, -0.3057741, -1.0960662, 0.2971667, 0.46700704, + -0.20812088, -0.602368, -1.4772825, 9.3877316e-07, 5.0798974, -1.758337e-06, + -0.7421876, -0.61749315, 0.21938956, -1.3415859, -0.838238, -0.6598083, + 1.0601404, -0.7129184, -0.27083004, 0.31763482, -1.7552516, 1.4677644e-06, + 4.893558, -1.4975667e-06, -0.06445231, -0.55879503, 0.08908144, -1.2869594, + 0.33623943, -0.7704663, -0.047739983, -1.0678453, 0.48350462, 1.5768427, + -0.52023804, 1.1697412e-06, 6.687699, -1.3113022e-06, -1.292419, -1.2920969, + 1.2041754, -0.2943018, 1.1889167, -0.66985166, 1.1336832, -0.13731277, + 0.008011267, -0.9506076, 0.1757915, 1.1026859e-06, 5.2818394, -1.4305115e-06, + -0.25987166, -0.48605326, 0.90237427, -0.8028362, -0.3040653, -1.6981151, + 1.1215456, -0.7120959, -0.4195284, 1.3941492, 0.22712523, 8.046627e-07, + 5.01909, -1.7881393e-06, -1.1856917, -0.10931289, -0.5164983, -0.9724103, + 0.30577338, -0.72837675, 0.89680094, 0.21036407, -0.052024096, -0.9455472, + -2.1892984, 1.4305115e-06, 5.054565, -1.5050173e-06, -0.3471575, 0.40542153, + 0.36438322, -0.9765247, 1.2703501, -1.7359983, -0.1160066, -0.25323528, + 0.9753329, 0.5339062, -1.418345, 9.834766e-07, 5.4799523, -1.7285347e-06, + -0.7905842, 0.093313254, 0.068526804, -1.8504739, -0.01845923, 0.26084417, + 1.5358877, -0.4159652, 0.089752786, 0.089908056, -0.7328786, 1.4007092e-06, + 4.129612, -9.536743e-07, -1.2393575, -0.28046644, -0.58673245, -0.39608067, + -0.12385368, -0.53435826, 0.77853805, 0.7645384, -0.18040559, 0.6678516, + 0.88703763, 8.046627e-07, 4.8632016, -1.0430813e-06, -1.1780663, -1.0952923, + 1.1691413, -1.4023741, -0.546494, -0.92614484, -1.1796933, -0.31762218, + 0.25592417, 0.0959474, -1.1024656, 1.013279e-06, 5.471545, -1.6987324e-06, + 0.35812324, -0.66833705, 0.07725692, -1.6537004, 1.6561611, 0.051166296, + 0.865453, -1.1392289, -0.23588535, -0.5480979, -0.2393986, 1.3411045e-06, + 5.126718, -9.23872e-07, -0.6379836, -1.6675751, 0.013057679, -0.9891113, + 0.20881936, -0.30439606, 0.37222707, 0.25244698, -0.9197892, -0.77782196, + 0.9255192, 1.1101365e-06, 3.8495903, -7.4505806e-07, -0.63088936, -0.4556699, + -1.1905057, -1.2522144, 0.46207082, -0.31992733, -0.4309795, 0.74295896, + -0.6106033, 0.18823686, 1.2655822, 7.748604e-07, 4.4372616, -7.0780516e-07, + -1.1016369, -1.0079124, -0.6083025, -0.0011255145, 1.4406854, -0.2912693, + -0.26610214, 0.87299407, 0.69553405, -0.45576566, 0.44617438, 7.4505806e-07, + 5.5374007, -1.5944242e-06, -0.32642078, -1.3683549, 0.079301864, -0.83741367, + 0.67391664, 0.69433576, 1.6423957, -1.1923066, 0.0334223, 0.37603495, + 0.23181117, 1.4156103e-06, 4.9894247, -7.748604e-07, 0.1788401, -0.39274544, + 0.78422666, -2.1340246, 0.5487572, -0.8765497, -0.7899384, 0.5434137, + 0.91613716, 0.08274247, 0.46843058, 8.34465e-07 +}; + +const std::vector expected_rdft2d_results = { + 52.8665, -2.9623508e-05, 1.1642078, 3.826082, -0.22771922, -0.49822173, + -0.3857528, 3.2676966, -2.5112464, -0.27454787, -8.678656, 3.7550926e-06, + -0.818072, 0.8330209, 3.4618711, -0.2419473, 1.7408192, 5.744002, + 1.8477443, 2.039329, 0.3268112, -2.7421296, 0.6809025, 1.7613728, + -2.294264, -0.8984407, -0.2868184, -3.2426705, -0.801461, -0.58971727, + -1.463435, -2.5413132, 0.116907075, -0.5013529, -2.8377397, -2.8455539, + -0.13475686, -1.3145845, -2.2820292, -0.199, -0.056986623, 0.12560216, + -0.589707, -1.7577857, -0.5274223, -1.0395792, 0.53813136, -1.7159984, + 0.22503978, 2.902198, -1.8643543, -1.8789856, 2.1722724, -2.068454, + 0.59446484, 0.6067899, 1.5525781, 1.7612485, 1.1877432, -0.48152098, + -0.16525066, 1.5497208e-06, 1.9815066, 0.55218977, 0.80434155, -3.575598, + -2.1471107, -0.57691807, -3.004384, 3.8775828, 3.1358109, -6.2584877e-07, + 0.22504184, -2.9021916, 1.0378464, 0.9877456, 0.38395065, -1.6089694, + -0.5107449, 1.8621777, -4.960479, -1.8983803, 1.187743, 0.48151842, + -0.1347583, 1.3145843, -0.9968031, -1.3782079, 0.9922035, 1.6614089, + -0.83039653, -0.043888614, 1.9431384, -1.6448143, 0.5381324, 1.7159982, + -2.2942696, 0.8984335, 1.3057998, -0.26607463, -3.2994738, -1.9240448, + 1.4963659, 2.8365738, -4.691832, 1.2995429, -2.8377357, 2.8455553, + -0.8180722, -0.8330165, -1.3755352, 0.34623986, -3.7555497, -0.9723124, + -1.1528367, -0.593254, -0.023679793, 1.8681414, 0.6809023, -1.7613728, + 48.939255, -2.4735928e-05, 1.3455832, 0.11001387, -2.3319814, -1.3735183, + -0.6780232, -2.4875786, 0.40718403, -1.0639579, 0.7314569, -1.2665987e-07, + 0.97006464, -0.30789328, 3.3290033, 2.7749023, -0.7520597, -0.98800826, + 1.3100916, 1.1514524, 1.1085359, 4.348257, -2.839456, 2.4404035, + 0.9518837, 2.1538901, 3.8438358, 2.410589, 3.0649068, 0.95690995, + 2.2213395, 0.66509914, -0.4409917, -0.37408838, -0.6316552, -1.5842111, + -0.72352415, -2.5862057, 0.2678757, 0.610149, 2.9564474, 0.08470708, + -2.0889034, -8.370071, -0.16373271, 2.0413866, -3.3811545, 2.0487003, + 0.0316903, -1.078939, -2.5515578, -0.16135174, -0.17406325, 1.2709827, + -0.67006403, -1.6342779, 0.42163712, 2.1418998, -0.96614444, 1.9175051, + -0.8538456, 2.8014183e-06, 2.0189362, 0.30467552, 0.5074463, 3.7919073, + 2.427857, 0.7526233, -2.4620402, 0.65359443, 0.7219074, -2.3841858e-07, + 0.03169757, 1.0789458, -2.1129081, -1.0250417, 4.8181386, -0.39162922, + -1.2349386, 1.8470186, -0.49495277, -1.5516026, -0.96614635, -1.9175065, + -0.7235237, 2.5862021, 0.677946, 2.0370173, -0.29536027, 0.6505451, + -2.8572361, 2.3176546, 3.4459226, 1.1869265, -3.3811545, -2.048697, + 0.95187366, -2.1538982, 1.808088, -1.1755496, -2.7418838, -1.6770658, + -3.5766084, -2.8320727, -0.02944839, -1.6522555, -0.63165283, 1.5842092, + 0.9700667, 0.30789307, 0.5195943, 2.4985125, 3.6537378, -0.5842519, + -0.4843334, 0.78346854, 0.84766304, 1.1503224, -2.839459, -2.440402}; + +const std::vector expected_rdft2d_results_2 = { + 25.904434, -8.46386e-06, -5.3626504, 0.3475349, -2.7060094, -5.767444, + 1.615847, -2.6387978, 4.020789, 1.4271183, 1.5420923, 0.6126925, + -4.6167765, 5.5730343e-06, -0.753784, -0.19148755, 1.4881928, -2.7645326, + -0.39467168, 1.014636, 0.5598, -1.7654291, -0.91835654, -2.3019042, + -0.49356225, -0.8411435, 0.080773115, -1.2883577, -0.5341466, 1.4913602, + -0.30008763, -0.5831754, 1.7365295, 1.821624, -0.08851206, -1.622279, + -0.27249795, -0.834725, -0.6706438, 0.4766277, 0.62642634, 0.5483514, + -0.5341469, -1.4913592, 0.8286207, 0.35826343, -1.0869694, -1.4876881, + -1.6723244, -0.06565219, 0.16255295, 0.5317876, -0.75649667, 1.2447717, + 0.6264261, -0.5483517, -0.7537827, 0.19148779, 0.6306459, -0.23442982, + 0.57131517, -1.366768, -2.7544713, 1.3638397, 0.43463084, -0.5446956, + -2.9949086, 1.4802479, 0.080771565, 1.2883584, 24.998875, -7.390976e-06, + -3.1970425, -1.5453612, 1.0925753, -6.279154, 2.237704, -2.8844912, + 1.8841789, -1.3615136, 0.90471864, 0.8395144, -2.6060505, 4.976988e-06, + 1.1634235, 0.42319643, 2.678257, 2.4692535, 0.34259582, 0.43598562, + 2.748452, 0.88622695, 2.2745323, -2.8840196, 1.8120161, -0.27884078, + -1.5445104, -0.7000726, -1.0264511, -0.7026249, -1.071573, 1.062395, + -0.64628685, -0.36214483, -0.5110928, -1.0534683, -2.786768, 2.6113648, + 0.94799054, 0.53423727, -0.69832724, 2.1821892, -1.0264513, 0.70262754, + -0.41705567, -0.17140968, 1.4991179, 2.9674625, -0.012362838, -3.8260121, + -1.5786235, -0.32526863, 1.2857957, 1.7469958, -0.6983267, -2.1821907, + 1.1634252, -0.42319855, 0.2716269, 0.21222934, -0.46608746, -1.6447732, + 1.8890494, -1.8022469, -0.37335354, 0.69326025, -0.07385725, -0.1723765, + -1.5445105, 0.7000739}; + +const std::vector expected_rdft3d_results = { + 101.805756, -5.2273273e-05, 2.5097876, 3.936094, -2.5597036, -1.8717405, + -1.0637736, 0.7801182, -2.1040666, -1.3385094, -7.9471993, 2.026558e-06, + 0.15199316, 0.52512753, 6.7908745, 2.5329556, 0.98875976, 4.755993, + 3.157838, 3.190782, 1.4353466, 1.6061276, -2.158554, 4.201776, + -1.3423799, 1.2554499, 3.5570183, -0.8320818, 2.263445, 0.36719292, + 0.7579028, -1.8762131, -0.32408538, -0.87544185, -3.4693956, -4.429764, + -0.85828185, -3.9007902, -2.0141544, 0.4111499, 2.8994608, 0.21030927, + -2.6786098, -10.127857, -0.6911557, 1.0018079, -2.8430226, 0.33270124, + 0.25672907, 1.8232578, -4.4159126, -2.040338, 1.9982092, -0.7974717, + -0.07559925, -1.0274884, 1.9742157, 3.9031482, 0.22159882, 1.4359848, + -1.0190966, 3.2186508e-06, 4.0004425, 0.8568655, 1.3117876, 0.2163087, + 0.28074512, 0.17570588, -5.466423, 4.531178, 3.857718, -1.2516975e-06, + 0.2567385, -1.823246, -1.0750613, -0.037295938, 5.20209, -2.0005994, + -1.7456844, 3.7091968, -5.45543, -3.4499822, 0.22159535, -1.4359887, + -0.8582816, 3.9007854, -0.31885874, 0.65880924, 0.6968423, 2.3119528, + -3.6876333, 2.273767, 5.38906, -0.45788872, -2.8430223, -0.33269957, + -1.3423961, -1.2554631, 3.1138885, -1.4416232, -6.0413575, -3.6011095, + -2.080242, 0.0045015216, -4.7212796, -0.3527125, -3.4693892, 4.429763, + 0.15199506, -0.52512354, -0.85594195, 2.8447511, -0.10181111, -1.5565643, + -1.6371696, 0.19021615, 0.8239815, 3.018465, -2.158556, -4.2017746, + 3.9272437, -3.9339066e-06, -0.18137527, 3.7160687, 2.1042633, 0.8752967, + 0.29226887, 5.755277, -2.9184306, 0.78941, -9.410112, 3.0100346e-06, + -1.7881365, 1.140914, 0.13286811, -3.01685, 2.4928799, 6.7320104, + 0.5376528, 0.88787735, -0.78172505, -7.0903873, 3.5203578, -0.6790314, + -3.246148, -3.0523329, -4.1306543, -5.653259, -3.866367, -1.5466263, + -3.6847744, -3.2064118, 0.5578996, -0.12726665, -2.2060838, -1.2613428, + 0.588767, 1.2716217, -2.5499039, -0.8091496, -3.0134337, 0.0408957, + 1.4991964, 6.6122847, -0.36368948, -3.0809648, 3.9192853, -3.764699, + 0.19334978, 3.9811373, 0.68720365, -1.717634, 2.346336, -3.3394372, + 1.2645291, 2.241068, 1.1309403, -0.3806507, 2.1538877, -2.3990266, + 0.6885946, -1.4901161e-06, -0.037429705, 0.24751475, 0.2968948, -7.367506, + -4.574969, -1.329541, -0.5423446, 3.2239883, 2.4139037, 2.9802322e-07, + 0.19334424, -3.9811373, 3.1507545, 2.0127864, -4.4341884, -1.2173393, + 0.72419256, 0.015158802, -4.4655256, -0.34677732, 2.1538897, 2.3990245, + 0.5887663, -1.2716188, -1.6747494, -3.415226, 1.2875631, 1.0108626, + 2.0268395, -2.3615427, -1.502785, -2.8317401, 3.919288, 3.764695, + -3.2461433, 3.0523314, -0.5022881, 0.9094755, -0.55759126, -0.24697942, + 5.0729737, 5.668646, -4.662384, 2.9517999, -2.2060819, 1.2613468, + -1.7881389, -1.1409098, -1.8951292, -2.1522717, -7.4092865, -0.38806117, + -0.6685039, -1.3767233, -0.8713439, 0.71781945, 3.5203605, 0.6790297}; + +const std::vector expected_rdft3d_results_2 = { + 50.90331, -1.4543533e-05, -8.559692, -1.1978266, -1.6134334, -12.046599, + 3.8535514, -5.5232873, 5.9049683, 0.065603495, 2.4468107, 1.4522064, + -7.222825, 1.2278557e-05, 0.40963984, 0.231709, 4.16645, -0.29528028, + -0.052075505, 1.450621, 3.3082519, -0.8792013, 1.356175, -5.1859245, + 1.3184534, -1.1199851, -1.4637363, -1.9884299, -1.5605974, 0.7887349, + -1.3716602, 0.47921878, 1.0902424, 1.4594792, -0.59960556, -2.6757474, + -3.0592656, 1.7766399, 0.27734682, 1.0108652, -0.07190053, 2.7305403, + -1.5605986, -0.78873086, 0.41156515, 0.18685403, 0.4121489, 1.4797752, + -1.6846865, -3.8916636, -1.4160703, 0.20651829, 0.52929974, 2.9917672, + -0.07190076, -2.7305427, 0.4096415, -0.23171037, 0.9022726, -0.022200808, + 0.10522783, -3.0115416, -0.8654218, -0.4384073, 0.061277367, 0.14856634, + -3.0687659, 1.3078697, -1.4637384, 1.9884316, 25.904425, -24.998884, + -6.9080105, 3.5445771, -8.985163, -6.860018, -1.2686447, -4.8765025, + 2.6592734, -0.45706248, 2.3816066, -0.29202732, -4.6167727, 2.6060565, + -0.33058774, -1.3549114, 3.9574459, -5.44279, 0.041313916, 0.67204094, + 1.446027, -4.5138807, -3.8023772, -4.576436, -0.7724026, -2.6531591, + -0.6192993, 0.25615194, -1.2367722, 2.5178113, 0.7623075, 0.48839718, + 1.3743844, 2.4679115, -1.1419809, -1.1111865, 2.3388672, 1.9520425, + -0.13640736, -0.47136223, 2.8086162, 1.2466785, 0.16848034, -0.46490768, + 0.6572111, 0.7753189, 1.8804929, -2.9868064, -5.498336, -0.053289652, + -0.16271627, 2.1104114, 0.9904991, -0.041024223, -1.5557647, 0.14997506, + -1.1769819, -0.9719368, 0.8428756, -0.5060569, -1.0734584, -0.9006812, + -4.556718, -0.5252099, 1.1278908, -0.17134166, -3.1672862, 1.5541049, + 0.78084624, 2.8328683, 0.90555733, -1.3709068e-06, -2.1656086, 1.8928962, + -3.7985847, 0.511709, -0.62185717, 0.24569236, 2.1366088, 2.7886305, + 0.6373716, -0.2268233, -2.0107267, 5.662441e-07, -1.9172084, -0.6146841, + -1.1900643, -5.233785, -0.73726743, 0.5786506, -2.188651, -2.6516552, + -3.1928902, 0.58211625, -2.305578, -0.5623034, 1.6252834, -0.58828497, + 0.49230486, 2.1939852, 0.7714851, -1.6455705, 2.382816, 2.1837692, + 0.4225806, -0.56881106, 2.514269, -3.4460905, -1.618634, -0.057608932, + 1.3247533, -1.6338379, 0.49230492, -2.1939862, 1.2456759, 0.5296728, + -2.5860875, -4.45515, -1.659962, 3.7603593, 1.7411764, 0.8570565, + -2.0422916, -0.50222373, 1.3247528, 1.633839, -1.9172082, 0.6146865, + 0.35901868, -0.44665974, 1.0374024, 0.27800465, -4.6435204, 3.1660864, + 0.8079842, -1.2379556, -2.921052, 1.6526239, 1.6252828, 0.588284, + 25.90444, 24.998867, -3.817289, -2.8495073, 3.573144, -4.6748676, + 4.500339, -0.40109348, 5.382302, 3.3112957, 0.7025763, 1.5174108, + -4.616783, -2.6060438, -1.1769816, 0.97193646, -0.9810596, -0.086276084, + -0.83065766, 1.3572321, -0.3264265, 0.9830234, 1.9656628, -0.027371943, + -0.2147214, 0.9708719, 0.7808455, -2.8328671, 0.16847888, 0.46490908, + -1.3624828, -1.6547482, 2.0986745, 1.1753378, 0.9649557, -2.1333718, + -2.8838634, -3.6214924, -1.2048804, 1.4246187, -1.5557631, -0.14997569, + -1.2367743, -2.5178103, 1.0000296, -0.05879204, -4.0544314, 0.01142931, + 2.153687, -0.078014135, 0.4878212, -1.0468364, -2.503492, 2.5305676, + 2.808617, -1.2466786, -0.33058444, 1.3549128, 0.41841656, 0.03719666, + 2.216088, -1.8328552, -0.95222485, 3.2528882, -0.25863037, -0.91804826, + -2.822532, 1.4063904, -0.6193025, -0.25615215}; + +template +static std::vector convert(const std::vector& v) { + if (v.empty()) { + return std::vector(); + } + + size_t num_of_elems = v.size(); + std::vector converted(num_of_elems); + for (size_t i = 0; i < num_of_elems; ++i) { + converted[i] = static_cast(v[i]); + } + return converted; +} + +template +static std::vector convert(const std::vector& v) { + if (v.empty()) { + return std::vector(); + } + + size_t num_of_elems = v.size(); + std::vector converted(num_of_elems); + for (size_t i = 0; i < num_of_elems; ++i) { + converted[i] = static_cast(v[i]); + } + return converted; +} + +template +static std::vector convert(const std::vector& v) { + if (v.empty()) { + return std::vector(); + } + + size_t num_of_elems = v.size(); + std::vector converted(num_of_elems); + for (size_t i = 0; i < num_of_elems; ++i) { + converted[i] = static_cast(v[i]); + } + return converted; +} + +template +std::vector generateParamsForRDFT() { + std::vector params{ + // rdft1d_eval + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft1d_results_1, + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {2}), + NULL), + // rdft1d_eval_signal_size_0 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft1d_results_1, + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {2}), + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {10})), + // rdft1d_eval_signal_size_0_1 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft1d_results_1, + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {2}), + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {-1})), + // rdft1d_eval_1 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft1d_results_1, + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {-1}), + NULL), + // rdft1d_eval_signal_size_1 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 3, 2}, + ET, + ET, + input_data, + expected_rdft1d_results_2, + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {2}), + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {5})), + // rdft1d_eval_signal_size_1_1 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 3, 2}, + ET, + ET, + input_data, + expected_rdft1d_results_2, + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {-1}), + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {5})), + // rdft1d_eval_signal_size_2 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 7, 2}, + ET, + ET, + input_data, + expected_rdft1d_results_3, + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {2}), + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {12})), + // rdft1d_eval_signal_size_2_1 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 7, 2}, + ET, + ET, + input_data, + expected_rdft1d_results_3, + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {-1}), + op::v0::Constant::create(element::Type_t::i64, Shape{1}, {12})), + // rdft2d_eval_1 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft2d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {1, 2}), + NULL), + // rdft2d_eval_1_positive_negative_axes + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft2d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {1, -1}), + NULL), + // rdft2d_eval_1_negative_positive_axes + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft2d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {-2, 2}), + NULL), + // rdft2d_eval_1_negative_negative_axes + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft2d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {-2, -1}), + NULL), + // rdft2d_eval_1_signal_size_0_s10_10 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft2d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {1, 2}), + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {10, 10})), + // rdft2d_eval_1_signal_size_0_s10_10_positive_negative_axes + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft2d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {1, -1}), + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {10, 10})), + // rdft2d_eval_1_signal_size_0_s10_10_negative_positive_axes + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft2d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {-2, 2}), + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {10, 10})), + // rdft2d_eval_1_signal_size_0_s10_10_negative_negative_axes + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft2d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {-2, -1}), + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {10, 10})), + // rdft2d_eval_1_signal_size_0_s10_m1 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft2d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {1, 2}), + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {10, -1})), + // rdft2d_eval_1_signal_size_0_sm1_10 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft2d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {1, 2}), + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {-1, 10})), + // rdft2d_eval_1_signal_size_0_sm1_m1 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft2d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {1, 2}), + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {-1, -1})), + // rdft2d_eval_2_signal_size + RDFTParams(Shape{2, 10, 10}, + Shape{2, 5, 7, 2}, + ET, + ET, + input_data, + expected_rdft2d_results_2, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {1, 2}), + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {5, 12})), + // rdft2d_eval_2_signal_size_positive_negative_axes + RDFTParams(Shape{2, 10, 10}, + Shape{2, 5, 7, 2}, + ET, + ET, + input_data, + expected_rdft2d_results_2, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {1, -1}), + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {5, 12})), + // rdft2d_eval_2_signal_size_negative_positive_axes + RDFTParams(Shape{2, 10, 10}, + Shape{2, 5, 7, 2}, + ET, + ET, + input_data, + expected_rdft2d_results_2, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {-2, 2}), + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {5, 12})), + // rdft2d_eval_2_signal_size_negative_negative_axes + RDFTParams(Shape{2, 10, 10}, + Shape{2, 5, 7, 2}, + ET, + ET, + input_data, + expected_rdft2d_results_2, + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {-2, -1}), + op::v0::Constant::create(element::Type_t::i64, Shape{2}, {5, 12})), + // rdft3d_eval_1 + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft3d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{3}, {0, 1, 2}), + NULL), + // rdft3d_eval_1_negative_axes_and_signal_size + RDFTParams(Shape{2, 10, 10}, + Shape{2, 10, 6, 2}, + ET, + ET, + input_data, + expected_rdft3d_results, + op::v0::Constant::create(element::Type_t::i64, Shape{3}, {-3, 1, 2}), + op::v0::Constant::create(element::Type_t::i64, Shape{3}, {-1, 10, -1})), + // rdft3d_eval_2 + RDFTParams(Shape{2, 10, 10}, + Shape{4, 5, 7, 2}, + ET, + ET, + input_data, + expected_rdft3d_results_2, + op::v0::Constant::create(element::Type_t::i64, Shape{3}, {0, 1, 2}), + op::v0::Constant::create(element::Type_t::i64, Shape{3}, {4, 5, 12})), + // rdft3d_eval_2_negative_axes + RDFTParams(Shape{2, 10, 10}, + Shape{4, 5, 7, 2}, + ET, + ET, + input_data, + expected_rdft3d_results_2, + op::v0::Constant::create(element::Type_t::i64, Shape{3}, {-3, -2, 2}), + op::v0::Constant::create(element::Type_t::i64, Shape{3}, {4, 5, 12})), + }; + + return params; +} + +std::vector generateCombinedParamsForRDFT() { + const std::vector> allTypeParams{ + generateParamsForRDFT() + }; + + std::vector combinedParams; + + for (const auto& params : allTypeParams) { + combinedParams.insert(combinedParams.end(), params.begin(), params.end()); + } + + return combinedParams; +} + +INSTANTIATE_TEST_SUITE_P( + smoke_RDFT_With_Hardcoded_Refs, + ReferenceRDFTLayerTest, + ::testing::ValuesIn(generateCombinedParamsForRDFT()), + ReferenceRDFTLayerTest::getTestCaseName); +} // namespace diff --git a/src/core/reference/include/ngraph/runtime/reference/rdft.hpp b/src/core/reference/include/ngraph/runtime/reference/rdft.hpp new file mode 100644 index 00000000000..e163cd48fd3 --- /dev/null +++ b/src/core/reference/include/ngraph/runtime/reference/rdft.hpp @@ -0,0 +1,38 @@ +//***************************************************************************** +// Copyright 2017-2022 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#pragma once + +#include +#include +#include + +#include "ngraph/node.hpp" +#include "ngraph/op/util/op_types.hpp" +#include "ngraph/ops.hpp" +#include "ngraph/shape_util.hpp" + +namespace ngraph { +namespace runtime { +namespace reference { +void rdft(const std::vector& input_data, + const Shape& input_data_shape, + const std::vector& axes_data, + const Shape& output_fft_shape, + float* rdft_result); +} // namespace reference +} // namespace runtime +} // namespace ngraph diff --git a/src/core/reference/include/ngraph/runtime/reference/utils/fft_common.hpp b/src/core/reference/include/ngraph/runtime/reference/utils/fft_common.hpp new file mode 100644 index 00000000000..d0d09a2f509 --- /dev/null +++ b/src/core/reference/include/ngraph/runtime/reference/utils/fft_common.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ngraph/shape.hpp" +#include "ngraph/type/element_type.hpp" + +namespace ngraph { +namespace runtime { +namespace reference { +namespace fft_common { +// To simplify calculation of strides for all axes of 'shape' of some complex +// tensor, we reverse numbers in 'shape'. Because we have no native support for +// complex numbers in tensors, we interpret float input tensors of the shape +// [N_0, ..., N_{r - 1}, 2] as a complex tensor with the shape +// [N_0, ..., N_{r - 1}]. Hence, we convert 'shape=[N_0, ..., N_{r - 1}, 2]' +// into [N_{r - 1}, ..., N_0]. +// At this time, complex tensors are supported only for FFT-like operations, as +// DFT, IDFT, RDFT +std::vector reverse_shape_of_emulated_complex_tensor(const ngraph::Shape& shape); + +// Calculates strides for all axes. +std::vector compute_strides(const std::vector& v); + +// Calculating coordinates c_0, ..., c_{k - 1} from the index of the form +// c_0 * strides[0] + ... c_{k - 1} * strides[k - 1] +// where k is the number of strides. +std::vector coords_from_index(int64_t index, const std::vector& strides); + +// Calculates offset of value using corresponding coordinates and strides. +int64_t offset_from_coords_and_strides(const std::vector& coords, const std::vector& strides); +} // namespace fft_common +} // namespace reference +} // namespace runtime +} // namespace ngraph diff --git a/src/core/reference/src/runtime/reference/fft.cpp b/src/core/reference/src/runtime/reference/fft.cpp index fea15056d56..5e706e4c026 100644 --- a/src/core/reference/src/runtime/reference/fft.cpp +++ b/src/core/reference/src/runtime/reference/fft.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -57,34 +58,6 @@ std::vector canonicalize_axes(const int64_t* axes_data, namespace { using complex_type = std::complex; -// Calculates strides for all axes. -std::vector compute_strides(const std::vector& v) { - std::vector strides(v.size() + 1); - int64_t stride = 1; - for (size_t i = 0; i < v.size(); ++i) { - strides[i] = stride; - stride *= v[i]; - } - strides.back() = stride; - return strides; -} - -// To simplify calculation of strides for all axes of 'shape' of some complex -// tensor, we reverse numbers in 'shape'. Because we have no native support for -// complex numbers in tensors, we interpret FFT input tensors of the shape -// [N_0, ..., N_{r - 1}, 2] as a complex tensor with the shape -// [N_0, ..., N_{r - 1}]. Hence, we convert 'shape=[N_0, ..., N_{r - 1}, 2]' -// into [N_{r - 1}, ..., N_0]. -std::vector reverse_shape(const Shape& shape) { - size_t complex_data_rank = shape.size() - 1; - - std::vector reversed_shape(complex_data_rank); - for (size_t i = 0; i < complex_data_rank; ++i) { - reversed_shape[i] = static_cast(shape[complex_data_rank - i - 1]); - } - return reversed_shape; -} - // This function gets FFT axes from axes_data std::vector get_axes(const int64_t* axes_data, const Shape& axes_data_shape, int64_t complex_data_rank) { auto axes = canonicalize_axes(axes_data, axes_data_shape, complex_data_rank); @@ -148,24 +121,6 @@ int64_t compute_buffer_size(const std::vector& fft_lengths) { return buffer_size; } -// Calculating coordinates c_0, ..., c_{k - 1} from the index of the form -// c_0 * strides[0] + ... c_{k - 1} * strides[k - 1] -// where k is the number of strides. -std::vector coords_from_index(int64_t index, const std::vector& strides) { - int64_t num_of_axes = static_cast(strides.size()) - 1; - if (num_of_axes == 0) { - return std::vector{}; - } - std::vector coords(num_of_axes); - int64_t curr = index; - for (int64_t j = num_of_axes - 1; j >= 1; --j) { - coords[j] = curr / strides[j]; - curr %= strides[j]; - } - coords[0] = curr; - return coords; -} - // This function gets a complex value from given coords of this value complex_type get_value_from_input(const complex_type* input_data, int64_t src_index, @@ -194,7 +149,7 @@ void copy_data_from_input(complex_type* result, const std::vector& input_fft_lengths, const std::vector& input_fft_strides) { for (int64_t idx = 0; idx < fft_size; ++idx) { - auto coords = coords_from_index(idx, fft_strides); + auto coords = fft_common::coords_from_index(idx, fft_strides); complex_type value = get_value_from_input(input_data, src_index, coords, input_fft_lengths, input_fft_strides); result[idx] = value; } @@ -210,16 +165,6 @@ bool blob_is_zero(const complex_type* data, int64_t blob_size) { return true; } -// Calculates offset of value using corresponding coordinates and strides. -int64_t offset_from_coords_and_strides(const std::vector& coords, const std::vector& strides) { - int64_t offset = 0; - int64_t num_of_axes = coords.size(); - for (int64_t i = 0; i < num_of_axes; ++i) { - offset += coords[i] * strides[i]; - } - return offset; -} - // Copying calculated data to the given memory domain. void copy_data_to_output(complex_type* output, const complex_type* data, @@ -228,9 +173,9 @@ void copy_data_to_output(complex_type* output, const std::vector& fft_strides, const std::vector& output_fft_strides) { for (int64_t idx = 0; idx < fft_size; ++idx) { - auto coords = coords_from_index(idx, fft_strides); + auto coords = fft_common::coords_from_index(idx, fft_strides); complex_type value = data[idx]; - int64_t offset = offset_from_coords_and_strides(coords, output_fft_strides); + int64_t offset = fft_common::offset_from_coords_and_strides(coords, output_fft_strides); output[dst_index + offset] = value; } @@ -372,29 +317,29 @@ InfoForFFTCalculation get_info_for_calculation(const Shape& input_data_shape, const int64_t complex_data_rank = static_cast(input_data_shape.size() - 1); - const auto reversed_output_shape = reverse_shape(output_shape); + const auto reversed_output_shape = fft_common::reverse_shape_of_emulated_complex_tensor(output_shape); auto fft_axes = get_axes(axes_data, axes_data_shape, complex_data_rank); reverse_fft_axes(fft_axes, complex_data_rank); const int64_t fft_rank = fft_axes.size(); const auto fft_lengths = get_lengths(reversed_output_shape, fft_axes); - const auto fft_strides = compute_strides(fft_lengths); + const auto fft_strides = fft_common::compute_strides(fft_lengths); const int64_t fft_size = fft_strides[fft_rank]; const auto outer_axes = get_outer_axes(fft_axes, complex_data_rank); const int64_t outer_rank = outer_axes.size(); const auto outer_lengths = get_lengths(reversed_output_shape, outer_axes); - const auto outer_strides = compute_strides(outer_lengths); + const auto outer_strides = fft_common::compute_strides(outer_lengths); const int64_t outer_size = outer_strides[outer_rank]; const int64_t buffer_size = compute_buffer_size(fft_lengths); - const auto output_strides = compute_strides(reversed_output_shape); + const auto output_strides = fft_common::compute_strides(reversed_output_shape); const auto output_fft_strides = get_lengths(output_strides, fft_axes); const auto output_outer_strides = get_lengths(output_strides, outer_axes); - const auto reversed_input_shape = reverse_shape(input_data_shape); + const auto reversed_input_shape = fft_common::reverse_shape_of_emulated_complex_tensor(input_data_shape); const auto input_fft_lengths = get_lengths(reversed_input_shape, fft_axes); - const auto input_strides = compute_strides(reversed_input_shape); + const auto input_strides = fft_common::compute_strides(reversed_input_shape); const auto input_fft_strides = get_lengths(input_strides, fft_axes); const auto input_outer_strides = get_lengths(input_strides, outer_axes); @@ -461,8 +406,8 @@ void fft(const float* input_data, // Loop along with 'outer' dimensions, that is along with // not transformed dimensions. for (int64_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) { - const auto outer_coords = coords_from_index(outer_idx, outer_strides); - int64_t outer_input_offset = offset_from_coords_and_strides(outer_coords, input_outer_strides); + const auto outer_coords = fft_common::coords_from_index(outer_idx, outer_strides); + int64_t outer_input_offset = fft_common::offset_from_coords_and_strides(outer_coords, input_outer_strides); // Copying current data to transform copy_data_from_input(data.data(), @@ -482,14 +427,14 @@ void fft(const float* input_data, auto outer_fft_axes = lengths_except_given_axis(fft_axes, axis_idx); int64_t outer_fft_size = fft_size / current_fft_length; - auto outer_fft_strides = compute_strides(outer_fft_lengths); + auto outer_fft_strides = fft_common::compute_strides(outer_fft_lengths); auto fft_strides_for_outer_fft_axes = lengths_except_given_axis(fft_strides, axis_idx); // Loop along with all FFT axes, except the current one. for (int64_t outer_fft_idx = 0; outer_fft_idx < outer_fft_size; ++outer_fft_idx) { - const auto outer_fft_coords = coords_from_index(outer_fft_idx, outer_fft_strides); + const auto outer_fft_coords = fft_common::coords_from_index(outer_fft_idx, outer_fft_strides); int64_t outer_fft_offset = - offset_from_coords_and_strides(outer_fft_coords, fft_strides_for_outer_fft_axes); + fft_common::offset_from_coords_and_strides(outer_fft_coords, fft_strides_for_outer_fft_axes); // Calculation of 1D FFT fft1d(current_fft_length, outer_fft_offset, @@ -502,7 +447,7 @@ void fft(const float* input_data, } // Copying current calculated data to the output blob. - int64_t outer_output_offset = offset_from_coords_and_strides(outer_coords, output_outer_strides); + int64_t outer_output_offset = fft_common::offset_from_coords_and_strides(outer_coords, output_outer_strides); copy_data_to_output(complex_output_ptr, data.data(), outer_output_offset, diff --git a/src/core/reference/src/runtime/reference/rdft.cpp b/src/core/reference/src/runtime/reference/rdft.cpp new file mode 100644 index 00000000000..2805b83ccca --- /dev/null +++ b/src/core/reference/src/runtime/reference/rdft.cpp @@ -0,0 +1,89 @@ +//***************************************************************************** +// Copyright 2017-2022 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include "ngraph/runtime/reference/rdft.hpp" + +#include +#include +#include + +#include "ngraph/runtime/reference/fft.hpp" +#include "ngraph/shape.hpp" + +using namespace ngraph; +using namespace ngraph::runtime::reference; + +namespace ngraph { +namespace runtime { +namespace reference { +namespace { +using complex_type = std::complex; + +// This function clips transformed axes and writes the result into output +void clip_and_write_result(const std::vector& axes_data, + const std::vector& fft_result, + const Shape& output_fft_shape, + float* rdft_result) { + auto rdft_result_shape = output_fft_shape; + const auto last_axis = axes_data.back(); + rdft_result_shape[last_axis] = rdft_result_shape[last_axis] / 2 + 1; + + const auto reversed_rdft_result_shape = fft_common::reverse_shape_of_emulated_complex_tensor(rdft_result_shape); + const auto rdft_output_strides = fft_common::compute_strides(reversed_rdft_result_shape); + const auto reversed_output_fft_shape = fft_common::reverse_shape_of_emulated_complex_tensor(output_fft_shape); + const auto output_fft_strides = fft_common::compute_strides(reversed_output_fft_shape); + const auto rdft_output_size = rdft_output_strides.back(); + + complex_type* complex_output_ptr = reinterpret_cast(rdft_result); + const complex_type* complex_input_ptr = reinterpret_cast(fft_result.data()); + for (int64_t i = 0; i < rdft_output_size; ++i) { + const auto coords = fft_common::coords_from_index(i, rdft_output_strides); + const int64_t input_offset = fft_common::offset_from_coords_and_strides(coords, output_fft_strides); + complex_output_ptr[i] = complex_input_ptr[input_offset]; + } +} +} // namespace + +void rdft(const std::vector& input_data, + const Shape& input_data_shape, + const std::vector& axes_data, + const Shape& output_fft_shape, + float* rdft_result) { + // Converting input data to complex type and calculation of DFT with such data. + size_t input_data_size = input_data.size(); + std::vector complex_data(input_data_size); + for (size_t i = 0; i < input_data_size; ++i) { + complex_data[i] = complex_type{input_data[i], 0.0f}; + } + + auto input_shape_for_fft = input_data_shape; + input_shape_for_fft.push_back(2); + + std::vector fft_result(shape_size(output_fft_shape), 0.0f); + + fft(reinterpret_cast(complex_data.data()), + input_shape_for_fft, + axes_data.data(), + Shape{axes_data.size()}, + fft_result.data(), + output_fft_shape, + FFTKind::Forward); + + clip_and_write_result(axes_data, fft_result, output_fft_shape, rdft_result); +} +} // namespace reference +} // namespace runtime +} // namespace ngraph \ No newline at end of file diff --git a/src/core/reference/src/runtime/reference/utils/fft_common.cpp b/src/core/reference/src/runtime/reference/utils/fft_common.cpp new file mode 100644 index 00000000000..c6be84529fc --- /dev/null +++ b/src/core/reference/src/runtime/reference/utils/fft_common.cpp @@ -0,0 +1,64 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ngraph/runtime/reference/utils/fft_common.hpp" + +#include +#include +#include +#include +#include + +#include "ngraph/check.hpp" + +namespace ngraph { +namespace runtime { +namespace reference { +namespace fft_common { +std::vector reverse_shape_of_emulated_complex_tensor(const ngraph::Shape& shape) { + assert(shape.size() >= 2); + std::vector reversed_shape(shape.begin(), shape.end() - 1); + std::reverse(reversed_shape.begin(), reversed_shape.end()); + return reversed_shape; +} + +std::vector compute_strides(const std::vector& v) { + std::vector strides(v.size() + 1); + int64_t stride = 1; + for (size_t i = 0; i < v.size(); ++i) { + strides[i] = stride; + stride *= v[i]; + } + strides.back() = stride; + return strides; +} + +std::vector coords_from_index(int64_t index, const std::vector& strides) { + int64_t num_of_axes = static_cast(strides.size()) - 1; + if (num_of_axes == 0) { + return std::vector{}; + } + std::vector coords(num_of_axes); + int64_t curr = index; + for (int64_t j = num_of_axes - 1; j >= 1; --j) { + coords[j] = curr / strides[j]; + curr %= strides[j]; + } + coords[0] = curr; + return coords; +} + +int64_t offset_from_coords_and_strides(const std::vector& coords, const std::vector& strides) { + assert(coords.size() < strides.size()); + int64_t offset = 0; + int64_t num_of_axes = coords.size(); + for (int64_t i = 0; i < num_of_axes; ++i) { + offset += coords[i] * strides[i]; + } + return offset; +} +} // namespace fft_common +} // namespace reference +} // namespace runtime +} // namespace ngraph