[core][python] ov::serialize (#10945)

* add ov::serialize

* create python binding

* update python tools

* use ov::serialize in benchmark app

* remove serialize from python offline_transformations

* fix import

* revert pot

* update docs

* apply review comments

* add const

* make bin path optional

* Add docs

* add compare test
This commit is contained in:
Alexey Lebedev 2022-03-23 11:44:00 +03:00 committed by GitHub
parent af874e7754
commit 3de9189d50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 145 additions and 106 deletions

View File

@ -199,7 +199,7 @@ void save_example() {
// ======== Step 3: Save the model ================ // ======== Step 3: Save the model ================
std::string xml = "/path/to/some_model_saved.xml"; std::string xml = "/path/to/some_model_saved.xml";
std::string bin = "/path/to/some_model_saved.bin"; std::string bin = "/path/to/some_model_saved.bin";
ov::pass::Serialize(xml, bin).run_on_model(model); ov::serialize(model, xml, bin);
//! [ov:preprocess:save] //! [ov:preprocess:save]
} }

View File

@ -3,7 +3,7 @@
# #
from openvino.preprocess import ResizeAlgorithm, ColorFormat from openvino.preprocess import ResizeAlgorithm, ColorFormat
from openvino.runtime import Layout, Type from openvino.runtime import Layout, Type, serialize
xml_path = '' xml_path = ''
@ -210,11 +210,7 @@ model = ppp.build()
set_batch(model, 2) set_batch(model, 2)
# ======== Step 3: Save the model ================ # ======== Step 3: Save the model ================
pass_manager = Manager() serialize(model, '/path/to/some_model_saved.xml', '/path/to/some_model_saved.bin')
pass_manager.register_pass(pass_name="Serialize",
xml_path='/path/to/some_model_saved.xml',
bin_path='/path/to/some_model_saved.bin')
pass_manager.run_passes(model)
# ! [ov:preprocess:save] # ! [ov:preprocess:save]
# ! [ov:preprocess:save_load] # ! [ov:preprocess:save_load]

View File

