Simplified cancel behavior (#4224)
This commit is contained in:
parent
751ac1aef4
commit
91dcb515a3
@ -131,13 +131,6 @@ void TemplateInferRequest::InferImpl() {
|
||||
}
|
||||
// ! [infer_request:infer_impl]
|
||||
|
||||
// ! [infer_request:cancel]
|
||||
InferenceEngine::StatusCode TemplateInferRequest::Cancel() {
|
||||
// TODO: add code to handle cancellation request
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
// ! [infer_request:cancel]
|
||||
|
||||
template<typename SrcT, typename DstT>
|
||||
static void blobCopy(const Blob::Ptr& src, const Blob::Ptr& dst) {
|
||||
std::copy_n(InferenceEngine::as<InferenceEngine::MemoryBlob>(src)->rmap().as<const SrcT*>(),
|
||||
|
@ -40,8 +40,6 @@ public:
|
||||
void InferImpl() override;
|
||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
|
||||
|
||||
InferenceEngine::StatusCode Cancel() override;
|
||||
|
||||
// pipeline methods-stages which are used in async infer request implementation and assigned to particular executor
|
||||
void inferPreprocess();
|
||||
void startPipeline();
|
||||
|
@ -162,10 +162,13 @@ public:
|
||||
CALL_STATUS_FNC_NO_ARGS(Infer);
|
||||
}
|
||||
|
||||
StatusCode Cancel() {
|
||||
ResponseDesc resp;
|
||||
if (actual == nullptr) THROW_IE_EXCEPTION << "InferRequest was not initialized.";
|
||||
return actual->Cancel(&resp);
|
||||
/**
|
||||
* @copybrief IInferRequest::Cancel
|
||||
*
|
||||
* Wraps IInferRequest::Cancel
|
||||
*/
|
||||
void Cancel() {
|
||||
CALL_STATUS_FNC_NO_ARGS(Cancel);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -60,6 +60,8 @@ inline void extract_exception(StatusCode status, const char* msg) {
|
||||
throw InferNotStarted(msg);
|
||||
case NETWORK_NOT_READ:
|
||||
throw NetworkNotRead(msg);
|
||||
case INFER_CANCELLED:
|
||||
throw InferCancelled(msg);
|
||||
default:
|
||||
THROW_IE_EXCEPTION << msg << InferenceEngine::details::as_status << status;
|
||||
}
|
||||
|
@ -334,6 +334,11 @@ class NetworkNotRead : public std::logic_error {
|
||||
using std::logic_error::logic_error;
|
||||
};
|
||||
|
||||
/** @brief This class represents StatusCode::INFER_CANCELLED exception */
|
||||
class InferCancelled : public std::logic_error {
|
||||
using std::logic_error::logic_error;
|
||||
};
|
||||
|
||||
} // namespace InferenceEngine
|
||||
|
||||
#if defined(_WIN32)
|
||||
|
@ -117,9 +117,5 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
|
||||
return plg->QueryState();
|
||||
}
|
||||
IE_SUPPRESS_DEPRECATED_END
|
||||
|
||||
InferenceEngine::StatusCode Cancel() override {
|
||||
return InferenceEngine::NOT_IMPLEMENTED;
|
||||
}
|
||||
};
|
||||
} // namespace GNAPluginNS
|
||||
|
@ -38,7 +38,7 @@ public:
|
||||
|
||||
StatusCode Cancel(ResponseDesc* resp) noexcept override {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Cancel");
|
||||
NO_EXCEPT_CALL_RETURN_STATUS(_impl->Cancel());
|
||||
TO_STATUS(_impl->Cancel());
|
||||
}
|
||||
|
||||
StatusCode GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo>& perfMap,
|
||||
|
@ -281,13 +281,10 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
StatusCode Cancel() override {
|
||||
void Cancel() override {
|
||||
std::lock_guard<std::mutex> lock{_mutex};
|
||||
if (_state == InferState::Idle) {
|
||||
return StatusCode::INFER_NOT_STARTED;
|
||||
} else {
|
||||
if (_state == InferState::Busy) {
|
||||
_state = InferState::Canceled;
|
||||
return InferenceEngine::OK;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -66,8 +66,8 @@ public:
|
||||
/**
|
||||
* @brief Default common implementation for all plugins
|
||||
*/
|
||||
StatusCode Cancel() override {
|
||||
return InferenceEngine::NOT_IMPLEMENTED;
|
||||
void Cancel() override {
|
||||
THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -42,7 +42,7 @@ public:
|
||||
/**
|
||||
* @brief Cancel current inference request execution
|
||||
*/
|
||||
virtual StatusCode Cancel() = 0;
|
||||
virtual void Cancel() = 0;
|
||||
|
||||
/**
|
||||
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.
|
||||
|
@ -41,31 +41,13 @@ TEST_P(CancellationTests, canCancelAsyncRequest) {
|
||||
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
|
||||
// Create InferRequest
|
||||
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
|
||||
bool cancelled = false;
|
||||
req.SetCompletionCallback<std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>>(
|
||||
[&](InferenceEngine::InferRequest request, InferenceEngine::StatusCode status) {
|
||||
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
|
||||
cancelled = (status == InferenceEngine::StatusCode::INFER_CANCELLED);
|
||||
}
|
||||
});
|
||||
|
||||
req.StartAsync();
|
||||
InferenceEngine::StatusCode cancelStatus = req.Cancel();
|
||||
InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
||||
|
||||
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
|
||||
ASSERT_EQ(true, cancelStatus == InferenceEngine::StatusCode::OK ||
|
||||
cancelStatus == InferenceEngine::StatusCode::INFER_NOT_STARTED);
|
||||
if (cancelStatus == InferenceEngine::StatusCode::OK) {
|
||||
ASSERT_EQ(true, cancelled);
|
||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_CANCELLED), waitStatus);
|
||||
} else {
|
||||
ASSERT_EQ(false, cancelled);
|
||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
|
||||
}
|
||||
} else {
|
||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
|
||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
|
||||
ASSERT_NO_THROW(req.Cancel());
|
||||
try {
|
||||
req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
||||
} catch (const InferenceEngine::InferCancelled& ex) {
|
||||
SUCCEED();
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,14 +61,16 @@ TEST_P(CancellationTests, canResetAfterCancelAsyncRequest) {
|
||||
// Create InferRequest
|
||||
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
|
||||
|
||||
req.StartAsync();
|
||||
req.Cancel();
|
||||
req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
||||
ASSERT_NO_THROW(req.StartAsync());
|
||||
ASSERT_NO_THROW(req.Cancel());
|
||||
try {
|
||||
req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
||||
} catch (const InferenceEngine::InferCancelled& ex) {
|
||||
SUCCEED();
|
||||
}
|
||||
|
||||
req.StartAsync();
|
||||
InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
|
||||
|
||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
|
||||
ASSERT_NO_THROW(req.StartAsync());
|
||||
ASSERT_NO_THROW(req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY));
|
||||
}
|
||||
|
||||
TEST_P(CancellationTests, canCancelBeforeAsyncRequest) {
|
||||
@ -99,13 +83,7 @@ TEST_P(CancellationTests, canCancelBeforeAsyncRequest) {
|
||||
// Create InferRequest
|
||||
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
|
||||
|
||||
InferenceEngine::StatusCode cancelStatus = req.Cancel();
|
||||
|
||||
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
|
||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_NOT_STARTED), cancelStatus);
|
||||
} else {
|
||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
|
||||
}
|
||||
ASSERT_NO_THROW(req.Cancel());
|
||||
}
|
||||
|
||||
TEST_P(CancellationTests, canCancelInferRequest) {
|
||||
@ -126,24 +104,11 @@ TEST_P(CancellationTests, canCancelInferRequest) {
|
||||
while (req.Wait(statusOnly) == InferenceEngine::StatusCode::INFER_NOT_STARTED) {
|
||||
}
|
||||
|
||||
InferenceEngine::StatusCode cancelStatus = req.Cancel();
|
||||
InferenceEngine::StatusCode inferStatus = InferenceEngine::StatusCode::OK;
|
||||
|
||||
ASSERT_NO_THROW(req.Cancel());
|
||||
try {
|
||||
infer.get();
|
||||
} catch (InferenceEngine::details::InferenceEngineException& ex) {
|
||||
inferStatus = ex.getStatus();
|
||||
}
|
||||
|
||||
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
|
||||
if (cancelStatus == InferenceEngine::StatusCode::OK) {
|
||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_CANCELLED), inferStatus);
|
||||
} else {
|
||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), inferStatus);
|
||||
}
|
||||
} else {
|
||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
|
||||
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), inferStatus);
|
||||
} catch (const InferenceEngine::InferCancelled& ex) {
|
||||
SUCCEED();
|
||||
}
|
||||
}
|
||||
} // namespace BehaviorTestsDefinitions
|
||||
|
@ -32,6 +32,6 @@ public:
|
||||
MOCK_METHOD1(setNetworkOutputs, void(OutputsDataMap));
|
||||
MOCK_METHOD1(GetBlob, Blob::Ptr(const std::string&));
|
||||
MOCK_METHOD1(SetCompletionCallback, void(IInferRequest::CompletionCallback));
|
||||
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
|
||||
MOCK_METHOD0(Cancel_ThreadUnsafe, InferenceEngine::StatusCode());
|
||||
MOCK_METHOD0(Cancel, void());
|
||||
MOCK_METHOD0(Cancel_ThreadUnsafe, void());
|
||||
};
|
||||
|
@ -23,5 +23,5 @@ public:
|
||||
MOCK_METHOD0(InferImpl, void());
|
||||
MOCK_CONST_METHOD0(GetPerformanceCounts, std::map<std::string, InferenceEngineProfileInfo>());
|
||||
MOCK_METHOD0(checkBlobs, void());
|
||||
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
|
||||
MOCK_METHOD0(Cancel, void());
|
||||
};
|
||||
|
@ -28,5 +28,5 @@ public:
|
||||
MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback));
|
||||
MOCK_METHOD1(SetBatch, void(int));
|
||||
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>());
|
||||
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
|
||||
MOCK_METHOD0(Cancel, void());
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user