Add CTC greedy decoder seq len op to ngraph (#3499)

* Add CTC greedy decoder seq len op to ngraph

* Remove some comments

* Add second constructor

* Fix code style

* Fix code style

* Add unit tests

* Add tests to cmake

* Fix according to review

* Fix code style

* fix

* Change input layout

* Fix code style

* Add unit tests

* Add 3 input tensor check

* Update shell impl

* Fix code style

* Fix code style

* Add Doxygen

* Fix code style

* Update Doxygen

* Update constructor description

* Fix code style

* Refactoring code

* fix code style

* Optimize op constructor

* Add macros. Optimize code for validate_and_infer_types

* Refactoring code

* Fix code style

* Fix code style

* Fix blank_index shape check

* Fix code style

* Fix unit test for dynamic case

* Fix code style

* Fix copyright

* reverse changes

* Update copyright
This commit is contained in:
iliya mironov 2021-01-14 14:14:32 +03:00 committed by GitHub
parent 20b9d390e1
commit 5408f611e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 492 additions and 0 deletions

View File

@ -0,0 +1,116 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v6
{
/// \brief Operator performing CTCGreedyDecoderSeqLen
///
/// Takes a 3-D logits tensor and a 1-D tensor of sequence lengths, plus an
/// optional blank index input. It produces two outputs whose element types are
/// selected by the classes_index_type and sequence_length_type attributes.
class NGRAPH_API CTCGreedyDecoderSeqLen : public Op
{
public:
    NGRAPH_RTTI_DECLARATION;

    /// \brief Constructs an op without inputs; every attribute keeps its
    ///        in-class default value
    CTCGreedyDecoderSeqLen() = default;

    /// \brief Constructs a CTCGreedyDecoderSeqLen operation
    ///
    /// \param input                3-D tensor of logits on which greedy decoding is
    ///                             performed
    /// \param seq_len              1-D tensor of sequence lengths
    /// \param merge_repeated       Whether to merge repeated labels
    /// \param classes_index_type   Specifies the output classes_index tensor type
    /// \param sequence_length_type Specifies the output sequence_length tensor type
    CTCGreedyDecoderSeqLen(const Output<Node>& input,
                           const Output<Node>& seq_len,
                           const bool merge_repeated = true,
                           const element::Type& classes_index_type = element::i32,
                           const element::Type& sequence_length_type = element::i32);

    /// \brief Constructs a CTCGreedyDecoderSeqLen operation
    ///
    /// \param input                3-D tensor of logits on which greedy decoding is
    ///                             performed
    /// \param seq_len              1-D tensor of sequence lengths
    /// \param blank_index          Scalar or 1-D tensor with 1 element used to mark a
    ///                             blank index
    /// \param merge_repeated       Whether to merge repeated labels
    /// \param classes_index_type   Specifies the output classes_index tensor type
    /// \param sequence_length_type Specifies the output sequence_length tensor type
    CTCGreedyDecoderSeqLen(const Output<Node>& input,
                           const Output<Node>& seq_len,
                           const Output<Node>& blank_index,
                           const bool merge_repeated = true,
                           const element::Type& classes_index_type = element::i32,
                           const element::Type& sequence_length_type = element::i32);

    void validate_and_infer_types() override;
    bool visit_attributes(AttributeVisitor& visitor) override;
    std::shared_ptr<Node>
        clone_with_new_inputs(const OutputVector& new_args) const override;

    /// \brief Get merge_repeated attribute
    ///
    /// \return Current value of merge_repeated attribute
    ///
    bool get_merge_repeated() const { return m_merge_repeated; }

    /// \brief Get classes_index_type attribute
    ///
    /// \return Current value of classes_index_type attribute
    ///
    const element::Type& get_classes_index_type() const { return m_classes_index_type; }

    /// \brief Set classes_index_type attribute and re-run
    ///        validate_and_infer_types()
    ///
    /// \param classes_index_type Type of classes_index
    ///
    void set_classes_index_type(const element::Type& classes_index_type)
    {
        m_classes_index_type = classes_index_type;
        validate_and_infer_types();
    }

    /// \brief Get sequence_length_type attribute
    ///
    /// \return Current value of sequence_length_type attribute
    ///
    const element::Type& get_sequence_length_type() const
    {
        return m_sequence_length_type;
    }

    /// \brief Set sequence_length_type attribute and re-run
    ///        validate_and_infer_types()
    ///
    /// \param sequence_length_type Type of sequence length
    ///
    void set_sequence_length_type(const element::Type& sequence_length_type)
    {
        m_sequence_length_type = sequence_length_type;
        validate_and_infer_types();
    }

private:
    // Initialized in-class (matching the constructors' default argument) so a
    // default-constructed op never exposes an indeterminate value through
    // get_merge_repeated() or visit_attributes().
    bool m_merge_repeated{true};
    element::Type m_classes_index_type{element::i32};
    element::Type m_sequence_length_type{element::i32};
};
} // namespace v6
} // namespace op
} // namespace ngraph

