Add CTCGreedyDecoder mo support (#4009)

* Add CTCGreedyDecoder mo support

* Update copyright

* Update bom file

* Add transformation

* Fix code style

* Fix according to review

* Add CTCGreedyDecoder v6 to ConvertPrecision

* Hot fix

* Add replacement for ctc_greedy_decoder

* Fix test

* Fix

* Update ie transform

* Draft ctc loss replacer

* Add ctcloss replacer

* Update

* Refactoring code

* Update transformation

* Update decoder

* Remove comments

* Convert seq mask from int to float

* Fix unit test

* Add dynamic tests

* Refactoring code

* Fix py code style

* update style

* Disable ctcgreedydecoder transform for mkldnn plugin

* Add some comments

* Add transfor code comments

* Enable transform from different plugins

* Fix mo

* fix tests

* Fix comment

* Fix convert precision

* Update comment

* Fix precision

* Refactoring according to review

* Add ir reader extender

* Rename transformation

* Update bom file

* Fix mo replacer

* Fix tests

* Move transform to decomp

* Add check blank_index

* Refactoring ctcloss

* Change dynamic rank check

* Fix ctcloss extractor

* Remove comment

* Fix code style

* Refactoring pattern matcher for transformation CTCGreedyDecoder

* Disable transform for vpu

* Refactoring according to review

* Refactoring code

* Disable transformation for cldnn

* Remove unused code

* Reverse transformation

* Fix code style

* Hot fix transform

* Fix unit tests

* Update transform

* Enable transform in common pipeline

* Fix name replacements for mo transformations

* Hot fix

* Fix
This commit is contained in:
iliya mironov 2021-02-17 14:38:51 +03:00 committed by GitHub
parent edbb802e55
commit f670b7cb3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1066 additions and 511 deletions

View File

@ -58,6 +58,7 @@
#include <transformations/op_conversions/convert_nms_to_nms_ie_internal.hpp>
#include <transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp>
#include <transformations/op_conversions/convert_gather_0d.hpp>
#include <transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp>
#include <transformations/convert_precision.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
@ -278,6 +279,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
pass_config->disable<ngraph::pass::ConvertBroadcast3>();
pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();
pass_config->disable<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();

View File

@ -57,6 +57,7 @@
#include <transformations/op_conversions/gru_cell_decomposition.hpp>
#include <transformations/op_conversions/log_softmax_decomposition.hpp>
#include <transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp>
#include <transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp>
#include <transformations/convert_precision.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
@ -227,6 +228,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
pass_config->disable<ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher>();
pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();
pass_config->disable<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();

View File

@ -0,0 +1,43 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API SimplifyCTCGreedyDecoderSeqLen;
} // namespace pass
} // namespace ngraph
/**
 * @ingroup ie_transformation_common_api
 * @brief SimplifyCTCGreedyDecoderSeqLen converts v6::CTCGreedyDecoderSeqLen into v0::CTCGreedyDecoder.
 *
 * data[N, T, C] seq_len[N]
 * \ /
 * CTCGreedyDecoderSeqLen
 *
 * will be converted to
 *
 * data[T, N, C] seq_mask[T, N]
 * \ /
 * CTCGreedyDecoder
 * / \
 * class_index[N, T] seq_len[N]
 *
 * i.e. the data input is transposed to time-major layout and the integer seq_len input
 * is replaced by a floating-point per-timestep sequence mask built from it.
 *
 * The transformation works only for case when the blank_index input == C-1, where C is the number of classes.
 */
class ngraph::pass::SimplifyCTCGreedyDecoderSeqLen: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
SimplifyCTCGreedyDecoderSeqLen();
};

View File

@ -52,6 +52,7 @@
#include "transformations/op_conversions/hsigmoid_decomposition.hpp"
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
#include "transformations/op_conversions/mvn6_decomposition.hpp"
#include "transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp"
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
@ -119,6 +120,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
decomp->add_matcher<ngraph::pass::ConvertSpaceToDepth>();
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
decomp->add_matcher<ngraph::pass::MVN6Decomposition>();
decomp->add_matcher<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
decomp->set_name("ngraph::pass::CommonDecompositions");
// CF is required after all decompositions

View File

