MakeStateful transformation (#7417)

* ReplaceInOutWithMemory transformation in ngraph

* add unit tests

* add ReplaceInputsOutputsWithMemory transformation in python

* update codestyle

* rename the transformation

* fix codestyle

* add Dynamic shapes check in the transformation and unit test

* fix codestyle

* rename files

* fix python API

* fix codestyle

* fix codestyle

* update python API

* fix codestyle

* fix build

* codestyle

* fix unit test

* fix RTTI declaration

* Apply suggestions from code review

Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com>

* review comments

* openvino codestyle

* change the name of Variable in the transformation

* fix build

* delete MakeStateful transformation from ie_transformations

* Resolve review comments

* codestyle

* fix misprint, codestyle

* delete unused variable

Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com>
This commit is contained in:
Ivan Tikhonov 2021-10-13 19:02:42 +03:00 committed by GitHub
parent fd97a62263
commit 7b1a418bf4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 326 additions and 7 deletions

View File

@ -6,6 +6,7 @@ from ..inference_engine.ie_api cimport IENetwork
from libcpp cimport bool from libcpp cimport bool
from libcpp.string cimport string from libcpp.string cimport string
from libcpp.map cimport map
from libc.stdint cimport int64_t from libc.stdint cimport int64_t
@ -17,6 +18,15 @@ def ApplyPOTTransformations(IENetwork network, string device):
C.ApplyPOTTransformations(network.impl, device) C.ApplyPOTTransformations(network.impl, device)
def ApplyMakeStatefulTransformation(IENetwork network, param_res_names : dict):
    """Apply the MakeStateful transformation to ``network``.

    :param network: network whose function is transformed.
    :param param_res_names: dict of name pairs; every key and every value must be a ``str``.
        Both are encoded to bytes and handed to the C++ implementation as a string map.
    :raises TypeError: if any key or value is not a string.
    """
    cdef map[string, string] c_param_res_names
    for param_name, res_name in param_res_names.items():
        # isinstance is the idiomatic type check and also accepts str subclasses.
        if not isinstance(param_name, str) or not isinstance(res_name, str):
            raise TypeError("Only string keys and values are allowed!")
        c_param_res_names[param_name.encode()] = res_name.encode()
    C.ApplyMakeStatefulTransformation(network.impl, c_param_res_names)
def ApplyLowLatencyTransformation(IENetwork network, bool use_const_initializer = True): def ApplyLowLatencyTransformation(IENetwork network, bool use_const_initializer = True):
C.ApplyLowLatencyTransformation(network.impl, use_const_initializer) C.ApplyLowLatencyTransformation(network.impl, use_const_initializer)

View File

@ -9,6 +9,7 @@
#include <ngraph/pass/constant_folding.hpp> #include <ngraph/pass/constant_folding.hpp>
#include <ngraph/pass/low_latency.hpp> #include <ngraph/pass/low_latency.hpp>
#include <ngraph/pass/manager.hpp> #include <ngraph/pass/manager.hpp>
#include <openvino/pass/make_stateful.hpp>
#include <pot_transformations.hpp> #include <pot_transformations.hpp>
#include <pruning.hpp> #include <pruning.hpp>
#include <transformations/common_optimizations/moc_transformations.hpp> #include <transformations/common_optimizations/moc_transformations.hpp>
@ -33,6 +34,13 @@ void InferenceEnginePython::ApplyLowLatencyTransformation(InferenceEnginePython:
manager.run_passes(network.actual->getFunction()); manager.run_passes(network.actual->getFunction());
} }
// Runs the ov::pass::MakeStateful pass, configured with the given name map,
// on the nGraph function owned by the network.
void InferenceEnginePython::ApplyMakeStatefulTransformation(InferenceEnginePython::IENetwork network,
                                                            std::map<std::string, std::string>& param_res_names) {
    const auto function = network.actual->getFunction();

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<ov::pass::MakeStateful>(param_res_names);
    pass_manager.run_passes(function);
}
void InferenceEnginePython::ApplyPruningTransformation(InferenceEnginePython::IENetwork network) { void InferenceEnginePython::ApplyPruningTransformation(InferenceEnginePython::IENetwork network) {
ngraph::pass::Manager manager; ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::Pruning>(); manager.register_pass<ngraph::pass::Pruning>();

View File

@ -4,6 +4,7 @@
#pragma once #pragma once
#include <map>
#include <string> #include <string>
#include "Python.h" #include "Python.h"
@ -17,6 +18,9 @@ void ApplyPOTTransformations(InferenceEnginePython::IENetwork network, std::stri
void ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network, bool use_const_initializer = true); void ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network, bool use_const_initializer = true);
// Applies the ov::pass::MakeStateful transformation to the function of the given network.
// param_res_names is forwarded unchanged to the pass as its map of name pairs.
void ApplyMakeStatefulTransformation(InferenceEnginePython::IENetwork network,
std::map<std::string, std::string>& param_res_names);
void ApplyPruningTransformation(InferenceEnginePython::IENetwork network); void ApplyPruningTransformation(InferenceEnginePython::IENetwork network);
void GenerateMappingFile(InferenceEnginePython::IENetwork network, std::string path, bool extract_names); void GenerateMappingFile(InferenceEnginePython::IENetwork network, std::string path, bool extract_names);

View File

@ -3,6 +3,7 @@
from libcpp cimport bool from libcpp cimport bool
from libcpp.string cimport string from libcpp.string cimport string
from libcpp.map cimport map
from ..inference_engine.ie_api_impl_defs cimport IENetwork from ..inference_engine.ie_api_impl_defs cimport IENetwork
@ -13,6 +14,8 @@ cdef extern from "offline_transformations_api_impl.hpp" namespace "InferenceEngi
cdef void ApplyLowLatencyTransformation(IENetwork network, bool use_const_initializer) cdef void ApplyLowLatencyTransformation(IENetwork network, bool use_const_initializer)
# Cython declaration of the C++ ApplyMakeStatefulTransformation; in_out_names is a C++ string-to-string map passed by reference.
cdef void ApplyMakeStatefulTransformation(IENetwork network, map[string, string]& in_out_names)
cdef void ApplyPruningTransformation(IENetwork network) cdef void ApplyPruningTransformation(IENetwork network)
cdef void GenerateMappingFile(IENetwork network, string path, bool extract_names) cdef void GenerateMappingFile(IENetwork network, string path, bool extract_names)

View File

@ -2,7 +2,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from openvino.inference_engine import IECore, IENetwork from openvino.inference_engine import IECore, IENetwork
from openvino.offline_transformations import ApplyMOCTransformations, ApplyLowLatencyTransformation, ApplyPruningTransformation from openvino.offline_transformations import ApplyMOCTransformations, ApplyLowLatencyTransformation, \
ApplyPruningTransformation, ApplyMakeStatefulTransformation
import ngraph as ng import ngraph as ng
from ngraph.impl.op import Parameter from ngraph.impl.op import Parameter
@ -14,10 +15,10 @@ from conftest import model_path
test_net_xml, test_net_bin = model_path() test_net_xml, test_net_bin = model_path()
def get_test_cnnnetwork(): def get_test_cnnnetwork():
element_type = Type.f32 param = ng.parameter(Shape([1, 3, 22, 22]), name="parameter")
param = Parameter(element_type, Shape([1, 3, 22, 22]))
relu = ng.relu(param) relu = ng.relu(param)
func = Function([relu], [param], 'test') res = ng.result(relu, name='result')
func = Function([res], [param], 'test')
caps = Function.to_capsule(func) caps = Function.to_capsule(func)
cnnNetwork = IENetwork(caps) cnnNetwork = IENetwork(caps)
@ -43,6 +44,16 @@ def test_low_latency_transformations():
assert len(f.get_ops()) == 3 assert len(f.get_ops()) == 3
def test_make_stateful_transformations():
    """MakeStateful must remove the paired Parameter and Result from the test function."""
    net = get_test_cnnnetwork()
    # "parameter" and "result" are the node names assigned in get_test_cnnnetwork().
    ApplyMakeStatefulTransformation(net, {"parameter": "result"})

    f = ng.function_from_cnn(net)
    # Identity comparison is the idiomatic None check and does not depend on __ne__.
    assert f is not None
    assert len(f.get_parameters()) == 0
    assert len(f.get_results()) == 0
def test_pruning_transformations(): def test_pruning_transformations():
net = get_test_cnnnetwork() net = get_test_cnnnetwork()
ApplyPruningTransformation(net) ApplyPruningTransformation(net)

View File

@ -84,5 +84,4 @@ INFERENCE_ENGINE_API_CPP(void) LowLatency(InferenceEngine::CNNNetwork& network);
* Loop operation by a given number. Does not affect TensorIterators. * Loop operation by a given number. Does not affect TensorIterators.
*/ */
INFERENCE_ENGINE_API_CPP(void) lowLatency2(InferenceEngine::CNNNetwork& network, bool use_const_initializer = true); INFERENCE_ENGINE_API_CPP(void) lowLatency2(InferenceEngine::CNNNetwork& network, bool use_const_initializer = true);
} // namespace InferenceEngine } // namespace InferenceEngine

