serialize to memory (#6485)
This commit is contained in:
parent
de8c57e034
commit
98148539b3
@ -199,6 +199,22 @@ public:
|
|||||||
*/
|
*/
|
||||||
void serialize(const std::string& xmlPath, const std::string& binPath = {}) const;
|
void serialize(const std::string& xmlPath, const std::string& binPath = {}) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Serialize network to IR and weights streams.
|
||||||
|
*
|
||||||
|
* @param xmlBuf output IR stream.
|
||||||
|
* @param binBuf output weights stream.
|
||||||
|
*/
|
||||||
|
void serialize(std::ostream& xmlBuf, std::ostream& binBuf) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Serialize network to IR stream and weights Blob::Ptr.
|
||||||
|
*
|
||||||
|
* @param xmlBuf output IR stream.
|
||||||
|
* @param binBlob output weights Blob::Ptr.
|
||||||
|
*/
|
||||||
|
void serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Method maps framework tensor name to OpenVINO name
|
* @brief Method maps framework tensor name to OpenVINO name
|
||||||
* @param orig_name Framework tensor name
|
* @param orig_name Framework tensor name
|
||||||
|
@ -200,6 +200,14 @@ public:
|
|||||||
virtual StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
|
virtual StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
|
||||||
noexcept = 0;
|
noexcept = 0;
|
||||||
|
|
||||||
|
INFERENCE_ENGINE_DEPRECATED("Use InferenceEngine::CNNNetwork wrapper instead")
|
||||||
|
virtual StatusCode serialize(std::ostream& xmlFile, std::ostream& binFile, ResponseDesc* resp) const
|
||||||
|
noexcept = 0;
|
||||||
|
|
||||||
|
INFERENCE_ENGINE_DEPRECATED("Use InferenceEngine::CNNNetwork wrapper instead")
|
||||||
|
virtual StatusCode serialize(std::ostream& xmlPath, Blob::Ptr& binPath, ResponseDesc* resp) const
|
||||||
|
noexcept = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @deprecated Use InferenceEngine::CNNNetwork wrapper instead
|
* @deprecated Use InferenceEngine::CNNNetwork wrapper instead
|
||||||
* @brief Methods maps framework tensor name to OpenVINO name
|
* @brief Methods maps framework tensor name to OpenVINO name
|
||||||
|
@ -8,6 +8,10 @@
|
|||||||
#include <ie_common.h>
|
#include <ie_common.h>
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
|
|
||||||
|
#include <ie_memcpy.h>
|
||||||
|
#include <blob_factory.hpp>
|
||||||
|
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
@ -476,6 +480,64 @@ StatusCode CNNNetworkNGraphImpl::serialize(const std::string& xmlPath,
|
|||||||
return OK;
|
return OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusCode CNNNetworkNGraphImpl::serialize(std::ostream& xmlBuf,
|
||||||
|
std::ostream& binBuf,
|
||||||
|
ResponseDesc* resp) const noexcept {
|
||||||
|
try {
|
||||||
|
std::map<std::string, ngraph::OpSet> custom_opsets;
|
||||||
|
for (const auto& extension : _ie_extensions) {
|
||||||
|
auto opset = extension->getOpSets();
|
||||||
|
custom_opsets.insert(begin(opset), end(opset));
|
||||||
|
}
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::Serialize>(
|
||||||
|
xmlBuf, binBuf, ngraph::pass::Serialize::Version::IR_V10,
|
||||||
|
custom_opsets);
|
||||||
|
manager.run_passes(_ngraph_function);
|
||||||
|
} catch (const Exception& e) {
|
||||||
|
return DescriptionBuffer(GENERAL_ERROR, resp) << e.what();
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
return DescriptionBuffer(UNEXPECTED, resp) << e.what();
|
||||||
|
} catch (...) {
|
||||||
|
return DescriptionBuffer(UNEXPECTED, resp);
|
||||||
|
}
|
||||||
|
return OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusCode CNNNetworkNGraphImpl::serialize(std::ostream& xmlBuf,
|
||||||
|
Blob::Ptr& binBlob,
|
||||||
|
ResponseDesc* resp) const noexcept {
|
||||||
|
try {
|
||||||
|
std::map<std::string, ngraph::OpSet> custom_opsets;
|
||||||
|
for (const auto& extension : _ie_extensions) {
|
||||||
|
auto opset = extension->getOpSets();
|
||||||
|
custom_opsets.insert(begin(opset), end(opset));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::stringstream binBuf;
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::Serialize>(
|
||||||
|
xmlBuf, binBuf, ngraph::pass::Serialize::Version::IR_V10,
|
||||||
|
custom_opsets);
|
||||||
|
manager.run_passes(_ngraph_function);
|
||||||
|
|
||||||
|
std::streambuf* pbuf = binBuf.rdbuf();
|
||||||
|
unsigned long bufSize = binBuf.tellp();
|
||||||
|
|
||||||
|
TensorDesc tensorDesc(Precision::U8, { bufSize }, Layout::C);
|
||||||
|
binBlob = make_shared_blob<uint8_t>(tensorDesc);
|
||||||
|
binBlob->allocate();
|
||||||
|
pbuf->sgetn(binBlob->buffer(), bufSize);
|
||||||
|
} catch (const Exception& e) {
|
||||||
|
return DescriptionBuffer(GENERAL_ERROR, resp) << e.what();
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
return DescriptionBuffer(UNEXPECTED, resp) << e.what();
|
||||||
|
} catch (...) {
|
||||||
|
return DescriptionBuffer(UNEXPECTED, resp);
|
||||||
|
}
|
||||||
|
return OK;
|
||||||
|
}
|
||||||
|
|
||||||
StatusCode CNNNetworkNGraphImpl::getOVNameForTensor(std::string& ov_name, const std::string& orig_name, ResponseDesc* resp) const noexcept {
|
StatusCode CNNNetworkNGraphImpl::getOVNameForTensor(std::string& ov_name, const std::string& orig_name, ResponseDesc* resp) const noexcept {
|
||||||
if (_tensorNames.find(orig_name) == _tensorNames.end())
|
if (_tensorNames.find(orig_name) == _tensorNames.end())
|
||||||
return DescriptionBuffer(NOT_FOUND, resp) << "Framework tensor with name \"" << orig_name << "\" was not mapped to OpenVINO data!";
|
return DescriptionBuffer(NOT_FOUND, resp) << "Framework tensor with name \"" << orig_name << "\" was not mapped to OpenVINO data!";
|
||||||
|
@ -79,6 +79,12 @@ public:
|
|||||||
StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
|
StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
|
||||||
noexcept override;
|
noexcept override;
|
||||||
|
|
||||||
|
StatusCode serialize(std::ostream& xmlBuf, std::ostream& binBuf, ResponseDesc* resp) const
|
||||||
|
noexcept override;
|
||||||
|
|
||||||
|
StatusCode serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob, ResponseDesc* resp) const
|
||||||
|
noexcept override;
|
||||||
|
|
||||||
StatusCode getOVNameForTensor(std::string& ov_name, const std::string& orig_name, ResponseDesc* resp) const noexcept override;
|
StatusCode getOVNameForTensor(std::string& ov_name, const std::string& orig_name, ResponseDesc* resp) const noexcept override;
|
||||||
|
|
||||||
// used by convertFunctionToICNNNetwork from legacy library
|
// used by convertFunctionToICNNNetwork from legacy library
|
||||||
|
@ -124,6 +124,14 @@ void CNNNetwork::serialize(const std::string& xmlPath, const std::string& binPat
|
|||||||
CALL_STATUS_FNC(serialize, xmlPath, binPath);
|
CALL_STATUS_FNC(serialize, xmlPath, binPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CNNNetwork::serialize(std::ostream& xmlBuf, std::ostream& binBuf) const {
|
||||||
|
CALL_STATUS_FNC(serialize, xmlBuf, binBuf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CNNNetwork::serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob) const {
|
||||||
|
CALL_STATUS_FNC(serialize, xmlBuf, binBlob);
|
||||||
|
}
|
||||||
|
|
||||||
std::string CNNNetwork::getOVNameForTensor(const std::string& orig_name) const {
|
std::string CNNNetwork::getOVNameForTensor(const std::string& orig_name) const {
|
||||||
std::string ov_name;
|
std::string ov_name;
|
||||||
CALL_STATUS_FNC(getOVNameForTensor, ov_name, orig_name);
|
CALL_STATUS_FNC(getOVNameForTensor, ov_name, orig_name);
|
||||||
|
@ -122,6 +122,12 @@ public:
|
|||||||
StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
|
StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
|
||||||
noexcept override;
|
noexcept override;
|
||||||
|
|
||||||
|
StatusCode serialize(std::ostream& xmlBuf, std::ostream& binBuf, ResponseDesc* resp) const
|
||||||
|
noexcept override;
|
||||||
|
|
||||||
|
StatusCode serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob, ResponseDesc* resp) const
|
||||||
|
noexcept override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::map<std::string, DataPtr> _data;
|
std::map<std::string, DataPtr> _data;
|
||||||
std::map<std::string, CNNLayerPtr> _layers;
|
std::map<std::string, CNNLayerPtr> _layers;
|
||||||
|
@ -408,6 +408,17 @@ StatusCode CNNNetworkImpl::serialize(const std::string& xmlPath, const std::stri
|
|||||||
return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
|
return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
StatusCode CNNNetworkImpl::serialize(std::ostream& xmlBuf, std::ostream& binBuf, ResponseDesc* resp) const
|
||||||
|
noexcept {
|
||||||
|
return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusCode CNNNetworkImpl::serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob, ResponseDesc* resp) const
|
||||||
|
noexcept {
|
||||||
|
return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
|
||||||
|
}
|
||||||
|
|
||||||
StatusCode CNNNetworkImpl::setBatchSize(size_t size, ResponseDesc* responseDesc) noexcept {
|
StatusCode CNNNetworkImpl::setBatchSize(size_t size, ResponseDesc* responseDesc) noexcept {
|
||||||
try {
|
try {
|
||||||
auto originalBatchSize = getBatchSize();
|
auto originalBatchSize = getBatchSize();
|
||||||
|
@ -124,3 +124,49 @@ TEST_F(SerializationDeterministicityTest, ModelWithConstants) {
|
|||||||
ASSERT_TRUE(files_equal(xml_1, xml_2));
|
ASSERT_TRUE(files_equal(xml_1, xml_2));
|
||||||
ASSERT_TRUE(files_equal(bin_1, bin_2));
|
ASSERT_TRUE(files_equal(bin_1, bin_2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializationDeterministicityTest, SerializeToStream) {
|
||||||
|
const std::string model =
|
||||||
|
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.xml";
|
||||||
|
const std::string weights =
|
||||||
|
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.bin";
|
||||||
|
|
||||||
|
std::stringstream m_out_xml_buf, m_out_bin_buf;
|
||||||
|
InferenceEngine::Blob::Ptr binBlob;
|
||||||
|
|
||||||
|
InferenceEngine::Core ie;
|
||||||
|
auto expected = ie.ReadNetwork(model, weights);
|
||||||
|
expected.serialize(m_out_xml_buf, m_out_bin_buf);
|
||||||
|
|
||||||
|
std::streambuf* pbuf = m_out_bin_buf.rdbuf();
|
||||||
|
unsigned long bufSize = m_out_bin_buf.tellp();
|
||||||
|
|
||||||
|
InferenceEngine::TensorDesc tensorDesc(InferenceEngine::Precision::U8,
|
||||||
|
{ bufSize }, InferenceEngine::Layout::C);
|
||||||
|
binBlob = InferenceEngine::make_shared_blob<uint8_t>(tensorDesc);
|
||||||
|
binBlob->allocate();
|
||||||
|
pbuf->sgetn(binBlob->buffer(), bufSize);
|
||||||
|
|
||||||
|
auto result = ie.ReadNetwork(m_out_xml_buf.str(), binBlob);
|
||||||
|
|
||||||
|
ASSERT_TRUE(expected.layerCount() == result.layerCount());
|
||||||
|
ASSERT_TRUE(expected.getInputShapes() == result.getInputShapes());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SerializationDeterministicityTest, SerializeToBlob) {
|
||||||
|
const std::string model =
|
||||||
|
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.xml";
|
||||||
|
const std::string weights =
|
||||||
|
IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.bin";
|
||||||
|
|
||||||
|
std::stringstream m_out_xml_buf;
|
||||||
|
InferenceEngine::Blob::Ptr m_out_bin_buf;
|
||||||
|
|
||||||
|
InferenceEngine::Core ie;
|
||||||
|
auto expected = ie.ReadNetwork(model, weights);
|
||||||
|
expected.serialize(m_out_xml_buf, m_out_bin_buf);
|
||||||
|
auto result = ie.ReadNetwork(m_out_xml_buf.str(), m_out_bin_buf);
|
||||||
|
|
||||||
|
ASSERT_TRUE(expected.layerCount() == result.layerCount());
|
||||||
|
ASSERT_TRUE(expected.getInputShapes() == result.getInputShapes());
|
||||||
|
}
|
||||||
|
@ -48,4 +48,8 @@ class MockICNNNetwork final : public InferenceEngine::ICNNNetwork {
|
|||||||
(const, noexcept));
|
(const, noexcept));
|
||||||
MOCK_METHOD(InferenceEngine::StatusCode, serialize,
|
MOCK_METHOD(InferenceEngine::StatusCode, serialize,
|
||||||
(const std::string &, const std::string &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
(const std::string &, const std::string &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
||||||
|
MOCK_METHOD(InferenceEngine::StatusCode, serialize,
|
||||||
|
(std::ostream &, std::ostream &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
||||||
|
MOCK_METHOD(InferenceEngine::StatusCode, serialize,
|
||||||
|
(std::ostream &, InferenceEngine::Blob::Ptr &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
||||||
};
|
};
|
||||||
|
@ -43,6 +43,10 @@ public:
|
|||||||
MOCK_METHOD(StatusCode, reshape, (const ICNNNetwork::InputShapes &, ResponseDesc *), (noexcept));
|
MOCK_METHOD(StatusCode, reshape, (const ICNNNetwork::InputShapes &, ResponseDesc *), (noexcept));
|
||||||
MOCK_METHOD(StatusCode, serialize,
|
MOCK_METHOD(StatusCode, serialize,
|
||||||
(const std::string &, const std::string &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
(const std::string &, const std::string &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
||||||
|
MOCK_METHOD(StatusCode, serialize,
|
||||||
|
(std::ostream &, std::ostream &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
||||||
|
MOCK_METHOD(StatusCode, serialize,
|
||||||
|
(std::ostream &, Blob::Ptr &, InferenceEngine::ResponseDesc*), (const, noexcept));
|
||||||
};
|
};
|
||||||
|
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
IE_SUPPRESS_DEPRECATED_END
|
||||||
|
Loading…
Reference in New Issue
Block a user