serialize to memory (#6485)
This commit is contained in:
parent
de8c57e034
commit
98148539b3
@ -199,6 +199,22 @@ public:
|
||||
*/
|
||||
void serialize(const std::string& xmlPath, const std::string& binPath = {}) const;
|
||||
|
||||
/**
|
||||
* @brief Serialize network to IR and weights streams.
|
||||
*
|
||||
* @param xmlBuf output IR stream.
|
||||
* @param binBuf output weights stream.
|
||||
*/
|
||||
void serialize(std::ostream& xmlBuf, std::ostream& binBuf) const;
|
||||
|
||||
/**
|
||||
* @brief Serialize network to IR stream and weights Blob::Ptr.
|
||||
*
|
||||
* @param xmlBuf output IR stream.
|
||||
* @param binBlob output weights Blob::Ptr.
|
||||
*/
|
||||
void serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob) const;
|
||||
|
||||
/**
|
||||
* @brief Method maps framework tensor name to OpenVINO name
|
||||
* @param orig_name Framework tensor name
|
||||
|
@ -200,6 +200,14 @@ public:
|
||||
virtual StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
|
||||
noexcept = 0;
|
||||
|
||||
INFERENCE_ENGINE_DEPRECATED("Use InferenceEngine::CNNNetwork wrapper instead")
|
||||
virtual StatusCode serialize(std::ostream& xmlFile, std::ostream& binFile, ResponseDesc* resp) const
|
||||
noexcept = 0;
|
||||
|
||||
INFERENCE_ENGINE_DEPRECATED("Use InferenceEngine::CNNNetwork wrapper instead")
|
||||
virtual StatusCode serialize(std::ostream& xmlPath, Blob::Ptr& binPath, ResponseDesc* resp) const
|
||||
noexcept = 0;
|
||||
|
||||
/**
|
||||
* @deprecated Use InferenceEngine::CNNNetwork wrapper instead
|
||||
* @brief Methods maps framework tensor name to OpenVINO name
|
||||
|
@ -8,6 +8,10 @@
|
||||
#include <ie_common.h>
|
||||
#include <math.h>
|
||||
|
||||
#include <ie_memcpy.h>
|
||||
#include <blob_factory.hpp>
|
||||
|
||||
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
@ -476,6 +480,64 @@ StatusCode CNNNetworkNGraphImpl::serialize(const std::string& xmlPath,
|
||||
return OK;
|
||||
}
|
||||
|
||||
StatusCode CNNNetworkNGraphImpl::serialize(std::ostream& xmlBuf,
|
||||
std::ostream& binBuf,
|
||||
ResponseDesc* resp) const noexcept {
|
||||
try {
|
||||
std::map<std::string, ngraph::OpSet> custom_opsets;
|
||||
for (const auto& extension : _ie_extensions) {
|
||||
auto opset = extension->getOpSets();
|
||||
custom_opsets.insert(begin(opset), end(opset));
|
||||
}
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::Serialize>(
|
||||
xmlBuf, binBuf, ngraph::pass::Serialize::Version::IR_V10,
|
||||
custom_opsets);
|
||||
manager.run_passes(_ngraph_function);
|
||||
} catch (const Exception& e) {
|
||||
return DescriptionBuffer(GENERAL_ERROR, resp) << e.what();
|
||||
} catch (const std::exception& e) {
|
||||
return DescriptionBuffer(UNEXPECTED, resp) << e.what();
|
||||
} catch (...) {
|
||||
return DescriptionBuffer(UNEXPECTED, resp);
|
||||
}
|
||||
return OK;
|
||||
}
|
||||
|
||||
StatusCode CNNNetworkNGraphImpl::serialize(std::ostream& xmlBuf,
|
||||
Blob::Ptr& binBlob,
|
||||
ResponseDesc* resp) const noexcept {
|
||||
try {
|
||||
std::map<std::string, ngraph::OpSet> custom_opsets;
|
||||
for (const auto& extension : _ie_extensions) {
|
||||
auto opset = extension->getOpSets();
|
||||
custom_opsets.insert(begin(opset), end(opset));
|
||||
}
|
||||
|
||||
std::stringstream binBuf;
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::Serialize>(
|
||||
xmlBuf, binBuf, ngraph::pass::Serialize::Version::IR_V10,
|
||||
custom_opsets);
|
||||
manager.run_passes(_ngraph_function);
|
||||
|
||||
std::streambuf* pbuf = binBuf.rdbuf();
|
||||
unsigned long bufSize = binBuf.tellp();
|
||||
|
||||
TensorDesc tensorDesc(Precision::U8, { bufSize }, Layout::C);
|
||||
binBlob = make_shared_blob<uint8_t>(tensorDesc);
|
||||
binBlob->allocate();
|
||||
pbuf->sgetn(binBlob->buffer(), bufSize);
|
||||
} catch (const Exception& e) {
|
||||
return DescriptionBuffer(GENERAL_ERROR, resp) << e.what();
|
||||
} catch (const std::exception& e) {
|
||||
return DescriptionBuffer(UNEXPECTED, resp) << e.what();
|
||||
} catch (...) {
|
||||
return DescriptionBuffer(UNEXPECTED, resp);
|
||||
}
|
||||
return OK;
|
||||
}
|
||||
|
||||
StatusCode CNNNetworkNGraphImpl::getOVNameForTensor(std::string& ov_name, const std::string& orig_name, ResponseDesc* resp) const noexcept {
|
||||
if (_tensorNames.find(orig_name) == _tensorNames.end())
|
||||
return DescriptionBuffer(NOT_FOUND, resp) << "Framework tensor with name \"" << orig_name << "\" was not mapped to OpenVINO data!";
|
||||
|
@ -79,6 +79,12 @@ public:
|
||||
StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
|
||||
noexcept override;
|
||||
|
||||
StatusCode serialize(std::ostream& xmlBuf, std::ostream& binBuf, ResponseDesc* resp) const
|
||||
noexcept override;
|
||||
|
||||
StatusCode serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob, ResponseDesc* resp) const
|
||||
noexcept override;
|
||||
|
||||
StatusCode getOVNameForTensor(std::string& ov_name, const std::string& orig_name, ResponseDesc* resp) const noexcept override;
|
||||
|
||||
// used by convertFunctionToICNNNetwork from legacy library
|
||||
|
@ -124,6 +124,14 @@ void CNNNetwork::serialize(const std::string& xmlPath, const std::string& binPat
|
||||
CALL_STATUS_FNC(serialize, xmlPath, binPath);
|
||||
}
|
||||
|
||||
void CNNNetwork::serialize(std::ostream& xmlBuf, std::ostream& binBuf) const {
|
||||
CALL_STATUS_FNC(serialize, xmlBuf, binBuf);
|
||||
}
|
||||
|
||||
void CNNNetwork::serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob) const {
|
||||
CALL_STATUS_FNC(serialize, xmlBuf, binBlob);
|
||||
}
|
||||
|
||||
std::string CNNNetwork::getOVNameForTensor(const std::string& orig_name) const {
|
||||
std::string ov_name;
|
||||
CALL_STATUS_FNC(getOVNameForTensor, ov_name, orig_name);
|
||||
|
@ -122,6 +122,12 @@ public:
|
||||
StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
|
||||
noexcept override;
|
||||
|
||||
StatusCode serialize(std::ostream& xmlBuf, std::ostream& binBuf, ResponseDesc* resp) const
|
||||
noexcept override;
|
||||
|
||||
StatusCode serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob, ResponseDesc* resp) const
|
||||
noexcept override;
|
||||
|
||||
protected:
|
||||
std::map<std::string, DataPtr> _data;
|
||||
std::map<std::string, CNNLayerPtr> _layers;
|
||||
|
@ -408,6 +408,17 @@ StatusCode CNNNetworkImpl::serialize(const std::string& xmlPath, const std::stri
|
||||
return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
|
||||
}
|
||||
|
||||
|
||||
StatusCode CNNNetworkImpl::serialize(std::ostream& xmlBuf, std::ostream& binBuf, ResponseDesc* resp) const
|
||||
noexcept {
|
||||
return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
|
||||
}
|
||||
|
||||
StatusCode CNNNetworkImpl::serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob, ResponseDesc* resp) const
|
||||
noexcept {
|
||||
return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
|
||||
}
|
||||
|
||||
StatusCode CNNNetworkImpl::setBatchSize(size_t size, ResponseDesc* responseDesc) noexcept {
|
||||
try {
|
||||
auto originalBatchSize = getBatchSize();
|
||||
|
@ -124,3 +124,49 @@ TEST_F(SerializationDeterministicityTest, ModelWithConstants) {
|
||||
ASSERT_TRUE(files_equal(xml_1, xml_2));
|
||||
ASSERT_TRUE(files_equal(bin_1, bin_2));
|
||||
}
|
||||
|
||||
TEST_F(SerializationDeterministicityTest, SerializeToStream) {
|
||||
const std::string model =
|
||||
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.xml";
|
||||
const std::string weights =
|
||||
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.bin";
|
||||
|
||||
std::stringstream m_out_xml_buf, m_out_bin_buf;
|
||||
InferenceEngine::Blob::Ptr binBlob;
|
||||
|
||||
InferenceEngine::Core ie;
|
||||
auto expected = ie.ReadNetwork(model, weights);
|
||||
expected.serialize(m_out_xml_buf, m_out_bin_buf);
|
||||
|
||||
std::streambuf* pbuf = m_out_bin_buf.rdbuf();
|
||||
unsigned long bufSize = m_out_bin_buf.tellp();
|
||||
|
||||
InferenceEngine::TensorDesc tensorDesc(InferenceEngine::Precision::U8,
|
||||
{ bufSize }, InferenceEngine::Layout::C);
|
||||
binBlob = InferenceEngine::make_shared_blob<uint8_t>(tensorDesc);
|
||||
binBlob->allocate();
|
||||
pbuf->sgetn(binBlob->buffer(), bufSize);
|
||||
|
||||
auto result = ie.ReadNetwork(m_out_xml_buf.str(), binBlob);
|
||||
|
||||
ASSERT_TRUE(expected.layerCount() == result.layerCount());
|
||||
ASSERT_TRUE(expected.getInputShapes() == result.getInputShapes());
|
||||
}
|
||||
|
||||
TEST_F(SerializationDeterministicityTest, SerializeToBlob) {
|
||||
const std::string model =
|
||||
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.xml";
|
||||
const std::string weights =
|
||||
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.bin";
|
||||
|
||||
std::stringstream m_out_xml_buf;
|
||||
InferenceEngine::Blob::Ptr m_out_bin_buf;
|
||||
|
||||
InferenceEngine::Core ie;
|
||||
auto expected = ie.ReadNetwork(model, weights);
|
||||
expected.serialize(m_out_xml_buf, m_out_bin_buf);
|
||||
auto result = ie.ReadNetwork(m_out_xml_buf.str(), m_out_bin_buf);
|
||||
|
||||
ASSERT_TRUE(expected.layerCount() == result.layerCount());
|
||||
ASSERT_TRUE(expected.getInputShapes() == result.getInputShapes());
|
||||
}
|
||||
|
@ -48,4 +48,8 @@ class MockICNNNetwork final : public InferenceEngine::ICNNNetwork {
|
||||
(const, noexcept));
|
||||
MOCK_METHOD(InferenceEngine::StatusCode, serialize,
|
||||
(const std::string &, const std::string &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
||||
MOCK_METHOD(InferenceEngine::StatusCode, serialize,
|
||||
(std::ostream &, std::ostream &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
||||
MOCK_METHOD(InferenceEngine::StatusCode, serialize,
|
||||
(std::ostream &, InferenceEngine::Blob::Ptr &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
||||
};
|
||||
|
@ -43,6 +43,10 @@ public:
|
||||
MOCK_METHOD(StatusCode, reshape, (const ICNNNetwork::InputShapes &, ResponseDesc *), (noexcept));
|
||||
MOCK_METHOD(StatusCode, serialize,
|
||||
(const std::string &, const std::string &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
||||
MOCK_METHOD(StatusCode, serialize,
|
||||
(std::ostream &, std::ostream &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
||||
MOCK_METHOD(StatusCode, serialize,
|
||||
(std::ostream &, Blob::Ptr &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
||||
};
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
Loading…
Reference in New Issue
Block a user