[ShapeInference] CTCGreedyDecoderSeqLen shape infer improvements (#15501)

* Shape infer improvements

* Add setter for merge repeated attribute

* Use new shape_infer in validate and infer types

* Add more type prop tests

* Add shape infer tests

* Align variable names in tests

* shape_infer call refactor

---------

Co-authored-by: Evgenya Stepyreva <evgenya.stepyreva@intel.com>
This commit is contained in:
Katarzyna Mitrus 2023-02-08 13:24:48 +01:00 committed by GitHub
parent c3083589bd
commit 76817f56c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 275 additions and 125 deletions

View File

@ -58,6 +58,14 @@ public:
bool get_merge_repeated() const { bool get_merge_repeated() const {
return m_merge_repeated; return m_merge_repeated;
} }
/// \brief Set merge_repeated attribute
///
/// \param merge_repeated A new value for the attribute
///
void set_merge_repeated(bool merge_repeated) {
m_merge_repeated = merge_repeated;
}
/// \brief Get classes_index_type attribute /// \brief Get classes_index_type attribute
/// ///
/// \return Current value of classes_index_type attribute /// \return Current value of classes_index_type attribute

View File

@ -2,55 +2,53 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#pragma once #pragma once
#include <openvino/op/ctc_greedy_decoder_seq_len.hpp> #include "openvino/op/ctc_greedy_decoder_seq_len.hpp"
namespace ov { namespace ov {
namespace op { namespace op {
namespace v6 { namespace v6 {
template <class T> template <class TShape>
void shape_infer(const CTCGreedyDecoderSeqLen* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) { std::vector<TShape> shape_infer(const CTCGreedyDecoderSeqLen* op, const std::vector<TShape>& input_shapes) {
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2 || input_shapes.size() == 3) && output_shapes.size() == 2); NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 || input_shapes.size() == 3);
using DimType = typename std::iterator_traits<typename T::iterator>::value_type; using DimType = typename TShape::value_type;
const auto& logits_shape = input_shapes[0]; const auto& logits_shape = input_shapes[0];
const auto& seq_len_shape = input_shapes[1]; const auto& seq_len_shape = input_shapes[1];
const bool logits_is_static_rank = logits_shape.rank().is_static();
const bool seq_len_is_static_rank = seq_len_shape.rank().is_static();
auto& decoded_shape = output_shapes[0];
auto& seq_shape = output_shapes[1];
decoded_shape.resize(2);
seq_shape.resize(1);
if (input_shapes.size() == 3) {
const auto& blank_shape = input_shapes[2];
const auto& blank_rank = blank_shape.rank();
if (blank_shape.is_static()) {
const auto blank_is_scalar = blank_rank.get_length() == 0;
const auto blank_has_one_elem = blank_rank.get_length() == 1 && blank_shape[0].get_length() == 1;
NODE_VALIDATION_CHECK(op,
blank_is_scalar || blank_has_one_elem,
"Expected 0D or 1D tensor for the 'blank_index' input. Got: ",
blank_shape);
}
}
auto& batch_size = decoded_shape[0];
auto& time_size = decoded_shape[1];
// check ranks of input tensors if (input_shapes.size() == 3 && input_shapes[2].is_static()) {
if (logits_is_static_rank) { const auto& blank_shape = input_shapes[2];
NODE_VALIDATION_CHECK(op, logits_shape.rank().compatible(3), "The rank of logits tensor must be equal to 3."); const auto blank_is_scalar = blank_shape.size() == 0;
const auto blank_has_one_elem = blank_shape.size() == 1 && blank_shape[0].get_length() == 1;
NODE_VALIDATION_CHECK(op,
blank_is_scalar || blank_has_one_elem,
"Expected 0D or 1D tensor for the 'blank_index' input. Got: ",
blank_shape);
}
DimType batch_size{};
DimType time_size{};
if (logits_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(op, logits_shape.size() == 3, "The rank of logits tensor must be equal to 3.");
batch_size = logits_shape[0]; batch_size = logits_shape[0];
time_size = logits_shape[1]; time_size = logits_shape[1];
} }
if (seq_len_is_static_rank) { if (seq_len_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(op, NODE_VALIDATION_CHECK(op, seq_len_shape.size() == 1, "The rank of sequence len tensor must be equal to 1.");
seq_len_shape.rank().compatible(1),
"The rank of sequence len tensor must be equal to 1.");
NODE_VALIDATION_CHECK(op, NODE_VALIDATION_CHECK(op,
DimType::merge(batch_size, batch_size, seq_len_shape[0]), DimType::merge(batch_size, batch_size, seq_len_shape[0]),
"The first dimensions of input tensors must match."); "The first dimensions of input tensors must match.");
} }
seq_shape[0] = batch_size; return {TShape{batch_size, time_size}, TShape{batch_size}};
}
template <class TShape>
void shape_infer(const CTCGreedyDecoderSeqLen* op,
const std::vector<TShape>& input_shapes,
std::vector<TShape>& output_shapes) {
output_shapes = shape_infer(op, input_shapes);
} }
} // namespace v6 } // namespace v6
} // namespace op } // namespace op

View File

@ -51,8 +51,8 @@ void op::v6::CTCGreedyDecoderSeqLen::validate_and_infer_types() {
input_shapes.push_back(get_input_partial_shape(2)); input_shapes.push_back(get_input_partial_shape(2));
} }
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}, ov::PartialShape{}}; const auto output_shapes = shape_infer(this, input_shapes);
shape_infer(this, input_shapes, output_shapes);
set_output_type(0, m_classes_index_type, output_shapes[0]); set_output_type(0, m_classes_index_type, output_shapes[0]);
set_output_type(1, m_sequence_length_type, output_shapes[1]); set_output_type(1, m_sequence_length_type, output_shapes[1]);
} }

