From c56f024630a55d33806c43475e9527e86a133a59 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Fri, 31 Jul 2020 11:57:29 +0300 Subject: [PATCH] Fix for CTCLoss in NGraph (#1563) Blank index is optional input and must be handled appropriately Signed-off-by: Roman Kazantsev --- .../src/readers/ir_reader/ie_ir_parser.cpp | 2 +- ngraph/src/ngraph/op/ctc_loss.cpp | 100 ++++++++++++------ ngraph/src/ngraph/op/ctc_loss.hpp | 9 +- ngraph/test/type_prop/ctc_loss.cpp | 16 +++ 4 files changed, 95 insertions(+), 32 deletions(-) diff --git a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp index e3ddad0e1f5..6e5c856d59a 100644 --- a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp +++ b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp @@ -369,7 +369,7 @@ std::shared_ptr V10Parser::createNode(const std::vector bool { - for (size_t i = 1; i <= 3; i++) { + for (size_t i = 1; i <= 4; i++) { std::string opset_name = "opset" + std::to_string(i); if (version == opset_name) return true; diff --git a/ngraph/src/ngraph/op/ctc_loss.cpp b/ngraph/src/ngraph/op/ctc_loss.cpp index 05d6b90ec8f..d0036639304 100644 --- a/ngraph/src/ngraph/op/ctc_loss.cpp +++ b/ngraph/src/ngraph/op/ctc_loss.cpp @@ -19,16 +19,31 @@ using namespace std; using namespace ngraph; -constexpr NodeTypeInfo op::CTCLoss::type_info; +constexpr NodeTypeInfo op::v4::CTCLoss::type_info; -op::CTCLoss::CTCLoss(const Output& logits, - const Output& logit_length, - const Output& labels, - const Output& label_length, - const Output& blank_index, - const bool preprocess_collapse_repeated, - const bool ctc_merge_repeated, - const bool unique) +op::v4::CTCLoss::CTCLoss(const Output& logits, + const Output& logit_length, + const Output& labels, + const Output& label_length, + const bool preprocess_collapse_repeated, + const bool ctc_merge_repeated, + const bool unique) + : Op({logits, logit_length, labels, label_length}) + , preprocess_collapse_repeated_(preprocess_collapse_repeated) + , ctc_merge_repeated_(ctc_merge_repeated) + , unique_(unique) +{ + constructor_validate_and_infer_types(); +} + +op::v4::CTCLoss::CTCLoss(const Output& logits, + const Output& logit_length, + const Output& labels, + const Output& label_length, + const Output& blank_index, + const bool preprocess_collapse_repeated, + const bool ctc_merge_repeated, + const bool unique) : Op({logits, logit_length, labels, label_length, blank_index}) , preprocess_collapse_repeated_(preprocess_collapse_repeated) , ctc_merge_repeated_(ctc_merge_repeated) @@ -37,14 +52,13 @@ op::CTCLoss::CTCLoss(const Output& logits, constructor_validate_and_infer_types(); } -void op::CTCLoss::validate_and_infer_types() +void op::v4::CTCLoss::validate_and_infer_types() { // check types of input tensors const auto& logits_type = get_input_element_type(0); const auto& logit_length_type = get_input_element_type(1); const auto& labels_type = get_input_element_type(2); const auto& label_length_type = get_input_element_type(3); - const auto& blank_index_type = get_input_element_type(4); NODE_VALIDATION_CHECK(this, logits_type.is_real(), @@ -66,17 +80,21 @@ void op::CTCLoss::validate_and_infer_types() "The label length type is expected to be an integer type. Got: ", label_length_type); - NODE_VALIDATION_CHECK(this, - blank_index_type.is_integral_number(), - "The blank index type is expected to be an integer type. Got: ", - blank_index_type); + // check optional input type: blank index + if (get_input_size() == 5) + { + const auto& blank_index_type = get_input_element_type(4); + NODE_VALIDATION_CHECK(this, + blank_index_type.is_integral_number(), + "The blank index type is expected to be an integer type. Got: ", + blank_index_type); + } // check ranks of input tensors const auto& logits_pshape = get_input_partial_shape(0); const auto& logit_length_pshape = get_input_partial_shape(1); const auto& labels_pshape = get_input_partial_shape(2); const auto& label_length_pshape = get_input_partial_shape(3); - const auto& blank_index_pshape = get_input_partial_shape(4); NODE_VALIDATION_CHECK(this, logits_pshape.rank().compatible(3), @@ -98,10 +116,15 @@ void op::CTCLoss::validate_and_infer_types() "Expected a 1D tensor for label length. Got: ", label_length_pshape); - NODE_VALIDATION_CHECK(this, - blank_index_pshape.rank().compatible(0), - "Expected a scalar for blank index. Got: ", - blank_index_pshape); + // check optional input shape: blank index + if (get_input_size() == 5) + { + const auto& blank_index_pshape = get_input_partial_shape(4); + NODE_VALIDATION_CHECK(this, + blank_index_pshape.rank().compatible(0), + "Expected a scalar for blank index. Got: ", + blank_index_pshape); + } // check shapes of input tensors size_t batch_size = 1; @@ -204,7 +227,7 @@ void op::CTCLoss::validate_and_infer_types() } } -bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor) +bool op::v4::CTCLoss::visit_attributes(AttributeVisitor& visitor) { visitor.on_attribute("preprocess_collapse_repeated", preprocess_collapse_repeated_); visitor.on_attribute("ctc_merge_repeated", ctc_merge_repeated_); @@ -212,15 +235,32 @@ bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor) return true; } -shared_ptr op::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const +shared_ptr op::v4::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const { check_new_args_count(this, new_args); - return make_shared(new_args.at(0), - new_args.at(1), - new_args.at(2), - new_args.at(3), - new_args.at(4), - preprocess_collapse_repeated_, - ctc_merge_repeated_, - unique_); + if (new_args.size() == 4) + { + return make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + new_args.at(3), + preprocess_collapse_repeated_, + ctc_merge_repeated_, + unique_); + } + else if (new_args.size() == 5) + { + return make_shared(new_args.at(0), + new_args.at(1), + new_args.at(2), + new_args.at(3), + new_args.at(4), + preprocess_collapse_repeated_, + ctc_merge_repeated_, + unique_); + } + else + { + throw ngraph_error("Incorrect number of arguments"); + } } diff --git a/ngraph/src/ngraph/op/ctc_loss.hpp b/ngraph/src/ngraph/op/ctc_loss.hpp index 82518e78322..0fa2f5c0f67 100644 --- a/ngraph/src/ngraph/op/ctc_loss.hpp +++ b/ngraph/src/ngraph/op/ctc_loss.hpp @@ -46,6 +46,14 @@ namespace ngraph /// potential alignment /// \param unique Flag to find unique elements in a target /// before matching with alignment + CTCLoss(const Output& logits, + const Output& logit_length, + const Output& labels, + const Output& label_length, + const bool preprocess_collapse_repeated = false, + const bool ctc_merge_repeated = true, + const bool unique = false); + CTCLoss(const Output& logits, const Output& logit_length, const Output& labels, @@ -72,6 +80,5 @@ namespace ngraph bool unique_; }; } - using v4::CTCLoss; } } diff --git a/ngraph/test/type_prop/ctc_loss.cpp b/ngraph/test/type_prop/ctc_loss.cpp index 07333a5260e..2b2cc6f1847 100644 --- a/ngraph/test/type_prop/ctc_loss.cpp +++ b/ngraph/test/type_prop/ctc_loss.cpp @@ -39,6 +39,22 @@ TEST(type_prop, ctc_loss) EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10})); } +TEST(type_prop, ctc_loss_no_blank_index) +{ + // create inputs + auto logits = make_shared(element::f32, Shape{10, 120, 28}); + auto logit_length = make_shared(element::i32, Shape{10}); + auto labels = make_shared(element::i32, Shape{10, 120}); + auto label_length = make_shared(element::i32, Shape{10}); + + // create CTCLoss node + auto ctc_loss = make_shared(logits, logit_length, labels, label_length); + + // check type and shape infer + EXPECT_EQ(ctc_loss->get_element_type(), element::f32); + EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10})); +} + TEST(type_prop, ctc_loss_output_type) { // create inputs