View File

@ -44,6 +44,7 @@
#include "ngraph/op/cos.hpp" #include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp" #include "ngraph/op/cosh.hpp"
#include "ngraph/op/ctc_greedy_decoder.hpp" #include "ngraph/op/ctc_greedy_decoder.hpp"
#include "ngraph/op/ctc_greedy_decoder_seq_len.hpp"
#include "ngraph/op/ctc_loss.hpp" #include "ngraph/op/ctc_loss.hpp"
#include "ngraph/op/cum_sum.hpp" #include "ngraph/op/cum_sum.hpp"
#include "ngraph/op/deformable_convolution.hpp" #include "ngraph/op/deformable_convolution.hpp"

View File

@ -173,5 +173,6 @@ NGRAPH_OP(RNNSequence, ngraph::op::v5)
NGRAPH_OP(Round, ngraph::op::v5) NGRAPH_OP(Round, ngraph::op::v5)
// New operations added in opset6 // New operations added in opset6
NGRAPH_OP(CTCGreedyDecoderSeqLen, ngraph::op::v6)
NGRAPH_OP(MVN, ngraph::op::v6) NGRAPH_OP(MVN, ngraph::op::v6)
NGRAPH_OP(GatherElements, ngraph::op::v6) NGRAPH_OP(GatherElements, ngraph::op::v6)

View File

