Blob processing in HETERO forwarded to sub requests (#8012)
* Blob processing in HETERO forwarded to subrequests * Added template dynamic hetero tests
This commit is contained in:
parent
fa38103e5b
commit
4cc53c97a1
@ -10,6 +10,7 @@ addIeTargetTest(
|
|||||||
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
|
ROOT ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
DEPENDENCIES
|
DEPENDENCIES
|
||||||
templatePlugin
|
templatePlugin
|
||||||
|
HeteroPlugin
|
||||||
LINK_LIBRARIES
|
LINK_LIBRARIES
|
||||||
IE::funcSharedTests
|
IE::funcSharedTests
|
||||||
INCLUDES
|
INCLUDES
|
||||||
|
@ -14,6 +14,9 @@ const std::vector<std::map<std::string, std::string>> configs = {
|
|||||||
{}
|
{}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const std::vector<std::map<std::string, std::string>> HeteroConfigs = {
|
||||||
|
{{"TARGET_FALLBACK", CommonTestUtils::DEVICE_TEMPLATE}}};
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, InferRequestDynamicTests,
|
INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, InferRequestDynamicTests,
|
||||||
::testing::Combine(
|
::testing::Combine(
|
||||||
::testing::Values(ngraph::builder::subgraph::makeSplitConvConcat()),
|
::testing::Values(ngraph::builder::subgraph::makeSplitConvConcat()),
|
||||||
@ -23,4 +26,13 @@ INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, InferRequestDynamicTests,
|
|||||||
::testing::ValuesIn(configs)),
|
::testing::ValuesIn(configs)),
|
||||||
InferRequestDynamicTests::getTestCaseName);
|
InferRequestDynamicTests::getTestCaseName);
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_Hetero_BehaviorTests, InferRequestDynamicTests,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::Values(ngraph::builder::subgraph::makeSplitConvConcat()),
|
||||||
|
::testing::Values(std::vector<std::pair<std::vector<size_t>, std::vector<size_t>>>{{{1, 4, 20, 20}, {1, 10, 18, 18}},
|
||||||
|
{{2, 4, 20, 20}, {2, 10, 18, 18}}}),
|
||||||
|
::testing::Values(CommonTestUtils::DEVICE_HETERO),
|
||||||
|
::testing::ValuesIn(HeteroConfigs)),
|
||||||
|
InferRequestDynamicTests::getTestCaseName);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -13,6 +13,9 @@ const std::vector<std::map<std::string, std::string>> configs = {
|
|||||||
{}
|
{}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const std::vector<std::map<std::string, std::string>> HeteroConfigs = {
|
||||||
|
{{"TARGET_FALLBACK", CommonTestUtils::DEVICE_TEMPLATE}}};
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, OVInferenceChaining,
|
INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, OVInferenceChaining,
|
||||||
::testing::Combine(
|
::testing::Combine(
|
||||||
::testing::Values(ov::element::f32),
|
::testing::Values(ov::element::f32),
|
||||||
@ -20,4 +23,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_BehaviorTests, OVInferenceChaining,
|
|||||||
::testing::ValuesIn(configs)),
|
::testing::ValuesIn(configs)),
|
||||||
OVInferenceChaining::getTestCaseName);
|
OVInferenceChaining::getTestCaseName);
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(smoke_Hetero_BehaviorTests, OVInferenceChaining,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::Values(ov::element::f32),
|
||||||
|
::testing::Values(CommonTestUtils::DEVICE_HETERO),
|
||||||
|
::testing::ValuesIn(HeteroConfigs)),
|
||||||
|
OVInferenceChaining::getTestCaseName);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -43,11 +43,6 @@ HeteroAsyncInferRequest::HeteroAsyncInferRequest(const IInferRequestInternal::Pt
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void HeteroAsyncInferRequest::StartAsync_ThreadUnsafe() {
|
|
||||||
_heteroInferRequest->updateInOutIfNeeded();
|
|
||||||
RunFirstStage(_pipeline.begin(), _pipeline.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusCode HeteroAsyncInferRequest::Wait(int64_t millis_timeout) {
|
StatusCode HeteroAsyncInferRequest::Wait(int64_t millis_timeout) {
|
||||||
auto waitStatus = StatusCode::OK;
|
auto waitStatus = StatusCode::OK;
|
||||||
try {
|
try {
|
||||||
|
@ -18,7 +18,6 @@ public:
|
|||||||
const InferenceEngine::ITaskExecutor::Ptr& taskExecutor,
|
const InferenceEngine::ITaskExecutor::Ptr& taskExecutor,
|
||||||
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
|
const InferenceEngine::ITaskExecutor::Ptr& callbackExecutor);
|
||||||
~HeteroAsyncInferRequest();
|
~HeteroAsyncInferRequest();
|
||||||
void StartAsync_ThreadUnsafe() override;
|
|
||||||
InferenceEngine::StatusCode Wait(int64_t millis_timeout) override;
|
InferenceEngine::StatusCode Wait(int64_t millis_timeout) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -36,11 +36,14 @@ HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInp
|
|||||||
bool emplaced = false;
|
bool emplaced = false;
|
||||||
std::tie(itBlob, emplaced) = _blobs.emplace(intermediateBlobName, Blob::Ptr{});
|
std::tie(itBlob, emplaced) = _blobs.emplace(intermediateBlobName, Blob::Ptr{});
|
||||||
if (emplaced) {
|
if (emplaced) {
|
||||||
|
if (InferenceEngine::details::contains(_networkInputs, blobName)) {
|
||||||
|
_subRequestFromBlobName.emplace(blobName, r._ptr.get());
|
||||||
|
_blobs.erase(intermediateBlobName);
|
||||||
|
} else if (InferenceEngine::details::contains(_networkOutputs, blobName)) {
|
||||||
|
_subRequestFromBlobName.emplace(blobName, r._ptr.get());
|
||||||
|
_blobs.erase(intermediateBlobName);
|
||||||
|
} else {
|
||||||
itBlob->second = r->GetBlob(blobName);
|
itBlob->second = r->GetBlob(blobName);
|
||||||
if (InferenceEngine::details::contains(networkInputs, blobName)) {
|
|
||||||
_inputs[blobName] = itBlob->second;
|
|
||||||
} else if (InferenceEngine::details::contains(networkOutputs, blobName)) {
|
|
||||||
_outputs[blobName] = itBlob->second;
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
r->SetBlob(blobName, itBlob->second);
|
r->SetBlob(blobName, itBlob->second);
|
||||||
@ -64,25 +67,39 @@ HeteroInferRequest::HeteroInferRequest(InferenceEngine::InputsDataMap networkInp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void HeteroInferRequest::SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& data) {
|
void HeteroInferRequest::SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& blob) {
|
||||||
InferenceEngine::IInferRequestInternal::SetBlob(name, data);
|
auto itRequest = _subRequestFromBlobName.find(name);
|
||||||
assert(!_inferRequests.empty());
|
if (itRequest == _subRequestFromBlobName.end()) {
|
||||||
for (auto &&desc : _inferRequests) {
|
IE_THROW() << "There is no infer requests binded to blob with name: " << name;
|
||||||
auto &r = desc._request;
|
|
||||||
assert(r);
|
|
||||||
InputInfo::Ptr foundInput;
|
|
||||||
DataPtr foundOutput;
|
|
||||||
try {
|
|
||||||
// if `name` is input blob
|
|
||||||
if (findInputAndOutputBlobByName(name, foundInput, foundOutput)) {
|
|
||||||
r->SetBlob(name, data, foundInput->getPreProcess());
|
|
||||||
}
|
}
|
||||||
} catch (const InferenceEngine::NotFound&) {}
|
itRequest->second->SetBlob(name, blob);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
InferenceEngine::Blob::Ptr HeteroInferRequest::GetBlob(const std::string& name) {
|
||||||
|
auto itRequest = _subRequestFromBlobName.find(name);
|
||||||
|
if (itRequest == _subRequestFromBlobName.end()) {
|
||||||
|
IE_THROW() << "There is no infer requests binded to blob with name: " << name;
|
||||||
|
}
|
||||||
|
return itRequest->second->GetBlob(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
void HeteroInferRequest::SetBlob(const std::string& name, const Blob::Ptr& blob, const PreProcessInfo& info) {
|
||||||
|
auto itRequest = _subRequestFromBlobName.find(name);
|
||||||
|
if (itRequest == _subRequestFromBlobName.end()) {
|
||||||
|
IE_THROW() << "There is no infer requests binded to blob with name: " << name;
|
||||||
|
}
|
||||||
|
itRequest->second->SetBlob(name, blob, info);
|
||||||
|
}
|
||||||
|
|
||||||
|
const InferenceEngine::PreProcessInfo& HeteroInferRequest::GetPreProcess(const std::string& name) const {
|
||||||
|
auto itRequest = _subRequestFromBlobName.find(name);
|
||||||
|
if (itRequest == _subRequestFromBlobName.end()) {
|
||||||
|
IE_THROW() << "There is no infer requests binded to blob with name: " << name;
|
||||||
|
}
|
||||||
|
return itRequest->second->GetPreProcess(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
void HeteroInferRequest::InferImpl() {
|
void HeteroInferRequest::InferImpl() {
|
||||||
updateInOutIfNeeded();
|
|
||||||
for (auto &&desc : _inferRequests) {
|
for (auto &&desc : _inferRequests) {
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::HeteroPlugin, desc._profilingTask);
|
OV_ITT_SCOPED_TASK(itt::domains::HeteroPlugin, desc._profilingTask);
|
||||||
auto &r = desc._request;
|
auto &r = desc._request;
|
||||||
@ -101,40 +118,3 @@ std::map<std::string, InferenceEngineProfileInfo> HeteroInferRequest::GetPerform
|
|||||||
}
|
}
|
||||||
return perfMap;
|
return perfMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
void HeteroInferRequest::updateInOutIfNeeded() {
|
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::HeteroPlugin, "updateInOutIfNeeded");
|
|
||||||
assert(!_inferRequests.empty());
|
|
||||||
for (auto &&desc : _inferRequests) {
|
|
||||||
auto &r = desc._request;
|
|
||||||
assert(r);
|
|
||||||
for (auto&& inputInfo : desc._network->GetInputsInfo()) {
|
|
||||||
auto& ioname = inputInfo.first;
|
|
||||||
auto iti = _inputs.find(ioname);
|
|
||||||
if (iti != _inputs.end()) {
|
|
||||||
auto it = _preProcData.find(ioname);
|
|
||||||
if (it != _preProcData.end()) {
|
|
||||||
if (it->second->getRoiBlob() != _blobs[ioname]) {
|
|
||||||
r->SetBlob(ioname.c_str(), it->second->getRoiBlob());
|
|
||||||
_blobs[ioname] = iti->second;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (iti->second != _blobs[ioname]) {
|
|
||||||
r->SetBlob(ioname.c_str(), iti->second);
|
|
||||||
_blobs[ioname] = iti->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (auto&& outputInfo : desc._network->GetOutputsInfo()) {
|
|
||||||
auto& ioname = outputInfo.first;
|
|
||||||
auto ito = _outputs.find(ioname);
|
|
||||||
if (ito != _outputs.end()) {
|
|
||||||
if (ito->second != _blobs[ioname]) {
|
|
||||||
r->SetBlob(ioname.c_str(), ito->second);
|
|
||||||
_blobs[ioname] = ito->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -34,14 +34,21 @@ public:
|
|||||||
|
|
||||||
void InferImpl() override;
|
void InferImpl() override;
|
||||||
|
|
||||||
void SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& data) override;
|
void SetBlob(const std::string& name, const InferenceEngine::Blob::Ptr& blob) override;
|
||||||
|
|
||||||
|
InferenceEngine::Blob::Ptr GetBlob(const std::string& name) override;
|
||||||
|
|
||||||
|
void SetBlob(const std::string& name,
|
||||||
|
const InferenceEngine::Blob::Ptr& blob,
|
||||||
|
const InferenceEngine::PreProcessInfo& info) override;
|
||||||
|
|
||||||
|
const InferenceEngine::PreProcessInfo& GetPreProcess(const std::string& name) const override;
|
||||||
|
|
||||||
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
|
std::map<std::string, InferenceEngine::InferenceEngineProfileInfo> GetPerformanceCounts() const override;
|
||||||
|
|
||||||
void updateInOutIfNeeded();
|
|
||||||
|
|
||||||
SubRequestsList _inferRequests;
|
SubRequestsList _inferRequests;
|
||||||
std::map<std::string, InferenceEngine::Blob::Ptr> _blobs;
|
std::map<std::string, InferenceEngine::Blob::Ptr> _blobs;
|
||||||
|
std::map<std::string, InferenceEngine::IInferRequestInternal*> _subRequestFromBlobName;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace HeteroPlugin
|
} // namespace HeteroPlugin
|
||||||
|
Loading…
Reference in New Issue
Block a user