Blob processing in HETERO forwarded to sub requests (#8012)
* Blob processing in HETERO forwarded to subrequests * Added template dynamic hetero tests
This commit is contained in:
parent
fa38103e5b
commit
4cc53c97a1
@ -10,6 +10,7 @@ addIeTargetTest(
|
||||
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
DEPENDENCIES
|
||||
templatePlugin
|
||||
HeteroPlugin
|
||||
LINK_LIBRARIES
|
||||
IE::funcSharedTests
|
||||
INCLUDES
|
||||
|
@ -14,6 +14,9 @@ const std::vector<std::map<std::string, std::string>> configs = {
|
||||
{}
|
||||
};
|
||||
|
||||
const std::vector<std::map<std::string, std::string>> HeteroConfigs = {
|
||||
{{"TARGET_FALLBACK", CommonTestUtils::DEVICE_TEMPLATE}}};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, InferRequestDynamicTests,
|
||||
::testing::Combine(
|
||||
::testing::Values(ngraph::builder::subgraph::makeSplitConvConcat()),
|
||||
@ -23,4 +26,13 @@ INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, InferRequestDynamicTests,
|
||||
::testing::ValuesIn(configs)),
|
||||
InferRequestDynamicTests::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Hetero_BehaviorTests, InferRequestDynamicTests,
|
||||
::testing::Combine(
|
||||
::testing::Values(ngraph::builder::subgraph::makeSplitConvConcat()),
|
||||
::testing::Values(std::vector<std::pair<std::vector<size_t>, std::vector<size_t>>>{{{1, 4, 20, 20}, {1, 10, 18, 18}},
|
||||
{{2, 4, 20, 20}, {2, 10, 18, 18}}}),
|
||||
::testing::Values(CommonTestUtils::DEVICE_HETERO),
|
||||
::testing::ValuesIn(HeteroConfigs)),
|
||||
InferRequestDynamicTests::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
|
@ -13,6 +13,9 @@ const std::vector<std::map<std::string, std::string>> configs = {
|
||||
{}
|
||||
};
|
||||
|
||||
const std::vector<std::map<std::string, std::string>> HeteroConfigs = {
|
||||
{{"TARGET_FALLBACK", CommonTestUtils::DEVICE_TEMPLATE}}};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, OVInferenceChaining,
|
||||
::testing::Combine(
|
||||
::testing::Values(ov::element::f32),
|
||||
@ -20,4 +23,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, OVInferenceChaining,
|
||||
::testing::ValuesIn(configs)),
|
||||
OVInferenceChaining::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Hetero_BehaviorTests, OVInferenceChaining,
|
||||
::testing::Combine(
|
||||
::testing::Values(ov::element::f32),
|
||||
::testing::Values(CommonTestUtils::DEVICE_HETERO),
|
||||
::testing::ValuesIn(HeteroConfigs)),
|
||||
OVInferenceChaining::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
|
@ -43,11 +43,6 @@ HeteroAsyncInferRequest::HeteroAsyncInferRequest(const IInferRequestInternal::Pt
|
||||
}
|
||||
}
|
||||
|
||||
void HeteroAsyncInferRequest::StartAsync_ThreadUnsafe() {
|
||||
_heteroInferRequest->updateInOutIfNeeded();
|
||||
RunFirstStage(_pipeline.begin(), _pipeline.end());
|
||||
}
|
||||
|
||||
StatusCode HeteroAsyncInferRequest::Wait(int64_t millis_timeout) {
|
||||
auto waitStatus = StatusCode::OK;
|
||||
try {
|
||||
|
@ -18,7 +18,6 @@ public:
|
||||
const InferenceEngine::ITaskExecutor::Ptr& taskExecutor,
|
||||
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
|
||||
~HeteroAsyncInferRequest();
|
||||
void StartAsync_ThreadUnsafe() override;
|
||||
InferenceEngine::StatusCode Wait(int64_t millis_timeout) override;
|
||||
|
||||
private:
|
||||
|
@ -36,11 +36,14 @@ HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInp
|
||||
bool emplaced = false;
|
||||
std::tie(itBlob, emplaced) = _blobs.emplace(intermediateBlobName, Blob::Ptr{});
|
||||
if (emplaced) {
|
||||
itBlob->second = r->GetBlob(blobName);
|
||||
if (InferenceEngine::details::contains(networkInputs, blobName)) {
|
||||
_inputs[blobName] = itBlob->second;
|
||||
} else if (InferenceEngine::details::contains(networkOutputs, blobName)) {
|
||||
_outputs[blobName] = itBlob->second;
|
||||
if (InferenceEngine::details::contains(_networkInputs, blobName)) {
|
||||
_subRequestFromBlobName.emplace(blobName, r._ptr.get());
|
||||
_blobs.erase(intermediateBlobName);
|
||||
} else if (InferenceEngine::details::contains(_networkOutputs, blobName)) {
|
||||
_subRequestFromBlobName.emplace(blobName, r._ptr.get());
|
||||
_blobs.erase(intermediateBlobName);
|
||||
} else {
|
||||
itBlob->second = r->GetBlob(blobName);
|
||||
}
|
||||
} else {
|
||||
r->SetBlob(blobName, itBlob->second);
|
||||
@ -64,25 +67,39 @@ HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInp
|
||||
}
|
||||
}
|
||||
|
||||
void HeteroInferRequest::SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& data) {
|
||||
InferenceEngine::IInferRequestInternal::SetBlob(name, data);
|
||||
assert(!_inferRequests.empty());
|
||||
for (auto &&desc : _inferRequests) {
|
||||
auto &r = desc._request;
|
||||
assert(r);
|
||||
InputInfo::Ptr foundInput;
|
||||
DataPtr foundOutput;
|
||||
try {
|
||||
// if `name` is input blob
|
||||
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
|
||||
r->SetBlob(name, data, foundInput->getPreProcess());
|
||||
}
|
||||
} catch (const InferenceEngine::NotFound&) {}
|
||||
void HeteroInferRequest::SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& blob) {
|
||||
auto itRequest = _subRequestFromBlobName.find(name);
|
||||
if (itRequest == _subRequestFromBlobName.end()) {
|
||||
IE_THROW() << "There is no infer requests binded to blob with name: " << name;
|
||||
}
|
||||
itRequest->second->SetBlob(name, blob);
|
||||
}
|
||||
|
||||
InferenceEngine::Blob::Ptr HeteroInferRequest::GetBlob(const std::string& name) {
|
||||
auto itRequest = _subRequestFromBlobName.find(name);
|
||||
if (itRequest == _subRequestFromBlobName.end()) {
|
||||
IE_THROW() << "There is no infer requests binded to blob with name: " << name;
|
||||
}
|
||||
return itRequest->second->GetBlob(name);
|
||||
}
|
||||
|
||||
void HeteroInferRequest::SetBlob(const std::string& name, const Blob::Ptr& blob, const PreProcessInfo& info) {
|
||||
auto itRequest = _subRequestFromBlobName.find(name);
|
||||
if (itRequest == _subRequestFromBlobName.end()) {
|
||||
IE_THROW() << "There is no infer requests binded to blob with name: " << name;
|
||||
}
|
||||
itRequest->second->SetBlob(name, blob, info);
|
||||
}
|
||||
|
||||
const InferenceEngine::PreProcessInfo& HeteroInferRequest::GetPreProcess(const std::string& name) const {
|
||||
auto itRequest = _subRequestFromBlobName.find(name);
|
||||
if (itRequest == _subRequestFromBlobName.end()) {
|
||||
IE_THROW() << "There is no infer requests binded to blob with name: " << name;
|
||||
}
|
||||
return itRequest->second->GetPreProcess(name);
|
||||
}
|
||||
|
||||
void HeteroInferRequest::InferImpl() {
|
||||
updateInOutIfNeeded();
|
||||
for (auto &&desc : _inferRequests) {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::HeteroPlugin, desc._profilingTask);
|
||||
auto &r = desc._request;
|
||||
@ -101,40 +118,3 @@ std::map<std::string, InferenceEngineProfileInfo> HeteroInferRequest::GetPerform
|
||||
}
|
||||
return perfMap;
|
||||
}
|
||||
|
||||
void HeteroInferRequest::updateInOutIfNeeded() {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::HeteroPlugin, "updateInOutIfNeeded");
|
||||
assert(!_inferRequests.empty());
|
||||
for (auto &&desc : _inferRequests) {
|
||||
auto &r = desc._request;
|
||||
assert(r);
|
||||
for (auto&& inputInfo : desc._network->GetInputsInfo()) {
|
||||
auto& ioname = inputInfo.first;
|
||||
auto iti = _inputs.find(ioname);
|
||||
if (iti != _inputs.end()) {
|
||||
auto it = _preProcData.find(ioname);
|
||||
if (it != _preProcData.end()) {
|
||||
if (it->second->getRoiBlob() != _blobs[ioname]) {
|
||||
r->SetBlob(ioname.c_str(), it->second->getRoiBlob());
|
||||
_blobs[ioname] = iti->second;
|
||||
}
|
||||
} else {
|
||||
if (iti->second != _blobs[ioname]) {
|
||||
r->SetBlob(ioname.c_str(), iti->second);
|
||||
_blobs[ioname] = iti->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto&& outputInfo : desc._network->GetOutputsInfo()) {
|
||||
auto& ioname = outputInfo.first;
|
||||
auto ito = _outputs.find(ioname);
|
||||
if (ito != _outputs.end()) {
|
||||
if (ito->second != _blobs[ioname]) {
|
||||
r->SetBlob(ioname.c_str(), ito->second);
|
||||
_blobs[ioname] = ito->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -34,14 +34,21 @@ public:
|
||||
|
||||
void InferImpl() override;
|
||||
|
||||
void SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& data) override;
|
||||
void SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& blob) override;
|
||||
|
||||
InferenceEngine::Blob::Ptr GetBlob(const std::string& name) override;
|
||||
|
||||
void SetBlob(const std::string& name,
|
||||
const InferenceEngine::Blob::Ptr& blob,
|
||||
const InferenceEngine::PreProcessInfo& info) override;
|
||||
|
||||
const InferenceEngine::PreProcessInfo& GetPreProcess(const std::string& name) const override;
|
||||
|
||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
|
||||
|
||||
void updateInOutIfNeeded();
|
||||
|
||||
SubRequestsList _inferRequests;
|
||||
std::map<std::string, InferenceEngine::Blob::Ptr> _blobs;
|
||||
std::map<std::string, InferenceEngine::Blob::Ptr> _blobs;
|
||||
std::map<std::string, InferenceEngine::IInferRequestInternal*> _subRequestFromBlobName;
|
||||
};
|
||||
|
||||
} // namespace HeteroPlugin
|
||||
|
Loading…
Reference in New Issue
Block a user