Simplified cancel behavior (#4224)

This commit is contained in:
Anton Pankratv
2021-02-09 12:55:24 +03:00
committed by GitHub
parent 751ac1aef4
commit 91dcb515a3
14 changed files with 42 additions and 83 deletions

View File

@@ -131,13 +131,6 @@ void TemplateInferRequest::InferImpl() {
} }
// ! [infer_request:infer_impl] // ! [infer_request:infer_impl]
// ! [infer_request:cancel]
InferenceEngine::StatusCode TemplateInferRequest::Cancel() {
// TODO: add code to handle cancellation request
return InferenceEngine::OK;
}
// ! [infer_request:cancel]
template<typename SrcT, typename DstT> template<typename SrcT, typename DstT>
static void blobCopy(const Blob::Ptr& src, const Blob::Ptr& dst) { static void blobCopy(const Blob::Ptr& src, const Blob::Ptr& dst) {
std::copy_n(InferenceEngine::as<InferenceEngine::MemoryBlob>(src)->rmap().as<const SrcT*>(), std::copy_n(InferenceEngine::as<InferenceEngine::MemoryBlob>(src)->rmap().as<const SrcT*>(),

View File

@@ -40,8 +40,6 @@ public:
void InferImpl() override; void InferImpl() override;
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override; std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
InferenceEngine::StatusCode Cancel() override;
// pipeline methods-stages which are used in async infer request implementation and assigned to particular executor // pipeline methods-stages which are used in async infer request implementation and assigned to particular executor
void inferPreprocess(); void inferPreprocess();
void startPipeline(); void startPipeline();

View File

@@ -162,10 +162,13 @@ public:
CALL_STATUS_FNC_NO_ARGS(Infer); CALL_STATUS_FNC_NO_ARGS(Infer);
} }
StatusCode Cancel() { /**
ResponseDesc resp; * @copybrief IInferRequest::Cancel
if (actual == nullptr) THROW_IE_EXCEPTION << "InferRequest was not initialized."; *
return actual->Cancel(&resp); * Wraps IInferRequest::Cancel
*/
void Cancel() {
CALL_STATUS_FNC_NO_ARGS(Cancel);
} }
/** /**

View File

@@ -60,6 +60,8 @@ inline void extract_exception(StatusCode status, const char* msg) {
throw InferNotStarted(msg); throw InferNotStarted(msg);
case NETWORK_NOT_READ: case NETWORK_NOT_READ:
throw NetworkNotRead(msg); throw NetworkNotRead(msg);
case INFER_CANCELLED:
throw InferCancelled(msg);
default: default:
THROW_IE_EXCEPTION << msg << InferenceEngine::details::as_status << status; THROW_IE_EXCEPTION << msg << InferenceEngine::details::as_status << status;
} }

View File

@@ -334,6 +334,11 @@ class NetworkNotRead : public std::logic_error {
using std::logic_error::logic_error; using std::logic_error::logic_error;
}; };
/** @brief This class represents StatusCode::INFER_CANCELLED exception */
class InferCancelled : public std::logic_error {
using std::logic_error::logic_error;
};
} // namespace InferenceEngine } // namespace InferenceEngine
#if defined(_WIN32) #if defined(_WIN32)

View File

@@ -117,9 +117,5 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
return plg->QueryState(); return plg->QueryState();
} }
IE_SUPPRESS_DEPRECATED_END IE_SUPPRESS_DEPRECATED_END
InferenceEngine::StatusCode Cancel() override {
return InferenceEngine::NOT_IMPLEMENTED;
}
}; };
} // namespace GNAPluginNS } // namespace GNAPluginNS

View File

@@ -38,7 +38,7 @@ public:
StatusCode Cancel(ResponseDesc* resp) noexcept override { StatusCode Cancel(ResponseDesc* resp) noexcept override {
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Cancel"); OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Cancel");
NO_EXCEPT_CALL_RETURN_STATUS(_impl->Cancel()); TO_STATUS(_impl->Cancel());
} }
StatusCode GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo>& perfMap, StatusCode GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo>& perfMap,

View File

@@ -281,13 +281,10 @@ public:
} }
} }
StatusCode Cancel() override { void Cancel() override {
std::lock_guard<std::mutex> lock{_mutex}; std::lock_guard<std::mutex> lock{_mutex};
if (_state == InferState::Idle) { if (_state == InferState::Busy) {
return StatusCode::INFER_NOT_STARTED;
} else {
_state = InferState::Canceled; _state = InferState::Canceled;
return InferenceEngine::OK;
} }
} }

View File

