Mang/shape inference (#8412)

* Implement DFT and IDFT shape inference

* Implement CTCLoss shape inference

* Fix error message.

* Refactor test case.

* Apply review comments

* Apply review comments

* Fix clang format error

* Fix merge error

* Remove axes_vector.
This commit is contained in:
Mang Guo
2021-12-22 12:47:44 +08:00
committed by GitHub
parent b8e6b6368c
commit 5fada94504
8 changed files with 460 additions and 257 deletions

View File

@@ -29,6 +29,8 @@
#include "reduce_shape_inference.hpp"
#include "scatter_elements_update_shape_inference.hpp"
#include "scatter_nd_base_shape_inference.hpp"
#include "ctc_loss_shape_inference.hpp"
#include "fft_base_shape_inference.hpp"
#include "shape_inference.hpp"
#include "shape_nodes.hpp"
#include "fake_quantize.hpp"
@@ -184,6 +186,12 @@ void shape_inference(ov::Node* op,
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset1::OneHot>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset4::CTCLoss>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset7::DFT>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset7::IDFT>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else {
ngraph::OutputVector new_inputs;
for (size_t i = 0; i < op->get_input_size(); ++i) {

View File

@@ -33,8 +33,6 @@ protected:
/// \param axes Axes to perform FFT
/// \param signal_size Signal sizes for 'axes'
FFTBase(const Output<Node>& data, const Output<Node>& axes, const Output<Node>& signal_size);
void validate();
};
} // namespace util
} // namespace op

View File

@@ -0,0 +1,133 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/op/ctc_loss.hpp>
namespace ov {
namespace op {
namespace v4 {
template <class T>
void shape_infer(const CTCLoss* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 4 || input_shapes.size() == 5) && output_shapes.size() == 1);
// check ranks of input tensors
const auto& logits_pshape = input_shapes[0];
const auto& logit_length_pshape = input_shapes[1];
const auto& labels_pshape = input_shapes[2];
const auto& label_length_pshape = input_shapes[3];
NODE_VALIDATION_CHECK(op,
logits_pshape.rank().compatible(3),
"Expected a 3D tensor for logits. Got: ",
logits_pshape);
NODE_VALIDATION_CHECK(op,
logit_length_pshape.rank().compatible(1),
"Expected a 1D tensor for logit length. Got: ",
logit_length_pshape);
NODE_VALIDATION_CHECK(op,
labels_pshape.rank().compatible(2),
"Expected a 2D tensor for labels. Got: ",
labels_pshape);
NODE_VALIDATION_CHECK(op,
label_length_pshape.rank().compatible(1),
"Expected a 1D tensor for label length. Got: ",
label_length_pshape);
// check optional input shape: blank index
if (input_shapes.size() == 5) {
const auto& blank_index_pshape = input_shapes[4];
NODE_VALIDATION_CHECK(op,
blank_index_pshape.rank().compatible(0),
"Expected a scalar for blank index. Got: ",
blank_index_pshape);
}
// check shapes of input tensors
DimType batch_size = 1;
bool is_batch_size_set = false;
DimType time_steps = 1;
bool is_time_steps_set = false;
if (logits_pshape.rank().is_static()) {
batch_size = logits_pshape[0];
is_batch_size_set = true;
time_steps = logits_pshape[1];
is_time_steps_set = true;
}
if (logit_length_pshape.rank().is_static()) {
if (is_batch_size_set) {
NODE_VALIDATION_CHECK(op,
DimType::merge(batch_size, batch_size, logit_length_pshape[0]),
"The first dimension of logit length must be equal to the first dimension ",
"of the logits. Got: ",
logit_length_pshape[0],
" and: ",
batch_size);
} else {
batch_size = logit_length_pshape[0];
is_batch_size_set = true;
}
}
if (labels_pshape.rank().is_static()) {
if (is_batch_size_set) {
NODE_VALIDATION_CHECK(op,
DimType::merge(batch_size, batch_size, labels_pshape[0]),
"The first dimension of labels must be equal to the first dimension ",
"of the logits and the logit length. Got: ",
labels_pshape[0],
" and: ",
batch_size);
} else {
batch_size = labels_pshape[0];
is_batch_size_set = true;
}
if (is_time_steps_set) {
NODE_VALIDATION_CHECK(op,
DimType::merge(time_steps, time_steps, labels_pshape[1]),
"The second dimension of labels must be equal to the second dimension ",
"of logits. Got: ",
labels_pshape[1],
" and: ",
time_steps);
}
}
if (label_length_pshape.rank().is_static()) {
if (is_batch_size_set) {
NODE_VALIDATION_CHECK(op,
DimType::merge(batch_size, batch_size, label_length_pshape[0]),
"The first dimension of label length must be equal to the first dimension ",
"of the logits, the logit length and labels. Got: ",
label_length_pshape[0],
" and: ",
batch_size);
} else {
batch_size = label_length_pshape[0];
is_batch_size_set = true;
}
}
auto& output_shape = output_shapes[0];
output_shape.resize(1);
if (is_batch_size_set) {
output_shape[0] = batch_size;
} else {
output_shape[0] = Dimension::dynamic();
}
}
} // namespace v4
} // namespace op
} // namespace ov

