Mang/shape inference (#8412)
* Implement DFT and IDFT shape inference
* Implement CTCLoss shape inference
* Fix error message.
* Refactor test case.
* Apply review comments
* Apply review comments
* Fix clang format error
* Fix merge error
* Remove axes_vector.
@@ -29,6 +29,8 @@
#include "reduce_shape_inference.hpp"
#include "scatter_elements_update_shape_inference.hpp"
#include "scatter_nd_base_shape_inference.hpp"
#include "ctc_loss_shape_inference.hpp"
#include "fft_base_shape_inference.hpp"
#include "shape_inference.hpp"
#include "shape_nodes.hpp"
#include "fake_quantize.hpp"
@@ -184,6 +186,12 @@ void shape_inference(ov::Node* op,
        shape_infer(node, input_shapes, output_shapes);
    } else if (auto node = ov::as_type<ov::opset1::OneHot>(op)) {
        shape_infer(node, input_shapes, output_shapes, constant_data);
    } else if (auto node = ov::as_type<ov::opset4::CTCLoss>(op)) {
        shape_infer(node, input_shapes, output_shapes);
    } else if (auto node = ov::as_type<ov::opset7::DFT>(op)) {
        shape_infer(node, input_shapes, output_shapes, constant_data);
    } else if (auto node = ov::as_type<ov::opset7::IDFT>(op)) {
        shape_infer(node, input_shapes, output_shapes, constant_data);
    } else {
        ngraph::OutputVector new_inputs;
        for (size_t i = 0; i < op->get_input_size(); ++i) {
@@ -33,8 +33,6 @@ protected:
    /// \param axes Axes to perform FFT
    /// \param signal_size Signal sizes for 'axes'
    FFTBase(const Output<Node>& data, const Output<Node>& axes, const Output<Node>& signal_size);

    void validate();
};
} // namespace util
} // namespace op
src/core/shape_inference/include/ctc_loss_shape_inference.hpp (new file, 133 lines)
@@ -0,0 +1,133 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <openvino/op/ctc_loss.hpp>

namespace ov {
namespace op {
namespace v4 {

template <class T>
void shape_infer(const CTCLoss* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
    using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
    NODE_VALIDATION_CHECK(op, (input_shapes.size() == 4 || input_shapes.size() == 5) && output_shapes.size() == 1);

    // check ranks of input tensors
    const auto& logits_pshape = input_shapes[0];
    const auto& logit_length_pshape = input_shapes[1];
    const auto& labels_pshape = input_shapes[2];
    const auto& label_length_pshape = input_shapes[3];

    NODE_VALIDATION_CHECK(op,
                          logits_pshape.rank().compatible(3),
                          "Expected a 3D tensor for logits. Got: ",
                          logits_pshape);

    NODE_VALIDATION_CHECK(op,
                          logit_length_pshape.rank().compatible(1),
                          "Expected a 1D tensor for logit length. Got: ",
                          logit_length_pshape);

    NODE_VALIDATION_CHECK(op,
                          labels_pshape.rank().compatible(2),
                          "Expected a 2D tensor for labels. Got: ",
                          labels_pshape);

    NODE_VALIDATION_CHECK(op,
                          label_length_pshape.rank().compatible(1),
                          "Expected a 1D tensor for label length. Got: ",
                          label_length_pshape);

    // check optional input shape: blank index
    if (input_shapes.size() == 5) {
        const auto& blank_index_pshape = input_shapes[4];
        NODE_VALIDATION_CHECK(op,
                              blank_index_pshape.rank().compatible(0),
                              "Expected a scalar for blank index. Got: ",
                              blank_index_pshape);
    }

    // check shapes of input tensors
    DimType batch_size = 1;
    bool is_batch_size_set = false;
    DimType time_steps = 1;
    bool is_time_steps_set = false;

    if (logits_pshape.rank().is_static()) {
        batch_size = logits_pshape[0];
        is_batch_size_set = true;
        time_steps = logits_pshape[1];
        is_time_steps_set = true;
    }

    if (logit_length_pshape.rank().is_static()) {
        if (is_batch_size_set) {
            NODE_VALIDATION_CHECK(op,
                                  DimType::merge(batch_size, batch_size, logit_length_pshape[0]),
                                  "The first dimension of logit length must be equal to the first dimension ",
                                  "of the logits. Got: ",
                                  logit_length_pshape[0],
                                  " and: ",
                                  batch_size);
        } else {
            batch_size = logit_length_pshape[0];
            is_batch_size_set = true;
        }
    }

    if (labels_pshape.rank().is_static()) {
        if (is_batch_size_set) {
            NODE_VALIDATION_CHECK(op,
                                  DimType::merge(batch_size, batch_size, labels_pshape[0]),
                                  "The first dimension of labels must be equal to the first dimension ",
                                  "of the logits and the logit length. Got: ",
                                  labels_pshape[0],
                                  " and: ",
                                  batch_size);
        } else {
            batch_size = labels_pshape[0];
            is_batch_size_set = true;
        }

        if (is_time_steps_set) {
            NODE_VALIDATION_CHECK(op,
                                  DimType::merge(time_steps, time_steps, labels_pshape[1]),
                                  "The second dimension of labels must be equal to the second dimension ",
                                  "of logits. Got: ",
                                  labels_pshape[1],
                                  " and: ",
                                  time_steps);
        }
    }

    if (label_length_pshape.rank().is_static()) {
        if (is_batch_size_set) {
            NODE_VALIDATION_CHECK(op,
                                  DimType::merge(batch_size, batch_size, label_length_pshape[0]),
                                  "The first dimension of label length must be equal to the first dimension ",
                                  "of the logits, the logit length and labels. Got: ",
                                  label_length_pshape[0],
                                  " and: ",
                                  batch_size);
        } else {
            batch_size = label_length_pshape[0];
            is_batch_size_set = true;
        }
    }

    auto& output_shape = output_shapes[0];
    output_shape.resize(1);

    if (is_batch_size_set) {
        output_shape[0] = batch_size;
    } else {
        output_shape[0] = Dimension::dynamic();
    }
}

} // namespace v4
} // namespace op
} // namespace ov
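For context, a minimal standalone sketch (not part of this commit) of how the templated shape_infer above can be driven with ov::PartialShape, mirroring the call made from CTCLoss::validate_and_infer_types later in this diff; the function name and the concrete dimension values are illustrative assumptions.

// Hypothetical example, not part of the PR.
#include <openvino/op/ctc_loss.hpp>
#include <openvino/op/parameter.hpp>

#include "ctc_loss_shape_inference.hpp"

void ctc_loss_shape_infer_sketch() {
    using namespace ov;
    // The node is built only to give shape_infer an op pointer for error reporting.
    auto logits = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, 120, 28});
    auto logit_length = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{10});
    auto labels = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{10, 120});
    auto label_length = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{10});
    auto node = std::make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length);

    // The batch dimension of the logits is dynamic; shape_infer merges it with the
    // static first dimensions of the other inputs, so output_shapes[0] is expected to become {10}.
    std::vector<PartialShape> input_shapes = {PartialShape{-1, 120, 28},
                                              PartialShape{10},
                                              PartialShape{10, 120},
                                              PartialShape{10}};
    std::vector<PartialShape> output_shapes = {PartialShape{}};
    op::v4::shape_infer(node.get(), input_shapes, output_shapes);
}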
src/core/shape_inference/include/fft_base_shape_inference.hpp (new file, 117 lines)
@@ -0,0 +1,117 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/op/util/fft_base.hpp>

#include "openvino/core/axis_vector.hpp"
#include "utils.hpp"

template <class T>
void shape_infer(const ov::op::util::FFTBase* op,
                 const std::vector<T>& input_shapes,
                 std::vector<T>& output_shapes,
                 const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
    using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
    NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2 || input_shapes.size() == 3) && output_shapes.size() == 1);

    const auto& input_shape = input_shapes[0];
    const auto& axes_shape = input_shapes[1];
    auto& output_shape = output_shapes[0];
    std::vector<int64_t> axes;
    bool axes_are_known = get_data_as_int64<T>(1, op, axes, constant_data);

    if (input_shape.rank().is_static()) {
        const auto input_rank = input_shape.size();
        NODE_VALIDATION_CHECK(op,
                              input_rank >= 2,
                              "The input rank must be greater or equal to 2. Got input rank: ",
                              input_rank);

        NODE_VALIDATION_CHECK(op,
                              input_shape[input_rank - 1].compatible(2),
                              "The last dimension of input data must be 2. Got: ",
                              input_shape[input_rank - 1]);

        if (axes_shape.is_static()) {
            NODE_VALIDATION_CHECK(op,
                                  input_rank >= static_cast<int64_t>(axes_shape[0].get_length() + 1),
                                  "The input rank must be greater than number of FFT op axes. Got "
                                  "input rank: ",
                                  input_rank,
                                  ", number of axes: ",
                                  axes_shape[0].get_length());
        }

        // FFT operation supports for negative axes to transform. More precisely, according to
        // the FFT operation specification, axes should be integers from -(r - 1) to (r - 2)
        // inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
        // 'r - 1 + a'. The reason is the following: real input tensor of the shape
        // [n_0, ..., n_{r - 1}, 2] is interpreted as a complex tensor with the shape
        // [n_0, ..., n_{r - 1}].
        if (axes_shape.rank().is_static() && axes_are_known) {
            for (int64_t& axis : axes) {
                if (axis < 0) {
                    axis += input_rank - 1;
                }
            }

            ov::AxisSet axes_set;
            for (const auto& axis : axes) {
                axes_set.insert(static_cast<size_t>(axis));
            }

            NODE_VALIDATION_CHECK(op, axes.size() == axes_set.size(), "FFT op axes must be unique.");

            NODE_VALIDATION_CHECK(op,
                                  std::find(axes.begin(), axes.end(), input_rank - 1) == axes.end(),
                                  "FFT op axes cannot contain the last axis.");
        }
    }

    NODE_VALIDATION_CHECK(op, axes_shape.rank().compatible(1), "FFT op axes input must be 1D tensor.");

    if (input_shapes.size() == 3) {
        const auto& signal_size_shape = input_shapes[2];
        NODE_VALIDATION_CHECK(op,
                              signal_size_shape.rank().compatible(1),
                              "FFT op signal size input must be 1D tensor. Got signal: ",
                              signal_size_shape);

        if (axes_shape.is_static() && signal_size_shape.is_static()) {
            NODE_VALIDATION_CHECK(op,
                                  axes_shape[0].compatible(signal_size_shape[0]),
                                  "Sizes of inputs 'axes' and 'signal_size' must be equal. Got "
                                  "size of 'axes': ",
                                  axes_shape[0],
                                  "size of 'signal_size': ",
                                  signal_size_shape[0]);
        }
    }

    output_shape = input_shape;
    if (input_shape.rank().is_static() && axes_shape.rank().is_static() && input_shapes.size() == 3 && axes_are_known) {
        const auto& signal_size_shape = input_shapes[2];
        std::vector<int64_t> signal_size;
        bool status_signal_size = get_data_as_int64<T>(2, op, signal_size, constant_data);

        if (signal_size_shape.rank().is_static() && status_signal_size) {
            size_t num_of_axes = axes.size();
            for (size_t i = 0; i < num_of_axes; ++i) {
                if (signal_size[i] == -1) {
                    continue;
                }
                output_shape[axes[i]] = DimType(signal_size[i]);
            }
        } else if (signal_size_shape.rank().is_static()) {
            for (int64_t& axis : axes) {
                output_shape[axis] = ov::Dimension::dynamic();
            }
        }
    } else if (input_shape.rank().is_static() && (axes_shape.rank().is_dynamic() || !axes_are_known)) {
        const auto input_rank = input_shape.size();
        for (int64_t i = 0; i < input_rank - 1; ++i) {
            output_shape[i] = ov::Dimension::dynamic();
        }
    }
}
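For context, a small standalone illustration (not part of this commit) of the negative-axis rule described in the comment inside shape_infer above; the helper name normalize_fft_axis is hypothetical.

// Hypothetical helper mirroring the axis normalization performed in shape_infer above.
#include <cassert>
#include <cstdint>

int64_t normalize_fft_axis(int64_t axis, int64_t input_rank) {
    // input_rank counts the trailing dimension of size 2, so the complex tensor
    // it represents has rank input_rank - 1.
    return axis < 0 ? axis + input_rank - 1 : axis;
}

void fft_axis_example() {
    // For a [n_0, n_1, n_2, 2] input (r = 4), valid axes are -3..2, and the last
    // real axis (index 3) is never a transform axis.
    assert(normalize_fft_axis(-1, 4) == 2);  // -1 maps to r - 2
    assert(normalize_fft_axis(-3, 4) == 0);  // -(r - 1) maps to 0
    assert(normalize_fft_axis(1, 4) == 1);   // non-negative axes are unchanged
}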
@@ -4,6 +4,8 @@

#include "ngraph/op/ctc_loss.hpp"

#include <ctc_loss_shape_inference.hpp>

#include "itt.hpp"

using namespace std;
@@ -77,119 +79,22 @@ void op::v4::CTCLoss::validate_and_infer_types() {
                          blank_index_type);
    }

    // check ranks of input tensors
    const auto& logits_pshape = get_input_partial_shape(0);
    const auto& logit_length_pshape = get_input_partial_shape(1);
    const auto& labels_pshape = get_input_partial_shape(2);
    const auto& label_length_pshape = get_input_partial_shape(3);

    NODE_VALIDATION_CHECK(this,
                          logits_pshape.rank().compatible(3),
                          "Expected a 3D tensor for logits. Got: ",
                          logits_pshape);

    NODE_VALIDATION_CHECK(this,
                          logit_length_pshape.rank().compatible(1),
                          "Expected a 1D tensor for logit length. Got: ",
                          logit_length_pshape);

    NODE_VALIDATION_CHECK(this,
                          labels_pshape.rank().compatible(2),
                          "Expected a 2D tensor for labels. Got: ",
                          labels_pshape);

    NODE_VALIDATION_CHECK(this,
                          label_length_pshape.rank().compatible(1),
                          "Expected a 1D tensor for label length. Got: ",
                          label_length_pshape);

    // check optional input shape: blank index
    std::vector<ov::PartialShape> input_shapes;
    if (get_input_size() == 5) {
        const auto& blank_index_pshape = get_input_partial_shape(4);
        NODE_VALIDATION_CHECK(this,
                              blank_index_pshape.rank().compatible(0),
                              "Expected a scalar for blank index. Got: ",
                              blank_index_pshape);
    }

    // check shapes of input tensors
    size_t batch_size = 1;
    bool is_batch_size_set = false;
    size_t time_steps = 1;
    bool is_time_steps_set = false;

    if (logits_pshape.rank().is_static()) {
        if (logits_pshape[0].is_static()) {
            batch_size = logits_pshape[0].get_length();
            is_batch_size_set = true;
        }
        if (logits_pshape[1].is_static()) {
            time_steps = logits_pshape[1].get_length();
            is_time_steps_set = true;
        }
    }

    if (logit_length_pshape.is_static()) {
        if (is_batch_size_set) {
            NODE_VALIDATION_CHECK(this,
                                  logit_length_pshape[0].compatible(batch_size),
                                  "The first dimension of logit length must be equal to the first dimension ",
                                  "of the logits. Got: ",
                                  logit_length_pshape[0],
                                  " and: ",
                                  batch_size);
        } else if (logit_length_pshape[0].is_static()) {
            batch_size = logit_length_pshape[0].get_length();
            is_batch_size_set = true;
        }
    }

    if (labels_pshape.is_static()) {
        if (is_batch_size_set) {
            NODE_VALIDATION_CHECK(this,
                                  labels_pshape[0].compatible(batch_size),
                                  "The first dimension of labels must be equal to the first dimension ",
                                  "of the logits and the logit length. Got: ",
                                  labels_pshape[0],
                                  " and: ",
                                  batch_size);
        } else if (labels_pshape[0].is_static()) {
            batch_size = labels_pshape[0].get_length();
            is_batch_size_set = true;
        }

        if (is_time_steps_set) {
            NODE_VALIDATION_CHECK(this,
                                  labels_pshape[1].compatible(time_steps),
                                  "The second dimension of labels must be equal to the second dimension ",
                                  "of logits. Got: ",
                                  labels_pshape[1],
                                  " and: ",
                                  time_steps);
        }
    }

    if (label_length_pshape.is_static()) {
        if (!is_batch_size_set && label_length_pshape[0].is_static()) {
            batch_size = label_length_pshape[0].get_length();
            is_batch_size_set = true;
        }
        NODE_VALIDATION_CHECK(this,
                              label_length_pshape[0].compatible(batch_size),
                              "The first dimension of label length must be equal to the first dimension ",
                              "of the logits, the logit length and labels. Got: ",
                              label_length_pshape[0],
                              " and: ",
                              batch_size);
    }

    // set output shape
    set_output_size(1);
    if (is_batch_size_set) {
        set_output_type(0, logits_type, ov::Shape{batch_size});
        input_shapes = {logits_pshape, logit_length_pshape, labels_pshape, label_length_pshape, blank_index_pshape};
    } else {
        set_output_type(0, logits_type, ov::PartialShape{Dimension::dynamic()});
        input_shapes = {logits_pshape, logit_length_pshape, labels_pshape, label_length_pshape};
    }
    std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};

    shape_infer(this, input_shapes, output_shapes);
    set_output_type(0, logits_type, output_shapes[0]);
}

bool op::v4::CTCLoss::visit_attributes(AttributeVisitor& visitor) {
@@ -4,6 +4,7 @@

#include "ngraph/op/util/fft_base.hpp"

#include <fft_base_shape_inference.hpp>
#include <ngraph/validation_util.hpp>

#include "itt.hpp"
@@ -23,9 +24,10 @@ bool ov::op::util::FFTBase::visit_attributes(AttributeVisitor& visitor) {
    return true;
}

void ov::op::util::FFTBase::validate() {
    size_t num_of_inputs = get_input_size();
void ov::op::util::FFTBase::validate_and_infer_types() {
    NGRAPH_OP_SCOPE(util_FFTBase_validate_and_infer_types);

    size_t num_of_inputs = get_input_size();
    NODE_VALIDATION_CHECK(this, num_of_inputs == 2 || num_of_inputs == 3, "FFT op must have 2 or 3 inputs.");

    element::Type input_et = get_input_element_type(0);
@@ -38,164 +40,25 @@ void ov::op::util::FFTBase::validate() {
                          axes_et == element::i64 || axes_et == element::i32,
                          "FFT op axes element type must be i32 or i64");

    const auto& input_shape = PartialShape(get_input_partial_shape(0));
    if (input_shape.rank().is_static()) {
        const auto input_rank = input_shape.rank().get_length();
        NODE_VALIDATION_CHECK(this,
                              input_rank >= 2,
                              "The input rank must be greater or equal to 2. Got input rank: ",
                              input_rank);

        auto last_dim_with_two = input_shape[input_rank - 1] & Dimension(2);
        NODE_VALIDATION_CHECK(this,
                              !last_dim_with_two.get_interval().empty(),
                              "The last dimension of input data must be 2. Got: ",
                              input_shape[input_rank - 1]);
    }

    const auto& axes_shape = PartialShape(get_input_partial_shape(1));
    if (axes_shape.rank().is_static()) {
        NODE_VALIDATION_CHECK(this,
                              axes_shape.rank().get_length() == 1,
                              "FFT op axes input must be 1D tensor. Got axes input rank: ",
                              axes_shape.rank().get_length());
    }

    if (input_shape.rank().is_static() && axes_shape.is_static()) {
        const auto input_rank = input_shape.rank().get_length();
        NODE_VALIDATION_CHECK(this,
                              input_rank >= static_cast<int64_t>(axes_shape.to_shape()[0] + 1),
                              "The input rank must be greater than number of FFT op axes. Got "
                              "input rank: ",
                              input_rank,
                              ", number of axes: ",
                              axes_shape.to_shape()[0]);
    }

    if (input_shape.rank().is_static() && ov::is_type<ngraph::op::Constant>(input_value(1).get_node())) {
        const auto input_rank = input_shape.rank().get_length();
        const auto& const_axes = get_constant_from_source(input_value(1));
        auto axes = const_axes->cast_vector<int64_t>();

        // FFT operation supports for negative axes to transform. More precisely, according to
        // the FFT operation specification, axes should be integers from -(r - 1) to (r - 2)
        // inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
        // 'r - 1 + a'. The reason is the following: real input tensor of the shape
        // [n_0, ..., n_{r - 1}, 2] is interpreted as a complex tensor with the shape
        // [n_0, ..., n_{r - 1}].
        for (int64_t& axis : axes) {
            if (axis < 0) {
                axis += input_rank - 1;
            }
        }

        AxisVector axes_vector;
        AxisSet axes_set;
        for (const int64_t axis : axes) {
            axes_vector.push_back(static_cast<size_t>(axis));
            axes_set.insert(static_cast<size_t>(axis));
        }

        NODE_VALIDATION_CHECK(this, axes.size() == axes_set.size(), "FFT op axes must be unique. Got: ", axes_vector);

        NODE_VALIDATION_CHECK(this,
                              std::find(axes.begin(), axes.end(), input_rank - 1) == axes.end(),
                              "FFT op axes cannot contain the last axis. Got axes: ",
                              axes_vector);
    }

    if (num_of_inputs == 3) {
        element::Type signal_size_et = get_input_element_type(2);
        NODE_VALIDATION_CHECK(this,
                              signal_size_et == element::i64 || signal_size_et == element::i32,
                              "FFT op signal_size element type must be i32 or i64");

        const auto& signal_size_shape = PartialShape(get_input_partial_shape(2));
        if (signal_size_shape.rank().is_static()) {
            NODE_VALIDATION_CHECK(this,
                                  signal_size_shape.rank().get_length() == 1,
                                  "FFT op signal size input must be 1D tensor. Got signal size "
                                  "input rank: ",
                                  signal_size_shape.rank().get_length());
        }

        if (axes_shape.is_static() && signal_size_shape.is_static()) {
            NODE_VALIDATION_CHECK(this,
                                  axes_shape.to_shape()[0] == signal_size_shape.to_shape()[0],
                                  "Sizes of inputs 'axes' and 'signal_size' must be equal. Got "
                                  "size of 'axes': ",
                                  axes_shape.to_shape()[0],
                                  "size of 'signal_size': ",
                                  signal_size_shape.to_shape()[0]);
        }
    }
}

void ov::op::util::FFTBase::validate_and_infer_types() {
    NGRAPH_OP_SCOPE(util_FFTBase_validate_and_infer_types);
    validate();

    const auto& input_shape = PartialShape(get_input_partial_shape(0));
    const auto& axes_shape = PartialShape(get_input_partial_shape(1));
    PartialShape output_shape = input_shape;
    if (input_shape.rank().is_dynamic()) {
        set_output_type(0, get_input_element_type(0), output_shape);
        return;
    }

    const auto input_rank = input_shape.rank().get_length();

    if (axes_shape.rank().is_dynamic() || !ov::is_type<ngraph::op::Constant>(input_value(1).get_node())) {
        for (int64_t i = 0; i < input_rank - 1; ++i) {
            output_shape[i] = Dimension::dynamic();
        }
        set_output_type(0, get_input_element_type(0), output_shape);
        return;
    }
    std::vector<ov::PartialShape> output_shapes = {ov::PartialShape()};
    std::vector<ov::PartialShape> input_shapes;

    const auto& data = get_input_partial_shape(0);
    const auto& axes = get_input_partial_shape(1);
    if (input_values().size() == 2) {
        set_output_type(0, get_input_element_type(0), output_shape);
        return;
        input_shapes = {data, axes};
    } else {
        const auto& signal_size = get_input_partial_shape(2);
        input_shapes = {data, axes, signal_size};
    }

    const auto& signal_size_shape = PartialShape(get_input_partial_shape(2));
    if (signal_size_shape.rank().is_dynamic()) {
        set_output_type(0, get_input_element_type(0), output_shape);
        return;
    }

    const auto& const_axes = get_constant_from_source(input_value(1));
    auto axes = const_axes->cast_vector<int64_t>();
    // FFT operation supports for negative axes to transform. More precisely, according to
    // the FFT operation specification, axes should be integers from -(r - 1) to (r - 2)
    // inclusively, where r = rank(data). A negative axis 'a' is interpreted as an axis
    // 'r - 1 + a'. The reason is the following: real input tensor of the shape
    // [n_0, ..., n_{r - 1}, 2] is interpreted as a complex tensor with the shape
    // [n_0, ..., n_{r - 1}].
    for (int64_t& axis : axes) {
        if (axis < 0) {
            axis += input_rank - 1;
        }
    }

    if (!ov::is_type<ngraph::op::Constant>(input_value(2).get_node())) {
        for (int64_t axis : axes) {
            output_shape[axis] = Dimension::dynamic();
        }
        set_output_type(0, get_input_element_type(0), output_shape);
        return;
    }

    const auto& const_signal_size = get_constant_from_source(input_value(2));
    const auto signal_size = const_signal_size->cast_vector<int64_t>();

    size_t num_of_axes = axes.size();
    for (size_t i = 0; i < num_of_axes; ++i) {
        if (signal_size[i] == -1) {
            continue;
        }
        output_shape[axes[i]] = Dimension(signal_size[i]);
    }

    set_output_type(0, get_input_element_type(0), output_shape);
    shape_infer(this, input_shapes, output_shapes);
    set_output_type(0, get_input_element_type(0), output_shapes[0]);
}
@@ -0,0 +1,34 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <ctc_loss_shape_inference.hpp>
#include <openvino/op/ctc_loss.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>

using namespace ov;

TEST(StaticShapeInferenceTest, CTCLossTest) {
    const auto& logits = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
    const auto& logit_length = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
    const auto& labels = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1, -1});
    const auto& label_length = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
    const auto& blank_index = std::make_shared<ov::op::v0::Parameter>(element::i32, ov::Shape{});

    // create CTCLoss node
    auto ctc_loss = std::make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);

    std::vector<StaticShape> static_input_shapes = {StaticShape{10, 120, 28},
                                                    StaticShape{10},
                                                    StaticShape{10, 120},
                                                    StaticShape{10},
                                                    ov::Shape{}},
                             static_output_shapes = {StaticShape{}};
    shape_inference(ctc_loss.get(), static_input_shapes, static_output_shapes);

    ASSERT_EQ(static_output_shapes[0], StaticShape({10}));
}
@@ -0,0 +1,145 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <fft_base_shape_inference.hpp>
#include <openvino/op/dft.hpp>
#include <openvino/op/idft.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>

using namespace ov;

static std::shared_ptr<op::v7::DFT> build_dft() {
    auto input_shape = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
    auto axes = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
    auto DFT = std::make_shared<ov::op::v7::DFT>(input_shape, axes);
    return DFT;
}

static std::shared_ptr<op::v7::DFT> build_dft_signal() {
    auto input_shape = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
    auto axes = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
    auto signal = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
    auto DFT_signal = std::make_shared<ov::op::v7::DFT>(input_shape, axes, signal);
    return DFT_signal;
}

static std::shared_ptr<op::v7::DFT> build_dft_constant() {
    auto input_shape = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
    auto axes = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{2}, std::vector<int32_t>{1, 2});
    auto signal = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{2}, std::vector<int32_t>{512, 100});
    auto DFT_signal = std::make_shared<ov::op::v7::DFT>(input_shape, axes, signal);
    return DFT_signal;
}

static std::shared_ptr<op::v7::IDFT> build_idft() {
    auto input_shape = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
    auto axes = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
    auto IDFT = std::make_shared<ov::op::v7::IDFT>(input_shape, axes);
    return IDFT;
}

static std::shared_ptr<op::v7::IDFT> build_idft_signal() {
    auto input_shape = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
    auto axes = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
    auto signal = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape::dynamic());
    auto IDFT_signal = std::make_shared<ov::op::v7::IDFT>(input_shape, axes, signal);
    return IDFT_signal;
}

TEST(StaticShapeInferenceTest, DFTTest) {
    auto DFT = build_dft();
    std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
    int32_t axes_val[] = {1, 2};
    constant_data[1] = std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, axes_val);

    std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}},
                             static_output_shapes = {StaticShape{}};

    shape_inference(DFT.get(), static_input_shapes, static_output_shapes, constant_data);
    ASSERT_EQ(static_output_shapes[0], StaticShape({1, 320, 320, 2}));
}

TEST(StaticShapeInferenceTest, DFTSignalTest) {
    auto DFT = build_dft_signal();
    std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
    int32_t axes_val[] = {1, 2};
    int32_t signal_val[] = {512, 100};
    constant_data[1] = std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, axes_val);
    constant_data[2] =
        std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, signal_val);

    std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}, StaticShape{2}},
                             static_output_shapes = {StaticShape{}};

    shape_inference(DFT.get(), static_input_shapes, static_output_shapes, constant_data);
    ASSERT_EQ(static_output_shapes[0], StaticShape({1, 512, 100, 2}));
}

TEST(StaticShapeInferenceTest, DFTConstantTest) {
    auto DFT = build_dft_constant();

    std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}, StaticShape{2}},
                             static_output_shapes = {StaticShape{}};

    shape_inference(DFT.get(), static_input_shapes, static_output_shapes);
    ASSERT_EQ(static_output_shapes[0], StaticShape({1, 512, 100, 2}));
}

TEST(StaticShapeInferenceTest, DFTSignalMissingConstDataTest) {
    auto DFT = build_dft_signal();
    std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
    int32_t axes_val[] = {1, 2};
    constant_data[1] = std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, axes_val);

    std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}, StaticShape{2}},
                             static_output_shapes = {StaticShape{}};
    EXPECT_THROW(shape_inference(DFT.get(), static_input_shapes, static_output_shapes, constant_data),
                 NodeValidationFailure);
}

TEST(StaticShapeInferenceTest, IDFTTest) {
    auto IDFT = build_idft();
    std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
    int32_t axes_val[] = {1, 2};
    constant_data[1] = std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, axes_val);

    std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}},
                             static_output_shapes = {StaticShape{}};

    shape_inference(IDFT.get(), static_input_shapes, static_output_shapes, constant_data);
    ASSERT_EQ(static_output_shapes[0], StaticShape({1, 320, 320, 2}));
}

TEST(StaticShapeInferenceTest, IDFTSignalTest) {
    auto IDFT = build_idft_signal();
    std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
    int32_t axes_val[] = {1, 2};
    int32_t signal_val[] = {512, 100};
    constant_data[1] = std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, axes_val);
    constant_data[2] =
        std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, signal_val);

    std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}, StaticShape{2}},
                             static_output_shapes = {StaticShape{}};

    shape_inference(IDFT.get(), static_input_shapes, static_output_shapes, constant_data);
    ASSERT_EQ(static_output_shapes[0], StaticShape({1, 512, 100, 2}));
}

TEST(StaticShapeInferenceTest, IDFTSignalMissingConstDataTest) {
    auto IDFT = build_idft_signal();
    std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> constant_data;
    int32_t axes_val[] = {1, 2};
    constant_data[1] = std::make_shared<ngraph::runtime::HostTensor>(ngraph::element::Type_t::i32, Shape{2}, axes_val);

    std::vector<StaticShape> static_input_shapes = {StaticShape{1, 320, 320, 2}, StaticShape{2}, StaticShape{2}},
                             static_output_shapes = {StaticShape{}};
    EXPECT_THROW(shape_inference(IDFT.get(), static_input_shapes, static_output_shapes, constant_data),
                 NodeValidationFailure);
}