diff --git a/inference-engine/include/cpp/ie_cnn_network.h b/inference-engine/include/cpp/ie_cnn_network.h
index 9e21a470d45..1fe5d2173f2 100644
--- a/inference-engine/include/cpp/ie_cnn_network.h
+++ b/inference-engine/include/cpp/ie_cnn_network.h
@@ -199,6 +199,22 @@ public:
      */
     void serialize(const std::string& xmlPath, const std::string& binPath = {}) const;
 
+    /**
+     * @brief Serialize network to IR and weights streams.
+     *
+     * @param xmlBuf output IR stream.
+     * @param binBuf output weights stream.
+     */
+    void serialize(std::ostream& xmlBuf, std::ostream& binBuf) const;
+
+    /**
+     * @brief Serialize network to IR stream and weights Blob::Ptr.
+     *
+     * @param xmlBuf output IR stream.
+     * @param binBlob output weights Blob::Ptr.
+     */
+    void serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob) const;
+
     /**
      * @brief Method maps framework tensor name to OpenVINO name
      * @param orig_name Framework tensor name
diff --git a/inference-engine/include/ie_icnn_network.hpp b/inference-engine/include/ie_icnn_network.hpp
index 6e42b5ea402..55b98ba58c6 100644
--- a/inference-engine/include/ie_icnn_network.hpp
+++ b/inference-engine/include/ie_icnn_network.hpp
@@ -200,6 +200,14 @@ public:
     virtual StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
         noexcept = 0;
 
+    INFERENCE_ENGINE_DEPRECATED("Use InferenceEngine::CNNNetwork wrapper instead")
+    virtual StatusCode serialize(std::ostream& xmlFile, std::ostream& binFile, ResponseDesc* resp) const
+        noexcept = 0;
+
+    INFERENCE_ENGINE_DEPRECATED("Use InferenceEngine::CNNNetwork wrapper instead")
+    virtual StatusCode serialize(std::ostream& xmlPath, Blob::Ptr& binPath, ResponseDesc* resp) const
+        noexcept = 0;
+
     /**
      * @deprecated Use InferenceEngine::CNNNetwork wrapper instead
      * @brief Methods maps framework tensor name to OpenVINO name
diff --git a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp
index 1252c1b856a..1f05ca0098c 100644
--- a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp
+++ b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.cpp
@@ -8,6 +8,10 @@
 #include <ie_common.h>
 #include <math.h>
+#include <sstream>
+#include <streambuf>
+
+
 #include <cassert>
 #include <map>
 #include <memory>
@@ -476,6 +480,64 @@ StatusCode CNNNetworkNGraphImpl::serialize(const std::string& xmlPath,
     return OK;
 }
 
+StatusCode CNNNetworkNGraphImpl::serialize(std::ostream& xmlBuf,
+                                           std::ostream& binBuf,
+                                           ResponseDesc* resp) const noexcept {
+    try {
+        std::map<std::string, ngraph::OpSet> custom_opsets;
+        for (const auto& extension : _ie_extensions) {
+            auto opset = extension->getOpSets();
+            custom_opsets.insert(begin(opset), end(opset));
+        }
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::Serialize>(
+            xmlBuf, binBuf, ngraph::pass::Serialize::Version::IR_V10,
+            custom_opsets);
+        manager.run_passes(_ngraph_function);
+    } catch (const Exception& e) {
+        return DescriptionBuffer(GENERAL_ERROR, resp) << e.what();
+    } catch (const std::exception& e) {
+        return DescriptionBuffer(UNEXPECTED, resp) << e.what();
+    } catch (...) {
+        return DescriptionBuffer(UNEXPECTED, resp);
+    }
+    return OK;
+}
+
+StatusCode CNNNetworkNGraphImpl::serialize(std::ostream& xmlBuf,
+                                           Blob::Ptr& binBlob,
+                                           ResponseDesc* resp) const noexcept {
+    try {
+        std::map<std::string, ngraph::OpSet> custom_opsets;
+        for (const auto& extension : _ie_extensions) {
+            auto opset = extension->getOpSets();
+            custom_opsets.insert(begin(opset), end(opset));
+        }
+
+        std::stringstream binBuf;
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::Serialize>(
+            xmlBuf, binBuf, ngraph::pass::Serialize::Version::IR_V10,
+            custom_opsets);
+        manager.run_passes(_ngraph_function);
+
+        std::streambuf* pbuf = binBuf.rdbuf();
+        unsigned long bufSize = binBuf.tellp();
+
+        TensorDesc tensorDesc(Precision::U8, { bufSize }, Layout::C);
+        binBlob = make_shared_blob<uint8_t>(tensorDesc);
+        binBlob->allocate();
+        pbuf->sgetn(binBlob->buffer(), bufSize);
+    } catch (const Exception& e) {
+        return DescriptionBuffer(GENERAL_ERROR, resp) << e.what();
+    } catch (const std::exception& e) {
+        return DescriptionBuffer(UNEXPECTED, resp) << e.what();
+    } catch (...) {
+        return DescriptionBuffer(UNEXPECTED, resp);
+    }
+    return OK;
+}
+
 StatusCode CNNNetworkNGraphImpl::getOVNameForTensor(std::string& ov_name, const std::string& orig_name, ResponseDesc* resp) const noexcept {
     if (_tensorNames.find(orig_name) == _tensorNames.end())
         return DescriptionBuffer(NOT_FOUND, resp) << "Framework tensor with name \"" << orig_name << "\" was not mapped to OpenVINO data!";
diff --git a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp
index db0d1d9ab49..6fe00b8ad81 100644
--- a/inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp
+++ b/inference-engine/src/inference_engine/cnn_network_ngraph_impl.hpp
@@ -79,6 +79,12 @@ public:
     StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
         noexcept override;
 
+    StatusCode serialize(std::ostream& xmlBuf, std::ostream& binBuf, ResponseDesc* resp) const
+        noexcept override;
+
+    StatusCode serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob, ResponseDesc* resp) const
+        noexcept override;
+
     StatusCode getOVNameForTensor(std::string& ov_name, const std::string& orig_name, ResponseDesc* resp) const noexcept override;
 
     // used by convertFunctionToICNNNetwork from legacy library
diff --git a/inference-engine/src/inference_engine/cpp/ie_cnn_network.cpp b/inference-engine/src/inference_engine/cpp/ie_cnn_network.cpp
index 52ae8998038..e2506d6cdca 100644
--- a/inference-engine/src/inference_engine/cpp/ie_cnn_network.cpp
+++ b/inference-engine/src/inference_engine/cpp/ie_cnn_network.cpp
@@ -124,6 +124,14 @@ void CNNNetwork::serialize(const std::string& xmlPath, const std::string& binPat
     CALL_STATUS_FNC(serialize, xmlPath, binPath);
 }
 
+void CNNNetwork::serialize(std::ostream& xmlBuf, std::ostream& binBuf) const {
+    CALL_STATUS_FNC(serialize, xmlBuf, binBuf);
+}
+
+void CNNNetwork::serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob) const {
+    CALL_STATUS_FNC(serialize, xmlBuf, binBlob);
+}
+
 std::string CNNNetwork::getOVNameForTensor(const std::string& orig_name) const {
     std::string ov_name;
     CALL_STATUS_FNC(getOVNameForTensor, ov_name, orig_name);
diff --git a/inference-engine/src/legacy_api/include/legacy/cnn_network_impl.hpp b/inference-engine/src/legacy_api/include/legacy/cnn_network_impl.hpp
index 8151832d027..4849f4d0193 100644
--- a/inference-engine/src/legacy_api/include/legacy/cnn_network_impl.hpp
+++ b/inference-engine/src/legacy_api/include/legacy/cnn_network_impl.hpp
@@ -122,6 +122,12 @@ public:
     StatusCode serialize(const std::string& xmlPath, const std::string& binPath, ResponseDesc* resp) const
         noexcept override;
 
+    StatusCode serialize(std::ostream& xmlBuf, std::ostream& binBuf, ResponseDesc* resp) const
+        noexcept override;
+
+    StatusCode serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob, ResponseDesc* resp) const
+        noexcept override;
+
 protected:
     std::map<std::string, DataPtr> _data;
     std::map<std::string, CNNLayerPtr> _layers;
diff --git a/inference-engine/src/legacy_api/src/cnn_network_impl.cpp b/inference-engine/src/legacy_api/src/cnn_network_impl.cpp
index c332d2e07ae..a7309caca69 100644
--- a/inference-engine/src/legacy_api/src/cnn_network_impl.cpp
+++ b/inference-engine/src/legacy_api/src/cnn_network_impl.cpp
@@ -408,6 +408,17 @@ StatusCode CNNNetworkImpl::serialize(const std::string& xmlPath, const std::stri
     return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
 }
 
+
+StatusCode CNNNetworkImpl::serialize(std::ostream& xmlBuf, std::ostream& binBuf, ResponseDesc* resp) const
+    noexcept {
+    return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
+}
+
+StatusCode CNNNetworkImpl::serialize(std::ostream& xmlBuf, Blob::Ptr& binBlob, ResponseDesc* resp) const
+    noexcept {
+    return DescriptionBuffer(NOT_IMPLEMENTED, resp) << "The CNNNetworkImpl::serialize is not implemented";
+}
+
 StatusCode CNNNetworkImpl::setBatchSize(size_t size, ResponseDesc* responseDesc) noexcept {
     try {
         auto originalBatchSize = getBatchSize();
diff --git a/inference-engine/tests/functional/inference_engine/ir_serialization/deterministicity.cpp b/inference-engine/tests/functional/inference_engine/ir_serialization/deterministicity.cpp
index 8fb248ed7dc..cbcd1b3093d 100644
--- a/inference-engine/tests/functional/inference_engine/ir_serialization/deterministicity.cpp
+++ b/inference-engine/tests/functional/inference_engine/ir_serialization/deterministicity.cpp
@@ -124,3 +124,49 @@ TEST_F(SerializationDeterministicityTest, ModelWithConstants) {
     ASSERT_TRUE(files_equal(xml_1, xml_2));
     ASSERT_TRUE(files_equal(bin_1, bin_2));
 }
+
+TEST_F(SerializationDeterministicityTest, SerializeToStream) {
+    const std::string model =
+        IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.xml";
+    const std::string weights =
+        IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.bin";
+
+    std::stringstream m_out_xml_buf, m_out_bin_buf;
+    InferenceEngine::Blob::Ptr binBlob;
+
+    InferenceEngine::Core ie;
+    auto expected = ie.ReadNetwork(model, weights);
+    expected.serialize(m_out_xml_buf, m_out_bin_buf);
+
+    std::streambuf* pbuf = m_out_bin_buf.rdbuf();
+    unsigned long bufSize = m_out_bin_buf.tellp();
+
+    InferenceEngine::TensorDesc tensorDesc(InferenceEngine::Precision::U8,
+        { bufSize }, InferenceEngine::Layout::C);
+    binBlob = InferenceEngine::make_shared_blob<uint8_t>(tensorDesc);
+    binBlob->allocate();
+    pbuf->sgetn(binBlob->buffer(), bufSize);
+
+    auto result = ie.ReadNetwork(m_out_xml_buf.str(), binBlob);
+
+    ASSERT_TRUE(expected.layerCount() == result.layerCount());
+    ASSERT_TRUE(expected.getInputShapes() == result.getInputShapes());
+}
+
+TEST_F(SerializationDeterministicityTest, SerializeToBlob) {
+    const std::string model =
+        IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.xml";
+    const std::string weights =
+        IR_SERIALIZATION_MODELS_PATH "add_abc_initializers.bin";
+
+    std::stringstream m_out_xml_buf;
+    InferenceEngine::Blob::Ptr m_out_bin_buf;
+
+    InferenceEngine::Core ie;
+    auto expected = ie.ReadNetwork(model, weights);
+    expected.serialize(m_out_xml_buf, m_out_bin_buf);
+    auto result = ie.ReadNetwork(m_out_xml_buf.str(), m_out_bin_buf);
+
+    ASSERT_TRUE(expected.layerCount() == result.layerCount());
+    ASSERT_TRUE(expected.getInputShapes() == result.getInputShapes());
+}
diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_icnn_network.hpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_icnn_network.hpp
index a3f7d337dd0..3d157842b4d 100644
--- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_icnn_network.hpp
+++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_icnn_network.hpp
@@ -48,4 +48,8 @@ class MockICNNNetwork final : public InferenceEngine::ICNNNetwork {
             (const, noexcept));
     MOCK_METHOD(InferenceEngine::StatusCode, serialize,
             (const std::string &, const std::string &, InferenceEngine::ResponseDesc*), (const, noexcept));
+    MOCK_METHOD(InferenceEngine::StatusCode, serialize,
+            (std::ostream &, std::ostream &, InferenceEngine::ResponseDesc*), (const, noexcept));
+    MOCK_METHOD(InferenceEngine::StatusCode, serialize,
+            (std::ostream &, InferenceEngine::Blob::Ptr &, InferenceEngine::ResponseDesc*), (const, noexcept));
 };
diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_not_empty_icnn_network.hpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_not_empty_icnn_network.hpp
index f9ffeab2f2d..d861ded519a 100644
--- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_not_empty_icnn_network.hpp
+++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_not_empty_icnn_network.hpp
@@ -43,6 +43,10 @@ public:
     MOCK_METHOD(StatusCode, reshape, (const ICNNNetwork::InputShapes &, ResponseDesc *), (noexcept));
     MOCK_METHOD(StatusCode, serialize,
             (const std::string &, const std::string &, InferenceEngine::ResponseDesc*), (const, noexcept));
+    MOCK_METHOD(StatusCode, serialize,
+            (std::ostream &, std::ostream &, InferenceEngine::ResponseDesc*), (const, noexcept));
+    MOCK_METHOD(StatusCode, serialize,
+            (std::ostream &, Blob::Ptr &, InferenceEngine::ResponseDesc*), (const, noexcept));
 };
 
 IE_SUPPRESS_DEPRECATED_END
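
Reviewer note: taken together, these changes give CNNNetwork two in-memory serialization paths. Below is a minimal usage sketch mirroring the SerializeToStream/SerializeToBlob tests above; it assumes the usual <ie_core.hpp> entry point, and the "model.xml"/"model.bin" paths are placeholders, not files from this change.

    #include <sstream>

    #include <ie_core.hpp>

    int main() {
        InferenceEngine::Core ie;
        // Placeholder IR; any network that Core::ReadNetwork accepts will do.
        auto network = ie.ReadNetwork("model.xml", "model.bin");

        // New overload 1: IR and weights are written to caller-provided
        // ostreams, so serialization never touches the filesystem.
        std::stringstream xmlStream, binStream;
        network.serialize(xmlStream, binStream);

        // New overload 2: IR goes to an ostream, weights land in a Blob::Ptr
        // that Core::ReadNetwork can consume directly.
        std::stringstream xmlOnly;
        InferenceEngine::Blob::Ptr weightsBlob;
        network.serialize(xmlOnly, weightsBlob);

        // Round-trip the network entirely in memory.
        auto restored = ie.ReadNetwork(xmlOnly.str(), weightsBlob);
        return restored.layerCount() == network.layerCount() ? 0 : 1;
    }

The Blob::Ptr overload spares callers the manual rdbuf()/sgetn() copy that the stream variant requires before the weights can be handed back to ReadNetwork, as the SerializeToStream test illustrates.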