[ShapeInference] CTCGreedyDecoder shape infer improvements (#15474)

* Add setter for ctc_merge_repeated

* shape infer improvements

* Add type prop tests

* Add cpu shape infer tests

* Tests refactor
This commit is contained in:
Katarzyna Mitrus 2023-02-08 08:41:14 +01:00 committed by GitHub
parent 18905ada20
commit d94dae79d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 211 additions and 81 deletions

View File

@ -32,6 +32,10 @@ public:
return m_ctc_merge_repeated;
}
void set_ctc_merge_repeated(bool ctc_merge_repeated) {
m_ctc_merge_repeated = ctc_merge_repeated;
}
private:
bool m_ctc_merge_repeated{true};
};

View File

@ -8,26 +8,27 @@ namespace ov {
namespace op {
namespace v0 {
template <class T>
void shape_infer(const CTCGreedyDecoder* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1);
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
// output dynamic rank tensor if all inputs are of dynamic rank
template <class TShape>
std::vector<TShape> shape_infer(const CTCGreedyDecoder* op, const std::vector<TShape>& input_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2);
using DimType = typename TShape::value_type;
// Output shape rank is always static and equal to 4
// The last two output shape dimensions are always static and equal to 1
std::vector<DimType> output_dims(4);
output_dims[2] = 1;
output_dims[3] = 1;
const auto& logits_pshape = input_shapes[0];
const auto& seq_mask_pshape = input_shapes[1];
auto& output_shape = output_shapes[0];
output_shape.resize(4);
output_shape[2] = 1;
output_shape[3] = 1;
if (logits_pshape.rank().is_dynamic() && seq_mask_pshape.rank().is_dynamic()) {
return;
return {TShape(std::move(output_dims))};
}
// validate input shapes and compute output shape
auto& batch_size = output_shape[0];
auto& time_size = output_shape[1];
// check ranks of input tensors
auto& batch_size = output_dims[0];
auto& time_size = output_dims[1];
if (logits_pshape.rank().is_static()) {
NODE_VALIDATION_CHECK(op, logits_pshape.rank().compatible(3), "The rank of logits tensor must be equal to 3.");
time_size = logits_pshape[0];
@ -45,6 +46,14 @@ void shape_infer(const CTCGreedyDecoder* op, const std::vector<T>& input_shapes,
DimType::merge(batch_size, batch_size, seq_mask_pshape[1]),
"The second dimensions of input tensors must match.");
}
return {TShape(std::move(output_dims))};
}
/// @brief Out-parameter overload of the CTCGreedyDecoder shape inference.
///
/// Delegates to the value-returning `shape_infer(op, input_shapes)` overload and
/// replaces the whole content of `output_shapes` with the shapes it returns.
///
/// @param op             Operator instance forwarded to the delegate.
/// @param input_shapes   Input shapes forwarded to the delegate.
/// @param output_shapes  Overwritten with the inferred output shapes.
template <class TShape>
void shape_infer(const CTCGreedyDecoder* op,
                 const std::vector<TShape>& input_shapes,
                 std::vector<TShape>& output_shapes) {
    auto inferred_shapes = shape_infer(op, input_shapes);
    output_shapes = std::move(inferred_shapes);
}
} // namespace v0
} // namespace op

View File

