Allow the HETERO code path for an explicit device name such as BATCH:GPU(4); this is used in the HETERO code-path tests

This commit is contained in:
myshevts 2021-11-24 17:16:59 +03:00
parent 2db8ff5ccc
commit 1ebbcef63f

View File

@ -540,28 +540,43 @@ public:
const std::string& deviceNameOrig, const std::string& deviceNameOrig,
const std::map<std::string, std::string>& config) override { const std::map<std::string, std::string>& config) override {
OV_ITT_SCOPE(FIRST_INFERENCE, ie::itt::domains::IE_LT, "Core::LoadNetwork::CNN"); OV_ITT_SCOPE(FIRST_INFERENCE, ie::itt::domains::IE_LT, "Core::LoadNetwork::CNN");
std::string deviceName = deviceNameOrig; std::string deviceName = deviceNameOrig, deviceNameWithBatchSize;
std::map<std::string, std::string> config_with_batch = config; std::map<std::string, std::string> config_with_batch = config;
if (deviceName.find("BATCH") != std::string::npos) {
auto pos = deviceName.find_first_of(":");
if (pos != std::string::npos) {
deviceNameWithBatchSize = deviceName.substr(pos + 1);
config_with_batch[CONFIG_KEY(ALLOW_AUTO_BATCHING)] = CONFIG_VALUE(YES);
}
}
const auto& batch_mode = config_with_batch.find(CONFIG_KEY(ALLOW_AUTO_BATCHING)); const auto& batch_mode = config_with_batch.find(CONFIG_KEY(ALLOW_AUTO_BATCHING));
if (batch_mode != config_with_batch.end() && batch_mode->second == CONFIG_VALUE(YES)) { if (batch_mode != config_with_batch.end() && batch_mode->second == CONFIG_VALUE(YES)) {
auto deviceNameWithoutBatch = !deviceNameWithBatchSize.empty()
? DeviceIDParser::getBatchDevice(deviceNameWithBatchSize)
: deviceNameOrig;
unsigned int requests = 0;
unsigned int optimalBatchSize = 0;
if (deviceNameWithBatchSize.empty()) { // batch size is not set explicitly via device name e.g. BATCH:GPU(4)
// query the optimal batch size
try {
std::map<std::string, ie::Parameter> options; std::map<std::string, ie::Parameter> options;
options["MODEL_ADDRESS"] = &network; options["MODEL_ADDRESS"] = &network;
auto optimalBatchSize = GetCPPPluginByName(DeviceIDParser(deviceNameOrig).getDeviceName()) optimalBatchSize = GetCPPPluginByName(DeviceIDParser(deviceNameWithoutBatch).getDeviceName())
.get_metric(METRIC_KEY(OPTIMAL_BATCH), options) .get_metric(METRIC_KEY(OPTIMAL_BATCH), options)
.as<unsigned int>(); .as<unsigned int>();
unsigned int requests = 0; auto res = GetConfig(deviceNameWithoutBatch, CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)).as<std::string>();
try {
auto res = GetConfig(deviceNameOrig, CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)).as<std::string>();
requests = PerfHintsConfig::CheckPerformanceHintRequestValue(res); requests = PerfHintsConfig::CheckPerformanceHintRequestValue(res);
} catch (...) {
}
const auto &reqs = config.find(CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)); const auto &reqs = config.find(CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS));
if (reqs != config.end()) if (reqs != config.end())
requests = (unsigned int) PerfHintsConfig::CheckPerformanceHintRequestValue(reqs->second); requests = (unsigned int) PerfHintsConfig::CheckPerformanceHintRequestValue(reqs->second);
if (requests) { if (requests) {
std::cout << "!!!!!!!!!!!!!!!Detected reqs_limitation: " << requests << std::endl; std::cout << "!!!!!!!!!!!!!!!Detected reqs_limitation: " << requests << std::endl;
optimalBatchSize = std::min(requests, optimalBatchSize); optimalBatchSize = std::max(1u, std::min(requests, optimalBatchSize));
}
} catch (...) {
}
} }
auto function = network.getFunction(); auto function = network.getFunction();
bool bDetectionOutput = false; bool bDetectionOutput = false;
@ -577,7 +592,7 @@ public:
if (!std::strcmp("DetectionOutput", node->get_type_info().name) || if (!std::strcmp("DetectionOutput", node->get_type_info().name) ||
(!std::strcmp("Result", node->get_type_info().name) && isDetectionOutputParent(node))) { (!std::strcmp("Result", node->get_type_info().name) && isDetectionOutputParent(node))) {
node->get_rt_info()["affinity"] = node->get_rt_info()["affinity"] =
std::make_shared<ngraph::VariantWrapper<std::string>>(deviceNameOrig); std::make_shared<ngraph::VariantWrapper<std::string>>(deviceNameWithoutBatch);
std::cout << "!!! AFF !!! type: " << node->get_type_info().name std::cout << "!!! AFF !!! type: " << node->get_type_info().name
<< ", name: " << node->get_friendly_name() << std::endl; << ", name: " << node->get_friendly_name() << std::endl;
bDetectionOutput = true; bDetectionOutput = true;
@ -585,16 +600,17 @@ public:
node->get_rt_info()["affinity"] = std::make_shared<ngraph::VariantWrapper<std::string>>("BATCH"); node->get_rt_info()["affinity"] = std::make_shared<ngraph::VariantWrapper<std::string>>("BATCH");
} }
} }
if (optimalBatchSize > 1) { if (optimalBatchSize > 1 || !deviceNameWithBatchSize.empty()) {
auto batchConfig = deviceNameWithBatchSize.empty()
? deviceNameWithoutBatch + "(" + std::to_string(optimalBatchSize) + ")"
: deviceNameWithBatchSize;
if (bDetectionOutput) { if (bDetectionOutput) {
deviceName = "HETERO:BATCH," + deviceNameOrig; deviceName = "HETERO:BATCH," + deviceNameWithoutBatch;
std::cout << "HETERO code path!!!!" << std::endl; std::cout << "HETERO code path!!!!" << std::endl;
// config["AUTO_BATCH"] = deviceNameOrig+"("+ std::to_string(optimalBatchSize)+ ")"; // config_with_batch[CONFIG_KEY(AUTO_BATCH)] = batchConfig;
SetConfigForPlugins({{"AUTO_BATCH", deviceNameOrig + "(" + std::to_string(optimalBatchSize) + ")"}}, SetConfigForPlugins({{CONFIG_KEY(AUTO_BATCH), batchConfig}},"BATCH");
"BATCH");
} else { } else {
std::string deviceBatch = "BATCH:" + deviceNameOrig + "(" + std::to_string(optimalBatchSize) + ")"; deviceName = "BATCH:" + batchConfig;
deviceName = deviceBatch;
} }
} }
config_with_batch.erase(batch_mode); config_with_batch.erase(batch_mode);