Files
openvino/docs/template_plugin/tests/functional/op_reference/irdft.cpp

834 lines
54 KiB
C++

// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <cstdint>
#include <iostream>
#include <sstream>
#include "base_reference_test.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/irdft.hpp"
using namespace reference_tests;
using namespace ov;
namespace {
// One IRDFT reference case: the Parameter shape/type, the expected result
// shape/type, the input and expected tensors, and the axes / signal_size
// constants fed to op::v9::IRDFT. `signal` may be null; the fixture checks for
// that before wiring it into the model.
struct IRDFTParams {
    // @param input_shape     shape of the Parameter that feeds IRDFT
    // @param expected_shape  shape of the reference output
    // @param input_type      element type of the input tensor
    // @param expected_type   element type of the reference output tensor
    // @param input_value     flattened input data, wrapped via CreateTensor
    // @param expected_value  flattened reference output, wrapped via CreateTensor
    // @param axes            constant passed as the IRDFT `axes` input
    // @param signal          constant passed as the IRDFT `signal_size` input (nullable)
    template <class T>
    IRDFTParams(const Shape& input_shape,
                const Shape& expected_shape,
                const element::Type_t& input_type,
                const element::Type_t& expected_type,
                const std::vector<T>& input_value,
                const std::vector<T>& expected_value,
                const std::shared_ptr<op::v0::Constant>& axes,
                const std::shared_ptr<op::v0::Constant>& signal)
        : m_input_shape(input_shape),
          m_expected_shape(expected_shape),
          m_input_type(input_type),
          m_expected_type(expected_type),
          m_input_value(CreateTensor(input_type, input_value)),
          m_expected_value(CreateTensor(expected_type, expected_value)),
          m_axes(axes),
          m_signal(signal) {}

    Shape m_input_shape;
    Shape m_expected_shape;
    element::Type_t m_input_type;
    element::Type_t m_expected_type;
    ov::Tensor m_input_value;
    ov::Tensor m_expected_value;
    std::shared_ptr<op::v0::Constant> m_axes;
    std::shared_ptr<op::v0::Constant> m_signal;
};
// Parameterized fixture: builds a single-op model (Parameter -> IRDFT) from an
// IRDFTParams case and registers its input/expected tensors with
// CommonReferenceTest.
class ReferenceIRDFTLayerTest : public testing::TestWithParam<IRDFTParams>, public CommonReferenceTest {
public:
    void SetUp() override {
        // GetParam() returns a reference to the stored case; bind to it instead
        // of copying the whole parameter struct.
        const auto& params = GetParam();
        function = CreateFunction(params);
        inputData = {params.m_input_value};
        refOutData = {params.m_expected_value};
    }

    // Produces a deterministic, human-readable name for each parameterized case.
    // Previously the axes shared_ptr itself was streamed (printing a heap address
    // under a copy-pasted "transpose1" label), making names differ run to run.
    // The case index is appended so names stay unique even when two cases share
    // shapes, types, axes and signal sizes but differ in data.
    static std::string getTestCaseName(const testing::TestParamInfo<IRDFTParams>& obj) {
        const auto& param = obj.param;
        std::ostringstream result;
        result << "input_shape1=" << param.m_input_shape << "; ";
        result << "output_shape=" << param.m_expected_shape << "; ";
        result << "input_type1=" << param.m_input_type << "; ";
        result << "output_type=" << param.m_expected_type << "; ";
        result << "axes=" << ConstantToString(param.m_axes) << "; ";
        result << "signal=" << ConstantToString(param.m_signal) << "; ";
        result << "index=" << obj.index;
        return result.str();
    }

private:
    // Renders the values of an integer constant as "[a,b,...]", or "none" when
    // the constant is absent (e.g. a case without signal_size).
    static std::string ConstantToString(const std::shared_ptr<op::v0::Constant>& constant) {
        if (!constant) {
            return "none";
        }
        std::ostringstream out;
        out << "[";
        const auto values = constant->cast_vector<int64_t>();
        for (size_t i = 0; i < values.size(); ++i) {
            if (i != 0) {
                out << ",";
            }
            out << values[i];
        }
        out << "]";
        return out.str();
    }

