Simplified cancel behavior (#4224)
This commit is contained in:
@@ -131,13 +131,6 @@ void TemplateInferRequest::InferImpl() {
|
|||||||
}
|
}
|
||||||
// ! [infer_request:infer_impl]
|
// ! [infer_request:infer_impl]
|
||||||
|
|
||||||
// ! [infer_request:cancel]
|
|
||||||
InferenceEngine::StatusCode TemplateInferRequest::Cancel() {
|
|
||||||
// TODO: add code to handle cancellation request
|
|
||||||
return InferenceEngine::OK;
|
|
||||||
}
|
|
||||||
// ! [infer_request:cancel]
|
|
||||||
|
|
||||||
template<typename SrcT, typename DstT>
|
template<typename SrcT, typename DstT>
|
||||||
static void blobCopy(const Blob::Ptr& src, const Blob::Ptr& dst) {
|
static void blobCopy(const Blob::Ptr& src, const Blob::Ptr& dst) {
|
||||||
std::copy_n(InferenceEngine::as<InferenceEngine::MemoryBlob>(src)->rmap().as<const SrcT*>(),
|
std::copy_n(InferenceEngine::as<InferenceEngine::MemoryBlob>(src)->rmap().as<const SrcT*>(),
|
||||||
|
|||||||
@@ -40,8 +40,6 @@ public:
|
|||||||
void InferImpl() override;
|
void InferImpl() override;
|
||||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
|
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
|
||||||
|
|
||||||
InferenceEngine::StatusCode Cancel() override;
|
|
||||||
|
|
||||||
// pipeline methods-stages which are used in async infer request implementation and assigned to particular executor
|
// pipeline methods-stages which are used in async infer request implementation and assigned to particular executor
|
||||||
void inferPreprocess();
|
void inferPreprocess();
|
||||||
void startPipeline();
|
void startPipeline();
|
||||||
|
|||||||
@@ -162,10 +162,13 @@ public:
|
|||||||
CALL_STATUS_FNC_NO_ARGS(Infer);
|
CALL_STATUS_FNC_NO_ARGS(Infer);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusCode Cancel() {
|
/**
|
||||||
ResponseDesc resp;
|
* @copybrief IInferRequest::Cancel
|
||||||
if (actual == nullptr) THROW_IE_EXCEPTION << "InferRequest was not initialized.";
|
*
|
||||||
return actual->Cancel(&resp);
|
* Wraps IInferRequest::Cancel
|
||||||
|
*/
|
||||||
|
void Cancel() {
|
||||||
|
CALL_STATUS_FNC_NO_ARGS(Cancel);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -60,6 +60,8 @@ inline void extract_exception(StatusCode status, const char* msg) {
|
|||||||
throw InferNotStarted(msg);
|
throw InferNotStarted(msg);
|
||||||
case NETWORK_NOT_READ:
|
case NETWORK_NOT_READ:
|
||||||
throw NetworkNotRead(msg);
|
throw NetworkNotRead(msg);
|
||||||
|
case INFER_CANCELLED:
|
||||||
|
throw InferCancelled(msg);
|
||||||
default:
|
default:
|
||||||
THROW_IE_EXCEPTION << msg << InferenceEngine::details::as_status << status;
|
THROW_IE_EXCEPTION << msg << InferenceEngine::details::as_status << status;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -334,6 +334,11 @@ class NetworkNotRead : public std::logic_error {
|
|||||||
using std::logic_error::logic_error;
|
using std::logic_error::logic_error;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** @brief This class represents StatusCode::INFER_CANCELLED exception */
|
||||||
|
class InferCancelled : public std::logic_error {
|
||||||
|
using std::logic_error::logic_error;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace InferenceEngine
|
} // namespace InferenceEngine
|
||||||
|
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
|
|||||||
@@ -117,9 +117,5 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
|
|||||||
return plg->QueryState();
|
return plg->QueryState();
|
||||||
}
|
}
|
||||||
IE_SUPPRESS_DEPRECATED_END
|
IE_SUPPRESS_DEPRECATED_END
|
||||||
|
|
||||||
InferenceEngine::StatusCode Cancel() override {
|
|
||||||
return InferenceEngine::NOT_IMPLEMENTED;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
} // namespace GNAPluginNS
|
} // namespace GNAPluginNS
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ public:
|
|||||||
|
|
||||||
StatusCode Cancel(ResponseDesc* resp) noexcept override {
|
StatusCode Cancel(ResponseDesc* resp) noexcept override {
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Cancel");
|
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Cancel");
|
||||||
NO_EXCEPT_CALL_RETURN_STATUS(_impl->Cancel());
|
TO_STATUS(_impl->Cancel());
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusCode GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo>& perfMap,
|
StatusCode GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo>& perfMap,
|
||||||
|
|||||||
@@ -281,13 +281,10 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusCode Cancel() override {
|
void Cancel() override {
|
||||||
std::lock_guard<std::mutex> lock{_mutex};
|
std::lock_guard<std::mutex> lock{_mutex};
|
||||||
if (_state == InferState::Idle) {
|
if (_state == InferState::Busy) {
|
||||||
return StatusCode::INFER_NOT_STARTED;
|
|
||||||
} else {
|
|
||||||
_state = InferState::Canceled;
|
_state = InferState::Canceled;
|
||||||
return InferenceEngine::OK;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -66,8 +66,8 @@ public:
|
|||||||
/**
|
/**
|
||||||
* @brief Default common implementation for all plugins
|
* @brief Default common implementation for all plugins
|
||||||
*/
|
*/
|
||||||
StatusCode Cancel() override {
|
void Cancel() override {
|
||||||
return InferenceEngine::NOT_IMPLEMENTED;
|
THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ public:
|
|||||||
/**
|
/**
|
||||||
* @brief Cancel current inference request execution
|
* @brief Cancel current inference request execution
|
||||||
*/
|
*/
|
||||||
virtual StatusCode Cancel() = 0;
|
virtual void Cancel() = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.
|
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.
|
||||||
|
|||||||
@@ -41,31 +41,13 @@ TEST_P(CancellationTests, canCancelAsyncRequest) {
|
|||||||
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
|
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
|
||||||
// Create InferRequest
|
// Create InferRequest
|
||||||
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
|
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
|
||||||
bool cancelled = false;
|
|
||||||
req.SetCompletionCallback<std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>>(
|
|
||||||
[&](InferenceEngine::InferRequest request, InferenceEngine::StatusCode status) {
|
|
||||||
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
|
|
||||||
cancelled = (status == InferenceEngine::StatusCode::INFER_CANCELLED);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
req.StartAsync();
|
req.StartAsync();
|
||||||
InferenceEngine::StatusCode cancelStatus = req.Cancel();
|
|
||||||
InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
|
||||||
|
|
||||||
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
|
ASSERT_NO_THROW(req.Cancel());
|
||||||
ASSERT_EQ(true, cancelStatus == InferenceEngine::StatusCode::OK ||
|
try {
|
||||||
cancelStatus == InferenceEngine::StatusCode::INFER_NOT_STARTED);
|
req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
||||||
if (cancelStatus == InferenceEngine::StatusCode::OK) {
|
} catch (const InferenceEngine::InferCancelled& ex) {
|
||||||
ASSERT_EQ(true, cancelled);
|
SUCCEED();
|
||||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_CANCELLED), waitStatus);
|
|
||||||
} else {
|
|
||||||
ASSERT_EQ(false, cancelled);
|
|
||||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
|
|
||||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,14 +61,16 @@ TEST_P(CancellationTests, canResetAfterCancelAsyncRequest) {
|
|||||||
// Create InferRequest
|
// Create InferRequest
|
||||||
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
|
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
|
||||||
|
|
||||||
req.StartAsync();
|
ASSERT_NO_THROW(req.StartAsync());
|
||||||
req.Cancel();
|
ASSERT_NO_THROW(req.Cancel());
|
||||||
req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
try {
|
||||||
|
req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
||||||
|
} catch (const InferenceEngine::InferCancelled& ex) {
|
||||||
|
SUCCEED();
|
||||||
|
}
|
||||||
|
|
||||||
req.StartAsync();
|
ASSERT_NO_THROW(req.StartAsync());
|
||||||
InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
ASSERT_NO_THROW(req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY));
|
||||||
|
|
||||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(CancellationTests, canCancelBeforeAsyncRequest) {
|
TEST_P(CancellationTests, canCancelBeforeAsyncRequest) {
|
||||||
@@ -99,13 +83,7 @@ TEST_P(CancellationTests, canCancelBeforeAsyncRequest) {
|
|||||||
// Create InferRequest
|
// Create InferRequest
|
||||||
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
|
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
|
||||||
|
|
||||||
InferenceEngine::StatusCode cancelStatus = req.Cancel();
|
ASSERT_NO_THROW(req.Cancel());
|
||||||
|
|
||||||
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
|
|
||||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_NOT_STARTED), cancelStatus);
|
|
||||||
} else {
|
|
||||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(CancellationTests, canCancelInferRequest) {
|
TEST_P(CancellationTests, canCancelInferRequest) {
|
||||||
@@ -126,24 +104,11 @@ TEST_P(CancellationTests, canCancelInferRequest) {
|
|||||||
while (req.Wait(statusOnly) == InferenceEngine::StatusCode::INFER_NOT_STARTED) {
|
while (req.Wait(statusOnly) == InferenceEngine::StatusCode::INFER_NOT_STARTED) {
|
||||||
}
|
}
|
||||||
|
|
||||||
InferenceEngine::StatusCode cancelStatus = req.Cancel();
|
ASSERT_NO_THROW(req.Cancel());
|
||||||
InferenceEngine::StatusCode inferStatus = InferenceEngine::StatusCode::OK;
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
infer.get();
|
infer.get();
|
||||||
} catch (InferenceEngine::details::InferenceEngineException& ex) {
|
} catch (const InferenceEngine::InferCancelled& ex) {
|
||||||
inferStatus = ex.getStatus();
|
SUCCEED();
|
||||||
}
|
|
||||||
|
|
||||||
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
|
|
||||||
if (cancelStatus == InferenceEngine::StatusCode::OK) {
|
|
||||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_CANCELLED), inferStatus);
|
|
||||||
} else {
|
|
||||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), inferStatus);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
|
|
||||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), inferStatus);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace BehaviorTestsDefinitions
|
} // namespace BehaviorTestsDefinitions
|
||||||
|
|||||||
@@ -32,6 +32,6 @@ public:
|
|||||||
MOCK_METHOD1(setNetworkOutputs, void(OutputsDataMap));
|
MOCK_METHOD1(setNetworkOutputs, void(OutputsDataMap));
|
||||||
MOCK_METHOD1(GetBlob, Blob::Ptr(const std::string&));
|
MOCK_METHOD1(GetBlob, Blob::Ptr(const std::string&));
|
||||||
MOCK_METHOD1(SetCompletionCallback, void(IInferRequest::CompletionCallback));
|
MOCK_METHOD1(SetCompletionCallback, void(IInferRequest::CompletionCallback));
|
||||||
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
|
MOCK_METHOD0(Cancel, void());
|
||||||
MOCK_METHOD0(Cancel_ThreadUnsafe, InferenceEngine::StatusCode());
|
MOCK_METHOD0(Cancel_ThreadUnsafe, void());
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -23,5 +23,5 @@ public:
|
|||||||
MOCK_METHOD0(InferImpl, void());
|
MOCK_METHOD0(InferImpl, void());
|
||||||
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngineProfileInfo>());
|
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngineProfileInfo>());
|
||||||
MOCK_METHOD0(checkBlobs, void());
|
MOCK_METHOD0(checkBlobs, void());
|
||||||
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
|
MOCK_METHOD0(Cancel, void());
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -28,5 +28,5 @@ public:
|
|||||||
MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback));
|
MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback));
|
||||||
MOCK_METHOD1(SetBatch, void(int));
|
MOCK_METHOD1(SetBatch, void(int));
|
||||||
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>());
|
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>());
|
||||||
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
|
MOCK_METHOD0(Cancel, void());
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user