Simplified cancel behavior (#4224)

This commit is contained in:
Anton Pankratv 2021-02-09 12:55:24 +03:00 committed by GitHub
parent 751ac1aef4
commit 91dcb515a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 42 additions and 83 deletions

View File

@ -131,13 +131,6 @@ void TemplateInferRequest::InferImpl() {
}
// ! [infer_request:infer_impl]
// ! [infer_request:cancel]
/**
 * @brief Cancellation entry point for the template plugin's infer request.
 *
 * TODO: wire up real cancellation of an in-flight pipeline; for now the
 * request unconditionally reports success to the caller.
 */
InferenceEngine::StatusCode TemplateInferRequest::Cancel() {
    // Nothing to abort yet — signal success.
    return InferenceEngine::OK;
}
// ! [infer_request:cancel]
template<typename SrcT, typename DstT>
static void blobCopy(const Blob::Ptr& src, const Blob::Ptr& dst) {
std::copy_n(InferenceEngine::as<InferenceEngine::MemoryBlob>(src)->rmap().as<const SrcT*>(),

View File

@ -40,8 +40,6 @@ public:
void InferImpl() override;
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
InferenceEngine::StatusCode Cancel() override;
// pipeline methods-stages which are used in async infer request implementation and assigned to particular executor
void inferPreprocess();
void startPipeline();

View File

@ -162,10 +162,13 @@ public:
CALL_STATUS_FNC_NO_ARGS(Infer);
}
StatusCode Cancel() {
ResponseDesc resp;
if (actual == nullptr) THROW_IE_EXCEPTION << "InferRequest was not initialized.";
return actual->Cancel(&resp);
/**
* @copybrief IInferRequest::Cancel
*
* Wraps IInferRequest::Cancel
*/
void Cancel() {
CALL_STATUS_FNC_NO_ARGS(Cancel);
}
/**

View File

@ -60,6 +60,8 @@ inline void extract_exception(StatusCode status, const char* msg) {
throw InferNotStarted(msg);
case NETWORK_NOT_READ:
throw NetworkNotRead(msg);
case INFER_CANCELLED:
throw InferCancelled(msg);
default:
THROW_IE_EXCEPTION << msg << InferenceEngine::details::as_status << status;
}

View File

@ -334,6 +334,11 @@ class NetworkNotRead : public std::logic_error {
using std::logic_error::logic_error;
};
/** @brief Exception type reported when an inference request is cancelled (StatusCode::INFER_CANCELLED) */
class InferCancelled : public std::logic_error {
public:
    /// Constructs the exception from an explanatory std::string
    explicit InferCancelled(const std::string& what_arg) : std::logic_error(what_arg) {}
    /// Constructs the exception from an explanatory C string
    explicit InferCancelled(const char* what_arg) : std::logic_error(what_arg) {}
};
} // namespace InferenceEngine
#if defined(_WIN32)

View File

@ -117,9 +117,5 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
return plg->QueryState();
}
IE_SUPPRESS_DEPRECATED_END
InferenceEngine::StatusCode Cancel() override {
return InferenceEngine::NOT_IMPLEMENTED;
}
};
} // namespace GNAPluginNS

View File

@ -38,7 +38,7 @@ public:
StatusCode Cancel(ResponseDesc* resp) noexcept override {
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Cancel");
NO_EXCEPT_CALL_RETURN_STATUS(_impl->Cancel());
TO_STATUS(_impl->Cancel());
}
StatusCode GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo>& perfMap,

View File