View File

@ -0,0 +1,145 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include <openvino/pass/make_stateful.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
using namespace opset8;
using namespace std;
// Verifies that MakeStateful, configured with Parameter/Result friendly names, rewrites the
// named pairs into ReadValue/Assign operations.
TEST(TransformationTests, make_stateful_by_name) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        // Function under test: two parameters summed and exposed through two results.
        auto param_x = make_shared<Parameter>(element::f32, Shape{32, 1, 10});
        auto param_y = make_shared<Parameter>(element::f32, Shape{32, 1, 10});
        param_x->set_friendly_name("x");
        param_y->set_friendly_name("y");

        auto sum = make_shared<Add>(param_x, param_y);
        auto res0 = make_shared<Result>(sum);
        auto res1 = make_shared<Result>(sum);
        res0->set_friendly_name("res0");
        res1->set_friendly_name("res1");

        f = make_shared<Function>(ResultVector{res0, res1}, ParameterVector{param_x, param_y});

        // Parameter friendly name -> Result friendly name.
        std::map<std::string, std::string> names_to_pair = {{"x", "res0"}, {"y", "res1"}};
        f->validate_nodes_and_infer_types();

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ov::pass::MakeStateful>(names_to_pair);
        pass_manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // Reference function: each parameter becomes a ReadValue fed by a zero constant,
        // each result becomes an Assign writing to the same variable.
        auto var_x = std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "xres0"});
        auto zero_x = make_shared<Constant>(element::f32, Shape{32, 1, 10}, 0);
        auto read_x = make_shared<ReadValue>(zero_x, var_x);

        auto var_y = std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "yres1"});
        auto zero_y = make_shared<Constant>(element::f32, Shape{32, 1, 10}, 0);
        auto read_y = make_shared<ReadValue>(zero_y, var_y);

        auto sum = make_shared<Add>(read_x, read_y);
        auto write_x = make_shared<Assign>(sum, var_x);
        auto write_y = make_shared<Assign>(sum, var_y);

        f_ref = make_shared<Function>(ResultVector{}, SinkVector{write_x, write_y}, ParameterVector{});
        f_ref->validate_nodes_and_infer_types();
    }

    auto comparison = compare_functions(f, f_ref);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// Verifies that MakeStateful, configured with explicit Parameter/Result node pairs, rewrites
