New extension serialization (#8252)

* Fixed custom op serialization

* Deprecate old serialize constructor
This commit is contained in:
Ilya Churaev 2021-11-10 07:15:30 +03:00 committed by GitHub
parent 97a4b944b1
commit 512db063a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 101 additions and 35 deletions

View File

@ -118,7 +118,9 @@ void CNNNetworkSerializer::operator << (const CNNNetwork & network) {
};
// Serialize to old representation in case of old API
OPENVINO_SUPPRESS_DEPRECATED_START
ov::pass::StreamSerialize serializer(_ostream, getCustomOpSets(), serializeInputsAndOutputs);
OPENVINO_SUPPRESS_DEPRECATED_END
serializer.run_on_function(std::const_pointer_cast<ngraph::Function>(network.getFunction()));
}

View File

@ -11,6 +11,7 @@
#include "common_test_utils/file_utils.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include "ie_core.hpp"
#include "openvino/runtime/core.hpp"
#include "ngraph/ngraph.hpp"
#include "transformations/serialize.hpp"
@ -27,6 +28,11 @@ static std::string get_extension_path() {
{}, std::string("template_extension") + IE_BUILD_POSTFIX);
}
static std::string get_ov_extension_path() {
return FileUtils::makePluginLibraryName<char>(
{}, std::string("template_ov_extension") + IE_BUILD_POSTFIX);
}
class CustomOpsSerializationTest : public ::testing::Test {
protected:
std::string test_name =
@ -158,3 +164,25 @@ TEST_F(CustomOpsSerializationTest, CustomOpNoExtensions) {
ASSERT_TRUE(success) << message;
}
TEST_F(CustomOpsSerializationTest, CustomOpOVExtensions) {
const std::string model = CommonTestUtils::getModelFromTestModelZoo(
IR_SERIALIZATION_MODELS_PATH "custom_identity.xml");
ov::runtime::Core core;
core.add_extension(get_ov_extension_path());
auto expected = core.read_model(model);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::Serialize>(
m_out_xml_path, m_out_bin_path,
ngraph::pass::Serialize::Version::IR_V10);
manager.run_passes(expected);
auto result = core.read_model(m_out_xml_path, m_out_bin_path);
bool success;
std::string message;
std::tie(success, message) =
compare_functions(result, expected, true, false, false, true, true);
ASSERT_TRUE(success) << message;
}

View File

@ -0,0 +1,51 @@
<?xml version="1.0" ?>
<!--This is syntetic model created by hand desined only for white-box unit testing-->
<net name="Network" version="10">
<layers>
<layer name="in1" type="Parameter" id="0" version="opset1">
<data element_type="f32" shape="2,2,2,1"/>
<output>
<port id="0" precision="FP32">
<dim>2</dim>
<dim>2</dim>
<dim>2</dim>
<dim>1</dim>
</port>
</output>
</layer>
<layer name="operation" id="1" type="Identity" version="extension">
<data add="11"/>
<input>
<port id="1" precision="FP32">
<dim>2</dim>
<dim>2</dim>
<dim>2</dim>
<dim>1</dim>
</port>
</input>
<output>
<port id="2" precision="FP32">
<dim>2</dim>
<dim>2</dim>
<dim>2</dim>
<dim>1</dim>
</port>
</output>
</layer>
<layer name="output" type="Result" id="2" version="opset1">
<input>
<port id="0" precision="FP32">
<dim>2</dim>
<dim>2</dim>
<dim>2</dim>
<dim>1</dim>
</port>
</input>
</layer>
</layers>
<edges>
<edge from-layer="0" from-port="0" to-layer="1" to-port="1"/>
<edge from-layer="1" from-port="2" to-layer="2" to-port="0"/>
</edges>
</net>

View File

@ -33,12 +33,14 @@ public:
};
bool run_on_function(std::shared_ptr<ov::Function> f) override;
OPENVINO_DEPRECATED("This constructor is deprecated. Please use new extension API")
Serialize(std::ostream& xmlFile,
std::ostream& binFile,
std::map<std::string, ngraph::OpSet> custom_opsets,
Version version = Version::UNSPECIFIED);
Serialize(std::ostream& xmlFile, std::ostream& binFile, Version version = Version::UNSPECIFIED);
OPENVINO_DEPRECATED("This constructor is deprecated. Please use new extension API")
Serialize(const std::string& xmlPath,
const std::string& binPath,
std::map<std::string, ngraph::OpSet> custom_opsets,
@ -74,10 +76,14 @@ public:
bool run_on_function(std::shared_ptr<ov::Function> f) override;
OPENVINO_DEPRECATED("This constructor is deprecated. Please use new extension API")
StreamSerialize(std::ostream& stream,
std::map<std::string, ngraph::OpSet>&& custom_opsets = {},
const std::function<void(std::ostream&)>& custom_data_serializer = {},
Serialize::Version version = Serialize::Version::UNSPECIFIED);
StreamSerialize(std::ostream& stream,
const std::function<void(std::ostream&)>& custom_data_serializer = {},
Serialize::Version version = Serialize::Version::UNSPECIFIED);
private:
std::ostream& m_stream;

View File

@ -125,21 +125,6 @@ void ngfunction_2_ir(pugi::xml_node& node,
int64_t version,
bool deterministic);
// Some of the operators were added to wrong opsets. This is a mapping
// that allows such operators to be serialized with proper opsets.
// If new operators are discovered that have the same problem, the mapping
// needs to be updated here. The keys contain op name and version in NodeTypeInfo.
const std::unordered_map<ngraph::Node::type_info_t, std::string> special_operator_to_opset_assignments = {
{ngraph::Node::type_info_t("ShuffleChannels", 0), "opset3"}};
std::string get_special_opset_for_op(const ngraph::Node::type_info_t& type_info) {
auto found = special_operator_to_opset_assignments.find(type_info);
if (found != end(special_operator_to_opset_assignments)) {
return found->second;
}
return "";
}
namespace rt_info {
const std::vector<std::string> list_of_names{
"PrimitivesPriority",
@ -574,6 +559,10 @@ const std::vector<Edge> create_edge_mapping(const std::unordered_map<ngraph::Nod
}
std::string get_opset_name(const ngraph::Node* n, const std::map<std::string, ngraph::OpSet>& custom_opsets) {
OPENVINO_ASSERT(n != nullptr);
if (n->get_type_info().version_id != nullptr) {
return n->get_type_info().version_id;
}
// Try to find opset name from RT info
auto opset_it = n->get_rt_info().find("opset");
if (opset_it != n->get_rt_info().end()) {
@ -585,26 +574,6 @@ std::string get_opset_name(const ngraph::Node* n, const std::map<std::string, ng
}
}
auto opsets = std::array<std::reference_wrapper<const ngraph::OpSet>, 8>{ngraph::get_opset1(),
ngraph::get_opset2(),
ngraph::get_opset3(),
ngraph::get_opset4(),
ngraph::get_opset5(),
ngraph::get_opset6(),
ngraph::get_opset7(),
ngraph::get_opset8()};
auto special_opset = get_special_opset_for_op(n->get_type_info());
if (!special_opset.empty()) {
return special_opset;
}
// return the oldest opset name where node type is present
for (size_t idx = 0; idx < opsets.size(); idx++) {
if (opsets[idx].get().contains_op_type(n)) {
return "opset" + std::to_string(idx + 1);
}
}
for (const auto& custom_opset : custom_opsets) {
std::string name = custom_opset.first;
ngraph::OpSet opset = custom_opset.second;
@ -1137,6 +1106,7 @@ bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
return false;
}
OPENVINO_SUPPRESS_DEPRECATED_START
pass::Serialize::Serialize(std::ostream& xmlFile,
std::ostream& binFile,
std::map<std::string, ngraph::OpSet> custom_opsets,
@ -1147,6 +1117,7 @@ pass::Serialize::Serialize(std::ostream& xmlFile,
m_binPath{},
m_version{version},
m_custom_opsets{custom_opsets} {}
pass::Serialize::Serialize(std::ostream& xmlFile, std::ostream& binFile, pass::Serialize::Version version)
: pass::Serialize::Serialize(xmlFile, binFile, std::map<std::string, ngraph::OpSet>{}, version) {}
@ -1162,7 +1133,9 @@ pass::Serialize::Serialize(const std::string& xmlPath,
pass::Serialize::Serialize(const std::string& xmlPath, const std::string& binPath, pass::Serialize::Version version)
: pass::Serialize::Serialize(xmlPath, binPath, std::map<std::string, ngraph::OpSet>{}, version) {}
OPENVINO_SUPPRESS_DEPRECATED_END
OPENVINO_SUPPRESS_DEPRECATED_START
pass::StreamSerialize::StreamSerialize(std::ostream& stream,
std::map<std::string, ngraph::OpSet>&& custom_opsets,
const std::function<void(std::ostream&)>& custom_data_serializer,
@ -1177,6 +1150,12 @@ pass::StreamSerialize::StreamSerialize(std::ostream& stream,
}
}
pass::StreamSerialize::StreamSerialize(std::ostream& stream,
const std::function<void(std::ostream&)>& custom_data_serializer,
Serialize::Version version)
: StreamSerialize(stream, {}, custom_data_serializer, version) {}
OPENVINO_SUPPRESS_DEPRECATED_END
bool pass::StreamSerialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
/*
Format: