Extend nGraph for operation CTCLoss (#1236)
* Extend nGraph for operation CTCLoss

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fixes as per review comments

Co-authored-by: Nikolay Shchegolev <nikolay.shchegolev@intel.com>
This commit is contained in:
@@ -157,6 +157,8 @@ set (SRC
|
||||
op/cosh.hpp
|
||||
op/ctc_greedy_decoder.cpp
|
||||
op/ctc_greedy_decoder.hpp
|
||||
op/ctc_loss.cpp
|
||||
op/ctc_loss.hpp
|
||||
op/cum_sum.cpp
|
||||
op/cum_sum.hpp
|
||||
op/crop_and_resize.cpp
|
||||
|
||||
226
ngraph/src/ngraph/op/ctc_loss.cpp
Normal file
226
ngraph/src/ngraph/op/ctc_loss.cpp
Normal file
@@ -0,0 +1,226 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/ctc_loss.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::CTCLoss::type_info;
|
||||
|
||||
// Constructs a CTCLoss-4 node over the five data inputs and stores the three
// boolean attributes; type/shape inference runs immediately at the end.
op::CTCLoss::CTCLoss(const Output<Node>& logits,
                     const Output<Node>& logit_length,
                     const Output<Node>& labels,
                     const Output<Node>& label_length,
                     const Output<Node>& blank_index,
                     const bool preprocess_collapse_repeated,
                     const bool ctc_merge_repeated,
                     const bool unique)
    : Op({logits, logit_length, labels, label_length, blank_index})
    , preprocess_collapse_repeated_(preprocess_collapse_repeated)
    , ctc_merge_repeated_(ctc_merge_repeated)
    , unique_(unique)
{
    // Validates element types/shapes and sets the output (see
    // validate_and_infer_types below).
    constructor_validate_and_infer_types();
}
|
||||
|
||||
// Validates element types and shapes of the five inputs and infers the output.
// Expected layouts (per this code's checks): logits [N, T, C] floating-point,
// logit_length [N] integer, labels [N, T] integer, label_length [N] integer,
// blank_index integer scalar. Output 0 is a loss vector [N] with the logits'
// element type; N falls back to a dynamic dimension when no input pins it.
void op::CTCLoss::validate_and_infer_types()
{
    // check types of input tensors
    const auto& logits_type = get_input_element_type(0);
    const auto& logit_length_type = get_input_element_type(1);
    const auto& labels_type = get_input_element_type(2);
    const auto& label_length_type = get_input_element_type(3);
    const auto& blank_index_type = get_input_element_type(4);

    NODE_VALIDATION_CHECK(this,
                          logits_type.is_real(),
                          "The data type for logits is expected to be a floating point type. Got: ",
                          logits_type);

    NODE_VALIDATION_CHECK(this,
                          logit_length_type.is_integral_number(),
                          "The logit length type is expected to be an integer type. Got: ",
                          logit_length_type);

    NODE_VALIDATION_CHECK(this,
                          labels_type.is_integral_number(),
                          "The labels type is expected to be an integer type. Got: ",
                          labels_type);

    NODE_VALIDATION_CHECK(this,
                          label_length_type.is_integral_number(),
                          "The label length type is expected to be an integer type. Got: ",
                          label_length_type);

    NODE_VALIDATION_CHECK(this,
                          blank_index_type.is_integral_number(),
                          "The blank index type is expected to be an integer type. Got: ",
                          blank_index_type);

    // check ranks of input tensors
    const auto& logits_pshape = get_input_partial_shape(0);
    const auto& logit_length_pshape = get_input_partial_shape(1);
    const auto& labels_pshape = get_input_partial_shape(2);
    const auto& label_length_pshape = get_input_partial_shape(3);
    const auto& blank_index_pshape = get_input_partial_shape(4);

    NODE_VALIDATION_CHECK(this,
                          logits_pshape.rank().compatible(3),
                          "Expected a 3D tensor for logits. Got: ",
                          logits_pshape);

    NODE_VALIDATION_CHECK(this,
                          logit_length_pshape.rank().compatible(1),
                          "Expected a 1D tensor for logit length. Got: ",
                          logit_length_pshape);

    NODE_VALIDATION_CHECK(this,
                          labels_pshape.rank().compatible(2),
                          "Expected a 2D tensor for labels. Got: ",
                          labels_pshape);

    NODE_VALIDATION_CHECK(this,
                          label_length_pshape.rank().compatible(1),
                          "Expected a 1D tensor for label length. Got: ",
                          label_length_pshape);

    NODE_VALIDATION_CHECK(this,
                          blank_index_pshape.rank().compatible(0),
                          "Expected a scalar for blank index. Got: ",
                          blank_index_pshape);

    // check shapes of input tensors
    // batch_size/time_steps are discovered greedily from the first input that
    // provides a static value for them; later inputs are then cross-checked.
    size_t batch_size = 1;
    bool is_batch_size_set = false;
    size_t time_steps = 1;
    bool is_time_steps_set = false;

    if (logits_pshape.rank().is_static())
    {
        // Rank was already proven compatible with 3 above, so indexing
        // dimensions 0 (batch) and 1 (time) is safe here.
        if (logits_pshape[0].is_static())
        {
            batch_size = logits_pshape[0].get_length();
            is_batch_size_set = true;
        }
        if (logits_pshape[1].is_static())
        {
            time_steps = logits_pshape[1].get_length();
            is_time_steps_set = true;
        }
    }

    if (logit_length_pshape.is_static())
    {
        if (is_batch_size_set)
        {
            NODE_VALIDATION_CHECK(
                this,
                logit_length_pshape[0].compatible(batch_size),
                "The first dimension of logit length must be equal to the first dimension ",
                "of the logits. Got: ",
                logit_length_pshape[0],
                " and: ",
                batch_size);
        }
        else if (logit_length_pshape[0].is_static())
        {
            batch_size = logit_length_pshape[0].get_length();
            is_batch_size_set = true;
        }
    }

    if (labels_pshape.is_static())
    {
        if (is_batch_size_set)
        {
            NODE_VALIDATION_CHECK(
                this,
                labels_pshape[0].compatible(batch_size),
                "The first dimension of labels must be equal to the first dimension ",
                "of the logits and the logit length. Got: ",
                labels_pshape[0],
                " and: ",
                batch_size);
        }
        else if (labels_pshape[0].is_static())
        {
            batch_size = labels_pshape[0].get_length();
            is_batch_size_set = true;
        }

        if (is_time_steps_set)
        {
            NODE_VALIDATION_CHECK(
                this,
                labels_pshape[1].compatible(time_steps),
                "The second dimension of labels must be equal to the second dimension ",
                "of logits. Got: ",
                labels_pshape[1],
                " and: ",
                time_steps);
        }
    }

    if (label_length_pshape.is_static())
    {
        // If no earlier input fixed the batch size, adopt it from label length
        // first so the compatibility check below is against the right value.
        if (!is_batch_size_set && label_length_pshape[0].is_static())
        {
            batch_size = label_length_pshape[0].get_length();
            is_batch_size_set = true;
        }
        NODE_VALIDATION_CHECK(
            this,
            label_length_pshape[0].compatible(batch_size),
            "The first dimension of label length must be equal to the first dimension ",
            "of the logits, the logit length and labels. Got: ",
            label_length_pshape[0],
            " and: ",
            batch_size);
    }

    // set output shape
    set_output_size(1);
    if (is_batch_size_set)
    {
        set_output_type(0, logits_type, Shape{batch_size});
    }
    else
    {
        set_output_type(0, logits_type, PartialShape{Dimension::dynamic()});
    }
}
|
||||
|
||||
// Exposes the three boolean attributes to serialization/deserialization.
// NOTE: visitation order is part of the on-disk attribute contract; keep it
// stable.
bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("preprocess_collapse_repeated", preprocess_collapse_repeated_);
    visitor.on_attribute("ctc_merge_repeated", ctc_merge_repeated_);
    visitor.on_attribute("unique", unique_);
    return true;
}
|
||||
|
||||
// Creates a copy of this node wired to `new_args`, carrying over the three
// boolean attributes unchanged. `new_args` must hold exactly five outputs.
shared_ptr<Node> op::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const
{
    check_new_args_count(this, new_args);
    auto replica = make_shared<CTCLoss>(new_args.at(0), // logits
                                        new_args.at(1), // logit_length
                                        new_args.at(2), // labels
                                        new_args.at(3), // label_length
                                        new_args.at(4), // blank_index
                                        preprocess_collapse_repeated_,
                                        ctc_merge_repeated_,
                                        unique_);
    return replica;
}
|
||||
77
ngraph/src/ngraph/op/ctc_loss.hpp
Normal file
77
ngraph/src/ngraph/op/ctc_loss.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph
{
    namespace op
    {
        namespace v4
        {
            /// \brief CTCLoss operation (opset4): computes the Connectionist
            ///        Temporal Classification loss for each sequence of a batch.
            class NGRAPH_API CTCLoss : public Op
            {
            public:
                static constexpr NodeTypeInfo type_info{"CTCLoss", 0};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                CTCLoss() = default;
                /// \brief Constructs a CTCLoss operation
                ///
                /// \param logits                       3-D tensor of logits
                /// \param logit_length                 1-D tensor of length for each object from
                ///                                     a batch
                /// \param labels                       2-D tensor of labels for which likelihood
                ///                                     is estimated using logits
                /// \param label_length                 1-D tensor of length for each label
                ///                                     sequence
                /// \param blank_index                  Scalar used to mark a blank index
                /// \param preprocess_collapse_repeated Flag for preprocessing labels before loss
                ///                                     calculation
                /// \param ctc_merge_repeated           Flag for merging repeated characters in a
                ///                                     potential alignment
                /// \param unique                       Flag to find unique elements in a target
                ///                                     before matching with alignment
                CTCLoss(const Output<Node>& logits,
                        const Output<Node>& logit_length,
                        const Output<Node>& labels,
                        const Output<Node>& label_length,
                        const Output<Node>& blank_index,
                        const bool preprocess_collapse_repeated = false,
                        const bool ctc_merge_repeated = true,
                        const bool unique = false);

                void validate_and_infer_types() override;
                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;

                // Attribute accessors; see the constructor parameters above for
                // the meaning of each flag.
                bool get_preprocess_collapse_repeated() const
                {
                    return preprocess_collapse_repeated_;
                }
                bool get_ctc_merge_repeated() const { return ctc_merge_repeated_; }
                bool get_unique() const { return unique_; }
            private:
                bool preprocess_collapse_repeated_;
                bool ctc_merge_repeated_;
                bool unique_;
            };
        }
        using v4::CTCLoss;
    }
}
|
||||
@@ -45,6 +45,7 @@
|
||||
#include "ngraph/op/cosh.hpp"
|
||||
#include "ngraph/op/crop_and_resize.hpp"
|
||||
#include "ngraph/op/ctc_greedy_decoder.hpp"
|
||||
#include "ngraph/op/ctc_loss.hpp"
|
||||
#include "ngraph/op/cum_sum.hpp"
|
||||
#include "ngraph/op/deformable_convolution.hpp"
|
||||
#include "ngraph/op/deformable_psroi_pooling.hpp"
|
||||
|
||||
@@ -155,3 +155,4 @@ NGRAPH_OP(TopK, ngraph::op::v3)
|
||||
// New operations added in opset4
|
||||
NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
|
||||
NGRAPH_OP(Mish, ngraph::op::v4)
|
||||
NGRAPH_OP(CTCLoss, ngraph::op::v4)
|
||||
|
||||
@@ -122,6 +122,7 @@ set(SRC
|
||||
type_prop/convert.cpp
|
||||
type_prop/convolution.cpp
|
||||
type_prop/crop_and_resize.cpp
|
||||
type_prop/ctc_loss.cpp
|
||||
type_prop/deformable_psroi_pooling.cpp
|
||||
type_prop/depth_to_space.cpp
|
||||
type_prop/dequantize.cpp
|
||||
|
||||
320
ngraph/test/type_prop/ctc_loss.cpp
Normal file
320
ngraph/test/type_prop/ctc_loss.cpp
Normal file
@@ -0,0 +1,320 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
// Fully static input shapes: output must be a static per-batch loss vector.
TEST(type_prop, ctc_loss)
{
    const auto data = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    const auto targets_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    const auto loss = make_shared<op::v4::CTCLoss>(data, data_len, targets, targets_len, blank);

    // Output inherits the logits' element type; shape is [batch].
    EXPECT_EQ(loss->get_element_type(), element::f32);
    EXPECT_TRUE(loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
}
|
||||
|
||||
// f64 logits: the output element type must follow the logits' type.
TEST(type_prop, ctc_loss_output_type)
{
    const auto data = make_shared<op::Parameter>(element::f64, Shape{10, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    const auto targets_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    const auto loss = make_shared<op::v4::CTCLoss>(data, data_len, targets, targets_len, blank);

    EXPECT_EQ(loss->get_element_type(), element::f64);
    EXPECT_TRUE(loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
}
|
||||
|
||||
// Non-default attribute values must not affect type/shape inference.
TEST(type_prop, ctc_loss_non_default_parameters)
{
    const auto data = make_shared<op::Parameter>(element::f64, Shape{10, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    const auto targets_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    // preprocess_collapse_repeated=true, ctc_merge_repeated=false, unique=false
    const auto loss = make_shared<op::v4::CTCLoss>(
        data, data_len, targets, targets_len, blank, true, false, false);

    EXPECT_EQ(loss->get_element_type(), element::f64);
    EXPECT_TRUE(loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
}
|
||||
|
||||
// Every input has a dynamic batch dimension: output batch stays dynamic.
TEST(type_prop, ctc_loss_dynamic_input)
{
    const auto dyn_batch = Dimension::dynamic();
    const auto data = make_shared<op::Parameter>(element::f32, PartialShape{dyn_batch, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, PartialShape{dyn_batch});
    const auto targets = make_shared<op::Parameter>(element::i32, PartialShape{dyn_batch, 120});
    const auto targets_len = make_shared<op::Parameter>(element::i32, PartialShape{dyn_batch});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    const auto loss = make_shared<op::v4::CTCLoss>(data, data_len, targets, targets_len, blank);

    EXPECT_EQ(loss->get_element_type(), element::f32);
    EXPECT_TRUE(
        loss->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic()}));
}
|
||||
|
||||
// Only logit_length pins the batch size; the output picks it up as static 10.
TEST(type_prop, ctc_loss_partly_dynamic_input)
{
    const auto dyn_batch = Dimension::dynamic();
    const auto data = make_shared<op::Parameter>(element::f32, PartialShape{dyn_batch, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, PartialShape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, PartialShape{dyn_batch, 120});
    const auto targets_len = make_shared<op::Parameter>(element::i32, PartialShape{dyn_batch});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    const auto loss = make_shared<op::v4::CTCLoss>(data, data_len, targets, targets_len, blank);

    EXPECT_EQ(loss->get_element_type(), element::f32);
    EXPECT_TRUE(loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
}
|
||||
|
||||
// Negative test: 4-D logits must be rejected with the rank-check message.
TEST(type_prop, ctc_loss_fail_inputs_dim)
{
    const auto data = make_shared<op::Parameter>(element::f32, Shape{10, 120, 40, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    const auto targets_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    try
    {
        const auto loss =
            make_shared<op::v4::CTCLoss>(data, data_len, targets, targets_len, blank);
        FAIL() << "Invalid inputs not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 3D tensor for logits."));
    }
    catch (...)
    {
        FAIL() << "Inputs shape check failed for unexpected reason";
    }
}
|
||||
|
||||
// Negative test: 2-D logit_length must be rejected with the rank-check message.
TEST(type_prop, ctc_loss_fail_logit_length_dim)
{
    const auto data = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10, 20});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    const auto targets_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    try
    {
        const auto loss =
            make_shared<op::v4::CTCLoss>(data, data_len, targets, targets_len, blank);
        FAIL() << "Invalid logit length not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 1D tensor for logit length."));
    }
    catch (...)
    {
        FAIL() << "Logit length shape check failed for unexpected reason";
    }
}
|
||||
|
||||
// Negative test: 1-D labels must be rejected with the rank-check message.
TEST(type_prop, ctc_loss_fail_labels_dim)
{
    const auto data = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    try
    {
        const auto loss =
            make_shared<op::v4::CTCLoss>(data, data_len, targets, targets_len, blank);
        FAIL() << "Invalid labels not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 2D tensor for labels."));
    }
    catch (...)
    {
        FAIL() << "Labels shape check failed for unexpected reason";
    }
}
|
||||
|
||||
// Negative test: 2-D label_length must be rejected with the rank-check message.
TEST(type_prop, ctc_loss_fail_label_length_dim)
{
    // create inputs
    auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
    auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    auto label_length = make_shared<op::Parameter>(element::i32, Shape{10, 40});
    auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});

    try
    {
        // create CTCLoss node
        auto ctc_loss =
            make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);

        // Should have thrown, so fail if it didn't.
        // Fixed copy-paste: this test targets label length, not labels.
        FAIL() << "Invalid label length not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 1D tensor for label length."));
    }
    catch (...)
    {
        FAIL() << "Label length shape check failed for unexpected reason";
    }
}
|
||||
|
||||
// Negative test: non-scalar blank_index must be rejected with the rank-check
// message.
TEST(type_prop, ctc_loss_fail_blank_index_dim)
{
    // create inputs
    auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
    auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
    auto blank_index = make_shared<op::Parameter>(element::i32, Shape{4});

    try
    {
        // create CTCLoss node
        auto ctc_loss =
            make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);

        // Should have thrown, so fail if it didn't.
        // Fixed copy-paste: this test targets blank index, not labels.
        FAIL() << "Invalid blank index not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a scalar for blank index."));
    }
    catch (...)
    {
        FAIL() << "Blank index shape check failed for unexpected reason";
    }
}
|
||||
|
||||
// Negative test: label_length batch (40) disagrees with the other inputs (10).
TEST(type_prop, ctc_loss_fail_batch_dim_mismatch)
{
    const auto data = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    const auto targets_len = make_shared<op::Parameter>(element::i32, Shape{40});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    try
    {
        const auto loss =
            make_shared<op::v4::CTCLoss>(data, data_len, targets, targets_len, blank);
        FAIL() << "Mismatch of batch dimension not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("The first dimension of label length must be equal to the first dimension "
                        "of the logits, the logit length and labels."));
    }
    catch (...)
    {
        FAIL() << "Batch dimension matching check failed for unexpected reason";
    }
}
|
||||
|
||||
// Negative test: labels' time dimension (130) disagrees with logits' (120).
TEST(type_prop, ctc_loss_fail_time_dim_mismatch)
{
    // create inputs
    auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
    auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 130});
    // Fixed: label_length batch must match (was Shape{40}), so the time-dim
    // mismatch is the ONLY defect and the test does not depend on check order.
    auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
    auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});

    try
    {
        // create CTCLoss node
        auto ctc_loss =
            make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);

        // Should have thrown, so fail if it didn't
        FAIL() << "Mismatch of time dimension not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("The second dimension of labels must be equal to the second dimension "
                        "of logits."));
    }
    catch (...)
    {
        FAIL() << "Time dimension matching check failed for unexpected reason";
    }
}
|
||||
Reference in New Issue
Block a user