@ -0,0 +1,170 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "itt.hpp"
#include "ngraph/op/ctc_greedy_decoder_seq_len.hpp"
using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(op::v6::CTCGreedyDecoderSeqLen, "CTCGreedyDecoderSeqLen", 6);
// Two-input form: logits and sequence lengths only, without a blank index
// input. Stores the attributes and then validates the inputs and infers the
// output types.
op::v6::CTCGreedyDecoderSeqLen::CTCGreedyDecoderSeqLen(
    const Output<Node>& input,
    const Output<Node>& seq_len,
    const bool merge_repeated,
    const element::Type& classes_index_type,
    const element::Type& sequence_length_type)
    : Op(OutputVector{input, seq_len})
    , m_merge_repeated(merge_repeated)
    , m_classes_index_type(classes_index_type)
    , m_sequence_length_type(sequence_length_type)
{
    constructor_validate_and_infer_types();
}
// Three-input form: logits, sequence lengths and an explicit blank index.
// Stores the attributes and then validates the inputs and infers the output
// types.
op::v6::CTCGreedyDecoderSeqLen::CTCGreedyDecoderSeqLen(
    const Output<Node>& input,
    const Output<Node>& seq_len,
    const Output<Node>& blank_index,
    const bool merge_repeated,
    const element::Type& classes_index_type,
    const element::Type& sequence_length_type)
    : Op(OutputVector{input, seq_len, blank_index})
    , m_merge_repeated(merge_repeated)
    , m_classes_index_type(classes_index_type)
    , m_sequence_length_type(sequence_length_type)
{
    constructor_validate_and_infer_types();
}
// Validates the inputs and derives the element types and shapes of both outputs.
//
// Checks performed (each raises a node validation failure when violated):
//  * logits (input 0) must have rank 3 when its rank is static;
//  * sequence lengths (input 1) must have rank 1 when its rank is static;
//  * the optional blank index (input 2) must have an integral element type and,
//    when its shape is fully static, be a scalar or a 1-D tensor with 1 element;
//  * when both ranks are static, the first dimensions of logits and sequence
//    lengths must be compatible.
//
// Output 0 is typed m_classes_index_type with shape {batch, time}; output 1 is
// typed m_sequence_length_type with shape {batch}. Dimensions that cannot be
// determined are left dynamic.
void op::v6::CTCGreedyDecoderSeqLen::validate_and_infer_types()
{
    NGRAPH_OP_SCOPE(v6_CTCGreedyDecoderSeqLen_validate_and_infer_types);
    const auto& logits_pshape = get_input_partial_shape(0);
    const auto& seq_len_pshape = get_input_partial_shape(1);
    const bool logits_is_static_rank = logits_pshape.rank().is_static();
    const bool seq_len_is_static_rank = seq_len_pshape.rank().is_static();

    // check ranks of input tensors
    if (logits_is_static_rank)
    {
        NODE_VALIDATION_CHECK(this,
                              logits_pshape.rank().get_length() == 3,
                              "The rank of logits tensor must be equal to 3.");
    }
    if (seq_len_is_static_rank)
    {
        NODE_VALIDATION_CHECK(this,
                              seq_len_pshape.rank().get_length() == 1,
                              "The rank of sequence len tensor must be equal to 1.");
    }

    // check optional input type: blank index
    if (get_input_size() == 3)
    {
        const auto& blank_index_type = get_input_element_type(2);
        NODE_VALIDATION_CHECK(this,
                              blank_index_type.is_integral_number(),
                              "The blank index type is expected to be an integer type. Got: ",
                              blank_index_type);

        // The shape can only be inspected once it is fully static.
        const auto& blank_index_partial_shape = get_input_partial_shape(2);
        if (blank_index_partial_shape.is_static())
        {
            Shape blank_index_shape = blank_index_partial_shape.to_shape();
            NODE_VALIDATION_CHECK(this,
                                  ngraph::is_scalar(blank_index_shape) ||
                                      (is_vector(blank_index_shape) &&
                                       (blank_index_shape[0] == 1)),
                                  "Expected 0D or 1D tensor for the 'blank_index' input. Got: ",
                                  blank_index_shape);
        }
    }

    // validate input shapes and compute output shape
    ngraph::Dimension batch_size = Dimension::dynamic();
    ngraph::Dimension time_size = Dimension::dynamic();

    if (logits_is_static_rank && logits_pshape[1].is_static())
    {
        time_size = logits_pshape[1];
    }

    if (logits_is_static_rank && seq_len_is_static_rank)
    {
        // Intersect the two batch dimensions. merge() reports failure when they
        // have nothing in common, which covers both mismatching static values
        // and disjoint dynamic ranges.
        const bool batch_dims_compatible =
            Dimension::merge(batch_size, logits_pshape[0], seq_len_pshape[0]);
        NODE_VALIDATION_CHECK(this,
                              batch_dims_compatible,
                              "The first dimensions of input tensors must match.");
    }
    else if (logits_is_static_rank && logits_pshape[0].is_static())
    {
        batch_size = logits_pshape[0];
    }
    else if (seq_len_is_static_rank && seq_len_pshape[0].is_static())
    {
        batch_size = seq_len_pshape[0];
    }

    set_output_type(0, m_classes_index_type, PartialShape{batch_size, time_size});
    set_output_type(1, m_sequence_length_type, PartialShape{batch_size});
}
// Presents the three attributes of the op to the visitor, in declaration order,
// and always reports success.
bool op::v6::CTCGreedyDecoderSeqLen::visit_attributes(AttributeVisitor& visitor)
{
    NGRAPH_OP_SCOPE(v6_CTCGreedyDecoderSeqLen_visit_attributes);

    visitor.on_attribute("merge_repeated", m_merge_repeated);
    visitor.on_attribute("classes_index_type", m_classes_index_type);
    visitor.on_attribute("sequence_length_type", m_sequence_length_type);

    return true;
}
// Creates a copy of this op wired to `new_args`, carrying over all attributes.
// Two arguments select the form without a blank index input, three arguments
// the form with it; any other count raises ngraph_error.
shared_ptr<Node>
    op::v6::CTCGreedyDecoderSeqLen::clone_with_new_inputs(const OutputVector& new_args) const
{
    NGRAPH_OP_SCOPE(v6_CTCGreedyDecoderSeqLen_clone_with_new_inputs);
    check_new_args_count(this, new_args);

    switch (new_args.size())
    {
    case 2:
        return make_shared<CTCGreedyDecoderSeqLen>(new_args.at(0),
                                                   new_args.at(1),
                                                   m_merge_repeated,
                                                   m_classes_index_type,
                                                   m_sequence_length_type);
    case 3:
        return make_shared<CTCGreedyDecoderSeqLen>(new_args.at(0),
                                                   new_args.at(1),
                                                   new_args.at(2),
                                                   m_merge_repeated,
                                                   m_classes_index_type,
                                                   m_sequence_length_type);
    default: throw ngraph_error("Incorrect number of arguments");
    }
}

View File