@ -8,6 +8,7 @@
#include <memory>
#include <vector>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset3.hpp>
@ -29,6 +30,7 @@ bool fuse_type_to_nms5(std::shared_ptr<ngraph::Node> & node, ngraph::element::Ty
bool fuse_type_to_topk(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_nonzero(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_bucketize(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_ctc_greedy_decoder_seq_len(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool extend_select_type(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
@ -87,6 +89,7 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
{opset3::NonMaxSuppression::type_info, fuse_type_to_nms3},
{opset4::NonMaxSuppression::type_info, fuse_type_to_nms4},
{opset5::NonMaxSuppression::type_info, fuse_type_to_nms5},
{opset6::CTCGreedyDecoderSeqLen::type_info, fuse_type_to_ctc_greedy_decoder_seq_len},
{opset4::TopK::type_info, fuse_type_to_topk},
{opset4::NonZero::type_info, fuse_type_to_nonzero},
{opset4::Bucketize::type_info, fuse_type_to_bucketize},
@ -260,6 +263,20 @@ bool fuse_type_to_topk(std::shared_ptr<ngraph::Node> & node, ngraph::element::Ty
return false;
}
// Applies the requested integral precision to one output of CTCGreedyDecoderSeqLen:
// output 0 -> classes_index_type, output 1 -> sequence_length_type.
// Returns true only when the node matched, idx identifies a known output, and
// `to` is one of the supported integer types (i32/i64); false otherwise.
bool fuse_type_to_ctc_greedy_decoder_seq_len(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
    auto ctc_decoder = as_type_ptr<opset6::CTCGreedyDecoderSeqLen>(node);
    if (!ctc_decoder)
        return false;
    // Both output types accept only i32/i64; reject everything else up front.
    const bool supported_type = (to == element::i32) || (to == element::i64);
    if (!supported_type)
        return false;
    switch (idx) {
    case 0:
        ctc_decoder->set_classes_index_type(to);
        return true;
    case 1:
        ctc_decoder->set_sequence_length_type(to);
        return true;
    default:
        return false;
    }
}
bool fuse_type_to_nonzero(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
if (auto nonzero = as_type_ptr<opset4::NonZero>(node)) {
if (to == element::i32 || to == element::i64) {

View File

@ -0,0 +1,128 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "itt.hpp"
#include "transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp"
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
// NOTE(review): the RTTI name string omits the "SeqLen" suffix — kept as-is for compatibility.
NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifyCTCGreedyDecoderSeqLen, "SimplifyCTCGreedyDecoder", 0);
// Decomposes a matched v6::CTCGreedyDecoderSeqLen (inputs: data [N, T, C], seq_len [N],
// optional blank_index) into an equivalent subgraph built around v0::CTCGreedyDecoder,
// which consumes [T, N, C] data plus a floating-point sequence mask [T, N]. The two
// original outputs (class indices and per-batch decoded lengths) are reconstructed
// from the single v0 output.
ngraph::pass::SimplifyCTCGreedyDecoderSeqLen::SimplifyCTCGreedyDecoderSeqLen() {
MATCHER_SCOPE(SimplifyCTCGreedyDecoderSeqLen);
// Match any CTCGreedyDecoderSeqLen node; all applicability checks live in the callback.
auto decoder = pattern::wrap_type<opset6::CTCGreedyDecoderSeqLen>();
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto decoder_seq_len = std::dynamic_pointer_cast<opset6::CTCGreedyDecoderSeqLen> (m.get_match_root());
if (!decoder_seq_len) {
return false;
}
// When an explicit blank_index input is present, the rewrite is valid only if it is
// a constant equal to C-1 (see the header doc). Checking that requires a constant
// blank_index, a static rank and a static class dimension data[2].
if (decoder_seq_len->get_input_size() > 2) {
const auto data_pshape = decoder_seq_len->get_input_partial_shape(0);
auto blank_index = std::dynamic_pointer_cast<ngraph::opset6::Constant>(decoder_seq_len->input_value(2).get_node_shared_ptr());
if (!blank_index || data_pshape.rank().is_dynamic() || data_pshape[2].is_dynamic()) {
return false;
}
const std::vector<int64_t> &blank_index_values = blank_index->cast_vector<int64_t>();
const auto num_classes = decoder_seq_len->get_input_partial_shape(0)[2].get_length();
if (blank_index_values[0] != (num_classes - 1)) {
return false;
}
}
element::Type data_type = decoder_seq_len->input_value(0).get_element_type();
element::Type seq_len_type = decoder_seq_len->input_value(1).get_element_type();
// Transposing input data channels from [N, T, C] to [T, N, C]. Need for compatible with CTCGreedyDecoder v1
auto transpose = std::make_shared<ngraph::opset6::Transpose>(decoder_seq_len->input_value(0),
ngraph::opset6::Constant::create(element::i32,
Shape({3}), {1, 0, 2}));
// Receive time and batch dimensions and concatenate to [T, N] tensor shapes
auto data_shape = std::make_shared<ngraph::opset6::ShapeOf>(decoder_seq_len->input_value(0));
// T = data_shape[1] and N = data_shape[0], each gathered as a 1-element tensor.
auto axisT = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto indexT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {1});
auto T = std::make_shared<ngraph::opset6::Gather>(data_shape, indexT, axisT);
auto axisN = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto indexN = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
auto N = std::make_shared<ngraph::opset6::Gather>(data_shape, indexN, axisN);
// Build Range [1..T]: per-timestep positions that are compared against each
// sequence length below. T+1 is squeezed to a scalar because Range needs scalar bounds.
auto start = opset6::Constant::create(seq_len_type, Shape{}, {1});
auto step = opset6::Constant::create(seq_len_type, Shape{}, {1});
auto plus1 = opset6::Constant::create(element::i64, Shape{1}, {1});
auto plusT = std::make_shared<ngraph::opset6::Add>(T, plus1);
auto const_plusT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
auto plusT_scalar = std::make_shared<ngraph::opset6::Squeeze>(plusT, const_plusT);
auto range1T = std::make_shared<ngraph::opset6::Range>(start, plusT_scalar, step, seq_len_type);
auto mask_shape = std::make_shared<ngraph::opset6::Concat>(
OutputVector{T->output(0), N->output(0)}, 0);
// Generate 2D tensor [T, N] for seq mask
auto upper_bounds = std::make_shared<ngraph::opset6::Broadcast>(
decoder_seq_len->input_value(1), mask_shape->output(0));
auto transpose_upper_bounds = std::make_shared<ngraph::opset6::Transpose>(upper_bounds->output(0),
ngraph::opset6::Constant::create(seq_len_type,
Shape({2}), {1, 0}));
// Compute boolean sequence mask: mask[n][t] = (seq_len[n] >= t+1)
auto bool_seq_mask = std::make_shared<ngraph::opset6::GreaterEqual>(transpose_upper_bounds->output(0),
range1T->output(0));
// Generate resulted seq mask
auto mask_val_true = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {1});
auto mask_val_false = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto seq_mask = std::make_shared<ngraph::opset6::Select>(bool_seq_mask, mask_val_true, mask_val_false);
auto transpose_seq_mask = std::make_shared<ngraph::opset6::Transpose>(seq_mask->output(0),
ngraph::opset6::Constant::create(seq_len_type,
Shape({2}), {1, 0}));
// v0::CTCGreedyDecoder expects a floating-point mask of the data element type.
auto transpose_seq_mask_f = std::make_shared<ngraph::opset6::Convert>(transpose_seq_mask->output(0), data_type);
// Create CTCGreedyDecoder with original merge_repeated attribute and connect data and resulted seq_mask
auto decoder = std::make_shared<ngraph::opset6::CTCGreedyDecoder>(transpose,
transpose_seq_mask_f->output(0),
decoder_seq_len->get_merge_repeated());
decoder->set_friendly_name(decoder_seq_len->get_friendly_name());
// Normalize output from CTCGreedyDecoder = output_f and create second output with output_seq_len
auto squeeze2_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {3});
auto squeeze2_output_f = std::make_shared<ngraph::opset6::Squeeze>(decoder->output(0), squeeze2_axis);
auto squeeze1_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {2});
auto squeeze1_output_f = std::make_shared<ngraph::opset6::Squeeze>(squeeze2_output_f->output(0), squeeze1_axis);
element::Type ci_type = decoder_seq_len->get_classes_index_type();
element::Type sl_type = decoder_seq_len->get_sequence_length_type();
// CTCGreedyDecoder return floating point output. For Normalize output we need to convert output to classes_index_type
// Receive the first output with correct classes_index_type
auto output_i = std::make_shared<ngraph::opset6::Convert>(squeeze1_output_f->output(0), ci_type);
auto minus1 = opset6::Constant::create(ci_type, Shape{}, {-1});
// Get to know where equal -1; those positions are excluded from the length count below
auto where_equal_minus1 = std::make_shared<ngraph::opset6::Equal>(output_i, minus1);
// Compute output seq mask
auto seq_mask_const0 = opset6::Constant::create(ci_type, Shape{1}, {0});
auto seq_mask_const1 = opset6::Constant::create(ci_type, Shape{1}, {1});
auto output_seq_mask = std::make_shared<ngraph::opset6::Select>(where_equal_minus1, seq_mask_const0, seq_mask_const1);
auto seq_mask_axis = opset6::Constant::create(ci_type, Shape{1}, {1});
// Receive the second output
auto output_seq_len = std::make_shared<ngraph::opset6::ReduceSum>(output_seq_mask, seq_mask_axis);
// Receive the second output with correct seq_len_type
auto output_seq_len_i = std::make_shared<ngraph::opset6::Convert>(output_seq_len->output(0), sl_type);
// Propagate runtime info and friendly names so the replacement stays traceable.
ngraph::copy_runtime_info(decoder_seq_len, {transpose, decoder, data_shape, T, N, plusT, plusT_scalar, range1T, mask_shape, upper_bounds,
squeeze2_output_f, squeeze1_output_f, transpose_upper_bounds, bool_seq_mask, seq_mask, transpose_seq_mask,
transpose_seq_mask_f, output_i, where_equal_minus1, output_seq_mask, output_seq_len, output_seq_len_i});
output_i->set_friendly_name(decoder_seq_len->get_friendly_name()+".0");
output_seq_len_i->set_friendly_name(decoder_seq_len->get_friendly_name()+".1");
ngraph::replace_node(decoder_seq_len, {output_i->output(0), output_seq_len_i->output(0)});
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(decoder, matcher_name);
register_matcher(m, callback);
}

View File

@ -31,6 +31,7 @@
#include <transformations/op_conversions/softplus_decomposition.hpp>
#include <transformations/op_conversions/convert_minimum_to_power_and_max.hpp>
#include <transformations/op_conversions/hswish_decomposition.hpp>
#include <transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp>
#include <transformations/common_optimizations/common_optimizations.hpp>
@ -196,6 +197,7 @@ ie::CNNNetwork FrontEnd::convertNetwork(ie::CNNNetwork& network) {
pass_config->disable<ngraph::pass::ConvertMinimum>();
pass_config->disable<ngraph::pass::HSwishDecomposition>();
pass_config->disable<ngraph::pass::MVN6Decomposition>();
pass_config->disable<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
auto transformationPredicate = [](const std::shared_ptr<const ngraph::Node>& node) -> bool {
return !!std::dynamic_pointer_cast<const ngraph::vpu::op::DynamicShapeResolver>(node->input_value(0).get_node_shared_ptr());

View File

@ -0,0 +1,560 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
// Static-shape case (no blank_index input, default i32 output types): the v6 decoder
// must be replaced by the reference CTCGreedyDecoder-based subgraph.
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenTest) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
// Build the function under test and run the transformation.
{
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 7 });
auto seq_len = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i64, ngraph::Shape{ 1 });
auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, true);
auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data, seq_len });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
// Reference function: the decomposition the pass is expected to produce, mirroring
// SimplifyCTCGreedyDecoderSeqLen node-for-node.
{
auto data1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 7 });
auto seq_len1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i64, ngraph::Shape{ 1 });
element::Type data_type = data1->get_element_type();
element::Type seq_len_type = seq_len1->get_element_type();
element::Type ci_type = element::i32;
element::Type sl_type = element::i32;
auto transpose = std::make_shared<ngraph::opset6::Transpose>(data1,
ngraph::opset6::Constant::create(element::i32,
Shape({3}), {1, 0, 2}));
auto data_shape = std::make_shared<ngraph::opset6::ShapeOf>(data1);
auto axisT = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto indexT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {1});
auto T = std::make_shared<ngraph::opset6::Gather>(data_shape, indexT, axisT);
auto axisN = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto indexN = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
auto N = std::make_shared<ngraph::opset6::Gather>(data_shape, indexN, axisN);
auto start = opset6::Constant::create(seq_len_type, Shape{}, {1});
auto step = opset6::Constant::create(seq_len_type, Shape{}, {1});
auto plus1 = opset6::Constant::create(element::i64, Shape{1}, {1});
auto plusT = std::make_shared<ngraph::opset6::Add>(T, plus1);
auto const_plusT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
auto plusT_scalar = std::make_shared<ngraph::opset6::Squeeze>(plusT, const_plusT);
auto range1T = std::make_shared<ngraph::opset6::Range>(start, plusT_scalar, step, seq_len_type);
auto mask_shape = std::make_shared<ngraph::opset6::Concat>(
OutputVector{T->output(0), N->output(0)}, 0);
auto upper_bounds = std::make_shared<ngraph::opset6::Broadcast>(
seq_len1, mask_shape->output(0));
auto transpose_upper_bounds = std::make_shared<ngraph::opset6::Transpose>(upper_bounds->output(0),
ngraph::opset6::Constant::create(seq_len_type,
Shape({2}), {1, 0}));
auto bool_seq_mask = std::make_shared<ngraph::opset6::GreaterEqual>(transpose_upper_bounds->output(0),
range1T->output(0));
auto mask_val_true = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {1});
auto mask_val_false = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto seq_mask = std::make_shared<ngraph::opset6::Select>(bool_seq_mask, mask_val_true, mask_val_false);
auto transpose_seq_mask = std::make_shared<ngraph::opset6::Transpose>(seq_mask->output(0),
ngraph::opset6::Constant::create(seq_len_type,
Shape({2}), {1, 0}));
auto transpose_seq_mask_f32 = std::make_shared<ngraph::opset6::Convert>(transpose_seq_mask->output(0), data_type);
auto simplified_decoder = std::make_shared<ngraph::opset6::CTCGreedyDecoder>(transpose,
transpose_seq_mask_f32->output(0),
true);
auto squeeze2_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {3});
auto squeeze2_output_f = std::make_shared<ngraph::opset6::Squeeze>(simplified_decoder->output(0), squeeze2_axis);
auto squeeze1_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {2});
auto squeeze1_output_f = std::make_shared<ngraph::opset6::Squeeze>(squeeze2_output_f->output(0), squeeze1_axis);
auto output_i = std::make_shared<ngraph::opset6::Convert>(squeeze1_output_f->output(0), ci_type);
auto minus1 = opset6::Constant::create(ci_type, Shape{}, {-1});
auto where_equal_minus1 = std::make_shared<ngraph::opset6::Equal>(output_i, minus1);
auto seq_mask_const0 = opset6::Constant::create(ci_type, Shape{1}, {0});
auto seq_mask_const1 = opset6::Constant::create(ci_type, Shape{1}, {1});
auto output_seq_mask = std::make_shared<ngraph::opset6::Select>(where_equal_minus1, seq_mask_const0, seq_mask_const1);
auto seq_mask_axis = opset6::Constant::create(ci_type, Shape{1}, {1});
auto output_seq_len = std::make_shared<ngraph::opset6::ReduceSum>(output_seq_mask, seq_mask_axis);
auto output_seq_len_i = std::make_shared<ngraph::opset6::Convert>(output_seq_len->output(0), sl_type);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ output_i, output_seq_len_i }, ngraph::ParameterVector{ data1, seq_len1 });
}
// Structural comparison of transformed vs. reference graphs.
auto res = compare_functions(f, f_ref, true, false, false, true, true);
ASSERT_TRUE(res.first) << res.second;
}
// Fully dynamic data shape with explicit output types (i64 class indices, i32 seq lens):
// the transformation must still apply since no blank_index input is given.
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenDynamicInputShapeTest) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
// Build the function under test and run the transformation.
{
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic());
auto seq_len = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::Shape{ 1 });
auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, true, element::i64, element::i32);
auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data, seq_len });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
// Reference function mirroring the expected decomposition node-for-node.
{
auto data1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic());
auto seq_len1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::Shape{ 1 });
element::Type data_type = data1->get_element_type();
element::Type seq_len_type = seq_len1->get_element_type();
element::Type ci_type = element::i64;
element::Type sl_type = element::i32;
auto transpose = std::make_shared<ngraph::opset6::Transpose>(data1,
ngraph::opset6::Constant::create(element::i32,
Shape({3}), {1, 0, 2}));
auto data_shape = std::make_shared<ngraph::opset6::ShapeOf>(data1);
auto axisT = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto indexT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {1});
auto T = std::make_shared<ngraph::opset6::Gather>(data_shape, indexT, axisT);
auto axisN = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto indexN = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
auto N = std::make_shared<ngraph::opset6::Gather>(data_shape, indexN, axisN);
auto start = opset6::Constant::create(seq_len_type, Shape{}, {1});
auto step = opset6::Constant::create(seq_len_type, Shape{}, {1});
auto plus1 = opset6::Constant::create(element::i64, Shape{1}, {1});
auto plusT = std::make_shared<ngraph::opset6::Add>(T, plus1);
auto const_plusT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
auto plusT_scalar = std::make_shared<ngraph::opset6::Squeeze>(plusT, const_plusT);
auto range1T = std::make_shared<ngraph::opset6::Range>(start, plusT_scalar, step, seq_len_type);
auto mask_shape = std::make_shared<ngraph::opset6::Concat>(
OutputVector{T->output(0), N->output(0)}, 0);
auto upper_bounds = std::make_shared<ngraph::opset6::Broadcast>(
seq_len1, mask_shape->output(0));
auto transpose_upper_bounds = std::make_shared<ngraph::opset6::Transpose>(upper_bounds->output(0),
ngraph::opset6::Constant::create(seq_len_type,
Shape({2}), {1, 0}));
auto bool_seq_mask = std::make_shared<ngraph::opset6::GreaterEqual>(transpose_upper_bounds->output(0),
range1T->output(0));
auto mask_val_true = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {1});
auto mask_val_false = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto seq_mask = std::make_shared<ngraph::opset6::Select>(bool_seq_mask, mask_val_true, mask_val_false);
auto transpose_seq_mask = std::make_shared<ngraph::opset6::Transpose>(seq_mask->output(0),
ngraph::opset6::Constant::create(seq_len_type,
Shape({2}), {1, 0}));
auto transpose_seq_mask_f = std::make_shared<ngraph::opset6::Convert>(transpose_seq_mask->output(0), data_type);
auto simplified_decoder = std::make_shared<ngraph::opset6::CTCGreedyDecoder>(transpose,
transpose_seq_mask_f->output(0),
true);
auto squeeze2_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {3});
auto squeeze2_output_f = std::make_shared<ngraph::opset6::Squeeze>(simplified_decoder->output(0), squeeze2_axis);
auto squeeze1_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {2});
auto squeeze1_output_f = std::make_shared<ngraph::opset6::Squeeze>(squeeze2_output_f->output(0), squeeze1_axis);
auto output_i = std::make_shared<ngraph::opset6::Convert>(squeeze1_output_f->output(0), ci_type);
auto minus1 = opset6::Constant::create(ci_type, Shape{}, {-1});
auto where_equal_minus1 = std::make_shared<ngraph::opset6::Equal>(output_i, minus1);
auto seq_mask_const0 = opset6::Constant::create(ci_type, Shape{1}, {0});
auto seq_mask_const1 = opset6::Constant::create(ci_type, Shape{1}, {1});
auto output_seq_mask = std::make_shared<ngraph::opset6::Select>(where_equal_minus1, seq_mask_const0, seq_mask_const1);
auto seq_mask_axis = opset6::Constant::create(ci_type, Shape{1}, {1});
auto output_seq_len = std::make_shared<ngraph::opset6::ReduceSum>(output_seq_mask, seq_mask_axis);
auto output_seq_len_i = std::make_shared<ngraph::opset6::Convert>(output_seq_len->output(0), sl_type);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ output_i, output_seq_len_i }, ngraph::ParameterVector{ data1, seq_len1 });
}
// Structural comparison of transformed vs. reference graphs.
auto res = compare_functions(f, f_ref, true, false, false, true, true);
ASSERT_TRUE(res.first) << res.second;
}
// Dynamic batch dimension (N) with static T and C, i32 class indices / i64 seq lens:
// the transformation must still apply since no blank_index input is given.
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenDynamicBatchTest) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
// Build the function under test and run the transformation.
{
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{Dimension::dynamic(), 3, 7});
auto seq_len = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{Dimension::dynamic()});
auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, true, element::i32, element::i64);
auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data, seq_len });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
// Reference function mirroring the expected decomposition node-for-node.
{
auto data1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{Dimension::dynamic(), 3, 7});
auto seq_len1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{Dimension::dynamic()});
element::Type data_type = data1->get_element_type();
element::Type seq_len_type = seq_len1->get_element_type();
element::Type ci_type = element::i32;
element::Type sl_type = element::i64;
auto transpose = std::make_shared<ngraph::opset6::Transpose>(data1,
ngraph::opset6::Constant::create(element::i32,
Shape({3}), {1, 0, 2}));
auto data_shape = std::make_shared<ngraph::opset6::ShapeOf>(data1);
auto axisT = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto indexT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {1});
auto T = std::make_shared<ngraph::opset6::Gather>(data_shape, indexT, axisT);
auto axisN = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto indexN = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
auto N = std::make_shared<ngraph::opset6::Gather>(data_shape, indexN, axisN);
auto start = opset6::Constant::create(seq_len_type, Shape{}, {1});
auto step = opset6::Constant::create(seq_len_type, Shape{}, {1});
auto plus1 = opset6::Constant::create(element::i64, Shape{1}, {1});
auto plusT = std::make_shared<ngraph::opset6::Add>(T, plus1);
auto const_plusT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
auto plusT_scalar = std::make_shared<ngraph::opset6::Squeeze>(plusT, const_plusT);
auto range1T = std::make_shared<ngraph::opset6::Range>(start, plusT_scalar, step, seq_len_type);
auto mask_shape = std::make_shared<ngraph::opset6::Concat>(
OutputVector{T->output(0), N->output(0)}, 0);
auto upper_bounds = std::make_shared<ngraph::opset6::Broadcast>(
seq_len1, mask_shape->output(0));
auto transpose_upper_bounds = std::make_shared<ngraph::opset6::Transpose>(upper_bounds->output(0),
ngraph::opset6::Constant::create(seq_len_type,
Shape({2}), {1, 0}));
auto bool_seq_mask = std::make_shared<ngraph::opset6::GreaterEqual>(transpose_upper_bounds->output(0),
range1T->output(0));
auto mask_val_true = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {1});
auto mask_val_false = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto seq_mask = std::make_shared<ngraph::opset6::Select>(bool_seq_mask, mask_val_true, mask_val_false);
auto transpose_seq_mask = std::make_shared<ngraph::opset6::Transpose>(seq_mask->output(0),
ngraph::opset6::Constant::create(seq_len_type,
Shape({2}), {1, 0}));
auto transpose_seq_mask_f = std::make_shared<ngraph::opset6::Convert>(transpose_seq_mask->output(0), data_type);
auto simplified_decoder = std::make_shared<ngraph::opset6::CTCGreedyDecoder>(transpose,
transpose_seq_mask_f->output(0),
true);
auto squeeze2_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {3});
auto squeeze2_output_f = std::make_shared<ngraph::opset6::Squeeze>(simplified_decoder->output(0), squeeze2_axis);
auto squeeze1_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {2});
auto squeeze1_output_f = std::make_shared<ngraph::opset6::Squeeze>(squeeze2_output_f->output(0), squeeze1_axis);
auto output_i = std::make_shared<ngraph::opset6::Convert>(squeeze1_output_f->output(0), ci_type);
auto minus1 = opset6::Constant::create(ci_type, Shape{}, {-1});
auto where_equal_minus1 = std::make_shared<ngraph::opset6::Equal>(output_i, minus1);
auto seq_mask_const0 = opset6::Constant::create(ci_type, Shape{1}, {0});
auto seq_mask_const1 = opset6::Constant::create(ci_type, Shape{1}, {1});
auto output_seq_mask = std::make_shared<ngraph::opset6::Select>(where_equal_minus1, seq_mask_const0, seq_mask_const1);
auto seq_mask_axis = opset6::Constant::create(ci_type, Shape{1}, {1});
auto output_seq_len = std::make_shared<ngraph::opset6::ReduceSum>(output_seq_mask, seq_mask_axis);
auto output_seq_len_i = std::make_shared<ngraph::opset6::Convert>(output_seq_len->output(0), sl_type);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ output_i, output_seq_len_i }, ngraph::ParameterVector{ data1, seq_len1 });
}
// Structural comparison of transformed vs. reference graphs.
auto res = compare_functions(f, f_ref, true, false, false, true, true);
ASSERT_TRUE(res.first) << res.second;
}
// Checks that SimplifyCTCGreedyDecoderSeqLen decomposes a v6 CTCGreedyDecoderSeqLen
// (no explicit blank_index input) into an opset CTCGreedyDecoder-based subgraph,
// even when the time dimension of the data is dynamic.
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenDynamicSeqLenTest) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
// Graph under transformation: data is [N, T, C] = [2, dynamic, 7].
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{2, Dimension::dynamic(), 7});
auto seq_len = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, true, ngraph::element::i64, ngraph::element::i64);
auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data, seq_len });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
// Reference graph: the decomposition the transformation is expected to produce.
auto data1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{2, Dimension::dynamic(), 7});
auto seq_len1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
element::Type data_type = data1->get_element_type();
element::Type seq_len_type = seq_len1->get_element_type();
// Element types of the two decoder outputs (class indices / sequence lengths).
element::Type ci_type = element::i64;
element::Type sl_type = element::i64;
// Data arrives as [N, T, C]; transpose to [T, N, C] for opset CTCGreedyDecoder.
auto transpose = std::make_shared<ngraph::opset6::Transpose>(data1,
ngraph::opset6::Constant::create(element::i32,
Shape({3}), {1, 0, 2}));
// Extract T (dim 1) and N (dim 0) from the runtime shape of the data.
auto data_shape = std::make_shared<ngraph::opset6::ShapeOf>(data1);
auto axisT = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto indexT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {1});
auto T = std::make_shared<ngraph::opset6::Gather>(data_shape, indexT, axisT);
auto axisN = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto indexN = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
auto N = std::make_shared<ngraph::opset6::Gather>(data_shape, indexN, axisN);
// Build the 1-based time-step range [1, T]: Range(start=1, stop=T+1, step=1).
auto start = opset6::Constant::create(seq_len_type, Shape{}, {1});
auto step = opset6::Constant::create(seq_len_type, Shape{}, {1});
auto plus1 = opset6::Constant::create(element::i64, Shape{1}, {1});
auto plusT = std::make_shared<ngraph::opset6::Add>(T, plus1);
auto const_plusT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
auto plusT_scalar = std::make_shared<ngraph::opset6::Squeeze>(plusT, const_plusT);
auto range1T = std::make_shared<ngraph::opset6::Range>(start, plusT_scalar, step, seq_len_type);
// Broadcast seq_len over [T, N], transpose to [N, T] and compare against the
// time-step range: mask[n][t] = (seq_len[n] >= t), i.e. 1 while the 1-based
// time step is still within the sequence, 0 afterwards.
auto mask_shape = std::make_shared<ngraph::opset6::Concat>(
OutputVector{T->output(0), N->output(0)}, 0);
auto upper_bounds = std::make_shared<ngraph::opset6::Broadcast>(
seq_len1, mask_shape->output(0));
auto transpose_upper_bounds = std::make_shared<ngraph::opset6::Transpose>(upper_bounds->output(0),
ngraph::opset6::Constant::create(seq_len_type,
Shape({2}), {1, 0}));
auto bool_seq_mask = std::make_shared<ngraph::opset6::GreaterEqual>(transpose_upper_bounds->output(0),
range1T->output(0));
auto mask_val_true = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {1});
auto mask_val_false = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto seq_mask = std::make_shared<ngraph::opset6::Select>(bool_seq_mask, mask_val_true, mask_val_false);
// The decoder consumes the mask as [T, N] in the data element type (float).
auto transpose_seq_mask = std::make_shared<ngraph::opset6::Transpose>(seq_mask->output(0),
ngraph::opset6::Constant::create(seq_len_type,
Shape({2}), {1, 0}));
auto transpose_seq_mask_f = std::make_shared<ngraph::opset6::Convert>(transpose_seq_mask->output(0), data_type);
auto simplified_decoder = std::make_shared<ngraph::opset6::CTCGreedyDecoder>(transpose,
transpose_seq_mask_f->output(0),
true);
// Squeeze the two trailing singleton axes (3 then 2) of the decoder output
// down to [N, T], then convert to the requested class-index type.
auto squeeze2_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {3});
auto squeeze2_output_f = std::make_shared<ngraph::opset6::Squeeze>(simplified_decoder->output(0), squeeze2_axis);
auto squeeze1_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {2});
auto squeeze1_output_f = std::make_shared<ngraph::opset6::Squeeze>(squeeze2_output_f->output(0), squeeze1_axis);
auto output_i = std::make_shared<ngraph::opset6::Convert>(squeeze1_output_f->output(0), ci_type);
// Second output (sequence lengths): count per-batch entries that are not -1.
auto minus1 = opset6::Constant::create(ci_type, Shape{}, {-1});
auto where_equal_minus1 = std::make_shared<ngraph::opset6::Equal>(output_i, minus1);
auto seq_mask_const0 = opset6::Constant::create(ci_type, Shape{1}, {0});
auto seq_mask_const1 = opset6::Constant::create(ci_type, Shape{1}, {1});
auto output_seq_mask = std::make_shared<ngraph::opset6::Select>(where_equal_minus1, seq_mask_const0, seq_mask_const1);
auto seq_mask_axis = opset6::Constant::create(ci_type, Shape{1}, {1});
auto output_seq_len = std::make_shared<ngraph::opset6::ReduceSum>(output_seq_mask, seq_mask_axis);
auto output_seq_len_i = std::make_shared<ngraph::opset6::Convert>(output_seq_len->output(0), sl_type);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ output_i, output_seq_len_i },
ngraph::ParameterVector{ data1, seq_len1 });
}
// Structural comparison of the transformed function against the reference.
auto res = compare_functions(f, f_ref, true, false, false, true, true);
ASSERT_TRUE(res.first) << res.second;
}
// SimplifyCTCGreedyDecoderSeqLen must NOT fire when blank_index differs from
// C - 1 (here 5 while C = 7): the v6 decoder is expected to stay untouched.
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenWrongBlankIndexTest) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto logits = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32,
                                                                  ngraph::PartialShape{2, Dimension::dynamic(), 7});
        auto lengths = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
        // 5 != C - 1 (= 6), so the pattern must reject this decoder.
        auto wrong_blank = op::Constant::create(element::i32, Shape{}, {5});
        auto decoder = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(logits, lengths, wrong_blank,
                                                                                true, ngraph::element::i64,
                                                                                ngraph::element::i64);
        auto classes_out = std::make_shared<opset6::Result>(decoder->output(0));
        auto lengths_out = std::make_shared<opset6::Result>(decoder->output(1));
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ classes_out, lengths_out },
                                               ngraph::ParameterVector{ logits, lengths });

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        // Reference: an identical, untransformed graph.
        auto ref_logits = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32,
                                                                      ngraph::PartialShape{2, Dimension::dynamic(), 7});
        auto ref_lengths = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
        auto ref_blank = op::Constant::create(element::i32, Shape{}, {5});
        auto ref_decoder = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(ref_logits, ref_lengths, ref_blank,
                                                                                    true, ngraph::element::i64,
                                                                                    ngraph::element::i64);
        auto ref_classes_out = std::make_shared<opset6::Result>(ref_decoder->output(0));
        auto ref_lengths_out = std::make_shared<opset6::Result>(ref_decoder->output(1));
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ ref_classes_out, ref_lengths_out },
                                                   ngraph::ParameterVector{ ref_logits, ref_lengths });
    }
    // Structural comparison of the (hopefully unchanged) function vs. the reference.
    const auto cmp = compare_functions(f, f_ref, true, false, false, true, true);
    ASSERT_TRUE(cmp.first) << cmp.second;
}
// Checks that SimplifyCTCGreedyDecoderSeqLen decomposes a v6 CTCGreedyDecoderSeqLen
// when an explicit blank_index input IS present but equals C - 1 (6 with C = 7):
// the blank_index is dropped and the same reference subgraph as in the
// no-blank-index case is expected.
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenDynamicSeqLenWithBlankIndexTest) {
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
{
// Graph under transformation: blank_index = 6 = C - 1, so the pass applies.
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{2, Dimension::dynamic(), 7});
auto seq_len = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
auto blank_index = op::Constant::create(element::i32, Shape{}, {6});
auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index,
true, ngraph::element::i64, ngraph::element::i64);
auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data, seq_len });
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
// Reference graph: the expected decomposition (blank_index no longer present).
auto data1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{2, Dimension::dynamic(), 7});
auto seq_len1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
element::Type data_type = data1->get_element_type();
element::Type seq_len_type = seq_len1->get_element_type();
// Element types of the two decoder outputs (class indices / sequence lengths).
element::Type ci_type = element::i64;
element::Type sl_type = element::i64;
// Data arrives as [N, T, C]; transpose to [T, N, C] for opset CTCGreedyDecoder.
auto transpose = std::make_shared<ngraph::opset6::Transpose>(data1,
ngraph::opset6::Constant::create(element::i32,
Shape({3}), {1, 0, 2}));
// Extract T (dim 1) and N (dim 0) from the runtime shape of the data.
auto data_shape = std::make_shared<ngraph::opset6::ShapeOf>(data1);
auto axisT = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto indexT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {1});
auto T = std::make_shared<ngraph::opset6::Gather>(data_shape, indexT, axisT);
auto axisN = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto indexN = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
auto N = std::make_shared<ngraph::opset6::Gather>(data_shape, indexN, axisN);
// Build the 1-based time-step range [1, T]: Range(start=1, stop=T+1, step=1).
auto start = opset6::Constant::create(seq_len_type, Shape{}, {1});
auto step = opset6::Constant::create(seq_len_type, Shape{}, {1});
auto plus1 = opset6::Constant::create(element::i64, Shape{1}, {1});
auto plusT = std::make_shared<ngraph::opset6::Add>(T, plus1);
auto const_plusT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
auto plusT_scalar = std::make_shared<ngraph::opset6::Squeeze>(plusT, const_plusT);
auto range1T = std::make_shared<ngraph::opset6::Range>(start, plusT_scalar, step, seq_len_type);
// Broadcast seq_len over [T, N], transpose to [N, T] and compare against the
// time-step range: mask[n][t] = (seq_len[n] >= t), i.e. 1 while the 1-based
// time step is still within the sequence, 0 afterwards.
auto mask_shape = std::make_shared<ngraph::opset6::Concat>(
OutputVector{T->output(0), N->output(0)}, 0);
auto upper_bounds = std::make_shared<ngraph::opset6::Broadcast>(
seq_len1, mask_shape->output(0));
auto transpose_upper_bounds = std::make_shared<ngraph::opset6::Transpose>(upper_bounds->output(0),
ngraph::opset6::Constant::create(seq_len_type,
Shape({2}), {1, 0}));
auto bool_seq_mask = std::make_shared<ngraph::opset6::GreaterEqual>(transpose_upper_bounds->output(0),
range1T->output(0));
auto mask_val_true = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {1});
auto mask_val_false = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
auto seq_mask = std::make_shared<ngraph::opset6::Select>(bool_seq_mask, mask_val_true, mask_val_false);
// The decoder consumes the mask as [T, N] in the data element type (float).
auto transpose_seq_mask = std::make_shared<ngraph::opset6::Transpose>(seq_mask->output(0),
ngraph::opset6::Constant::create(seq_len_type,
Shape({2}), {1, 0}));
auto transpose_seq_mask_f = std::make_shared<ngraph::opset6::Convert>(transpose_seq_mask->output(0), data_type);
auto simplified_decoder = std::make_shared<ngraph::opset6::CTCGreedyDecoder>(transpose,
transpose_seq_mask_f->output(0),
true);
// Squeeze the two trailing singleton axes (3 then 2) of the decoder output
// down to [N, T], then convert to the requested class-index type.
auto squeeze2_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {3});
auto squeeze2_output_f = std::make_shared<ngraph::opset6::Squeeze>(simplified_decoder->output(0), squeeze2_axis);
auto squeeze1_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {2});
auto squeeze1_output_f = std::make_shared<ngraph::opset6::Squeeze>(squeeze2_output_f->output(0), squeeze1_axis);
auto output_i = std::make_shared<ngraph::opset6::Convert>(squeeze1_output_f->output(0), ci_type);
// Second output (sequence lengths): count per-batch entries that are not -1.
auto minus1 = opset6::Constant::create(ci_type, Shape{}, {-1});
auto where_equal_minus1 = std::make_shared<ngraph::opset6::Equal>(output_i, minus1);
auto seq_mask_const0 = opset6::Constant::create(ci_type, Shape{1}, {0});
auto seq_mask_const1 = opset6::Constant::create(ci_type, Shape{1}, {1});
auto output_seq_mask = std::make_shared<ngraph::opset6::Select>(where_equal_minus1, seq_mask_const0, seq_mask_const1);
auto seq_mask_axis = opset6::Constant::create(ci_type, Shape{1}, {1});
auto output_seq_len = std::make_shared<ngraph::opset6::ReduceSum>(output_seq_mask, seq_mask_axis);
auto output_seq_len_i = std::make_shared<ngraph::opset6::Convert>(output_seq_len->output(0), sl_type);
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ output_i, output_seq_len_i },
ngraph::ParameterVector{ data1, seq_len1 });
}
// Structural comparison of the transformed function against the reference.
auto res = compare_functions(f, f_ref, true, false, false, true, true);
ASSERT_TRUE(res.first) << res.second;
}
// SimplifyCTCGreedyDecoderSeqLen must keep the v6 decoder intact when blank_index
// is fed by a Parameter: its value is not a compile-time constant, so the
// transformation cannot verify it equals C - 1 and must not apply.
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenDynamicSeqLenParamWithBlankIndexTest) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto logits = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32,
                                                                  ngraph::PartialShape{2, Dimension::dynamic(), 7});
        auto lengths = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
        // blank_index is a runtime input, not a Constant.
        auto blank = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{1});
        auto decoder = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(logits, lengths, blank,
                                                                                true, ngraph::element::i64,
                                                                                ngraph::element::i64);
        auto classes_out = std::make_shared<opset6::Result>(decoder->output(0));
        auto lengths_out = std::make_shared<opset6::Result>(decoder->output(1));
        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ classes_out, lengths_out },
                                               ngraph::ParameterVector{ logits, lengths, blank });

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }
    {
        // Reference: an identical, untransformed graph.
        auto ref_logits = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32,
                                                                      ngraph::PartialShape{2, Dimension::dynamic(), 7});
        auto ref_lengths = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
        auto ref_blank = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{1});
        auto ref_decoder = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(ref_logits, ref_lengths, ref_blank,
                                                                                    true, ngraph::element::i64,
                                                                                    ngraph::element::i64);
        auto ref_classes_out = std::make_shared<opset6::Result>(ref_decoder->output(0));
        auto ref_lengths_out = std::make_shared<opset6::Result>(ref_decoder->output(1));
        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ ref_classes_out, ref_lengths_out },
                                                   ngraph::ParameterVector{ ref_logits, ref_lengths, ref_blank });
    }
    // Structural comparison of the (hopefully unchanged) function vs. the reference.
    const auto cmp = compare_functions(f, f_ref, true, false, false, true, true);
    ASSERT_TRUE(cmp.first) << cmp.second;
}