@ -2,87 +2,165 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "dimension_tracker.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
using namespace testing;
// A default-constructed CTCGreedyDecoder is configured through set_arguments()
// and set_ctc_merge_repeated(), then validated; the inferred element type and
// output shape are checked afterwards.
TEST(type_prop, ctc_greedy_decoder_default_constructor) {
    PartialShape logits_pshape{100, 3, 1200};
    PartialShape mask_pshape{100, 3};
    PartialShape expected_pshape{3, 100, 1, 1};

    auto logits = make_shared<op::Parameter>(element::f32, logits_pshape);
    auto mask = make_shared<op::Parameter>(element::f32, mask_pshape);

    auto decoder = make_shared<op::CTCGreedyDecoder>();
    decoder->set_arguments(OutputVector{logits, mask});

    // The attribute must round-trip through its setter/getter pair.
    decoder->set_ctc_merge_repeated(false);
    EXPECT_FALSE(decoder->get_ctc_merge_repeated());
    decoder->set_ctc_merge_repeated(true);
    EXPECT_TRUE(decoder->get_ctc_merge_repeated());

    decoder->validate_and_infer_types();

    EXPECT_EQ(decoder->get_element_type(), element::f32);
    EXPECT_EQ(decoder->get_output_partial_shape(0), expected_pshape);
}
// Fully static inputs: logits {100, 3, 1200} and sequence mask {100, 3} are
// expected to produce an f32 output with the static shape {3, 100, 1, 1}.
// NOTE(review): this span comes from a rendered diff; the P/I/G statements and
// the data/seq_mask/op statements are presumably the removed and added versions
// of the same checks — confirm against the actual file.
TEST(type_prop, ctc_greedy_decoder_static_shapes) {
PartialShape logits_shape{100, 3, 1200};
PartialShape seq_mask_shape{100, 3};
Shape out_shape{3, 100, 1, 1};
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(op->get_shape(), out_shape);
}
// Interval dimensions with labels set on both inputs (starting values 10 for
// the logits and 20 for the mask). The output is expected to carry labels 21
// and 20 on its first two dimensions and no label on the trailing ones.
TEST(type_prop, ctc_greedy_decoder_interval_labeled_dims_all) {
    PartialShape logits_pshape{{1, 100}, {2, 6}, {600, 1200}};
    PartialShape mask_pshape{{10, 1000}, {4, 8}};
    PartialShape expected_pshape{{4, 6}, {10, 100}, 1, 1};

    set_shape_labels(logits_pshape, 10);
    set_shape_labels(mask_pshape, 20);

    auto logits = make_shared<op::Parameter>(element::f32, logits_pshape);
    auto mask = make_shared<op::Parameter>(element::f32, mask_pshape);
    auto decoder = make_shared<op::CTCGreedyDecoder>(logits, mask, false);

    const auto& result_pshape = decoder->get_output_partial_shape(0);

    EXPECT_EQ(decoder->get_element_type(), element::f32);
    EXPECT_EQ(result_pshape, expected_pshape);
    EXPECT_THAT(get_shape_labels(result_pshape), ElementsAre(21, 20, ov::no_label, ov::no_label));
}
// Interval dimensions with labels set only on the logits input (starting
// value 10). The output is expected to carry labels 11 and 10 on its first two
// dimensions and no label on the trailing ones.
TEST(type_prop, ctc_greedy_decoder_interval_labeled_dims_data) {
    PartialShape logits_pshape{{1, 100}, {2, 6}, {600, 1200}};
    PartialShape mask_pshape{{10, 1000}, {4, 8}};
    PartialShape expected_pshape{{4, 6}, {10, 100}, 1, 1};

    set_shape_labels(logits_pshape, 10);

    auto logits = make_shared<op::Parameter>(element::f32, logits_pshape);
    auto mask = make_shared<op::Parameter>(element::f32, mask_pshape);
    auto decoder = make_shared<op::CTCGreedyDecoder>(logits, mask, false);

    const auto& result_pshape = decoder->get_output_partial_shape(0);

    EXPECT_EQ(decoder->get_element_type(), element::f32);
    EXPECT_EQ(result_pshape, expected_pshape);
    EXPECT_THAT(get_shape_labels(result_pshape), ElementsAre(11, 10, ov::no_label, ov::no_label));
}
// Interval dimensions with labels set only on the sequence mask input
// (starting value 20). The output is expected to carry labels 21 and 20 on its
// first two dimensions and no label on the trailing ones.
TEST(type_prop, ctc_greedy_decoder_interval_labeled_dims_mask) {
    PartialShape logits_pshape{{1, 100}, {2, 6}, {600, 1200}};
    PartialShape mask_pshape{{10, 1000}, {4, 8}};
    PartialShape expected_pshape{{4, 6}, {10, 100}, 1, 1};

    set_shape_labels(mask_pshape, 20);

    auto logits = make_shared<op::Parameter>(element::f32, logits_pshape);
    auto mask = make_shared<op::Parameter>(element::f32, mask_pshape);
    auto decoder = make_shared<op::CTCGreedyDecoder>(logits, mask, false);

    const auto& result_pshape = decoder->get_output_partial_shape(0);

    EXPECT_EQ(decoder->get_output_element_type(0), element::f32);
    EXPECT_EQ(result_pshape, expected_pshape);
    EXPECT_THAT(get_shape_labels(result_pshape), ElementsAre(21, 20, ov::no_label, ov::no_label));
}
// Logits {dynamic, 3, 1200} combined with a static mask {100, 3}: the output
// is still expected to be the fully static shape {3, 100, 1, 1}.
// NOTE(review): this span comes from a rendered diff; the P/I/G statements and
// the data/seq_mask/op statements are presumably the removed and added versions
// of the same checks — confirm against the actual file.
TEST(type_prop, ctc_greedy_decoder_output_static_shape1) {
PartialShape logits_shape{Dimension::dynamic(), 3, 1200};
PartialShape seq_mask_shape{100, 3};
Shape out_shape{3, 100, 1, 1};
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_shape(), out_shape);
}
// Logits {dynamic, 3, 1200} and mask {100, dynamic}: each input supplies one of
// the two leading output dimensions, so the expected output is the static
// shape {3, 100, 1, 1}.
// NOTE(review): this span comes from a rendered diff; the P/I/G statements and
// the data/seq_mask/op statements are presumably the removed and added versions
// of the same checks — confirm against the actual file.
TEST(type_prop, ctc_greedy_decoder_output_static_shape2) {
PartialShape logits_shape{Dimension::dynamic(), 3, 1200};
PartialShape seq_mask_shape{100, Dimension::dynamic()};
Shape out_shape{3, 100, 1, 1};
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_EQ(G->get_shape(), out_shape);
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
EXPECT_EQ(op->get_element_type(), element::f32);
EXPECT_EQ(op->get_shape(), out_shape);
}
// Static-rank inputs whose leading dimensions are all dynamic: the output is
// expected to match the scheme {dynamic, dynamic, 1, 1}.
// NOTE(review): this span comes from a rendered diff; the P/I/G statements and
// the data/seq_mask/op statements are presumably the removed and added versions
// of the same checks — confirm against the actual file.
TEST(type_prop, ctc_greedy_decoder_dynamic_shapes) {
PartialShape logits_shape{Dimension::dynamic(), Dimension::dynamic(), 1200};
PartialShape seq_mask_shape{Dimension::dynamic(), Dimension::dynamic()};
PartialShape out_shape{Dimension::dynamic(), Dimension::dynamic(), 1, 1};
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape));
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape));
}
// Logits of dynamic rank with a mask {100, dynamic}: the output is expected to
// match the scheme {dynamic, 100, 1, 1}.
// NOTE(review): this span comes from a rendered diff; the P/I/G statements and
// the data/seq_mask/op statements are presumably the removed and added versions
// of the same checks — confirm against the actual file.
TEST(type_prop, ctc_greedy_decoder_dynamic_ranks1) {
PartialShape logits_shape = PartialShape::dynamic();
PartialShape seq_mask_shape{100, Dimension::dynamic()};
PartialShape out_shape{Dimension::dynamic(), 100, 1, 1};
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape));
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape));
}
// Both inputs have dynamic rank: the output is still expected to have rank 4
// and match the scheme {dynamic, dynamic, 1, 1}.
// NOTE(review): this span comes from a rendered diff; the P/I/G statements and
// the data/seq_mask/op statements are presumably the removed and added versions
// of the same checks — confirm against the actual file.
TEST(type_prop, ctc_greedy_decoder_dynamic_ranks2) {
PartialShape logits_shape = PartialShape::dynamic();
PartialShape seq_mask_shape = PartialShape::dynamic();
PartialShape out_shape{Dimension::dynamic(), Dimension::dynamic(), 1, 1};
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
ASSERT_EQ(G->get_element_type(), element::f32);
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape));
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
EXPECT_EQ(op->get_element_type(), element::f32);
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(out_shape));
}
TEST(type_prop, ctc_greedy_decoder_incorrect_rank) {
PartialShape logits_shape{Dimension::dynamic(), 3, 1200, 5};
PartialShape seq_mask_shape{100, 3};
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
try {
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) {
@ -95,11 +173,11 @@ TEST(type_prop, ctc_greedy_decoder_incorrect_rank) {
TEST(type_prop, ctc_greedy_decoder_incorrect_rank2) {
PartialShape logits_shape{Dimension::dynamic(), 3, 1200};
PartialShape seq_mask_shape{100, 3, 2};
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
try {
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) {
@ -112,11 +190,11 @@ TEST(type_prop, ctc_greedy_decoder_incorrect_rank2) {
TEST(type_prop, ctc_greedy_decoder_mismatched_dim1) {
PartialShape logits_shape{100, 4, 1200};
PartialShape seq_mask_shape{100, 3};
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
try {
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) {
@ -129,11 +207,11 @@ TEST(type_prop, ctc_greedy_decoder_mismatched_dim1) {
TEST(type_prop, ctc_greedy_decoder_mismatched_dim2) {
PartialShape logits_shape{101, 3, 1200};
PartialShape seq_mask_shape{100, 3};
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
auto I = make_shared<op::Parameter>(element::f32, seq_mask_shape);
auto data = make_shared<op::Parameter>(element::f32, logits_shape);
auto seq_mask = make_shared<op::Parameter>(element::f32, seq_mask_shape);
try {
auto G = make_shared<op::CTCGreedyDecoder>(P, I, false);
auto op = make_shared<op::CTCGreedyDecoder>(data, seq_mask, false);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect indices rank";
} catch (const NodeValidationFailure& error) {

View File

@ -1,25 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <ctc_greedy_decoder_shape_inference.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
using namespace ov::intel_cpu;
// Builds a CTCGreedyDecoder over fully dynamic parameters, then runs static
// shape inference with logits {100, 3, 1200} and mask {100, 3}; the single
// output is expected to be {3, 100, 1, 1}.
TEST(StaticShapeInferenceTest, CtcGreedyDecoderTest) {
    auto logits = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
    auto seq_mask = std::make_shared<ov::op::v0::Parameter>(element::i32, PartialShape{-1, -1});
    auto decoder = std::make_shared<op::v0::CTCGreedyDecoder>(logits, seq_mask, false);

    // Test StaticShape
    std::vector<StaticShape> input_shapes = {StaticShape{100, 3, 1200}, StaticShape{100, 3}};
    std::vector<StaticShape> output_shapes = {StaticShape{}};

    shape_inference(decoder.get(), input_shapes, output_shapes);

    ASSERT_EQ(output_shapes[0], StaticShape({3, 100, 1, 1}));
}

View File

@ -0,0 +1,64 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ctc_greedy_decoder_shape_inference.hpp"
#include <gtest/gtest.h>
#include "common_test_utils/test_assertions.hpp"
#include "openvino/op/ops.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace testing;
// Test fixture for CTCGreedyDecoder (v0) static shape inference cases.
class CTCGreedyDecoderV0StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v0::CTCGreedyDecoder> {
private:
    // Sizes the output shape container to a single entry as part of fixture set-up.
    void SetUp() override {
        this->output_shapes.resize(1);
    }
};
// Operator built over dynamic-dimension parameters; static inference with
// logits {100, 3, 1200} and mask {100, 3} must give {3, 100, 1, 1}.
TEST_F(CTCGreedyDecoderV0StaticShapeInferenceTest, basic) {
    auto logits_param = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
    auto mask_param = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1});
    op = make_op(logits_param, mask_param, false);

    input_shapes = {StaticShape{100, 3, 1200}, StaticShape{100, 3}};
    shape_inference(op.get(), input_shapes, output_shapes);

    const auto expected_shape = StaticShape({3, 100, 1, 1});
    EXPECT_EQ(output_shapes[0], expected_shape);
}
// Operator created without inputs; shape_infer() is called directly with
// static input shapes and must give {3, 100, 1, 1}.
TEST_F(CTCGreedyDecoderV0StaticShapeInferenceTest, decoder_default_ctor) {
    op = make_op();

    input_shapes = {StaticShape{100, 3, 1200}, StaticShape{100, 3}};
    shape_infer(op.get(), input_shapes, output_shapes);

    const auto expected_shape = StaticShape({3, 100, 1, 1});
    EXPECT_EQ(output_shapes[0], expected_shape);
}
// Logits {10, 3, 1200} against mask {100, 3}: the first dimensions differ, so
// inference must throw NodeValidationFailure with the matching message.
TEST_F(CTCGreedyDecoderV0StaticShapeInferenceTest, incompatible_batch) {
    auto logits_param = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
    auto mask_param = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
    op = make_op(logits_param, mask_param, false);

    input_shapes = {StaticShape{10, 3, 1200}, StaticShape{100, 3}};

    OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
                    NodeValidationFailure,
                    HasSubstr("The first dimensions of input tensors must match"))
}
// Logits {100, 3, 1200} against mask {100, 5}: the second dimensions differ, so
// inference must throw NodeValidationFailure with the matching message.
TEST_F(CTCGreedyDecoderV0StaticShapeInferenceTest, incompatible_t_dim) {
    auto logits_param = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
    auto mask_param = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
    op = make_op(logits_param, mask_param, false);

    input_shapes = {StaticShape{100, 3, 1200}, StaticShape{100, 5}};

    OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
                    NodeValidationFailure,
                    HasSubstr("The second dimensions of input tensors must match"))
}