serialize to memory (#6485)

This commit is contained in:
Anna Likholat 2021-07-09 09:52:40 +03:00 committed by GitHub
parent de8c57e034
commit 98148539b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 171 additions and 0 deletions

View File

@ -199,6 +199,22 @@ public:
*/ */
void serialize(const std::string& xmlPath, const std::string& binPath = {}) const; void serialize(const std::string& xmlPath, const std::string& binPath = {}) const;
/**
* @brief Serialize network to IR and weights streams.
*
* @param xmlBuf output IR stream.
* @param binBuf output weights stream.
*/
void serialize(std::ostream& xmlBuf, std::ostream& binBuf) const;
/**
* @brief Serialize network to IR stream and weights Blob::Ptr.
*
* @param xmlBuf output IR stream.
* @param binBlob output weights Blob::Ptr.
*/
void serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob) const;
/** /**
* @brief Method maps framework tensor name to OpenVINO name * @brief Method maps framework tensor name to OpenVINO name
* @param orig_name Framework tensor name * @param orig_name Framework tensor name

View File

@ -200,6 +200,14 @@ public:
virtual StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const virtual StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
noexcept = 0; noexcept = 0;
INFERENCE_ENGINE_DEPRECATED("Use InferenceEngine::CNNNetwork wrapper instead")
virtual StatusCode serialize(std::ostream& xmlFile, std::ostream& binFile, ResponseDesc* resp) const
noexcept = 0;
INFERENCE_ENGINE_DEPRECATED("Use InferenceEngine::CNNNetwork wrapper instead")
virtual StatusCode serialize(std::ostream& xmlPath, Blob::Ptr& binPath, ResponseDesc* resp) const
noexcept = 0;
/** /**
* @deprecated Use InferenceEngine::CNNNetwork wrapper instead * @deprecated Use InferenceEngine::CNNNetwork wrapper instead
* @brief Methods maps framework tensor name to OpenVINO name * @brief Methods maps framework tensor name to OpenVINO name

View File

@ -8,6 +8,10 @@
#include <ie_common.h> #include <ie_common.h>
#include <math.h> #include <math.h>
#include <ie_memcpy.h>
#include <blob_factory.hpp>
#include <cassert> #include <cassert>
#include <map> #include <map>
#include <memory> #include <memory>
@ -476,6 +480,64 @@ StatusCode CNNNetworkNGraphImpl::serialize(const std::string& xmlPath,
return OK; return OK;
} }
StatusCode CNNNetworkNGraphImpl::serialize(std::ostream& xmlBuf,
std::ostream& binBuf,
ResponseDesc* resp) const noexcept {
try {
std::map<std::string, ngraph::OpSet> custom_opsets;
for (const auto& extension : _ie_extensions) {
auto opset = extension->getOpSets();
custom_opsets.insert(begin(opset), end(opset));
}
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::Serialize>(
xmlBuf, binBuf, ngraph::pass::Serialize::Version::IR_V10,
custom_opsets);
manager.run_passes(_ngraph_function);
} catch (const Exception& e) {
return DescriptionBuffer(GENERAL_ERROR, resp) << e.what();
} catch (const std::exception& e) {
return DescriptionBuffer(UNEXPECTED, resp) << e.what();
} catch (...) {
return DescriptionBuffer(UNEXPECTED, resp);
}
return OK;
}
StatusCode CNNNetworkNGraphImpl::serialize(std::ostream& xmlBuf,
Blob::Ptr& binBlob,
ResponseDesc* resp) const noexcept {
try {
std::map<std::string, ngraph::OpSet> custom_opsets;
for (const auto& extension : _ie_extensions) {
auto opset = extension->getOpSets();
custom_opsets.insert(begin(opset), end(opset));
}
std::stringstream binBuf;
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::Serialize>(
xmlBuf, binBuf, ngraph::pass::Serialize::Version::IR_V10,
custom_opsets);
manager.run_passes(_ngraph_function);
std::streambuf* pbuf = binBuf.rdbuf();
unsigned long bufSize = binBuf.tellp();
TensorDesc tensorDesc(Precision::U8, { bufSize }, Layout::C);
binBlob = make_shared_blob<uint8_t>(tensorDesc);
binBlob->allocate();
pbuf->sgetn(binBlob->buffer(), bufSize);
} catch (const Exception& e) {
return DescriptionBuffer(GENERAL_ERROR, resp) << e.what();
} catch (const std::exception& e) {
return DescriptionBuffer(UNEXPECTED, resp) << e.what();
} catch (...) {
return DescriptionBuffer(UNEXPECTED, resp);
}
return OK;
}
StatusCode CNNNetworkNGraphImpl::getOVNameForTensor(std::string& ov_name, const std::string& orig_name, ResponseDesc* resp) const noexcept { StatusCode CNNNetworkNGraphImpl::getOVNameForTensor(std::string& ov_name, const std::string& orig_name, ResponseDesc* resp) const noexcept {
if (_tensorNames.find(orig_name) == _tensorNames.end()) if (_tensorNames.find(orig_name) == _tensorNames.end())
return DescriptionBuffer(NOT_FOUND, resp) << "Framework tensor with name \"" << orig_name << "\" was not mapped to OpenVINO data!"; return DescriptionBuffer(NOT_FOUND, resp) << "Framework tensor with name \"" << orig_name << "\" was not mapped to OpenVINO data!";

View File

@ -79,6 +79,12 @@ public:
StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
noexcept override; noexcept override;
StatusCode serialize(std::ostream& xmlBuf, std::ostream& binBuf, ResponseDesc* resp) const
noexcept override;
StatusCode serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob, ResponseDesc* resp) const
noexcept override;
StatusCode getOVNameForTensor(std::string& ov_name, const std::string& orig_name, ResponseDesc* resp) const noexcept override; StatusCode getOVNameForTensor(std::string& ov_name, const std::string& orig_name, ResponseDesc* resp) const noexcept override;
// used by convertFunctionToICNNNetwork from legacy library // used by convertFunctionToICNNNetwork from legacy library

View File

@ -124,6 +124,14 @@ void CNNNetwork::serialize(const std::string& xmlPath, const std::string& binPat
CALL_STATUS_FNC(serialize, xmlPath, binPath); CALL_STATUS_FNC(serialize, xmlPath, binPath);
} }
void CNNNetwork::serialize(std::ostream& xmlBuf, std::ostream& binBuf) const {
CALL_STATUS_FNC(serialize, xmlBuf, binBuf);
}
void CNNNetwork::serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob) const {
CALL_STATUS_FNC(serialize, xmlBuf, binBlob);
}
std::string CNNNetwork::getOVNameForTensor(const std::string& orig_name) const { std::string CNNNetwork::getOVNameForTensor(const std::string& orig_name) const {
std::string ov_name; std::string ov_name;
CALL_STATUS_FNC(getOVNameForTensor, ov_name, orig_name); CALL_STATUS_FNC(getOVNameForTensor, ov_name, orig_name);

View File

@ -122,6 +122,12 @@ public:
StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
noexcept override; noexcept override;
StatusCode serialize(std::ostream& xmlBuf, std::ostream& binBuf, ResponseDesc* resp) const
noexcept override;
StatusCode serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob, ResponseDesc* resp) const
noexcept override;
protected: protected:
std::map<std::string, DataPtr> _data; std::map<std::string, DataPtr> _data;
std::map<std::string, CNNLayerPtr> _layers; std::map<std::string, CNNLayerPtr> _layers;

View File

@ -408,6 +408,17 @@ StatusCode CNNNetworkImpl::serialize(const std::string& xmlPath, const std::stri
return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented"; return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
} }
StatusCode CNNNetworkImpl::serialize(std::ostream& xmlBuf, std::ostream& binBuf, ResponseDesc* resp) const
noexcept {
return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
}
StatusCode CNNNetworkImpl::serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob, ResponseDesc* resp) const
noexcept {
return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
}
StatusCode CNNNetworkImpl::setBatchSize(size_t size, ResponseDesc* responseDesc) noexcept { StatusCode CNNNetworkImpl::setBatchSize(size_t size, ResponseDesc* responseDesc) noexcept {
try { try {
auto originalBatchSize = getBatchSize(); auto originalBatchSize = getBatchSize();

View File

@ -124,3 +124,49 @@ TEST_F(SerializationDeterministicityTest, ModelWithConstants) {
ASSERT_TRUE(files_equal(xml_1, xml_2)); ASSERT_TRUE(files_equal(xml_1, xml_2));
ASSERT_TRUE(files_equal(bin_1, bin_2)); ASSERT_TRUE(files_equal(bin_1, bin_2));
} }
TEST_F(SerializationDeterministicityTest, SerializeToStream) {
const std::string model =
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.xml";
const std::string weights =
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.bin";
std::stringstream m_out_xml_buf, m_out_bin_buf;
InferenceEngine::Blob::Ptr binBlob;
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model, weights);
expected.serialize(m_out_xml_buf, m_out_bin_buf);
std::streambuf* pbuf = m_out_bin_buf.rdbuf();
unsigned long bufSize = m_out_bin_buf.tellp();
InferenceEngine::TensorDesc tensorDesc(InferenceEngine::Precision::U8,
{ bufSize }, InferenceEngine::Layout::C);
binBlob = InferenceEngine::make_shared_blob<uint8_t>(tensorDesc);
binBlob->allocate();
pbuf->sgetn(binBlob->buffer(), bufSize);
auto result = ie.ReadNetwork(m_out_xml_buf.str(), binBlob);
ASSERT_TRUE(expected.layerCount() == result.layerCount());
ASSERT_TRUE(expected.getInputShapes() == result.getInputShapes());
}
TEST_F(SerializationDeterministicityTest, SerializeToBlob) {
const std::string model =
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.xml";
const std::string weights =
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.bin";
std::stringstream m_out_xml_buf;
InferenceEngine::Blob::Ptr m_out_bin_buf;
InferenceEngine::Core ie;
auto expected = ie.ReadNetwork(model, weights);
expected.serialize(m_out_xml_buf, m_out_bin_buf);
auto result = ie.ReadNetwork(m_out_xml_buf.str(), m_out_bin_buf);
ASSERT_TRUE(expected.layerCount() == result.layerCount());
ASSERT_TRUE(expected.getInputShapes() == result.getInputShapes());
}

View File

@ -48,4 +48,8 @@ class MockICNNNetwork final : public InferenceEngine::ICNNNetwork {
(const, noexcept)); (const, noexcept));
MOCK_METHOD(InferenceEngine::StatusCode, serialize, MOCK_METHOD(InferenceEngine::StatusCode, serialize,
(const std::string &, const std::string &, InferenceEngine::ResponseDesc*), (const, noexcept)); (const std::string &, const std::string &, InferenceEngine::ResponseDesc*), (const, noexcept));
MOCK_METHOD(InferenceEngine::StatusCode, serialize,
(std::ostream &, std::ostream &, InferenceEngine::ResponseDesc*), (const, noexcept));
MOCK_METHOD(InferenceEngine::StatusCode, serialize,
(std::ostream &, InferenceEngine::Blob::Ptr &, InferenceEngine::ResponseDesc*), (const, noexcept));
}; };

View File

@ -43,6 +43,10 @@ public:
MOCK_METHOD(StatusCode, reshape, (const ICNNNetwork::InputShapes &, ResponseDesc *), (noexcept)); MOCK_METHOD(StatusCode, reshape, (const ICNNNetwork::InputShapes &, ResponseDesc *), (noexcept));
MOCK_METHOD(StatusCode, serialize, MOCK_METHOD(StatusCode, serialize,
(const std::string &, const std::string &, InferenceEngine::ResponseDesc*), (const, noexcept)); (const std::string &, const std::string &, InferenceEngine::ResponseDesc*), (const, noexcept));
MOCK_METHOD(StatusCode, serialize,
(std::ostream &, std::ostream &, InferenceEngine::ResponseDesc*), (const, noexcept));
MOCK_METHOD(StatusCode, serialize,
(std::ostream &, Blob::Ptr &, InferenceEngine::ResponseDesc*), (const, noexcept));
}; };
IE_SUPPRESS_DEPRECATED_END IE_SUPPRESS_DEPRECATED_END