@ -1103,9 +1103,7 @@ int main(int argc, char* argv[]) {
if (!FLAGS_exec_graph_path.empty()) { if (!FLAGS_exec_graph_path.empty()) {
try { try {
std::string fileName = fileNameNoExt(FLAGS_exec_graph_path); ov::serialize(compiledModel.get_runtime_model(), FLAGS_exec_graph_path);
ov::pass::Serialize serializer(fileName + ".xml", fileName + ".bin");
serializer.run_on_model(std::const_pointer_cast<ov::Model>(compiledModel.get_runtime_model()));
slog::info << "executable graph is stored to " << FLAGS_exec_graph_path << slog::endl; slog::info << "executable graph is stored to " << FLAGS_exec_graph_path << slog::endl;
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
slog::err << "Can't get executable graph: " << ex.what() << slog::endl; slog::err << "Can't get executable graph: " << ex.what() << slog::endl;

View File

@ -15,6 +15,5 @@ from openvino.pyopenvino.offline_transformations import apply_low_latency_transf
from openvino.pyopenvino.offline_transformations import apply_pruning_transformation from openvino.pyopenvino.offline_transformations import apply_pruning_transformation
from openvino.pyopenvino.offline_transformations import generate_mapping_file from openvino.pyopenvino.offline_transformations import generate_mapping_file
from openvino.pyopenvino.offline_transformations import apply_make_stateful_transformation from openvino.pyopenvino.offline_transformations import apply_make_stateful_transformation
from openvino.pyopenvino.offline_transformations import serialize
from openvino.pyopenvino.offline_transformations import compress_model_transformation from openvino.pyopenvino.offline_transformations import compress_model_transformation
from openvino.pyopenvino.offline_transformations import compress_quantize_weights_transformation from openvino.pyopenvino.offline_transformations import compress_quantize_weights_transformation

View File

@ -47,6 +47,7 @@ from openvino.pyopenvino import ProfilingInfo
from openvino.pyopenvino import get_version from openvino.pyopenvino import get_version
from openvino.pyopenvino import get_batch from openvino.pyopenvino import get_batch
from openvino.pyopenvino import set_batch from openvino.pyopenvino import set_batch
from openvino.pyopenvino import serialize
# Import opsets # Import opsets
from openvino.runtime import opset1 from openvino.runtime import opset1

View File

@ -505,4 +505,20 @@ py::dict outputs_to_dict(const std::vector<ov::Output<const ov::Node>>& outputs,
return res; return res;
} }
/// Maps the textual IR version accepted by the Python API onto the
/// ov::pass::Serialize::Version enumeration.
/// Recognised names are "UNSPECIFIED", "IR_V10" and "IR_V11"; any other
/// string results in an ov::Exception listing the supported values.
ov::pass::Serialize::Version convert_to_version(const std::string& version) {
    using Version = ov::pass::Serialize::Version;

    struct NamedVersion {
        const char* name;
        Version value;
    };
    static const NamedVersion known_versions[] = {
        {"UNSPECIFIED", Version::UNSPECIFIED},
        {"IR_V10", Version::IR_V10},
        {"IR_V11", Version::IR_V11},
    };

    for (const auto& candidate : known_versions) {
        if (version == candidate.name) {
            return candidate.value;
        }
    }

    // No entry matched: report the offending value together with the accepted names.
    throw ov::Exception("Invoked with wrong version argument: '" + version +
                        "'! The supported versions are: 'UNSPECIFIED'(default), 'IR_V10', 'IR_V11'.");
}
}; // namespace Common }; // namespace Common

View File

@ -19,6 +19,7 @@
#include "openvino/runtime/infer_request.hpp" #include "openvino/runtime/infer_request.hpp"
#include "openvino/runtime/tensor.hpp" #include "openvino/runtime/tensor.hpp"
#include "openvino/runtime/properties.hpp" #include "openvino/runtime/properties.hpp"
#include "openvino/pass/serialize.hpp"
#include "pyopenvino/core/containers.hpp" #include "pyopenvino/core/containers.hpp"
#include "pyopenvino/graph/any.hpp" #include "pyopenvino/graph/any.hpp"
@ -55,6 +56,8 @@ uint32_t get_optimal_number_of_requests(const ov::CompiledModel& actual);
py::dict outputs_to_dict(const std::vector<ov::Output<const ov::Node>>& outputs, ov::InferRequest& request); py::dict outputs_to_dict(const std::vector<ov::Output<const ov::Node>>& outputs, ov::InferRequest& request);
ov::pass::Serialize::Version convert_to_version(const std::string& version);
// Use only with classes that are not creatable by users on Python's side, because // Use only with classes that are not creatable by users on Python's side, because
// Objects created in Python that are wrapped with such wrapper will cause memory leaks. // Objects created in Python that are wrapped with such wrapper will cause memory leaks.
template <typename T> template <typename T>

View File

@ -21,19 +21,6 @@
#include "openvino/pass/low_latency.hpp" #include "openvino/pass/low_latency.hpp"
#include "openvino/pass/manager.hpp" #include "openvino/pass/manager.hpp"
using Version = ov::pass::Serialize::Version;
inline Version convert_to_version(const std::string& version) {
if (version == "UNSPECIFIED")
return Version::UNSPECIFIED;
if (version == "IR_V10")
return Version::IR_V10;
if (version == "IR_V11")
return Version::IR_V11;
throw ov::Exception("Invoked with wrong version argument: '" + version +
"'! The supported versions are: 'UNSPECIFIED'(default), 'IR_V10', 'IR_V11'.");
}
namespace py = pybind11; namespace py = pybind11;
void regmodule_offline_transformations(py::module m) { void regmodule_offline_transformations(py::module m) {
@ -129,63 +116,4 @@ void regmodule_offline_transformations(py::module m) {
manager.run_passes(model); manager.run_passes(model);
}, },
py::arg("model")); py::arg("model"));
// todo: remove as serialize as part of passManager api will be merged
m_offline_transformations.def(
"serialize",
[](std::shared_ptr<ov::Model> model,
const std::string& path_to_xml,
const std::string& path_to_bin,
const std::string& version) {
ov::pass::Manager manager;
manager.register_pass<ov::pass::Serialize>(path_to_xml, path_to_bin, convert_to_version(version));
manager.run_passes(model);
},
py::arg("model"),
py::arg("model_path"),
py::arg("weights_path"),
py::arg("version") = "UNSPECIFIED",
R"(
Serialize given model into IR. The generated .xml and .bin files will be saved
into provided paths.
:param model: model which will be converted to IR representation
:type model: openvino.runtime.Model
:param xml_path: path where .xml file will be saved
:type xml_path: str
:param bin_path: path where .bin file will be saved
:type bin_path: str
:param version: sets the version of the IR which will be generated.
Supported versions are:
- "UNSPECIFIED" (default) : Use the latest or model version
- "IR_V10" : v10 IR
- "IR_V11" : v11 IR
:Examples:
1. Default IR version:
.. code-block:: python
shape = [2, 2]
parameter_a = ov.parameter(shape, dtype=np.float32, name="A")
parameter_b = ov.parameter(shape, dtype=np.float32, name="B")
parameter_c = ov.parameter(shape, dtype=np.float32, name="C")
model = (parameter_a + parameter_b) * parameter_c
func = Model(model, [parameter_a, parameter_b, parameter_c], "Model")
# IR generated with default version
serialize(func, model_path="./serialized.xml", weights_path="./serialized.bin")
2. IR version 11:
.. code-block:: python
parameter_a = ov.parameter(shape, dtype=np.float32, name="A")
parameter_b = ov.parameter(shape, dtype=np.float32, name="B")
parameter_c = ov.parameter(shape, dtype=np.float32, name="C")
model = (parameter_a + parameter_b) * parameter_c
func = Model(model, [parameter_a, parameter_b, parameter_c], "Model")
# IR generated with default version
serialize(func, model_path="./serialized.xml", "./serialized.bin", version="IR_V11")
)");
} }

