Feature/drizshko/cancellable request (#2635)

Added cancellability to the InferRequest class (currently implemented only for the CPU plugin; other devices keep a stub that returns NOT_IMPLEMENTED)
Dmitrii Ryzhkov 2020-12-13 22:38:29 -08:00 committed by GitHub
parent 2495eaf56f
commit 77ecd7e17c
22 changed files with 262 additions and 5 deletions
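
For orientation, a minimal usage sketch of the new API from application code (not part of the diff), assuming a CPU device and a placeholder "model.xml"; InferRequest::Cancel() and the INFER_CANCELLED status code are the pieces added in this commit:

#include <ie_core.hpp>

int main() {
    InferenceEngine::Core core;
    // "model.xml" and the CPU device are placeholders; error handling is trimmed.
    auto network = core.ReadNetwork("model.xml");
    auto execNet = core.LoadNetwork(network, "CPU");
    auto request = execNet.CreateInferRequest();

    request.StartAsync();

    // Ask the plugin to stop the running inference (only the CPU plugin honours this here).
    if (request.Cancel() == InferenceEngine::StatusCode::OK) {
        // For a cancelled request, Wait() reports INFER_CANCELLED instead of throwing.
        auto waitStatus = request.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
        if (waitStatus == InferenceEngine::StatusCode::INFER_CANCELLED) {
            // The request was stopped; it can simply be started again.
            request.StartAsync();
            request.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
        }
    }
    return 0;
}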

View File

@ -131,6 +131,13 @@ void TemplateInferRequest::InferImpl() {
}
// ! [infer_request:infer_impl]
// ! [infer_request:cancel]
InferenceEngine::StatusCode TemplateInferRequest::Cancel() {
// TODO: add code to handle cancellation request
return InferenceEngine::OK;
}
// ! [infer_request:cancel]
template<typename SrcT, typename DstT>
static void blobCopy(const Blob::Ptr& src, const Blob::Ptr& dst) {
std::copy_n(InferenceEngine::as<InferenceEngine::MemoryBlob>(src)->rmap().as<const SrcT*>(),

View File

@ -40,6 +40,8 @@ public:
void InferImpl() override;
void GetPerformanceCounts(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo>& perfMap) const override;
InferenceEngine::StatusCode Cancel() override;
// pipeline methods-stages which are used in async infer request implementation and assigned to particular executor
void inferPreprocess();
void startPipeline();

View File

@ -162,6 +162,12 @@ public:
CALL_STATUS_FNC_NO_ARGS(Infer);
}
StatusCode Cancel() {
ResponseDesc resp;
if (actual == nullptr) THROW_IE_EXCEPTION << "InferRequest was not initialized.";
return actual->Cancel(&resp);
}
/**
* @copybrief IInferRequest::GetPerformanceCounts
*
@ -233,7 +239,8 @@ public:
ResponseDesc resp;
if (actual == nullptr) THROW_IE_EXCEPTION << "InferRequest was not initialized.";
auto res = actual->Wait(millis_timeout, &resp);
if (res != OK && res != RESULT_NOT_READY && res != INFER_NOT_STARTED) {
if (res != OK && res != RESULT_NOT_READY &&
res != INFER_NOT_STARTED && res != INFER_CANCELLED) {
InferenceEngine::details::extract_exception(res, resp.msg);
}
return res;

View File

@ -235,7 +235,8 @@ enum StatusCode : int {
RESULT_NOT_READY = -9,
NOT_ALLOCATED = -10,
INFER_NOT_STARTED = -11,
NETWORK_NOT_READ = -12
NETWORK_NOT_READ = -12,
INFER_CANCELLED = -13
};
/**

View File

@ -96,6 +96,12 @@ public:
* @return Status code of the operation: InferenceEngine::OK (0) for success
*/
virtual StatusCode Infer(ResponseDesc* resp) noexcept = 0;
/**
* @brief Cancels current async inference request
* @param resp Optional: pointer to an already allocated object to contain information in case of failure
* @return Status code of the operation: InferenceEngine::OK (0) for success
*/
virtual StatusCode Cancel(ResponseDesc* resp) noexcept = 0;
/**
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer

View File

@ -121,5 +121,9 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
return plg->QueryState();
}
IE_SUPPRESS_DEPRECATED_END
InferenceEngine::StatusCode Cancel() override {
return InferenceEngine::NOT_IMPLEMENTED;
}
};
} // namespace GNAPluginNS
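
Plugins that keep the stub (GNA above, and the base-class default) report NOT_IMPLEMENTED from Cancel(), so a portable caller can treat that as "cancellation unsupported" and fall back to a plain wait. A minimal sketch with a hypothetical helper name:

#include <ie_core.hpp>

// cancelOrWait is a hypothetical helper, not part of the diff.
InferenceEngine::StatusCode cancelOrWait(InferenceEngine::InferRequest& request) {
    const auto status = request.Cancel();
    if (status == InferenceEngine::StatusCode::NOT_IMPLEMENTED) {
        // No cancellation support in this plugin; wait for normal completion instead.
        return request.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
    }
    return status;
}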

View File

@ -767,6 +767,11 @@ void MKLDNNGraph::Infer(int batch) {
mkldnn::stream stream = mkldnn::stream(stream::kind::eager);
for (int i = 0; i < graphNodes.size(); i++) {
if (IsCancellationRequested()) {
ResetCancellationRequest();
THROW_IE_EXCEPTION << InferenceEngine::details::as_status << InferenceEngine::INFER_CANCELLED;
}
PERF(graphNodes[i]);
if (batch > 0)

View File

@ -16,6 +16,7 @@
#include <string>
#include <vector>
#include <memory>
#include <atomic>
namespace MKLDNNPlugin {
@ -29,7 +30,7 @@ public:
Ready = 1,
};
MKLDNNGraph(): status(NotReady), eng(mkldnn::engine(mkldnn::engine::kind::cpu, 0)) {}
MKLDNNGraph(): status(NotReady), eng(mkldnn::engine(mkldnn::engine::kind::cpu, 0)), cancelation_requested(false) {}
Status GetStatus() {
return status;
@ -39,6 +40,10 @@ public:
return (GetStatus() == Ready);
}
void Cancel() {
cancelation_requested.store(true);
}
void setConfig(const Config &cfg);
void setProperty(const std::map<std::string, std::string> &properties);
Config getProperty();
@ -124,6 +129,14 @@ public:
void SortTopologically();
protected:
bool IsCancellationRequested() const {
return cancelation_requested.load();
}
void ResetCancellationRequest() {
cancelation_requested.store(false);
}
void VisitNode(MKLDNNNodePtr node, std::vector<MKLDNNNodePtr>& sortedNodes);
void ForgetGraphData() {
@ -185,6 +198,8 @@ private:
InferenceEngine::CNNLayerPtr cnnLayer;
size_t outIdx;
};
std::atomic<bool> cancelation_requested;
};
} // namespace MKLDNNPlugin
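
The CPU implementation is cooperative: Cancel() only sets the atomic cancelation_requested flag, and the Infer() loop polls it before each node, resets it, and aborts with INFER_CANCELLED (see the mkldnn_graph.cpp hunk above). A stripped-down sketch of that pattern, using an illustrative class rather than the actual MKLDNNGraph code:

#include <atomic>
#include <functional>
#include <stdexcept>
#include <vector>

class CancellableGraph {
public:
    // Called from another thread; only raises the flag, the running inference notices it.
    void Cancel() { cancellationRequested.store(true); }

    void Infer(const std::vector<std::function<void()>>& nodes) {
        for (const auto& node : nodes) {
            if (cancellationRequested.load()) {
                cancellationRequested.store(false);  // reset so the next run can proceed
                // The plugin throws an IE exception tagged with INFER_CANCELLED here.
                throw std::runtime_error("inference cancelled");
            }
            node();  // execute one graph node
        }
    }

private:
    std::atomic<bool> cancellationRequested{false};
};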

View File

@ -149,6 +149,11 @@ void MKLDNNPlugin::MKLDNNInferRequest::InferImpl() {
graph->PullOutputData(_outputs);
}
InferenceEngine::StatusCode MKLDNNPlugin::MKLDNNInferRequest::Cancel() {
graph->Cancel();
return InferenceEngine::OK;
}
void MKLDNNPlugin::MKLDNNInferRequest::GetPerformanceCounts(
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &perfMap) const {
if (!graph || !graph->IsReady())

View File

@ -25,6 +25,8 @@ public:
void InferImpl() override;
InferenceEngine::StatusCode Cancel() override;
void GetPerformanceCounts(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &perfMap) const override;
/**

View File

@ -38,6 +38,11 @@ public:
TO_STATUS(_impl->Infer());
}
StatusCode Cancel(ResponseDesc* resp) noexcept override {
OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Cancel");
NO_EXCEPT_CALL_RETURN_STATUS(_impl->Cancel());
}
StatusCode GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo>& perfMap,
ResponseDesc* resp) const noexcept override {
TO_STATUS(_impl->GetPerformanceCounts(perfMap));

View File

@ -156,6 +156,14 @@ public:
return _syncRequest->QueryState();
}
StatusCode Cancel() override {
StatusCode status = Wait(IInferRequest::WaitMode::STATUS_ONLY);
if (status == INFER_NOT_STARTED) {
return status;
}
return _syncRequest->Cancel();
}
protected:
/**
* @brief Each pipeline stage is a @ref Task that is executed by specified ITaskExecutor implementation

View File

@ -63,6 +63,13 @@ public:
InferImpl();
}
/**
* @brief Default common implementation for all plugins
*/
StatusCode Cancel() override {
return InferenceEngine::NOT_IMPLEMENTED;
}
/**
* @brief Given optional implementation of setting blob to avoid need for it to be implemented by plugin
* @param name - a name of input or output blob.

View File

@ -39,6 +39,11 @@ public:
*/
virtual void Infer() = 0;
/**
* @brief Cancel current inference request execution
*/
virtual StatusCode Cancel() = 0;
/**
* @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.
* Note: not all plugins may provide meaningful data

View File

@ -0,0 +1,24 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "behavior/infer_request_cancellation.hpp"
using namespace BehaviorTestsDefinitions;
namespace {
const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16
};
const std::vector<std::map<std::string, std::string>> configs = {
{},
};
INSTANTIATE_TEST_CASE_P(smoke_BehaviorTests, CancellationTests,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::ValuesIn(configs)),
CancellationTests::getTestCaseName);
} // namespace
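
Only the CPU suite is registered by this commit. A hypothetical instantiation for another plugin's test directory (assuming CommonTestUtils::DEVICE_GPU and the same netPrecisions/configs vectors) would follow the same pattern and exercise the NOT_IMPLEMENTED path:

INSTANTIATE_TEST_CASE_P(smoke_BehaviorTests, CancellationTests,
        ::testing::Combine(
            ::testing::ValuesIn(netPrecisions),
            ::testing::Values(CommonTestUtils::DEVICE_GPU),
            ::testing::ValuesIn(configs)),
        CancellationTests::getTestCaseName);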

View File

@ -0,0 +1,144 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <tuple>
#include <vector>
#include <string>
#include <memory>
#include <future>
#include "ie_extension.h"
#include <condition_variable>
#include "functional_test_utils/layer_test_utils.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
#include "ngraph_functions/builders.hpp"
#include <ie_core.hpp>
#include <functional_test_utils/behavior_test_utils.hpp>
#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "ngraph_functions/pass/convert_prc.hpp"
#include "ngraph_functions/subgraph_builders.hpp"
#include "behavior/infer_request_cancellation.hpp"
namespace BehaviorTestsDefinitions {
using CancellationTests = BehaviorTestsUtils::BehaviorTestsBasic;
TEST_P(CancellationTests, canCancelAsyncRequest) {
// Skip test according to plugin specific disabledTestPatterns() (if any)
SKIP_IF_CURRENT_TEST_IS_DISABLED()
std::shared_ptr<ngraph::Function> largeNetwork = ngraph::builder::subgraph::makeConvPoolRelu({1, 3, 640, 640});
// Create CNNNetwork from ngraph::Function
InferenceEngine::CNNNetwork cnnNet(largeNetwork);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
bool cancelled = false;
req.SetCompletionCallback<std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>>(
[&](InferenceEngine::InferRequest request, InferenceEngine::StatusCode status) {
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
cancelled = (status == InferenceEngine::StatusCode::INFER_CANCELLED);
}
});
req.StartAsync();
InferenceEngine::StatusCode cancelStatus = req.Cancel();
InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
ASSERT_EQ(true, cancelStatus == InferenceEngine::StatusCode::OK ||
cancelStatus == InferenceEngine::StatusCode::INFER_NOT_STARTED);
if (cancelStatus == InferenceEngine::StatusCode::OK) {
ASSERT_EQ(true, cancelled);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_CANCELLED), waitStatus);
} else {
ASSERT_EQ(false, cancelled);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
}
} else {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
}
}
TEST_P(CancellationTests, canResetAfterCancelAsyncRequest) {
// Skip test according to plugin specific disabledTestPatterns() (if any)
SKIP_IF_CURRENT_TEST_IS_DISABLED()
// Create CNNNetwork from ngraph::Function
InferenceEngine::CNNNetwork cnnNet(function);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
req.StartAsync();
req.Cancel();
req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
req.StartAsync();
InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
}
TEST_P(CancellationTests, canCancelBeforeAsyncRequest) {
// Skip test according to plugin specific disabledTestPatterns() (if any)
SKIP_IF_CURRENT_TEST_IS_DISABLED()
// Create CNNNetwork from ngraph::Function
InferenceEngine::CNNNetwork cnnNet(function);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
InferenceEngine::StatusCode cancelStatus = req.Cancel();
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_NOT_STARTED), cancelStatus);
} else {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
}
}
TEST_P(CancellationTests, canCancelInferRequest) {
// Skip test according to plugin specific disabledTestPatterns() (if any)
SKIP_IF_CURRENT_TEST_IS_DISABLED()
// Create a function with a large input so there is time to cancel the request
std::shared_ptr<ngraph::Function> largeNetwork = ngraph::builder::subgraph::makeConvPoolRelu({1, 3, 640, 640});
// Create CNNNetwork from ngraph::Function
InferenceEngine::CNNNetwork cnnNet(largeNetwork);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest
InferenceEngine::InferRequest req = execNet.CreateInferRequest();
auto infer = std::async(std::launch::async, [&req]{ req.Infer(); });
const auto statusOnly = InferenceEngine::IInferRequest::WaitMode::STATUS_ONLY;
while (req.Wait(statusOnly) == InferenceEngine::StatusCode::INFER_NOT_STARTED) {
}
InferenceEngine::StatusCode cancelStatus = req.Cancel();
InferenceEngine::StatusCode inferStatus = InferenceEngine::StatusCode::OK;
try {
infer.get();
} catch (InferenceEngine::details::InferenceEngineException& ex) {
inferStatus = ex.getStatus();
}
if (targetDevice == CommonTestUtils::DEVICE_CPU) {
if (cancelStatus == InferenceEngine::StatusCode::OK) {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_CANCELLED), inferStatus);
} else {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), inferStatus);
}
} else {
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), inferStatus);
}
}
} // namespace BehaviorTestsDefinitions
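
Outside the test harness, the same applies to an application that runs a blocking Infer() on a worker thread: cancellation surfaces as an InferenceEngineException carrying INFER_CANCELLED, as canCancelInferRequest shows. A minimal sketch, assuming req already has its inputs set and ignoring the start-up race that the test handles by polling Wait(STATUS_ONLY):

#include <future>
#include <ie_core.hpp>
#include <details/ie_exception.hpp>

void runWithCancellation(InferenceEngine::InferRequest& req) {
    auto worker = std::async(std::launch::async, [&req] { req.Infer(); });

    req.Cancel();  // request cancellation from this thread

    try {
        worker.get();  // rethrows any exception raised on the worker thread
    } catch (InferenceEngine::details::InferenceEngineException& ex) {
        if (ex.getStatus() == InferenceEngine::StatusCode::INFER_CANCELLED) {
            // The inference was stopped before completion.
        }
    }
}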

View File

@ -32,4 +32,6 @@ public:
MOCK_METHOD1(setNetworkOutputs, void(OutputsDataMap));
MOCK_METHOD2(GetBlob, void(const char *name, Blob::Ptr &));
MOCK_METHOD1(SetCompletionCallback, void(IInferRequest::CompletionCallback));
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
MOCK_METHOD0(Cancel_ThreadUnsafe, InferenceEngine::StatusCode());
};

View File

@ -62,4 +62,6 @@ public:
MOCK_METHOD1(SetBatch, void(int));
MOCK_METHOD1(SetBatch_ThreadUnsafe, void(int));
MOCK_METHOD0(QueryState, std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>>(void));
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
};

View File

@ -22,4 +22,5 @@ public:
using InferRequestInternal::GetBlob;
MOCK_METHOD0(InferImpl, void());
MOCK_CONST_METHOD1(GetPerformanceCounts, void(std::map<std::string, InferenceEngineProfileInfo> &));
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
};

View File

@ -28,4 +28,5 @@ public:
MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback));
MOCK_METHOD1(SetBatch, void(int));
MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>());
MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
};

View File

@ -35,4 +35,5 @@ public:
MOCK_QUALIFIED_METHOD4(SetBlob, noexcept, StatusCode(const char*, const Blob::Ptr&, const PreProcessInfo&, ResponseDesc*));
MOCK_QUALIFIED_METHOD2(SetBatch, noexcept, StatusCode(int batch, ResponseDesc*));
MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IVariableState::Ptr &, size_t, ResponseDesc *));
MOCK_QUALIFIED_METHOD1(Cancel, noexcept, InferenceEngine::StatusCode(ResponseDesc*));
};

View File

@ -13,7 +13,8 @@ static std::shared_ptr<ngraph::Function> makeConvPoolRelu(std::vector<size_t> in
ngraph::element::Type_t ngPrc = ngraph::element::Type_t::f32) {
auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
params.front()->set_friendly_name("Param_1");
auto const1 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, ngraph::Shape{1, 32, 1, 32});
std::vector<size_t> constShape = {inputShape[0], inputShape[2], inputShape[1], inputShape[3]};
auto const1 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, constShape);
const1->set_friendly_name("Const_1");
auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(params.front(), const1, false);
reshape1->set_friendly_name("Reshape_1");
@ -27,7 +28,9 @@ static std::shared_ptr<ngraph::Function> makeConvPoolRelu(std::vector<size_t> in
pool1->set_friendly_name("Pool_1");
auto relu1 = std::make_shared<ngraph::opset1::Relu>(pool1);
relu1->set_friendly_name("Relu_1");
auto const2 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, ngraph::Shape{1, 116});
ngraph::Shape reluShape = relu1->outputs()[0].get_tensor().get_shape();
std::vector<size_t> constShape2 = {1, ngraph::shape_size(reluShape)};
auto const2 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, constShape2);
const2->set_friendly_name("Const_2");
auto reshape2 = std::make_shared<ngraph::opset1::Reshape>(relu1, const2, false);
reshape2->set_friendly_name("Reshape_2");
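
With the reshape constants now derived from inputShape and from the actual Relu output shape, instead of the hard-coded {1, 32, 1, 32} and {1, 116} values, the builder stays consistent for non-default shapes; that is what lets the cancellation tests build a larger model. A short usage sketch with a hypothetical wrapper function:

#include <memory>
#include <ie_core.hpp>
#include "ngraph_functions/subgraph_builders.hpp"

// makeLargeNetwork is a hypothetical wrapper; the shape matches the one used by the cancellation tests.
InferenceEngine::CNNNetwork makeLargeNetwork() {
    std::shared_ptr<ngraph::Function> fn =
            ngraph::builder::subgraph::makeConvPoolRelu({1, 3, 640, 640});
    return InferenceEngine::CNNNetwork(fn);
}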