Review CTCLoss class for shape inference aspects (#15375)
* Review ctc loss operator for - partial shape and label propagation - template implementation of shape_infer - update/extend tests * Use namespace ov in ctc loss operator
This commit is contained in:
parent
3a8646215f
commit
4ce3e9a88d
@ -3,43 +3,35 @@
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <openvino/op/ctc_loss.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace v4 {
|
||||
|
||||
template <class T>
|
||||
void shape_infer(const CTCLoss* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
|
||||
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
||||
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 4 || input_shapes.size() == 5) && output_shapes.size() == 1);
|
||||
namespace ctc_loss {
|
||||
constexpr auto shape_names =
|
||||
std::array<const char*, 5>{"logits", "logit length", "labels", "label length", "blank index"};
|
||||
constexpr auto shape_ranks = std::array<int64_t, 4>{3, 1, 2, 1};
|
||||
} // namespace ctc_loss
|
||||
|
||||
template <class TShape>
|
||||
std::vector<TShape> shape_infer(const CTCLoss* op, const std::vector<TShape>& input_shapes) {
|
||||
using DimType = typename TShape::value_type;
|
||||
NODE_VALIDATION_CHECK(op, input_shapes.size() == 4 || input_shapes.size() == 5);
|
||||
|
||||
// check ranks of input tensors
|
||||
const auto& logits_pshape = input_shapes[0];
|
||||
const auto& logit_length_pshape = input_shapes[1];
|
||||
const auto& labels_pshape = input_shapes[2];
|
||||
const auto& label_length_pshape = input_shapes[3];
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
logits_pshape.rank().compatible(3),
|
||||
"Expected a 3D tensor for logits. Got: ",
|
||||
logits_pshape);
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
logit_length_pshape.rank().compatible(1),
|
||||
"Expected a 1D tensor for logit length. Got: ",
|
||||
logit_length_pshape);
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
labels_pshape.rank().compatible(2),
|
||||
"Expected a 2D tensor for labels. Got: ",
|
||||
labels_pshape);
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
label_length_pshape.rank().compatible(1),
|
||||
"Expected a 1D tensor for label length. Got: ",
|
||||
label_length_pshape);
|
||||
for (size_t i = 0; i < ctc_loss::shape_ranks.size(); ++i) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_shapes[i].rank().compatible(ctc_loss::shape_ranks[i]),
|
||||
"Expected a ",
|
||||
ctc_loss::shape_ranks[i],
|
||||
"D tensor for ",
|
||||
ctc_loss::shape_names[i],
|
||||
". Got: ",
|
||||
input_shapes[i]);
|
||||
}
|
||||
|
||||
// check optional input shape: blank index
|
||||
if (input_shapes.size() == 5) {
|
||||
@ -50,82 +42,59 @@ void shape_infer(const CTCLoss* op, const std::vector<T>& input_shapes, std::vec
|
||||
blank_index_pshape);
|
||||
}
|
||||
|
||||
const auto& logits_pshape = input_shapes[0];
|
||||
const auto& logits_rank = logits_pshape.rank();
|
||||
|
||||
const auto& logit_length_pshape = input_shapes[1];
|
||||
const auto& labels_pshape = input_shapes[2];
|
||||
const auto& label_length_pshape = input_shapes[3];
|
||||
|
||||
// check shapes of input tensors
|
||||
DimType batch_size = 1;
|
||||
bool is_batch_size_set = false;
|
||||
DimType time_steps = 1;
|
||||
bool is_time_steps_set = false;
|
||||
DimType batch_size = logits_rank.is_static() ? logits_pshape[0] : -1;
|
||||
DimType time_steps = logits_rank.is_static() ? logits_pshape[1] : -1;
|
||||
|
||||
if (logits_pshape.rank().is_static()) {
|
||||
batch_size = logits_pshape[0];
|
||||
is_batch_size_set = true;
|
||||
time_steps = logits_pshape[1];
|
||||
is_time_steps_set = true;
|
||||
}
|
||||
|
||||
if (logit_length_pshape.rank().is_static()) {
|
||||
if (is_batch_size_set) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
DimType::merge(batch_size, batch_size, logit_length_pshape[0]),
|
||||
"The first dimension of logit length must be equal to the first dimension ",
|
||||
"of the logits. Got: ",
|
||||
logit_length_pshape[0],
|
||||
" and: ",
|
||||
batch_size);
|
||||
} else {
|
||||
batch_size = logit_length_pshape[0];
|
||||
is_batch_size_set = true;
|
||||
}
|
||||
}
|
||||
NODE_VALIDATION_CHECK(
|
||||
op,
|
||||
logit_length_pshape.rank().is_dynamic() || DimType::merge(batch_size, batch_size, logit_length_pshape[0]),
|
||||
"The first dimension of logit length must be equal to the first dimension ",
|
||||
"of the logits. Got: ",
|
||||
logit_length_pshape[0],
|
||||
" and: ",
|
||||
batch_size);
|
||||
|
||||
if (labels_pshape.rank().is_static()) {
|
||||
if (is_batch_size_set) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
DimType::merge(batch_size, batch_size, labels_pshape[0]),
|
||||
"The first dimension of labels must be equal to the first dimension ",
|
||||
"of the logits and the logit length. Got: ",
|
||||
labels_pshape[0],
|
||||
" and: ",
|
||||
batch_size);
|
||||
} else {
|
||||
batch_size = labels_pshape[0];
|
||||
is_batch_size_set = true;
|
||||
}
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
DimType::merge(batch_size, batch_size, labels_pshape[0]),
|
||||
"The first dimension of labels must be equal to the first dimension ",
|
||||
"of the logits and the logit length. Got: ",
|
||||
labels_pshape[0],
|
||||
" and: ",
|
||||
batch_size);
|
||||
|
||||
if (is_time_steps_set) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
DimType::merge(time_steps, time_steps, labels_pshape[1]),
|
||||
"The second dimension of labels must be equal to the second dimension ",
|
||||
"of logits. Got: ",
|
||||
labels_pshape[1],
|
||||
" and: ",
|
||||
time_steps);
|
||||
}
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
labels_pshape[1].compatible(time_steps),
|
||||
"The second dimension of labels must be equal to the second dimension ",
|
||||
"of logits. Got: ",
|
||||
labels_pshape[1],
|
||||
" and: ",
|
||||
time_steps);
|
||||
}
|
||||
|
||||
if (label_length_pshape.rank().is_static()) {
|
||||
if (is_batch_size_set) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
DimType::merge(batch_size, batch_size, label_length_pshape[0]),
|
||||
"The first dimension of label length must be equal to the first dimension ",
|
||||
"of the logits, the logit length and labels. Got: ",
|
||||
label_length_pshape[0],
|
||||
" and: ",
|
||||
batch_size);
|
||||
} else {
|
||||
batch_size = label_length_pshape[0];
|
||||
is_batch_size_set = true;
|
||||
}
|
||||
}
|
||||
NODE_VALIDATION_CHECK(
|
||||
op,
|
||||
label_length_pshape.rank().is_dynamic() || DimType::merge(batch_size, batch_size, label_length_pshape[0]),
|
||||
"The first dimension of label length must be equal to the first dimension ",
|
||||
"of the logits, the logit length and labels. Got: ",
|
||||
label_length_pshape[0],
|
||||
" and: ",
|
||||
batch_size);
|
||||
|
||||
auto& output_shape = output_shapes[0];
|
||||
output_shape.resize(1);
|
||||
return {TShape{batch_size}};
|
||||
}
|
||||
|
||||
if (is_batch_size_set) {
|
||||
output_shape[0] = batch_size;
|
||||
} else {
|
||||
output_shape[0] = Dimension::dynamic();
|
||||
}
|
||||
template <class TShape>
|
||||
void shape_infer(const CTCLoss* op, const std::vector<TShape>& input_shapes, std::vector<TShape>& output_shapes) {
|
||||
output_shapes = shape_infer(op, input_shapes);
|
||||
}
|
||||
|
||||
} // namespace v4
|
||||
|
@ -2,15 +2,15 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/op/ctc_loss.hpp"
|
||||
|
||||
#include <ctc_loss_shape_inference.hpp>
|
||||
#include "openvino/op/ctc_loss.hpp"
|
||||
|
||||
#include "ctc_loss_shape_inference.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
namespace ov {
|
||||
op::v4::CTCLoss::CTCLoss(const Output<Node>& logits,
|
||||
const Output<Node>& logit_length,
|
||||
const Output<Node>& labels,
|
||||
@ -44,55 +44,26 @@ void op::v4::CTCLoss::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v4_CTCLoss_validate_and_infer_types);
|
||||
// check types of input tensors
|
||||
const auto& logits_type = get_input_element_type(0);
|
||||
const auto& logit_length_type = get_input_element_type(1);
|
||||
const auto& labels_type = get_input_element_type(2);
|
||||
const auto& label_length_type = get_input_element_type(3);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
logits_type.is_real(),
|
||||
"The data type for logits is expected to be a floating point type. Got: ",
|
||||
"The data type for ",
|
||||
ctc_loss::shape_names[0],
|
||||
" is expected to be a floating point type. Got: ",
|
||||
logits_type);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
logit_length_type.is_integral_number(),
|
||||
"The logit length type is expected to be an integer type. Got: ",
|
||||
logit_length_type);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
labels_type.is_integral_number(),
|
||||
"The labels type is expected to be an integer type. Got: ",
|
||||
labels_type);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
label_length_type.is_integral_number(),
|
||||
"The label length type is expected to be an integer type. Got: ",
|
||||
label_length_type);
|
||||
|
||||
// check optional input type: blank index
|
||||
if (get_input_size() == 5) {
|
||||
const auto& blank_index_type = get_input_element_type(4);
|
||||
for (size_t i = 1; i < get_input_size(); ++i) {
|
||||
const auto& input_et = get_input_element_type(i);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
blank_index_type.is_integral_number(),
|
||||
"The blank index type is expected to be an integer type. Got: ",
|
||||
blank_index_type);
|
||||
input_et.is_integral_number(),
|
||||
"The ",
|
||||
ctc_loss::shape_names[i],
|
||||
" type is expected to be an integer type. Got: ",
|
||||
input_et);
|
||||
}
|
||||
|
||||
const auto& logits_pshape = get_input_partial_shape(0);
|
||||
const auto& logit_length_pshape = get_input_partial_shape(1);
|
||||
const auto& labels_pshape = get_input_partial_shape(2);
|
||||
const auto& label_length_pshape = get_input_partial_shape(3);
|
||||
|
||||
std::vector<ov::PartialShape> input_shapes;
|
||||
if (get_input_size() == 5) {
|
||||
const auto& blank_index_pshape = get_input_partial_shape(4);
|
||||
input_shapes = {logits_pshape, logit_length_pshape, labels_pshape, label_length_pshape, blank_index_pshape};
|
||||
} else {
|
||||
input_shapes = {logits_pshape, logit_length_pshape, labels_pshape, label_length_pshape};
|
||||
}
|
||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
|
||||
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
set_output_type(0, logits_type, output_shapes[0]);
|
||||
const auto output_shape = shape_infer(this, ov::get_node_input_partial_shapes(*this)).front();
|
||||
set_output_type(0, logits_type, output_shape);
|
||||
}
|
||||
|
||||
bool op::v4::CTCLoss::visit_attributes(AttributeVisitor& visitor) {
|
||||
@ -124,6 +95,7 @@ shared_ptr<Node> op::v4::CTCLoss::clone_with_new_inputs(const OutputVector& new_
|
||||
ctc_merge_repeated_,
|
||||
unique_);
|
||||
} else {
|
||||
throw ngraph_error("Incorrect number of arguments");
|
||||
throw ov::Exception("Incorrect number of arguments");
|
||||
}
|
||||
}
|
||||
} // namespace ov
|
||||
|
@ -2,256 +2,247 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "common_test_utils/test_assertions.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace testing;
|
||||
|
||||
TEST(type_prop, ctc_loss) {
|
||||
class TypePropCTCLossV4Test : public TypePropOpTest<op::v4::CTCLoss> {};
|
||||
|
||||
TEST_F(TypePropCTCLossV4Test, with_blank_index) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
|
||||
auto logits_shape = PartialShape{10, 120, 28};
|
||||
set_shape_labels(logits_shape, 10);
|
||||
|
||||
auto logits = make_shared<Parameter>(element::f32, logits_shape);
|
||||
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
|
||||
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
|
||||
auto ctc_loss = make_op(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
// check type and shape infer
|
||||
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
|
||||
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
|
||||
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape({10}));
|
||||
EXPECT_THAT(get_shape_labels(ctc_loss->get_output_partial_shape(0)), ElementsAre(10));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_no_blank_index) {
|
||||
TEST_F(TypePropCTCLossV4Test, no_blank_index) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
|
||||
auto labels_shape = PartialShape{10, 120};
|
||||
set_shape_labels(labels_shape, 20);
|
||||
auto labels = make_shared<Parameter>(element::i32, labels_shape);
|
||||
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length);
|
||||
auto ctc_loss = make_op(logits, logit_length, labels, label_length);
|
||||
|
||||
// check type and shape infer
|
||||
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
|
||||
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
|
||||
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape({10}));
|
||||
EXPECT_THAT(get_shape_labels(ctc_loss->get_output_partial_shape(0)), ElementsAre(20));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_output_type) {
|
||||
TEST_F(TypePropCTCLossV4Test, output_type_f64) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f64, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
|
||||
auto logits = make_shared<Parameter>(element::f64, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
|
||||
|
||||
auto label_len_shape = PartialShape{10};
|
||||
set_shape_labels(label_len_shape, 30);
|
||||
auto label_length = make_shared<Parameter>(element::i32, label_len_shape);
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
|
||||
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
|
||||
auto ctc_loss = make_op(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
// check type and shape infer
|
||||
EXPECT_EQ(ctc_loss->get_element_type(), element::f64);
|
||||
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
|
||||
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape({10}));
|
||||
EXPECT_THAT(get_shape_labels(ctc_loss->get_output_partial_shape(0)), ElementsAre(30));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_non_default_parameters) {
|
||||
TEST_F(TypePropCTCLossV4Test, non_default_parameters) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f64, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
|
||||
auto logits = make_shared<Parameter>(element::f64, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
|
||||
|
||||
// create CTCLoss node
|
||||
auto ctc_loss =
|
||||
make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index, true, false, false);
|
||||
auto ctc_loss = make_op(logits, logit_length, labels, label_length, blank_index, true, false, false);
|
||||
|
||||
// check type and shape infer
|
||||
EXPECT_EQ(ctc_loss->get_element_type(), element::f64);
|
||||
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
|
||||
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape({10}));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_dynamic_input) {
|
||||
TEST_F(TypePropCTCLossV4Test, dynamic_input) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic(), 120});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
|
||||
auto logits = make_shared<Parameter>(element::f32, PartialShape{Dimension::dynamic(), 120, 28});
|
||||
auto logit_length = make_shared<Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||
auto labels = make_shared<Parameter>(element::i32, PartialShape{Dimension::dynamic(), 120});
|
||||
auto label_length = make_shared<Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
|
||||
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
|
||||
auto ctc_loss = make_op(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
// check type and shape infer
|
||||
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
|
||||
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic()}));
|
||||
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape::dynamic(1));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_partly_dynamic_input) {
|
||||
TEST_F(TypePropCTCLossV4Test, partly_dynamic_input) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, PartialShape{10});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic(), 120});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
|
||||
auto logits_shape = PartialShape{{2, 20}, {100, 130}, 28};
|
||||
auto logits_len_shape = PartialShape{{5, 10}};
|
||||
auto labels_shape = PartialShape{-1, 120};
|
||||
set_shape_labels(logits_shape, 10);
|
||||
set_shape_labels(logits_len_shape, 20);
|
||||
set_shape_labels(labels_shape, 30);
|
||||
|
||||
auto logits = make_shared<Parameter>(element::f32, logits_shape);
|
||||
auto logit_length = make_shared<Parameter>(element::i32, logits_len_shape);
|
||||
auto labels = make_shared<Parameter>(element::i32, labels_shape);
|
||||
auto label_length = make_shared<Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
|
||||
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
|
||||
auto ctc_loss = make_op(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
// check type and shape infer
|
||||
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
|
||||
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
|
||||
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape({{5, 10}}));
|
||||
EXPECT_THAT(get_shape_labels(ctc_loss->get_output_partial_shape(0)), ElementsAre(30));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_fail_inputs_dim) {
|
||||
TEST_F(TypePropCTCLossV4Test, fail_inputs_dim) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 40, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
|
||||
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 40, 28});
|
||||
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
|
||||
|
||||
try {
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Invalid inputs not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 3D tensor for logits."));
|
||||
} catch (...) {
|
||||
FAIL() << "Inputs shape check failed for unexpected reason";
|
||||
}
|
||||
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("Expected a 3D tensor for logits."));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_fail_logit_length_dim) {
|
||||
TEST_F(TypePropCTCLossV4Test, fail_logit_length_dim) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10, 20});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
|
||||
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<Parameter>(element::i32, Shape{10, 20});
|
||||
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
|
||||
|
||||
try {
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Invalid logit length not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 1D tensor for logit length."));
|
||||
} catch (...) {
|
||||
FAIL() << "Logit length shape check failed for unexpected reason";
|
||||
}
|
||||
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("Expected a 1D tensor for logit length."));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_fail_labels_dim) {
|
||||
TEST_F(TypePropCTCLossV4Test, fail_labels_dim) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
|
||||
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
|
||||
|
||||
try {
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Invalid labels not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 2D tensor for labels."));
|
||||
} catch (...) {
|
||||
FAIL() << "Labels shape check failed for unexpected reason";
|
||||
}
|
||||
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("Expected a 2D tensor for labels."));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_fail_label_length_dim) {
|
||||
TEST_F(TypePropCTCLossV4Test, fail_label_length_dim) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10, 40});
|
||||
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
|
||||
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<Parameter>(element::i32, Shape{10, 40});
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
|
||||
|
||||
try {
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Invalid labels not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 1D tensor for label length."));
|
||||
} catch (...) {
|
||||
FAIL() << "Label length shape check failed for unexpected reason";
|
||||
}
|
||||
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("Expected a 1D tensor for label length."));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_fail_blank_index_dim) {
|
||||
TEST_F(TypePropCTCLossV4Test, fail_blank_index_dim) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{4});
|
||||
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{4});
|
||||
|
||||
try {
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Invalid labels not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a scalar for blank index."));
|
||||
} catch (...) {
|
||||
FAIL() << "Blank index shape check failed for unexpected reason";
|
||||
}
|
||||
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("Expected a scalar for blank index."));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_fail_batch_dim_mismatch) {
|
||||
TEST_F(TypePropCTCLossV4Test, fail_batch_dim_mismatch) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, Shape{40});
|
||||
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
|
||||
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<Parameter>(element::i32, Shape{40});
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
|
||||
|
||||
try {
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Mismatch of batch dimension not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("The first dimension of label length must be equal to the first dimension "
|
||||
"of the logits, the logit length and labels."));
|
||||
} catch (...) {
|
||||
FAIL() << "Batch dimension matching check failed for unexpected reason";
|
||||
}
|
||||
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("The first dimension of label length must be equal to the first dimension of the logits, "
|
||||
"the logit length and labels."));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_fail_time_dim_mismatch) {
|
||||
TEST_F(TypePropCTCLossV4Test, fail_time_dim_mismatch) {
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 130});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, Shape{40});
|
||||
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
|
||||
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<Parameter>(element::i32, Shape{10, 130});
|
||||
auto label_length = make_shared<Parameter>(element::i32, Shape{40});
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
|
||||
|
||||
try {
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Mismatch of time dimension not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("The second dimension of labels must be equal to the second dimension "
|
||||
"of logits."));
|
||||
} catch (...) {
|
||||
FAIL() << "Time dimension matching check failed for unexpected reason";
|
||||
}
|
||||
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
|
||||
NodeValidationFailure,
|
||||
HasSubstr("The second dimension of labels must be equal to the second dimension of logits."));
|
||||
}
|
||||
|
||||
TEST_F(TypePropCTCLossV4Test, default_ctor) {
|
||||
// create inputs
|
||||
auto logits_shape = PartialShape{{2, 20}, {100, 130}, 28};
|
||||
auto logits_len_shape = PartialShape{{5, 10}};
|
||||
auto labels_shape = PartialShape{-1, 120};
|
||||
set_shape_labels(logits_shape, 10);
|
||||
set_shape_labels(logits_len_shape, 20);
|
||||
set_shape_labels(labels_shape, 30);
|
||||
|
||||
auto logits = make_shared<Parameter>(element::f32, logits_shape);
|
||||
auto logit_length = make_shared<Parameter>(element::i32, logits_len_shape);
|
||||
auto labels = make_shared<Parameter>(element::i32, labels_shape);
|
||||
auto label_length = make_shared<Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
|
||||
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_op();
|
||||
ctc_loss->set_arguments(OutputVector{logits, logit_length, labels, label_length, blank_index});
|
||||
ctc_loss->validate_and_infer_types();
|
||||
|
||||
// check type and shape infer
|
||||
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
|
||||
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape({{5, 10}}));
|
||||
EXPECT_THAT(get_shape_labels(ctc_loss->get_output_partial_shape(0)), ElementsAre(30));
|
||||
}
|
||||
|
@ -1,35 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ctc_loss_shape_inference.hpp>
|
||||
#include <openvino/op/ctc_loss.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <utils/shape_inference/shape_inference.hpp>
|
||||
#include <utils/shape_inference/static_shape.hpp>
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
|
||||
TEST(StaticShapeInferenceTest, CTCLossTest) {
|
||||
const auto& logits = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
const auto& logit_length = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||
const auto& labels = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1, -1});
|
||||
const auto& label_length = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||
const auto& blank_index = std::make_shared<ov::op::v0::Parameter>(element::i32, ov::Shape{});
|
||||
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = std::make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{10, 120, 28},
|
||||
StaticShape{10},
|
||||
StaticShape{10, 120},
|
||||
StaticShape{10},
|
||||
ov::Shape{}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(ctc_loss.get(), static_input_shapes, static_output_shapes);
|
||||
|
||||
ASSERT_EQ(static_output_shapes[0], StaticShape({10}));
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ctc_loss_shape_inference.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::intel_cpu;
|
||||
|
||||
class CTCLossV4StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v4::CTCLoss> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
output_shapes.resize(1);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(CTCLossV4StaticShapeInferenceTest, correct_input_shapes) {
|
||||
const auto& logits = std::make_shared<Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
const auto& logit_length = std::make_shared<Parameter>(element::i32, PartialShape{-1});
|
||||
const auto& labels = std::make_shared<Parameter>(element::i32, PartialShape{-1, -1});
|
||||
const auto& label_length = std::make_shared<Parameter>(element::i32, PartialShape{-1});
|
||||
const auto& blank_index = std::make_shared<Parameter>(element::i32, ov::Shape{});
|
||||
|
||||
auto op = make_op(logits, logit_length, labels, label_length, blank_index);
|
||||
|
||||
input_shapes = ShapeVector{{10, 120, 28}, {10}, {10, 120}, {10}, {}};
|
||||
shape_inference(op.get(), input_shapes, output_shapes);
|
||||
|
||||
EXPECT_EQ(output_shapes.size(), 1);
|
||||
EXPECT_EQ(output_shapes.front(), StaticShape({10}));
|
||||
}
|
||||
|
||||
TEST_F(CTCLossV4StaticShapeInferenceTest, default_ctor) {
|
||||
auto op = make_op();
|
||||
|
||||
input_shapes = ShapeVector{{12, 120, 28}, {12}, {12, 120}, {12}, {}};
|
||||
shape_inference(op.get(), input_shapes, output_shapes);
|
||||
|
||||
EXPECT_EQ(output_shapes.size(), 1);
|
||||
EXPECT_EQ(output_shapes.front(), StaticShape({12}));
|
||||
}
|
Loading…
Reference in New Issue
Block a user