View File

@ -3,6 +3,7 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <openvino/core/graph_util.hpp>
#include <openvino/core/model.hpp> #include <openvino/core/model.hpp>
#include <openvino/core/node.hpp> #include <openvino/core/node.hpp>
#include <openvino/core/version.hpp> #include <openvino/core/version.hpp>
@ -95,6 +96,60 @@ PYBIND11_MODULE(pyopenvino, m) {
py::arg("model"), py::arg("model"),
py::arg("batch_size") = -1); py::arg("batch_size") = -1);
m.def(
    "serialize",
    [](std::shared_ptr<ov::Model>& model,
       const std::string& xml_path,
       const std::string& bin_path,
       const std::string& version) {
        // The Python side passes the IR version as a string; convert it to the
        // enum expected by ov::serialize before delegating.
        ov::serialize(model, xml_path, bin_path, Common::convert_to_version(version));
    },
    py::arg("model"),
    py::arg("xml_path"),
    py::arg("bin_path") = "",
    py::arg("version") = "UNSPECIFIED",
    R"(
        Serialize given model into IR. The generated .xml and .bin files will be saved
        into provided paths.

        :param model: model which will be converted to IR representation
        :type model: openvino.runtime.Model
        :param xml_path: path where .xml file will be saved
        :type xml_path: str
        :param bin_path: path where .bin file will be saved (optional),
                         the same name as for xml_path will be used by default.
        :type bin_path: str
        :param version: version of the generated IR (optional).
        Supported versions are:
        - "UNSPECIFIED" (default) : Use the latest or model version
        - "IR_V10" : v10 IR
        - "IR_V11" : v11 IR

        :Examples:

        1. Default IR version:

        .. code-block:: python

            shape = [2, 2]
            parameter_a = ov.parameter(shape, dtype=np.float32, name="A")
            parameter_b = ov.parameter(shape, dtype=np.float32, name="B")
            parameter_c = ov.parameter(shape, dtype=np.float32, name="C")
            op = (parameter_a + parameter_b) * parameter_c
            model = Model(op, [parameter_a, parameter_b, parameter_c], "Model")
            # IR generated with default version
            serialize(model, xml_path="./serialized.xml", bin_path="./serialized.bin")

        2. IR version 11:

        .. code-block:: python

            shape = [2, 2]
            parameter_a = ov.parameter(shape, dtype=np.float32, name="A")
            parameter_b = ov.parameter(shape, dtype=np.float32, name="B")
            parameter_c = ov.parameter(shape, dtype=np.float32, name="C")
            op = (parameter_a + parameter_b) * parameter_c
            model = Model(op, [parameter_a, parameter_b, parameter_c], "Model")
            # IR generated with version 11
            serialize(model, xml_path="./serialized.xml", bin_path="./serialized.bin", version="IR_V11")
    )");
