New extension serialization (#8252)
* Fixed custom op serialization * Deprecate old serialize constructor
This commit is contained in:
parent
97a4b944b1
commit
512db063a8
@ -118,7 +118,9 @@ void CNNNetworkSerializer::operator << (const CNNNetwork & network) {
|
||||
};
|
||||
|
||||
// Serialize to old representation in case of old API
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
ov::pass::StreamSerialize serializer(_ostream, getCustomOpSets(), serializeInputsAndOutputs);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
serializer.run_on_function(std::const_pointer_cast<ngraph::Function>(network.getFunction()));
|
||||
}
|
||||
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "common_test_utils/file_utils.hpp"
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "ie_core.hpp"
|
||||
#include "openvino/runtime/core.hpp"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "transformations/serialize.hpp"
|
||||
|
||||
@ -27,6 +28,11 @@ static std::string get_extension_path() {
|
||||
{}, std::string("template_extension") + IE_BUILD_POSTFIX);
|
||||
}
|
||||
|
||||
static std::string get_ov_extension_path() {
|
||||
return FileUtils::makePluginLibraryName<char>(
|
||||
{}, std::string("template_ov_extension") + IE_BUILD_POSTFIX);
|
||||
}
|
||||
|
||||
class CustomOpsSerializationTest : public ::testing::Test {
|
||||
protected:
|
||||
std::string test_name =
|
||||
@ -158,3 +164,25 @@ TEST_F(CustomOpsSerializationTest, CustomOpNoExtensions) {
|
||||
|
||||
ASSERT_TRUE(success) << message;
|
||||
}
|
||||
|
||||
TEST_F(CustomOpsSerializationTest, CustomOpOVExtensions) {
|
||||
const std::string model = CommonTestUtils::getModelFromTestModelZoo(
|
||||
IR_SERIALIZATION_MODELS_PATH "custom_identity.xml");
|
||||
|
||||
ov::runtime::Core core;
|
||||
core.add_extension(get_ov_extension_path());
|
||||
auto expected = core.read_model(model);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::Serialize>(
|
||||
m_out_xml_path, m_out_bin_path,
|
||||
ngraph::pass::Serialize::Version::IR_V10);
|
||||
manager.run_passes(expected);
|
||||
auto result = core.read_model(m_out_xml_path, m_out_bin_path);
|
||||
|
||||
bool success;
|
||||
std::string message;
|
||||
std::tie(success, message) =
|
||||
compare_functions(result, expected, true, false, false, true, true);
|
||||
|
||||
ASSERT_TRUE(success) << message;
|
||||
}
|
||||
|
@ -0,0 +1,51 @@
|
||||
<?xml version="1.0" ?>
|
||||
<!--This is syntetic model created by hand desined only for white-box unit testing-->
|
||||
<net name="Network" version="10">
|
||||
<layers>
|
||||
<layer name="in1" type="Parameter" id="0" version="opset1">
|
||||
<data element_type="f32" shape="2,2,2,1"/>
|
||||
<output>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>2</dim>
|
||||
<dim>2</dim>
|
||||
<dim>2</dim>
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer name="operation" id="1" type="Identity" version="extension">
|
||||
<data add="11"/>
|
||||
<input>
|
||||
<port id="1" precision="FP32">
|
||||
<dim>2</dim>
|
||||
<dim>2</dim>
|
||||
<dim>2</dim>
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="2" precision="FP32">
|
||||
<dim>2</dim>
|
||||
<dim>2</dim>
|
||||
<dim>2</dim>
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
<layer name="output" type="Result" id="2" version="opset1">
|
||||
<input>
|
||||
<port id="0" precision="FP32">
|
||||
<dim>2</dim>
|
||||
<dim>2</dim>
|
||||
<dim>2</dim>
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</input>
|
||||
</layer>
|
||||
</layers>
|
||||
<edges>
|
||||
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
|
||||
<edge from-layer="1" from-port="2" to-layer="2" to-port="0"/>
|
||||
</edges>
|
||||
</net>
|
||||
|
@ -33,12 +33,14 @@ public:
|
||||
};
|
||||
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
||||
|
||||
OPENVINO_DEPRECATED("This constructor is deprecated. Please use new extension API")
|
||||
Serialize(std::ostream& xmlFile,
|
||||
std::ostream& binFile,
|
||||
std::map<std::string, ngraph::OpSet> custom_opsets,
|
||||
Version version = Version::UNSPECIFIED);
|
||||
Serialize(std::ostream& xmlFile, std::ostream& binFile, Version version = Version::UNSPECIFIED);
|
||||
|
||||
OPENVINO_DEPRECATED("This constructor is deprecated. Please use new extension API")
|
||||
Serialize(const std::string& xmlPath,
|
||||
const std::string& binPath,
|
||||
std::map<std::string, ngraph::OpSet> custom_opsets,
|
||||
@ -74,10 +76,14 @@ public:
|
||||
|
||||
bool run_on_function(std::shared_ptr<ov::Function> f) override;
|
||||
|
||||
OPENVINO_DEPRECATED("This constructor is deprecated. Please use new extension API")
|
||||
StreamSerialize(std::ostream& stream,
|
||||
std::map<std::string, ngraph::OpSet>&& custom_opsets = {},
|
||||
const std::function<void(std::ostream&)>& custom_data_serializer = {},
|
||||
Serialize::Version version = Serialize::Version::UNSPECIFIED);
|
||||
StreamSerialize(std::ostream& stream,
|
||||
const std::function<void(std::ostream&)>& custom_data_serializer = {},
|
||||
Serialize::Version version = Serialize::Version::UNSPECIFIED);
|
||||
|
||||
private:
|
||||
std::ostream& m_stream;
|
||||
|
@ -125,21 +125,6 @@ void ngfunction_2_ir(pugi::xml_node& node,
|
||||
int64_t version,
|
||||
bool deterministic);
|
||||
|
||||
// Some of the operators were added to wrong opsets. This is a mapping
|
||||
// that allows such operators to be serialized with proper opsets.
|
||||
// If new operators are discovered that have the same problem, the mapping
|
||||
// needs to be updated here. The keys contain op name and version in NodeTypeInfo.
|
||||
const std::unordered_map<ngraph::Node::type_info_t, std::string> special_operator_to_opset_assignments = {
|
||||
{ngraph::Node::type_info_t("ShuffleChannels", 0), "opset3"}};
|
||||
|
||||
std::string get_special_opset_for_op(const ngraph::Node::type_info_t& type_info) {
|
||||
auto found = special_operator_to_opset_assignments.find(type_info);
|
||||
if (found != end(special_operator_to_opset_assignments)) {
|
||||
return found->second;
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
namespace rt_info {
|
||||
const std::vector<std::string> list_of_names{
|
||||
"PrimitivesPriority",
|
||||
@ -574,6 +559,10 @@ const std::vector<Edge> create_edge_mapping(const std::unordered_map<ngraph::Nod
|
||||
}
|
||||
|
||||
std::string get_opset_name(const ngraph::Node* n, const std::map<std::string, ngraph::OpSet>& custom_opsets) {
|
||||
OPENVINO_ASSERT(n != nullptr);
|
||||
if (n->get_type_info().version_id != nullptr) {
|
||||
return n->get_type_info().version_id;
|
||||
}
|
||||
// Try to find opset name from RT info
|
||||
auto opset_it = n->get_rt_info().find("opset");
|
||||
if (opset_it != n->get_rt_info().end()) {
|
||||
@ -585,26 +574,6 @@ std::string get_opset_name(const ngraph::Node* n, const std::map<std::string, ng
|
||||
}
|
||||
}
|
||||
|
||||
auto opsets = std::array<std::reference_wrapper<const ngraph::OpSet>, 8>{ngraph::get_opset1(),
|
||||
ngraph::get_opset2(),
|
||||
ngraph::get_opset3(),
|
||||
ngraph::get_opset4(),
|
||||
ngraph::get_opset5(),
|
||||
ngraph::get_opset6(),
|
||||
ngraph::get_opset7(),
|
||||
ngraph::get_opset8()};
|
||||
|
||||
auto special_opset = get_special_opset_for_op(n->get_type_info());
|
||||
if (!special_opset.empty()) {
|
||||
return special_opset;
|
||||
}
|
||||
// return the oldest opset name where node type is present
|
||||
for (size_t idx = 0; idx < opsets.size(); idx++) {
|
||||
if (opsets[idx].get().contains_op_type(n)) {
|
||||
return "opset" + std::to_string(idx + 1);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& custom_opset : custom_opsets) {
|
||||
std::string name = custom_opset.first;
|
||||
ngraph::OpSet opset = custom_opset.second;
|
||||
@ -1137,6 +1106,7 @@ bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
return false;
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
pass::Serialize::Serialize(std::ostream& xmlFile,
|
||||
std::ostream& binFile,
|
||||
std::map<std::string, ngraph::OpSet> custom_opsets,
|
||||
@ -1147,6 +1117,7 @@ pass::Serialize::Serialize(std::ostream& xmlFile,
|
||||
m_binPath{},
|
||||
m_version{version},
|
||||
m_custom_opsets{custom_opsets} {}
|
||||
|
||||
pass::Serialize::Serialize(std::ostream& xmlFile, std::ostream& binFile, pass::Serialize::Version version)
|
||||
: pass::Serialize::Serialize(xmlFile, binFile, std::map<std::string, ngraph::OpSet>{}, version) {}
|
||||
|
||||
@ -1162,7 +1133,9 @@ pass::Serialize::Serialize(const std::string& xmlPath,
|
||||
|
||||
pass::Serialize::Serialize(const std::string& xmlPath, const std::string& binPath, pass::Serialize::Version version)
|
||||
: pass::Serialize::Serialize(xmlPath, binPath, std::map<std::string, ngraph::OpSet>{}, version) {}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
pass::StreamSerialize::StreamSerialize(std::ostream& stream,
|
||||
std::map<std::string, ngraph::OpSet>&& custom_opsets,
|
||||
const std::function<void(std::ostream&)>& custom_data_serializer,
|
||||
@ -1177,6 +1150,12 @@ pass::StreamSerialize::StreamSerialize(std::ostream& stream,
|
||||
}
|
||||
}
|
||||
|
||||
pass::StreamSerialize::StreamSerialize(std::ostream& stream,
|
||||
const std::function<void(std::ostream&)>& custom_data_serializer,
|
||||
Serialize::Version version)
|
||||
: StreamSerialize(stream, {}, custom_data_serializer, version) {}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
||||
bool pass::StreamSerialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
/*
|
||||
Format:
|
||||
|
Loading…
Reference in New Issue
Block a user