Added axes node validation to DFTs operations (#11814)

* Fix DFTs axes node validation.

* Add DFTs type prop tests for invalid nodes.

* Adjusted DFTs axes node validation.
This commit is contained in:
Mykhailo Hnap 2022-07-01 09:19:04 +03:00 committed by GitHub
parent 8138e240a0
commit e23a568b7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 387 additions and 14 deletions

View File

@ -53,22 +53,26 @@ void shape_infer(const ov::op::util::FFTBase* op,
// [n_0, ..., n_{r - 1}, 2] is interpreted as a complex tensor with the shape
// [n_0, ..., n_{r - 1}].
if (axes_shape.rank().is_static() && axes_are_known) {
const auto axis_min_value = -static_cast<int64_t>(input_rank);
const auto axis_max_value = static_cast<int64_t>(input_rank) - 1;
ov::AxisSet axes_set;
for (int64_t& axis : axes) {
NODE_VALIDATION_CHECK(op,
axis_min_value < axis && axis < axis_max_value,
"FFT op axis ",
axis,
" must be in the input rank range (",
axis_min_value,
", ",
axis_max_value,
").");
if (axis < 0) {
axis += input_rank - 1;
}
}
ov::AxisSet axes_set;
for (const auto& axis : axes) {
axes_set.insert(static_cast<size_t>(axis));
}
NODE_VALIDATION_CHECK(op, axes.size() == axes_set.size(), "FFT op axes must be unique.");
NODE_VALIDATION_CHECK(op,
std::find(axes.begin(), axes.end(), input_rank - 1) == axes.end(),
"FFT op axes cannot contain the last axis.");
}
}

View File

@ -78,8 +78,26 @@ void validate_axes(const ov::op::util::FFTBase* op,
// according to the RDFT operation specification, axes should be integers from -r to (r - 1)
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis 'r + a'.
const int64_t axis_correction = (rfft_kind == RFFTKind::Forward) ? input_rank : (input_rank - 1);
auto axis_min_value = -static_cast<int64_t>(input_rank);
auto axis_max_value = static_cast<int64_t>(input_rank) - 1;
// RDFT op axes can contain the last axis
if (rfft_kind == RFFTKind::Forward) {
--axis_min_value;
++axis_max_value;
}
ov::AxisSet axes_set;
for (int64_t& axis : axes) {
NODE_VALIDATION_CHECK(op,
axis_min_value < axis && axis < axis_max_value,
"(I)RDFT op axis ",
axis,
" must be in the input rank range (",
axis_min_value,
", ",
axis_max_value,
").");
if (axis < 0) {
axis += axis_correction;
}
@ -87,12 +105,6 @@ void validate_axes(const ov::op::util::FFTBase* op,
}
NODE_VALIDATION_CHECK(op, axes.size() == axes_set.size(), "(I)RDFT op axes must be unique.");
if (rfft_kind == RFFTKind::Inverse) {
NODE_VALIDATION_CHECK(op,
std::find(axes.begin(), axes.end(), input_rank - 1) == axes.end(),
"IRDFT op axes cannot contain the last axis.");
}
}
template <class T>

View File

@ -308,3 +308,96 @@ INSTANTIATE_TEST_SUITE_P(
{Dimension::dynamic(), Dimension::dynamic(), 130, Dimension::dynamic(), 2},
{3, 0, 1}}),
PrintToDummyParamName());
TEST(type_prop, dft_invalid_input) {
auto axes = op::Constant::create(element::i64, Shape{2}, {0, 1});
try {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{2});
auto dft = std::make_shared<op::v7::DFT>(data, axes);
FAIL() << "DFT node was created with invalid input.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "The input rank must be greater or equal to 2.");
}
try {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 3});
auto dft = std::make_shared<op::v7::DFT>(data, axes);
FAIL() << "DFT node was created with invalid input.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "The last dimension of input data must be 2.");
}
try {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 2});
auto dft = std::make_shared<op::v7::DFT>(data, axes);
FAIL() << "DFT node was created with invalid input.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "The input rank must be greater than number of FFT op axes.");
}
}
TEST(type_prop, dft_invalid_axes) {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
try {
auto axes = op::Constant::create(element::i64, Shape{1}, {3});
auto dft = std::make_shared<op::v7::DFT>(data, axes);
FAIL() << "DFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "FFT op axis 3 must be in the input rank range");
}
try {
auto axes = op::Constant::create(element::i64, Shape{1}, {-3});
auto dft = std::make_shared<op::v7::DFT>(data, axes);
FAIL() << "DFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "FFT op axis -3 must be in the input rank range");
}
try {
auto axes = op::Constant::create(element::i64, Shape{2}, {0, -2});
auto dft = std::make_shared<op::v7::DFT>(data, axes);
FAIL() << "DFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "FFT op axes must be unique.");
}
try {
auto axes = op::Constant::create(element::i64, Shape{1}, {2});
auto dft = std::make_shared<op::v7::DFT>(data, axes);
FAIL() << "DFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "FFT op axis 2 must be in the input rank range");
}
try {
auto axes = op::Constant::create(element::i64, Shape{1, 2}, {0, 1});
auto dft = std::make_shared<op::v7::DFT>(data, axes);
FAIL() << "DFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "FFT op axes input must be 1D tensor.");
}
}
TEST(type_prop, dft_invalid_signal_size) {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
auto axes = op::Constant::create(element::i64, Shape{1}, {0});
try {
auto signal_size = op::Constant::create(element::i64, Shape{1, 2}, {0, 1});
auto dft = std::make_shared<op::v7::DFT>(data, axes, signal_size);
FAIL() << "DFT node was created with invalid signal size.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "FFT op signal size input must be 1D tensor.");
}
try {
auto signal_size = op::Constant::create(element::i64, Shape{2}, {0, 1});
auto dft = std::make_shared<op::v7::DFT>(data, axes, signal_size);
FAIL() << "DFT node was created with invalid signal size.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Sizes of inputs 'axes' and 'signal_size' must be equal.");
}
}

