Make model reshape and track batch (#12736)
* CVS-89672 Make model reshape and track batch * Minor refactoring * Changed mechanism of constant replacement to more mature * Update src/common/transformations/include/transformations/smart_reshape/lstm_states_broadcast.hpp * Update src/common/transformations/src/transformations/smart_reshape/lstm_states_broadcast.cpp
This commit is contained in:
committed by
GitHub
parent
8c1a3bab25
commit
3731913049
@@ -0,0 +1,29 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include <vector>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class NGRAPH_API LSTMStatesBroadcast;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
 * @ingroup ie_transformation_common_api
 * @brief In case LSTMCell has constant initial hidden and cell state with single batch size
 * we make them broadcast-able by batch
 */
class ngraph::pass::LSTMStatesBroadcast : public ngraph::pass::FunctionPass {
public:
    OPENVINO_RTTI("LSTMStatesBroadcast", "0");
    // Runs once over the whole model (FunctionPass); returns true if the graph was modified.
    bool run_on_model(const std::shared_ptr<ngraph::Function>& m) override;
};
|
||||
@@ -0,0 +1,31 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include <vector>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class NGRAPH_API ReshapeSinkingMatMul;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
 * @ingroup ie_transformation_common_api
 * @brief ReshapeSinkingMatMul transformation looks for MatMul followed by optional Add
 * surrounded with Reshape operations which are only needed to merge and unmerge dimensions
 * into MatMuls batch. In case of success upscales MatMul to work with multidimensional batch and updates
 * Reshape operators to make batch propagate through freely
 */
class ngraph::pass::ReshapeSinkingMatMul : public ngraph::pass::MatcherPass {
public:
    OPENVINO_RTTI("ReshapeSinkingMatMul", "0");
    // Constructor registers the Reshape -> MatMul [-> Add] -> Reshape matcher and its callback.
    ReshapeSinkingMatMul();
};
|
||||
@@ -67,6 +67,8 @@
|
||||
#include <transformations/op_conversions/convert_divide.hpp>
|
||||
#include <transformations/op_conversions/convert_negative.hpp>
|
||||
#include <transformations/op_conversions/convert_scatter_elements_to_scatter.hpp>
|
||||
#include <transformations/smart_reshape/lstm_states_broadcast.hpp>
|
||||
#include <transformations/smart_reshape/reshape_sinking.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
|
||||
@@ -115,6 +117,13 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
|
||||
manager.register_pass<ngraph::pass::FuseFilteringBoxesBySize>();
|
||||
manager.register_pass<ngraph::pass::Validate>();
|
||||
|
||||
if (!m_use_shapes) { // Approved Smart Reshape
|
||||
manager.register_pass<ngraph::pass::LSTMStatesBroadcast>();
|
||||
manager.register_pass<ngraph::pass::Validate>();
|
||||
manager.register_pass<ngraph::pass::ReshapeSinkingMatMul>();
|
||||
manager.register_pass<ngraph::pass::Validate>();
|
||||
}
|
||||
|
||||
manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();
|
||||
manager.register_pass<ngraph::pass::SimplifyShapeOfSubGraph>();
|
||||
if (!m_use_shapes) {
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <openvino/op/util/sub_graph_base.hpp>
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
#include <transformations/smart_reshape/lstm_states_broadcast.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "dimension_tracker.hpp"
|
||||
#include "itt.hpp"
|
||||
|
||||
// Maps a Parameter of the TensorIterator body onto the corresponding outer input of the
// TensorIterator itself, by scanning the TI input descriptors for the parameter's index.
// Throws via OPENVINO_ASSERT if the parameter is not registered in the body, and via
// OPENVINO_UNREACHABLE if no input descriptor references it.
ov::Input<ov::Node> get_outer_input_of_ti_by_parameter(const std::shared_ptr<ov::opset9::Parameter>& parameter,
                                                       const std::shared_ptr<ov::opset9::TensorIterator>& ti) {
    const auto& body = ti->get_body();
    OPENVINO_ASSERT(body != nullptr, "TI returns invalid body graph ", ti);
    // Consistency fix: reuse the already-fetched (and null-checked) body instead of
    // calling ti->get_body() a second time.
    int64_t parameter_index = body->get_parameter_index(parameter);
    OPENVINO_ASSERT(parameter_index >= 0,
                    "LSTMStatesBroadcast encountered unregistered parameter ",
                    parameter,
                    " related to TI body ",
                    ti);
    for (const auto& input_descriptor : ti->get_input_descriptions())
        if (input_descriptor->m_body_parameter_index == parameter_index)
            return ti->input(input_descriptor->m_input_index);
    OPENVINO_UNREACHABLE("LSTMStatesBroadcast failed to get outer input of TI by its inner Parameter. TI ",
                         ti,
                         " Parameter ",
                         parameter);
}
|
||||
|
||||
// Finds, in the OUTER graph, a node that produces the batch size feeding the given LSTMCell
// located INSIDE the TI body. Technique: temporarily relax every static body-input dimension
// to a uniquely-labeled dynamic one, re-run shape inference to see which label arrives at
// lstm_cell input 0 dimension 0, then restore the original shapes. On success returns a
// [1]-shaped i64 subgraph (Gather over ShapeOf of the outer source); returns nullptr if the
// batch dimension could not be tracked.
std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
    const std::shared_ptr<ov::opset9::TensorIterator>& ti,
    const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
    const auto& body = ti->get_body();
    OPENVINO_ASSERT(body != nullptr, "TI returns invalid body graph ", ti);

    // Original parameter shapes, keyed by raw pointer, so they can be restored below.
    std::map<ov::opset9::Parameter*, ov::PartialShape> original_shapes;
    size_t label = 1;  // label 0 means "untracked", so labeling starts from 1

    // mark all input dimensions with labels and making them dynamic, keeping original shapes
    for (auto& parameter : body->get_parameters()) {
        auto pshape = parameter->get_partial_shape();
        original_shapes[parameter.get()] = pshape;
        if (pshape.rank().is_dynamic())
            continue;
        for (ngraph::Dimension& n : pshape) {
            OPENVINO_ASSERT(ov::DimensionTracker::get_label(n) == 0,
                            "LSTMStatesBroadcast encountered TI with previously tracked dimensions");
            n = ov::Dimension::dynamic();
            ov::DimensionTracker::set_label(n, label++);
        }
        parameter->set_partial_shape(pshape);
    }

    // propagate labels through TI body
    body->validate_nodes_and_infer_types();
    // if lstm first input has undefined rank or if tracked label is zero -- we failed to track batch dimension
    // returning body to initial state
    if (lstm_cell->get_input_partial_shape(0).rank().is_dynamic() ||
        ov::DimensionTracker::get_label(lstm_cell->get_input_partial_shape(0)[0]) == 0) {
        for (auto& item : original_shapes)
            item.first->set_partial_shape(item.second);
        body->validate_nodes_and_infer_types();
        return nullptr;
    }

    // batch label was tracked -- finding parameter that delivered it
    std::shared_ptr<ov::opset9::Parameter> batch_delivering_parameter;
    size_t index_of_batch_dim = 0;

    size_t batch_label = ov::DimensionTracker::get_label(lstm_cell->get_input_partial_shape(0)[0]);
    for (auto& parameter : body->get_parameters()) {
        auto pshape = parameter->get_partial_shape();
        if (pshape.rank().is_dynamic())
            continue;
        for (size_t i = 0; i < pshape.size(); ++i) {
            if (ov::DimensionTracker::get_label(pshape[i]) == batch_label) {
                // NOTE(review): no break here — if several dimensions carry the batch label,
                // the last match wins; confirm this is intended.
                batch_delivering_parameter = parameter;
                index_of_batch_dim = i;
            }
        }
    }
    // Restore original (labeled-free) shapes regardless of the search outcome.
    for (auto& item : original_shapes)
        item.first->set_partial_shape(item.second);
    body->validate_nodes_and_infer_types();

    if (batch_delivering_parameter == nullptr)
        return nullptr;

    // Build ShapeOf(outer source)[index_of_batch_dim] as the batch-delivering subgraph.
    const auto& batched_source = get_outer_input_of_ti_by_parameter(batch_delivering_parameter, ti);
    const auto& batched_shape = std::make_shared<ov::opset9::ShapeOf>(batched_source.get_source_output());
    const auto& batch = std::make_shared<ov::opset9::Gather>(
        batched_shape,
        ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {index_of_batch_dim}),
        ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}));
    return batch;
}
|
||||
|
||||
// Replaces a constant LSTM state feeding `input` (expected shape [1, hidden_size]) with
// Broadcast(state, Concat(batch, hidden_size)), where `batch` comes from
// batch_delivering_node. Returns true iff the graph was modified; returns false when the
// state is not a Constant or its batch dimension is not 1.
bool broadcast_state_by_batch(ov::Input<ov::Node> input, const std::shared_ptr<ov::Node>& batch_delivering_node) {
    auto constant_state =
        std::dynamic_pointer_cast<ov::opset9::Constant>(input.get_source_output().get_node_shared_ptr());
    if (constant_state == nullptr)
        return false;
    const auto& constant_shape = constant_state->get_shape();
    OPENVINO_ASSERT(constant_shape.size() == 2, "State has unexpected shape ", constant_shape);
    if (constant_shape[0] != 1)
        // we only expect to broadcast LSTM states prepared for batch 1 -- no tiling of batch > 1 will be done
        return false;

    // Work on a copy of the constant so the original stays untouched for other consumers.
    const auto& constant_copy = constant_state->copy_with_new_inputs({});
    // Target shape = [batch, hidden_size]: hidden_size is read back from the constant itself
    // via (try-folded) ShapeOf+Gather and concatenated with the outer batch value.
    const auto& broadcast_by_batch = std::make_shared<ov::opset9::Broadcast>(
        constant_copy,
        std::make_shared<ov::opset9::Concat>(
            ngraph::NodeVector{batch_delivering_node,
                               ngraph::op::util::make_try_fold<ov::opset9::Gather>(
                                   ngraph::op::util::make_try_fold<ov::opset9::ShapeOf>(constant_copy),
                                   ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {1}),
                                   ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}))},
            0));
    input.replace_source_output(broadcast_by_batch->output(0));
    return true;
}
|
||||
|
||||
// For an LSTMCell living inside a TI body: deduce where the batch comes from in the outer
// graph, then broadcast the outer constant initial hidden/cell states by that batch.
// Returns true iff at least one state was rewritten.
bool relax_batch_for_initial_states_of_lstm_in_ti(const std::shared_ptr<ov::opset9::TensorIterator>& ti,
                                                  const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
    bool rewritten = false;
    auto batch_delivering_node = deduce_outer_source_of_batch_for_inner_lstm_cell(ti, lstm_cell);
    if (batch_delivering_node == nullptr)
        return rewritten;
    // LSTMCell inputs 1 and 2 are the initial hidden and cell states; inside a TI body they
    // must be Parameters so that they can be traced back to the outer graph.
    for (size_t state_index : {size_t(1), size_t(2)}) {
        auto state_parameter =
            std::dynamic_pointer_cast<ov::opset9::Parameter>(lstm_cell->get_input_node_shared_ptr(state_index));
        if (!state_parameter)
            continue;
        auto outer_state_input = get_outer_input_of_ti_by_parameter(state_parameter, ti);
        rewritten |= broadcast_state_by_batch(outer_state_input, batch_delivering_node);
    }
    return rewritten;
}
|
||||
|
||||
// TI-less case: LSTMCell and its constant states live in the same ov::Model.
// Batch is dimension 0 of the LSTM data input, so ShapeOf(data)[0] delivers it directly.
// Returns true iff at least one state was rewritten.
bool relax_batch_for_initial_states_of_lstm(const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
    auto data_shape = std::make_shared<ov::opset9::ShapeOf>(lstm_cell->get_input_source_output(0));
    auto batch_index = ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {0});
    auto gather_axis = ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0});
    auto batch_source = std::make_shared<ov::opset9::Gather>(data_shape, batch_index, gather_axis);

    // Inputs 1 and 2 are the initial hidden and cell states.
    bool changed = broadcast_state_by_batch(lstm_cell->input(1), batch_source);
    changed |= broadcast_state_by_batch(lstm_cell->input(2), batch_source);
    return changed;
}
|
||||
|
||||
// Walks the model in topological order, broadcasting constant batch-1 LSTM initial states:
// both for bare LSTMCells and for LSTMCells wrapped in TensorIterator bodies.
// Returns true iff any graph was modified.
bool ngraph::pass::LSTMStatesBroadcast::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
    RUN_ON_FUNCTION_SCOPE(LSTMStatesBroadcast);
    bool rewritten = false;
    for (auto& node : f->get_ordered_ops()) {
        // Recursively apply transformation for sub-graph based operations
        // NOTE(review): TensorIterator is itself a SubGraphOp, so a TI body is visited both
        // here and by the TI-specific branch below — confirm the double visit is benign
        // (broadcast_state_by_batch only rewrites outer Constants, so the recursive visit
        // should no-op inside TI bodies).
        if (const auto& sub_graph_node = std::dynamic_pointer_cast<ov::op::util::SubGraphOp>(node))
            if (const auto& sub_graph = sub_graph_node->get_function())
                rewritten |= run_on_model(sub_graph);

        // Case without TI (LSTMCell and Constant are in the same ov::Model)
        if (const auto& lstm_cell = std::dynamic_pointer_cast<ov::opset9::LSTMCell>(node))
            rewritten |= relax_batch_for_initial_states_of_lstm(lstm_cell);

        // Case with TI (LSTMCell and Constant are in different ov::Model objects)
        if (auto ti = std::dynamic_pointer_cast<ov::opset9::TensorIterator>(node)) {
            auto body = ti->get_body();
            OPENVINO_ASSERT(body, "TensorIterator must have body network");
            for (const auto& body_node : body->get_ordered_ops())
                if (const auto& lstm_cell = std::dynamic_pointer_cast<ov::opset9::LSTMCell>(body_node))
                    rewritten |= relax_batch_for_initial_states_of_lstm_in_ti(ti, lstm_cell);
        }
    }
    return rewritten;
}
|
||||
@@ -0,0 +1,156 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <ngraph/opsets/opset9.hpp>
|
||||
#include <ngraph/pattern/matcher.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <transformations/smart_reshape/reshape_sinking.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
|
||||
// Matches Reshape(-1, K) -> MatMul(const) [-> Add(const)] -> Reshape(all-positive pattern)
// and, when the surrounding Reshapes only merge/unmerge dimensions into MatMul's batch,
// rewrites both Reshape patterns to use special-zero form so that the batch dimension
// propagates freely through the chain.
ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
    MATCHER_SCOPE(ReshapeSinkingMatMul);
    /* Original graph: Transformed graph:
     *
     * any_input any_input
     * | shape=[B, S, K] | shape=[B, S, K]
     * Reshape output_pattern=(-1, K) Reshape output_pattern=(0, 0, K)
     * | shape=[B * S, K] | shape=[B, S, K]
     * MatMul constant_shape=[K, O] MatMul constant_shape=[K, O]
     * | shape=[B * S, O] | shape=[B, S, O]
     * Reshape output_pattern=(B=1, S, O) Reshape output_pattern=(0, S, O)
     * | shape=[1, S, O] | shape=[B, S, O]
     */
    // Fix: removed unused local `any_input` — the reshape pattern below creates its own
    // pattern::any_input() and never referenced it.
    auto reshape_label = ngraph::pattern::wrap_type<opset9::Reshape>(
        {pattern::any_input(), ngraph::pattern::wrap_type<opset9::Constant>()},
        pattern::rank_equals(2));

    auto matmul_label =
        ngraph::pattern::wrap_type<opset9::MatMul>({reshape_label, ngraph::pattern::wrap_type<opset9::Constant>()},
                                                   pattern::rank_equals(2));
    auto add_label =
        ngraph::pattern::wrap_type<opset9::Add>({matmul_label, ngraph::pattern::wrap_type<opset9::Constant>()},
                                                pattern::rank_equals(2));

    auto matmul_or_matmul_add_label = std::make_shared<pattern::op::Or>(OutputVector{add_label, matmul_label});

    auto reshape_1_label = ngraph::pattern::wrap_type<opset9::Reshape>(
        {matmul_or_matmul_add_label, ngraph::pattern::wrap_type<opset9::Constant>()},
        pattern::has_static_rank());

    matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
        auto pattern_to_node = m.get_pattern_map();

        // check first Reshape eligibility: has a constant output pattern in a form of [-1, K]
        auto reshape = pattern_to_node.at(reshape_label);
        int64_t K = -1;
        if (const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(reshape->get_input_node_shared_ptr(1))) {
            auto output_pattern_vector = constant->cast_vector<int64_t>();
            if (output_pattern_vector.size() != 2 || output_pattern_vector[0] != -1)
                return false;
            K = output_pattern_vector[1];
        }
        if (K == -1)
            return false;

        // check input shape eligibility: has a form of [x1, x2, ..., xn, K]
        auto input_pshape = reshape->get_input_partial_shape(0);
        if (input_pshape.rank().is_dynamic() || input_pshape.rank().get_length() <= 2)
            return false;
        auto input_rank = input_pshape.size();
        if (input_pshape[input_rank - 1] != K)
            return false;

        // check matmul eligibility: has constant second input in a form of [O, K]
        // (K position depends on transpose_b); transpose_a is not supported
        auto matmul = std::dynamic_pointer_cast<opset9::MatMul>(pattern_to_node.at(matmul_label));
        if (!matmul || matmul->get_transpose_a())
            return false;
        int64_t O = -1;
        if (const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(matmul->get_input_node_shared_ptr(1))) {
            const auto& constant_shape = constant->get_shape();
            if (constant_shape.size() != 2)
                return false;
            const auto& desired_K_index = matmul->get_transpose_b() ? 1 : 0;
            const auto& O_index = matmul->get_transpose_b() ? 0 : 1;
            if (constant_shape[desired_K_index] != K)
                return false;
            O = static_cast<int64_t>(constant_shape[O_index]);
        }
        if (O == -1)
            return false;

        // check add eligibility if present: has constant second input that has a form of [1, 1, ..., O] (doesn't
        // broadcast first input)
        if (pattern_to_node.count(add_label)) {
            auto add = std::dynamic_pointer_cast<opset9::Add>(pattern_to_node.at(add_label));
            if (!add || add->get_autob() != ov::op::AutoBroadcastType::NUMPY)
                return false;
            const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(add->get_input_node_shared_ptr(1));
            if (!constant)
                return false;
            const auto& constant_shape = constant->get_shape();
            // accepted bias shapes: all ones, or ones everywhere except O in the last dim
            auto desired_ones_shape = ov::Shape(constant_shape.size(), 1);
            auto desired_shape = ov::Shape(constant_shape.size() - 1, 1);
            desired_shape.push_back(O);
            OPENVINO_ASSERT(constant_shape.size() == desired_ones_shape.size() &&
                            constant_shape.size() == desired_shape.size());
            if (constant_shape != desired_shape && constant_shape != desired_ones_shape)
                return false;
        }

        // check second Reshape eligibility: has hard-coded output pattern constant which is almost the same as
        // input_shape of the pattern except for the batch and last dimension
        auto reshape_1 = m.get_match_root();

        const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(reshape_1->get_input_node_shared_ptr(1));
        if (constant == nullptr)
            return false;
        auto output_pattern = constant->cast_vector<int64_t>();
        if (!std::all_of(output_pattern.begin(), output_pattern.end(), [](const int64_t& i) {
                return i > 0;
            }))
            return false;
        if (output_pattern.size() != input_rank)
            return false;
        for (size_t i = 0; i < input_rank; ++i) {
            if (i == 0)  // batch dimension may differ; it is replaced by 0 below
                continue;
            if (i + 1 == input_rank) {  // last dimension must equal O
                if (output_pattern[i] != O)
                    return false;
                else
                    continue;
            }
            if (input_pshape[i] != output_pattern[i])
                return false;
        }

        // this is the pattern we are looking for! performing the transformation
        auto first_reshape = std::dynamic_pointer_cast<opset9::Reshape>(reshape);
        if (!first_reshape)
            return false;
        first_reshape->set_special_zero(true);
        auto second_reshape = std::dynamic_pointer_cast<opset9::Reshape>(reshape_1);
        if (!second_reshape)
            return false;
        second_reshape->set_special_zero(true);

        // first Reshape: (-1, K) -> (0, 0, ..., 0, K) keeps all leading dims as-is
        std::vector<int64_t> output_pattern_vector(input_rank - 1, 0);
        output_pattern_vector.push_back(K);
        auto new_reshape_constant =
            opset9::Constant::create(ov::element::i64, Shape{input_rank}, output_pattern_vector);
        reshape->input(1).replace_source_output(new_reshape_constant->output(0));

        // second Reshape: batch entry replaced with 0 so batch propagates through
        output_pattern[0] = 0;
        auto new_reshape_1_constant = opset9::Constant::create(ov::element::i64, Shape{input_rank}, output_pattern);
        reshape_1->input(1).replace_source_output(new_reshape_1_constant->output(0));

        return true;
    };
    auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_1_label, matcher_name);
    register_matcher(m, callback);
}
|
||||
@@ -6,9 +6,10 @@
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/smart_reshape/broadcast_const_range_replacement.hpp>
|
||||
#include <transformations/smart_reshape/lstm_states_broadcast.hpp>
|
||||
#include <transformations/smart_reshape/matmul_sr.hpp>
|
||||
#include <transformations/smart_reshape/mimic_set_batch_size.hpp>
|
||||
#include <transformations/smart_reshape/proposal_scales_stridedslice.hpp>
|
||||
#include <transformations/smart_reshape/reshape_sinking.hpp>
|
||||
#include <transformations/smart_reshape/reshape_to_1D.hpp>
|
||||
#include <transformations/smart_reshape/smart_reshape.hpp>
|
||||
#include <transformations/smart_reshape/strided_slice_squeeze.hpp>
|
||||
@@ -29,6 +30,8 @@ bool ngraph::pass::SmartReshape::run_on_model(const std::shared_ptr<ngraph::Func
|
||||
static_manager.register_pass<ngraph::pass::ReshapeTo1D>();
|
||||
static_manager.register_pass<ngraph::pass::TransposeMatMul>();
|
||||
static_manager.register_pass<ngraph::pass::BroadcastConstRangeReplacement>();
|
||||
static_manager.register_pass<ngraph::pass::LSTMStatesBroadcast>();
|
||||
static_manager.register_pass<ngraph::pass::ReshapeSinkingMatMul>();
|
||||
static_manager.run_passes(f);
|
||||
|
||||
ngraph::pass::Manager dynamic_manager;
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <openvino/core/model.hpp>
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
template<class T>
|
||||
std::shared_ptr<ov::opset9::Constant> create_constant(const std::vector<T>& data, const ov::element::Type_t et = ov::element::i64, bool scalar = false) {
|
||||
ov::Shape shape = scalar ? ov::Shape{} : ov::Shape{data.size()};
|
||||
return ov::opset9::Constant::create(et, shape, data);
|
||||
}
|
||||
|
||||
// Zero-filled constant of the given element type.
// NOTE(review): `shape.to_shape()` requires a fully static PartialShape and throws
// otherwise — all call sites in this file pass static dimensions; confirm when reusing.
std::shared_ptr<ov::opset9::Constant> create_zero_constant(const ov::element::Type_t et, ov::PartialShape shape) {
    return ov::opset9::Constant::create(et, shape.to_shape(), {0});
}
|
||||
|
||||
// Test parameters for the LSTMStatesBroadcast reshape tests.
struct LSTMStatesAttributes {
    ov::element::Type_t data_et;                     // element type of data and weights
    ov::Dimension data_batch_size, new_batch_size;   // batch before / after model->reshape()
    ov::Dimension input_size, hidden_size;           // LSTM input and hidden dimensions
};
|
||||
|
||||
// Parameterized fixture for the bare-LSTMCell case; see LSTMStatesAttributes.
class LSTMStatesBroadcastTest
    : public testing::WithParamInterface<LSTMStatesAttributes>, public CommonTestUtils::TestsCommon {
};
|
||||
|
||||
// Model with a single LSTMCell whose initial states are batch-1 Constants; checks that
// reshaping the model to a new batch size does not throw (LSTMStatesBroadcast must make
// the constant states broadcast-able by batch).
TEST_P(LSTMStatesBroadcastTest, BareLSTM) {
    auto p = GetParam();

    std::shared_ptr<ov::Model> model(nullptr);
    {
        auto parameter = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{p.data_batch_size, p.input_size});
        // initial states deliberately created with batch == 1
        auto initial_hidden_state = create_zero_constant(p.data_et, {1, p.hidden_size});
        auto initial_cell_state = create_zero_constant(p.data_et, {1, p.hidden_size});
        auto W = create_zero_constant(p.data_et, {p.hidden_size * 4, p.input_size});
        auto R = create_zero_constant(p.data_et, {p.hidden_size * 4, p.hidden_size});

        auto cell = std::make_shared<ov::opset9::LSTMCell>(
            parameter, initial_hidden_state, initial_cell_state, W, R, static_cast<size_t>(p.hidden_size.get_length()));

        model = std::make_shared<ov::Model>(ov::NodeVector{cell}, ov::ParameterVector{parameter});
    }
    ASSERT_NO_THROW(model->reshape(ov::PartialShape{p.new_batch_size, p.input_size}));
}
|
||||
|
||||
|
||||
// Parameterized fixture for the LSTMCell-inside-TensorIterator case; see LSTMStatesAttributes.
class LSTMStatesBroadcastTestWithTI
    : public testing::WithParamInterface<LSTMStatesAttributes>, public CommonTestUtils::TestsCommon {
};
|
||||
|
||||
// LSTMCell wrapped in a TensorIterator with constant batch-1 initial states in the OUTER
// graph; checks that reshaping the model to a new batch size does not throw.
TEST_P(LSTMStatesBroadcastTestWithTI, TI_With_LSTM) {
    auto p = GetParam();

    std::shared_ptr<ov::Model> model(nullptr);
    {
        // Outer graph: data input plus constant (batch == 1) initial LSTM states.
        auto X = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{p.data_batch_size, 1, p.input_size});
        // NOTE(review): initial states are created as i64 regardless of p.data_et — confirm intended.
        auto H_init = create_zero_constant(ov::element::i64, {1, p.hidden_size});
        auto C_init = create_zero_constant(ov::element::i64, {1, p.hidden_size});

        // TI body parameters: per-iteration slice and the two recurrent states.
        auto Xi = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{1, 1, p.input_size});
        auto H_t = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{1, p.hidden_size});
        auto C_t = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{1, p.hidden_size});

        // Body
        auto squeeze = std::make_shared<ov::opset9::Squeeze>(Xi, create_constant<int64_t>({1}));
        auto W = create_zero_constant(p.data_et, {p.hidden_size * 4, p.input_size});
        auto R = create_zero_constant(p.data_et, {p.hidden_size * 4, p.hidden_size});

        auto lstm_cell = std::make_shared<ov::opset9::LSTMCell>(squeeze, H_t, C_t, W, R, static_cast<size_t>(p.hidden_size.get_length()));
        auto res_1 = std::make_shared<ov::opset9::Result>(lstm_cell->output(0));
        auto unsqueeze = std::make_shared<ov::opset9::Unsqueeze>(lstm_cell->output(0), create_constant<int64_t>({1}));
        auto res_2 = std::make_shared<ov::opset9::Result>(unsqueeze);
        auto res_3 = std::make_shared<ov::opset9::Result>(lstm_cell->output(1));
        auto body = std::make_shared<ov::Model>(ov::OutputVector{res_1, res_2, res_3}, ov::ParameterVector{Xi, H_t, C_t});

        auto tensor_iterator = std::make_shared<ov::opset9::TensorIterator>();
        tensor_iterator->set_body(body);

        tensor_iterator->set_merged_input(C_t, C_init, res_3);
        tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 1);
        tensor_iterator->set_merged_input(H_t, H_init, res_1);

        auto out0 = tensor_iterator->get_iter_value(res_1, -1);
        auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 0);

        auto res_ti_1 = std::make_shared<ov::opset9::Result>(tensor_iterator->output(1));
        auto res_ti_2 = std::make_shared<ov::opset9::Result>(tensor_iterator->output(0));
        model = std::make_shared<ov::Model>(ov::NodeVector{res_ti_1, res_ti_2},
                                            ov::ParameterVector{X});
    }
    // Fix: a bare duplicate `model->reshape(...)` call used to precede this assertion — if
    // reshape threw, the test aborted with an unhandled exception instead of reporting an
    // assertion failure, and on success the work was done twice.
    ASSERT_NO_THROW(model->reshape(ov::PartialShape{p.new_batch_size, 1, p.input_size}));
}
|
||||
|
||||
// {data_et, data_batch_size, new_batch_size, input_size, hidden_size}
// second entry covers a dynamic (-1) original batch dimension
static std::vector<LSTMStatesAttributes> params = {
    LSTMStatesAttributes{ov::element::f32, {1}, {2}, {512}, {256}},
    LSTMStatesAttributes{ov::element::f32, {-1}, {2}, {512}, {256}},
};

