Expose enum version (#8893)

* add version param to offline_transofmations serialize

* fix style

* fix style

* remove redundant commented code
This commit is contained in:
Bartek Szmelczynski 2021-12-01 11:49:28 +01:00 committed by GitHub
parent 8a1b63ec51
commit f211749e15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 121 additions and 3 deletions

View File

@ -8,6 +8,7 @@
#include <generate_mapping_file.hpp>
#include <openvino/pass/make_stateful.hpp>
#include <openvino/pass/serialize.hpp>
#include <pot_transformations.hpp>
#include <pruning.hpp>
#include <transformations/common_optimizations/compress_float_constants.hpp>
@ -18,6 +19,19 @@
#include "openvino/pass/low_latency.hpp"
#include "openvino/pass/manager.hpp"
using Version = ov::pass::Serialize::Version;
inline Version convert_to_version(const std::string& version) {
if (version == "UNSPECIFIED")
return Version::UNSPECIFIED;
if (version == "IR_V10")
return Version::IR_V10;
if (version == "IR_V11")
return Version::IR_V11;
throw ov::Exception("Invoked with wrong version argument: '" + version +
"'! The supported versions are: 'UNSPECIFIED'(default), 'IR_V10', 'IR_V11'.");
}
namespace py = pybind11;
void regmodule_offline_transformations(py::module m) {
@ -98,12 +112,56 @@ void regmodule_offline_transformations(py::module m) {
// todo: remove as serialize as part of passManager api will be merged
m_offline_transformations.def(
"serialize",
[](std::shared_ptr<ov::Function> function, const std::string& path_to_xml, const std::string& path_to_bin) {
[](std::shared_ptr<ov::Function> function,
const std::string& path_to_xml,
const std::string& path_to_bin,
const std::string& version) {
ov::pass::Manager manager;
manager.register_pass<ov::pass::Serialize>(path_to_xml, path_to_bin);
manager.register_pass<ov::pass::Serialize>(path_to_xml, path_to_bin, convert_to_version(version));
manager.run_passes(function);
},
py::arg("function"),
py::arg("model_path"),
py::arg("weights_path"));
py::arg("weights_path"),
py::arg("version") = "UNSPECIFIED",
R"(
Serialize given function into IR. The generated .xml and .bin files will be save
into provided paths.
Parameters
----------
function : ov.Function
function which will be converted to IR representation
xml_path : str
path where .xml file will be saved
bin_path : str
path where .bin file will be saved
version : str
sets the version of the IR which will be generated.
Supported versions are:
- "UNSPECIFIED" (default) : Use the latest or function version
- "IR_V10" : v10 IR
- "IR_V11" : v11 IR
Examples:
----------
1. Default IR version:
shape = [2, 2]
parameter_a = ov.parameter(shape, dtype=np.float32, name="A")
parameter_b = ov.parameter(shape, dtype=np.float32, name="B")
parameter_c = ov.parameter(shape, dtype=np.float32, name="C")
model = (parameter_a + parameter_b) * parameter_c
func = Function(model, [parameter_a, parameter_b, parameter_c], "Function")
# IR generated with default version
serialize(func, model_path="./serialized.xml", weights_path="./serialized.bin")
2. IR version 11:
shape = [2, 2]
parameter_a = ov.parameter(shape, dtype=np.float32, name="A")
parameter_b = ov.parameter(shape, dtype=np.float32, name="B")
parameter_c = ov.parameter(shape, dtype=np.float32, name="C")
model = (parameter_a + parameter_b) * parameter_c
func = Function(model, [parameter_a, parameter_b, parameter_c], "Function")
# IR generated with default version
serialize(func, model_path="./serialized.xml", "./serialized.bin", version="IR_V11")
// )");
}

View File

@ -117,3 +117,63 @@ def test_compress_model_transformation():
assert func is not None
assert func.get_ordered_ops()[0].get_element_type().get_type_name() == "f16"
def test_Version_default():
core = Core()
xml_path = "./serialized_function.xml"
bin_path = "./serialized_function.bin"
shape = [100, 100, 2]
parameter_a = ov.opset8.parameter(shape, dtype=np.float32, name="A")
parameter_b = ov.opset8.parameter(shape, dtype=np.float32, name="B")
model = ov.opset8.floor(ov.opset8.minimum(ov.opset8.abs(parameter_a), parameter_b))
func = Function(model, [parameter_a, parameter_b], "Function")
serialize(func, xml_path, bin_path)
res_func = core.read_model(model=xml_path, weights=bin_path)
assert func.get_parameters() == res_func.get_parameters()
assert func.get_ordered_ops() == res_func.get_ordered_ops()
os.remove(xml_path)
os.remove(bin_path)
def test_Version_ir_v10():
core = Core()
xml_path = "./serialized_function.xml"
bin_path = "./serialized_function.bin"
shape = [100, 100, 2]
parameter_a = ov.opset8.parameter(shape, dtype=np.float32, name="A")
parameter_b = ov.opset8.parameter(shape, dtype=np.float32, name="B")
model = ov.opset8.floor(ov.opset8.minimum(ov.opset8.abs(parameter_a), parameter_b))
func = Function(model, [parameter_a, parameter_b], "Function")
serialize(func, xml_path, bin_path, "IR_V10")
res_func = core.read_model(model=xml_path, weights=bin_path)
assert func.get_parameters() == res_func.get_parameters()
assert func.get_ordered_ops() == res_func.get_ordered_ops()
os.remove(xml_path)
os.remove(bin_path)
def test_Version_ir_v11():
core = Core()
xml_path = "./serialized_function.xml"
bin_path = "./serialized_function.bin"
shape = [100, 100, 2]
parameter_a = ov.opset8.parameter(shape, dtype=np.float32, name="A")
parameter_b = ov.opset8.parameter(shape, dtype=np.float32, name="B")
model = ov.opset8.floor(ov.opset8.minimum(ov.opset8.abs(parameter_a), parameter_b))
func = Function(model, [parameter_a, parameter_b], "Function")
serialize(func, xml_path, bin_path, "IR_V11")
res_func = core.read_model(model=xml_path, weights=bin_path)
assert func.get_parameters() == res_func.get_parameters()
assert func.get_ordered_ops() == res_func.get_ordered_ops()
os.remove(xml_path)
os.remove(bin_path)