From 403339f8f470c90dee6f6d94ed58644b2787f66b Mon Sep 17 00:00:00 2001 From: Yuan Hu Date: Wed, 19 Jan 2022 22:13:29 +0800 Subject: [PATCH] [AUTOPLUGIN] not select if only one device (#9730) * if only one Device, not select Signed-off-by: Hu, Yuan2 * modify test case to match logic Signed-off-by: Hu, Yuan2 --- src/plugins/auto/plugin.cpp | 34 ++++++++++++---------- src/tests/unit/auto/select_device_test.cpp | 21 ++++++++----- 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/src/plugins/auto/plugin.cpp b/src/plugins/auto/plugin.cpp index be02361c1fc..cc82be25d4b 100644 --- a/src/plugins/auto/plugin.cpp +++ b/src/plugins/auto/plugin.cpp @@ -455,23 +455,27 @@ DeviceInformation MultiDeviceInferencePlugin::SelectDevice(const std::vector validDevices; - auto selectSupportDev = [this, &devices, &validDevices](const std::string& networkPrecision) { - for (auto iter = devices.begin(); iter != devices.end();) { - std::vector capability = GetCore()->GetMetric(iter->deviceName, METRIC_KEY(OPTIMIZATION_CAPABILITIES)); - auto supportNetwork = std::find(capability.begin(), capability.end(), (networkPrecision)); - if (supportNetwork != capability.end()) { - validDevices.push_back(std::move(*iter)); - devices.erase(iter++); - continue; + if (metaDevices.size() > 1) { + auto selectSupportDev = [this, &devices, &validDevices](const std::string& networkPrecision) { + for (auto iter = devices.begin(); iter != devices.end();) { + std::vector capability = GetCore()->GetMetric(iter->deviceName, METRIC_KEY(OPTIMIZATION_CAPABILITIES)); + auto supportNetwork = std::find(capability.begin(), capability.end(), (networkPrecision)); + if (supportNetwork != capability.end()) { + validDevices.push_back(std::move(*iter)); + devices.erase(iter++); + continue; + } + iter++; } - iter++; + }; + selectSupportDev(networkPrecision); + // If network is FP32, continue to collect the device support FP16 but not support FP32. + if (networkPrecision == "FP32") { + const std::string f16 = "FP16"; + selectSupportDev(f16); } - }; - selectSupportDev(networkPrecision); - // If network is FP32, continue to collect the device support FP16 but not support FP32. - if (networkPrecision == "FP32") { - const std::string f16 = "FP16"; - selectSupportDev(f16); + } else { + validDevices.push_back(metaDevices[0]); } if (validDevices.empty()) { diff --git a/src/tests/unit/auto/select_device_test.cpp b/src/tests/unit/auto/select_device_test.cpp index 89d1f6e0853..c8a94da0f4c 100644 --- a/src/tests/unit/auto/select_device_test.cpp +++ b/src/tests/unit/auto/select_device_test.cpp @@ -106,14 +106,21 @@ public: auto& devicesInfo = devicesMap[netPrecision]; bool find = false; DeviceInformation expect; - for (auto& item : devicesInfo) { - auto device = std::find_if(metaDevices.begin(), metaDevices.end(), - [&item](const DeviceInformation& d)->bool{return d.uniqueName == item.uniqueName;}); - if (device != metaDevices.end()) { - find = true; - expect = item; - break; + if (metaDevices.size() > 1) { + for (auto& item : devicesInfo) { + auto device = std::find_if(metaDevices.begin(), metaDevices.end(), + [&item](const DeviceInformation& d)->bool{return d.uniqueName == item.uniqueName;}); + if (device != metaDevices.end()) { + find = true; + expect = item; + break; + } } + } else if (metaDevices.size() == 1) { + expect = metaDevices[0]; + find = true; + } else { + find = false; } testConfigs.push_back(std::make_tuple(netPrecision, metaDevices, expect, !find)); } else {