Blob processing in HETERO forwarded to sub requests (#8012)

* Blob processing in HETERO forwarded to subrequests

* Added template dynamic hetero tests
This commit is contained in:
Anton Pankratv 2021-10-15 10:44:11 +03:00 committed by GitHub
parent fa38103e5b
commit 4cc53c97a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 71 additions and 67 deletions

View File

@ -10,6 +10,7 @@ addIeTargetTest(
ROOT ${CMAKE_CURRENT_SOURCE_DIR} ROOT ${CMAKE_CURRENT_SOURCE_DIR}
DEPENDENCIES DEPENDENCIES
templatePlugin templatePlugin
HeteroPlugin
LINK_LIBRARIES LINK_LIBRARIES
IE::funcSharedTests IE::funcSharedTests
INCLUDES INCLUDES

View File

@ -14,6 +14,9 @@ const std::vector<std::map<std::string, std::string>> configs = {
{} {}
}; };
const std::vector<std::map<std::string, std::string>> HeteroConfigs = {
{{"TARGET_FALLBACK", CommonTestUtils::DEVICE_TEMPLATE}}};
INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, InferRequestDynamicTests, INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, InferRequestDynamicTests,
::testing::Combine( ::testing::Combine(
::testing::Values(ngraph::builder::subgraph::makeSplitConvConcat()), ::testing::Values(ngraph::builder::subgraph::makeSplitConvConcat()),
@ -23,4 +26,13 @@ INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, InferRequestDynamicTests,
::testing::ValuesIn(configs)), ::testing::ValuesIn(configs)),
InferRequestDynamicTests::getTestCaseName); InferRequestDynamicTests::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Hetero_BehaviorTests, InferRequestDynamicTests,
::testing::Combine(
::testing::Values(ngraph::builder::subgraph::makeSplitConvConcat()),
::testing::Values(std::vector<std::pair<std::vector<size_t>, std::vector<size_t>>>{{{1, 4, 20, 20}, {1, 10, 18, 18}},
{{2, 4, 20, 20}, {2, 10, 18, 18}}}),
::testing::Values(CommonTestUtils::DEVICE_HETERO),
::testing::ValuesIn(HeteroConfigs)),
InferRequestDynamicTests::getTestCaseName);
} // namespace } // namespace

View File

@ -13,6 +13,9 @@ const std::vector<std::map<std::string, std::string>> configs = {
{} {}
}; };
const std::vector<std::map<std::string, std::string>> HeteroConfigs = {
{{"TARGET_FALLBACK", CommonTestUtils::DEVICE_TEMPLATE}}};
INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, OVInferenceChaining, INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, OVInferenceChaining,
::testing::Combine( ::testing::Combine(
::testing::Values(ov::element::f32), ::testing::Values(ov::element::f32),
@ -20,4 +23,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, OVInferenceChaining,
::testing::ValuesIn(configs)), ::testing::ValuesIn(configs)),
OVInferenceChaining::getTestCaseName); OVInferenceChaining::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Hetero_BehaviorTests, OVInferenceChaining,
::testing::Combine(
::testing::Values(ov::element::f32),
::testing::Values(CommonTestUtils::DEVICE_HETERO),
::testing::ValuesIn(HeteroConfigs)),
OVInferenceChaining::getTestCaseName);
} // namespace } // namespace

View File

