Make model reshape and track batch (#12736)

* CVS-89672 Make model reshape and track batch

* Minor refactoring

* Changed the constant replacement mechanism to a more mature approach

* Update src/common/transformations/include/transformations/smart_reshape/lstm_states_broadcast.hpp

* Update src/common/transformations/src/transformations/smart_reshape/lstm_states_broadcast.cpp
This commit is contained in:
Evgenya Stepyreva
2022-08-26 12:05:34 +04:00
committed by GitHub
parent 8c1a3bab25
commit 3731913049
8 changed files with 595 additions and 1 deletions

View File

@@ -0,0 +1,29 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <ngraph/pass/graph_rewrite.hpp>
#include <vector>
namespace ngraph {
namespace pass {
class NGRAPH_API LSTMStatesBroadcast;
} // namespace pass
} // namespace ngraph
/**
 * @ingroup ie_transformation_common_api
 * @brief If an LSTMCell has constant initial hidden and cell states with a batch size of one,
 * this pass makes them broadcastable by batch
 */
// Model-level pass: finds LSTMCell operations (standalone or inside a
// TensorIterator body) whose initial hidden/cell states are batch-1 Constants
// and replaces those Constants with a Broadcast that follows the data batch
// dimension, so the whole model can later be reshaped by batch.
class ngraph::pass::LSTMStatesBroadcast : public ngraph::pass::FunctionPass {
public:
    OPENVINO_RTTI("LSTMStatesBroadcast", "0");
    // Returns true if the model was modified.
    bool run_on_model(const std::shared_ptr<ngraph::Function>& m) override;
};

View File

@@ -0,0 +1,31 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <ngraph/pass/graph_rewrite.hpp>
#include <vector>
namespace ngraph {
namespace pass {
class NGRAPH_API ReshapeSinkingMatMul;
} // namespace pass
} // namespace ngraph
/**
 * @ingroup ie_transformation_common_api
 * @brief ReshapeSinkingMatMul transformation looks for a MatMul followed by an optional Add,
 * surrounded by Reshape operations that are only needed to merge and unmerge dimensions
 * into the MatMul's batch. On success it upscales the MatMul to work with a multidimensional batch
 * and updates the Reshape operators so that the batch propagates through freely
 */
// Matcher pass: detects Reshape -> MatMul [-> Add] -> Reshape chains where the
// Reshapes exist only to flatten/unflatten batch dimensions around a 2D MatMul,
// and rewrites the Reshape output patterns (using special-zero) so a changed
// batch dimension propagates through the chain freely.
class ngraph::pass::ReshapeSinkingMatMul : public ngraph::pass::MatcherPass {
public:
    OPENVINO_RTTI("ReshapeSinkingMatMul", "0");
    ReshapeSinkingMatMul();
};

View File

@@ -67,6 +67,8 @@
#include <transformations/op_conversions/convert_divide.hpp>
#include <transformations/op_conversions/convert_negative.hpp>
#include <transformations/op_conversions/convert_scatter_elements_to_scatter.hpp>
#include <transformations/smart_reshape/lstm_states_broadcast.hpp>
#include <transformations/smart_reshape/reshape_sinking.hpp>
#include "itt.hpp"
@@ -115,6 +117,13 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
manager.register_pass<ngraph::pass::FuseFilteringBoxesBySize>();
manager.register_pass<ngraph::pass::Validate>();
if (!m_use_shapes) { // Approved Smart Reshape
manager.register_pass<ngraph::pass::LSTMStatesBroadcast>();
manager.register_pass<ngraph::pass::Validate>();
manager.register_pass<ngraph::pass::ReshapeSinkingMatMul>();
manager.register_pass<ngraph::pass::Validate>();
}
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
if (!m_use_shapes) {

View File

@@ -0,0 +1,181 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <memory>
#include <ngraph/pass/manager.hpp>
#include <openvino/op/util/sub_graph_base.hpp>
#include <openvino/opsets/opset9.hpp>
#include <transformations/smart_reshape/lstm_states_broadcast.hpp>
#include <transformations/utils/utils.hpp>
#include "dimension_tracker.hpp"
#include "itt.hpp"
// Returns the outer-graph input of the TensorIterator that feeds the given
// body Parameter.
//
// \param parameter  Parameter registered in the TI body.
// \param ti         TensorIterator owning that body.
// \return           The ov::Input on the TI node connected to the parameter.
// Asserts if the parameter is not registered in the body; unreachable-asserts
// if no input description refers to the parameter's index.
ov::Input<ov::Node> get_outer_input_of_ti_by_parameter(const std::shared_ptr<ov::opset9::Parameter>& parameter,
                                                       const std::shared_ptr<ov::opset9::TensorIterator>& ti) {
    const auto& body = ti->get_body();
    OPENVINO_ASSERT(body != nullptr, "TI returns invalid body graph ", ti);
    // reuse the already null-checked body instead of fetching it a second time
    const int64_t parameter_index = body->get_parameter_index(parameter);
    OPENVINO_ASSERT(parameter_index >= 0,
                    "LSTMStatesBroadcast encountered unregistered parameter ",
                    parameter,
                    " related to TI body ",
                    ti);
    // input descriptions map outer TI inputs to body parameter indices
    for (const auto& input_descriptor : ti->get_input_descriptions())
        if (input_descriptor->m_body_parameter_index == static_cast<size_t>(parameter_index))
            return ti->input(input_descriptor->m_input_index);
    OPENVINO_UNREACHABLE("LSTMStatesBroadcast failed to get outer input of TI by its inner Parameter. TI ",
                         ti,
                         " Parameter ",
                         parameter);
}
// Deduces which outer input of the TensorIterator delivers the batch dimension
// consumed by the inner LSTMCell, and returns a node (ShapeOf -> Gather on the
// outer graph) producing that batch size, or nullptr on failure.
//
// Mechanism: every dimension of every body Parameter is replaced with a
// dynamic dimension carrying a unique tracking label; shape inference is
// re-run so labels propagate through the body; the label that arrives at
// dimension 0 of the LSTMCell's data input identifies the delivering
// Parameter. Original body shapes are restored before returning.
std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
    const std::shared_ptr<ov::opset9::TensorIterator>& ti,
    const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
    const auto& body = ti->get_body();
    OPENVINO_ASSERT(body != nullptr, "TI returns invalid body graph ", ti);
    // saved parameter shapes, used to restore the body afterwards
    std::map<ov::opset9::Parameter*, ov::PartialShape> original_shapes;
    size_t label = 1;  // label 0 means "untracked", so numbering starts at 1
    // mark all input dimensions with labels and making them dynamic, keeping original shapes
    for (auto& parameter : body->get_parameters()) {
        auto pshape = parameter->get_partial_shape();
        original_shapes[parameter.get()] = pshape;
        if (pshape.rank().is_dynamic())
            continue;
        for (ngraph::Dimension& n : pshape) {
            OPENVINO_ASSERT(ov::DimensionTracker::get_label(n) == 0,
                            "LSTMStatesBroadcast encountered TI with previously tracked dimensions");
            // dynamic value keeps the label flowing through shape inference
            n = ov::Dimension::dynamic();
            ov::DimensionTracker::set_label(n, label++);
        }
        parameter->set_partial_shape(pshape);
    }
    // propagate labels through TI body
    body->validate_nodes_and_infer_types();
    // if lstm first input has undefined rank or if tracked label is zero -- we failed to track batch dimension
    // returning body to initial state
    if (lstm_cell->get_input_partial_shape(0).rank().is_dynamic() ||
        ov::DimensionTracker::get_label(lstm_cell->get_input_partial_shape(0)[0]) == 0) {
        for (auto& item : original_shapes)
            item.first->set_partial_shape(item.second);
        body->validate_nodes_and_infer_types();
        return nullptr;
    }
    // batch label was tracked -- finding parameter that delivered it
    std::shared_ptr<ov::opset9::Parameter> batch_delivering_parameter;
    size_t index_of_batch_dim = 0;
    size_t batch_label = ov::DimensionTracker::get_label(lstm_cell->get_input_partial_shape(0)[0]);
    for (auto& parameter : body->get_parameters()) {
        auto pshape = parameter->get_partial_shape();
        if (pshape.rank().is_dynamic())
            continue;
        for (size_t i = 0; i < pshape.size(); ++i) {
            if (ov::DimensionTracker::get_label(pshape[i]) == batch_label) {
                batch_delivering_parameter = parameter;
                index_of_batch_dim = i;
            }
        }
    }
    // restore the body to its original (unlabeled) shapes in all cases
    for (auto& item : original_shapes)
        item.first->set_partial_shape(item.second);
    body->validate_nodes_and_infer_types();
    if (batch_delivering_parameter == nullptr)
        return nullptr;
    // build batch = ShapeOf(outer source)[index_of_batch_dim] on the outer graph
    const auto& batched_source = get_outer_input_of_ti_by_parameter(batch_delivering_parameter, ti);
    const auto& batched_shape = std::make_shared<ov::opset9::ShapeOf>(batched_source.get_source_output());
    const auto& batch = std::make_shared<ov::opset9::Gather>(
        batched_shape,
        ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {index_of_batch_dim}),
        ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}));
    return batch;
}
// Replaces a constant LSTM state feeding `input` (expected shape [1, hidden])
// with Broadcast(state, [batch, hidden]), taking the batch value from
// `batch_delivering_node`. Returns true if the graph was modified.
bool broadcast_state_by_batch(ov::Input<ov::Node> input, const std::shared_ptr<ov::Node>& batch_delivering_node) {
    const auto state_source = input.get_source_output().get_node_shared_ptr();
    const auto state_constant = std::dynamic_pointer_cast<ov::opset9::Constant>(state_source);
    if (state_constant == nullptr)
        return false;
    const auto& state_shape = state_constant->get_shape();
    OPENVINO_ASSERT(state_shape.size() == 2, "State has unexpected shape ", state_shape);
    // only batch-1 states are broadcast; batch > 1 would require tiling instead
    if (state_shape[0] != 1)
        return false;
    const auto state_copy = state_constant->copy_with_new_inputs({});
    // hidden size is read back from the (folded where possible) state shape
    const auto hidden_size = ngraph::op::util::make_try_fold<ov::opset9::Gather>(
        ngraph::op::util::make_try_fold<ov::opset9::ShapeOf>(state_copy),
        ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {1}),
        ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}));
    const auto target_shape =
        std::make_shared<ov::opset9::Concat>(ngraph::NodeVector{batch_delivering_node, hidden_size}, 0);
    const auto broadcasted_state = std::make_shared<ov::opset9::Broadcast>(state_copy, target_shape);
    input.replace_source_output(broadcasted_state->output(0));
    return true;
}
// Broadcasts constant initial hidden (input 1) and cell (input 2) states of an
// LSTMCell located inside a TensorIterator body. The states reach the cell
// through body Parameters, so the actual Constants sit on the outer TI inputs.
// Returns true if the graph was modified.
bool relax_batch_for_initial_states_of_lstm_in_ti(const std::shared_ptr<ov::opset9::TensorIterator>& ti,
                                                  const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
    const auto batch_source = deduce_outer_source_of_batch_for_inner_lstm_cell(ti, lstm_cell);
    if (batch_source == nullptr)
        return false;
    bool modified = false;
    // LSTMCell inputs 1 and 2 are the initial hidden and cell states
    for (const size_t state_idx : {size_t(1), size_t(2)}) {
        const auto state_parameter =
            std::dynamic_pointer_cast<ov::opset9::Parameter>(lstm_cell->get_input_node_shared_ptr(state_idx));
        if (state_parameter == nullptr)
            continue;
        auto outer_state_input = get_outer_input_of_ti_by_parameter(state_parameter, ti);
        modified |= broadcast_state_by_batch(outer_state_input, batch_source);
    }
    return modified;
}
// Broadcasts constant initial states of a standalone LSTMCell (no TI around
// it). The batch value is dimension 0 of the cell's data input shape.
// Returns true if the graph was modified.
bool relax_batch_for_initial_states_of_lstm(const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
    const auto data_shape = std::make_shared<ov::opset9::ShapeOf>(lstm_cell->get_input_source_output(0));
    const auto batch_index = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {0});
    const auto gather_axis = ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0});
    const auto batch = std::make_shared<ov::opset9::Gather>(data_shape, batch_index, gather_axis);
    bool modified = broadcast_state_by_batch(lstm_cell->input(1), batch);
    modified |= broadcast_state_by_batch(lstm_cell->input(2), batch);
    return modified;
}
// Walks the model (and recursively every sub-graph op body) and applies the
// state broadcasting to each LSTMCell found, both standalone and TI-wrapped.
bool ngraph::pass::LSTMStatesBroadcast::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
    RUN_ON_FUNCTION_SCOPE(LSTMStatesBroadcast);
    bool modified = false;
    for (const auto& op : f->get_ordered_ops()) {
        // dive into sub-graph based operations first
        if (const auto sub_graph_op = std::dynamic_pointer_cast<ov::op::util::SubGraphOp>(op)) {
            if (const auto inner_model = sub_graph_op->get_function())
                modified |= run_on_model(inner_model);
        }
        // standalone case: LSTMCell and its Constants live in the same ov::Model
        if (const auto standalone_cell = std::dynamic_pointer_cast<ov::opset9::LSTMCell>(op))
            modified |= relax_batch_for_initial_states_of_lstm(standalone_cell);
        // TI case: LSTMCell sits in the body, Constants sit on the outer model
        if (const auto ti = std::dynamic_pointer_cast<ov::opset9::TensorIterator>(op)) {
            const auto body = ti->get_body();
            OPENVINO_ASSERT(body, "TensorIterator must have body network");
            for (const auto& body_op : body->get_ordered_ops())
                if (const auto inner_cell = std::dynamic_pointer_cast<ov::opset9::LSTMCell>(body_op))
                    modified |= relax_batch_for_initial_states_of_lstm_in_ti(ti, inner_cell);
        }
    }
    return modified;
}