@@ -66,8 +66,8 @@ public:
/** /**
* @brief Default common implementation for all plugins * @brief Default common implementation for all plugins
*/ */
StatusCode Cancel() override { void Cancel() override {
return InferenceEngine::NOT_IMPLEMENTED; THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED);
} }
/** /**

View File

@@ -42,7 +42,7 @@ public:
/** /**
* @brief Cancel current inference request execution * @brief Cancel current inference request execution
*/ */
virtual StatusCode Cancel() = 0; virtual void Cancel() = 0;
/** /**
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer. * @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.

View File

@@ -41,31 +41,13 @@ TEST_P(CancellationTests, canCancelAsyncRequest) {
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration); auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req = execNet.CreateInferRequest(); InferenceEngine::InferRequest req = execNet.CreateInferRequest();
bool cancelled = false;
req.SetCompletionCallback<std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>>(
[&](InferenceEngine::InferRequest request, InferenceEngine::StatusCode status) {
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
cancelled = (status == InferenceEngine::StatusCode::INFER_CANCELLED);
}
});
req.StartAsync(); req.StartAsync();
InferenceEngine::StatusCode cancelStatus = req.Cancel();
InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
if (targetDevice == CommonTestUtils::DEVICE_CPU) { ASSERT_NO_THROW(req.Cancel());
ASSERT_EQ(true, cancelStatus == InferenceEngine::StatusCode::OK || try {
cancelStatus == InferenceEngine::StatusCode::INFER_NOT_STARTED); req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
if (cancelStatus == InferenceEngine::StatusCode::OK) { } catch (const InferenceEngine::InferCancelled& ex) {
ASSERT_EQ(true, cancelled); SUCCEED();
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_CANCELLED), waitStatus);
} else {
ASSERT_EQ(false, cancelled);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
}
} else {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
} }
} }
@@ -79,14 +61,16 @@ TEST_P(CancellationTests, canResetAfterCancelAsyncRequest) {
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req = execNet.CreateInferRequest(); InferenceEngine::InferRequest req = execNet.CreateInferRequest();
req.StartAsync(); ASSERT_NO_THROW(req.StartAsync());
req.Cancel(); ASSERT_NO_THROW(req.Cancel());
req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY); try {
req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
} catch (const InferenceEngine::InferCancelled& ex) {
SUCCEED();
}
req.StartAsync(); ASSERT_NO_THROW(req.StartAsync());
InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY); ASSERT_NO_THROW(req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY));
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
} }
TEST_P(CancellationTests, canCancelBeforeAsyncRequest) { TEST_P(CancellationTests, canCancelBeforeAsyncRequest) {
@@ -99,13 +83,7 @@ TEST_P(CancellationTests, canCancelBeforeAsyncRequest) {
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req = execNet.CreateInferRequest(); InferenceEngine::InferRequest req = execNet.CreateInferRequest();
InferenceEngine::StatusCode cancelStatus = req.Cancel(); ASSERT_NO_THROW(req.Cancel());
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_NOT_STARTED), cancelStatus);
} else {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
}
} }
TEST_P(CancellationTests, canCancelInferRequest) { TEST_P(CancellationTests, canCancelInferRequest) {
@@ -126,24 +104,11 @@ TEST_P(CancellationTests, canCancelInferRequest) {
while (req.Wait(statusOnly) == InferenceEngine::StatusCode::INFER_NOT_STARTED) { while (req.Wait(statusOnly) == InferenceEngine::StatusCode::INFER_NOT_STARTED) {
} }
InferenceEngine::StatusCode cancelStatus = req.Cancel(); ASSERT_NO_THROW(req.Cancel());
InferenceEngine::StatusCode inferStatus = InferenceEngine::StatusCode::OK;
try { try {
infer.get(); infer.get();
} catch (InferenceEngine::details::InferenceEngineException& ex) { } catch (const InferenceEngine::InferCancelled& ex) {
inferStatus = ex.getStatus(); SUCCEED();
}
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
if (cancelStatus == InferenceEngine::StatusCode::OK) {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_CANCELLED), inferStatus);
} else {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), inferStatus);
}
} else {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), inferStatus);
} }
} }
} // namespace BehaviorTestsDefinitions } // namespace BehaviorTestsDefinitions

View File

@@ -32,6 +32,6 @@ public:
MOCK_METHOD1(setNetworkOutputs, void(OutputsDataMap)); MOCK_METHOD1(setNetworkOutputs, void(OutputsDataMap));
MOCK_METHOD1(GetBlob, Blob::Ptr(const std::string&)); MOCK_METHOD1(GetBlob, Blob::Ptr(const std::string&));
MOCK_METHOD1(SetCompletionCallback, void(IInferRequest::CompletionCallback)); MOCK_METHOD1(SetCompletionCallback, void(IInferRequest::CompletionCallback));
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode()); MOCK_METHOD0(Cancel, void());
MOCK_METHOD0(Cancel_ThreadUnsafe, InferenceEngine::StatusCode()); MOCK_METHOD0(Cancel_ThreadUnsafe, void());
}; };

View File

@@ -23,5 +23,5 @@ public:
MOCK_METHOD0(InferImpl, void()); MOCK_METHOD0(InferImpl, void());
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngineProfileInfo>()); MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngineProfileInfo>());
MOCK_METHOD0(checkBlobs, void()); MOCK_METHOD0(checkBlobs, void());
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode()); MOCK_METHOD0(Cancel, void());
}; };

View File

@@ -28,5 +28,5 @@ public:
MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback)); MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback));
MOCK_METHOD1(SetBatch, void(int)); MOCK_METHOD1(SetBatch, void(int));
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>()); MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>());
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode()); MOCK_METHOD0(Cancel, void());
}; };