Use runtime::Tensor instead of Blob in read_model (#7951)

* Use runtime::Tensor instead of Blob in read_model

* Fixed the memory ownership
This commit is contained in:
Ilya Churaev 2021-10-13 09:13:11 +03:00 committed by GitHub
parent d23ec24fd8
commit 972524f1cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 16 additions and 15 deletions

View File

@ -20,10 +20,10 @@
#include "openvino/runtime/common.hpp"
#include "openvino/runtime/executable_network.hpp"
#include "openvino/runtime/remote_context.hpp"
#include "openvino/runtime/tensor.hpp"
namespace InferenceEngine {
class IExtension;
class Blob;
class RemoteContext;
} // namespace InferenceEngine
@ -91,19 +91,14 @@ public:
/**
* @brief Reads models from IR and ONNX formats
* @param model string with model in IR or ONNX format
* @param weights shared pointer to constant blob with weights
* Reading ONNX models doesn't support loading weights from data blobs.
* If you are using an ONNX model with external data files, please use the
* `ov::runtime::Core::read_model(const std::string& model, const Blob::CPtr& weights) const`
* function overload which takes a filesystem path to the model.
* For ONNX case the second parameter should contain empty blob.
* @param weights shared pointer to constant tensor with weights
* Reading ONNX models doesn't support loading weights from data tensors.
* @note Created Function object shares the weights with `weights` object.
* So, do not create `weights` on temporary data which can be later freed, since the network
* constant data becomes to point to invalid memory.
* @return Function
*/
std::shared_ptr<ov::Function> read_model(const std::string& model,
const std::shared_ptr<const ie::Blob>& weights) const;
std::shared_ptr<ov::Function> read_model(const std::string& model, const Tensor& weights) const;
/**
* @brief Creates an executable network from a network object.

View File

@ -1293,8 +1293,12 @@ std::shared_ptr<ov::Function> Core::read_model(const std::string& modelPath, con
OV_CORE_CALL_STATEMENT(return _impl->ReadNetwork(modelPath, binPath).getFunction(););
}
std::shared_ptr<ov::Function> Core::read_model(const std::string& model, const ie::Blob::CPtr& weights) const {
OV_CORE_CALL_STATEMENT(return _impl->ReadNetwork(model, weights).getFunction(););
std::shared_ptr<ov::Function> Core::read_model(const std::string& model, const ov::runtime::Tensor& weights) const {
InferenceEngine::Blob::Ptr blob;
if (weights) {
blob = weights._impl;
}
OV_CORE_CALL_STATEMENT(return _impl->ReadNetwork(model, blob).getFunction(););
}
namespace {

View File

@ -178,7 +178,7 @@ TEST_F(RTInfoDeserialization, NodeV10) {
f_10_ref->set_friendly_name("Network");
ov::runtime::Core core;
auto f_10_core = core.read_model(model, InferenceEngine::Blob::CPtr());
auto f_10_core = core.read_model(model, ov::runtime::Tensor());
ASSERT_NE(nullptr, f_10_core);
check_version(f_10_core, 10);
@ -330,7 +330,7 @@ TEST_F(RTInfoDeserialization, InputAndOutputV10) {
f_10_ref->set_friendly_name("Network");
ov::runtime::Core core;
auto f_10_core = core.read_model(model, InferenceEngine::Blob::CPtr());
auto f_10_core = core.read_model(model, ov::runtime::Tensor());
ASSERT_NE(nullptr, f_10_core);
check_version(f_10_core, 10);
@ -451,7 +451,7 @@ TEST_F(RTInfoDeserialization, NodeV11) {
// read IR v11 with new API
{
ov::runtime::Core core;
auto f_11 = core.read_model(model, InferenceEngine::Blob::CPtr());
auto f_11 = core.read_model(model, ov::runtime::Tensor());
ASSERT_NE(nullptr, f_11);
check_old_api_map(f_11->get_parameters()[0]->get_rt_info(),

View File

@ -320,7 +320,7 @@ TEST_P(OVExecNetwork, readFromV10IR) {
</edges>
</net>
)V0G0N";
function = ie->read_model(model, InferenceEngine::Blob::Ptr());
function = ie->read_model(model, ov::runtime::Tensor());
EXPECT_EQ(function->inputs().size(), 1);
EXPECT_EQ(function->outputs().size(), 1);
EXPECT_NO_THROW(function->input("in1")); // remove if read_model does not change function names

View File

@ -23,6 +23,7 @@ class Blob;
namespace ov {
namespace runtime {
class Core;
class InferRequest;
class RemoteContext;
class VariableState;
@ -45,6 +46,7 @@ protected:
*/
Tensor(const std::shared_ptr<void>& so, const std::shared_ptr<InferenceEngine::Blob>& impl);
friend class ov::runtime::Core;
friend class ov::runtime::InferRequest;
friend class ov::runtime::RemoteContext;
friend class ov::runtime::VariableState;