Fix for CTCLoss in NGraph (#1563)
Blank index is optional input and must be handled appropriately Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
43652498c7
commit
c56f024630
@ -369,7 +369,7 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
|
||||
|
||||
// Check that operation in default opsets
|
||||
auto isDefaultOpSet = [](const std::string& version) -> bool {
|
||||
for (size_t i = 1; i <= 3; i++) {
|
||||
for (size_t i = 1; i <= 4; i++) {
|
||||
std::string opset_name = "opset" + std::to_string(i);
|
||||
if (version == opset_name)
|
||||
return true;
|
||||
|
@ -19,16 +19,31 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::CTCLoss::type_info;
|
||||
constexpr NodeTypeInfo op::v4::CTCLoss::type_info;
|
||||
|
||||
op::CTCLoss::CTCLoss(const Output<Node>& logits,
|
||||
const Output<Node>& logit_length,
|
||||
const Output<Node>& labels,
|
||||
const Output<Node>& label_length,
|
||||
const Output<Node>& blank_index,
|
||||
const bool preprocess_collapse_repeated,
|
||||
const bool ctc_merge_repeated,
|
||||
const bool unique)
|
||||
op::v4::CTCLoss::CTCLoss(const Output<Node>& logits,
|
||||
const Output<Node>& logit_length,
|
||||
const Output<Node>& labels,
|
||||
const Output<Node>& label_length,
|
||||
const bool preprocess_collapse_repeated,
|
||||
const bool ctc_merge_repeated,
|
||||
const bool unique)
|
||||
: Op({logits, logit_length, labels, label_length})
|
||||
, preprocess_collapse_repeated_(preprocess_collapse_repeated)
|
||||
, ctc_merge_repeated_(ctc_merge_repeated)
|
||||
, unique_(unique)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
op::v4::CTCLoss::CTCLoss(const Output<Node>& logits,
|
||||
const Output<Node>& logit_length,
|
||||
const Output<Node>& labels,
|
||||
const Output<Node>& label_length,
|
||||
const Output<Node>& blank_index,
|
||||
const bool preprocess_collapse_repeated,
|
||||
const bool ctc_merge_repeated,
|
||||
const bool unique)
|
||||
: Op({logits, logit_length, labels, label_length, blank_index})
|
||||
, preprocess_collapse_repeated_(preprocess_collapse_repeated)
|
||||
, ctc_merge_repeated_(ctc_merge_repeated)
|
||||
@ -37,14 +52,13 @@ op::CTCLoss::CTCLoss(const Output<Node>& logits,
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::CTCLoss::validate_and_infer_types()
|
||||
void op::v4::CTCLoss::validate_and_infer_types()
|
||||
{
|
||||
// check types of input tensors
|
||||
const auto& logits_type = get_input_element_type(0);
|
||||
const auto& logit_length_type = get_input_element_type(1);
|
||||
const auto& labels_type = get_input_element_type(2);
|
||||
const auto& label_length_type = get_input_element_type(3);
|
||||
const auto& blank_index_type = get_input_element_type(4);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
logits_type.is_real(),
|
||||
@ -66,17 +80,21 @@ void op::CTCLoss::validate_and_infer_types()
|
||||
"The label length type is expected to be an integer type. Got: ",
|
||||
label_length_type);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
blank_index_type.is_integral_number(),
|
||||
"The blank index type is expected to be an integer type. Got: ",
|
||||
blank_index_type);
|
||||
// check optional input type: blank index
|
||||
if (get_input_size() == 5)
|
||||
{
|
||||
const auto& blank_index_type = get_input_element_type(4);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
blank_index_type.is_integral_number(),
|
||||
"The blank index type is expected to be an integer type. Got: ",
|
||||
blank_index_type);
|
||||
}
|
||||
|
||||
// check ranks of input tensors
|
||||
const auto& logits_pshape = get_input_partial_shape(0);
|
||||
const auto& logit_length_pshape = get_input_partial_shape(1);
|
||||
const auto& labels_pshape = get_input_partial_shape(2);
|
||||
const auto& label_length_pshape = get_input_partial_shape(3);
|
||||
const auto& blank_index_pshape = get_input_partial_shape(4);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
logits_pshape.rank().compatible(3),
|
||||
@ -98,10 +116,15 @@ void op::CTCLoss::validate_and_infer_types()
|
||||
"Expected a 1D tensor for label length. Got: ",
|
||||
label_length_pshape);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
blank_index_pshape.rank().compatible(0),
|
||||
"Expected a scalar for blank index. Got: ",
|
||||
blank_index_pshape);
|
||||
// check optional input shape: blank index
|
||||
if (get_input_size() == 5)
|
||||
{
|
||||
const auto& blank_index_pshape = get_input_partial_shape(4);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
blank_index_pshape.rank().compatible(0),
|
||||
"Expected a scalar for blank index. Got: ",
|
||||
blank_index_pshape);
|
||||
}
|
||||
|
||||
// check shapes of input tensors
|
||||
size_t batch_size = 1;
|
||||
@ -204,7 +227,7 @@ void op::CTCLoss::validate_and_infer_types()
|
||||
}
|
||||
}
|
||||
|
||||
bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor)
|
||||
bool op::v4::CTCLoss::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("preprocess_collapse_repeated", preprocess_collapse_repeated_);
|
||||
visitor.on_attribute("ctc_merge_repeated", ctc_merge_repeated_);
|
||||
@ -212,15 +235,32 @@ bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor)
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
shared_ptr<Node> op::v4::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<CTCLoss>(new_args.at(0),
|
||||
new_args.at(1),
|
||||
new_args.at(2),
|
||||
new_args.at(3),
|
||||
new_args.at(4),
|
||||
preprocess_collapse_repeated_,
|
||||
ctc_merge_repeated_,
|
||||
unique_);
|
||||
if (new_args.size() == 4)
|
||||
{
|
||||
return make_shared<CTCLoss>(new_args.at(0),
|
||||
new_args.at(1),
|
||||
new_args.at(2),
|
||||
new_args.at(3),
|
||||
preprocess_collapse_repeated_,
|
||||
ctc_merge_repeated_,
|
||||
unique_);
|
||||
}
|
||||
else if (new_args.size() == 5)
|
||||
{
|
||||
return make_shared<CTCLoss>(new_args.at(0),
|
||||
new_args.at(1),
|
||||
new_args.at(2),
|
||||
new_args.at(3),
|
||||
new_args.at(4),
|
||||
preprocess_collapse_repeated_,
|
||||
ctc_merge_repeated_,
|
||||
unique_);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("Incorrect number of arguments");
|
||||
}
|
||||
}
|
||||
|
@ -46,6 +46,14 @@ namespace ngraph
|
||||
/// potential alignment
|
||||
/// \param unique Flag to find unique elements in a target
|
||||
/// before matching with alignment
|
||||
CTCLoss(const Output<Node>& logits,
|
||||
const Output<Node>& logit_length,
|
||||
const Output<Node>& labels,
|
||||
const Output<Node>& label_length,
|
||||
const bool preprocess_collapse_repeated = false,
|
||||
const bool ctc_merge_repeated = true,
|
||||
const bool unique = false);
|
||||
|
||||
CTCLoss(const Output<Node>& logits,
|
||||
const Output<Node>& logit_length,
|
||||
const Output<Node>& labels,
|
||||
@ -72,6 +80,5 @@ namespace ngraph
|
||||
bool unique_;
|
||||
};
|
||||
}
|
||||
using v4::CTCLoss;
|
||||
}
|
||||
}
|
||||
|
@ -39,6 +39,22 @@ TEST(type_prop, ctc_loss)
|
||||
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_no_blank_index)
|
||||
{
|
||||
// create inputs
|
||||
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
|
||||
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
|
||||
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||
|
||||
// create CTCLoss node
|
||||
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length);
|
||||
|
||||
// check type and shape infer
|
||||
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
|
||||
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
|
||||
}
|
||||
|
||||
TEST(type_prop, ctc_loss_output_type)
|
||||
{
|
||||
// create inputs
|
||||
|
Loading…
Reference in New Issue
Block a user