@ -281,13 +281,10 @@ public:
}
}
StatusCode Cancel() override {
void Cancel() override {
std::lock_guard<std::mutex> lock{_mutex};
if (_state == InferState::Idle) {
return StatusCode::INFER_NOT_STARTED;
} else {
if (_state == InferState::Busy) {
_state = InferState::Canceled;
return InferenceEngine::OK;
}
}

View File

@ -66,8 +66,8 @@ public:
/**
* @brief Default common implementation for all plugins
*/
StatusCode Cancel() override {
return InferenceEngine::NOT_IMPLEMENTED;
void Cancel() override {
THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED);
}
/**

View File

@ -42,7 +42,7 @@ public:
/**
* @brief Cancel current inference request execution
*/
virtual StatusCode Cancel() = 0;
virtual void Cancel() = 0;
/**
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.

View File

@ -41,31 +41,13 @@ TEST_P(CancellationTests, canCancelAsyncRequest) {
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
bool cancelled = false;
req.SetCompletionCallback<std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>>(
[&](InferenceEngine::InferRequest request, InferenceEngine::StatusCode status) {
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
cancelled = (status == InferenceEngine::StatusCode::INFER_CANCELLED);
}
});
req.StartAsync();
InferenceEngine::StatusCode cancelStatus = req.Cancel();
InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
ASSERT_EQ(true, cancelStatus == InferenceEngine::StatusCode::OK ||
cancelStatus == InferenceEngine::StatusCode::INFER_NOT_STARTED);
if (cancelStatus == InferenceEngine::StatusCode::OK) {
ASSERT_EQ(true, cancelled);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_CANCELLED), waitStatus);
} else {
ASSERT_EQ(false, cancelled);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
}
} else {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
ASSERT_NO_THROW(req.Cancel());
try {
req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
} catch (const InferenceEngine::InferCancelled& ex) {
SUCCEED();
}
}
@ -79,14 +61,16 @@ TEST_P(CancellationTests, canResetAfterCancelAsyncRequest) {
// Create InferRequest
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
req.StartAsync();
req.Cancel();
req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
ASSERT_NO_THROW(req.StartAsync());
ASSERT_NO_THROW(req.Cancel());
try {
req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
} catch (const InferenceEngine::InferCancelled& ex) {
SUCCEED();
}
req.StartAsync();
InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
ASSERT_NO_THROW(req.StartAsync());
ASSERT_NO_THROW(req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY));
}
TEST_P(CancellationTests, canCancelBeforeAsyncRequest) {
@ -99,13 +83,7 @@ TEST_P(CancellationTests, canCancelBeforeAsyncRequest) {
// Create InferRequest
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
InferenceEngine::StatusCode cancelStatus = req.Cancel();
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_NOT_STARTED), cancelStatus);
} else {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
}
ASSERT_NO_THROW(req.Cancel());
}
TEST_P(CancellationTests, canCancelInferRequest) {
@ -126,24 +104,11 @@ TEST_P(CancellationTests, canCancelInferRequest) {
while (req.Wait(statusOnly) == InferenceEngine::StatusCode::INFER_NOT_STARTED) {
}
InferenceEngine::StatusCode cancelStatus = req.Cancel();
InferenceEngine::StatusCode inferStatus = InferenceEngine::StatusCode::OK;
ASSERT_NO_THROW(req.Cancel());
try {
infer.get();
} catch (InferenceEngine::details::InferenceEngineException& ex) {
inferStatus = ex.getStatus();
}
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
if (cancelStatus == InferenceEngine::StatusCode::OK) {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_CANCELLED), inferStatus);
} else {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), inferStatus);
}
} else {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), inferStatus);
} catch (const InferenceEngine::InferCancelled& ex) {
SUCCEED();
}
}
} // namespace BehaviorTestsDefinitions

View File

@ -32,6 +32,6 @@ public:
MOCK_METHOD1(setNetworkOutputs, void(OutputsDataMap));
MOCK_METHOD1(GetBlob, Blob::Ptr(const std::string&));
MOCK_METHOD1(SetCompletionCallback, void(IInferRequest::CompletionCallback));
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
MOCK_METHOD0(Cancel_ThreadUnsafe, InferenceEngine::StatusCode());
MOCK_METHOD0(Cancel, void());
MOCK_METHOD0(Cancel_ThreadUnsafe, void());
};

View File

@ -23,5 +23,5 @@ public:
MOCK_METHOD0(InferImpl, void());
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngineProfileInfo>());
MOCK_METHOD0(checkBlobs, void());
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
MOCK_METHOD0(Cancel, void());
};

View File

@ -28,5 +28,5 @@ public:
MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback));
MOCK_METHOD1(SetBatch, void(int));
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>());
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
MOCK_METHOD0(Cancel, void());
};