Extend nGraph for operation CTCLoss (#1236)

* Extend nGraph for operation CTCLoss

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fixes as per comments

Co-authored-by: Nikolay Shchegolev <nikolay.shchegolev@intel.com>
This commit is contained in:
Roman Kazantsev
2020-07-22 13:45:42 +03:00
committed by GitHub
parent 141b24cf44
commit 6ccc025a43
7 changed files with 628 additions and 0 deletions

View File

@@ -157,6 +157,8 @@ set (SRC
op/cosh.hpp
op/ctc_greedy_decoder.cpp
op/ctc_greedy_decoder.hpp
op/ctc_loss.cpp
op/ctc_loss.hpp
op/cum_sum.cpp
op/cum_sum.hpp
op/crop_and_resize.cpp

View File

@@ -0,0 +1,226 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/ctc_loss.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::CTCLoss::type_info;
// CTCLoss constructor.
//
// Registers the five data inputs with the base Op and stores the three
// boolean attributes, then runs constructor-time validation/shape inference.
op::CTCLoss::CTCLoss(const Output<Node>& logits,
                     const Output<Node>& logit_length,
                     const Output<Node>& labels,
                     const Output<Node>& label_length,
                     const Output<Node>& blank_index,
                     const bool preprocess_collapse_repeated,
                     const bool ctc_merge_repeated,
                     const bool unique)
    : Op({logits, logit_length, labels, label_length, blank_index})
    , preprocess_collapse_repeated_(preprocess_collapse_repeated)
    , ctc_merge_repeated_(ctc_merge_repeated)
    , unique_(unique)
{
    // Validates element types/shapes and sets the output (see
    // validate_and_infer_types).
    constructor_validate_and_infer_types();
}
// Validates the element types and (partial) shapes of the five inputs and
// sets the single output: a 1D tensor of shape [batch_size] with the logits'
// floating-point element type. batch_size (and time_steps) are deduced
// lazily from the first input that provides a static dimension; later inputs
// are cross-checked against the deduced value.
void op::CTCLoss::validate_and_infer_types()
{
// check types of input tensors
const auto& logits_type = get_input_element_type(0);
const auto& logit_length_type = get_input_element_type(1);
const auto& labels_type = get_input_element_type(2);
const auto& label_length_type = get_input_element_type(3);
const auto& blank_index_type = get_input_element_type(4);
// logits must be floating point; all length/label/index inputs integral
NODE_VALIDATION_CHECK(this,
logits_type.is_real(),
"The data type for logits is expected to be a floating point type. Got: ",
logits_type);
NODE_VALIDATION_CHECK(this,
logit_length_type.is_integral_number(),
"The logit length type is expected to be an integer type. Got: ",
logit_length_type);
NODE_VALIDATION_CHECK(this,
labels_type.is_integral_number(),
"The labels type is expected to be an integer type. Got: ",
labels_type);
NODE_VALIDATION_CHECK(this,
label_length_type.is_integral_number(),
"The label length type is expected to be an integer type. Got: ",
label_length_type);
NODE_VALIDATION_CHECK(this,
blank_index_type.is_integral_number(),
"The blank index type is expected to be an integer type. Got: ",
blank_index_type);
// check ranks of input tensors
// (rank().compatible(n) also accepts a fully dynamic rank, so the indexing
// below is guarded by explicit is_static() checks)
const auto& logits_pshape = get_input_partial_shape(0);
const auto& logit_length_pshape = get_input_partial_shape(1);
const auto& labels_pshape = get_input_partial_shape(2);
const auto& label_length_pshape = get_input_partial_shape(3);
const auto& blank_index_pshape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
logits_pshape.rank().compatible(3),
"Expected a 3D tensor for logits. Got: ",
logits_pshape);
NODE_VALIDATION_CHECK(this,
logit_length_pshape.rank().compatible(1),
"Expected a 1D tensor for logit length. Got: ",
logit_length_pshape);
NODE_VALIDATION_CHECK(this,
labels_pshape.rank().compatible(2),
"Expected a 2D tensor for labels. Got: ",
labels_pshape);
NODE_VALIDATION_CHECK(this,
label_length_pshape.rank().compatible(1),
"Expected a 1D tensor for label length. Got: ",
label_length_pshape);
NODE_VALIDATION_CHECK(this,
blank_index_pshape.rank().compatible(0),
"Expected a scalar for blank index. Got: ",
blank_index_pshape);
// check shapes of input tensors
// The *_set flags record whether a concrete value has been deduced yet;
// the initial values of 1 are placeholders and are never used unless the
// corresponding flag is true (except in the label_length check, where a
// dynamic dimension is compatible with any value anyway).
size_t batch_size = 1;
bool is_batch_size_set = false;
size_t time_steps = 1;
bool is_time_steps_set = false;
// logits is [batch, time, classes]: primary source for batch and time dims
if (logits_pshape.rank().is_static())
{
if (logits_pshape[0].is_static())
{
batch_size = logits_pshape[0].get_length();
is_batch_size_set = true;
}
if (logits_pshape[1].is_static())
{
time_steps = logits_pshape[1].get_length();
is_time_steps_set = true;
}
}
// NOTE(review): is_static() requires EVERY dimension static; a rank-static
// shape with a dynamic dimension skips this whole cross-check block —
// presumably intentional leniency for dynamic shapes, confirm.
if (logit_length_pshape.is_static())
{
if (is_batch_size_set)
{
NODE_VALIDATION_CHECK(
this,
logit_length_pshape[0].compatible(batch_size),
"The first dimension of logit length must be equal to the first dimension ",
"of the logits. Got: ",
logit_length_pshape[0],
" and: ",
batch_size);
}
else if (logit_length_pshape[0].is_static())
{
// logits did not provide the batch size; take it from logit_length
batch_size = logit_length_pshape[0].get_length();
is_batch_size_set = true;
}
}
// labels is [batch, time]: cross-check (or deduce) both dimensions
if (labels_pshape.is_static())
{
if (is_batch_size_set)
{
NODE_VALIDATION_CHECK(
this,
labels_pshape[0].compatible(batch_size),
"The first dimension of labels must be equal to the first dimension ",
"of the logits and the logit length. Got: ",
labels_pshape[0],
" and: ",
batch_size);
}
else if (labels_pshape[0].is_static())
{
batch_size = labels_pshape[0].get_length();
is_batch_size_set = true;
}
if (is_time_steps_set)
{
NODE_VALIDATION_CHECK(
this,
labels_pshape[1].compatible(time_steps),
"The second dimension of labels must be equal to the second dimension ",
"of logits. Got: ",
labels_pshape[1],
" and: ",
time_steps);
}
}
if (label_length_pshape.is_static())
{
if (!is_batch_size_set && label_length_pshape[0].is_static())
{
batch_size = label_length_pshape[0].get_length();
is_batch_size_set = true;
}
// runs unconditionally: if batch_size is still the placeholder 1 the
// dimension here must be dynamic (else the branch above set the flag),
// and a dynamic dimension is compatible with any value
NODE_VALIDATION_CHECK(
this,
label_length_pshape[0].compatible(batch_size),
"The first dimension of label length must be equal to the first dimension ",
"of the logits, the logit length and labels. Got: ",
label_length_pshape[0],
" and: ",
batch_size);
}
// set output shape
// output is one loss value per batch element; dynamic when no input
// provided a static batch dimension
set_output_size(1);
if (is_batch_size_set)
{
set_output_type(0, logits_type, Shape{batch_size});
}
else
{
set_output_type(0, logits_type, PartialShape{Dimension::dynamic()});
}
}
// Exposes the three boolean attributes to an AttributeVisitor
// (used for serialization/comparison of the node).
bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor)
{
    // Visit every attribute under its canonical name.
    visitor.on_attribute("preprocess_collapse_repeated", preprocess_collapse_repeated_);
    visitor.on_attribute("ctc_merge_repeated", ctc_merge_repeated_);
    visitor.on_attribute("unique", unique_);

    return true; // all attributes visited
}
// Creates a copy of this node wired to the replacement inputs, preserving
// all three attribute values of the original node.
shared_ptr<Node> op::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const
{
    // Base-class helper: throws if the argument count does not match.
    check_new_args_count(this, new_args);

    const auto& logits = new_args.at(0);
    const auto& logit_length = new_args.at(1);
    const auto& labels = new_args.at(2);
    const auto& label_length = new_args.at(3);
    const auto& blank_index = new_args.at(4);
    return make_shared<CTCLoss>(logits,
                                logit_length,
                                labels,
                                label_length,
                                blank_index,
                                preprocess_collapse_repeated_,
                                ctc_merge_repeated_,
                                unique_);
}