regclass_graph_PyRTMap(m); regclass_graph_PyRTMap(m);
regmodule_graph_types(m); regmodule_graph_types(m);
regclass_graph_Dimension(m); // Dimension must be registered before PartialShape regclass_graph_Dimension(m); // Dimension must be registered before PartialShape

View File

@ -3,9 +3,10 @@
import os import os
import numpy as np import numpy as np
from openvino.runtime import serialize
from openvino.offline_transformations import apply_moc_transformations, apply_pot_transformations, \ from openvino.offline_transformations import apply_moc_transformations, apply_pot_transformations, \
apply_low_latency_transformation, apply_pruning_transformation, apply_make_stateful_transformation, \ apply_low_latency_transformation, apply_pruning_transformation, apply_make_stateful_transformation, \
compress_model_transformation, serialize compress_model_transformation
from openvino.runtime import Model, PartialShape, Core from openvino.runtime import Model, PartialShape, Core
import openvino.runtime as ov import openvino.runtime as ov
@ -140,6 +141,16 @@ def test_Version_default():
os.remove(bin_path) os.remove(bin_path)
def test_serialize_default_bin():
    """Check that ``serialize`` produces a .bin file next to the .xml when ``bin_path`` is omitted."""
    xml_path = "./serialized_function.xml"
    bin_path = "./serialized_function.bin"
    model = get_test_function()
    try:
        # Only the xml path is given; the bin path is expected to be derived from it.
        serialize(model, xml_path)
        assert os.path.exists(xml_path)
        assert os.path.exists(bin_path)
    finally:
        # Clean up even when an assertion fails: other tests in this module reuse
        # the same output paths, so leftovers must not survive this test.
        for path in (xml_path, bin_path):
            if os.path.exists(path):
                os.remove(path)
def test_Version_ir_v10(): def test_Version_ir_v10():
core = Core() core = Core()
xml_path = "./serialized_function.xml" xml_path = "./serialized_function.xml"

View File

