diff --git a/inference-engine/src/inference_engine/src/ie_core.cpp b/inference-engine/src/inference_engine/src/ie_core.cpp index 16f5acf1f6f..352f938c142 100644 --- a/inference-engine/src/inference_engine/src/ie_core.cpp +++ b/inference-engine/src/inference_engine/src/ie_core.cpp @@ -540,28 +540,43 @@ public: const std::string& deviceNameOrig, const std::map& config) override { OV_ITT_SCOPE(FIRST_INFERENCE, ie::itt::domains::IE_LT, "Core::LoadNetwork::CNN"); - std::string deviceName = deviceNameOrig; + std::string deviceName = deviceNameOrig, deviceNameWithBatchSize; std::map config_with_batch = config; + if (deviceName.find("BATCH") != std::string::npos) { + auto pos = deviceName.find_first_of(":"); + if (pos != std::string::npos) { + deviceNameWithBatchSize = deviceName.substr(pos + 1); + config_with_batch[CONFIG_KEY(ALLOW_AUTO_BATCHING)] = CONFIG_VALUE(YES); + } + } + const auto& batch_mode = config_with_batch.find(CONFIG_KEY(ALLOW_AUTO_BATCHING)); if (batch_mode != config_with_batch.end() && batch_mode->second == CONFIG_VALUE(YES)) { - std::map options; - options["MODEL_ADDRESS"] = &network; - auto optimalBatchSize = GetCPPPluginByName(DeviceIDParser(deviceNameOrig).getDeviceName()) - .get_metric(METRIC_KEY(OPTIMAL_BATCH), options) - .as(); + auto deviceNameWithoutBatch = !deviceNameWithBatchSize.empty() + ? DeviceIDParser::getBatchDevice(deviceNameWithBatchSize) + : deviceNameOrig; unsigned int requests = 0; - try { - auto res = GetConfig(deviceNameOrig, CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)).as(); - requests = PerfHintsConfig::CheckPerformanceHintRequestValue(res); - } catch (...) { - } - const auto& reqs = config.find(CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)); - if (reqs != config.end()) - requests = (unsigned int)PerfHintsConfig::CheckPerformanceHintRequestValue(reqs->second); - if (requests) { - std::cout << "!!!!!!!!!!!!!!!Detected reqs_limitation: " << requests << std::endl; - optimalBatchSize = std::min(requests, optimalBatchSize); + unsigned int optimalBatchSize = 0; + if (deviceNameWithBatchSize.empty()) { // batch size is not set explicitly via device name e.g. BATCH:GPU(4) + // query the optimal batch size + try { + std::map options; + options["MODEL_ADDRESS"] = &network; + optimalBatchSize = GetCPPPluginByName(DeviceIDParser(deviceNameWithoutBatch).getDeviceName()) + .get_metric(METRIC_KEY(OPTIMAL_BATCH), options) + .as(); + auto res = GetConfig(deviceNameWithoutBatch, CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)).as(); + requests = PerfHintsConfig::CheckPerformanceHintRequestValue(res); + const auto &reqs = config.find(CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)); + if (reqs != config.end()) + requests = (unsigned int) PerfHintsConfig::CheckPerformanceHintRequestValue(reqs->second); + if (requests) { + std::cout << "!!!!!!!!!!!!!!!Detected reqs_limitation: " << requests << std::endl; + optimalBatchSize = std::max(1u, std::min(requests, optimalBatchSize)); + } + } catch (...) { + } } auto function = network.getFunction(); bool bDetectionOutput = false; @@ -577,7 +592,7 @@ public: if (!std::strcmp("DetectionOutput", node->get_type_info().name) || (!std::strcmp("Result", node->get_type_info().name) && isDetectionOutputParent(node))) { node->get_rt_info()["affinity"] = - std::make_shared>(deviceNameOrig); + std::make_shared>(deviceNameWithoutBatch); std::cout << "!!! AFF !!! type: " << node->get_type_info().name << ", name: " << node->get_friendly_name() << std::endl; bDetectionOutput = true; @@ -585,16 +600,17 @@ public: node->get_rt_info()["affinity"] = std::make_shared>("BATCH"); } } - if (optimalBatchSize > 1) { + if (optimalBatchSize > 1 || !deviceNameWithBatchSize.empty()) { + auto batchConfig = deviceNameWithBatchSize.empty() + ? deviceNameWithoutBatch + "(" + std::to_string(optimalBatchSize) + ")" + : deviceNameWithBatchSize; if (bDetectionOutput) { - deviceName = "HETERO:BATCH," + deviceNameOrig; + deviceName = "HETERO:BATCH," + deviceNameWithoutBatch; std::cout << "HETERO code path!!!!" << std::endl; - // config["AUTO_BATCH"] = deviceNameOrig+"("+ std::to_string(optimalBatchSize)+ ")"; - SetConfigForPlugins({{"AUTO_BATCH", deviceNameOrig + "(" + std::to_string(optimalBatchSize) + ")"}}, - "BATCH"); + // config_with_batch[CONFIG_KEY(AUTO_BATCH)] = batchConfig; + SetConfigForPlugins({{CONFIG_KEY(AUTO_BATCH), batchConfig}},"BATCH"); } else { - std::string deviceBatch = "BATCH:" + deviceNameOrig + "(" + std::to_string(optimalBatchSize) + ")"; - deviceName = deviceBatch; + deviceName = "BATCH:" + batchConfig; } } config_with_batch.erase(batch_mode);