Add ctc gready decoder sec len op to ngraph (#3499)
* Add ctc gready decoder sec len op to ngraph * Remove some comments * Add second constructor * Fix code style * Fix code style * Add unit tests * Add tests to cmake * Fix according to review * Fix code style * fix * Change input layoyt * Fix code style * Add unit tests * Add 3 input tensor check * Update shell impl * Fix code style * Fix code style * Add doxy gen * Fix code style * Update doxigen * Update constructor description * Fix code style * Refactoring code * fix code style * Optimize op constructor * Add macros. Optimize code for validate_and_infer_types * Refactoring code * Fix code style * Fix code style * Fix check blanck_index shape * Fix code style * Fix unit test for dynemic case * Fix code style * Fix copyryting * reverse changes * Update copyrite
This commit is contained in:
parent
20b9d390e1
commit
5408f611e9
116
ngraph/core/include/ngraph/op/ctc_greedy_decoder_seq_len.hpp
Normal file
116
ngraph/core/include/ngraph/op/ctc_greedy_decoder_seq_len.hpp
Normal file
@ -0,0 +1,116 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace v6
|
||||
{
|
||||
/// \brief Operator performing CTCGreedyDecoder
|
||||
///
|
||||
class NGRAPH_API CTCGreedyDecoderSeqLen : public Op
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
CTCGreedyDecoderSeqLen() = default;
|
||||
/// \brief Constructs a CTCGreedyDecoderSeqLen operation
|
||||
///
|
||||
/// \param input 3-D tensor of logits on which greedy decoding is
|
||||
/// performed
|
||||
/// \param seq_len 1-D tensor of sequence lengths
|
||||
/// \param merge_repeated Whether to merge repeated labels
|
||||
/// \param classes_index_type Specifies the output classes_index tensor type
|
||||
/// \param sequence_length_type Specifies the output sequence_length tensor type
|
||||
CTCGreedyDecoderSeqLen(const Output<Node>& input,
|
||||
const Output<Node>& seq_len,
|
||||
const bool merge_repeated = true,
|
||||
const element::Type& classes_index_type = element::i32,
|
||||
const element::Type& sequence_length_type = element::i32);
|
||||
/// \brief Constructs a CTCGreedyDecoderSeqLen operation
|
||||
///
|
||||
/// \param input 3-D tensor of logits on which greedy decoding is
|
||||
/// performed
|
||||
/// \param seq_len 1-D tensor of sequence lengths
|
||||
/// \param blank_index Scalar or 1-D tensor with 1 element used to mark a
|
||||
/// blank index
|
||||
/// \param merge_repeated Whether to merge repeated labels
|
||||
/// \param classes_index_type Specifies the output classes_index tensor type
|
||||
/// \param sequence_length_type Specifies the output sequence_length tensor type
|
||||
CTCGreedyDecoderSeqLen(const Output<Node>& input,
|
||||
const Output<Node>& seq_len,
|
||||
const Output<Node>& blank_index,
|
||||
const bool merge_repeated = true,
|
||||
const element::Type& classes_index_type = element::i32,
|
||||
const element::Type& sequence_length_type = element::i32);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
/// \brief Get merge_repeated attribute
|
||||
///
|
||||
/// \return Current value of merge_repeated attribute
|
||||
///
|
||||
bool get_merge_repeated() const { return m_merge_repeated; }
|
||||
/// \brief Get classes_index_type attribute
|
||||
///
|
||||
/// \return Current value of classes_index_type attribute
|
||||
///
|
||||
const element::Type& get_classes_index_type() const { return m_classes_index_type; }
|
||||
/// \brief Set classes_index_type attribute
|
||||
///
|
||||
/// \param classes_index_type Type of classes_index
|
||||
///
|
||||
void set_classes_index_type(const element::Type& classes_index_type)
|
||||
{
|
||||
m_classes_index_type = classes_index_type;
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
/// \brief Get sequence_length_type attribute
|
||||
///
|
||||
/// \return Current value of sequence_length_type attribute
|
||||
///
|
||||
const element::Type& get_sequence_length_type() const
|
||||
{
|
||||
return m_sequence_length_type;
|
||||
}
|
||||
|
||||
/// \brief Set sequence_length_type attribute
|
||||
///
|
||||
/// \param sequence_length_type Type of sequence length
|
||||
///
|
||||
void set_sequence_length_type(const element::Type& sequence_length_type)
|
||||
{
|
||||
m_sequence_length_type = sequence_length_type;
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
private:
|
||||
bool m_merge_repeated;
|
||||
element::Type m_classes_index_type{element::i32};
|
||||
element::Type m_sequence_length_type{element::i32};
|
||||
};
|
||||
} // namespace v6
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
@ -44,6 +44,7 @@
|
||||
#include "ngraph/op/cos.hpp"
|
||||
#include "ngraph/op/cosh.hpp"
|
||||
#include "ngraph/op/ctc_greedy_decoder.hpp"
|
||||
#include "ngraph/op/ctc_greedy_decoder_seq_len.hpp"
|
||||
#include "ngraph/op/ctc_loss.hpp"
|
||||
#include "ngraph/op/cum_sum.hpp"
|
||||
#include "ngraph/op/deformable_convolution.hpp"
|
||||
|
@ -173,5 +173,6 @@ NGRAPH_OP(RNNSequence, ngraph::op::v5)
|
||||
NGRAPH_OP(Round, ngraph::op::v5)
|
||||
|
||||
// New operations added in opset6
|
||||
NGRAPH_OP(CTCGreedyDecoderSeqLen, ngraph::op::v6)
|
||||
NGRAPH_OP(MVN, ngraph::op::v6)
|
||||
NGRAPH_OP(GatherElements, ngraph::op::v6)
|
||||
|
170
ngraph/core/src/op/ctc_greedy_decoder_seq_len.cpp
Normal file
170
ngraph/core/src/op/ctc_greedy_decoder_seq_len.cpp
Normal file
@ -0,0 +1,170 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "itt.hpp"
|
||||
|
||||
#include "ngraph/op/ctc_greedy_decoder_seq_len.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v6::CTCGreedyDecoderSeqLen, "CTCGreedyDecoderSeqLen", 6);
|
||||
|
||||
op::v6::CTCGreedyDecoderSeqLen::CTCGreedyDecoderSeqLen(const Output<Node>& input,
|
||||
const Output<Node>& seq_len,
|
||||
const bool merge_repeated,
|
||||
const element::Type& classes_index_type,
|
||||
const element::Type& sequence_length_type)
|
||||
: Op({input, seq_len})
|
||||
, m_merge_repeated(merge_repeated)
|
||||
, m_classes_index_type(classes_index_type)
|
||||
, m_sequence_length_type(sequence_length_type)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
op::v6::CTCGreedyDecoderSeqLen::CTCGreedyDecoderSeqLen(const Output<Node>& input,
|
||||
const Output<Node>& seq_len,
|
||||
const Output<Node>& blank_index,
|
||||
const bool merge_repeated,
|
||||
const element::Type& classes_index_type,
|
||||
const element::Type& sequence_length_type)
|
||||
: Op({input, seq_len, blank_index})
|
||||
, m_merge_repeated(merge_repeated)
|
||||
, m_classes_index_type(classes_index_type)
|
||||
, m_sequence_length_type(sequence_length_type)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v6::CTCGreedyDecoderSeqLen::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v6_CTCGreedyDecoderSeqLen_validate_and_infer_types);
|
||||
const auto& logits_pshape = get_input_partial_shape(0);
|
||||
const auto& seq_len_pshape = get_input_partial_shape(1);
|
||||
auto input_et = get_input_element_type(0);
|
||||
const bool logits_is_static_rank = logits_pshape.rank().is_static();
|
||||
const bool seq_len_is_static_rank = seq_len_pshape.rank().is_static();
|
||||
|
||||
// check ranks of input tensors
|
||||
if (logits_is_static_rank)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
logits_pshape.rank().get_length() == 3,
|
||||
"The rank of logits tensor must be equal to 3.");
|
||||
}
|
||||
if (seq_len_is_static_rank)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
seq_len_pshape.rank().get_length() == 1,
|
||||
"The rank of sequence len tensor must be equal to 1.");
|
||||
}
|
||||
|
||||
// check optional input type: blank index
|
||||
if (get_input_size() == 3)
|
||||
{
|
||||
const auto& blank_index_type = get_input_element_type(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
blank_index_type.is_integral_number(),
|
||||
"The blank index type is expected to be an integer type. Got: ",
|
||||
blank_index_type);
|
||||
|
||||
const auto& blank_index_partial_shape = get_input_partial_shape(2);
|
||||
if (blank_index_partial_shape.is_static())
|
||||
{
|
||||
Shape blank_index_shape = blank_index_partial_shape.to_shape();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
ngraph::is_scalar(blank_index_shape) ||
|
||||
(is_vector(blank_index_shape) && (blank_index_shape[0] == 1)),
|
||||
"Expected 0D or 1D tensor for the 'blank_index' input. Got: ",
|
||||
blank_index_shape);
|
||||
}
|
||||
}
|
||||
|
||||
// validate input shapes and compute output shape
|
||||
ngraph::Dimension batch_size = Dimension::dynamic();
|
||||
ngraph::Dimension time_size = Dimension::dynamic();
|
||||
|
||||
if (logits_is_static_rank)
|
||||
{
|
||||
if (logits_pshape[0].is_static())
|
||||
{
|
||||
batch_size = logits_pshape[0];
|
||||
}
|
||||
if (logits_pshape[1].is_static())
|
||||
{
|
||||
time_size = logits_pshape[1];
|
||||
}
|
||||
}
|
||||
|
||||
if (seq_len_is_static_rank && seq_len_pshape[0].is_static())
|
||||
{
|
||||
if (batch_size != Dimension::dynamic())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
seq_len_pshape[0] == batch_size,
|
||||
"The first dimensions of input tensors must match.");
|
||||
}
|
||||
batch_size = seq_len_pshape[0];
|
||||
}
|
||||
|
||||
if (logits_is_static_rank && seq_len_is_static_rank)
|
||||
{
|
||||
batch_size = seq_len_pshape[0] & logits_pshape[0];
|
||||
}
|
||||
|
||||
set_output_type(0, m_classes_index_type, PartialShape{batch_size, time_size});
|
||||
set_output_type(1, m_sequence_length_type, PartialShape{batch_size});
|
||||
}
|
||||
|
||||
bool op::v6::CTCGreedyDecoderSeqLen::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v6_CTCGreedyDecoderSeqLen_visit_attributes);
|
||||
visitor.on_attribute("merge_repeated", m_merge_repeated);
|
||||
visitor.on_attribute("classes_index_type", m_classes_index_type);
|
||||
visitor.on_attribute("sequence_length_type", m_sequence_length_type);
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<Node>
|
||||
op::v6::CTCGreedyDecoderSeqLen::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v6_CTCGreedyDecoderSeqLen_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
|
||||
size_t args_size = new_args.size();
|
||||
if (args_size == 2)
|
||||
{
|
||||
return make_shared<CTCGreedyDecoderSeqLen>(new_args.at(0),
|
||||
new_args.at(1),
|
||||
m_merge_repeated,
|
||||
m_classes_index_type,
|
||||
m_sequence_length_type);
|
||||
}
|
||||
else if (args_size == 3)
|
||||
{
|
||||
return make_shared<CTCGreedyDecoderSeqLen>(new_args.at(0),
|
||||
new_args.at(1),
|
||||
new_args.at(2),
|
||||
m_merge_repeated,
|
||||
m_classes_index_type,
|
||||
m_sequence_length_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("Incorrect number of arguments");
|
||||
}
|
||||
}
|
@ -118,6 +118,7 @@ set(SRC
|
||||
type_prop/convert.cpp
|
||||
type_prop/convolution.cpp
|
||||
type_prop/ctc_greedy_decoder.cpp
|
||||
type_prop/ctc_greedy_decoder_seq_len.cpp
|
||||
type_prop/ctc_loss.cpp
|
||||
type_prop/deformable_convolution.cpp
|
||||
type_prop/deformable_psroi_pooling.cpp
|
||||
|
203
ngraph/test/type_prop/ctc_greedy_decoder_seq_len.cpp
Normal file
203
ngraph/test/type_prop/ctc_greedy_decoder_seq_len.cpp
Normal file
@ -0,0 +1,203 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes)
|
||||
{
|
||||
PartialShape logits_shape{3, 100, 1200};
|
||||
PartialShape seq_len_shape{3};
|
||||
Shape out_shape1{3, 100};
|
||||
Shape out_shape2{3};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I);
|
||||
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
||||
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
||||
ASSERT_EQ(G->get_output_shape(0), out_shape1);
|
||||
ASSERT_EQ(G->get_output_shape(1), out_shape2);
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_bi)
|
||||
{
|
||||
PartialShape logits_shape{3, 100, 1200};
|
||||
PartialShape seq_len_shape{3};
|
||||
Shape out_shape1{3, 100};
|
||||
Shape out_shape2{3};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
auto BI = op::Constant::create(element::i32, Shape{}, {1});
|
||||
auto G =
|
||||
make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, BI, false, element::i64, element::i64);
|
||||
ASSERT_EQ(G->get_output_element_type(0), element::i64);
|
||||
ASSERT_EQ(G->get_output_element_type(1), element::i64);
|
||||
ASSERT_EQ(G->get_output_shape(0), out_shape1);
|
||||
ASSERT_EQ(G->get_output_shape(1), out_shape2);
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_seq_len_static_shapes_with_dinemic_bi)
|
||||
{
|
||||
PartialShape logits_shape{3, 100, 1200};
|
||||
PartialShape seq_len_shape{3};
|
||||
Shape out_shape1{3, 100};
|
||||
Shape out_shape2{3};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
auto BI = make_shared<op::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
|
||||
auto G =
|
||||
make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, BI, false, element::i64, element::i64);
|
||||
ASSERT_EQ(G->get_output_element_type(0), element::i64);
|
||||
ASSERT_EQ(G->get_output_element_type(1), element::i64);
|
||||
ASSERT_EQ(G->get_output_shape(0), out_shape1);
|
||||
ASSERT_EQ(G->get_output_shape(1), out_shape2);
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_seq_len_output_static_shape1)
|
||||
{
|
||||
PartialShape logits_shape{Dimension::dynamic(), 100, 1200};
|
||||
PartialShape seq_len_shape{3};
|
||||
Shape out_shape1{3, 100};
|
||||
Shape out_shape2{3};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
||||
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
||||
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
||||
ASSERT_EQ(G->get_output_shape(0), out_shape1);
|
||||
ASSERT_EQ(G->get_output_shape(1), out_shape2);
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_shapes)
|
||||
{
|
||||
PartialShape logits_shape{Dimension::dynamic(), Dimension::dynamic(), 1200};
|
||||
PartialShape seq_len_shape{Dimension::dynamic()};
|
||||
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
|
||||
PartialShape out_shape2{Dimension::dynamic()};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
||||
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
||||
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
||||
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1));
|
||||
ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks1)
|
||||
{
|
||||
PartialShape logits_shape = PartialShape::dynamic();
|
||||
PartialShape seq_len_shape{Dimension::dynamic()};
|
||||
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
|
||||
PartialShape out_shape2{Dimension::dynamic()};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I);
|
||||
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
||||
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
||||
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1));
|
||||
ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_seq_len_dynamic_ranks2)
|
||||
{
|
||||
PartialShape logits_shape = PartialShape::dynamic();
|
||||
PartialShape seq_mask_shape = PartialShape::dynamic();
|
||||
PartialShape out_shape1{Dimension::dynamic(), Dimension::dynamic()};
|
||||
PartialShape out_shape2{Dimension::dynamic()};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, seq_mask_shape);
|
||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
||||
ASSERT_EQ(G->get_output_element_type(0), element::i32);
|
||||
ASSERT_EQ(G->get_output_element_type(1), element::i32);
|
||||
ASSERT_TRUE(G->get_output_partial_shape(0).same_scheme(out_shape1));
|
||||
ASSERT_TRUE(G->get_output_partial_shape(1).same_scheme(out_shape2));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank)
|
||||
{
|
||||
PartialShape logits_shape{Dimension::dynamic(), 100, 1200, 5};
|
||||
PartialShape seq_len_shape{3};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices rank";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("The rank of logits tensor must be equal to 3."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Rank check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_seq_len_incorrect_rank2)
|
||||
{
|
||||
PartialShape logits_shape{3, 100, 1200};
|
||||
PartialShape seq_len_shape{3, 100};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, seq_len_shape);
|
||||
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices rank";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("The rank of sequence len tensor must be equal to 1."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Rank check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_greedy_decoder_seq_len_mismatched_dim1)
|
||||
{
|
||||
PartialShape logits_shape{4, 100, 1200};
|
||||
PartialShape seq_mask_shape{3};
|
||||
auto P = make_shared<op::Parameter>(element::f32, logits_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, seq_mask_shape);
|
||||
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v6::CTCGreedyDecoderSeqLen>(P, I, false);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect indices rank";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("The first dimensions of input tensors must match."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Rank check failed for unexpected reason";
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user