[PyOV] add param-result to offline transformation (#21140)

This commit is contained in:
Anastasia Kuporosova 2023-11-20 18:59:02 +04:00 committed by GitHub
parent 0c041d7ebc
commit 56301d8878
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 0 deletions

View File

@ -94,6 +94,16 @@ void regmodule_offline_transformations(py::module m) {
py::arg("model"), py::arg("model"),
py::arg("param_res_names")); py::arg("param_res_names"));
m_offline_transformations.def(
"apply_make_stateful_transformation",
[](std::shared_ptr<ov::Model> model, const ov::pass::MakeStateful::ParamResPairs& pairs_to_replace) {
ov::pass::Manager manager;
manager.register_pass<ov::pass::MakeStateful>(pairs_to_replace);
manager.run_passes(model);
},
py::arg("model"),
py::arg("pairs_to_replace"));
m_offline_transformations.def( m_offline_transformations.def(
"compress_model_transformation", "compress_model_transformation",
[](std::shared_ptr<ov::Model> model) { [](std::shared_ptr<ov::Model> model) {

View File

@ -141,6 +141,21 @@ def test_pruning_transformation():
def test_make_stateful_transformations(): def test_make_stateful_transformations():
param = ov.opset13.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
param.get_output_tensor(0).set_names({"parameter"})
relu = ov.opset13.relu(param)
res = ov.opset13.result(relu, name="result")
res.get_output_tensor(0).set_names({"result"})
model = Model([res], [param], "test")
apply_make_stateful_transformation(model, [(param, res)])
assert model is not None
assert len(model.get_parameters()) == 0
assert len(model.get_results()) == 0
def test_make_stateful_transformations_with_dics():
model = get_relu_model() model = get_relu_model()
apply_make_stateful_transformation(model, {"parameter": "result"}) apply_make_stateful_transformation(model, {"parameter": "result"})