From 42c93a70d4ba68ee48fe8b372508e195e4d0044c Mon Sep 17 00:00:00 2001
From: Mateusz Mikolajczyk
Date: Fri, 16 Jun 2023 11:35:30 +0200
Subject: [PATCH] [PT FE] Add aten::fft_rfftn and aten::fft_irfftn complex
 replacers (#17999)

* Add transformations for rfftn and irfftn

* Formatting & reduce tests

* Improvements

* rm unused namespace

* Apply suggestions from review

* Fix format

* Add tests for dynamic type

* Add error messages

---
 src/core/src/op/util/fft_base.cpp             |  12 +-
 src/core/tests/type_prop/dft.cpp              |  15 ++
 src/core/tests/type_prop/idft.cpp             |  15 ++
 src/core/tests/type_prop/irdft.cpp            |  15 ++
 src/core/tests/type_prop/rdft.cpp             |  15 ++
 src/frontends/pytorch/src/frontend.cpp        |   4 +
 .../transforms/irfftn_complex_replacer.cpp    | 163 ++++++++++++++++++
 .../transforms/irfftn_complex_replacer.hpp    |  24 +++
 .../src/transforms/rfftn_complex_replacer.cpp | 162 +++++++++++++++++
 .../src/transforms/rfftn_complex_replacer.hpp |  24 +++
 src/frontends/pytorch/src/utils.cpp           |  10 ++
 src/frontends/pytorch/src/utils.hpp           |   2 +
 .../test_rfftn_complex_transforms.py          |  46 +++++
 13 files changed, 502 insertions(+), 5 deletions(-)
 create mode 100644 src/frontends/pytorch/src/transforms/irfftn_complex_replacer.cpp
 create mode 100644 src/frontends/pytorch/src/transforms/irfftn_complex_replacer.hpp
 create mode 100644 src/frontends/pytorch/src/transforms/rfftn_complex_replacer.cpp
 create mode 100644 src/frontends/pytorch/src/transforms/rfftn_complex_replacer.hpp
 create mode 100644 tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py

diff --git a/src/core/src/op/util/fft_base.cpp b/src/core/src/op/util/fft_base.cpp
index bfc58ed4145..7d615721adc 100644
--- a/src/core/src/op/util/fft_base.cpp
+++ b/src/core/src/op/util/fft_base.cpp
@@ -30,19 +30,21 @@ void ov::op::util::FFTBase::validate_types() {
     element::Type input_et = get_input_element_type(0);

     NODE_VALIDATION_CHECK(this,
-                          input_et == element::f32 || input_et == element::f16 || input_et == element::bf16,
+                          input_et == element::f32 || input_et == element::f16 || input_et == element::bf16 ||
+                              input_et == element::dynamic,
                           "FFT op input element type must be f32, f16, or bf16");

     element::Type axes_et = get_input_element_type(1);

     NODE_VALIDATION_CHECK(this,
-                          axes_et == element::i64 || axes_et == element::i32,
+                          axes_et == element::i64 || axes_et == element::i32 || axes_et == element::dynamic,
                           "FFT op axes element type must be i32 or i64");

     if (num_of_inputs == 3) {
         element::Type signal_size_et = get_input_element_type(2);

-        NODE_VALIDATION_CHECK(this,
-                              signal_size_et == element::i64 || signal_size_et == element::i32,
-                              "FFT op signal_size element type must be i32 or i64");
+        NODE_VALIDATION_CHECK(
+            this,
+            signal_size_et == element::i64 || signal_size_et == element::i32 || signal_size_et == element::dynamic,
+            "FFT op signal_size element type must be i32 or i64");
     }
 }
diff --git a/src/core/tests/type_prop/dft.cpp b/src/core/tests/type_prop/dft.cpp
index 1fa3733b38a..a8aa945dcb4 100644
--- a/src/core/tests/type_prop/dft.cpp
+++ b/src/core/tests/type_prop/dft.cpp
@@ -401,3 +401,18 @@ TEST(type_prop, dft_invalid_signal_size) {
         EXPECT_HAS_SUBSTRING(error.what(), "Sizes of inputs 'axes' and 'signal_size' must be equal.");
     }
 }
+
+TEST(type_prop, dft_dynamic_types) {
+    const auto input_shape = PartialShape{2, 180, 180, 2};
+    const auto axes_shape = PartialShape::dynamic();
+    const auto signal_size_shape = PartialShape::dynamic();
+    const auto ref_output_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2};
+
+    auto data = std::make_shared<op::v0::Parameter>(element::dynamic, input_shape);
+    auto axes_input = std::make_shared<op::v0::Parameter>(element::dynamic, axes_shape);
+    auto signal_size_input = std::make_shared<op::v0::Parameter>(element::dynamic, signal_size_shape);
+    auto dft = std::make_shared<op::v7::DFT>(data, axes_input, signal_size_input);
+
+    EXPECT_EQ(dft->get_element_type(), element::dynamic);
+    ASSERT_TRUE(dft->get_output_partial_shape(0).same_scheme(ref_output_shape));
+}
diff --git a/src/core/tests/type_prop/idft.cpp b/src/core/tests/type_prop/idft.cpp
index fe3251f7299..0ff0f4a0957 100644
--- a/src/core/tests/type_prop/idft.cpp
+++ b/src/core/tests/type_prop/idft.cpp
@@ -389,3 +389,18 @@ TEST(type_prop, idft_invalid_signal_size) {
         EXPECT_HAS_SUBSTRING(error.what(), "Sizes of inputs 'axes' and 'signal_size' must be equal.");
     }
 }
+
+TEST(type_prop, idft_dynamic_types) {
+    const auto input_shape = PartialShape{2, 180, 180, 2};
+    const auto axes_shape = PartialShape::dynamic();
+    const auto signal_size_shape = PartialShape::dynamic();
+    const auto ref_output_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2};
+
+    auto data = std::make_shared<op::v0::Parameter>(element::dynamic, input_shape);
+    auto axes_input = std::make_shared<op::v0::Parameter>(element::dynamic, axes_shape);
+    auto signal_size_input = std::make_shared<op::v0::Parameter>(element::dynamic, signal_size_shape);
+    auto idft = std::make_shared<op::v7::IDFT>(data, axes_input, signal_size_input);
+
+    EXPECT_EQ(idft->get_element_type(), element::dynamic);
+    ASSERT_TRUE(idft->get_output_partial_shape(0).same_scheme(ref_output_shape));
+}
diff --git a/src/core/tests/type_prop/irdft.cpp b/src/core/tests/type_prop/irdft.cpp
index cabc7439bfa..79529755f3e 100644
--- a/src/core/tests/type_prop/irdft.cpp
+++ b/src/core/tests/type_prop/irdft.cpp
@@ -400,3 +400,18 @@ TEST(type_prop, irdft_invalid_signal_size) {
         EXPECT_HAS_SUBSTRING(error.what(), "Sizes of inputs 'axes' and 'signal_size' of (I)RDFT op must be equal.");
     }
 }
+
+TEST(type_prop, irdft_dynamic_types) {
+    const auto input_shape = PartialShape{2, 180, 180, 2};
+    const auto axes_shape = PartialShape::dynamic();
+    const auto signal_size_shape = PartialShape::dynamic();
+    const auto ref_output_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
+
+    auto data = std::make_shared<op::v0::Parameter>(element::dynamic, input_shape);
+    auto axes_input = std::make_shared<op::v0::Parameter>(element::dynamic, axes_shape);
+    auto signal_size_input = std::make_shared<op::v0::Parameter>(element::dynamic, signal_size_shape);
+    auto irdft = std::make_shared<op::v9::IRDFT>(data, axes_input, signal_size_input);
+
+    EXPECT_EQ(irdft->get_element_type(), element::dynamic);
+    ASSERT_TRUE(irdft->get_output_partial_shape(0).same_scheme(ref_output_shape));
+}
diff --git a/src/core/tests/type_prop/rdft.cpp b/src/core/tests/type_prop/rdft.cpp
index 38a3d60afe4..c2a035c2240 100644
--- a/src/core/tests/type_prop/rdft.cpp
+++ b/src/core/tests/type_prop/rdft.cpp
@@ -321,3 +321,18 @@ TEST(type_prop, rdft_invalid_signal_size) {
         EXPECT_HAS_SUBSTRING(error.what(), "Sizes of inputs 'axes' and 'signal_size' of (I)RDFT op must be equal.");
     }
 }
+
+TEST(type_prop, rdft_dynamic_types) {
+    const auto input_shape = PartialShape{2, 180, 180};
+    const auto axes_shape = PartialShape::dynamic();
+    const auto signal_size_shape = PartialShape::dynamic();
+    const auto ref_output_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2};
+
+    auto data = std::make_shared<op::v0::Parameter>(element::dynamic, input_shape);
+    auto axes_input = std::make_shared<op::v0::Parameter>(element::dynamic, axes_shape);
+    auto signal_size_input = std::make_shared<op::v0::Parameter>(element::dynamic, signal_size_shape);
+    auto rdft = std::make_shared<op::v9::RDFT>(data, axes_input, signal_size_input);
+
+    EXPECT_EQ(rdft->get_element_type(), element::dynamic);
+    ASSERT_TRUE(rdft->get_output_partial_shape(0).same_scheme(ref_output_shape));
+}
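The expected shapes in these tests follow from the packed real/imaginary layout used by OpenVINO's FFT family: DFT and IDFT carry complex values as a trailing dimension of size 2, RDFT adds that dimension, and IRDFT consumes it. With dynamic axes and signal_size inputs only the rank relation is known, hence the all-dynamic expected dimensions. A minimal NumPy sketch of the same shape contract (illustrative only, not part of the patch):

import numpy as np

# Real input of shape (2, 180, 180), transformed over the last two axes.
x = np.random.randn(2, 180, 180).astype(np.float32)
spectrum = np.fft.rfftn(x, axes=(1, 2))                         # complex, shape (2, 180, 91)
packed = np.stack([spectrum.real, spectrum.imag], axis=-1)      # (2, 180, 91, 2): RDFT-style packed output
restored = np.fft.irfftn(spectrum, s=x.shape[1:], axes=(1, 2))  # (2, 180, 180): IRDFT drops the packed dim
assert packed.shape == (2, 180, 91, 2) and restored.shape == x.shape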
diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index ee9e59d0f6c..9ed4ddfb817 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -27,11 +27,13 @@
 #include "transforms/dict_resolver.hpp"
 #include "transforms/einsum_list_construct.hpp"
 #include "transforms/index_loop_getitem_replacer.hpp"
+#include "transforms/irfftn_complex_replacer.hpp"
 #include "transforms/listconstruct_replacer.hpp"
 #include "transforms/min_max_prim_list_construct_replacer.hpp"
 #include "transforms/prim_list_construct_pad.hpp"
 #include "transforms/prim_list_tuple_construct_replacer.hpp"
 #include "transforms/prim_list_unpack_replacer.hpp"
+#include "transforms/rfftn_complex_replacer.hpp"
 #include "transforms/string_equality_replacer.hpp"
 #include "translate_session.hpp"

@@ -168,6 +170,8 @@ void FrontEnd::normalize(const std::shared_ptr<Model>& model) const {
     manager.register_pass();
     manager.register_pass();
     manager.register_pass();
+    manager.register_pass<ov::frontend::pytorch::pass::RFFTNComplexReplacer>();
+    manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
     manager.register_pass();
     manager.register_pass();
     manager.register_pass();
diff --git a/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.cpp b/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.cpp
new file mode 100644
index 00000000000..327b8756b68
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.cpp
@@ -0,0 +1,163 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "irfftn_complex_replacer.hpp"
+
+#include "openvino/core/rt_info.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/convert_like.hpp"
+#include "openvino/op/equal.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/irdft.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/range.hpp"
+#include "openvino/op/reduce_prod.hpp"
+#include "openvino/op/scatter_update.hpp"
+#include "openvino/op/select.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/sqrt.hpp"
+#include "openvino/op/squeeze.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/op/unsqueeze.hpp"
+#include "openvino/op/util/framework_node.hpp"
+#include "openvino/pass/pattern/matcher.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+using namespace ov::pass;
+using namespace ov::op;
+
+IRFFTNComplexReplacer::IRFFTNComplexReplacer() {
+    // Transformation to replace the combination of aten::complex -> aten::fft_irfftn torch operators.
+    // Pattern: aten::complex -> aten::fft_irfftn
+    auto fft_op = pattern::wrap_type<ov::op::util::FrameworkNode>();
+
+    ov::matcher_pass_callback irfftn_callback = [](pattern::Matcher& m) {
+        // "aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor"
norm=None) -> Tensor" + auto irfftn_op = cast_fw_node(m.get_match_root(), "aten::fft_irfftn"); + if (!irfftn_op) { + return false; + } + auto const_neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1}); + auto const_0 = v0::Constant::create(element::i32, Shape{1}, {0}); + auto const_scalar_0 = v0::Constant::create(element::i32, Shape{}, {0}); + auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1}); + auto const_scalar_1 = v0::Constant::create(element::i32, Shape{}, {1}); + auto const_2 = v0::Constant::create(element::i32, Shape{1}, {2}); + + // Check whether input node being aten::complex. + auto fw_node_complex_input = cast_fw_node(irfftn_op->input_value(0).get_node_shared_ptr(), "aten::complex"); + if (!fw_node_complex_input) { + return false; + } + + // Concatenate real and imag parts over additional, last dimension. + auto real = std::make_shared(fw_node_complex_input->input_value(0), const_neg_1); + auto imag = std::make_shared(fw_node_complex_input->input_value(1), const_neg_1); + NodeVector complex = {real, imag}; + auto input = std::make_shared(complex, -1); + + // Input shape of complex number (excluding dimension created by concatenation of real and imag) + auto complex_input_shape = std::make_shared(fw_node_complex_input->input_value(0), element::i32); + auto input_rank = std::make_shared(complex_input_shape, element::i32); + auto input_rank_scalar = std::make_shared(input_rank); + + // Inputs can be either none or ListConstruct. Check whether input values should be used or should be set to + // default values. + bool dim_use_default = is_none_node(irfftn_op->input_value(2)); + bool s_use_default = is_none_node(irfftn_op->input_value(1)); + // Can be None constant, when used check s_use_default. + auto raw_s_input_maybe = concat_list_construct(irfftn_op->input_value(1)).get_node_shared_ptr(); + + // Handle dim parameter containing vector of intigers indicating dimensions to be transformed. + std::shared_ptr dim; + if (!dim_use_default) { + // Dim values is provided, load from input. + dim = std::make_shared(concat_list_construct(irfftn_op->input_value(2)), element::i32); + } else if (!s_use_default) { + // If dim is default and s is provided, use last s_len dimensions where s_len is length of s. + auto s_len = std::make_shared(raw_s_input_maybe, element::i32); + auto range_start = std::make_shared(input_rank, s_len); + auto range_start_scalar = std::make_shared(range_start); + dim = std::make_shared(range_start_scalar, input_rank_scalar, const_scalar_1, element::i32); + } else { + // Dim and s are set to default, use all of dimensions. + dim = std::make_shared(const_scalar_0, input_rank_scalar, const_scalar_1, element::i32); + } + + // Calculate default s values. Use full available size except last element, which is set to even value in last + // dimension: s[-1] = 2 * (complex_input_shape[dim[-1]]) + auto default_s_raw = std::make_shared(complex_input_shape, dim, const_0); + auto last_s = std::make_shared(default_s_raw, const_neg_1, const_0); + auto last_s_m_1 = std::make_shared(last_s, const_1); + auto s_upd = std::make_shared(last_s_m_1, const_2); + auto s_shape = std::make_shared(default_s_raw, element::i32); + auto last_s_idx = std::make_shared(s_shape, const_1); + auto default_s = std::make_shared(default_s_raw, last_s_idx, s_upd, const_0); + + // Handle s parameter containing vector of intigers indicating signal sizes for dimensions. + std::shared_ptr s; + if (!s_use_default) { + // Values for s were provided. 
+
+        // Handle the s parameter containing a vector of integers indicating signal sizes for dimensions.
+        std::shared_ptr<ov::Node> s;
+        if (!s_use_default) {
+            // Values for s were provided; replace -1 values with the default full size of the given dimension.
+            auto full_s_cond = std::make_shared<v1::Equal>(raw_s_input_maybe, const_neg_1);
+            s = std::make_shared<v1::Select>(full_s_cond, default_s, raw_s_input_maybe);
+        } else {
+            // Value for s was set to default.
+            s = default_s;
+        }
+
+        // Handle the norm parameter indicating the normalization mode to use. Defaults to "backward".
+        std::string norm;
+        if (const auto& fw_node_mode = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(
+                irfftn_op->input_value(3).get_node_shared_ptr())) {
+            const auto& attrs = fw_node_mode->get_attrs();
+            if (attrs.find("string_value") != attrs.end()) {
+                norm = attrs.at("string_value");
+            } else {
+                norm = "backward";
+            }
+        } else {
+            add_exception_to_fw_node(irfftn_op, "aten::fft_irfftn: could not retrieve value for norm attribute.");
+            return false;
+        }
+
+        auto irdft = std::make_shared<v9::IRDFT>(input, dim, s);
+
+        // Apply normalizations. IRDFT itself already scales by 1/n (the "backward" convention).
+        auto n_int = std::make_shared<v1::ReduceProd>(s, const_0);
+        auto n = std::make_shared<v1::ConvertLike>(n_int, irdft);
+        std::shared_ptr<ov::Node> normalized_irfftn;
+        if (norm == "forward") {
+            normalized_irfftn = std::make_shared<v1::Multiply>(irdft, n);
+        } else if (norm == "backward") {
+            normalized_irfftn = irdft;
+        } else if (norm == "ortho") {
+            auto sqrt_n = std::make_shared<v0::Sqrt>(n);
+            normalized_irfftn = std::make_shared<v1::Multiply>(irdft, sqrt_n);
+        } else {
+            add_exception_to_fw_node(
+                irfftn_op,
+                "aten::fft_irfftn: unrecognized normalization mode. Only forward, backward and ortho are supported.");
+            return false;
+        }
+
+        copy_runtime_info({irfftn_op, fw_node_complex_input}, normalized_irfftn);
+        normalized_irfftn->set_friendly_name(irfftn_op->get_friendly_name());
+        replace_node(irfftn_op, normalized_irfftn);
+        return true;
+    };
+    auto m = std::make_shared<pattern::Matcher>(fft_op, "ov::frontend::pytorch::pass::IRFFTNComplexReplacer");
+    this->register_matcher(m, irfftn_callback);
+};
+
+} // namespace pass
+} // namespace pytorch
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.hpp b/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.hpp
new file mode 100644
index 00000000000..63061defd1f
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/irfftn_complex_replacer.hpp
@@ -0,0 +1,24 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/pass/pass.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+class IRFFTNComplexReplacer : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("ov::frontend::pytorch::pass::IRFFTNComplexReplacer");
+    IRFFTNComplexReplacer();
+};
+
+} // namespace pass
+} // namespace pytorch
+} // namespace frontend
+} // namespace ov
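The producer pattern matched above comes from PyTorch code of the following form (illustrative sketch; function and tensor names are arbitrary). Under tracing, torch.complex lowers to aten::complex and feeds aten::fft_irfftn, which is exactly the FrameworkNode pair the callback checks for:

import torch

def complex_to_irfftn(re: torch.Tensor, im: torch.Tensor) -> torch.Tensor:
    # Traces to aten::complex -> aten::fft_irfftn.
    return torch.fft.irfftn(torch.complex(re, im), dim=(-2, -1), norm="ortho")

re, im = torch.randn(2, 180, 91), torch.randn(2, 180, 91)
out = complex_to_irfftn(re, im)  # real tensor of shape [2, 180, 180]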
+#include "openvino/op/slice.hpp" +#include "openvino/op/split.hpp" +#include "openvino/op/sqrt.hpp" +#include "openvino/op/squeeze.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/util/framework_node.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +using namespace ov::pass; +using namespace ov::op; + +RFFTNComplexReplacer::RFFTNComplexReplacer() { + // Transformation used to replace combination of aten::fft_rfftn -> {aten::real, aten::imag} torch operators. + // Pattern: aten::fft_rfftn -> {aten::real, aten::imag} + auto fft_op = pattern::wrap_type(); + ov::matcher_pass_callback rfftn_callback = [](pattern::Matcher& m) { + // Schema: "aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor" + auto rfftn_op = cast_fw_node(m.get_match_root(), "aten::fft_rfftn"); + if (!rfftn_op) { + return false; + } + auto const_neg_1 = v0::Constant::create(element::i32, Shape{}, {-1}); + auto const_0 = v0::Constant::create(element::i32, Shape{}, {0}); + auto const_1 = v0::Constant::create(element::i32, Shape{}, {1}); + + auto input = rfftn_op->input_value(0); + auto input_shape = std::make_shared(input, element::i32); + auto input_rank = std::make_shared(input_shape, element::i32); + auto input_rank_scalar = std::make_shared(input_rank); + + // Inputs can be either none or ListConstruct. Check whether input values should be used or should be set to + // default values. + bool dim_use_default = is_none_node(rfftn_op->input_value(2)); + bool s_use_default = is_none_node(rfftn_op->input_value(1)); + // Can be None constant, when used check s_use_default. + auto raw_s_input_maybe = concat_list_construct(rfftn_op->input_value(1)).get_node_shared_ptr(); + + // Handle dim parameter containing vector of intigers indicating dimensions to be transformed. + std::shared_ptr dim; + if (!dim_use_default) { + // Dim values is provided, load from input. + dim = std::make_shared(concat_list_construct(rfftn_op->input_value(2)), element::i32); + } else if (!s_use_default) { + // If dim is default and s is provided, use last s_len dimensions where s_len is length of s. + auto s_len = std::make_shared(raw_s_input_maybe, element::i32); + auto slice_start = std::make_shared(input_rank, s_len); + auto slice_start_scalar = std::make_shared(slice_start); + dim = std::make_shared(slice_start_scalar, input_rank_scalar, const_1, element::i32); + } else { + // Dim and s are set to default, use all of dimensions. + dim = std::make_shared(const_0, input_rank_scalar, const_1, element::i32); + } + + // Handle s parameter containing vector of intigers indicating signal sizes for dimensions. + std::shared_ptr s; + if (!s_use_default) { + // Values for s were provided. Replace -1 values with default full size in given dimension. + auto full_s_cond = std::make_shared(raw_s_input_maybe, const_neg_1); + auto full_s_values = std::make_shared(input_shape, dim, const_0); + s = std::make_shared(full_s_cond, full_s_values, raw_s_input_maybe); + } else { + // Value for s was set to default, use full size for all dimensions. + s = std::make_shared(input_shape, dim, const_0); + } + + // Handle norm parameter indicating normalization mode to use. Defaults to "backward". 
+
+        // Handle the norm parameter indicating the normalization mode to use. Defaults to "backward".
+        std::string norm;
+        if (const auto& fw_node_mode = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(
+                rfftn_op->input_value(3).get_node_shared_ptr())) {
+            const auto& attrs = fw_node_mode->get_attrs();
+            if (attrs.find("string_value") != attrs.end()) {
+                norm = attrs.at("string_value");
+            } else {
+                norm = "backward";
+            }
+        } else {
+            add_exception_to_fw_node(rfftn_op, "aten::fft_rfftn: could not retrieve value for norm attribute.");
+            return false;
+        }
+
+        auto rdft = std::make_shared<v9::RDFT>(input, dim, s);
+
+        // Apply normalizations
+        auto n_int = std::make_shared<v1::ReduceProd>(s, const_0);
+        auto n = std::make_shared<v1::ConvertLike>(n_int, rdft);
+        std::shared_ptr<ov::Node> normalized_rfftn;
+        if (norm == "forward") {
+            // Normalize by 1/n
+            normalized_rfftn = std::make_shared<v1::Divide>(rdft, n);
+        } else if (norm == "backward") {
+            // No normalization
+            normalized_rfftn = rdft;
+        } else if (norm == "ortho") {
+            // Normalize by 1/sqrt(n)
+            auto sqrt_n = std::make_shared<v0::Sqrt>(n);
+            normalized_rfftn = std::make_shared<v1::Divide>(rdft, sqrt_n);
+        } else {
+            add_exception_to_fw_node(
+                rfftn_op,
+                "aten::fft_rfftn: unrecognized normalization mode. Only forward, backward and ortho are supported.");
+            return false;
+        }
+
+        // Replace outputs that are the torch operators aten::real or aten::imag. Apply Squeeze to remove the last
+        // dimension, which holds the concatenated real and imag parts.
+        auto normalized_rfftn_splitted = std::make_shared<v1::Split>(normalized_rfftn, const_neg_1, 2);
+        auto rfftn_outs = rfftn_op->get_users();
+        bool rval = false;
+        for (auto& out : rfftn_outs) {
+            if (auto real_op = cast_fw_node(out, "aten::real")) {
+                auto squeezed = std::make_shared<v0::Squeeze>(normalized_rfftn_splitted->output(0), const_neg_1);
+                copy_runtime_info({rfftn_op, real_op}, squeezed);
+                squeezed->set_friendly_name(real_op->get_friendly_name());
+                replace_node(real_op, squeezed);
+                rval = true;
+            }
+            if (auto imag_op = cast_fw_node(out, "aten::imag")) {
+                auto squeezed = std::make_shared<v0::Squeeze>(normalized_rfftn_splitted->output(1), const_neg_1);
+                copy_runtime_info({rfftn_op, imag_op}, squeezed);
+                squeezed->set_friendly_name(imag_op->get_friendly_name());
+                replace_node(imag_op, squeezed);
+                rval = true;
+            }
+        }
+        if (!rval) {
+            add_exception_to_fw_node(
+                rfftn_op,
+                "aten::fft_rfftn: Unsupported output node. Only aten::real and aten::imag are supported.");
+        }
+        return rval;
+    };
+
+    auto m = std::make_shared<pattern::Matcher>(fft_op, "ov::frontend::pytorch::pass::RFFTNComplexReplacer");
+    this->register_matcher(m, rfftn_callback);
+};
+
+} // namespace pass
+} // namespace pytorch
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/pytorch/src/transforms/rfftn_complex_replacer.hpp b/src/frontends/pytorch/src/transforms/rfftn_complex_replacer.hpp
new file mode 100644
index 00000000000..3a6bceb0dfe
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/rfftn_complex_replacer.hpp
@@ -0,0 +1,24 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/pass/pass.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+class RFFTNComplexReplacer : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("ov::frontend::pytorch::pass::RFFTNComplexReplacer");
+    RFFTNComplexReplacer();
+};
+
+} // namespace pass
+} // namespace pytorch
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp
index b6e9d0972be..1f364895e44 100644
--- a/src/frontends/pytorch/src/utils.cpp
+++ b/src/frontends/pytorch/src/utils.cpp
@@ -331,6 +331,16 @@ std::shared_ptr<ov::op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node>
     return fw_node;
 }

+bool is_none_node(const Output<Node>& node) {
+    if (const auto& fw_node_inp = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(node.get_node_shared_ptr())) {
+        const auto& attrs = fw_node_inp->get_attrs();
+        if (attrs.find("none_value") != attrs.end()) {
+            return true;
+        }
+    }
+    return false;
+}
+
 Any simplified_type_interpret(Any type) {
     // Interpret Tensor[type] as just type
     // After applying of this interpretation we cannot distinguish true scalars (not tensors) and tensors with elements
diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp
index a3732b681ae..6fb2a036cf1 100644
--- a/src/frontends/pytorch/src/utils.hpp
+++ b/src/frontends/pytorch/src/utils.hpp
@@ -58,6 +58,8 @@ OutputVector make_framework_node(const NodeContext& context, const std::string&

 std::shared_ptr<ov::op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node, const std::string& type);

+bool is_none_node(const Output<Node>& node);
+
 // TODO: Eliminate the need of this function by implementing more accurate custom data type handling
 Any simplified_type_interpret(Any type);
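The normalization handled by both replacers can be checked numerically against PyTorch itself; a small sketch (illustrative only, not part of the test suite):

import torch

x = torch.randn(4, 6)
n = x.numel()  # s defaults to the full extent of every transformed dim, so n = 4 * 6
backward = torch.fft.rfftn(x)                  # unscaled
forward = torch.fft.rfftn(x, norm="forward")   # scaled by 1/n
ortho = torch.fft.rfftn(x, norm="ortho")       # scaled by 1/sqrt(n)
assert torch.allclose(forward, backward / n)
assert torch.allclose(ortho, backward / n ** 0.5)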
diff --git a/tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py b/tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py
new file mode 100644
index 00000000000..7e3dd169e5b
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_rfftn_complex_transforms.py
@@ -0,0 +1,46 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestRFFTN(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.random.randn(*self.input_shape).astype(np.float32),)
+
+    def create_model(self, dim, s, norm):
+        class aten_fft_rfftn(torch.nn.Module):
+            def __init__(self, dim, s, norm):
+                super(aten_fft_rfftn, self).__init__()
+                self.dim = dim
+                self.s = s
+                self.norm = norm
+
+            def forward(self, x):
+                rfftn = torch.fft.rfftn(x, s=self.s, dim=self.dim, norm=self.norm)
+                r = rfftn.real
+                i = rfftn.imag
+                irfftn = torch.fft.irfftn(torch.complex(r, i), s=self.s, dim=self.dim, norm=self.norm)
+                return irfftn, r, i
+
+        ref_net = None
+
+        return (
+            aten_fft_rfftn(dim, s, norm),
+            ref_net,
+            ["aten::fft_irfftn", "aten::complex", "aten::fft_rfftn", "aten::real", "aten::imag"],
+        )
+
+    @pytest.mark.parametrize("input_shape", [[64, 49], [64, 50], [64, 64, 49]])
+    @pytest.mark.parametrize("dim", [[0, -1], [-2, -1], None, [0, 1]])
+    @pytest.mark.parametrize("s", [None, [-1, 49], [64, -1], [64, 49], [5, 1]])
+    @pytest.mark.parametrize("norm", ["forward", "backward", "ortho", None])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_rfftn(self, ie_device, precision, ir_version, input_shape, dim, s, norm):
+        self.input_shape = input_shape
+        # Unfrozen test would fail due to issues with prim::GetAttr containing lists, strings or none.
+        self._test(*self.create_model(dim, s, norm), ie_device, precision, ir_version, custom_eps=1e-3,
+                   freeze_model=True)
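For reference, the composition exercised by this test is a round trip: rfftn followed by irfftn with matching s, dim, and norm reconstructs the input up to floating-point error, so the converted model's outputs can be compared directly against PyTorch's. A standalone sketch of that property (illustrative only, not part of the test suite):

import numpy as np
import torch

x = torch.from_numpy(np.random.randn(64, 49).astype(np.float32))
spec = torch.fft.rfftn(x, s=[64, 49], dim=[0, 1], norm="ortho")
y = torch.fft.irfftn(torch.complex(spec.real, spec.imag), s=[64, 49], dim=[0, 1], norm="ortho")
assert torch.allclose(x, y, atol=1e-5)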