View File

@@ -0,0 +1,117 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/op/util/fft_base.hpp>
#include "openvino/core/axis_vector.hpp"
#include "utils.hpp"
template <class T>
void shape_infer(const ov::op::util::FFTBase* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2 || input_shapes.size() == 3) && output_shapes.size() == 1);
const auto& input_shape = input_shapes[0];
const auto& axes_shape = input_shapes[1];
auto& output_shape = output_shapes[0];
std::vector<int64_t> axes;
bool axes_are_known = get_data_as_int64<T>(1, op, axes, constant_data);
if (input_shape.rank().is_static()) {
const auto input_rank = input_shape.size();
NODE_VALIDATION_CHECK(op,
input_rank >= 2,
"The input rank must be greater or equal to 2. Got input rank: ",
input_rank);
NODE_VALIDATION_CHECK(op,
input_shape[input_rank - 1].compatible(2),
"The last dimension of input data must be 2. Got: ",
input_shape[input_rank - 1]);
if (axes_shape.is_static()) {
NODE_VALIDATION_CHECK(op,
input_rank >= static_cast<int64_t>(axes_shape[0].get_length() + 1),
"The input rank must be greater than number of FFT op axes. Got "
"input rank: ",
input_rank,
", number of axes: ",
axes_shape[0].get_length());
}
// FFT operation supports for negative axes to transform. More precisely, according to
// the FFT operation specification, axes should be integers from -(r - 1) to (r - 2)
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
// 'r - 1 + a'. The reason is the following: real input tensor of the shape
// [n_0, ..., n_{r - 1}, 2] is interpreted as a complex tensor with the shape
// [n_0, ..., n_{r - 1}].
if (axes_shape.rank().is_static() && axes_are_known) {
for (int64_t& axis : axes) {
if (axis < 0) {
axis += input_rank - 1;
}
}
ov::AxisSet axes_set;
for (const auto& axis : axes) {
axes_set.insert(static_cast<size_t>(axis));
}
NODE_VALIDATION_CHECK(op, axes.size() == axes_set.size(), "FFT op axes must be unique.");
NODE_VALIDATION_CHECK(op,
std::find(axes.begin(), axes.end(), input_rank - 1) == axes.end(),
"FFT op axes cannot contain the last axis.");
}
}
NODE_VALIDATION_CHECK(op, axes_shape.rank().compatible(1), "FFT op axes input must be 1D tensor.");
if (input_shapes.size() == 3) {
const auto& signal_size_shape = input_shapes[2];
NODE_VALIDATION_CHECK(op,
signal_size_shape.rank().compatible(1),
"FFT op signal size input must be 1D tensor. Got signal: ",
signal_size_shape);
if (axes_shape.is_static() && signal_size_shape.is_static()) {
NODE_VALIDATION_CHECK(op,
axes_shape[0].compatible(signal_size_shape[0]),
"Sizes of inputs 'axes' and 'signal_size' must be equal. Got "
"size of 'axes': ",
axes_shape[0],
"size of 'signal_size': ",
signal_size_shape[0]);
}
}
output_shape = input_shape;
if (input_shape.rank().is_static() && axes_shape.rank().is_static() && input_shapes.size() == 3 && axes_are_known) {
const auto& signal_size_shape = input_shapes[2];
std::vector<int64_t> signal_size;
bool status_signal_size = get_data_as_int64<T>(2, op, signal_size, constant_data);
if (signal_size_shape.rank().is_static() && status_signal_size) {
size_t num_of_axes = axes.size();
for (size_t i = 0; i < num_of_axes; ++i) {
if (signal_size[i] == -1) {
continue;
}
output_shape[axes[i]] = DimType(signal_size[i]);
}
} else if (signal_size_shape.rank().is_static()) {
for (int64_t& axis : axes) {
output_shape[axis] = ov::Dimension::dynamic();
}
}
} else if (input_shape.rank().is_static() && (axes_shape.rank().is_dynamic() || !axes_are_known)) {
const auto input_rank = input_shape.size();
for (int64_t i = 0; i < input_rank - 1; ++i) {
output_shape[i] = ov::Dimension::dynamic();
}
}
}