View File

@@ -0,0 +1,77 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v4
{
/// \brief CTCLoss operation (opset v4): computes the Connectionist Temporal
///        Classification loss, producing one loss value per batch element.
class NGRAPH_API CTCLoss : public Op
{
public:
static constexpr NodeTypeInfo type_info{"CTCLoss", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
CTCLoss() = default;
/// \brief Constructs a CTCLoss operation
///
/// \param logits 3-D tensor of logits
/// \param logit_length 1-D tensor of length for each object from
/// a batch
/// \param labels 2-D tensor of labels for which likelihood
/// is estimated using logits
/// \param label_length 1-D tensor of length for each label
/// sequence
/// \param blank_index Scalar used to mark a blank index
/// \param preprocess_collapse_repeated Flag for preprocessing labels before loss
/// calculation
/// \param ctc_merge_repeated Flag for merging repeated characters in a
/// potential alignment
/// \param unique Flag to find unique elements in a target
/// before matching with alignment
CTCLoss(const Output<Node>& logits,
const Output<Node>& logit_length,
const Output<Node>& labels,
const Output<Node>& label_length,
const Output<Node>& blank_index,
const bool preprocess_collapse_repeated = false,
const bool ctc_merge_repeated = true,
const bool unique = false);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
// Accessors for the three boolean attributes (see constructor docs).
bool get_preprocess_collapse_repeated() const
{
return preprocess_collapse_repeated_;
}
bool get_ctc_merge_repeated() const { return ctc_merge_repeated_; }
bool get_unique() const { return unique_; }
private:
// Attribute storage; set once in the constructor, read via the getters
// above and serialized through visit_attributes.
bool preprocess_collapse_repeated_;
bool ctc_merge_repeated_;
bool unique_;
};
}
using v4::CTCLoss;
}
}