View File

@ -296,3 +296,96 @@ INSTANTIATE_TEST_SUITE_P(
{Dimension::dynamic(), Dimension::dynamic(), 130, Dimension::dynamic(), 2},
{3, 0, 1}}),
PrintToDummyParamName());
TEST(type_prop, idft_invalid_input) {
auto axes = op::Constant::create(element::i64, Shape{2}, {0, 1});
try {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{2});
auto idft = std::make_shared<op::v7::IDFT>(data, axes);
FAIL() << "IDFT node was created with invalid input.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "The input rank must be greater or equal to 2.");
}
try {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 3});
auto idft = std::make_shared<op::v7::IDFT>(data, axes);
FAIL() << "IDFT node was created with invalid input.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "The last dimension of input data must be 2.");
}
try {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 2});
auto idft = std::make_shared<op::v7::IDFT>(data, axes);
FAIL() << "IDFT node was created with invalid input.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "The input rank must be greater than number of FFT op axes.");
}
}
TEST(type_prop, idft_invalid_axes) {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
try {
auto axes = op::Constant::create(element::i64, Shape{1}, {3});
auto idft = std::make_shared<op::v7::IDFT>(data, axes);
FAIL() << "IDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "FFT op axis 3 must be in the input rank range");
}
try {
auto axes = op::Constant::create(element::i64, Shape{1}, {-3});
auto idft = std::make_shared<op::v7::IDFT>(data, axes);
FAIL() << "IDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "FFT op axis -3 must be in the input rank range");
}
try {
auto axes = op::Constant::create(element::i64, Shape{2}, {0, -2});
auto idft = std::make_shared<op::v7::IDFT>(data, axes);
FAIL() << "IDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "FFT op axes must be unique.");
}
try {
auto axes = op::Constant::create(element::i64, Shape{1}, {2});
auto idft = std::make_shared<op::v7::IDFT>(data, axes);
FAIL() << "IDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "FFT op axis 2 must be in the input rank range");
}
try {
auto axes = op::Constant::create(element::i64, Shape{1, 2}, {0, 1});
auto idft = std::make_shared<op::v7::IDFT>(data, axes);
FAIL() << "IDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "FFT op axes input must be 1D tensor.");
}
}
TEST(type_prop, idft_invalid_signal_size) {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
auto axes = op::Constant::create(element::i64, Shape{1}, {0});
try {
auto signal_size = op::Constant::create(element::i64, Shape{1, 2}, {0, 1});
auto idft = std::make_shared<op::v7::IDFT>(data, axes, signal_size);
FAIL() << "IDFT node was created with invalid signal size.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "FFT op signal size input must be 1D tensor.");
}
try {
auto signal_size = op::Constant::create(element::i64, Shape{2}, {0, 1});
auto idft = std::make_shared<op::v7::IDFT>(data, axes, signal_size);
FAIL() << "IDFT node was created with invalid signal size.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Sizes of inputs 'axes' and 'signal_size' must be equal.");
}
}

View File

