Review CTCLoss class for shape inference aspects (#15375)

* Review ctc loss operator for
- partial shape and label propagation
- template implementation of shape_infer
- update/extend tests

* Use namespace ov in ctc loss operator
Pawel Raasz 2023-01-31 11:10:30 +01:00 committed by GitHub
parent 3a8646215f
commit 4ce3e9a88d
5 changed files with 293 additions and 350 deletions
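Before the per-file diffs, a minimal usage sketch of the interface this commit introduces. Only the new signature (shape_infer takes the input shapes and returns the inferred output shapes by value) comes from the diff below; the concrete shapes, the function name, and the call site are illustrative assumptions.

// Illustrative sketch only, not part of this commit.
#include <vector>

#include "ctc_loss_shape_inference.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/op/ctc_loss.hpp"

ov::PartialShape infer_ctc_loss_output(const ov::op::v4::CTCLoss* op) {
    // Inputs: logits [N, T, C], logit length [N], labels [N, T], label length [N],
    // and the optional scalar blank index.
    const std::vector<ov::PartialShape> input_shapes{ov::PartialShape{10, 120, 28},
                                                     ov::PartialShape{10},
                                                     ov::PartialShape{10, 120},
                                                     ov::PartialShape{10},
                                                     ov::PartialShape{}};

    // New style: the output shapes are returned instead of written to an out-parameter.
    const auto output_shapes = ov::op::v4::shape_infer(op, input_shapes);
    return output_shapes.front();  // expected to be {10}, the batch dimension
}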


@@ -3,43 +3,35 @@
//
#pragma once
#include <array>
#include <openvino/op/ctc_loss.hpp>
namespace ov {
namespace op {
namespace v4 {
template <class T>
void shape_infer(const CTCLoss* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 4 || input_shapes.size() == 5) && output_shapes.size() == 1);
namespace ctc_loss {
constexpr auto shape_names =
std::array<const char*, 5>{"logits", "logit length", "labels", "label length", "blank index"};
constexpr auto shape_ranks = std::array<int64_t, 4>{3, 1, 2, 1};
} // namespace ctc_loss
template <class TShape>
std::vector<TShape> shape_infer(const CTCLoss* op, const std::vector<TShape>& input_shapes) {
using DimType = typename TShape::value_type;
NODE_VALIDATION_CHECK(op, input_shapes.size() == 4 || input_shapes.size() == 5);
// check ranks of input tensors
const auto& logits_pshape = input_shapes[0];
const auto& logit_length_pshape = input_shapes[1];
const auto& labels_pshape = input_shapes[2];
const auto& label_length_pshape = input_shapes[3];
NODE_VALIDATION_CHECK(op,
logits_pshape.rank().compatible(3),
"Expected a 3D tensor for logits. Got: ",
logits_pshape);
NODE_VALIDATION_CHECK(op,
logit_length_pshape.rank().compatible(1),
"Expected a 1D tensor for logit length. Got: ",
logit_length_pshape);
NODE_VALIDATION_CHECK(op,
labels_pshape.rank().compatible(2),
"Expected a 2D tensor for labels. Got: ",
labels_pshape);
NODE_VALIDATION_CHECK(op,
label_length_pshape.rank().compatible(1),
"Expected a 1D tensor for label length. Got: ",
label_length_pshape);
for (size_t i = 0; i < ctc_loss::shape_ranks.size(); ++i) {
NODE_VALIDATION_CHECK(op,
input_shapes[i].rank().compatible(ctc_loss::shape_ranks[i]),
"Expected a ",
ctc_loss::shape_ranks[i],
"D tensor for ",
ctc_loss::shape_names[i],
". Got: ",
input_shapes[i]);
}
// check optional input shape: blank index
if (input_shapes.size() == 5) {
@@ -50,82 +42,59 @@ void shape_infer(const CTCLoss* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
blank_index_pshape);
}
const auto& logits_pshape = input_shapes[0];
const auto& logits_rank = logits_pshape.rank();
const auto& logit_length_pshape = input_shapes[1];
const auto& labels_pshape = input_shapes[2];
const auto& label_length_pshape = input_shapes[3];
// check shapes of input tensors
DimType batch_size = 1;
bool is_batch_size_set = false;
DimType time_steps = 1;
bool is_time_steps_set = false;
DimType batch_size = logits_rank.is_static() ? logits_pshape[0] : -1;
DimType time_steps = logits_rank.is_static() ? logits_pshape[1] : -1;
if (logits_pshape.rank().is_static()) {
batch_size = logits_pshape[0];
is_batch_size_set = true;
time_steps = logits_pshape[1];
is_time_steps_set = true;
}
if (logit_length_pshape.rank().is_static()) {
if (is_batch_size_set) {
NODE_VALIDATION_CHECK(op,
DimType::merge(batch_size, batch_size, logit_length_pshape[0]),
"The first dimension of logit length must be equal to the first dimension ",
"of the logits. Got: ",
logit_length_pshape[0],
" and: ",
batch_size);
} else {
batch_size = logit_length_pshape[0];
is_batch_size_set = true;
}
}
NODE_VALIDATION_CHECK(
op,
logit_length_pshape.rank().is_dynamic() || DimType::merge(batch_size, batch_size, logit_length_pshape[0]),
"The first dimension of logit length must be equal to the first dimension ",
"of the logits. Got: ",
logit_length_pshape[0],
" and: ",
batch_size);
if (labels_pshape.rank().is_static()) {
if (is_batch_size_set) {
NODE_VALIDATION_CHECK(op,
DimType::merge(batch_size, batch_size, labels_pshape[0]),
"The first dimension of labels must be equal to the first dimension ",
"of the logits and the logit length. Got: ",
labels_pshape[0],
" and: ",
batch_size);
} else {
batch_size = labels_pshape[0];
is_batch_size_set = true;
}
NODE_VALIDATION_CHECK(op,
DimType::merge(batch_size, batch_size, labels_pshape[0]),
"The first dimension of labels must be equal to the first dimension ",
"of the logits and the logit length. Got: ",
labels_pshape[0],
" and: ",
batch_size);
if (is_time_steps_set) {
NODE_VALIDATION_CHECK(op,
DimType::merge(time_steps, time_steps, labels_pshape[1]),
"The second dimension of labels must be equal to the second dimension ",
"of logits. Got: ",
labels_pshape[1],
" and: ",
time_steps);
}
NODE_VALIDATION_CHECK(op,
labels_pshape[1].compatible(time_steps),
"The second dimension of labels must be equal to the second dimension ",
"of logits. Got: ",
labels_pshape[1],
" and: ",
time_steps);
}
if (label_length_pshape.rank().is_static()) {
if (is_batch_size_set) {
NODE_VALIDATION_CHECK(op,
DimType::merge(batch_size, batch_size, label_length_pshape[0]),
"The first dimension of label length must be equal to the first dimension ",
"of the logits, the logit length and labels. Got: ",
label_length_pshape[0],
" and: ",
batch_size);
} else {
batch_size = label_length_pshape[0];
is_batch_size_set = true;
}
}
NODE_VALIDATION_CHECK(
op,
label_length_pshape.rank().is_dynamic() || DimType::merge(batch_size, batch_size, label_length_pshape[0]),
"The first dimension of label length must be equal to the first dimension ",
"of the logits, the logit length and labels. Got: ",
label_length_pshape[0],
" and: ",
batch_size);
auto& output_shape = output_shapes[0];
output_shape.resize(1);
return {TShape{batch_size}};
}
if (is_batch_size_set) {
output_shape[0] = batch_size;
} else {
output_shape[0] = Dimension::dynamic();
}
template <class TShape>
void shape_infer(const CTCLoss* op, const std::vector<TShape>& input_shapes, std::vector<TShape>& output_shapes) {
output_shapes = shape_infer(op, input_shapes);
}
} // namespace v4
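The batch-size propagation above is driven by DimType::merge (ov::Dimension::merge when the shape type is PartialShape). Below is a small, self-contained sketch of that merge behaviour, with interval values borrowed from the partly_dynamic_input test further down; it is an illustration under those assumptions, not part of the commit.

#include <iostream>

#include "openvino/core/dimension.hpp"

int main() {
    ov::Dimension batch(2, 20);                  // batch dimension of logits: interval [2, 20]
    const ov::Dimension logit_len_batch(5, 10);  // batch dimension of logit length: [5, 10]

    // merge() intersects the two dimensions and writes the result into the first
    // argument; it returns false when they are incompatible, which shape_infer
    // turns into a NODE_VALIDATION_CHECK failure.
    if (ov::Dimension::merge(batch, batch, logit_len_batch)) {
        std::cout << batch << '\n';  // the merged interval, i.e. [5, 10]
    }
    return 0;
}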


@@ -2,15 +2,15 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/ctc_loss.hpp"
#include <ctc_loss_shape_inference.hpp>
#include "openvino/op/ctc_loss.hpp"
#include "ctc_loss_shape_inference.hpp"
#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
using namespace std;
using namespace ngraph;
namespace ov {
op::v4::CTCLoss::CTCLoss(const Output<Node>& logits,
const Output<Node>& logit_length,
const Output<Node>& labels,
@@ -44,55 +44,26 @@ void op::v4::CTCLoss::validate_and_infer_types() {
OV_OP_SCOPE(v4_CTCLoss_validate_and_infer_types);
// check types of input tensors
const auto& logits_type = get_input_element_type(0);
const auto& logit_length_type = get_input_element_type(1);
const auto& labels_type = get_input_element_type(2);
const auto& label_length_type = get_input_element_type(3);
NODE_VALIDATION_CHECK(this,
logits_type.is_real(),
"The data type for logits is expected to be a floating point type. Got: ",
"The data type for ",
ctc_loss::shape_names[0],
" is expected to be a floating point type. Got: ",
logits_type);
NODE_VALIDATION_CHECK(this,
logit_length_type.is_integral_number(),
"The logit length type is expected to be an integer type. Got: ",
logit_length_type);
NODE_VALIDATION_CHECK(this,
labels_type.is_integral_number(),
"The labels type is expected to be an integer type. Got: ",
labels_type);
NODE_VALIDATION_CHECK(this,
label_length_type.is_integral_number(),
"The label length type is expected to be an integer type. Got: ",
label_length_type);
// check optional input type: blank index
if (get_input_size() == 5) {
const auto& blank_index_type = get_input_element_type(4);
for (size_t i = 1; i < get_input_size(); ++i) {
const auto& input_et = get_input_element_type(i);
NODE_VALIDATION_CHECK(this,
blank_index_type.is_integral_number(),
"The blank index type is expected to be an integer type. Got: ",
blank_index_type);
input_et.is_integral_number(),
"The ",
ctc_loss::shape_names[i],
" type is expected to be an integer type. Got: ",
input_et);
}
const auto& logits_pshape = get_input_partial_shape(0);
const auto& logit_length_pshape = get_input_partial_shape(1);
const auto& labels_pshape = get_input_partial_shape(2);
const auto& label_length_pshape = get_input_partial_shape(3);
std::vector<ov::PartialShape> input_shapes;
if (get_input_size() == 5) {
const auto& blank_index_pshape = get_input_partial_shape(4);
input_shapes = {logits_pshape, logit_length_pshape, labels_pshape, label_length_pshape, blank_index_pshape};
} else {
input_shapes = {logits_pshape, logit_length_pshape, labels_pshape, label_length_pshape};
}
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
shape_infer(this, input_shapes, output_shapes);
set_output_type(0, logits_type, output_shapes[0]);
const auto output_shape = shape_infer(this, ov::get_node_input_partial_shapes(*this)).front();
set_output_type(0, logits_type, output_shape);
}
bool op::v4::CTCLoss::visit_attributes(AttributeVisitor& visitor) {
@@ -124,6 +95,7 @@ shared_ptr<Node> op::v4::CTCLoss::clone_with_new_inputs(const OutputVector& new_
ctc_merge_repeated_,
unique_);
} else {
throw ngraph_error("Incorrect number of arguments");
throw ov::Exception("Incorrect number of arguments");
}
}
} // namespace ov


@@ -2,256 +2,247 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "common_test_utils/test_assertions.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "openvino/opsets/opset10.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
using namespace ov;
using namespace ov::opset10;
using namespace testing;
TEST(type_prop, ctc_loss) {
class TypePropCTCLossV4Test : public TypePropOpTest<op::v4::CTCLoss> {};
TEST_F(TypePropCTCLossV4Test, with_blank_index) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
auto logits_shape = PartialShape{10, 120, 28};
set_shape_labels(logits_shape, 10);
auto logits = make_shared<Parameter>(element::f32, logits_shape);
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
// create CTCLoss node
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
auto ctc_loss = make_op(logits, logit_length, labels, label_length, blank_index);
// check type and shape infer
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape({10}));
EXPECT_THAT(get_shape_labels(ctc_loss->get_output_partial_shape(0)), ElementsAre(10));
}
TEST(type_prop, ctc_loss_no_blank_index) {
TEST_F(TypePropCTCLossV4Test, no_blank_index) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
auto labels_shape = PartialShape{10, 120};
set_shape_labels(labels_shape, 20);
auto labels = make_shared<Parameter>(element::i32, labels_shape);
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
// create CTCLoss node
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length);
auto ctc_loss = make_op(logits, logit_length, labels, label_length);
// check type and shape infer
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape({10}));
EXPECT_THAT(get_shape_labels(ctc_loss->get_output_partial_shape(0)), ElementsAre(20));
}
TEST(type_prop, ctc_loss_output_type) {
TEST_F(TypePropCTCLossV4Test, output_type_f64) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f64, Shape{10, 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
auto logits = make_shared<Parameter>(element::f64, Shape{10, 120, 28});
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
auto label_len_shape = PartialShape{10};
set_shape_labels(label_len_shape, 30);
auto label_length = make_shared<Parameter>(element::i32, label_len_shape);
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
// create CTCLoss node
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
auto ctc_loss = make_op(logits, logit_length, labels, label_length, blank_index);
// check type and shape infer
EXPECT_EQ(ctc_loss->get_element_type(), element::f64);
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape({10}));
EXPECT_THAT(get_shape_labels(ctc_loss->get_output_partial_shape(0)), ElementsAre(30));
}
TEST(type_prop, ctc_loss_non_default_parameters) {
TEST_F(TypePropCTCLossV4Test, non_default_parameters) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f64, Shape{10, 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
auto logits = make_shared<Parameter>(element::f64, Shape{10, 120, 28});
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
// create CTCLoss node
auto ctc_loss =
make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index, true, false, false);
auto ctc_loss = make_op(logits, logit_length, labels, label_length, blank_index, true, false, false);
// check type and shape infer
EXPECT_EQ(ctc_loss->get_element_type(), element::f64);
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape({10}));
}
TEST(type_prop, ctc_loss_dynamic_input) {
TEST_F(TypePropCTCLossV4Test, dynamic_input) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto labels = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic(), 120});
auto label_length = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
auto logits = make_shared<Parameter>(element::f32, PartialShape{Dimension::dynamic(), 120, 28});
auto logit_length = make_shared<Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto labels = make_shared<Parameter>(element::i32, PartialShape{Dimension::dynamic(), 120});
auto label_length = make_shared<Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
// create CTCLoss node
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
auto ctc_loss = make_op(logits, logit_length, labels, label_length, blank_index);
// check type and shape infer
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic()}));
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape::dynamic(1));
}
TEST(type_prop, ctc_loss_partly_dynamic_input) {
TEST_F(TypePropCTCLossV4Test, partly_dynamic_input) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic(), 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, PartialShape{10});
auto labels = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic(), 120});
auto label_length = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
auto logits_shape = PartialShape{{2, 20}, {100, 130}, 28};
auto logits_len_shape = PartialShape{{5, 10}};
auto labels_shape = PartialShape{-1, 120};
set_shape_labels(logits_shape, 10);
set_shape_labels(logits_len_shape, 20);
set_shape_labels(labels_shape, 30);
auto logits = make_shared<Parameter>(element::f32, logits_shape);
auto logit_length = make_shared<Parameter>(element::i32, logits_len_shape);
auto labels = make_shared<Parameter>(element::i32, labels_shape);
auto label_length = make_shared<Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
// create CTCLoss node
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
auto ctc_loss = make_op(logits, logit_length, labels, label_length, blank_index);
// check type and shape infer
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape({{5, 10}}));
EXPECT_THAT(get_shape_labels(ctc_loss->get_output_partial_shape(0)), ElementsAre(30));
}
TEST(type_prop, ctc_loss_fail_inputs_dim) {
TEST_F(TypePropCTCLossV4Test, fail_inputs_dim) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 40, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 40, 28});
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
try {
// create CTCLoss node
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid inputs not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 3D tensor for logits."));
} catch (...) {
FAIL() << "Inputs shape check failed for unexpected reason";
}
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
NodeValidationFailure,
HasSubstr("Expected a 3D tensor for logits."));
}
TEST(type_prop, ctc_loss_fail_logit_length_dim) {
TEST_F(TypePropCTCLossV4Test, fail_logit_length_dim) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10, 20});
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<Parameter>(element::i32, Shape{10, 20});
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
try {
// create CTCLoss node
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid logit length not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 1D tensor for logit length."));
} catch (...) {
FAIL() << "Logit length shape check failed for unexpected reason";
}
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
NodeValidationFailure,
HasSubstr("Expected a 1D tensor for logit length."));
}
TEST(type_prop, ctc_loss_fail_labels_dim) {
TEST_F(TypePropCTCLossV4Test, fail_labels_dim) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto labels = make_shared<op::Parameter>(element::i32, Shape{10});
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
auto labels = make_shared<Parameter>(element::i32, Shape{10});
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
try {
// create CTCLoss node
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid labels not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 2D tensor for labels."));
} catch (...) {
FAIL() << "Labels shape check failed for unexpected reason";
}
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
NodeValidationFailure,
HasSubstr("Expected a 2D tensor for labels."));
}
TEST(type_prop, ctc_loss_fail_label_length_dim) {
TEST_F(TypePropCTCLossV4Test, fail_label_length_dim) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10, 40});
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<Parameter>(element::i32, Shape{10, 40});
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
try {
// create CTCLoss node
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid labels not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 1D tensor for label length."));
} catch (...) {
FAIL() << "Label length shape check failed for unexpected reason";
}
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
NodeValidationFailure,
HasSubstr("Expected a 1D tensor for label length."));
}
TEST(type_prop, ctc_loss_fail_blank_index_dim) {
TEST_F(TypePropCTCLossV4Test, fail_blank_index_dim) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{4});
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<Parameter>(element::i32, Shape{10});
auto blank_index = make_shared<Parameter>(element::i32, Shape{4});
try {
// create CTCLoss node
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
// Should have thrown, so fail if it didn't
FAIL() << "Invalid labels not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a scalar for blank index."));
} catch (...) {
FAIL() << "Blank index shape check failed for unexpected reason";
}
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
NodeValidationFailure,
HasSubstr("Expected a scalar for blank index."));
}
TEST(type_prop, ctc_loss_fail_batch_dim_mismatch) {
TEST_F(TypePropCTCLossV4Test, fail_batch_dim_mismatch) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<op::Parameter>(element::i32, Shape{40});
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
auto labels = make_shared<Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<Parameter>(element::i32, Shape{40});
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
try {
// create CTCLoss node
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of batch dimension not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(),
std::string("The first dimension of label length must be equal to the first dimension "
"of the logits, the logit length and labels."));
} catch (...) {
FAIL() << "Batch dimension matching check failed for unexpected reason";
}
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
NodeValidationFailure,
HasSubstr("The first dimension of label length must be equal to the first dimension of the logits, "
"the logit length and labels."));
}
TEST(type_prop, ctc_loss_fail_time_dim_mismatch) {
TEST_F(TypePropCTCLossV4Test, fail_time_dim_mismatch) {
// create inputs
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 130});
auto label_length = make_shared<op::Parameter>(element::i32, Shape{40});
auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
auto logits = make_shared<Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<Parameter>(element::i32, Shape{10});
auto labels = make_shared<Parameter>(element::i32, Shape{10, 130});
auto label_length = make_shared<Parameter>(element::i32, Shape{40});
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
try {
// create CTCLoss node
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
// Should have thrown, so fail if it didn't
FAIL() << "Mismatch of time dimension not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(),
std::string("The second dimension of labels must be equal to the second dimension "
"of logits."));
} catch (...) {
FAIL() << "Time dimension matching check failed for unexpected reason";
}
OV_EXPECT_THROW(auto op = make_op(logits, logit_length, labels, label_length, blank_index),
NodeValidationFailure,
HasSubstr("The second dimension of labels must be equal to the second dimension of logits."));
}
TEST_F(TypePropCTCLossV4Test, default_ctor) {
// create inputs
auto logits_shape = PartialShape{{2, 20}, {100, 130}, 28};
auto logits_len_shape = PartialShape{{5, 10}};
auto labels_shape = PartialShape{-1, 120};
set_shape_labels(logits_shape, 10);
set_shape_labels(logits_len_shape, 20);
set_shape_labels(labels_shape, 30);
auto logits = make_shared<Parameter>(element::f32, logits_shape);
auto logit_length = make_shared<Parameter>(element::i32, logits_len_shape);
auto labels = make_shared<Parameter>(element::i32, labels_shape);
auto label_length = make_shared<Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto blank_index = make_shared<Parameter>(element::i32, Shape{});
// create CTCLoss node
auto ctc_loss = make_op();
ctc_loss->set_arguments(OutputVector{logits, logit_length, labels, label_length, blank_index});
ctc_loss->validate_and_infer_types();
// check type and shape infer
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
EXPECT_EQ(ctc_loss->get_output_partial_shape(0), PartialShape({{5, 10}}));
EXPECT_THAT(get_shape_labels(ctc_loss->get_output_partial_shape(0)), ElementsAre(30));
}


@@ -1,35 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <ctc_loss_shape_inference.hpp>
#include <openvino/op/ctc_loss.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
using namespace ov::intel_cpu;
TEST(StaticShapeInferenceTest, CTCLossTest) {
const auto& logits = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
const auto& logit_length = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
const auto& labels = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1, -1});
const auto& label_length = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
const auto& blank_index = std::make_shared<ov::op::v0::Parameter>(element::i32, ov::Shape{});
// create CTCLoss node
auto ctc_loss = std::make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
std::vector<StaticShape> static_input_shapes = {StaticShape{10, 120, 28},
StaticShape{10},
StaticShape{10, 120},
StaticShape{10},
ov::Shape{}},
static_output_shapes = {StaticShape{}};
shape_inference(ctc_loss.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({10}));
}


@@ -0,0 +1,46 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "ctc_loss_shape_inference.hpp"
#include "openvino/opsets/opset10.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::intel_cpu;
class CTCLossV4StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v4::CTCLoss> {
protected:
void SetUp() override {
output_shapes.resize(1);
}
};
TEST_F(CTCLossV4StaticShapeInferenceTest, correct_input_shapes) {
const auto& logits = std::make_shared<Parameter>(element::f32, PartialShape{-1, -1, -1});
const auto& logit_length = std::make_shared<Parameter>(element::i32, PartialShape{-1});
const auto& labels = std::make_shared<Parameter>(element::i32, PartialShape{-1, -1});
const auto& label_length = std::make_shared<Parameter>(element::i32, PartialShape{-1});
const auto& blank_index = std::make_shared<Parameter>(element::i32, ov::Shape{});
auto op = make_op(logits, logit_length, labels, label_length, blank_index);
input_shapes = ShapeVector{{10, 120, 28}, {10}, {10, 120}, {10}, {}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes.front(), StaticShape({10}));
}
TEST_F(CTCLossV4StaticShapeInferenceTest, default_ctor) {
auto op = make_op();
input_shapes = ShapeVector{{12, 120, 28}, {12}, {12, 120}, {12}, {}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes.size(), 1);
EXPECT_EQ(output_shapes.front(), StaticShape({12}));
}