// those pairs into ReadValue/Assign operations.
TEST(TransformationTests, make_stateful_by_param_res) {
    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
    {
        auto X = make_shared<Parameter>(element::f32, Shape{32, 1, 10});
        auto Y = make_shared<Parameter>(element::f32, Shape{32, 1, 10});
        X->set_friendly_name("x");
        Y->set_friendly_name("y");

        auto add = make_shared<Add>(X, Y);
        auto result0 = make_shared<Result>(add);
        auto result1 = make_shared<Result>(add);
        result0->set_friendly_name("res0");
        result1->set_friendly_name("res1");

        f = make_shared<Function>(ResultVector{result0, result1}, ParameterVector{X, Y});
        f->validate_nodes_and_infer_types();

        // The pairs are passed as nodes, so no name map is needed in this test.
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitNodeInfo>();
        manager.register_pass<ov::pass::MakeStateful>(ov::pass::MakeStateful::ParamResPairs{{X, result0}, {Y, result1}});
        manager.run_passes(f);
        ASSERT_NO_THROW(check_rt_info(f));
    }

    {
        // create ReadValue for X
        auto variable_x = std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "xres0"});
        auto const_zero_x = make_shared<Constant>(element::f32, Shape{32, 1, 10}, 0);
        auto read_val_x = make_shared<ReadValue>(const_zero_x, variable_x);

        // create ReadValue for Y
        auto variable_y = std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "yres1"});
        auto const_zero_y = make_shared<Constant>(element::f32, Shape{32, 1, 10}, 0);
        auto read_val_y = make_shared<ReadValue>(const_zero_y, variable_y);

        // Both Assign nodes consume the same Add output, mirroring the two original results.
        auto add = make_shared<Add>(read_val_x, read_val_y);
        auto assign_x = make_shared<Assign>(add, variable_x);
        auto assign_y = make_shared<Assign>(add, variable_y);

        f_ref = make_shared<Function>(ResultVector{}, SinkVector{assign_x, assign_y}, ParameterVector{});
        f_ref->validate_nodes_and_infer_types();
    }

    auto res = compare_functions(f, f_ref);
    ASSERT_TRUE(res.first) << res.second;
}
// Verifies that MakeStateful rejects parameters with dynamic shapes by throwing ov::AssertFailure.
TEST(TransformationTests, make_stateful_dynamic_shapes) {
    std::shared_ptr<ngraph::Function> f(nullptr);
    {
        // Both parameters are fully dynamic.
        auto param_x = make_shared<Parameter>(element::f32, PartialShape::dynamic());
        auto param_y = make_shared<Parameter>(element::f32, PartialShape::dynamic());
        param_x->set_friendly_name("x");
        param_y->set_friendly_name("y");

        auto sum = make_shared<Add>(param_x, param_y);
        auto res0 = make_shared<Result>(sum);
        auto res1 = make_shared<Result>(sum);
        res0->set_friendly_name("res0");
        res1->set_friendly_name("res1");

        f = make_shared<Function>(ResultVector{res0, res1}, ParameterVector{param_x, param_y});

        map<std::string, std::string> names_to_pair = {{"x", "res0"}, {"y", "res1"}};
        f->validate_nodes_and_infer_types();

        ngraph::pass::Manager pass_manager;
        pass_manager.register_pass<ngraph::pass::InitNodeInfo>();
        pass_manager.register_pass<ov::pass::MakeStateful>(names_to_pair);

        // The pass is expected to fail its static-shape check.
        EXPECT_THROW(pass_manager.run_passes(f), ::ov::AssertFailure);
        ASSERT_NO_THROW(check_rt_info(f));
    }
}

