From 6ccc025a43b9131c4c3a04703110cc517fdef7bf Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Wed, 22 Jul 2020 13:45:42 +0300 Subject: [PATCH] Extend nGraph for operation CTCLoss (#1236) * Extend nGraph for operation CTCLoss Signed-off-by: Roman Kazantsev * Fixes as per comments Co-authored-by: Nikolay Shchegolev --- ngraph/src/ngraph/CMakeLists.txt | 2 + ngraph/src/ngraph/op/ctc_loss.cpp | 226 +++++++++++++++++ ngraph/src/ngraph/op/ctc_loss.hpp | 77 ++++++ ngraph/src/ngraph/ops.hpp | 1 + ngraph/src/ngraph/opsets/opset4_tbl.hpp | 1 + ngraph/test/CMakeLists.txt | 1 + ngraph/test/type_prop/ctc_loss.cpp | 320 ++++++++++++++++++++++++ 7 files changed, 628 insertions(+) create mode 100644 ngraph/src/ngraph/op/ctc_loss.cpp create mode 100644 ngraph/src/ngraph/op/ctc_loss.hpp create mode 100644 ngraph/test/type_prop/ctc_loss.cpp diff --git a/ngraph/src/ngraph/CMakeLists.txt b/ngraph/src/ngraph/CMakeLists.txt index b9e57b6a993..e5717fd0c76 100644 --- a/ngraph/src/ngraph/CMakeLists.txt +++ b/ngraph/src/ngraph/CMakeLists.txt @@ -157,6 +157,8 @@ set (SRC op/cosh.hpp op/ctc_greedy_decoder.cpp op/ctc_greedy_decoder.hpp + op/ctc_loss.cpp + op/ctc_loss.hpp op/cum_sum.cpp op/cum_sum.hpp op/crop_and_resize.cpp diff --git a/ngraph/src/ngraph/op/ctc_loss.cpp b/ngraph/src/ngraph/op/ctc_loss.cpp new file mode 100644 index 00000000000..05d6b90ec8f --- /dev/null +++ b/ngraph/src/ngraph/op/ctc_loss.cpp @@ -0,0 +1,226 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "ngraph/op/ctc_loss.hpp"
+
+using namespace std;
+using namespace ngraph;
+
+constexpr NodeTypeInfo op::CTCLoss::type_info;
+
+op::CTCLoss::CTCLoss(const Output<Node>& logits,
+                     const Output<Node>& logit_length,
+                     const Output<Node>& labels,
+                     const Output<Node>& label_length,
+                     const Output<Node>& blank_index,
+                     const bool preprocess_collapse_repeated,
+                     const bool ctc_merge_repeated,
+                     const bool unique)
+    : Op({logits, logit_length, labels, label_length, blank_index})
+    , preprocess_collapse_repeated_(preprocess_collapse_repeated)
+    , ctc_merge_repeated_(ctc_merge_repeated)
+    , unique_(unique)
+{
+    constructor_validate_and_infer_types();
+}
+
+void op::CTCLoss::validate_and_infer_types()
+{
+    // check types of input tensors
+    const auto& logits_type = get_input_element_type(0);
+    const auto& logit_length_type = get_input_element_type(1);
+    const auto& labels_type = get_input_element_type(2);
+    const auto& label_length_type = get_input_element_type(3);
+    const auto& blank_index_type = get_input_element_type(4);
+
+    NODE_VALIDATION_CHECK(this,
+                          logits_type.is_real(),
+                          "The data type for logits is expected to be a floating point type. Got: ",
+                          logits_type);
+
+    NODE_VALIDATION_CHECK(this,
+                          logit_length_type.is_integral_number(),
+                          "The logit length type is expected to be an integer type. Got: ",
+                          logit_length_type);
+
+    NODE_VALIDATION_CHECK(this,
+                          labels_type.is_integral_number(),
+                          "The labels type is expected to be an integer type. 
Got: ", + labels_type); + + NODE_VALIDATION_CHECK(this, + label_length_type.is_integral_number(), + "The label length type is expected to be an integer type. Got: ", + label_length_type); + + NODE_VALIDATION_CHECK(this, + blank_index_type.is_integral_number(), + "The blank index type is expected to be an integer type. Got: ", + blank_index_type); + + // check ranks of input tensors + const auto& logits_pshape = get_input_partial_shape(0); + const auto& logit_length_pshape = get_input_partial_shape(1); + const auto& labels_pshape = get_input_partial_shape(2); + const auto& label_length_pshape = get_input_partial_shape(3); + const auto& blank_index_pshape = get_input_partial_shape(4); + + NODE_VALIDATION_CHECK(this, + logits_pshape.rank().compatible(3), + "Expected a 3D tensor for logits. Got: ", + logits_pshape); + + NODE_VALIDATION_CHECK(this, + logit_length_pshape.rank().compatible(1), + "Expected a 1D tensor for logit length. Got: ", + logit_length_pshape); + + NODE_VALIDATION_CHECK(this, + labels_pshape.rank().compatible(2), + "Expected a 2D tensor for labels. Got: ", + labels_pshape); + + NODE_VALIDATION_CHECK(this, + label_length_pshape.rank().compatible(1), + "Expected a 1D tensor for label length. Got: ", + label_length_pshape); + + NODE_VALIDATION_CHECK(this, + blank_index_pshape.rank().compatible(0), + "Expected a scalar for blank index. 
Got: ", + blank_index_pshape); + + // check shapes of input tensors + size_t batch_size = 1; + bool is_batch_size_set = false; + size_t time_steps = 1; + bool is_time_steps_set = false; + + if (logits_pshape.rank().is_static()) + { + if (logits_pshape[0].is_static()) + { + batch_size = logits_pshape[0].get_length(); + is_batch_size_set = true; + } + if (logits_pshape[1].is_static()) + { + time_steps = logits_pshape[1].get_length(); + is_time_steps_set = true; + } + } + + if (logit_length_pshape.is_static()) + { + if (is_batch_size_set) + { + NODE_VALIDATION_CHECK( + this, + logit_length_pshape[0].compatible(batch_size), + "The first dimension of logit length must be equal to the first dimension ", + "of the logits. Got: ", + logit_length_pshape[0], + " and: ", + batch_size); + } + else if (logit_length_pshape[0].is_static()) + { + batch_size = logit_length_pshape[0].get_length(); + is_batch_size_set = true; + } + } + + if (labels_pshape.is_static()) + { + if (is_batch_size_set) + { + NODE_VALIDATION_CHECK( + this, + labels_pshape[0].compatible(batch_size), + "The first dimension of labels must be equal to the first dimension ", + "of the logits and the logit length. Got: ", + labels_pshape[0], + " and: ", + batch_size); + } + else if (labels_pshape[0].is_static()) + { + batch_size = labels_pshape[0].get_length(); + is_batch_size_set = true; + } + + if (is_time_steps_set) + { + NODE_VALIDATION_CHECK( + this, + labels_pshape[1].compatible(time_steps), + "The second dimension of labels must be equal to the second dimension ", + "of logits. 
Got: ", + labels_pshape[1], + " and: ", + time_steps); + } + } + + if (label_length_pshape.is_static()) + { + if (!is_batch_size_set && label_length_pshape[0].is_static()) + { + batch_size = label_length_pshape[0].get_length(); + is_batch_size_set = true; + } + NODE_VALIDATION_CHECK( + this, + label_length_pshape[0].compatible(batch_size), + "The first dimension of label length must be equal to the first dimension ", + "of the logits, the logit length and labels. Got: ", + label_length_pshape[0], + " and: ", + batch_size); + } + + // set output shape + set_output_size(1); + if (is_batch_size_set) + { + set_output_type(0, logits_type, Shape{batch_size}); + } + else + { + set_output_type(0, logits_type, PartialShape{Dimension::dynamic()}); + } +} + +bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor) +{ + visitor.on_attribute("preprocess_collapse_repeated", preprocess_collapse_repeated_); + visitor.on_attribute("ctc_merge_repeated", ctc_merge_repeated_); + visitor.on_attribute("unique", unique_); + return true; +} + +shared_ptr op::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const +{ + check_new_args_count(this, new_args); + return make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + new_args.at(3), + new_args.at(4), + preprocess_collapse_repeated_, + ctc_merge_repeated_, + unique_); +} diff --git a/ngraph/src/ngraph/op/ctc_loss.hpp b/ngraph/src/ngraph/op/ctc_loss.hpp new file mode 100644 index 00000000000..82518e78322 --- /dev/null +++ b/ngraph/src/ngraph/op/ctc_loss.hpp @@ -0,0 +1,77 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include "ngraph/op/op.hpp"
+
+namespace ngraph
+{
+    namespace op
+    {
+        namespace v4
+        {
+            class NGRAPH_API CTCLoss : public Op
+            {
+            public:
+                static constexpr NodeTypeInfo type_info{"CTCLoss", 0};
+                const NodeTypeInfo& get_type_info() const override { return type_info; }
+                CTCLoss() = default;
+                /// \brief Constructs a CTCLoss operation
+                ///
+                /// \param logits                         3-D tensor of logits
+                /// \param logit_length                   1-D tensor of length for each object from
+                /// a batch
+                /// \param labels                         2-D tensor of labels for which likelihood
+                /// is estimated using logits
+                /// \param label_length                   1-D tensor of length for each label
+                /// sequence
+                /// \param blank_index                    Scalar used to mark a blank index
+                /// \param preprocess_collapse_repeated   Flag for preprocessing labels before loss
+                /// calculation
+                /// \param ctc_merge_repeated             Flag for merging repeated characters in a
+                /// potential alignment
+                /// \param unique                         Flag to find unique elements in a target
+                /// before matching with alignment
+                CTCLoss(const Output<Node>& logits,
+                        const Output<Node>& logit_length,
+                        const Output<Node>& labels,
+                        const Output<Node>& label_length,
+                        const Output<Node>& blank_index,
+                        const bool preprocess_collapse_repeated = false,
+                        const bool ctc_merge_repeated = true,
+                        const bool unique = false);
+
+                void validate_and_infer_types() override;
+                bool visit_attributes(AttributeVisitor& visitor) override;
+                virtual std::shared_ptr<Node>
+                    clone_with_new_inputs(const OutputVector& new_args) const override;
+
bool get_preprocess_collapse_repeated() const + { + return preprocess_collapse_repeated_; + } + bool get_ctc_merge_repeated() const { return ctc_merge_repeated_; } + bool get_unique() const { return unique_; } + private: + bool preprocess_collapse_repeated_; + bool ctc_merge_repeated_; + bool unique_; + }; + } + using v4::CTCLoss; + } +} diff --git a/ngraph/src/ngraph/ops.hpp b/ngraph/src/ngraph/ops.hpp index e47ad77295f..2b09fdea803 100644 --- a/ngraph/src/ngraph/ops.hpp +++ b/ngraph/src/ngraph/ops.hpp @@ -45,6 +45,7 @@ #include "ngraph/op/cosh.hpp" #include "ngraph/op/crop_and_resize.hpp" #include "ngraph/op/ctc_greedy_decoder.hpp" +#include "ngraph/op/ctc_loss.hpp" #include "ngraph/op/cum_sum.hpp" #include "ngraph/op/deformable_convolution.hpp" #include "ngraph/op/deformable_psroi_pooling.hpp" diff --git a/ngraph/src/ngraph/opsets/opset4_tbl.hpp b/ngraph/src/ngraph/opsets/opset4_tbl.hpp index c7d804aab6a..573759f862a 100644 --- a/ngraph/src/ngraph/opsets/opset4_tbl.hpp +++ b/ngraph/src/ngraph/opsets/opset4_tbl.hpp @@ -155,3 +155,4 @@ NGRAPH_OP(TopK, ngraph::op::v3) // New operations added in opset4 NGRAPH_OP(NonMaxSuppression, ngraph::op::v4) NGRAPH_OP(Mish, ngraph::op::v4) +NGRAPH_OP(CTCLoss, ngraph::op::v4) diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index c74c739a55e..59521922a34 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -122,6 +122,7 @@ set(SRC type_prop/convert.cpp type_prop/convolution.cpp type_prop/crop_and_resize.cpp + type_prop/ctc_loss.cpp type_prop/deformable_psroi_pooling.cpp type_prop/depth_to_space.cpp type_prop/dequantize.cpp diff --git a/ngraph/test/type_prop/ctc_loss.cpp b/ngraph/test/type_prop/ctc_loss.cpp new file mode 100644 index 00000000000..07333a5260e --- /dev/null +++ b/ngraph/test/type_prop/ctc_loss.cpp @@ -0,0 +1,320 @@ +//***************************************************************************** +// Copyright 2017-2020 Intel Corporation +// +// Licensed under the Apache 
License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** + +#include "gtest/gtest.h" +#include "ngraph/ngraph.hpp" +#include "util/type_prop.hpp" + +using namespace std; +using namespace ngraph; + +TEST(type_prop, ctc_loss) +{ + // create inputs + auto logits = make_shared(element::f32, Shape{10, 120, 28}); + auto logit_length = make_shared(element::i32, Shape{10}); + auto labels = make_shared(element::i32, Shape{10, 120}); + auto label_length = make_shared(element::i32, Shape{10}); + auto blank_index = make_shared(element::i32, Shape{}); + + // create CTCLoss node + auto ctc_loss = + make_shared(logits, logit_length, labels, label_length, blank_index); + + // check type and shape infer + EXPECT_EQ(ctc_loss->get_element_type(), element::f32); + EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10})); +} + +TEST(type_prop, ctc_loss_output_type) +{ + // create inputs + auto logits = make_shared(element::f64, Shape{10, 120, 28}); + auto logit_length = make_shared(element::i32, Shape{10}); + auto labels = make_shared(element::i32, Shape{10, 120}); + auto label_length = make_shared(element::i32, Shape{10}); + auto blank_index = make_shared(element::i32, Shape{}); + + // create CTCLoss node + auto ctc_loss = + make_shared(logits, logit_length, labels, label_length, blank_index); + + // check type and shape infer + EXPECT_EQ(ctc_loss->get_element_type(), element::f64); + 
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10})); +} + +TEST(type_prop, ctc_loss_non_default_parameters) +{ + // create inputs + auto logits = make_shared(element::f64, Shape{10, 120, 28}); + auto logit_length = make_shared(element::i32, Shape{10}); + auto labels = make_shared(element::i32, Shape{10, 120}); + auto label_length = make_shared(element::i32, Shape{10}); + auto blank_index = make_shared(element::i32, Shape{}); + + // create CTCLoss node + auto ctc_loss = make_shared( + logits, logit_length, labels, label_length, blank_index, true, false, false); + + // check type and shape infer + EXPECT_EQ(ctc_loss->get_element_type(), element::f64); + EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10})); +} + +TEST(type_prop, ctc_loss_dynamic_input) +{ + // create inputs + auto logits = + make_shared(element::f32, PartialShape{Dimension::dynamic(), 120, 28}); + auto logit_length = + make_shared(element::i32, PartialShape{Dimension::dynamic()}); + auto labels = make_shared(element::i32, PartialShape{Dimension::dynamic(), 120}); + auto label_length = + make_shared(element::i32, PartialShape{Dimension::dynamic()}); + auto blank_index = make_shared(element::i32, Shape{}); + + // create CTCLoss node + auto ctc_loss = + make_shared(logits, logit_length, labels, label_length, blank_index); + + // check type and shape infer + EXPECT_EQ(ctc_loss->get_element_type(), element::f32); + EXPECT_TRUE( + ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic()})); +} + +TEST(type_prop, ctc_loss_partly_dynamic_input) +{ + // create inputs + auto logits = + make_shared(element::f32, PartialShape{Dimension::dynamic(), 120, 28}); + auto logit_length = make_shared(element::i32, PartialShape{10}); + auto labels = make_shared(element::i32, PartialShape{Dimension::dynamic(), 120}); + auto label_length = + make_shared(element::i32, PartialShape{Dimension::dynamic()}); + auto blank_index = 
make_shared(element::i32, Shape{}); + + // create CTCLoss node + auto ctc_loss = + make_shared(logits, logit_length, labels, label_length, blank_index); + + // check type and shape infer + EXPECT_EQ(ctc_loss->get_element_type(), element::f32); + EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10})); +} + +TEST(type_prop, ctc_loss_fail_inputs_dim) +{ + // create inputs + auto logits = make_shared(element::f32, Shape{10, 120, 40, 28}); + auto logit_length = make_shared(element::i32, Shape{10}); + auto labels = make_shared(element::i32, Shape{10, 120}); + auto label_length = make_shared(element::i32, Shape{10}); + auto blank_index = make_shared(element::i32, Shape{}); + + try + { + // create CTCLoss node + auto ctc_loss = + make_shared(logits, logit_length, labels, label_length, blank_index); + + // Should have thrown, so fail if it didn't + FAIL() << "Invalid inputs not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 3D tensor for logits.")); + } + catch (...) + { + FAIL() << "Inputs shape check failed for unexpected reason"; + } +} + +TEST(type_prop, ctc_loss_fail_logit_length_dim) +{ + // create inputs + auto logits = make_shared(element::f32, Shape{10, 120, 28}); + auto logit_length = make_shared(element::i32, Shape{10, 20}); + auto labels = make_shared(element::i32, Shape{10, 120}); + auto label_length = make_shared(element::i32, Shape{10}); + auto blank_index = make_shared(element::i32, Shape{}); + + try + { + // create CTCLoss node + auto ctc_loss = + make_shared(logits, logit_length, labels, label_length, blank_index); + + // Should have thrown, so fail if it didn't + FAIL() << "Invalid logit length not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 1D tensor for logit length.")); + } + catch (...) 
+ { + FAIL() << "Logit length shape check failed for unexpected reason"; + } +} + +TEST(type_prop, ctc_loss_fail_labels_dim) +{ + // create inputs + auto logits = make_shared(element::f32, Shape{10, 120, 28}); + auto logit_length = make_shared(element::i32, Shape{10}); + auto labels = make_shared(element::i32, Shape{10}); + auto label_length = make_shared(element::i32, Shape{10}); + auto blank_index = make_shared(element::i32, Shape{}); + + try + { + // create CTCLoss node + auto ctc_loss = + make_shared(logits, logit_length, labels, label_length, blank_index); + + // Should have thrown, so fail if it didn't + FAIL() << "Invalid labels not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 2D tensor for labels.")); + } + catch (...) + { + FAIL() << "Labels shape check failed for unexpected reason"; + } +} + +TEST(type_prop, ctc_loss_fail_label_length_dim) +{ + // create inputs + auto logits = make_shared(element::f32, Shape{10, 120, 28}); + auto logit_length = make_shared(element::i32, Shape{10}); + auto labels = make_shared(element::i32, Shape{10, 120}); + auto label_length = make_shared(element::i32, Shape{10, 40}); + auto blank_index = make_shared(element::i32, Shape{}); + + try + { + // create CTCLoss node + auto ctc_loss = + make_shared(logits, logit_length, labels, label_length, blank_index); + + // Should have thrown, so fail if it didn't + FAIL() << "Invalid labels not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a 1D tensor for label length.")); + } + catch (...) 
+ { + FAIL() << "Label length shape check failed for unexpected reason"; + } +} + +TEST(type_prop, ctc_loss_fail_blank_index_dim) +{ + // create inputs + auto logits = make_shared(element::f32, Shape{10, 120, 28}); + auto logit_length = make_shared(element::i32, Shape{10}); + auto labels = make_shared(element::i32, Shape{10, 120}); + auto label_length = make_shared(element::i32, Shape{10}); + auto blank_index = make_shared(element::i32, Shape{4}); + + try + { + // create CTCLoss node + auto ctc_loss = + make_shared(logits, logit_length, labels, label_length, blank_index); + + // Should have thrown, so fail if it didn't + FAIL() << "Invalid labels not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), std::string("Expected a scalar for blank index.")); + } + catch (...) + { + FAIL() << "Blank index shape check failed for unexpected reason"; + } +} + +TEST(type_prop, ctc_loss_fail_batch_dim_mismatch) +{ + // create inputs + auto logits = make_shared(element::f32, Shape{10, 120, 28}); + auto logit_length = make_shared(element::i32, Shape{10}); + auto labels = make_shared(element::i32, Shape{10, 120}); + auto label_length = make_shared(element::i32, Shape{40}); + auto blank_index = make_shared(element::i32, Shape{}); + + try + { + // create CTCLoss node + auto ctc_loss = + make_shared(logits, logit_length, labels, label_length, blank_index); + + // Should have thrown, so fail if it didn't + FAIL() << "Mismatch of batch dimension not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("The first dimension of label length must be equal to the first dimension " + "of the logits, the logit length and labels.")); + } + catch (...) 
+ { + FAIL() << "Batch dimension matching check failed for unexpected reason"; + } +} + +TEST(type_prop, ctc_loss_fail_time_dim_mismatch) +{ + // create inputs + auto logits = make_shared(element::f32, Shape{10, 120, 28}); + auto logit_length = make_shared(element::i32, Shape{10}); + auto labels = make_shared(element::i32, Shape{10, 130}); + auto label_length = make_shared(element::i32, Shape{40}); + auto blank_index = make_shared(element::i32, Shape{}); + + try + { + // create CTCLoss node + auto ctc_loss = + make_shared(logits, logit_length, labels, label_length, blank_index); + + // Should have thrown, so fail if it didn't + FAIL() << "Mismatch of time dimension not detected"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("The second dimension of labels must be equal to the second dimension " + "of logits.")); + } + catch (...) + { + FAIL() << "Time dimension matching check failed for unexpected reason"; + } +}