diff --git a/src/common/transformations/include/transformations/smart_reshape/lstm_states_broadcast.hpp b/src/common/transformations/include/transformations/smart_reshape/lstm_states_broadcast.hpp index 2e8e2345726..037a5f06c72 100644 --- a/src/common/transformations/include/transformations/smart_reshape/lstm_states_broadcast.hpp +++ b/src/common/transformations/include/transformations/smart_reshape/lstm_states_broadcast.hpp @@ -5,16 +5,18 @@ #pragma once #include -#include #include -namespace ngraph { +#include "openvino/pass/graph_rewrite.hpp" +#include "transformations_visibility.hpp" + +namespace ov { namespace pass { -class NGRAPH_API LSTMStatesBroadcast; +class TRANSFORMATIONS_API LSTMStatesBroadcast; } // namespace pass -} // namespace ngraph +} // namespace ov /** * @ingroup ie_transformation_common_api @@ -22,8 +24,8 @@ class NGRAPH_API LSTMStatesBroadcast; * we make them broadcast-able by batch */ -class ngraph::pass::LSTMStatesBroadcast : public ngraph::pass::FunctionPass { +class ov::pass::LSTMStatesBroadcast : public ov::pass::ModelPass { public: OPENVINO_RTTI("LSTMStatesBroadcast", "0"); - bool run_on_model(const std::shared_ptr& m) override; + bool run_on_model(const std::shared_ptr& m) override; }; diff --git a/src/common/transformations/include/transformations/smart_reshape/reshape_sinking.hpp b/src/common/transformations/include/transformations/smart_reshape/reshape_sinking.hpp index a755e337027..b8b7eabc1df 100644 --- a/src/common/transformations/include/transformations/smart_reshape/reshape_sinking.hpp +++ b/src/common/transformations/include/transformations/smart_reshape/reshape_sinking.hpp @@ -5,16 +5,18 @@ #pragma once #include -#include #include -namespace ngraph { +#include "openvino/pass/graph_rewrite.hpp" +#include "transformations_visibility.hpp" + +namespace ov { namespace pass { -class NGRAPH_API ReshapeSinkingMatMul; +class TRANSFORMATIONS_API ReshapeSinkingMatMul; } // namespace pass -} // namespace ngraph +} // namespace ov /** * @ingroup ie_transformation_common_api @@ -24,7 +26,7 @@ class NGRAPH_API ReshapeSinkingMatMul; * Reshape operators to make batch propagate through freely */ -class ngraph::pass::ReshapeSinkingMatMul : public ngraph::pass::MatcherPass { +class ov::pass::ReshapeSinkingMatMul : public ov::pass::MatcherPass { public: OPENVINO_RTTI("ReshapeSinkingMatMul", "0"); ReshapeSinkingMatMul(); diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp index e812414e36c..a34d2aa5cdc 100644 --- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp @@ -118,10 +118,10 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr(); if (!m_use_shapes) { // Approved Smart Reshape - manager.register_pass(); - manager.register_pass(); - manager.register_pass(); - manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); } manager.register_pass(); diff --git a/src/common/transformations/src/transformations/smart_reshape/lstm_states_broadcast.cpp b/src/common/transformations/src/transformations/smart_reshape/lstm_states_broadcast.cpp index ab2e026995a..017aa22362b 100644 --- a/src/common/transformations/src/transformations/smart_reshape/lstm_states_broadcast.cpp +++ b/src/common/transformations/src/transformations/smart_reshape/lstm_states_broadcast.cpp @@ -2,26 +2,23 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "transformations/smart_reshape/lstm_states_broadcast.hpp" + #include -#include -#include -#include -#include -#include #include "dimension_tracker.hpp" #include "itt.hpp" +#include "openvino/op/util/sub_graph_base.hpp" +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/manager.hpp" +#include "transformations/utils/utils.hpp" -ov::Input get_outer_input_of_ti_by_parameter(const std::shared_ptr& parameter, - const std::shared_ptr& ti) { - const auto& body = ti->get_body(); - OPENVINO_ASSERT(body != nullptr, "TI returns invalid body graph ", ti); +using namespace std; +using namespace ov::opset9; + +ov::Input get_outer_input_of_ti_by_parameter(const shared_ptr& parameter, + const shared_ptr& ti) { int64_t parameter_index = ti->get_body()->get_parameter_index(parameter); - OPENVINO_ASSERT(parameter_index >= 0, - "LSTMStatesBroadcast encountered unregistered parameter ", - parameter, - " related to TI body ", - ti); for (const auto& input_descriptor : ti->get_input_descriptions()) if (input_descriptor->m_body_parameter_index == parameter_index) return ti->input(input_descriptor->m_input_index); @@ -31,13 +28,11 @@ ov::Input get_outer_input_of_ti_by_parameter(const std::shared_ptr deduce_outer_source_of_batch_for_inner_lstm_cell( - const std::shared_ptr& ti, - const std::shared_ptr& lstm_cell) { - const auto& body = ti->get_body(); - OPENVINO_ASSERT(body != nullptr, "TI returns invalid body graph ", ti); +shared_ptr deduce_outer_source_of_batch_for_inner_lstm_cell(const shared_ptr& ti, + const shared_ptr& lstm_cell) { + const auto& body = ti->get_body(); // body is not nullptr -- we checked earlier - std::map original_shapes; + map original_shapes; size_t label = 1; // mark all input dimensions with labels and making them dynamic, keeping original shapes @@ -46,9 +41,7 @@ std::shared_ptr deduce_outer_source_of_batch_for_inner_lstm_cell( original_shapes[parameter.get()] = pshape; if (pshape.rank().is_dynamic()) continue; - for (ngraph::Dimension& n : pshape) { - OPENVINO_ASSERT(ov::DimensionTracker::get_label(n) == 0, - "LSTMStatesBroadcast encountered TI with previously tracked dimensions"); + for (ov::Dimension& n : pshape) { n = ov::Dimension::dynamic(); ov::DimensionTracker::set_label(n, label++); } @@ -68,7 +61,7 @@ std::shared_ptr deduce_outer_source_of_batch_for_inner_lstm_cell( } // batch label was tracked -- finding parameter that delivered it - std::shared_ptr batch_delivering_parameter; + shared_ptr batch_delivering_parameter; size_t index_of_batch_dim = 0; size_t batch_label = ov::DimensionTracker::get_label(lstm_cell->get_input_partial_shape(0)[0]); @@ -80,8 +73,11 @@ std::shared_ptr deduce_outer_source_of_batch_for_inner_lstm_cell( if (ov::DimensionTracker::get_label(pshape[i]) == batch_label) { batch_delivering_parameter = parameter; index_of_batch_dim = i; + break; } } + if (index_of_batch_dim != 0 && batch_delivering_parameter != nullptr) + break; } for (auto& item : original_shapes) item.first->set_partial_shape(item.second); @@ -91,89 +87,83 @@ std::shared_ptr deduce_outer_source_of_batch_for_inner_lstm_cell( return nullptr; const auto& batched_source = get_outer_input_of_ti_by_parameter(batch_delivering_parameter, ti); - const auto& batched_shape = std::make_shared(batched_source.get_source_output()); - const auto& batch = std::make_shared( - batched_shape, - ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {index_of_batch_dim}), - ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0})); + const auto& batched_shape = make_shared(batched_source.get_source_output()); + const auto& batch = make_shared(batched_shape, + Constant::create(ov::element::i64, ov::Shape{1}, {index_of_batch_dim}), + Constant::create(ov::element::i64, ov::Shape{}, {0})); return batch; } -bool broadcast_state_by_batch(ov::Input input, const std::shared_ptr& batch_delivering_node) { - auto constant_state = - std::dynamic_pointer_cast(input.get_source_output().get_node_shared_ptr()); +bool broadcast_state_by_batch(ov::Input input, const shared_ptr& batch_delivering_node) { + auto constant_state = dynamic_pointer_cast(input.get_source_output().get_node_shared_ptr()); if (constant_state == nullptr) return false; const auto& constant_shape = constant_state->get_shape(); - OPENVINO_ASSERT(constant_shape.size() == 2, "State has unexpected shape ", constant_shape); if (constant_shape[0] != 1) // we only expect to broadcast LSTM states prepared for batch 1 -- no tiling of batch > 1 will be done return false; const auto& constant_copy = constant_state->copy_with_new_inputs({}); - const auto& broadcast_by_batch = std::make_shared( + const auto& broadcast_by_batch = make_shared( constant_copy, - std::make_shared( - ngraph::NodeVector{batch_delivering_node, - ngraph::op::util::make_try_fold( - ngraph::op::util::make_try_fold(constant_copy), - ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {1}), - ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0}))}, - 0)); + make_shared(ngraph::NodeVector{batch_delivering_node, + ngraph::op::util::make_try_fold( + ngraph::op::util::make_try_fold(constant_copy), + Constant::create(ov::element::i64, ov::Shape{1}, {1}), + Constant::create(ov::element::i64, ov::Shape{}, {0}))}, + 0)); input.replace_source_output(broadcast_by_batch->output(0)); return true; } -bool relax_batch_for_initial_states_of_lstm_in_ti(const std::shared_ptr& ti, - const std::shared_ptr& lstm_cell) { +bool relax_batch_for_initial_states_of_lstm_in_ti(const shared_ptr& ti, + const shared_ptr& lstm_cell) { bool rewritten = false; auto batch_delivering_node = deduce_outer_source_of_batch_for_inner_lstm_cell(ti, lstm_cell); if (batch_delivering_node == nullptr) return rewritten; - if (auto init_hidden_state = - std::dynamic_pointer_cast(lstm_cell->get_input_node_shared_ptr(1))) { + if (auto init_hidden_state = dynamic_pointer_cast(lstm_cell->get_input_node_shared_ptr(1))) { auto outer_init_hidden_state_input = get_outer_input_of_ti_by_parameter(init_hidden_state, ti); rewritten |= broadcast_state_by_batch(outer_init_hidden_state_input, batch_delivering_node); } - if (auto init_cell_state = - std::dynamic_pointer_cast(lstm_cell->get_input_node_shared_ptr(2))) { + if (auto init_cell_state = dynamic_pointer_cast(lstm_cell->get_input_node_shared_ptr(2))) { auto outer_init_cell_state_input = get_outer_input_of_ti_by_parameter(init_cell_state, ti); rewritten |= broadcast_state_by_batch(outer_init_cell_state_input, batch_delivering_node); } return rewritten; } -bool relax_batch_for_initial_states_of_lstm(const std::shared_ptr& lstm_cell) { +bool relax_batch_for_initial_states_of_lstm(const shared_ptr& lstm_cell) { bool rewritten = false; - const auto& batched_shape = std::make_shared(lstm_cell->get_input_source_output(0)); - const auto& batch_delivering_node = - std::make_shared(batched_shape, - ov::opset9::Constant::create(ov::element::i64, ov::Shape{1}, {0}), - ov::opset9::Constant::create(ov::element::i64, ov::Shape{}, {0})); + const auto& batched_shape = make_shared(lstm_cell->get_input_source_output(0)); + const auto& batch_delivering_node = make_shared(batched_shape, + Constant::create(ov::element::i64, ov::Shape{1}, {0}), + Constant::create(ov::element::i64, ov::Shape{}, {0})); rewritten |= broadcast_state_by_batch(lstm_cell->input(1), batch_delivering_node); rewritten |= broadcast_state_by_batch(lstm_cell->input(2), batch_delivering_node); return rewritten; } -bool ngraph::pass::LSTMStatesBroadcast::run_on_model(const std::shared_ptr& f) { +bool ov::pass::LSTMStatesBroadcast::run_on_model(const shared_ptr& f) { RUN_ON_FUNCTION_SCOPE(LSTMStatesBroadcast); bool rewritten = false; for (auto& node : f->get_ordered_ops()) { // Recursively apply transformation for sub-graph based operations - if (const auto& sub_graph_node = std::dynamic_pointer_cast(node)) + if (const auto& sub_graph_node = dynamic_pointer_cast(node)) if (const auto& sub_graph = sub_graph_node->get_function()) rewritten |= run_on_model(sub_graph); // Case without TI (LSTMCell and Constant are in the same ov::Model) - if (const auto& lstm_cell = std::dynamic_pointer_cast(node)) + if (const auto& lstm_cell = dynamic_pointer_cast(node)) rewritten |= relax_batch_for_initial_states_of_lstm(lstm_cell); // Case with TI (LSTMCell and Constant are in different ov::Model objects) - if (auto ti = std::dynamic_pointer_cast(node)) { + if (auto ti = dynamic_pointer_cast(node)) { auto body = ti->get_body(); - OPENVINO_ASSERT(body, "TensorIterator must have body network"); + if (body == nullptr) + continue; for (const auto& body_node : body->get_ordered_ops()) - if (const auto& lstm_cell = std::dynamic_pointer_cast(body_node)) + if (const auto& lstm_cell = dynamic_pointer_cast(body_node)) rewritten |= relax_batch_for_initial_states_of_lstm_in_ti(ti, lstm_cell); } } diff --git a/src/common/transformations/src/transformations/smart_reshape/reshape_sinking.cpp b/src/common/transformations/src/transformations/smart_reshape/reshape_sinking.cpp index 8528a65e3bf..7a8994f00cc 100644 --- a/src/common/transformations/src/transformations/smart_reshape/reshape_sinking.cpp +++ b/src/common/transformations/src/transformations/smart_reshape/reshape_sinking.cpp @@ -2,16 +2,18 @@ // SPDX-License-Identifier: Apache-2.0 // -#include -#include -#include -#include -#include -#include +#include "transformations/smart_reshape/reshape_sinking.hpp" #include "itt.hpp" +#include "openvino/opsets/opset9.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" -ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() { +using namespace std; +using namespace ov::opset9; + +ov::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() { MATCHER_SCOPE(ReshapeSinkingMatMul); /* Original graph: Transformed graph: * @@ -25,22 +27,20 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() { * | shape=[1, S, O] | shape=[B, S, O] */ auto any_input = pattern::any_input(pattern::has_static_rank()); - auto reshape_label = ngraph::pattern::wrap_type( - {pattern::any_input(), ngraph::pattern::wrap_type()}, - pattern::rank_equals(2)); + auto reshape_label = + ov::pass::pattern::wrap_type({pattern::any_input(), ov::pass::pattern::wrap_type()}, + pattern::rank_equals(2)); - auto matmul_label = - ngraph::pattern::wrap_type({reshape_label, ngraph::pattern::wrap_type()}, - pattern::rank_equals(2)); - auto add_label = - ngraph::pattern::wrap_type({matmul_label, ngraph::pattern::wrap_type()}, - pattern::rank_equals(2)); + auto matmul_label = ov::pass::pattern::wrap_type({reshape_label, ov::pass::pattern::wrap_type()}, + pattern::rank_equals(2)); + auto add_label = ov::pass::pattern::wrap_type({matmul_label, ov::pass::pattern::wrap_type()}, + pattern::rank_equals(2)); - auto matmul_or_matmul_add_label = std::make_shared(OutputVector{add_label, matmul_label}); + auto matmul_or_matmul_add_label = make_shared(OutputVector{add_label, matmul_label}); - auto reshape_1_label = ngraph::pattern::wrap_type( - {matmul_or_matmul_add_label, ngraph::pattern::wrap_type()}, - pattern::has_static_rank()); + auto reshape_1_label = + ov::pass::pattern::wrap_type({matmul_or_matmul_add_label, ov::pass::pattern::wrap_type()}, + pattern::has_static_rank()); matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool { auto pattern_to_node = m.get_pattern_map(); @@ -48,7 +48,7 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() { // check first Reshape eligibility: has a constant output pattern in a form of [-1, K] auto reshape = pattern_to_node.at(reshape_label); int64_t K = -1; - if (const auto& constant = std::dynamic_pointer_cast(reshape->get_input_node_shared_ptr(1))) { + if (const auto& constant = dynamic_pointer_cast(reshape->get_input_node_shared_ptr(1))) { auto output_pattern_vector = constant->cast_vector(); if (output_pattern_vector.size() != 2 || output_pattern_vector[0] != -1) return false; @@ -66,11 +66,11 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() { return false; // check matmul eligibility: has constant second input in a form of [O, K] - auto matmul = std::dynamic_pointer_cast(pattern_to_node.at(matmul_label)); + auto matmul = dynamic_pointer_cast(pattern_to_node.at(matmul_label)); if (!matmul || matmul->get_transpose_a()) return false; int64_t O = -1; - if (const auto& constant = std::dynamic_pointer_cast(matmul->get_input_node_shared_ptr(1))) { + if (const auto& constant = dynamic_pointer_cast(matmul->get_input_node_shared_ptr(1))) { const auto& constant_shape = constant->get_shape(); if (constant_shape.size() != 2) return false; @@ -86,10 +86,10 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() { // check add eligibility if present: has constant second input that has a form of [1, 1, ..., O] (doesn't // broadcast first input) if (pattern_to_node.count(add_label)) { - auto add = std::dynamic_pointer_cast(pattern_to_node.at(add_label)); + auto add = dynamic_pointer_cast(pattern_to_node.at(add_label)); if (!add || add->get_autob() != ov::op::AutoBroadcastType::NUMPY) return false; - const auto& constant = std::dynamic_pointer_cast(add->get_input_node_shared_ptr(1)); + const auto& constant = dynamic_pointer_cast(add->get_input_node_shared_ptr(1)); if (!constant) return false; const auto& constant_shape = constant->get_shape(); @@ -106,19 +106,17 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() { // input_shape of the pattern except for the batch and last dimension auto reshape_1 = m.get_match_root(); - const auto& constant = std::dynamic_pointer_cast(reshape_1->get_input_node_shared_ptr(1)); + const auto& constant = dynamic_pointer_cast(reshape_1->get_input_node_shared_ptr(1)); if (constant == nullptr) return false; auto output_pattern = constant->cast_vector(); - if (!std::all_of(output_pattern.begin(), output_pattern.end(), [](const int64_t& i) { + if (output_pattern.size() != input_rank) + return false; + if (!all_of(output_pattern.begin(), output_pattern.end(), [](const int64_t& i) { return i > 0; })) return false; - if (output_pattern.size() != input_rank) - return false; - for (size_t i = 0; i < input_rank; ++i) { - if (i == 0) - continue; + for (size_t i = 1; i < input_rank; ++i) { if (i + 1 == input_rank) { if (output_pattern[i] != O) return false; @@ -129,28 +127,26 @@ ngraph::pass::ReshapeSinkingMatMul::ReshapeSinkingMatMul() { return false; } + auto first_reshape = dynamic_pointer_cast(reshape); + auto second_reshape = dynamic_pointer_cast(reshape_1); + if (!first_reshape || !second_reshape) + return false; + // this is the pattern we are looking for! performing the transformation - auto first_reshape = std::dynamic_pointer_cast(reshape); - if (!first_reshape) - return false; first_reshape->set_special_zero(true); - auto second_reshape = std::dynamic_pointer_cast(reshape_1); - if (!second_reshape) - return false; second_reshape->set_special_zero(true); - std::vector output_pattern_vector(input_rank - 1, 0); + vector output_pattern_vector(input_rank - 1, 0); output_pattern_vector.push_back(K); - auto new_reshape_constant = - opset9::Constant::create(ov::element::i64, Shape{input_rank}, output_pattern_vector); + auto new_reshape_constant = Constant::create(ov::element::i64, Shape{input_rank}, output_pattern_vector); reshape->input(1).replace_source_output(new_reshape_constant->output(0)); output_pattern[0] = 0; - auto new_reshape_1_constant = opset9::Constant::create(ov::element::i64, Shape{input_rank}, output_pattern); + auto new_reshape_1_constant = Constant::create(ov::element::i64, Shape{input_rank}, output_pattern); reshape_1->input(1).replace_source_output(new_reshape_1_constant->output(0)); return true; }; - auto m = std::make_shared(reshape_1_label, matcher_name); + auto m = make_shared(reshape_1_label, matcher_name); register_matcher(m, callback); } diff --git a/src/common/transformations/src/transformations/smart_reshape/smart_reshape.cpp b/src/common/transformations/src/transformations/smart_reshape/smart_reshape.cpp index 0ce962c96b9..c27462b59df 100644 --- a/src/common/transformations/src/transformations/smart_reshape/smart_reshape.cpp +++ b/src/common/transformations/src/transformations/smart_reshape/smart_reshape.cpp @@ -30,8 +30,8 @@ bool ngraph::pass::SmartReshape::run_on_model(const std::shared_ptr(); static_manager.register_pass(); static_manager.register_pass(); - static_manager.register_pass(); - static_manager.register_pass(); + static_manager.register_pass(); + static_manager.register_pass(); static_manager.run_passes(f); ngraph::pass::Manager dynamic_manager; diff --git a/src/tests/functional/inference_engine/transformations/smart_reshape/lstm_states_broadcast.cpp b/src/tests/functional/inference_engine/transformations/smart_reshape/lstm_states_broadcast.cpp index 27772141560..9fa21850212 100644 --- a/src/tests/functional/inference_engine/transformations/smart_reshape/lstm_states_broadcast.cpp +++ b/src/tests/functional/inference_engine/transformations/smart_reshape/lstm_states_broadcast.cpp @@ -4,20 +4,13 @@ #include -#include -#include +#include "openvino/core/model.hpp" +#include "openvino/opsets/opset9.hpp" #include "common_test_utils/ngraph_test_utils.hpp" -template -std::shared_ptr create_constant(const std::vector& data, const ov::element::Type_t et = ov::element::i64, bool scalar = false) { - ov::Shape shape = scalar ? ov::Shape{} : ov::Shape{data.size()}; - return ov::opset9::Constant::create(et, shape, data); -} - -std::shared_ptr create_zero_constant(const ov::element::Type_t et, ov::PartialShape shape) { - return ov::opset9::Constant::create(et, shape.to_shape(), {0}); -} +using namespace std; +using namespace ov::opset9; struct LSTMStatesAttributes { ov::element::Type_t data_et; @@ -32,18 +25,18 @@ class LSTMStatesBroadcastTest TEST_P(LSTMStatesBroadcastTest, BareLSTM) { auto p = GetParam(); - std::shared_ptr model(nullptr); + shared_ptr model(nullptr); { - auto parameter = std::make_shared(p.data_et, ov::PartialShape{p.data_batch_size, p.input_size}); - auto initial_hidden_state = create_zero_constant(p.data_et, {1, p.hidden_size}); - auto initial_cell_state = create_zero_constant(p.data_et, {1, p.hidden_size}); - auto W = create_zero_constant(p.data_et, {p.hidden_size * 4, p.input_size}); - auto R = create_zero_constant(p.data_et, {p.hidden_size * 4, p.hidden_size}); + auto parameter = make_shared(p.data_et, ov::PartialShape{p.data_batch_size, p.input_size}); + auto initial_hidden_state = create_zero_constant(p.data_et, ov::PartialShape{1, p.hidden_size}.to_shape()); + auto initial_cell_state = create_zero_constant(p.data_et, ov::PartialShape{1, p.hidden_size}.to_shape()); + auto W = create_zero_constant(p.data_et, ov::PartialShape{p.hidden_size * 4, p.input_size}.to_shape()); + auto R = create_zero_constant(p.data_et, ov::PartialShape{p.hidden_size * 4, p.hidden_size}.to_shape()); - auto cell = std::make_shared( + auto cell = make_shared( parameter, initial_hidden_state, initial_cell_state, W, R, static_cast(p.hidden_size.get_length())); - model = std::make_shared(ov::NodeVector{cell}, ov::ParameterVector{parameter}); + model = make_shared(ov::NodeVector{cell}, ov::ParameterVector{parameter}); } ASSERT_NO_THROW(model->reshape(ov::PartialShape{p.new_batch_size, p.input_size})); } @@ -56,29 +49,29 @@ class LSTMStatesBroadcastTestWithTI TEST_P(LSTMStatesBroadcastTestWithTI, TI_With_LSTM) { auto p = GetParam(); - std::shared_ptr model(nullptr); + shared_ptr model(nullptr); { - auto X = std::make_shared(p.data_et, ov::PartialShape{p.data_batch_size, 1, p.input_size}); - auto H_init = create_zero_constant(ov::element::i64, {1, p.hidden_size}); - auto C_init = create_zero_constant(ov::element::i64, {1, p.hidden_size}); + auto X = make_shared(p.data_et, ov::PartialShape{p.data_batch_size, 1, p.input_size}); + auto H_init = create_zero_constant(ov::element::i64, ov::PartialShape{1, p.hidden_size}.to_shape()); + auto C_init = create_zero_constant(ov::element::i64, ov::PartialShape{1, p.hidden_size}.to_shape()); - auto Xi = std::make_shared(p.data_et, ov::PartialShape{1, 1, p.input_size}); - auto H_t = std::make_shared(p.data_et, ov::PartialShape{1, p.hidden_size}); - auto C_t = std::make_shared(p.data_et, ov::PartialShape{1, p.hidden_size}); + auto Xi = make_shared(p.data_et, ov::PartialShape{1, 1, p.input_size}); + auto H_t = make_shared(p.data_et, ov::PartialShape{1, p.hidden_size}); + auto C_t = make_shared(p.data_et, ov::PartialShape{1, p.hidden_size}); // Body - auto squeeze = std::make_shared(Xi, create_constant({1})); - auto W = create_zero_constant(p.data_et, {p.hidden_size * 4, p.input_size}); - auto R = create_zero_constant(p.data_et, {p.hidden_size * 4, p.hidden_size}); + auto squeeze = make_shared(Xi, create_constant({1})); + auto W = create_zero_constant(p.data_et, ov::PartialShape{p.hidden_size * 4, p.input_size}.to_shape()); + auto R = create_zero_constant(p.data_et, ov::PartialShape{p.hidden_size * 4, p.hidden_size}.to_shape()); - auto lstm_cell = std::make_shared(squeeze, H_t, C_t, W, R, static_cast(p.hidden_size.get_length())); - auto res_1 = std::make_shared(lstm_cell->output(0)); - auto unsqueeze = std::make_shared(lstm_cell->output(0), create_constant({1})); - auto res_2 = std::make_shared(unsqueeze); - auto res_3 = std::make_shared(lstm_cell->output(1)); - auto body = std::make_shared(ov::OutputVector{res_1, res_2, res_3}, ov::ParameterVector{Xi, H_t, C_t}); + auto lstm_cell = make_shared(squeeze, H_t, C_t, W, R, static_cast(p.hidden_size.get_length())); + auto res_1 = make_shared(lstm_cell->output(0)); + auto unsqueeze = make_shared(lstm_cell->output(0), create_constant({1})); + auto res_2 = make_shared(unsqueeze); + auto res_3 = make_shared(lstm_cell->output(1)); + auto body = make_shared(ov::OutputVector{res_1, res_2, res_3}, ov::ParameterVector{Xi, H_t, C_t}); - auto tensor_iterator = std::make_shared(); + auto tensor_iterator = make_shared(); tensor_iterator->set_body(body); tensor_iterator->set_merged_input(C_t, C_init, res_3); @@ -88,16 +81,15 @@ TEST_P(LSTMStatesBroadcastTestWithTI, TI_With_LSTM) { auto out0 = tensor_iterator->get_iter_value(res_1, -1); auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 0); - auto res_ti_1 = std::make_shared(tensor_iterator->output(1)); - auto res_ti_2 = std::make_shared(tensor_iterator->output(0)); - model = std::make_shared(ov::NodeVector{res_ti_1, res_ti_2}, + auto res_ti_1 = make_shared(tensor_iterator->output(1)); + auto res_ti_2 = make_shared(tensor_iterator->output(0)); + model = make_shared(ov::NodeVector{res_ti_1, res_ti_2}, ov::ParameterVector{X}); } - model->reshape(ov::PartialShape{p.new_batch_size, 1, p.input_size}); ASSERT_NO_THROW(model->reshape(ov::PartialShape{p.new_batch_size, 1, p.input_size})); } -static std::vector params = { +static vector params = { LSTMStatesAttributes{ov::element::f32, {1}, {2}, {512}, {256}}, LSTMStatesAttributes{ov::element::f32, {-1}, {2}, {512}, {256}}, }; diff --git a/src/tests/functional/inference_engine/transformations/smart_reshape/reshape_sinking.cpp b/src/tests/functional/inference_engine/transformations/smart_reshape/reshape_sinking.cpp index 37e6fbcea00..5f560c08979 100644 --- a/src/tests/functional/inference_engine/transformations/smart_reshape/reshape_sinking.cpp +++ b/src/tests/functional/inference_engine/transformations/smart_reshape/reshape_sinking.cpp @@ -4,21 +4,11 @@ #include -#include -#include +#include "openvino/core/model.hpp" +#include "openvino/opsets/opset9.hpp" #include "common_test_utils/ngraph_test_utils.hpp" -template -std::shared_ptr create_constant(const std::vector& data, const ov::element::Type_t et = ov::element::i64, bool scalar = false) { - ov::Shape shape = scalar ? ov::Shape{} : ov::Shape{data.size()}; - return ov::opset9::Constant::create(et, shape, data); -} - -std::shared_ptr create_zero_constant(const ov::element::Type_t et, ov::Shape shape) { - return ov::opset9::Constant::create(et, shape, {0}); -} - struct ReshapeSinkingAttributes { ov::element::Type_t data_et; ov::PartialShape input_shape; diff --git a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp index dcecd9672a0..b5fcc9b102d 100644 --- a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp +++ b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp @@ -64,3 +64,7 @@ void check_unique_names(std::shared_ptr f, const std::shared_p manager.register_pass(unh, true); manager.run_passes(f); } + +std::shared_ptr create_zero_constant(const ov::element::Type_t& et, const ov::Shape& shape) { + return ov::opset8::Constant::create(et, shape, {0}); +} \ No newline at end of file diff --git a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp index 5101bb93b0b..95d528a7078 100644 --- a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp +++ b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.hpp @@ -69,3 +69,11 @@ size_t count_ops_of_type(const std::shared_ptr& f) { return count; } + +template +std::shared_ptr create_constant(const std::vector& data, const ov::element::Type_t et = ov::element::i64, bool scalar = false) { + ov::Shape shape = scalar ? ov::Shape{} : ov::Shape{data.size()}; + return ov::opset8::Constant::create(et, shape, data); +} + +std::shared_ptr create_zero_constant(const ov::element::Type_t& et, const ov::Shape& shape);