diff --git a/inference-engine/src/auto_batch/auto_batch.cpp b/inference-engine/src/auto_batch/auto_batch.cpp
index 6a17016d71e..83d309c7989 100644
--- a/inference-engine/src/auto_batch/auto_batch.cpp
+++ b/inference-engine/src/auto_batch/auto_batch.cpp
@@ -235,7 +235,8 @@ AutoBatchAsyncInferRequest::~AutoBatchAsyncInferRequest() {
 }
 
 // ------------------------------AutoBatchExecutableNetwork----------------------------
-AutoBatchExecutableNetwork::AutoBatchExecutableNetwork(const InferenceEngine::SoExecutableNetworkInternal& networkForDevice,
+AutoBatchExecutableNetwork::AutoBatchExecutableNetwork(
+    const InferenceEngine::SoExecutableNetworkInternal& networkWithBatch,
     const InferenceEngine::SoExecutableNetworkInternal& networkWithoutBatch,
     const DeviceInformation& networkDevice,
     const std::unordered_map<std::string, InferenceEngine::Parameter>& config,
@@ -244,7 +245,7 @@ AutoBatchExecutableNetwork::AutoBatchExecutableNetwork(const InferenceEngine::So
           nullptr,
           std::make_shared<InferenceEngine::ImmediateExecutor>()),
       _device{networkDevice},
-      _network{networkForDevice},
+      _network{networkWithBatch},
       _networkWithoutBatch{networkWithoutBatch},
       _config{config},
       _needPerfCounters{needPerfCounters} {
@@ -259,7 +260,7 @@ AutoBatchExecutableNetwork::~AutoBatchExecutableNetwork() {
 }
 
 std::shared_ptr<InferenceEngine::RemoteContext> AutoBatchExecutableNetwork::GetContext() const {
-    return _networkWithoutBatch->GetContext();
+    return _network->GetContext();
 }
 
 InferenceEngine::IInferRequestInternal::Ptr AutoBatchExecutableNetwork::CreateInferRequestImpl(InferenceEngine::InputsDataMap networkInputs,
@@ -527,14 +528,14 @@ IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadExeNetworkImpl(con
     const auto perfConfig = fullConfig.find(PluginConfigParams::KEY_PERF_COUNT);
     const bool enablePerfCounters =
         (fullConfig.end() != perfConfig) && (perfConfig->second == PluginConfigParams::YES);
 
-    auto networkWithoutBatch = GetCore()->LoadNetwork(network, deviceName, deviceConfig);
+    auto executableNetworkWithoutBatch = GetCore()->LoadNetwork(network, deviceName, deviceConfig);
     // device settings + auto-batch settings
     std::unordered_map<std::string, std::string> networkConfig;
     networkConfig.insert(*device_batch);
     networkConfig.insert(deviceConfig.begin(), deviceConfig.end());
     // TODO: remove this experimental code that does loop rather than use the batch1 footprint only
-    InferenceEngine::SoExecutableNetworkInternal executableNetworkForDevice;
+    InferenceEngine::SoExecutableNetworkInternal executableNetworkWithBatch;
     do {
         try {
             CNNNetwork clonedNetwork(InferenceEngine::cloneNetwork(network));
@@ -555,20 +556,21 @@ IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadExeNetworkImpl(con
             }
             std::cout << "Reshaped network by batch to " << metaDevice.batchForDevice << std::endl;
             clonedNetwork.reshape(shapes);
-            executableNetworkForDevice = GetCore()->LoadNetwork(CNNNetwork{clonedNetwork}, deviceName, deviceConfig);
+            executableNetworkWithBatch = GetCore()->LoadNetwork(CNNNetwork{clonedNetwork}, deviceName, deviceConfig);
+            IE_ASSERT(executableNetworkWithoutBatch->GetContext() == executableNetworkWithBatch->GetContext());
         } catch (...) {
             // reload the network with smaller batch
-            executableNetworkForDevice = {nullptr, nullptr};
+            executableNetworkWithBatch = {nullptr, nullptr};
             std::cout << "WA for network failure!!!" << std::endl;
             metaDevice.batchForDevice /= 2;
         }
-    } while (!executableNetworkForDevice && (metaDevice.batchForDevice));
-    if (executableNetworkForDevice == nullptr)
+    } while (!executableNetworkWithBatch && (metaDevice.batchForDevice));
+    if (executableNetworkWithBatch == nullptr)
         IE_THROW(NetworkNotLoaded) << "Failed to load Executable network to the device " << deviceName
                                    << "that the BATCH device is initialized to work with";
-    return std::make_shared<AutoBatchExecutableNetwork>(executableNetworkForDevice,
-                                                        networkWithoutBatch,
+    return std::make_shared<AutoBatchExecutableNetwork>(executableNetworkWithBatch,
+                                                        executableNetworkWithoutBatch,
                                                         metaDevice,
                                                         networkConfig,
                                                         enablePerfCounters);
 }