MakeStateful transformation (#7417)
* ReplaceInOutWithMemory transformation in ngraph * add unit tests * add ReplaceInputsOutputsWithMemory transformation in python * update codestyle * rename the transformation * fix codestyle * add Dynamic shapes check in the transformation and unit test * fix codestyle * rename files * fix python API * fix codestyle * fix codestyle * update python API * fix codestyle * fix build * codestyle * fix unit test * fix RTTI declaration * Apply suggestions from code review Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com> * review comments * openvino codestyle * change the name of Variable in the transformation * fix build * delete MakeStateful transformation from ie_transformations * Resolve review comments * codestyle * fix missprint, codestyle * delete unused variable Co-authored-by: Gleb Kazantaev <gleb.nnstu@gmail.com>
This commit is contained in:
parent
fd97a62263
commit
7b1a418bf4
@ -6,6 +6,7 @@ from ..inference_engine.ie_api cimport IENetwork
|
||||
|
||||
from libcpp cimport bool
|
||||
from libcpp.string cimport string
|
||||
from libcpp.map cimport map
|
||||
from libc.stdint cimport int64_t
|
||||
|
||||
|
||||
@ -17,6 +18,15 @@ def ApplyPOTTransformations(IENetwork network, string device):
|
||||
C.ApplyPOTTransformations(network.impl, device)
|
||||
|
||||
|
||||
def ApplyMakeStatefulTransformation(IENetwork network, param_res_names : dict):
|
||||
cdef map[string, string] c_param_res_names
|
||||
for param_name, res_name in param_res_names.items():
|
||||
if type(param_name) != str or type(res_name) != str:
|
||||
raise TypeError("Only string keys and values are allowed!")
|
||||
c_param_res_names[param_name.encode()] = res_name.encode()
|
||||
C.ApplyMakeStatefulTransformation(network.impl, c_param_res_names)
|
||||
|
||||
|
||||
def ApplyLowLatencyTransformation(IENetwork network, bool use_const_initializer = True):
|
||||
C.ApplyLowLatencyTransformation(network.impl, use_const_initializer)
|
||||
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <ngraph/pass/low_latency.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <openvino/pass/make_stateful.hpp>
|
||||
#include <pot_transformations.hpp>
|
||||
#include <pruning.hpp>
|
||||
#include <transformations/common_optimizations/moc_transformations.hpp>
|
||||
@ -33,6 +34,13 @@ void InferenceEnginePython::ApplyLowLatencyTransformation(InferenceEnginePython:
|
||||
manager.run_passes(network.actual->getFunction());
|
||||
}
|
||||
|
||||
void InferenceEnginePython::ApplyMakeStatefulTransformation(InferenceEnginePython::IENetwork network,
|
||||
std::map<std::string, std::string>& param_res_names) {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ov::pass::MakeStateful>(param_res_names);
|
||||
manager.run_passes(network.actual->getFunction());
|
||||
}
|
||||
|
||||
void InferenceEnginePython::ApplyPruningTransformation(InferenceEnginePython::IENetwork network) {
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::Pruning>();
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "Python.h"
|
||||
@ -17,6 +18,9 @@ void ApplyPOTTransformations(InferenceEnginePython::IENetwork network, std::stri
|
||||
|
||||
void ApplyLowLatencyTransformation(InferenceEnginePython::IENetwork network, bool use_const_initializer = true);
|
||||
|
||||
void ApplyMakeStatefulTransformation(InferenceEnginePython::IENetwork network,
|
||||
std::map<std::string, std::string>& param_res_names);
|
||||
|
||||
void ApplyPruningTransformation(InferenceEnginePython::IENetwork network);
|
||||
|
||||
void GenerateMappingFile(InferenceEnginePython::IENetwork network, std::string path, bool extract_names);
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
from libcpp cimport bool
|
||||
from libcpp.string cimport string
|
||||
from libcpp.map cimport map
|
||||
|
||||
from ..inference_engine.ie_api_impl_defs cimport IENetwork
|
||||
|
||||
@ -13,6 +14,8 @@ cdef extern from "offline_transformations_api_impl.hpp" namespace "InferenceEngi
|
||||
|
||||
cdef void ApplyLowLatencyTransformation(IENetwork network, bool use_const_initializer)
|
||||
|
||||
cdef void ApplyMakeStatefulTransformation(IENetwork network, map[string, string]& in_out_names)
|
||||
|
||||
cdef void ApplyPruningTransformation(IENetwork network)
|
||||
|
||||
cdef void GenerateMappingFile(IENetwork network, string path, bool extract_names)
|
||||
|
@ -2,7 +2,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.inference_engine import IECore, IENetwork
|
||||
from openvino.offline_transformations import ApplyMOCTransformations, ApplyLowLatencyTransformation, ApplyPruningTransformation
|
||||
from openvino.offline_transformations import ApplyMOCTransformations, ApplyLowLatencyTransformation, \
|
||||
ApplyPruningTransformation, ApplyMakeStatefulTransformation
|
||||
|
||||
import ngraph as ng
|
||||
from ngraph.impl.op import Parameter
|
||||
@ -14,10 +15,10 @@ from conftest import model_path
|
||||
test_net_xml, test_net_bin = model_path()
|
||||
|
||||
def get_test_cnnnetwork():
|
||||
element_type = Type.f32
|
||||
param = Parameter(element_type, Shape([1, 3, 22, 22]))
|
||||
param = ng.parameter(Shape([1, 3, 22, 22]), name="parameter")
|
||||
relu = ng.relu(param)
|
||||
func = Function([relu], [param], 'test')
|
||||
res = ng.result(relu, name='result')
|
||||
func = Function([res], [param], 'test')
|
||||
caps = Function.to_capsule(func)
|
||||
|
||||
cnnNetwork = IENetwork(caps)
|
||||
@ -43,6 +44,16 @@ def test_low_latency_transformations():
|
||||
assert len(f.get_ops()) == 3
|
||||
|
||||
|
||||
def test_make_stateful_transformations():
|
||||
net = get_test_cnnnetwork()
|
||||
ApplyMakeStatefulTransformation(net, {"parameter": "result"})
|
||||
|
||||
f = ng.function_from_cnn(net)
|
||||
assert f != None
|
||||
assert len(f.get_parameters()) == 0
|
||||
assert len(f.get_results()) == 0
|
||||
|
||||
|
||||
def test_pruning_transformations():
|
||||
net = get_test_cnnnetwork()
|
||||
ApplyPruningTransformation(net)
|
||||
|
@ -84,5 +84,4 @@ INFERENCE_ENGINE_API_CPP(void) LowLatency(InferenceEngine::CNNNetwork& network);
|
||||
* Loop operation by a given number. Does not affect TensorIterators.
|
||||
*/
|
||||
INFERENCE_ENGINE_API_CPP(void) lowLatency2(InferenceEngine::CNNNetwork& network, bool use_const_initializer = true);
|
||||
|
||||
} // namespace InferenceEngine
|
||||
|
@ -0,0 +1,145 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <openvino/pass/make_stateful.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
using namespace opset8;
|
||||
using namespace std;
|
||||
|
||||
TEST(TransformationTests, make_stateful_by_name) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto X = make_shared<Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto Y = make_shared<Parameter>(element::f32, Shape{32, 1, 10});
|
||||
X->set_friendly_name("x");
|
||||
Y->set_friendly_name("y");
|
||||
|
||||
auto add = make_shared<Add>(X, Y);
|
||||
auto result0 = make_shared<Result>(add);
|
||||
auto result1 = make_shared<Result>(add);
|
||||
result0->set_friendly_name("res0");
|
||||
result1->set_friendly_name("res1");
|
||||
|
||||
f = make_shared<Function>(ResultVector{result0, result1}, ParameterVector{X, Y});
|
||||
std::map<std::string, std::string> pair_names = {{"x", "res0"}, {"y", "res1"}};
|
||||
f->validate_nodes_and_infer_types();
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ov::pass::MakeStateful>(pair_names);
|
||||
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
// create ReadValue for X
|
||||
auto variable_x = std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "xres0"});
|
||||
auto const_zero_x = make_shared<Constant>(element::f32, Shape{32, 1, 10}, 0);
|
||||
auto read_val_x = make_shared<ReadValue>(const_zero_x, variable_x);
|
||||
|
||||
// create ReadValue for Y
|
||||
auto variable_y = std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "yres1"});
|
||||
auto const_zero_y = make_shared<Constant>(element::f32, Shape{32, 1, 10}, 0);
|
||||
auto read_val_y = make_shared<ReadValue>(const_zero_y, variable_y);
|
||||
|
||||
auto add = make_shared<Add>(read_val_x, read_val_y);
|
||||
auto assign_x = make_shared<Assign>(add, variable_x);
|
||||
auto assign_y = make_shared<Assign>(add, variable_y);
|
||||
|
||||
f_ref = make_shared<Function>(ResultVector{}, SinkVector{assign_x, assign_y}, ParameterVector{});
|
||||
f_ref->validate_nodes_and_infer_types();
|
||||
}
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, make_stateful_by_param_res) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto X = make_shared<Parameter>(element::f32, Shape{32, 1, 10});
|
||||
auto Y = make_shared<Parameter>(element::f32, Shape{32, 1, 10});
|
||||
X->set_friendly_name("x");
|
||||
Y->set_friendly_name("y");
|
||||
|
||||
auto add = make_shared<Add>(X, Y);
|
||||
auto result0 = make_shared<Result>(add);
|
||||
auto result1 = make_shared<Result>(add);
|
||||
result0->set_friendly_name("res0");
|
||||
result1->set_friendly_name("res1");
|
||||
|
||||
f = make_shared<Function>(ResultVector{result0, result1}, ParameterVector{X, Y});
|
||||
std::vector<std::pair<std::string, std::string>> pair_names = {{"x", "res0"}, {"y", "res1"}};
|
||||
f->validate_nodes_and_infer_types();
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ov::pass::MakeStateful>(ov::pass::MakeStateful::ParamResPairs{{X, result0}, {Y, result1}});
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
// create ReadValue for X
|
||||
auto variable_x = std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "xres0"});
|
||||
auto const_zero_x = make_shared<Constant>(element::f32, Shape{32, 1, 10}, 0);
|
||||
auto read_val_x = make_shared<ReadValue>(const_zero_x, variable_x);
|
||||
|
||||
// create ReadValue for Y
|
||||
auto variable_y = std::make_shared<Variable>(VariableInfo{PartialShape::dynamic(), element::dynamic, "yres1"});
|
||||
auto const_zero_y = make_shared<Constant>(element::f32, Shape{32, 1, 10}, 0);
|
||||
auto read_val_y = make_shared<ReadValue>(const_zero_y, variable_y);
|
||||
|
||||
auto add = make_shared<Add>(read_val_x, read_val_y);
|
||||
auto assign_x = make_shared<Assign>(add, variable_x);
|
||||
auto assign_y = make_shared<Assign>(add, variable_y);
|
||||
|
||||
f_ref = make_shared<Function>(ResultVector{}, SinkVector{assign_x, assign_y}, ParameterVector{});
|
||||
f_ref->validate_nodes_and_infer_types();
|
||||
}
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, make_stateful_dynamic_shapes) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr);
|
||||
{
|
||||
auto X = make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Y = make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||
X->set_friendly_name("x");
|
||||
Y->set_friendly_name("y");
|
||||
|
||||
auto add = make_shared<Add>(X, Y);
|
||||
auto result0 = make_shared<Result>(add);
|
||||
auto result1 = make_shared<Result>(add);
|
||||
result0->set_friendly_name("res0");
|
||||
result1->set_friendly_name("res1");
|
||||
|
||||
f = make_shared<Function>(ResultVector{result0, result1}, ParameterVector{X, Y});
|
||||
map<std::string, std::string> pair_names = {{"x", "res0"}, {"y", "res1"}};
|
||||
f->validate_nodes_and_infer_types();
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ov::pass::MakeStateful>(pair_names);
|
||||
|
||||
EXPECT_THROW(manager.run_passes(f), ::ov::AssertFailure);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
}
|
@ -9,8 +9,9 @@ from mo.utils.cli_parser import parse_transform
|
||||
|
||||
def get_available_transformations():
|
||||
try:
|
||||
from openvino.offline_transformations import ApplyLowLatencyTransformation # pylint: disable=import-error,no-name-in-module
|
||||
from openvino.offline_transformations import ApplyLowLatencyTransformation, ApplyMakeStatefulTransformation # pylint: disable=import-error,no-name-in-module
|
||||
return {
|
||||
'MakeStateful': ApplyMakeStatefulTransformation,
|
||||
'LowLatency2': ApplyLowLatencyTransformation,
|
||||
}
|
||||
except Exception as e:
|
||||
|
@ -49,7 +49,7 @@ def import_core_modules(silent: bool, path_to_module: str):
|
||||
try:
|
||||
from openvino.inference_engine import get_version, read_network # pylint: disable=import-error,no-name-in-module
|
||||
from openvino.offline_transformations import ApplyMOCTransformations, ApplyLowLatencyTransformation, \
|
||||
GenerateMappingFile # pylint: disable=import-error,no-name-in-module
|
||||
ApplyMakeStatefulTransformation, GenerateMappingFile # pylint: disable=import-error,no-name-in-module
|
||||
|
||||
# TODO: it is temporary import to check that nGraph python API is available. But in future
|
||||
# we need to replace it with Frontend imports
|
||||
|
@ -1190,7 +1190,18 @@ def isbool(value):
|
||||
return False
|
||||
|
||||
|
||||
def isdict(value):
|
||||
try:
|
||||
evaluated = ast.literal_eval(value)
|
||||
return isinstance(evaluated, dict)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def convert_string_to_real_type(value: str):
|
||||
if isdict(value):
|
||||
return ast.literal_eval(value)
|
||||
|
||||
values = value.split(',')
|
||||
for i in range(len(values)):
|
||||
value = values[i]
|
||||
|
36
ngraph/core/include/openvino/pass/make_stateful.hpp
Normal file
36
ngraph/core/include/openvino/pass/make_stateful.hpp
Normal file
@ -0,0 +1,36 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/opsets/opset8.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
/**
|
||||
* @brief The transformation replaces the provided pairs Parameter and Result with ngraph Memory layers
|
||||
* ReadValue and Assign
|
||||
*/
|
||||
class OPENVINO_API MakeStateful : public FunctionPass {
|
||||
public:
|
||||
OPENVINO_RTTI("MakeStateful");
|
||||
|
||||
using ParamResPairs =
|
||||
std::vector<std::pair<std::shared_ptr<ngraph::opset8::Parameter>, std::shared_ptr<ngraph::opset8::Result>>>;
|
||||
|
||||
explicit MakeStateful(const ParamResPairs& pairs_to_replace) : m_param_res_pairs(pairs_to_replace) {}
|
||||
explicit MakeStateful(const std::map<std::string, std::string>& param_res_names)
|
||||
: m_param_res_names(param_res_names) {}
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
private:
|
||||
ParamResPairs m_param_res_pairs;
|
||||
std::map<std::string, std::string> m_param_res_names;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace ov
|
91
ngraph/core/src/pass/make_stateful.cpp
Normal file
91
ngraph/core/src/pass/make_stateful.cpp
Normal file
@ -0,0 +1,91 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/pass/make_stateful.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <openvino/op/util/variable.hpp>
|
||||
#include <openvino/opsets/opset8.hpp>
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using namespace opset8;
|
||||
using namespace op::util;
|
||||
|
||||
namespace {
|
||||
string generate_variable_name(const shared_ptr<Parameter>& param, const shared_ptr<Result>& res) {
|
||||
return param->get_friendly_name() + res->get_friendly_name();
|
||||
}
|
||||
|
||||
ov::pass::MakeStateful::ParamResPairs find_param_results_by_names(
|
||||
const shared_ptr<ngraph::Function>& func,
|
||||
const std::map<std::string, std::string>& param_res_names) {
|
||||
ov::pass::MakeStateful::ParamResPairs pairs_to_replace;
|
||||
const auto& params = func->get_parameters();
|
||||
const auto& results = func->get_results();
|
||||
|
||||
// find corresponding param and result by name and add to the list
|
||||
for (const auto& param_res : param_res_names) {
|
||||
const auto& param_name = param_res.first;
|
||||
const auto& res_name = param_res.second;
|
||||
|
||||
auto param = std::find_if(params.begin(), params.end(), [&](const std::shared_ptr<ngraph::Node>& node) {
|
||||
return node->get_friendly_name() == param_name;
|
||||
});
|
||||
NGRAPH_CHECK(param != params.end(), "Parameter node with name = ", param_name, "doesn't exist in the function");
|
||||
|
||||
auto res = std::find_if(results.begin(), results.end(), [&](const std::shared_ptr<ngraph::Node>& node) {
|
||||
return node->get_friendly_name() == res_name;
|
||||
});
|
||||
NGRAPH_CHECK(res != results.end(), "Result node with name = ", res_name, " doesn't exist in the function");
|
||||
|
||||
pairs_to_replace.emplace_back(*param, *res);
|
||||
}
|
||||
return pairs_to_replace;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool ov::pass::MakeStateful::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
if (m_param_res_pairs.empty()) {
|
||||
m_param_res_pairs = find_param_results_by_names(f, m_param_res_names);
|
||||
}
|
||||
|
||||
VariableVector variables;
|
||||
SinkVector sinks;
|
||||
for (const auto& pair : m_param_res_pairs) {
|
||||
const auto& param = pair.first;
|
||||
const auto& res = pair.second;
|
||||
|
||||
NGRAPH_CHECK(param->get_partial_shape().is_static(),
|
||||
"Shape of Parameter ",
|
||||
param->get_friendly_name(),
|
||||
" must be static. MakeStateful transformation doesn't support dynamic shapes.");
|
||||
|
||||
// Create Variable
|
||||
std::string var_name = generate_variable_name(param, res);
|
||||
auto variable =
|
||||
std::make_shared<Variable>(VariableInfo{param->get_shape(), param->get_element_type(), var_name});
|
||||
variables.push_back(variable);
|
||||
|
||||
// Create ReadValue
|
||||
auto const_zero = make_shared<Constant>(param->get_element_type(), param->get_shape(), 0);
|
||||
auto read_val = make_shared<ReadValue>(const_zero, variable);
|
||||
replace_node(param, read_val);
|
||||
copy_runtime_info(param, {read_val, const_zero});
|
||||
|
||||
// Create Assign
|
||||
auto assign = make_shared<Assign>(res->input_value(0), variable);
|
||||
copy_runtime_info(res, assign);
|
||||
|
||||
// Update Function
|
||||
sinks.push_back(assign);
|
||||
f->remove_result(res);
|
||||
f->remove_parameter(param);
|
||||
assign->add_control_dependency(read_val);
|
||||
}
|
||||
f->add_variables(variables);
|
||||
f->add_sinks(sinks);
|
||||
return true;
|
||||
}
|
Loading…
Reference in New Issue
Block a user