Add CTCGreedyDecoder mo support (#4009)
* Add CTCGreedyDecoder mo support
* Update copyright
* Update BOM file
* Add transformation
* Fix code style
* Fix according to review
* Add CTCGreedyDecoder v6 to ConvertPrecision
* Hot fix
* Add replacement for ctc_greedy_decoder
* Fix test
* Fix
* Update IE transform
* Draft CTC loss replacer
* Add CTCLoss replacer
* Update
* Refactor code
* Update transformation
* Update decoder
* Remove comments
* Convert seq mask from int to float
* Fix unit test
* Add dynamic tests
* Refactor code
* Fix Python code style
* Update style
* Disable CTCGreedyDecoder transform for MKLDNN plugin
* Add some comments
* Add transform code comments
* Enable transform from different plugins
* Fix MO
* Fix tests
* Fix comment
* Fix convert precision
* Update comment
* Fix precision
* Refactor according to review
* Add IR reader extender
* Rename transformation
* Update BOM file
* Fix MO replacer
* Fix tests
* Move transform to decomp
* Add check for blank_index
* Refactor CTCLoss
* Change dynamic rank check
* Fix CTCLoss extractor
* Remove comment
* Fix code style
* Refactor pattern matcher for the CTCGreedyDecoder transformation
* Disable transform for VPU
* Refactor according to review
* Refactor code
* Disable transformation for clDNN
* Remove unused code
* Reverse transformation
* Fix code style
* Hot fix transform
* Fix unit tests
* Update transform
* Enable transform in common pipeline
* Fix names of replacements for MO transformations
* Hot fix
* Fix
This commit is contained in:
parent edbb802e55
commit f670b7cb3a
@@ -58,6 +58,7 @@
#include <transformations/op_conversions/convert_nms_to_nms_ie_internal.hpp>
#include <transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp>
#include <transformations/op_conversions/convert_gather_0d.hpp>
#include <transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp>
#include <transformations/convert_precision.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
@@ -278,6 +279,7 @@ InferenceEngine::CNNNetwork clDNNEngine::CloneAndTransformNetwork(const Inferenc
    pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
    pass_config->disable<ngraph::pass::ConvertBroadcast3>();
    pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();
    pass_config->disable<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();

    pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();
@@ -57,6 +57,7 @@
#include <transformations/op_conversions/gru_cell_decomposition.hpp>
#include <transformations/op_conversions/log_softmax_decomposition.hpp>
#include <transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp>
#include <transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp>
#include <transformations/convert_precision.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/rt_info/fused_names_attribute.hpp>
@@ -227,6 +228,7 @@ static void Transformation(CNNNetwork& clonedNetwork, const Config& conf) {
    pass_config->disable<ngraph::pass::LogSoftmaxDecomposition>();
    pass_config->disable<ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher>();
    pass_config->disable<ngraph::pass::WeightsDequantizeToFakeQuantize>();
    pass_config->disable<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();

    pass_config->enable<ngraph::pass::ConvertInterpolate1ToInterpolate4>();
@@ -0,0 +1,43 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>

#include <transformations_visibility.hpp>

#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API SimplifyCTCGreedyDecoderSeqLen;

}  // namespace pass
}  // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief SimplifyCTCGreedyDecoderSeqLen converts v6::CTCGreedyDecoderSeqLen into v0::CTCGreedyDecoder.
 *
 *   data[N, T, C]    seq_len[N]
 *            \          /
 *       CTCGreedyDecoderSeqLen
 *
 * will be converted to
 *
 *   data[T, N, C]   seq_mask[T, N]
 *            \          /
 *        CTCGreedyDecoder
 *           /          \
 *   class_index[N, T]   seq_len[N]
 *
 * The transformation is applied only when the blank_index input equals C-1, where C is the number of classes.
 */
class ngraph::pass::SimplifyCTCGreedyDecoderSeqLen : public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    SimplifyCTCGreedyDecoderSeqLen();
};
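For reference, a minimal usage sketch (mirroring the unit tests added later in this commit) of how the pass can be run over an ngraph::Function through a pass manager; the wrapper function name run_simplify_pass is illustrative only:

#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp>

// `model` is any ngraph::Function that contains a v6::CTCGreedyDecoderSeqLen node.
void run_simplify_pass(const std::shared_ptr<ngraph::Function>& model) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
    manager.run_passes(model);
}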
@@ -52,6 +52,7 @@
#include "transformations/op_conversions/hsigmoid_decomposition.hpp"
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
#include "transformations/op_conversions/mvn6_decomposition.hpp"
#include "transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp"

#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
@@ -119,6 +120,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
    decomp->add_matcher<ngraph::pass::ConvertSpaceToDepth>();
    decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
    decomp->add_matcher<ngraph::pass::MVN6Decomposition>();
    decomp->add_matcher<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
    decomp->set_name("ngraph::pass::CommonDecompositions");

    // CF is required after all decompositions
@@ -8,6 +8,7 @@
#include <memory>
#include <vector>

#include <ngraph/opsets/opset6.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset3.hpp>
@@ -29,6 +30,7 @@ bool fuse_type_to_nms5(std::shared_ptr<ngraph::Node> & node, ngraph::element::Ty
bool fuse_type_to_topk(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_nonzero(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_bucketize(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);
bool fuse_type_to_ctc_greedy_decoder_seq_len(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);

bool extend_select_type(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx);

@@ -87,6 +89,7 @@ bool ngraph::pass::ConvertPrecision::run_on_function(std::shared_ptr<ngraph::Fun
    {opset3::NonMaxSuppression::type_info, fuse_type_to_nms3},
    {opset4::NonMaxSuppression::type_info, fuse_type_to_nms4},
    {opset5::NonMaxSuppression::type_info, fuse_type_to_nms5},
    {opset6::CTCGreedyDecoderSeqLen::type_info, fuse_type_to_ctc_greedy_decoder_seq_len},
    {opset4::TopK::type_info, fuse_type_to_topk},
    {opset4::NonZero::type_info, fuse_type_to_nonzero},
    {opset4::Bucketize::type_info, fuse_type_to_bucketize},
@@ -260,6 +263,20 @@ bool fuse_type_to_topk(std::shared_ptr<ngraph::Node> & node, ngraph::element::Ty
    return false;
}

bool fuse_type_to_ctc_greedy_decoder_seq_len(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
    if (auto ctc_decoder = as_type_ptr<opset6::CTCGreedyDecoderSeqLen>(node)) {
        if (idx == 0 && (to == element::i32 || to == element::i64)) {
            ctc_decoder->set_classes_index_type(to);
            return true;
        }
        if (idx == 1 && (to == element::i32 || to == element::i64)) {
            ctc_decoder->set_sequence_length_type(to);
            return true;
        }
    }
    return false;
}

bool fuse_type_to_nonzero(std::shared_ptr<ngraph::Node> & node, ngraph::element::Type to, size_t idx) {
    if (auto nonzero = as_type_ptr<opset4::NonZero>(node)) {
        if (to == element::i32 || to == element::i64) {
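A hedged sketch of how this hook is exercised: when ConvertPrecision lowers i64 to i32 over a function that contains a CTCGreedyDecoderSeqLen node, the table entry above routes the node to fuse_type_to_ctc_greedy_decoder_seq_len, which updates the declared output types instead of inserting Convert nodes. The two-argument ConvertPrecision registration shown here is an assumption about the pass constructor of this code base at the time:

#include <ngraph/pass/manager.hpp>
#include <transformations/convert_precision.hpp>

// `model` holds a v6::CTCGreedyDecoderSeqLen whose outputs are declared as i64.
void lower_i64_outputs(const std::shared_ptr<ngraph::Function>& model) {
    ngraph::pass::Manager manager;
    // Assumed signature: ConvertPrecision(from, to), as used by the plugins in this repository.
    manager.register_pass<ngraph::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
    manager.run_passes(model);
    // Afterwards, get_classes_index_type() / get_sequence_length_type() on the decoder report i32.
}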
@@ -0,0 +1,128 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "itt.hpp"

#include "transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp"

#include <ngraph/opsets/opset6.hpp>
#include <ngraph/rt_info.hpp>

#include <ngraph/pattern/op/wrap_type.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifyCTCGreedyDecoderSeqLen, "SimplifyCTCGreedyDecoder", 0);

ngraph::pass::SimplifyCTCGreedyDecoderSeqLen::SimplifyCTCGreedyDecoderSeqLen() {
    MATCHER_SCOPE(SimplifyCTCGreedyDecoderSeqLen);
    auto decoder = pattern::wrap_type<opset6::CTCGreedyDecoderSeqLen>();

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        auto decoder_seq_len = std::dynamic_pointer_cast<opset6::CTCGreedyDecoderSeqLen>(m.get_match_root());
        if (!decoder_seq_len) {
            return false;
        }

        if (decoder_seq_len->get_input_size() > 2) {
            const auto data_pshape = decoder_seq_len->get_input_partial_shape(0);
            auto blank_index = std::dynamic_pointer_cast<ngraph::opset6::Constant>(decoder_seq_len->input_value(2).get_node_shared_ptr());
            if (!blank_index || data_pshape.rank().is_dynamic() || data_pshape[2].is_dynamic()) {
                return false;
            }

            const std::vector<int64_t>& blank_index_values = blank_index->cast_vector<int64_t>();
            const auto num_classes = decoder_seq_len->get_input_partial_shape(0)[2].get_length();
            if (blank_index_values[0] != (num_classes - 1)) {
                return false;
            }
        }

        element::Type data_type = decoder_seq_len->input_value(0).get_element_type();
        element::Type seq_len_type = decoder_seq_len->input_value(1).get_element_type();
        // Transpose the input data from [N, T, C] to [T, N, C], as required by CTCGreedyDecoder v0
        auto transpose = std::make_shared<ngraph::opset6::Transpose>(decoder_seq_len->input_value(0),
            ngraph::opset6::Constant::create(element::i32, Shape({3}), {1, 0, 2}));
        // Take the time and batch dimensions and concatenate them into a [T, N] shape
        auto data_shape = std::make_shared<ngraph::opset6::ShapeOf>(decoder_seq_len->input_value(0));
        auto axisT = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto indexT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {1});
        auto T = std::make_shared<ngraph::opset6::Gather>(data_shape, indexT, axisT);

        auto axisN = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto indexN = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
        auto N = std::make_shared<ngraph::opset6::Gather>(data_shape, indexN, axisN);

        auto start = opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto step = opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto plus1 = opset6::Constant::create(element::i64, Shape{1}, {1});
        auto plusT = std::make_shared<ngraph::opset6::Add>(T, plus1);
        auto const_plusT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
        auto plusT_scalar = std::make_shared<ngraph::opset6::Squeeze>(plusT, const_plusT);
        auto range1T = std::make_shared<ngraph::opset6::Range>(start, plusT_scalar, step, seq_len_type);

        auto mask_shape = std::make_shared<ngraph::opset6::Concat>(OutputVector{T->output(0), N->output(0)}, 0);

        // Generate a 2D [T, N] tensor for the sequence mask
        auto upper_bounds = std::make_shared<ngraph::opset6::Broadcast>(decoder_seq_len->input_value(1), mask_shape->output(0));
        auto transpose_upper_bounds = std::make_shared<ngraph::opset6::Transpose>(upper_bounds->output(0),
            ngraph::opset6::Constant::create(seq_len_type, Shape({2}), {1, 0}));
        // Compute the boolean sequence mask
        auto bool_seq_mask = std::make_shared<ngraph::opset6::GreaterEqual>(transpose_upper_bounds->output(0), range1T->output(0));

        // Generate the resulting sequence mask
        auto mask_val_true = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto mask_val_false = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto seq_mask = std::make_shared<ngraph::opset6::Select>(bool_seq_mask, mask_val_true, mask_val_false);
        auto transpose_seq_mask = std::make_shared<ngraph::opset6::Transpose>(seq_mask->output(0),
            ngraph::opset6::Constant::create(seq_len_type, Shape({2}), {1, 0}));
        auto transpose_seq_mask_f = std::make_shared<ngraph::opset6::Convert>(transpose_seq_mask->output(0), data_type);
        // Create CTCGreedyDecoder with the original merge_repeated attribute and connect the data and the resulting seq_mask
        auto decoder = std::make_shared<ngraph::opset6::CTCGreedyDecoder>(transpose, transpose_seq_mask_f->output(0),
            decoder_seq_len->get_merge_repeated());
        decoder->set_friendly_name(decoder_seq_len->get_friendly_name());

        // Normalize the CTCGreedyDecoder output (output_f) and create the second output with output_seq_len
        auto squeeze2_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {3});
        auto squeeze2_output_f = std::make_shared<ngraph::opset6::Squeeze>(decoder->output(0), squeeze2_axis);
        auto squeeze1_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {2});
        auto squeeze1_output_f = std::make_shared<ngraph::opset6::Squeeze>(squeeze2_output_f->output(0), squeeze1_axis);

        element::Type ci_type = decoder_seq_len->get_classes_index_type();
        element::Type sl_type = decoder_seq_len->get_sequence_length_type();
        // CTCGreedyDecoder returns a floating-point output, so normalize the first output
        // by converting it to classes_index_type
        auto output_i = std::make_shared<ngraph::opset6::Convert>(squeeze1_output_f->output(0), ci_type);
        auto minus1 = opset6::Constant::create(ci_type, Shape{}, {-1});
        // Find the positions that are equal to -1
        auto where_equal_minus1 = std::make_shared<ngraph::opset6::Equal>(output_i, minus1);

        // Compute the output sequence mask
        auto seq_mask_const0 = opset6::Constant::create(ci_type, Shape{1}, {0});
        auto seq_mask_const1 = opset6::Constant::create(ci_type, Shape{1}, {1});
        auto output_seq_mask = std::make_shared<ngraph::opset6::Select>(where_equal_minus1, seq_mask_const0, seq_mask_const1);
        auto seq_mask_axis = opset6::Constant::create(ci_type, Shape{1}, {1});
        // Produce the second output
        auto output_seq_len = std::make_shared<ngraph::opset6::ReduceSum>(output_seq_mask, seq_mask_axis);
        // Convert the second output to the requested seq_len_type
        auto output_seq_len_i = std::make_shared<ngraph::opset6::Convert>(output_seq_len->output(0), sl_type);
        ngraph::copy_runtime_info(decoder_seq_len, {transpose, decoder, data_shape, T, N, plusT, plusT_scalar, range1T, mask_shape, upper_bounds,
                                                    squeeze2_output_f, squeeze1_output_f, transpose_upper_bounds, bool_seq_mask, seq_mask, transpose_seq_mask,
                                                    transpose_seq_mask_f, output_i, where_equal_minus1, output_seq_mask, output_seq_len, output_seq_len_i});

        output_i->set_friendly_name(decoder_seq_len->get_friendly_name()+".0");
        output_seq_len_i->set_friendly_name(decoder_seq_len->get_friendly_name()+".1");
        ngraph::replace_node(decoder_seq_len, {output_i->output(0), output_seq_len_i->output(0)});

        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(decoder, matcher_name);
    register_matcher(m, callback);
}
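To make the mask construction above concrete, here is a plain-C++ illustration (not part of the commit; the helper name make_seq_mask is hypothetical) of the [T, N] sequence mask that the Range/Broadcast/GreaterEqual/Select subgraph computes from seq_len:

#include <cstdint>
#include <vector>

// mask[t][n] == 1.0f while seq_len[n] >= t + 1, which is what
// GreaterEqual(broadcast seq_len, Range(1..T)) produces in the subgraph above.
std::vector<std::vector<float>> make_seq_mask(const std::vector<int64_t>& seq_len, int64_t T) {
    const size_t N = seq_len.size();
    std::vector<std::vector<float>> mask(T, std::vector<float>(N, 0.0f));
    for (int64_t t = 0; t < T; ++t)
        for (size_t n = 0; n < N; ++n)
            mask[t][n] = (seq_len[n] >= t + 1) ? 1.0f : 0.0f;
    return mask;
}
// Example: T = 3, seq_len = {2, 3} -> mask = {{1, 1}, {1, 1}, {0, 1}}.

Each column n carries ones for the first seq_len[n] time steps and zeros afterwards, which is the mask format CTCGreedyDecoder v0 expects.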
@@ -31,6 +31,7 @@
#include <transformations/op_conversions/softplus_decomposition.hpp>
#include <transformations/op_conversions/convert_minimum_to_power_and_max.hpp>
#include <transformations/op_conversions/hswish_decomposition.hpp>
#include <transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_prior_to_ie_prior.hpp>
#include <transformations/common_optimizations/common_optimizations.hpp>
@@ -196,6 +197,7 @@ ie::CNNNetwork FrontEnd::convertNetwork(ie::CNNNetwork& network) {
    pass_config->disable<ngraph::pass::ConvertMinimum>();
    pass_config->disable<ngraph::pass::HSwishDecomposition>();
    pass_config->disable<ngraph::pass::MVN6Decomposition>();
    pass_config->disable<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();

    auto transformationPredicate = [](const std::shared_ptr<const ngraph::Node>& node) -> bool {
        return !!std::dynamic_pointer_cast<const ngraph::vpu::op::DynamicShapeResolver>(node->input_value(0).get_node_shared_ptr());
@@ -0,0 +1,560 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <string>
#include <memory>

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/simplify_ctc_greedy_decoder_seq_len.hpp>
#include <transformations/init_node_info.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"

using namespace testing;
using namespace ngraph;
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenTest) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 7 });
        auto seq_len = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i64, ngraph::Shape{ 1 });

        auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, true);
        auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
        auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data, seq_len });

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto data1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 7 });
        auto seq_len1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i64, ngraph::Shape{ 1 });

        element::Type data_type = data1->get_element_type();
        element::Type seq_len_type = seq_len1->get_element_type();
        element::Type ci_type = element::i32;
        element::Type sl_type = element::i32;
        auto transpose = std::make_shared<ngraph::opset6::Transpose>(data1,
            ngraph::opset6::Constant::create(element::i32, Shape({3}), {1, 0, 2}));
        auto data_shape = std::make_shared<ngraph::opset6::ShapeOf>(data1);
        auto axisT = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto indexT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {1});
        auto T = std::make_shared<ngraph::opset6::Gather>(data_shape, indexT, axisT);

        auto axisN = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto indexN = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
        auto N = std::make_shared<ngraph::opset6::Gather>(data_shape, indexN, axisN);

        auto start = opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto step = opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto plus1 = opset6::Constant::create(element::i64, Shape{1}, {1});
        auto plusT = std::make_shared<ngraph::opset6::Add>(T, plus1);
        auto const_plusT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
        auto plusT_scalar = std::make_shared<ngraph::opset6::Squeeze>(plusT, const_plusT);
        auto range1T = std::make_shared<ngraph::opset6::Range>(start, plusT_scalar, step, seq_len_type);

        auto mask_shape = std::make_shared<ngraph::opset6::Concat>(OutputVector{T->output(0), N->output(0)}, 0);

        auto upper_bounds = std::make_shared<ngraph::opset6::Broadcast>(seq_len1, mask_shape->output(0));
        auto transpose_upper_bounds = std::make_shared<ngraph::opset6::Transpose>(upper_bounds->output(0),
            ngraph::opset6::Constant::create(seq_len_type, Shape({2}), {1, 0}));
        auto bool_seq_mask = std::make_shared<ngraph::opset6::GreaterEqual>(transpose_upper_bounds->output(0), range1T->output(0));

        auto mask_val_true = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto mask_val_false = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto seq_mask = std::make_shared<ngraph::opset6::Select>(bool_seq_mask, mask_val_true, mask_val_false);
        auto transpose_seq_mask = std::make_shared<ngraph::opset6::Transpose>(seq_mask->output(0),
            ngraph::opset6::Constant::create(seq_len_type, Shape({2}), {1, 0}));
        auto transpose_seq_mask_f32 = std::make_shared<ngraph::opset6::Convert>(transpose_seq_mask->output(0), data_type);
        auto simplified_decoder = std::make_shared<ngraph::opset6::CTCGreedyDecoder>(transpose, transpose_seq_mask_f32->output(0), true);

        auto squeeze2_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {3});
        auto squeeze2_output_f = std::make_shared<ngraph::opset6::Squeeze>(simplified_decoder->output(0), squeeze2_axis);
        auto squeeze1_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {2});
        auto squeeze1_output_f = std::make_shared<ngraph::opset6::Squeeze>(squeeze2_output_f->output(0), squeeze1_axis);

        auto output_i = std::make_shared<ngraph::opset6::Convert>(squeeze1_output_f->output(0), ci_type);
        auto minus1 = opset6::Constant::create(ci_type, Shape{}, {-1});
        auto where_equal_minus1 = std::make_shared<ngraph::opset6::Equal>(output_i, minus1);

        auto seq_mask_const0 = opset6::Constant::create(ci_type, Shape{1}, {0});
        auto seq_mask_const1 = opset6::Constant::create(ci_type, Shape{1}, {1});
        auto output_seq_mask = std::make_shared<ngraph::opset6::Select>(where_equal_minus1, seq_mask_const0, seq_mask_const1);
        auto seq_mask_axis = opset6::Constant::create(ci_type, Shape{1}, {1});
        auto output_seq_len = std::make_shared<ngraph::opset6::ReduceSum>(output_seq_mask, seq_mask_axis);
        auto output_seq_len_i = std::make_shared<ngraph::opset6::Convert>(output_seq_len->output(0), sl_type);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ output_i, output_seq_len_i }, ngraph::ParameterVector{ data1, seq_len1 });
    }

    auto res = compare_functions(f, f_ref, true, false, false, true, true);
    ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenDynamicInputShapeTest) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic());
        auto seq_len = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::Shape{ 1 });

        auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, true, element::i64, element::i32);
        auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
        auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data, seq_len });

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto data1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f16, ngraph::PartialShape::dynamic());
        auto seq_len1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::Shape{ 1 });

        element::Type data_type = data1->get_element_type();
        element::Type seq_len_type = seq_len1->get_element_type();
        element::Type ci_type = element::i64;
        element::Type sl_type = element::i32;
        auto transpose = std::make_shared<ngraph::opset6::Transpose>(data1,
            ngraph::opset6::Constant::create(element::i32, Shape({3}), {1, 0, 2}));
        auto data_shape = std::make_shared<ngraph::opset6::ShapeOf>(data1);
        auto axisT = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto indexT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {1});
        auto T = std::make_shared<ngraph::opset6::Gather>(data_shape, indexT, axisT);

        auto axisN = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto indexN = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
        auto N = std::make_shared<ngraph::opset6::Gather>(data_shape, indexN, axisN);

        auto start = opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto step = opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto plus1 = opset6::Constant::create(element::i64, Shape{1}, {1});
        auto plusT = std::make_shared<ngraph::opset6::Add>(T, plus1);
        auto const_plusT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
        auto plusT_scalar = std::make_shared<ngraph::opset6::Squeeze>(plusT, const_plusT);
        auto range1T = std::make_shared<ngraph::opset6::Range>(start, plusT_scalar, step, seq_len_type);

        auto mask_shape = std::make_shared<ngraph::opset6::Concat>(OutputVector{T->output(0), N->output(0)}, 0);

        auto upper_bounds = std::make_shared<ngraph::opset6::Broadcast>(seq_len1, mask_shape->output(0));
        auto transpose_upper_bounds = std::make_shared<ngraph::opset6::Transpose>(upper_bounds->output(0),
            ngraph::opset6::Constant::create(seq_len_type, Shape({2}), {1, 0}));
        auto bool_seq_mask = std::make_shared<ngraph::opset6::GreaterEqual>(transpose_upper_bounds->output(0), range1T->output(0));

        auto mask_val_true = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto mask_val_false = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto seq_mask = std::make_shared<ngraph::opset6::Select>(bool_seq_mask, mask_val_true, mask_val_false);
        auto transpose_seq_mask = std::make_shared<ngraph::opset6::Transpose>(seq_mask->output(0),
            ngraph::opset6::Constant::create(seq_len_type, Shape({2}), {1, 0}));
        auto transpose_seq_mask_f = std::make_shared<ngraph::opset6::Convert>(transpose_seq_mask->output(0), data_type);
        auto simplified_decoder = std::make_shared<ngraph::opset6::CTCGreedyDecoder>(transpose, transpose_seq_mask_f->output(0), true);

        auto squeeze2_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {3});
        auto squeeze2_output_f = std::make_shared<ngraph::opset6::Squeeze>(simplified_decoder->output(0), squeeze2_axis);
        auto squeeze1_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {2});
        auto squeeze1_output_f = std::make_shared<ngraph::opset6::Squeeze>(squeeze2_output_f->output(0), squeeze1_axis);

        auto output_i = std::make_shared<ngraph::opset6::Convert>(squeeze1_output_f->output(0), ci_type);
        auto minus1 = opset6::Constant::create(ci_type, Shape{}, {-1});
        auto where_equal_minus1 = std::make_shared<ngraph::opset6::Equal>(output_i, minus1);

        auto seq_mask_const0 = opset6::Constant::create(ci_type, Shape{1}, {0});
        auto seq_mask_const1 = opset6::Constant::create(ci_type, Shape{1}, {1});
        auto output_seq_mask = std::make_shared<ngraph::opset6::Select>(where_equal_minus1, seq_mask_const0, seq_mask_const1);
        auto seq_mask_axis = opset6::Constant::create(ci_type, Shape{1}, {1});
        auto output_seq_len = std::make_shared<ngraph::opset6::ReduceSum>(output_seq_mask, seq_mask_axis);
        auto output_seq_len_i = std::make_shared<ngraph::opset6::Convert>(output_seq_len->output(0), sl_type);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ output_i, output_seq_len_i }, ngraph::ParameterVector{ data1, seq_len1 });
    }

    auto res = compare_functions(f, f_ref, true, false, false, true, true);
    ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenDynamicBatchTest) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{Dimension::dynamic(), 3, 7});
        auto seq_len = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{Dimension::dynamic()});

        auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, true, element::i32, element::i64);
        auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
        auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data, seq_len });

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto data1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{Dimension::dynamic(), 3, 7});
        auto seq_len1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{Dimension::dynamic()});

        element::Type data_type = data1->get_element_type();
        element::Type seq_len_type = seq_len1->get_element_type();
        element::Type ci_type = element::i32;
        element::Type sl_type = element::i64;
        auto transpose = std::make_shared<ngraph::opset6::Transpose>(data1,
            ngraph::opset6::Constant::create(element::i32, Shape({3}), {1, 0, 2}));
        auto data_shape = std::make_shared<ngraph::opset6::ShapeOf>(data1);
        auto axisT = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto indexT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {1});
        auto T = std::make_shared<ngraph::opset6::Gather>(data_shape, indexT, axisT);

        auto axisN = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto indexN = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
        auto N = std::make_shared<ngraph::opset6::Gather>(data_shape, indexN, axisN);

        auto start = opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto step = opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto plus1 = opset6::Constant::create(element::i64, Shape{1}, {1});
        auto plusT = std::make_shared<ngraph::opset6::Add>(T, plus1);
        auto const_plusT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
        auto plusT_scalar = std::make_shared<ngraph::opset6::Squeeze>(plusT, const_plusT);
        auto range1T = std::make_shared<ngraph::opset6::Range>(start, plusT_scalar, step, seq_len_type);

        auto mask_shape = std::make_shared<ngraph::opset6::Concat>(OutputVector{T->output(0), N->output(0)}, 0);

        auto upper_bounds = std::make_shared<ngraph::opset6::Broadcast>(seq_len1, mask_shape->output(0));
        auto transpose_upper_bounds = std::make_shared<ngraph::opset6::Transpose>(upper_bounds->output(0),
            ngraph::opset6::Constant::create(seq_len_type, Shape({2}), {1, 0}));
        auto bool_seq_mask = std::make_shared<ngraph::opset6::GreaterEqual>(transpose_upper_bounds->output(0), range1T->output(0));

        auto mask_val_true = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto mask_val_false = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto seq_mask = std::make_shared<ngraph::opset6::Select>(bool_seq_mask, mask_val_true, mask_val_false);
        auto transpose_seq_mask = std::make_shared<ngraph::opset6::Transpose>(seq_mask->output(0),
            ngraph::opset6::Constant::create(seq_len_type, Shape({2}), {1, 0}));
        auto transpose_seq_mask_f = std::make_shared<ngraph::opset6::Convert>(transpose_seq_mask->output(0), data_type);
        auto simplified_decoder = std::make_shared<ngraph::opset6::CTCGreedyDecoder>(transpose, transpose_seq_mask_f->output(0), true);

        auto squeeze2_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {3});
        auto squeeze2_output_f = std::make_shared<ngraph::opset6::Squeeze>(simplified_decoder->output(0), squeeze2_axis);
        auto squeeze1_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {2});
        auto squeeze1_output_f = std::make_shared<ngraph::opset6::Squeeze>(squeeze2_output_f->output(0), squeeze1_axis);

        auto output_i = std::make_shared<ngraph::opset6::Convert>(squeeze1_output_f->output(0), ci_type);
        auto minus1 = opset6::Constant::create(ci_type, Shape{}, {-1});
        auto where_equal_minus1 = std::make_shared<ngraph::opset6::Equal>(output_i, minus1);

        auto seq_mask_const0 = opset6::Constant::create(ci_type, Shape{1}, {0});
        auto seq_mask_const1 = opset6::Constant::create(ci_type, Shape{1}, {1});
        auto output_seq_mask = std::make_shared<ngraph::opset6::Select>(where_equal_minus1, seq_mask_const0, seq_mask_const1);
        auto seq_mask_axis = opset6::Constant::create(ci_type, Shape{1}, {1});
        auto output_seq_len = std::make_shared<ngraph::opset6::ReduceSum>(output_seq_mask, seq_mask_axis);
        auto output_seq_len_i = std::make_shared<ngraph::opset6::Convert>(output_seq_len->output(0), sl_type);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ output_i, output_seq_len_i }, ngraph::ParameterVector{ data1, seq_len1 });
    }

    auto res = compare_functions(f, f_ref, true, false, false, true, true);
    ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenDynamicSeqLenTest) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{2, Dimension::dynamic(), 7});
        auto seq_len = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});

        auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, true, ngraph::element::i64, ngraph::element::i64);
        auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
        auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data, seq_len });

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto data1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{2, Dimension::dynamic(), 7});
        auto seq_len1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});

        element::Type data_type = data1->get_element_type();
        element::Type seq_len_type = seq_len1->get_element_type();
        element::Type ci_type = element::i64;
        element::Type sl_type = element::i64;
        auto transpose = std::make_shared<ngraph::opset6::Transpose>(data1,
            ngraph::opset6::Constant::create(element::i32, Shape({3}), {1, 0, 2}));
        auto data_shape = std::make_shared<ngraph::opset6::ShapeOf>(data1);
        auto axisT = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto indexT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {1});
        auto T = std::make_shared<ngraph::opset6::Gather>(data_shape, indexT, axisT);

        auto axisN = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto indexN = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
        auto N = std::make_shared<ngraph::opset6::Gather>(data_shape, indexN, axisN);

        auto start = opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto step = opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto plus1 = opset6::Constant::create(element::i64, Shape{1}, {1});
        auto plusT = std::make_shared<ngraph::opset6::Add>(T, plus1);
        auto const_plusT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
        auto plusT_scalar = std::make_shared<ngraph::opset6::Squeeze>(plusT, const_plusT);
        auto range1T = std::make_shared<ngraph::opset6::Range>(start, plusT_scalar, step, seq_len_type);

        auto mask_shape = std::make_shared<ngraph::opset6::Concat>(OutputVector{T->output(0), N->output(0)}, 0);

        auto upper_bounds = std::make_shared<ngraph::opset6::Broadcast>(seq_len1, mask_shape->output(0));
        auto transpose_upper_bounds = std::make_shared<ngraph::opset6::Transpose>(upper_bounds->output(0),
            ngraph::opset6::Constant::create(seq_len_type, Shape({2}), {1, 0}));
        auto bool_seq_mask = std::make_shared<ngraph::opset6::GreaterEqual>(transpose_upper_bounds->output(0), range1T->output(0));

        auto mask_val_true = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto mask_val_false = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto seq_mask = std::make_shared<ngraph::opset6::Select>(bool_seq_mask, mask_val_true, mask_val_false);
        auto transpose_seq_mask = std::make_shared<ngraph::opset6::Transpose>(seq_mask->output(0),
            ngraph::opset6::Constant::create(seq_len_type, Shape({2}), {1, 0}));
        auto transpose_seq_mask_f = std::make_shared<ngraph::opset6::Convert>(transpose_seq_mask->output(0), data_type);
        auto simplified_decoder = std::make_shared<ngraph::opset6::CTCGreedyDecoder>(transpose, transpose_seq_mask_f->output(0), true);

        auto squeeze2_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {3});
        auto squeeze2_output_f = std::make_shared<ngraph::opset6::Squeeze>(simplified_decoder->output(0), squeeze2_axis);
        auto squeeze1_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {2});
        auto squeeze1_output_f = std::make_shared<ngraph::opset6::Squeeze>(squeeze2_output_f->output(0), squeeze1_axis);

        auto output_i = std::make_shared<ngraph::opset6::Convert>(squeeze1_output_f->output(0), ci_type);
        auto minus1 = opset6::Constant::create(ci_type, Shape{}, {-1});
        auto where_equal_minus1 = std::make_shared<ngraph::opset6::Equal>(output_i, minus1);

        auto seq_mask_const0 = opset6::Constant::create(ci_type, Shape{1}, {0});
        auto seq_mask_const1 = opset6::Constant::create(ci_type, Shape{1}, {1});
        auto output_seq_mask = std::make_shared<ngraph::opset6::Select>(where_equal_minus1, seq_mask_const0, seq_mask_const1);
        auto seq_mask_axis = opset6::Constant::create(ci_type, Shape{1}, {1});
        auto output_seq_len = std::make_shared<ngraph::opset6::ReduceSum>(output_seq_mask, seq_mask_axis);
        auto output_seq_len_i = std::make_shared<ngraph::opset6::Convert>(output_seq_len->output(0), sl_type);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ output_i, output_seq_len_i }, ngraph::ParameterVector{ data1, seq_len1 });
    }

    auto res = compare_functions(f, f_ref, true, false, false, true, true);
    ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenWrongBlankIndexTest) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{2, Dimension::dynamic(), 7});
        auto seq_len = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
        auto blank_index = op::Constant::create(element::i32, Shape{}, {5});

        auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, true, ngraph::element::i64, ngraph::element::i64);
        auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
        auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data, seq_len });

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto data1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{2, Dimension::dynamic(), 7});
        auto seq_len1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
        auto blank_index1 = op::Constant::create(element::i32, Shape{}, {5});

        auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data1, seq_len1, blank_index1, true, ngraph::element::i64, ngraph::element::i64);
        auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
        auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data1, seq_len1 });
    }

    auto res = compare_functions(f, f_ref, true, false, false, true, true);
    ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenDynamicSeqLenWithBlankIndexTest) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{2, Dimension::dynamic(), 7});
        auto seq_len = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
        auto blank_index = op::Constant::create(element::i32, Shape{}, {6});

        auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, true, ngraph::element::i64, ngraph::element::i64);
        auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
        auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data, seq_len });

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto data1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{2, Dimension::dynamic(), 7});
        auto seq_len1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});

        element::Type data_type = data1->get_element_type();
        element::Type seq_len_type = seq_len1->get_element_type();
        element::Type ci_type = element::i64;
        element::Type sl_type = element::i64;
        auto transpose = std::make_shared<ngraph::opset6::Transpose>(data1,
            ngraph::opset6::Constant::create(element::i32, Shape({3}), {1, 0, 2}));
        auto data_shape = std::make_shared<ngraph::opset6::ShapeOf>(data1);
        auto axisT = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto indexT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {1});
        auto T = std::make_shared<ngraph::opset6::Gather>(data_shape, indexT, axisT);

        auto axisN = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto indexN = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
        auto N = std::make_shared<ngraph::opset6::Gather>(data_shape, indexN, axisN);

        auto start = opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto step = opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto plus1 = opset6::Constant::create(element::i64, Shape{1}, {1});
        auto plusT = std::make_shared<ngraph::opset6::Add>(T, plus1);
        auto const_plusT = ngraph::opset6::Constant::create(seq_len_type, Shape{1}, {0});
        auto plusT_scalar = std::make_shared<ngraph::opset6::Squeeze>(plusT, const_plusT);
        auto range1T = std::make_shared<ngraph::opset6::Range>(start, plusT_scalar, step, seq_len_type);

        auto mask_shape = std::make_shared<ngraph::opset6::Concat>(OutputVector{T->output(0), N->output(0)}, 0);

        auto upper_bounds = std::make_shared<ngraph::opset6::Broadcast>(seq_len1, mask_shape->output(0));
        auto transpose_upper_bounds = std::make_shared<ngraph::opset6::Transpose>(upper_bounds->output(0),
            ngraph::opset6::Constant::create(seq_len_type, Shape({2}), {1, 0}));
        auto bool_seq_mask = std::make_shared<ngraph::opset6::GreaterEqual>(transpose_upper_bounds->output(0), range1T->output(0));

        auto mask_val_true = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {1});
        auto mask_val_false = ngraph::opset6::Constant::create(seq_len_type, Shape{}, {0});
        auto seq_mask = std::make_shared<ngraph::opset6::Select>(bool_seq_mask, mask_val_true, mask_val_false);
        auto transpose_seq_mask = std::make_shared<ngraph::opset6::Transpose>(seq_mask->output(0),
            ngraph::opset6::Constant::create(seq_len_type, Shape({2}), {1, 0}));
        auto transpose_seq_mask_f = std::make_shared<ngraph::opset6::Convert>(transpose_seq_mask->output(0), data_type);
        auto simplified_decoder = std::make_shared<ngraph::opset6::CTCGreedyDecoder>(transpose, transpose_seq_mask_f->output(0), true);

        auto squeeze2_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {3});
        auto squeeze2_output_f = std::make_shared<ngraph::opset6::Squeeze>(simplified_decoder->output(0), squeeze2_axis);
        auto squeeze1_axis = ngraph::opset6::Constant::create(seq_len_type, Shape({1}), {2});
        auto squeeze1_output_f = std::make_shared<ngraph::opset6::Squeeze>(squeeze2_output_f->output(0), squeeze1_axis);

        auto output_i = std::make_shared<ngraph::opset6::Convert>(squeeze1_output_f->output(0), ci_type);
        auto minus1 = opset6::Constant::create(ci_type, Shape{}, {-1});
        auto where_equal_minus1 = std::make_shared<ngraph::opset6::Equal>(output_i, minus1);

        auto seq_mask_const0 = opset6::Constant::create(ci_type, Shape{1}, {0});
        auto seq_mask_const1 = opset6::Constant::create(ci_type, Shape{1}, {1});
        auto output_seq_mask = std::make_shared<ngraph::opset6::Select>(where_equal_minus1, seq_mask_const0, seq_mask_const1);
        auto seq_mask_axis = opset6::Constant::create(ci_type, Shape{1}, {1});
        auto output_seq_len = std::make_shared<ngraph::opset6::ReduceSum>(output_seq_mask, seq_mask_axis);
        auto output_seq_len_i = std::make_shared<ngraph::opset6::Convert>(output_seq_len->output(0), sl_type);

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ output_i, output_seq_len_i }, ngraph::ParameterVector{ data1, seq_len1 });
    }

    auto res = compare_functions(f, f_ref, true, false, false, true, true);
    ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifyCTCGreedyDecoderSeqLenDynamicSeqLenParamWithBlankIndexTest) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{2, Dimension::dynamic(), 7});
        auto seq_len = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
        auto blank_index = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{1});

        auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data, seq_len, blank_index, true, ngraph::element::i64, ngraph::element::i64);
        auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
        auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));

        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data, seq_len, blank_index });

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ngraph::pass::SimplifyCTCGreedyDecoderSeqLen>();
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        auto data1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::PartialShape{2, Dimension::dynamic(), 7});
        auto seq_len1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{2});
        auto blank_index1 = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i32, ngraph::PartialShape{1});

        auto decoder_v6 = std::make_shared<ngraph::op::v6::CTCGreedyDecoderSeqLen>(data1, seq_len1, blank_index1, true, ngraph::element::i64, ngraph::element::i64);
        auto res_1 = std::make_shared<opset6::Result>(decoder_v6->output(0));
        auto res_2 = std::make_shared<opset6::Result>(decoder_v6->output(1));

        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ res_1, res_2 }, ngraph::ParameterVector{ data1, seq_len1, blank_index1 });
    }

    auto res = compare_functions(f, f_ref, true, false, false, true, true);
    ASSERT_TRUE(res.first) << res.second;
}
@@ -640,6 +640,7 @@ extensions/ops/constant_fill.py
extensions/ops/copyop.py
extensions/ops/correlation.py
extensions/ops/ctc_greedy_decoder.py
extensions/ops/ctc_greedy_decoder_seq_len.py
extensions/ops/ctc_loss.py
extensions/ops/cumsum.py
extensions/ops/data_augmentation.py
@@ -981,6 +982,7 @@ mo/utils/ir_reader/extenders/binary_convolution_extender.py
mo/utils/ir_reader/extenders/bucketize_extender.py
mo/utils/ir_reader/extenders/conv_extender.py
mo/utils/ir_reader/extenders/convert_extender.py
mo/utils/ir_reader/extenders/ctc_greedy_decoder_seq_len_extender.py
mo/utils/ir_reader/extenders/deconvolution_extender.py
mo/utils/ir_reader/extenders/deformable_convolution_extender.py
mo/utils/ir_reader/extenders/experimental_extender.py
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,219 +14,61 @@
limitations under the License.
"""

import logging as log

import numpy as np

from extensions.ops.Cast import Cast
from extensions.front.Pack import Pack
from extensions.front.FillToBroadcast import FillToBroadcast
from extensions.ops.ctc_greedy_decoder_seq_len import CTCGreedyDecoderSeqLenOp
from extensions.ops.transpose import Transpose
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.ops.broadcast import Broadcast
from mo.ops.concat import Concat
from mo.ops.squeeze import Squeeze
from mo.ops.unsqueeze import Unsqueeze
class CTCGreedyDecoderReplacement(FrontReplacementSubgraph):
|
||||
"""
|
||||
TensorFlow CTCGreedyDecoder produces output in a sparse tensor that is not supported by Inference Engine and
|
||||
Inference Engine's CTCGreedyDecoder has different output that is in a dense format. So this transformation
|
||||
intents to replace TF CTCGreedyDecoder+SparseToDense with IE one.
|
||||
Also Inference Engine's CTCGreedyDecoder has a specific format for the second input tensor, a sequence length,
|
||||
different from TF's one so this transformation cares about transformation of its format.
|
||||
The second input to the CTCGreedyDecoder in the TensorFlow is a 1D tensor with sequence lengths. In the Inference
|
||||
Engine the second input to the CTCGreedyDecoder is a 2D tensor, a sequence mask, where the first element
|
||||
in each row is equal to 1 and all others in the tail are equal to 0. The number of ones represents
|
||||
a sequence length.
|
||||
Inference Engine's CTCGreedyDecoderSeqLen has different output that is in a dense format. So this transformation
|
||||
intents to replace TF CTCGreedyDecoder+SparseToDense to CTCGreedyDecoderSeqLen which compatible with IE.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
# CTCGreedyDecoderReplacement is not reshape-able transformation
|
||||
# so reshape-able CTCGreedyDecoderReplacement2 transformation is applied first
|
||||
return [CTCGreedyDecoderReplacement2]
|
||||
|
||||
@staticmethod
|
||||
def pattern(**kwargs):
|
||||
return dict(
|
||||
nodes=[('decoder', dict(op='CTCGreedyDecoder')),
|
||||
nodes=[('decoder', dict(op='CTCGreedyDecoderSeqLen')),
|
||||
('cast', dict(op='Cast')),
|
||||
('sparse_to_dense', dict(op='SparseToDense'))
|
||||
],
|
||||
edges=[('decoder', 'sparse_to_dense', {'out': 0}),
|
||||
('decoder', 'sparse_to_dense', {'out': 2}),
|
||||
('decoder', 'cast', {'out': 1}),
|
||||
('cast', 'sparse_to_dense', {'out': 0})
|
||||
]
|
||||
)
|
||||
|
||||
    def replace_sub_graph(self, graph: Graph, match: dict):
        # TODO: Once Inference Engine's CTCGreedyDecoder starts to support the sequence length format like in
        # TensorFlow, CTCGreedyDecoderReplacement2 needs to be removed and CTCGreedyDecoderReplacement, a more
        # generic transformation, needs to be adopted for all cases
        ctc_greedy_decoder = match['decoder']
        ctc_greedy_decoder_tf = match['decoder']
        cast = match['cast']
        sparse_to_dense = match['sparse_to_dense']
        sparse_to_dense_name = sparse_to_dense.soft_get('name', sparse_to_dense.id)
        ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get('name', ctc_greedy_decoder_tf.id)

        # disconnect SparseToDense and Cast nodes
        sparse_to_dense.in_port(0).disconnect()
        cast.in_port(0).disconnect()
        # to normalize the input, the data needs to be transposed from [T, N, C] to [N, T, C],
        # which is the layout supported by the CTCGreedyDecoderSeqLen op
        ctc_data_permute = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0, 2])},
                                                       {'name': ctc_greedy_decoder_tf_name + '/ctc_data_permute'})

        # transform CTCGreedyDecoder output to TensorFlow's one:
        # 1. squeeze the output to [N, T] shape
        # 2. cast it to integer
        squeeze_dec_seq = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([2, 3])},
                                                      {'name': sparse_to_dense_name})
        squeeze_dec_seq.in_port(0).connect(ctc_greedy_decoder.out_port(0))
        cast_to_int = Cast(graph, {'name': sparse_to_dense_name + '/CastToInt',
                                   'dst_type': np.int32}).create_node()
        cast_to_int.in_port(0).connect(squeeze_dec_seq.out_port(0))
        assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
            'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(ctc_greedy_decoder_tf_name)

        # preserve output name from original graph
        rename_nodes([(sparse_to_dense, sparse_to_dense_name + '/AbandonedName'),
                      (cast_to_int, sparse_to_dense_name)])
        ctc_greedy_decoder_tf.in_port(0).get_source().connect(ctc_data_permute.in_port(0))
        merge_repeated_tf = ctc_greedy_decoder_tf.merge_repeated
        ctc_greedy_decoder = CTCGreedyDecoderSeqLenOp(graph, {'name': sparse_to_dense_name,
                                                              'merge_repeated': merge_repeated_tf}).create_node()
        rename_nodes([(sparse_to_dense, sparse_to_dense_name + '/AbandonedName'), (ctc_greedy_decoder, sparse_to_dense_name)])
        ctc_greedy_decoder.in_port(0).connect(ctc_data_permute.out_port(0))
        ctc_greedy_decoder_tf.in_port(1).get_source().connect(ctc_greedy_decoder.in_port(1))

        # set output of the new sub-graph as a source for the SparseToDense consumer
        sparse_to_dense.out_port(0).get_connection().set_source(cast_to_int.out_port(0))
        sparse_to_dense.out_port(0).get_connection().set_source(ctc_greedy_decoder.out_port(0))

        # remove no longer needed nodes
        graph.remove_nodes_from([sparse_to_dense.id, cast.id])

        # mark the CTCGreedyDecoder node as a node that requires transformation of the sequence length
        # to a mask format in the middle phase
        ctc_greedy_decoder['use_mask_format'] = True

        # unless the second input of CTCGreedyDecoder is a parameter, it enforces MO to use --static-shape
        # to try getting the second input with a value
        sequence_length_node = ctc_greedy_decoder.in_node(1)
        if sequence_length_node.soft_get('op') != 'Parameter' and not graph.graph['cmd_params'].static_shape:
            log.error(
                "Model can not be translated in a reshape-able way.\n"
                "Model Optimizer key static_shape was turned on to prevent related errors.\n"
                "There will be no success changing input shapes of the model with the help of "
                "InferenceEngine reshape method", extra={'is_warning': True})
            graph.graph['cmd_params'].static_shape = True
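
# Illustration only (not part of the patch): a minimal NumPy sketch of what the opset-6
# CTCGreedyDecoderSeqLen created above is expected to compute for logits in [N, T, C] layout.
# The helper name and the default blank index of C - 1 are assumptions made for this sketch.
def _ctc_greedy_decode_seq_len_reference(logits, seq_len, merge_repeated=True, blank_index=None):
    import numpy as np
    n, t, c = logits.shape
    blank_index = c - 1 if blank_index is None else blank_index
    classes = np.full((n, t), -1, dtype=np.int32)   # mirrors output 0, padded with -1
    decoded_len = np.zeros(n, dtype=np.int32)       # mirrors output 1
    for b in range(n):
        prev = -1
        for s in np.argmax(logits[b, :seq_len[b]], axis=-1):   # per-frame best class
            if merge_repeated and s == prev:
                continue
            prev = s
            if s != blank_index:
                classes[b, decoded_len[b]] = s
                decoded_len[b] += 1
    return classes, decoded_len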


class CTCGreedyDecoderReplacement2(FrontReplacementSubgraph):
    """
    The TF implementation of the CTCGreedyDecoder produces a tuple with two tensors. The first element in the tuple
    is the SparseTensor, which is converted to a regular tensor with the SparseToDense operation. This replacer
    matches CTCGreedyDecoder and SparseToDense operations and removes the SparseToDense and the Cast operation that
    feeds the SparseToDense, because the Inference Engine implementation of the CTCGreedyDecoder produces a regular
    tensor as output.
    Also, the Inference Engine CTCGreedyDecoder requires a mask format for sequence lengths that is different from
    the original one. Hence, this transformation changes the format of the sequence length to a mask by replacing
    the Fill and Pack nodes with a special sub-graph that produces a tensor of ones with shape [T, N] accepted by
    the opset CTCGreedyDecoder.
    """
    enabled = True

    def run_before(self):
        return [Pack, FillToBroadcast]

    @staticmethod
    def pattern(**kwargs):
        return dict(
            nodes=[
                ('transpose', dict(op='Transpose')),
                ('shape', dict(op='ShapeOf')),
                ('shape_1', dict(op='ShapeOf')),
                ('strided_slice', dict(op='StridedSlice')),
                ('stack', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [1]))),
                ('stack1', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [2]))),
                ('stack2', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [1]))),
                ('strided_slice_1', dict(op='StridedSlice')),
                ('stack_1', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [0]))),
                ('stack1_1', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [1]))),
                ('stack2_1', dict(op='Const', value=lambda v: v is not None and np.array_equal(v, [1]))),
                ('dims', dict(op='Pack')),
                ('fill', dict(op='Fill')),
                ('decoder', dict(op='CTCGreedyDecoder')),
                ('cast', dict(op='Cast')),
                ('sparse_to_dense', dict(op='SparseToDense')),
            ],
            edges=[
                ('transpose', 'shape', {'out': 0}),
                ('transpose', 'shape_1', {'out': 0}),
                ('transpose', 'decoder', {'out': 0, 'in': 0}),
                ('shape', 'strided_slice', {'out': 0, 'in': 0}),
                ('stack', 'strided_slice', {'out': 0, 'in': 1}),
                ('stack1', 'strided_slice', {'out': 0, 'in': 2}),
                ('stack2', 'strided_slice', {'out': 0, 'in': 3}),
                ('shape_1', 'strided_slice_1', {'out': 0, 'in': 0}),
                ('stack_1', 'strided_slice_1', {'out': 0, 'in': 1}),
                ('stack1_1', 'strided_slice_1', {'out': 0, 'in': 2}),
                ('stack2_1', 'strided_slice_1', {'out': 0, 'in': 3}),
                ('strided_slice', 'dims', {'out': 0, 'in': 0}),
                ('dims', 'fill', {'out': 0, 'in': 0}),
                ('strided_slice_1', 'fill', {'out': 0, 'in': 1}),
                ('fill', 'decoder', {'out': 0, 'in': 1}),
                ('decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
                ('decoder', 'cast', {'out': 1, 'in': 0}),
                ('cast', 'sparse_to_dense', {'out': 0}),
            ]
        )

    def replace_sub_graph(self, graph: Graph, match: dict):
        # obtain references to necessary nodes and their names
        fill = match['fill']
        dims = match['dims']
        strided_slice = match['strided_slice']
        strided_slice_1 = match['strided_slice_1']
        ctc_greedy_decoder = match['decoder']
        cast = match['cast']
        sparse_to_dense = match['sparse_to_dense']
        strided_slice_name = strided_slice.soft_get('name', strided_slice.id)
        strided_slice_1_name = strided_slice_1.soft_get('name', strided_slice_1.id)
        ctc_greedy_decoder_name = ctc_greedy_decoder.soft_get('name', ctc_greedy_decoder.id)
        sparse_to_dense_name = sparse_to_dense.soft_get('name', sparse_to_dense.id)

        # unsqueeze the scalar values with the batch size and the time dimension
        unsqueeze_batch_size = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
                                                           {'name': strided_slice_name + '/Unsqueeze'})
        dims.in_port(0).get_connection().set_destination(unsqueeze_batch_size.in_port(0))
        unsqueeze_time_size = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
                                                          {'name': strided_slice_1_name + '/Unsqueeze'})
        fill.in_port(1).get_connection().set_destination(unsqueeze_time_size.in_port(0))

        # compute a sequence mask shape [T, N] required for CTCGreedyDecoder
        seq_mask_shape = Concat(graph, {'axis': 0, 'in_ports_count': 2,
                                        'name': ctc_greedy_decoder_name + '/SequenceMaskShape'}).create_node()
        seq_mask_shape.in_port(0).connect(unsqueeze_time_size.out_port(0))
        seq_mask_shape.in_port(1).connect(unsqueeze_batch_size.out_port(0))

        # compute a sequence mask
        sequence_mask = create_op_with_const_inputs(graph, Broadcast, {0: np.array([1.0], dtype=np.float)},
                                                    {'mode': 'numpy',
                                                     'name': ctc_greedy_decoder_name + '/SequenceMask'})
        sequence_mask.in_port(1).connect(seq_mask_shape.out_port(0))

        # create CTCGreedyDecoder with the sequence mask instead of the sequence length
        ctc_greedy_decoder.in_port(1).disconnect()
        ctc_greedy_decoder.in_port(1).connect(sequence_mask.out_port(0))

        # remove the Fill and Pack nodes since they are now in an unconnected component
        graph.remove_nodes_from([fill.id, dims.id])

        # transform the opset CTCGreedyDecoder output to TensorFlow's one that has a shape [N, T];
        # the opset CTCGreedyDecoder has an output with a shape [N, T, 1, 1]
        squeeze_dec_seq = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([2, 3])},
                                                      {'name': sparse_to_dense_name})
        squeeze_dec_seq.in_port(0).connect(ctc_greedy_decoder.out_port(0))
        cast_to_int = Cast(graph, {'name': sparse_to_dense_name + '/CastToInt',
                                   'dst_type': np.int32}).create_node()
        cast_to_int.in_port(0).connect(squeeze_dec_seq.out_port(0))

        # preserve output name from original graph
        rename_nodes([(sparse_to_dense, sparse_to_dense_name + '/AbandonedName'),
                      (cast_to_int, sparse_to_dense_name)])

        # set output of the new sub-graph as a source for the SparseToDense consumer
        sparse_to_dense.out_port(0).get_connection().set_source(cast_to_int.out_port(0))

        # clean up the graph
        graph.remove_nodes_from([cast.id, sparse_to_dense.id])
        graph.remove_nodes_from([sparse_to_dense.id, cast.id, ctc_greedy_decoder_tf.id])
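
In NumPy terms, the Unsqueeze + Concat + Broadcast sub-graph built by this replacer boils down to the sketch below (illustration only; the helper is an assumption, not code from the patch):

import numpy as np

def build_sequence_mask(time_size, batch_size):
    # Broadcast(1.0, [T, N]): every time step is marked as valid for every batch item
    return np.broadcast_to(np.float32(1.0), (time_size, batch_size)).copy()

mask = build_sequence_mask(time_size=100, batch_size=4)   # shape (100, 4), filled with ones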
@@ -1,5 +1,5 @@
"""
Copyright (C) 2020 Intel Corporation
Copyright (C) 2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -14,10 +14,11 @@
limitations under the License.
"""

import numpy as np
import unittest

from extensions.front.tf.CTCGreedyDecoderReplacement import CTCGreedyDecoderReplacement, CTCGreedyDecoderReplacement2
import numpy as np

from extensions.front.tf.CTCGreedyDecoderReplacement import CTCGreedyDecoderReplacement
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, const
@@ -29,13 +30,15 @@ class CTCGreedyDecoderReplacementTests(unittest.TestCase):
            # nodes from original graph
            'logits': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
            'seq_len': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
            'decoder': {'kind': 'op', 'op': 'CTCGreedyDecoder'},
            'order_arr': {'kind': 'op', 'op': 'Const'},
            'transpose': {'type': 'Transpose', 'kind': 'op', 'op': 'Transpose'},
            'decoder': {'kind': 'op', 'op': 'CTCGreedyDecoderSeqLen', 'merge_repeated': True},
            'cast': {'kind': 'op', 'op': 'Cast'},
            'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
            'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},

            # new nodes
            'new_decoder': {'kind': 'op', 'op': 'CTCGreedyDecoder', 'use_mask_format': True},
            'new_decoder': {'kind': 'op', 'op': 'CTCGreedyDecoderSeqLen', 'use_mask_format': True},
            **const('squeeze_axes', int64_array([2, 3])),
            'squeeze_dec_seq': {'kind': 'op', 'op': 'Squeeze'},
            'cast_to_int': {'kind': 'op', 'op': 'Cast'},
@@ -45,6 +48,7 @@ class CTCGreedyDecoderReplacementTests(unittest.TestCase):
                            [('logits', 'decoder', {'out': 0, 'in': 0}),
                             ('seq_len', 'decoder', {'out': 0, 'in': 1}),
                             ('decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
                             ('decoder', 'sparse_to_dense', {'out': 2, 'in': 1}),
                             ('decoder', 'cast', {'out': 1, 'in': 0}),
                             ('cast', 'sparse_to_dense', {'out': 0}),
                             ('sparse_to_dense', 'last', {'out': 0, 'in': 0}),
@@ -53,113 +57,13 @@ class CTCGreedyDecoderReplacementTests(unittest.TestCase):
        CTCGreedyDecoderReplacement().find_and_replace_pattern(graph)

        graph_ref = build_graph(nodes_attributes,
                                [('logits', 'decoder', {'out': 0, 'in': 0}),
                                 ('seq_len', 'decoder', {'out': 0, 'in': 1}),
                                 ('decoder', 'squeeze_dec_seq', {'out': 0, 'in': 0}),
                                 ('squeeze_axes', 'squeeze_dec_seq', {'out': 0, 'in': 1}),
                                 ('squeeze_dec_seq', 'cast_to_int', {'out': 0, 'in': 0}),
                                 ('cast_to_int', 'last', {'out': 0, 'in': 0}),
                                 ],
                                nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
        self.assertEqual(len(graph.get_op_nodes(op='Cast')) == 1 and
                         graph.get_op_nodes(op='Cast')[0]['name'] == 'sparse_to_dense', True,
                         'Name is not inherited from original node for CTCGreedyDecoderReplacement')
        self.assertTrue(flag, resp)

    def test2(self):
        nodes_attributes = {
            # nodes from original graph
            'logits': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
            'transpose': {'kind': 'op', 'op': 'Transpose'},
            'shape': {'kind': 'op', 'op': 'ShapeOf'},
            'shape_1': {'kind': 'op', 'op': 'ShapeOf'},
            'strided_slice': {'kind': 'op', 'op': 'StridedSlice'},
            **const('stack', int64_array([1])),
            **const('stack1', int64_array([2])),
            **const('stack2', int64_array([1])),
            'strided_slice_1': {'kind': 'op', 'op': 'StridedSlice'},
            **const('stack_1', int64_array([0])),
            **const('stack1_1', int64_array([1])),
            **const('stack2_1', int64_array([1])),
            'dims': {'kind': 'op', 'op': 'Pack'},
            'fill': {'kind': 'op', 'op': 'Fill'},
            'decoder': {'kind': 'op', 'op': 'CTCGreedyDecoder'},
            'cast': {'kind': 'op', 'op': 'Cast'},
            'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},

            # new nodes
            **const('unsqueeze_batch_size_axis', int64_array(0)),
            'unsqueeze_batch_size': {'kind': 'op', 'op': 'Unsqueeze'},
            **const('unsqueeze_time_size_axis', int64_array(0)),
            'unsqueeze_time_size': {'kind': 'op', 'op': 'Unsqueeze'},
            'seq_mask_shape': {'kind': 'op', 'op': 'Concat'},
            'sequence_mask': {'kind': 'op', 'op': 'Broadcast'},
            **const('one', np.array([1.0], dtype=np.float)),
            **const('squeeze_axes', int64_array([2, 3])),
            'squeeze_dec_seq': {'kind': 'op', 'op': 'Squeeze'},
            'cast_to_int': {'kind': 'op', 'op': 'Cast'},

            'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
        }

        graph = build_graph(nodes_attributes,
                            [('logits', 'transpose', {'out': 0}),
                             ('transpose', 'shape', {'out': 0}),
                             ('transpose', 'shape_1', {'out': 0}),
                             ('transpose', 'decoder', {'out': 0, 'in': 0}),
                             ('shape', 'strided_slice', {'out': 0, 'in': 0}),
                             ('stack', 'strided_slice', {'out': 0, 'in': 1}),
                             ('stack1', 'strided_slice', {'out': 0, 'in': 2}),
                             ('stack2', 'strided_slice', {'out': 0, 'in': 3}),
                             ('shape_1', 'strided_slice_1', {'out': 0, 'in': 0}),
                             ('stack_1', 'strided_slice_1', {'out': 0, 'in': 1}),
                             ('stack1_1', 'strided_slice_1', {'out': 0, 'in': 2}),
                             ('stack2_1', 'strided_slice_1', {'out': 0, 'in': 3}),
                             ('strided_slice', 'dims', {'out': 0, 'in': 0}),
                             ('dims', 'fill', {'out': 0, 'in': 0}),
                             ('strided_slice_1', 'fill', {'out': 0, 'in': 1}),
                             ('fill', 'decoder', {'out': 0, 'in': 1}),
                             ('decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
                             ('decoder', 'cast', {'out': 1, 'in': 0}),
                             ('cast', 'sparse_to_dense', {'out': 0}),
                             ('sparse_to_dense', 'last', {'out': 0, 'in': 0}),
                             ], nodes_with_edges_only=True)
        graph.stage = 'front'
        CTCGreedyDecoderReplacement2().find_and_replace_pattern(graph)

        graph_ref = build_graph(nodes_attributes,
                                [('logits', 'transpose', {'out': 0}),
                                 ('transpose', 'shape', {'out': 0}),
                                 ('transpose', 'shape_1', {'out': 0}),
                                [('logits', 'transpose', {'out': 0, 'in': 0}),
                                 ('order_arr', 'transpose', {'out': 0, 'in': 1}),
                                 ('transpose', 'decoder', {'out': 0, 'in': 0}),
                                 ('shape', 'strided_slice', {'out': 0, 'in': 0}),
                                 ('stack', 'strided_slice', {'out': 0, 'in': 1}),
                                 ('stack1', 'strided_slice', {'out': 0, 'in': 2}),
                                 ('stack2', 'strided_slice', {'out': 0, 'in': 3}),
                                 ('shape_1', 'strided_slice_1', {'out': 0, 'in': 0}),
                                 ('stack_1', 'strided_slice_1', {'out': 0, 'in': 1}),
                                 ('stack1_1', 'strided_slice_1', {'out': 0, 'in': 2}),
                                 ('stack2_1', 'strided_slice_1', {'out': 0, 'in': 3}),
                                 ('strided_slice', 'unsqueeze_batch_size', {'out': 0, 'in': 0}),
                                 ('unsqueeze_batch_size_axis', 'unsqueeze_batch_size', {'out': 0, 'in': 1}),
                                 ('strided_slice_1', 'unsqueeze_time_size', {'out': 0, 'in': 0}),
                                 ('unsqueeze_time_size_axis', 'unsqueeze_time_size', {'out': 0, 'in': 1}),
                                 ('unsqueeze_batch_size', 'seq_mask_shape', {'out': 0, 'in': 1}),
                                 ('unsqueeze_time_size', 'seq_mask_shape', {'out': 0, 'in': 0}),
                                 ('one', 'sequence_mask', {'out': 0, 'in': 0}),
                                 ('seq_mask_shape', 'sequence_mask', {'out': 0, 'in': 1}),
                                 ('sequence_mask', 'decoder', {'out': 0, 'in': 1}),
                                 ('decoder', 'squeeze_dec_seq', {'out': 0, 'in': 0}),
                                 ('squeeze_axes', 'squeeze_dec_seq', {'out': 0, 'in': 1}),
                                 ('squeeze_dec_seq', 'cast_to_int', {'out': 0, 'in': 0}),
                                 ('cast_to_int', 'last', {'out': 0, 'in': 0}),
                                 ('seq_len', 'decoder', {'out': 0, 'in': 1}),
                                 ('decoder', 'last', {'out': 0, 'in': 0}),
                                 ],
                                nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
        self.assertEqual(len(graph.get_op_nodes(op='Cast')) == 1 and
                         graph.get_op_nodes(op='Cast')[0]['name'] == 'sparse_to_dense', True,
                         'Name is not inherited from original node for CTCGreedyDecoderReplacement2')
        self.assertTrue(flag, resp)

@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.ctc_greedy_decoder import CTCGreedyDecoderOp
from extensions.ops.ctc_greedy_decoder_seq_len import CTCGreedyDecoderSeqLenOp
from mo.front.extractor import FrontExtractorOp


@@ -24,7 +24,7 @@ class CTCCGreedyDecoderFrontExtractor(FrontExtractorOp):
    @classmethod
    def extract(cls, node):
        attrs = {
            'ctc_merge_repeated': int(node.pb.attr['merge_repeated'].b),
            'merge_repeated': bool(node.pb.attr['merge_repeated'].b),
        }
        CTCGreedyDecoderOp.update_node_stat(node, attrs)
        CTCGreedyDecoderSeqLenOp.update_node_stat(node, attrs)
        return cls.enabled

@@ -14,37 +14,21 @@
limitations under the License.
"""

import numpy as np
import logging as log

from extensions.ops.Cast import Cast
from extensions.ops.ctc_greedy_decoder import CTCGreedyDecoderOp
from extensions.ops.ctc_greedy_decoder_seq_len import CTCGreedyDecoderSeqLenOp
from extensions.ops.ctc_loss import CTCLoss
from extensions.ops.elementwise import Equal
from extensions.ops.parameter import Parameter
from extensions.ops.ReduceOps import ReduceSum
from extensions.ops.select import Select
from extensions.ops.transpose import Transpose
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.middle.passes.convert_data_type import data_type_str_to_np
from mo.ops.broadcast import Broadcast
from mo.ops.shape import Shape
from mo.ops.squeeze import Squeeze
from mo.utils.error import Error


class CTCLossReplacement(FrontReplacementSubgraph):
    """
    The CTCLoss appears along with the CTCGreedyDecoder operation in particular. Since the TensorFlow* CTCGreedyDecoder
    outputs a sparse tensor format, the OpenVINO CTCGreedyDecoder has a different format and the CTCLoss is also affected
    outputs a sparse tensor format, the OpenVINO CTCGreedyDecoderSeqLen has a different format and the CTCLoss is also affected
    in terms of a different format for its inputs. So the corresponding sub-graph with CTCGreedyDecoder and CTCLoss
    must be transformed properly.
    Also, the transformation changes the input sequence length format into a mask format. For example, a 1D tensor of
    sequence lengths equal to [4 2] is coded as a 2D tensor [[1 1 1 1 0], [1 1 0 0 0]] with the time dimension
    equal to 5.
    """
    enabled = True

@@ -55,17 +39,14 @@ class CTCLossReplacement(FrontReplacementSubgraph):
    def pattern(self):
        return dict(
            nodes=[
                ('seq_len', dict(op='Parameter')),
                ('transpose', dict(op='Transpose')),
                ('ctc_greedy_decoder', dict(op='CTCGreedyDecoder')),
                ('ctc_greedy_decoder', dict(op='CTCGreedyDecoderSeqLen')),
                ('cast', dict(op='Cast')),
                ('sparse_to_dense', dict(op='SparseToDense')),
                ('const', dict(op='Const')),
                ('ctc_loss', dict(op='CTCLoss')),
            ],
            edges=[
                ('seq_len', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
                ('seq_len', 'ctc_loss', {'out': 0, 'in': 3}),
                ('transpose', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
                ('transpose', 'ctc_loss', {'out': 0, 'in': 0}),
                ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
@@ -78,53 +59,29 @@ class CTCLossReplacement(FrontReplacementSubgraph):
            ])

    def replace_sub_graph(self, graph: Graph, match: dict):
        seq_len_tf = match['seq_len']
        transpose_tf = match['transpose']
        ctc_greedy_decoder_tf = match['ctc_greedy_decoder']
        cast_tf = match['cast']
        ctc_loss_tf = match['ctc_loss']
        sparse_to_dense_tf = match['sparse_to_dense']

        output_sparse_to_dense_name = sparse_to_dense_tf.soft_get('name', sparse_to_dense_tf.id)
        output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)
        ctc_data_permute = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0, 2])},
                                                       {'name': ctc_greedy_decoder_tf.name + '/ctc_data_permute'})
        ctc_data_permute.in_port(0).connect(transpose_tf.out_port(0))

        ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get('name', ctc_greedy_decoder_tf.id)

        log.debug('Found CTCLossFrontReplacer pattern after {} with name {}'.format(ctc_greedy_decoder_tf.op,
                                                                                    ctc_greedy_decoder_tf.name))

        # create a sequence mask node, a sub-graph for transforming it into a sequence length, and connect with consumers
        seq_len_tf_shape = seq_len_tf.soft_get('shape', None)
        if seq_len_tf_shape is None or len(seq_len_tf_shape) != 2:
            raise Error('The sequence length that is the second input to the CTCGreedyDecoder node "{}"'
                        ' must be specified in a mask format.'.format(ctc_greedy_decoder_tf_name))
        log.error('The format of input sequence length has been changed to a mask format', extra={'is_warning': True})
        seq_len_tf_type = seq_len_tf.soft_get('data_type', None)
        seq_len_tf_name = seq_len_tf.soft_get('name', seq_len_tf.id)
        seq_mask_placeholder = Parameter(graph, {'name': seq_len_tf_name, 'shape': seq_len_tf_shape,
                                                 'data_type': seq_len_tf_type}).create_node()
        reduce_to_seq_len_node = create_op_with_const_inputs(graph, ReduceSum, {1: np.array(1, dtype=np.int32)},
                                                             {'name': seq_len_tf_name + '/ReduceToSeqLen',
                                                              'keep_dims': False})
        reduce_to_seq_len_node.in_port(0).connect(seq_mask_placeholder.out_port(0))
        seq_len_tf.out_port(0).get_connection().set_source(reduce_to_seq_len_node.out_port(0))

        cast_fp_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
        casted_seq_mask_node = Cast(graph, {'name': seq_len_tf_name + '/CastToFP32', 'dst_type': cast_fp_type}).create_node()
        casted_seq_mask_node.in_port(0).connect(seq_mask_placeholder.out_port(0))
        permuted_casted_seq_mask = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])},
                                                               {'name': seq_len_tf_name + '/Permute'})
        permuted_casted_seq_mask.in_port(0).connect(casted_seq_mask_node.out_port(0))
        rename_nodes([(seq_len_tf, seq_len_tf_name + '/AbandonedName'), (seq_mask_placeholder, seq_len_tf_name)])

        # create the CTCGreedyDecoder node and set the mask node
        ctc_merge_repeated_i = ctc_greedy_decoder_tf.soft_get('ctc_merge_repeated', ctc_greedy_decoder_tf.id)
        ctc_greedy_decoder = CTCGreedyDecoderOp(graph, {'name': output_sparse_to_dense_name,
                                                        'ctc_merge_repeated': ctc_merge_repeated_i}).create_node()
        ctc_greedy_decoder.in_port(1).connect(permuted_casted_seq_mask.out_port(0))
        assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
            'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(ctc_greedy_decoder_tf_name)
        merge_repeated_tf = ctc_greedy_decoder_tf.merge_repeated
        ctc_greedy_decoder = CTCGreedyDecoderSeqLenOp(graph, {'name': output_sparse_to_dense_name,
                                                              'merge_repeated': merge_repeated_tf}).create_node()
        rename_nodes([(sparse_to_dense_tf, output_sparse_to_dense_name + '/AbandonedName'),
                      (ctc_greedy_decoder, output_sparse_to_dense_name)])
        ctc_greedy_decoder.in_port(0).connect(ctc_data_permute.out_port(0))
        ctc_greedy_decoder.in_port(1).connect(ctc_greedy_decoder_tf.in_port(1).get_connection().get_source())

        # create the CTCLoss node and set attributes
        # set output of the new sub-graph as a source for the SparseToDense consumer
        output_ctc_loss_name = ctc_loss_tf.soft_get('name', ctc_loss_tf.id)
        assert ctc_loss_tf.has_valid('preprocess_collapse_repeated'), \
            'The CTCLoss node "{}" misses "preprocess_collapse_repeated" attribute'.format(output_ctc_loss_name)
        assert ctc_loss_tf.has_valid('ctc_merge_repeated'), \
@@ -139,48 +96,14 @@ class CTCLossReplacement(FrontReplacementSubgraph):
                                               'ctc_merge_repeated': ctc_merge_repeated,
                                               'unique': unique}).create_node()
        rename_nodes([(ctc_loss_tf, output_ctc_loss_name + '/AbandonedName'), (ctc_loss, output_ctc_loss_name)])

        # connect logits
        ctc_greedy_decoder_tf.in_port(0).get_connection().set_destination(ctc_greedy_decoder.in_port(0))
        ctc_loss.in_port(0).disconnect()
        transpose_tf.in_port(0).get_connection().add_destination(ctc_loss.in_port(0))

        # connect logit lengths
        ctc_greedy_decoder_tf.in_port(1).disconnect()
        ctc_loss.in_port(1).connect(reduce_to_seq_len_node.out_port(0))

        # connect labels to ctc_loss
        squeeze_op = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([2, 3])})
        cast_labels_op = Cast(graph, {'name': output_sparse_to_dense_name + '/CastLabels', 'dst_type': np.int32}).create_node()
        squeeze_op.in_port(0).connect(ctc_greedy_decoder.out_port(0))
        cast_labels_op.in_port(0).connect(squeeze_op.out_port(0))
        ctc_loss.in_port(2).connect(cast_labels_op.out_port(0))

        # connect label lengths
        equal_op = create_op_with_const_inputs(graph, Equal, {1: np.array([-1], dtype=np.int32)},
                                               {'name': output_sparse_to_dense_name + '/Equal'})
        equal_op.in_port(0).connect(cast_labels_op.out_port(0))
        labels_shape_op = Shape(graph, {'name': output_sparse_to_dense_name + '/ShapeOf'}).create_node()
        labels_shape_op.in_port(0).connect(equal_op.out_port(0))
        broadcast_one = create_op_with_const_inputs(graph, Broadcast, {0: np.array([1], dtype=np.int32)},
                                                    {'mode': 'numpy',
                                                     'name': output_sparse_to_dense_name + '/One'})
        broadcast_one.in_port(1).connect(labels_shape_op.out_port(0))
        broadcast_zero = create_op_with_const_inputs(graph, Broadcast, {0: np.array([0], dtype=np.int32)},
                                                     {'mode': 'numpy',
                                                      'name': output_sparse_to_dense_name + '/Zero'})
        broadcast_zero.in_port(1).connect(labels_shape_op.out_port(0))

        select_node = Select(graph, {'name': output_sparse_to_dense_name + '/Select'}).create_node()
        select_node.in_port(0).connect(equal_op.out_port(0))
        select_node.in_port(1).connect(broadcast_zero.out_port(0))
        select_node.in_port(2).connect(broadcast_one.out_port(0))
        label_length_node = create_op_with_const_inputs(graph, ReduceSum, {1: int64_array([1])},
                                                        op_attrs={'name': output_sparse_to_dense_name + '/LabelLength',
                                                                  'keep_dims': False})
        label_length_node.in_port(0).connect(select_node.out_port(0))
        ctc_loss.in_port(3).connect(label_length_node.out_port(0))

        # set the source for the output of the new sub-graph and remove old nodes
        ctc_loss_tf.out_port(0).get_connection().set_source(ctc_loss.out_port(0))
        graph.remove_nodes_from([ctc_greedy_decoder_tf.id, ctc_loss_tf.id, cast_tf.id, sparse_to_dense_tf.id])
        if ctc_loss_tf.logits_time_major:
            ctc_loss.in_port(0).connect(ctc_data_permute.out_port(0))
        else:
            ctc_loss.in_port(0).connect(transpose_tf.out_port(0))
        ctc_loss.in_port(1).connect(ctc_greedy_decoder_tf.in_port(1).get_connection().get_source())
        ctc_loss.in_port(2).connect(ctc_greedy_decoder.out_port(0))
        ctc_loss.in_port(3).connect(ctc_greedy_decoder.out_port(1))

        # remove no longer needed nodes
        graph.remove_nodes_from([sparse_to_dense_tf.id, cast_tf.id, ctc_loss_tf.id, ctc_greedy_decoder_tf.id])
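
The Equal(-1) -> Select(0/1) -> ReduceSum(axis=1) chain that feeds input 3 of CTCLoss computes, per batch item, how many decoded labels are real rather than the -1 padding produced by CTCGreedyDecoderSeqLen. A NumPy illustration with assumed values (not code from the patch):

import numpy as np

decoded_labels = np.array([[7, 7, 3, -1, -1],
                           [2, -1, -1, -1, -1]], dtype=np.int32)
label_length = np.where(decoded_labels == -1, 0, 1).sum(axis=1)   # -> array([3, 1])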
@@ -28,107 +28,101 @@ class CTCLossFrontReplacementTest(unittest.TestCase):
    def test1(self):
        nodes_attributes = {
            'logits': {'shape': int64_array([2, 6, 100]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
            'seq_mask': {'shape': int64_array([2, 100]), 'data_type': np.int32, 'kind': 'op', 'op': 'Parameter'},

            'reduce_seq_mask': {'kind': 'op', 'op': 'ReduceSum'},
            's_cast_seq_mask': {'kind': 'op', 'op': 'Cast'},
            'transpose_cast_seq_mask': {'kind': 'op', 'op': 'Transpose'},

            'seq_mask': {'shape': int64_array([2]), 'data_type': np.int32, 'kind': 'op', 'op': 'Parameter'},
            'transpose': {'kind': 'op', 'op': 'Transpose'},
            'ctc_greedy_decoder': {'kind': 'op', 'op': 'CTCGreedyDecoder'},
            'ctc_greedy_decoder': {'kind': 'op', 'op': 'CTCGreedyDecoderSeqLen', 'merge_repeated': True},
            'cast': {'kind': 'op', 'op': 'Cast'},
            'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
            'const': {'kind': 'op', 'op': 'Const'},
            'tf_ctc_loss': {'kind': 'op', 'op': 'CTCLoss', 'preprocess_collapse_repeated': False,
                            'ctc_merge_repeated': True, 'unique': False, 'logits_time_major': True},
            'ctc_loss': {'kind': 'op', 'op': 'CTCLoss', 'preprocess_collapse_repeated': False,
                         'ctc_merge_repeated': True, 'unique': False},

            'equal_op': {'kind': 'op', 'op': 'Equal'},

            'ctc_greedy_decoder_op': {'kind': 'op', 'op': 'CTCGreedyDecoder'},
            'ctc_loss_op': {'kind': 'op', 'op': 'CTCLoss'},
            'squeeze_op': {'kind': 'op', 'op': 'Squeeze'},

            'cast_labels_op': {'kind': 'op', 'op': 'Cast', 'type': 'Convert'},
            'labels_shape_op': {'kind': 'op', 'op': 'ShapeOf'},
            'broadcast_one_op': {'kind': 'op', 'op': 'Broadcast'},
            'broadcast_zero_op': {'kind': 'op', 'op': 'Broadcast'},
            'select_op': {'kind': 'op', 'op': 'Select'},
            'label_length_op': {'kind': 'op', 'op': 'ReduceSum'},

            **const('reduce_indices', int64_array(1)),
            **const('permute_order', int64_array([1, 0])),
            'ctc_merge_repeated': True, 'unique': False},
            **const('default_value', int64_array(-1)),
            **const('squeeze_axis', int64_array([2, 3])),
            **const('minus_one', np.array([-1], dtype=np.int32)),
            **const('one', np.array([1], dtype=np.int32)),
            **const('zero', np.array([0], dtype=np.int32)),
            **const('reduce_sum_axis', int64_array([1])),

            'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
            'transpose2': {'kind': 'op', 'op': 'Transpose'},
            **const('transpose2_axis', int64_array([1, 0, 2])),
        }

        graph = build_graph(nodes_attributes,
                            [('logits', 'transpose', {'out': 0, 'in': 0}),
                             ('transpose', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
                             ('seq_mask', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),

                             ('transpose', 'ctc_loss', {'out': 0, 'in': 0}),
                             ('seq_mask', 'ctc_loss', {'out': 0, 'in': 3}),

                             ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
                             ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 2, 'in': 1}),
                             ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 1, 'in': 2}),
                             ('default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
                             ('ctc_greedy_decoder', 'cast', {'out': 1, 'in': 0}),
                             ('ctc_greedy_decoder', 'ctc_loss', {'out': 0, 'in': 1}),
                             ('cast', 'ctc_loss', {'out': 0, 'in': 2}),

                             ('ctc_loss', 'last', {'out': 0, 'in': 0}),
                             ], nodes_with_edges_only=True)
        graph = build_graph(nodes_attributes, [('logits', 'transpose', {'out': 0, 'in': 0}),
                                               ('transpose', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
                                               ('seq_mask', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
                                               ('transpose', 'tf_ctc_loss', {'out': 0, 'in': 0}),
                                               ('seq_mask', 'tf_ctc_loss', {'out': 0, 'in': 3}),
                                               ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
                                               ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 2, 'in': 1}),
                                               ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 1, 'in': 2}),
                                               ('default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
                                               ('ctc_greedy_decoder', 'cast', {'out': 1, 'in': 0}),
                                               ('ctc_greedy_decoder', 'tf_ctc_loss', {'out': 0, 'in': 1}),
                                               ('cast', 'tf_ctc_loss', {'out': 0, 'in': 2}),
                                               ('tf_ctc_loss', 'last', {'out': 0, 'in': 0})],
                            nodes_with_edges_only=True)
        graph.graph['cmd_params'] = Namespace(data_type='FP32')
        graph.stage = 'front'
        CTCLossReplacement().find_and_replace_pattern(graph)

        graph_ref = build_graph(nodes_attributes,
                                [('seq_mask', 'reduce_seq_mask', {'out': 0, 'in': 0}),
                                 ('reduce_indices', 'reduce_seq_mask', {'out': 0, 'in': 1}),

                                 ('seq_mask', 's_cast_seq_mask', {'out': 0, 'in': 0}),
                                 ('s_cast_seq_mask', 'transpose_cast_seq_mask', {'out': 0, 'in': 0}),
                                 ('permute_order', 'transpose_cast_seq_mask', {'out': 0, 'in': 1}),

                                 ('logits', 'transpose', {'out': 0, 'in': 0}),
                                 ('transpose', 'ctc_greedy_decoder_op', {'out': 0, 'in': 0}),
                                 ('transpose_cast_seq_mask', 'ctc_greedy_decoder_op', {'out': 0, 'in': 1}),

                                 ('ctc_greedy_decoder_op', 'squeeze_op', {'out': 0, 'in': 0}),
                                 ('squeeze_axis', 'squeeze_op', {'out': 0, 'in': 1}),
                                 ('squeeze_op', 'cast_labels_op', {'in': 0}),

                                 ('minus_one', 'equal_op', {'out': 0, 'in': 1}),

                                 ('equal_op', 'labels_shape_op', {'out': 0, 'in': 0}),
                                 ('one', 'broadcast_one_op', {'out': 0, 'in': 0}),
                                 ('labels_shape_op', 'broadcast_one_op', {'out': 0, 'in': 1}),
                                 ('zero', 'broadcast_zero_op', {'out': 0, 'in': 0}),
                                 ('labels_shape_op', 'broadcast_zero_op', {'out': 0, 'in': 1}),

                                 ('equal_op', 'select_op', {'out': 0, 'in': 0}),
                                 ('broadcast_zero_op', 'select_op', {'out': 0, 'in': 1}),
                                 ('broadcast_one_op', 'select_op', {'out': 0, 'in': 2}),

                                 ('select_op', 'label_length_op', {'out': 0, 'in': 0}),
                                 ('reduce_sum_axis', 'label_length_op', {'out': 0, 'in': 1}),

                                 ('logits', 'ctc_loss_op', {'out': 0, 'in': 0}),
                                 ('reduce_seq_mask', 'ctc_loss_op', {'out': 0, 'in': 1}),
                                 ('cast_labels_op', 'ctc_loss_op', {'out': 0, 'in': 2}),
                                 ('label_length_op', 'ctc_loss_op', {'out': 0, 'in': 3}),

                                 ('cast_labels_op', 'equal_op', {'out': 0, 'in': 0}),

                                 ('ctc_loss_op', 'last', {'out': 0, 'in': 0})],
                                nodes_with_edges_only=True)
                                [('logits', 'transpose', {'out': 0, 'in': 0}),
                                 ('transpose', 'transpose2', {'out': 0, 'in': 0}),
                                 ('transpose2_axis', 'transpose2', {'out': 0, 'in': 1}),
                                 ('transpose2', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
                                 ('seq_mask', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
                                 ('transpose2', 'ctc_loss', {'out': 0, 'in': 0}),
                                 ('ctc_greedy_decoder', 'ctc_loss', {'out': 0, 'in': 2}),
                                 ('ctc_greedy_decoder', 'ctc_loss', {'out': 1, 'in': 3}),
                                 ('seq_mask', 'ctc_loss', {'out': 0, 'in': 1}),
                                 ('ctc_loss', 'last', {'out': 0, 'in': 0})],
                                nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test2(self):
        nodes_attributes = {
            'logits': {'shape': int64_array([2, 6, 100]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
            'seq_mask': {'shape': int64_array([2]), 'data_type': np.int32, 'kind': 'op', 'op': 'Parameter'},
            'transpose': {'kind': 'op', 'op': 'Transpose'},
            'ctc_greedy_decoder': {'kind': 'op', 'op': 'CTCGreedyDecoderSeqLen', 'merge_repeated': True},
            'cast': {'kind': 'op', 'op': 'Cast'},
            'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
            'tf_ctc_loss': {'kind': 'op', 'op': 'CTCLoss', 'preprocess_collapse_repeated': False,
                            'ctc_merge_repeated': True, 'unique': False, 'logits_time_major': False},
            'ctc_loss': {'kind': 'op', 'op': 'CTCLoss', 'preprocess_collapse_repeated': False,
                         'ctc_merge_repeated': True, 'unique': False},
            **const('default_value', int64_array(-1)),
            'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
            'transpose2': {'kind': 'op', 'op': 'Transpose'},
            **const('transpose2_axis', int64_array([1, 0, 2])),
        }
        graph = build_graph(nodes_attributes, [('logits', 'transpose', {'out': 0, 'in': 0}),
                                               ('transpose', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
                                               ('seq_mask', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
                                               ('transpose', 'tf_ctc_loss', {'out': 0, 'in': 0}),
                                               ('seq_mask', 'tf_ctc_loss', {'out': 0, 'in': 3}),
                                               ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 0, 'in': 0}),
                                               ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 2, 'in': 1}),
                                               ('ctc_greedy_decoder', 'sparse_to_dense', {'out': 1, 'in': 2}),
                                               ('default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
                                               ('ctc_greedy_decoder', 'cast', {'out': 1, 'in': 0}),
                                               ('ctc_greedy_decoder', 'tf_ctc_loss', {'out': 0, 'in': 1}),
                                               ('cast', 'tf_ctc_loss', {'out': 0, 'in': 2}),
                                               ('tf_ctc_loss', 'last', {'out': 0, 'in': 0})],
                            nodes_with_edges_only=True)
        graph.graph['cmd_params'] = Namespace(data_type='FP32')
        graph.stage = 'front'
        CTCLossReplacement().find_and_replace_pattern(graph)

        graph_ref = build_graph(nodes_attributes,
                                [('logits', 'transpose', {'out': 0, 'in': 0}),
                                 ('transpose', 'transpose2', {'out': 0, 'in': 0}),
                                 ('transpose2_axis', 'transpose2', {'out': 0, 'in': 1}),
                                 ('transpose2', 'ctc_greedy_decoder', {'out': 0, 'in': 0}),
                                 ('seq_mask', 'ctc_greedy_decoder', {'out': 0, 'in': 1}),
                                 ('transpose', 'ctc_loss', {'out': 0, 'in': 0}),
                                 ('ctc_greedy_decoder', 'ctc_loss', {'out': 0, 'in': 2}),
                                 ('ctc_greedy_decoder', 'ctc_loss', {'out': 1, 'in': 3}),
                                 ('seq_mask', 'ctc_loss', {'out': 0, 'in': 1}),
                                 ('ctc_loss', 'last', {'out': 0, 'in': 0})],
                                nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
        self.assertTrue(flag, resp)

@@ -23,11 +23,18 @@ class CTCLossFrontExtractor(FrontExtractorOp):

    @classmethod
    def extract(cls, node):
        # for CTCLoss the default is logits_time_major=True, i.e. time-major logits
        logits_time_major = True
        if 'logits_time_major' in node.pb.attr:
            logits_time_major = node.pb.attr['logits_time_major'].b

        attrs = {
            'ctc_merge_repeated': node.pb.attr['ctc_merge_repeated'].b,
            'preprocess_collapse_repeated': node.pb.attr['preprocess_collapse_repeated'].b,
            'logits_time_major': logits_time_major,
            # unique is always false for CTCLoss V1
            'unique': False
        }

        CTCLoss.update_node_stat(node, attrs)
        return cls.enabled

model-optimizer/extensions/ops/ctc_greedy_decoder_seq_len.py (new file, 92 lines)
@@ -0,0 +1,92 @@
"""
Copyright (C) 2017-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import bool_to_str
from mo.graph.graph import Node, Graph
from mo.middle.passes.convert_data_type import np_data_type_to_destination_type
from mo.ops.op import Op
from mo.utils.error import Error


class CTCGreedyDecoderSeqLenOp(Op):
    op = 'CTCGreedyDecoderSeqLen'

    def __init__(self, graph: Graph, attrs: dict):
        mandatory_props = {
            'type': self.op,
            'op': self.op,
            'version': 'opset6',

            'infer': self.infer,
            'type_infer': self.type_infer,

            'in_ports_count': 3,
            'out_ports_count': 2,

            'merge_repeated': True,
            'classes_index_type': np.int32,
            'sequence_length_type': np.int32
        }
        super().__init__(graph, mandatory_props, attrs)

    def backend_attrs(self):
        version = self.get_opset()
        if version == 'opset6':
            return [('classes_index_type', lambda node: np_data_type_to_destination_type(node.classes_index_type)),
                    ('sequence_length_type', lambda node: np_data_type_to_destination_type(node.sequence_length_type)),
                    ('merge_repeated', lambda node: bool_to_str(node, 'merge_repeated'))]
        else:
            raise Error('Unknown opset version "{}"'.format(version))

    @staticmethod
    def type_infer(node):
        opset = node.get_opset()
        if opset == 'opset6':
            node.out_port(0).set_data_type(node.classes_index_type)
            node.out_port(1).set_data_type(node.sequence_length_type)

    @staticmethod
    def infer(node: Node):
        node_name = node.soft_get('name', node.id)
        connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
        assert len(connected_in_ports) in [2, 3], \
            "Incorrect number of inputs for {} node".format(node_name)

        logits_shape = node.in_port(0).data.get_shape()
        sequence_len_shape = node.in_port(1).data.get_shape()
        if len(node.in_nodes()) == 3:
            blank_index_shape = node.in_port(2).data.get_shape()
            assert len(blank_index_shape) == 1, \
                'Incorrect rank of blank_index for {} node'.format(node_name)

        # check shapes of input tensors
        assert len(logits_shape) == 3, \
            'Incorrect rank of logits for {} node'.format(node_name)

        assert len(sequence_len_shape) == 1, \
            'Incorrect rank of sequence length tensor for {} node'.format(node_name)
        assert logits_shape[0] == sequence_len_shape[0], \
            'Batch dimensions of input tensors must be the same for {} node'.format(node_name)

        batch_size = logits_shape[0]
        time_size = logits_shape[1]
        if node.is_out_port_connected(0):
            node.out_port(0).data.set_shape(int64_array([batch_size, time_size]))
        if node.is_out_port_connected(1):
            node.out_port(1).data.set_shape(int64_array([batch_size]))
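
A minimal usage sketch of the shape inference above; the helper name and the concrete values are illustrative assumptions, not part of the patch:

def ctc_greedy_decoder_seq_len_output_shapes(logits_shape):
    # mirrors CTCGreedyDecoderSeqLenOp.infer: [N, T, C] -> ([N, T], [N])
    n, t, _ = logits_shape
    return [n, t], [n]

assert ctc_greedy_decoder_seq_len_output_shapes([4, 100, 5]) == ([4, 100], [4])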
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ import unittest

import numpy as np

from extensions.ops.ctc_greedy_decoder import CTCGreedyDecoderOp
from extensions.ops.ctc_greedy_decoder_seq_len import CTCGreedyDecoderSeqLenOp
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph
@@ -28,68 +28,73 @@ nodes_attributes = {'logits': {'kind': 'op'},
                    'logits_data': {'shape': None, 'value': None, 'kind': 'data'},
                    'seq_mask': {'kind': 'op'},
                    'seq_mask_data': {'shape': None, 'value': None, 'kind': 'data'},
                    'ctcgreedydecoder_node': {'op': 'CTCGreedyDecoder', 'kind': 'op',
                    'ctcgreedydecoder_node': {'op': 'CTCGreedyDecoderSeqLen', 'kind': 'op',
                                              'ctc_merge_repeated': True},
                    'output': {'shape': None, 'value': None, 'kind': 'data'}}
                    'output1': {'shape': None, 'value': None, 'kind': 'data'},
                    'last_output1': {'shape': None, 'value': None, 'kind': 'op'},
                    'output2': {'shape': None, 'value': None, 'kind': 'data'}
                    }

# graph 1
edges1 = [('logits', 'logits_data'),
          ('seq_mask', 'seq_mask_data'),
          ('logits_data', 'ctcgreedydecoder_node', {'in': 0}),
          ('seq_mask_data', 'ctcgreedydecoder_node', {'in': 1}),
          ('ctcgreedydecoder_node', 'output', {'out': 0})]
          ('ctcgreedydecoder_node', 'output1', {'out': 0}),
          ('ctcgreedydecoder_node', 'output2', {'out': 1}),
          ('output1', 'last_output1', {'out': 0}),]

# valid test case
inputs1 = {'logits_data': {'shape': int64_array([100, 4, 5])},
           'seq_mask_data': {'shape': int64_array([100, 4])}}
inputs1 = {'logits_data': {'shape': int64_array([4, 100, 5])},
           'seq_mask_data': {'shape': int64_array([4])}}

# invalid test case with incorrect rank for the first input tensor
inputs1_inv = {'logits_data': {'shape': int64_array([100, 4, 5, 6])},
               'seq_mask_data': {'shape': int64_array([100, 4])}}
inputs1_inv = {'logits_data': {'shape': int64_array([4, 100, 5, 6])},
               'seq_mask_data': {'shape': int64_array([4])}}

# invalid test case with incorrect rank for the second input tensor
inputs2_inv = {'logits_data': {'shape': int64_array([100, 4, 5])},
               'seq_mask_data': {'shape': int64_array([100])}}
inputs2_inv = {'logits_data': {'shape': int64_array([4, 100, 5])},
               'seq_mask_data': {'shape': int64_array([4, 100])}}

# invalid test case with incorrect time dimension
inputs3_inv = {'logits_data': {'shape': int64_array([100, 4, 5])},
               'seq_mask_data': {'shape': int64_array([101, 4])}}
inputs3_inv = {'logits_data': {'shape': int64_array([4, 100, 5])},
               'seq_mask_data': {'shape': int64_array([4, 101])}}

# invalid test case with incorrect batch dimension
inputs4_inv = {'logits_data': {'shape': int64_array([100, 4, 5])},
               'seq_mask_data': {'shape': int64_array([100, 14])}}
inputs4_inv = {'logits_data': {'shape': int64_array([4, 100, 5])},
               'seq_mask_data': {'shape': int64_array([14, 100])}}

class TestCTCGreedyDecoder(unittest.TestCase):
    def test_infer1(self):
        graph = build_graph(nodes_attributes, edges1, inputs1)
        ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
        CTCGreedyDecoderOp.infer(ctcgreedydecoder_node)
        CTCGreedyDecoderSeqLenOp.infer(ctcgreedydecoder_node)

        # prepare reference results
        ref_output_shape = int64_array([4, 100, 1, 1])
        ref_output1_shape = int64_array([4, 100])

        # get the result
        res_output_shape = graph.node['output']['shape']
        res_output1_shape = graph.node['output1']['shape']

        self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
                        'shapes do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))
        self.assertTrue(np.array_equal(ref_output1_shape, res_output1_shape),
                        'shapes do not match expected: {} and given: {}'.format(ref_output1_shape, res_output1_shape))

    def test_infer_invalid1(self):
        graph = build_graph(nodes_attributes, edges1, inputs1_inv)
        ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
        self.assertRaises(AssertionError, CTCGreedyDecoderOp.infer, ctcgreedydecoder_node)
        self.assertRaises(AssertionError, CTCGreedyDecoderSeqLenOp.infer, ctcgreedydecoder_node)

    def test_infer_invalid2(self):
        graph = build_graph(nodes_attributes, edges1, inputs2_inv)
        ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
        self.assertRaises(AssertionError, CTCGreedyDecoderOp.infer, ctcgreedydecoder_node)
        self.assertRaises(AssertionError, CTCGreedyDecoderSeqLenOp.infer, ctcgreedydecoder_node)

    def test_infer_invalid3(self):
        graph = build_graph(nodes_attributes, edges1, inputs3_inv)
        ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
        self.assertRaises(AssertionError, CTCGreedyDecoderOp.infer, ctcgreedydecoder_node)
        self.assertRaises(AssertionError, CTCGreedyDecoderSeqLenOp.infer, ctcgreedydecoder_node)

    def test_infer_invalid4(self):
        graph = build_graph(nodes_attributes, edges1, inputs4_inv)
        ctcgreedydecoder_node = Node(graph, 'ctcgreedydecoder_node')
        self.assertRaises(AssertionError, CTCGreedyDecoderOp.infer, ctcgreedydecoder_node)
        self.assertRaises(AssertionError, CTCGreedyDecoderSeqLenOp.infer, ctcgreedydecoder_node)

@@ -0,0 +1,30 @@
"""
Copyright (c) 2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from mo.middle.passes.convert_data_type import destination_type_to_np_data_type
from mo.utils.graph import Node
from mo.utils.ir_reader.extender import Extender


class CTCGreedyDecoderSeqLenExtender(Extender):
    op = 'CTCGreedyDecoderSeqLen'

    @staticmethod
    def extend(op: Node):
        if op.has_valid('classes_index_type'):
            op['classes_index_type'] = destination_type_to_np_data_type(op.classes_index_type)
        if op.has_valid('sequence_length_type'):
            op['sequence_length_type'] = destination_type_to_np_data_type(op.sequence_length_type)