View File

@ -9,8 +9,9 @@ from mo.utils.cli_parser import parse_transform
def get_available_transformations(): def get_available_transformations():
try: try:
from openvino.offline_transformations import ApplyLowLatencyTransformation # pylint: disable=import-error,no-name-in-module from openvino.offline_transformations import ApplyLowLatencyTransformation, ApplyMakeStatefulTransformation # pylint: disable=import-error,no-name-in-module
return { return {
'MakeStateful': ApplyMakeStatefulTransformation,
'LowLatency2': ApplyLowLatencyTransformation, 'LowLatency2': ApplyLowLatencyTransformation,
} }
except Exception as e: except Exception as e:

View File

@ -49,7 +49,7 @@ def import_core_modules(silent: bool, path_to_module: str):
try: try:
from openvino.inference_engine import get_version, read_network # pylint: disable=import-error,no-name-in-module from openvino.inference_engine import get_version, read_network # pylint: disable=import-error,no-name-in-module
from openvino.offline_transformations import ApplyMOCTransformations, ApplyLowLatencyTransformation, \ from openvino.offline_transformations import ApplyMOCTransformations, ApplyLowLatencyTransformation, \
GenerateMappingFile # pylint: disable=import-error,no-name-in-module ApplyMakeStatefulTransformation, GenerateMappingFile # pylint: disable=import-error,no-name-in-module
# TODO: it is temporary import to check that nGraph python API is available. But in future # TODO: it is temporary import to check that nGraph python API is available. But in future
# we need to replace it with Frontend imports # we need to replace it with Frontend imports

