[PyOV] add param-result to offline transformation (#21140)
This commit is contained in:
parent
0c041d7ebc
commit
56301d8878
@ -94,6 +94,16 @@ void regmodule_offline_transformations(py::module m) {
|
||||
py::arg("model"),
|
||||
py::arg("param_res_names"));
|
||||
|
||||
m_offline_transformations.def(
|
||||
"apply_make_stateful_transformation",
|
||||
[](std::shared_ptr<ov::Model> model, const ov::pass::MakeStateful::ParamResPairs& pairs_to_replace) {
|
||||
ov::pass::Manager manager;
|
||||
manager.register_pass<ov::pass::MakeStateful>(pairs_to_replace);
|
||||
manager.run_passes(model);
|
||||
},
|
||||
py::arg("model"),
|
||||
py::arg("pairs_to_replace"));
|
||||
|
||||
m_offline_transformations.def(
|
||||
"compress_model_transformation",
|
||||
[](std::shared_ptr<ov::Model> model) {
|
||||
|
@ -141,6 +141,21 @@ def test_pruning_transformation():
|
||||
|
||||
|
||||
def test_make_stateful_transformations():
|
||||
param = ov.opset13.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
|
||||
param.get_output_tensor(0).set_names({"parameter"})
|
||||
relu = ov.opset13.relu(param)
|
||||
res = ov.opset13.result(relu, name="result")
|
||||
res.get_output_tensor(0).set_names({"result"})
|
||||
model = Model([res], [param], "test")
|
||||
|
||||
apply_make_stateful_transformation(model, [(param, res)])
|
||||
|
||||
assert model is not None
|
||||
assert len(model.get_parameters()) == 0
|
||||
assert len(model.get_results()) == 0
|
||||
|
||||
|
||||
def test_make_stateful_transformations_with_dics():
|
||||
model = get_relu_model()
|
||||
|
||||
apply_make_stateful_transformation(model, {"parameter": "result"})
|
||||
|
Loading…
Reference in New Issue
Block a user