View File

@ -2,25 +2,57 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "dimension_tracker.hpp"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "openvino/op/ops.hpp"
#include "util/type_prop.hpp" #include "util/type_prop.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ov;
using namespace testing;
TEST(type_prop, ctc_greedy_decoder_seq_len_default_ctor) {
PartialShape logits_shape{2, 10, 1200};
PartialShape seq_len_shape{2};
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>();
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto seq_len_param = make_shared<op::v0::Parameter>(element::f32, seq_len_shape);
op->set_arguments(OutputVector{logits_param, seq_len_param});
op->set_merge_repeated(false);
EXPECT_EQ(op->get_merge_repeated(), false);
op->set_merge_repeated(true);
EXPECT_EQ(op->get_merge_repeated(), true);
op->validate_and_infer_types();
EXPECT_EQ(op->get_output_element_type(0), element::i32);
EXPECT_EQ(op->get_output_element_type(1), element::i32);
op->set_classes_index_type(element::i64);
EXPECT_EQ(op->get_output_element_type(0), element::i64);
op->set_sequence_length_type(element::i64);
EXPECT_EQ(op->get_output_element_type(1), element::i64);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 10}));
EXPECT_EQ(op->get_output_partial_shape(1), (PartialShape{2}));
}
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes) { TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes) {
PartialShape logits_shape{3, 100, 1200}; PartialShape logits_shape{3, 100, 1200};
PartialShape seq_len_shape{3}; PartialShape seq_len_shape{3};
Shape out_shape1{3, 100}; Shape out_shape1{3, 100};
Shape out_shape2{3}; Shape out_shape2{3};
auto P = make_shared<op::Parameter>(element::f32, logits_shape); auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape); auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I); auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param);
ASSERT_EQ(G->get_output_element_type(0), element::i32); EXPECT_EQ(op->get_output_element_type(0), element::i32);
ASSERT_EQ(G->get_output_element_type(1), element::i32); EXPECT_EQ(op->get_output_element_type(1), element::i32);
ASSERT_EQ(G->get_output_shape(0), out_shape1); EXPECT_EQ(op->get_output_shape(0), out_shape1);
ASSERT_EQ(G->get_output_shape(1), out_shape2); EXPECT_EQ(op->get_output_shape(1), out_shape2);
} }
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_bi) { TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_bi) {
@ -28,14 +60,15 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_bi) {
PartialShape seq_len_shape{3}; PartialShape seq_len_shape{3};
Shape out_shape1{3, 100}; Shape out_shape1{3, 100};
Shape out_shape2{3}; Shape out_shape2{3};
auto P = make_shared<op::Parameter>(element::f32, logits_shape); auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape); auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
auto BI = op::Constant::create(element::i32, Shape{}, {1}); auto bi = op::v0::Constant::create(element::i32, Shape{}, {1});
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, BI, false, element::i64, element::i64); auto op =
ASSERT_EQ(G->get_output_element_type(0), element::i64); make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, bi, false, element::i64, element::i64);
ASSERT_EQ(G->get_output_element_type(1), element::i64); EXPECT_EQ(op->get_output_element_type(0), element::i64);
ASSERT_EQ(G->get_output_shape(0), out_shape1); EXPECT_EQ(op->get_output_element_type(1), element::i64);
ASSERT_EQ(G->get_output_shape(1), out_shape2); EXPECT_EQ(op->get_output_shape(0), out_shape1);
EXPECT_EQ(op->get_output_shape(1), out_shape2);
} }
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_dinemic_bi) { TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_dinemic_bi) {
@ -43,14 +76,15 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_dinemic_bi) {
PartialShape seq_len_shape{3}; PartialShape seq_len_shape{3};
Shape out_shape1{3, 100}; Shape out_shape1{3, 100};
Shape out_shape2{3}; Shape out_shape2{3};
auto P = make_shared<op::Parameter>(element::f32, logits_shape); auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape); auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
auto BI = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()}); auto bi = make_shared<op::v0::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, BI, false, element::i64, element::i64); auto op =
ASSERT_EQ(G->get_output_element_type(0), element::i64); make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, bi, false, element::i64, element::i64);
ASSERT_EQ(G->get_output_element_type(1), element::i64); EXPECT_EQ(op->get_output_element_type(0), element::i64);
ASSERT_EQ(G->get_output_shape(0), out_shape1); EXPECT_EQ(op->get_output_element_type(1), element::i64);
ASSERT_EQ(G->get_output_shape(1), out_shape2); EXPECT_EQ(op->get_output_shape(0), out_shape1);
EXPECT_EQ(op->get_output_shape(1), out_shape2);
} }
TEST(type_prop, ctc_greedy_decoder_seq_len_output_static_shape1) { TEST(type_prop, ctc_greedy_decoder_seq_len_output_static_shape1) {
@ -58,13 +92,13 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_output_static_shape1) {
PartialShape seq_len_shape{3}; PartialShape seq_len_shape{3};
Shape out_shape1{3, 100}; Shape out_shape1{3, 100};
Shape out_shape2{3}; Shape out_shape2{3};
auto P = make_shared<op::Parameter>(element::f32, logits_shape); auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape); auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false); auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, false);
ASSERT_EQ(G->get_output_element_type(0), element::i32); EXPECT_EQ(op->get_output_element_type(0), element::i32);
ASSERT_EQ(G->get_output_element_type(1), element::i32); EXPECT_EQ(op->get_output_element_type(1), element::i32);
ASSERT_EQ(G->get_output_shape(0), out_shape1); EXPECT_EQ(op->get_output_shape(0), out_shape1);
ASSERT_EQ(G->get_output_shape(1), out_shape2); EXPECT_EQ(op->get_output_shape(1), out_shape2);
} }
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_shapes) { TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_shapes) {
@ -72,13 +106,13 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_shapes) {
PartialShape seq_len_shape{Dimension::dynamic()}; PartialShape seq_len_shape{Dimension::dynamic()};
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()}; PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
PartialShape out_shape2{Dimension::dynamic()}; PartialShape out_shape2{Dimension::dynamic()};
auto P = make_shared<op::Parameter>(element::f32, logits_shape); auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape); auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false); auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, false);
ASSERT_EQ(G->get_output_element_type(0), element::i32); EXPECT_EQ(op->get_output_element_type(0), element::i32);
ASSERT_EQ(G->get_output_element_type(1), element::i32); EXPECT_EQ(op->get_output_element_type(1), element::i32);
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1)); EXPECT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape1));
ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2)); EXPECT_TRUE(op->get_output_partial_shape(1).same_scheme(out_shape2));
} }
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks1) { TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks1) {
@ -86,37 +120,101 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks1) {
PartialShape seq_len_shape{Dimension::dynamic()}; PartialShape seq_len_shape{Dimension::dynamic()};
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()}; PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
PartialShape out_shape2{Dimension::dynamic()}; PartialShape out_shape2{Dimension::dynamic()};
auto P = make_shared<op::Parameter>(element::f32, logits_shape); auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape); auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I); auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param);
ASSERT_EQ(G->get_output_element_type(0), element::i32); EXPECT_EQ(op->get_output_element_type(0), element::i32);
ASSERT_EQ(G->get_output_element_type(1), element::i32); EXPECT_EQ(op->get_output_element_type(1), element::i32);
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1)); EXPECT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape1));
ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2)); EXPECT_TRUE(op->get_output_partial_shape(1).same_scheme(out_shape2));
} }
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks2) { TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks2) {
PartialShape logits_shape = PartialShape::dynamic(); PartialShape logits_shape = PartialShape::dynamic();
PartialShape seq_mask_shape = PartialShape::dynamic(); PartialShape seq_len_shape = PartialShape::dynamic();
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()}; PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
PartialShape out_shape2{Dimension::dynamic()}; PartialShape out_shape2{Dimension::dynamic()};
auto P = make_shared<op::Parameter>(element::f32, logits_shape); auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::i32, seq_mask_shape); auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false); auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, false);
ASSERT_EQ(G->get_output_element_type(0), element::i32); EXPECT_EQ(op->get_output_element_type(0), element::i32);
ASSERT_EQ(G->get_output_element_type(1), element::i32); EXPECT_EQ(op->get_output_element_type(1), element::i32);
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1)); EXPECT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape1));
ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2)); EXPECT_TRUE(op->get_output_partial_shape(1).same_scheme(out_shape2));
}
TEST(type_prop, ctc_greedy_decoder_seq_len_interval_labeled_dims_all) {
PartialShape logits_shape{{2, 6}, {10, 100}, {600, 1200}};
PartialShape seq_len_shape{{4, 8}};
set_shape_labels(logits_shape, 10);
set_shape_labels(seq_len_shape, 20);
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto seq_len_param = make_shared<op::v0::Parameter>(element::f32, seq_len_shape);
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param);
// Output 0
EXPECT_EQ(op->get_output_element_type(0), element::i32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{4, 6}, {10, 100}}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(20, 11));
// Output 1
EXPECT_EQ(op->get_output_element_type(1), element::i32);
EXPECT_EQ(op->get_output_partial_shape(1), (PartialShape{{4, 6}}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(1)), ElementsAre(20));
}
TEST(type_prop, ctc_greedy_decoder_seq_len_interval_labeled_dims_in0) {
PartialShape logits_shape{{2, 6}, {10, 100}, {600, 1200}};
PartialShape seq_len_shape{{4, 8}};
set_shape_labels(logits_shape, 10);
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto seq_len_param = make_shared<op::v0::Parameter>(element::f32, seq_len_shape);
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param);
// Output 0
EXPECT_EQ(op->get_output_element_type(0), element::i32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{4, 6}, {10, 100}}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 11));
// Output 1
EXPECT_EQ(op->get_output_element_type(1), element::i32);
EXPECT_EQ(op->get_output_partial_shape(1), (PartialShape{{4, 6}}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(1)), ElementsAre(10));
}
TEST(type_prop, ctc_greedy_decoder_seq_len_interval_labeled_dims_in1) {
PartialShape logits_shape{{2, 6}, {10, 100}, {600, 1200}};
PartialShape seq_len_shape{{4, 8}};
set_shape_labels(seq_len_shape, 20);
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto seq_len_param = make_shared<op::v0::Parameter>(element::f32, seq_len_shape);
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param);
// Output 0
EXPECT_EQ(op->get_output_element_type(0), element::i32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{4, 6}, {10, 100}}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(20, ov::no_label));
// Output 1
EXPECT_EQ(op->get_output_element_type(1), element::i32);
EXPECT_EQ(op->get_output_partial_shape(1), (PartialShape{{4, 6}}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(1)), ElementsAre(20));
} }
TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank) { TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank) {
PartialShape logits_shape{Dimension::dynamic(), 100, 1200, 5}; PartialShape logits_shape{Dimension::dynamic(), 100, 1200, 5};
PartialShape seq_len_shape{3}; PartialShape seq_len_shape{3};
auto P = make_shared<op::Parameter>(element::f32, logits_shape); auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape); auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
try { try {
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false); auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, false);
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank"; FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) { } catch (const NodeValidationFailure& error) {
@ -129,11 +227,11 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank) {
TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank2) { TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank2) {
PartialShape logits_shape{3, 100, 1200}; PartialShape logits_shape{3, 100, 1200};
PartialShape seq_len_shape{3, 100}; PartialShape seq_len_shape{3, 100};
auto P = make_shared<op::Parameter>(element::f32, logits_shape); auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape); auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
try { try {
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false); auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, false);
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank"; FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) { } catch (const NodeValidationFailure& error) {
@ -145,12 +243,12 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank2) {
TEST(type_prop, ctc_greedy_decoder_seq_len_mismatched_dim1) { TEST(type_prop, ctc_greedy_decoder_seq_len_mismatched_dim1) {
PartialShape logits_shape{4, 100, 1200}; PartialShape logits_shape{4, 100, 1200};
PartialShape seq_mask_shape{3}; PartialShape seq_len_shape{3};
auto P = make_shared<op::Parameter>(element::f32, logits_shape); auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::i32, seq_mask_shape); auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
try { try {
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false); auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, false);
// Should have thrown, so fail if it didn't // Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank"; FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) { } catch (const NodeValidationFailure& error) {

View File

@ -1,26 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <ctc_greedy_decoder_seq_len_shape_inference.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
using namespace ov::intel_cpu;
TEST(StaticShapeInferenceTest, CtcGreedyDecoderSeqLenTest) {
auto P = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto I = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
auto G = std::make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I);
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 100, 1200}, StaticShape{3}},
static_output_shapes = {StaticShape{}, StaticShape{}};
shape_inference(G.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 100}));
ASSERT_EQ(static_output_shapes[1], StaticShape({3}));
}