@ -307,3 +307,96 @@ INSTANTIATE_TEST_SUITE_P(
{Dimension(8, 129), Dimension::dynamic(), 130, Dimension(0, 500)},
{3, 0, 1}}),
PrintToDummyParamName());
TEST(type_prop, irdft_invalid_input) {
auto axes = op::Constant::create(element::i64, Shape{2}, {0, 1});
try {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{2});
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes);
FAIL() << "IRDFT node was created with invalid input.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "The input rank must be greater or equal to 2.");
}
try {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 3});
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes);
FAIL() << "IRDFT node was created with invalid input.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "The last dimension of input data must be 2.");
}
try {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 2});
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes);
FAIL() << "IRDFT node was created with invalid input.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "The input rank must be greater than number of IRDFT op axes.");
}
}
TEST(type_prop, irdft_invalid_axes) {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
try {
auto axes = op::Constant::create(element::i64, Shape{1}, {3});
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes);
FAIL() << "IRDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "(I)RDFT op axis 3 must be in the input rank range");
}
try {
auto axes = op::Constant::create(element::i64, Shape{1}, {-3});
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes);
FAIL() << "IRDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "(I)RDFT op axis -3 must be in the input rank range");
}
try {
auto axes = op::Constant::create(element::i64, Shape{2}, {0, -2});
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes);
FAIL() << "IRDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "(I)RDFT op axes must be unique.");
}
try {
auto axes = op::Constant::create(element::i64, Shape{1}, {2});
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes);
FAIL() << "IRDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "(I)RDFT op axis 2 must be in the input rank range");
}
try {
auto axes = op::Constant::create(element::i64, Shape{1, 2}, {0, 1});
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes);
FAIL() << "IRDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "(I)RDFT op axes input must be 1D tensor.");
}
}
TEST(type_prop, irdft_invalid_signal_size) {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
auto axes = op::Constant::create(element::i64, Shape{1}, {0});
try {
auto signal_size = op::Constant::create(element::i64, Shape{1, 2}, {0, 1});
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes, signal_size);
FAIL() << "IRDFT node was created with invalid signal size.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "(I)RDFT op signal size input must be 1D tensor.");
}
try {
auto signal_size = op::Constant::create(element::i64, Shape{2}, {0, 1});
auto irdft = std::make_shared<op::v9::IRDFT>(data, axes, signal_size);
FAIL() << "IRDFT node was created with invalid signal size.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Sizes of inputs 'axes' and 'signal_size' of (I)RDFT op must be equal.");
}
}

View File

@ -243,3 +243,81 @@ INSTANTIATE_TEST_SUITE_P(
{Dimension(8, 129), Dimension::dynamic(), 130, Dimension(0, 500), 2},
{3, 0, 1}}),
PrintToDummyParamName());
TEST(type_prop, rdft_invalid_input) {
auto axes = op::Constant::create(element::i64, Shape{2}, {0, 1});
try {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{});
auto rdft = std::make_shared<op::v9::RDFT>(data, axes);
FAIL() << "RDFT node was created with invalid input.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "The input rank must be greater or equal to 1.");
}
try {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4});
auto rdft = std::make_shared<op::v9::RDFT>(data, axes);
FAIL() << "RDFT node was created with invalid input.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(),
"The input rank must be greater than or equal to the number of RDFT op axes.");
}
}
TEST(type_prop, rdft_invalid_axes) {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
try {
auto axes = op::Constant::create(element::i64, Shape{1}, {3});
auto rdft = std::make_shared<op::v9::RDFT>(data, axes);
FAIL() << "RDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "(I)RDFT op axis 3 must be in the input rank range");
}
try {
auto axes = op::Constant::create(element::i64, Shape{1}, {-4});
auto rdft = std::make_shared<op::v9::RDFT>(data, axes);
FAIL() << "RDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "(I)RDFT op axis -4 must be in the input rank range");
}
try {
auto axes = op::Constant::create(element::i64, Shape{2}, {0, -3});
auto rdft = std::make_shared<op::v9::RDFT>(data, axes);
FAIL() << "RDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "(I)RDFT op axes must be unique.");
}
try {
auto axes = op::Constant::create(element::i64, Shape{1, 2}, {0, 1});
auto rdft = std::make_shared<op::v9::RDFT>(data, axes);
FAIL() << "RDFT node was created with invalid axes.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "(I)RDFT op axes input must be 1D tensor.");
}
}
TEST(type_prop, rdft_invalid_signal_size) {
auto data = std::make_shared<op::Parameter>(element::f32, Shape{4, 3, 2});
auto axes = op::Constant::create(element::i64, Shape{1}, {0});
try {
auto signal_size = op::Constant::create(element::i64, Shape{1, 2}, {0, 1});
auto rdft = std::make_shared<op::v9::RDFT>(data, axes, signal_size);
FAIL() << "RDFT node was created with invalid signal size.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "(I)RDFT op signal size input must be 1D tensor.");
}
try {
auto signal_size = op::Constant::create(element::i64, Shape{2}, {0, 1});
auto rdft = std::make_shared<op::v9::RDFT>(data, axes, signal_size);
FAIL() << "RDFT node was created with invalid signal size.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Sizes of inputs 'axes' and 'signal_size' of (I)RDFT op must be equal.");
}
}