View File

@ -640,6 +640,7 @@ extensions/ops/constant_fill.py
extensions/ops/copyop.py
extensions/ops/correlation.py
extensions/ops/ctc_greedy_decoder.py
extensions/ops/ctc_greedy_decoder_seq_len.py
extensions/ops/ctc_loss.py
extensions/ops/cumsum.py
extensions/ops/data_augmentation.py
@ -981,6 +982,7 @@ mo/utils/ir_reader/extenders/binary_convolution_extender.py
mo/utils/ir_reader/extenders/bucketize_extender.py
mo/utils/ir_reader/extenders/conv_extender.py
mo/utils/ir_reader/extenders/convert_extender.py
mo/utils/ir_reader/extenders/ctc_greedy_decoder_seq_len_extender.py
mo/utils/ir_reader/extenders/deconvolution_extender.py
mo/utils/ir_reader/extenders/deformable_convolution_extender.py
mo/utils/ir_reader/extenders/experimental_extender.py

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,219 +14,61 @@
limitations under the License.
"""
import logging as log
import numpy as np
from extensions.ops.Cast import Cast
from extensions.front.Pack import Pack
from extensions.front.FillToBroadcast import FillToBroadcast
from extensions.ops.ctc_greedy_decoder_seq_len import CTCGreedyDecoderSeqLenOp
from extensions.ops.transpose import Transpose
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.ops.broadcast import Broadcast
from mo.ops.concat import Concat
from mo.ops.squeeze import Squeeze
from mo.ops.unsqueeze import Unsqueeze
class CTCGreedyDecoderReplacement(FrontReplacementSubgraph):
"""
TensorFlow CTCGreedyDecoder produces output in a sparse tensor that is not supported by Inference Engine and
Inference Engine's CTCGreedyDecoder has different output that is in a dense format. So this transformation
intends to replace TF CTCGreedyDecoder+SparseToDense with the IE one.
Also Inference Engine's CTCGreedyDecoder has a specific format for the second input tensor, a sequence length,
different from TF's one so this transformation cares about transformation of its format.
The second input to the CTCGreedyDecoder in the TensorFlow is a 1D tensor with sequence lengths. In the Inference
Engine the second input to the CTCGreedyDecoder is a 2D tensor, a sequence mask, where the first element
in each row is equal to 1 and all others in the tail are equal to 0. The number of ones represents
a sequence length.
Inference Engine's CTCGreedyDecoderSeqLen has a different output that is in a dense format. So this transformation
intends to replace TF CTCGreedyDecoder+SparseToDense with CTCGreedyDecoderSeqLen, which is compatible with IE.
"""
enabled = True
def run_after(self):
# CTCGreedyDecoderReplacement is not reshape-able transformation
# so reshape-able CTCGreedyDecoderReplacement2 transformation is applied first
return [CTCGreedyDecoderReplacement2]
@staticmethod
def pattern(**kwargs):
return dict(
nodes=[('decoder', dict(op='CTCGreedyDecoder')),
nodes=[('decoder', dict(op='CTCGreedyDecoderSeqLen')),
('cast', dict(op='Cast')),
('sparse_to_dense', dict(op='SparseToDense'))
],
edges=[('decoder', 'sparse_to_dense', {'out': 0}),
('decoder', 'sparse_to_dense', {'out': 2}),
('decoder', 'cast', {'out': 1}),
('cast', 'sparse_to_dense', {'out': 0})
]
)
def replace_sub_graph(self, graph: Graph, match: dict):
# TODO: Once Inference Engine's CTCGreedyDecoder starts to support sequence length format like in TensorFlow,
# CTCGreedyDecoderReplacement2 needs to be removed and CTCGreedyDecoderReplacement, a more generic
# transformation, needs to be adopted for all cases
ctc_greedy_decoder = match['decoder']
ctc_greedy_decoder_tf = match['decoder']
cast = match['cast']
sparse_to_dense = match['sparse_to_dense']
sparse_to_dense_name = sparse_to_dense.soft_get('name', sparse_to_dense.id)
ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get('name', ctc_greedy_decoder_tf.id)
# disconnect SparseToDense and Cast nodes
sparse_to_dense.in_port(0).disconnect()
cast.in_port(0).disconnect()
to normalize the input channel layout we need to transpose input data from [T, N, C] to [N, T, C],
which is the layout supported by the CTCGreedyDecoderSeqLen op.
ctc_data_permute = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0, 2])},
{'name': ctc_greedy_decoder_tf_name + '/ctc_data_permute'})
# transform CTCGreedyDecoder output to TensorFlow's one:
# 1. squeeze the output to [N, T] shape
# 2. cast it to integer
squeeze_dec_seq = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([2, 3])},
{'name': sparse_to_dense_name})
squeeze_dec_seq.in_port(0).connect(ctc_greedy_decoder.out_port(0))
cast_to_int = Cast(graph, {'name': sparse_to_dense_name + '/CastToInt',
'dst_type': np.int32}).create_node()
cast_to_int.in_port(0).connect(squeeze_dec_seq.out_port(0))
assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(ctc_greedy_decoder_tf_name)
# preserve output name from original graph
rename_nodes([(sparse_to_dense, sparse_to_dense_name + '/AbandonedName'),
(cast_to_int, sparse_to_dense_name)])
ctc_greedy_decoder_tf.in_port(0).get_source().connect(ctc_data_permute.in_port(0))
merge_repeated_tf = ctc_greedy_decoder_tf.merge_repeated
ctc_greedy_decoder = CTCGreedyDecoderSeqLenOp(graph, {'name': sparse_to_dense_name,
'merge_repeated': merge_repeated_tf}).create_node()
rename_nodes([(sparse_to_dense, sparse_to_dense_name + '/AbandonedName'), (ctc_greedy_decoder, sparse_to_dense_name)])
ctc_greedy_decoder.in_port(0).connect(ctc_data_permute.out_port(0))
ctc_greedy_decoder_tf.in_port(1).get_source().connect(ctc_greedy_decoder.in_port(1))
# set output of the new sub-graph as a source for SparseToDense consumer
sparse_to_dense.out_port(0).get_connection().set_source(cast_to_int.out_port(0))
sparse_to_dense.out_port(0).get_connection().set_source(ctc_greedy_decoder.out_port(0))
# remove no longer needed nodes
graph.remove_nodes_from([sparse_to_dense.id, cast.id])
# mark CTCGreedyDecoder node as a node that requires transformation of sequence length to a mask format
# in the middle phase
ctc_greedy_decoder['use_mask_format'] = True
# unless the second input of CTCGreedyDecoder is a parameter, it enforces MO to use --static-shape
# to try getting the second input with a value
sequence_length_node = ctc_greedy_decoder.in_node(1)
if sequence_length_node.soft_get('op') != 'Parameter' and not graph.graph['cmd_params'].static_shape:
log.error(
"Model can not be translated in a reshape-able way.\n"
"Model Optimizer key static_shape was turned on to prevent related errors.\n"
"There will be no success changing input shapes of the model with the help of "
"InferenceEngine reshape method", extra={'is_warning': True})
graph.graph['cmd_params'].static_shape = True
class CTCGreedyDecoderReplacement2(FrontReplacementSubgraph):
"""
The TF implementation of the CTCGreedyDecoder produces a tuple with two tensors. The first element in the tuple is
the SparseTensor which is converted to a regular tensor with the SparseToDense operation. This replacer matches
CTCGreedyDecoder and SparseToDense operations and removes the SparseToDense and Cast operation which is also used
in the SparseToDense operation, because Inference Engine implementation of the CTCGreedyDecoder produces regular
tensor as output.
Also, Inference Engine CTCGreedyDecoder requires a mask format for sequence lengths that is a different from
original one. Hence, this transformation changes a format of sequence length to a mask by replacing Fill and Pack
nodes with a special graph that produces a tensor of ones with shape [T, N] accepted by opset CTCGreedyDecoder.
"""
enabled = True
def run_before(self):
return [Pack, FillToBroadcast]
@staticmethod
def pattern(**kwargs):
return dict(
nodes=[
('transpose', dict(op='Transpose')),
('shape', dict(op='ShapeOf')),
('shape_1', dict(op='ShapeOf')),
('strided_slice', dict(op='StridedSlice')),
('stack', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [1]))),
('stack1', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [2]))),
('stack2', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [1]))),
('strided_slice_1', dict(op='StridedSlice')),
('stack_1', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [0]))),
('stack1_1', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [1]))),
('stack2_1', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [1]))),
('dims', dict(op='Pack')),
('fill', dict(op='Fill')),
('decoder', dict(op='CTCGreedyDecoder')),
('cast', dict(op='Cast')),
('sparse_to_dense', dict(op='SparseToDense')),
],
edges=[
('transpose', 'shape', {'out': 0}),
('transpose', 'shape_1', {'out': 0}),
('transpose', 'decoder', {'out': 0, 'in': 0}),
('shape', 'strided_slice', {'out': 0, 'in': 0}),
('stack', 'strided_slice', {'out': 0, 'in': 1}),
('stack1', 'strided_slice', {'out': 0, 'in': 2}),
('stack2', 'strided_slice', {'out': 0, 'in': 3}),
('shape_1', 'strided_slice_1', {'out': 0, 'in': 0}),
('stack_1', 'strided_slice_1', {'out': 0, 'in': 1}),
('stack1_1', 'strided_slice_1', {'out': 0, 'in': 2}),
('stack2_1', 'strided_slice_1', {'out': 0, 'in': 3}),
('strided_slice', 'dims', {'out': 0, 'in': 0}),
('dims', 'fill', {'out': 0, 'in': 0}),
('strided_slice_1', 'fill', {'out': 0, 'in': 1}),
('fill', 'decoder', {'out': 0, 'in': 1}),
('decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
('decoder', 'cast', {'out': 1, 'in': 0}),
('cast', 'sparse_to_dense', {'out': 0}),
]
)
def replace_sub_graph(self, graph: Graph, match: dict):
# obtain references to necessary nodes and their names
fill = match['fill']
dims = match['dims']
strided_slice = match['strided_slice']
strided_slice_1 = match['strided_slice_1']
ctc_greedy_decoder = match['decoder']
cast = match['cast']
sparse_to_dense = match['sparse_to_dense']
strided_slice_name = strided_slice.soft_get('name', strided_slice.id)
strided_slice_1_name = strided_slice_1.soft_get('name', strided_slice_1.id)
ctc_greedy_decoder_name = ctc_greedy_decoder.soft_get('name', ctc_greedy_decoder.id)
sparse_to_dense_name = sparse_to_dense.soft_get('name', sparse_to_dense.id)
# unsqueeze scalar values with batch size and time dimension
unsqueeze_batch_size = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
{'name': strided_slice_name + '/Unsqueeze'})
dims.in_port(0).get_connection().set_destination(unsqueeze_batch_size.in_port(0))
unsqueeze_time_size = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
{'name': strided_slice_1_name + '/Unsqueeze'})
fill.in_port(1).get_connection().set_destination(unsqueeze_time_size.in_port(0))
# compute a sequence mask shape [T, N] required for CTCGreedyDecoder
seq_mask_shape = Concat(graph, {'axis': 0, 'in_ports_count': 2,
'name': ctc_greedy_decoder_name + '/SequenceMaskShape'}).create_node()
seq_mask_shape.in_port(0).connect(unsqueeze_time_size.out_port(0))
seq_mask_shape.in_port(1).connect(unsqueeze_batch_size.out_port(0))
# compute a sequence mask
sequence_mask = create_op_with_const_inputs(graph, Broadcast, {0: np.array([1.0], dtype=np.float)},
{'mode': 'numpy',
'name': ctc_greedy_decoder_name + '/SequenceMask'})
sequence_mask.in_port(1).connect(seq_mask_shape.out_port(0))
# create CTCGreedyDecoder with the sequence mask instead of sequence length
ctc_greedy_decoder.in_port(1).disconnect()
ctc_greedy_decoder.in_port(1).connect(sequence_mask.out_port(0))
# remove fill and pack nodes since they are now in unconnected component
graph.remove_nodes_from([fill.id, dims.id])
# transform opset CTCGreedyDecoder output to TensorFlow's one that has a shape [N, T]
# opset CTCGreedyDecoder has an output with a shape [N, T, 1, 1]
squeeze_dec_seq = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([2, 3])},
{'name': sparse_to_dense_name})
squeeze_dec_seq.in_port(0).connect(ctc_greedy_decoder.out_port(0))
cast_to_int = Cast(graph, {'name': sparse_to_dense_name + '/CastToInt',
'dst_type': np.int32}).create_node()
cast_to_int.in_port(0).connect(squeeze_dec_seq.out_port(0))
# preserve output name from original graph
rename_nodes([(sparse_to_dense, sparse_to_dense_name + '/AbandonedName'),
(cast_to_int, sparse_to_dense_name)])
# set output of the new sub-graph as a source for SparseToDense consumer
sparse_to_dense.out_port(0).get_connection().set_source(cast_to_int.out_port(0))
# cleanup a graph
graph.remove_nodes_from([cast.id, sparse_to_dense.id])
graph.remove_nodes_from([sparse_to_dense.id, cast.id, ctc_greedy_decoder_tf.id])

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2020 Intel Corporation
Copyright (C) 2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,10 +14,11 @@
limitations under the License.
"""
import numpy as np
import unittest
from extensions.front.tf.CTCGreedyDecoderReplacement import CTCGreedyDecoderReplacement, CTCGreedyDecoderReplacement2
import numpy as np
from extensions.front.tf.CTCGreedyDecoderReplacement import CTCGreedyDecoderReplacement
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, const
@ -29,13 +30,15 @@ class CTCGreedyDecoderReplacementTests(unittest.TestCase):
# nodes from original graph
'logits': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'seq_len': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'decoder': {'kind': 'op', 'op': 'CTCGreedyDecoder'},
'order_arr': {'kind': 'op', 'op': 'Const'},
'transpose': {'type': 'Transpose', 'kind': 'op', 'op': 'Transpose'},
'decoder': {'kind': 'op', 'op': 'CTCGreedyDecoderSeqLen', 'merge_repeated': True},
'cast': {'kind': 'op', 'op': 'Cast'},
'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
# new nodes
'new_decoder': {'kind': 'op', 'op': 'CTCGreedyDecoder', 'use_mask_format': True},
'new_decoder': {'kind': 'op', 'op': 'CTCGreedyDecoderSeqLen', 'use_mask_format': True},
**const('squeeze_axes', int64_array([2, 3])),
'squeeze_dec_seq': {'kind': 'op', 'op': 'Squeeze'},
'cast_to_int': {'kind': 'op', 'op': 'Cast'},
@ -45,6 +48,7 @@ class CTCGreedyDecoderReplacementTests(unittest.TestCase):
[('logits', 'decoder', {'out': 0, 'in': 0}),
('seq_len', 'decoder', {'out': 0, 'in': 1}),
('decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
('decoder', 'sparse_to_dense', {'out': 2, 'in': 1}),
('decoder', 'cast', {'out': 1, 'in': 0}),
('cast', 'sparse_to_dense', {'out': 0}),
('sparse_to_dense', 'last', {'out': 0, 'in': 0}),
@ -53,113 +57,13 @@ class CTCGreedyDecoderReplacementTests(unittest.TestCase):
CTCGreedyDecoderReplacement().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes_attributes,
[('logits', 'decoder', {'out': 0, 'in': 0}),
('seq_len', 'decoder', {'out': 0, 'in': 1}),
('decoder', 'squeeze_dec_seq', {'out': 0, 'in': 0}),
('squeeze_axes', 'squeeze_dec_seq', {'out': 0, 'in': 1}),
('squeeze_dec_seq', 'cast_to_int', {'out': 0, 'in': 0}),
('cast_to_int', 'last', {'out': 0, 'in': 0}),
],
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
self.assertEqual(len(graph.get_op_nodes(op='Cast')) == 1 and
graph.get_op_nodes(op='Cast')[0]['name'] == 'sparse_to_dense', True,
'Name is not inherited from original node for CTCGreedyDecoderReplacement')
self.assertTrue(flag, resp)
def test2(self):
nodes_attributes = {
# nodes from original graph
'logits': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'transpose': {'kind': 'op', 'op': 'Transpose'},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_1': {'kind': 'op', 'op': 'ShapeOf'},
'strided_slice': {'kind': 'op', 'op': 'StridedSlice'},
**const('stack', int64_array([1])),
**const('stack1', int64_array([2])),
**const('stack2', int64_array([1])),
'strided_slice_1': {'kind': 'op', 'op': 'StridedSlice'},
**const('stack_1', int64_array([0])),
**const('stack1_1', int64_array([1])),
**const('stack2_1', int64_array([1])),
'dims': {'kind': 'op', 'op': 'Pack'},
'fill': {'kind': 'op', 'op': 'Fill'},
'decoder': {'kind': 'op', 'op': 'CTCGreedyDecoder'},
'cast': {'kind': 'op', 'op': 'Cast'},
'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
# new nodes
**const('unsqueeze_batch_size_axis', int64_array(0)),
'unsqueeze_batch_size': {'kind': 'op', 'op': 'Unsqueeze'},
**const('unsqueeze_time_size_axis', int64_array(0)),
'unsqueeze_time_size': {'kind': 'op', 'op': 'Unsqueeze'},
'seq_mask_shape': {'kind': 'op', 'op': 'Concat'},
'sequence_mask': {'kind': 'op', 'op': 'Broadcast'},
**const('one', np.array([1.0], dtype=np.float)),
**const('squeeze_axes', int64_array([2, 3])),
'squeeze_dec_seq': {'kind': 'op', 'op': 'Squeeze'},
'cast_to_int': {'kind': 'op', 'op': 'Cast'},
'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
}
graph = build_graph(nodes_attributes,
[('logits', 'transpose', {'out': 0}),
('transpose', 'shape', {'out': 0}),
('transpose', 'shape_1', {'out': 0}),
('transpose', 'decoder', {'out': 0, 'in': 0}),
('shape', 'strided_slice', {'out': 0, 'in': 0}),
('stack', 'strided_slice', {'out': 0, 'in': 1}),
('stack1', 'strided_slice', {'out': 0, 'in': 2}),
('stack2', 'strided_slice', {'out': 0, 'in': 3}),
('shape_1', 'strided_slice_1', {'out': 0, 'in': 0}),
('stack_1', 'strided_slice_1', {'out': 0, 'in': 1}),
('stack1_1', 'strided_slice_1', {'out': 0, 'in': 2}),
('stack2_1', 'strided_slice_1', {'out': 0, 'in': 3}),
('strided_slice', 'dims', {'out': 0, 'in': 0}),
('dims', 'fill', {'out': 0, 'in': 0}),
('strided_slice_1', 'fill', {'out': 0, 'in': 1}),
('fill', 'decoder', {'out': 0, 'in': 1}),
('decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
('decoder', 'cast', {'out': 1, 'in': 0}),
('cast', 'sparse_to_dense', {'out': 0}),
('sparse_to_dense', 'last', {'out': 0, 'in': 0}),
], nodes_with_edges_only=True)
graph.stage = 'front'
CTCGreedyDecoderReplacement2().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes_attributes,
[('logits', 'transpose', {'out': 0}),
('transpose', 'shape', {'out': 0}),
('transpose', 'shape_1', {'out': 0}),
[('logits', 'transpose', {'out': 0, 'in': 0}),
('order_arr', 'transpose', {'out': 0, 'in': 1}),
('transpose', 'decoder', {'out': 0, 'in': 0}),
('shape', 'strided_slice', {'out': 0, 'in': 0}),
('stack', 'strided_slice', {'out': 0, 'in': 1}),
('stack1', 'strided_slice', {'out': 0, 'in': 2}),
('stack2', 'strided_slice', {'out': 0, 'in': 3}),
('shape_1', 'strided_slice_1', {'out': 0, 'in': 0}),
('stack_1', 'strided_slice_1', {'out': 0, 'in': 1}),
('stack1_1', 'strided_slice_1', {'out': 0, 'in': 2}),
('stack2_1', 'strided_slice_1', {'out': 0, 'in': 3}),
('strided_slice', 'unsqueeze_batch_size', {'out': 0, 'in': 0}),
('unsqueeze_batch_size_axis', 'unsqueeze_batch_size', {'out': 0, 'in': 1}),
('strided_slice_1', 'unsqueeze_time_size', {'out': 0, 'in': 0}),
('unsqueeze_time_size_axis', 'unsqueeze_time_size', {'out': 0, 'in': 1}),
('unsqueeze_batch_size', 'seq_mask_shape', {'out': 0, 'in': 1}),
('unsqueeze_time_size', 'seq_mask_shape', {'out': 0, 'in': 0}),
('one', 'sequence_mask', {'out': 0, 'in': 0}),
('seq_mask_shape', 'sequence_mask', {'out': 0, 'in': 1}),
('sequence_mask', 'decoder', {'out': 0, 'in': 1}),
('decoder', 'squeeze_dec_seq', {'out': 0, 'in': 0}),
('squeeze_axes', 'squeeze_dec_seq', {'out': 0, 'in': 1}),
('squeeze_dec_seq', 'cast_to_int', {'out': 0, 'in': 0}),
('cast_to_int', 'last', {'out': 0, 'in': 0}),
('seq_len', 'decoder', {'out': 0, 'in': 1}),
('decoder', 'last', {'out': 0, 'in': 0}),
],
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
self.assertEqual(len(graph.get_op_nodes(op='Cast')) == 1 and
graph.get_op_nodes(op='Cast')[0]['name'] == 'sparse_to_dense', True,
'Name is not inherited from original node for CTCGreedyDecoderReplacement2')
self.assertTrue(flag, resp)

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.ctc_greedy_decoder import CTCGreedyDecoderOp
from extensions.ops.ctc_greedy_decoder_seq_len import CTCGreedyDecoderSeqLenOp
from mo.front.extractor import FrontExtractorOp
@ -24,7 +24,7 @@ class CTCCGreedyDecoderFrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
attrs = {
'ctc_merge_repeated': int(node.pb.attr['merge_repeated'].b),
'merge_repeated': bool(node.pb.attr['merge_repeated'].b),
}
CTCGreedyDecoderOp.update_node_stat(node, attrs)
CTCGreedyDecoderSeqLenOp.update_node_stat(node, attrs)
return cls.enabled

View File

@ -14,37 +14,21 @@
limitations under the License.
"""
import numpy as np
import logging as log
from extensions.ops.Cast import Cast
from extensions.ops.ctc_greedy_decoder import CTCGreedyDecoderOp
from extensions.ops.ctc_greedy_decoder_seq_len import CTCGreedyDecoderSeqLenOp
from extensions.ops.ctc_loss import CTCLoss
from extensions.ops.elementwise import Equal
from extensions.ops.parameter import Parameter
from extensions.ops.ReduceOps import ReduceSum
from extensions.ops.select import Select
from extensions.ops.transpose import Transpose
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.middle.passes.convert_data_type import data_type_str_to_np
from mo.ops.broadcast import Broadcast
from mo.ops.shape import Shape
from mo.ops.squeeze import Squeeze
from mo.utils.error import Error
class CTCLossReplacement(FrontReplacementSubgraph):
"""
The CTCLoss appears along with CTCGreedyDecoder operation in particular. Since the TensorFlow* CTCGreedyDecoder
outputs sparse tensor format, the OpenVINO CTCGreedyDecoder has a different format and the CTCLoss is also affected
outputs sparse tensor format, the OpenVINO CTCGreedyDecoderSeqLen has a different format and the CTCLoss is also affected
in terms of different format for its inputs. So the corresponding sub-graph with CTCGreedyDecoding and CTCLoss
must be transformed properly.
Also, the transformation changes the input sequence length format into a mask format. For example, 1D tensor of
sequence lengths equal to [4 2] is coded as 2D tensor [[1 1 1 1 0], [1 1 0 0 0]] with a time dimension is
equal to 5.
"""
enabled = True
@ -55,17 +39,14 @@ class CTCLossReplacement(FrontReplacementSubgraph):
def pattern(self):
return dict(
nodes=[
('seq_len', dict(op='Parameter')),
('transpose', dict(op='Transpose')),
('ctc_greedy_decoder', dict(op='CTCGreedyDecoder')),
('ctc_greedy_decoder', dict(op='CTCGreedyDecoderSeqLen')),
('cast', dict(op='Cast')),
('sparse_to_dense', dict(op='SparseToDense')),
('const', dict(op='Const')),
('ctc_loss', dict(op='CTCLoss')),
],
edges=[
('seq_len', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
('seq_len', 'ctc_loss', {'out': 0, 'in': 3}),
('transpose', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
('transpose', 'ctc_loss', {'out': 0, 'in': 0}),
('ctc_greedy_decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
@ -78,53 +59,29 @@ class CTCLossReplacement(FrontReplacementSubgraph):
])
def replace_sub_graph(self, graph: Graph, match: dict):
seq_len_tf = match['seq_len']
transpose_tf = match['transpose']
ctc_greedy_decoder_tf = match['ctc_greedy_decoder']
cast_tf = match['cast']
ctc_loss_tf = match['ctc_loss']
sparse_to_dense_tf = match['sparse_to_dense']
output_sparse_to_dense_name = sparse_to_dense_tf.soft_get('name', sparse_to_dense_tf.id)
output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)
ctc_data_permute = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0, 2])},
{'name': ctc_greedy_decoder_tf.name + '/ctc_data_permute'})
ctc_data_permute.in_port(0).connect(transpose_tf.out_port(0))
ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get('name', ctc_greedy_decoder_tf.id)
log.debug('Found CTCLossFrontReplacer pattern after {} with name {}'.format(ctc_greedy_decoder_tf.op,
ctc_greedy_decoder_tf.name))
# create sequence mask node, sub-graph for transforming into sequence length and connect with consumers
seq_len_tf_shape = seq_len_tf.soft_get('shape', None)
if seq_len_tf_shape is None or len(seq_len_tf_shape) != 2:
raise Error('The sequence length that is the second input to the CTCGreedyDecoder node "{}"'
' must be specified in a mask format.'.format(ctc_greedy_decoder_tf_name))
log.error('The format of input sequence length has been changed to a mask format', extra={'is_warning': True})
seq_len_tf_type = seq_len_tf.soft_get('data_type', None)
seq_len_tf_name = seq_len_tf.soft_get('name', seq_len_tf.id)
seq_mask_placeholder = Parameter(graph, {'name': seq_len_tf_name, 'shape': seq_len_tf_shape,
'data_type': seq_len_tf_type}).create_node()
reduce_to_seq_len_node = create_op_with_const_inputs(graph, ReduceSum, {1: np.array(1, dtype=np.int32)},
{'name': seq_len_tf_name + '/ReduceToSeqLen',
'keep_dims': False})
reduce_to_seq_len_node.in_port(0).connect(seq_mask_placeholder.out_port(0))
seq_len_tf.out_port(0).get_connection().set_source(reduce_to_seq_len_node.out_port(0))
cast_fp_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
casted_seq_mask_node = Cast(graph, {'name': seq_len_tf_name + '/CastToFP32', 'dst_type': cast_fp_type}).create_node()
casted_seq_mask_node.in_port(0).connect(seq_mask_placeholder.out_port(0))
permuted_casted_seq_mask = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])},
{'name': seq_len_tf_name + '/Permute'})
permuted_casted_seq_mask.in_port(0).connect(casted_seq_mask_node.out_port(0))
rename_nodes([(seq_len_tf, seq_len_tf_name + '/AbandonedName'), (seq_mask_placeholder, seq_len_tf_name)])
# create CTCGreedyDecoder node and set mask node
ctc_merge_repeated_i = ctc_greedy_decoder_tf.soft_get('ctc_merge_repeated', ctc_greedy_decoder_tf.id)
ctc_greedy_decoder = CTCGreedyDecoderOp(graph, {'name': output_sparse_to_dense_name,
'ctc_merge_repeated': ctc_merge_repeated_i}).create_node()
ctc_greedy_decoder.in_port(1).connect(permuted_casted_seq_mask.out_port(0))
assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(ctc_greedy_decoder_tf_name)
merge_repeated_tf = ctc_greedy_decoder_tf.merge_repeated
ctc_greedy_decoder = CTCGreedyDecoderSeqLenOp(graph, {'name': output_sparse_to_dense_name,
'merge_repeated': merge_repeated_tf}).create_node()
rename_nodes([(sparse_to_dense_tf, output_sparse_to_dense_name + '/AbandonedName'),
(ctc_greedy_decoder, output_sparse_to_dense_name)])
ctc_greedy_decoder.in_port(0).connect(ctc_data_permute.out_port(0))
ctc_greedy_decoder.in_port(1).connect(ctc_greedy_decoder_tf.in_port(1).get_connection().get_source())
# create CTCLoss node and set attributes
# set output of the new sub-graph as a source for SparseToDense consumer
output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)
assert ctc_loss_tf.has_valid('preprocess_collapse_repeated'), \
'The CTCLoss node "{}" misses "preprocess_collapse_repeated" attribute'.format(output_ctc_loss_name)
assert ctc_loss_tf.has_valid('ctc_merge_repeated'), \
@ -139,48 +96,14 @@ class CTCLossReplacement(FrontReplacementSubgraph):
'ctc_merge_repeated': ctc_merge_repeated,
'unique': unique}).create_node()
rename_nodes([(ctc_loss_tf, output_ctc_loss_name + '/AbandonedName'), (ctc_loss, output_ctc_loss_name)])
# connect logits
ctc_greedy_decoder_tf.in_port(0).get_connection().set_destination(ctc_greedy_decoder.in_port(0))
ctc_loss.in_port(0).disconnect()
transpose_tf.in_port(0).get_connection().add_destination(ctc_loss.in_port(0))
# connect logit lengths
ctc_greedy_decoder_tf.in_port(1).disconnect()
ctc_loss.in_port(1).connect(reduce_to_seq_len_node.out_port(0))
# connect labels to ctc_loss
squeeze_op = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([2, 3])})
cast_labels_op = Cast(graph, {'name': output_sparse_to_dense_name + '/CastLabels', 'dst_type': np.int32}).create_node()
squeeze_op.in_port(0).connect(ctc_greedy_decoder.out_port(0))
cast_labels_op.in_port(0).connect(squeeze_op.out_port(0))
ctc_loss.in_port(2).connect(cast_labels_op.out_port(0))
# connect label lengths
equal_op = create_op_with_const_inputs(graph, Equal, {1: np.array([-1], dtype=np.int32)},
{'name': output_sparse_to_dense_name + '/Equal'})
equal_op.in_port(0).connect(cast_labels_op.out_port(0))
labels_shape_op = Shape(graph, {'name': output_sparse_to_dense_name + '/ShapeOf'}).create_node()
labels_shape_op.in_port(0).connect(equal_op.out_port(0))
broadcast_one = create_op_with_const_inputs(graph, Broadcast, {0: np.array([1], dtype=np.int32)},
{'mode': 'numpy',
'name': output_sparse_to_dense_name + '/One'})
broadcast_one.in_port(1).connect(labels_shape_op.out_port(0))
broadcast_zero = create_op_with_const_inputs(graph, Broadcast, {0: np.array([0], dtype=np.int32)},
{'mode': 'numpy',
'name': output_sparse_to_dense_name + '/Zero'})
broadcast_zero.in_port(1).connect(labels_shape_op.out_port(0))
select_node = Select(graph, {'name': output_sparse_to_dense_name + '/Select'}).create_node()
select_node.in_port(0).connect(equal_op.out_port(0))
select_node.in_port(1).connect(broadcast_zero.out_port(0))
select_node.in_port(2).connect(broadcast_one.out_port(0))
label_length_node = create_op_with_const_inputs(graph, ReduceSum, {1: int64_array([1])},
op_attrs={'name': output_sparse_to_dense_name + '/LabelLength',
'keep_dims': False})
label_length_node.in_port(0).connect(select_node.out_port(0))
ctc_loss.in_port(3).connect(label_length_node.out_port(0))
# set source for output of new sub-graph and remove old nodes
ctc_loss_tf.out_port(0).get_connection().set_source(ctc_loss.out_port(0))
graph.remove_nodes_from([ctc_greedy_decoder_tf.id, ctc_loss_tf.id, cast_tf.id, sparse_to_dense_tf.id])
if ctc_loss_tf.logits_time_major:
ctc_loss.in_port(0).connect(ctc_data_permute.out_port(0))
else:
ctc_loss.in_port(0).connect(transpose_tf.out_port(0))
ctc_loss.in_port(1).connect(ctc_greedy_decoder_tf.in_port(1).get_connection().get_source())
ctc_loss.in_port(2).connect(ctc_greedy_decoder.out_port(0))
ctc_loss.in_port(3).connect(ctc_greedy_decoder.out_port(1))
# remove no longer needed nodes
graph.remove_nodes_from([sparse_to_dense_tf.id, cast_tf.id, ctc_loss_tf.id, ctc_greedy_decoder_tf.id])

View File

@ -28,107 +28,101 @@ class CTCLossFrontReplacementTest(unittest.TestCase):
def test1(self):
nodes_attributes = {
'logits': {'shape': int64_array([2, 6, 100]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'seq_mask': {'shape': int64_array([2, 100]), 'data_type': np.int32, 'kind': 'op', 'op': 'Parameter'},
'reduce_seq_mask': {'kind': 'op', 'op': 'ReduceSum'},
's_cast_seq_mask': {'kind': 'op', 'op': 'Cast'},
'transpose_cast_seq_mask': {'kind': 'op', 'op': 'Transpose'},
'seq_mask': {'shape': int64_array([2]), 'data_type': np.int32, 'kind': 'op', 'op': 'Parameter'},
'transpose': {'kind': 'op', 'op': 'Transpose'},
'ctc_greedy_decoder': {'kind': 'op', 'op': 'CTCGreedyDecoder'},
'ctc_greedy_decoder': {'kind': 'op', 'op': 'CTCGreedyDecoderSeqLen', 'merge_repeated': True},
'cast': {'kind': 'op', 'op': 'Cast'},
'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
'const': {'kind': 'op', 'op': 'Const'},
'tf_ctc_loss': {'kind': 'op', 'op': 'CTCLoss', 'preprocess_collapse_repeated': False,
'ctc_merge_repeated': True, 'unique': False, 'logits_time_major': True},
'ctc_loss': {'kind': 'op', 'op': 'CTCLoss', 'preprocess_collapse_repeated': False,
'ctc_merge_repeated': True, 'unique': False},
'equal_op': {'kind': 'op', 'op': 'Equal'},
'ctc_greedy_decoder_op': {'kind': 'op', 'op': 'CTCGreedyDecoder'},
'ctc_loss_op': {'kind': 'op', 'op': 'CTCLoss'},
'squeeze_op': {'kind': 'op', 'op': 'Squeeze'},
'cast_labels_op': {'kind': 'op', 'op': 'Cast', 'type': 'Convert'},
'labels_shape_op': {'kind': 'op', 'op': 'ShapeOf'},
'broadcast_one_op': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_zero_op': {'kind': 'op', 'op': 'Broadcast'},
'select_op': {'kind': 'op', 'op': 'Select'},
'label_length_op': {'kind': 'op', 'op': 'ReduceSum'},
**const('reduce_indices', int64_array(1)),
**const('permute_order', int64_array([1, 0])),
'ctc_merge_repeated': True, 'unique': False},
**const('default_value', int64_array(-1)),
**const('squeeze_axis', int64_array([2, 3])),
**const('minus_one', np.array([-1], dtype=np.int32)),
**const('one', np.array([1], dtype=np.int32)),
**const('zero', np.array([0], dtype=np.int32)),
**const('reduce_sum_axis', int64_array([1])),
'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
'transpose2': {'kind': 'op', 'op': 'Transpose'},
**const('transpose2_axis', int64_array([1, 0, 2])),
}
graph = build_graph(nodes_attributes,
[('logits', 'transpose', {'out': 0, 'in': 0}),
('transpose', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
('seq_mask', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
('transpose', 'ctc_loss', {'out': 0, 'in': 0}),
('seq_mask', 'ctc_loss', {'out': 0, 'in': 3}),
('ctc_greedy_decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
('ctc_greedy_decoder', 'sparse_to_dense', {'out': 2, 'in': 1}),
('ctc_greedy_decoder', 'sparse_to_dense', {'out': 1, 'in': 2}),
('default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
('ctc_greedy_decoder', 'cast', {'out': 1, 'in': 0}),
('ctc_greedy_decoder', 'ctc_loss', {'out': 0, 'in': 1}),
('cast', 'ctc_loss', {'out': 0, 'in': 2}),
('ctc_loss', 'last', {'out': 0, 'in': 0}),
], nodes_with_edges_only=True)
graph = build_graph(nodes_attributes, [('logits', 'transpose', {'out': 0, 'in': 0}),
('transpose', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
('seq_mask', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
('transpose', 'tf_ctc_loss', {'out': 0, 'in': 0}),
('seq_mask', 'tf_ctc_loss', {'out': 0, 'in': 3}),
('ctc_greedy_decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
('ctc_greedy_decoder', 'sparse_to_dense', {'out': 2, 'in': 1}),
('ctc_greedy_decoder', 'sparse_to_dense', {'out': 1, 'in': 2}),
('default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
('ctc_greedy_decoder', 'cast', {'out': 1, 'in': 0}),
('ctc_greedy_decoder', 'tf_ctc_loss', {'out': 0, 'in': 1}),
('cast', 'tf_ctc_loss', {'out': 0, 'in': 2}),
('tf_ctc_loss', 'last', {'out': 0, 'in': 0})],
nodes_with_edges_only=True)
graph.graph['cmd_params'] = Namespace(data_type='FP32')
graph.stage = 'front'
CTCLossReplacement().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes_attributes,
[('seq_mask', 'reduce_seq_mask', {'out': 0, 'in': 0}),
('reduce_indices', 'reduce_seq_mask', {'out': 0, 'in': 1}),
('seq_mask', 's_cast_seq_mask', {'out': 0, 'in': 0}),
('s_cast_seq_mask', 'transpose_cast_seq_mask', {'out': 0, 'in': 0}),
('permute_order', 'transpose_cast_seq_mask', {'out': 0, 'in': 1}),
('logits', 'transpose', {'out': 0, 'in': 0}),
('transpose', 'ctc_greedy_decoder_op', {'out': 0, 'in': 0}),
('transpose_cast_seq_mask', 'ctc_greedy_decoder_op', {'out': 0, 'in': 1}),
('ctc_greedy_decoder_op', 'squeeze_op', {'out': 0, 'in': 0}),
('squeeze_axis', 'squeeze_op', {'out': 0, 'in': 1}),
('squeeze_op', 'cast_labels_op', {'in': 0}),
('minus_one', 'equal_op', {'out': 0, 'in': 1}),
('equal_op', 'labels_shape_op', {'out': 0, 'in': 0}),
('one', 'broadcast_one_op', {'out': 0, 'in': 0}),
('labels_shape_op', 'broadcast_one_op', {'out': 0, 'in': 1}),
('zero', 'broadcast_zero_op', {'out': 0, 'in': 0}),
('labels_shape_op', 'broadcast_zero_op', {'out': 0, 'in': 1}),
('equal_op', 'select_op', {'out': 0, 'in': 0}),
('broadcast_zero_op', 'select_op', {'out': 0, 'in': 1}),
('broadcast_one_op', 'select_op', {'out': 0, 'in': 2}),
('select_op', 'label_length_op', {'out': 0, 'in': 0}),
('reduce_sum_axis', 'label_length_op', {'out': 0, 'in': 1}),
('logits', 'ctc_loss_op', {'out': 0, 'in': 0}),
('reduce_seq_mask', 'ctc_loss_op', {'out': 0, 'in': 1}),
('cast_labels_op', 'ctc_loss_op', {'out': 0, 'in': 2}),
('label_length_op', 'ctc_loss_op', {'out': 0, 'in': 3}),
('cast_labels_op', 'equal_op', {'out': 0, 'in': 0}),
('ctc_loss_op', 'last', {'out': 0, 'in': 0})],
nodes_with_edges_only=True)
[('logits', 'transpose', {'out': 0, 'in': 0}),
('transpose', 'transpose2', {'out': 0, 'in': 0}),
('transpose2_axis', 'transpose2', {'out': 0, 'in': 1}),
('transpose2', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
('seq_mask', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
('transpose2', 'ctc_loss', {'out': 0, 'in': 0}),
('ctc_greedy_decoder', 'ctc_loss', {'out': 0, 'in': 2}),
('ctc_greedy_decoder', 'ctc_loss', {'out': 1, 'in': 3}),
('seq_mask', 'ctc_loss', {'out': 0, 'in': 1}),
('ctc_loss', 'last', {'out': 0, 'in': 0})],
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
self.assertTrue(flag, resp)
def test2(self):
# Checks CTCLossReplacement when the TF CTCLoss consumes batch-major logits
# (logits_time_major=False): the replaced CTCLoss must then read port 0 from
# the original 'transpose' node, not from the inserted [1, 0, 2] permute.
nodes_attributes = {
'logits': {'shape': int64_array([2, 6, 100]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'seq_mask': {'shape': int64_array([2]), 'data_type': np.int32, 'kind': 'op', 'op': 'Parameter'},
'transpose': {'kind': 'op', 'op': 'Transpose'},
'ctc_greedy_decoder': {'kind': 'op', 'op': 'CTCGreedyDecoderSeqLen', 'merge_repeated': True},
'cast': {'kind': 'op', 'op': 'Cast'},
'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
# 'tf_ctc_loss' models the original TF node; 'ctc_loss' is the expected replacement.
'tf_ctc_loss': {'kind': 'op', 'op': 'CTCLoss', 'preprocess_collapse_repeated': False,
'ctc_merge_repeated': True, 'unique': False, 'logits_time_major': False},
'ctc_loss': {'kind': 'op', 'op': 'CTCLoss', 'preprocess_collapse_repeated': False,
'ctc_merge_repeated': True, 'unique': False},
**const('default_value', int64_array(-1)),
'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
# 'transpose2' is the [1, 0, 2] permute the transformation inserts before the decoder.
'transpose2': {'kind': 'op', 'op': 'Transpose'},
**const('transpose2_axis', int64_array([1, 0, 2])),
}
# Original TF-style sub-graph: CTCGreedyDecoder feeding SparseToDense and
# (via Cast) CTCLoss, as produced by the TF front-end before replacement.
graph = build_graph(nodes_attributes, [('logits', 'transpose', {'out': 0, 'in': 0}),
('transpose', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
('seq_mask', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
('transpose', 'tf_ctc_loss', {'out': 0, 'in': 0}),
('seq_mask', 'tf_ctc_loss', {'out': 0, 'in': 3}),
('ctc_greedy_decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
('ctc_greedy_decoder', 'sparse_to_dense', {'out': 2, 'in': 1}),
('ctc_greedy_decoder', 'sparse_to_dense', {'out': 1, 'in': 2}),
('default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
('ctc_greedy_decoder', 'cast', {'out': 1, 'in': 0}),
('ctc_greedy_decoder', 'tf_ctc_loss', {'out': 0, 'in': 1}),
('cast', 'tf_ctc_loss', {'out': 0, 'in': 2}),
('tf_ctc_loss', 'last', {'out': 0, 'in': 0})],
nodes_with_edges_only=True)
graph.graph['cmd_params'] = Namespace(data_type='FP32')
graph.stage = 'front'
CTCLossReplacement().find_and_replace_pattern(graph)
# Expected result: CTCGreedyDecoderSeqLen behind the inserted permute, and
# CTCLoss taking logits straight from 'transpose' (batch-major path) with
# decoded classes/lengths on its ports 2 and 3.
graph_ref = build_graph(nodes_attributes,
[('logits', 'transpose', {'out': 0, 'in': 0}),
('transpose', 'transpose2', {'out': 0, 'in': 0}),
('transpose2_axis', 'transpose2', {'out': 0, 'in': 1}),
('transpose2', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
('seq_mask', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
('transpose', 'ctc_loss', {'out': 0, 'in': 0}),
('ctc_greedy_decoder', 'ctc_loss', {'out': 0, 'in': 2}),
('ctc_greedy_decoder', 'ctc_loss', {'out': 1, 'in': 3}),
('seq_mask', 'ctc_loss', {'out': 0, 'in': 1}),
('ctc_loss', 'last', {'out': 0, 'in': 0})],
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@ -23,11 +23,18 @@ class CTCLossFrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
    """Extract CTCLoss attributes from the TF node definition.

    When the 'logits_time_major' attribute is absent from the protobuf,
    it defaults to True (time-major logits), matching TF's behaviour —
    TODO confirm against the TF CTCLoss op registration.
    """
    pb_attrs = node.pb.attr
    time_major = pb_attrs['logits_time_major'].b if 'logits_time_major' in pb_attrs else True
    CTCLoss.update_node_stat(node, {
        'ctc_merge_repeated': pb_attrs['ctc_merge_repeated'].b,
        'preprocess_collapse_repeated': pb_attrs['preprocess_collapse_repeated'].b,
        'logits_time_major': time_major,
        # 'unique' is always False for CTCLoss V1
        'unique': False,
    })
    return cls.enabled

View File

@ -0,0 +1,92 @@
"""
Copyright (C) 2017-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import bool_to_str
from mo.graph.graph import Node, Graph
from mo.middle.passes.convert_data_type import np_data_type_to_destination_type
from mo.ops.op import Op
from mo.utils.error import Error
class CTCGreedyDecoderSeqLenOp(Op):
    """Model Optimizer representation of the opset-6 CTCGreedyDecoderSeqLen op.

    Inputs: logits [N, T, C], sequence lengths [N], and an optional blank index.
    Outputs: decoded class indices [N, T] (port 0) and per-batch decoded
    sequence lengths [N] (port 1).
    """
    op = 'CTCGreedyDecoderSeqLen'

    def __init__(self, graph: Graph, attrs: dict):
        super().__init__(graph, {
            'type': self.op,
            'op': self.op,
            'version': 'opset6',
            'infer': self.infer,
            'type_infer': self.type_infer,
            'in_ports_count': 3,
            'out_ports_count': 2,
            # Defaults below follow the opset-6 specification.
            'merge_repeated': True,
            'classes_index_type': np.int32,
            'sequence_length_type': np.int32,
        }, attrs)

    def backend_attrs(self):
        # Serialize attributes for the IR; the numpy dtypes are converted to
        # IR type names and the boolean flag to its string form.
        opset = self.get_opset()
        if opset != 'opset6':
            raise Error('Unknown opset version "{}"'.format(opset))
        return [
            ('classes_index_type', lambda node: np_data_type_to_destination_type(node.classes_index_type)),
            ('sequence_length_type', lambda node: np_data_type_to_destination_type(node.sequence_length_type)),
            ('merge_repeated', lambda node: bool_to_str(node, 'merge_repeated')),
        ]

    @staticmethod
    def type_infer(node):
        # Output element types are driven by the node's attributes.
        if node.get_opset() == 'opset6':
            node.out_port(0).set_data_type(node.classes_index_type)
            node.out_port(1).set_data_type(node.sequence_length_type)

    @staticmethod
    def infer(node: Node):
        """Validate input ranks/batch consistency and set output shapes."""
        node_name = node.soft_get('name', node.id)
        active_ports = [p for p in node.in_ports().values() if not p.disconnected()]
        assert len(active_ports) in [2, 3], \
            "Incorrect number of inputs for {} node".format(node_name)

        logits_shape = node.in_port(0).data.get_shape()
        seq_lengths_shape = node.in_port(1).data.get_shape()
        if len(node.in_nodes()) == 3:
            # Optional third input: blank index, expected as a 1D tensor here.
            blank_shape = node.in_port(2).data.get_shape()
            assert len(blank_shape) == 1, \
                'Incorrect rank of blank_index for {} node'.format(node_name)

        # check shapes of input tensors
        assert len(logits_shape) == 3, \
            'Incorrect rank of logits for {} node'.format(node_name)
        assert len(seq_lengths_shape) == 1, \
            'Incorrect rank of sequence length tensor for {} node'.format(node_name)
        assert logits_shape[0] == seq_lengths_shape[0], \
            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)

        batch_size, time_size = logits_shape[0], logits_shape[1]
        if node.is_out_port_connected(0):
            node.out_port(0).data.set_shape(int64_array([batch_size, time_size]))
        if node.is_out_port_connected(1):
            node.out_port(1).data.set_shape(int64_array([batch_size]))

View File

@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -18,7 +18,7 @@ import unittest
import numpy as np
from extensions.ops.ctc_greedy_decoder import CTCGreedyDecoderOp
from extensions.ops.ctc_greedy_decoder_seq_len import CTCGreedyDecoderSeqLenOp
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph
@ -28,68 +28,73 @@ nodes_attributes = {'logits': {'kind': 'op'},
'logits_data': {'shape': None, 'value': None, 'kind': 'data'},
'seq_mask': {'kind': 'op'},
'seq_mask_data': {'shape': None, 'value': None, 'kind': 'data'},
'ctcgreedydecoder_node': {'op': 'CTCGreedyDecoder', 'kind': 'op',
'ctcgreedydecoder_node': {'op': 'CTCGreedyDecoderSeqLen', 'kind': 'op',
'ctc_merge_repeated': True},
'output': {'shape': None, 'value': None, 'kind': 'data'}}
'output1': {'shape': None, 'value': None, 'kind': 'data'},
'last_output1': {'shape': None, 'value': None, 'kind': 'op'},
'output2': {'shape': None, 'value': None, 'kind': 'data'}
}
# graph 1
edges1 = [('logits', 'logits_data'),
('seq_mask', 'seq_mask_data'),
('logits_data', 'ctcgreedydecoder_node', {'in': 0}),
('seq_mask_data', 'ctcgreedydecoder_node', {'in': 1}),
('ctcgreedydecoder_node', 'output', {'out': 0})]
('ctcgreedydecoder_node', 'output1', {'out': 0}),
('ctcgreedydecoder_node', 'output2', {'out': 1}),
('output1', 'last_output1', {'out': 0}),]
# valid test case
inputs1 = {'logits_data': {'shape': int64_array([100, 4, 5])},
'seq_mask_data': {'shape': int64_array([100, 4])}}
inputs1 = {'logits_data': {'shape': int64_array([4, 100, 5])},
'seq_mask_data': {'shape': int64_array([4])}}
# invalid test case with incorrect rank for the first input tensor
inputs1_inv = {'logits_data': {'shape': int64_array([100, 4, 5, 6])},
'seq_mask_data': {'shape': int64_array([100, 4])}}
inputs1_inv = {'logits_data': {'shape': int64_array([4, 100, 5, 6])},
'seq_mask_data': {'shape': int64_array([4])}}
# invalid test case with incorrect rank for the second input tensor
inputs2_inv = {'logits_data': {'shape': int64_array([100, 4, 5])},
'seq_mask_data': {'shape': int64_array([100])}}
inputs2_inv = {'logits_data': {'shape': int64_array([4, 100, 5])},
'seq_mask_data': {'shape': int64_array([4, 100])}}
# invalid test case with incorrect time dimension
inputs3_inv = {'logits_data': {'shape': int64_array([100, 4, 5])},
'seq_mask_data': {'shape': int64_array([101, 4])}}
inputs3_inv = {'logits_data': {'shape': int64_array([4, 100, 5])},
'seq_mask_data': {'shape': int64_array([4, 101])}}
# invalid test case with incorrect batch dimension
inputs4_inv = {'logits_data': {'shape': int64_array([100, 4, 5])},
'seq_mask_data': {'shape': int64_array([100, 14])}}
inputs4_inv = {'logits_data': {'shape': int64_array([4, 100, 5])},
'seq_mask_data': {'shape': int64_array([14, 100])}}
class TestCTCGreedyDecoder(unittest.TestCase):
def test_infer1(self):
graph = build_graph(nodes_attributes, edges1, inputs1)
ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
CTCGreedyDecoderOp.infer(ctcgreedydecoder_node)
CTCGreedyDecoderSeqLenOp.infer(ctcgreedydecoder_node)
# prepare reference results
ref_output_shape = int64_array([4, 100, 1, 1])
ref_output1_shape = int64_array([4, 100])
# get the result
res_output_shape = graph.node['output']['shape']
res_output1_shape = graph.node['output1']['shape']
self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
'shapes do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
self.assertTrue(np.array_equal(ref_output1_shape, res_output1_shape),
'shapes do not match expected: {} and given: {}'.format(ref_output1_shape, res_output1_shape))
def test_infer_invalid1(self):
graph = build_graph(nodes_attributes, edges1, inputs1_inv)
ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
self.assertRaises(AssertionError, CTCGreedyDecoderOp.infer, ctcgreedydecoder_node)
self.assertRaises(AssertionError, CTCGreedyDecoderSeqLenOp.infer, ctcgreedydecoder_node)
def test_infer_invalid2(self):
graph = build_graph(nodes_attributes, edges1, inputs2_inv)
ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
self.assertRaises(AssertionError, CTCGreedyDecoderOp.infer, ctcgreedydecoder_node)
self.assertRaises(AssertionError, CTCGreedyDecoderSeqLenOp.infer, ctcgreedydecoder_node)
def test_infer_invalid3(self):
graph = build_graph(nodes_attributes, edges1, inputs3_inv)
ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
self.assertRaises(AssertionError, CTCGreedyDecoderOp.infer, ctcgreedydecoder_node)
self.assertRaises(AssertionError, CTCGreedyDecoderSeqLenOp.infer, ctcgreedydecoder_node)
def test_infer_invalid4(self):
graph = build_graph(nodes_attributes, edges1, inputs4_inv)
ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
self.assertRaises(AssertionError, CTCGreedyDecoderOp.infer, ctcgreedydecoder_node)
self.assertRaises(AssertionError, CTCGreedyDecoderSeqLenOp.infer, ctcgreedydecoder_node)

View File

@ -0,0 +1,30 @@
"""
Copyright (c) 2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.middle.passes.convert_data_type import destination_type_to_np_data_type
from mo.utils.graph import Node
from mo.utils.ir_reader.extender import Extender
class CTCGreedyDecoderSeqLenExtender(Extender):
    """IR-reader extender for CTCGreedyDecoderSeqLen nodes."""
    op = 'CTCGreedyDecoderSeqLen'

    @staticmethod
    def extend(op: Node):
        # The IR stores output element types as destination-type strings;
        # convert them back to numpy dtypes for the in-memory graph.
        for attr_name in ('classes_index_type', 'sequence_length_type'):
            if op.has_valid(attr_name):
                op[attr_name] = destination_type_to_np_data_type(op[attr_name])