Expose enum version (#8893)
* add version param to offline_transofmations serialize * fix style * fix style * remove redundant commented code
This commit is contained in:
parent
8a1b63ec51
commit
f211749e15
@ -8,6 +8,7 @@
|
||||
|
||||
#include <generate_mapping_file.hpp>
|
||||
#include <openvino/pass/make_stateful.hpp>
|
||||
#include <openvino/pass/serialize.hpp>
|
||||
#include <pot_transformations.hpp>
|
||||
#include <pruning.hpp>
|
||||
#include <transformations/common_optimizations/compress_float_constants.hpp>
|
||||
@ -18,6 +19,19 @@
|
||||
#include "openvino/pass/low_latency.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
|
||||
using Version = ov::pass::Serialize::Version;
|
||||
|
||||
inline Version convert_to_version(const std::string& version) {
|
||||
if (version == "UNSPECIFIED")
|
||||
return Version::UNSPECIFIED;
|
||||
if (version == "IR_V10")
|
||||
return Version::IR_V10;
|
||||
if (version == "IR_V11")
|
||||
return Version::IR_V11;
|
||||
throw ov::Exception("Invoked with wrong version argument: '" + version +
|
||||
"'! The supported versions are: 'UNSPECIFIED'(default), 'IR_V10', 'IR_V11'.");
|
||||
}
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regmodule_offline_transformations(py::module m) {
|
||||
@ -98,12 +112,56 @@ void regmodule_offline_transformations(py::module m) {
|
||||
// todo: remove as serialize as part of passManager api will be merged
|
||||
m_offline_transformations.def(
|
||||
"serialize",
|
||||
[](std::shared_ptr<ov::Function> function, const std::string& path_to_xml, const std::string& path_to_bin) {
|
||||
[](std::shared_ptr<ov::Function> function,
|
||||
const std::string& path_to_xml,
|
||||
const std::string& path_to_bin,
|
||||
const std::string& version) {
|
||||
ov::pass::Manager manager;
|
||||
manager.register_pass<ov::pass::Serialize>(path_to_xml, path_to_bin);
|
||||
manager.register_pass<ov::pass::Serialize>(path_to_xml, path_to_bin, convert_to_version(version));
|
||||
manager.run_passes(function);
|
||||
},
|
||||
py::arg("function"),
|
||||
py::arg("model_path"),
|
||||
py::arg("weights_path"));
|
||||
py::arg("weights_path"),
|
||||
py::arg("version") = "UNSPECIFIED",
|
||||
R"(
|
||||
Serialize given function into IR. The generated .xml and .bin files will be save
|
||||
into provided paths.
|
||||
Parameters
|
||||
----------
|
||||
function : ov.Function
|
||||
function which will be converted to IR representation
|
||||
xml_path : str
|
||||
path where .xml file will be saved
|
||||
bin_path : str
|
||||
path where .bin file will be saved
|
||||
version : str
|
||||
sets the version of the IR which will be generated.
|
||||
Supported versions are:
|
||||
- "UNSPECIFIED" (default) : Use the latest or function version
|
||||
- "IR_V10" : v10 IR
|
||||
- "IR_V11" : v11 IR
|
||||
|
||||
Examples:
|
||||
----------
|
||||
1. Default IR version:
|
||||
shape = [2, 2]
|
||||
parameter_a = ov.parameter(shape, dtype=np.float32, name="A")
|
||||
parameter_b = ov.parameter(shape, dtype=np.float32, name="B")
|
||||
parameter_c = ov.parameter(shape, dtype=np.float32, name="C")
|
||||
model = (parameter_a + parameter_b) * parameter_c
|
||||
func = Function(model, [parameter_a, parameter_b, parameter_c], "Function")
|
||||
# IR generated with default version
|
||||
serialize(func, model_path="./serialized.xml", weights_path="./serialized.bin")
|
||||
|
||||
2. IR version 11:
|
||||
shape = [2, 2]
|
||||
parameter_a = ov.parameter(shape, dtype=np.float32, name="A")
|
||||
parameter_b = ov.parameter(shape, dtype=np.float32, name="B")
|
||||
parameter_c = ov.parameter(shape, dtype=np.float32, name="C")
|
||||
model = (parameter_a + parameter_b) * parameter_c
|
||||
func = Function(model, [parameter_a, parameter_b, parameter_c], "Function")
|
||||
# IR generated with default version
|
||||
serialize(func, model_path="./serialized.xml", "./serialized.bin", version="IR_V11")
|
||||
// )");
|
||||
}
|
||||
|
@ -117,3 +117,63 @@ def test_compress_model_transformation():
|
||||
|
||||
assert func is not None
|
||||
assert func.get_ordered_ops()[0].get_element_type().get_type_name() == "f16"
|
||||
|
||||
|
||||
def test_Version_default():
|
||||
core = Core()
|
||||
xml_path = "./serialized_function.xml"
|
||||
bin_path = "./serialized_function.bin"
|
||||
shape = [100, 100, 2]
|
||||
parameter_a = ov.opset8.parameter(shape, dtype=np.float32, name="A")
|
||||
parameter_b = ov.opset8.parameter(shape, dtype=np.float32, name="B")
|
||||
model = ov.opset8.floor(ov.opset8.minimum(ov.opset8.abs(parameter_a), parameter_b))
|
||||
func = Function(model, [parameter_a, parameter_b], "Function")
|
||||
|
||||
serialize(func, xml_path, bin_path)
|
||||
res_func = core.read_model(model=xml_path, weights=bin_path)
|
||||
|
||||
assert func.get_parameters() == res_func.get_parameters()
|
||||
assert func.get_ordered_ops() == res_func.get_ordered_ops()
|
||||
|
||||
os.remove(xml_path)
|
||||
os.remove(bin_path)
|
||||
|
||||
|
||||
def test_Version_ir_v10():
|
||||
core = Core()
|
||||
xml_path = "./serialized_function.xml"
|
||||
bin_path = "./serialized_function.bin"
|
||||
shape = [100, 100, 2]
|
||||
parameter_a = ov.opset8.parameter(shape, dtype=np.float32, name="A")
|
||||
parameter_b = ov.opset8.parameter(shape, dtype=np.float32, name="B")
|
||||
model = ov.opset8.floor(ov.opset8.minimum(ov.opset8.abs(parameter_a), parameter_b))
|
||||
func = Function(model, [parameter_a, parameter_b], "Function")
|
||||
|
||||
serialize(func, xml_path, bin_path, "IR_V10")
|
||||
res_func = core.read_model(model=xml_path, weights=bin_path)
|
||||
|
||||
assert func.get_parameters() == res_func.get_parameters()
|
||||
assert func.get_ordered_ops() == res_func.get_ordered_ops()
|
||||
|
||||
os.remove(xml_path)
|
||||
os.remove(bin_path)
|
||||
|
||||
|
||||
def test_Version_ir_v11():
|
||||
core = Core()
|
||||
xml_path = "./serialized_function.xml"
|
||||
bin_path = "./serialized_function.bin"
|
||||
shape = [100, 100, 2]
|
||||
parameter_a = ov.opset8.parameter(shape, dtype=np.float32, name="A")
|
||||
parameter_b = ov.opset8.parameter(shape, dtype=np.float32, name="B")
|
||||
model = ov.opset8.floor(ov.opset8.minimum(ov.opset8.abs(parameter_a), parameter_b))
|
||||
func = Function(model, [parameter_a, parameter_b], "Function")
|
||||
|
||||
serialize(func, xml_path, bin_path, "IR_V11")
|
||||
res_func = core.read_model(model=xml_path, weights=bin_path)
|
||||
|
||||
assert func.get_parameters() == res_func.get_parameters()
|
||||
assert func.get_ordered_ops() == res_func.get_ordered_ops()
|
||||
|
||||
os.remove(xml_path)
|
||||
os.remove(bin_path)
|
||||
|
Loading…
Reference in New Issue
Block a user