View File

@@ -0,0 +1,156 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <ngraph/opsets/opset9.hpp>
#include <ngraph/pattern/matcher.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <transformations/smart_reshape/reshape_sinking.hpp>
#include "itt.hpp"
// Matcher pass constructor. Detects Reshape -> MatMul [-> Add] -> Reshape
// chains whose Reshapes only merge/unmerge batch dimensions around a 2D
// MatMul, and rewrites both Reshape output patterns (with special-zero
// semantics) so the batch dimension propagates through the chain freely.
// Fix: removed the local `any_input` pattern variable that was declared but
// never used (the diagram's any_input is matched by the inline
// pattern::any_input() inside reshape_label).
ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
    MATCHER_SCOPE(ReshapeSinkingMatMul);
    /* Original graph:                       Transformed graph:
     *
     * any_input                             any_input
     * | shape=[B, S, K]                     | shape=[B, S, K]
     * Reshape output_pattern=(-1, K)        Reshape output_pattern=(0, 0, K)
     * | shape=[B * S, K]                    | shape=[B, S, K]
     * MatMul constant_shape=[K, O]          MatMul constant_shape=[K, O]
     * | shape=[B * S, O]                    | shape=[B, S, O]
     * Reshape output_pattern=(B=1, S, O)    Reshape output_pattern=(0, S, O)
     * | shape=[1, S, O]                     | shape=[B, S, O]
     */
    auto reshape_label = ngraph::pattern::wrap_type<opset9::Reshape>(
        {pattern::any_input(), ngraph::pattern::wrap_type<opset9::Constant>()},
        pattern::rank_equals(2));
    auto matmul_label =
        ngraph::pattern::wrap_type<opset9::MatMul>({reshape_label, ngraph::pattern::wrap_type<opset9::Constant>()},
                                                   pattern::rank_equals(2));
    auto add_label =
        ngraph::pattern::wrap_type<opset9::Add>({matmul_label, ngraph::pattern::wrap_type<opset9::Constant>()},
                                                pattern::rank_equals(2));
    auto matmul_or_matmul_add_label = std::make_shared<pattern::op::Or>(OutputVector{add_label, matmul_label});
    auto reshape_1_label = ngraph::pattern::wrap_type<opset9::Reshape>(
        {matmul_or_matmul_add_label, ngraph::pattern::wrap_type<opset9::Constant>()},
        pattern::has_static_rank());

    matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
        auto pattern_to_node = m.get_pattern_map();

        // check first Reshape eligibility: has a constant output pattern in a form of [-1, K]
        auto reshape = pattern_to_node.at(reshape_label);
        int64_t K = -1;
        if (const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(reshape->get_input_node_shared_ptr(1))) {
            auto output_pattern_vector = constant->cast_vector<int64_t>();
            if (output_pattern_vector.size() != 2 || output_pattern_vector[0] != -1)
                return false;
            K = output_pattern_vector[1];
        }
        if (K == -1)
            return false;

        // check input shape eligibility: has a form of [x1, x2, ..., xn, K]
        auto input_pshape = reshape->get_input_partial_shape(0);
        if (input_pshape.rank().is_dynamic() || input_pshape.rank().get_length() <= 2)
            return false;
        auto input_rank = input_pshape.size();
        if (input_pshape[input_rank - 1] != K)
            return false;

        // check matmul eligibility: has constant second input in a form of [O, K]
        auto matmul = std::dynamic_pointer_cast<opset9::MatMul>(pattern_to_node.at(matmul_label));
        if (!matmul || matmul->get_transpose_a())
            return false;
        int64_t O = -1;
        if (const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(matmul->get_input_node_shared_ptr(1))) {
            const auto& constant_shape = constant->get_shape();
            if (constant_shape.size() != 2)
                return false;
            // transpose_b decides which weight axis must equal K and which is O
            const auto& desired_K_index = matmul->get_transpose_b() ? 1 : 0;
            const auto& O_index = matmul->get_transpose_b() ? 0 : 1;
            if (constant_shape[desired_K_index] != K)
                return false;
            O = static_cast<int64_t>(constant_shape[O_index]);
        }
        if (O == -1)
            return false;

        // check add eligibility if present: has constant second input that has a form of [1, 1, ..., O] (doesn't
        // broadcast first input)
        if (pattern_to_node.count(add_label)) {
            auto add = std::dynamic_pointer_cast<opset9::Add>(pattern_to_node.at(add_label));
            if (!add || add->get_autob() != ov::op::AutoBroadcastType::NUMPY)
                return false;
            const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(add->get_input_node_shared_ptr(1));
            if (!constant)
                return false;
            const auto& constant_shape = constant->get_shape();
            // bias must be all-ones or [1, ..., 1, O] so it cannot broadcast the batch
            auto desired_ones_shape = ov::Shape(constant_shape.size(), 1);
            auto desired_shape = ov::Shape(constant_shape.size() - 1, 1);
            desired_shape.push_back(O);
            OPENVINO_ASSERT(constant_shape.size() == desired_ones_shape.size() &&
                            constant_shape.size() == desired_shape.size());
            if (constant_shape != desired_shape && constant_shape != desired_ones_shape)
                return false;
        }

        // check second Reshape eligibility: has hard-coded output pattern constant which is almost the same as
        // input_shape of the pattern except for the batch and last dimension
        auto reshape_1 = m.get_match_root();
        const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(reshape_1->get_input_node_shared_ptr(1));
        if (constant == nullptr)
            return false;
        auto output_pattern = constant->cast_vector<int64_t>();
        if (!std::all_of(output_pattern.begin(), output_pattern.end(), [](const int64_t& i) {
                return i > 0;
            }))
            return false;
        if (output_pattern.size() != input_rank)
            return false;
        for (size_t i = 0; i < input_rank; ++i) {
            if (i == 0)
                continue;  // batch may differ -- it will be made special-zero below
            if (i + 1 == input_rank) {
                // last dimension must equal the MatMul output channels O
                if (output_pattern[i] != O)
                    return false;
                else
                    continue;
            }
            if (input_pshape[i] != output_pattern[i])
                return false;
        }

        // this is the pattern we are looking for! performing the transformation
        auto first_reshape = std::dynamic_pointer_cast<opset9::Reshape>(reshape);
        if (!first_reshape)
            return false;
        first_reshape->set_special_zero(true);
        auto second_reshape = std::dynamic_pointer_cast<opset9::Reshape>(reshape_1);
        if (!second_reshape)
            return false;
        second_reshape->set_special_zero(true);

        // first Reshape: copy all leading dims from input (zeros), pin the last to K
        std::vector<int64_t> output_pattern_vector(input_rank - 1, 0);
        output_pattern_vector.push_back(K);
        auto new_reshape_constant =
            opset9::Constant::create(ov::element::i64, Shape{input_rank}, output_pattern_vector);
        reshape->input(1).replace_source_output(new_reshape_constant->output(0));

        // second Reshape: make the batch dimension special-zero (copied from input)
        output_pattern[0] = 0;
        auto new_reshape_1_constant = opset9::Constant::create(ov::element::i64, Shape{input_rank}, output_pattern);
        reshape_1->input(1).replace_source_output(new_reshape_1_constant->output(0));
        return true;
    };
    auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_1_label, matcher_name);
    register_matcher(m, callback);
}

