From 1c3208ffe0888ee720ba21bfe80d3ed3a84f2796 Mon Sep 17 00:00:00 2001 From: Ivan Tikhonov Date: Fri, 6 Nov 2020 14:11:11 +0300 Subject: [PATCH] Low Latency transformation (#2869) * initial draft of adding sinks to ngraph::Function * style fixes * code style fixes * code style fixes * code style fix * review fix+build fix * code style fix * fix build * API changed according to latest discussion * review fixes * review fixes + tests * initial draft of adding sinks to ngraph::Function * style fixes * code style fixes * code style fixes * code style fix * review fix+build fix * code style fix * fix build * API changed according to latest discussion * review fixes * review fixes + tests * added 1 more ctor * style fixes * used new api in ir parser * fixed build * update low latency transformation, fix unroll transformation, add unit tests, modify subgraph tests * fix low latency transformation * Update low latency transformation, unit and sub-graph tests * update LowLatency transformation and tests * ngraph codestyle * fix build, update description * resolve review remarks Co-authored-by: Svetlana Dolinina --- .../include/ie_transformations.hpp | 56 +++ inference-engine/include/inference_engine.hpp | 1 + .../src/gna_plugin/CMakeLists.txt | 5 +- .../src/gna_plugin/gna_graph_compiler.cpp | 2 +- .../src/gna_plugin/gna_plugin.cpp | 36 +- .../inference_engine/ie_transformations.cpp | 16 + .../src/mkldnn_plugin/mkldnn_plugin.cpp | 7 + .../common_optimizations/low_latency.hpp | 5 + .../control_flow/unroll_tensor_iterator.hpp | 4 +- .../control_flow/unroll_tensor_iterator.cpp | 23 +- .../transformations/low_latency_test.cpp | 353 ++++++++++++++++++ .../include/subgraph_tests/basic_lstm.hpp | 6 +- .../subgraph_tests/memory_LSTMCell.hpp | 2 + .../subgraph_tests/multiple_LSTMCell.hpp | 2 + .../shared/src/subgraph_tests/basic_lstm.cpp | 113 +++--- .../src/subgraph_tests/memory_LSTMCell.cpp | 124 ++++++ .../src/subgraph_tests/multiple_LSTMCell.cpp | 190 ++++++++++ .../common_test_utils/ngraph_test_utils.cpp | 17 +- ngraph/core/include/ngraph/function.hpp | 1 + ngraph/core/include/ngraph/op/sink.hpp | 1 + .../core/include/ngraph/pass/low_latency.hpp | 55 +++ ngraph/core/src/op/loop.cpp | 6 +- ngraph/core/src/op/tensor_iterator.cpp | 20 +- ngraph/core/src/pass/low_latency.cpp | 71 ++++ 24 files changed, 1023 insertions(+), 93 deletions(-) create mode 100644 inference-engine/include/ie_transformations.hpp create mode 100644 inference-engine/src/inference_engine/ie_transformations.cpp create mode 100644 inference-engine/src/transformations/include/transformations/common_optimizations/low_latency.hpp create mode 100644 inference-engine/tests/functional/inference_engine/transformations/low_latency_test.cpp create mode 100644 ngraph/core/include/ngraph/pass/low_latency.hpp create mode 100644 ngraph/core/src/pass/low_latency.cpp diff --git a/inference-engine/include/ie_transformations.hpp b/inference-engine/include/ie_transformations.hpp new file mode 100644 index 00000000000..673e563f4e8 --- /dev/null +++ b/inference-engine/include/ie_transformations.hpp @@ -0,0 +1,56 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +/** + * @brief This header file defines the list of public transformations. 
+ *
+ * @file ie_transformations.hpp
+ */
+
+#pragma once
+
+#include
+#include
+
+namespace InferenceEngine {
+
+/**
+ * @brief The transformation finds all TensorIterator layers in the network, processes all back
+ * edges that describe a connection between the Result and Parameter of the TensorIterator body,
+ * and inserts a ReadValue layer between the Parameter and the layers that follow it,
+ * and an Assign layer between the layers that precede the Result and the Result itself.
+ * Supported platforms: CPU, GNA.
+ *
+ * The example below describes the changes to the inner part (body, back edges) of the TensorIterator layer.
+ * [] - TensorIterator body
+ * () - new layer
+ *
+ * before applying the transformation:
+ * back_edge_1 -> [Parameter -> some layers ... -> Result ] -> back_edge_1
+ *
+ * after applying the transformation:
+ * back_edge_1 -> [Parameter -> (ReadValue layer) -> some layers ... -> (Assign layer) ]
+ *                                                                      \
+ *                                                                       -> Result ] -> back_edge_1
+ *
+ * It is recommended to use this transformation in conjunction with the Reshape feature to set the
+ * sequence dimension to 1 and with the UnrollTensorIterator transformation.
+ * For convenience, the UnrollTensorIterator transformation is executed unconditionally when the
+ * LowLatency transformation is applied for the CPU and GNA plugins, so no action is required there.
+ * After applying both of these transformations, the resulting network can be inferred step by
+ * step, and the states are stored between inferences.
+ *
+ * An illustrative example, not real API:
+ *
+ * network->reshape(...) // Set the sequence dimension to 1, recalculating shapes. Optional, depends on the network.
+ * LowLatency(network)   // Apply the LowLatency and UnrollTensorIterator transformations.
+ * network->infer (...)  // Calculate new values for the states.
+ *                       // All states are stored between inferences via Assign, ReadValue layers.
+ * network->infer (...)  // Using the stored states, calculate new values for the states.
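+ *
+ * A concrete sketch of the same flow through the public API (the model path and device name
+ * below are placeholders; error handling is omitted):
+ *
+ *     InferenceEngine::Core core;
+ *     auto network = core.ReadNetwork("model.xml");
+ *     InferenceEngine::LowLatency(network);               // inserts ReadValue/Assign, unrolls TI
+ *     auto executable = core.LoadNetwork(network, "CPU");
+ *     auto request = executable.CreateInferRequest();
+ *     request.Infer();                                    // states persist between Infer() calls
+ *     request.Infer();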
+ * + * @param network A network to apply LowLatency transformation + * * + */ +INFERENCE_ENGINE_API_CPP(void) LowLatency(InferenceEngine::CNNNetwork& network); +} // namespace InferenceEngine diff --git a/inference-engine/include/inference_engine.hpp b/inference-engine/include/inference_engine.hpp index 4f97a3e866b..300566f79b9 100644 --- a/inference-engine/include/inference_engine.hpp +++ b/inference-engine/include/inference_engine.hpp @@ -8,6 +8,7 @@ */ #pragma once +#include "ie_transformations.hpp" #include "ie_plugin_config.hpp" #include "ie_compound_blob.h" #include "ie_core.hpp" diff --git a/inference-engine/src/gna_plugin/CMakeLists.txt b/inference-engine/src/gna_plugin/CMakeLists.txt index 645e7e941e7..dfd15e6bbf0 100644 --- a/inference-engine/src/gna_plugin/CMakeLists.txt +++ b/inference-engine/src/gna_plugin/CMakeLists.txt @@ -31,7 +31,8 @@ ie_add_plugin(NAME ${TARGET_NAME} # saving rpath to GNA shared library be used by CI log_rpath_from_dir(GNA ${libGNA_LIBRARIES_BASE_PATH}) -target_link_libraries(${TARGET_NAME} PRIVATE inference_engine inference_engine_legacy Threads::Threads libGNA) +target_link_libraries(${TARGET_NAME} PRIVATE inference_engine inference_engine_legacy inference_engine_transformations + Threads::Threads libGNA) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(${TARGET_NAME} @@ -57,7 +58,7 @@ target_compile_definitions(${TARGET_NAME}_test_static INTEGER_LOW_P USE_STATIC_IE) -target_link_libraries(${TARGET_NAME}_test_static PUBLIC inference_engine_preproc_s libGNA::API) +target_link_libraries(${TARGET_NAME}_test_static PUBLIC inference_engine_preproc_s inference_engine_transformations libGNA::API) target_include_directories(${TARGET_NAME}_test_static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} $) set_target_properties(${TARGET_NAME}_test_static PROPERTIES COMPILE_PDB_NAME ${TARGET_NAME}_test_static) diff --git a/inference-engine/src/gna_plugin/gna_graph_compiler.cpp b/inference-engine/src/gna_plugin/gna_graph_compiler.cpp index 68f0403be72..256489be376 100644 --- a/inference-engine/src/gna_plugin/gna_graph_compiler.cpp +++ b/inference-engine/src/gna_plugin/gna_graph_compiler.cpp @@ -2137,7 +2137,7 @@ GNAPluginNS::ConnectionDetails GNAGraphCompiler::connectInput(CNNLayerPtr layer, auto prevMemoryLayer = std::find_if(begin(memory_connection), end(memory_connection), [&](MemoryConnection::value_type &comp) { - return comp.second.getInput()->name == prevLayer->name; + return comp.second.getInput()->params.at("id") == prevLayer->params.at("id"); }); if (prevMemoryLayer != memory_connection.end()) { // dnnLayer that is input for memory output layer diff --git a/inference-engine/src/gna_plugin/gna_plugin.cpp b/inference-engine/src/gna_plugin/gna_plugin.cpp index 8a69d0480f8..69ae31cca09 100644 --- a/inference-engine/src/gna_plugin/gna_plugin.cpp +++ b/inference-engine/src/gna_plugin/gna_plugin.cpp @@ -38,6 +38,18 @@ #include "gna_model_serial.hpp" #include "runtime/gna_float_runtime.hpp" +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + #if GNA_LIB_VER == 2 #include @@ -342,7 +354,29 @@ void GNAPlugin::InitGNADevice() { void GNAPlugin::LoadNetwork(ICNNNetwork & _network) { std::shared_ptr convertedNetwork; if (_network.getFunction()) { - convertedNetwork = std::make_shared(_network); + std::shared_ptr clonedNetwork = cloneNetwork(_network); + const auto& graph = clonedNetwork->getFunction(); + // Disable shape inference (WA for generic operations) + 
ngraph::op::GenericIE::DisableReshape noReshape(graph); + ngraph::pass::Manager manager; + manager.register_pass(); + // WA: ConvertPriorBox must be executed before the 1st ConstantFolding pass + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + // UnrollTI should be the last transformation in the transformation pipeline + manager.register_pass(); + + const auto& pass_config = manager.get_pass_config(); + pass_config->set_callback( + [](const std::shared_ptr &node) -> bool { + // UnrollTI transformation is disabled by default, is turned on by LowLatency transformation + return node->get_rt_info().count("UNROLL_TI") == 0; + }); + manager.run_passes(graph); + convertedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(graph, *clonedNetwork); } InferenceEngine::ICNNNetwork &network = convertedNetwork ? *convertedNetwork : _network; diff --git a/inference-engine/src/inference_engine/ie_transformations.cpp b/inference-engine/src/inference_engine/ie_transformations.cpp new file mode 100644 index 00000000000..798510a161c --- /dev/null +++ b/inference-engine/src/inference_engine/ie_transformations.cpp @@ -0,0 +1,16 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ie_transformations.hpp" +#include +#include + +using namespace InferenceEngine; + +void InferenceEngine::LowLatency(InferenceEngine::CNNNetwork &network) { + auto function = network.getFunction(); + ngraph::pass::Manager manager; + manager.register_pass(); + manager.run_passes(function); +} diff --git a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp index 1f69e7bdbba..0cc8b2c9b7d 100644 --- a/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp +++ b/inference-engine/src/mkldnn_plugin/mkldnn_plugin.cpp @@ -32,6 +32,7 @@ #include #include +#include #include #include #include @@ -177,6 +178,8 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork, const Config& conf) ngraph::pass::Manager legacyManager; legacyManager.register_pass(); legacyManager.register_pass(ngraph::element::i64, ngraph::element::i32); + // not legacy actually, but it should be the last transformation in the transformation pipeline + legacyManager.register_pass(); auto legacyPassConfig = manager.get_pass_config(); legacyPassConfig->set_callback([](const_node_ptr &node) -> bool { @@ -193,6 +196,10 @@ static void Transformation(ICNNNetwork::Ptr& clonedNetwork, const Config& conf) return false; }); + legacyManager.get_pass_config()->set_callback([](const_node_ptr &node) -> bool { + // UnrollTI transformation is disabled by default, is turned on by LowLatency transformation + return node->get_rt_info().count("UNROLL_TI") == 0; + }); legacyManager.run_passes(nGraphFunc); clonedNetwork = InferenceEngine::details::convertFunctionToICNNNetwork(nGraphFunc, *clonedNetwork); diff --git a/inference-engine/src/transformations/include/transformations/common_optimizations/low_latency.hpp b/inference-engine/src/transformations/include/transformations/common_optimizations/low_latency.hpp new file mode 100644 index 00000000000..eb05df852bc --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/common_optimizations/low_latency.hpp @@ -0,0 +1,5 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include diff --git 
a/inference-engine/src/transformations/include/transformations/control_flow/unroll_tensor_iterator.hpp b/inference-engine/src/transformations/include/transformations/control_flow/unroll_tensor_iterator.hpp index a0016c3c793..b1424325498 100644 --- a/inference-engine/src/transformations/include/transformations/control_flow/unroll_tensor_iterator.hpp +++ b/inference-engine/src/transformations/include/transformations/control_flow/unroll_tensor_iterator.hpp @@ -27,8 +27,8 @@ class TRANSFORMATIONS_API UnrollTensorIterator; * are added to the network. */ -class ngraph::pass::UnrollTensorIterator: public ngraph::pass::MatcherPass { +class ngraph::pass::UnrollTensorIterator: public ngraph::pass::FunctionPass { public: NGRAPH_RTTI_DECLARATION; - UnrollTensorIterator(); + bool run_on_function(std::shared_ptr) override; }; diff --git a/inference-engine/src/transformations/src/transformations/control_flow/unroll_tensor_iterator.cpp b/inference-engine/src/transformations/src/transformations/control_flow/unroll_tensor_iterator.cpp index 45f96fb6797..d374e281a5e 100644 --- a/inference-engine/src/transformations/src/transformations/control_flow/unroll_tensor_iterator.cpp +++ b/inference-engine/src/transformations/src/transformations/control_flow/unroll_tensor_iterator.cpp @@ -15,12 +15,11 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::UnrollTensorIterator, "UnrollTensorIterator", 0); -ngraph::pass::UnrollTensorIterator::UnrollTensorIterator() : MatcherPass() { - auto tensor_iterator = ngraph::pattern::wrap_type(); - ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) { - auto ti = std::dynamic_pointer_cast(m.get_match_root()); - if (!ti) { - return false; +bool ngraph::pass::UnrollTensorIterator::run_on_function(std::shared_ptr f) { + for (const auto& op : f->get_ops()) { + auto ti = std::dynamic_pointer_cast(op); + if (!ti || m_transformation_callback(ti)) { + continue; } const auto function = ti->get_body(); @@ -28,7 +27,7 @@ ngraph::pass::UnrollTensorIterator::UnrollTensorIterator() : MatcherPass() { // negative value means inconsistent TI if (num_iter <= -1) { - return false; + continue; } // Create copies of the TensorIterator body, the number of copies is equal to the number of iterations. 
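
For orientation, this file's hunks convert UnrollTensorIterator from a MatcherPass into a
FunctionPass. A condensed sketch of the resulting control flow (the exact TensorIterator opset
type is elided in this copy of the patch, and the body-copying logic is only summarized):

    bool ngraph::pass::UnrollTensorIterator::run_on_function(std::shared_ptr<ngraph::Function> f) {
        for (const auto& op : f->get_ops()) {
            auto ti = std::dynamic_pointer_cast<ngraph::op::TensorIterator>(op);
            if (!ti || m_transformation_callback(ti)) {
                continue; // not a TensorIterator, or rejected by the plugin callback
            }
            // ... create num_iterations copies of the body, rewire slices and back edges ...
            // then propagate the body sinks (e.g. Assign) up to the parent function,
            // as the closing hunk below does via f->add_sinks(body_func->get_sinks()).
        }
        return true;
    }
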
@@ -183,9 +182,9 @@ ngraph::pass::UnrollTensorIterator::UnrollTensorIterator() : MatcherPass() { } } - return true; - }; - - auto m = std::make_shared(tensor_iterator, "UnrollTensorIterator"); - register_matcher(m, callback); + for (const auto& body_func : body_functions) { + f->add_sinks(body_func->get_sinks()); + } + } + return true; } diff --git a/inference-engine/tests/functional/inference_engine/transformations/low_latency_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/low_latency_test.cpp new file mode 100644 index 00000000000..dc1db93da6f --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/low_latency_test.cpp @@ -0,0 +1,353 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; +using namespace ngraph; + +TEST(TransformationTests, LowLatencyLSTM) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto X = std::make_shared(element::f32, Shape{1, 1, 16}); + auto H_init = std::make_shared(element::f32, Shape{1, 128}); + auto C_init = std::make_shared(element::f32, Shape{1, 128}); + + auto Xi = std::make_shared(element::f32, Shape{1, 1, 16}); + auto H_t = std::make_shared(element::f32, Shape{1, 128}); + auto C_t = std::make_shared(element::f32, Shape{1, 128}); + + // Body + auto axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}); + auto squeeze = std::make_shared(Xi, axis); + + auto w_val = std::vector(512 * 16, 0); + auto r_val = std::vector(512 * 128, 0); + auto b_val = std::vector(512, 0); + auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512, 16}, w_val); + auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512, 128}, r_val); + auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val); + + auto lstm_cell = std::make_shared(squeeze, H_t, C_t, W, R, B, 128); + auto res_1 = std::make_shared(lstm_cell->output(0)); + auto unsqueeze = std::make_shared(lstm_cell->output(0), axis); + auto res_2 = std::make_shared(unsqueeze); + auto res_3 = std::make_shared(lstm_cell->output(1)); + auto body = std::make_shared(OutputVector{res_1, res_2, res_3}, ParameterVector{Xi, H_t, C_t}); + + auto tensor_iterator = std::make_shared(); + tensor_iterator->set_body(body); + tensor_iterator->set_friendly_name("LSTMTensorIterator"); + + tensor_iterator->set_merged_input(C_t, C_init, res_3); + tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 0); + tensor_iterator->set_merged_input(H_t, H_init, res_1); + + auto out0 = tensor_iterator->get_iter_value(res_1, -1); + auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 0); + + auto res_ti_1 = std::make_shared(tensor_iterator->output(1)); + auto res_ti_2 = std::make_shared(tensor_iterator->output(0)); + f = std::make_shared(ngraph::NodeVector{res_ti_1, res_ti_2}, + ngraph::ParameterVector{X, H_init, C_init}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + } + { + auto Xi = std::make_shared(element::f32, Shape{1, 1, 16}); + auto H_t = std::make_shared(element::f32, Shape{1, 128}); + auto C_t = std::make_shared(element::f32, Shape{1, 128}); + + const std::string variable_name_H("LSTMTensorIterator/variable0"); + const std::string 
variable_name_C("LSTMTensorIterator/variable1"); + auto read_value_H = std::make_shared(H_t, variable_name_H); + auto read_value_C = std::make_shared(C_t, variable_name_C); + // Body + auto axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}); + auto squeeze = std::make_shared(Xi, axis); + + auto w_val = std::vector(512 * 16, 0); + auto r_val = std::vector(512 * 128, 0); + auto b_val = std::vector(512, 0); + auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512, 16}, w_val); + auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512, 128}, r_val); + auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val); + + auto lstm_cell = std::make_shared(squeeze, read_value_H, read_value_C, W, R, B, 128); + auto assign_H = std::make_shared(lstm_cell->output(0), variable_name_H); + auto assign_C = std::make_shared(lstm_cell->output(1), variable_name_C); + auto res_1 = std::make_shared(lstm_cell->output(0)); + auto unsqueeze = std::make_shared(lstm_cell->output(0), axis); + auto res_2 = std::make_shared(unsqueeze); + f_ref = std::make_shared(OutputVector{unsqueeze, res_1}, ParameterVector{Xi, H_t, C_t}); + f_ref->add_sinks({assign_C, assign_H}); + assign_H->add_control_dependency(read_value_H); + assign_C->add_control_dependency(read_value_C); + } + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, LowLatencyGRU) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto X = std::make_shared(element::f32, Shape{1, 1, 16}); + auto Y = std::make_shared(element::f32, Shape{1, 128}); + + auto Xi = std::make_shared(element::f32, Shape{1, 1, 16}); + auto Yi = std::make_shared(element::f32, Shape{1, 128}); + + // Body + auto axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}); + auto squeeze = std::make_shared(Xi, axis); + + auto w_val = std::vector(384 * 16, 0); + auto r_val = std::vector(384 * 128, 0); + auto b_val = std::vector(384, 0); + auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{384, 16}, w_val); + auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{384, 128}, r_val); + auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{384}, b_val); + + auto gru_cell = std::make_shared(squeeze, Yi, W, R, B, 128); + auto res_1 = std::make_shared(gru_cell); + auto unsqueeze = std::make_shared(gru_cell, axis); + auto res_2 = std::make_shared(unsqueeze); + auto body = std::make_shared(OutputVector{res_1, res_2}, ParameterVector{Xi, Yi}); + + auto tensor_iterator = std::make_shared(); + tensor_iterator->set_body(body); + + tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 0); + tensor_iterator->set_merged_input(Yi, Y, res_1); + + auto out0 = tensor_iterator->get_iter_value(res_1, -1); + auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 0); + + auto res_ti_1 = std::make_shared(tensor_iterator->output(1)); + f = std::make_shared(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + + ASSERT_NO_THROW(check_rt_info(f)); + } + { + auto Xi = std::make_shared(element::f32, Shape{1, 1, 16}); + auto H_t = std::make_shared(element::f32, Shape{1, 128}); + + const std::string variable_name_H("GRUTensorIterator/variable0"); + auto read_value_H = std::make_shared(H_t, 
variable_name_H); + // Body + auto axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}); + auto squeeze = std::make_shared(Xi, axis); + + auto w_val = std::vector(384 * 16, 0); + auto r_val = std::vector(384 * 128, 0); + auto b_val = std::vector(384, 0); + auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{384, 16}, w_val); + auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{384, 128}, r_val); + auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{384}, b_val); + + auto rnn_cell = std::make_shared(squeeze, read_value_H, W, R, B, 128); + auto assign_H = std::make_shared(rnn_cell->output(0), variable_name_H); + auto res_1 = std::make_shared(assign_H); + auto unsqueeze = std::make_shared(rnn_cell->output(0), axis); + auto res_2 = std::make_shared(unsqueeze); + f_ref = std::make_shared(OutputVector{unsqueeze}, ParameterVector{Xi, H_t}); + f_ref->add_sinks({assign_H}); + assign_H->add_control_dependency(read_value_H); + } + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, LowLatencyRNN) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto X = std::make_shared(element::f32, Shape{1, 1, 16}); + auto Y = std::make_shared(element::f32, Shape{1, 128}); + + auto Xi = std::make_shared(element::f32, Shape{1, 1, 16}); + auto Yi = std::make_shared(element::f32, Shape{1, 128}); + + // Body + auto axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}); + auto squeeze = std::make_shared(Xi, axis); + + auto w_val = std::vector(128 * 16, 0); + auto r_val = std::vector(128 * 128, 0); + auto b_val = std::vector(128, 0); + auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{128, 16}, w_val); + auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{128, 128}, r_val); + auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{128}, b_val); + + auto rnn_cell = std::make_shared(squeeze, Yi, W, R, B, 128); + auto res_1 = std::make_shared(rnn_cell); + auto unsqueeze = std::make_shared(rnn_cell, axis); + auto res_2 = std::make_shared(unsqueeze); + auto body = std::make_shared(OutputVector{res_1, res_2}, ParameterVector{Xi, + Yi}); + + auto tensor_iterator = std::make_shared(); + tensor_iterator->set_body(body); + + tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 0); + tensor_iterator->set_merged_input(Yi, Y, res_1); + + auto out0 = tensor_iterator->get_iter_value(res_1, -1); + auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 0); + + auto res_ti_1 = std::make_shared(tensor_iterator->output(1)); + f = std::make_shared(ngraph::NodeVector{res_ti_1}, ngraph::ParameterVector{X, Y}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + + ASSERT_NO_THROW(check_rt_info(f)); + } + { + auto Xi = std::make_shared(element::f32, Shape{1, 1, 16}); + auto H_t = std::make_shared(element::f32, Shape{1, 128}); + + const std::string variable_name_H("RNNTensorIterator/variable0"); + auto read_value_H = std::make_shared(H_t, variable_name_H); + // Body + auto axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}); + auto squeeze = std::make_shared(Xi, axis); + + auto w_val = std::vector(128 * 16, 0); + auto r_val = std::vector(128 * 128, 0); + auto b_val = std::vector(128, 0); + auto W = 
ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{128, 16}, w_val); + auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{128, 128}, r_val); + auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{128}, b_val); + + auto rnn_cell = std::make_shared(squeeze, read_value_H, W, R, B, 128); + auto assign_H = std::make_shared(rnn_cell->output(0), variable_name_H); + auto res_1 = std::make_shared(assign_H); + auto unsqueeze = std::make_shared(rnn_cell->output(0), axis); + auto res_2 = std::make_shared(unsqueeze); + f_ref = std::make_shared(OutputVector{unsqueeze}, ParameterVector{Xi, H_t}); + f_ref->add_sinks({assign_H}); + assign_H->add_control_dependency(read_value_H); + } + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} + +TEST(TransformationTests, LowLatencyLSTMReshape) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto X = std::make_shared(element::f32, Shape{2, 1, 16}); + auto H = std::make_shared(element::f32, Shape{1, 128}); + auto C = std::make_shared(element::f32, Shape{1, 128}); + + auto Xi = std::make_shared(element::f32, Shape{1, 1, 16}); + auto H_t = std::make_shared(element::f32, Shape{1, 128}); + auto C_t = std::make_shared(element::f32, Shape{1, 128}); + + // Body + auto axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}); + auto squeeze = std::make_shared(Xi, axis); + + auto w_val = std::vector(512 * 16, 0); + auto r_val = std::vector(512 * 128, 0); + auto b_val = std::vector(512, 0); + auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512, 16}, w_val); + auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512, 128}, r_val); + auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val); + + auto lstm_cell = std::make_shared(squeeze, H_t, C_t, W, R, B, 128); + auto res_1 = std::make_shared(lstm_cell->output(0)); + auto unsqueeze = std::make_shared(lstm_cell, axis); + auto res_2 = std::make_shared(unsqueeze); + auto res_3 = std::make_shared(lstm_cell->output(1)); + auto body = std::make_shared(OutputVector{res_1, res_2, res_3}, + ParameterVector{Xi, H_t, C_t}); + + auto tensor_iterator = std::make_shared(); + tensor_iterator->set_body(body); + + tensor_iterator->set_merged_input(C_t, C, res_3); + tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, -1, 0); + tensor_iterator->set_merged_input(H_t, H, res_1); + + auto out0 = tensor_iterator->get_iter_value(res_1, -1); + auto out1 = tensor_iterator->get_concatenated_slices(res_2, 0, 1, 1, -1, 0); + + auto res_ti_1 = std::make_shared(tensor_iterator->output(1)); + auto res_ti_2 = std::make_shared(tensor_iterator->output(0)); + f = std::make_shared(ngraph::NodeVector{res_ti_1, res_ti_2}, ngraph::ParameterVector{X, H, + C}); + + // Reshape + // change the number of iteration of TI. 
2 -> 1 + auto new_X = std::make_shared(element::f32, Shape{1, 1, 16}); + f->replace_parameter(0, new_X); + f->validate_nodes_and_infer_types(); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + } + { + auto Xi = std::make_shared(element::f32, Shape{1, 1, 16}); + auto H_t = std::make_shared(element::f32, Shape{1, 128}); + auto C_t = std::make_shared(element::f32, Shape{1, 128}); + + const std::string variable_name_H("LSTMTensorIterator/variable0"); + const std::string variable_name_C("LSTMTensorIterator/variable1"); + auto read_value_H = std::make_shared(H_t, variable_name_H); + auto read_value_C = std::make_shared(C_t, variable_name_C); + // Body + auto axis = ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}); + auto squeeze = std::make_shared(Xi, axis); + + auto w_val = std::vector(512 * 16, 0); + auto r_val = std::vector(512 * 128, 0); + auto b_val = std::vector(512, 0); + auto W = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512, 16}, w_val); + auto R = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512, 128}, r_val); + auto B = ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{512}, b_val); + + auto lstm_cell = std::make_shared(squeeze, read_value_H, read_value_C, W, R, B, 128); + auto assign_H = std::make_shared(lstm_cell->output(0), variable_name_H); + auto assign_C = std::make_shared(lstm_cell->output(1), variable_name_C); + auto res_1 = std::make_shared(lstm_cell->output(0)); + auto unsqueeze = std::make_shared(lstm_cell->output(0), axis); + auto res_2 = std::make_shared(unsqueeze); + f_ref = std::make_shared(OutputVector{unsqueeze, res_1}, ParameterVector{Xi, H_t, C_t}); + f_ref->add_sinks({assign_C, assign_H}); + assign_H->add_control_dependency(read_value_H); + assign_C->add_control_dependency(read_value_C); + } + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} diff --git a/inference-engine/tests/functional/plugin/shared/include/subgraph_tests/basic_lstm.hpp b/inference-engine/tests/functional/plugin/shared/include/subgraph_tests/basic_lstm.hpp index b6b68e142ca..daa8ca358fe 100644 --- a/inference-engine/tests/functional/plugin/shared/include/subgraph_tests/basic_lstm.hpp +++ b/inference-engine/tests/functional/plugin/shared/include/subgraph_tests/basic_lstm.hpp @@ -29,11 +29,11 @@ public: void Run() override; protected: + size_t hidden_size; + std::vector hidden_memory_init; + std::vector cell_memory_init; void SetUp() override; std::vector> CalculateRefs() override; - -private: - std::shared_ptr CreateGraphWithUnrolledTI(); }; } // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/plugin/shared/include/subgraph_tests/memory_LSTMCell.hpp b/inference-engine/tests/functional/plugin/shared/include/subgraph_tests/memory_LSTMCell.hpp index ba1252f440c..e30f62a81c7 100644 --- a/inference-engine/tests/functional/plugin/shared/include/subgraph_tests/memory_LSTMCell.hpp +++ b/inference-engine/tests/functional/plugin/shared/include/subgraph_tests/memory_LSTMCell.hpp @@ -20,6 +20,7 @@ class MemoryLSTMCellTest : public LayerTestsUtils::LayerTestsCommon, private: // you have to Unroll TI manually and remove memory until ngraph supports it void switchToNgraphFriendlyModel(); + void CreatePureTensorIteratorModel(); // since we are switching models, we need to generate and save weights, biases, and inputs in SetUp std::vector input_bias; std::vector
input_weights; @@ -31,6 +32,7 @@ private: protected: void SetUp() override; void Run() override; + void RunLowLatency(bool regular_api = false); public: static std::string getTestCaseName(const testing::TestParamInfo &obj); }; diff --git a/inference-engine/tests/functional/plugin/shared/include/subgraph_tests/multiple_LSTMCell.hpp b/inference-engine/tests/functional/plugin/shared/include/subgraph_tests/multiple_LSTMCell.hpp index 9e03de02d1a..16b6d7a867e 100644 --- a/inference-engine/tests/functional/plugin/shared/include/subgraph_tests/multiple_LSTMCell.hpp +++ b/inference-engine/tests/functional/plugin/shared/include/subgraph_tests/multiple_LSTMCell.hpp @@ -20,6 +20,7 @@ class MultipleLSTMCellTest : public LayerTestsUtils::LayerTestsCommon, private: // you have to Unroll TI manually and remove memory until ngraph supports it void switchToNgraphFriendlyModel(); + void CreatePureTensorIteratorModel(); // since we are switching models, we need to generate and save weights, biases, and inputs in SetUp size_t hiddenSize; std::vector input_bias; @@ -33,6 +34,7 @@ private: protected: void SetUp() override; void Run() override; + void RunLowLatency(bool regular_api = false); public: static std::string getTestCaseName(const testing::TestParamInfo &obj); }; diff --git a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/basic_lstm.cpp b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/basic_lstm.cpp index 0d30b7ec84c..b06cee1c4e0 100644 --- a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/basic_lstm.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/basic_lstm.cpp @@ -8,12 +8,15 @@ #include #include #include +#include #include "common_test_utils/common_utils.hpp" #include "functional_test_utils/blob_utils.hpp" #include "functional_test_utils/layer_test_utils.hpp" #include "functional_test_utils/plugin_cache.hpp" #include "ngraph_functions/pass/convert_prc.hpp" +#include "transformations/control_flow/unroll_tensor_iterator.hpp" +#include "transformations/common_optimizations/low_latency.hpp" #include "subgraph_tests/basic_lstm.hpp" @@ -47,7 +50,7 @@ void Basic_LSTM_S::SetUp() { auto params = ngraph::builder::makeParams(ngPrc, { {1, 490} }); - const size_t hidden_size = 118; + hidden_size = 118; const size_t batch_size = 1; outPrc = InferenceEngine::Precision::FP32; @@ -60,10 +63,13 @@ void Basic_LSTM_S::SetUp() { auto reshape1_shape = reshape1->output(0).get_shape(); auto H_init = ngraph::builder::makeConstant(ngPrc, { batch_size, hidden_size }, {}, true); auto C_init = ngraph::builder::makeConstant(ngPrc, { batch_size, hidden_size }, {}, true); + hidden_memory_init = std::static_pointer_cast(H_init)->cast_vector(); + cell_memory_init = std::static_pointer_cast(C_init)->cast_vector(); auto H_t = std::make_shared(ngPrc, ngraph::Shape{ batch_size, hidden_size }); auto C_t = std::make_shared(ngPrc, ngraph::Shape{ batch_size, hidden_size }); - + H_t->set_friendly_name("hidden_state_1"); + C_t->set_friendly_name("cell_state_1"); //Body auto X = std::make_shared(ngPrc, ngraph::Shape{ batch_size, 1, reshape1_shape[2] }); auto weightsNode = ngraph::builder::makeConstant(ngPrc, { 4 * hidden_size, reshape1_shape[2] }, {}, true); @@ -112,60 +118,12 @@ void Basic_LSTM_S::Run() { Compare(referenceOutputs, actualOutputs); } -std::shared_ptr Basic_LSTM_S::CreateGraphWithUnrolledTI() { - InferenceEngine::Precision netPrecision; - netPrecision = std::get<0>(this->GetParam()); - auto ngPrc =
FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); - - auto params = ngraph::builder::makeParams(ngPrc, { {1, 490} }); - - const size_t hidden_size = 118; - const size_t batch_size = 1; - const size_t iterations = 10; - - outPrc = InferenceEngine::Precision::FP32; - - //Reshape_1 [1,490] -> [1, 10, 49] - std::vector outFormShapes1 = { batch_size, iterations, 49 }; - auto pattern1 = std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{ 3 }, outFormShapes1); - auto reshape1 = std::make_shared(params[0], pattern1, false); - - std::vector axis_shape = { 1 }; - auto axis = std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{ }, axis_shape); - auto split1 = std::make_shared(reshape1, axis, iterations); - - ngraph::Output H[iterations + 1]; - ngraph::Output C[iterations + 1]; - std::shared_ptr lstm[iterations]; - H[0] = ngraph::builder::makeConstant(ngPrc, { batch_size, hidden_size }, {}, true); - C[0] = ngraph::builder::makeConstant(ngPrc, { batch_size, hidden_size }, {}, true); - auto reshape1_shape = reshape1->output(0).get_shape(); - auto weightsNode = ngraph::builder::makeConstant(ngPrc, { 4 * hidden_size, reshape1_shape[2] }, {}, true); - auto reccurrenceWeightsNode = ngraph::builder::makeConstant(ngPrc, { 4 * hidden_size, hidden_size }, {}, true); - - outFormShapes1 = { batch_size, reshape1_shape[2] }; - auto constantX = std::make_shared(ngraph::element::i64, ngraph::Shape{ 2 }, outFormShapes1); - - for (size_t i = 0; i < iterations; ++i) { - auto X = split1->output(i); - lstm[i] = std::make_shared(std::make_shared(X, constantX, false), - H[i], C[i], - weightsNode, reccurrenceWeightsNode, hidden_size); - - H[i+1] = lstm[i]->output(0); - C[i+1] = lstm[i]->output(1); - } - - const size_t output_size = 12; - auto fc1 = ngraph::builder::makeFullyConnected(H[iterations], ngPrc, output_size, true, { hidden_size, output_size }, { 1 }, { 1 }); - - ngraph::ResultVector results{ std::make_shared(fc1) }; - return std::make_shared(results, params, "Basic_LSTM_S_Ref"); -} - std::vector> Basic_LSTM_S::CalculateRefs() { //For now TensorIterator is not implemented in ngraph interpreter so it is needed to validate with another reference - auto reference_model = CreateGraphWithUnrolledTI(); + auto reference_model = ngraph::clone_function(*function); + ngraph::pass::Manager manager; + manager.register_pass(); + manager.run_passes(reference_model); auto refCnnNetwork = InferenceEngine::CNNNetwork{ reference_model }; auto refExecutableNetwork = core->LoadNetwork(refCnnNetwork, targetDevice); @@ -215,4 +173,51 @@ TEST_P(Basic_LSTM_S, CompareWithRefImpl) { Run(); }; +TEST_P(Basic_LSTM_S, CompareWithRefImpl_LowLatencyTransformation) { + InferenceEngine::TensorDesc state_description(InferenceEngine::Precision::FP32, + InferenceEngine::SizeVector({1, hidden_size}), + InferenceEngine::Layout::NC); + // Reshape + auto params = ngraph::builder::makeParams(function->get_parameters().at(0)->get_element_type(), { {1, 49} }); + function->replace_parameter(0, params[0]); + + // todo: it is better to modify the model -> use ShapeOf() and Gather() + std::vector outFormShapes1 = { 1, 1, 49 }; + auto pattern1 = std::make_shared(ngraph::element::Type_t::i64, ngraph::Shape{3}, outFormShapes1); + auto param_target_inputs = function->get_parameters().at(0)->output(0).get_target_inputs(); + + // replace hardcoded shape + for (const auto& target : param_target_inputs.begin()->get_node()->input(1).get_source_output().get_target_inputs()) { + target.replace_source_output(pattern1); + } + 
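+ // At this point every consumer of the old shape constant reads the new {1, 1, 49}
+ // pattern, so the first Reshape now emits a single-step sequence and the
+ // TensorIterator runs exactly one iteration.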
function->validate_nodes_and_infer_types(); + + // Calculate References for the network before transformation passes + auto referenceOutputs = CalculateRefs(); + + // Apply LowLatency and UnrollTensorIterator transformations + ngraph::pass::Manager manager; + manager.register_pass(); // LowLatency enables UnrollTI + manager.run_passes(function); + LoadNetwork(); + auto states = executableNetwork.QueryState(); + for (auto& state : states) { + auto name = state.GetName(); + if (name.find("cell_state_1") != std::string::npos) { + auto blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description, + cell_memory_init.data(), cell_memory_init.size()); + state.SetState(blob); + } else if (name.find("hidden_state_1") != std::string::npos) { + auto blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description, + hidden_memory_init.data(), hidden_memory_init.size()); + state.SetState(blob); + } else { + GTEST_FAIL() << "unknown memory state"; + } + } + // Run and compare + Infer(); + const auto& actualOutputs = GetOutputs(); + Compare(referenceOutputs, actualOutputs); +}; } // namespace LayerTestsDefinitions diff --git a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp index 051aa47e959..dcbeb7c68d3 100644 --- a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/memory_LSTMCell.cpp @@ -10,6 +10,7 @@ #include "ie_core.hpp" +#include "ie_transformations.hpp" #include "common_test_utils/common_utils.hpp" #include "functional_test_utils/blob_utils.hpp" #include "functional_test_utils/precision_utils.hpp" @@ -19,6 +20,8 @@ #include "ngraph_functions/builders.hpp" #include +#include "transformations/control_flow/unroll_tensor_iterator.hpp" +#include "transformations/common_optimizations/low_latency.hpp" #include "subgraph_tests/memory_LSTMCell.hpp" namespace SubgraphTestsDefinitions { @@ -194,6 +197,79 @@ namespace SubgraphTestsDefinitions { function = std::make_shared(final_reshape, input_parameter, "TI_unrolled_without_memory"); } + void MemoryLSTMCellTest::CreatePureTensorIteratorModel() { + InferenceEngine::Precision netPrecision; + std::map config; + size_t inputSize; + std::tie(targetDevice, netPrecision, inputSize, hiddenSize, config) = this->GetParam(); + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + + std::vector input_dims { 1, inputSize }; + std::vector squeeze_axes {0}; + std::vector hidden_memory_dims {1, hiddenSize}; + std::vector cell_memory_dims {1, hiddenSize}; + + auto input_parameter = ngraph::builder::makeParams(ngPrc, {input_dims}); + + auto input_add_const = ngraph::builder::makeConstant(ngPrc, input_dims, input_bias); + auto add = ngraph::builder::makeEltwise(input_parameter[0], input_add_const, ngraph::helpers::EltwiseTypes::ADD); + + auto input_mul_const = ngraph::builder::makeConstant(ngPrc, input_dims, input_weights); + auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY); + + auto unsqueeze_input_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); + auto unsqueeze_input = std::make_shared(mul, unsqueeze_input_const); + + auto permute_in_params = std::make_shared(ngraph::element::i64, ngraph::Shape{3}, ngraph::Shape{{1, 0, 2}}); + auto permute_in = std::make_shared(unsqueeze_input, permute_in_params); + + auto 
cell_memory_constant = ngraph::builder::makeConstant(ngPrc, cell_memory_dims, cell_memory_init); + + auto hidden_memory_constant = ngraph::builder::makeConstant(ngPrc, hidden_memory_dims, hidden_memory_init); + + // Body - inputs + auto X = std::make_shared(ngPrc, ngraph::Shape{1, 1, inputSize}); + auto H_t = std::make_shared(ngPrc, ngraph::Shape{1, hiddenSize}); + auto C_t = std::make_shared(ngPrc, ngraph::Shape{1, hiddenSize}); + H_t->set_friendly_name("hidden_state_1"); + C_t->set_friendly_name("cell_state_1"); + // Body - layers + auto squeeze_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); + auto squeeze = std::make_shared(X, squeeze_const); + + auto weightsNode = ngraph::builder::makeConstant(ngPrc, { 4 * hiddenSize, inputSize }, weights_vals); + auto reccurrenceWeightsNode = ngraph::builder::makeConstant(ngPrc, { 4 * hiddenSize, hiddenSize }, reccurrenceWeights_vals); + auto biasNode = ngraph::builder::makeConstant(ngPrc, {4 * hiddenSize}, bias_vals); + auto lstm = std::make_shared(squeeze, H_t, C_t, weightsNode, reccurrenceWeightsNode, biasNode, hiddenSize); + + auto unsqueeze_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); + auto unsqueeze = std::make_shared(lstm->output(0), unsqueeze_const); + // body - outputs + auto H_o = lstm->output(0); + auto C_o = lstm->output(1); + auto unsqueeze_o = unsqueeze->output(0); + + auto body = std::make_shared(ngraph::OutputVector{unsqueeze_o, H_o, C_o}, ngraph::ParameterVector {X, H_t, C_t}); + // TI construction + auto tensor_iterator = std::make_shared(); + tensor_iterator->set_body(body); + tensor_iterator->set_sliced_input(X, permute_in, 0, 1, 1, -1, 0); + tensor_iterator->set_merged_input(H_t, hidden_memory_constant, H_o); + tensor_iterator->set_merged_input(C_t, cell_memory_constant, C_o); + + auto out_unsqueeze = tensor_iterator->get_iter_value(unsqueeze_o, -1); + auto out_hidden = tensor_iterator->get_iter_value(H_o, -1); + auto out_cell = tensor_iterator->get_iter_value(C_o, -1); + + out_hidden.get_tensor().set_element_type(ngPrc); + out_cell.get_tensor().set_element_type(ngPrc); + + auto final_reshape_pattern = std::make_shared(ngraph::element::i64, ngraph::Shape{4}, std::vector({1, 1, 1, hiddenSize})); + auto final_reshape = std::make_shared(out_unsqueeze, final_reshape_pattern, false); + + function = std::make_shared(final_reshape, input_parameter, "PureTI"); + } + void MemoryLSTMCellTest::Run() { SKIP_IF_CURRENT_TEST_IS_DISABLED() @@ -218,7 +294,55 @@ namespace SubgraphTestsDefinitions { Validate(); } + void MemoryLSTMCellTest::RunLowLatency(bool regular_api) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + + CreatePureTensorIteratorModel(); + if (regular_api) { + cnnNetwork = InferenceEngine::CNNNetwork{function}; + InferenceEngine::LowLatency(cnnNetwork); + ConfigureNetwork(); + executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); + } else { + // Apply LowLatency (insert Assigns/ReadValues) and UnrollTensorIterator + ngraph::pass::Manager manager; + manager.register_pass(); // LowLatency enables UnrollTI + manager.run_passes(function); + LoadNetwork(); + } + auto states = executableNetwork.QueryState(); + for (auto& state : states) { + auto name = state.GetName(); + if (name.find("cell_state_1") != std::string::npos) { + auto blob = FuncTestUtils::createAndFillBlobWithFloatArray(state.GetLastState()->getTensorDesc(), + cell_memory_init.data(), cell_memory_init.size()); + state.SetState(blob); + } else if (name.find("hidden_state_1") != 
std::string::npos) { + auto blob = FuncTestUtils::createAndFillBlobWithFloatArray(state.GetLastState()->getTensorDesc(), + hidden_memory_init.data(), hidden_memory_init.size()); + state.SetState(blob); + } else { + GTEST_FAIL() << "unknown memory state"; + } + } + Infer(); + + CreatePureTensorIteratorModel(); + ngraph::pass::Manager manager_2; + manager_2.register_pass(); + manager_2.run_passes(function); + Validate(); + } + TEST_P(MemoryLSTMCellTest, CompareWithRefs) { Run(); }; + + TEST_P(MemoryLSTMCellTest, CompareWithRefs_LowLatencyTransformation) { + RunLowLatency(); + }; + + TEST_P(MemoryLSTMCellTest, CompareWithRefs_LowLatencyRegularAPITransformation) { + RunLowLatency(true); + }; } // namespace SubgraphTestsDefinitions diff --git a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp index 22532965cf1..1df0e7baf26 100644 --- a/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp +++ b/inference-engine/tests/functional/plugin/shared/src/subgraph_tests/multiple_LSTMCell.cpp @@ -19,6 +19,10 @@ #include "ngraph_functions/builders.hpp" #include +#include +#include +#include "transformations/control_flow/unroll_tensor_iterator.hpp" +#include "transformations/common_optimizations/low_latency.hpp" #include "subgraph_tests/multiple_LSTMCell.hpp" namespace SubgraphTestsDefinitions { @@ -280,6 +284,131 @@ void MultipleLSTMCellTest::switchToNgraphFriendlyModel() { function = std::make_shared(final_reshape, input_parameter, "TI_unrolled_without_memory"); } +void MultipleLSTMCellTest::CreatePureTensorIteratorModel() { + InferenceEngine::Precision netPrecision; + std::map config; + size_t inputSize; + std::tie(targetDevice, netPrecision, inputSize, hiddenSize, config) = this->GetParam(); + auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); + + std::vector input_dims { 1, inputSize }; + std::vector squeeze_axes {0}; + std::vector hidden_memory_dims {1, hiddenSize}; + std::vector cell_memory_dims {1, hiddenSize}; + + auto input_parameter = ngraph::builder::makeParams(ngPrc, {input_dims}); + + auto input_add_const = ngraph::builder::makeConstant(ngPrc, input_dims, input_bias); + auto add = ngraph::builder::makeEltwise(input_parameter[0], input_add_const, ngraph::helpers::EltwiseTypes::ADD); + + auto input_mul_const = ngraph::builder::makeConstant(ngPrc, input_dims, input_weights); + auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY); + + auto unsqueeze_input_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); + auto unsqueeze_input = std::make_shared(mul, unsqueeze_input_const); + + auto permute_in_params = std::make_shared(ngraph::element::i64, ngraph::Shape{3}, ngraph::Shape{{1, 0, 2}}); + auto permute_in = std::make_shared(unsqueeze_input, permute_in_params); + + auto cell_memory_constant = ngraph::builder::makeConstant(ngPrc, cell_memory_dims, cell_memory_init); + + auto hidden_memory_constant = ngraph::builder::makeConstant(ngPrc, hidden_memory_dims, hidden_memory_init); + + // Body - inputs + auto X = std::make_shared(ngPrc, ngraph::Shape{1, 1, inputSize}); + auto H_t = std::make_shared(ngPrc, ngraph::Shape{1, hiddenSize}); + auto C_t = std::make_shared(ngPrc, ngraph::Shape{1, hiddenSize}); + H_t->set_friendly_name("hidden_state_1"); + C_t->set_friendly_name("cell_state_1"); + // Body - layers + auto squeeze_const = 
std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); + auto squeeze = std::make_shared(X, squeeze_const); + + auto weightsNode = ngraph::builder::makeConstant(ngPrc, { 4 * hiddenSize, inputSize }, weights_vals); + auto reccurrenceWeightsNode = ngraph::builder::makeConstant(ngPrc, { 4 * hiddenSize, hiddenSize }, reccurrenceWeights_vals); + auto biasNode = ngraph::builder::makeConstant(ngPrc, {4 * hiddenSize}, bias_vals); + auto lstm = std::make_shared(squeeze, H_t, C_t, weightsNode, reccurrenceWeightsNode, biasNode, hiddenSize); + + auto unsqueeze_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); + auto unsqueeze = std::make_shared(lstm->output(0), unsqueeze_const); + // body - outputs + auto H_o = lstm->output(0); + auto C_o = lstm->output(1); + auto unsqueeze_o = unsqueeze->output(0); + + auto body = std::make_shared(ngraph::OutputVector{unsqueeze_o, H_o, C_o}, ngraph::ParameterVector {X, H_t, C_t}); + // TI construction + auto tensor_iterator = std::make_shared(); + tensor_iterator->set_body(body); + tensor_iterator->set_sliced_input(X, permute_in, 0, 1, 1, -1, 0); + tensor_iterator->set_merged_input(H_t, hidden_memory_constant, H_o); + tensor_iterator->set_merged_input(C_t, cell_memory_constant, C_o); + + auto out_unsqueeze = tensor_iterator->get_iter_value(unsqueeze_o, -1); + auto out_hidden = tensor_iterator->get_iter_value(H_o, -1); + auto out_cell = tensor_iterator->get_iter_value(C_o, -1); + + out_hidden.get_tensor().set_element_type(ngPrc); + out_cell.get_tensor().set_element_type(ngPrc); + tensor_iterator->validate_and_infer_types(); + + auto first_reshape_pattern = std::make_shared(ngraph::element::i64, + ngraph::Shape{4}, std::vector({1, 1, 1, hiddenSize})); + auto first_reshape = std::make_shared(out_unsqueeze, first_reshape_pattern, false); + // End of TI 1 + + auto inbetween_squeeze_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); + auto inbetween_squeeze = std::make_shared(first_reshape, inbetween_squeeze_const); + + // Second TI + auto cell_memory_2_constant = ngraph::builder::makeConstant(ngPrc, cell_memory_dims, cell_memory_init); + + auto hidden_memory_2_constant = ngraph::builder::makeConstant(ngPrc, hidden_memory_dims, hidden_memory_init); + + // Body - inputs + auto X_2 = std::make_shared(ngPrc, ngraph::Shape{1, 1, hiddenSize}); + auto H_t_2 = std::make_shared(ngPrc, ngraph::Shape{1, hiddenSize}); + auto C_t_2 = std::make_shared(ngPrc, ngraph::Shape{1, hiddenSize}); + H_t_2->set_friendly_name("hidden_state_2"); + C_t_2->set_friendly_name("cell_state_2"); + // Body - layers + auto squeeze_2_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); + auto squeeze_2 = std::make_shared(X_2, squeeze_2_const); + + auto weightsNode_2 = ngraph::builder::makeConstant(ngPrc, { 4 * hiddenSize, hiddenSize }, weights_2_vals); + auto reccurrenceWeightsNode_2 = ngraph::builder::makeConstant(ngPrc, { 4 * hiddenSize, hiddenSize }, reccurrenceWeights_vals); + auto biasNode_2 = ngraph::builder::makeConstant(ngPrc, {4 * hiddenSize}, bias_vals); + auto lstm_2 = std::make_shared(squeeze_2, H_t_2, C_t_2, weightsNode_2, reccurrenceWeightsNode_2, biasNode_2, hiddenSize); + + auto unsqueeze_2_const = std::make_shared(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes); + auto unsqueeze_2 = std::make_shared(lstm_2->output(0), unsqueeze_2_const); + // body - outputs + auto H_o_2 = lstm_2->output(0); + auto C_o_2 = lstm_2->output(1); + auto unsqueeze_o_2 = unsqueeze_2->output(0); + + auto 
body_2 = std::make_shared(ngraph::OutputVector{unsqueeze_o_2, H_o_2, C_o_2}, ngraph::ParameterVector {X_2, H_t_2, C_t_2}); + // TI construction + auto tensor_iterator_2 = std::make_shared(); + tensor_iterator_2->set_body(body_2); + tensor_iterator_2->set_sliced_input(X_2, inbetween_squeeze, 0, 1, 1, -1, 0); + tensor_iterator_2->set_merged_input(H_t_2, hidden_memory_2_constant, H_o_2); + tensor_iterator_2->set_merged_input(C_t_2, cell_memory_2_constant, C_o_2); + + auto out_unsqueeze_2 = tensor_iterator_2->get_iter_value(unsqueeze_o_2, -1); + auto out_hidden_2 = tensor_iterator_2->get_iter_value(H_o_2, -1); + auto out_cell_2 = tensor_iterator_2->get_iter_value(C_o_2, -1); + + out_hidden_2.get_tensor().set_element_type(ngPrc); + out_cell_2.get_tensor().set_element_type(ngPrc); + tensor_iterator_2->validate_and_infer_types(); + auto final_reshape_pattern = std::make_shared(ngraph::element::i64, + ngraph::Shape{4}, std::vector({1, 1, 1, hiddenSize})); + auto final_reshape = std::make_shared(out_unsqueeze_2, final_reshape_pattern, false); + + function = std::make_shared(final_reshape, input_parameter, "PureTI"); +} + void MultipleLSTMCellTest::Run() { SKIP_IF_CURRENT_TEST_IS_DISABLED() InferenceEngine::TensorDesc state_description(InferenceEngine::Precision::FP32, @@ -314,7 +443,68 @@ void MultipleLSTMCellTest::Run() { Validate(); } +void MultipleLSTMCellTest::RunLowLatency(bool regular_api) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + InferenceEngine::TensorDesc state_description(InferenceEngine::Precision::FP32, + InferenceEngine::SizeVector({1, hiddenSize}), + InferenceEngine::Layout::NC); + // Calculate values after LowLatency transformation + CreatePureTensorIteratorModel(); + if (regular_api) { + cnnNetwork = InferenceEngine::CNNNetwork{function}; + InferenceEngine::LowLatency(cnnNetwork); + ConfigureNetwork(); + executableNetwork = core->LoadNetwork(cnnNetwork, targetDevice, configuration); + } else { + function->validate_nodes_and_infer_types(); + // Apply LowLatency (insert Assigns/ReadValues) and UnrollTensorIterator + ngraph::pass::Manager manager; + manager.register_pass(); // LowLatency enables UnrollTI + manager.run_passes(function); + LoadNetwork(); + } + auto states = executableNetwork.QueryState(); + for (auto& state : states) { + auto name = state.GetName(); + if (name.find("cell_state_1") != std::string::npos) { + auto blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description, + cell_memory_init.data(), cell_memory_init.size()); + state.SetState(blob); + } else if (name.find("hidden_state_1") != std::string::npos) { + auto blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description, + hidden_memory_init.data(), hidden_memory_init.size()); + state.SetState(blob); + } else if (name.find("cell_state_2") != std::string::npos) { + auto blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description, + cell_memory_init.data(), cell_memory_init.size()); + state.SetState(blob); + } else if (name.find("hidden_state_2") != std::string::npos) { + auto blob = FuncTestUtils::createAndFillBlobWithFloatArray(state_description, + hidden_memory_init.data(), hidden_memory_init.size()); + state.SetState(blob); + } else { + GTEST_FAIL() << "unknown memory state"; + } + } + Infer(); + + // Calculate ref values for Unrolled TI + CreatePureTensorIteratorModel(); + ngraph::pass::Manager manager_2; + manager_2.register_pass(); + manager_2.run_passes(function); + Validate(); +} + TEST_P(MultipleLSTMCellTest, CompareWithRefs) { Run(); }; + +TEST_P(MultipleLSTMCellTest, 
CompareWithRefs_LowLatencyTransformation) { + RunLowLatency(); +}; + +TEST_P(MultipleLSTMCellTest, CompareWithRefs_LowLatencyRegularAPITransformation) { + RunLowLatency(true); +}; } // namespace SubgraphTestsDefinitions diff --git a/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp b/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp index aba68380be7..d24f45b52c1 100644 --- a/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp +++ b/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp @@ -84,12 +84,18 @@ std::pair compare_functions( * - Do not check nodes attributes (requires visitor mechanism to be completed) */ - const auto f1_results = f1->get_results(); - const auto f2_results = f2->get_results(); + const auto& f1_results = f1->get_results(); + const auto& f2_results = f2->get_results(); if (f1_results.size() != f2_results.size()) { return { false, "Number of results is different: " + std::to_string(f1_results.size()) + " and " + std::to_string(f2_results.size()) }; } + const auto& f1_sinks = f1->get_sinks(); + const auto& f2_sinks = f2->get_sinks(); + if (f1_sinks.size() != f2_sinks.size()) { + return { false, "Number of sinks is different: " + std::to_string(f1_sinks.size()) + " and " + std::to_string(f2_sinks.size()) }; + } + auto typeInfoToStr = [](const ngraph::Node::type_info_t & typeInfo) { return std::string(typeInfo.name) + "/" + std::to_string(typeInfo.version); }; @@ -120,6 +126,13 @@ std::pair compare_functions( return {false, typeInfoToStr(type_info1) + " != " + typeInfoToStr(type_info2)}; } + const auto& dependencies_1 = node1->get_control_dependencies(); + const auto& dependencies_2 = node2->get_control_dependencies(); + if (dependencies_1.size() != dependencies_2.size()) { + return {false, "Number of dependencies is different: " + std::to_string(dependencies_1.size()) + " for " + node1->get_friendly_name() + + + " and " + std::to_string(dependencies_2.size()) + " for " + node2->get_friendly_name()}; + } + if (node1->inputs().size() != node2->inputs().size()) { return {false, "Number of inputs is different: " + std::to_string(node1->inputs().size()) + " for " + node1->get_friendly_name() + + " and " + std::to_string(node2->inputs().size()) + " for " + node2->get_friendly_name()}; diff --git a/ngraph/core/include/ngraph/function.hpp b/ngraph/core/include/ngraph/function.hpp index 5affa0b28f3..4f5cc1ce12d 100644 --- a/ngraph/core/include/ngraph/function.hpp +++ b/ngraph/core/include/ngraph/function.hpp @@ -182,6 +182,7 @@ namespace ngraph topological_sort_t m_topological_sorter; ResultVector m_results; + // List of the nodes with side effects in the graph. // These nodes are not outputs of the graph but should not be removed even if they have no children.
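+ // For example, the Assign nodes produced by the LowLatency transformation are
+ // registered here via Function::add_sinks() so that they are kept alive even
+ // though no Result consumes them.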
diff --git a/ngraph/core/include/ngraph/function.hpp b/ngraph/core/include/ngraph/function.hpp
index 5affa0b28f3..4f5cc1ce12d 100644
--- a/ngraph/core/include/ngraph/function.hpp
+++ b/ngraph/core/include/ngraph/function.hpp
@@ -182,6 +182,7 @@ namespace ngraph
         topological_sort_t m_topological_sorter;
         ResultVector m_results;
+        // List of the nodes with side effects in the graph.
         // These nodes are not outputs of graph but should not be removed even if have no children.
         SinkVector m_sinks;
diff --git a/ngraph/core/include/ngraph/op/sink.hpp b/ngraph/core/include/ngraph/op/sink.hpp
index 266b99a2b59..c198ac0d783 100644
--- a/ngraph/core/include/ngraph/op/sink.hpp
+++ b/ngraph/core/include/ngraph/op/sink.hpp
@@ -36,6 +36,7 @@
             : Op()
         {
         }
+
         Sink(const OutputVector& arguments)
             : Op(arguments)
         {
diff --git a/ngraph/core/include/ngraph/pass/low_latency.hpp b/ngraph/core/include/ngraph/pass/low_latency.hpp
new file mode 100644
index 00000000000..4650a3022e0
--- /dev/null
+++ b/ngraph/core/include/ngraph/pass/low_latency.hpp
@@ -0,0 +1,55 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph
+{
+    namespace pass
+    {
+        class NGRAPH_API LowLatency;
+    } // namespace pass
+} // namespace ngraph
+
+/**
+ * @brief The transformation finds all TensorIterator layers in the network, processes all back
+ * edges that describe a connection between a Result and a Parameter of the TensorIterator body,
+ * inserts a ReadValue layer between the Parameter and the layers that consume it, and inserts
+ * an Assign layer between the last layer of the back edge and the Result.
+ * Supported platforms: CPU, GNA.
+ *
+ * The example below describes the changes to the inner part (body, back edges) of the
+ * TensorIterator layer.
+ * [] - TensorIterator body
+ * () - new layer
+ *
+ * before applying the transformation:
+ * back_edge_1 -> [Parameter -> some layers ... -> Result ] -> back_edge_1
+ *
+ * after applying the transformation:
+ * back_edge_1 -> [Parameter -> (ReadValue layer) -> some layers ... -> (Assign layer) ]
+ *                                                       \
+ *                                                        -> Result ] -> back_edge_1
+ *
+ * It is recommended to use this transformation in conjunction with the Reshape feature to set
+ * the sequence dimension to 1 and with the UnrollTensorIterator transformation.
+ * For convenience, the UnrollTensorIterator transformation is executed unconditionally after
+ * LowLatency in the CPU and GNA plugins, so no extra action is required there.
+ * After applying both of these transformations, the resulting network can be inferred step by
+ * step, and the states are stored between inferences.
+ */
+
+class ngraph::pass::LowLatency : public ngraph::pass::MatcherPass
+{
+public:
+    NGRAPH_RTTI_DECLARATION;
+    LowLatency();
+};
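At the nGraph level, the new header is used as in this short sketch, mirroring the pass-manager branch of the subgraph tests; `f` stands for any std::shared_ptr<ngraph::Function>:

    #include <ngraph/pass/low_latency.hpp>
    #include <ngraph/pass/manager.hpp>

    void apply_low_latency(const std::shared_ptr<ngraph::Function>& f) {
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::LowLatency>();
        manager.run_passes(f);
        // TensorIterators are now marked with "UNROLL_TI"; the CPU and GNA
        // plugins run UnrollTensorIterator afterwards to finish the job.
    }
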
diff --git a/ngraph/core/src/op/loop.cpp b/ngraph/core/src/op/loop.cpp
index 25245522da6..436c61d4bd5 100644
--- a/ngraph/core/src/op/loop.cpp
+++ b/ngraph/core/src/op/loop.cpp
@@ -349,10 +349,12 @@ std::shared_ptr<Node> op::v5::Loop::clone_with_new_inputs(const OutputVector& new_args) const
     }
     op->m_num_iterations = m_num_iterations;
     op->m_special_body_ports = m_special_body_ports;
-    auto func = std::make_shared<Function>(m_body->get_results(), m_body->get_parameters());
+    auto func = std::make_shared<Function>(
+        m_body->get_results(), m_body->get_sinks(), m_body->get_parameters());
     auto spec_func = specialize_function(
         func, types, new_shapes, std::vector<void*>(body_params_args.size(), nullptr));
-    op->m_body = std::make_shared<Function>(spec_func->get_results(), spec_func->get_parameters());
+    op->m_body = std::make_shared<Function>(
+        spec_func->get_results(), spec_func->get_sinks(), spec_func->get_parameters());
 
     for (auto& input_description : m_input_descriptions)
     {
diff --git a/ngraph/core/src/op/tensor_iterator.cpp b/ngraph/core/src/op/tensor_iterator.cpp
index cefe0362895..d40de4bef69 100644
--- a/ngraph/core/src/op/tensor_iterator.cpp
+++ b/ngraph/core/src/op/tensor_iterator.cpp
@@ -126,18 +126,8 @@ void op::v0::TensorIterator::validate_and_infer_types()
                 auto start = make_positive(slice_input_description->m_start, dim_size);
                 auto end = make_positive(slice_input_description->m_end, dim_size);
 
-                if (m_num_iterations == -1)
-                {
-                    // +1 because the left and right borders are included [start, end]
-                    m_num_iterations = (abs(end - start) + 1) / part_size;
-                }
-                else
-                {
-                    NODE_VALIDATION_CHECK(this,
-                                          m_num_iterations == (abs(end - start) + 1) / part_size,
-                                          "Number of slices not the same");
-                }
-
+                // +1 because the left and right borders are included [start, end]
+                m_num_iterations = (abs(end - start) + 1) / part_size;
                 if (body_param_partial_shape.is_static())
                 {
                     // validate
@@ -316,10 +306,12 @@ std::shared_ptr<Node>
     }
     op->m_num_iterations = m_num_iterations;
-    auto func = std::make_shared<Function>(m_body->get_results(), m_body->get_parameters());
+    auto func = std::make_shared<Function>(
+        m_body->get_results(), m_body->get_sinks(), m_body->get_parameters());
     auto spec_func = specialize_function(
         func, types, new_shapes, std::vector<void*>(new_args.size(), nullptr));
-    op->m_body = std::make_shared<Function>(spec_func->get_results(), spec_func->get_parameters());
+    op->m_body = std::make_shared<Function>(
+        spec_func->get_results(), spec_func->get_sinks(), spec_func->get_parameters());
 
     for (auto& input_description : m_input_descriptions)
     {
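As a quick sanity check of the iteration-count formula that now runs unconditionally, a standalone restatement (values are illustrative):

    #include <cstdint>
    #include <cstdlib>

    // With start = 0, end = 9 (an axis of length 10) and part_size = 1 this
    // returns 10: one iteration per slice, borders included.
    int64_t num_iterations(int64_t start, int64_t end, int64_t part_size) {
        // +1 because the left and right borders are included: [start, end]
        return (std::abs(end - start) + 1) / part_size;
    }

Recomputing the value on every validation, instead of checking it against a previously cached one, keeps repeated validate_and_infer_types() calls consistent after transformations such as LowLatency modify the body.
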
diff --git a/ngraph/core/src/pass/low_latency.cpp b/ngraph/core/src/pass/low_latency.cpp
new file mode 100644
index 00000000000..89a7a73c499
--- /dev/null
+++ b/ngraph/core/src/pass/low_latency.cpp
@@ -0,0 +1,71 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "ngraph/pass/low_latency.hpp"
+
+#include <memory>
+
+#include <ngraph/opsets/opset5.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/variant.hpp>
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::LowLatency, "LowLatency", 0);
+
+ngraph::pass::LowLatency::LowLatency()
+{
+    auto tensor_iterator = ngraph::pattern::wrap_type<ngraph::opset5::TensorIterator>();
+    ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
+        auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
+        if (!ti)
+        {
+            return false;
+        }
+
+        // Mark the TI layer to be unrolled. Enables unconditional TI unrolling for all plugins.
+        auto& rt_info = ti->get_rt_info();
+        rt_info["UNROLL_TI"] = std::make_shared<ngraph::VariantWrapper<int64_t>>(1);
+
+        int64_t variable_id = 0;
+        std::vector<std::shared_ptr<ngraph::op::Sink>> assigns;
+        const auto& func = ti->get_function();
+        for (const auto& in : ti->get_input_descriptions())
+        {
+            // Process all back edges
+            if (const auto& merged_in = std::dynamic_pointer_cast<
+                    ngraph::opset5::TensorIterator::MergedInputDescription>(in))
+            {
+                // Insert ReadValue nodes: Parameter -> (new ReadValue) -> consumers
+                const auto& inputs_to = func->get_parameters()
+                                            .at(merged_in->m_body_parameter_index)
+                                            ->get_output_target_inputs(0);
+                const std::string variable_name(ti->get_friendly_name() + "/" +
+                                                func->get_parameters()
+                                                    .at(merged_in->m_body_parameter_index)
+                                                    ->get_friendly_name() +
+                                                "/variable_" + std::to_string(variable_id));
+                auto read_value = std::make_shared<ngraph::opset5::ReadValue>(
+                    func->get_parameters().at(merged_in->m_body_parameter_index), variable_name);
+                read_value->set_friendly_name(variable_name);
+                for (const auto& input_to : inputs_to)
+                {
+                    input_to.replace_source_output(read_value->output(0));
+                }
+
+                // Insert Assign nodes: provider -> (new Assign) -> Result
+                const auto res = func->get_results().at(merged_in->m_body_value_index);
+                auto assign = std::make_shared<ngraph::opset5::Assign>(res->input_value(0),
+                                                                       variable_name);
+                // Control dependency so that ReadValue is processed before Assign
+                assign->add_control_dependency(read_value);
+                assigns.emplace_back(assign);
+            }
+            variable_id++;
+        }
+        // Save the Assign nodes in the Function as sinks so that they take part in graph
+        // traversals and are not removed as dead code.
+        func->add_sinks(assigns);
+        return false;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(tensor_iterator, "LowLatency");
+    register_matcher(m, callback);
+}
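To make the variable naming above concrete: for a TensorIterator whose friendly name is "TI" and whose first input description is the back edge of a body Parameter named "H_t", the pass creates the variable "TI/H_t/variable_0". A minimal sketch of seeding such a state from application code, in the style of the subgraph tests; `executable` is assumed to be an InferenceEngine::ExecutableNetwork and `init_blob` a prepared Blob::Ptr:

    // Match states by substring, as the tests do; exact names are model-dependent.
    for (auto&& state : executable.QueryState()) {
        if (state.GetName().find("variable_0") != std::string::npos) {
            state.SetState(init_blob); // seed the first back-edge state before Infer()
        }
    }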