diff --git a/src/core/include/openvino/op/ctc_greedy_decoder.hpp b/src/core/include/openvino/op/ctc_greedy_decoder.hpp index b0b2fdd3e26..eff8ac82418 100644 --- a/src/core/include/openvino/op/ctc_greedy_decoder.hpp +++ b/src/core/include/openvino/op/ctc_greedy_decoder.hpp @@ -32,6 +32,10 @@ public: return m_ctc_merge_repeated; } + void set_ctc_merge_repeated(bool ctc_merge_repeated) { + m_ctc_merge_repeated = ctc_merge_repeated; + } + private: bool m_ctc_merge_repeated{true}; }; diff --git a/src/core/shape_inference/include/ctc_greedy_decoder_shape_inference.hpp b/src/core/shape_inference/include/ctc_greedy_decoder_shape_inference.hpp index 61eda198d3b..81e4b1cf678 100644 --- a/src/core/shape_inference/include/ctc_greedy_decoder_shape_inference.hpp +++ b/src/core/shape_inference/include/ctc_greedy_decoder_shape_inference.hpp @@ -8,26 +8,27 @@ namespace ov { namespace op { namespace v0 { -template -void shape_infer(const CTCGreedyDecoder* op, const std::vector& input_shapes, std::vector& output_shapes) { - NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1); - using DimType = typename std::iterator_traits::value_type; - // output dynamic rank tensor if all inputs are of dynamic rank +template +std::vector shape_infer(const CTCGreedyDecoder* op, const std::vector& input_shapes) { + NODE_VALIDATION_CHECK(op, input_shapes.size() == 2); + using DimType = typename TShape::value_type; + + // Output shape rank is always static and equal to 4 + // The last two output shape dimensions are always static and equal to 1 + std::vector output_dims(4); + output_dims[2] = 1; + output_dims[3] = 1; + const auto& logits_pshape = input_shapes[0]; const auto& seq_mask_pshape = input_shapes[1]; - auto& output_shape = output_shapes[0]; - output_shape.resize(4); - output_shape[2] = 1; - output_shape[3] = 1; if (logits_pshape.rank().is_dynamic() && seq_mask_pshape.rank().is_dynamic()) { - return; + return {TShape(std::move(output_dims))}; } - // validate input shapes and compute output shape - auto& batch_size = output_shape[0]; - auto& time_size = output_shape[1]; - // check ranks of input tensors + auto& batch_size = output_dims[0]; + auto& time_size = output_dims[1]; + if (logits_pshape.rank().is_static()) { NODE_VALIDATION_CHECK(op, logits_pshape.rank().compatible(3), "The rank of logits tensor must be equal to 3."); time_size = logits_pshape[0]; @@ -45,6 +46,14 @@ void shape_infer(const CTCGreedyDecoder* op, const std::vector& input_shapes, DimType::merge(batch_size, batch_size, seq_mask_pshape[1]), "The second dimensions of input tensors must match."); } + return {TShape(std::move(output_dims))}; +} + +template +void shape_infer(const CTCGreedyDecoder* op, + const std::vector& input_shapes, + std::vector& output_shapes) { + output_shapes = shape_infer(op, input_shapes); } } // namespace v0 } // namespace op diff --git a/src/core/tests/type_prop/ctc_greedy_decoder.cpp b/src/core/tests/type_prop/ctc_greedy_decoder.cpp index 8cc8c981114..ad8054f10d6 100644 --- a/src/core/tests/type_prop/ctc_greedy_decoder.cpp +++ b/src/core/tests/type_prop/ctc_greedy_decoder.cpp @@ -2,87 +2,165 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "dimension_tracker.hpp" #include "gtest/gtest.h" #include "ngraph/ngraph.hpp" #include "util/type_prop.hpp" using namespace std; using namespace ngraph; +using namespace testing; + +TEST(type_prop, ctc_greedy_decoder_default_constructor) { + PartialShape data_shape{100, 3, 1200}; + PartialShape seq_mask_shape{100, 3}; + PartialShape expected_shape{3, 100, 1, 1}; + + auto op = make_shared(); + + auto data = make_shared(element::f32, data_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); + op->set_arguments(OutputVector{data, seq_mask}); + + op->set_ctc_merge_repeated(false); + EXPECT_EQ(op->get_ctc_merge_repeated(), false); + + op->set_ctc_merge_repeated(true); + EXPECT_EQ(op->get_ctc_merge_repeated(), true); + + op->validate_and_infer_types(); + + EXPECT_EQ(op->get_element_type(), element::f32); + EXPECT_EQ(op->get_output_partial_shape(0), expected_shape); +} TEST(type_prop, ctc_greedy_decoder_static_shapes) { PartialShape logits_shape{100, 3, 1200}; PartialShape seq_mask_shape{100, 3}; Shape out_shape{3, 100, 1, 1}; - auto P = make_shared(element::f32, logits_shape); - auto I = make_shared(element::f32, seq_mask_shape); - auto G = make_shared(P, I, false); - ASSERT_EQ(G->get_element_type(), element::f32); - ASSERT_EQ(G->get_shape(), out_shape); + auto data = make_shared(element::f32, logits_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); + auto op = make_shared(data, seq_mask, false); + EXPECT_EQ(op->get_element_type(), element::f32); + EXPECT_EQ(op->get_shape(), out_shape); +} + +TEST(type_prop, ctc_greedy_decoder_interval_labeled_dims_all) { + PartialShape data_shape{{1, 100}, {2, 6}, {600, 1200}}; + PartialShape seq_mask_shape{{10, 1000}, {4, 8}}; + PartialShape expected_shape{{4, 6}, {10, 100}, 1, 1}; + + set_shape_labels(data_shape, 10); + set_shape_labels(seq_mask_shape, 20); + + auto data = make_shared(element::f32, data_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); + auto op = make_shared(data, seq_mask, false); + + const auto& out_shape = op->get_output_partial_shape(0); + EXPECT_EQ(op->get_element_type(), element::f32); + EXPECT_EQ(out_shape, expected_shape); + EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(21, 20, ov::no_label, ov::no_label)); +} + +TEST(type_prop, ctc_greedy_decoder_interval_labeled_dims_data) { + PartialShape data_shape{{1, 100}, {2, 6}, {600, 1200}}; + PartialShape seq_mask_shape{{10, 1000}, {4, 8}}; + PartialShape expected_shape{{4, 6}, {10, 100}, 1, 1}; + + set_shape_labels(data_shape, 10); + + auto data = make_shared(element::f32, data_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); + auto op = make_shared(data, seq_mask, false); + + const auto& out_shape = op->get_output_partial_shape(0); + EXPECT_EQ(op->get_element_type(), element::f32); + EXPECT_EQ(out_shape, expected_shape); + EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(11, 10, ov::no_label, ov::no_label)); +} + +TEST(type_prop, ctc_greedy_decoder_interval_labeled_dims_mask) { + PartialShape data_shape{{1, 100}, {2, 6}, {600, 1200}}; + PartialShape seq_mask_shape{{10, 1000}, {4, 8}}; + PartialShape expected_shape{{4, 6}, {10, 100}, 1, 1}; + + set_shape_labels(seq_mask_shape, 20); + + auto data = make_shared(element::f32, data_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); + auto op = make_shared(data, seq_mask, false); + + const auto& out_shape = op->get_output_partial_shape(0); + EXPECT_EQ(op->get_output_element_type(0), element::f32); + EXPECT_EQ(out_shape, expected_shape); + EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(21, 20, ov::no_label, ov::no_label)); } TEST(type_prop, ctc_greedy_decoder_output_static_shape1) { PartialShape logits_shape{Dimension::dynamic(), 3, 1200}; PartialShape seq_mask_shape{100, 3}; Shape out_shape{3, 100, 1, 1}; - auto P = make_shared(element::f32, logits_shape); - auto I = make_shared(element::f32, seq_mask_shape); - auto G = make_shared(P, I, false); - ASSERT_EQ(G->get_element_type(), element::f32); - ASSERT_EQ(G->get_shape(), out_shape); + auto data = make_shared(element::f32, logits_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); + auto op = make_shared(data, seq_mask, false); + + EXPECT_EQ(op->get_output_element_type(0), element::f32); + EXPECT_EQ(op->get_shape(), out_shape); } TEST(type_prop, ctc_greedy_decoder_output_static_shape2) { PartialShape logits_shape{Dimension::dynamic(), 3, 1200}; PartialShape seq_mask_shape{100, Dimension::dynamic()}; Shape out_shape{3, 100, 1, 1}; - auto P = make_shared(element::f32, logits_shape); - auto I = make_shared(element::f32, seq_mask_shape); - auto G = make_shared(P, I, false); - ASSERT_EQ(G->get_element_type(), element::f32); - ASSERT_EQ(G->get_shape(), out_shape); + auto data = make_shared(element::f32, logits_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); + auto op = make_shared(data, seq_mask, false); + EXPECT_EQ(op->get_element_type(), element::f32); + EXPECT_EQ(op->get_shape(), out_shape); } TEST(type_prop, ctc_greedy_decoder_dynamic_shapes) { PartialShape logits_shape{Dimension::dynamic(), Dimension::dynamic(), 1200}; PartialShape seq_mask_shape{Dimension::dynamic(), Dimension::dynamic()}; PartialShape out_shape{Dimension::dynamic(), Dimension::dynamic(), 1, 1}; - auto P = make_shared(element::f32, logits_shape); - auto I = make_shared(element::f32, seq_mask_shape); - auto G = make_shared(P, I, false); - ASSERT_EQ(G->get_element_type(), element::f32); - ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape)); + auto data = make_shared(element::f32, logits_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); + auto op = make_shared(data, seq_mask, false); + EXPECT_EQ(op->get_element_type(), element::f32); + ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape)); } TEST(type_prop, ctc_greedy_decoder_dynamic_ranks1) { PartialShape logits_shape = PartialShape::dynamic(); PartialShape seq_mask_shape{100, Dimension::dynamic()}; PartialShape out_shape{Dimension::dynamic(), 100, 1, 1}; - auto P = make_shared(element::f32, logits_shape); - auto I = make_shared(element::f32, seq_mask_shape); - auto G = make_shared(P, I, false); - ASSERT_EQ(G->get_element_type(), element::f32); - ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape)); + auto data = make_shared(element::f32, logits_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); + auto op = make_shared(data, seq_mask, false); + EXPECT_EQ(op->get_element_type(), element::f32); + ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape)); } TEST(type_prop, ctc_greedy_decoder_dynamic_ranks2) { PartialShape logits_shape = PartialShape::dynamic(); PartialShape seq_mask_shape = PartialShape::dynamic(); PartialShape out_shape{Dimension::dynamic(), Dimension::dynamic(), 1, 1}; - auto P = make_shared(element::f32, logits_shape); - auto I = make_shared(element::f32, seq_mask_shape); - auto G = make_shared(P, I, false); - ASSERT_EQ(G->get_element_type(), element::f32); - ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape)); + auto data = make_shared(element::f32, logits_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); + auto op = make_shared(data, seq_mask, false); + EXPECT_EQ(op->get_element_type(), element::f32); + ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape)); } TEST(type_prop, ctc_greedy_decoder_incorrect_rank) { PartialShape logits_shape{Dimension::dynamic(), 3, 1200, 5}; PartialShape seq_mask_shape{100, 3}; - auto P = make_shared(element::f32, logits_shape); - auto I = make_shared(element::f32, seq_mask_shape); + auto data = make_shared(element::f32, logits_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); try { - auto G = make_shared(P, I, false); + auto op = make_shared(data, seq_mask, false); // Should have thrown, so fail if it didn't FAIL() << "Incorrect indices rank"; } catch (const NodeValidationFailure& error) { @@ -95,11 +173,11 @@ TEST(type_prop, ctc_greedy_decoder_incorrect_rank) { TEST(type_prop, ctc_greedy_decoder_incorrect_rank2) { PartialShape logits_shape{Dimension::dynamic(), 3, 1200}; PartialShape seq_mask_shape{100, 3, 2}; - auto P = make_shared(element::f32, logits_shape); - auto I = make_shared(element::f32, seq_mask_shape); + auto data = make_shared(element::f32, logits_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); try { - auto G = make_shared(P, I, false); + auto op = make_shared(data, seq_mask, false); // Should have thrown, so fail if it didn't FAIL() << "Incorrect indices rank"; } catch (const NodeValidationFailure& error) { @@ -112,11 +190,11 @@ TEST(type_prop, ctc_greedy_decoder_incorrect_rank2) { TEST(type_prop, ctc_greedy_decoder_mismatched_dim1) { PartialShape logits_shape{100, 4, 1200}; PartialShape seq_mask_shape{100, 3}; - auto P = make_shared(element::f32, logits_shape); - auto I = make_shared(element::f32, seq_mask_shape); + auto data = make_shared(element::f32, logits_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); try { - auto G = make_shared(P, I, false); + auto op = make_shared(data, seq_mask, false); // Should have thrown, so fail if it didn't FAIL() << "Incorrect indices rank"; } catch (const NodeValidationFailure& error) { @@ -129,11 +207,11 @@ TEST(type_prop, ctc_greedy_decoder_mismatched_dim1) { TEST(type_prop, ctc_greedy_decoder_mismatched_dim2) { PartialShape logits_shape{101, 3, 1200}; PartialShape seq_mask_shape{100, 3}; - auto P = make_shared(element::f32, logits_shape); - auto I = make_shared(element::f32, seq_mask_shape); + auto data = make_shared(element::f32, logits_shape); + auto seq_mask = make_shared(element::f32, seq_mask_shape); try { - auto G = make_shared(P, I, false); + auto op = make_shared(data, seq_mask, false); // Should have thrown, so fail if it didn't FAIL() << "Incorrect indices rank"; } catch (const NodeValidationFailure& error) { diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/ctc_greedy_decoder_shape_inference.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/ctc_greedy_decoder_shape_inference.cpp deleted file mode 100644 index 3427de780cb..00000000000 --- a/src/plugins/intel_cpu/tests/unit/shape_inference_test/ctc_greedy_decoder_shape_inference.cpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include - -#include -#include -#include -#include -#include - -using namespace ov; -using namespace ov::intel_cpu; - -TEST(StaticShapeInferenceTest, CtcGreedyDecoderTest) { - auto P = std::make_shared(element::f32, PartialShape{-1, -1, -1}); - auto I = std::make_shared(element::i32, PartialShape{-1, -1}); - auto G = std::make_shared(P, I, false); - // Test StaticShape - std::vector static_input_shapes = {StaticShape{100, 3, 1200}, StaticShape{100, 3}}, - static_output_shapes = {StaticShape{}}; - shape_inference(G.get(), static_input_shapes, static_output_shapes); - ASSERT_EQ(static_output_shapes[0], StaticShape({3, 100, 1, 1})); -} diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/ctc_greedy_decoder_shape_inference_test.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/ctc_greedy_decoder_shape_inference_test.cpp new file mode 100644 index 00000000000..9f63542d383 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/shape_inference_test/ctc_greedy_decoder_shape_inference_test.cpp @@ -0,0 +1,64 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include "ctc_greedy_decoder_shape_inference.hpp" + +#include + +#include "common_test_utils/test_assertions.hpp" +#include "openvino/op/ops.hpp" +#include "utils.hpp" + +using namespace ov; +using namespace ov::intel_cpu; +using namespace testing; + +class CTCGreedyDecoderV0StaticShapeInferenceTest : public OpStaticShapeInferenceTest { + void SetUp() override { + output_shapes.resize(1); + } +}; + +TEST_F(CTCGreedyDecoderV0StaticShapeInferenceTest, basic) { + auto data = std::make_shared(element::f32, PartialShape{-1, -1, -1}); + auto seq_mask = std::make_shared(element::i32, PartialShape{-1, -1}); + op = make_op(data, seq_mask, false); + + input_shapes = {StaticShape{100, 3, 1200}, StaticShape{100, 3}}; + + shape_inference(op.get(), input_shapes, output_shapes); + EXPECT_EQ(output_shapes[0], StaticShape({3, 100, 1, 1})); +} + +TEST_F(CTCGreedyDecoderV0StaticShapeInferenceTest, decoder_default_ctor) { + op = make_op(); + + input_shapes = {StaticShape{100, 3, 1200}, StaticShape{100, 3}}; + + shape_infer(op.get(), input_shapes, output_shapes); + EXPECT_EQ(output_shapes[0], StaticShape({3, 100, 1, 1})); +} + +TEST_F(CTCGreedyDecoderV0StaticShapeInferenceTest, incompatible_batch) { + auto data = std::make_shared(element::f32, PartialShape::dynamic()); + auto seq_mask = std::make_shared(element::i32, PartialShape::dynamic()); + op = make_op(data, seq_mask, false); + + input_shapes = {StaticShape{10, 3, 1200}, StaticShape{100, 3}}; + + OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes), + NodeValidationFailure, + HasSubstr("The first dimensions of input tensors must match")) +} + +TEST_F(CTCGreedyDecoderV0StaticShapeInferenceTest, incompatible_t_dim) { + auto data = std::make_shared(element::f32, PartialShape::dynamic()); + auto seq_mask = std::make_shared(element::i32, PartialShape::dynamic()); + op = make_op(data, seq_mask, false); + + input_shapes = {StaticShape{100, 3, 1200}, StaticShape{100, 5}}; + + OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes), + NodeValidationFailure, + HasSubstr("The second dimensions of input tensors must match")) +}