Fix IRDFT for case when axes are in reversed order (#12155)

This commit is contained in:
Mateusz Bencer
2022-07-15 12:22:52 +02:00
committed by GitHub
parent dc7efafd7c
commit e8310f7e0b
4 changed files with 120 additions and 7 deletions

View File

@@ -2456,7 +2456,8 @@ bool evaluate(const shared_ptr<op::v9::IRDFT>& op, const HostTensorVector& outpu
info.axes_data,
irfft_result.data(),
info.fft_output_shape,
info.output_shape);
info.output_shape,
info.last_signal_size);
const auto output_type = op->get_input_element_type(0);
runtime::reference::fft_postprocessing(outputs, output_type, irfft_result);

View File

@@ -366,6 +366,40 @@ static const std::vector<float> input_data_6 = {
-1.7881389, -1.1409098, -1.8951292, -2.1522717, -7.4092865, -0.38806117,
-0.6685039, -1.3767233, -0.8713439, 0.71781945, 3.5203605, 0.6790297};
static const std::vector<float> input_data_7 = {
0.73348462, 0.74833735, 0.40982435, 0.51988197, 0.99384421, 0.12469386,
0.47686314, 0.25882564, 0.67028317, 0.58466398, 0.74927361, 0.19614283,
0.82593526, 0.41205770, 0.74020169, 0.62222693, 0.33264240, 0.84108156,
0.86392366, 0.79030966, 0.79792986, 0.47647899, 0.65967837, 0.92732906,
0.90477190, 0.87232389, 0.55734667, 0.75560744, 0.70658521, 0.28530827,
0.02554864, 0.14915414, 0.29936996, 0.74239557, 0.38158196, 0.26483291,
0.15843351, 0.38703221, 0.79967600, 0.63790851, 0.66191234, 0.19395184,
0.34992850, 0.89077723, 0.40746049, 0.01455611, 0.84174579, 0.91950995,
0.43402124, 0.76620100, 0.96476467, 0.78331896, 0.48567269, 0.33793230,
0.20362115, 0.51710568, 0.55455124, 0.10148728, 0.48229121, 0.58612092,
0.91786709, 0.94405867, 0.54302465, 0.24146348, 0.34853454, 0.75880201,
0.67781768, 0.29531289, 0.35969526, 0.01040005, 0.63142510, 0.67264276,
0.57920180, 0.99608063, 0.91108299, 0.82647166, 0.54134147, 0.79556370,
0.18579404, 0.95271365, 0.61918245, 0.17552980, 0.56332554, 0.58036855,
0.33756331, 0.69359258, 0.03914420, 0.14962257, 0.26647894, 0.45042564,
0.60093050, 0.67657016, 0.12601171, 0.95279680, 0.02868298, 0.82188820,
0.17558198, 0.40678849, 0.90804391, 0.21813571, 0.69710526, 0.91450289,
0.44277349, 0.70432336, 0.88161566, 0.23739783, 0.02746046, 0.05775890,
0.63494471, 0.10963744, 0.68260565, 0.87579980, 0.34451002, 0.01422449,
0.44081511, 0.78790226, 0.42010180, 0.62148773, 0.73164358, 0.85657540,
0.21649672, 0.93347654, 0.65511518, 0.45192463, 0.57671214, 0.09925586,
0.76042901, 0.84041443, 0.91933065, 0.00541233, 0.56194300, 0.71416635,
0.15882159, 0.57976451, 0.37377713, 0.48352544, 0.96645849, 0.50040596,
0.06060478, 0.21032667, 0.33303769, 0.80884551, 0.97500277, 0.28607026,
0.12235457, 0.47764468, 0.09834820, 0.08864630, 0.21728048, 0.92446905,
0.53802798, 0.22378462, 0.66087828, 0.64754384, 0.09980577, 0.50331927,
0.90966904, 0.67624758, 0.22728569, 0.61184030, 0.66753081, 0.00405466,
0.93407600, 0.89524725, 0.34496848, 0.01595642, 0.54338693, 0.65760153,
0.69930304, 0.54202591, 0.66030817, 0.74371140, 0.95000083, 0.86475930,
0.99826786, 0.85464029, 0.89926621, 0.90551912, 0.89889036, 0.38316505,
0.06428984, 0.39342267, 0.40689672, 0.37076883, 0.72720439, 0.05071236,
0.01355718, 0.95169120, 0.03623840, 0.05569115, 0.47255274, 0.44040655};
static const std::vector<float> expected_irdft2d_results_1 = {
0.106065355, 0.7454709, 0.5723129, 0.45824066, 0.384706, 0.27398905, 0.6679619, 0.39547434,
0.2815724, 0.779919, 0.59909385, 0.122946456, 0.38957337, 0.97498655, 0.46759892, 0.14017127,
@@ -413,7 +447,7 @@ static const std::vector<float> expected_irdft2d_results_2 = {
0.56639084, 0.01420842, 0.29673067, 0.63477397, 0.68019596, 0.39601113, 0.00000014,
0.00000022};
static const std::vector<float> expected_rdft3d_results_2 = {
static const std::vector<float> expected_irdft3d_results_2 = {
0.29655575, 0.59799123, 0.22431113, 0.46143103, 0.53208175, 0.32705094, 0.59367000,
0.29963828, 0.41763943, 0.24033307, 0.42796425, 0.56577777, 0.37677909, 0.32099129,
0.28778578, 0.50527716, 0.39592624, -0.01477019, 0.46390174, 0.48881302, 0.69299017,
@@ -450,6 +484,47 @@ static const std::vector<float> expected_rdft3d_results_2 = {
0.49906203, 0.53449270, 0.22820431, 0.19888670, 0.56200754, 0.55242130, 0.36939947,
0.01671917, 0.60996081};
static const std::vector<float> expected_irdft3d_results_3 ={
0.51795123, 0.01846075, 0.03363710, -0.02286412, -0.00527071, -0.05116411,
-0.01142488, -0.01784910, -0.01088149, 0.01049122, -0.00829387, 0.00942086,
-0.02915924, 0.05941228, 0.05868882, -0.02329090, 0.06043447, 0.01260666,
0.04213929, -0.03578551, -0.00354573, -0.02047438, -0.03469945, -0.02365786,
0.00807303, 0.02364844, -0.00346402, -0.00134415, 0.04106979, 0.04961361,
-0.01212564, -0.04288128, -0.26157875, -0.01917418, -0.04232584, 0.02477720,
0.02514449, 0.04955597, -0.00301304, 0.00663580, 0.01947190, -0.01163269,
-0.07920224, -0.01201069, 0.00564843, 0.00283007, -0.05916596, 0.03569793,
-0.02454099, -0.01977048, -0.00360401, 0.00924050, -0.01237082, -0.04213287,
-0.03306797, -0.01442351, -0.02601594, 0.07406829, -0.02896844, 0.00503278,
0.00700455, 0.02915976, 0.01761130, -0.04474307, 0.03632101, 0.00957998,
-0.02003984, -0.04022581, 0.03104216, 0.00388626, 0.05861915, 0.01034101,
-0.00741989, 0.01010181, 0.01496502, -0.00544559, 0.04015258, -0.00600315,
-0.06137903, 0.07850411, -0.00074931, 0.02540785, -0.00166176, 0.02205904,
-0.02429718, 0.04010517, 0.02375359, 0.02229406, 0.01806382, -0.06089136,
0.00447113, -0.03169147, 0.02836490, -0.05821620, 0.03905417, 0.03987032,
0.29899586, -0.02616866, -0.00927641, -0.02134532, -0.02480746, -0.02636082,
-0.05009444, -0.02208490, 0.02632000, 0.00493334, -0.00402312, -0.00935831,
0.04154630, 0.00849218, 0.00232782, -0.01192997, -0.03309486, 0.01678531,
0.03526979, 0.09272132, 0.01420703, -0.01919909, 0.01321082, -0.01661140,
0.07861365, -0.02784724, 0.03900426, -0.00096805, -0.02880604, 0.02753764,
-0.02092520, -0.01412453};
static const std::vector<float> expected_irdft3d_results_4 = {
0.24882269, -0.00554157, -0.00759689, -0.00413212, 0.01099624, 0.02191469,
0.02829072, -0.01410181, 0.04826954, 0.03587530, -0.01151859, 0.03459743,
0.03157633, -0.03446264, 0.03595825, -0.01176664, 0.00625817, 0.00981066,
-0.11900401, -0.02756717, 0.01933546, 0.03042892, -0.04917013, 0.00048474,
-0.01849990, -0.01050222, -0.02433642, -0.08657554, -0.03473007, -0.01486101,
0.00137630, -0.01972852, -0.06159696, 0.02284726, -0.03851998, -0.00885092,
0.02397606, -0.02071742, -0.00586151, -0.01287085, 0.01713095, -0.07724825,
0.05983482, -0.02824272, 0.02959802, 0.04051825, 0.00219584, 0.04053028,
0.00415529, 0.02379833, -0.01936524, 0.04350142, 0.02095385, 0.03121966,
-0.02675550, 0.01142533, 0.05606331, 0.02115209, 0.00866956, 0.05367358,
-0.00479556, 0.05423974, -0.01172735, -0.01203834, 0.00181946, 0.00594081,
0.00527473, 0.00781714, 0.07042868, -0.02243115, 0.03207793, -0.04213578,
0.14912935, -0.01012542, -0.05799989, -0.02889979, 0.02934662, 0.03385938,
0.00951527, -0.01760542, -0.01611288, 0.29838892, -0.01029289, -0.06226702,
-0.03670440, 0.03954893, 0.00725941, 0.04219448, -0.03698240, 0.03564729};
template <element::Type_t ET>
std::vector<IRDFTParams> generateParamsForIRDFT() {
std::vector<IRDFTParams> params{
@@ -684,7 +759,7 @@ std::vector<IRDFTParams> generateParamsForIRDFT() {
ET,
ET,
input_data_6,
expected_rdft3d_results_2,
expected_irdft3d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {0, 1, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {4, 5, 12})),
// irdft3d_eval_2_negative_axes
@@ -693,9 +768,45 @@ std::vector<IRDFTParams> generateParamsForIRDFT() {
ET,
ET,
input_data_6,
expected_rdft3d_results_2,
expected_irdft3d_results_2,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-3, -2, 2}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {4, 5, 12})),
// irdft3d_reversed_axes
IRDFTParams(Shape{3, 4, 8, 2},
Shape{4, 4, 8},
ET,
ET,
input_data_7,
expected_irdft3d_results_3,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {2, 1, 0}),
NULL),
// irdft3d_reversed_negative_axes
IRDFTParams(Shape{3, 4, 8, 2},
Shape{4, 4, 8},
ET,
ET,
input_data_7,
expected_irdft3d_results_3,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-1, -2, -3}),
NULL),
// irdft3d_reversed_axes_with_signals
IRDFTParams(Shape{3, 4, 8, 2},
Shape{10, 3, 3},
ET,
ET,
input_data_7,
expected_irdft3d_results_4,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {2, 1, 0}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {3, 3, 10})),
// irdft3d_reversed_negative_axes_with_signals
IRDFTParams(Shape{3, 4, 8, 2},
Shape{10, 3, 3},
ET,
ET,
input_data_7,
expected_irdft3d_results_4,
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {-1, -2, -3}),
op::v0::Constant::create<int64_t>(element::Type_t::i64, Shape{3}, {3, 3, 10})),
};
return params;

View File

@@ -16,7 +16,8 @@ void irdft(const std::vector<float>& input_data,
const std::vector<int64_t>& axes_data,
float* irdft_result,
const Shape& fft_output_shape,
const Shape& irdft_output_shape);
const Shape& irdft_output_shape,
const int64_t last_signal_size);
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@@ -104,9 +104,9 @@ void irdft(const std::vector<float>& input_data,
const std::vector<int64_t>& axes_data,
float* irdft_result,
const Shape& fft_output_shape,
const Shape& irdft_output_shape) {
const Shape& irdft_output_shape,
const int64_t last_signal_size) {
// calculate inverse FFT over the outer axes
const int64_t last_signal_size = irdft_output_shape.back();
const auto outer_ifft_axes = get_outer_fft_axes(axes_data);
auto outer_ifft_shape = input_data_shape;
for (const auto& a : outer_ifft_axes) {