View File

@ -0,0 +1,72 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ctc_greedy_decoder_seq_len_shape_inference.hpp"
#include <gtest/gtest.h>
#include "common_test_utils/test_assertions.hpp"
#include "openvino/op/ops.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace testing;
class CTCGreedyDecoderSeqLenV6StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v6::CTCGreedyDecoderSeqLen> {
void SetUp() override {
output_shapes.resize(1);
}
};
TEST_F(CTCGreedyDecoderSeqLenV6StaticShapeInferenceTest, basic) {
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto seq_mask = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
op = make_op(data, seq_mask, false);
input_shapes = {StaticShape{4, 100, 1200}, StaticShape{4}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], StaticShape({4, 100}));
EXPECT_EQ(output_shapes[1], StaticShape({4}));
}
TEST_F(CTCGreedyDecoderSeqLenV6StaticShapeInferenceTest, default_ctor) {
op = make_op();
// Two inputs
input_shapes = {StaticShape{4, 100, 1200}, StaticShape{4}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], StaticShape({4, 100}));
EXPECT_EQ(output_shapes[1], StaticShape({4}));
// Three inputs (the last one is optional)
input_shapes = {StaticShape{4, 100, 1200}, StaticShape{4}, {}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], StaticShape({4, 100}));
EXPECT_EQ(output_shapes[1], StaticShape({4}));
}
TEST_F(CTCGreedyDecoderSeqLenV6StaticShapeInferenceTest, incompatible_batch) {
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
auto seq_mask = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
op = make_op(data, seq_mask, false);
input_shapes = {StaticShape{4, 100, 1200}, StaticShape{6}};
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
NodeValidationFailure,
HasSubstr("The first dimensions of input tensors must match"))
}
TEST_F(CTCGreedyDecoderSeqLenV6StaticShapeInferenceTest, incompatible_seq_len_rank) {
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
auto seq_mask = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
op = make_op(data, seq_mask, false);
input_shapes = {StaticShape{4, 100, 1200}, StaticShape{4, 1}};
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
NodeValidationFailure,
HasSubstr("The rank of sequence len tensor must be equal to 1"))
}