From 5408f611e9130555bc617266b5e77aace9e5f618 Mon Sep 17 00:00:00 2001
From: iliya mironov
Date: Thu, 14 Jan 2021 14:14:32 +0300
Subject: [PATCH] Add ctc greedy decoder seq len op to ngraph (#3499)

* Add ctc greedy decoder seq len op to ngraph

* Remove some comments

* Add second constructor

* Fix code style

* Fix code style

* Add unit tests

* Add tests to cmake

* Fix according to review

* Fix code style

* fix

* Change input layout

* Fix code style

* Add unit tests

* Add 3 input tensor check

* Update shell impl

* Fix code style

* Fix code style

* Add doxygen

* Fix code style

* Update doxygen

* Update constructor description

* Fix code style

* Refactoring code

* fix code style

* Optimize op constructor

* Add macros. Optimize code for validate_and_infer_types

* Refactoring code

* Fix code style

* Fix code style

* Fix check blank_index shape

* Fix code style

* Fix unit test for dynamic case

* Fix code style

* Fix copyright

* reverse changes

* Update copyright
---
 .../ngraph/op/ctc_greedy_decoder_seq_len.hpp  | 116 ++++++++++
 ngraph/core/include/ngraph/ops.hpp            |   1 +
 .../core/include/ngraph/opsets/opset6_tbl.hpp |   1 +
 .../src/op/ctc_greedy_decoder_seq_len.cpp     | 170 +++++++++++++++
 ngraph/test/CMakeLists.txt                    |   1 +
 .../type_prop/ctc_greedy_decoder_seq_len.cpp  | 203 ++++++++++++++++++
 6 files changed, 492 insertions(+)
 create mode 100644 ngraph/core/include/ngraph/op/ctc_greedy_decoder_seq_len.hpp
 create mode 100644 ngraph/core/src/op/ctc_greedy_decoder_seq_len.cpp
 create mode 100644 ngraph/test/type_prop/ctc_greedy_decoder_seq_len.cpp

diff --git a/ngraph/core/include/ngraph/op/ctc_greedy_decoder_seq_len.hpp b/ngraph/core/include/ngraph/op/ctc_greedy_decoder_seq_len.hpp
new file mode 100644
index 00000000000..444c055d07c
--- /dev/null
+++ b/ngraph/core/include/ngraph/op/ctc_greedy_decoder_seq_len.hpp
@@ -0,0 +1,116 @@
+//*****************************************************************************
+// Copyright 2017-2021 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include "ngraph/op/op.hpp"
+
+namespace ngraph
+{
+    namespace op
+    {
+        namespace v6
+        {
+            /// \brief Operator performing CTCGreedyDecoder
+            ///
+            class NGRAPH_API CTCGreedyDecoderSeqLen : public Op
+            {
+            public:
+                NGRAPH_RTTI_DECLARATION;
+                CTCGreedyDecoderSeqLen() = default;
+                /// \brief Constructs a CTCGreedyDecoderSeqLen operation
+                ///
+                /// \param input 3-D tensor of logits on which greedy decoding is
+                ///        performed
+                /// \param seq_len 1-D tensor of sequence lengths
+                /// \param merge_repeated Whether to merge repeated labels
+                /// \param classes_index_type Specifies the output classes_index tensor type
+                /// \param sequence_length_type Specifies the output sequence_length tensor type
+                CTCGreedyDecoderSeqLen(const Output<Node>& input,
+                                       const Output<Node>& seq_len,
+                                       const bool merge_repeated = true,
+                                       const element::Type& classes_index_type = element::i32,
+                                       const element::Type& sequence_length_type = element::i32);
+                /// \brief Constructs a CTCGreedyDecoderSeqLen operation
+                ///
+                /// \param input 3-D tensor of logits on which greedy decoding is
+                ///        performed
+                /// \param seq_len 1-D tensor of sequence lengths
+                /// \param blank_index Scalar or 1-D tensor with 1 element used to mark a
+                ///        blank index
+                /// \param merge_repeated Whether to merge repeated labels
+                /// \param classes_index_type Specifies the output classes_index tensor type
+                /// \param sequence_length_type Specifies the output sequence_length tensor type
+                CTCGreedyDecoderSeqLen(const Output<Node>& input,
+                                       const Output<Node>& seq_len,
+                                       const Output<Node>& blank_index,
+                                       const bool merge_repeated = true,
+                                       const element::Type& classes_index_type = element::i32,
+                                       const element::Type& sequence_length_type = element::i32);
+
+                void validate_and_infer_types() override;
+                bool visit_attributes(AttributeVisitor& visitor) override;
+
+                std::shared_ptr<Node>
+                    clone_with_new_inputs(const OutputVector& new_args) const override;
+
+                /// \brief Get merge_repeated attribute
+                ///
+                /// \return Current value of merge_repeated attribute
+                ///
+                bool get_merge_repeated() const { return m_merge_repeated; }
+                /// \brief Get classes_index_type attribute
+                ///
+                /// \return Current value of classes_index_type attribute
+                ///
+                const element::Type& get_classes_index_type() const { return m_classes_index_type; }
+                /// \brief Set classes_index_type attribute
+                ///
+                /// \param classes_index_type Type of classes_index
+                ///
+                void set_classes_index_type(const element::Type& classes_index_type)
+                {
+                    m_classes_index_type = classes_index_type;
+                    validate_and_infer_types();
+                }
+
+                /// \brief Get sequence_length_type attribute
+                ///
+                /// \return Current value of sequence_length_type attribute
+                ///
+                const element::Type& get_sequence_length_type() const
+                {
+                    return m_sequence_length_type;
+                }
+
+                /// \brief Set sequence_length_type attribute
+                ///
+                /// \param sequence_length_type Type of sequence length
+                ///
+                void set_sequence_length_type(const element::Type& sequence_length_type)
+                {
+                    m_sequence_length_type = sequence_length_type;
+                    validate_and_infer_types();
+                }
+
+            private:
+                bool m_merge_repeated;
+                element::Type m_classes_index_type{element::i32};
+                element::Type m_sequence_length_type{element::i32};
+            };
+        } // namespace v6
+    } // namespace op
+} // namespace ngraph
diff --git a/ngraph/core/include/ngraph/ops.hpp b/ngraph/core/include/ngraph/ops.hpp
index e9b5d1a0d25..491d40ddbd5 100644
--- a/ngraph/core/include/ngraph/ops.hpp
+++ b/ngraph/core/include/ngraph/ops.hpp
@@ -44,6 +44,7 @@
 #include "ngraph/op/cos.hpp"
 #include "ngraph/op/cosh.hpp"
 #include "ngraph/op/ctc_greedy_decoder.hpp"
+#include "ngraph/op/ctc_greedy_decoder_seq_len.hpp"
 #include "ngraph/op/ctc_loss.hpp"
 #include "ngraph/op/cum_sum.hpp"
 #include "ngraph/op/deformable_convolution.hpp"
diff --git a/ngraph/core/include/ngraph/opsets/opset6_tbl.hpp b/ngraph/core/include/ngraph/opsets/opset6_tbl.hpp
index 19068ee31b4..a6554b903f3 100644
--- a/ngraph/core/include/ngraph/opsets/opset6_tbl.hpp
+++ b/ngraph/core/include/ngraph/opsets/opset6_tbl.hpp
@@ -173,5 +173,6 @@ NGRAPH_OP(RNNSequence, ngraph::op::v5)
 NGRAPH_OP(Round, ngraph::op::v5)
 
 // New operations added in opset6
+NGRAPH_OP(CTCGreedyDecoderSeqLen, ngraph::op::v6)
 NGRAPH_OP(MVN, ngraph::op::v6)
 NGRAPH_OP(GatherElements, ngraph::op::v6)
diff --git a/ngraph/core/src/op/ctc_greedy_decoder_seq_len.cpp b/ngraph/core/src/op/ctc_greedy_decoder_seq_len.cpp
new file mode 100644
index 00000000000..322575d1730
--- /dev/null
+++ b/ngraph/core/src/op/ctc_greedy_decoder_seq_len.cpp
@@ -0,0 +1,170 @@
+//*****************************************************************************
+// Copyright 2017-2021 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "itt.hpp"
+
+#include "ngraph/op/ctc_greedy_decoder_seq_len.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+NGRAPH_RTTI_DEFINITION(op::v6::CTCGreedyDecoderSeqLen, "CTCGreedyDecoderSeqLen", 6);
+
+op::v6::CTCGreedyDecoderSeqLen::CTCGreedyDecoderSeqLen(const Output<Node>& input,
+                                                       const Output<Node>& seq_len,
+                                                       const bool merge_repeated,
+                                                       const element::Type& classes_index_type,
+                                                       const element::Type& sequence_length_type)
+    : Op({input, seq_len})
+    , m_merge_repeated(merge_repeated)
+    , m_classes_index_type(classes_index_type)
+    , m_sequence_length_type(sequence_length_type)
+{
+    constructor_validate_and_infer_types();
+}
+
+op::v6::CTCGreedyDecoderSeqLen::CTCGreedyDecoderSeqLen(const Output<Node>& input,
+                                                       const Output<Node>& seq_len,
+                                                       const Output<Node>& blank_index,
+                                                       const bool merge_repeated,
+                                                       const element::Type& classes_index_type,
+                                                       const element::Type& sequence_length_type)
+    : Op({input, seq_len, blank_index})
+    , m_merge_repeated(merge_repeated)
+    , m_classes_index_type(classes_index_type)
+    , m_sequence_length_type(sequence_length_type)
+{
+    constructor_validate_and_infer_types();
+}
+
+void op::v6::CTCGreedyDecoderSeqLen::validate_and_infer_types()
+{
+    NGRAPH_OP_SCOPE(v6_CTCGreedyDecoderSeqLen_validate_and_infer_types);
+    const auto& logits_pshape = get_input_partial_shape(0);
+    const auto& seq_len_pshape = get_input_partial_shape(1);
+    auto input_et = get_input_element_type(0);
+    const bool logits_is_static_rank = logits_pshape.rank().is_static();
+    const bool seq_len_is_static_rank = seq_len_pshape.rank().is_static();
+
+    // check ranks of input tensors
+    if (logits_is_static_rank)
+    {
+        NODE_VALIDATION_CHECK(this,
+                              logits_pshape.rank().get_length() == 3,
+                              "The rank of logits tensor must be equal to 3.");
+    }
+    if (seq_len_is_static_rank)
+    {
+        NODE_VALIDATION_CHECK(this,
+                              seq_len_pshape.rank().get_length() == 1,
+                              "The rank of sequence len tensor must be equal to 1.");
+    }
+
+    // check optional input type: blank index
+    if (get_input_size() == 3)
+    {
+        const auto& blank_index_type = get_input_element_type(2);
+        NODE_VALIDATION_CHECK(this,
+                              blank_index_type.is_integral_number(),
+                              "The blank index type is expected to be an integer type. Got: ",
+                              blank_index_type);
+
+        const auto& blank_index_partial_shape = get_input_partial_shape(2);
+        if (blank_index_partial_shape.is_static())
+        {
+            Shape blank_index_shape = blank_index_partial_shape.to_shape();
+            NODE_VALIDATION_CHECK(this,
+                                  ngraph::is_scalar(blank_index_shape) ||
+                                      (is_vector(blank_index_shape) && (blank_index_shape[0] == 1)),
+                                  "Expected 0D or 1D tensor for the 'blank_index' input. Got: ",
+                                  blank_index_shape);
+        }
+    }
+
+    // validate input shapes and compute output shape
+    ngraph::Dimension batch_size = Dimension::dynamic();
+    ngraph::Dimension time_size = Dimension::dynamic();
+
+    if (logits_is_static_rank)
+    {
+        if (logits_pshape[0].is_static())
+        {
+            batch_size = logits_pshape[0];
+        }
+        if (logits_pshape[1].is_static())
+        {
+            time_size = logits_pshape[1];
+        }
+    }
+
+    if (seq_len_is_static_rank && seq_len_pshape[0].is_static())
+    {
+        if (batch_size != Dimension::dynamic())
+        {
+            NODE_VALIDATION_CHECK(this,
+                                  seq_len_pshape[0] == batch_size,
+                                  "The first dimensions of input tensors must match.");
+        }
+        batch_size = seq_len_pshape[0];
+    }
+
+    if (logits_is_static_rank && seq_len_is_static_rank)
+    {
+        batch_size = seq_len_pshape[0] & logits_pshape[0];
+    }
+
+    set_output_type(0, m_classes_index_type, PartialShape{batch_size, time_size});
+    set_output_type(1, m_sequence_length_type, PartialShape{batch_size});
+}
+
+bool op::v6::CTCGreedyDecoderSeqLen::visit_attributes(AttributeVisitor& visitor)
+{
+    NGRAPH_OP_SCOPE(v6_CTCGreedyDecoderSeqLen_visit_attributes);
+    visitor.on_attribute("merge_repeated", m_merge_repeated);
+    visitor.on_attribute("classes_index_type", m_classes_index_type);
+    visitor.on_attribute("sequence_length_type", m_sequence_length_type);
+    return true;
+}
+
+shared_ptr<Node>
+    op::v6::CTCGreedyDecoderSeqLen::clone_with_new_inputs(const OutputVector& new_args) const
+{
+    NGRAPH_OP_SCOPE(v6_CTCGreedyDecoderSeqLen_clone_with_new_inputs);
+    check_new_args_count(this, new_args);
+
+    size_t args_size = new_args.size();
+    if (args_size == 2)
+    {
+        return make_shared<CTCGreedyDecoderSeqLen>(new_args.at(0),
+                                                   new_args.at(1),
+                                                   m_merge_repeated,
+                                                   m_classes_index_type,
+                                                   m_sequence_length_type);
+    }
+    else if (args_size == 3)
+    {
+        return make_shared<CTCGreedyDecoderSeqLen>(new_args.at(0),
+                                                   new_args.at(1),
+                                                   new_args.at(2),
+                                                   m_merge_repeated,
+                                                   m_classes_index_type,
+                                                   m_sequence_length_type);
+    }
+    else
+    {
+        throw ngraph_error("Incorrect number of arguments");
+    }
+}
diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt
index 7f99ece7c30..6a1be5b9e4c 100644
--- a/ngraph/test/CMakeLists.txt
+++ b/ngraph/test/CMakeLists.txt
@@ -118,6 +118,7 @@ set(SRC
     type_prop/convert.cpp
     type_prop/convolution.cpp
     type_prop/ctc_greedy_decoder.cpp
+    type_prop/ctc_greedy_decoder_seq_len.cpp
     type_prop/ctc_loss.cpp
     type_prop/deformable_convolution.cpp
     type_prop/deformable_psroi_pooling.cpp
diff --git a/ngraph/test/type_prop/ctc_greedy_decoder_seq_len.cpp b/ngraph/test/type_prop/ctc_greedy_decoder_seq_len.cpp
new file mode 100644
index 00000000000..c1847653d24
--- /dev/null
+++ b/ngraph/test/type_prop/ctc_greedy_decoder_seq_len.cpp
@@ -0,0 +1,203 @@
+//*****************************************************************************
+// Copyright 2017-2021 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "gtest/gtest.h"
+#include "ngraph/ngraph.hpp"
+#include "util/type_prop.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes)
+{
+    PartialShape logits_shape{3, 100, 1200};
+    PartialShape seq_len_shape{3};
+    Shape out_shape1{3, 100};
+    Shape out_shape2{3};
+    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
+    auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
+    auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I);
+    ASSERT_EQ(G->get_output_element_type(0), element::i32);
+    ASSERT_EQ(G->get_output_element_type(1), element::i32);
+    ASSERT_EQ(G->get_output_shape(0), out_shape1);
+    ASSERT_EQ(G->get_output_shape(1), out_shape2);
+}
+
+TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_bi)
+{
+    PartialShape logits_shape{3, 100, 1200};
+    PartialShape seq_len_shape{3};
+    Shape out_shape1{3, 100};
+    Shape out_shape2{3};
+    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
+    auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
+    auto BI = op::Constant::create(element::i32, Shape{}, {1});
+    auto G =
+        make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, BI, false, element::i64, element::i64);
+    ASSERT_EQ(G->get_output_element_type(0), element::i64);
+    ASSERT_EQ(G->get_output_element_type(1), element::i64);
+    ASSERT_EQ(G->get_output_shape(0), out_shape1);
+    ASSERT_EQ(G->get_output_shape(1), out_shape2);
+}
+
+TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_dynamic_bi)
+{
+    PartialShape logits_shape{3, 100, 1200};
+    PartialShape seq_len_shape{3};
+    Shape out_shape1{3, 100};
+    Shape out_shape2{3};
+    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
+    auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
+    auto BI = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
+    auto G =
+        make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, BI, false, element::i64, element::i64);
+    ASSERT_EQ(G->get_output_element_type(0), element::i64);
+    ASSERT_EQ(G->get_output_element_type(1), element::i64);
+    ASSERT_EQ(G->get_output_shape(0), out_shape1);
+    ASSERT_EQ(G->get_output_shape(1), out_shape2);
+}
+
+TEST(type_prop, ctc_greedy_decoder_seq_len_output_static_shape1)
+{
+    PartialShape logits_shape{Dimension::dynamic(), 100, 1200};
+    PartialShape seq_len_shape{3};
+    Shape out_shape1{3, 100};
+    Shape out_shape2{3};
+    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
+    auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
+    auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
+    ASSERT_EQ(G->get_output_element_type(0), element::i32);
+    ASSERT_EQ(G->get_output_element_type(1), element::i32);
+    ASSERT_EQ(G->get_output_shape(0), out_shape1);
+    ASSERT_EQ(G->get_output_shape(1), out_shape2);
+}
+
+TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_shapes)
+{
+    PartialShape logits_shape{Dimension::dynamic(), Dimension::dynamic(), 1200};
+    PartialShape seq_len_shape{Dimension::dynamic()};
+    PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
+    PartialShape out_shape2{Dimension::dynamic()};
+    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
+    auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
+    auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
+    ASSERT_EQ(G->get_output_element_type(0), element::i32);
+    ASSERT_EQ(G->get_output_element_type(1), element::i32);
+    ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1));
+    ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2));
+}
+
+TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks1)
+{
+    PartialShape logits_shape = PartialShape::dynamic();
+    PartialShape seq_len_shape{Dimension::dynamic()};
+    PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
+    PartialShape out_shape2{Dimension::dynamic()};
+    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
+    auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
+    auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I);
+    ASSERT_EQ(G->get_output_element_type(0), element::i32);
+    ASSERT_EQ(G->get_output_element_type(1), element::i32);
+    ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1));
+    ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2));
+}
+
+TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks2)
+{
+    PartialShape logits_shape = PartialShape::dynamic();
+    PartialShape seq_mask_shape = PartialShape::dynamic();
+    PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
+    PartialShape out_shape2{Dimension::dynamic()};
+    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
+    auto I = make_shared<op::Parameter>(element::i32, seq_mask_shape);
+    auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
+    ASSERT_EQ(G->get_output_element_type(0), element::i32);
+    ASSERT_EQ(G->get_output_element_type(1), element::i32);
+    ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1));
+    ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2));
+}
+
+TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank)
+{
+    PartialShape logits_shape{Dimension::dynamic(), 100, 1200, 5};
+    PartialShape seq_len_shape{3};
+    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
+    auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
+
+    try
+    {
+        auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incorrect indices rank";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("The rank of logits tensor must be equal to 3."));
+    }
+    catch (...)
+    {
+        FAIL() << "Rank check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank2)
+{
+    PartialShape logits_shape{3, 100, 1200};
+    PartialShape seq_len_shape{3, 100};
+    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
+    auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
+
+    try
+    {
+        auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incorrect indices rank";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("The rank of sequence len tensor must be equal to 1."));
+    }
+    catch (...)
+    {
+        FAIL() << "Rank check failed for unexpected reason";
+    }
+}
+
+TEST(type_prop, ctc_greedy_decoder_seq_len_mismatched_dim1)
+{
+    PartialShape logits_shape{4, 100, 1200};
+    PartialShape seq_mask_shape{3};
+    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
+    auto I = make_shared<op::Parameter>(element::i32, seq_mask_shape);
+
+    try
+    {
+        auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
+        // Should have thrown, so fail if it didn't
+        FAIL() << "Incorrect indices rank";
+    }
+    catch (const NodeValidationFailure& error)
+    {
+        EXPECT_HAS_SUBSTRING(error.what(),
+                             std::string("The first dimensions of input tensors must match."));
+    }
+    catch (...)
+    {
+        FAIL() << "Rank check failed for unexpected reason";
+    }
+}
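
For reference, a minimal usage sketch of the new opset6 operation outside of the patch itself. The function name, tensor shapes, and the choice of the last class as the blank index are illustrative assumptions, not part of this change; the sketch only wires the node into an ngraph::Function the same way the type_prop tests construct it.

// Sketch: build a small nGraph function around op::v6::CTCGreedyDecoderSeqLen (assumed shapes/values).
#include <memory>

#include "ngraph/ngraph.hpp"

using namespace ngraph;

std::shared_ptr<Function> make_ctc_decoder_function()
{
    // Logits in [N, T, C] layout: batch = 8, time steps = 20, classes = 128 (assumed sizes).
    auto logits = std::make_shared<op::Parameter>(element::f32, Shape{8, 20, 128});
    // Per-batch valid sequence lengths, shape [N].
    auto seq_len = std::make_shared<op::Parameter>(element::i32, Shape{8});
    // Optional scalar blank index; here the last class (127) is assumed to be the blank label.
    auto blank_index = op::Constant::create(element::i32, Shape{}, {127});

    // merge_repeated = true; both outputs are requested as i32.
    auto decoder = std::make_shared<op::v6::CTCGreedyDecoderSeqLen>(
        logits, seq_len, blank_index, true, element::i32, element::i32);

    // Output 0: decoded class indices, shape [N, T]; output 1: decoded sequence lengths, shape [N].
    return std::make_shared<Function>(decoder->outputs(), ParameterVector{logits, seq_len});
}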