[ShapeInference] CTCGreedyDecoderSeqLen shape infer improvements (#15501)
* Shape infer improvements * Add setter for merge repeated attribute * Use new shape_infer in validate and infer types * Add more type prop tests * Add shape infer tests * Align variable names in tests * shape_infer call refactor --------- Co-authored-by: Evgenya Stepyreva <evgenya.stepyreva@intel.com>
This commit is contained in:
parent
c3083589bd
commit
76817f56c2
@ -58,6 +58,14 @@ public:
|
|||||||
bool get_merge_repeated() const {
|
bool get_merge_repeated() const {
|
||||||
return m_merge_repeated;
|
return m_merge_repeated;
|
||||||
}
|
}
|
||||||
|
/// \brief Set merge_repeated attribute
|
||||||
|
///
|
||||||
|
/// \param merge_repeated A new value for the attribute
|
||||||
|
///
|
||||||
|
void set_merge_repeated(bool merge_repeated) {
|
||||||
|
m_merge_repeated = merge_repeated;
|
||||||
|
}
|
||||||
|
|
||||||
/// \brief Get classes_index_type attribute
|
/// \brief Get classes_index_type attribute
|
||||||
///
|
///
|
||||||
/// \return Current value of classes_index_type attribute
|
/// \return Current value of classes_index_type attribute
|
||||||
|
@ -2,55 +2,53 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <openvino/op/ctc_greedy_decoder_seq_len.hpp>
|
#include "openvino/op/ctc_greedy_decoder_seq_len.hpp"
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
namespace op {
|
namespace op {
|
||||||
namespace v6 {
|
namespace v6 {
|
||||||
|
|
||||||
template <class T>
|
template <class TShape>
|
||||||
void shape_infer(const CTCGreedyDecoderSeqLen* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
|
std::vector<TShape> shape_infer(const CTCGreedyDecoderSeqLen* op, const std::vector<TShape>& input_shapes) {
|
||||||
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2 || input_shapes.size() == 3) && output_shapes.size() == 2);
|
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 || input_shapes.size() == 3);
|
||||||
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
using DimType = typename TShape::value_type;
|
||||||
|
|
||||||
const auto& logits_shape = input_shapes[0];
|
const auto& logits_shape = input_shapes[0];
|
||||||
const auto& seq_len_shape = input_shapes[1];
|
const auto& seq_len_shape = input_shapes[1];
|
||||||
const bool logits_is_static_rank = logits_shape.rank().is_static();
|
|
||||||
const bool seq_len_is_static_rank = seq_len_shape.rank().is_static();
|
|
||||||
auto& decoded_shape = output_shapes[0];
|
|
||||||
auto& seq_shape = output_shapes[1];
|
|
||||||
decoded_shape.resize(2);
|
|
||||||
seq_shape.resize(1);
|
|
||||||
if (input_shapes.size() == 3) {
|
|
||||||
const auto& blank_shape = input_shapes[2];
|
|
||||||
const auto& blank_rank = blank_shape.rank();
|
|
||||||
if (blank_shape.is_static()) {
|
|
||||||
const auto blank_is_scalar = blank_rank.get_length() == 0;
|
|
||||||
const auto blank_has_one_elem = blank_rank.get_length() == 1 && blank_shape[0].get_length() == 1;
|
|
||||||
NODE_VALIDATION_CHECK(op,
|
|
||||||
blank_is_scalar || blank_has_one_elem,
|
|
||||||
"Expected 0D or 1D tensor for the 'blank_index' input. Got: ",
|
|
||||||
blank_shape);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto& batch_size = decoded_shape[0];
|
|
||||||
auto& time_size = decoded_shape[1];
|
|
||||||
|
|
||||||
// check ranks of input tensors
|
if (input_shapes.size() == 3 && input_shapes[2].is_static()) {
|
||||||
if (logits_is_static_rank) {
|
const auto& blank_shape = input_shapes[2];
|
||||||
NODE_VALIDATION_CHECK(op, logits_shape.rank().compatible(3), "The rank of logits tensor must be equal to 3.");
|
const auto blank_is_scalar = blank_shape.size() == 0;
|
||||||
|
const auto blank_has_one_elem = blank_shape.size() == 1 && blank_shape[0].get_length() == 1;
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
blank_is_scalar || blank_has_one_elem,
|
||||||
|
"Expected 0D or 1D tensor for the 'blank_index' input. Got: ",
|
||||||
|
blank_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
DimType batch_size{};
|
||||||
|
DimType time_size{};
|
||||||
|
|
||||||
|
if (logits_shape.rank().is_static()) {
|
||||||
|
NODE_VALIDATION_CHECK(op, logits_shape.size() == 3, "The rank of logits tensor must be equal to 3.");
|
||||||
batch_size = logits_shape[0];
|
batch_size = logits_shape[0];
|
||||||
time_size = logits_shape[1];
|
time_size = logits_shape[1];
|
||||||
}
|
}
|
||||||
if (seq_len_is_static_rank) {
|
if (seq_len_shape.rank().is_static()) {
|
||||||
NODE_VALIDATION_CHECK(op,
|
NODE_VALIDATION_CHECK(op, seq_len_shape.size() == 1, "The rank of sequence len tensor must be equal to 1.");
|
||||||
seq_len_shape.rank().compatible(1),
|
|
||||||
"The rank of sequence len tensor must be equal to 1.");
|
|
||||||
NODE_VALIDATION_CHECK(op,
|
NODE_VALIDATION_CHECK(op,
|
||||||
DimType::merge(batch_size, batch_size, seq_len_shape[0]),
|
DimType::merge(batch_size, batch_size, seq_len_shape[0]),
|
||||||
"The first dimensions of input tensors must match.");
|
"The first dimensions of input tensors must match.");
|
||||||
}
|
}
|
||||||
|
|
||||||
seq_shape[0] = batch_size;
|
return {TShape{batch_size, time_size}, TShape{batch_size}};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class TShape>
|
||||||
|
void shape_infer(const CTCGreedyDecoderSeqLen* op,
|
||||||
|
const std::vector<TShape>& input_shapes,
|
||||||
|
std::vector<TShape>& output_shapes) {
|
||||||
|
output_shapes = shape_infer(op, input_shapes);
|
||||||
}
|
}
|
||||||
} // namespace v6
|
} // namespace v6
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -51,8 +51,8 @@ void op::v6::CTCGreedyDecoderSeqLen::validate_and_infer_types() {
|
|||||||
input_shapes.push_back(get_input_partial_shape(2));
|
input_shapes.push_back(get_input_partial_shape(2));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}, ov::PartialShape{}};
|
const auto output_shapes = shape_infer(this, input_shapes);
|
||||||
shape_infer(this, input_shapes, output_shapes);
|
|
||||||
set_output_type(0, m_classes_index_type, output_shapes[0]);
|
set_output_type(0, m_classes_index_type, output_shapes[0]);
|
||||||
set_output_type(1, m_sequence_length_type, output_shapes[1]);
|
set_output_type(1, m_sequence_length_type, output_shapes[1]);
|
||||||
}
|
}
|
||||||
|
@ -2,25 +2,57 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include "dimension_tracker.hpp"
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "ngraph/ngraph.hpp"
|
#include "openvino/op/ops.hpp"
|
||||||
#include "util/type_prop.hpp"
|
#include "util/type_prop.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
using namespace ov;
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_default_ctor) {
|
||||||
|
PartialShape logits_shape{2, 10, 1200};
|
||||||
|
PartialShape seq_len_shape{2};
|
||||||
|
|
||||||
|
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>();
|
||||||
|
|
||||||
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::f32, seq_len_shape);
|
||||||
|
op->set_arguments(OutputVector{logits_param, seq_len_param});
|
||||||
|
|
||||||
|
op->set_merge_repeated(false);
|
||||||
|
EXPECT_EQ(op->get_merge_repeated(), false);
|
||||||
|
|
||||||
|
op->set_merge_repeated(true);
|
||||||
|
EXPECT_EQ(op->get_merge_repeated(), true);
|
||||||
|
|
||||||
|
op->validate_and_infer_types();
|
||||||
|
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||||
|
EXPECT_EQ(op->get_output_element_type(1), element::i32);
|
||||||
|
|
||||||
|
op->set_classes_index_type(element::i64);
|
||||||
|
EXPECT_EQ(op->get_output_element_type(0), element::i64);
|
||||||
|
|
||||||
|
op->set_sequence_length_type(element::i64);
|
||||||
|
EXPECT_EQ(op->get_output_element_type(1), element::i64);
|
||||||
|
|
||||||
|
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 10}));
|
||||||
|
EXPECT_EQ(op->get_output_partial_shape(1), (PartialShape{2}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes) {
|
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes) {
|
||||||
PartialShape logits_shape{3, 100, 1200};
|
PartialShape logits_shape{3, 100, 1200};
|
||||||
PartialShape seq_len_shape{3};
|
PartialShape seq_len_shape{3};
|
||||||
Shape out_shape1{3, 100};
|
Shape out_shape1{3, 100};
|
||||||
Shape out_shape2{3};
|
Shape out_shape2{3};
|
||||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
|
||||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I);
|
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param);
|
||||||
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||||
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
EXPECT_EQ(op->get_output_element_type(1), element::i32);
|
||||||
ASSERT_EQ(G->get_output_shape(0), out_shape1);
|
EXPECT_EQ(op->get_output_shape(0), out_shape1);
|
||||||
ASSERT_EQ(G->get_output_shape(1), out_shape2);
|
EXPECT_EQ(op->get_output_shape(1), out_shape2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_bi) {
|
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_bi) {
|
||||||
@ -28,14 +60,15 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_bi) {
|
|||||||
PartialShape seq_len_shape{3};
|
PartialShape seq_len_shape{3};
|
||||||
Shape out_shape1{3, 100};
|
Shape out_shape1{3, 100};
|
||||||
Shape out_shape2{3};
|
Shape out_shape2{3};
|
||||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
|
||||||
auto BI = op::Constant::create(element::i32, Shape{}, {1});
|
auto bi = op::v0::Constant::create(element::i32, Shape{}, {1});
|
||||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, BI, false, element::i64, element::i64);
|
auto op =
|
||||||
ASSERT_EQ(G->get_output_element_type(0), element::i64);
|
make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, bi, false, element::i64, element::i64);
|
||||||
ASSERT_EQ(G->get_output_element_type(1), element::i64);
|
EXPECT_EQ(op->get_output_element_type(0), element::i64);
|
||||||
ASSERT_EQ(G->get_output_shape(0), out_shape1);
|
EXPECT_EQ(op->get_output_element_type(1), element::i64);
|
||||||
ASSERT_EQ(G->get_output_shape(1), out_shape2);
|
EXPECT_EQ(op->get_output_shape(0), out_shape1);
|
||||||
|
EXPECT_EQ(op->get_output_shape(1), out_shape2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_dinemic_bi) {
|
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_dinemic_bi) {
|
||||||
@ -43,14 +76,15 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_dinemic_bi) {
|
|||||||
PartialShape seq_len_shape{3};
|
PartialShape seq_len_shape{3};
|
||||||
Shape out_shape1{3, 100};
|
Shape out_shape1{3, 100};
|
||||||
Shape out_shape2{3};
|
Shape out_shape2{3};
|
||||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
|
||||||
auto BI = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
auto bi = make_shared<op::v0::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, BI, false, element::i64, element::i64);
|
auto op =
|
||||||
ASSERT_EQ(G->get_output_element_type(0), element::i64);
|
make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, bi, false, element::i64, element::i64);
|
||||||
ASSERT_EQ(G->get_output_element_type(1), element::i64);
|
EXPECT_EQ(op->get_output_element_type(0), element::i64);
|
||||||
ASSERT_EQ(G->get_output_shape(0), out_shape1);
|
EXPECT_EQ(op->get_output_element_type(1), element::i64);
|
||||||
ASSERT_EQ(G->get_output_shape(1), out_shape2);
|
EXPECT_EQ(op->get_output_shape(0), out_shape1);
|
||||||
|
EXPECT_EQ(op->get_output_shape(1), out_shape2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, ctc_greedy_decoder_seq_len_output_static_shape1) {
|
TEST(type_prop, ctc_greedy_decoder_seq_len_output_static_shape1) {
|
||||||
@ -58,13 +92,13 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_output_static_shape1) {
|
|||||||
PartialShape seq_len_shape{3};
|
PartialShape seq_len_shape{3};
|
||||||
Shape out_shape1{3, 100};
|
Shape out_shape1{3, 100};
|
||||||
Shape out_shape2{3};
|
Shape out_shape2{3};
|
||||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
|
||||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, false);
|
||||||
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||||
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
EXPECT_EQ(op->get_output_element_type(1), element::i32);
|
||||||
ASSERT_EQ(G->get_output_shape(0), out_shape1);
|
EXPECT_EQ(op->get_output_shape(0), out_shape1);
|
||||||
ASSERT_EQ(G->get_output_shape(1), out_shape2);
|
EXPECT_EQ(op->get_output_shape(1), out_shape2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_shapes) {
|
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_shapes) {
|
||||||
@ -72,13 +106,13 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_shapes) {
|
|||||||
PartialShape seq_len_shape{Dimension::dynamic()};
|
PartialShape seq_len_shape{Dimension::dynamic()};
|
||||||
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
|
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
|
||||||
PartialShape out_shape2{Dimension::dynamic()};
|
PartialShape out_shape2{Dimension::dynamic()};
|
||||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
|
||||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, false);
|
||||||
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||||
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
EXPECT_EQ(op->get_output_element_type(1), element::i32);
|
||||||
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1));
|
EXPECT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape1));
|
||||||
ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2));
|
EXPECT_TRUE(op->get_output_partial_shape(1).same_scheme(out_shape2));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks1) {
|
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks1) {
|
||||||
@ -86,37 +120,101 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks1) {
|
|||||||
PartialShape seq_len_shape{Dimension::dynamic()};
|
PartialShape seq_len_shape{Dimension::dynamic()};
|
||||||
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
|
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
|
||||||
PartialShape out_shape2{Dimension::dynamic()};
|
PartialShape out_shape2{Dimension::dynamic()};
|
||||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
|
||||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I);
|
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param);
|
||||||
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||||
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
EXPECT_EQ(op->get_output_element_type(1), element::i32);
|
||||||
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1));
|
EXPECT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape1));
|
||||||
ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2));
|
EXPECT_TRUE(op->get_output_partial_shape(1).same_scheme(out_shape2));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks2) {
|
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks2) {
|
||||||
PartialShape logits_shape = PartialShape::dynamic();
|
PartialShape logits_shape = PartialShape::dynamic();
|
||||||
PartialShape seq_mask_shape = PartialShape::dynamic();
|
PartialShape seq_len_shape = PartialShape::dynamic();
|
||||||
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
|
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
|
||||||
PartialShape out_shape2{Dimension::dynamic()};
|
PartialShape out_shape2{Dimension::dynamic()};
|
||||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::i32, seq_mask_shape);
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
|
||||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, false);
|
||||||
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||||
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
EXPECT_EQ(op->get_output_element_type(1), element::i32);
|
||||||
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1));
|
EXPECT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape1));
|
||||||
ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2));
|
EXPECT_TRUE(op->get_output_partial_shape(1).same_scheme(out_shape2));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_interval_labeled_dims_all) {
|
||||||
|
PartialShape logits_shape{{2, 6}, {10, 100}, {600, 1200}};
|
||||||
|
PartialShape seq_len_shape{{4, 8}};
|
||||||
|
|
||||||
|
set_shape_labels(logits_shape, 10);
|
||||||
|
set_shape_labels(seq_len_shape, 20);
|
||||||
|
|
||||||
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::f32, seq_len_shape);
|
||||||
|
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param);
|
||||||
|
|
||||||
|
// Output 0
|
||||||
|
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||||
|
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{4, 6}, {10, 100}}));
|
||||||
|
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(20, 11));
|
||||||
|
|
||||||
|
// Output 1
|
||||||
|
EXPECT_EQ(op->get_output_element_type(1), element::i32);
|
||||||
|
EXPECT_EQ(op->get_output_partial_shape(1), (PartialShape{{4, 6}}));
|
||||||
|
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(1)), ElementsAre(20));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_interval_labeled_dims_in0) {
|
||||||
|
PartialShape logits_shape{{2, 6}, {10, 100}, {600, 1200}};
|
||||||
|
PartialShape seq_len_shape{{4, 8}};
|
||||||
|
|
||||||
|
set_shape_labels(logits_shape, 10);
|
||||||
|
|
||||||
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::f32, seq_len_shape);
|
||||||
|
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param);
|
||||||
|
|
||||||
|
// Output 0
|
||||||
|
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||||
|
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{4, 6}, {10, 100}}));
|
||||||
|
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(10, 11));
|
||||||
|
|
||||||
|
// Output 1
|
||||||
|
EXPECT_EQ(op->get_output_element_type(1), element::i32);
|
||||||
|
EXPECT_EQ(op->get_output_partial_shape(1), (PartialShape{{4, 6}}));
|
||||||
|
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(1)), ElementsAre(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_interval_labeled_dims_in1) {
|
||||||
|
PartialShape logits_shape{{2, 6}, {10, 100}, {600, 1200}};
|
||||||
|
PartialShape seq_len_shape{{4, 8}};
|
||||||
|
|
||||||
|
set_shape_labels(seq_len_shape, 20);
|
||||||
|
|
||||||
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::f32, seq_len_shape);
|
||||||
|
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param);
|
||||||
|
|
||||||
|
// Output 0
|
||||||
|
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||||
|
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{4, 6}, {10, 100}}));
|
||||||
|
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(20, ov::no_label));
|
||||||
|
|
||||||
|
// Output 1
|
||||||
|
EXPECT_EQ(op->get_output_element_type(1), element::i32);
|
||||||
|
EXPECT_EQ(op->get_output_partial_shape(1), (PartialShape{{4, 6}}));
|
||||||
|
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(1)), ElementsAre(20));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank) {
|
TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank) {
|
||||||
PartialShape logits_shape{Dimension::dynamic(), 100, 1200, 5};
|
PartialShape logits_shape{Dimension::dynamic(), 100, 1200, 5};
|
||||||
PartialShape seq_len_shape{3};
|
PartialShape seq_len_shape{3};
|
||||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, false);
|
||||||
// Should have thrown, so fail if it didn't
|
// Should have thrown, so fail if it didn't
|
||||||
FAIL() << "Incorrect indices rank";
|
FAIL() << "Incorrect indices rank";
|
||||||
} catch (const NodeValidationFailure& error) {
|
} catch (const NodeValidationFailure& error) {
|
||||||
@ -129,11 +227,11 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank) {
|
|||||||
TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank2) {
|
TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank2) {
|
||||||
PartialShape logits_shape{3, 100, 1200};
|
PartialShape logits_shape{3, 100, 1200};
|
||||||
PartialShape seq_len_shape{3, 100};
|
PartialShape seq_len_shape{3, 100};
|
||||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, false);
|
||||||
// Should have thrown, so fail if it didn't
|
// Should have thrown, so fail if it didn't
|
||||||
FAIL() << "Incorrect indices rank";
|
FAIL() << "Incorrect indices rank";
|
||||||
} catch (const NodeValidationFailure& error) {
|
} catch (const NodeValidationFailure& error) {
|
||||||
@ -145,12 +243,12 @@ TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank2) {
|
|||||||
|
|
||||||
TEST(type_prop, ctc_greedy_decoder_seq_len_mismatched_dim1) {
|
TEST(type_prop, ctc_greedy_decoder_seq_len_mismatched_dim1) {
|
||||||
PartialShape logits_shape{4, 100, 1200};
|
PartialShape logits_shape{4, 100, 1200};
|
||||||
PartialShape seq_mask_shape{3};
|
PartialShape seq_len_shape{3};
|
||||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
auto logits_param = make_shared<op::v0::Parameter>(element::f32, logits_shape);
|
||||||
auto I = make_shared<op::Parameter>(element::i32, seq_mask_shape);
|
auto seq_len_param = make_shared<op::v0::Parameter>(element::i32, seq_len_shape);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
auto op = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits_param, seq_len_param, false);
|
||||||
// Should have thrown, so fail if it didn't
|
// Should have thrown, so fail if it didn't
|
||||||
FAIL() << "Incorrect indices rank";
|
FAIL() << "Incorrect indices rank";
|
||||||
} catch (const NodeValidationFailure& error) {
|
} catch (const NodeValidationFailure& error) {
|
||||||
|
@ -1,26 +0,0 @@
|
|||||||
// Copyright (C) 2018-2023 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
|
|
||||||
#include <ctc_greedy_decoder_seq_len_shape_inference.hpp>
|
|
||||||
#include <openvino/op/ops.hpp>
|
|
||||||
#include <openvino/op/parameter.hpp>
|
|
||||||
#include <utils/shape_inference/shape_inference.hpp>
|
|
||||||
#include <utils/shape_inference/static_shape.hpp>
|
|
||||||
|
|
||||||
using namespace ov;
|
|
||||||
using namespace ov::intel_cpu;
|
|
||||||
|
|
||||||
TEST(StaticShapeInferenceTest, CtcGreedyDecoderSeqLenTest) {
|
|
||||||
auto P = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
|
||||||
auto I = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1});
|
|
||||||
auto G = std::make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I);
|
|
||||||
// Test StaticShape
|
|
||||||
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 100, 1200}, StaticShape{3}},
|
|
||||||
static_output_shapes = {StaticShape{}, StaticShape{}};
|
|
||||||
shape_inference(G.get(), static_input_shapes, static_output_shapes);
|
|
||||||
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 100}));
|
|
||||||
ASSERT_EQ(static_output_shapes[1], StaticShape({3}));
|
|
||||||
}
|
|
@ -0,0 +1,72 @@
|
|||||||
|
// Copyright (C) 2018-2024 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
#include "ctc_greedy_decoder_seq_len_shape_inference.hpp"
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "common_test_utils/test_assertions.hpp"
|
||||||
|
#include "openvino/op/ops.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::intel_cpu;
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
|
class CTCGreedyDecoderSeqLenV6StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v6::CTCGreedyDecoderSeqLen> {
|
||||||
|
void SetUp() override {
|
||||||
|
output_shapes.resize(1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(CTCGreedyDecoderSeqLenV6StaticShapeInferenceTest, basic) {
|
||||||
|
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||||
|
auto seq_mask = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1});
|
||||||
|
op = make_op(data, seq_mask, false);
|
||||||
|
|
||||||
|
input_shapes = {StaticShape{4, 100, 1200}, StaticShape{4}};
|
||||||
|
|
||||||
|
shape_inference(op.get(), input_shapes, output_shapes);
|
||||||
|
EXPECT_EQ(output_shapes[0], StaticShape({4, 100}));
|
||||||
|
EXPECT_EQ(output_shapes[1], StaticShape({4}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CTCGreedyDecoderSeqLenV6StaticShapeInferenceTest, default_ctor) {
|
||||||
|
op = make_op();
|
||||||
|
|
||||||
|
// Two inputs
|
||||||
|
input_shapes = {StaticShape{4, 100, 1200}, StaticShape{4}};
|
||||||
|
shape_inference(op.get(), input_shapes, output_shapes);
|
||||||
|
EXPECT_EQ(output_shapes[0], StaticShape({4, 100}));
|
||||||
|
EXPECT_EQ(output_shapes[1], StaticShape({4}));
|
||||||
|
|
||||||
|
// Three inputs (the last one is optional)
|
||||||
|
input_shapes = {StaticShape{4, 100, 1200}, StaticShape{4}, {}};
|
||||||
|
shape_inference(op.get(), input_shapes, output_shapes);
|
||||||
|
EXPECT_EQ(output_shapes[0], StaticShape({4, 100}));
|
||||||
|
EXPECT_EQ(output_shapes[1], StaticShape({4}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CTCGreedyDecoderSeqLenV6StaticShapeInferenceTest, incompatible_batch) {
|
||||||
|
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||||
|
auto seq_mask = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
|
||||||
|
op = make_op(data, seq_mask, false);
|
||||||
|
|
||||||
|
input_shapes = {StaticShape{4, 100, 1200}, StaticShape{6}};
|
||||||
|
|
||||||
|
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
|
||||||
|
NodeValidationFailure,
|
||||||
|
HasSubstr("The first dimensions of input tensors must match"))
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CTCGreedyDecoderSeqLenV6StaticShapeInferenceTest, incompatible_seq_len_rank) {
|
||||||
|
auto data = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||||
|
auto seq_mask = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
|
||||||
|
op = make_op(data, seq_mask, false);
|
||||||
|
|
||||||
|
input_shapes = {StaticShape{4, 100, 1200}, StaticShape{4, 1}};
|
||||||
|
|
||||||
|
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
|
||||||
|
NodeValidationFailure,
|
||||||
|
HasSubstr("The rank of sequence len tensor must be equal to 1"))
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user