Removed template from base (#4045)

Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
This commit is contained in:
Anton Pankratv 2021-02-04 17:43:20 +03:00 committed by GitHub
parent e80e5e7ae5
commit 945da5ff4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 46 additions and 60 deletions

View File

@ -143,7 +143,7 @@ InferenceEngine::IInferRequest::Ptr TemplatePlugin::ExecutableNetwork::CreateInf
auto internalRequest = CreateInferRequestImpl(_networkInputs, _networkOutputs);
auto asyncThreadSafeImpl = std::make_shared<TemplateAsyncInferRequest>(std::static_pointer_cast<TemplateInferRequest>(internalRequest),
_taskExecutor, _plugin->_waitExecutor, _callbackExecutor);
asyncRequest.reset(new InferenceEngine::InferRequestBase<TemplateAsyncInferRequest>(asyncThreadSafeImpl),
asyncRequest.reset(new InferenceEngine::InferRequestBase(asyncThreadSafeImpl),
[](InferenceEngine::IInferRequest *p) { p->Release(); });
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
return asyncRequest;

View File

@ -201,7 +201,7 @@ IInferRequest::Ptr MultiDeviceExecutableNetwork::CreateInferRequest() {
_needPerfCounters,
std::static_pointer_cast<MultiDeviceExecutableNetwork>(shared_from_this()),
_callbackExecutor);
asyncRequest.reset(new InferRequestBase<MultiDeviceAsyncInferRequest>(asyncTreadSafeImpl), [](IInferRequest *p) { p->Release(); });
asyncRequest.reset(new InferRequestBase(asyncTreadSafeImpl), [](IInferRequest *p) { p->Release(); });
asyncTreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
return asyncRequest;
}

View File

@ -12,6 +12,7 @@
#include <cpp/ie_executable_network.hpp>
#include <cpp_interfaces/base/ie_variable_state_base.hpp>
#include <cpp_interfaces/interface/ie_ivariable_state_internal.hpp>
#include <cpp_interfaces/interface/ie_iexecutable_network_internal.hpp>
#include <map>
#include <memory>
#include <string>
@ -24,18 +25,17 @@ namespace InferenceEngine {
/**
* @brief Executable network `noexcept` wrapper which accepts IExecutableNetworkInternal derived instance which can throw exceptions
* @ingroup ie_dev_api_exec_network_api
* @tparam T Minimal CPP implementation of IExecutableNetworkInternal (e.g. ExecutableNetworkInternal)
*/
template <class T>
*/
class ExecutableNetworkBase : public IExecutableNetwork {
std::shared_ptr<T> _impl;
protected:
std::shared_ptr<IExecutableNetworkInternal> _impl;
public:
/**
* @brief Constructor with actual underlying implementation.
* @param impl Underlying implementation of type IExecutableNetworkInternal
*/
explicit ExecutableNetworkBase(std::shared_ptr<T> impl) {
explicit ExecutableNetworkBase(std::shared_ptr<IExecutableNetworkInternal> impl) {
if (impl.get() == nullptr) {
THROW_IE_EXCEPTION << "implementation not defined";
}
@ -77,7 +77,7 @@ public:
if (idx >= v.size()) {
return OUT_OF_BOUNDS;
}
pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
pState = std::make_shared<VariableStateBase>(v[idx]);
return OK;
} catch (const std::exception& ex) {
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();
@ -91,11 +91,6 @@ public:
delete this;
}
/// @private Need for unit tests only - TODO: unit tests should test using public API, non having details
const std::shared_ptr<T> getImpl() const {
return _impl;
}
StatusCode SetConfig(const std::map<std::string, Parameter>& config, ResponseDesc* resp) noexcept override {
TO_STATUS(_impl->SetConfig(config));
}
@ -112,8 +107,8 @@ public:
TO_STATUS(pContext = _impl->GetContext());
}
private:
~ExecutableNetworkBase() = default;
protected:
~ExecutableNetworkBase() override = default;
};
/**
@ -127,7 +122,7 @@ template <class T>
inline typename InferenceEngine::ExecutableNetwork make_executable_network(std::shared_ptr<T> impl) {
// to suppress warning about deprecated QueryState
IE_SUPPRESS_DEPRECATED_START
typename ExecutableNetworkBase<T>::Ptr net(new ExecutableNetworkBase<T>(impl), [](IExecutableNetwork* p) {
typename ExecutableNetworkBase::Ptr net(new ExecutableNetworkBase(impl), [](IExecutableNetwork* p) {
p->Release();
});
IE_SUPPRESS_DEPRECATED_END

View File

@ -11,6 +11,7 @@
#include "cpp_interfaces/exception2status.hpp"
#include "cpp_interfaces/plugin_itt.hpp"
#include <cpp_interfaces/base/ie_variable_state_base.hpp>
#include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
#include "ie_iinfer_request.hpp"
#include "ie_preprocess.hpp"
@ -19,18 +20,16 @@ namespace InferenceEngine {
/**
* @brief Inference request `noexcept` wrapper which accepts IAsyncInferRequestInternal derived instance which can throw exceptions
* @ingroup ie_dev_api_async_infer_request_api
* @tparam T Minimal CPP implementation of IAsyncInferRequestInternal (e.g. AsyncInferRequestThreadSafeDefault)
*/
template <class T>
class InferRequestBase : public IInferRequest {
std::shared_ptr<T> _impl;
std::shared_ptr<IAsyncInferRequestInternal> _impl;
public:
/**
* @brief Constructor with actual underlying implementation.
* @param impl Underlying implementation of type IAsyncInferRequestInternal
*/
explicit InferRequestBase(std::shared_ptr<T> impl): _impl(impl) {}
explicit InferRequestBase(std::shared_ptr<IAsyncInferRequestInternal> impl): _impl(impl) {}
StatusCode Infer(ResponseDesc* resp) noexcept override {
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Infer");
@ -100,7 +99,7 @@ public:
if (idx >= v.size()) {
return OUT_OF_BOUNDS;
}
pState = std::make_shared<VariableStateBase<IVariableStateInternal>>(v[idx]);
pState = std::make_shared<VariableStateBase>(v[idx]);
return OK;
} catch (const std::exception& ex) {
return InferenceEngine::DescriptionBuffer(GENERAL_ERROR, resp) << ex.what();

View File

@ -16,19 +16,17 @@ IE_SUPPRESS_DEPRECATED_START
/**
* @brief Default implementation for IVariableState
* @tparam T Minimal CPP implementation of IVariableStateInternal (e.g. VariableStateInternal)
* @ingroup ie_dev_api_variable_state_api
* @ingroup ie_dev_api_variable_state_api
*/
template <class T>
class VariableStateBase : public IVariableState {
std::shared_ptr<T> impl;
std::shared_ptr<IVariableStateInternal> impl;
public:
/**
* @brief Constructor with actual underlying implementation.
* @param impl Underlying implementation of type IVariableStateInternal
*/
explicit VariableStateBase(std::shared_ptr<T> impl): impl(impl) {
explicit VariableStateBase(std::shared_ptr<IVariableStateInternal> impl): impl(impl) {
if (impl == nullptr) {
THROW_IE_EXCEPTION << "VariableStateBase implementation is not defined";
}

View File

@ -39,7 +39,7 @@ public:
auto asyncRequestImpl = this->CreateAsyncInferRequestImpl(_networkInputs, _networkOutputs);
asyncRequestImpl->setPointerToExecutableNetworkInternal(shared_from_this());
asyncRequest.reset(new InferRequestBase<AsyncInferRequestInternal>(asyncRequestImpl), [](IInferRequest* p) {
asyncRequest.reset(new InferRequestBase(asyncRequestImpl), [](IInferRequest* p) {
p->Release();
});
asyncRequestImpl->SetPointerToPublicInterface(asyncRequest);

View File

@ -69,7 +69,7 @@ protected:
auto asyncThreadSafeImpl = std::make_shared<AsyncInferRequestType>(
syncRequestImpl, _taskExecutor, _callbackExecutor);
asyncRequest.reset(new InferRequestBase<AsyncInferRequestType>(asyncThreadSafeImpl),
asyncRequest.reset(new InferRequestBase(asyncThreadSafeImpl),
[](IInferRequest *p) { p->Release(); });
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);

View File

@ -88,8 +88,7 @@ public:
auto taskExecutorGetResult = getNextTaskExecutor();
auto asyncThreadSafeImpl = std::make_shared<MyriadAsyncInferRequest>(
syncRequestImpl, _taskExecutor, _callbackExecutor, taskExecutorGetResult);
asyncRequest.reset(new ie::InferRequestBase<ie::AsyncInferRequestThreadSafeDefault>(
asyncThreadSafeImpl),
asyncRequest.reset(new ie::InferRequestBase(asyncThreadSafeImpl),
[](ie::IInferRequest *p) { p->Release(); });
asyncThreadSafeImpl->SetPointerToPublicInterface(asyncRequest);
return asyncRequest;

View File

@ -34,7 +34,7 @@ protected:
virtual void SetUp() {
mockExeNetwork = make_shared<MockExecutableNetworkThreadSafeAsyncOnly>();
exeNetwork = details::shared_from_irelease(
new ExecutableNetworkBase<MockExecutableNetworkThreadSafeAsyncOnly>(mockExeNetwork));
new ExecutableNetworkBase(mockExeNetwork));
InputsDataMap networkInputs;
OutputsDataMap networkOutputs;
mockAsyncInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(networkInputs, networkOutputs);
@ -46,7 +46,7 @@ TEST_F(ExecutableNetworkThreadSafeAsyncOnlyTests, createAsyncInferRequestCallsTh
EXPECT_CALL(*mockExeNetwork.get(), CreateAsyncInferRequestImpl(_, _)).WillOnce(
Return(mockAsyncInferRequestInternal));
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase<AsyncInferRequestInternal>>(req);
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase>(req);
ASSERT_NE(threadSafeReq, nullptr);
}
@ -109,7 +109,7 @@ protected:
virtual void SetUp() {
mockExeNetwork = make_shared<MockExecutableNetworkThreadSafe>();
exeNetwork = details::shared_from_irelease(
new ExecutableNetworkBase<MockExecutableNetworkThreadSafe>(mockExeNetwork));
new ExecutableNetworkBase(mockExeNetwork));
InputsDataMap networkInputs;
OutputsDataMap networkOutputs;
mockInferRequestInternal = make_shared<MockInferRequestInternal>(networkInputs, networkOutputs);
@ -120,7 +120,7 @@ TEST_F(ExecutableNetworkThreadSafeTests, createInferRequestCallsThreadSafeImplAn
IInferRequest::Ptr req;
EXPECT_CALL(*mockExeNetwork.get(), CreateInferRequestImpl(_, _)).WillOnce(Return(mockInferRequestInternal));
EXPECT_NO_THROW(exeNetwork->CreateInferRequest(req, &dsc));
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase<AsyncInferRequestThreadSafeDefault>>(req);
auto threadSafeReq = dynamic_pointer_cast<InferRequestBase>(req);
ASSERT_NE(threadSafeReq, nullptr);
}

View File

@ -34,7 +34,7 @@ protected:
virtual void SetUp() {
mock_impl.reset(new MockIAsyncInferRequestInternal());
request = details::shared_from_irelease(new InferRequestBase<MockIAsyncInferRequestInternal>(mock_impl));
request = details::shared_from_irelease(new InferRequestBase(mock_impl));
}
};
@ -243,7 +243,7 @@ protected:
mockNotEmptyNet.getOutputsInfo(outputsInfo);
mockInferRequestInternal = make_shared<MockAsyncInferRequestInternal>(inputsInfo, outputsInfo);
inferRequest = shared_from_irelease(
new InferRequestBase<MockAsyncInferRequestInternal>(mockInferRequestInternal));
new InferRequestBase(mockInferRequestInternal));
return make_shared<InferRequest>(inferRequest);
}

View File

@ -198,8 +198,7 @@ TEST_F(InferRequestThreadSafeDefaultTests, callbackTakesOKIfAsyncRequestWasOK) {
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
IInferRequest::Ptr asyncRequest;
asyncRequest.reset(new InferRequestBase<AsyncInferRequestThreadSafeDefault>(testRequest),
[](IInferRequest *p) { p->Release(); });
asyncRequest.reset(new InferRequestBase(testRequest), [](IInferRequest *p) { p->Release(); });
testRequest->SetPointerToPublicInterface(asyncRequest);
testRequest->SetCompletionCallback([](InferenceEngine::IInferRequest::Ptr request, StatusCode status) {
@ -215,8 +214,7 @@ TEST_F(InferRequestThreadSafeDefaultTests, callbackIsCalledIfAsyncRequestFailed)
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
IInferRequest::Ptr asyncRequest;
asyncRequest.reset(new InferRequestBase<AsyncInferRequestThreadSafeDefault>(testRequest),
[](IInferRequest *p) { p->Release(); });
asyncRequest.reset(new InferRequestBase(testRequest), [](IInferRequest *p) { p->Release(); });
testRequest->SetPointerToPublicInterface(asyncRequest);
bool wasCalled = false;
@ -238,8 +236,7 @@ TEST_F(InferRequestThreadSafeDefaultTests, canCatchExceptionIfAsyncRequestFailed
auto taskExecutor = std::make_shared<CPUStreamsExecutor>();
testRequest = make_shared<AsyncInferRequestThreadSafeDefault>(mockInferRequestInternal, taskExecutor, taskExecutor);
IInferRequest::Ptr asyncRequest;
asyncRequest.reset(new InferRequestBase<AsyncInferRequestThreadSafeDefault>(testRequest),
[](IInferRequest *p) { p->Release(); });
asyncRequest.reset(new InferRequestBase(testRequest), [](IInferRequest *p) { p->Release(); });
testRequest->SetPointerToPublicInterface(asyncRequest);
EXPECT_CALL(*mockInferRequestInternal.get(), InferImpl()).WillOnce(Throw(std::exception()));

View File

@ -20,7 +20,7 @@ using namespace InferenceEngine::details;
template <class T>
inline typename InferenceEngine::InferRequest make_infer_request(std::shared_ptr<T> impl) {
typename InferRequestBase<T>::Ptr req(new InferRequestBase<T>(impl), [](IInferRequest* p) {
typename InferRequestBase::Ptr req(new InferRequestBase(impl), [](IInferRequest* p) {
p->Release();
});
return InferenceEngine::InferRequest(req);

View File

@ -223,7 +223,7 @@ protected:
virtual void SetUp() {
mock_impl.reset(new MockIExecutableNetworkInternal());
exeNetwork = shared_from_irelease(new ExecutableNetworkBase<MockIExecutableNetworkInternal>(mock_impl));
exeNetwork = shared_from_irelease(new ExecutableNetworkBase(mock_impl));
}
};

View File

@ -20,21 +20,17 @@ public:
}
};
class MKLDNNTestEngine: public MKLDNNPlugin::Engine {
public:
MKLDNNPlugin::MKLDNNGraph& getGraph(InferenceEngine::IExecutableNetwork::Ptr execNetwork) {
auto * execNetworkInt =
dynamic_cast<InferenceEngine::ExecutableNetworkBase<InferenceEngine::ExecutableNetworkInternal> *>(execNetwork.get());
if (!execNetworkInt)
THROW_IE_EXCEPTION << "Cannot find loaded network!";
auto * network = reinterpret_cast<MKLDNNTestExecNetwork *>(execNetworkInt->getImpl().get());
if (!network)
THROW_IE_EXCEPTION << "Cannot get mkldnn graph!";
return network->getGraph();
}
struct TestExecutableNetworkBase : public InferenceEngine::ExecutableNetworkBase {
using InferenceEngine::ExecutableNetworkBase::_impl;
~TestExecutableNetworkBase() override = default;
};
static MKLDNNPlugin::MKLDNNGraph& getGraph(InferenceEngine::IExecutableNetwork::Ptr execNetwork) {
return reinterpret_cast<MKLDNNTestExecNetwork*>(
reinterpret_cast<TestExecutableNetworkBase*>(
execNetwork.get())->_impl.get())->getGraph();
}
class MKLDNNGraphLeaksTests: public ::testing::Test {
protected:
void addOutputToEachNode(InferenceEngine::CNNNetwork& network, std::vector<std::string>& new_outputs,
@ -257,11 +253,11 @@ TEST_F(MKLDNNGraphLeaksTests, MKLDNN_not_release_outputs_fp32) {
ASSERT_NE(1, network.getOutputsInfo().size());
std::shared_ptr<MKLDNNTestEngine> score_engine(new MKLDNNTestEngine());
std::shared_ptr<MKLDNNPlugin::Engine> score_engine(new MKLDNNPlugin::Engine());
InferenceEngine::ExecutableNetwork exeNetwork1;
ASSERT_NO_THROW(exeNetwork1 = score_engine->LoadNetwork(network, {}));
size_t modified_outputs_size = score_engine->getGraph(exeNetwork1).GetOutputNodes().size();
size_t modified_outputs_size = getGraph(exeNetwork1).GetOutputNodes().size();
InferenceEngine::CNNNetwork network2;
ASSERT_NO_THROW(network2 = core.ReadNetwork(model, weights_ptr));
@ -270,10 +266,12 @@ TEST_F(MKLDNNGraphLeaksTests, MKLDNN_not_release_outputs_fp32) {
InferenceEngine::ExecutableNetwork exeNetwork2;
ASSERT_NO_THROW(exeNetwork2 = score_engine->LoadNetwork(network2, {}));
size_t original_outputs_size = score_engine->getGraph(exeNetwork2).GetOutputNodes().size();
size_t original_outputs_size = getGraph(exeNetwork2).GetOutputNodes().size();
ASSERT_NE(modified_outputs_size, original_outputs_size);
ASSERT_EQ(1, original_outputs_size);
} catch (std::exception& e) {
FAIL() << e.what();
} catch (...) {
FAIL();
}