    // Builds Parameter -> IRDFT(axes[, signal_size]) -> result. The signal_size
    // input is only passed when the case provides one.
    static std::shared_ptr<Model> CreateFunction(const IRDFTParams& p) {
        auto in = std::make_shared<op::v0::Parameter>(p.m_input_type, p.m_input_shape);
        auto irdft = p.m_signal ? std::make_shared<op::v9::IRDFT>(in, p.m_axes, p.m_signal)
                                : std::make_shared<op::v9::IRDFT>(in, p.m_axes);
        return std::make_shared<ov::Model>(irdft, ParameterVector{in});
    }
};
// Single parameterized test body. SetUp() has already built the IRDFT model and
// filled inputData / refOutData from the current IRDFTParams; Exec() (inherited
// from CommonReferenceTest, defined outside this file) presumably runs the model
// and compares its output against refOutData.
TEST_P(ReferenceIRDFTLayerTest, CompareWithHardcodedRefs) {
Exec();
}
// Flattened IRDFT input values for test cases defined further down the file.
// NOTE(review): the data looks like interleaved (real, imag) pairs in groups of
// 6 complex bins, with near-zero imaginary parts on the first and last bin of each
// group (DC/Nyquist terms of a Hermitian spectrum) — TODO confirm against the
// input_shape of the cases that use this vector.
static const std::vector<float> input_data_1 = {
4.6657147, -1.1622906e-06, 0.21456887, -0.14946258, -0.20476034, -0.37063062,
-0.31414136, 0.5099413, -1.1779613, 0.07057127, -0.64047664, -1.0058284e-07,
4.982774, -1.1771917e-06, 0.6607505, 0.18829148, -0.9772357, 1.4243596,
0.8640026, 0.34923682, 0.33401352, 0.25859502, -0.7548928, 8.940697e-08,
5.9711604, -1.4901161e-06, 0.5638976, 1.5429841, -0.52065414, 0.24638398,
-0.27140495, 0.5040715, 0.5360231, 0.3234269, -0.36054826, 1.7508864e-07,
4.7464237, -1.2218952e-06, -0.29650804, 0.80609477, -0.161426, 1.0022418,
-0.50812817, 0.7967348, 0.4394225, -0.1588624, -1.3835809, -7.4505806e-08,
5.53836, -1.7136335e-06, -0.38635445, 0.8284859, -0.23278837, -0.63777345,
-0.93614054, 0.3215857, -0.14075133, -0.67071164, -1.4772836, 2.0861626e-07,
5.0798974, -1.5944242e-06, 0.056767445, 0.03468219, -0.1497254, -0.9672509,
0.2603209, 0.69644475, -0.9208536, 0.006730467, -1.7552528, 2.682209e-07,
4.893558, -1.6242266e-06, 0.6719861, -0.13982919, 0.064845346, -0.39896214,
0.21785057, -0.5099982, -0.65526295, 1.4383471, -0.52023906, 2.5331974e-07,
6.687699, -1.5497208e-06, -0.7423769, 0.09968524, 1.052381, -0.21306956,
0.5875206, -0.3038844, 0.3991575, -1.1895186, 0.17579001, 3.874302e-07,
5.2818384, -1.1026859e-06, 0.5087582, 0.106959194, 1.1816688, -0.87592727,
0.03740315, 0.5197907, -1.3198637, 0.6398836, 0.22712436, 2.2351742e-08,
5.0190897, -1.5646219e-06, -0.087282926, 0.50819266, -0.28002462, 0.29240948,
-0.32303664, 0.38377762, -0.0051696897, -0.99301195, -2.189299, 2.0861626e-07,
5.0545654, -1.5795231e-06, 0.9146397, 0.83839166, 0.870533, 0.17405808,
-0.56308234, -0.7806684, 0.26397777, 0.6880482, -1.4183462, 2.682209e-07,
5.479953, -1.2665987e-06, 0.49444157, 0.7534672, -0.76784146, -0.4507342,
0.88815784, 0.6985409, -0.2727425, -0.25027415, -0.7328796, 2.682209e-07,
4.1296124, -5.662441e-07, -0.46133032, 0.30635798, -0.18225375, 0.42515472,
-0.5484285, 0.9704039, -0.35255045, 0.17549685, 0.8870368, -3.1292439e-07,
4.8632016, -1.8924475e-06, -0.6926452, 0.025076404, -0.039108217, -1.7492937,
-0.8120377, -0.85315156, -0.0022608787, 0.45002514, -1.1024668, 3.501773e-07,
5.4715447, -1.4901161e-06, 1.1176248, -0.2109062, -0.27492502, 0.08983741,
1.1903813, -1.007312, -0.20150042, -0.83919466, -0.23939973, 4.917383e-07,
5.1267176, -9.983778e-07, -0.44803134, -0.8066604, -0.3435102, -0.41692197,
-0.22457689, -0.1076939, -0.29129186, -1.1880502, 0.9255183, -1.6391277e-07,
3.8495903, -5.5134296e-07, 0.09505272, -0.12751618, -1.1264827, 0.5068884,
-1.055237, -0.19516481, -0.34035242, -0.15379356, 1.2655814, -2.6077032e-07,
4.4372616, -9.23872e-07, -0.72962606, -0.23475963, -0.04278487, 1.1032158,
-0.558924, -0.5300043, 1.0578637, -0.2466627, 0.44617313, -7.8231096e-08,
5.5374002, -1.4156103e-06, 0.016273111, -0.5989829, -0.19913958, 0.013256833,
1.8512837, 0.14526272, -0.39700353, -0.07573915, 0.23181, 2.9429793e-07,
4.989425, -1.4901161e-06, 1.0391837, 0.16554561, -0.22647032, -1.0689808,
-0.84556, -0.82779336, 0.9430445, 0.37618563, 0.4684292, -9.685755e-08};
// Flattened expected real-valued IRDFT output. The name suggests a 1D case —
// TODO confirm which input vector and axes it is paired with in the instantiation.
static const std::vector<float> expected_irdft1d_results_1 = {
0.10606491, 0.7454715, 0.57231355, 0.4582412, 0.3847059, 0.27398932, 0.66796243, 0.395475,
0.2815729, 0.7799197, 0.59909415, 0.12294636, 0.38957402, 0.97498834, 0.46759892, 0.14017141,
0.04206858, 0.7279963, 0.61560553, 0.9027321, 0.6226334, 0.2601217, 0.5555177, 0.40498647,
0.14175586, 0.57774633, 0.52652127, 0.9385691, 0.9588788, 0.9844318, 0.23095612, 0.09707925,
0.24574867, 0.6907577, 0.1974319, 0.8295272, 0.34612727, 0.51401484, 0.66115797, 0.9336245,
0.06690067, 0.7468897, 0.39028263, 0.53575844, 0.060429193, 0.8913558, 0.77787375, 0.6701197,
0.7350527, 0.6636995, 0.18176624, 0.8629976, 0.45142895, 0.6497297, 0.159372, 0.40598175,
0.7988516, 0.7291543, 0.07090418, 0.7697132, 0.4972157, 0.7669217, 0.67975855, 0.13026066,
0.6587437, 0.24532892, 0.24545169, 0.83795583, 0.105490535, 0.7264323, 0.94568557, 0.7216649,
0.14389831, 0.7930531, 0.70895344, 0.9724701, 0.9775157, 0.49999878, 0.65569246, 0.26876843,
0.63248956, 0.85201293, 0.5689624, 0.023386303, 0.5546464, 0.36860028, 0.9603114, 0.39123482,
0.0380728, 0.89212376, 0.14387614, 0.63858676, 0.10003748, 0.8906635, 0.06681054, 0.7458642,
0.45452347, 0.54724604, 0.6496482, 0.7818356, 0.6608355, 0.77711326, 0.24588613, 0.013456763,
0.355845, 0.80388206, 0.027993264, 0.73677206, 0.52755004, 0.9052324, 0.54311025, 0.5367805,
0.4131242, 0.7752338, 0.109669454, 0.13664648, 0.7828739, 0.9083969, 0.5247593, 0.7493595,
0.19275227, 0.007190853, 0.6087981, 0.344136, 0.46909887, 0.41924855, 0.7072913, 0.19932869,
0.5303847, 0.651384, 0.06686331, 0.9717932, 0.65702224, 0.11786682, 0.3154073, 0.88923013,
0.5564087, 0.91047823, 0.28466642, 0.0934668, 0.88953066, 0.9919338, 0.18322521, 0.8185455,
0.566391, 0.014207997, 0.29673064, 0.6347744, 0.6801958, 0.39601147, 0.34374171, 0.7216888,
0.6152569, 0.76679546, 0.5860851, 0.4276813, 0.79339284, 0.13130653, 0.68764234, 0.053128112,
0.02611321, 0.2982243, 0.7618372, 0.3331729, 0.5468192, 0.15707079, 0.28592056, 0.15286565,
0.9368963, 0.350671, 0.4336494, 0.08934934, 0.41172776, 0.5850259, 0.70730376, 0.8598349,
0.088788144, 0.26711187, 0.8002491, 0.19422275, 0.8312039, 0.5198718, 0.40111357, 0.98375803,
0.77703434, 0.037818834, 0.704231, 0.689808, 0.17102319, 0.42153922, 0.7278252, 0.8030207,
0.9101717, 0.0199644, 0.13768466, 0.55669, 0.17991355, 0.6720098, 0.7733328, 0.20881335};
// Flattened IRDFT input values for test cases defined further down the file.
// NOTE(review): appears to be interleaved (real, imag) pairs in groups of 3 complex
// bins (near-zero imaginary part on the first bin of each group) — TODO confirm
// against the consuming case's input_shape.
static const std::vector<float> input_data_2 = {
2.266797, -8.195639e-08, -0.37842733, -0.41015846, -0.48980892, -0.10356337,
2.5542018, -2.2351742e-08, -0.3223713, 0.671882, 0.54300576, -0.35418037,
1.985015, -2.2351742e-08, -0.030243821, -0.20105253, 0.59431964, 0.07358998,
1.4619737, -7.450581e-09, -0.4356845, 0.35701087, 0.28208786, -0.36424285,
1.8002605, -1.1920929e-07, -0.43280697, -0.56735414, -0.30007166, -0.541847,
2.3052943, -1.2293458e-07, -0.39316025, -0.5526293, -0.30507135, -0.6021758,
2.7329001, -6.7055225e-08, 0.28245124, -0.42586988, -0.40586215, 0.4590181,
3.3132548, -5.9604645e-08, 0.6297612, 0.3694744, 0.077824846, -0.6248544,
2.6314974, -2.9802322e-08, 0.58795106, -0.60349375, -0.3224758, 0.34408605,
1.8399743, -9.685755e-08, -0.43963802, -0.079073176, -0.120658875, -1.0880115,
2.0531366, -4.4703484e-08, 0.80112594, -0.53726834, -0.17560546, -0.026561722,
2.3779182, -9.685755e-08, -0.21852754, -0.19336401, 0.38734403, -0.5954362,
1.6219761, 7.450581e-09, -0.43100592, 0.28373614, 0.101898566, 0.52321124,
2.128953, -1.4901161e-07, -0.1622684, -0.94116735, -0.7350497, 0.12695336,
3.449626, -8.940697e-08, 0.56062996, -0.031283244, -0.06161648, -0.8543532,
3.033568, -8.195639e-08, -0.37023768, -0.03989461, -0.28719214, -0.22382751,
1.9661667, -1.4901161e-08, -0.59863573, -0.015534669, -0.31916466, 0.55380434,
2.227056, -5.2154064e-08, -0.12656188, 0.6895717, 0.097157195, 0.19840825,
3.5129817, -2.1234155e-07, 0.11158541, 0.5870459, 0.20993343, -0.40297145,
2.5986667, 0.0, 0.26602313, -1.1560227, 0.2542065, 0.45556274};
// Flattened expected real-valued IRDFT output. The name suggests a 1D case —
// TODO confirm which input vector, axes and signal_size it is paired with.
static const std::vector<float> expected_irdft1d_results_2 = {
0.10606494, 0.7454715, 0.5723136, 0.45824113, 0.38470596, 0.59909415, 0.12294642,
0.38957405, 0.9749881, 0.46759906, 0.62263334, 0.26012173, 0.5555176, 0.40498644,
0.14175594, 0.23095612, 0.097079255, 0.24574867, 0.6907576, 0.197432, 0.066900685,
0.7468896, 0.39028254, 0.5357583, 0.0604293, 0.18176621, 0.8629975, 0.45142898,
0.64972955, 0.15937212, 0.49721566, 0.7669216, 0.6797584, 0.13026062, 0.6587438,
0.9456854, 0.7216646, 0.14389832, 0.7930529, 0.7089534, 0.6324895, 0.85201263,
0.5689623, 0.023386315, 0.55464643, 0.1438762, 0.63858664, 0.10003753, 0.8906633,
0.06681056, 0.66083544, 0.7771131, 0.24588616, 0.013456774, 0.35584506, 0.54311025,
0.53678054, 0.41312417, 0.7752337, 0.10966951, 0.19275223, 0.007190934, 0.608798,
0.344136, 0.46909887, 0.06686333, 0.971793, 0.65702224, 0.117866985, 0.31540743,
0.8895306, 0.99193364, 0.18322523, 0.81854534, 0.5663911, 0.34374166, 0.72168875,
0.6152569, 0.7667953, 0.58608514, 0.026113158, 0.2982243, 0.76183707, 0.3331729,
0.5468192, 0.43364936, 0.089349344, 0.41172776, 0.5850257, 0.7073037, 0.8312039,
0.5198719, 0.4011136, 0.9837578, 0.7770344, 0.72782516, 0.8030205, 0.9101716,
0.019964492, 0.13768478};
// Flattened IRDFT input values for test cases defined further down the file.
// NOTE(review): appears to be interleaved (real, imag) pairs — TODO confirm the
// grouping against the consuming case's input_shape and axes.
static const std::vector<float> input_data_3 = {
4.665715, -1.6093254e-06, -0.5430559, -0.5752678, -0.37596112, -1.1571281,
-0.46793216, -0.94566363, 0.6854232, -0.3444838, -0.674704, 0.5946392,
-0.64047587, 1.3560057e-06, 4.9827743, -1.7434359e-06, -0.43517, -0.049020194,
-1.4773891, -1.0811031, 1.2506557, 0.5371344, 1.2869358, -0.14998645,
0.8555907, 0.3693859, -0.7548918, 1.5944242e-06, 5.971161, -1.5199184e-06,
-1.2643411, 0.85635287, -0.1801207, -1.7264944, 0.6412285, -0.4787441,
0.82227707, 0.65098876, 0.9114491, 0.40323836, -0.36054718, 1.2852252e-06,
4.7464237, -1.66893e-06, -1.5010594, 0.2253451, -0.87915635, -0.4252541,
0.4976693, -0.6554581, 0.928985, 0.8035921, 0.6578763, -0.15220329,
-1.3835799, 1.0430813e-06, 5.5383606, -1.4901161e-06, -1.619024, -0.10987502,
0.20661727, -1.3774645, -0.3057741, -1.0960662, 0.2971667, 0.46700704,
-0.20812088, -0.602368, -1.4772825, 9.3877316e-07, 5.0798974, -1.758337e-06,
-0.7421876, -0.61749315, 0.21938956, -1.3415859, -0.838238, -0.6598083,
1.0601404, -0.7129184, -0.27083004, 0.31763482, -1.7552516, 1.4677644e-06,
4.893558, -1.4975667e-06, -0.06445231, -0.55879503, 0.08908144, -1.2869594,
0.33623943, -0.7704663, -0.047739983, -1.0678453, 0.48350462, 1.5768427,
-0.52023804, 1.1697412e-06, 6.687699, -1.3113022e-06, -1.292419, -1.2920969,
1.2041754, -0.2943018, 1.1889167, -0.66985166, 1.1336832, -0.13731277,
0.008011267, -0.9506076, 0.1757915, 1.1026859e-06, 5.2818394, -1.4305115e-06,
-0.25987166, -0.48605326, 0.90237427, -0.8028362, -0.3040653, -1.6981151,
1.1215456, -0.7120959, -0.4195284, 1.3941492, 0.22712523, 8.046627e-07,
5.01909, -1.7881393e-06, -1.1856917, -0.10931289, -0.5164983, -0.9724103,
0.30577338, -0.72837675, 0.89680094, 0.21036407, -0.052024096, -0.9455472,
-2.1892984, 1.4305115e-06, 5.054565, -1.5050173e-06, -0.3471575, 0.40542153,
0.36438322, -0.9765247, 1.2703501, -1.7359983, -0.1160066, -0.25323528,
0.9753329, 0.5339062, -1.418345, 9.834766e-07, 5.4799523, -1.7285347e-06,
-0.7905842, 0.093313254, 0.068526804, -1.8504739, -0.01845923, 0.26084417,
1.5358877, -0.4159652, 0.089752786, 0.089908056, -0.7328786, 1.4007092e-06,
4.129612, -9.536743e-07, -1.2393575, -0.28046644, -0.58673245, -0.39608067,
-0.12385368, -0.53435826, 0.77853805, 0.7645384, -0.18040559, 0.6678516,
0.88703763, 8.046627e-07, 4.8632016, -1.0430813e-06, -1.1780663, -1.0952923,
1.1691413, -1.4023741, -0.546494, -0.92614484, -1.1796933, -0.31762218,
0.25592417, 0.0959474, -1.1024656, 1.013279e-06, 5.471545, -1.6987324e-06,
0.35812324, -0.66833705, 0.07725692, -1.6537004, 1.6561611, 0.051166296,
0.865453, -1.1392289, -0.23588535, -0.5480979, -0.2393986, 1.3411045e-06,
5.126718, -9.23872e-07, -0.6379836, -1.6675751, 0.013057679, -0.9891113,
0.20881936, -0.30439606, 0.37222707, 0.25244698, -0.9197892, -0.77782196,
0.9255192, 1.1101365e-06, 3.8495903, -7.4505806e-07, -0.63088936, -0.4556699,
-1.1905057, -1.2522144, 0.46207082, -0.31992733, -0.4309795, 0.74295896,
-0.6106033, 0.18823686, 1.2655822, 7.748604e-07, 4.4372616, -7.0780516e-07,
-1.1016369, -1.0079124, -0.6083025, -0.0011255145, 1.4406854, -0.2912693,
-0.26610214, 0.87299407, 0.69553405, -0.45576566, 0.44617438, 7.4505806e-07,
5.5374007, -1.5944242e-06, -0.32642078, -1.3683549, 0.079301864, -0.83741367,
0.67391664, 0.69433576, 1.6423957, -1.1923066, 0.0334223, 0.37603495,
0.23181117, 1.4156103e-06, 4.9894247, -7.748604e-07, 0.1788401, -0.39274544,
0.78422666, -2.1340246, 0.5487572, -0.8765497, -0.7899384, 0.5434137,
0.91613716, 0.08274247, 0.46843058, 8.34465e-07};
// Flattened expected real-valued IRDFT output. The name suggests a 1D case —
// TODO confirm which input vector, axes and signal_size it is paired with.
static const std::vector<float> expected_irdft1d_results_3 = {
0.80091053, 1.548053, 1.3439665, 0.97278523, 0.65876126, 1.6395509, 1.0939313,
1.5905306, 0.81558955, 1.1096439, 2.0799308, 1.9659967, 0.21628714, 1.2937224,
1.7173465, 1.5190675, 0.62673247, 1.3878733, 2.2457566, 1.2779983, 0.9537279,
1.5238736, 1.6959155, 0.9063804, 1.2134336, 1.4805167, 1.277886, 0.9217217,
1.3267591, 2.0169291, 2.619178, 0.7248324, 1.4161175, 1.3378929, 1.6759893,
0.85183966, 0.53280216, 1.4385536, 1.7184939, 1.3292406, 1.1811583, 0.9698347,
1.5283158, 1.3752562, 0.99182767, 1.3061998, 1.7824118, 1.399513, 0.26604116,
1.3193192, 1.5053986, 1.0388529, 0.9190526, 1.4711612, 2.0971189, 0.37586892,
1.5662622, 1.6827406, 1.208139, 1.0144035, 0.96595216, 2.1122026, 1.6039357,
0.44462752, 0.34932646, 1.487859, 0.9802158, 1.0321891, 0.4064213, 1.7653472,
1.5080582, 0.75743484, 1.2409652, 2.0487022, 1.567386, 0.68034726, 1.5328329,
1.2476723, 1.3539927, 0.8549268};
// Flattened IRDFT input values for test cases defined further down the file.
// NOTE(review): appears to be interleaved (real, imag) pairs — TODO confirm the
// grouping against the consuming case's input_shape and axes.
static const std::vector<float> input_data_4 = {
52.8665, -2.9623508e-05, 1.1642078, 3.826082, -0.22771922, -0.49822173, -0.3857528,
3.2676966, -2.5112464, -0.27454787, -8.678656, 3.7550926e-06, -0.818072, 0.8330209,
3.4618711, -0.2419473, 1.7408192, 5.744002, 1.8477443, 2.039329, 0.3268112,
-2.7421296, 0.6809025, 1.7613728, -2.294264, -0.8984407, -0.2868184, -3.2426705,
-0.801461, -0.58971727, -1.463435, -2.5413132, 0.116907075, -0.5013529, -2.8377397,
-2.8455539, -0.13475686, -1.3145845, -2.2820292, -0.199, -0.056986623, 0.12560216,
-0.589707, -1.7577857, -0.5274223, -1.0395792, 0.53813136, -1.7159984, 0.22503978,
2.902198, -1.8643543, -1.8789856, 2.1722724, -2.068454, 0.59446484, 0.6067899,
1.5525781, 1.7612485, 1.1877432, -0.48152098, -0.16525066, 1.5497208e-06, 1.9815066,
0.55218977, 0.80434155, -3.575598, -2.1471107, -0.57691807, -3.004384, 3.8775828,
3.1358109, -6.2584877e-07, 0.22504184, -2.9021916, 1.0378464, 0.9877456, 0.38395065,
-1.6089694, -0.5107449, 1.8621777, -4.960479, -1.8983803, 1.187743, 0.48151842,
-0.1347583, 1.3145843, -0.9968031, -1.3782079, 0.9922035, 1.6614089, -0.83039653,
-0.043888614, 1.9431384, -1.6448143, 0.5381324, 1.7159982, -2.2942696, 0.8984335,
1.3057998, -0.26607463, -3.2994738, -1.9240448, 1.4963659, 2.8365738, -4.691832,
1.2995429, -2.8377357, 2.8455553, -0.8180722, -0.8330165, -1.3755352, 0.34623986,
-3.7555497, -0.9723124, -1.1528367, -0.593254, -0.023679793, 1.8681414, 0.6809023,
-1.7613728, 48.939255, -2.4735928e-05, 1.3455832, 0.11001387, -2.3319814, -1.3735183,
-0.6780232, -2.4875786, 0.40718403, -1.0639579, 0.7314569, -1.2665987e-07, 0.97006464,
-0.30789328, 3.3290033, 2.7749023, -0.7520597, -0.98800826, 1.3100916, 1.1514524,
1.1085359, 4.348257, -2.839456, 2.4404035, 0.9518837, 2.1538901, 3.8438358,
2.410589, 3.0649068, 0.95690995, 2.2213395, 0.66509914, -0.4409917, -0.37408838,
-0.6316552, -1.5842111, -0.72352415, -2.5862057, 0.2678757, 0.610149, 2.9564474,
0.08470708, -2.0889034, -8.370071, -0.16373271, 2.0413866, -3.3811545, 2.0487003,
0.0316903, -1.078939, -2.5515578, -0.16135174, -0.17406325, 1.2709827, -0.67006403,
-1.6342779, 0.42163712, 2.1418998, -0.96614444, 1.9175051, -0.8538456, 2.8014183e-06,
2.0189362, 0.30467552, 0.5074463, 3.7919073, 2.427857, 0.7526233, -2.4620402,
0.65359443, 0.7219074, -2.3841858e-07, 0.03169757, 1.0789458, -2.1129081, -1.0250417,
4.8181386, -0.39162922, -1.2349386, 1.8470186, -0.49495277, -1.5516026, -0.96614635,
-1.9175065, -0.7235237, 2.5862021, 0.677946, 2.0370173, -0.29536027, 0.6505451,
-2.8572361, 2.3176546, 3.4459226, 1.1869265, -3.3811545, -2.048697, 0.95187366,
-2.1538982, 1.808088, -1.1755496, -2.7418838, -1.6770658, -3.5766084, -2.8320727,
-0.02944839, -1.6522555, -0.63165283, 1.5842092, 0.9700667, 0.30789307, 0.5195943,
2.4985125, 3.6537378, -0.5842519, -0.4843334, 0.78346854, 0.84766304, 1.1503224,
-2.839459, -2.440402};
// Flattened IRDFT input values for test cases defined further down the file.
// NOTE(review): appears to be interleaved (real, imag) pairs — TODO confirm the
// grouping against the consuming case's input_shape and axes.
static const std::vector<float> input_data_5 = {
25.904434, -8.46386e-06, -5.3626504, 0.3475349, -2.7060094, -5.767444,
1.615847, -2.6387978, 4.020789, 1.4271183, 1.5420923, 0.6126925,
-4.6167765, 5.5730343e-06, -0.753784, -0.19148755, 1.4881928, -2.7645326,
-0.39467168, 1.014636, 0.5598, -1.7654291, -0.91835654, -2.3019042,
-0.49356225, -0.8411435, 0.080773115, -1.2883577, -0.5341466, 1.4913602,
-0.30008763, -0.5831754, 1.7365295, 1.821624, -0.08851206, -1.622279,
-0.27249795, -0.834725, -0.6706438, 0.4766277, 0.62642634, 0.5483514,
-0.5341469, -1.4913592, 0.8286207, 0.35826343, -1.0869694, -1.4876881,
-1.6723244, -0.06565219, 0.16255295, 0.5317876, -0.75649667, 1.2447717,
0.6264261, -0.5483517, -0.7537827, 0.19148779, 0.6306459, -0.23442982,
0.57131517, -1.366768, -2.7544713, 1.3638397, 0.43463084, -0.5446956,
-2.9949086, 1.4802479, 0.080771565, 1.2883584, 24.998875, -7.390976e-06,
-3.1970425, -1.5453612, 1.0925753, -6.279154, 2.237704, -2.8844912,
1.8841789, -1.3615136, 0.90471864, 0.8395144, -2.6060505, 4.976988e-06,
1.1634235, 0.42319643, 2.678257, 2.4692535, 0.34259582, 0.43598562,
2.748452, 0.88622695, 2.2745323, -2.8840196, 1.8120161, -0.27884078,
-1.5445104, -0.7000726, -1.0264511, -0.7026249, -1.071573, 1.062395,
-0.64628685, -0.36214483, -0.5110928, -1.0534683, -2.786768, 2.6113648,
0.94799054, 0.53423727, -0.69832724, 2.1821892, -1.0264513, 0.70262754,
-0.41705567, -0.17140968, 1.4991179, 2.9674625, -0.012362838, -3.8260121,
-1.5786235, -0.32526863, 1.2857957, 1.7469958, -0.6983267, -2.1821907,
1.1634252, -0.42319855, 0.2716269, 0.21222934, -0.46608746, -1.6447732,
1.8890494, -1.8022469, -0.37335354, 0.69326025, -0.07385725, -0.1723765,
-1.5445105, 0.7000739};
// Flattened IRDFT input values for test cases defined further down the file.
// NOTE(review): appears to be interleaved (real, imag) pairs — TODO confirm the
// grouping against the consuming case's input_shape and axes.
static const std::vector<float> input_data_6 = {
101.805756, -5.2273273e-05, 2.5097876, 3.936094, -2.5597036, -1.8717405,
-1.0637736, 0.7801182, -2.1040666, -1.3385094, -7.9471993, 2.026558e-06,
0.15199316, 0.52512753, 6.7908745, 2.5329556, 0.98875976, 4.755993,
3.157838, 3.190782, 1.4353466, 1.6061276, -2.158554, 4.201776,
-1.3423799, 1.2554499, 3.5570183, -0.8320818, 2.263445, 0.36719292,
0.7579028, -1.8762131, -0.32408538, -0.87544185, -3.4693956, -4.429764,
-0.85828185, -3.9007902, -2.0141544, 0.4111499, 2.8994608, 0.21030927,
-2.6786098, -10.127857, -0.6911557, 1.0018079, -2.8430226, 0.33270124,
0.25672907, 1.8232578, -4.4159126, -2.040338, 1.9982092, -0.7974717,
-0.07559925, -1.0274884, 1.9742157, 3.9031482, 0.22159882, 1.4359848,
-1.0190966, 3.2186508e-06, 4.0004425, 0.8568655, 1.3117876, 0.2163087,
0.28074512, 0.17570588, -5.466423, 4.531178, 3.857718, -1.2516975e-06,
0.2567385, -1.823246, -1.0750613, -0.037295938, 5.20209, -2.0005994,
-1.7456844, 3.7091968, -5.45543, -3.4499822, 0.22159535, -1.4359887,
-0.8582816, 3.9007854, -0.31885874, 0.65880924, 0.6968423, 2.3119528,
-3.6876333, 2.273767, 5.38906, -0.45788872, -2.8430223, -0.33269957,
-1.3423961, -1.2554631, 3.1138885, -1.4416232, -6.0413575, -3.6011095,
-2.080242, 0.0045015216, -4.7212796, -0.3527125, -3.4693892, 4.429763,
0.15199506, -0.52512354, -0.85594195, 2.8447511, -0.10181111, -1.5565643,
-1.6371696, 0.19021615, 0.8239815, 3.018465, -2.158556, -4.2017746,
3.9272437, -3.9339066e-06, -0.18137527, 3.7160687, 2.1042633, 0.8752967,
0.29226887, 5.755277, -2.9184306, 0.78941, -9.410112, 3.0100346e-06,
-1.7881365, 1.140914, 0.13286811, -3.01685, 2.4928799, 6.7320104,
0.5376528, 0.88787735, -0.78172505, -7.0903873, 3.5203578, -0.6790314,
-3.246148, -3.0523329, -4.1306543, -5.653259, -3.866367, -1.5466263,
-3.6847744, -3.2064118, 0.5578996, -0.12726665, -2.2060838, -1.2613428,
0.588767, 1.2716217, -2.5499039, -0.8091496, -3.0134337, 0.0408957,
1.4991964, 6.6122847, -0.36368948, -3.0809648, 3.9192853, -3.764699,
0.19334978, 3.9811373, 0.68720365, -1.717634, 2.346336, -3.3394372,
1.2645291, 2.241068, 1.1309403, -0.3806507, 2.1538877, -2.3990266,
0.6885946, -1.4901161e-06, -0.037429705, 0.24751475, 0.2968948, -7.367506,
-4.574969, -1.329541, -0.5423446, 3.2239883, 2.4139037, 2.9802322e-07,
0.19334424, -3.9811373, 3.1507545, 2.0127864, -4.4341884, -1.2173393,
0.72419256, 0.015158802, -4.4655256, -0.34677732, 2.1538897, 2.3990245,
0.5887663, -1.2716188, -1.6747494, -3.415226, 1.2875631, 1.0108626,
2.0268395, -2.3615427, -1.502785, -2.8317401, 3.919288, 3.764695,
-3.2461433, 3.0523314, -0.5022881, 0.9094755, -0.55759126, -0.24697942,
5.0729737, 5.668646, -4.662384, 2.9517999, -2.2060819, 1.2613468,
-1.7881389, -1.1409098, -1.8951292, -2.1522717, -7.4092865, -0.38806117,
-0.6685039, -1.3767233, -0.8713439, 0.71781945, 3.5203605, 0.6790297};
// Flattened IRDFT input values for test cases defined further down the file.
// NOTE(review): unlike the other inputs these values all lie in [0, 1) and look
// like raw random data rather than a transformed spectrum — TODO confirm how the
// consuming case interprets them (shape/axes).
static const std::vector<float> input_data_7 = {
0.73348462, 0.74833735, 0.40982435, 0.51988197, 0.99384421, 0.12469386,
0.47686314, 0.25882564, 0.67028317, 0.58466398, 0.74927361, 0.19614283,
0.82593526, 0.41205770, 0.74020169, 0.62222693, 0.33264240, 0.84108156,
0.86392366, 0.79030966, 0.79792986, 0.47647899, 0.65967837, 0.92732906,
0.90477190, 0.87232389, 0.55734667, 0.75560744, 0.70658521, 0.28530827,
0.02554864, 0.14915414, 0.29936996, 0.74239557, 0.38158196, 0.26483291,
0.15843351, 0.38703221, 0.79967600, 0.63790851, 0.66191234, 0.19395184,
0.34992850, 0.89077723, 0.40746049, 0.01455611, 0.84174579, 0.91950995,
0.43402124, 0.76620100, 0.96476467, 0.78331896, 0.48567269, 0.33793230,
0.20362115, 0.51710568, 0.55455124, 0.10148728, 0.48229121, 0.58612092,
0.91786709, 0.94405867, 0.54302465, 0.24146348, 0.34853454, 0.75880201,
0.67781768, 0.29531289, 0.35969526, 0.01040005, 0.63142510, 0.67264276,
0.57920180, 0.99608063, 0.91108299, 0.82647166, 0.54134147, 0.79556370,
0.18579404, 0.95271365, 0.61918245, 0.17552980, 0.56332554, 0.58036855,
0.33756331, 0.69359258, 0.03914420, 0.14962257, 0.26647894, 0.45042564,
0.60093050, 0.67657016, 0.12601171, 0.95279680, 0.02868298, 0.82188820,
0.17558198, 0.40678849, 0.90804391, 0.21813571, 0.69710526, 0.91450289,
0.44277349, 0.70432336, 0.88161566, 0.23739783, 0.02746046, 0.05775890,
0.63494471, 0.10963744, 0.68260565, 0.87579980, 0.34451002, 0.01422449,
0.44081511, 0.78790226, 0.42010180, 0.62148773, 0.73164358, 0.85657540,
0.21649672, 0.93347654, 0.65511518, 0.45192463, 0.57671214, 0.09925586,
0.76042901, 0.84041443, 0.91933065, 0.00541233, 0.56194300, 0.71416635,
0.15882159, 0.57976451, 0.37377713, 0.48352544, 0.96645849, 0.50040596,
0.06060478, 0.21032667, 0.33303769, 0.80884551, 0.97500277, 0.28607026,
0.12235457, 0.47764468, 0.09834820, 0.08864630, 0.21728048, 0.92446905,
0.53802798, 0.22378462, 0.66087828, 0.64754384, 0.09980577, 0.50331927,
0.90966904, 0.67624758, 0.22728569, 0.61184030, 0.66753081, 0.00405466,
0.93407600, 0.89524725, 0.34496848, 0.01595642, 0.54338693, 0.65760153,
0.69930304, 0.54202591, 0.66030817, 0.74371140, 0.95000083, 0.86475930,
0.99826786, 0.85464029, 0.89926621, 0.90551912, 0.89889036, 0.38316505,
0.06428984, 0.39342267, 0.40689672, 0.37076883, 0.72720439, 0.05071236,
0.01355718, 0.95169120, 0.03623840, 0.05569115, 0.47255274, 0.44040655};
// Flattened expected real-valued IRDFT output. The name suggests a 2D case.
// NOTE(review): the values closely match expected_irdft1d_results_1 (differences
// around 1e-6), presumably because a different input recovers the same signal —
// TODO confirm the pairing in the instantiation.
static const std::vector<float> expected_irdft2d_results_1 = {
0.106065355, 0.7454709, 0.5723129, 0.45824066, 0.384706, 0.27398905, 0.6679619, 0.39547434,
0.2815724, 0.779919, 0.59909385, 0.122946456, 0.38957337, 0.97498655, 0.46759892, 0.14017127,
0.04206834, 0.72799486, 0.61560476, 0.9027304, 0.6226336, 0.2601218, 0.5555171, 0.4049862,
0.14175594, 0.57774574, 0.52652067, 0.9385676, 0.958878, 0.9844308, 0.2309568, 0.0970796,
0.24574815, 0.6907565, 0.19743192, 0.8295261, 0.3461272, 0.5140136, 0.66115695, 0.93362343,
0.06690116, 0.74688905, 0.39028272, 0.53575796, 0.060429227, 0.89135474, 0.77787286, 0.67011875,
0.73505205, 0.6636992, 0.18176568, 0.8629964, 0.4514285, 0.6497283, 0.15937214, 0.40598106,
0.7988508, 0.72915316, 0.07090413, 0.76971227, 0.49721542, 0.7669206, 0.6797579, 0.13026048,
0.6587432, 0.24532847, 0.24545121, 0.83795464, 0.10549038, 0.72643167, 0.94568396, 0.72166353,
0.14389817, 0.79305094, 0.7089523, 0.9724684, 0.9775141, 0.49999753, 0.6556916, 0.2687679,
0.6324893, 0.85201234, 0.5689621, 0.023386242, 0.5546462, 0.36860004, 0.9603104, 0.3912346,
0.038073156, 0.8921232, 0.14387667, 0.63858616, 0.10003737, 0.8906622, 0.06681097, 0.74586314,
0.4545233, 0.54724485, 0.6496472, 0.7818348, 0.6608358, 0.77711284, 0.24588637, 0.0134570245,
0.35584468, 0.8038809, 0.027993381, 0.7367708, 0.52754945, 0.90523165, 0.54310995, 0.5367796,
0.41312343, 0.7752323, 0.10966998, 0.13664615, 0.7828726, 0.9083951, 0.524759, 0.7493586,
0.19275239, 0.0071907635, 0.60879755, 0.34413564, 0.4690983, 0.4192482, 0.70729065, 0.1993285,
0.5303842, 0.65138334, 0.06686333, 0.97179186, 0.657022, 0.11786719, 0.3154068, 0.8892283,
0.55640805, 0.9104763, 0.28466636, 0.093467236, 0.88953, 0.9919328, 0.18322526, 0.8185441,
0.56639117, 0.014208457, 0.29673028, 0.6347738, 0.68019533, 0.39601144, 0.34374115, 0.72168803,
0.61525595, 0.76679367, 0.5860848, 0.42768106, 0.7933919, 0.13130645, 0.68764144, 0.05312841,
0.026113434, 0.2982238, 0.7618365, 0.3331724, 0.5468184, 0.15707079, 0.28592035, 0.15286529,
0.9368952, 0.35067078, 0.43364897, 0.089348935, 0.41172677, 0.58502454, 0.7073026, 0.85983366,
0.08878795, 0.2671109, 0.8002475, 0.19422255, 0.83120316, 0.5198712, 0.40111288, 0.98375624,
0.77703446, 0.03781964, 0.7042304, 0.68980736, 0.17102323, 0.42153904, 0.7278248, 0.80301994,
0.91017085, 0.019965423, 0.13768451, 0.556689, 0.17991383, 0.6720085, 0.7733324, 0.20881362};
// Flattened expected real-valued IRDFT output. The name suggests a 2D case.
// NOTE(review): every 12th pair of values is ~1e-7 (near zero), presumably an
// effect of the signal_size used by the consuming case — TODO confirm.
static const std::vector<float> expected_irdft2d_results_2 = {
0.10606504, 0.74547091, 0.57231341, 0.45824085, 0.38470576, 0.27398939, 0.66796227,
0.39547472, 0.28157284, 0.77991920, 0.00000012, 0.00000025, 0.59909402, 0.12294612,
0.38957398, 0.97498753, 0.46759871, 0.14017182, 0.04206866, 0.72799575, 0.61560553,
0.90273150, 0.00000029, 0.00000019, 0.62263335, 0.26012139, 0.55551768, 0.40498611,
0.14175560, 0.57774629, 0.52652119, 0.93856842, 0.95887877, 0.98443111, 0.00000026,
0.00000029, 0.23095626, 0.09707905, 0.24574875, 0.69075717, 0.19743158, 0.82952691,
0.34612741, 0.51401454, 0.66115784, 0.93362381, 0.00000013, 0.00000019, 0.06690087,
0.74688917, 0.39028283, 0.53575807, 0.06042910, 0.89135566, 0.77787371, 0.67011938,
0.73505260, 0.66369919, 0.00000020, 0.00000025, 0.66083517, 0.77711292, 0.24588620,
0.01345654, 0.35584463, 0.80388178, 0.02799342, 0.73677143, 0.52754998, 0.90523178,
0.00000020, 0.00000022, 0.54311002, 0.53678006, 0.41312413, 0.77523314, 0.10966939,
0.13664682, 0.78287364, 0.90839633, 0.52475940, 0.74935884, 0.00000017, 0.00000024,
0.19275220, 0.00719083, 0.60879792, 0.34413568, 0.46909855, 0.41924857, 0.70729118,
0.19932858, 0.53038468, 0.65138356, 0.00000024, 0.00000004, 0.06686326, 0.97179258,
0.65702215, 0.11786667, 0.31540699, 0.88922984, 0.55640881, 0.91047768, 0.28466661,
0.09346649, 0.00000006, 0.00000008, 0.88953045, 0.99193334, 0.18322520, 0.81854497,
0.56639084, 0.01420842, 0.29673067, 0.63477397, 0.68019596, 0.39601113, 0.00000014,
0.00000022};
// Flattened expected real-valued IRDFT output. The name suggests a 3D case —
// TODO confirm which input vector, axes and signal_size it is paired with.
static const std::vector<float> expected_irdft3d_results_2 = {
0.29655575, 0.59799123, 0.22431113, 0.46143103, 0.53208175, 0.32705094, 0.59367000,
0.29963828, 0.41763943, 0.24033307, 0.42796425, 0.56577777, 0.37677909, 0.32099129,
0.28778578, 0.50527716, 0.39592624, -0.01477019, 0.46390174, 0.48881302, 0.69299017,
0.69097986, 0.60120016, 0.82729206, -0.09137908, 0.49852066, 0.41157645, 0.50370176,
0.50602146, 0.12422646, 0.66381460, 0.40124601, 0.71138931, 0.66414101, 0.50896081,
0.51854765, 0.21342740, 0.75042767, 0.40385838, 0.28173387, 0.29258505, 0.34233110,
0.44617152, 0.32590713, 0.69813927, 0.27029157, 0.49500125, 0.57849153, 0.52079012,
0.46437605, 0.44842544, 0.21380078, 0.57897044, 0.32123390, 0.46531573, 0.55946432,
0.36995799, 0.19326348, 0.26279333, 0.89411452, 0.45806675, 0.58413552, 0.47982321,
0.40877153, 0.23978246, 0.33369794, 0.56433968, 0.09308288, 0.20574836, 0.51936717,
0.46905154, 0.47775696, 0.17856948, 0.04195880, 0.24284739, 0.63731160, 0.16159543,
0.08925854, 0.50157161, 0.67721677, 0.75653236, 0.50840554, 0.73467008, 0.62163510,
0.00566245, 0.92257200, 0.42133956, 0.45249607, 0.36451271, 0.46674756, 0.65809363,
0.29478180, 0.79919561, 0.37987672, 0.46803394, 0.20036376, 0.30268271, 0.62990812,
0.29745090, 0.46503467, 0.30444576, 0.43581755, 0.38956261, 0.58891618, 0.43936615,
0.12833645, 0.82411153, 0.30960669, 0.24676315, 0.39269569, 0.26772071, 0.46022705,
0.77598541, 0.46882716, 0.40922151, 0.28451272, 0.27156988, 0.32720683, 0.48740341,
0.52519462, 0.47371313, 0.61046947, 0.46505542, 0.04019986, 0.27622309, 0.42926452,
0.49897225, 0.04617115, 0.50902017, 0.74826910, 0.28548445, 0.63409441, 0.13183664,
0.02507987, 0.51695660, 0.50593892, 0.17335312, 0.24157819, 0.45513622, 0.69800550,
0.40604969, 0.47128647, 0.59389774, 0.33534107, 0.50887902, 0.82998967, 0.22642939,
0.32967160, 0.50515564, 0.54070049, 0.28947697, 0.35626388, 0.58235507, 0.30633221,
0.50041779, 0.24975602, 0.38320678, 0.40595842, 0.50651077, 0.42963483, 0.25977121,
0.32014694, 0.37577291, 0.46638206, 0.05511259, 0.45463482, 0.62685054, 0.13046773,
0.49768469, 0.47645129, 0.56182954, 0.74548830, 0.73150766, 0.37579758, 0.14279248,
0.28705593, 0.45403320, 0.50334282, 0.24132925, 0.24104091, 0.31220213, 0.62432518,
0.20954334, 0.09285936, 0.56852238, 0.42261752, 0.52830257, 0.25272655, 0.72091123,
0.46923499, 0.24439716, 0.72211522, 0.33004626, 0.30411236, 0.56189500, 0.37390448,
0.40768394, 0.13754946, 0.41746636, 0.50960175, 0.34250750, 0.65386079, 0.46042782,
0.54099804, 0.41183749, 0.40593833, 0.21666628, 0.38087729, 0.64666439, 0.19817938,
0.29519793, 0.46272810, 0.49454878, 0.59059650, 0.54134465, 0.56793991, 0.29395146,
0.52647797, 0.61291826, 0.24633402, 0.24791051, 0.22666050, 0.43238182, 0.20337301,
0.31388571, 0.59658993, 0.29774026, 0.39935257, 0.77171166, 0.54813165, 0.74253426,
0.49906203, 0.53449270, 0.22820431, 0.19888670, 0.56200754, 0.55242130, 0.36939947,
0.01671917, 0.60996081};
// Reference output for the irdft3d_reversed_axes and
// irdft3d_reversed_negative_axes cases (input_data_7, axes {2, 1, 0} or
// {-1, -2, -3}, no signal_size): 128 values = Shape{4, 4, 8}.
static const std::vector<float> expected_irdft3d_results_3 ={
0.51795123, 0.01846075, 0.03363710, -0.02286412, -0.00527071, -0.05116411,
-0.01142488, -0.01784910, -0.01088149, 0.01049122, -0.00829387, 0.00942086,
-0.02915924, 0.05941228, 0.05868882, -0.02329090, 0.06043447, 0.01260666,
0.04213929, -0.03578551, -0.00354573, -0.02047438, -0.03469945, -0.02365786,
0.00807303, 0.02364844, -0.00346402, -0.00134415, 0.04106979, 0.04961361,
-0.01212564, -0.04288128, -0.26157875, -0.01917418, -0.04232584, 0.02477720,
0.02514449, 0.04955597, -0.00301304, 0.00663580, 0.01947190, -0.01163269,
-0.07920224, -0.01201069, 0.00564843, 0.00283007, -0.05916596, 0.03569793,
-0.02454099, -0.01977048, -0.00360401, 0.00924050, -0.01237082, -0.04213287,
-0.03306797, -0.01442351, -0.02601594, 0.07406829, -0.02896844, 0.00503278,
0.00700455, 0.02915976, 0.01761130, -0.04474307, 0.03632101, 0.00957998,
-0.02003984, -0.04022581, 0.03104216, 0.00388626, 0.05861915, 0.01034101,
-0.00741989, 0.01010181, 0.01496502, -0.00544559, 0.04015258, -0.00600315,
-0.06137903, 0.07850411, -0.00074931, 0.02540785, -0.00166176, 0.02205904,
-0.02429718, 0.04010517, 0.02375359, 0.02229406, 0.01806382, -0.06089136,
0.00447113, -0.03169147, 0.02836490, -0.05821620, 0.03905417, 0.03987032,
0.29899586, -0.02616866, -0.00927641, -0.02134532, -0.02480746, -0.02636082,
-0.05009444, -0.02208490, 0.02632000, 0.00493334, -0.00402312, -0.00935831,
0.04154630, 0.00849218, 0.00232782, -0.01192997, -0.03309486, 0.01678531,
0.03526979, 0.09272132, 0.01420703, -0.01919909, 0.01321082, -0.01661140,
0.07861365, -0.02784724, 0.03900426, -0.00096805, -0.02880604, 0.02753764,
-0.02092520, -0.01412453};
// Reference output for the irdft3d_reversed_axes_with_signals and
// irdft3d_reversed_negative_axes_with_signals cases (input_data_7, axes
// {2, 1, 0} or {-1, -2, -3}, signal_size {3, 3, 10}): 90 values = Shape{10, 3, 3}.
static const std::vector<float> expected_irdft3d_results_4 = {
0.24882269, -0.00554157, -0.00759689, -0.00413212, 0.01099624, 0.02191469,
0.02829072, -0.01410181, 0.04826954, 0.03587530, -0.01151859, 0.03459743,
0.03157633, -0.03446264, 0.03595825, -0.01176664, 0.00625817, 0.00981066,
-0.11900401, -0.02756717, 0.01933546, 0.03042892, -0.04917013, 0.00048474,
-0.01849990, -0.01050222, -0.02433642, -0.08657554, -0.03473007, -0.01486101,
0.00137630, -0.01972852, -0.06159696, 0.02284726, -0.03851998, -0.00885092,
0.02397606, -0.02071742, -0.00586151, -0.01287085, 0.01713095, -0.07724825,
0.05983482, -0.02824272, 0.02959802, 0.04051825, 0.00219584, 0.04053028,
0.00415529, 0.02379833, -0.01936524, 0.04350142, 0.02095385, 0.03121966,
-0.02675550, 0.01142533, 0.05606331, 0.02115209, 0.00866956, 0.05367358,
-0.00479556, 0.05423974, -0.01172735, -0.01203834, 0.00181946, 0.00594081,
0.00527473, 0.00781714, 0.07042868, -0.02243115, 0.03207793, -0.04213578,
0.14912935, -0.01012542, -0.05799989, -0.02889979, 0.02934662, 0.03385938,
0.00951527, -0.01760542, -0.01611288, 0.29838892, -0.01029289, -0.06226702,
-0.03670440, 0.03954893, 0.00725941, 0.04219448, -0.03698240, 0.03564729};
template <element::Type_t ET>
std::vector<IRDFTParams> generateParamsForIRDFT() {
std::vector<IRDFTParams> params{
// irdft1d_eval
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_1,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
NULL),
// irdft1d_eval_1
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_1,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {-1}),
NULL),
// irdft1d_eval_signal_size_0
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_1,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {10})),
// irdft1d_eval_signal_size_0_1
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_1,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {-1})),
// irdft1d_eval_signal_size_1
IRDFTParams(Shape{2, 10, 3, 2},
Shape{2, 10, 5},
ET,
ET,
input_data_2,
expected_irdft1d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {5})),
// irdft1d_eval_signal_size_1_1
IRDFTParams(Shape{2, 10, 3, 2},
Shape{2, 10, 5},
ET,
ET,
input_data_2,
expected_irdft1d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {-1}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {5})),
// irdft1d_eval_signal_size_2
IRDFTParams(Shape{2, 10, 7, 2},
Shape{2, 10, 4},
ET,
ET,
input_data_3,
expected_irdft1d_results_3,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {4})),
// irdft1d_eval_signal_size_2_1
IRDFTParams(Shape{2, 10, 7, 2},
Shape{2, 10, 4},
ET,
ET,
input_data_3,
expected_irdft1d_results_3,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {-1}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{1}, {4})),
// irdft2d_eval_1
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_4,
expected_irdft2d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
NULL),
// irdft2d_eval_1_positive_negative_axes
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_4,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, -1}),
NULL),
// irdft2d_eval_1_negative_positive_axes
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_4,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, 2}),
NULL),
// irdft2d_eval_1_negative_negative_axes
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_4,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, -1}),
NULL),
// irdft2d_eval_1_signal_size_0_s10_10
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_4,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, 10})),
// irdft2d_eval_1_signal_size_0_s10_10_positive_negative_axes
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_4,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, -1}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, 10})),
// irdft2d_eval_1_signal_size_0_s10_10_negative_positive_axes
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_4,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, 10})),
// irdft2d_eval_1_signal_size_0_s10_10_negative_negative_axes
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_4,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, -1}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, 10})),
// irdft2d_eval_1_signal_size_0_s10_m1
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_4,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {10, -1})),
// irdft2d_eval_1_signal_size_0_sm1_10
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_4,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-1, 10})),
// irdft2d_eval_1_signal_size_0_sm1_m1
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_4,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-1, -1})),
// irdft2d_eval_2_signal_size
IRDFTParams(Shape{2, 5, 7, 2},
Shape{2, 5, 12},
ET,
ET,
input_data_5,
expected_irdft2d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {5, 12})),
// irdft2d_eval_2_signal_size_positive_negative_axes
IRDFTParams(Shape{2, 5, 7, 2},
Shape{2, 5, 12},
ET,
ET,
input_data_5,
expected_irdft2d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {1, -1}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {5, 12})),
// irdft2d_eval_2_signal_size_negative_positive_axes
IRDFTParams(Shape{2, 5, 7, 2},
Shape{2, 5, 12},
ET,
ET,
input_data_5,
expected_irdft2d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {5, 12})),
// irdft2d_eval_2_signal_size_negative_negative_axes
IRDFTParams(Shape{2, 5, 7, 2},
Shape{2, 5, 12},
ET,
ET,
input_data_5,
expected_irdft2d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {-2, -1}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{2}, {5, 12})),
// irdft3d_eval_1
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_6,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {0, 1, 2}),
NULL),
// irdft3d_eval_1_negative_axes_and_signal_size
IRDFTParams(Shape{2, 10, 6, 2},
Shape{2, 10, 10},
ET,
ET,
input_data_6,
expected_irdft1d_results_1,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-3, 1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-1, 10, -1})),
// irdft3d_eval_2
IRDFTParams(Shape{2, 10, 6, 2},
Shape{4, 5, 12},
ET,
ET,
input_data_6,
expected_irdft3d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {0, 1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {4, 5, 12})),
// irdft3d_eval_2_negative_axes
IRDFTParams(Shape{2, 10, 6, 2},
Shape{4, 5, 12},
ET,
ET,
input_data_6,
expected_irdft3d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-3, -2, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {4, 5, 12})),
// irdft3d_reversed_axes
IRDFTParams(Shape{3, 4, 8, 2},
Shape{4, 4, 8},
ET,
ET,
input_data_7,
expected_irdft3d_results_3,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {2, 1, 0}),
NULL),
// irdft3d_reversed_negative_axes
IRDFTParams(Shape{3, 4, 8, 2},
Shape{4, 4, 8},
ET,
ET,
input_data_7,
expected_irdft3d_results_3,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-1, -2, -3}),
NULL),
// irdft3d_reversed_axes_with_signals
IRDFTParams(Shape{3, 4, 8, 2},
Shape{10, 3, 3},
ET,
ET,
input_data_7,
expected_irdft3d_results_4,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {2, 1, 0}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {3, 3, 10})),
// irdft3d_reversed_negative_axes_with_signals
IRDFTParams(Shape{3, 4, 8, 2},
Shape{10, 3, 3},
ET,
ET,
input_data_7,
expected_irdft3d_results_4,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-1, -2, -3}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {3, 3, 10})),
};
return params;
}
std::vector<IRDFTParams> generateCombinedParamsForIRDFT() {
const std::vector<std::vector<IRDFTParams>> allTypeParams{
generateParamsForIRDFT<element::Type_t::f32>()
};
std::vector<IRDFTParams> combinedParams;
for (const auto& params : allTypeParams) {
combinedParams.insert(combinedParams.end(), params.begin(), params.end());
}
return combinedParams;
}
// Registers one ReferenceIRDFTLayerTest per entry returned by
// generateCombinedParamsForIRDFT(); ReferenceIRDFTLayerTest::getTestCaseName
// produces each instance's name suffix from its parameters.
INSTANTIATE_TEST_SUITE_P(
smoke_IRDFT_With_Hardcoded_Refs,
ReferenceIRDFTLayerTest,
::testing::ValuesIn(generateCombinedParamsForIRDFT()),
ReferenceIRDFTLayerTest::getTestCaseName);
} // namespace