View File

@@ -4,6 +4,8 @@
#include "ngraph/op/ctc_loss.hpp"
#include <ctc_loss_shape_inference.hpp>
#include "itt.hpp"
using namespace std;
@@ -77,119 +79,22 @@ void op::v4::CTCLoss::validate_and_infer_types() {
blank_index_type);
}
// check ranks of input tensors
const auto& logits_pshape = get_input_partial_shape(0);
const auto& logit_length_pshape = get_input_partial_shape(1);
const auto& labels_pshape = get_input_partial_shape(2);
const auto& label_length_pshape = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
logits_pshape.rank().compatible(3),
"Expected a 3D tensor for logits. Got: ",
logits_pshape);
NODE_VALIDATION_CHECK(this,
logit_length_pshape.rank().compatible(1),
"Expected a 1D tensor for logit length. Got: ",
logit_length_pshape);
NODE_VALIDATION_CHECK(this,
labels_pshape.rank().compatible(2),
"Expected a 2D tensor for labels. Got: ",
labels_pshape);
NODE_VALIDATION_CHECK(this,
label_length_pshape.rank().compatible(1),
"Expected a 1D tensor for label length. Got: ",
label_length_pshape);
// check optional input shape: blank index
std::vector<ov::PartialShape> input_shapes;
if (get_input_size() == 5) {
const auto& blank_index_pshape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
blank_index_pshape.rank().compatible(0),
"Expected a scalar for blank index. Got: ",
blank_index_pshape);
}
// check shapes of input tensors
size_t batch_size = 1;
bool is_batch_size_set = false;
size_t time_steps = 1;
bool is_time_steps_set = false;
if (logits_pshape.rank().is_static()) {
if (logits_pshape[0].is_static()) {
batch_size = logits_pshape[0].get_length();
is_batch_size_set = true;
}
if (logits_pshape[1].is_static()) {
time_steps = logits_pshape[1].get_length();
is_time_steps_set = true;
}
}
if (logit_length_pshape.is_static()) {
if (is_batch_size_set) {
NODE_VALIDATION_CHECK(this,
logit_length_pshape[0].compatible(batch_size),
"The first dimension of logit length must be equal to the first dimension ",
"of the logits. Got: ",
logit_length_pshape[0],
" and: ",
batch_size);
} else if (logit_length_pshape[0].is_static()) {
batch_size = logit_length_pshape[0].get_length();
is_batch_size_set = true;
}
}
if (labels_pshape.is_static()) {
if (is_batch_size_set) {
NODE_VALIDATION_CHECK(this,
labels_pshape[0].compatible(batch_size),
"The first dimension of labels must be equal to the first dimension ",
"of the logits and the logit length. Got: ",
labels_pshape[0],
" and: ",
batch_size);
} else if (labels_pshape[0].is_static()) {
batch_size = labels_pshape[0].get_length();
is_batch_size_set = true;
}
if (is_time_steps_set) {
NODE_VALIDATION_CHECK(this,
labels_pshape[1].compatible(time_steps),
"The second dimension of labels must be equal to the second dimension ",
"of logits. Got: ",
labels_pshape[1],
" and: ",
time_steps);
}
}
if (label_length_pshape.is_static()) {
if (!is_batch_size_set && label_length_pshape[0].is_static()) {
batch_size = label_length_pshape[0].get_length();
is_batch_size_set = true;
}
NODE_VALIDATION_CHECK(this,
label_length_pshape[0].compatible(batch_size),
"The first dimension of label length must be equal to the first dimension ",
"of the logits, the logit length and labels. Got: ",
label_length_pshape[0],
" and: ",
batch_size);
}
// set output shape
set_output_size(1);
if (is_batch_size_set) {
set_output_type(0, logits_type, ov::Shape{batch_size});
input_shapes = {logits_pshape, logit_length_pshape, labels_pshape, label_length_pshape, blank_index_pshape};
} else {
set_output_type(0, logits_type, ov::PartialShape{Dimension::dynamic()});
input_shapes = {logits_pshape, logit_length_pshape, labels_pshape, label_length_pshape};
}
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
shape_infer(this, input_shapes, output_shapes);
set_output_type(0, logits_type, output_shapes[0]);
}
bool op::v4::CTCLoss::visit_attributes(AttributeVisitor& visitor) {

View File

@@ -4,6 +4,7 @@
#include "ngraph/op/util/fft_base.hpp"
#include <fft_base_shape_inference.hpp>
#include <ngraph/validation_util.hpp>
#include "itt.hpp"
@@ -23,9 +24,10 @@ bool ov::op::util::FFTBase::visit_attributes(AttributeVisitor& visitor) {
return true;
}
void ov::op::util::FFTBase::validate() {
size_t num_of_inputs = get_input_size();
void ov::op::util::FFTBase::validate_and_infer_types() {
NGRAPH_OP_SCOPE(util_FFTBase_validate_and_infer_types);
size_t num_of_inputs = get_input_size();
NODE_VALIDATION_CHECK(this, num_of_inputs == 2 || num_of_inputs == 3, "FFT op must have 2 or 3 inputs.");
element::Type input_et = get_input_element_type(0);
@@ -38,164 +40,25 @@ void ov::op::util::FFTBase::validate() {
axes_et == element::i64 || axes_et == element::i32,
"FFT op axes element type must be i32 or i64");
const auto& input_shape = PartialShape(get_input_partial_shape(0));
if (input_shape.rank().is_static()) {
const auto input_rank = input_shape.rank().get_length();
NODE_VALIDATION_CHECK(this,
input_rank >= 2,
"The input rank must be greater or equal to 2. Got input rank: ",
input_rank);
auto last_dim_with_two = input_shape[input_rank - 1] & Dimension(2);
NODE_VALIDATION_CHECK(this,
!last_dim_with_two.get_interval().empty(),
"The last dimension of input data must be 2. Got: ",
input_shape[input_rank - 1]);
}
const auto& axes_shape = PartialShape(get_input_partial_shape(1));
if (axes_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
axes_shape.rank().get_length() == 1,
"FFT op axes input must be 1D tensor. Got axes input rank: ",
axes_shape.rank().get_length());
}
if (input_shape.rank().is_static() && axes_shape.is_static()) {
const auto input_rank = input_shape.rank().get_length();
NODE_VALIDATION_CHECK(this,
input_rank >= static_cast<int64_t>(axes_shape.to_shape()[0] + 1),
"The input rank must be greater than number of FFT op axes. Got "
"input rank: ",
input_rank,
", number of axes: ",
axes_shape.to_shape()[0]);
}
if (input_shape.rank().is_static() && ov::is_type<ngraph::op::Constant>(input_value(1).get_node())) {
const auto input_rank = input_shape.rank().get_length();
const auto& const_axes = get_constant_from_source(input_value(1));
auto axes = const_axes->cast_vector<int64_t>();
// FFT operation supports for negative axes to transform. More precisely, according to
// the FFT operation specification, axes should be integers from -(r - 1) to (r - 2)
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
// 'r - 1 + a'. The reason is the following: real input tensor of the shape
// [n_0, ..., n_{r - 1}, 2] is interpreted as a complex tensor with the shape
// [n_0, ..., n_{r - 1}].
for (int64_t& axis : axes) {
if (axis < 0) {
axis += input_rank - 1;
}
}
AxisVector axes_vector;
AxisSet axes_set;
for (const int64_t axis : axes) {
axes_vector.push_back(static_cast<size_t>(axis));
axes_set.insert(static_cast<size_t>(axis));
}
NODE_VALIDATION_CHECK(this, axes.size() == axes_set.size(), "FFT op axes must be unique. Got: ", axes_vector);
NODE_VALIDATION_CHECK(this,
std::find(axes.begin(), axes.end(), input_rank - 1) == axes.end(),
"FFT op axes cannot contain the last axis. Got axes: ",
axes_vector);
}
if (num_of_inputs == 3) {
element::Type signal_size_et = get_input_element_type(2);
NODE_VALIDATION_CHECK(this,
signal_size_et == element::i64 || signal_size_et == element::i32,
"FFT op signal_size element type must be i32 or i64");
const auto& signal_size_shape = PartialShape(get_input_partial_shape(2));
if (signal_size_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
signal_size_shape.rank().get_length() == 1,
"FFT op signal size input must be 1D tensor. Got signal size "
"input rank: ",
signal_size_shape.rank().get_length());
}
if (axes_shape.is_static() && signal_size_shape.is_static()) {
NODE_VALIDATION_CHECK(this,
axes_shape.to_shape()[0] == signal_size_shape.to_shape()[0],
"Sizes of inputs 'axes' and 'signal_size' must be equal. Got "
"size of 'axes': ",
axes_shape.to_shape()[0],
"size of 'signal_size': ",
signal_size_shape.to_shape()[0]);
}
}
}
void ov::op::util::FFTBase::validate_and_infer_types() {
NGRAPH_OP_SCOPE(util_FFTBase_validate_and_infer_types);
validate();
const auto& input_shape = PartialShape(get_input_partial_shape(0));
const auto& axes_shape = PartialShape(get_input_partial_shape(1));
PartialShape output_shape = input_shape;
if (input_shape.rank().is_dynamic()) {
set_output_type(0, get_input_element_type(0), output_shape);
return;
}
const auto input_rank = input_shape.rank().get_length();
if (axes_shape.rank().is_dynamic() || !ov::is_type<ngraph::op::Constant>(input_value(1).get_node())) {
for (int64_t i = 0; i < input_rank - 1; ++i) {
output_shape[i] = Dimension::dynamic();
}
set_output_type(0, get_input_element_type(0), output_shape);
return;
}
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()};
std::vector<ov::PartialShape> input_shapes;
const auto& data = get_input_partial_shape(0);
const auto& axes = get_input_partial_shape(1);
if (input_values().size() == 2) {
set_output_type(0, get_input_element_type(0), output_shape);
return;
input_shapes = {data, axes};
} else {
const auto& signal_size = get_input_partial_shape(2);
input_shapes = {data, axes, signal_size};
}
const auto& signal_size_shape = PartialShape(get_input_partial_shape(2));
if (signal_size_shape.rank().is_dynamic()) {
set_output_type(0, get_input_element_type(0), output_shape);
return;
}
const auto& const_axes = get_constant_from_source(input_value(1));
auto axes = const_axes->cast_vector<int64_t>();
// FFT operation supports for negative axes to transform. More precisely, according to
// the FFT operation specification, axes should be integers from -(r - 1) to (r - 2)
// inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
// 'r - 1 + a'. The reason is the following: real input tensor of the shape
// [n_0, ..., n_{r - 1}, 2] is interpreted as a complex tensor with the shape
// [n_0, ..., n_{r - 1}].
for (int64_t& axis : axes) {
if (axis < 0) {
axis += input_rank - 1;
}
}
if (!ov::is_type<ngraph::op::Constant>(input_value(2).get_node())) {
for (int64_t axis : axes) {
output_shape[axis] = Dimension::dynamic();
}
set_output_type(0, get_input_element_type(0), output_shape);
return;
}
const auto& const_signal_size = get_constant_from_source(input_value(2));
const auto signal_size = const_signal_size->cast_vector<int64_t>();
size_t num_of_axes = axes.size();
for (size_t i = 0; i < num_of_axes; ++i) {
if (signal_size[i] == -1) {
continue;
}
output_shape[axes[i]] = Dimension(signal_size[i]);
}
set_output_type(0, get_input_element_type(0), output_shape);
shape_infer(this, input_shapes, output_shapes);
set_output_type(0, get_input_element_type(0), output_shapes[0]);
}

