Removed template from base (#4045)
Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
This commit is contained in:
parent
e80e5e7ae5
commit
945da5ff4f
@ -143,7 +143,7 @@ InferenceEngine::IInferRequest::Ptr TemplatePlugin::ExecutableNetwork::CreateInf
|
||||
auto internalRequest = CreateInferRequestImpl(_networkInputs, _networkOutputs);
|
||||
auto asyncThreadSafeImpl = std::make_shared<TemplateAsyncInferRequest>(std::static_pointer_cast<TemplateInferRequest>(internalRequest),
|
||||
_taskExecutor, _plugin->_waitExecutor, _callbackExecutor);
|
||||
asyncRequest.reset(new InferenceEngine::InferRequestBase<TemplateAsyncInferRequest>(asyncThreadSafeImpl),
|
||||
asyncRequest.reset(new InferenceEngine::InferRequestBase(asyncThreadSafeImpl),
|
||||
[](InferenceEngine::IInferRequest *p) { p->Release(); });
|
||||
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
||||
return asyncRequest;
|
||||
|
@ -201,7 +201,7 @@ IInferRequest::Ptr MultiDeviceExecutableNetwork::CreateInferRequest() {
|
||||
_needPerfCounters,
|
||||
std::static_pointer_cast<MultiDeviceExecutableNetwork>(shared_from_this()),
|
||||
_callbackExecutor);
|
||||
asyncRequest.reset(new InferRequestBase<MultiDeviceAsyncInferRequest>(asyncTreadSafeImpl), [](IInferRequest *p) { p->Release(); });
|
||||
asyncRequest.reset(new InferRequestBase(asyncTreadSafeImpl), [](IInferRequest *p) { p->Release(); });
|
||||
asyncTreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
||||
return asyncRequest;
|
||||
}
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include <cpp/ie_executable_network.hpp>
|
||||
#include <cpp_interfaces/base/ie_variable_state_base.hpp>
|
||||
#include <cpp_interfaces/interface/ie_ivariable_state_internal.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iexecutable_network_internal.hpp>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@ -24,18 +25,17 @@ namespace InferenceEngine {
|
||||
/**
|
||||
* @brief Executable network `noexcept` wrapper which accepts IExecutableNetworkInternal derived instance which can throw exceptions
|
||||
* @ingroup ie_dev_api_exec_network_api
|
||||
* @tparam T Minimal CPP implementation of IExecutableNetworkInternal (e.g. ExecutableNetworkInternal)
|
||||
*/
|
||||
template <class T>
|
||||
class ExecutableNetworkBase : public IExecutableNetwork {
|
||||
std::shared_ptr<T> _impl;
|
||||
protected:
|
||||
std::shared_ptr<IExecutableNetworkInternal> _impl;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Constructor with actual underlying implementation.
|
||||
* @param impl Underlying implementation of type IExecutableNetworkInternal
|
||||
*/
|
||||
explicit ExecutableNetworkBase(std::shared_ptr<T> impl) {
|
||||
explicit ExecutableNetworkBase(std::shared_ptr<IExecutableNetworkInternal> impl) {
|
||||
if (impl.get() == nullptr) {
|
||||
THROW_IE_EXCEPTION << "implementation not defined";
|
||||
}
|
||||
@ -77,7 +77,7 @@ public:
|
||||
if (idx >= v.size()) {
|
||||
return OUT_OF_BOUNDS;
|
||||
}
|
||||
pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
|
||||
pState = std::make_shared<VariableStateBase>(v[idx]);
|
||||
return OK;
|
||||
} catch (const std::exception& ex) {
|
||||
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
|
||||
@ -91,11 +91,6 @@ public:
|
||||
delete this;
|
||||
}
|
||||
|
||||
/// @private Need for unit tests only - TODO: unit tests should test using public API, non having details
|
||||
const std::shared_ptr<T> getImpl() const {
|
||||
return _impl;
|
||||
}
|
||||
|
||||
StatusCode SetConfig(const std::map<std::string, Parameter>& config, ResponseDesc* resp) noexcept override {
|
||||
TO_STATUS(_impl->SetConfig(config));
|
||||
}
|
||||
@ -112,8 +107,8 @@ public:
|
||||
TO_STATUS(pContext = _impl->GetContext());
|
||||
}
|
||||
|
||||
private:
|
||||
~ExecutableNetworkBase() = default;
|
||||
protected:
|
||||
~ExecutableNetworkBase() override = default;
|
||||
};
|
||||
|
||||
/**
|
||||
@ -127,7 +122,7 @@ template <class T>
|
||||
inline typename InferenceEngine::ExecutableNetwork make_executable_network(std::shared_ptr<T> impl) {
|
||||
// to suppress warning about deprecated QueryState
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
typename ExecutableNetworkBase<T>::Ptr net(new ExecutableNetworkBase<T>(impl), [](IExecutableNetwork* p) {
|
||||
typename ExecutableNetworkBase::Ptr net(new ExecutableNetworkBase(impl), [](IExecutableNetwork* p) {
|
||||
p->Release();
|
||||
});
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "cpp_interfaces/exception2status.hpp"
|
||||
#include "cpp_interfaces/plugin_itt.hpp"
|
||||
#include <cpp_interfaces/base/ie_variable_state_base.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
|
||||
#include "ie_iinfer_request.hpp"
|
||||
#include "ie_preprocess.hpp"
|
||||
|
||||
@ -19,18 +20,16 @@ namespace InferenceEngine {
|
||||
/**
|
||||
* @brief Inference request `noexcept` wrapper which accepts IAsyncInferRequestInternal derived instance which can throw exceptions
|
||||
* @ingroup ie_dev_api_async_infer_request_api
|
||||
* @tparam T Minimal CPP implementation of IAsyncInferRequestInternal (e.g. AsyncInferRequestThreadSafeDefault)
|
||||
*/
|
||||
template <class T>
|
||||
class InferRequestBase : public IInferRequest {
|
||||
std::shared_ptr<T> _impl;
|
||||
std::shared_ptr<IAsyncInferRequestInternal> _impl;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Constructor with actual underlying implementation.
|
||||
* @param impl Underlying implementation of type IAsyncInferRequestInternal
|
||||
*/
|
||||
explicit InferRequestBase(std::shared_ptr<T> impl): _impl(impl) {}
|
||||
explicit InferRequestBase(std::shared_ptr<IAsyncInferRequestInternal> impl): _impl(impl) {}
|
||||
|
||||
StatusCode Infer(ResponseDesc* resp) noexcept override {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Infer");
|
||||
@ -100,7 +99,7 @@ public:
|
||||
if (idx >= v.size()) {
|
||||
return OUT_OF_BOUNDS;
|
||||
}
|
||||
pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
|
||||
pState = std::make_shared<VariableStateBase>(v[idx]);
|
||||
return OK;
|
||||
} catch (const std::exception& ex) {
|
||||
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
|
||||
|
@ -16,19 +16,17 @@ IE_SUPPRESS_DEPRECATED_START
|
||||
|
||||
/**
|
||||
* @brief Default implementation for IVariableState
|
||||
* @tparam T Minimal CPP implementation of IVariableStateInternal (e.g. VariableStateInternal)
|
||||
* @ingroup ie_dev_api_variable_state_api
|
||||
*/
|
||||
template <class T>
|
||||
class VariableStateBase : public IVariableState {
|
||||
std::shared_ptr<T> impl;
|
||||
std::shared_ptr<IVariableStateInternal> impl;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Constructor with actual underlying implementation.
|
||||
* @param impl Underlying implementation of type IVariableStateInternal
|
||||
*/
|
||||
explicit VariableStateBase(std::shared_ptr<T> impl): impl(impl) {
|
||||
explicit VariableStateBase(std::shared_ptr<IVariableStateInternal> impl): impl(impl) {
|
||||
if (impl == nullptr) {
|
||||
THROW_IE_EXCEPTION << "VariableStateBase implementation is not defined";
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ public:
|
||||
auto asyncRequestImpl = this->CreateAsyncInferRequestImpl(_networkInputs, _networkOutputs);
|
||||
asyncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
|
||||
|
||||
asyncRequest.reset(new InferRequestBase<AsyncInferRequestInternal>(asyncRequestImpl), [](IInferRequest* p) {
|
||||
asyncRequest.reset(new InferRequestBase(asyncRequestImpl), [](IInferRequest* p) {
|
||||
p->Release();
|
||||
});
|
||||
asyncRequestImpl->SetPointerToPublicInterface(asyncRequest);
|
||||
|
@ -69,7 +69,7 @@ protected:
|
||||
|
||||
auto asyncThreadSafeImpl = std::make_shared<AsyncInferRequestType>(
|
||||
syncRequestImpl, _taskExecutor, _callbackExecutor);
|
||||
asyncRequest.reset(new InferRequestBase<AsyncInferRequestType>(asyncThreadSafeImpl),
|
||||
asyncRequest.reset(new InferRequestBase(asyncThreadSafeImpl),
|
||||
[](IInferRequest *p) { p->Release(); });
|
||||
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
||||
|
||||
|
@ -88,8 +88,7 @@ public:
|
||||
auto taskExecutorGetResult = getNextTaskExecutor();
|
||||
auto asyncThreadSafeImpl = std::make_shared<MyriadAsyncInferRequest>(
|
||||
syncRequestImpl, _taskExecutor, _callbackExecutor, taskExecutorGetResult);
|
||||
asyncRequest.reset(new ie::InferRequestBase<ie::AsyncInferRequestThreadSafeDefault>(
|
||||
asyncThreadSafeImpl),
|
||||
asyncRequest.reset(new ie::InferRequestBase(asyncThreadSafeImpl),
|
||||
[](ie::IInferRequest *p) { p->Release(); });
|
||||
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
|
||||
return asyncRequest;
|
||||
|
@ -34,7 +34,7 @@ protected:
|
||||
virtual void SetUp() {
|
||||
mockExeNetwork = make_shared<MockExecutableNetworkThreadSafeAsyncOnly>();
|
||||
exeNetwork = details::shared_from_irelease(
|
||||
new ExecutableNetworkBase<MockExecutableNetworkThreadSafeAsyncOnly>(mockExeNetwork));
|
||||
new ExecutableNetworkBase(mockExeNetwork));
|
||||
InputsDataMap networkInputs;
|
||||
OutputsDataMap networkOutputs;
|
||||
mockAsyncInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(networkInputs, networkOutputs);
|
||||
@ -46,7 +46,7 @@ TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, createAsyncInferRequestCallsTh
|
||||
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
|
||||
Return(mockAsyncInferRequestInternal));
|
||||
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
|
||||
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase<AsyncInferRequestInternal>>(req);
|
||||
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase>(req);
|
||||
ASSERT_NE(threadSafeReq, nullptr);
|
||||
}
|
||||
|
||||
@ -109,7 +109,7 @@ protected:
|
||||
virtual void SetUp() {
|
||||
mockExeNetwork = make_shared<MockExecutableNetworkThreadSafe>();
|
||||
exeNetwork = details::shared_from_irelease(
|
||||
new ExecutableNetworkBase<MockExecutableNetworkThreadSafe>(mockExeNetwork));
|
||||
new ExecutableNetworkBase(mockExeNetwork));
|
||||
InputsDataMap networkInputs;
|
||||
OutputsDataMap networkOutputs;
|
||||
mockInferRequestInternal = make_shared<MockInferRequestInternal>(networkInputs, networkOutputs);
|
||||
@ -120,7 +120,7 @@ TEST_F(ExecutableNetworkThreadSafeTests, createInferRequestCallsThreadSafeImplAn
|
||||
IInferRequest::Ptr req;
|
||||
EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
|
||||
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
|
||||
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase<AsyncInferRequestThreadSafeDefault>>(req);
|
||||
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase>(req);
|
||||
ASSERT_NE(threadSafeReq, nullptr);
|
||||
}
|
||||
|
||||
|
@ -34,7 +34,7 @@ protected:
|
||||
|
||||
virtual void SetUp() {
|
||||
mock_impl.reset(new MockIAsyncInferRequestInternal());
|
||||
request = details::shared_from_irelease(new InferRequestBase<MockIAsyncInferRequestInternal>(mock_impl));
|
||||
request = details::shared_from_irelease(new InferRequestBase(mock_impl));
|
||||
}
|
||||
};
|
||||
|
||||
@ -243,7 +243,7 @@ protected:
|
||||
mockNotEmptyNet.getOutputsInfo(outputsInfo);
|
||||
mockInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(inputsInfo, outputsInfo);
|
||||
inferRequest = shared_from_irelease(
|
||||
new InferRequestBase<MockAsyncInferRequestInternal>(mockInferRequestInternal));
|
||||
new InferRequestBase(mockInferRequestInternal));
|
||||
return make_shared<InferRequest>(inferRequest);
|
||||
}
|
||||
|
||||
|
@ -198,8 +198,7 @@ TEST_F(InferRequestThreadSafeDefaultTests, callbackTakesOKIfAsyncRequestWasOK) {
|
||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
||||
|
||||
IInferRequest::Ptr asyncRequest;
|
||||
asyncRequest.reset(new InferRequestBase<AsyncInferRequestThreadSafeDefault>(testRequest),
|
||||
[](IInferRequest *p) { p->Release(); });
|
||||
asyncRequest.reset(new InferRequestBase(testRequest), [](IInferRequest *p) { p->Release(); });
|
||||
testRequest->SetPointerToPublicInterface(asyncRequest);
|
||||
|
||||
testRequest->SetCompletionCallback([](InferenceEngine::IInferRequest::Ptr request, StatusCode status) {
|
||||
@ -215,8 +214,7 @@ TEST_F(InferRequestThreadSafeDefaultTests, callbackIsCalledIfAsyncRequestFailed)
|
||||
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
|
||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
||||
IInferRequest::Ptr asyncRequest;
|
||||
asyncRequest.reset(new InferRequestBase<AsyncInferRequestThreadSafeDefault>(testRequest),
|
||||
[](IInferRequest *p) { p->Release(); });
|
||||
asyncRequest.reset(new InferRequestBase(testRequest), [](IInferRequest *p) { p->Release(); });
|
||||
testRequest->SetPointerToPublicInterface(asyncRequest);
|
||||
|
||||
bool wasCalled = false;
|
||||
@ -238,8 +236,7 @@ TEST_F(InferRequestThreadSafeDefaultTests, canCatchExceptionIfAsyncRequestFailed
|
||||
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
|
||||
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
|
||||
IInferRequest::Ptr asyncRequest;
|
||||
asyncRequest.reset(new InferRequestBase<AsyncInferRequestThreadSafeDefault>(testRequest),
|
||||
[](IInferRequest *p) { p->Release(); });
|
||||
asyncRequest.reset(new InferRequestBase(testRequest), [](IInferRequest *p) { p->Release(); });
|
||||
testRequest->SetPointerToPublicInterface(asyncRequest);
|
||||
|
||||
EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));
|
||||
|
@ -20,7 +20,7 @@ using namespace InferenceEngine::details;
|
||||
|
||||
template <class T>
|
||||
inline typename InferenceEngine::InferRequest make_infer_request(std::shared_ptr<T> impl) {
|
||||
typename InferRequestBase<T>::Ptr req(new InferRequestBase<T>(impl), [](IInferRequest* p) {
|
||||
typename InferRequestBase::Ptr req(new InferRequestBase(impl), [](IInferRequest* p) {
|
||||
p->Release();
|
||||
});
|
||||
return InferenceEngine::InferRequest(req);
|
||||
|
@ -223,7 +223,7 @@ protected:
|
||||
|
||||
virtual void SetUp() {
|
||||
mock_impl.reset(new MockIExecutableNetworkInternal());
|
||||
exeNetwork = shared_from_irelease(new ExecutableNetworkBase<MockIExecutableNetworkInternal>(mock_impl));
|
||||
exeNetwork = shared_from_irelease(new ExecutableNetworkBase(mock_impl));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -20,21 +20,17 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class MKLDNNTestEngine: public MKLDNNPlugin::Engine {
|
||||
public:
|
||||
MKLDNNPlugin::MKLDNNGraph& getGraph(InferenceEngine::IExecutableNetwork::Ptr execNetwork) {
|
||||
auto * execNetworkInt =
|
||||
dynamic_cast<InferenceEngine::ExecutableNetworkBase<InferenceEngine::ExecutableNetworkInternal> *>(execNetwork.get());
|
||||
if (!execNetworkInt)
|
||||
THROW_IE_EXCEPTION << "Cannot find loaded network!";
|
||||
|
||||
auto * network = reinterpret_cast<MKLDNNTestExecNetwork *>(execNetworkInt->getImpl().get());
|
||||
if (!network)
|
||||
THROW_IE_EXCEPTION << "Cannot get mkldnn graph!";
|
||||
return network->getGraph();
|
||||
}
|
||||
struct TestExecutableNetworkBase : public InferenceEngine::ExecutableNetworkBase {
|
||||
using InferenceEngine::ExecutableNetworkBase::_impl;
|
||||
~TestExecutableNetworkBase() override = default;
|
||||
};
|
||||
|
||||
static MKLDNNPlugin::MKLDNNGraph& getGraph(InferenceEngine::IExecutableNetwork::Ptr execNetwork) {
|
||||
return reinterpret_cast<MKLDNNTestExecNetwork*>(
|
||||
reinterpret_cast<TestExecutableNetworkBase*>(
|
||||
execNetwork.get())->_impl.get())->getGraph();
|
||||
}
|
||||
|
||||
class MKLDNNGraphLeaksTests: public ::testing::Test {
|
||||
protected:
|
||||
void addOutputToEachNode(InferenceEngine::CNNNetwork& network, std::vector<std::string>& new_outputs,
|
||||
@ -257,11 +253,11 @@ TEST_F(MKLDNNGraphLeaksTests, MKLDNN_not_release_outputs_fp32) {
|
||||
|
||||
ASSERT_NE(1, network.getOutputsInfo().size());
|
||||
|
||||
std::shared_ptr<MKLDNNTestEngine> score_engine(new MKLDNNTestEngine());
|
||||
std::shared_ptr<MKLDNNPlugin::Engine> score_engine(new MKLDNNPlugin::Engine());
|
||||
InferenceEngine::ExecutableNetwork exeNetwork1;
|
||||
ASSERT_NO_THROW(exeNetwork1 = score_engine->LoadNetwork(network, {}));
|
||||
|
||||
size_t modified_outputs_size = score_engine->getGraph(exeNetwork1).GetOutputNodes().size();
|
||||
size_t modified_outputs_size = getGraph(exeNetwork1).GetOutputNodes().size();
|
||||
|
||||
InferenceEngine::CNNNetwork network2;
|
||||
ASSERT_NO_THROW(network2 = core.ReadNetwork(model, weights_ptr));
|
||||
@ -270,10 +266,12 @@ TEST_F(MKLDNNGraphLeaksTests, MKLDNN_not_release_outputs_fp32) {
|
||||
InferenceEngine::ExecutableNetwork exeNetwork2;
|
||||
ASSERT_NO_THROW(exeNetwork2 = score_engine->LoadNetwork(network2, {}));
|
||||
|
||||
size_t original_outputs_size = score_engine->getGraph(exeNetwork2).GetOutputNodes().size();
|
||||
size_t original_outputs_size = getGraph(exeNetwork2).GetOutputNodes().size();
|
||||
|
||||
ASSERT_NE(modified_outputs_size, original_outputs_size);
|
||||
ASSERT_EQ(1, original_outputs_size);
|
||||
} catch (std::exception& e) {
|
||||
FAIL() << e.what();
|
||||
} catch (...) {
|
||||
FAIL();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user