View File

@@ -6,9 +6,10 @@
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/smart_reshape/broadcast_const_range_replacement.hpp>
#include <transformations/smart_reshape/lstm_states_broadcast.hpp>
#include <transformations/smart_reshape/matmul_sr.hpp>
#include <transformations/smart_reshape/mimic_set_batch_size.hpp>
#include <transformations/smart_reshape/proposal_scales_stridedslice.hpp>
#include <transformations/smart_reshape/reshape_sinking.hpp>
#include <transformations/smart_reshape/reshape_to_1D.hpp>
#include <transformations/smart_reshape/smart_reshape.hpp>
#include <transformations/smart_reshape/strided_slice_squeeze.hpp>
@@ -29,6 +30,8 @@ bool ngraph::pass::SmartReshape::run_on_model(const std::shared_ptr<ngraph::Func
static_manager.register_pass<ngraph::pass::ReshapeTo1D>();
static_manager.register_pass<ngraph::pass::TransposeMatMul>();
static_manager.register_pass<ngraph::pass::BroadcastConstRangeReplacement>();
static_manager.register_pass<ngraph::pass::LSTMStatesBroadcast>();
static_manager.register_pass<ngraph::pass::ReshapeSinkingMatMul>();
static_manager.run_passes(f);
ngraph::pass::Manager dynamic_manager;

View File

@@ -0,0 +1,106 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <openvino/core/model.hpp>
#include <openvino/opsets/opset9.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
// Builds a Constant from `data`. Element type defaults to i64; `scalar`
// selects a rank-0 shape instead of a 1D shape of size data.size().
template <class T>
std::shared_ptr<ov::opset9::Constant> create_constant(const std::vector<T>& data,
                                                      const ov::element::Type_t et = ov::element::i64,
                                                      bool scalar = false) {
    ov::Shape shape{};
    if (!scalar)
        shape = ov::Shape{data.size()};
    return ov::opset9::Constant::create(et, shape, data);
}
// Zero-filled Constant; the given partial shape must be fully static
// (PartialShape::to_shape() requires a static shape).
std::shared_ptr<ov::opset9::Constant> create_zero_constant(const ov::element::Type_t et, ov::PartialShape shape) {
    const auto static_shape = shape.to_shape();
    // a single value is broadcast across the whole constant
    return ov::opset9::Constant::create(et, static_shape, {0});
}
// Parameters for the LSTM-state broadcasting tests.
struct LSTMStatesAttributes {
    ov::element::Type_t data_et;                    // element type of the data input
    ov::Dimension data_batch_size, new_batch_size;  // batch at build time / reshape target batch
    ov::Dimension input_size, hidden_size;          // LSTM input and hidden sizes
};
// Parameterized fixture: reshaping a bare LSTMCell model by batch.
class LSTMStatesBroadcastTest
    : public testing::WithParamInterface<LSTMStatesAttributes>, public CommonTestUtils::TestsCommon {
};
// Builds a bare LSTMCell whose initial states are batch-1 zero constants and
// checks the model can be reshaped to a new batch size without throwing.
TEST_P(LSTMStatesBroadcastTest, BareLSTM) {
    const auto attrs = GetParam();
    const auto hidden = static_cast<size_t>(attrs.hidden_size.get_length());
    std::shared_ptr<ov::Model> model;
    {
        auto data = std::make_shared<ov::opset9::Parameter>(attrs.data_et,
                                                            ov::PartialShape{attrs.data_batch_size, attrs.input_size});
        auto H = create_zero_constant(attrs.data_et, {1, attrs.hidden_size});
        auto C = create_zero_constant(attrs.data_et, {1, attrs.hidden_size});
        auto W = create_zero_constant(attrs.data_et, {attrs.hidden_size * 4, attrs.input_size});
        auto R = create_zero_constant(attrs.data_et, {attrs.hidden_size * 4, attrs.hidden_size});
        auto cell = std::make_shared<ov::opset9::LSTMCell>(data, H, C, W, R, hidden);
        model = std::make_shared<ov::Model>(ov::NodeVector{cell}, ov::ParameterVector{data});
    }
    ASSERT_NO_THROW(model->reshape(ov::PartialShape{attrs.new_batch_size, attrs.input_size}));
}
// Parameterized fixture: reshaping a model with an LSTMCell inside a TensorIterator.
class LSTMStatesBroadcastTestWithTI
    : public testing::WithParamInterface<LSTMStatesAttributes>, public CommonTestUtils::TestsCommon {
};
// TI-wrapped LSTMCell: initial hidden/cell states are batch-1 constants fed to
// the TensorIterator from the outer graph; the model must reshape to a new
// batch without throwing.
// Fix: the original made an unguarded model->reshape(...) call immediately
// before the ASSERT_NO_THROW of the very same call -- a throw there would
// abort the test body before the assertion ran; the duplicate is removed.
TEST_P(LSTMStatesBroadcastTestWithTI, TI_With_LSTM) {
    auto p = GetParam();
    std::shared_ptr<ov::Model> model(nullptr);
    {
        auto X = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{p.data_batch_size, 1, p.input_size});
        // NOTE(review): the initial states are created as i64 while the rest of
        // the graph uses p.data_et -- confirm whether these should be p.data_et.
        auto H_init = create_zero_constant(ov::element::i64, {1, p.hidden_size});
        auto C_init = create_zero_constant(ov::element::i64, {1, p.hidden_size});
        // body parameters (single-iteration shapes)
        auto Xi = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{1, 1, p.input_size});
        auto H_t = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{1, p.hidden_size});
        auto C_t = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{1, p.hidden_size});
        // Body
        auto squeeze = std::make_shared<ov::opset9::Squeeze>(Xi, create_constant<int64_t>({1}));
        auto W = create_zero_constant(p.data_et, {p.hidden_size * 4, p.input_size});
        auto R = create_zero_constant(p.data_et, {p.hidden_size * 4, p.hidden_size});
        auto lstm_cell = std::make_shared<ov::opset9::LSTMCell>(squeeze, H_t, C_t, W, R,
                                                                static_cast<size_t>(p.hidden_size.get_length()));
        auto res_1 = std::make_shared<ov::opset9::Result>(lstm_cell->output(0));
        auto unsqueeze = std::make_shared<ov::opset9::Unsqueeze>(lstm_cell->output(0), create_constant<int64_t>({1}));
        auto res_2 = std::make_shared<ov::opset9::Result>(unsqueeze);
        auto res_3 = std::make_shared<ov::opset9::Result>(lstm_cell->output(1));
        auto body = std::make_shared<ov::Model>(ov::OutputVector{res_1, res_2, res_3}, ov::ParameterVector{Xi, H_t, C_t});
        auto tensor_iterator = std::make_shared<ov::opset9::TensorIterator>();
        tensor_iterator->set_body(body);
        // states are merged inputs (fed back each iteration); data is sliced
        tensor_iterator->set_merged_input(C_t, C_init, res_3);
        tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
        tensor_iterator->set_merged_input(H_t, H_init, res_1);
        auto out0 = tensor_iterator->get_iter_value(res_1, -1);
        auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 0);
        auto res_ti_1 = std::make_shared<ov::opset9::Result>(tensor_iterator->output(1));
        auto res_ti_2 = std::make_shared<ov::opset9::Result>(tensor_iterator->output(0));
        model = std::make_shared<ov::Model>(ov::NodeVector{res_ti_1, res_ti_2},
                                            ov::ParameterVector{X});
    }
    // the reshape itself is the check: it must not throw
    ASSERT_NO_THROW(model->reshape(ov::PartialShape{p.new_batch_size, 1, p.input_size}));
}
// Cases: static batch 1 -> 2, and dynamic batch (-1) -> 2.
static std::vector<LSTMStatesAttributes> params = {
    LSTMStatesAttributes{ov::element::f32, {1}, {2}, {512}, {256}},
    LSTMStatesAttributes{ov::element::f32, {-1}, {2}, {512}, {256}},
};
INSTANTIATE_TEST_SUITE_P(SmartReshapeTests, LSTMStatesBroadcastTest, ::testing::ValuesIn(params));
INSTANTIATE_TEST_SUITE_P(SmartReshapeTests, LSTMStatesBroadcastTestWithTI, ::testing::ValuesIn(params));