View File

@@ -0,0 +1,34 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <ctc_loss_shape_inference.hpp>
#include <openvino/op/ctc_loss.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
TEST(StaticShapeInferenceTest, CTCLossTest) {
const auto& logits = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
const auto& logit_length = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
const auto& labels = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1, -1});
const auto& label_length = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
const auto& blank_index = std::make_shared<ov::op::v0::Parameter>(element::i32, ov::Shape{});
// create CTCLoss node
auto ctc_loss = std::make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
std::vector<StaticShape> static_input_shapes = {StaticShape{10, 120, 28},
StaticShape{10},
StaticShape{10, 120},
StaticShape{10},
ov::Shape{}},
static_output_shapes = {StaticShape{}};
shape_inference(ctc_loss.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({10}));
}

View File

@@ -0,0 +1,145 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <fft_base_shape_inference.hpp>
#include <openvino/op/dft.hpp>
#include <openvino/op/idft.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
static std::shared_ptr<op::v7::DFT> build_dft() {
auto input_shape = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto axes = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
auto DFT = std::make_shared<ov::op::v7::DFT>(input_shape, axes);
return DFT;
}
static std::shared_ptr<op::v7::DFT> build_dft_signal() {
auto input_shape = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto axes = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
auto signal = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
auto DFT_signal = std::make_shared<ov::op::v7::DFT>(input_shape, axes, signal);
return DFT_signal;
}
static std::shared_ptr<op::v7::DFT> build_dft_constant() {
auto input_shape = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto axes = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{2}, std::vector<int32_t>{1, 2});
auto signal = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{2}, std::vector<int32_t>{512, 100});
auto DFT_signal = std::make_shared<ov::op::v7::DFT>(input_shape, axes, signal);
return DFT_signal;
}
static std::shared_ptr<op::v7::IDFT> build_idft() {
auto input_shape = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto axes = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
auto IDFT = std::make_shared<ov::op::v7::IDFT>(input_shape, axes);
return IDFT;
}
static std::shared_ptr<op::v7::IDFT> build_idft_signal() {
auto input_shape = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto axes = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
auto signal = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
auto IDFT_signal = std::make_shared<ov::op::v7::IDFT>(input_shape, axes, signal);
return IDFT_signal;
}
TEST(StaticShapeInferenceTest, DFTTest) {
auto DFT = build_dft();
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
int32_t axes_val[] = {1, 2};
constant_data[1] = std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, axes_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}},
static_output_shapes = {StaticShape{}};
shape_inference(DFT.get(), static_input_shapes, static_output_shapes, constant_data);
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 320, 320, 2}));
}
TEST(StaticShapeInferenceTest, DFTSignalTest) {
auto DFT = build_dft_signal();
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
int32_t axes_val[] = {1, 2};
int32_t signal_val[] = {512, 100};
constant_data[1] = std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, axes_val);
constant_data[2] =
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, signal_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}, StaticShape{2}},
static_output_shapes = {StaticShape{}};
shape_inference(DFT.get(), static_input_shapes, static_output_shapes, constant_data);
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 512, 100, 2}));
}
TEST(StaticShapeInferenceTest, DFTConstantTest) {
auto DFT = build_dft_constant();
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}, StaticShape{2}},
static_output_shapes = {StaticShape{}};
shape_inference(DFT.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 512, 100, 2}));
}
TEST(StaticShapeInferenceTest, DFTSignalMissingConstDataTest) {
auto DFT = build_dft_signal();
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
int32_t axes_val[] = {1, 2};
constant_data[1] = std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, axes_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}, StaticShape{2}},
static_output_shapes = {StaticShape{}};
EXPECT_THROW(shape_inference(DFT.get(), static_input_shapes, static_output_shapes, constant_data),
NodeValidationFailure);
}
TEST(StaticShapeInferenceTest, IDFTTest) {
auto IDFT = build_idft();
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
int32_t axes_val[] = {1, 2};
constant_data[1] = std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, axes_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}},
static_output_shapes = {StaticShape{}};
shape_inference(IDFT.get(), static_input_shapes, static_output_shapes, constant_data);
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 320, 320, 2}));
}
TEST(StaticShapeInferenceTest, IDFTSignalTest) {
auto IDFT = build_idft_signal();
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
int32_t axes_val[] = {1, 2};
int32_t signal_val[] = {512, 100};
constant_data[1] = std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, axes_val);
constant_data[2] =
std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, signal_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}, StaticShape{2}},
static_output_shapes = {StaticShape{}};
shape_inference(IDFT.get(), static_input_shapes, static_output_shapes, constant_data);
ASSERT_EQ(static_output_shapes[0], StaticShape({1, 512, 100, 2}));
}
TEST(StaticShapeInferenceTest, IDFTSignalMissingConstDataTest) {
auto IDFT = build_idft_signal();
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
int32_t axes_val[] = {1, 2};
constant_data[1] = std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, axes_val);
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}, StaticShape{2}},
static_output_shapes = {StaticShape{}};
EXPECT_THROW(shape_inference(IDFT.get(), static_input_shapes, static_output_shapes, constant_data),
NodeValidationFailure);
}