Added get compiled model API (#9186)

Author: Anton Pankratov, 2021-12-14 00:14:01 +03:00 (committed by GitHub)
Parent: c39dba62b0
Commit: ef58ec6c8c
7 changed files with 48 additions and 0 deletions


@@ -8,6 +8,7 @@
#include <openvino/core/except.hpp>
#include <openvino/runtime/infer_request.hpp>
#include <openvino/runtime/remote_tensor.hpp>
#include <openvino/runtime/compiled_model.hpp>

using namespace ::testing;
using namespace std;
@@ -83,3 +84,9 @@ TEST(InferRequestOVTests, throwsOnUninitializedSetRemoteTensor) {
    ov::runtime::RemoteTensor remote_tensor;
    ASSERT_THROW(req.set_tensor(ov::Output<const ov::Node>(), remote_tensor), ov::Exception);
}

TEST(InferRequestOVTests, throwsOnGetCompiledModel) {
    ov::runtime::InferRequest req;
    ASSERT_THROW(req.get_compiled_model(), ov::Exception);
}


@@ -633,6 +633,23 @@ TEST_P(OVExecutableNetworkBaseTest, precisionsAsInOriginalIR) {
    EXPECT_EQ(ref_result->get_shape(), actual_result->get_shape());
    EXPECT_EQ(ref_result->get_friendly_name(), actual_result->get_friendly_name());
}

TEST_P(OVExecutableNetworkBaseTest, getCompiledModelFromInferRequest) {
    ov::runtime::InferRequest req;
    {
        ov::runtime::CompiledModel compiled_model;
        ASSERT_NO_THROW(compiled_model = core->compile_model(function, targetDevice, configuration));
        ASSERT_NO_THROW(req = compiled_model.create_infer_request());
        ASSERT_NO_THROW(req.infer());
    }
    {
        ov::runtime::CompiledModel restored_compiled_model;
        ov::runtime::InferRequest another_req;
        ASSERT_NO_THROW(restored_compiled_model = req.get_compiled_model());
        ASSERT_NO_THROW(another_req = restored_compiled_model.create_infer_request());
        ASSERT_NO_THROW(another_req.infer());
    }
}

} // namespace behavior
} // namespace test
} // namespace ov


@@ -185,6 +185,12 @@ public:
     */
    void setPointerToExecutableNetworkInternal(const std::shared_ptr<IExecutableNetworkInternal>& exeNetwork);

    /**
     * @brief Returns the pointer to executable network internal.
     * @returns The executable network
     */
    std::shared_ptr<IExecutableNetworkInternal> getPointerToExecutableNetworkInternal() const;

    /**
     * @brief Gets the pointer to userData.
     * @return Pointer to user data


@@ -28,6 +28,7 @@ namespace ov {
namespace runtime {

class Core;
class InferRequest;

/**
 * @brief This is an interface of an executable network
@@ -45,6 +46,7 @@ class OPENVINO_RUNTIME_API CompiledModel {
    CompiledModel(const std::shared_ptr<void>& so,
                  const std::shared_ptr<InferenceEngine::IExecutableNetworkInternal>& impl);
    friend class ov::runtime::Core;
    friend class ov::runtime::InferRequest;

public:
    /**

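The new friend declaration above is what lets InferRequest construct a CompiledModel through the private (so, impl) constructor. A minimal standalone sketch of that pattern, using toy class names rather than the real OpenVINO types:

#include <memory>

class Handle;

class Resource {
    // Private constructor: only friends may rebuild a Resource from its internals.
    explicit Resource(std::shared_ptr<int> impl) : _impl(std::move(impl)) {}
    std::shared_ptr<int> _impl;
    friend class Handle;  // grants Handle access to the private constructor
public:
    int value() const { return *_impl; }
};

class Handle {
    std::shared_ptr<int> _impl = std::make_shared<int>(42);
public:
    // Allowed only because Handle is a friend of Resource.
    Resource get_resource() const { return Resource{_impl}; }
};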

@@ -222,6 +222,13 @@ public:
     */
    std::vector<VariableState> query_state();

    /**
     * @brief Returns the compiled model that created this inference request
     *
     * @return Compiled model object
     */
    CompiledModel get_compiled_model();

    /**
     * @brief Checks if current InferRequest object is not initialized
     * @return true if current InferRequest object is not initialized, false - otherwise

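For context, a minimal usage sketch of the new call, not part of this commit; the model path "model.xml" and the "CPU" device are assumptions, and the flow mirrors the behavior test above:

#include <openvino/runtime/core.hpp>
#include <openvino/runtime/infer_request.hpp>

int main() {
    ov::runtime::Core core;
    ov::runtime::InferRequest req;
    {
        // Assumptions: "model.xml" exists and the CPU plugin is available.
        auto model = core.read_model("model.xml");
        auto compiled = core.compile_model(model, "CPU");
        req = compiled.create_infer_request();
        req.infer();
    }  // the local CompiledModel handle is destroyed here
    // The request still keeps the compiled model alive; recover a handle to it.
    ov::runtime::CompiledModel restored = req.get_compiled_model();
    auto another_req = restored.create_infer_request();
    another_req.infer();
    return 0;
}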

@@ -13,6 +13,7 @@
#include "ie_infer_async_request_base.hpp"
#include "ie_ngraph_utils.hpp"
#include "ie_remote_context.hpp"
#include "openvino/runtime/compiled_model.hpp"
#include "openvino/runtime/exception.hpp"
#include "openvino/runtime/infer_request.hpp"
#include "transformations/utils/utils.hpp"
@@ -446,6 +447,10 @@ std::vector<VariableState> InferRequest::query_state() {
    return variable_states;
}

CompiledModel InferRequest::get_compiled_model() {
    OV_INFER_REQ_CALL_STATEMENT(return {_so, _impl->getPointerToExecutableNetworkInternal()});
}

bool InferRequest::operator!() const noexcept {
    return !_impl;
}
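Roughly, and assuming OV_INFER_REQ_CALL_STATEMENT only adds an initialization check plus exception translation around the statement it wraps (the exact assertion and message below are assumptions), the new method amounts to the following sketch:

// Sketch, not the literal macro expansion: verify that the request is initialized,
// then rebuild a CompiledModel from the plugin shared-object handle and the
// internal executable network (legal here because InferRequest is now a friend).
CompiledModel InferRequest::get_compiled_model() {
    OPENVINO_ASSERT(_impl != nullptr, "InferRequest was not initialized.");
    return CompiledModel{_so, _impl->getPointerToExecutableNetworkInternal()};
}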


@@ -355,6 +355,10 @@ void IInferRequestInternal::setPointerToExecutableNetworkInternal(
    _exeNetwork = exeNetwork;
}

std::shared_ptr<IExecutableNetworkInternal> IInferRequestInternal::getPointerToExecutableNetworkInternal() const {
    return _exeNetwork;
}

bool IInferRequestInternal::preProcessingRequired(const InputInfo::Ptr& info,
                                                  const Blob::Ptr& userBlob,
                                                  const Blob::Ptr& deviceBlob) {