From 43266c585e90efc40d5705490a03cd1b87d51b5e Mon Sep 17 00:00:00 2001
From: myshevts
Date: Tue, 30 Nov 2021 19:06:48 +0300
Subject: [PATCH] improved remote-blobs tests

---
 .../src/auto_batch/auto_batch.cpp             |  7 ++--
 .../cldnn_remote_blob_tests.cpp               | 11 ++++-
 src/inference/src/ie_core.cpp                 | 42 +++++++++----------
 3 files changed, 34 insertions(+), 26 deletions(-)

diff --git a/inference-engine/src/auto_batch/auto_batch.cpp b/inference-engine/src/auto_batch/auto_batch.cpp
index de4309ab3d6..62bb722d71a 100644
--- a/inference-engine/src/auto_batch/auto_batch.cpp
+++ b/inference-engine/src/auto_batch/auto_batch.cpp
@@ -465,7 +465,7 @@ DeviceInformation AutoBatchInferencePlugin::ParseMetaDevice(const std::string& d
     return metaDevice;
 }
 
-RemoteContext::Ptr AutoBatchInferencePlugin::CreateContext(const InferenceEngine::ParamMap& config) {
+RemoteContext::Ptr AutoBatchInferencePlugin::CreateContext(const InferenceEngine::ParamMap& config) {
     auto cfg = config;
     auto it = cfg.find(CONFIG_KEY(AUTO_BATCH));
     if (it == cfg.end())
@@ -474,7 +474,7 @@ RemoteContext::Ptr AutoBatchInferencePlugin::CreateContext(const InferenceEngin
     auto val = it->second;
     auto metaDevice = ParseMetaDevice(val, {{}});
     cfg.erase(it);
-
+    std::cout << "AutoBatchInferencePlugin::CreateContext" << std::endl;
     return GetCore()->CreateContext(metaDevice.deviceName, cfg);
 }
 
@@ -604,7 +604,8 @@ InferenceEngine::IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadE
     const InferenceEngine::CNNNetwork& network,
     const std::shared_ptr<RemoteContext>& context,
     const std::map<std::string, std::string>& config) {
-    return LoadNetworkImpl(network, context, config);
+    std::cout << "AutoBatchInferencePlugin::LoadExeNetworkImpl with context for " << context->getDeviceName() << std::endl;
+    return LoadNetworkImpl(network, context, config);
 }
 
 InferenceEngine::QueryNetworkResult AutoBatchInferencePlugin::QueryNetwork(const InferenceEngine::CNNNetwork& network,
diff --git a/inference-engine/tests/functional/plugin/gpu/remote_blob_tests/cldnn_remote_blob_tests.cpp b/inference-engine/tests/functional/plugin/gpu/remote_blob_tests/cldnn_remote_blob_tests.cpp
index 2026d042ee1..35024ff1ee7 100644
--- a/inference-engine/tests/functional/plugin/gpu/remote_blob_tests/cldnn_remote_blob_tests.cpp
+++ b/inference-engine/tests/functional/plugin/gpu/remote_blob_tests/cldnn_remote_blob_tests.cpp
@@ -25,13 +25,17 @@ class RemoteBlob_Test : public CommonTestUtils::TestsCommon, public testing::Wit
 protected:
     std::shared_ptr<ngraph::Function> fn_ptr;
     std::string deviceName;
+    std::map<std::string, std::string> config;
 
 public:
     void SetUp() override {
         fn_ptr = ngraph::builder::subgraph::makeSplitMultiConvConcat();
         deviceName = CommonTestUtils::DEVICE_GPU;
-        if (this->GetParam()) // BATCH:GPU(1)
+        auto with_auto_batching = this->GetParam();
+        if (with_auto_batching) {  // BATCH:GPU(1)
             deviceName = std::string(CommonTestUtils::DEVICE_BATCH) + ":" + deviceName + "(1)";
+            config = {{CONFIG_KEY(ALLOW_AUTO_BATCHING), CONFIG_VALUE(YES)}};
+        }
     }
     static std::string getTestCaseName(const testing::TestParamInfo<bool>& obj) {
         auto with_auto_batch = obj.param;
@@ -172,7 +176,10 @@ TEST_P(RemoteBlob_Test, smoke_canInferOnUserContext) {
     // inference using remote blob
     auto ocl_instance = std::make_shared<OpenCL>();
     auto remote_context = make_shared_context(*ie, deviceName, ocl_instance->_context.get());
-    auto exec_net_shared = ie->LoadNetwork(net, remote_context);
+    // since there is no way to enable Auto-Batching through the device name when loading with a RemoteContext
+    // (the device name is deduced from the context, which is plain "GPU"),
+    // the only way to test auto-batching is an explicit config with ALLOW_AUTO_BATCHING set to YES
+    auto exec_net_shared = ie->LoadNetwork(net, remote_context, config);
     auto inf_req_shared = exec_net_shared.CreateInferRequest();
     inf_req_shared.SetBlob(net.getInputsInfo().begin()->first, fakeImageData);
diff --git a/src/inference/src/ie_core.cpp b/src/inference/src/ie_core.cpp
index f24827e111f..4b64eeb2312 100644
--- a/src/inference/src/ie_core.cpp
+++ b/src/inference/src/ie_core.cpp
@@ -484,10 +484,10 @@ public:
         return newAPI;
     }
 
-    ov::runtime::SoPtr<ie::IExecutableNetworkInternal> LoadNetwork(const ie::CNNNetwork& network,
-                                                                   const std::shared_ptr<ie::RemoteContext>& context,
-                                                                   const std::map<std::string, std::string>& config)
-        override {
+    ov::runtime::SoPtr<ie::IExecutableNetworkInternal> LoadNetwork(
+        const ie::CNNNetwork& network,
+        const std::shared_ptr<ie::RemoteContext>& context,
+        const std::map<std::string, std::string>& config) override {
        OV_ITT_SCOPE(FIRST_INFERENCE, ie::itt::domains::IE_LT, "Core::LoadNetwork::RemoteContext");
        if (context == nullptr) {
            IE_THROW() << "Remote context is null";
@@ -497,6 +497,7 @@ public:
        std::map<std::string, std::string>& config_with_batch = parsed._config;
        // if auto-batching is applicable, the below function will patch the device name and config accordingly:
        ApplyAutoBatching(network, deviceName, config_with_batch);
+       parsed = parseDeviceNameIntoConfig(deviceName, config_with_batch);
 
        auto plugin = GetCPPPluginByName(parsed._deviceName);
        ov::runtime::SoPtr<ie::IExecutableNetworkInternal> res;
@@ -515,9 +516,9 @@ public:
        return res;
    }
 
-    void ApplyAutoBatching (const ie::CNNNetwork& network,
-                            std::string& deviceName,
-                            std::map<std::string, std::string>& config_with_batch) {
+    void ApplyAutoBatching(const ie::CNNNetwork& network,
+                           std::string& deviceName,
+                           std::map<std::string, std::string>& config_with_batch) {
        std::string deviceNameWithBatchSize;
        if (deviceName.find("BATCH") != std::string::npos) {
            auto pos = deviceName.find_first_of(":");
@@ -529,9 +530,8 @@ public:
 
        const auto& batch_mode = config_with_batch.find(CONFIG_KEY(ALLOW_AUTO_BATCHING));
        if (batch_mode != config_with_batch.end() && batch_mode->second == CONFIG_VALUE(YES)) {
-            auto deviceNameWithoutBatch = !deviceNameWithBatchSize.empty()
-                                              ? DeviceIDParser::getBatchDevice(deviceNameWithBatchSize)
-                                              : deviceName;
+            auto deviceNameWithoutBatch =
+                !deviceNameWithBatchSize.empty() ? DeviceIDParser::getBatchDevice(deviceNameWithBatchSize) : deviceName;
            unsigned int requests = 0;
            unsigned int optimalBatchSize = 0;
            if (deviceNameWithBatchSize.empty()) {
@@ -541,10 +541,10 @@ public:
                std::map<std::string, ie::Parameter> options;
                options["MODEL_PTR"] = &network;
                optimalBatchSize = GetCPPPluginByName(DeviceIDParser(deviceNameWithoutBatch).getDeviceName())
-                                       .get_metric(METRIC_KEY(OPTIMAL_BATCH), options)
-                                       .as<unsigned int>();
+                                           .get_metric(METRIC_KEY(OPTIMAL_BATCH), options)
+                                           .as<unsigned int>();
                auto res =
-                    GetConfig(deviceNameWithoutBatch, CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)).as<std::string>();
+                    GetConfig(deviceNameWithoutBatch, CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS)).as<std::string>();
                requests = PerfHintsConfig::CheckPerformanceHintRequestValue(res);
                const auto& reqs = config_with_batch.find(CONFIG_KEY(PERFORMANCE_HINT_NUM_REQUESTS));
                if (reqs != config_with_batch.end())
@@ -570,7 +570,7 @@ public:
                    if (!std::strcmp("DetectionOutput", node->get_type_info().name) ||
                        (!std::strcmp("Result", node->get_type_info().name) && isDetectionOutputParent(node))) {
                        node->get_rt_info()["affinity"] =
-                            std::make_shared<ngraph::VariantWrapper<std::string>>(deviceNameWithoutBatch);
+                            std::make_shared<ngraph::VariantWrapper<std::string>>(deviceNameWithoutBatch);
                        std::cout << "!!! AFF !!! type: " << node->get_type_info().name
                                  << ", name: " << node->get_friendly_name() << std::endl;
                        bDetectionOutput = true;
@@ -580,8 +580,8 @@ public:
                    }
                }
            }
            if (optimalBatchSize > 1 || !deviceNameWithBatchSize.empty()) {
                auto batchConfig = deviceNameWithBatchSize.empty()
-                                       ? deviceNameWithoutBatch + "(" + std::to_string(optimalBatchSize) + ")"
-                                       : deviceNameWithBatchSize;
+                                           ? deviceNameWithoutBatch + "(" + std::to_string(optimalBatchSize) + ")"
+                                           : deviceNameWithBatchSize;
                if (bDetectionOutput) {
                    deviceName = "HETERO:BATCH," + deviceNameWithoutBatch;
                    std::cout << "HETERO code path!!!!" << std::endl;
@@ -828,11 +828,11 @@ public:
      * @param params Map of device-specific shared context parameters.
      * @return A shared pointer to a created remote context.
      */
-    InferenceEngine::RemoteContext::Ptr CreateContext(const std::string& deviceName,
-                                                      const InferenceEngine::ParamMap& params) override {
-        auto parsed = ov::runtime::parseDeviceNameIntoConfig(deviceName, params);
-        return GetCPPPluginByName(parsed._deviceName).create_context(parsed._config)._ptr;
-    }
+    InferenceEngine::RemoteContext::Ptr CreateContext(const std::string& deviceName,
+                                                      const InferenceEngine::ParamMap& params) override {
+        auto parsed = ov::runtime::parseDeviceNameIntoConfig(deviceName, params);
+        return GetCPPPluginByName(parsed._deviceName).create_context(parsed._config)._ptr;
+    }
 
     /**
      * @brief Returns reference to CPP plugin wrapper by a device name