@ -18,6 +18,7 @@
#include "openvino/core/model.hpp" #include "openvino/core/model.hpp"
#include "openvino/core/node.hpp" #include "openvino/core/node.hpp"
#include "openvino/op/parameter.hpp" #include "openvino/op/parameter.hpp"
#include "openvino/pass/serialize.hpp"
namespace ov { namespace ov {
@ -278,4 +279,16 @@ bool replace_output_update_name(Output<Node> node, const Output<Node>& node_inpu
OPENVINO_API OPENVINO_API
bool replace_node_update_name(const std::shared_ptr<Node>& target, const std::shared_ptr<Node>& replacement); bool replace_node_update_name(const std::shared_ptr<Node>& target, const std::shared_ptr<Node>& replacement);
/// \brief Serialize given model into IR. The generated .xml and .bin files will be saved into provided paths.
/// \param m Model which will be converted to IR representation.
/// \param xml_path Path where .xml file will be saved.
/// \param bin_path Path where .bin file will be saved (optional).
/// The same name as for xml_path will be used by default.
/// \param version Version of the generated IR (optional).
OPENVINO_API
void serialize(const std::shared_ptr<const ov::Model>& m,
const std::string& xml_path,
const std::string& bin_path = "",
ov::pass::Serialize::Version version = ov::pass::Serialize::Version::UNSPECIFIED);
} // namespace ov } // namespace ov

View File

@ -808,3 +808,12 @@ bool ov::replace_node_update_name(const std::shared_ptr<Node>& target, const std
copy_runtime_info(target, replacement); copy_runtime_info(target, replacement);
return true; return true;
} }
// Convenience wrapper: registers the Serialize pass with the given output
// paths and IR version in a pass manager and runs it on the model.
void ov::serialize(const std::shared_ptr<const ov::Model>& m,
                   const std::string& xml_path,
                   const std::string& bin_path,
                   ov::pass::Serialize::Version version) {
    // The pass manager is invoked with a non-const model pointer, so the
    // constness of the incoming handle is cast away for the call.
    const auto mutable_model = std::const_pointer_cast<ov::Model>(m);

    ov::pass::Manager pass_manager;
    pass_manager.register_pass<ov::pass::Serialize>(xml_path, bin_path, version);
    pass_manager.run_passes(mutable_model);
}

View File

@ -24,6 +24,20 @@ public:
std::string m_out_xml_path; std::string m_out_xml_path;
std::string m_out_bin_path; std::string m_out_bin_path;
// Shared round-trip check used by the serialization tests.
// Reads the reference model, lets `serializer` write it to the output paths,
// reads the written files back, and expects that (a) the reloaded model
// matches the reference and (b) the reference still matches a clone taken
// before serialization.
void CompareSerialized(std::function<void(const std::shared_ptr<ov::Model>&)> serializer) {
    auto reference = ov::test::readModel(m_model_path, m_binary_path);
    auto untouched = ov::clone_model(*reference);

    serializer(reference);
    auto reloaded = ov::test::readModel(m_out_xml_path, m_out_bin_path);

    const auto comparator = FunctionsComparator::with_default()
                                .enable(FunctionsComparator::ATTRIBUTES)
                                .enable(FunctionsComparator::CONST_VALUES);

    const auto roundtrip = comparator.compare(reloaded, reference);
    const auto unchanged = comparator.compare(reference, untouched);

    EXPECT_TRUE(roundtrip.valid) << roundtrip.message;
    EXPECT_TRUE(unchanged.valid) << unchanged.message;
}
void SetUp() override { void SetUp() override {
m_model_path = ov::util::path_join({SERIALIZED_ZOO, "ir/", std::get<0>(GetParam())}); m_model_path = ov::util::path_join({SERIALIZED_ZOO, "ir/", std::get<0>(GetParam())});
if (!std::get<1>(GetParam()).empty()) { if (!std::get<1>(GetParam()).empty()) {
@ -42,17 +56,15 @@ public:
}; };
TEST_P(SerializationTest, CompareFunctions) { TEST_P(SerializationTest, CompareFunctions) {
auto expected = ov::test::readModel(m_model_path, m_binary_path); CompareSerialized([this](const std::shared_ptr<ov::Model>& m) {
auto orig = ov::clone_model(*expected); ov::pass::Serialize(m_out_xml_path, m_out_bin_path).run_on_model(m);
ov::pass::Serialize(m_out_xml_path, m_out_bin_path).run_on_model(expected); });
auto result = ov::test::readModel(m_out_xml_path, m_out_bin_path); }
const auto fc = FunctionsComparator::with_default()
.enable(FunctionsComparator::ATTRIBUTES) TEST_P(SerializationTest, SerializeHelper) {
.enable(FunctionsComparator::CONST_VALUES); CompareSerialized([this](const std::shared_ptr<ov::Model>& m) {
const auto res = fc.compare(result, expected); ov::serialize(m, m_out_xml_path, m_out_bin_path);
const auto res2 = fc.compare(expected, orig); });
EXPECT_TRUE(res.valid) << res.message;
EXPECT_TRUE(res2.valid) << res2.message;
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(

View File

@ -3,9 +3,8 @@
from collections import defaultdict from collections import defaultdict
import datetime import datetime
from openvino.runtime import Core, Model, PartialShape, Dimension, Layout, Type from openvino.runtime import Core, Model, PartialShape, Dimension, Layout, Type, serialize
from openvino.preprocess import PrePostProcessor from openvino.preprocess import PrePostProcessor
from openvino.runtime.passes import Manager
from .constants import DEVICE_DURATION_IN_SECS, UNKNOWN_DEVICE_TYPE, \ from .constants import DEVICE_DURATION_IN_SECS, UNKNOWN_DEVICE_TYPE, \
CPU_DEVICE_NAME, GPU_DEVICE_NAME CPU_DEVICE_NAME, GPU_DEVICE_NAME
@ -308,11 +307,7 @@ def process_help_inference_string(benchmark_app, device_number_streams):
def dump_exec_graph(compiled_model, model_path): def dump_exec_graph(compiled_model, model_path):
weight_path = model_path[:model_path.find(".xml")] + ".bin" serialize(compiled_model.get_runtime_model(), model_path)
pass_manager = Manager()
pass_manager.register_pass("Serialize", model_path, weight_path)
pass_manager.run_passes(compiled_model.get_runtime_model())
def print_perf_counters(perf_counts_list): def print_perf_counters(perf_counts_list):

View File

@ -54,7 +54,8 @@ def apply_offline_transformations(input_model: str, argv: argparse.Namespace):
# to produce correct mapping # to produce correct mapping
extract_names = argv.framework in ['tf', 'mxnet', 'kaldi'] extract_names = argv.framework in ['tf', 'mxnet', 'kaldi']
from openvino.offline_transformations import generate_mapping_file, serialize # pylint: disable=import-error,no-name-in-module from openvino.runtime import serialize # pylint: disable=import-error,no-name-in-module
from openvino.offline_transformations import generate_mapping_file # pylint: disable=import-error,no-name-in-module
from openvino.frontend import FrontEndManager # pylint: disable=no-name-in-module,import-error from openvino.frontend import FrontEndManager # pylint: disable=no-name-in-module,import-error
from openvino.tools.mo.back.preprocessing import apply_preprocessing # pylint: disable=no-name-in-module,import-error from openvino.tools.mo.back.preprocessing import apply_preprocessing # pylint: disable=no-name-in-module,import-error

View File

@ -40,7 +40,8 @@ def moc_emit_ir(ngraph_function: Model, argv: argparse.Namespace):
orig_model_name = os.path.normpath(os.path.join(output_dir, argv.model_name)) orig_model_name = os.path.normpath(os.path.join(output_dir, argv.model_name))
from openvino.offline_transformations import serialize, generate_mapping_file # pylint: disable=import-error,no-name-in-module from openvino.runtime import serialize # pylint: disable=import-error,no-name-in-module
from openvino.offline_transformations import generate_mapping_file # pylint: disable=import-error,no-name-in-module
serialize(ngraph_function, (orig_model_name + ".xml").encode('utf-8'), (orig_model_name + ".bin").encode('utf-8')) serialize(ngraph_function, (orig_model_name + ".xml").encode('utf-8'), (orig_model_name + ".bin").encode('utf-8'))
del argv.feManager del argv.feManager

View File

@ -51,9 +51,9 @@ def import_core_modules(silent: bool, path_to_module: str):
from openvino.offline_transformations import apply_moc_transformations, apply_moc_legacy_transformations,\ from openvino.offline_transformations import apply_moc_transformations, apply_moc_legacy_transformations,\
apply_low_latency_transformation # pylint: disable=import-error,no-name-in-module apply_low_latency_transformation # pylint: disable=import-error,no-name-in-module
from openvino.offline_transformations import apply_make_stateful_transformation, generate_mapping_file # pylint: disable=import-error,no-name-in-module from openvino.offline_transformations import apply_make_stateful_transformation, generate_mapping_file # pylint: disable=import-error,no-name-in-module
from openvino.offline_transformations import generate_mapping_file, apply_make_stateful_transformation, serialize # pylint: disable=import-error,no-name-in-module from openvino.offline_transformations import generate_mapping_file, apply_make_stateful_transformation # pylint: disable=import-error,no-name-in-module
from openvino.runtime import Model, get_version # pylint: disable=import-error,no-name-in-module from openvino.runtime import Model, serialize, get_version # pylint: disable=import-error,no-name-in-module
from openvino.runtime.op import Parameter # pylint: disable=import-error,no-name-in-module from openvino.runtime.op import Parameter # pylint: disable=import-error,no-name-in-module
from openvino.runtime import PartialShape, Dimension # pylint: disable=import-error,no-name-in-module from openvino.runtime import PartialShape, Dimension # pylint: disable=import-error,no-name-in-module
from openvino.frontend import FrontEndManager, FrontEnd # pylint: disable=no-name-in-module,import-error from openvino.frontend import FrontEndManager, FrontEnd # pylint: disable=no-name-in-module,import-error

View File

@ -34,6 +34,7 @@ def load_graph(model_config, target_device='ANY'):
apply_pot_transformations(model, target_device.encode('utf-8')) apply_pot_transformations(model, target_device.encode('utf-8'))
bin_path = serialized_bin_path bin_path = serialized_bin_path
xml_path = serialized_xml_path xml_path = serialized_xml_path
# TODO: replace by openvino.runtime.serialize
pass_manager.register_pass(pass_name="Serialize", xml_path=xml_path, bin_path=bin_path) pass_manager.register_pass(pass_name="Serialize", xml_path=xml_path, bin_path=bin_path)
pass_manager.run_passes(model) pass_manager.run_passes(model)