Support dynamic seq lengths in ConvertSequenceToTensorIterator transformation (#20671)
This commit is contained in:
parent
620a0fc289
commit
69e1258cc5
@ -182,7 +182,8 @@ TRANSFORMATIONS_API bool check_for_broadcast(const PartialShape& ref_shape, cons
|
||||
|
||||
TRANSFORMATIONS_API std::shared_ptr<Node> activation(const std::string& activation_name, const Output<Node>& apply_to);
|
||||
|
||||
TRANSFORMATIONS_API bool is_seq_len_provided(const std::shared_ptr<Node>& seq_len_input, int64_t max_seq_len);
|
||||
TRANSFORMATIONS_API bool is_seq_len_provided(const std::shared_ptr<Node>& X,
|
||||
const std::shared_ptr<Node>& seq_len_input);
|
||||
|
||||
TRANSFORMATIONS_API std::shared_ptr<Node> try_fold_unary_output(const std::shared_ptr<Node>& node);
|
||||
|
||||
|
@ -88,12 +88,11 @@ bool convert_sequence_to_ti(const std::shared_ptr<ov::Node>& sequence,
|
||||
const ov::Output<ov::Node>& B,
|
||||
const ov::op::RecurrentSequenceDirection& direction) {
|
||||
auto X_pshape = X.get_partial_shape();
|
||||
if (X_pshape.size() < 2 || X_pshape[1].is_dynamic()) {
|
||||
if (X_pshape.size() < 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto max_seq_len = X_pshape[1].get_length();
|
||||
bool enable_mask = ov::op::util::is_seq_len_provided(seq_lengths.get_node_shared_ptr(), max_seq_len);
|
||||
bool enable_mask = ov::op::util::is_seq_len_provided(X.get_node_shared_ptr(), seq_lengths.get_node_shared_ptr());
|
||||
|
||||
const bool is_reverse = direction == ov::op::RecurrentSequenceDirection::REVERSE;
|
||||
std::shared_ptr<ov::Node> reverse_seq_before;
|
||||
|
@ -132,11 +132,56 @@ std::shared_ptr<ov::Node> activation(const std::string& activation_name, const o
|
||||
}
|
||||
}
|
||||
|
||||
bool is_seq_len_provided(const std::shared_ptr<Node>& seq_len_input, int64_t max_seq_len) {
|
||||
bool is_seq_len_provided(const std::shared_ptr<Node>& X, const std::shared_ptr<Node>& seq_len_input) {
|
||||
auto max_seq_dim = X->get_output_partial_shape(0)[1];
|
||||
if (max_seq_dim.is_dynamic()) {
|
||||
// if values in seq_len input are equal to max_seq_len dim in X input
|
||||
// then we don't need to insert Select operations
|
||||
// supported seq_len_input:
|
||||
// X -> ShapeOf -> Gather (max_seq_dim) -> Optional (Broadcast)
|
||||
std::shared_ptr<Node> input = seq_len_input;
|
||||
auto broadcast = ov::as_type_ptr<ov::op::v3::Broadcast>(input);
|
||||
if (broadcast) {
|
||||
input = seq_len_input->input_value(0).get_node_shared_ptr();
|
||||
}
|
||||
|
||||
auto gather = ov::as_type_ptr<ov::op::util::GatherBase>(input);
|
||||
bool valid_gather = false;
|
||||
if (gather) {
|
||||
auto indices = gather->input_value(1).get_node_shared_ptr();
|
||||
auto axis = gather->input_value(2).get_node_shared_ptr();
|
||||
auto indices_const = ov::as_type_ptr<ov::op::v0::Constant>(indices);
|
||||
auto axis_const = ov::as_type_ptr<ov::op::v0::Constant>(axis);
|
||||
if (indices_const && axis_const) {
|
||||
auto ind_values = indices_const->cast_vector<int64_t>();
|
||||
auto axis_values = axis_const->cast_vector<int64_t>();
|
||||
if (ind_values.size() == 1 && ind_values[0] == 1 && axis_values.size() == 1 && axis_values[0] == 0) {
|
||||
valid_gather = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!valid_gather) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto shape_of = ov::as_type_ptr<ov::op::util::ShapeOfBase>(gather->input_value(0).get_node_shared_ptr());
|
||||
if (!shape_of) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (shape_of->input_value(0).get_node_shared_ptr() != X) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
auto max_seq_len_val = max_seq_dim.get_length();
|
||||
if (const auto& seq_len_const = std::dynamic_pointer_cast<op::v0::Constant>(seq_len_input)) {
|
||||
const auto& seq_len_values = seq_len_const->cast_vector<int64_t>();
|
||||
return std::any_of(seq_len_values.begin(), seq_len_values.end(), [max_seq_len](const int64_t val) {
|
||||
return val != max_seq_len;
|
||||
return std::any_of(seq_len_values.begin(), seq_len_values.end(), [max_seq_len_val](const int64_t val) {
|
||||
return val != max_seq_len_val;
|
||||
});
|
||||
}
|
||||
return true;
|
||||
|
@ -798,3 +798,115 @@ TEST(TransformationTests, ConvertQuantizedGRUSequenceToTensorIterator) {
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
// Runs ConvertLSTMSequenceToTensorIterator on an LSTMSequence whose X input has a dynamic
// dimension 1 and whose seq_lengths input is ShapeOf(X) -> Gather(indices={1}, axis=0),
// then compares the transformed model `f` with a hand-built TensorIterator reference `f_ref`.
TEST(TransformationTests, ConvertLSTMSequenceWithDynSeqLenToTensorIterator) {
|
||||
std::shared_ptr<ov::Model> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
// Model under test: X has shape {1, -1, 16}, i.e. dimension 1 is dynamic.
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{1, -1, 16});
|
||||
auto Y = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 128});
|
||||
auto Z = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 128});
|
||||
// seq_lengths is taken from X's own shape: Gather picks element 1 of ShapeOf(X).
auto shape_of = std::make_shared<opset5::ShapeOf>(X);
|
||||
auto indices = opset5::Constant::create(element::i32, {1}, {1});
|
||||
auto axis = opset5::Constant::create(element::i32, {}, {0});
|
||||
auto seq_lengths = std::make_shared<opset5::Gather>(shape_of, indices, axis);
|
||||
|
||||
// Zero-filled W, R and B constants.
auto w_val = std::vector<float>(512 * 16, 0);
|
||||
auto r_val = std::vector<float>(512 * 128, 0);
|
||||
auto b_val = std::vector<float>(512, 0);
|
||||
auto W = opset5::Constant::create(element::f32, Shape{1, 512, 16}, w_val);
|
||||
auto R = opset5::Constant::create(element::f32, Shape{1, 512, 128}, r_val);
|
||||
auto B = opset5::Constant::create(element::f32, Shape{1, 512}, b_val);
|
||||
|
||||
auto rnn_sequence = std::make_shared<opset5::LSTMSequence>(X,
|
||||
Y,
|
||||
Z,
|
||||
seq_lengths,
|
||||
W,
|
||||
R,
|
||||
B,
|
||||
128,
|
||||
op::RecurrentSequenceDirection::FORWARD);
|
||||
auto Y_out = std::make_shared<opset5::Result>(rnn_sequence->output(0));
|
||||
auto Ho = std::make_shared<opset5::Result>(rnn_sequence->output(1));
|
||||
auto Co = std::make_shared<opset5::Result>(rnn_sequence->output(2));
|
||||
Y_out->set_friendly_name("Y_out");
|
||||
Ho->set_friendly_name("Ho");
|
||||
Co->set_friendly_name("Co");
|
||||
|
||||
f = std::make_shared<ov::Model>(NodeVector{Y_out, Ho, Co}, ParameterVector{X, Y, Z});
|
||||
|
||||
// Apply the transformation to f and check rt_info afterwards.
pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ov::pass::ConvertLSTMSequenceToTensorIterator>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
// Reference model: a TensorIterator whose body is Squeeze -> LSTMCell -> Unsqueeze.
auto X = std::make_shared<opset5::Parameter>(element::f32, PartialShape{1, -1, 16});
|
||||
auto Y = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 128});
|
||||
auto Z = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 128});
|
||||
auto squeeze_pattern = opset5::Constant::create(element::i64, Shape{1}, {1});
|
||||
auto squeeze_y = std::make_shared<opset5::Squeeze>(Y, squeeze_pattern);
|
||||
auto squeeze_z = std::make_shared<opset5::Squeeze>(Z, squeeze_pattern);
|
||||
|
||||
auto Xi = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 1, 16});
|
||||
auto Yi = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 128});
|
||||
auto Zi = std::make_shared<opset5::Parameter>(element::f32, Shape{1, 128});
|
||||
// Body parameter that receives seq_lengths through set_invariant_input further down;
// no body operation built in this test consumes it.
auto seq_body_param = std::make_shared<opset5::Parameter>(element::i32, PartialShape{1});
|
||||
|
||||
// Body
|
||||
auto squeeze_x = std::make_shared<opset5::Squeeze>(Xi, squeeze_pattern);
|
||||
|
||||
auto w_val = std::vector<float>(512 * 16, 0);
|
||||
auto r_val = std::vector<float>(512 * 128, 0);
|
||||
auto b_val = std::vector<float>(512, 0);
|
||||
auto W = opset5::Constant::create(element::f32, Shape{512, 16}, w_val);
|
||||
auto R = opset5::Constant::create(element::f32, Shape{512, 128}, r_val);
|
||||
auto B = opset5::Constant::create(element::f32, Shape{512}, b_val);
|
||||
|
||||
auto rnn_cell = std::make_shared<opset5::LSTMCell>(squeeze_x, Yi, Zi, W, R, B, 128);
|
||||
|
||||
auto unsqueeze_pattern = opset5::Constant::create(element::i64, Shape{1}, {1});
|
||||
auto Ho = std::make_shared<opset5::Result>(rnn_cell->output(0));
|
||||
|
||||
auto Co = std::make_shared<opset5::Result>(rnn_cell->output(1));
|
||||
|
||||
auto unsqueeze_y = std::make_shared<opset5::Unsqueeze>(rnn_cell->output(0), unsqueeze_pattern);
|
||||
auto Y_out = std::make_shared<opset5::Result>(unsqueeze_y);
|
||||
|
||||
auto body = std::make_shared<Model>(OutputVector{Y_out, Ho, Co}, ParameterVector{Xi, Yi, Zi, seq_body_param});
|
||||
|
||||
auto tensor_iterator = std::make_shared<opset5::TensorIterator>();
|
||||
tensor_iterator->set_body(body);
|
||||
|
||||
// NOTE(review): the trailing arguments are presumably (start=0, stride=1, part_size=1,
// end=-1, axis=1) - confirm against the TensorIterator API.
tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
|
||||
tensor_iterator->get_concatenated_slices(Y_out, 0, 1, 1, -1, 1);
|
||||
|
||||
// NOTE(review): presumably Yi/Zi take squeeze_y/squeeze_z as initial values and
// Ho/Co as back-edges - confirm against the TensorIterator API.
tensor_iterator->set_merged_input(Yi, squeeze_y, Ho);
|
||||
tensor_iterator->set_merged_input(Zi, squeeze_z, Co);
|
||||
|
||||
// Same ShapeOf(X) -> Gather chain as in the model under test, passed into the body unchanged.
auto shape_of = std::make_shared<opset5::ShapeOf>(X);
|
||||
auto indices = opset5::Constant::create(element::i32, {1}, {1});
|
||||
auto axis = opset5::Constant::create(element::i32, {}, {0});
|
||||
auto seq_lengths = std::make_shared<opset5::Gather>(shape_of, indices, axis);
|
||||
tensor_iterator->set_invariant_input(seq_body_param, seq_lengths);
|
||||
|
||||
tensor_iterator->get_iter_value(Ho);
|
||||
tensor_iterator->get_iter_value(Co);
|
||||
|
||||
// Each TensorIterator output is unsqueezed with unsqueeze_pattern before its Result.
auto res_ti_Y = std::make_shared<opset5::Result>(
|
||||
std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(0), unsqueeze_pattern));
|
||||
auto res_ti_H = std::make_shared<opset5::Result>(
|
||||
std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(1), unsqueeze_pattern));
|
||||
auto res_ti_C = std::make_shared<opset5::Result>(
|
||||
std::make_shared<opset5::Unsqueeze>(tensor_iterator->output(2), unsqueeze_pattern));
|
||||
res_ti_Y->set_friendly_name("Y_out");
|
||||
res_ti_H->set_friendly_name("Ho");
|
||||
res_ti_C->set_friendly_name("Co");
|
||||
f_ref = std::make_shared<ov::Model>(NodeVector{res_ti_Y, res_ti_H, res_ti_C}, ParameterVector{X, Y, Z});
|
||||
}
|
||||
|
||||
// The transformed model must match the reference; on mismatch res.second is printed.
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
@ -318,8 +318,9 @@ bool RNN::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::s
|
||||
errorMessage = "Max sequence length dimension is dynamic";
|
||||
return false;
|
||||
}
|
||||
auto maxSeqLen = data_pshape[maxSeqLenDimIdx].get_length();
|
||||
if (ov::op::util::is_seq_len_provided(op->get_input_node_shared_ptr(seqLenIdx), maxSeqLen)) {
|
||||
|
||||
if (ov::op::util::is_seq_len_provided(op->get_input_node_shared_ptr(0),
|
||||
op->get_input_node_shared_ptr(seqLenIdx))) {
|
||||
errorMessage = "Unsupported sequence length.";
|
||||
return false;
|
||||
}
|
||||
|
@ -379,8 +379,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
|
||||
return lstm_seq->get_clip() == 0.0f &&
|
||||
lstm_seq->get_activations() == std::vector<std::string>{"sigmoid", "tanh", "tanh"} &&
|
||||
max_seq_len < 16 &&
|
||||
!ov::op::util::is_seq_len_provided(lstm_seq->get_input_node_shared_ptr(3),
|
||||
max_seq_len);
|
||||
!ov::op::util::is_seq_len_provided(lstm_seq->get_input_node_shared_ptr(0),
|
||||
lstm_seq->get_input_node_shared_ptr(3));
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user