[ShapeInfer] FFTBase shape infer improvement - preserve input sizes bounds and labels (#19463)
* Reduce number of rank checks * Preserve data shape if signal_size input is not provided * Add bounds propagation on fft input * Improved preserving bounds on fft input * Remove size_t rank cast and have_axes variable * Check refactor * Use ge helper for rank comparison * Make bounds constexpr * Pass raw pointer instead of unique_ptr ref * Use normalize_axes helper * Ensure to call set label if it's not zero
This commit is contained in:
parent
8723b5dd6d
commit
bd66971dd6
@ -2,57 +2,92 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
#include <openvino/op/util/fft_base.hpp>
|
||||
|
||||
#include "dimension_util.hpp"
|
||||
#include "fft_common_validation.hpp"
|
||||
#include "openvino/core/axis_vector.hpp"
|
||||
#include "openvino/core/dimension_tracker.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/op/util/fft_base.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
|
||||
namespace fft {
|
||||
template <class TRShape, typename std::enable_if<std::is_same<TRShape, PartialShape>::value>::type* = nullptr>
|
||||
void apply_dims_from_sizes(const util::FFTBase* op,
|
||||
TRShape& output_shape,
|
||||
const std::vector<int64_t>& axes,
|
||||
const ITensorAccessor& ta) {
|
||||
using namespace ov::util;
|
||||
using DimType = typename TRShape::value_type;
|
||||
|
||||
if (const auto output_bounds = get_input_bounds<TRShape, int64_t>(op, 2, ta)) {
|
||||
const auto minus_one_bound = std::make_pair(dim::inf_bound, dim::inf_bound);
|
||||
const auto num_of_axes = axes.size();
|
||||
const auto labels =
|
||||
op->get_input_size() > 2 ? op->get_input_source_output(2).get_tensor().get_value_label() : TensorLabel();
|
||||
const bool propagate_labels = num_of_axes <= labels.size();
|
||||
for (size_t i = 0; i < num_of_axes; ++i) {
|
||||
if ((*output_bounds)[i] != minus_one_bound) {
|
||||
auto& out_dim = output_shape[(axes)[i]];
|
||||
out_dim = DimType((*output_bounds)[i].first, (*output_bounds)[i].second);
|
||||
if (propagate_labels && labels[i] != ov::no_label) {
|
||||
DimensionTracker::set_label(out_dim, labels[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (auto axis : axes) {
|
||||
output_shape[axis] = ov::Dimension::dynamic();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class TRShape, typename std::enable_if<!std::is_same<TRShape, PartialShape>::value>::type* = nullptr>
|
||||
void apply_dims_from_sizes(const util::FFTBase* op,
|
||||
TRShape& output_shape,
|
||||
const std::vector<int64_t>& axes,
|
||||
const ITensorAccessor& ta) {
|
||||
using namespace ov::util;
|
||||
using DimType = typename TRShape::value_type;
|
||||
|
||||
if (const auto output_dim_vals = get_input_const_data_as<TRShape, int64_t>(op, 2, ta)) {
|
||||
const auto num_of_axes = axes.size();
|
||||
for (size_t i = 0; i < num_of_axes; ++i) {
|
||||
if ((*output_dim_vals)[i] != dim::inf_bound) {
|
||||
output_shape[(axes)[i]] = DimType((*output_dim_vals)[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace fft
|
||||
|
||||
template <class T, class TRShape = result_shape_t<T>>
|
||||
std::vector<TRShape> shape_infer(const util::FFTBase* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
const ITensorAccessor& ta = make_tensor_accessor()) {
|
||||
using DimType = typename T::value_type;
|
||||
|
||||
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2 || input_shapes.size() == 3));
|
||||
|
||||
const auto& input_shape = input_shapes[0];
|
||||
const auto& axes_shape = input_shapes[1];
|
||||
auto output_shapes = std::vector<TRShape>(1);
|
||||
auto& output_shape = output_shapes[0];
|
||||
auto axes = get_input_const_data_as<TRShape, int64_t>(op, 1, ta);
|
||||
|
||||
util::fft_common_validation::shape_validation(op,
|
||||
input_shapes,
|
||||
*axes,
|
||||
static_cast<bool>(axes),
|
||||
axes.get(),
|
||||
util::fft_common_validation::FFTKind::ComplexInput);
|
||||
|
||||
output_shape = input_shape;
|
||||
|
||||
if (input_shape.rank().is_static() && axes_shape.rank().is_static() && input_shapes.size() == 3 && axes) {
|
||||
const auto& signal_size_shape = input_shapes[2];
|
||||
auto signal_size = get_input_const_data_as<TRShape, int64_t>(op, 2, ta);
|
||||
|
||||
if (signal_size_shape.rank().is_static() && signal_size) {
|
||||
size_t num_of_axes = axes->size();
|
||||
for (size_t i = 0; i < num_of_axes; ++i) {
|
||||
if ((*signal_size)[i] == -1) {
|
||||
continue;
|
||||
}
|
||||
output_shape[(*axes)[i]] = DimType((*signal_size)[i]);
|
||||
if (input_shapes.size() == 3 && input_shape.rank().is_static()) {
|
||||
if (axes) {
|
||||
ov::op::fft::apply_dims_from_sizes(op, output_shape, *axes, ta);
|
||||
} else {
|
||||
for (size_t i = 0; i < input_shape.size() - 1; ++i) {
|
||||
output_shape[i] = ov::Dimension::dynamic();
|
||||
}
|
||||
} else if (signal_size_shape.rank().is_static()) {
|
||||
for (int64_t& axis : *axes) {
|
||||
output_shape[axis] = ov::Dimension::dynamic();
|
||||
}
|
||||
}
|
||||
} else if (input_shape.rank().is_static() && (axes_shape.rank().is_dynamic() || !axes)) {
|
||||
const auto input_rank = input_shape.size();
|
||||
for (size_t i = 0; i < input_rank - 1; ++i) {
|
||||
output_shape[i] = ov::Dimension::dynamic();
|
||||
}
|
||||
}
|
||||
return output_shapes;
|
||||
|
@ -11,15 +11,15 @@ namespace ov {
|
||||
namespace op {
|
||||
namespace util {
|
||||
namespace fft_common_validation {
|
||||
enum class FFTKind { RealInput, ComplexInput };
|
||||
enum FFTKind { RealInput = 1, ComplexInput = 2 };
|
||||
template <class T>
|
||||
void validate_input_rank(const ov::op::util::FFTBase* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
const T& input_shape,
|
||||
const T& axes_shape,
|
||||
size_t input_rank,
|
||||
int64_t input_rank,
|
||||
FFTKind fft_kind) {
|
||||
const size_t min_rank = (fft_kind == FFTKind::RealInput) ? 1 : 2;
|
||||
const int64_t min_rank = fft_kind;
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
input_rank >= min_rank,
|
||||
@ -40,12 +40,12 @@ void validate_input_rank(const ov::op::util::FFTBase* op,
|
||||
if (fft_kind == FFTKind::RealInput) {
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
input_rank >= static_cast<size_t>(axes_shape[0].get_length()),
|
||||
ov::cmp::ge(input_rank, axes_shape[0].get_length()),
|
||||
"The input rank must be greater than or equal to the number of axes. ");
|
||||
} else {
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
input_rank >= static_cast<size_t>(axes_shape[0].get_length() + 1),
|
||||
ov::cmp::ge(input_rank, axes_shape[0].get_length() + 1),
|
||||
"The input rank must be greater than number of axes.");
|
||||
}
|
||||
}
|
||||
@ -55,10 +55,9 @@ void validate_axes(const ov::op::util::FFTBase* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
const T& axes_shape,
|
||||
std::vector<int64_t>& axes,
|
||||
size_t input_rank,
|
||||
bool axes_are_known,
|
||||
int64_t input_rank,
|
||||
FFTKind fft_kind) {
|
||||
if (axes_shape.rank().is_dynamic() || !axes_are_known) {
|
||||
if (axes_shape.rank().is_dynamic()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -73,33 +72,11 @@ void validate_axes(const ov::op::util::FFTBase* op,
|
||||
// according to the RDFT operation specification, axes should be integers from -r to (r - 1)
|
||||
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis 'r + a'.
|
||||
const int64_t axis_correction = (fft_kind == FFTKind::RealInput) ? input_rank : (input_rank - 1);
|
||||
auto axis_min_value = -static_cast<int64_t>(input_rank);
|
||||
auto axis_max_value = static_cast<int64_t>(input_rank) - 1;
|
||||
|
||||
// RDFT op axes can contain the last axis
|
||||
if (fft_kind == FFTKind::RealInput) {
|
||||
--axis_min_value;
|
||||
++axis_max_value;
|
||||
}
|
||||
|
||||
ov::AxisSet axes_set;
|
||||
for (int64_t& axis : axes) {
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
axis_min_value < axis && axis < axis_max_value,
|
||||
"Axis value: ",
|
||||
axis,
|
||||
", must be in range (",
|
||||
axis_min_value,
|
||||
", ",
|
||||
axis_max_value,
|
||||
").");
|
||||
if (axis < 0) {
|
||||
axis += axis_correction;
|
||||
}
|
||||
axes_set.insert(static_cast<size_t>(axis));
|
||||
}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
ov::normalize_axes(op, axis_correction, axes);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
auto axes_set = AxisSet(std::vector<size_t>(axes.begin(), axes.end()));
|
||||
NODE_VALIDATION_CHECK(op, axes.size() == axes_set.size(), "Each axis must be unique.");
|
||||
}
|
||||
|
||||
@ -124,16 +101,18 @@ void validate_signal_size(const ov::op::util::FFTBase* op,
|
||||
template <class T>
|
||||
void shape_validation(const ov::op::util::FFTBase* op,
|
||||
const std::vector<T>& input_shapes,
|
||||
std::vector<int64_t>& axes,
|
||||
bool axes_are_known,
|
||||
std::vector<int64_t>* axes,
|
||||
FFTKind fft_kind) {
|
||||
const auto& input_shape = input_shapes[0];
|
||||
const auto& axes_shape = input_shapes[1];
|
||||
|
||||
if (input_shape.rank().is_static()) {
|
||||
const auto input_rank = input_shape.size();
|
||||
validate_input_rank(op, input_shapes, input_shape, axes_shape, input_rank, fft_kind);
|
||||
validate_axes(op, input_shapes, axes_shape, axes, input_rank, axes_are_known, fft_kind);
|
||||
const auto input_shape_rank = input_shape.rank();
|
||||
if (input_shape_rank.is_static()) {
|
||||
const auto input_rank_length = input_shape_rank.get_length();
|
||||
validate_input_rank(op, input_shapes, input_shape, axes_shape, input_rank_length, fft_kind);
|
||||
if (axes) {
|
||||
validate_axes(op, input_shapes, axes_shape, *axes, input_rank_length, fft_kind);
|
||||
}
|
||||
}
|
||||
|
||||
NODE_SHAPE_INFER_CHECK(op, input_shapes, axes_shape.rank().compatible(1), "Axes input must be 1D tensor.");
|
||||
|
@ -28,8 +28,7 @@ std::vector<TRShape> shape_infer(const IRDFT* op,
|
||||
|
||||
util::fft_common_validation::shape_validation(op,
|
||||
input_shapes,
|
||||
*axes,
|
||||
axes_are_known,
|
||||
axes.get(),
|
||||
util::fft_common_validation::FFTKind::ComplexInput);
|
||||
|
||||
if (input_shape.rank().is_dynamic()) {
|
||||
|
@ -39,8 +39,7 @@ std::vector<TRShape> shape_infer(const RDFT* op,
|
||||
|
||||
util::fft_common_validation::shape_validation(op,
|
||||
input_shapes,
|
||||
*axes,
|
||||
static_cast<bool>(axes),
|
||||
axes.get(),
|
||||
util::fft_common_validation::FFTKind::RealInput);
|
||||
|
||||
if (input_shape.rank().is_dynamic()) {
|
||||
|
@ -283,11 +283,8 @@ TYPED_TEST_P(FFTNonConstantAxesTest, non_constant_axes_no_signal_size) {
|
||||
auto dft = this->make_op(data, axes_input);
|
||||
|
||||
EXPECT_EQ(dft->get_element_type(), element::f32);
|
||||
EXPECT_EQ(dft->get_output_partial_shape(0), params.ref_output_shape);
|
||||
|
||||
std::vector<label_t> expected_labels(params.input_shape.size() - 1, no_label);
|
||||
expected_labels.push_back(get_shape_labels(params.input_shape).back());
|
||||
EXPECT_EQ(get_shape_labels(dft->get_output_partial_shape(0)), expected_labels);
|
||||
EXPECT_EQ(dft->get_output_partial_shape(0), params.input_shape);
|
||||
EXPECT_EQ(get_shape_labels(dft->get_output_partial_shape(0)), get_shape_labels(params.input_shape));
|
||||
}
|
||||
}
|
||||
|
||||
@ -318,7 +315,7 @@ TYPED_TEST_P(FFTNonConstantAxesTest, non_constant_axes_const_signal_size) {
|
||||
auto axes_input = std::make_shared<op::v0::Parameter>(element::i64, params.axes_shape);
|
||||
auto signal_size_input = op::v0::Constant::create<int64_t>(element::i64, Shape{2}, {100, 200});
|
||||
|
||||
auto dft = this->make_op(data, axes_input);
|
||||
auto dft = this->make_op(data, axes_input, signal_size_input);
|
||||
|
||||
EXPECT_EQ(dft->get_element_type(), element::f32);
|
||||
EXPECT_EQ(dft->get_output_partial_shape(0), params.ref_output_shape);
|
||||
@ -417,12 +414,12 @@ TYPED_TEST_P(FFTInvalidInput, invalid_axes) {
|
||||
auto axes = op::v0::Constant::create(element::i64, Shape{1}, {3});
|
||||
OV_EXPECT_THROW(std::ignore = this->make_op(data, axes),
|
||||
ov::Exception,
|
||||
HasSubstr("Axis value: 3, must be in range (-3, 2)"));
|
||||
HasSubstr("Parameter axis 3 out of the tensor rank range [-2, 1]"));
|
||||
|
||||
axes = op::v0::Constant::create(element::i64, Shape{1}, {-3});
|
||||
OV_EXPECT_THROW(std::ignore = this->make_op(data, axes),
|
||||
ov::Exception,
|
||||
HasSubstr("Axis value: -3, must be in range (-3, 2)"));
|
||||
HasSubstr("Parameter axis -3 out of the tensor rank range [-2, 1]"));
|
||||
|
||||
axes = op::v0::Constant::create(element::i64, Shape{2}, {0, -2});
|
||||
OV_EXPECT_THROW(std::ignore = this->make_op(data, axes), ov::Exception, HasSubstr("Each axis must be unique"));
|
||||
@ -430,7 +427,7 @@ TYPED_TEST_P(FFTInvalidInput, invalid_axes) {
|
||||
axes = op::v0::Constant::create(element::i64, Shape{1}, {2});
|
||||
OV_EXPECT_THROW(std::ignore = this->make_op(data, axes),
|
||||
ov::Exception,
|
||||
HasSubstr("Axis value: 2, must be in range (-3, 2)"));
|
||||
HasSubstr("Parameter axis 2 out of the tensor rank range [-2, 1]"));
|
||||
|
||||
axes = op::v0::Constant::create(element::i64, Shape{1, 2}, {0, 1});
|
||||
OV_EXPECT_THROW(std::ignore = this->make_op(data, axes), ov::Exception, HasSubstr("Axes input must be 1D tensor."));
|
||||
@ -477,4 +474,169 @@ TYPED_TEST_P(FFTDynamicTypes, dynamic_types) {
|
||||
REGISTER_TYPED_TEST_SUITE_P(FFTDynamicTypes, dynamic_types);
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(type_prop, FFTDynamicTypes, FFTBaseTypes);
|
||||
|
||||
struct FFTConstantAxesAndShapeOfSignalSizeTestParams {
|
||||
PartialShape input_shape;
|
||||
PartialShape ref_output_shape;
|
||||
std::vector<int64_t> axes;
|
||||
PartialShape signal_size;
|
||||
std::vector<ov::label_t> expected_labels;
|
||||
};
|
||||
|
||||
template <class TOp>
|
||||
class FFTConstantAxesAndShapeOfSignalSizeTest : public TypePropOpTest<TOp> {
|
||||
public:
|
||||
const std::vector<FFTConstantAxesAndShapeOfSignalSizeTestParams> test_params{
|
||||
FFTConstantAxesAndShapeOfSignalSizeTestParams{{3, {10, 16}, 18, 20, 2},
|
||||
{-1, {10, 16}, {6, 8}, 4, 2},
|
||||
{2, 0, -1},
|
||||
{{6, 8}, -1, 4},
|
||||
{21, 11, 20, 22, 14}},
|
||||
FFTConstantAxesAndShapeOfSignalSizeTestParams{{{8, 129}, 50, 130, {0, 500}, 2},
|
||||
{{8, 129}, {0, 10}, -1, 40, 2},
|
||||
{1, 2, 3},
|
||||
{{0, 10}, -1, 40},
|
||||
{10, 20, 21, 22, 14}}};
|
||||
};
|
||||
TYPED_TEST_SUITE_P(FFTConstantAxesAndShapeOfSignalSizeTest);
|
||||
|
||||
TYPED_TEST_P(FFTConstantAxesAndShapeOfSignalSizeTest, constant_axes_and_shape_of_signal_size) {
|
||||
for (auto params : this->test_params) {
|
||||
set_shape_labels(params.input_shape, 10);
|
||||
auto data = std::make_shared<op::v0::Parameter>(element::f32, params.input_shape);
|
||||
auto axes_input = op::v0::Constant::create<int64_t>(element::i64, Shape{params.axes.size()}, params.axes);
|
||||
|
||||
set_shape_labels(params.signal_size, 20);
|
||||
auto param_of_shape = std::make_shared<op::v0::Parameter>(element::f32, params.signal_size);
|
||||
auto signal_size_input = std::make_shared<op::v3::ShapeOf>(param_of_shape, element::i64);
|
||||
|
||||
auto dft = this->make_op(data, axes_input, signal_size_input);
|
||||
|
||||
EXPECT_EQ(dft->get_element_type(), element::f32);
|
||||
EXPECT_EQ(dft->get_output_partial_shape(0), params.ref_output_shape);
|
||||
|
||||
EXPECT_EQ(get_shape_labels(dft->get_output_partial_shape(0)), params.expected_labels);
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FFTConstantAxesAndShapeOfSignalSizeTest, constant_axes_and_shape_of_signal_size_first_input_labels) {
|
||||
auto params = FFTConstantAxesAndShapeOfSignalSizeTestParams{{3, {10, 16}, 18, 20, 2},
|
||||
{-1, {10, 16}, {6, 8}, 4, 2},
|
||||
{2, 0, -1},
|
||||
{{6, 8}, -1, 4},
|
||||
{ov::no_label, 11, ov::no_label, ov::no_label, 14}};
|
||||
set_shape_labels(params.input_shape, 10);
|
||||
auto data = std::make_shared<op::v0::Parameter>(element::f32, params.input_shape);
|
||||
auto axes_input = op::v0::Constant::create<int64_t>(element::i64, Shape{params.axes.size()}, params.axes);
|
||||
|
||||
auto param_of_shape = std::make_shared<op::v0::Parameter>(element::f32, params.signal_size);
|
||||
auto signal_size_input = std::make_shared<op::v3::ShapeOf>(param_of_shape, element::i64);
|
||||
|
||||
auto dft = this->make_op(data, axes_input, signal_size_input);
|
||||
|
||||
EXPECT_EQ(dft->get_element_type(), element::f32);
|
||||
EXPECT_EQ(dft->get_output_partial_shape(0), params.ref_output_shape);
|
||||
|
||||
EXPECT_EQ(get_shape_labels(dft->get_output_partial_shape(0)), params.expected_labels);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FFTConstantAxesAndShapeOfSignalSizeTest, constant_axes_and_shape_of_signal_size_second_input_labels) {
|
||||
auto params = FFTConstantAxesAndShapeOfSignalSizeTestParams{{3, {10, 16}, 18, 20, 2},
|
||||
{-1, {10, 16}, {6, 8}, 4, 2},
|
||||
{2, 0, -1},
|
||||
{{6, 8}, -1, 4},
|
||||
{21, ov::no_label, 20, 22, ov::no_label}};
|
||||
|
||||
auto data = std::make_shared<op::v0::Parameter>(element::f32, params.input_shape);
|
||||
auto axes_input = op::v0::Constant::create<int64_t>(element::i64, Shape{params.axes.size()}, params.axes);
|
||||
|
||||
set_shape_labels(params.signal_size, 20);
|
||||
auto param_of_shape = std::make_shared<op::v0::Parameter>(element::f32, params.signal_size);
|
||||
auto signal_size_input = std::make_shared<op::v3::ShapeOf>(param_of_shape, element::i64);
|
||||
|
||||
auto dft = this->make_op(data, axes_input, signal_size_input);
|
||||
|
||||
EXPECT_EQ(dft->get_element_type(), element::f32);
|
||||
EXPECT_EQ(dft->get_output_partial_shape(0), params.ref_output_shape);
|
||||
|
||||
EXPECT_EQ(get_shape_labels(dft->get_output_partial_shape(0)), params.expected_labels);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FFTConstantAxesAndShapeOfSignalSizeTest, constant_axes_and_shape_of_signal_size_single_label) {
|
||||
auto params = FFTConstantAxesAndShapeOfSignalSizeTestParams{{3, {10, 16}, 18, 20, 2},
|
||||
{-1, {10, 16}, {6, 8}, 4, 2},
|
||||
{2, 0, -1},
|
||||
{{6, 8}, -1, 4},
|
||||
{ov::no_label, 11, 22, ov::no_label, ov::no_label}};
|
||||
ov::DimensionTracker::set_label(params.input_shape[1], 11);
|
||||
|
||||
auto data = std::make_shared<op::v0::Parameter>(element::f32, params.input_shape);
|
||||
auto axes_input = op::v0::Constant::create<int64_t>(element::i64, Shape{params.axes.size()}, params.axes);
|
||||
|
||||
ov::DimensionTracker::set_label(params.signal_size[0], 22);
|
||||
|
||||
auto param_of_shape = std::make_shared<op::v0::Parameter>(element::f32, params.signal_size);
|
||||
auto signal_size_input = std::make_shared<op::v3::ShapeOf>(param_of_shape, element::i64);
|
||||
|
||||
auto dft = this->make_op(data, axes_input, signal_size_input);
|
||||
|
||||
EXPECT_EQ(dft->get_element_type(), element::f32);
|
||||
EXPECT_EQ(dft->get_output_partial_shape(0), params.ref_output_shape);
|
||||
|
||||
EXPECT_EQ(get_shape_labels(dft->get_output_partial_shape(0)), params.expected_labels);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FFTConstantAxesAndShapeOfSignalSizeTest, constant_axes_and_shape_of_signal_size_no_labels) {
|
||||
auto params = FFTConstantAxesAndShapeOfSignalSizeTestParams{
|
||||
{3, {10, 16}, 18, 20, 2},
|
||||
{-1, {10, 16}, {6, 8}, 4, 2},
|
||||
{2, 0, -1},
|
||||
{{6, 8}, -1, 4},
|
||||
{ov::no_label, ov::no_label, ov::no_label, ov::no_label, ov::no_label}};
|
||||
auto data = std::make_shared<op::v0::Parameter>(element::f32, params.input_shape);
|
||||
auto axes_input = op::v0::Constant::create<int64_t>(element::i64, Shape{params.axes.size()}, params.axes);
|
||||
|
||||
auto param_of_shape = std::make_shared<op::v0::Parameter>(element::f32, params.signal_size);
|
||||
auto signal_size_input = std::make_shared<op::v3::ShapeOf>(param_of_shape, element::i64);
|
||||
|
||||
auto dft = this->make_op(data, axes_input, signal_size_input);
|
||||
|
||||
EXPECT_EQ(dft->get_element_type(), element::f32);
|
||||
EXPECT_EQ(dft->get_output_partial_shape(0), params.ref_output_shape);
|
||||
|
||||
EXPECT_EQ(get_shape_labels(dft->get_output_partial_shape(0)), params.expected_labels);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FFTConstantAxesAndShapeOfSignalSizeTest, constant_axes_and_shape_of_concat_signal_size) {
|
||||
auto params = FFTConstantAxesAndShapeOfSignalSizeTestParams{{4, {8, 16}, 24, -1, 2},
|
||||
{{5, 10}, {8, 16}, -1, 40, 2},
|
||||
{1, 0, -2, 3},
|
||||
{{5, 10}, -1, 40},
|
||||
{20, 11, 21, 22, 14}};
|
||||
set_shape_labels(params.input_shape, 10);
|
||||
auto data = std::make_shared<op::v0::Parameter>(element::f32, params.input_shape);
|
||||
auto axes_input = op::v0::Constant::create<int64_t>(element::i64, Shape{params.axes.size()}, params.axes);
|
||||
|
||||
set_shape_labels(params.signal_size, 20);
|
||||
auto param_of_shape = std::make_shared<op::v0::Parameter>(element::f32, params.signal_size);
|
||||
auto shape_of = std::make_shared<op::v3::ShapeOf>(param_of_shape, element::i64);
|
||||
auto minus_one = op::v0::Constant::create<int64_t>(element::i64, Shape{1}, {-1});
|
||||
|
||||
auto signal_size_input = std::make_shared<op::v0::Concat>(OutputVector{minus_one, shape_of}, 0);
|
||||
|
||||
auto dft = this->make_op(data, axes_input, signal_size_input);
|
||||
|
||||
EXPECT_EQ(dft->get_element_type(), element::f32);
|
||||
EXPECT_EQ(dft->get_output_partial_shape(0), params.ref_output_shape);
|
||||
EXPECT_EQ(get_shape_labels(dft->get_output_partial_shape(0)), params.expected_labels);
|
||||
}
|
||||
|
||||
REGISTER_TYPED_TEST_SUITE_P(FFTConstantAxesAndShapeOfSignalSizeTest,
|
||||
constant_axes_and_shape_of_signal_size,
|
||||
constant_axes_and_shape_of_signal_size_first_input_labels,
|
||||
constant_axes_and_shape_of_signal_size_second_input_labels,
|
||||
constant_axes_and_shape_of_signal_size_single_label,
|
||||
constant_axes_and_shape_of_signal_size_no_labels,
|
||||
constant_axes_and_shape_of_concat_signal_size);
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(type_prop, FFTConstantAxesAndShapeOfSignalSizeTest, FFTBaseTypes);
|
||||
|
||||
} // namespace fft_base_test
|
||||
|
@ -395,12 +395,12 @@ TEST(type_prop, irdft_invalid_axes) {
|
||||
auto axes = op::v0::Constant::create(element::i64, Shape{1}, {3});
|
||||
OV_EXPECT_THROW(std::ignore = std::make_shared<op::v9::IRDFT>(data, axes),
|
||||
Exception,
|
||||
HasSubstr("Axis value: 3, must be in range (-3, 2)"));
|
||||
HasSubstr("Parameter axis 3 out of the tensor rank range [-2, 1]"));
|
||||
|
||||
axes = op::v0::Constant::create(element::i64, Shape{1}, {-3});
|
||||
OV_EXPECT_THROW(std::ignore = std::make_shared<op::v9::IRDFT>(data, axes),
|
||||
Exception,
|
||||
HasSubstr("Axis value: -3, must be in range (-3, 2)"));
|
||||
HasSubstr("Parameter axis -3 out of the tensor rank range [-2, 1]"));
|
||||
|
||||
axes = op::v0::Constant::create(element::i64, Shape{2}, {0, -2});
|
||||
OV_EXPECT_THROW(std::ignore = std::make_shared<op::v9::IRDFT>(data, axes),
|
||||
@ -410,7 +410,7 @@ TEST(type_prop, irdft_invalid_axes) {
|
||||
axes = op::v0::Constant::create(element::i64, Shape{1}, {2});
|
||||
OV_EXPECT_THROW(std::ignore = std::make_shared<op::v9::IRDFT>(data, axes),
|
||||
Exception,
|
||||
HasSubstr("Axis value: 2, must be in range (-3, 2)"));
|
||||
HasSubstr("Parameter axis 2 out of the tensor rank range [-2, 1]"));
|
||||
|
||||
axes = op::v0::Constant::create(element::i64, Shape{1, 2}, {0, 1});
|
||||
OV_EXPECT_THROW(std::ignore = std::make_shared<op::v9::IRDFT>(data, axes),
|
||||
|
@ -305,12 +305,12 @@ TEST(type_prop, rdft_invalid_axes) {
|
||||
auto axes = op::v0::Constant::create(element::i64, Shape{1}, {3});
|
||||
OV_EXPECT_THROW(std::ignore = std::make_shared<op::v9::RDFT>(data, axes),
|
||||
ov::Exception,
|
||||
HasSubstr("Axis value: 3, must be in range (-4, 3)"));
|
||||
HasSubstr("Parameter axis 3 out of the tensor rank range [-3, 2]"));
|
||||
|
||||
axes = op::v0::Constant::create(element::i64, Shape{1}, {-4});
|
||||
OV_EXPECT_THROW(std::ignore = std::make_shared<op::v9::RDFT>(data, axes),
|
||||
ov::Exception,
|
||||
HasSubstr("Axis value: -4, must be in range (-4, 3)"));
|
||||
HasSubstr("Parameter axis -4 out of the tensor rank range [-3, 2]"));
|
||||
|
||||
axes = op::v0::Constant::create(element::i64, Shape{2}, {0, -3});
|
||||
OV_EXPECT_THROW(std::ignore = std::make_shared<op::v9::RDFT>(data, axes),
|
||||
|
Loading…
Reference in New Issue
Block a user