[PyOV] add param-result to offline transformation (#21140)
This commit is contained in:
parent
0c041d7ebc
commit
56301d8878
@ -94,6 +94,16 @@ void regmodule_offline_transformations(py::module m) {
|
|||||||
py::arg("model"),
|
py::arg("model"),
|
||||||
py::arg("param_res_names"));
|
py::arg("param_res_names"));
|
||||||
|
|
||||||
|
m_offline_transformations.def(
|
||||||
|
"apply_make_stateful_transformation",
|
||||||
|
[](std::shared_ptr<ov::Model> model, const ov::pass::MakeStateful::ParamResPairs& pairs_to_replace) {
|
||||||
|
ov::pass::Manager manager;
|
||||||
|
manager.register_pass<ov::pass::MakeStateful>(pairs_to_replace);
|
||||||
|
manager.run_passes(model);
|
||||||
|
},
|
||||||
|
py::arg("model"),
|
||||||
|
py::arg("pairs_to_replace"));
|
||||||
|
|
||||||
m_offline_transformations.def(
|
m_offline_transformations.def(
|
||||||
"compress_model_transformation",
|
"compress_model_transformation",
|
||||||
[](std::shared_ptr<ov::Model> model) {
|
[](std::shared_ptr<ov::Model> model) {
|
||||||
|
@ -141,6 +141,21 @@ def test_pruning_transformation():
|
|||||||
|
|
||||||
|
|
||||||
def test_make_stateful_transformations():
|
def test_make_stateful_transformations():
|
||||||
|
param = ov.opset13.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
|
||||||
|
param.get_output_tensor(0).set_names({"parameter"})
|
||||||
|
relu = ov.opset13.relu(param)
|
||||||
|
res = ov.opset13.result(relu, name="result")
|
||||||
|
res.get_output_tensor(0).set_names({"result"})
|
||||||
|
model = Model([res], [param], "test")
|
||||||
|
|
||||||
|
apply_make_stateful_transformation(model, [(param, res)])
|
||||||
|
|
||||||
|
assert model is not None
|
||||||
|
assert len(model.get_parameters()) == 0
|
||||||
|
assert len(model.get_results()) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_stateful_transformations_with_dics():
|
||||||
model = get_relu_model()
|
model = get_relu_model()
|
||||||
|
|
||||||
apply_make_stateful_transformation(model, {"parameter": "result"})
|
apply_make_stateful_transformation(model, {"parameter": "result"})
|
||||||
|
Loading…
Reference in New Issue
Block a user