View File

@ -1190,7 +1190,18 @@ def isbool(value):
return False return False
def isdict(value):
    """Return True if ``value`` is the literal representation of a Python dict.

    ``ast.literal_eval`` signals malformed input with ``ValueError``, ``SyntaxError`` or
    ``TypeError`` depending on what is wrong with it; any of them means "not a dict".

    :param value: string to inspect.
    :return: True when the string evaluates to a dict literal, False otherwise.
    """
    try:
        evaluated = ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError):
        # e.g. "a b" or "1.5.2" raise SyntaxError rather than ValueError
        return False
    return isinstance(evaluated, dict)
def convert_string_to_real_type(value: str): def convert_string_to_real_type(value: str):
if isdict(value):
return ast.literal_eval(value)
values = value.split(',') values = value.split(',')
for i in range(len(values)): for i in range(len(values)):
value = values[i] value = values[i]

View File

@ -0,0 +1,36 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ngraph/opsets/opset8.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace pass {
/**
 * @brief The transformation replaces the provided pairs of Parameter and Result nodes with
 * ReadValue and Assign operations bound to a common Variable.
 *
 * The pairs can be given either directly as nodes (ParamResPairs) or as a map from the
 * friendly name of a Parameter to the friendly name of a Result. Name pairs are resolved
 * against the function when the pass runs and no node pairs are stored.
 */
class OPENVINO_API MakeStateful : public FunctionPass {
public:
OPENVINO_RTTI("MakeStateful");
// List of (Parameter, Result) node pairs to be replaced.
using ParamResPairs =
std::vector<std::pair<std::shared_ptr<ngraph::opset8::Parameter>, std::shared_ptr<ngraph::opset8::Result>>>;
// Configures the pass with the exact Parameter/Result nodes to replace.
explicit MakeStateful(const ParamResPairs& pairs_to_replace) : m_param_res_pairs(pairs_to_replace) {}
// Configures the pass with name pairs: key is a Parameter friendly name, value is a Result friendly name.
explicit MakeStateful(const std::map<std::string, std::string>& param_res_names)
: m_param_res_names(param_res_names) {}
// Performs the replacement on the given function.
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
private:
// Node pairs to replace; filled from m_param_res_names at run time when left empty.
ParamResPairs m_param_res_pairs;
// Parameter friendly name -> Result friendly name; used only when no node pairs were given.
std::map<std::string, std::string> m_param_res_names;
};
} // namespace pass
} // namespace ov

View File

