Fix for CTCLoss in NGraph (#1563)
The blank index is an optional input and must be handled appropriately. Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
43652498c7
commit
c56f024630
@ -369,7 +369,7 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
|
|||||||
|
|
||||||
// Check that operation in default opsets
|
// Check that operation in default opsets
|
||||||
auto isDefaultOpSet = [](const std::string& version) -> bool {
|
auto isDefaultOpSet = [](const std::string& version) -> bool {
|
||||||
for (size_t i = 1; i <= 3; i++) {
|
for (size_t i = 1; i <= 4; i++) {
|
||||||
std::string opset_name = "opset" + std::to_string(i);
|
std::string opset_name = "opset" + std::to_string(i);
|
||||||
if (version == opset_name)
|
if (version == opset_name)
|
||||||
return true;
|
return true;
|
||||||
|
@ -19,9 +19,24 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
constexpr NodeTypeInfo op::CTCLoss::type_info;
|
constexpr NodeTypeInfo op::v4::CTCLoss::type_info;
|
||||||
|
|
||||||
op::CTCLoss::CTCLoss(const Output<Node>& logits,
|
op::v4::CTCLoss::CTCLoss(const Output<Node>& logits,
|
||||||
|
const Output<Node>& logit_length,
|
||||||
|
const Output<Node>& labels,
|
||||||
|
const Output<Node>& label_length,
|
||||||
|
const bool preprocess_collapse_repeated,
|
||||||
|
const bool ctc_merge_repeated,
|
||||||
|
const bool unique)
|
||||||
|
: Op({logits, logit_length, labels, label_length})
|
||||||
|
, preprocess_collapse_repeated_(preprocess_collapse_repeated)
|
||||||
|
, ctc_merge_repeated_(ctc_merge_repeated)
|
||||||
|
, unique_(unique)
|
||||||
|
{
|
||||||
|
constructor_validate_and_infer_types();
|
||||||
|
}
|
||||||
|
|
||||||
|
op::v4::CTCLoss::CTCLoss(const Output<Node>& logits,
|
||||||
const Output<Node>& logit_length,
|
const Output<Node>& logit_length,
|
||||||
const Output<Node>& labels,
|
const Output<Node>& labels,
|
||||||
const Output<Node>& label_length,
|
const Output<Node>& label_length,
|
||||||
@ -37,14 +52,13 @@ op::CTCLoss::CTCLoss(const Output<Node>& logits,
|
|||||||
constructor_validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
void op::CTCLoss::validate_and_infer_types()
|
void op::v4::CTCLoss::validate_and_infer_types()
|
||||||
{
|
{
|
||||||
// check types of input tensors
|
// check types of input tensors
|
||||||
const auto& logits_type = get_input_element_type(0);
|
const auto& logits_type = get_input_element_type(0);
|
||||||
const auto& logit_length_type = get_input_element_type(1);
|
const auto& logit_length_type = get_input_element_type(1);
|
||||||
const auto& labels_type = get_input_element_type(2);
|
const auto& labels_type = get_input_element_type(2);
|
||||||
const auto& label_length_type = get_input_element_type(3);
|
const auto& label_length_type = get_input_element_type(3);
|
||||||
const auto& blank_index_type = get_input_element_type(4);
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
logits_type.is_real(),
|
logits_type.is_real(),
|
||||||
@ -66,17 +80,21 @@ void op::CTCLoss::validate_and_infer_types()
|
|||||||
"The label length type is expected to be an integer type. Got: ",
|
"The label length type is expected to be an integer type. Got: ",
|
||||||
label_length_type);
|
label_length_type);
|
||||||
|
|
||||||
|
// check optional input type: blank index
|
||||||
|
if (get_input_size() == 5)
|
||||||
|
{
|
||||||
|
const auto& blank_index_type = get_input_element_type(4);
|
||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
blank_index_type.is_integral_number(),
|
blank_index_type.is_integral_number(),
|
||||||
"The blank index type is expected to be an integer type. Got: ",
|
"The blank index type is expected to be an integer type. Got: ",
|
||||||
blank_index_type);
|
blank_index_type);
|
||||||
|
}
|
||||||
|
|
||||||
// check ranks of input tensors
|
// check ranks of input tensors
|
||||||
const auto& logits_pshape = get_input_partial_shape(0);
|
const auto& logits_pshape = get_input_partial_shape(0);
|
||||||
const auto& logit_length_pshape = get_input_partial_shape(1);
|
const auto& logit_length_pshape = get_input_partial_shape(1);
|
||||||
const auto& labels_pshape = get_input_partial_shape(2);
|
const auto& labels_pshape = get_input_partial_shape(2);
|
||||||
const auto& label_length_pshape = get_input_partial_shape(3);
|
const auto& label_length_pshape = get_input_partial_shape(3);
|
||||||
const auto& blank_index_pshape = get_input_partial_shape(4);
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
logits_pshape.rank().compatible(3),
|
logits_pshape.rank().compatible(3),
|
||||||
@ -98,10 +116,15 @@ void op::CTCLoss::validate_and_infer_types()
|
|||||||
"Expected a 1D tensor for label length. Got: ",
|
"Expected a 1D tensor for label length. Got: ",
|
||||||
label_length_pshape);
|
label_length_pshape);
|
||||||
|
|
||||||
|
// check optional input shape: blank index
|
||||||
|
if (get_input_size() == 5)
|
||||||
|
{
|
||||||
|
const auto& blank_index_pshape = get_input_partial_shape(4);
|
||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
blank_index_pshape.rank().compatible(0),
|
blank_index_pshape.rank().compatible(0),
|
||||||
"Expected a scalar for blank index. Got: ",
|
"Expected a scalar for blank index. Got: ",
|
||||||
blank_index_pshape);
|
blank_index_pshape);
|
||||||
|
}
|
||||||
|
|
||||||
// check shapes of input tensors
|
// check shapes of input tensors
|
||||||
size_t batch_size = 1;
|
size_t batch_size = 1;
|
||||||
@ -204,7 +227,7 @@ void op::CTCLoss::validate_and_infer_types()
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor)
|
bool op::v4::CTCLoss::visit_attributes(AttributeVisitor& visitor)
|
||||||
{
|
{
|
||||||
visitor.on_attribute("preprocess_collapse_repeated", preprocess_collapse_repeated_);
|
visitor.on_attribute("preprocess_collapse_repeated", preprocess_collapse_repeated_);
|
||||||
visitor.on_attribute("ctc_merge_repeated", ctc_merge_repeated_);
|
visitor.on_attribute("ctc_merge_repeated", ctc_merge_repeated_);
|
||||||
@ -212,9 +235,21 @@ bool op::CTCLoss::visit_attributes(AttributeVisitor& visitor)
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> op::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const
|
shared_ptr<Node> op::v4::CTCLoss::clone_with_new_inputs(const OutputVector& new_args) const
|
||||||
{
|
{
|
||||||
check_new_args_count(this, new_args);
|
check_new_args_count(this, new_args);
|
||||||
|
if (new_args.size() == 4)
|
||||||
|
{
|
||||||
|
return make_shared<CTCLoss>(new_args.at(0),
|
||||||
|
new_args.at(1),
|
||||||
|
new_args.at(2),
|
||||||
|
new_args.at(3),
|
||||||
|
preprocess_collapse_repeated_,
|
||||||
|
ctc_merge_repeated_,
|
||||||
|
unique_);
|
||||||
|
}
|
||||||
|
else if (new_args.size() == 5)
|
||||||
|
{
|
||||||
return make_shared<CTCLoss>(new_args.at(0),
|
return make_shared<CTCLoss>(new_args.at(0),
|
||||||
new_args.at(1),
|
new_args.at(1),
|
||||||
new_args.at(2),
|
new_args.at(2),
|
||||||
@ -224,3 +259,8 @@ shared_ptr<Node> op::CTCLoss::clone_with_new_inputs(const OutputVector& new_args
|
|||||||
ctc_merge_repeated_,
|
ctc_merge_repeated_,
|
||||||
unique_);
|
unique_);
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
throw ngraph_error("Incorrect number of arguments");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -46,6 +46,14 @@ namespace ngraph
|
|||||||
/// potential alignment
|
/// potential alignment
|
||||||
/// \param unique Flag to find unique elements in a target
|
/// \param unique Flag to find unique elements in a target
|
||||||
/// before matching with alignment
|
/// before matching with alignment
|
||||||
|
CTCLoss(const Output<Node>& logits,
|
||||||
|
const Output<Node>& logit_length,
|
||||||
|
const Output<Node>& labels,
|
||||||
|
const Output<Node>& label_length,
|
||||||
|
const bool preprocess_collapse_repeated = false,
|
||||||
|
const bool ctc_merge_repeated = true,
|
||||||
|
const bool unique = false);
|
||||||
|
|
||||||
CTCLoss(const Output<Node>& logits,
|
CTCLoss(const Output<Node>& logits,
|
||||||
const Output<Node>& logit_length,
|
const Output<Node>& logit_length,
|
||||||
const Output<Node>& labels,
|
const Output<Node>& labels,
|
||||||
@ -72,6 +80,5 @@ namespace ngraph
|
|||||||
bool unique_;
|
bool unique_;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
using v4::CTCLoss;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -39,6 +39,22 @@ TEST(type_prop, ctc_loss)
|
|||||||
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
|
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, ctc_loss_no_blank_index)
|
||||||
|
{
|
||||||
|
// create inputs
|
||||||
|
auto logits = make_shared<op::Parameter>(element::f32, Shape{10, 120, 28});
|
||||||
|
auto logit_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||||
|
auto labels = make_shared<op::Parameter>(element::i32, Shape{10, 120});
|
||||||
|
auto label_length = make_shared<op::Parameter>(element::i32, Shape{10});
|
||||||
|
|
||||||
|
// create CTCLoss node
|
||||||
|
auto ctc_loss = make_shared<op::v4::CTCLoss>(logits, logit_length, labels, label_length);
|
||||||
|
|
||||||
|
// check type and shape infer
|
||||||
|
EXPECT_EQ(ctc_loss->get_element_type(), element::f32);
|
||||||
|
EXPECT_TRUE(ctc_loss->get_output_partial_shape(0).same_scheme(PartialShape{10}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(type_prop, ctc_loss_output_type)
|
TEST(type_prop, ctc_loss_output_type)
|
||||||
{
|
{
|
||||||
// create inputs
|
// create inputs
|
||||||
|
Loading…
Reference in New Issue
Block a user