@ -118,6 +118,7 @@ set(SRC
type_prop/convert.cpp type_prop/convert.cpp
type_prop/convolution.cpp type_prop/convolution.cpp
type_prop/ctc_greedy_decoder.cpp type_prop/ctc_greedy_decoder.cpp
type_prop/ctc_greedy_decoder_seq_len.cpp
type_prop/ctc_loss.cpp type_prop/ctc_loss.cpp
type_prop/deformable_convolution.cpp type_prop/deformable_convolution.cpp
type_prop/deformable_psroi_pooling.cpp type_prop/deformable_psroi_pooling.cpp

View File

@ -0,0 +1,203 @@
//*****************************************************************************
// Copyright 2017-2021 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
// Fully static inputs with default attributes: both outputs are i32 and their
// shapes follow the batch and time dimensions of the logits.
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes)
{
    const PartialShape logits_shape{3, 100, 1200};
    const PartialShape seq_len_shape{3};
    const Shape expected_classes_shape{3, 100};
    const Shape expected_lengths_shape{3};

    const auto logits = make_shared<op::Parameter>(element::f32, logits_shape);
    const auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
    const auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits, seq_len);

    ASSERT_EQ(decoder->get_output_element_type(0), element::i32);
    ASSERT_EQ(decoder->get_output_element_type(1), element::i32);
    ASSERT_EQ(decoder->get_output_shape(0), expected_classes_shape);
    ASSERT_EQ(decoder->get_output_shape(1), expected_lengths_shape);
}
// Static inputs plus a scalar constant blank index; the i64 output types
// requested through the attributes must be propagated to both outputs.
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_bi)
{
    const PartialShape logits_shape{3, 100, 1200};
    const PartialShape seq_len_shape{3};
    const Shape expected_classes_shape{3, 100};
    const Shape expected_lengths_shape{3};

    const auto logits = make_shared<op::Parameter>(element::f32, logits_shape);
    const auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
    const auto blank_index = op::Constant::create(element::i32, Shape{}, {1});
    const auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(
        logits, seq_len, blank_index, false, element::i64, element::i64);

    ASSERT_EQ(decoder->get_output_element_type(0), element::i64);
    ASSERT_EQ(decoder->get_output_element_type(1), element::i64);
    ASSERT_EQ(decoder->get_output_shape(0), expected_classes_shape);
    ASSERT_EQ(decoder->get_output_shape(1), expected_lengths_shape);
}
// Static logits and sequence lengths with a blank index parameter whose single
// dimension is dynamic: construction must succeed and the output shapes stay
// static.
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_dinemic_bi)
{
    const PartialShape logits_shape{3, 100, 1200};
    const PartialShape seq_len_shape{3};
    const Shape expected_classes_shape{3, 100};
    const Shape expected_lengths_shape{3};

    const auto logits = make_shared<op::Parameter>(element::f32, logits_shape);
    const auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
    const auto blank_index =
        make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
    const auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(
        logits, seq_len, blank_index, false, element::i64, element::i64);

    ASSERT_EQ(decoder->get_output_element_type(0), element::i64);
    ASSERT_EQ(decoder->get_output_element_type(1), element::i64);
    ASSERT_EQ(decoder->get_output_shape(0), expected_classes_shape);
    ASSERT_EQ(decoder->get_output_shape(1), expected_lengths_shape);
}
// The logits batch dimension is dynamic, but the static sequence-length
// dimension supplies it, so both output shapes are fully static.
TEST(type_prop, ctc_greedy_decoder_seq_len_output_static_shape1)
{
    const PartialShape logits_shape{Dimension::dynamic(), 100, 1200};
    const PartialShape seq_len_shape{3};
    const Shape expected_classes_shape{3, 100};
    const Shape expected_lengths_shape{3};

    const auto logits = make_shared<op::Parameter>(element::f32, logits_shape);
    const auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
    const auto decoder =
        make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits, seq_len, false);

    ASSERT_EQ(decoder->get_output_element_type(0), element::i32);
    ASSERT_EQ(decoder->get_output_element_type(1), element::i32);
    ASSERT_EQ(decoder->get_output_shape(0), expected_classes_shape);
    ASSERT_EQ(decoder->get_output_shape(1), expected_lengths_shape);
}
// Static ranks with dynamic batch and time dimensions: the output ranks are
// known while every output dimension remains dynamic.
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_shapes)
{
    const PartialShape logits_shape{Dimension::dynamic(), Dimension::dynamic(), 1200};
    const PartialShape seq_len_shape{Dimension::dynamic()};
    const PartialShape expected_classes_shape{Dimension::dynamic(), Dimension::dynamic()};
    const PartialShape expected_lengths_shape{Dimension::dynamic()};

    const auto logits = make_shared<op::Parameter>(element::f32, logits_shape);
    const auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
    const auto decoder =
        make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits, seq_len, false);

    ASSERT_EQ(decoder->get_output_element_type(0), element::i32);
    ASSERT_EQ(decoder->get_output_element_type(1), element::i32);
    ASSERT_TRUE(decoder->get_output_partial_shape(0).same_scheme(expected_classes_shape));
    ASSERT_TRUE(decoder->get_output_partial_shape(1).same_scheme(expected_lengths_shape));
}
// Logits of dynamic rank with a 1-D dynamic sequence-length input: the outputs
// still get ranks 2 and 1, with dynamic dimensions.
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks1)
{
    const PartialShape logits_shape = PartialShape::dynamic();
    const PartialShape seq_len_shape{Dimension::dynamic()};
    const PartialShape expected_classes_shape{Dimension::dynamic(), Dimension::dynamic()};
    const PartialShape expected_lengths_shape{Dimension::dynamic()};

    const auto logits = make_shared<op::Parameter>(element::f32, logits_shape);
    const auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
    const auto decoder = make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits, seq_len);

    ASSERT_EQ(decoder->get_output_element_type(0), element::i32);
    ASSERT_EQ(decoder->get_output_element_type(1), element::i32);
    ASSERT_TRUE(decoder->get_output_partial_shape(0).same_scheme(expected_classes_shape));
    ASSERT_TRUE(decoder->get_output_partial_shape(1).same_scheme(expected_lengths_shape));
}
// Both inputs have dynamic rank: the outputs still get ranks 2 and 1, with
// dynamic dimensions.
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks2)
{
    const PartialShape logits_shape = PartialShape::dynamic();
    const PartialShape seq_len_shape = PartialShape::dynamic();
    const PartialShape expected_classes_shape{Dimension::dynamic(), Dimension::dynamic()};
    const PartialShape expected_lengths_shape{Dimension::dynamic()};

    const auto logits = make_shared<op::Parameter>(element::f32, logits_shape);
    const auto seq_len = make_shared<op::Parameter>(element::i32, seq_len_shape);
    const auto decoder =
        make_shared<op::v6::CTCGreedyDecoderSeqLen>(logits, seq_len, false);

    ASSERT_EQ(decoder->get_output_element_type(0), element::i32);
    ASSERT_EQ(decoder->get_output_element_type(1), element::i32);
    ASSERT_TRUE(decoder->get_output_partial_shape(0).same_scheme(expected_classes_shape));
    ASSERT_TRUE(decoder->get_output_partial_shape(1).same_scheme(expected_lengths_shape));
}
// A 4-D logits input must be rejected with the logits rank error.
TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank)
{
    PartialShape logits_shape{Dimension::dynamic(), 100, 1200, 5};
    PartialShape seq_len_shape{3};
    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
    auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);

    try
    {
        auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect logits rank not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("The rank of logits tensor must be equal to 3."));
    }
    catch (...)
    {
        FAIL() << "Rank check failed for unexpected reason";
    }
}
// A 2-D sequence-length input must be rejected with the sequence length rank
// error.
TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank2)
{
    PartialShape logits_shape{3, 100, 1200};
    PartialShape seq_len_shape{3, 100};
    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
    auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);

    try
    {
        auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
        // Should have thrown, so fail if it didn't
        FAIL() << "Incorrect sequence length rank not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("The rank of sequence len tensor must be equal to 1."));
    }
    catch (...)
    {
        FAIL() << "Rank check failed for unexpected reason";
    }
}
// Logits with batch size 4 and sequence lengths with 3 entries disagree on the
// first dimension and must be rejected.
TEST(type_prop, ctc_greedy_decoder_seq_len_mismatched_dim1)
{
    PartialShape logits_shape{4, 100, 1200};
    PartialShape seq_len_shape{3};
    auto P = make_shared<op::Parameter>(element::f32, logits_shape);
    auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);

    try
    {
        auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
        // Should have thrown, so fail if it didn't
        FAIL() << "Mismatched first dimensions of inputs not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("The first dimensions of input tensors must match."));
    }
    catch (...)
    {
        FAIL() << "First dimension check failed for unexpected reason";
    }
}