@ -0,0 +1,91 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/pass/make_stateful.hpp"
#include <memory>
#include <ngraph/rt_info.hpp>
#include <openvino/op/util/variable.hpp>
#include <openvino/opsets/opset8.hpp>
using namespace std;
using namespace ngraph;
using namespace opset8;
using namespace op::util;
namespace {
// Builds the Variable id for a pair: the Parameter friendly name immediately followed by
// the Result friendly name.
string generate_variable_name(const shared_ptr<Parameter>& param, const shared_ptr<Result>& res) {
    string variable_name = param->get_friendly_name();
    variable_name += res->get_friendly_name();
    return variable_name;
}
// Resolves (Parameter friendly name -> Result friendly name) entries into node pairs taken
// from the parameters and results of `func`.
// An NGRAPH_CHECK fails if a name does not match any Parameter / Result of the function.
ov::pass::MakeStateful::ParamResPairs find_param_results_by_names(
    const shared_ptr<ngraph::Function>& func,
    const std::map<std::string, std::string>& param_res_names) {
    ov::pass::MakeStateful::ParamResPairs pairs_to_replace;
    pairs_to_replace.reserve(param_res_names.size());

    const auto& params = func->get_parameters();
    const auto& results = func->get_results();

    // find corresponding param and result by name and add to the list
    for (const auto& param_res : param_res_names) {
        const auto& param_name = param_res.first;
        const auto& res_name = param_res.second;

        auto param = std::find_if(params.begin(), params.end(), [&](const std::shared_ptr<ngraph::Node>& node) {
            return node->get_friendly_name() == param_name;
        });
        // The message pieces are concatenated as-is, so the leading space is required.
        NGRAPH_CHECK(param != params.end(),
                     "Parameter node with name = ",
                     param_name,
                     " doesn't exist in the function");

        auto res = std::find_if(results.begin(), results.end(), [&](const std::shared_ptr<ngraph::Node>& node) {
            return node->get_friendly_name() == res_name;
        });
        NGRAPH_CHECK(res != results.end(), "Result node with name = ", res_name, " doesn't exist in the function");

        pairs_to_replace.emplace_back(*param, *res);
    }
    return pairs_to_replace;
}
} // namespace
// Replaces every configured (Parameter, Result) pair of `f` with a ReadValue / Assign pair
// bound to a new Variable, then registers the variables and sinks on the function.
// Fails an NGRAPH_CHECK if any Parameter has a non-static shape. Always returns true.
bool ov::pass::MakeStateful::run_on_function(std::shared_ptr<ngraph::Function> f) {
    // When only names were supplied, resolve them against this function first.
    if (m_param_res_pairs.empty()) {
        m_param_res_pairs = find_param_results_by_names(f, m_param_res_names);
    }

    // Validate all pairs before touching the graph, so that a failure on a later pair does not
    // leave the function half-rewritten (parameters removed but sinks/variables not yet added).
    for (const auto& pair : m_param_res_pairs) {
        const auto& param = pair.first;
        NGRAPH_CHECK(param->get_partial_shape().is_static(),
                     "Shape of Parameter ",
                     param->get_friendly_name(),
                     " must be static. MakeStateful transformation doesn't support dynamic shapes.");
    }

    VariableVector variables;
    SinkVector sinks;
    variables.reserve(m_param_res_pairs.size());
    sinks.reserve(m_param_res_pairs.size());

    for (const auto& pair : m_param_res_pairs) {
        const auto& param = pair.first;
        const auto& res = pair.second;

        // Create Variable
        std::string var_name = generate_variable_name(param, res);
        auto variable =
            std::make_shared<Variable>(VariableInfo{param->get_shape(), param->get_element_type(), var_name});
        variables.push_back(variable);

        // Create ReadValue, initialized from a zero constant of the Parameter's type and shape
        auto const_zero = make_shared<Constant>(param->get_element_type(), param->get_shape(), 0);
        auto read_val = make_shared<ReadValue>(const_zero, variable);
        replace_node(param, read_val);
        copy_runtime_info(param, {read_val, const_zero});

        // Create Assign fed by the value that used to go to the Result
        auto assign = make_shared<Assign>(res->input_value(0), variable);
        copy_runtime_info(res, assign);

        // Update Function
        sinks.push_back(assign);
        f->remove_result(res);
        f->remove_parameter(param);
        assign->add_control_dependency(read_val);
    }

    f->add_variables(variables);
    f->add_sinks(sinks);
    return true;
}