Resolve comments 12736 (#12778)

* Resolve review comments

* Style cleanup and removal of asserts

* Style

* Update src/common/transformations/src/transformations/smart_reshape/lstm_states_broadcast.cpp
Evgenya Stepyreva, 2022-08-29 10:16:49 +04:00, committed by GitHub
parent 7601400d99
commit 79f1e720e7
10 changed files with 156 additions and 172 deletions


@@ -5,16 +5,18 @@
 #pragma once
 
 #include <memory>
-#include <ngraph/pass/graph_rewrite.hpp>
 #include <vector>
 
-namespace ngraph {
+#include "openvino/pass/graph_rewrite.hpp"
+#include "transformations_visibility.hpp"
+
+namespace ov {
 namespace pass {
 
-class NGRAPH_API LSTMStatesBroadcast;
+class TRANSFORMATIONS_API LSTMStatesBroadcast;
 
 }  // namespace pass
-}  // namespace ngraph
+}  // namespace ov
 
 /**
  * @ingroup ie_transformation_common_api
@@ -22,8 +24,8 @@ class NGRAPH_API LSTMStatesBroadcast;
  * we make them broadcast-able by batch
 */
-class ngraph::pass::LSTMStatesBroadcast : public ngraph::pass::FunctionPass {
+class ov::pass::LSTMStatesBroadcast : public ov::pass::ModelPass {
 public:
     OPENVINO_RTTI("LSTMStatesBroadcast", "0");
-    bool run_on_model(const std::shared_ptr<ngraph::Function>& m) override;
+    bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
 };
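For context: the renamed pass is driven through a pass manager exactly as its ngraph predecessor was. A minimal usage sketch under the new namespace (the wrapper function name is illustrative, not part of this commit; the registration call mirrors the MOC and SmartReshape diffs below):

#include <memory>

#include "openvino/pass/manager.hpp"
#include "transformations/smart_reshape/lstm_states_broadcast.hpp"

// Illustrative helper: run only LSTMStatesBroadcast on a model.
void apply_lstm_states_broadcast(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::LSTMStatesBroadcast>();  // new ov::pass namespace
    manager.run_passes(model);
}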


@@ -5,16 +5,18 @@
 #pragma once
 
 #include <memory>
-#include <ngraph/pass/graph_rewrite.hpp>
 #include <vector>
 
-namespace ngraph {
+#include "openvino/pass/graph_rewrite.hpp"
+#include "transformations_visibility.hpp"
+
+namespace ov {
 namespace pass {
 
-class NGRAPH_API ReshapeSinkingMatMul;
+class TRANSFORMATIONS_API ReshapeSinkingMatMul;
 
 }  // namespace pass
-}  // namespace ngraph
+}  // namespace ov
 
 /**
  * @ingroup ie_transformation_common_api
@@ -24,7 +26,7 @@ class NGRAPH_API ReshapeSinkingMatMul;
  * Reshape operators to make batch propagate through freely
 */
-class ngraph::pass::ReshapeSinkingMatMul : public ngraph::pass::MatcherPass {
+class ov::pass::ReshapeSinkingMatMul : public ov::pass::MatcherPass {
 public:
     OPENVINO_RTTI("ReshapeSinkingMatMul", "0");
     ReshapeSinkingMatMul();


@@ -118,10 +118,10 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
     manager.register_pass<ngraph::pass::Validate>();
     if (!m_use_shapes) {  // Approved Smart Reshape
-        manager.register_pass<ngraph::pass::LSTMStatesBroadcast>();
-        manager.register_pass<ngraph::pass::Validate>();
-        manager.register_pass<ngraph::pass::ReshapeSinkingMatMul>();
-        manager.register_pass<ngraph::pass::Validate>();
+        manager.register_pass<ov::pass::LSTMStatesBroadcast>();
+        manager.register_pass<ov::pass::Validate>();
+        manager.register_pass<ov::pass::ReshapeSinkingMatMul>();
+        manager.register_pass<ov::pass::Validate>();
     }
     manager.register_pass<ngraph::pass::ConvertQuantizeDequantize>();


@@ -2,26 +2,23 @@
 // SPDX-License-Identifier: Apache-2.0
 //
+#include "transformations/smart_reshape/lstm_states_broadcast.hpp"
+
 #include <memory>
-#include <ngraph/pass/manager.hpp>
-#include <openvino/op/util/sub_graph_base.hpp>
-#include <openvino/opsets/opset9.hpp>
-#include <transformations/smart_reshape/lstm_states_broadcast.hpp>
-#include <transformations/utils/utils.hpp>
+
 #include "dimension_tracker.hpp"
 #include "itt.hpp"
+#include "openvino/op/util/sub_graph_base.hpp"
+#include "openvino/opsets/opset9.hpp"
+#include "openvino/pass/manager.hpp"
+#include "transformations/utils/utils.hpp"
 
-ov::Input<ov::Node> get_outer_input_of_ti_by_parameter(const std::shared_ptr<ov::opset9::Parameter>& parameter,
-                                                       const std::shared_ptr<ov::opset9::TensorIterator>& ti) {
-    const auto& body = ti->get_body();
-    OPENVINO_ASSERT(body != nullptr, "TI returns invalid body graph ", ti);
+using namespace std;
+using namespace ov::opset9;
+
+ov::Input<ov::Node> get_outer_input_of_ti_by_parameter(const shared_ptr<Parameter>& parameter,
+                                                       const shared_ptr<TensorIterator>& ti) {
     int64_t parameter_index = ti->get_body()->get_parameter_index(parameter);
     OPENVINO_ASSERT(parameter_index >= 0,
                     "LSTMStatesBroadcast encountered unregistered parameter ",
                     parameter,
                     " related to TI body ",
                     ti);
     for (const auto& input_descriptor : ti->get_input_descriptions())
         if (input_descriptor->m_body_parameter_index == parameter_index)
             return ti->input(input_descriptor->m_input_index);
@@ -31,13 +28,11 @@ ov::Input<ov::Node> get_outer_input_of_ti_by_parameter(const std::shared_ptr<ov:
                     parameter);
 }
 
-std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
-    const std::shared_ptr<ov::opset9::TensorIterator>& ti,
-    const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
-    const auto& body = ti->get_body();
-    OPENVINO_ASSERT(body != nullptr, "TI returns invalid body graph ", ti);
+shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(const shared_ptr<TensorIterator>& ti,
+                                                                      const shared_ptr<LSTMCell>& lstm_cell) {
+    const auto& body = ti->get_body();  // body is not nullptr -- we checked earlier
 
-    std::map<ov::opset9::Parameter*, ov::PartialShape> original_shapes;
+    map<Parameter*, ov::PartialShape> original_shapes;
     size_t label = 1;
     // mark all input dimensions with labels and making them dynamic, keeping original shapes
@@ -46,9 +41,7 @@ std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
         original_shapes[parameter.get()] = pshape;
         if (pshape.rank().is_dynamic())
             continue;
-        for (ngraph::Dimension& n : pshape) {
-            OPENVINO_ASSERT(ov::DimensionTracker::get_label(n) == 0,
-                            "LSTMStatesBroadcast encountered TI with previously tracked dimensions");
+        for (ov::Dimension& n : pshape) {
             n = ov::Dimension::dynamic();
             ov::DimensionTracker::set_label(n, label++);
         }
@@ -68,7 +61,7 @@ std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
     }
 
     // batch label was tracked -- finding parameter that delivered it
-    std::shared_ptr<ov::opset9::Parameter> batch_delivering_parameter;
+    shared_ptr<Parameter> batch_delivering_parameter;
     size_t index_of_batch_dim = 0;
     size_t batch_label = ov::DimensionTracker::get_label(lstm_cell->get_input_partial_shape(0)[0]);
@@ -80,8 +73,11 @@ std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
             if (ov::DimensionTracker::get_label(pshape[i]) == batch_label) {
                 batch_delivering_parameter = parameter;
                 index_of_batch_dim = i;
+                break;
             }
         }
+        if (index_of_batch_dim != 0 && batch_delivering_parameter != nullptr)
+            break;
     }
     for (auto& item : original_shapes)
         item.first->set_partial_shape(item.second);
@@ -91,89 +87,83 @@ std::shared_ptr<ov::Node> deduce_outer_source_of_batch_for_inner_lstm_cell(
         return nullptr;
     const auto& batched_source = get_outer_input_of_ti_by_parameter(batch_delivering_parameter, ti);
-    const auto& batched_shape = std::make_shared<ov::opset9::ShapeOf>(batched_source.get_source_output());
-    const auto& batch = std::make_shared<ov::opset9::Gather>(
-        batched_shape,
-        ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {index_of_batch_dim}),
-        ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}));
+    const auto& batched_shape = make_shared<ShapeOf>(batched_source.get_source_output());
+    const auto& batch = make_shared<Gather>(batched_shape,
+                                            Constant::create(ov::element::i64, ov::Shape{1}, {index_of_batch_dim}),
+                                            Constant::create(ov::element::i64, ov::Shape{}, {0}));
     return batch;
 }
 
-bool broadcast_state_by_batch(ov::Input<ov::Node> input, const std::shared_ptr<ov::Node>& batch_delivering_node) {
-    auto constant_state =
-        std::dynamic_pointer_cast<ov::opset9::Constant>(input.get_source_output().get_node_shared_ptr());
+bool broadcast_state_by_batch(ov::Input<ov::Node> input, const shared_ptr<ov::Node>& batch_delivering_node) {
+    auto constant_state = dynamic_pointer_cast<Constant>(input.get_source_output().get_node_shared_ptr());
     if (constant_state == nullptr)
         return false;
     const auto& constant_shape = constant_state->get_shape();
     OPENVINO_ASSERT(constant_shape.size() == 2, "State has unexpected shape ", constant_shape);
     if (constant_shape[0] != 1)
         // we only expect to broadcast LSTM states prepared for batch 1 -- no tiling of batch > 1 will be done
         return false;
     const auto& constant_copy = constant_state->copy_with_new_inputs({});
-    const auto& broadcast_by_batch = std::make_shared<ov::opset9::Broadcast>(
+    const auto& broadcast_by_batch = make_shared<Broadcast>(
         constant_copy,
-        std::make_shared<ov::opset9::Concat>(
-            ngraph::NodeVector{batch_delivering_node,
-                               ngraph::op::util::make_try_fold<ov::opset9::Gather>(
-                                   ngraph::op::util::make_try_fold<ov::opset9::ShapeOf>(constant_copy),
-                                   ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {1}),
-                                   ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}))},
-            0));
+        make_shared<Concat>(ngraph::NodeVector{batch_delivering_node,
+                                               ngraph::op::util::make_try_fold<Gather>(
                                                   ngraph::op::util::make_try_fold<ShapeOf>(constant_copy),
                                                   Constant::create(ov::element::i64, ov::Shape{1}, {1}),
                                                   Constant::create(ov::element::i64, ov::Shape{}, {0}))},
                            0));
     input.replace_source_output(broadcast_by_batch->output(0));
     return true;
 }
 
-bool relax_batch_for_initial_states_of_lstm_in_ti(const std::shared_ptr<ov::opset9::TensorIterator>& ti,
-                                                  const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
+bool relax_batch_for_initial_states_of_lstm_in_ti(const shared_ptr<TensorIterator>& ti,
+                                                  const shared_ptr<LSTMCell>& lstm_cell) {
     bool rewritten = false;
     auto batch_delivering_node = deduce_outer_source_of_batch_for_inner_lstm_cell(ti, lstm_cell);
     if (batch_delivering_node == nullptr)
         return rewritten;
-    if (auto init_hidden_state =
-            std::dynamic_pointer_cast<ov::opset9::Parameter>(lstm_cell->get_input_node_shared_ptr(1))) {
+    if (auto init_hidden_state = dynamic_pointer_cast<Parameter>(lstm_cell->get_input_node_shared_ptr(1))) {
         auto outer_init_hidden_state_input = get_outer_input_of_ti_by_parameter(init_hidden_state, ti);
         rewritten |= broadcast_state_by_batch(outer_init_hidden_state_input, batch_delivering_node);
     }
-    if (auto init_cell_state =
-            std::dynamic_pointer_cast<ov::opset9::Parameter>(lstm_cell->get_input_node_shared_ptr(2))) {
+    if (auto init_cell_state = dynamic_pointer_cast<Parameter>(lstm_cell->get_input_node_shared_ptr(2))) {
         auto outer_init_cell_state_input = get_outer_input_of_ti_by_parameter(init_cell_state, ti);
         rewritten |= broadcast_state_by_batch(outer_init_cell_state_input, batch_delivering_node);
     }
     return rewritten;
 }
 
-bool relax_batch_for_initial_states_of_lstm(const std::shared_ptr<ov::opset9::LSTMCell>& lstm_cell) {
+bool relax_batch_for_initial_states_of_lstm(const shared_ptr<LSTMCell>& lstm_cell) {
     bool rewritten = false;
-    const auto& batched_shape = std::make_shared<ov::opset9::ShapeOf>(lstm_cell->get_input_source_output(0));
-    const auto& batch_delivering_node =
-        std::make_shared<ov::opset9::Gather>(batched_shape,
-                                             ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {0}),
-                                             ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}));
+    const auto& batched_shape = make_shared<ShapeOf>(lstm_cell->get_input_source_output(0));
+    const auto& batch_delivering_node = make_shared<Gather>(batched_shape,
+                                                            Constant::create(ov::element::i64, ov::Shape{1}, {0}),
+                                                            Constant::create(ov::element::i64, ov::Shape{}, {0}));
     rewritten |= broadcast_state_by_batch(lstm_cell->input(1), batch_delivering_node);
     rewritten |= broadcast_state_by_batch(lstm_cell->input(2), batch_delivering_node);
     return rewritten;
 }
 
-bool ngraph::pass::LSTMStatesBroadcast::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
+bool ov::pass::LSTMStatesBroadcast::run_on_model(const shared_ptr<ov::Model>& f) {
     RUN_ON_FUNCTION_SCOPE(LSTMStatesBroadcast);
     bool rewritten = false;
     for (auto& node : f->get_ordered_ops()) {
         // Recursively apply transformation for sub-graph based operations
-        if (const auto& sub_graph_node = std::dynamic_pointer_cast<ov::op::util::SubGraphOp>(node))
+        if (const auto& sub_graph_node = dynamic_pointer_cast<ov::op::util::SubGraphOp>(node))
             if (const auto& sub_graph = sub_graph_node->get_function())
                 rewritten |= run_on_model(sub_graph);
         // Case without TI (LSTMCell and Constant are in the same ov::Model)
-        if (const auto& lstm_cell = std::dynamic_pointer_cast<ov::opset9::LSTMCell>(node))
+        if (const auto& lstm_cell = dynamic_pointer_cast<LSTMCell>(node))
            rewritten |= relax_batch_for_initial_states_of_lstm(lstm_cell);
         // Case with TI (LSTMCell and Constant are in different ov::Model objects)
-        if (auto ti = std::dynamic_pointer_cast<ov::opset9::TensorIterator>(node)) {
+        if (auto ti = dynamic_pointer_cast<TensorIterator>(node)) {
             auto body = ti->get_body();
-            OPENVINO_ASSERT(body, "TensorIterator must have body network");
+            if (body == nullptr)
+                continue;
             for (const auto& body_node : body->get_ordered_ops())
-                if (const auto& lstm_cell = std::dynamic_pointer_cast<ov::opset9::LSTMCell>(body_node))
+                if (const auto& lstm_cell = dynamic_pointer_cast<LSTMCell>(body_node))
                    rewritten |= relax_batch_for_initial_states_of_lstm_in_ti(ti, lstm_cell);
         }
     }
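For context: the batch deduction above works by labeling every dimension of the TI body inputs, re-running shape inference, and then checking which label ends up on the LSTMCell's batch dimension; the Parameter that delivered that label is the batch source. A standalone sketch of the labeling primitive itself (not part of this commit; these are the same DimensionTracker calls used in the hunks above):

#include "dimension_tracker.hpp"
#include "openvino/core/partial_shape.hpp"

void dimension_label_sketch() {
    ov::PartialShape shape{ov::Dimension::dynamic(), 512};
    ov::DimensionTracker::set_label(shape[0], 1);  // mark dim 0 with label 1
    // after shape propagation, the label can be read back from downstream shapes
    size_t label = ov::DimensionTracker::get_label(shape[0]);  // == 1
    (void)label;
}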


@@ -2,16 +2,18 @@
 // SPDX-License-Identifier: Apache-2.0
 //
-#include <ngraph/opsets/opset9.hpp>
-#include <ngraph/pattern/matcher.hpp>
-#include <ngraph/pattern/op/or.hpp>
-#include <ngraph/pattern/op/wrap_type.hpp>
-#include <ngraph/rt_info.hpp>
-#include <transformations/smart_reshape/reshape_sinking.hpp>
+#include "transformations/smart_reshape/reshape_sinking.hpp"
+
 #include "itt.hpp"
+#include "openvino/opsets/opset9.hpp"
+#include "openvino/pass/pattern/matcher.hpp"
+#include "openvino/pass/pattern/op/or.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
 
-ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
+using namespace std;
+using namespace ov::opset9;
+
+ov::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
     MATCHER_SCOPE(ReshapeSinkingMatMul);
     /* Original graph:               Transformed graph:
      *
@@ -25,22 +27,20 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
      *    |  shape=[1, S, O]          |  shape=[B, S, O]
      */
     auto any_input = pattern::any_input(pattern::has_static_rank());
-    auto reshape_label = ngraph::pattern::wrap_type<opset9::Reshape>(
-        {pattern::any_input(), ngraph::pattern::wrap_type<opset9::Constant>()},
-        pattern::rank_equals(2));
+    auto reshape_label =
+        ov::pass::pattern::wrap_type<Reshape>({pattern::any_input(), ov::pass::pattern::wrap_type<Constant>()},
+                                              pattern::rank_equals(2));
-    auto matmul_label =
-        ngraph::pattern::wrap_type<opset9::MatMul>({reshape_label, ngraph::pattern::wrap_type<opset9::Constant>()},
-                                                   pattern::rank_equals(2));
-    auto add_label =
-        ngraph::pattern::wrap_type<opset9::Add>({matmul_label, ngraph::pattern::wrap_type<opset9::Constant>()},
-                                                pattern::rank_equals(2));
+    auto matmul_label = ov::pass::pattern::wrap_type<MatMul>({reshape_label, ov::pass::pattern::wrap_type<Constant>()},
+                                                             pattern::rank_equals(2));
+    auto add_label = ov::pass::pattern::wrap_type<Add>({matmul_label, ov::pass::pattern::wrap_type<Constant>()},
+                                                       pattern::rank_equals(2));
-    auto matmul_or_matmul_add_label = std::make_shared<pattern::op::Or>(OutputVector{add_label, matmul_label});
+    auto matmul_or_matmul_add_label = make_shared<pattern::op::Or>(OutputVector{add_label, matmul_label});
-    auto reshape_1_label = ngraph::pattern::wrap_type<opset9::Reshape>(
-        {matmul_or_matmul_add_label, ngraph::pattern::wrap_type<opset9::Constant>()},
-        pattern::has_static_rank());
+    auto reshape_1_label =
+        ov::pass::pattern::wrap_type<Reshape>({matmul_or_matmul_add_label, ov::pass::pattern::wrap_type<Constant>()},
+                                              pattern::has_static_rank());
 
     matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
         auto pattern_to_node = m.get_pattern_map();
@@ -48,7 +48,7 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
         // check first Reshape eligibility: has a constant output pattern in a form of [-1, K]
         auto reshape = pattern_to_node.at(reshape_label);
         int64_t K = -1;
-        if (const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(reshape->get_input_node_shared_ptr(1))) {
+        if (const auto& constant = dynamic_pointer_cast<Constant>(reshape->get_input_node_shared_ptr(1))) {
             auto output_pattern_vector = constant->cast_vector<int64_t>();
             if (output_pattern_vector.size() != 2 || output_pattern_vector[0] != -1)
                 return false;
@@ -66,11 +66,11 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
             return false;
 
         // check matmul eligibility: has constant second input in a form of [O, K]
-        auto matmul = std::dynamic_pointer_cast<opset9::MatMul>(pattern_to_node.at(matmul_label));
+        auto matmul = dynamic_pointer_cast<MatMul>(pattern_to_node.at(matmul_label));
         if (!matmul || matmul->get_transpose_a())
             return false;
         int64_t O = -1;
-        if (const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(matmul->get_input_node_shared_ptr(1))) {
+        if (const auto& constant = dynamic_pointer_cast<Constant>(matmul->get_input_node_shared_ptr(1))) {
             const auto& constant_shape = constant->get_shape();
             if (constant_shape.size() != 2)
                 return false;
@@ -86,10 +86,10 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
         // check add eligibility if present: has constant second input that has a form of [1, 1, ..., O] (doesn't
         // broadcast first input)
         if (pattern_to_node.count(add_label)) {
-            auto add = std::dynamic_pointer_cast<opset9::Add>(pattern_to_node.at(add_label));
+            auto add = dynamic_pointer_cast<Add>(pattern_to_node.at(add_label));
             if (!add || add->get_autob() != ov::op::AutoBroadcastType::NUMPY)
                 return false;
-            const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(add->get_input_node_shared_ptr(1));
+            const auto& constant = dynamic_pointer_cast<Constant>(add->get_input_node_shared_ptr(1));
             if (!constant)
                 return false;
             const auto& constant_shape = constant->get_shape();
@@ -106,19 +106,17 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
         // input_shape of the pattern except for the batch and last dimension
         auto reshape_1 = m.get_match_root();
-        const auto& constant = std::dynamic_pointer_cast<opset9::Constant>(reshape_1->get_input_node_shared_ptr(1));
+        const auto& constant = dynamic_pointer_cast<Constant>(reshape_1->get_input_node_shared_ptr(1));
         if (constant == nullptr)
             return false;
         auto output_pattern = constant->cast_vector<int64_t>();
-        if (!std::all_of(output_pattern.begin(), output_pattern.end(), [](const int64_t& i) {
+        if (output_pattern.size() != input_rank)
+            return false;
+        if (!all_of(output_pattern.begin(), output_pattern.end(), [](const int64_t& i) {
                 return i > 0;
             }))
             return false;
-        if (output_pattern.size() != input_rank)
-            return false;
-        for (size_t i = 0; i < input_rank; ++i) {
-            if (i == 0)
-                continue;
+        for (size_t i = 1; i < input_rank; ++i) {
             if (i + 1 == input_rank) {
                 if (output_pattern[i] != O)
                     return false;
@@ -129,28 +127,26 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() {
                 return false;
         }
 
+        auto first_reshape = dynamic_pointer_cast<Reshape>(reshape);
+        auto second_reshape = dynamic_pointer_cast<Reshape>(reshape_1);
+        if (!first_reshape || !second_reshape)
+            return false;
+
         // this is the pattern we are looking for! performing the transformation
-        auto first_reshape = std::dynamic_pointer_cast<opset9::Reshape>(reshape);
-        if (!first_reshape)
-            return false;
         first_reshape->set_special_zero(true);
-        auto second_reshape = std::dynamic_pointer_cast<opset9::Reshape>(reshape_1);
-        if (!second_reshape)
-            return false;
         second_reshape->set_special_zero(true);
 
-        std::vector<int64_t> output_pattern_vector(input_rank - 1, 0);
+        vector<int64_t> output_pattern_vector(input_rank - 1, 0);
         output_pattern_vector.push_back(K);
-        auto new_reshape_constant =
-            opset9::Constant::create(ov::element::i64, Shape{input_rank}, output_pattern_vector);
+        auto new_reshape_constant = Constant::create(ov::element::i64, Shape{input_rank}, output_pattern_vector);
         reshape->input(1).replace_source_output(new_reshape_constant->output(0));
 
         output_pattern[0] = 0;
-        auto new_reshape_1_constant = opset9::Constant::create(ov::element::i64, Shape{input_rank}, output_pattern);
+        auto new_reshape_1_constant = Constant::create(ov::element::i64, Shape{input_rank}, output_pattern);
         reshape_1->input(1).replace_source_output(new_reshape_1_constant->output(0));
         return true;
     };
 
-    auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_1_label, matcher_name);
+    auto m = make_shared<ov::pass::pattern::Matcher>(reshape_1_label, matcher_name);
     register_matcher(m, callback);
 }
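To make the matched pattern concrete: the matcher above looks for a 2D Reshape with output pattern [-1, K] feeding a MatMul with a constant second input (optionally followed by an Add with a non-broadcasting constant), ending in a final Reshape of static rank. A sketch of a graph it is meant to hit (shapes, names, and the transpose_b choice are illustrative assumptions, picked to satisfy the [O, K] comment in the hunk above, not the matcher's exact eligibility set):

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/opsets/opset9.hpp"

std::shared_ptr<ov::Model> reshape_matmul_reshape_sketch() {
    using namespace ov::opset9;
    // input flattened to [-1, K] with K = 256
    auto data = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{1, 100, 256});
    auto reshape = std::make_shared<Reshape>(
        data, Constant::create(ov::element::i64, ov::Shape{2}, std::vector<int64_t>{-1, 256}), false);
    // constant weight in [O, K] form, O = 512, consumed with transpose_b
    auto weight = Constant::create(ov::element::f32, ov::Shape{512, 256}, {0});
    auto matmul = std::make_shared<MatMul>(reshape, weight, false, true);
    // restore the leading dimensions: [1, 100, 512]
    auto reshape_1 = std::make_shared<Reshape>(
        matmul, Constant::create(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{1, 100, 512}), false);
    return std::make_shared<ov::Model>(ov::NodeVector{reshape_1}, ov::ParameterVector{data});
}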


@@ -30,8 +30,8 @@ bool ngraph::pass::SmartReshape::run_on_model(const std::shared_ptr<ngraph::Func
     static_manager.register_pass<ngraph::pass::ReshapeTo1D>();
     static_manager.register_pass<ngraph::pass::TransposeMatMul>();
     static_manager.register_pass<ngraph::pass::BroadcastConstRangeReplacement>();
-    static_manager.register_pass<ngraph::pass::LSTMStatesBroadcast>();
-    static_manager.register_pass<ngraph::pass::ReshapeSinkingMatMul>();
+    static_manager.register_pass<ov::pass::LSTMStatesBroadcast>();
+    static_manager.register_pass<ov::pass::ReshapeSinkingMatMul>();
     static_manager.run_passes(f);
 
     ngraph::pass::Manager dynamic_manager;


@@ -4,20 +4,13 @@
 #include <gtest/gtest.h>
 
-#include <openvino/core/model.hpp>
-#include <openvino/opsets/opset9.hpp>
+#include "openvino/core/model.hpp"
+#include "openvino/opsets/opset9.hpp"
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
-template<class T>
-std::shared_ptr<ov::opset9::Constant> create_constant(const std::vector<T>& data, const ov::element::Type_t et = ov::element::i64, bool scalar = false) {
-    ov::Shape shape = scalar ? ov::Shape{} : ov::Shape{data.size()};
-    return ov::opset9::Constant::create(et, shape, data);
-}
-
-std::shared_ptr<ov::opset9::Constant> create_zero_constant(const ov::element::Type_t et, ov::PartialShape shape) {
-    return ov::opset9::Constant::create(et, shape.to_shape(), {0});
-}
+using namespace std;
+using namespace ov::opset9;
 
 struct LSTMStatesAttributes {
     ov::element::Type_t data_et;
@@ -32,18 +25,18 @@ class LSTMStatesBroadcastTest
 TEST_P(LSTMStatesBroadcastTest, BareLSTM) {
     auto p = GetParam();
 
-    std::shared_ptr<ov::Model> model(nullptr);
+    shared_ptr<ov::Model> model(nullptr);
     {
-        auto parameter = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{p.data_batch_size, p.input_size});
-        auto initial_hidden_state = create_zero_constant(p.data_et, {1, p.hidden_size});
-        auto initial_cell_state = create_zero_constant(p.data_et, {1, p.hidden_size});
-        auto W = create_zero_constant(p.data_et, {p.hidden_size * 4, p.input_size});
-        auto R = create_zero_constant(p.data_et, {p.hidden_size * 4, p.hidden_size});
+        auto parameter = make_shared<Parameter>(p.data_et, ov::PartialShape{p.data_batch_size, p.input_size});
+        auto initial_hidden_state = create_zero_constant(p.data_et, ov::PartialShape{1, p.hidden_size}.to_shape());
+        auto initial_cell_state = create_zero_constant(p.data_et, ov::PartialShape{1, p.hidden_size}.to_shape());
+        auto W = create_zero_constant(p.data_et, ov::PartialShape{p.hidden_size * 4, p.input_size}.to_shape());
+        auto R = create_zero_constant(p.data_et, ov::PartialShape{p.hidden_size * 4, p.hidden_size}.to_shape());
 
-        auto cell = std::make_shared<ov::opset9::LSTMCell>(
+        auto cell = make_shared<LSTMCell>(
             parameter, initial_hidden_state, initial_cell_state, W, R, static_cast<size_t>(p.hidden_size.get_length()));
-        model = std::make_shared<ov::Model>(ov::NodeVector{cell}, ov::ParameterVector{parameter});
+        model = make_shared<ov::Model>(ov::NodeVector{cell}, ov::ParameterVector{parameter});
     }
     ASSERT_NO_THROW(model->reshape(ov::PartialShape{p.new_batch_size, p.input_size}));
 }
@@ -56,29 +49,29 @@ class LSTMStatesBroadcastTestWithTI
 TEST_P(LSTMStatesBroadcastTestWithTI, TI_With_LSTM) {
     auto p = GetParam();
 
-    std::shared_ptr<ov::Model> model(nullptr);
+    shared_ptr<ov::Model> model(nullptr);
     {
-        auto X = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{p.data_batch_size, 1, p.input_size});
-        auto H_init = create_zero_constant(ov::element::i64, {1, p.hidden_size});
-        auto C_init = create_zero_constant(ov::element::i64, {1, p.hidden_size});
+        auto X = make_shared<Parameter>(p.data_et, ov::PartialShape{p.data_batch_size, 1, p.input_size});
+        auto H_init = create_zero_constant(ov::element::i64, ov::PartialShape{1, p.hidden_size}.to_shape());
+        auto C_init = create_zero_constant(ov::element::i64, ov::PartialShape{1, p.hidden_size}.to_shape());
 
-        auto Xi = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{1, 1, p.input_size});
-        auto H_t = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{1, p.hidden_size});
-        auto C_t = std::make_shared<ov::opset9::Parameter>(p.data_et, ov::PartialShape{1, p.hidden_size});
+        auto Xi = make_shared<Parameter>(p.data_et, ov::PartialShape{1, 1, p.input_size});
+        auto H_t = make_shared<Parameter>(p.data_et, ov::PartialShape{1, p.hidden_size});
+        auto C_t = make_shared<Parameter>(p.data_et, ov::PartialShape{1, p.hidden_size});
 
         // Body
-        auto squeeze = std::make_shared<ov::opset9::Squeeze>(Xi, create_constant<int64_t>({1}));
-        auto W = create_zero_constant(p.data_et, {p.hidden_size * 4, p.input_size});
-        auto R = create_zero_constant(p.data_et, {p.hidden_size * 4, p.hidden_size});
+        auto squeeze = make_shared<Squeeze>(Xi, create_constant<int64_t>({1}));
+        auto W = create_zero_constant(p.data_et, ov::PartialShape{p.hidden_size * 4, p.input_size}.to_shape());
+        auto R = create_zero_constant(p.data_et, ov::PartialShape{p.hidden_size * 4, p.hidden_size}.to_shape());
 
-        auto lstm_cell = std::make_shared<ov::opset9::LSTMCell>(squeeze, H_t, C_t, W, R, static_cast<size_t>(p.hidden_size.get_length()));
-        auto res_1 = std::make_shared<ov::opset9::Result>(lstm_cell->output(0));
-        auto unsqueeze = std::make_shared<ov::opset9::Unsqueeze>(lstm_cell->output(0), create_constant<int64_t>({1}));
-        auto res_2 = std::make_shared<ov::opset9::Result>(unsqueeze);
-        auto res_3 = std::make_shared<ov::opset9::Result>(lstm_cell->output(1));
-        auto body = std::make_shared<ov::Model>(ov::OutputVector{res_1, res_2, res_3}, ov::ParameterVector{Xi, H_t, C_t});
+        auto lstm_cell = make_shared<LSTMCell>(squeeze, H_t, C_t, W, R, static_cast<size_t>(p.hidden_size.get_length()));
+        auto res_1 = make_shared<Result>(lstm_cell->output(0));
+        auto unsqueeze = make_shared<Unsqueeze>(lstm_cell->output(0), create_constant<int64_t>({1}));
+        auto res_2 = make_shared<Result>(unsqueeze);
+        auto res_3 = make_shared<Result>(lstm_cell->output(1));
+        auto body = make_shared<ov::Model>(ov::OutputVector{res_1, res_2, res_3}, ov::ParameterVector{Xi, H_t, C_t});
 
-        auto tensor_iterator = std::make_shared<ov::opset9::TensorIterator>();
+        auto tensor_iterator = make_shared<TensorIterator>();
         tensor_iterator->set_body(body);
 
         tensor_iterator->set_merged_input(C_t, C_init, res_3);
@@ -88,16 +81,15 @@ TEST_P(LSTMStatesBroadcastTestWithTI, TI_With_LSTM) {
         auto out0 = tensor_iterator->get_iter_value(res_1, -1);
         auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 0);
 
-        auto res_ti_1 = std::make_shared<ov::opset9::Result>(tensor_iterator->output(1));
-        auto res_ti_2 = std::make_shared<ov::opset9::Result>(tensor_iterator->output(0));
-        model = std::make_shared<ov::Model>(ov::NodeVector{res_ti_1, res_ti_2},
+        auto res_ti_1 = make_shared<Result>(tensor_iterator->output(1));
+        auto res_ti_2 = make_shared<Result>(tensor_iterator->output(0));
+        model = make_shared<ov::Model>(ov::NodeVector{res_ti_1, res_ti_2},
                                        ov::ParameterVector{X});
     }
-    model->reshape(ov::PartialShape{p.new_batch_size, 1, p.input_size});
+    ASSERT_NO_THROW(model->reshape(ov::PartialShape{p.new_batch_size, 1, p.input_size}));
 }
 
-static std::vector<LSTMStatesAttributes> params = {
+static vector<LSTMStatesAttributes> params = {
     LSTMStatesAttributes{ov::element::f32, {1}, {2}, {512}, {256}},
     LSTMStatesAttributes{ov::element::f32, {-1}, {2}, {512}, {256}},
 };


@@ -4,21 +4,11 @@
 #include <gtest/gtest.h>
 
-#include <openvino/core/model.hpp>
-#include <openvino/opsets/opset9.hpp>
+#include "openvino/core/model.hpp"
+#include "openvino/opsets/opset9.hpp"
 
 #include "common_test_utils/ngraph_test_utils.hpp"
 
-template<class T>
-std::shared_ptr<ov::opset9::Constant> create_constant(const std::vector<T>& data, const ov::element::Type_t et = ov::element::i64, bool scalar = false) {
-    ov::Shape shape = scalar ? ov::Shape{} : ov::Shape{data.size()};
-    return ov::opset9::Constant::create(et, shape, data);
-}
-
-std::shared_ptr<ov::opset9::Constant> create_zero_constant(const ov::element::Type_t et, ov::Shape shape) {
-    return ov::opset9::Constant::create(et, shape, {0});
-}
-
 struct ReshapeSinkingAttributes {
     ov::element::Type_t data_et;
     ov::PartialShape input_shape;


@@ -64,3 +64,7 @@ void check_unique_names(std::shared_ptr<ngraph::Function> f, const std::shared_p
     manager.register_pass<ngraph::pass::CheckUniqueNames>(unh, true);
     manager.run_passes(f);
 }
+
+std::shared_ptr<ov::opset8::Constant> create_zero_constant(const ov::element::Type_t& et, const ov::Shape& shape) {
+    return ov::opset8::Constant::create(et, shape, {0});
+}


@@ -69,3 +69,11 @@ size_t count_ops_of_type(const std::shared_ptr<ngraph::Function>& f) {
     return count;
 }
+
+template<class T>
+std::shared_ptr<ov::opset8::Constant> create_constant(const std::vector<T>& data, const ov::element::Type_t et = ov::element::i64, bool scalar = false) {
+    ov::Shape shape = scalar ? ov::Shape{} : ov::Shape{data.size()};
+    return ov::opset8::Constant::create(et, shape, data);
+}
+
+std::shared_ptr<ov::opset8::Constant> create_zero_constant(const ov::element::Type_t& et, const ov::Shape& shape);
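A possible call site for these now-shared helpers, matching how the test diffs above use them (a sketch; the concrete values are illustrative):

// i64 constant of shape [1] holding {1}, e.g. a squeeze/unsqueeze axis
auto axis = create_constant<int64_t>({1});
// scalar i64 constant
auto scalar_axis = create_constant<int64_t>({0}, ov::element::i64, true);
// zero-filled f32 LSTM weight of shape [4 * hidden_size, input_size]
auto W = create_zero_constant(ov::element::f32, ov::Shape{4 * 256, 512});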