New extension serialization (#8252)
* Fixed custom op serialization * Deprecate old serialize constructor
This commit is contained in:
parent
97a4b944b1
commit
512db063a8
@ -118,7 +118,9 @@ void CNNNetworkSerializer::operator << (const CNNNetwork & network) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Serialize to old representation in case of old API
|
// Serialize to old representation in case of old API
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
ov::pass::StreamSerialize serializer(_ostream, getCustomOpSets(), serializeInputsAndOutputs);
|
ov::pass::StreamSerialize serializer(_ostream, getCustomOpSets(), serializeInputsAndOutputs);
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
serializer.run_on_function(std::const_pointer_cast<ngraph::Function>(network.getFunction()));
|
serializer.run_on_function(std::const_pointer_cast<ngraph::Function>(network.getFunction()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
#include "common_test_utils/file_utils.hpp"
|
#include "common_test_utils/file_utils.hpp"
|
||||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
#include "ie_core.hpp"
|
#include "ie_core.hpp"
|
||||||
|
#include "openvino/runtime/core.hpp"
|
||||||
#include "ngraph/ngraph.hpp"
|
#include "ngraph/ngraph.hpp"
|
||||||
#include "transformations/serialize.hpp"
|
#include "transformations/serialize.hpp"
|
||||||
|
|
||||||
@ -27,6 +28,11 @@ static std::string get_extension_path() {
|
|||||||
{}, std::string("template_extension") + IE_BUILD_POSTFIX);
|
{}, std::string("template_extension") + IE_BUILD_POSTFIX);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::string get_ov_extension_path() {
|
||||||
|
return FileUtils::makePluginLibraryName<char>(
|
||||||
|
{}, std::string("template_ov_extension") + IE_BUILD_POSTFIX);
|
||||||
|
}
|
||||||
|
|
||||||
class CustomOpsSerializationTest : public ::testing::Test {
|
class CustomOpsSerializationTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
std::string test_name =
|
std::string test_name =
|
||||||
@ -158,3 +164,25 @@ TEST_F(CustomOpsSerializationTest, CustomOpNoExtensions) {
|
|||||||
|
|
||||||
ASSERT_TRUE(success) << message;
|
ASSERT_TRUE(success) << message;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(CustomOpsSerializationTest, CustomOpOVExtensions) {
|
||||||
|
const std::string model = CommonTestUtils::getModelFromTestModelZoo(
|
||||||
|
IR_SERIALIZATION_MODELS_PATH "custom_identity.xml");
|
||||||
|
|
||||||
|
ov::runtime::Core core;
|
||||||
|
core.add_extension(get_ov_extension_path());
|
||||||
|
auto expected = core.read_model(model);
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::Serialize>(
|
||||||
|
m_out_xml_path, m_out_bin_path,
|
||||||
|
ngraph::pass::Serialize::Version::IR_V10);
|
||||||
|
manager.run_passes(expected);
|
||||||
|
auto result = core.read_model(m_out_xml_path, m_out_bin_path);
|
||||||
|
|
||||||
|
bool success;
|
||||||
|
std::string message;
|
||||||
|
std::tie(success, message) =
|
||||||
|
compare_functions(result, expected, true, false, false, true, true);
|
||||||
|
|
||||||
|
ASSERT_TRUE(success) << message;
|
||||||
|
}
|
||||||
|
@ -0,0 +1,51 @@
|
|||||||
|
<?xml version="1.0" ?>
|
||||||
|
<!--This is syntetic model created by hand desined only for white-box unit testing-->
|
||||||
|
<net name="Network" version="10">
|
||||||
|
<layers>
|
||||||
|
<layer name="in1" type="Parameter" id="0" version="opset1">
|
||||||
|
<data element_type="f32" shape="2,2,2,1"/>
|
||||||
|
<output>
|
||||||
|
<port id="0" precision="FP32">
|
||||||
|
<dim>2</dim>
|
||||||
|
<dim>2</dim>
|
||||||
|
<dim>2</dim>
|
||||||
|
<dim>1</dim>
|
||||||
|
</port>
|
||||||
|
</output>
|
||||||
|
</layer>
|
||||||
|
<layer name="operation" id="1" type="Identity" version="extension">
|
||||||
|
<data add="11"/>
|
||||||
|
<input>
|
||||||
|
<port id="1" precision="FP32">
|
||||||
|
<dim>2</dim>
|
||||||
|
<dim>2</dim>
|
||||||
|
<dim>2</dim>
|
||||||
|
<dim>1</dim>
|
||||||
|
</port>
|
||||||
|
</input>
|
||||||
|
<output>
|
||||||
|
<port id="2" precision="FP32">
|
||||||
|
<dim>2</dim>
|
||||||
|
<dim>2</dim>
|
||||||
|
<dim>2</dim>
|
||||||
|
<dim>1</dim>
|
||||||
|
</port>
|
||||||
|
</output>
|
||||||
|
</layer>
|
||||||
|
<layer name="output" type="Result" id="2" version="opset1">
|
||||||
|
<input>
|
||||||
|
<port id="0" precision="FP32">
|
||||||
|
<dim>2</dim>
|
||||||
|
<dim>2</dim>
|
||||||
|
<dim>2</dim>
|
||||||
|
<dim>1</dim>
|
||||||
|
</port>
|
||||||
|
</input>
|
||||||
|
</layer>
|
||||||
|
</layers>
|
||||||
|
<edges>
|
||||||
|
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
|
||||||
|
<edge from-layer="1" from-port="2" to-layer="2" to-port="0"/>
|
||||||
|
</edges>
|
||||||
|
</net>
|
||||||
|
|
@ -33,12 +33,14 @@ public:
|
|||||||
};
|
};
|
||||||
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
||||||
|
|
||||||
|
OPENVINO_DEPRECATED("This constructor is deprecated. Please use new extension API")
|
||||||
Serialize(std::ostream& xmlFile,
|
Serialize(std::ostream& xmlFile,
|
||||||
std::ostream& binFile,
|
std::ostream& binFile,
|
||||||
std::map<std::string, ngraph::OpSet> custom_opsets,
|
std::map<std::string, ngraph::OpSet> custom_opsets,
|
||||||
Version version = Version::UNSPECIFIED);
|
Version version = Version::UNSPECIFIED);
|
||||||
Serialize(std::ostream& xmlFile, std::ostream& binFile, Version version = Version::UNSPECIFIED);
|
Serialize(std::ostream& xmlFile, std::ostream& binFile, Version version = Version::UNSPECIFIED);
|
||||||
|
|
||||||
|
OPENVINO_DEPRECATED("This constructor is deprecated. Please use new extension API")
|
||||||
Serialize(const std::string& xmlPath,
|
Serialize(const std::string& xmlPath,
|
||||||
const std::string& binPath,
|
const std::string& binPath,
|
||||||
std::map<std::string, ngraph::OpSet> custom_opsets,
|
std::map<std::string, ngraph::OpSet> custom_opsets,
|
||||||
@ -74,10 +76,14 @@ public:
|
|||||||
|
|
||||||
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
||||||
|
|
||||||
|
OPENVINO_DEPRECATED("This constructor is deprecated. Please use new extension API")
|
||||||
StreamSerialize(std::ostream& stream,
|
StreamSerialize(std::ostream& stream,
|
||||||
std::map<std::string, ngraph::OpSet>&& custom_opsets = {},
|
std::map<std::string, ngraph::OpSet>&& custom_opsets = {},
|
||||||
const std::function<void(std::ostream&)>& custom_data_serializer = {},
|
const std::function<void(std::ostream&)>& custom_data_serializer = {},
|
||||||
Serialize::Version version = Serialize::Version::UNSPECIFIED);
|
Serialize::Version version = Serialize::Version::UNSPECIFIED);
|
||||||
|
StreamSerialize(std::ostream& stream,
|
||||||
|
const std::function<void(std::ostream&)>& custom_data_serializer = {},
|
||||||
|
Serialize::Version version = Serialize::Version::UNSPECIFIED);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::ostream& m_stream;
|
std::ostream& m_stream;
|
||||||
|
@ -125,21 +125,6 @@ void ngfunction_2_ir(pugi::xml_node& node,
|
|||||||
int64_t version,
|
int64_t version,
|
||||||
bool deterministic);
|
bool deterministic);
|
||||||
|
|
||||||
// Some of the operators were added to wrong opsets. This is a mapping
|
|
||||||
// that allows such operators to be serialized with proper opsets.
|
|
||||||
// If new operators are discovered that have the same problem, the mapping
|
|
||||||
// needs to be updated here. The keys contain op name and version in NodeTypeInfo.
|
|
||||||
const std::unordered_map<ngraph::Node::type_info_t, std::string> special_operator_to_opset_assignments = {
|
|
||||||
{ngraph::Node::type_info_t("ShuffleChannels", 0), "opset3"}};
|
|
||||||
|
|
||||||
std::string get_special_opset_for_op(const ngraph::Node::type_info_t& type_info) {
|
|
||||||
auto found = special_operator_to_opset_assignments.find(type_info);
|
|
||||||
if (found != end(special_operator_to_opset_assignments)) {
|
|
||||||
return found->second;
|
|
||||||
}
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace rt_info {
|
namespace rt_info {
|
||||||
const std::vector<std::string> list_of_names{
|
const std::vector<std::string> list_of_names{
|
||||||
"PrimitivesPriority",
|
"PrimitivesPriority",
|
||||||
@ -574,6 +559,10 @@ const std::vector<Edge> create_edge_mapping(const std::unordered_map<ngraph::Nod
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::string get_opset_name(const ngraph::Node* n, const std::map<std::string, ngraph::OpSet>& custom_opsets) {
|
std::string get_opset_name(const ngraph::Node* n, const std::map<std::string, ngraph::OpSet>& custom_opsets) {
|
||||||
|
OPENVINO_ASSERT(n != nullptr);
|
||||||
|
if (n->get_type_info().version_id != nullptr) {
|
||||||
|
return n->get_type_info().version_id;
|
||||||
|
}
|
||||||
// Try to find opset name from RT info
|
// Try to find opset name from RT info
|
||||||
auto opset_it = n->get_rt_info().find("opset");
|
auto opset_it = n->get_rt_info().find("opset");
|
||||||
if (opset_it != n->get_rt_info().end()) {
|
if (opset_it != n->get_rt_info().end()) {
|
||||||
@ -585,26 +574,6 @@ std::string get_opset_name(const ngraph::Node* n, const std::map<std::string, ng
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto opsets = std::array<std::reference_wrapper<const ngraph::OpSet>, 8>{ngraph::get_opset1(),
|
|
||||||
ngraph::get_opset2(),
|
|
||||||
ngraph::get_opset3(),
|
|
||||||
ngraph::get_opset4(),
|
|
||||||
ngraph::get_opset5(),
|
|
||||||
ngraph::get_opset6(),
|
|
||||||
ngraph::get_opset7(),
|
|
||||||
ngraph::get_opset8()};
|
|
||||||
|
|
||||||
auto special_opset = get_special_opset_for_op(n->get_type_info());
|
|
||||||
if (!special_opset.empty()) {
|
|
||||||
return special_opset;
|
|
||||||
}
|
|
||||||
// return the oldest opset name where node type is present
|
|
||||||
for (size_t idx = 0; idx < opsets.size(); idx++) {
|
|
||||||
if (opsets[idx].get().contains_op_type(n)) {
|
|
||||||
return "opset" + std::to_string(idx + 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const auto& custom_opset : custom_opsets) {
|
for (const auto& custom_opset : custom_opsets) {
|
||||||
std::string name = custom_opset.first;
|
std::string name = custom_opset.first;
|
||||||
ngraph::OpSet opset = custom_opset.second;
|
ngraph::OpSet opset = custom_opset.second;
|
||||||
@ -1137,6 +1106,7 @@ bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
pass::Serialize::Serialize(std::ostream& xmlFile,
|
pass::Serialize::Serialize(std::ostream& xmlFile,
|
||||||
std::ostream& binFile,
|
std::ostream& binFile,
|
||||||
std::map<std::string, ngraph::OpSet> custom_opsets,
|
std::map<std::string, ngraph::OpSet> custom_opsets,
|
||||||
@ -1147,6 +1117,7 @@ pass::Serialize::Serialize(std::ostream& xmlFile,
|
|||||||
m_binPath{},
|
m_binPath{},
|
||||||
m_version{version},
|
m_version{version},
|
||||||
m_custom_opsets{custom_opsets} {}
|
m_custom_opsets{custom_opsets} {}
|
||||||
|
|
||||||
pass::Serialize::Serialize(std::ostream& xmlFile, std::ostream& binFile, pass::Serialize::Version version)
|
pass::Serialize::Serialize(std::ostream& xmlFile, std::ostream& binFile, pass::Serialize::Version version)
|
||||||
: pass::Serialize::Serialize(xmlFile, binFile, std::map<std::string, ngraph::OpSet>{}, version) {}
|
: pass::Serialize::Serialize(xmlFile, binFile, std::map<std::string, ngraph::OpSet>{}, version) {}
|
||||||
|
|
||||||
@ -1162,7 +1133,9 @@ pass::Serialize::Serialize(const std::string& xmlPath,
|
|||||||
|
|
||||||
pass::Serialize::Serialize(const std::string& xmlPath, const std::string& binPath, pass::Serialize::Version version)
|
pass::Serialize::Serialize(const std::string& xmlPath, const std::string& binPath, pass::Serialize::Version version)
|
||||||
: pass::Serialize::Serialize(xmlPath, binPath, std::map<std::string, ngraph::OpSet>{}, version) {}
|
: pass::Serialize::Serialize(xmlPath, binPath, std::map<std::string, ngraph::OpSet>{}, version) {}
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
|
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||||
pass::StreamSerialize::StreamSerialize(std::ostream& stream,
|
pass::StreamSerialize::StreamSerialize(std::ostream& stream,
|
||||||
std::map<std::string, ngraph::OpSet>&& custom_opsets,
|
std::map<std::string, ngraph::OpSet>&& custom_opsets,
|
||||||
const std::function<void(std::ostream&)>& custom_data_serializer,
|
const std::function<void(std::ostream&)>& custom_data_serializer,
|
||||||
@ -1177,6 +1150,12 @@ pass::StreamSerialize::StreamSerialize(std::ostream& stream,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pass::StreamSerialize::StreamSerialize(std::ostream& stream,
|
||||||
|
const std::function<void(std::ostream&)>& custom_data_serializer,
|
||||||
|
Serialize::Version version)
|
||||||
|
: StreamSerialize(stream, {}, custom_data_serializer, version) {}
|
||||||
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
|
|
||||||
bool pass::StreamSerialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
bool pass::StreamSerialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||||
/*
|
/*
|
||||||
Format:
|
Format:
|
||||||
|
Loading…
Reference in New Issue
Block a user