Feature/drizshko/cancellable request (#2635)
Added cancellability to the InferRequest class (implemented for CPU only; other devices get a stub that returns NOT_IMPLEMENTED)
parent 2495eaf56f
commit 77ecd7e17c
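Before the diff: a minimal usage sketch (not part of this commit) of how an application could exercise the new API against a plugin that supports it. The InferRequest calls below are the ones this commit touches; the helper function and the pre-loaded ExecutableNetwork are illustrative assumptions.

```cpp
// Hypothetical usage sketch; assumes a network is already loaded on a device
// that implements cancellation (CPU in this commit).
#include <ie_core.hpp>

void run_and_cancel(InferenceEngine::ExecutableNetwork& execNet) {
    InferenceEngine::InferRequest req = execNet.CreateInferRequest();

    req.StartAsync();
    // Ask the plugin to stop the in-flight inference.
    InferenceEngine::StatusCode cancelStatus = req.Cancel();

    // Wait() now reports INFER_CANCELLED when the request was stopped early.
    InferenceEngine::StatusCode waitStatus =
        req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);

    if (cancelStatus == InferenceEngine::StatusCode::OK &&
        waitStatus == InferenceEngine::StatusCode::INFER_CANCELLED) {
        // Inference was interrupted; outputs are not valid, but the request
        // can be reused (see the canResetAfterCancelAsyncRequest test below).
    }
}
```

Cancellation is best-effort: as the CPU implementation further down shows, the request stops at the next node boundary, and Wait() then reports the new INFER_CANCELLED status.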
@@ -131,6 +131,13 @@ void TemplateInferRequest::InferImpl() {
}
// ! [infer_request:infer_impl]

// ! [infer_request:cancel]
InferenceEngine::StatusCode TemplateInferRequest::Cancel() {
    // TODO: add code to handle cancellation request
    return InferenceEngine::OK;
}
// ! [infer_request:cancel]

template<typename SrcT, typename DstT>
static void blobCopy(const Blob::Ptr& src, const Blob::Ptr& dst) {
    std::copy_n(InferenceEngine::as<InferenceEngine::MemoryBlob>(src)->rmap().as<const SrcT*>(),
@@ -40,6 +40,8 @@ public:
    void InferImpl() override;
    void GetPerformanceCounts(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo>& perfMap) const override;

    InferenceEngine::StatusCode Cancel() override;

    // pipeline methods-stages which are used in async infer request implementation and assigned to particular executor
    void inferPreprocess();
    void startPipeline();
@@ -162,6 +162,12 @@ public:
        CALL_STATUS_FNC_NO_ARGS(Infer);
    }

    StatusCode Cancel() {
        ResponseDesc resp;
        if (actual == nullptr) THROW_IE_EXCEPTION << "InferRequest was not initialized.";
        return actual->Cancel(&resp);
    }

    /**
     * @copybrief IInferRequest::GetPerformanceCounts
     *
@@ -233,7 +239,8 @@ public:
        ResponseDesc resp;
        if (actual == nullptr) THROW_IE_EXCEPTION << "InferRequest was not initialized.";
        auto res = actual->Wait(millis_timeout, &resp);
        if (res != OK && res != RESULT_NOT_READY && res != INFER_NOT_STARTED) {
        if (res != OK && res != RESULT_NOT_READY &&
            res != INFER_NOT_STARTED && res != INFER_CANCELLED) {
            InferenceEngine::details::extract_exception(res, resp.msg);
        }
        return res;
@@ -235,7 +235,8 @@ enum StatusCode : int {
    RESULT_NOT_READY = -9,
    NOT_ALLOCATED = -10,
    INFER_NOT_STARTED = -11,
    NETWORK_NOT_READ = -12
    NETWORK_NOT_READ = -12,
    INFER_CANCELLED = -13
};

/**
@@ -96,6 +96,12 @@ public:
     * @return Status code of the operation: InferenceEngine::OK (0) for success
     */
    virtual StatusCode Infer(ResponseDesc* resp) noexcept = 0;
    /**
     * @brief Cancels current async inference request
     * @param resp Optional: pointer to an already allocated object to contain information in case of failure
     * @return Status code of the operation: InferenceEngine::OK (0) for success
     */
    virtual StatusCode Cancel(ResponseDesc* resp) noexcept = 0;

    /**
     * @brief Queries performance measures per layer to get feedback of what is the most time consuming layer
@@ -121,5 +121,9 @@ class GNAInferRequest : public InferenceEngine::AsyncInferRequestInternal {
        return plg->QueryState();
    }
    IE_SUPPRESS_DEPRECATED_END

    InferenceEngine::StatusCode Cancel() override {
        return InferenceEngine::NOT_IMPLEMENTED;
    }
};
} // namespace GNAPluginNS
@@ -767,6 +767,11 @@ void MKLDNNGraph::Infer(int batch) {

    mkldnn::stream stream = mkldnn::stream(stream::kind::eager);
    for (int i = 0; i < graphNodes.size(); i++) {
        if (IsCancellationRequested()) {
            ResetCancellationRequest();
            THROW_IE_EXCEPTION << InferenceEngine::details::as_status << InferenceEngine::INFER_CANCELLED;
        }

        PERF(graphNodes[i]);

        if (batch > 0)
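The hunk above is the heart of the CPU implementation: Cancel() only raises a flag, and the graph polls it between nodes, so a node that is already executing always runs to completion. A standalone sketch of that cooperative-cancellation pattern, with hypothetical names rather than the plugin's real classes:

```cpp
#include <atomic>
#include <functional>
#include <stdexcept>
#include <vector>

// Illustrative executor; names are hypothetical, not MKLDNNGraph itself.
class CancellableGraph {
public:
    void Cancel() { cancelRequested.store(true); }           // may be called from any thread

    void Infer(const std::vector<std::function<void()>>& nodes) {
        for (const auto& node : nodes) {
            if (cancelRequested.exchange(false)) {            // check-and-reset between nodes
                throw std::runtime_error("INFER_CANCELLED");  // the plugin reports the new status code here
            }
            node();                                           // a node that has started is never interrupted
        }
    }

private:
    std::atomic<bool> cancelRequested{false};
};
```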
@@ -16,6 +16,7 @@
#include <string>
#include <vector>
#include <memory>
#include <atomic>

namespace MKLDNNPlugin {

@@ -29,7 +30,7 @@ public:
        Ready = 1,
    };

    MKLDNNGraph(): status(NotReady), eng(mkldnn::engine(mkldnn::engine::kind::cpu, 0)) {}
    MKLDNNGraph(): status(NotReady), eng(mkldnn::engine(mkldnn::engine::kind::cpu, 0)), cancelation_requested(false) {}

    Status GetStatus() {
        return status;
@@ -39,6 +40,10 @@ public:
        return (GetStatus() == Ready);
    }

    void Cancel() {
        cancelation_requested.store(true);
    }

    void setConfig(const Config &cfg);
    void setProperty(const std::map<std::string, std::string> &properties);
    Config getProperty();
@@ -124,6 +129,14 @@ public:
    void SortTopologically();

protected:
    bool IsCancellationRequested() const {
        return cancelation_requested.load();
    }

    void ResetCancellationRequest() {
        cancelation_requested.store(false);
    }

    void VisitNode(MKLDNNNodePtr node, std::vector<MKLDNNNodePtr>& sortedNodes);

    void ForgetGraphData() {
@@ -185,6 +198,8 @@ private:
        InferenceEngine::CNNLayerPtr cnnLayer;
        size_t outIdx;
    };

    std::atomic<bool> cancelation_requested;
};

} // namespace MKLDNNPlugin
@@ -149,6 +149,11 @@ void MKLDNNPlugin::MKLDNNInferRequest::InferImpl() {
    graph->PullOutputData(_outputs);
}

InferenceEngine::StatusCode MKLDNNPlugin::MKLDNNInferRequest::Cancel() {
    graph->Cancel();
    return InferenceEngine::OK;
}

void MKLDNNPlugin::MKLDNNInferRequest::GetPerformanceCounts(
        std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &perfMap) const {
    if (!graph || !graph->IsReady())
@@ -25,6 +25,8 @@ public:

    void InferImpl() override;

    InferenceEngine::StatusCode Cancel() override;

    void GetPerformanceCounts(std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> &perfMap) const override;

    /**
@@ -38,6 +38,11 @@ public:
        TO_STATUS(_impl->Infer());
    }

    StatusCode Cancel(ResponseDesc* resp) noexcept override {
        OV_ITT_SCOPED_TASK(itt::domains::Plugin, "Cancel");
        NO_EXCEPT_CALL_RETURN_STATUS(_impl->Cancel());
    }

    StatusCode GetPerformanceCounts(std::map<std::string, InferenceEngineProfileInfo>& perfMap,
                                    ResponseDesc* resp) const noexcept override {
        TO_STATUS(_impl->GetPerformanceCounts(perfMap));
@@ -156,6 +156,14 @@ public:
        return _syncRequest->QueryState();
    }

    StatusCode Cancel() override {
        StatusCode status = Wait(IInferRequest::WaitMode::STATUS_ONLY);
        if (status == INFER_NOT_STARTED) {
            return status;
        }
        return _syncRequest->Cancel();
    }

protected:
    /**
     * @brief Each pipeline stage is a @ref Task that is executed by specified ITaskExecutor implementation
@@ -63,6 +63,13 @@ public:
        InferImpl();
    }

    /**
     * @brief Default common implementation for all plugins
     */
    StatusCode Cancel() override {
        return InferenceEngine::NOT_IMPLEMENTED;
    }

    /**
     * @brief Given optional implementation of setting blob to avoid need for it to be implemented by plugin
     * @param name - a name of input or output blob.
@@ -39,6 +39,11 @@ public:
     */
    virtual void Infer() = 0;

    /**
     * @brief Cancel current inference request execution
     */
    virtual StatusCode Cancel() = 0;

    /**
     * @brief Queries performance measures per layer to get feedback of what is the most time consuming layer.
     * Note: not all plugins may provide meaningful data
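Taken together, the internal interfaces above mean a plugin opts into cancellation by overriding Cancel(); anything that keeps the InferRequestInternal default simply reports NOT_IMPLEMENTED, which is also what the explicit GNA stub earlier in this diff returns. A hedged sketch of the shape a plugin request class might take; the class, its members, and the local StatusCode stand-in are hypothetical, not part of the commit:

```cpp
#include <atomic>

// Self-contained sketch; StatusCode here is a stand-in for the real
// InferenceEngine::StatusCode, and the plugin base class is omitted.
enum class StatusCode { OK, NOT_IMPLEMENTED, INFER_CANCELLED };

class MyPluginInferRequest {
public:
    // A plugin opts into cancellation by overriding Cancel(); otherwise the
    // default InferRequestInternal::Cancel() keeps returning NOT_IMPLEMENTED.
    StatusCode Cancel() {
        cancelRequested.store(true);   // InferImpl() is expected to poll this flag
        return StatusCode::OK;
    }

    void InferImpl() {
        // ... run the pipeline stages, checking cancelRequested between them
        // and reporting INFER_CANCELLED when the flag is set ...
    }

private:
    std::atomic<bool> cancelRequested{false};
};
```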
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "behavior/infer_request_cancellation.hpp"

using namespace BehaviorTestsDefinitions;
namespace {
const std::vector<InferenceEngine::Precision> netPrecisions = {
    InferenceEngine::Precision::FP32,
    InferenceEngine::Precision::FP16
};

const std::vector<std::map<std::string, std::string>> configs = {
    {},
};

INSTANTIATE_TEST_CASE_P(smoke_BehaviorTests, CancellationTests,
                        ::testing::Combine(
                            ::testing::ValuesIn(netPrecisions),
                            ::testing::Values(CommonTestUtils::DEVICE_CPU),
                            ::testing::ValuesIn(configs)),
                        CancellationTests::getTestCaseName);
} // namespace
@@ -0,0 +1,144 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <tuple>
#include <vector>
#include <string>
#include <memory>
#include <future>
#include "ie_extension.h"
#include <condition_variable>
#include "functional_test_utils/layer_test_utils.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
#include "ngraph_functions/builders.hpp"
#include <ie_core.hpp>
#include <functional_test_utils/behavior_test_utils.hpp>
#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "ngraph_functions/pass/convert_prc.hpp"
#include "ngraph_functions/subgraph_builders.hpp"
#include "behavior/infer_request_cancellation.hpp"

namespace BehaviorTestsDefinitions {
using CancellationTests = BehaviorTestsUtils::BehaviorTestsBasic;

TEST_P(CancellationTests, canCancelAsyncRequest) {
    // Skip test according to plugin specific disabledTestPatterns() (if any)
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    std::shared_ptr<ngraph::Function> largeNetwork = ngraph::builder::subgraph::makeConvPoolRelu({1, 3, 640, 640});
    // Create CNNNetwork from ngraph::Function
    InferenceEngine::CNNNetwork cnnNet(largeNetwork);
    // Load CNNNetwork to target plugins
    auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
    // Create InferRequest
    InferenceEngine::InferRequest req = execNet.CreateInferRequest();
    bool cancelled = false;
    req.SetCompletionCallback<std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>>(
        [&](InferenceEngine::InferRequest request, InferenceEngine::StatusCode status) {
            if (targetDevice == CommonTestUtils::DEVICE_CPU) {
                cancelled = (status == InferenceEngine::StatusCode::INFER_CANCELLED);
            }
        });

    req.StartAsync();
    InferenceEngine::StatusCode cancelStatus = req.Cancel();
    InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);

    if (targetDevice == CommonTestUtils::DEVICE_CPU) {
        ASSERT_EQ(true, cancelStatus == InferenceEngine::StatusCode::OK ||
                        cancelStatus == InferenceEngine::StatusCode::INFER_NOT_STARTED);
        if (cancelStatus == InferenceEngine::StatusCode::OK) {
            ASSERT_EQ(true, cancelled);
            ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_CANCELLED), waitStatus);
        } else {
            ASSERT_EQ(false, cancelled);
            ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
        }
    } else {
        ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
        ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
    }
}

TEST_P(CancellationTests, canResetAfterCancelAsyncRequest) {
    // Skip test according to plugin specific disabledTestPatterns() (if any)
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    // Create CNNNetwork from ngraph::Function
    InferenceEngine::CNNNetwork cnnNet(function);
    // Load CNNNetwork to target plugins
    auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
    // Create InferRequest
    InferenceEngine::InferRequest req = execNet.CreateInferRequest();

    req.StartAsync();
    req.Cancel();
    req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);

    req.StartAsync();
    InferenceEngine::StatusCode waitStatus = req.Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);

    ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), waitStatus);
}

TEST_P(CancellationTests, canCancelBeforeAsyncRequest) {
    // Skip test according to plugin specific disabledTestPatterns() (if any)
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    // Create CNNNetwork from ngraph::Function
    InferenceEngine::CNNNetwork cnnNet(function);
    // Load CNNNetwork to target plugins
    auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
    // Create InferRequest
    InferenceEngine::InferRequest req = execNet.CreateInferRequest();

    InferenceEngine::StatusCode cancelStatus = req.Cancel();

    if (targetDevice == CommonTestUtils::DEVICE_CPU) {
        ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_NOT_STARTED), cancelStatus);
    } else {
        ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
    }
}

TEST_P(CancellationTests, canCancelInferRequest) {
    // Skip test according to plugin specific disabledTestPatterns() (if any)
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    // Create function with large input, to have time to Cancel request
    std::shared_ptr<ngraph::Function> largeNetwork = ngraph::builder::subgraph::makeConvPoolRelu({1, 3, 640, 640});
    // Create CNNNetwork from ngraph::Function
    InferenceEngine::CNNNetwork cnnNet(largeNetwork);
    // Load CNNNetwork to target plugins
    auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
    // Create InferRequest
    InferenceEngine::InferRequest req = execNet.CreateInferRequest();

    auto infer = std::async(std::launch::async, [&req]{ req.Infer(); });

    const auto statusOnly = InferenceEngine::IInferRequest::WaitMode::STATUS_ONLY;
    while (req.Wait(statusOnly) == InferenceEngine::StatusCode::INFER_NOT_STARTED) {
    }

    InferenceEngine::StatusCode cancelStatus = req.Cancel();
    InferenceEngine::StatusCode inferStatus = InferenceEngine::StatusCode::OK;

    try {
        infer.get();
    } catch (InferenceEngine::details::InferenceEngineException& ex) {
        inferStatus = ex.getStatus();
    }

    if (targetDevice == CommonTestUtils::DEVICE_CPU) {
        if (cancelStatus == InferenceEngine::StatusCode::OK) {
            ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::INFER_CANCELLED), inferStatus);
        } else {
            ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), inferStatus);
        }
    } else {
        ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::NOT_IMPLEMENTED), cancelStatus);
        ASSERT_EQ(static_cast<int>(InferenceEngine::StatusCode::OK), inferStatus);
    }
}
} // namespace BehaviorTestsDefinitions
@@ -32,4 +32,6 @@ public:
    MOCK_METHOD1(setNetworkOutputs, void(OutputsDataMap));
    MOCK_METHOD2(GetBlob, void(const char *name, Blob::Ptr &));
    MOCK_METHOD1(SetCompletionCallback, void(IInferRequest::CompletionCallback));
    MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
    MOCK_METHOD0(Cancel_ThreadUnsafe, InferenceEngine::StatusCode());
};
@@ -62,4 +62,6 @@ public:
    MOCK_METHOD1(SetBatch, void(int));
    MOCK_METHOD1(SetBatch_ThreadUnsafe, void(int));
    MOCK_METHOD0(QueryState, std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>>(void));

    MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
};
@@ -22,4 +22,5 @@ public:
    using InferRequestInternal::GetBlob;
    MOCK_METHOD0(InferImpl, void());
    MOCK_CONST_METHOD1(GetPerformanceCounts, void(std::map<std::string, InferenceEngineProfileInfo> &));
    MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
};
@@ -28,4 +28,5 @@ public:
    MOCK_METHOD1(SetCompletionCallback, void(InferenceEngine::IInferRequest::CompletionCallback));
    MOCK_METHOD1(SetBatch, void(int));
    MOCK_METHOD0(QueryState, std::vector<IVariableStateInternal::Ptr>());
    MOCK_METHOD0(Cancel, InferenceEngine::StatusCode());
};
@@ -35,4 +35,5 @@ public:
    MOCK_QUALIFIED_METHOD4(SetBlob, noexcept, StatusCode(const char*, const Blob::Ptr&, const PreProcessInfo&, ResponseDesc*));
    MOCK_QUALIFIED_METHOD2(SetBatch, noexcept, StatusCode(int batch, ResponseDesc*));
    MOCK_QUALIFIED_METHOD3(QueryState, noexcept, StatusCode(IVariableState::Ptr &, size_t, ResponseDesc *));
    MOCK_QUALIFIED_METHOD1(Cancel, noexcept, InferenceEngine::StatusCode(ResponseDesc*));
};
@@ -13,7 +13,8 @@ static std::shared_ptr<ngraph::Function> makeConvPoolRelu(std::vector<size_t> in
                                                           ngraph::element::Type_t ngPrc = ngraph::element::Type_t::f32) {
    auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
    params.front()->set_friendly_name("Param_1");
    auto const1 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, ngraph::Shape{1, 32, 1, 32});
    std::vector<size_t> constShape = {inputShape[0], inputShape[2], inputShape[1], inputShape[3]};
    auto const1 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{4}, constShape);
    const1->set_friendly_name("Const_1");
    auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(params.front(), const1, false);
    reshape1->set_friendly_name("Reshape_1");
@@ -27,7 +28,9 @@ static std::shared_ptr<ngraph::Function> makeConvPoolRelu(std::vector<size_t> in
    pool1->set_friendly_name("Pool_1");
    auto relu1 = std::make_shared<ngraph::opset1::Relu>(pool1);
    relu1->set_friendly_name("Relu_1");
    auto const2 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, ngraph::Shape{1, 116});
    ngraph::Shape reluShape = relu1->outputs()[0].get_tensor().get_shape();
    std::vector<size_t> constShape2 = {1, ngraph::shape_size(reluShape)};
    auto const2 = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, constShape2);
    const2->set_friendly_name("Const_2");
    auto reshape2 = std::make_shared<ngraph::opset1::Reshape>(relu1, const2, false);
    reshape2->set_friendly_name("Reshape_2");