diff --git a/inference-engine/src/auto_batch/auto_batch.cpp b/inference-engine/src/auto_batch/auto_batch.cpp
index 4b5d6bf502e..915d561d776 100644
--- a/inference-engine/src/auto_batch/auto_batch.cpp
+++ b/inference-engine/src/auto_batch/auto_batch.cpp
@@ -535,7 +535,6 @@ IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadExeNetworkImpl(con
     // TODO: remove this experimental code that does loop rather than use the batch1 footprint only
     InferenceEngine::SoExecutableNetworkInternal executableNetworkForDevice;
-    auto needLoop = deviceName.find("GPU") != std::string::npos;
     do {
         try {
             CNNNetwork clonedNetwork(InferenceEngine::cloneNetwork(network));
@@ -555,27 +554,27 @@ IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadExeNetworkImpl(con
                 }
             }
             std::cout << "Reshaped network by batch to " << metaDevice.batchForDevice << std::endl;
-            clonedNetwork.reshape(shapes);
-            executableNetworkForDevice = GetCore()->LoadNetwork(CNNNetwork{clonedNetwork}, deviceName, deviceConfig);
-            if (executableNetworkForDevice == nullptr)
-                IE_THROW(NotFound) << "Failed to load Executable network the device "
-                                   << "that the BATCH device is initialized to work with";
+            executableNetworkForDevice = GetCore()->LoadNetwork(CNNNetwork{clonedNetwork}, deviceName, deviceConfig);
             if (deviceName.find("GPU") != std::string::npos) {
                 const uint64_t total_mem = GetCore()->GetMetric(deviceName, GPU_METRIC_KEY(DEVICE_TOTAL_MEM_SIZE));
                 const uint64_t footprint = executableNetworkForDevice->GetMetric(GPU_METRIC_KEY(NETWORK_MEM_FOOTPRINT));
                 std::cout << "!!!!!!!!!!!!!! (BATCHED):" << footprint << std::endl;
-                if (footprint < total_mem)
-                    break;
-                else  // WA for inaccurate footprint estimations
+                if (footprint > total_mem)  // WA for inaccurate footprint estimations
                     throw NETWORK_NOT_LOADED;
             }
         } catch (...) {
+            // reload the network with smaller batch
+            executableNetworkForDevice = {nullptr, nullptr};
             std::cout << "WA for network failure!!!" << std::endl;
             metaDevice.batchForDevice /= 2;
         }
-    } while (needLoop && (metaDevice.batchForDevice));
+    } while (!executableNetworkForDevice && (metaDevice.batchForDevice));
+    if (executableNetworkForDevice == nullptr)
+        IE_THROW(NetworkNotLoaded) << "Failed to load Executable network to the device " << deviceName
+                                   << " that the BATCH device is initialized to work with";
+    return std::make_shared<AutoBatchExecutableNetwork>(executableNetworkForDevice, networkWithoutBatch, metaDevice,
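
For context, the change above reduces to a simple retry pattern: try to load the network with the requested batch, and on any failure (including the GPU footprint exceeding total device memory) halve the batch and try again, throwing only once no batch size works. The sketch below shows that pattern in isolation; it is not the plugin code itself, and `ExecNetwork`, `tryLoadWithBatch`, `loadWithBatchFallback`, and the batch-of-8 fit threshold are all hypothetical stand-ins for `GetCore()->LoadNetwork(...)` and the real device-memory check.

```cpp
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the executable network handle.
struct ExecNetwork {
    int batch;
};

// Hypothetical loader: pretends that batches above 8 do not fit into device memory,
// mimicking the NETWORK_NOT_LOADED / out-of-memory failures the plugin works around.
std::shared_ptr<ExecNetwork> tryLoadWithBatch(int batch) {
    if (batch > 8)
        throw std::runtime_error("network does not fit with batch " + std::to_string(batch));
    return std::make_shared<ExecNetwork>(ExecNetwork{batch});
}

// Same retry pattern as in the diff: halve the batch on failure until the network
// loads or the batch reaches 0, then report a hard failure.
std::shared_ptr<ExecNetwork> loadWithBatchFallback(int batch) {
    std::shared_ptr<ExecNetwork> execNetwork;
    do {
        try {
            execNetwork = tryLoadWithBatch(batch);
        } catch (...) {
            execNetwork = nullptr;  // retry with a smaller batch on the next iteration
            batch /= 2;
        }
    } while (!execNetwork && batch);
    if (!execNetwork)
        throw std::runtime_error("failed to load the network with any batch size");
    return execNetwork;
}

int main() {
    auto net = loadWithBatchFallback(32);                           // 32 -> 16 -> 8
    std::cout << "Loaded with batch " << net->batch << std::endl;   // prints 8
    return 0;
}
```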