INSTANTIATE_TEST_SUITE_P(SmartReshapeTests, LSTMStatesBroadcastTest, ::testing::ValuesIn(params));
INSTANTIATE_TEST_SUITE_P(SmartReshapeTests, LSTMStatesBroadcastTestWithTI, ::testing::ValuesIn(params));
|
||||
@@ -0,0 +1,79 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <openvino/core/model.hpp>
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
template<class T>
|
||||
std::shared_ptr<ov::opset9::Constant> create_constant(const std::vector<T>& data, const ov::element::Type_t et = ov::element::i64, bool scalar = false) {
|
||||
ov::Shape shape = scalar ? ov::Shape{} : ov::Shape{data.size()};
|
||||
return ov::opset9::Constant::create(et, shape, data);
|
||||
}
|
||||
|
||||
// Zero-filled constant of the requested element type and static shape.
std::shared_ptr<ov::opset9::Constant> create_zero_constant(const ov::element::Type_t et, ov::Shape shape) {
    return ov::opset9::Constant::create(et, std::move(shape), {0});
}
|
||||
|
||||
// Test parameters for the ReshapeSinkingMatMul tests.
struct ReshapeSinkingAttributes {
    ov::element::Type_t data_et;               // element type of data and MatMul weight
    ov::PartialShape input_shape;              // original model input shape
    ov::PartialShape new_shape;                // shape passed to model->reshape()
    std::vector<int64_t> output_pattern;       // first Reshape target pattern, e.g. {-1, K}
    std::vector<int64_t> output_pattern_back;  // final Reshape target pattern
    ov::Shape mm_second_input_shape;           // MatMul weight shape
    bool transpose_a, transpose_b;             // MatMul transpose flags
};
|
||||
|
||||
// Parameterized fixture for the MatMul-only sinking case; see ReshapeSinkingAttributes.
class ReshapeSinkingTest
    : public testing::WithParamInterface<ReshapeSinkingAttributes>, public CommonTestUtils::TestsCommon {
};
|
||||
|
||||
// Reshape -> MatMul -> Reshape chain; checks that reshaping the model to a new batch
// does not throw (ReshapeSinkingMatMul must rewrite the Reshape patterns).
TEST_P(ReshapeSinkingTest, ReshapeSinkingOnlyMatMul) {
    auto p = GetParam();

    std::shared_ptr<ov::Model> model(nullptr);
    {
        auto parameter = std::make_shared<ov::opset9::Parameter>(p.data_et, p.input_shape);
        auto reshape = std::make_shared<ov::opset9::Reshape>(parameter, create_constant(p.output_pattern), false);
        auto matmul = std::make_shared<ov::opset9::MatMul>(reshape, create_zero_constant(p.data_et, p.mm_second_input_shape),
                                                           p.transpose_a, p.transpose_b);
        auto reshape_back = std::make_shared<ov::opset9::Reshape>(matmul, create_constant(p.output_pattern_back), false);
        model = std::make_shared<ov::Model>(ov::NodeVector{reshape_back}, ov::ParameterVector{parameter});
    }
    ASSERT_NO_THROW(model->reshape(p.new_shape));
}
|
||||
|
||||
// Parameterized fixture for the MatMul + bias-Add sinking case; see ReshapeSinkingAttributes.
class ReshapeSinkingTestWithAdd
    : public testing::WithParamInterface<ReshapeSinkingAttributes>, public CommonTestUtils::TestsCommon {
};
|
||||
|
||||
// Reshape -> MatMul -> Add(bias) -> Reshape chain; checks that reshaping the model to a
// new batch does not throw.
TEST_P(ReshapeSinkingTestWithAdd, ReshapeSinkingMatMulAdd) {
    auto p = GetParam();

    std::shared_ptr<ov::Model> model(nullptr);
    {
        auto parameter = std::make_shared<ov::opset9::Parameter>(p.data_et, p.input_shape);
        auto reshape = std::make_shared<ov::opset9::Reshape>(parameter, create_constant(p.output_pattern), false);
        auto matmul = std::make_shared<ov::opset9::MatMul>(reshape, create_zero_constant(p.data_et, p.mm_second_input_shape),
                                                           p.transpose_a, p.transpose_b);
        // NOTE(review): bias shape {1, 37} is hard-coded to the O == 37 used by every
        // entry of `params` below — keep in sync if params change.
        auto add = std::make_shared<ov::opset9::Add>(matmul, create_zero_constant(p.data_et, {1, 37}));
        auto reshape_back = std::make_shared<ov::opset9::Reshape>(add, create_constant(p.output_pattern_back), false);
        model = std::make_shared<ov::Model>(ov::NodeVector{reshape_back}, ov::ParameterVector{parameter});
    }
    ASSERT_NO_THROW(model->reshape(p.new_shape));
}
|
||||
|
||||
// {data_et, input_shape, new_shape, output_pattern, output_pattern_back,
//  mm_second_input_shape, transpose_a, transpose_b}; K = 512, O = 37 throughout,
// covering static/dynamic batch, 3-D/4-D inputs, and both transpose_b settings.
static std::vector<ReshapeSinkingAttributes> params = {
    ReshapeSinkingAttributes{ov::element::f32, {10, 30, 512}, {20, 30, 512}, {-1, 512}, {10, 30, 37}, {37, 512}, false, true},
    ReshapeSinkingAttributes{ov::element::f32, {-1, 30, 512}, {20, 30, 512}, {-1, 512}, {10, 30, 37}, {37, 512}, false, true},
    ReshapeSinkingAttributes{ov::element::f32, {1, 3, 4, 512}, {2, 3, 4, 512}, {-1, 512}, {1, 3, 4, 37}, {37, 512}, false, true},
    ReshapeSinkingAttributes{ov::element::f32, {1, 3, 4, 512}, {2, 3, 4, 512}, {-1, 512}, {1, 3, 4, 37}, {512, 37}, false, false},
};

INSTANTIATE_TEST_SUITE_P(SmartReshapeTests, ReshapeSinkingTest, ::testing::ValuesIn(params));
INSTANTIATE_TEST_SUITE_P(SmartReshapeTests, ReshapeSinkingTestWithAdd, ::testing::ValuesIn(params));
|
||||
Reference in New Issue
Block a user