Fix for CTCLoss in NGraph (#1563)

The blank index is an optional input and must be handled accordingly

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2020-07-31 11:57:29 +03:00 committed by GitHub
parent 43652498c7
commit c56f024630
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 95 additions and 32 deletions

View File

@ -369,7 +369,7 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
// Check that operation in default opsets
auto isDefaultOpSet = [](const std::string& version) -> bool {
for (size_t i = 1; i <= 3; i++) {
for (size_t i = 1; i <= 4; i++) {
std::string opset_name = "opset" + std::to_string(i);
if (version == opset_name)
return true;

View File

@ -19,16 +19,31 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::CTCLoss::type_info;
constexpr NodeTypeInfo op::v4::CTCLoss::type_info;
op::CTCLoss::CTCLoss(const Output<Node>& logits,
const Output<Node>& logit_length,
const Output<Node>& labels,
const Output<Node>& label_length,
const Output<Node>& blank_index,
const bool preprocess_collapse_repeated,
const bool ctc_merge_repeated,
const bool unique)
// Constructs a CTCLoss node from four inputs only (logits, logit_length,
// labels, label_length); this overload attaches no blank index input, so the
// node ends up with four inputs instead of five.
// The three boolean flags are stored in the corresponding members, and
// constructor_validate_and_infer_types() is invoked once the inputs are set.
op::v4::CTCLoss::CTCLoss(const Output<Node>& logits,
const Output<Node>& logit_length,
const Output<Node>& labels,
const Output<Node>& label_length,
const bool preprocess_collapse_repeated,
const bool ctc_merge_repeated,
const bool unique)
: Op({logits, logit_length, labels, label_length})
, preprocess_collapse_repeated_(preprocess_collapse_repeated)
, ctc_merge_repeated_(ctc_merge_repeated)
, unique_(unique)
{
constructor_validate_and_infer_types();
}
op::v4::CTCLoss::CTCLoss(const Output<Node>& logits,
const Output<Node>& logit_length,
const Output<Node>& labels,
const Output<Node>& label_length,
const Output<Node>& blank_index,
const bool preprocess_collapse_repeated,
const bool ctc_merge_repeated,
const bool unique)
: Op({logits, logit_length, labels, label_length, blank_index})
, preprocess_collapse_repeated_(preprocess_collapse_repeated)
, ctc_merge_repeated_(ctc_merge_repeated)
@ -37,14 +52,13 @@ op::CTCLoss::CTCLoss(const Output<Node>& logits,
constructor_validate_and_infer_types();
}
void op::CTCLoss::validate_and_infer_types()
void op::v4::CTCLoss::validate_and_infer_types()
{
// check types of input tensors
const auto& logits_type = get_input_element_type(0);
const auto& logit_length_type = get_input_element_type(1);
const auto& labels_type = get_input_element_type(2);
const auto& label_length_type = get_input_element_type(3);
const auto& blank_index_type = get_input_element_type(4);
NODE_VALIDATION_CHECK(this,
logits_type.is_real(),
@ -66,17 +80,21 @@ void op::CTCLoss::validate_and_infer_types()
"The label length type is expected to be an integer type. Got: ",
label_length_type);
NODE_VALIDATION_CHECK(this,
blank_index_type.is_integral_number(),
"The blank index type is expected to be an integer type. Got: ",
blank_index_type);
// check optional input type: blank index
if (get_input_size() == 5)
{
const auto& blank_index_type = get_input_element_type(4);
NODE_VALIDATION_CHECK(this,
blank_index_type.is_integral_number(),
"The blank index type is expected to be an integer type. Got: ",
blank_index_type);
}
// check ranks of input tensors
const auto& logits_pshape = get_input_partial_shape(0);
const auto& logit_length_pshape = get_input_partial_shape(1);
const auto& labels_pshape = get_input_partial_shape(2);
const auto& label_length_pshape = get_input_partial_shape(3);
const auto& blank_index_pshape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
logits_pshape.rank().compatible(3),
@ -98,10 +116,15 @@ void op::CTCLoss::validate_and_infer_types()
"Expected a 1D tensor for label length. Got: ",
label_length_pshape);
NODE_VALIDATION_CHECK(this,
blank_index_pshape.rank().compatible(0),
"Expected a scalar for blank index. Got: ",
blank_index_pshape);
// check optional input shape: blank index
if (get_input_size() == 5)
{
const auto& blank_index_pshape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
blank_index_pshape.rank().compatible(0),
"Expected a scalar for blank index. Got: ",
blank_index_pshape);
}
// check shapes of input tensors
size_t batch_size = 1;
@ -204,7 +227,7 @@ void op::CTCLoss::validate_and_infer_types()
}
}
bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor)
bool op::v4::CTCLoss::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("preprocess_collapse_repeated", preprocess_collapse_repeated_);
visitor.on_attribute("ctc_merge_repeated", ctc_merge_repeated_);
@ -212,15 +235,32 @@ bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor)
return true;
}
shared_ptr<Node> op::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const
shared_ptr<Node> op::v4::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<CTCLoss>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
preprocess_collapse_repeated_,
ctc_merge_repeated_,
unique_);
if (new_args.size() == 4)
{
return make_shared<CTCLoss>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
preprocess_collapse_repeated_,
ctc_merge_repeated_,
unique_);
}
else if (new_args.size() == 5)
{
return make_shared<CTCLoss>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
preprocess_collapse_repeated_,
ctc_merge_repeated_,
unique_);
}
else
{
throw ngraph_error("Incorrect number of arguments");
}
}

View File

@ -46,6 +46,14 @@ namespace ngraph
/// potential alignment
/// \param unique Flag to find unique elements in a target
/// before matching with alignment
CTCLoss(const Output<Node>& logits,
const Output<Node>& logit_length,
const Output<Node>& labels,
const Output<Node>& label_length,
const bool preprocess_collapse_repeated = false,
const bool ctc_merge_repeated = true,
const bool unique = false);
CTCLoss(const Output<Node>& logits,
const Output<Node>& logit_length,
const Output<Node>& labels,
@ -72,6 +80,5 @@ namespace ngraph
bool unique_;
};
}
using v4::CTCLoss;
}
}

View File

@ -39,6 +39,22 @@ TEST(type_prop, ctc_loss)
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
}
// Type-propagation test for a CTCLoss node built without the blank index
// input: only logits, logit_length, labels and label_length are passed to the
// constructor. The inferred output is expected to be f32 with shape {10}.
TEST(type_prop, ctc_loss_no_blank_index)
{
// create inputs
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
// create CTCLoss node (four-input overload, no blank index supplied)
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length);
// check type and shape infer
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
}
TEST(type_prop, ctc_loss_output_type)
{
// create inputs