MakeStateful transformation (#7417)
* ReplaceInOutWithMemory transformation in ngraph * add unit tests * add ReplaceInputsOutputsWithMemory transformation in python * update codestyle * rename the transformation * fix codestyle * add Dynamic shapes check in the transformation and unit test * fix codestyle * rename files * fix python API * fix codestyle * fix codestyle * update python API * fix codestyle * fix build * codestyle * fix unit test * fix RTTI declaration * Apply suggestions from code review Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com> * review comments * openvino codestyle * change the name of Variable in the transformation * fix build * delete MakeStateful transformation from ie_transformations * Resolve review comments * codestyle * fix missprint, codestyle * delete unused variable Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com>
This commit is contained in:
parent
fd97a62263
commit
7b1a418bf4
@ -6,6 +6,7 @@ from ..inference_engine.ie_api cimport IENetwork
|
|||||||
|
|
||||||
from libcpp cimport bool
|
from libcpp cimport bool
|
||||||
from libcpp.string cimport string
|
from libcpp.string cimport string
|
||||||
|
from libcpp.map cimport map
|
||||||
from libc.stdint cimport int64_t
|
from libc.stdint cimport int64_t
|
||||||
|
|
||||||
|
|
||||||
@ -17,6 +18,15 @@ def ApplyPOTTransformations(IENetwork network, string device):
|
|||||||
C.ApplyPOTTransformations(network.impl, device)
|
C.ApplyPOTTransformations(network.impl, device)
|
||||||
|
|
||||||
|
|
||||||
|
def ApplyMakeStatefulTransformation(IENetwork network, param_res_names : dict):
|
||||||
|
cdef map[string, string] c_param_res_names
|
||||||
|
for param_name, res_name in param_res_names.items():
|
||||||
|
if type(param_name) != str or type(res_name) != str:
|
||||||
|
raise TypeError("Only string keys and values are allowed!")
|
||||||
|
c_param_res_names[param_name.encode()] = res_name.encode()
|
||||||
|
C.ApplyMakeStatefulTransformation(network.impl, c_param_res_names)
|
||||||
|
|
||||||
|
|
||||||
def ApplyLowLatencyTransformation(IENetwork network, bool use_const_initializer = True):
|
def ApplyLowLatencyTransformation(IENetwork network, bool use_const_initializer = True):
|
||||||
C.ApplyLowLatencyTransformation(network.impl, use_const_initializer)
|
C.ApplyLowLatencyTransformation(network.impl, use_const_initializer)
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
#include <ngraph/pass/constant_folding.hpp>
|
#include <ngraph/pass/constant_folding.hpp>
|
||||||
#include <ngraph/pass/low_latency.hpp>
|
#include <ngraph/pass/low_latency.hpp>
|
||||||
#include <ngraph/pass/manager.hpp>
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <openvino/pass/make_stateful.hpp>
|
||||||
#include <pot_transformations.hpp>
|
#include <pot_transformations.hpp>
|
||||||
#include <pruning.hpp>
|
#include <pruning.hpp>
|
||||||
#include <transformations/common_optimizations/moc_transformations.hpp>
|
#include <transformations/common_optimizations/moc_transformations.hpp>
|
||||||
@ -33,6 +34,13 @@ void InferenceEnginePython::ApplyLowLatencyTransformation(InferenceEnginePython:
|
|||||||
manager.run_passes(network.actual->getFunction());
|
manager.run_passes(network.actual->getFunction());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void InferenceEnginePython::ApplyMakeStatefulTransformation(InferenceEnginePython::IENetwork network,
|
||||||
|
std::map<std::string, std::string>& param_res_names) {
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ov::pass::MakeStateful>(param_res_names);
|
||||||
|
manager.run_passes(network.actual->getFunction());
|
||||||
|
}
|
||||||
|
|
||||||
void InferenceEnginePython::ApplyPruningTransformation(InferenceEnginePython::IENetwork network) {
|
void InferenceEnginePython::ApplyPruningTransformation(InferenceEnginePython::IENetwork network) {
|
||||||
ngraph::pass::Manager manager;
|
ngraph::pass::Manager manager;
|
||||||
manager.register_pass<ngraph::pass::Pruning>();
|
manager.register_pass<ngraph::pass::Pruning>();
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "Python.h"
|
#include "Python.h"
|
||||||
@ -17,6 +18,9 @@ void ApplyPOTTransformations(InferenceEnginePython::IENetwork network, std::stri
|
|||||||
|
|
||||||
void ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network, bool use_const_initializer = true);
|
void ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network, bool use_const_initializer = true);
|
||||||
|
|
||||||
|
void ApplyMakeStatefulTransformation(InferenceEnginePython::IENetwork network,
|
||||||
|
std::map<std::string, std::string>& param_res_names);
|
||||||
|
|
||||||
void ApplyPruningTransformation(InferenceEnginePython::IENetwork network);
|
void ApplyPruningTransformation(InferenceEnginePython::IENetwork network);
|
||||||
|
|
||||||
void GenerateMappingFile(InferenceEnginePython::IENetwork network, std::string path, bool extract_names);
|
void GenerateMappingFile(InferenceEnginePython::IENetwork network, std::string path, bool extract_names);
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
from libcpp cimport bool
|
from libcpp cimport bool
|
||||||
from libcpp.string cimport string
|
from libcpp.string cimport string
|
||||||
|
from libcpp.map cimport map
|
||||||
|
|
||||||
from ..inference_engine.ie_api_impl_defs cimport IENetwork
|
from ..inference_engine.ie_api_impl_defs cimport IENetwork
|
||||||
|
|
||||||
@ -13,6 +14,8 @@ cdef extern from "offline_transformations_api_impl.hpp" namespace "InferenceEngi
|
|||||||
|
|
||||||
cdef void ApplyLowLatencyTransformation(IENetwork network, bool use_const_initializer)
|
cdef void ApplyLowLatencyTransformation(IENetwork network, bool use_const_initializer)
|
||||||
|
|
||||||
|
cdef void ApplyMakeStatefulTransformation(IENetwork network, map[string, string]& in_out_names)
|
||||||
|
|
||||||
cdef void ApplyPruningTransformation(IENetwork network)
|
cdef void ApplyPruningTransformation(IENetwork network)
|
||||||
|
|
||||||
cdef void GenerateMappingFile(IENetwork network, string path, bool extract_names)
|
cdef void GenerateMappingFile(IENetwork network, string path, bool extract_names)
|
||||||
|
@ -2,7 +2,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from openvino.inference_engine import IECore, IENetwork
|
from openvino.inference_engine import IECore, IENetwork
|
||||||
from openvino.offline_transformations import ApplyMOCTransformations, ApplyLowLatencyTransformation, ApplyPruningTransformation
|
from openvino.offline_transformations import ApplyMOCTransformations, ApplyLowLatencyTransformation, \
|
||||||
|
ApplyPruningTransformation, ApplyMakeStatefulTransformation
|
||||||
|
|
||||||
import ngraph as ng
|
import ngraph as ng
|
||||||
from ngraph.impl.op import Parameter
|
from ngraph.impl.op import Parameter
|
||||||
@ -14,10 +15,10 @@ from conftest import model_path
|
|||||||
test_net_xml, test_net_bin = model_path()
|
test_net_xml, test_net_bin = model_path()
|
||||||
|
|
||||||
def get_test_cnnnetwork():
|
def get_test_cnnnetwork():
|
||||||
element_type = Type.f32
|
param = ng.parameter(Shape([1, 3, 22, 22]), name="parameter")
|
||||||
param = Parameter(element_type, Shape([1, 3, 22, 22]))
|
|
||||||
relu = ng.relu(param)
|
relu = ng.relu(param)
|
||||||
func = Function([relu], [param], 'test')
|
res = ng.result(relu, name='result')
|
||||||
|
func = Function([res], [param], 'test')
|
||||||
caps = Function.to_capsule(func)
|
caps = Function.to_capsule(func)
|
||||||
|
|
||||||
cnnNetwork = IENetwork(caps)
|
cnnNetwork = IENetwork(caps)
|
||||||
@ -43,6 +44,16 @@ def test_low_latency_transformations():
|
|||||||
assert len(f.get_ops()) == 3
|
assert len(f.get_ops()) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_stateful_transformations():
|
||||||
|
net = get_test_cnnnetwork()
|
||||||
|
ApplyMakeStatefulTransformation(net, {"parameter": "result"})
|
||||||
|
|
||||||
|
f = ng.function_from_cnn(net)
|
||||||
|
assert f != None
|
||||||
|
assert len(f.get_parameters()) == 0
|
||||||
|
assert len(f.get_results()) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_pruning_transformations():
|
def test_pruning_transformations():
|
||||||
net = get_test_cnnnetwork()
|
net = get_test_cnnnetwork()
|
||||||
ApplyPruningTransformation(net)
|
ApplyPruningTransformation(net)
|
||||||
|
@ -84,5 +84,4 @@ INFERENCE_ENGINE_API_CPP(void) LowLatency(InferenceEngine::CNNNetwork& network);
|
|||||||
* Loop operation by a given number. Does not affect TensorIterators.
|
* Loop operation by a given number. Does not affect TensorIterators.
|
||||||
*/
|
*/
|
||||||
INFERENCE_ENGINE_API_CPP(void) lowLatency2(InferenceEngine::CNNNetwork& network, bool use_const_initializer = true);
|
INFERENCE_ENGINE_API_CPP(void) lowLatency2(InferenceEngine::CNNNetwork& network, bool use_const_initializer = true);
|
||||||
|
|
||||||
} // namespace InferenceEngine
|
} // namespace InferenceEngine
|
||||||
|
@ -0,0 +1,145 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset8.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <openvino/pass/make_stateful.hpp>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
using namespace ngraph;
|
||||||
|
using namespace opset8;
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
TEST(TransformationTests, make_stateful_by_name) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto X = make_shared<Parameter>(element::f32, Shape{32, 1, 10});
|
||||||
|
auto Y = make_shared<Parameter>(element::f32, Shape{32, 1, 10});
|
||||||
|
X->set_friendly_name("x");
|
||||||
|
Y->set_friendly_name("y");
|
||||||
|
|
||||||
|
auto add = make_shared<Add>(X, Y);
|
||||||
|
auto result0 = make_shared<Result>(add);
|
||||||
|
auto result1 = make_shared<Result>(add);
|
||||||
|
result0->set_friendly_name("res0");
|
||||||
|
result1->set_friendly_name("res1");
|
||||||
|
|
||||||
|
f = make_shared<Function>(ResultVector{result0, result1}, ParameterVector{X, Y});
|
||||||
|
std::map<std::string, std::string> pair_names = {{"x", "res0"}, {"y", "res1"}};
|
||||||
|
f->validate_nodes_and_infer_types();
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ov::pass::MakeStateful>(pair_names);
|
||||||
|
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// create ReadValue for X
|
||||||
|
auto variable_x = std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "xres0"});
|
||||||
|
auto const_zero_x = make_shared<Constant>(element::f32, Shape{32, 1, 10}, 0);
|
||||||
|
auto read_val_x = make_shared<ReadValue>(const_zero_x, variable_x);
|
||||||
|
|
||||||
|
// create ReadValue for Y
|
||||||
|
auto variable_y = std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "yres1"});
|
||||||
|
auto const_zero_y = make_shared<Constant>(element::f32, Shape{32, 1, 10}, 0);
|
||||||
|
auto read_val_y = make_shared<ReadValue>(const_zero_y, variable_y);
|
||||||
|
|
||||||
|
auto add = make_shared<Add>(read_val_x, read_val_y);
|
||||||
|
auto assign_x = make_shared<Assign>(add, variable_x);
|
||||||
|
auto assign_y = make_shared<Assign>(add, variable_y);
|
||||||
|
|
||||||
|
f_ref = make_shared<Function>(ResultVector{}, SinkVector{assign_x, assign_y}, ParameterVector{});
|
||||||
|
f_ref->validate_nodes_and_infer_types();
|
||||||
|
}
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, make_stateful_by_param_res) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto X = make_shared<Parameter>(element::f32, Shape{32, 1, 10});
|
||||||
|
auto Y = make_shared<Parameter>(element::f32, Shape{32, 1, 10});
|
||||||
|
X->set_friendly_name("x");
|
||||||
|
Y->set_friendly_name("y");
|
||||||
|
|
||||||
|
auto add = make_shared<Add>(X, Y);
|
||||||
|
auto result0 = make_shared<Result>(add);
|
||||||
|
auto result1 = make_shared<Result>(add);
|
||||||
|
result0->set_friendly_name("res0");
|
||||||
|
result1->set_friendly_name("res1");
|
||||||
|
|
||||||
|
f = make_shared<Function>(ResultVector{result0, result1}, ParameterVector{X, Y});
|
||||||
|
std::vector<std::pair<std::string, std::string>> pair_names = {{"x", "res0"}, {"y", "res1"}};
|
||||||
|
f->validate_nodes_and_infer_types();
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ov::pass::MakeStateful>(ov::pass::MakeStateful::ParamResPairs{{X, result0}, {Y, result1}});
|
||||||
|
manager.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// create ReadValue for X
|
||||||
|
auto variable_x = std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "xres0"});
|
||||||
|
auto const_zero_x = make_shared<Constant>(element::f32, Shape{32, 1, 10}, 0);
|
||||||
|
auto read_val_x = make_shared<ReadValue>(const_zero_x, variable_x);
|
||||||
|
|
||||||
|
// create ReadValue for Y
|
||||||
|
auto variable_y = std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "yres1"});
|
||||||
|
auto const_zero_y = make_shared<Constant>(element::f32, Shape{32, 1, 10}, 0);
|
||||||
|
auto read_val_y = make_shared<ReadValue>(const_zero_y, variable_y);
|
||||||
|
|
||||||
|
auto add = make_shared<Add>(read_val_x, read_val_y);
|
||||||
|
auto assign_x = make_shared<Assign>(add, variable_x);
|
||||||
|
auto assign_y = make_shared<Assign>(add, variable_y);
|
||||||
|
|
||||||
|
f_ref = make_shared<Function>(ResultVector{}, SinkVector{assign_x, assign_y}, ParameterVector{});
|
||||||
|
f_ref->validate_nodes_and_infer_types();
|
||||||
|
}
|
||||||
|
auto res = compare_functions(f, f_ref);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, make_stateful_dynamic_shapes) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr);
|
||||||
|
{
|
||||||
|
auto X = make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||||
|
auto Y = make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||||
|
X->set_friendly_name("x");
|
||||||
|
Y->set_friendly_name("y");
|
||||||
|
|
||||||
|
auto add = make_shared<Add>(X, Y);
|
||||||
|
auto result0 = make_shared<Result>(add);
|
||||||
|
auto result1 = make_shared<Result>(add);
|
||||||
|
result0->set_friendly_name("res0");
|
||||||
|
result1->set_friendly_name("res1");
|
||||||
|
|
||||||
|
f = make_shared<Function>(ResultVector{result0, result1}, ParameterVector{X, Y});
|
||||||
|
map<std::string, std::string> pair_names = {{"x", "res0"}, {"y", "res1"}};
|
||||||
|
f->validate_nodes_and_infer_types();
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ov::pass::MakeStateful>(pair_names);
|
||||||
|
|
||||||
|
EXPECT_THROW(manager.run_passes(f), ::ov::AssertFailure);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
}
|
@ -9,8 +9,9 @@ from mo.utils.cli_parser import parse_transform
|
|||||||
|
|
||||||
def get_available_transformations():
|
def get_available_transformations():
|
||||||
try:
|
try:
|
||||||
from openvino.offline_transformations import ApplyLowLatencyTransformation # pylint: disable=import-error,no-name-in-module
|
from openvino.offline_transformations import ApplyLowLatencyTransformation, ApplyMakeStatefulTransformation # pylint: disable=import-error,no-name-in-module
|
||||||
return {
|
return {
|
||||||
|
'MakeStateful': ApplyMakeStatefulTransformation,
|
||||||
'LowLatency2': ApplyLowLatencyTransformation,
|
'LowLatency2': ApplyLowLatencyTransformation,
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -49,7 +49,7 @@ def import_core_modules(silent: bool, path_to_module: str):
|
|||||||
try:
|
try:
|
||||||
from openvino.inference_engine import get_version, read_network # pylint: disable=import-error,no-name-in-module
|
from openvino.inference_engine import get_version, read_network # pylint: disable=import-error,no-name-in-module
|
||||||
from openvino.offline_transformations import ApplyMOCTransformations, ApplyLowLatencyTransformation, \
|
from openvino.offline_transformations import ApplyMOCTransformations, ApplyLowLatencyTransformation, \
|
||||||
GenerateMappingFile # pylint: disable=import-error,no-name-in-module
|
ApplyMakeStatefulTransformation, GenerateMappingFile # pylint: disable=import-error,no-name-in-module
|
||||||
|
|
||||||
# TODO: it is temporary import to check that nGraph python API is available. But in future
|
# TODO: it is temporary import to check that nGraph python API is available. But in future
|
||||||
# we need to replace it with Frontend imports
|
# we need to replace it with Frontend imports
|
||||||
|
@ -1190,7 +1190,18 @@ def isbool(value):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def isdict(value):
|
||||||
|
try:
|
||||||
|
evaluated = ast.literal_eval(value)
|
||||||
|
return isinstance(evaluated, dict)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def convert_string_to_real_type(value: str):
|
def convert_string_to_real_type(value: str):
|
||||||
|
if isdict(value):
|
||||||
|
return ast.literal_eval(value)
|
||||||
|
|
||||||
values = value.split(',')
|
values = value.split(',')
|
||||||
for i in range(len(values)):
|
for i in range(len(values)):
|
||||||
value = values[i]
|
value = values[i]
|
||||||
|
36
ngraph/core/include/openvino/pass/make_stateful.hpp
Normal file
36
ngraph/core/include/openvino/pass/make_stateful.hpp
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ngraph/opsets/opset8.hpp"
|
||||||
|
#include "openvino/pass/pass.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
/**
|
||||||
|
* @brief The transformation replaces the provided pairs Parameter and Result with ngraph Memory layers
|
||||||
|
* ReadValue and Assign
|
||||||
|
*/
|
||||||
|
class OPENVINO_API MakeStateful : public FunctionPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("MakeStateful");
|
||||||
|
|
||||||
|
using ParamResPairs =
|
||||||
|
std::vector<std::pair<std::shared_ptr<ngraph::opset8::Parameter>, std::shared_ptr<ngraph::opset8::Result>>>;
|
||||||
|
|
||||||
|
explicit MakeStateful(const ParamResPairs& pairs_to_replace) : m_param_res_pairs(pairs_to_replace) {}
|
||||||
|
explicit MakeStateful(const std::map<std::string, std::string>& param_res_names)
|
||||||
|
: m_param_res_names(param_res_names) {}
|
||||||
|
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
ParamResPairs m_param_res_pairs;
|
||||||
|
std::map<std::string, std::string> m_param_res_names;
|
||||||
|
};
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
91
ngraph/core/src/pass/make_stateful.cpp
Normal file
91
ngraph/core/src/pass/make_stateful.cpp
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "openvino/pass/make_stateful.hpp"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <ngraph/rt_info.hpp>
|
||||||
|
#include <openvino/op/util/variable.hpp>
|
||||||
|
#include <openvino/opsets/opset8.hpp>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ngraph;
|
||||||
|
using namespace opset8;
|
||||||
|
using namespace op::util;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
string generate_variable_name(const shared_ptr<Parameter>& param, const shared_ptr<Result>& res) {
|
||||||
|
return param->get_friendly_name() + res->get_friendly_name();
|
||||||
|
}
|
||||||
|
|
||||||
|
ov::pass::MakeStateful::ParamResPairs find_param_results_by_names(
|
||||||
|
const shared_ptr<ngraph::Function>& func,
|
||||||
|
const std::map<std::string, std::string>& param_res_names) {
|
||||||
|
ov::pass::MakeStateful::ParamResPairs pairs_to_replace;
|
||||||
|
const auto& params = func->get_parameters();
|
||||||
|
const auto& results = func->get_results();
|
||||||
|
|
||||||
|
// find corresponding param and result by name and add to the list
|
||||||
|
for (const auto& param_res : param_res_names) {
|
||||||
|
const auto& param_name = param_res.first;
|
||||||
|
const auto& res_name = param_res.second;
|
||||||
|
|
||||||
|
auto param = std::find_if(params.begin(), params.end(), [&](const std::shared_ptr<ngraph::Node>& node) {
|
||||||
|
return node->get_friendly_name() == param_name;
|
||||||
|
});
|
||||||
|
NGRAPH_CHECK(param != params.end(), "Parameter node with name = ", param_name, "doesn't exist in the function");
|
||||||
|
|
||||||
|
auto res = std::find_if(results.begin(), results.end(), [&](const std::shared_ptr<ngraph::Node>& node) {
|
||||||
|
return node->get_friendly_name() == res_name;
|
||||||
|
});
|
||||||
|
NGRAPH_CHECK(res != results.end(), "Result node with name = ", res_name, " doesn't exist in the function");
|
||||||
|
|
||||||
|
pairs_to_replace.emplace_back(*param, *res);
|
||||||
|
}
|
||||||
|
return pairs_to_replace;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool ov::pass::MakeStateful::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||||
|
if (m_param_res_pairs.empty()) {
|
||||||
|
m_param_res_pairs = find_param_results_by_names(f, m_param_res_names);
|
||||||
|
}
|
||||||
|
|
||||||
|
VariableVector variables;
|
||||||
|
SinkVector sinks;
|
||||||
|
for (const auto& pair : m_param_res_pairs) {
|
||||||
|
const auto& param = pair.first;
|
||||||
|
const auto& res = pair.second;
|
||||||
|
|
||||||
|
NGRAPH_CHECK(param->get_partial_shape().is_static(),
|
||||||
|
"Shape of Parameter ",
|
||||||
|
param->get_friendly_name(),
|
||||||
|
" must be static. MakeStateful transformation doesn't support dynamic shapes.");
|
||||||
|
|
||||||
|
// Create Variable
|
||||||
|
std::string var_name = generate_variable_name(param, res);
|
||||||
|
auto variable =
|
||||||
|
std::make_shared<Variable>(VariableInfo{param->get_shape(), param->get_element_type(), var_name});
|
||||||
|
variables.push_back(variable);
|
||||||
|
|
||||||
|
// Create ReadValue
|
||||||
|
auto const_zero = make_shared<Constant>(param->get_element_type(), param->get_shape(), 0);
|
||||||
|
auto read_val = make_shared<ReadValue>(const_zero, variable);
|
||||||
|
replace_node(param, read_val);
|
||||||
|
copy_runtime_info(param, {read_val, const_zero});
|
||||||
|
|
||||||
|
// Create Assign
|
||||||
|
auto assign = make_shared<Assign>(res->input_value(0), variable);
|
||||||
|
copy_runtime_info(res, assign);
|
||||||
|
|
||||||
|
// Update Function
|
||||||
|
sinks.push_back(assign);
|
||||||
|
f->remove_result(res);
|
||||||
|
f->remove_parameter(param);
|
||||||
|
assign->add_control_dependency(read_val);
|
||||||
|
}
|
||||||
|
f->add_variables(variables);
|
||||||
|
f->add_sinks(sinks);
|
||||||
|
return true;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user