Removed template from base (#4045)
Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
This commit is contained in:
parent
e80e5e7ae5
commit
945da5ff4f
@ -143,7 +143,7 @@ InferenceEngine::IInferRequest::Ptr TemplatePlugin::ExecutableNetwork::CreateInf
|
|||||||
auto internalRequest = CreateInferRequestImpl(_networkInputs, _networkOutputs);
|
auto internalRequest = CreateInferRequestImpl(_networkInputs, _networkOutputs);
|
||||||
auto asyncThreadSafeImpl = std::make_shared<TemplateAsyncInferRequest>(std::static_pointer_cast<TemplateInferRequest>(internalRequest),
|
auto asyncThreadSafeImpl = std::make_shared<TemplateAsyncInferRequest>(std::static_pointer_cast<TemplateInferRequest>(internalRequest),
|
||||||
_taskExecutor, _plugin->_waitExecutor, _callbackExecutor);
|
_taskExecutor, _plugin->_waitExecutor, _callbackExecutor);
|
||||||
asyncRequest.reset(new InferenceEngine::InferRequestBase<TemplateAsyncInferRequest>(asyncThreadSafeImpl),
|
asyncRequest.reset(new InferenceEngine::InferRequestBase(asyncThreadSafeImpl),
|
||||||
[](InferenceEngine::IInferRequest *p) { p->Release(); });
|
[](InferenceEngine::IInferRequest *p) { p->Release(); });
|
||||||
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
||||||
return asyncRequest;
|
return asyncRequest;
|
||||||
|
@ -201,7 +201,7 @@ IInferRequest::Ptr MultiDeviceExecutableNetwork::CreateInferRequest() {
|
|||||||
_needPerfCounters,
|
_needPerfCounters,
|
||||||
std::static_pointer_cast<MultiDeviceExecutableNetwork>(shared_from_this()),
|
std::static_pointer_cast<MultiDeviceExecutableNetwork>(shared_from_this()),
|
||||||
_callbackExecutor);
|
_callbackExecutor);
|
||||||
asyncRequest.reset(new InferRequestBase<MultiDeviceAsyncInferRequest>(asyncTreadSafeImpl), [](IInferRequest *p) { p->Release(); });
|
asyncRequest.reset(new InferRequestBase(asyncTreadSafeImpl), [](IInferRequest *p) { p->Release(); });
|
||||||
asyncTreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
asyncTreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
||||||
return asyncRequest;
|
return asyncRequest;
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
#include <cpp/ie_executable_network.hpp>
|
#include <cpp/ie_executable_network.hpp>
|
||||||
#include <cpp_interfaces/base/ie_variable_state_base.hpp>
|
#include <cpp_interfaces/base/ie_variable_state_base.hpp>
|
||||||
#include <cpp_interfaces/interface/ie_ivariable_state_internal.hpp>
|
#include <cpp_interfaces/interface/ie_ivariable_state_internal.hpp>
|
||||||
|
#include <cpp_interfaces/interface/ie_iexecutable_network_internal.hpp>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -24,18 +25,17 @@ namespace InferenceEngine {
|
|||||||
/**
|
/**
|
||||||
* @brief Executable network `noexcept` wrapper which accepts IExecutableNetworkInternal derived instance which can throw exceptions
|
* @brief Executable network `noexcept` wrapper which accepts IExecutableNetworkInternal derived instance which can throw exceptions
|
||||||
* @ingroup ie_dev_api_exec_network_api
|
* @ingroup ie_dev_api_exec_network_api
|
||||||
* @tparam T Minimal CPP implementation of IExecutableNetworkInternal (e.g. ExecutableNetworkInternal)
|
*/
|
||||||
*/
|
|
||||||
template <class T>
|
|
||||||
class ExecutableNetworkBase : public IExecutableNetwork {
|
class ExecutableNetworkBase : public IExecutableNetwork {
|
||||||
std::shared_ptr<T> _impl;
|
protected:
|
||||||
|
std::shared_ptr<IExecutableNetworkInternal> _impl;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Constructor with actual underlying implementation.
|
* @brief Constructor with actual underlying implementation.
|
||||||
* @param impl Underlying implementation of type IExecutableNetworkInternal
|
* @param impl Underlying implementation of type IExecutableNetworkInternal
|
||||||
*/
|
*/
|
||||||
explicit ExecutableNetworkBase(std::shared_ptr<T> impl) {
|
explicit ExecutableNetworkBase(std::shared_ptr<IExecutableNetworkInternal> impl) {
|
||||||
if (impl.get() == nullptr) {
|
if (impl.get() == nullptr) {
|
||||||
THROW_IE_EXCEPTION << "implementation not defined";
|
THROW_IE_EXCEPTION << "implementation not defined";
|
||||||
}
|
}
|
||||||
@ -77,7 +77,7 @@ public:
|
|||||||
if (idx >= v.size()) {
|
if (idx >= v.size()) {
|
||||||
return OUT_OF_BOUNDS;
|
return OUT_OF_BOUNDS;
|
||||||
}
|
}
|
||||||
pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
|
pState = std::make_shared<VariableStateBase>(v[idx]);
|
||||||
return OK;
|
return OK;
|
||||||
} catch (const std::exception& ex) {
|
} catch (const std::exception& ex) {
|
||||||
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
|
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
|
||||||
@ -91,11 +91,6 @@ public:
|
|||||||
delete this;
|
delete this;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @private Need for unit tests only - TODO: unit tests should test using public API, non having details
|
|
||||||
const std::shared_ptr<T> getImpl() const {
|
|
||||||
return _impl;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusCode SetConfig(const std::map<std::string, Parameter>& config, ResponseDesc* resp) noexcept override {
|
StatusCode SetConfig(const std::map<std::string, Parameter>& config, ResponseDesc* resp) noexcept override {
|
||||||
TO_STATUS(_impl->SetConfig(config));
|
TO_STATUS(_impl->SetConfig(config));
|
||||||
}
|
}
|
||||||
@ -112,8 +107,8 @@ public:
|
|||||||
TO_STATUS(pContext = _impl->GetContext());
|
TO_STATUS(pContext = _impl->GetContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
~ExecutableNetworkBase() = default;
|
~ExecutableNetworkBase() override = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -127,7 +122,7 @@ template <class T>
|
|||||||
inline typename InferenceEngine::ExecutableNetwork make_executable_network(std::shared_ptr<T> impl) {
|
inline typename InferenceEngine::ExecutableNetwork make_executable_network(std::shared_ptr<T> impl) {
|
||||||
// to suppress warning about deprecated QueryState
|
// to suppress warning about deprecated QueryState
|
||||||
IE_SUPPRESS_DEPRECATED_START
|
IE_SUPPRESS_DEPRECATED_START
|
||||||
typename ExecutableNetworkBase<T>::Ptr net(new ExecutableNetworkBase<T>(impl), [](IExecutableNetwork* p) {
|
typename ExecutableNetworkBase::Ptr net(new ExecutableNetworkBase(impl), [](IExecutableNetwork* p) {
|
||||||
p->Release();
|
p->Release();
|
||||||
});
|
});
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
IE_SUPPRESS_DEPRECATED_END
|
||||||
|
@ -11,6 +11,7 @@
|
|||||||
#include "cpp_interfaces/exception2status.hpp"
|
#include "cpp_interfaces/exception2status.hpp"
|
||||||
#include "cpp_interfaces/plugin_itt.hpp"
|
#include "cpp_interfaces/plugin_itt.hpp"
|
||||||
#include <cpp_interfaces/base/ie_variable_state_base.hpp>
|
#include <cpp_interfaces/base/ie_variable_state_base.hpp>
|
||||||
|
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
|
||||||
#include "ie_iinfer_request.hpp"
|
#include "ie_iinfer_request.hpp"
|
||||||
#include "ie_preprocess.hpp"
|
#include "ie_preprocess.hpp"
|
||||||
|
|
||||||
@ -19,18 +20,16 @@ namespace InferenceEngine {
|
|||||||
/**
|
/**
|
||||||
* @brief Inference request `noexcept` wrapper which accepts IAsyncInferRequestInternal derived instance which can throw exceptions
|
* @brief Inference request `noexcept` wrapper which accepts IAsyncInferRequestInternal derived instance which can throw exceptions
|
||||||
* @ingroup ie_dev_api_async_infer_request_api
|
* @ingroup ie_dev_api_async_infer_request_api
|
||||||
* @tparam T Minimal CPP implementation of IAsyncInferRequestInternal (e.g. AsyncInferRequestThreadSafeDefault)
|
|
||||||
*/
|
*/
|
||||||
template <class T>
|
|
||||||
class InferRequestBase : public IInferRequest {
|
class InferRequestBase : public IInferRequest {
|
||||||
std::shared_ptr<T> _impl;
|
std::shared_ptr<IAsyncInferRequestInternal> _impl;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Constructor with actual underlying implementation.
|
* @brief Constructor with actual underlying implementation.
|
||||||
* @param impl Underlying implementation of type IAsyncInferRequestInternal
|
* @param impl Underlying implementation of type IAsyncInferRequestInternal
|
||||||
*/
|
*/
|
||||||
explicit InferRequestBase(std::shared_ptr<T> impl): _impl(impl) {}
|
explicit InferRequestBase(std::shared_ptr<IAsyncInferRequestInternal> impl): _impl(impl) {}
|
||||||
|
|
||||||
StatusCode Infer(ResponseDesc* resp) noexcept override {
|
StatusCode Infer(ResponseDesc* resp) noexcept override {
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Infer");
|
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Infer");
|
||||||
@ -100,7 +99,7 @@ public:
|
|||||||
if (idx >= v.size()) {
|
if (idx >= v.size()) {
|
||||||
return OUT_OF_BOUNDS;
|
return OUT_OF_BOUNDS;
|
||||||
}
|
}
|
||||||
pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
|
pState = std::make_shared<VariableStateBase>(v[idx]);
|
||||||
return OK;
|
return OK;
|
||||||
} catch (const std::exception& ex) {
|
} catch (const std::exception& ex) {
|
||||||
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
|
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
|
||||||
|
@ -16,19 +16,17 @@ IE_SUPPRESS_DEPRECATED_START
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Default implementation for IVariableState
|
* @brief Default implementation for IVariableState
|
||||||
* @tparam T Minimal CPP implementation of IVariableStateInternal (e.g. VariableStateInternal)
|
* @ingroup ie_dev_api_variable_state_api
|
||||||
* @ingroup ie_dev_api_variable_state_api
|
|
||||||
*/
|
*/
|
||||||
template <class T>
|
|
||||||
class VariableStateBase : public IVariableState {
|
class VariableStateBase : public IVariableState {
|
||||||
std::shared_ptr<T> impl;
|
std::shared_ptr<IVariableStateInternal> impl;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @brief Constructor with actual underlying implementation.
|
* @brief Constructor with actual underlying implementation.
|
||||||
* @param impl Underlying implementation of type IVariableStateInternal
|
* @param impl Underlying implementation of type IVariableStateInternal
|
||||||
*/
|
*/
|
||||||
explicit VariableStateBase(std::shared_ptr<T> impl): impl(impl) {
|
explicit VariableStateBase(std::shared_ptr<IVariableStateInternal> impl): impl(impl) {
|
||||||
if (impl == nullptr) {
|
if (impl == nullptr) {
|
||||||
THROW_IE_EXCEPTION << "VariableStateBase implementation is not defined";
|
THROW_IE_EXCEPTION << "VariableStateBase implementation is not defined";
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ public:
|
|||||||
auto asyncRequestImpl = this->CreateAsyncInferRequestImpl(_networkInputs, _networkOutputs);
|
auto asyncRequestImpl = this->CreateAsyncInferRequestImpl(_networkInputs, _networkOutputs);
|
||||||
asyncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
|
asyncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
|
||||||
|
|
||||||
asyncRequest.reset(new InferRequestBase<AsyncInferRequestInternal>(asyncRequestImpl), [](IInferRequest* p) {
|
asyncRequest.reset(new InferRequestBase(asyncRequestImpl), [](IInferRequest* p) {
|
||||||
p->Release();
|
p->Release();
|
||||||
});
|
});
|
||||||
asyncRequestImpl->SetPointerToPublicInterface(asyncRequest);
|
asyncRequestImpl->SetPointerToPublicInterface(asyncRequest);
|
||||||
|
@ -69,7 +69,7 @@ protected:
|
|||||||
|
|
||||||
auto asyncThreadSafeImpl = std::make_shared<AsyncInferRequestType>(
|
auto asyncThreadSafeImpl = std::make_shared<AsyncInferRequestType>(
|
||||||
syncRequestImpl, _taskExecutor, _callbackExecutor);
|
syncRequestImpl, _taskExecutor, _callbackExecutor);
|
||||||
asyncRequest.reset(new InferRequestBase<AsyncInferRequestType>(asyncThreadSafeImpl),
|
asyncRequest.reset(new InferRequestBase(asyncThreadSafeImpl),
|
||||||
[](IInferRequest *p) { p->Release(); });
|
[](IInferRequest *p) { p->Release(); });
|
||||||
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
||||||
|
|
||||||
|
@ -88,8 +88,7 @@ public:
|
|||||||
auto taskExecutorGetResult = getNextTaskExecutor();
|
auto taskExecutorGetResult = getNextTaskExecutor();
|
||||||
auto asyncThreadSafeImpl = std::make_shared<MyriadAsyncInferRequest>(
|
auto asyncThreadSafeImpl = std::make_shared<MyriadAsyncInferRequest>(
|
||||||
syncRequestImpl, _taskExecutor, _callbackExecutor, taskExecutorGetResult);
|
syncRequestImpl, _taskExecutor, _callbackExecutor, taskExecutorGetResult);
|
||||||
asyncRequest.reset(new ie::InferRequestBase<ie::AsyncInferRequestThreadSafeDefault>(
|
asyncRequest.reset(new ie::InferRequestBase(asyncThreadSafeImpl),
|
||||||
asyncThreadSafeImpl),
|
|
||||||
[](ie::IInferRequest *p) { p->Release(); });
|
[](ie::IInferRequest *p) { p->Release(); });
|
||||||
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
||||||
return asyncRequest;
|
return asyncRequest;
|
||||||
|
@ -34,7 +34,7 @@ protected:
|
|||||||
virtual void SetUp() {
|
virtual void SetUp() {
|
||||||
mockExeNetwork = make_shared<MockExecutableNetworkThreadSafeAsyncOnly>();
|
mockExeNetwork = make_shared<MockExecutableNetworkThreadSafeAsyncOnly>();
|
||||||
exeNetwork = details::shared_from_irelease(
|
exeNetwork = details::shared_from_irelease(
|
||||||
new ExecutableNetworkBase<MockExecutableNetworkThreadSafeAsyncOnly>(mockExeNetwork));
|
new ExecutableNetworkBase(mockExeNetwork));
|
||||||
InputsDataMap networkInputs;
|
InputsDataMap networkInputs;
|
||||||
OutputsDataMap networkOutputs;
|
OutputsDataMap networkOutputs;
|
||||||
mockAsyncInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(networkInputs, networkOutputs);
|
mockAsyncInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(networkInputs, networkOutputs);
|
||||||
@ -46,7 +46,7 @@ TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, createAsyncInferRequestCallsTh
|
|||||||
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
|
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
|
||||||
Return(mockAsyncInferRequestInternal));
|
Return(mockAsyncInferRequestInternal));
|
||||||
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
|
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
|
||||||
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase<AsyncInferRequestInternal>>(req);
|
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase>(req);
|
||||||
ASSERT_NE(threadSafeReq, nullptr);
|
ASSERT_NE(threadSafeReq, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ protected:
|
|||||||
virtual void SetUp() {
|
virtual void SetUp() {
|
||||||
mockExeNetwork = make_shared<MockExecutableNetworkThreadSafe>();
|
mockExeNetwork = make_shared<MockExecutableNetworkThreadSafe>();
|
||||||
exeNetwork = details::shared_from_irelease(
|
exeNetwork = details::shared_from_irelease(
|
||||||
new ExecutableNetworkBase<MockExecutableNetworkThreadSafe>(mockExeNetwork));
|
new ExecutableNetworkBase(mockExeNetwork));
|
||||||
InputsDataMap networkInputs;
|
InputsDataMap networkInputs;
|
||||||
OutputsDataMap networkOutputs;
|
OutputsDataMap networkOutputs;
|
||||||
mockInferRequestInternal = make_shared<MockInferRequestInternal>(networkInputs, networkOutputs);
|
mockInferRequestInternal = make_shared<MockInferRequestInternal>(networkInputs, networkOutputs);
|
||||||
@ -120,7 +120,7 @@ TEST_F(ExecutableNetworkThreadSafeTests, createInferRequestCallsThreadSafeImplAn
|
|||||||
IInferRequest::Ptr req;
|
IInferRequest::Ptr req;
|
||||||
EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
|
EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
|
||||||
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
|
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
|
||||||
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase<AsyncInferRequestThreadSafeDefault>>(req);
|
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase>(req);
|
||||||
ASSERT_NE(threadSafeReq, nullptr);
|
ASSERT_NE(threadSafeReq, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ protected:
|
|||||||
|
|
||||||
virtual void SetUp() {
|
virtual void SetUp() {
|
||||||
mock_impl.reset(new MockIAsyncInferRequestInternal());
|
mock_impl.reset(new MockIAsyncInferRequestInternal());
|
||||||
request = details::shared_from_irelease(new InferRequestBase<MockIAsyncInferRequestInternal>(mock_impl));
|
request = details::shared_from_irelease(new InferRequestBase(mock_impl));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -243,7 +243,7 @@ protected:
|
|||||||
mockNotEmptyNet.getOutputsInfo(outputsInfo);
|
mockNotEmptyNet.getOutputsInfo(outputsInfo);
|
||||||
mockInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(inputsInfo, outputsInfo);
|
mockInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(inputsInfo, outputsInfo);
|
||||||
inferRequest = shared_from_irelease(
|
inferRequest = shared_from_irelease(
|
||||||
new InferRequestBase<MockAsyncInferRequestInternal>(mockInferRequestInternal));
|
new InferRequestBase(mockInferRequestInternal));
|
||||||
return make_shared<InferRequest>(inferRequest);
|
return make_shared<InferRequest>(inferRequest);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,8 +198,7 @@ TEST_F(InferRequestThreadSafeDefaultTests, callbackTakesOKIfAsyncRequestWasOK) {
|
|||||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
||||||
|
|
||||||
IInferRequest::Ptr asyncRequest;
|
IInferRequest::Ptr asyncRequest;
|
||||||
asyncRequest.reset(new InferRequestBase<AsyncInferRequestThreadSafeDefault>(testRequest),
|
asyncRequest.reset(new InferRequestBase(testRequest), [](IInferRequest *p) { p->Release(); });
|
||||||
[](IInferRequest *p) { p->Release(); });
|
|
||||||
testRequest->SetPointerToPublicInterface(asyncRequest);
|
testRequest->SetPointerToPublicInterface(asyncRequest);
|
||||||
|
|
||||||
testRequest->SetCompletionCallback([](InferenceEngine::IInferRequest::Ptr request, StatusCode status) {
|
testRequest->SetCompletionCallback([](InferenceEngine::IInferRequest::Ptr request, StatusCode status) {
|
||||||
@ -215,8 +214,7 @@ TEST_F(InferRequestThreadSafeDefaultTests, callbackIsCalledIfAsyncRequestFailed)
|
|||||||
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
|
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
|
||||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
||||||
IInferRequest::Ptr asyncRequest;
|
IInferRequest::Ptr asyncRequest;
|
||||||
asyncRequest.reset(new InferRequestBase<AsyncInferRequestThreadSafeDefault>(testRequest),
|
asyncRequest.reset(new InferRequestBase(testRequest), [](IInferRequest *p) { p->Release(); });
|
||||||
[](IInferRequest *p) { p->Release(); });
|
|
||||||
testRequest->SetPointerToPublicInterface(asyncRequest);
|
testRequest->SetPointerToPublicInterface(asyncRequest);
|
||||||
|
|
||||||
bool wasCalled = false;
|
bool wasCalled = false;
|
||||||
@ -238,8 +236,7 @@ TEST_F(InferRequestThreadSafeDefaultTests, canCatchExceptionIfAsyncRequestFailed
|
|||||||
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
|
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
|
||||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
||||||
IInferRequest::Ptr asyncRequest;
|
IInferRequest::Ptr asyncRequest;
|
||||||
asyncRequest.reset(new InferRequestBase<AsyncInferRequestThreadSafeDefault>(testRequest),
|
asyncRequest.reset(new InferRequestBase(testRequest), [](IInferRequest *p) { p->Release(); });
|
||||||
[](IInferRequest *p) { p->Release(); });
|
|
||||||
testRequest->SetPointerToPublicInterface(asyncRequest);
|
testRequest->SetPointerToPublicInterface(asyncRequest);
|
||||||
|
|
||||||
EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
|
EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
|
||||||
|
@ -20,7 +20,7 @@ using namespace InferenceEngine::details;
|
|||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
inline typename InferenceEngine::InferRequest make_infer_request(std::shared_ptr<T> impl) {
|
inline typename InferenceEngine::InferRequest make_infer_request(std::shared_ptr<T> impl) {
|
||||||
typename InferRequestBase<T>::Ptr req(new InferRequestBase<T>(impl), [](IInferRequest* p) {
|
typename InferRequestBase::Ptr req(new InferRequestBase(impl), [](IInferRequest* p) {
|
||||||
p->Release();
|
p->Release();
|
||||||
});
|
});
|
||||||
return InferenceEngine::InferRequest(req);
|
return InferenceEngine::InferRequest(req);
|
||||||
|
@ -223,7 +223,7 @@ protected:
|
|||||||
|
|
||||||
virtual void SetUp() {
|
virtual void SetUp() {
|
||||||
mock_impl.reset(new MockIExecutableNetworkInternal());
|
mock_impl.reset(new MockIExecutableNetworkInternal());
|
||||||
exeNetwork = shared_from_irelease(new ExecutableNetworkBase<MockIExecutableNetworkInternal>(mock_impl));
|
exeNetwork = shared_from_irelease(new ExecutableNetworkBase(mock_impl));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -20,21 +20,17 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class MKLDNNTestEngine: public MKLDNNPlugin::Engine {
|
struct TestExecutableNetworkBase : public InferenceEngine::ExecutableNetworkBase {
|
||||||
public:
|
using InferenceEngine::ExecutableNetworkBase::_impl;
|
||||||
MKLDNNPlugin::MKLDNNGraph& getGraph(InferenceEngine::IExecutableNetwork::Ptr execNetwork) {
|
~TestExecutableNetworkBase() override = default;
|
||||||
auto * execNetworkInt =
|
|
||||||
dynamic_cast<InferenceEngine::ExecutableNetworkBase<InferenceEngine::ExecutableNetworkInternal> *>(execNetwork.get());
|
|
||||||
if (!execNetworkInt)
|
|
||||||
THROW_IE_EXCEPTION << "Cannot find loaded network!";
|
|
||||||
|
|
||||||
auto * network = reinterpret_cast<MKLDNNTestExecNetwork *>(execNetworkInt->getImpl().get());
|
|
||||||
if (!network)
|
|
||||||
THROW_IE_EXCEPTION << "Cannot get mkldnn graph!";
|
|
||||||
return network->getGraph();
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static MKLDNNPlugin::MKLDNNGraph& getGraph(InferenceEngine::IExecutableNetwork::Ptr execNetwork) {
|
||||||
|
return reinterpret_cast<MKLDNNTestExecNetwork*>(
|
||||||
|
reinterpret_cast<TestExecutableNetworkBase*>(
|
||||||
|
execNetwork.get())->_impl.get())->getGraph();
|
||||||
|
}
|
||||||
|
|
||||||
class MKLDNNGraphLeaksTests: public ::testing::Test {
|
class MKLDNNGraphLeaksTests: public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
void addOutputToEachNode(InferenceEngine::CNNNetwork& network, std::vector<std::string>& new_outputs,
|
void addOutputToEachNode(InferenceEngine::CNNNetwork& network, std::vector<std::string>& new_outputs,
|
||||||
@ -257,11 +253,11 @@ TEST_F(MKLDNNGraphLeaksTests, MKLDNN_not_release_outputs_fp32) {
|
|||||||
|
|
||||||
ASSERT_NE(1, network.getOutputsInfo().size());
|
ASSERT_NE(1, network.getOutputsInfo().size());
|
||||||
|
|
||||||
std::shared_ptr<MKLDNNTestEngine> score_engine(new MKLDNNTestEngine());
|
std::shared_ptr<MKLDNNPlugin::Engine> score_engine(new MKLDNNPlugin::Engine());
|
||||||
InferenceEngine::ExecutableNetwork exeNetwork1;
|
InferenceEngine::ExecutableNetwork exeNetwork1;
|
||||||
ASSERT_NO_THROW(exeNetwork1 = score_engine->LoadNetwork(network, {}));
|
ASSERT_NO_THROW(exeNetwork1 = score_engine->LoadNetwork(network, {}));
|
||||||
|
|
||||||
size_t modified_outputs_size = score_engine->getGraph(exeNetwork1).GetOutputNodes().size();
|
size_t modified_outputs_size = getGraph(exeNetwork1).GetOutputNodes().size();
|
||||||
|
|
||||||
InferenceEngine::CNNNetwork network2;
|
InferenceEngine::CNNNetwork network2;
|
||||||
ASSERT_NO_THROW(network2 = core.ReadNetwork(model, weights_ptr));
|
ASSERT_NO_THROW(network2 = core.ReadNetwork(model, weights_ptr));
|
||||||
@ -270,10 +266,12 @@ TEST_F(MKLDNNGraphLeaksTests, MKLDNN_not_release_outputs_fp32) {
|
|||||||
InferenceEngine::ExecutableNetwork exeNetwork2;
|
InferenceEngine::ExecutableNetwork exeNetwork2;
|
||||||
ASSERT_NO_THROW(exeNetwork2 = score_engine->LoadNetwork(network2, {}));
|
ASSERT_NO_THROW(exeNetwork2 = score_engine->LoadNetwork(network2, {}));
|
||||||
|
|
||||||
size_t original_outputs_size = score_engine->getGraph(exeNetwork2).GetOutputNodes().size();
|
size_t original_outputs_size = getGraph(exeNetwork2).GetOutputNodes().size();
|
||||||
|
|
||||||
ASSERT_NE(modified_outputs_size, original_outputs_size);
|
ASSERT_NE(modified_outputs_size, original_outputs_size);
|
||||||
ASSERT_EQ(1, original_outputs_size);
|
ASSERT_EQ(1, original_outputs_size);
|
||||||
|
} catch (std::exception& e) {
|
||||||
|
FAIL() << e.what();
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
FAIL();
|
FAIL();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user