Add ctc gready decoder sec len op to ngraph (#3499)
* Add ctc gready decoder sec len op to ngraph * Remove some comments * Add second constructor * Fix code style * Fix code style * Add unit tests * Add tests to cmake * Fix according to review * Fix code style * fix * Change input layoyt * Fix code style * Add unit tests * Add 3 input tensor check * Update shell impl * Fix code style * Fix code style * Add doxy gen * Fix code style * Update doxigen * Update constructor description * Fix code style * Refactoring code * fix code style * Optimize op constructor * Add macros. Optimize code for validate_and_infer_types * Refactoring code * Fix code style * Fix code style * Fix check blanck_index shape * Fix code style * Fix unit test for dynemic case * Fix code style * Fix copyryting * reverse changes * Update copyrite
This commit is contained in:
parent
20b9d390e1
commit
5408f611e9
116
ngraph/core/include/ngraph/op/ctc_greedy_decoder_seq_len.hpp
Normal file
116
ngraph/core/include/ngraph/op/ctc_greedy_decoder_seq_len.hpp
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
//*****************************************************************************
|
||||||
|
// Copyright 2017-2021 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//*****************************************************************************
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ngraph/op/op.hpp"
|
||||||
|
|
||||||
|
namespace ngraph
|
||||||
|
{
|
||||||
|
namespace op
|
||||||
|
{
|
||||||
|
namespace v6
|
||||||
|
{
|
||||||
|
/// \brief Operator performing CTCGreedyDecoder
|
||||||
|
///
|
||||||
|
class NGRAPH_API CTCGreedyDecoderSeqLen : public Op
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
NGRAPH_RTTI_DECLARATION;
|
||||||
|
CTCGreedyDecoderSeqLen() = default;
|
||||||
|
/// \brief Constructs a CTCGreedyDecoderSeqLen operation
|
||||||
|
///
|
||||||
|
/// \param input 3-D tensor of logits on which greedy decoding is
|
||||||
|
/// performed
|
||||||
|
/// \param seq_len 1-D tensor of sequence lengths
|
||||||
|
/// \param merge_repeated Whether to merge repeated labels
|
||||||
|
/// \param classes_index_type Specifies the output classes_index tensor type
|
||||||
|
/// \param sequence_length_type Specifies the output sequence_length tensor type
|
||||||
|
CTCGreedyDecoderSeqLen(const Output<Node>& input,
|
||||||
|
const Output<Node>& seq_len,
|
||||||
|
const bool merge_repeated = true,
|
||||||
|
const element::Type& classes_index_type = element::i32,
|
||||||
|
const element::Type& sequence_length_type = element::i32);
|
||||||
|
/// \brief Constructs a CTCGreedyDecoderSeqLen operation
|
||||||
|
///
|
||||||
|
/// \param input 3-D tensor of logits on which greedy decoding is
|
||||||
|
/// performed
|
||||||
|
/// \param seq_len 1-D tensor of sequence lengths
|
||||||
|
/// \param blank_index Scalar or 1-D tensor with 1 element used to mark a
|
||||||
|
/// blank index
|
||||||
|
/// \param merge_repeated Whether to merge repeated labels
|
||||||
|
/// \param classes_index_type Specifies the output classes_index tensor type
|
||||||
|
/// \param sequence_length_type Specifies the output sequence_length tensor type
|
||||||
|
CTCGreedyDecoderSeqLen(const Output<Node>& input,
|
||||||
|
const Output<Node>& seq_len,
|
||||||
|
const Output<Node>& blank_index,
|
||||||
|
const bool merge_repeated = true,
|
||||||
|
const element::Type& classes_index_type = element::i32,
|
||||||
|
const element::Type& sequence_length_type = element::i32);
|
||||||
|
|
||||||
|
void validate_and_infer_types() override;
|
||||||
|
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||||
|
|
||||||
|
std::shared_ptr<Node>
|
||||||
|
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||||
|
|
||||||
|
/// \brief Get merge_repeated attribute
|
||||||
|
///
|
||||||
|
/// \return Current value of merge_repeated attribute
|
||||||
|
///
|
||||||
|
bool get_merge_repeated() const { return m_merge_repeated; }
|
||||||
|
/// \brief Get classes_index_type attribute
|
||||||
|
///
|
||||||
|
/// \return Current value of classes_index_type attribute
|
||||||
|
///
|
||||||
|
const element::Type& get_classes_index_type() const { return m_classes_index_type; }
|
||||||
|
/// \brief Set classes_index_type attribute
|
||||||
|
///
|
||||||
|
/// \param classes_index_type Type of classes_index
|
||||||
|
///
|
||||||
|
void set_classes_index_type(const element::Type& classes_index_type)
|
||||||
|
{
|
||||||
|
m_classes_index_type = classes_index_type;
|
||||||
|
validate_and_infer_types();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Get sequence_length_type attribute
|
||||||
|
///
|
||||||
|
/// \return Current value of sequence_length_type attribute
|
||||||
|
///
|
||||||
|
const element::Type& get_sequence_length_type() const
|
||||||
|
{
|
||||||
|
return m_sequence_length_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// \brief Set sequence_length_type attribute
|
||||||
|
///
|
||||||
|
/// \param sequence_length_type Type of sequence length
|
||||||
|
///
|
||||||
|
void set_sequence_length_type(const element::Type& sequence_length_type)
|
||||||
|
{
|
||||||
|
m_sequence_length_type = sequence_length_type;
|
||||||
|
validate_and_infer_types();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool m_merge_repeated;
|
||||||
|
element::Type m_classes_index_type{element::i32};
|
||||||
|
element::Type m_sequence_length_type{element::i32};
|
||||||
|
};
|
||||||
|
} // namespace v6
|
||||||
|
} // namespace op
|
||||||
|
} // namespace ngraph
|
@ -44,6 +44,7 @@
|
|||||||
#include "ngraph/op/cos.hpp"
|
#include "ngraph/op/cos.hpp"
|
||||||
#include "ngraph/op/cosh.hpp"
|
#include "ngraph/op/cosh.hpp"
|
||||||
#include "ngraph/op/ctc_greedy_decoder.hpp"
|
#include "ngraph/op/ctc_greedy_decoder.hpp"
|
||||||
|
#include "ngraph/op/ctc_greedy_decoder_seq_len.hpp"
|
||||||
#include "ngraph/op/ctc_loss.hpp"
|
#include "ngraph/op/ctc_loss.hpp"
|
||||||
#include "ngraph/op/cum_sum.hpp"
|
#include "ngraph/op/cum_sum.hpp"
|
||||||
#include "ngraph/op/deformable_convolution.hpp"
|
#include "ngraph/op/deformable_convolution.hpp"
|
||||||
|
@ -173,5 +173,6 @@ NGRAPH_OP(RNNSequence, ngraph::op::v5)
|
|||||||
NGRAPH_OP(Round, ngraph::op::v5)
|
NGRAPH_OP(Round, ngraph::op::v5)
|
||||||
|
|
||||||
// New operations added in opset6
|
// New operations added in opset6
|
||||||
|
NGRAPH_OP(CTCGreedyDecoderSeqLen, ngraph::op::v6)
|
||||||
NGRAPH_OP(MVN, ngraph::op::v6)
|
NGRAPH_OP(MVN, ngraph::op::v6)
|
||||||
NGRAPH_OP(GatherElements, ngraph::op::v6)
|
NGRAPH_OP(GatherElements, ngraph::op::v6)
|
||||||
|
170
ngraph/core/src/op/ctc_greedy_decoder_seq_len.cpp
Normal file
170
ngraph/core/src/op/ctc_greedy_decoder_seq_len.cpp
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
//*****************************************************************************
|
||||||
|
// Copyright 2017-2021 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//*****************************************************************************
|
||||||
|
|
||||||
|
#include "itt.hpp"
|
||||||
|
|
||||||
|
#include "ngraph/op/ctc_greedy_decoder_seq_len.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ngraph;
|
||||||
|
|
||||||
|
NGRAPH_RTTI_DEFINITION(op::v6::CTCGreedyDecoderSeqLen, "CTCGreedyDecoderSeqLen", 6);
|
||||||
|
|
||||||
|
op::v6::CTCGreedyDecoderSeqLen::CTCGreedyDecoderSeqLen(const Output<Node>& input,
|
||||||
|
const Output<Node>& seq_len,
|
||||||
|
const bool merge_repeated,
|
||||||
|
const element::Type& classes_index_type,
|
||||||
|
const element::Type& sequence_length_type)
|
||||||
|
: Op({input, seq_len})
|
||||||
|
, m_merge_repeated(merge_repeated)
|
||||||
|
, m_classes_index_type(classes_index_type)
|
||||||
|
, m_sequence_length_type(sequence_length_type)
|
||||||
|
{
|
||||||
|
constructor_validate_and_infer_types();
|
||||||
|
}
|
||||||
|
|
||||||
|
op::v6::CTCGreedyDecoderSeqLen::CTCGreedyDecoderSeqLen(const Output<Node>& input,
|
||||||
|
const Output<Node>& seq_len,
|
||||||
|
const Output<Node>& blank_index,
|
||||||
|
const bool merge_repeated,
|
||||||
|
const element::Type& classes_index_type,
|
||||||
|
const element::Type& sequence_length_type)
|
||||||
|
: Op({input, seq_len, blank_index})
|
||||||
|
, m_merge_repeated(merge_repeated)
|
||||||
|
, m_classes_index_type(classes_index_type)
|
||||||
|
, m_sequence_length_type(sequence_length_type)
|
||||||
|
{
|
||||||
|
constructor_validate_and_infer_types();
|
||||||
|
}
|
||||||
|
|
||||||
|
void op::v6::CTCGreedyDecoderSeqLen::validate_and_infer_types()
|
||||||
|
{
|
||||||
|
NGRAPH_OP_SCOPE(v6_CTCGreedyDecoderSeqLen_validate_and_infer_types);
|
||||||
|
const auto& logits_pshape = get_input_partial_shape(0);
|
||||||
|
const auto& seq_len_pshape = get_input_partial_shape(1);
|
||||||
|
auto input_et = get_input_element_type(0);
|
||||||
|
const bool logits_is_static_rank = logits_pshape.rank().is_static();
|
||||||
|
const bool seq_len_is_static_rank = seq_len_pshape.rank().is_static();
|
||||||
|
|
||||||
|
// check ranks of input tensors
|
||||||
|
if (logits_is_static_rank)
|
||||||
|
{
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
logits_pshape.rank().get_length() == 3,
|
||||||
|
"The rank of logits tensor must be equal to 3.");
|
||||||
|
}
|
||||||
|
if (seq_len_is_static_rank)
|
||||||
|
{
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
seq_len_pshape.rank().get_length() == 1,
|
||||||
|
"The rank of sequence len tensor must be equal to 1.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// check optional input type: blank index
|
||||||
|
if (get_input_size() == 3)
|
||||||
|
{
|
||||||
|
const auto& blank_index_type = get_input_element_type(2);
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
blank_index_type.is_integral_number(),
|
||||||
|
"The blank index type is expected to be an integer type. Got: ",
|
||||||
|
blank_index_type);
|
||||||
|
|
||||||
|
const auto& blank_index_partial_shape = get_input_partial_shape(2);
|
||||||
|
if (blank_index_partial_shape.is_static())
|
||||||
|
{
|
||||||
|
Shape blank_index_shape = blank_index_partial_shape.to_shape();
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
ngraph::is_scalar(blank_index_shape) ||
|
||||||
|
(is_vector(blank_index_shape) && (blank_index_shape[0] == 1)),
|
||||||
|
"Expected 0D or 1D tensor for the 'blank_index' input. Got: ",
|
||||||
|
blank_index_shape);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate input shapes and compute output shape
|
||||||
|
ngraph::Dimension batch_size = Dimension::dynamic();
|
||||||
|
ngraph::Dimension time_size = Dimension::dynamic();
|
||||||
|
|
||||||
|
if (logits_is_static_rank)
|
||||||
|
{
|
||||||
|
if (logits_pshape[0].is_static())
|
||||||
|
{
|
||||||
|
batch_size = logits_pshape[0];
|
||||||
|
}
|
||||||
|
if (logits_pshape[1].is_static())
|
||||||
|
{
|
||||||
|
time_size = logits_pshape[1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (seq_len_is_static_rank && seq_len_pshape[0].is_static())
|
||||||
|
{
|
||||||
|
if (batch_size != Dimension::dynamic())
|
||||||
|
{
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
seq_len_pshape[0] == batch_size,
|
||||||
|
"The first dimensions of input tensors must match.");
|
||||||
|
}
|
||||||
|
batch_size = seq_len_pshape[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (logits_is_static_rank && seq_len_is_static_rank)
|
||||||
|
{
|
||||||
|
batch_size = seq_len_pshape[0] & logits_pshape[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
set_output_type(0, m_classes_index_type, PartialShape{batch_size, time_size});
|
||||||
|
set_output_type(1, m_sequence_length_type, PartialShape{batch_size});
|
||||||
|
}
|
||||||
|
|
||||||
|
bool op::v6::CTCGreedyDecoderSeqLen::visit_attributes(AttributeVisitor& visitor)
|
||||||
|
{
|
||||||
|
NGRAPH_OP_SCOPE(v6_CTCGreedyDecoderSeqLen_visit_attributes);
|
||||||
|
visitor.on_attribute("merge_repeated", m_merge_repeated);
|
||||||
|
visitor.on_attribute("classes_index_type", m_classes_index_type);
|
||||||
|
visitor.on_attribute("sequence_length_type", m_sequence_length_type);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
shared_ptr<Node>
|
||||||
|
op::v6::CTCGreedyDecoderSeqLen::clone_with_new_inputs(const OutputVector& new_args) const
|
||||||
|
{
|
||||||
|
NGRAPH_OP_SCOPE(v6_CTCGreedyDecoderSeqLen_clone_with_new_inputs);
|
||||||
|
check_new_args_count(this, new_args);
|
||||||
|
|
||||||
|
size_t args_size = new_args.size();
|
||||||
|
if (args_size == 2)
|
||||||
|
{
|
||||||
|
return make_shared<CTCGreedyDecoderSeqLen>(new_args.at(0),
|
||||||
|
new_args.at(1),
|
||||||
|
m_merge_repeated,
|
||||||
|
m_classes_index_type,
|
||||||
|
m_sequence_length_type);
|
||||||
|
}
|
||||||
|
else if (args_size == 3)
|
||||||
|
{
|
||||||
|
return make_shared<CTCGreedyDecoderSeqLen>(new_args.at(0),
|
||||||
|
new_args.at(1),
|
||||||
|
new_args.at(2),
|
||||||
|
m_merge_repeated,
|
||||||
|
m_classes_index_type,
|
||||||
|
m_sequence_length_type);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
throw ngraph_error("Incorrect number of arguments");
|
||||||
|
}
|
||||||
|
}
|
@ -118,6 +118,7 @@ set(SRC
|
|||||||
type_prop/convert.cpp
|
type_prop/convert.cpp
|
||||||
type_prop/convolution.cpp
|
type_prop/convolution.cpp
|
||||||
type_prop/ctc_greedy_decoder.cpp
|
type_prop/ctc_greedy_decoder.cpp
|
||||||
|
type_prop/ctc_greedy_decoder_seq_len.cpp
|
||||||
type_prop/ctc_loss.cpp
|
type_prop/ctc_loss.cpp
|
||||||
type_prop/deformable_convolution.cpp
|
type_prop/deformable_convolution.cpp
|
||||||
type_prop/deformable_psroi_pooling.cpp
|
type_prop/deformable_psroi_pooling.cpp
|
||||||
|
203
ngraph/test/type_prop/ctc_greedy_decoder_seq_len.cpp
Normal file
203
ngraph/test/type_prop/ctc_greedy_decoder_seq_len.cpp
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
//*****************************************************************************
|
||||||
|
// Copyright 2017-2021 Intel Corporation
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//*****************************************************************************
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "ngraph/ngraph.hpp"
|
||||||
|
#include "util/type_prop.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ngraph;
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes)
|
||||||
|
{
|
||||||
|
PartialShape logits_shape{3, 100, 1200};
|
||||||
|
PartialShape seq_len_shape{3};
|
||||||
|
Shape out_shape1{3, 100};
|
||||||
|
Shape out_shape2{3};
|
||||||
|
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||||
|
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||||
|
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
||||||
|
ASSERT_EQ(G->get_output_shape(0), out_shape1);
|
||||||
|
ASSERT_EQ(G->get_output_shape(1), out_shape2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_bi)
|
||||||
|
{
|
||||||
|
PartialShape logits_shape{3, 100, 1200};
|
||||||
|
PartialShape seq_len_shape{3};
|
||||||
|
Shape out_shape1{3, 100};
|
||||||
|
Shape out_shape2{3};
|
||||||
|
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||||
|
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||||
|
auto BI = op::Constant::create(element::i32, Shape{}, {1});
|
||||||
|
auto G =
|
||||||
|
make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, BI, false, element::i64, element::i64);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(0), element::i64);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(1), element::i64);
|
||||||
|
ASSERT_EQ(G->get_output_shape(0), out_shape1);
|
||||||
|
ASSERT_EQ(G->get_output_shape(1), out_shape2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_dinemic_bi)
|
||||||
|
{
|
||||||
|
PartialShape logits_shape{3, 100, 1200};
|
||||||
|
PartialShape seq_len_shape{3};
|
||||||
|
Shape out_shape1{3, 100};
|
||||||
|
Shape out_shape2{3};
|
||||||
|
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||||
|
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||||
|
auto BI = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||||
|
auto G =
|
||||||
|
make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, BI, false, element::i64, element::i64);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(0), element::i64);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(1), element::i64);
|
||||||
|
ASSERT_EQ(G->get_output_shape(0), out_shape1);
|
||||||
|
ASSERT_EQ(G->get_output_shape(1), out_shape2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_output_static_shape1)
|
||||||
|
{
|
||||||
|
PartialShape logits_shape{Dimension::dynamic(), 100, 1200};
|
||||||
|
PartialShape seq_len_shape{3};
|
||||||
|
Shape out_shape1{3, 100};
|
||||||
|
Shape out_shape2{3};
|
||||||
|
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||||
|
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||||
|
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
||||||
|
ASSERT_EQ(G->get_output_shape(0), out_shape1);
|
||||||
|
ASSERT_EQ(G->get_output_shape(1), out_shape2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_shapes)
|
||||||
|
{
|
||||||
|
PartialShape logits_shape{Dimension::dynamic(), Dimension::dynamic(), 1200};
|
||||||
|
PartialShape seq_len_shape{Dimension::dynamic()};
|
||||||
|
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
|
||||||
|
PartialShape out_shape2{Dimension::dynamic()};
|
||||||
|
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||||
|
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||||
|
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
||||||
|
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1));
|
||||||
|
ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks1)
|
||||||
|
{
|
||||||
|
PartialShape logits_shape = PartialShape::dynamic();
|
||||||
|
PartialShape seq_len_shape{Dimension::dynamic()};
|
||||||
|
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
|
||||||
|
PartialShape out_shape2{Dimension::dynamic()};
|
||||||
|
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||||
|
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||||
|
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
||||||
|
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1));
|
||||||
|
ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks2)
|
||||||
|
{
|
||||||
|
PartialShape logits_shape = PartialShape::dynamic();
|
||||||
|
PartialShape seq_mask_shape = PartialShape::dynamic();
|
||||||
|
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
|
||||||
|
PartialShape out_shape2{Dimension::dynamic()};
|
||||||
|
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||||
|
auto I = make_shared<op::Parameter>(element::i32, seq_mask_shape);
|
||||||
|
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
||||||
|
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
||||||
|
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1));
|
||||||
|
ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank)
|
||||||
|
{
|
||||||
|
PartialShape logits_shape{Dimension::dynamic(), 100, 1200, 5};
|
||||||
|
PartialShape seq_len_shape{3};
|
||||||
|
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||||
|
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
||||||
|
// Should have thrown, so fail if it didn't
|
||||||
|
FAIL() << "Incorrect indices rank";
|
||||||
|
}
|
||||||
|
catch (const NodeValidationFailure& error)
|
||||||
|
{
|
||||||
|
EXPECT_HAS_SUBSTRING(error.what(),
|
||||||
|
std::string("The rank of logits tensor must be equal to 3."));
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
FAIL() << "Rank check failed for unexpected reason";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank2)
|
||||||
|
{
|
||||||
|
PartialShape logits_shape{3, 100, 1200};
|
||||||
|
PartialShape seq_len_shape{3, 100};
|
||||||
|
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||||
|
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
||||||
|
// Should have thrown, so fail if it didn't
|
||||||
|
FAIL() << "Incorrect indices rank";
|
||||||
|
}
|
||||||
|
catch (const NodeValidationFailure& error)
|
||||||
|
{
|
||||||
|
EXPECT_HAS_SUBSTRING(error.what(),
|
||||||
|
std::string("The rank of sequence len tensor must be equal to 1."));
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
FAIL() << "Rank check failed for unexpected reason";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_greedy_decoder_seq_len_mismatched_dim1)
|
||||||
|
{
|
||||||
|
PartialShape logits_shape{4, 100, 1200};
|
||||||
|
PartialShape seq_mask_shape{3};
|
||||||
|
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||||
|
auto I = make_shared<op::Parameter>(element::i32, seq_mask_shape);
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
||||||
|
// Should have thrown, so fail if it didn't
|
||||||
|
FAIL() << "Incorrect indices rank";
|
||||||
|
}
|
||||||
|
catch (const NodeValidationFailure& error)
|
||||||
|
{
|
||||||
|
EXPECT_HAS_SUBSTRING(error.what(),
|
||||||
|
std::string("The first dimensions of input tensors must match."));
|
||||||
|
}
|
||||||
|
catch (...)
|
||||||
|
{
|
||||||
|
FAIL() << "Rank check failed for unexpected reason";
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user