@ -43,11 +43,6 @@ HeteroAsyncInferRequest::HeteroAsyncInferRequest(const IInferRequestInternal::Pt
} }
} }
void HeteroAsyncInferRequest::StartAsync_ThreadUnsafe() {
_heteroInferRequest->updateInOutIfNeeded();
RunFirstStage(_pipeline.begin(), _pipeline.end());
}
StatusCode HeteroAsyncInferRequest::Wait(int64_t millis_timeout) { StatusCode HeteroAsyncInferRequest::Wait(int64_t millis_timeout) {
auto waitStatus = StatusCode::OK; auto waitStatus = StatusCode::OK;
try { try {

View File

@ -18,7 +18,6 @@ public:
const InferenceEngine::ITaskExecutor::Ptr& taskExecutor, const InferenceEngine::ITaskExecutor::Ptr& taskExecutor,
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor); const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
~HeteroAsyncInferRequest(); ~HeteroAsyncInferRequest();
void StartAsync_ThreadUnsafe() override;
InferenceEngine::StatusCode Wait(int64_t millis_timeout) override; InferenceEngine::StatusCode Wait(int64_t millis_timeout) override;
private: private:

View File

@ -36,11 +36,14 @@ HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInp
bool emplaced = false; bool emplaced = false;
std::tie(itBlob, emplaced) = _blobs.emplace(intermediateBlobName, Blob::Ptr{}); std::tie(itBlob, emplaced) = _blobs.emplace(intermediateBlobName, Blob::Ptr{});
if (emplaced) { if (emplaced) {
itBlob->second = r->GetBlob(blobName); if (InferenceEngine::details::contains(_networkInputs, blobName)) {
if (InferenceEngine::details::contains(networkInputs, blobName)) { _subRequestFromBlobName.emplace(blobName, r._ptr.get());
_inputs[blobName] = itBlob->second; _blobs.erase(intermediateBlobName);
} else if (InferenceEngine::details::contains(networkOutputs, blobName)) { } else if (InferenceEngine::details::contains(_networkOutputs, blobName)) {
_outputs[blobName] = itBlob->second; _subRequestFromBlobName.emplace(blobName, r._ptr.get());
_blobs.erase(intermediateBlobName);
} else {
itBlob->second = r->GetBlob(blobName);
} }
} else { } else {
r->SetBlob(blobName, itBlob->second); r->SetBlob(blobName, itBlob->second);
@ -64,25 +67,39 @@ HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInp
} }
} }
void HeteroInferRequest::SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& data) { void HeteroInferRequest::SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& blob) {
InferenceEngine::IInferRequestInternal::SetBlob(name, data); auto itRequest = _subRequestFromBlobName.find(name);
assert(!_inferRequests.empty()); if (itRequest == _subRequestFromBlobName.end()) {
for (auto &&desc : _inferRequests) { IE_THROW() << "There is no infer requests binded to blob with name: " << name;
auto &r = desc._request;
assert(r);
InputInfo::Ptr foundInput;
DataPtr foundOutput;
try {
// if `name` is input blob
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
r->SetBlob(name, data, foundInput->getPreProcess());
}
} catch (const InferenceEngine::NotFound&) {}
} }
itRequest->second->SetBlob(name, blob);
}
InferenceEngine::Blob::Ptr HeteroInferRequest::GetBlob(const std::string& name) {
auto itRequest = _subRequestFromBlobName.find(name);
if (itRequest == _subRequestFromBlobName.end()) {
IE_THROW() << "There is no infer requests binded to blob with name: " << name;
}
return itRequest->second->GetBlob(name);
}
void HeteroInferRequest::SetBlob(const std::string& name, const Blob::Ptr& blob, const PreProcessInfo& info) {
auto itRequest = _subRequestFromBlobName.find(name);
if (itRequest == _subRequestFromBlobName.end()) {
IE_THROW() << "There is no infer requests binded to blob with name: " << name;
}
itRequest->second->SetBlob(name, blob, info);
}
const InferenceEngine::PreProcessInfo& HeteroInferRequest::GetPreProcess(const std::string& name) const {
auto itRequest = _subRequestFromBlobName.find(name);
if (itRequest == _subRequestFromBlobName.end()) {
IE_THROW() << "There is no infer requests binded to blob with name: " << name;
}
return itRequest->second->GetPreProcess(name);
} }
void HeteroInferRequest::InferImpl() { void HeteroInferRequest::InferImpl() {
updateInOutIfNeeded();
for (auto &&desc : _inferRequests) { for (auto &&desc : _inferRequests) {
OV_ITT_SCOPED_TASK(itt::domains::HeteroPlugin, desc._profilingTask); OV_ITT_SCOPED_TASK(itt::domains::HeteroPlugin, desc._profilingTask);
auto &r = desc._request; auto &r = desc._request;
@ -101,40 +118,3 @@ std::map<std::string, InferenceEngineProfileInfo> HeteroInferRequest::GetPerform
} }
return perfMap; return perfMap;
} }
void HeteroInferRequest::updateInOutIfNeeded() {
OV_ITT_SCOPED_TASK(itt::domains::HeteroPlugin, "updateInOutIfNeeded");
assert(!_inferRequests.empty());
for (auto &&desc : _inferRequests) {
auto &r = desc._request;
assert(r);
for (auto&& inputInfo : desc._network->GetInputsInfo()) {
auto& ioname = inputInfo.first;
auto iti = _inputs.find(ioname);
if (iti != _inputs.end()) {
auto it = _preProcData.find(ioname);
if (it != _preProcData.end()) {
if (it->second->getRoiBlob() != _blobs[ioname]) {
r->SetBlob(ioname.c_str(), it->second->getRoiBlob());
_blobs[ioname] = iti->second;
}
} else {
if (iti->second != _blobs[ioname]) {
r->SetBlob(ioname.c_str(), iti->second);
_blobs[ioname] = iti->second;
}
}
}
}
for (auto&& outputInfo : desc._network->GetOutputsInfo()) {
auto& ioname = outputInfo.first;
auto ito = _outputs.find(ioname);
if (ito != _outputs.end()) {
if (ito->second != _blobs[ioname]) {
r->SetBlob(ioname.c_str(), ito->second);
_blobs[ioname] = ito->second;
}
}
}
}
}

View File

@ -34,14 +34,21 @@ public:
void InferImpl() override; void InferImpl() override;
void SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& data) override; void SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& blob) override;
InferenceEngine::Blob::Ptr GetBlob(const std::string& name) override;
void SetBlob(const std::string& name,
const InferenceEngine::Blob::Ptr& blob,
const InferenceEngine::PreProcessInfo& info) override;
const InferenceEngine::PreProcessInfo& GetPreProcess(const std::string& name) const override;
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override; std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
void updateInOutIfNeeded();
SubRequestsList _inferRequests; SubRequestsList _inferRequests;
std::map<std::string, InferenceEngine::Blob::Ptr> _blobs; std::map<std::string, InferenceEngine::Blob::Ptr> _blobs;
std::map<std::string, InferenceEngine::IInferRequestInternal*> _subRequestFromBlobName;
}; };
} // namespace HeteroPlugin } // namespace HeteroPlugin