[PT FE] Add aten::fft_rfftn and aten::fft_irfft complex replacers (#17999)

* Add transformations for rfftn and irfftn

* Formatting & reduce tests

* Improvements

* rm unused namespace

* Apply suggestions from review

* Fix format

* Add tests for dynamic type

* Add error messages
Mateusz Mikolajczyk 2023-06-16 11:35:30 +02:00 committed by GitHub
parent 945157cc7b
commit 42c93a70d4
13 changed files with 502 additions and 5 deletions


@@ -30,19 +30,21 @@ void ov::op::util::FFTBase::validate_types() {
 element::Type input_et = get_input_element_type(0);
 NODE_VALIDATION_CHECK(this,
-                      input_et == element::f32 || input_et == element::f16 || input_et == element::bf16,
+                      input_et == element::f32 || input_et == element::f16 || input_et == element::bf16 ||
+                          input_et == element::dynamic,
                       "FFT op input element type must be f32, f16, or bf16");
 element::Type axes_et = get_input_element_type(1);
 NODE_VALIDATION_CHECK(this,
-                      axes_et == element::i64 || axes_et == element::i32,
+                      axes_et == element::i64 || axes_et == element::i32 || axes_et == element::dynamic,
                       "FFT op axes element type must be i32 or i64");
 if (num_of_inputs == 3) {
     element::Type signal_size_et = get_input_element_type(2);
-    NODE_VALIDATION_CHECK(this,
-                          signal_size_et == element::i64 || signal_size_et == element::i32,
-                          "FFT op signal_size element type must be i32 or i64");
+    NODE_VALIDATION_CHECK(
+        this,
+        signal_size_et == element::i64 || signal_size_et == element::i32 || signal_size_et == element::dynamic,
+        "FFT op signal_size element type must be i32 or i64");
 }
 }
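A minimal sketch (not part of the commit, mirroring the new type_prop tests below) of what the relaxed checks enable: FFT ops can now be constructed with fully dynamic element types.

#include <memory>
#include "openvino/op/dft.hpp"
#include "openvino/op/parameter.hpp"

using namespace ov;

int main() {
    auto data = std::make_shared<op::v0::Parameter>(element::dynamic, PartialShape{2, 180, 180, 2});
    auto axes = std::make_shared<op::v0::Parameter>(element::dynamic, PartialShape::dynamic());
    auto dft = std::make_shared<op::v7::DFT>(data, axes);  // previously rejected by validate_types()
    return dft->get_element_type() == element::dynamic ? 0 : 1;
}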


@@ -401,3 +401,18 @@ TEST(type_prop, dft_invalid_signal_size) {
EXPECT_HAS_SUBSTRING(error.what(), "Sizes of inputs 'axes' and 'signal_size' must be equal.");
}
}
TEST(type_prop, dft_dynamic_types) {
const auto input_shape = PartialShape{2, 180, 180, 2};
const auto axes_shape = PartialShape::dynamic();
const auto signal_size_shape = PartialShape::dynamic();
const auto ref_output_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2};
auto data = std::make_shared<op::Parameter>(element::dynamic, input_shape);
auto axes_input = std::make_shared<op::Parameter>(element::dynamic, axes_shape);
auto signal_size_input = std::make_shared<op::Parameter>(element::dynamic, signal_size_shape);
auto dft = std::make_shared<op::v7::DFT>(data, axes_input, signal_size_input);
EXPECT_EQ(dft->get_element_type(), element::dynamic);
ASSERT_TRUE(dft->get_output_partial_shape(0).same_scheme(ref_output_shape));
}


@@ -389,3 +389,18 @@ TEST(type_prop, idft_invalid_signal_size) {
EXPECT_HAS_SUBSTRING(error.what(), "Sizes of inputs 'axes' and 'signal_size' must be equal.");
}
}
TEST(type_prop, idft_dynamic_types) {
const auto input_shape = PartialShape{2, 180, 180, 2};
const auto axes_shape = PartialShape::dynamic();
const auto signal_size_shape = PartialShape::dynamic();
const auto ref_output_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2};
auto data = std::make_shared<op::Parameter>(element::dynamic, input_shape);
auto axes_input = std::make_shared<op::Parameter>(element::dynamic, axes_shape);
auto signal_size_input = std::make_shared<op::Parameter>(element::dynamic, signal_size_shape);
auto idft = std::make_shared<op::v7::IDFT>(data, axes_input, signal_size_input);
EXPECT_EQ(idft->get_element_type(), element::dynamic);
ASSERT_TRUE(idft->get_output_partial_shape(0).same_scheme(ref_output_shape));
}


@@ -400,3 +400,18 @@ TEST(type_prop, irdft_invalid_signal_size) {
EXPECT_HAS_SUBSTRING(error.what(), "Sizes of inputs 'axes' and 'signal_size' of (I)RDFT op must be equal.");
}
}
TEST(type_prop, irdft_dynamic_types) {
const auto input_shape = PartialShape{2, 180, 180, 2};
const auto axes_shape = PartialShape::dynamic();
const auto signal_size_shape = PartialShape::dynamic();
const auto ref_output_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic()};
auto data = std::make_shared<op::Parameter>(element::dynamic, input_shape);
auto axes_input = std::make_shared<op::Parameter>(element::dynamic, axes_shape);
auto signal_size_input = std::make_shared<op::Parameter>(element::dynamic, signal_size_shape);
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes_input, signal_size_input);
EXPECT_EQ(irdft->get_element_type(), element::dynamic);
ASSERT_TRUE(irdft->get_output_partial_shape(0).same_scheme(ref_output_shape));
}


@@ -321,3 +321,18 @@ TEST(type_prop, rdft_invalid_signal_size) {
EXPECT_HAS_SUBSTRING(error.what(), "Sizes of inputs 'axes' and 'signal_size' of (I)RDFT op must be equal.");
}
}
TEST(type_prop, rdft_dynamic_types) {
const auto input_shape = PartialShape{2, 180, 180};
const auto axes_shape = PartialShape::dynamic();
const auto signal_size_shape = PartialShape::dynamic();
const auto ref_output_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 2};
auto data = std::make_shared<op::Parameter>(element::dynamic, input_shape);
auto axes_input = std::make_shared<op::Parameter>(element::dynamic, axes_shape);
auto signal_size_input = std::make_shared<op::Parameter>(element::dynamic, signal_size_shape);
auto rdft = std::make_shared<op::v9::RDFT>(data, axes_input, signal_size_input);
EXPECT_EQ(rdft->get_element_type(), element::dynamic);
ASSERT_TRUE(rdft->get_output_partial_shape(0).same_scheme(ref_output_shape));
}


@@ -27,11 +27,13 @@
#include "transforms/dict_resolver.hpp"
#include "transforms/einsum_list_construct.hpp"
#include "transforms/index_loop_getitem_replacer.hpp"
#include "transforms/irfftn_complex_replacer.hpp"
#include "transforms/listconstruct_replacer.hpp"
#include "transforms/min_max_prim_list_construct_replacer.hpp"
#include "transforms/prim_list_construct_pad.hpp"
#include "transforms/prim_list_tuple_construct_replacer.hpp"
#include "transforms/prim_list_unpack_replacer.hpp"
#include "transforms/rfftn_complex_replacer.hpp"
#include "transforms/string_equality_replacer.hpp"
#include "translate_session.hpp"
@@ -168,6 +170,8 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::StringEqualityReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::RFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
manager.register_pass<ov::frontend::pytorch::pass::DictResolver>();
manager.register_pass<ov::frontend::pytorch::pass::IndexLoopGetitemReplacer>();


@@ -0,0 +1,163 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "irfftn_complex_replacer.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/irdft.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reduce_prod.hpp"
#include "openvino/op/scatter_update.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
using namespace ov::pass;
using namespace ov::op;
IRFFTNComplexReplacer::IRFFTNComplexReplacer() {
// Transformation replacing the combination of aten::complex -> aten::fft_irfftn torch operators.
// Pattern: aten::complex -> aten::fft_irfftn
auto fft_op = pattern::wrap_type<ov::op::util::FrameworkNode>();
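// The pattern matches any FrameworkNode; the callback below filters for aten::fft_irfftn
// via cast_fw_node and rejects everything else.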
ov::matcher_pass_callback irfftn_callback = [](pattern::Matcher& m) {
// "aten::fft_irfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor"
auto irfftn_op = cast_fw_node(m.get_match_root(), "aten::fft_irfftn");
if (!irfftn_op) {
return false;
}
auto const_neg_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
auto const_0 = v0::Constant::create(element::i32, Shape{1}, {0});
auto const_scalar_0 = v0::Constant::create(element::i32, Shape{}, {0});
auto const_1 = v0::Constant::create(element::i32, Shape{1}, {1});
auto const_scalar_1 = v0::Constant::create(element::i32, Shape{}, {1});
auto const_2 = v0::Constant::create(element::i32, Shape{1}, {2});
// Check whether the input node is aten::complex.
auto fw_node_complex_input = cast_fw_node(irfftn_op->input_value(0).get_node_shared_ptr(), "aten::complex");
if (!fw_node_complex_input) {
return false;
}
// Concatenate real and imag parts over an additional last dimension.
auto real = std::make_shared<v0::Unsqueeze>(fw_node_complex_input->input_value(0), const_neg_1);
auto imag = std::make_shared<v0::Unsqueeze>(fw_node_complex_input->input_value(1), const_neg_1);
NodeVector complex = {real, imag};
auto input = std::make_shared<v0::Concat>(complex, -1);
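// E.g. real and imag parts of shape [2, 180, 91] are packed into a [2, 180, 91, 2] tensor,
// the (last dim == 2) packed-complex layout that OpenVINO IRDFT expects.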
// Shape of the complex input, excluding the dimension created by concatenating real and imag.
auto complex_input_shape = std::make_shared<v3::ShapeOf>(fw_node_complex_input->input_value(0), element::i32);
auto input_rank = std::make_shared<v3::ShapeOf>(complex_input_shape, element::i32);
auto input_rank_scalar = std::make_shared<v0::Squeeze>(input_rank);
// Inputs can be either None or a ListConstruct. Check whether the input values should be used
// or set to defaults.
bool dim_use_default = is_none_node(irfftn_op->input_value(2));
bool s_use_default = is_none_node(irfftn_op->input_value(1));
// May be a None constant; check s_use_default before using it.
auto raw_s_input_maybe = concat_list_construct(irfftn_op->input_value(1)).get_node_shared_ptr();
// Handle the dim parameter: a vector of integers indicating the dimensions to be transformed.
std::shared_ptr<ov::Node> dim;
if (!dim_use_default) {
// Dim value is provided, load it from the input.
dim = std::make_shared<v0::Convert>(concat_list_construct(irfftn_op->input_value(2)), element::i32);
} else if (!s_use_default) {
// If dim is default and s is provided, use the last s_len dimensions, where s_len is the length of s.
auto s_len = std::make_shared<v3::ShapeOf>(raw_s_input_maybe, element::i32);
auto range_start = std::make_shared<v1::Subtract>(input_rank, s_len);
auto range_start_scalar = std::make_shared<v0::Squeeze>(range_start);
dim = std::make_shared<v4::Range>(range_start_scalar, input_rank_scalar, const_scalar_1, element::i32);
} else {
// Dim and s are set to default, use all dimensions.
dim = std::make_shared<v4::Range>(const_scalar_0, input_rank_scalar, const_scalar_1, element::i32);
}
// Calculate default s values: use the full available size in each transformed dimension, except
// the last one, which is set to the even value s[-1] = 2 * (complex_input_shape[dim[-1]] - 1).
auto default_s_raw = std::make_shared<v8::Gather>(complex_input_shape, dim, const_0);
auto last_s = std::make_shared<v8::Gather>(default_s_raw, const_neg_1, const_0);
auto last_s_m_1 = std::make_shared<v1::Subtract>(last_s, const_1);
auto s_upd = std::make_shared<v1::Multiply>(last_s_m_1, const_2);
auto s_shape = std::make_shared<v3::ShapeOf>(default_s_raw, element::i32);
auto last_s_idx = std::make_shared<v1::Subtract>(s_shape, const_1);
auto default_s = std::make_shared<v3::ScatterUpdate>(default_s_raw, last_s_idx, s_upd, const_0);
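// Worked example: a complex input of shape [2, 180, 91] transformed over all dims gives
// default_s_raw = [2, 180, 91]; the last element becomes 2 * (91 - 1) = 180, so default_s = [2, 180, 180].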
// Handle the s parameter: a vector of integers indicating signal sizes for the transformed dimensions.
std::shared_ptr<ov::Node> s;
if (!s_use_default) {
// Values for s were provided. Replace -1 entries with the default full size of the given dimension.
auto full_s_cond = std::make_shared<v1::Equal>(raw_s_input_maybe, const_neg_1);
s = std::make_shared<v1::Select>(full_s_cond, default_s, raw_s_input_maybe);
} else {
// Value for s was set to default.
s = default_s;
}
// Handle norm parameter indicating normalization mode to use. Defaults to "backward".
std::string norm;
if (const auto& fw_node_mode = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(
irfftn_op->input_value(3).get_node_shared_ptr())) {
const auto& attrs = fw_node_mode->get_attrs();
if (attrs.find("string_value") != attrs.end()) {
norm = attrs.at("string_value");
} else {
norm = "backward";
}
} else {
add_exception_to_fw_node(irfftn_op, "aten::fft_irfftn: could not retrieve value for norm attribute.");
return false;
}
auto irdft = std::make_shared<v9::IRDFT>(input, dim, s);
// Apply normalizations.
auto n_int = std::make_shared<v1::ReduceProd>(s, const_0);
auto n = std::make_shared<v1::ConvertLike>(n_int, irdft);
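// Note: OpenVINO IRDFT already applies the 1/n ("backward") scaling internally, so the factors
// below restore PyTorch semantics: "forward" -> multiply by n (net result: unscaled inverse),
// "backward" -> unchanged (net: 1/n), "ortho" -> multiply by sqrt(n) (net: 1/sqrt(n)).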
std::shared_ptr<ov::Node> normalized_irfftn;
if (norm == "forward") {
normalized_irfftn = std::make_shared<v1::Multiply>(irdft, n);
} else if (norm == "backward") {
normalized_irfftn = irdft;
} else if (norm == "ortho") {
auto sqrt_n = std::make_shared<v0::Sqrt>(n);
normalized_irfftn = std::make_shared<v1::Multiply>(irdft, sqrt_n);
} else {
add_exception_to_fw_node(
irfftn_op,
"aten::fft_irfftn: unrecognized normalization mode. Only forward, backward and ortho are supported.");
return false;
}
copy_runtime_info({irfftn_op, fw_node_complex_input}, normalized_irfftn);
normalized_irfftn->set_friendly_name(irfftn_op->get_friendly_name());
replace_node(irfftn_op, normalized_irfftn);
return true;
};
auto m = std::make_shared<pattern::Matcher>(fft_op, "ov::frontend::pytorch::pass::IRFFTNComplexReplacer");
this->register_matcher(m, irfftn_callback);
}
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov


@@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
class IRFFTNComplexReplacer : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::IRFFTNComplexReplacer");
IRFFTNComplexReplacer();
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov


@@ -0,0 +1,162 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "rfftn_complex_replacer.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/rdft.hpp"
#include "openvino/op/reduce_prod.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
using namespace ov::pass;
using namespace ov::op;
RFFTNComplexReplacer::RFFTNComplexReplacer() {
// Transformation replacing the combination of aten::fft_rfftn -> {aten::real, aten::imag} torch operators.
// Pattern: aten::fft_rfftn -> {aten::real, aten::imag}
auto fft_op = pattern::wrap_type<ov::op::util::FrameworkNode>();
ov::matcher_pass_callback rfftn_callback = [](pattern::Matcher& m) {
// Schema: "aten::fft_rfftn(Tensor self, int[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor"
auto rfftn_op = cast_fw_node(m.get_match_root(), "aten::fft_rfftn");
if (!rfftn_op) {
return false;
}
auto const_neg_1 = v0::Constant::create(element::i32, Shape{}, {-1});
auto const_0 = v0::Constant::create(element::i32, Shape{}, {0});
auto const_1 = v0::Constant::create(element::i32, Shape{}, {1});
auto input = rfftn_op->input_value(0);
auto input_shape = std::make_shared<v3::ShapeOf>(input, element::i32);
auto input_rank = std::make_shared<v3::ShapeOf>(input_shape, element::i32);
auto input_rank_scalar = std::make_shared<v0::Squeeze>(input_rank);
// Inputs can be either None or a ListConstruct. Check whether the input values should be used
// or set to defaults.
bool dim_use_default = is_none_node(rfftn_op->input_value(2));
bool s_use_default = is_none_node(rfftn_op->input_value(1));
// May be a None constant; check s_use_default before using it.
auto raw_s_input_maybe = concat_list_construct(rfftn_op->input_value(1)).get_node_shared_ptr();
// Handle the dim parameter: a vector of integers indicating the dimensions to be transformed.
std::shared_ptr<ov::Node> dim;
if (!dim_use_default) {
// Dim value is provided, load it from the input.
dim = std::make_shared<v0::Convert>(concat_list_construct(rfftn_op->input_value(2)), element::i32);
} else if (!s_use_default) {
// If dim is default and s is provided, use the last s_len dimensions, where s_len is the length of s.
auto s_len = std::make_shared<v3::ShapeOf>(raw_s_input_maybe, element::i32);
auto slice_start = std::make_shared<v1::Subtract>(input_rank, s_len);
auto slice_start_scalar = std::make_shared<v0::Squeeze>(slice_start);
dim = std::make_shared<v4::Range>(slice_start_scalar, input_rank_scalar, const_1, element::i32);
} else {
// Dim and s are set to default, use all dimensions.
dim = std::make_shared<v4::Range>(const_0, input_rank_scalar, const_1, element::i32);
}
// Handle the s parameter: a vector of integers indicating signal sizes for the transformed dimensions.
std::shared_ptr<ov::Node> s;
if (!s_use_default) {
// Values for s were provided. Replace -1 entries with the default full size of the given dimension.
auto full_s_cond = std::make_shared<v1::Equal>(raw_s_input_maybe, const_neg_1);
auto full_s_values = std::make_shared<v8::Gather>(input_shape, dim, const_0);
s = std::make_shared<v1::Select>(full_s_cond, full_s_values, raw_s_input_maybe);
} else {
// Value for s was set to default, use full size for all dimensions.
s = std::make_shared<v8::Gather>(input_shape, dim, const_0);
}
// Handle norm parameter indicating normalization mode to use. Defaults to "backward".
std::string norm;
if (const auto& fw_node_mode = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(
rfftn_op->input_value(3).get_node_shared_ptr())) {
const auto& attrs = fw_node_mode->get_attrs();
if (attrs.find("string_value") != attrs.end()) {
norm = attrs.at("string_value");
} else {
norm = "backward";
}
} else {
add_exception_to_fw_node(rfftn_op, "aten::fft_rfftn: could not retrieve value for norm attribute.");
return false;
}
auto rdft = std::make_shared<v9::RDFT>(input, dim, s);
// Apply normalizations
auto n_int = std::make_shared<v1::ReduceProd>(s, const_0);
auto n = std::make_shared<v1::ConvertLike>(n_int, rdft);
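// Note: OpenVINO RDFT output is unnormalized, so dividing reproduces PyTorch semantics:
// "forward" scales by 1/n, "ortho" by 1/sqrt(n), and "backward" needs no scaling.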
std::shared_ptr<ov::Node> normalized_rfftn;
if (norm == "forward") {
// Normalize by 1/n
normalized_rfftn = std::make_shared<v1::Divide>(rdft, n);
} else if (norm == "backward") {
// No normalization
normalized_rfftn = rdft;
} else if (norm == "ortho") {
// Normalize by 1/sqrt(n)
auto sqrt_n = std::make_shared<v0::Sqrt>(n);
normalized_rfftn = std::make_shared<v1::Divide>(rdft, sqrt_n);
} else {
add_exception_to_fw_node(
rfftn_op,
"aten::fft_rfftn: unrecognized normalization mode. Only forward, backward and ortho are supported.");
return false;
}
// Replace outputs that are the torch operators aten::real or aten::imag. Apply Squeeze to
// remove the last (packed complex) dimension after splitting it into real and imag halves.
auto normalized_rfftn_splitted = std::make_shared<v1::Split>(normalized_rfftn, const_neg_1, 2);
auto rfftn_outs = rfftn_op->get_users();
bool rval = false;
for (auto out : rfftn_outs) {
if (auto real_op = cast_fw_node(out, "aten::real")) {
auto squeezed = std::make_shared<v0::Squeeze>(normalized_rfftn_splitted->output(0), const_neg_1);
copy_runtime_info({rfftn_op, real_op}, squeezed);
squeezed->set_friendly_name(real_op->get_friendly_name());
replace_node(real_op, squeezed);
rval = true;
}
if (auto imag_op = cast_fw_node(out, "aten::imag")) {
auto squeezed = std::make_shared<v0::Squeeze>(normalized_rfftn_splitted->output(1), const_neg_1);
copy_runtime_info({rfftn_op, imag_op}, squeezed);
squeezed->set_friendly_name(imag_op->get_friendly_name());
replace_node(imag_op, squeezed);
rval = true;
}
}
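// Note: the exception attribute is attached unconditionally; it should only take effect if the
// original node survives, i.e. when some consumer other than aten::real/aten::imag remains.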
add_exception_to_fw_node(
rfftn_op,
"aten::fft_rfftn: Unsupported output node. Only aten::real and aten::imag are supported.");
return rval;
};
auto m = std::make_shared<pattern::Matcher>(fft_op, "ov::frontend::pytorch::pass::RFFTNComplexReplacer");
this->register_matcher(m, rfftn_callback);
}
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov


@@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
class RFFTNComplexReplacer : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::RFFTNComplexReplacer");
RFFTNComplexReplacer();
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov


@@ -331,6 +331,16 @@ std::shared_ptr<ov::op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node>
return fw_node;
}
bool is_none_node(const Output<Node>& node) {
if (const auto& fw_node_inp = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(node.get_node_shared_ptr())) {
const auto& attrs = fw_node_inp->get_attrs();
if (attrs.find("none_value") != attrs.end()) {
return true;
}
}
return false;
}
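// Usage note for is_none_node: in the FFT replacers above, optional aten inputs arrive as
// FrameworkNodes tagged with a "none_value" attribute, e.g.:
//   bool s_use_default = is_none_node(irfftn_op->input_value(1));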
Any simplified_type_interpret(Any type) {
// Interpret Tensor[type] as just type
// After applying of this interpretation we cannot distinguish true scalars (not tensors) and tensors with elements


@@ -58,6 +58,8 @@ OutputVector make_framework_node(const NodeContext& context, const std::string&
std::shared_ptr<op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node, const std::string& type);
bool is_none_node(const Output<Node>& node);
// TODO: Eliminate the need of this function by implementing more accurate custom data type handling
Any simplified_type_interpret(Any type);


@@ -0,0 +1,46 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
class TestRFFTN(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(*self.input_shape).astype(np.float32),)
def create_model(self, dim, s, norm):
class aten_fft_rfftn(torch.nn.Module):
def __init__(self, dim, s, norm):
super(aten_fft_rfftn, self).__init__()
self.dim = dim
self.s = s
self.norm = norm
def forward(self, x):
rfftn = torch.fft.rfftn(x, s=self.s, dim=self.dim, norm=self.norm)
r = rfftn.real
i = rfftn.imag
irfftn = torch.fft.irfftn(torch.complex(r, i), s=self.s, dim=self.dim, norm=self.norm)
return irfftn, r, i
ref_net = None
return (
aten_fft_rfftn(dim, s, norm),
ref_net,
["aten::fft_irfftn", "aten::complex", "aten::fft_rfftn", "aten::real", "aten::imag"],
)
@pytest.mark.parametrize("input_shape", [[64, 49], [64, 50], [64, 64, 49]])
@pytest.mark.parametrize("dim", [[0, -1], [-2, -1], None, [0, 1]])
@pytest.mark.parametrize("s", [None, [-1, 49], [64, -1], [64, 49], [5, 1]])
@pytest.mark.parametrize("norm", ["forward", "backward", "ortho", None])
@pytest.mark.nightly
@pytest.mark.precommit
def test_rfftn(self, ie_device, precision, ir_version, input_shape, dim, s, norm):
self.input_shape = input_shape
# Unfrozen test would fail due to issues with prim::GetAttr containing lists, strings, or None.
self._test(*self.create_model(dim, s, norm), ie_device, precision, ir_version, custom_eps=1e-3, freeze_model=True)