View File

@@ -0,0 +1,79 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <openvino/core/model.hpp>
#include <openvino/opsets/opset9.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
// Builds a Constant from `data`. Element type defaults to i64; `scalar`
// selects a rank-0 shape instead of a 1D shape of size data.size().
template <class T>
std::shared_ptr<ov::opset9::Constant> create_constant(const std::vector<T>& data,
                                                      const ov::element::Type_t et = ov::element::i64,
                                                      bool scalar = false) {
    ov::Shape shape{};
    if (!scalar)
        shape = ov::Shape{data.size()};
    return ov::opset9::Constant::create(et, shape, data);
}
// Zero-filled Constant of the given static shape.
std::shared_ptr<ov::opset9::Constant> create_zero_constant(const ov::element::Type_t et, ov::Shape shape) {
    // a single value is broadcast across the whole constant
    const std::vector<int64_t> zero_value{0};
    return ov::opset9::Constant::create(et, shape, zero_value);
}
// Parameters for the reshape-sinking tests.
struct ReshapeSinkingAttributes {
    ov::element::Type_t data_et;               // element type of the data input
    ov::PartialShape input_shape;              // shape the model is built with
    ov::PartialShape new_shape;                // shape passed to model->reshape
    std::vector<int64_t> output_pattern;       // first Reshape target pattern, e.g. {-1, K}
    std::vector<int64_t> output_pattern_back;  // last Reshape target pattern
    ov::Shape mm_second_input_shape;           // MatMul weight (second input) shape
    bool transpose_a, transpose_b;             // MatMul transpose attributes
};
// Parameterized fixture: Reshape -> MatMul -> Reshape chain without bias.
class ReshapeSinkingTest
    : public testing::WithParamInterface<ReshapeSinkingAttributes>, public CommonTestUtils::TestsCommon {
};
// Builds Reshape -> MatMul -> Reshape and checks the model accepts a new
// input shape (different batch) without throwing.
TEST_P(ReshapeSinkingTest, ReshapeSinkingOnlyMatMul) {
    const auto attrs = GetParam();
    std::shared_ptr<ov::Model> model;
    {
        auto data = std::make_shared<ov::opset9::Parameter>(attrs.data_et, attrs.input_shape);
        auto flatten = std::make_shared<ov::opset9::Reshape>(data, create_constant(attrs.output_pattern), false);
        auto mm = std::make_shared<ov::opset9::MatMul>(flatten,
                                                       create_zero_constant(attrs.data_et, attrs.mm_second_input_shape),
                                                       attrs.transpose_a,
                                                       attrs.transpose_b);
        auto unflatten = std::make_shared<ov::opset9::Reshape>(mm, create_constant(attrs.output_pattern_back), false);
        model = std::make_shared<ov::Model>(ov::NodeVector{unflatten}, ov::ParameterVector{data});
    }
    ASSERT_NO_THROW(model->reshape(attrs.new_shape));
}
// Parameterized fixture: Reshape -> MatMul -> Add -> Reshape chain with bias.
class ReshapeSinkingTestWithAdd
    : public testing::WithParamInterface<ReshapeSinkingAttributes>, public CommonTestUtils::TestsCommon {
};
// Builds Reshape -> MatMul -> Add -> Reshape and checks the bias Add does not
// prevent the model from being reshaped to a new batch without throwing.
TEST_P(ReshapeSinkingTestWithAdd, ReshapeSinkingMatMulAdd) {
    const auto attrs = GetParam();
    std::shared_ptr<ov::Model> model;
    {
        auto data = std::make_shared<ov::opset9::Parameter>(attrs.data_et, attrs.input_shape);
        auto flatten = std::make_shared<ov::opset9::Reshape>(data, create_constant(attrs.output_pattern), false);
        auto mm = std::make_shared<ov::opset9::MatMul>(flatten,
                                                       create_zero_constant(attrs.data_et, attrs.mm_second_input_shape),
                                                       attrs.transpose_a,
                                                       attrs.transpose_b);
        // bias shape {1, 37} matches O = 37 used by every parameter set below
        auto biased = std::make_shared<ov::opset9::Add>(mm, create_zero_constant(attrs.data_et, {1, 37}));
        auto unflatten = std::make_shared<ov::opset9::Reshape>(biased, create_constant(attrs.output_pattern_back), false);
        model = std::make_shared<ov::Model>(ov::NodeVector{unflatten}, ov::ParameterVector{data});
    }
    ASSERT_NO_THROW(model->reshape(attrs.new_shape));
}
// Each entry: data_et, input_shape, new_shape, output_pattern,
// output_pattern_back, mm_second_input_shape, transpose_a, transpose_b.
static std::vector<ReshapeSinkingAttributes> params = {
    ReshapeSinkingAttributes{ov::element::f32, {10, 30, 512}, {20, 30, 512}, {-1, 512}, {10, 30, 37}, {37, 512}, false, true},
    ReshapeSinkingAttributes{ov::element::f32, {-1, 30, 512}, {20, 30, 512}, {-1, 512}, {10, 30, 37}, {37, 512}, false, true},
    ReshapeSinkingAttributes{ov::element::f32, {1, 3, 4, 512}, {2, 3, 4, 512}, {-1, 512}, {1, 3, 4, 37}, {37, 512}, false, true},
    ReshapeSinkingAttributes{ov::element::f32, {1, 3, 4, 512}, {2, 3, 4, 512}, {-1, 512}, {1, 3, 4, 37}, {512, 37}, false, false},
};
INSTANTIATE_TEST_SUITE_P(SmartReshapeTests, ReshapeSinkingTest, ::testing::ValuesIn(params));
INSTANTIATE_TEST_SUITE_P(SmartReshapeTests, ReshapeSinkingTestWithAdd, ::testing::ValuesIn(params));