View File

@@ -45,6 +45,7 @@
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/crop_and_resize.hpp"
#include "ngraph/op/ctc_greedy_decoder.hpp"
#include "ngraph/op/ctc_loss.hpp"
#include "ngraph/op/cum_sum.hpp"
#include "ngraph/op/deformable_convolution.hpp"
#include "ngraph/op/deformable_psroi_pooling.hpp"

View File

@@ -155,3 +155,4 @@ NGRAPH_OP(TopK, ngraph::op::v3)
// New operations added in opset4
NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
NGRAPH_OP(Mish, ngraph::op::v4)
NGRAPH_OP(CTCLoss, ngraph::op::v4)

View File

@@ -122,6 +122,7 @@ set(SRC
type_prop/convert.cpp
type_prop/convolution.cpp
type_prop/crop_and_resize.cpp
type_prop/ctc_loss.cpp
type_prop/deformable_psroi_pooling.cpp
type_prop/depth_to_space.cpp
type_prop/dequantize.cpp

View File

@@ -0,0 +1,320 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
// Basic type-prop test: fully static input shapes, default attributes.
TEST(type_prop, ctc_loss)
{
    // build inputs with fully static shapes
    const auto data = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    const auto target_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    // construct the operation under test
    const auto node = make_shared<op::v4::CTCLoss>(data, data_len, targets, target_len, blank);

    // output inherits the logits element type and has shape [batch]
    EXPECT_EQ(node->get_element_type(), element::f32);
    EXPECT_TRUE(node->get_output_partial_shape(0).same_scheme(PartialShape{10}));
}
// Type-prop test: f64 logits must yield an f64 loss.
TEST(type_prop, ctc_loss_output_type)
{
    // inputs identical to the basic test except the logits precision
    const auto data = make_shared<op::Parameter>(element::f64, Shape{10, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    const auto target_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    const auto node = make_shared<op::v4::CTCLoss>(data, data_len, targets, target_len, blank);

    // output element type follows the logits type
    EXPECT_EQ(node->get_element_type(), element::f64);
    EXPECT_TRUE(node->get_output_partial_shape(0).same_scheme(PartialShape{10}));
}
// Type-prop test: non-default attribute values do not affect inference.
TEST(type_prop, ctc_loss_non_default_parameters)
{
    const auto data = make_shared<op::Parameter>(element::f64, Shape{10, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    const auto target_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    // preprocess_collapse_repeated=true, ctc_merge_repeated=false, unique=false
    const auto node =
        make_shared<op::v4::CTCLoss>(data, data_len, targets, target_len, blank, true, false, false);

    // shape/type inference is independent of the boolean attributes
    EXPECT_EQ(node->get_element_type(), element::f64);
    EXPECT_TRUE(node->get_output_partial_shape(0).same_scheme(PartialShape{10}));
}
// Type-prop test: when no input carries a static batch dimension the output
// batch dimension must stay dynamic.
TEST(type_prop, ctc_loss_dynamic_input)
{
    const auto dyn = Dimension::dynamic();
    const auto data = make_shared<op::Parameter>(element::f32, PartialShape{dyn, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, PartialShape{dyn});
    const auto targets = make_shared<op::Parameter>(element::i32, PartialShape{dyn, 120});
    const auto target_len = make_shared<op::Parameter>(element::i32, PartialShape{dyn});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    const auto node = make_shared<op::v4::CTCLoss>(data, data_len, targets, target_len, blank);

    // batch size cannot be deduced, so the output shape is {?}
    EXPECT_EQ(node->get_element_type(), element::f32);
    EXPECT_TRUE(node->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic()}));
}
// Type-prop test: the static batch dimension of logit_length must propagate
// to the output even though the other inputs are dynamic.
TEST(type_prop, ctc_loss_partly_dynamic_input)
{
    const auto dyn = Dimension::dynamic();
    const auto data = make_shared<op::Parameter>(element::f32, PartialShape{dyn, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, PartialShape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, PartialShape{dyn, 120});
    const auto target_len = make_shared<op::Parameter>(element::i32, PartialShape{dyn});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    const auto node = make_shared<op::v4::CTCLoss>(data, data_len, targets, target_len, blank);

    // batch size (10) is deduced from logit_length
    EXPECT_EQ(node->get_element_type(), element::f32);
    EXPECT_TRUE(node->get_output_partial_shape(0).same_scheme(PartialShape{10}));
}
// Negative type-prop test: 4D logits must be rejected (3D required).
TEST(type_prop, ctc_loss_fail_inputs_dim)
{
    // logits has rank 4 instead of the required 3
    const auto data = make_shared<op::Parameter>(element::f32, Shape{10, 120, 40, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    const auto target_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    try
    {
        const auto node =
            make_shared<op::v4::CTCLoss>(data, data_len, targets, target_len, blank);
        FAIL() << "Invalid inputs not detected"; // construction must throw
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 3D tensor for logits."));
    }
    catch (...)
    {
        FAIL() << "Inputs shape check failed for unexpected reason";
    }
}
// Negative type-prop test: 2D logit length must be rejected (1D required).
TEST(type_prop, ctc_loss_fail_logit_length_dim)
{
    // logit_length has rank 2 instead of the required 1
    const auto data = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10, 20});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    const auto target_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    try
    {
        const auto node =
            make_shared<op::v4::CTCLoss>(data, data_len, targets, target_len, blank);
        FAIL() << "Invalid logit length not detected"; // construction must throw
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 1D tensor for logit length."));
    }
    catch (...)
    {
        FAIL() << "Logit length shape check failed for unexpected reason";
    }
}
// Negative type-prop test: 1D labels must be rejected (2D required).
TEST(type_prop, ctc_loss_fail_labels_dim)
{
    // labels has rank 1 instead of the required 2
    const auto data = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto target_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    try
    {
        const auto node =
            make_shared<op::v4::CTCLoss>(data, data_len, targets, target_len, blank);
        FAIL() << "Invalid labels not detected"; // construction must throw
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 2D tensor for labels."));
    }
    catch (...)
    {
        FAIL() << "Labels shape check failed for unexpected reason";
    }
}
// Negative type-prop test: 2D label length must be rejected (1D required).
TEST(type_prop, ctc_loss_fail_label_length_dim)
{
    // create inputs; label_length is deliberately 2D
    auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
    auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    auto label_length = make_shared<op::Parameter>(element::i32, Shape{10, 40});
    auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
    try
    {
        // create CTCLoss node
        auto ctc_loss =
            make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
        // Should have thrown, so fail if it didn't.
        // (message fixed: this test checks label length, not labels)
        FAIL() << "Invalid label length not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 1D tensor for label length."));
    }
    catch (...)
    {
        FAIL() << "Label length shape check failed for unexpected reason";
    }
}
// Negative type-prop test: non-scalar blank index must be rejected.
TEST(type_prop, ctc_loss_fail_blank_index_dim)
{
    // create inputs; blank_index is deliberately 1D (must be a scalar)
    auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
    auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
    auto blank_index = make_shared<op::Parameter>(element::i32, Shape{4});
    try
    {
        // create CTCLoss node
        auto ctc_loss =
            make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
        // Should have thrown, so fail if it didn't.
        // (message fixed: this test checks the blank index, not labels)
        FAIL() << "Invalid blank index not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a scalar for blank index."));
    }
    catch (...)
    {
        FAIL() << "Blank index shape check failed for unexpected reason";
    }
}
// Negative type-prop test: label_length batch dimension (40) disagrees with
// the batch dimension (10) deduced from the other inputs.
TEST(type_prop, ctc_loss_fail_batch_dim_mismatch)
{
    const auto data = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    const auto data_len = make_shared<op::Parameter>(element::i32, Shape{10});
    const auto targets = make_shared<op::Parameter>(element::i32, Shape{10, 120});
    const auto target_len = make_shared<op::Parameter>(element::i32, Shape{40});
    const auto blank = make_shared<op::Parameter>(element::i32, Shape{});

    try
    {
        const auto node =
            make_shared<op::v4::CTCLoss>(data, data_len, targets, target_len, blank);
        FAIL() << "Mismatch of batch dimension not detected"; // construction must throw
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("The first dimension of label length must be equal to the first dimension "
                        "of the logits, the logit length and labels."));
    }
    catch (...)
    {
        FAIL() << "Batch dimension matching check failed for unexpected reason";
    }
}
// Negative type-prop test: labels' time dimension (130) disagrees with the
// logits' time dimension (120).
TEST(type_prop, ctc_loss_fail_time_dim_mismatch)
{
    // create inputs: label_length now uses the correct batch size (10) so
    // that ONLY the time-dimension mismatch can trigger the failure.
    // (Previously it was Shape{40}, which also mismatched the batch
    // dimension and made the test depend on validation-check ordering.)
    auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
    auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
    auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 130});
    auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
    auto blank_index = make_shared<op::Parameter>(element::i32, Shape{});
    try
    {
        // create CTCLoss node
        auto ctc_loss =
            make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length, blank_index);
        // Should have thrown, so fail if it didn't
        FAIL() << "Mismatch of time dimension not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(
            error.what(),
            std::string("The second dimension of labels must be equal to the second dimension "
                        "of logits."));
    }
    catch (...)
    {
        FAIL() << "Time dimension matching check failed for unexpected reason";
    }
}