diff --git a/src/inference/src/ie_core.cpp b/src/inference/src/ie_core.cpp
index 76fd36349ea..e470cbca059 100644
--- a/src/inference/src/ie_core.cpp
+++ b/src/inference/src/ie_core.cpp
@@ -701,7 +701,7 @@ public:
         std::map<std::string, std::string>& config_with_batch = parsed._config;
         // if auto-batching is applicable, the below function will patch the device name and config accordingly:
         ApplyAutoBatching(network, deviceName, config_with_batch);
-        CleanUpProperties(deviceName, config_with_batch);
+        CleanUpProperties(deviceName, config_with_batch, ov::auto_batch_timeout);
         parsed = parseDeviceNameIntoConfig(deviceName, config_with_batch);
 
         auto plugin = GetCPPPluginByName(parsed._deviceName);
@@ -796,10 +796,10 @@ public:
         }
     }
 
-    void CleanUpProperties(std::string& deviceName, std::map<std::string, std::string>& config) {
+    void CleanUpProperties(std::string& deviceName, std::map<std::string, std::string>& config, ov::Any property) {
         // auto-batching is not applicable, if there is auto_batch_timeout, delete it
         if (deviceName.find("BATCH") == std::string::npos) {
-            const auto& batch_timeout_mode = config.find(ov::auto_batch_timeout.name());
+            const auto& batch_timeout_mode = config.find(property.as<std::string>());
             if (batch_timeout_mode != config.end()) {
                 if (deviceName.find("AUTO") == std::string::npos && deviceName.find("MULTI") == std::string::npos)
                     config.erase(batch_timeout_mode);
@@ -815,7 +815,7 @@ public:
         std::map<std::string, std::string> config_with_batch = config;
         // if auto-batching is applicable, the below function will patch the device name and config accordingly:
         ApplyAutoBatching(network, deviceName, config_with_batch);
-        CleanUpProperties(deviceName, config_with_batch);
+        CleanUpProperties(deviceName, config_with_batch, ov::auto_batch_timeout);
 
         bool forceDisableCache = config_with_batch.count(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE)) > 0;
         auto parsed = parseDeviceNameIntoConfig(deviceName, config_with_batch);
diff --git a/src/plugins/auto/plugin.cpp b/src/plugins/auto/plugin.cpp
index 221a3c34d48..c14ffea6e59 100644
--- a/src/plugins/auto/plugin.cpp
+++ b/src/plugins/auto/plugin.cpp
@@ -412,8 +412,11 @@ IExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadNetworkImpl(cons
                          config.first.c_str(),
                          config.second.c_str());
             }
-            insertPropToConfig(CONFIG_KEY(ALLOW_AUTO_BATCHING), iter->deviceName, configs);
-            insertPropToConfig(CONFIG_KEY(AUTO_BATCH_TIMEOUT), iter->deviceName, configs);
+            // carry on batch configs only if the user sets them explicitly
+            if (config.find(CONFIG_KEY(ALLOW_AUTO_BATCHING)) != config.end())
+                insertPropToConfig(CONFIG_KEY(ALLOW_AUTO_BATCHING), iter->deviceName, configs);
+            if (config.find(CONFIG_KEY(AUTO_BATCH_TIMEOUT)) != config.end())
+                insertPropToConfig(CONFIG_KEY(AUTO_BATCH_TIMEOUT), iter->deviceName, configs);
             insertPropToConfig(CONFIG_KEY(CACHE_DIR), iter->deviceName, configs);
             strDevices += iter->deviceName;
             strDevices += ((iter + 1) == supportDevices.end()) ? "" : ",";
"" : ","; @@ -463,9 +466,11 @@ IExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadNetworkImpl(cons LOG_INFO_TAG("set %s=%s", tmpiter->first.c_str(), tmpiter->second.c_str()); multiSContext->_batchingDisabled = true; } - p.config.insert({tmpiter->first, tmpiter->second}); + if (config.find(CONFIG_KEY(ALLOW_AUTO_BATCHING)) != config.end()) + p.config.insert({tmpiter->first, tmpiter->second}); } - insertPropToConfig(CONFIG_KEY(AUTO_BATCH_TIMEOUT), p.deviceName, p.config); + if (config.find(CONFIG_KEY(AUTO_BATCH_TIMEOUT)) != config.end()) + insertPropToConfig(CONFIG_KEY(AUTO_BATCH_TIMEOUT), p.deviceName, p.config); insertPropToConfig(CONFIG_KEY(CACHE_DIR), p.deviceName, p.config); const auto& deviceName = p.deviceName; const auto& deviceConfig = p.config; diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/behavior/ov_plugin/caching_tests.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/behavior/ov_plugin/caching_tests.cpp index 4419020c17e..77242e6c9fe 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/behavior/ov_plugin/caching_tests.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/behavior/ov_plugin/caching_tests.cpp @@ -128,4 +128,17 @@ namespace { ::testing::ValuesIn(autoConfigs)), CompileModelCacheTestBase::getTestCaseName); + const std::vector LoadFromFileConfigs = { + {ov::device::priorities(CommonTestUtils::DEVICE_CPU)}, + }; + const std::vector TestTargets = + {CommonTestUtils::DEVICE_AUTO, + CommonTestUtils::DEVICE_MULTI, + }; + + INSTANTIATE_TEST_SUITE_P(smoke_Auto_CachingSupportCase_CPU, CompileModelLoadFromFileTestBase, + ::testing::Combine( + ::testing::ValuesIn(TestTargets), + ::testing::ValuesIn(LoadFromFileConfigs)), + CompileModelLoadFromFileTestBase::getTestCaseName); } // namespace diff --git a/src/tests/functional/plugin/shared/include/behavior/ov_plugin/caching_tests.hpp b/src/tests/functional/plugin/shared/include/behavior/ov_plugin/caching_tests.hpp index 9dbcfde4917..2d681023ae3 100644 --- a/src/tests/functional/plugin/shared/include/behavior/ov_plugin/caching_tests.hpp +++ b/src/tests/functional/plugin/shared/include/behavior/ov_plugin/caching_tests.hpp @@ -53,6 +53,24 @@ public: static std::vector getStandardFunctions(); }; +using compileModelLoadFromFileParams = std::tuple< + std::string, // device name + ov::AnyMap // device configuration +>; +class CompileModelLoadFromFileTestBase : public testing::WithParamInterface, + virtual public SubgraphBaseTest, + virtual public OVPluginTestBase { + std::string m_cacheFolderName; + std::string m_modelName; + std::string m_weightsName; + +public: + static std::string getTestCaseName(testing::TestParamInfo obj); + + void SetUp() override; + void TearDown() override; + void run() override; +}; using compileKernelsCacheParams = std::tuple< std::string, // device name std::pair // device and cache configuration diff --git a/src/tests/functional/plugin/shared/src/behavior/ov_plugin/caching_tests.cpp b/src/tests/functional/plugin/shared/src/behavior/ov_plugin/caching_tests.cpp index 7c66fceb874..e802c32fc7e 100644 --- a/src/tests/functional/plugin/shared/src/behavior/ov_plugin/caching_tests.cpp +++ b/src/tests/functional/plugin/shared/src/behavior/ov_plugin/caching_tests.cpp @@ -225,6 +225,67 @@ TEST_P(CompileModelCacheTestBase, CompareWithRefImpl) { run(); } +std::string CompileModelLoadFromFileTestBase::getTestCaseName(testing::TestParamInfo obj) { + auto param = obj.param; + auto deviceName = std::get<0>(param); + auto 
+    std::ostringstream result;
+    std::replace(deviceName.begin(), deviceName.end(), ':', '.');
+    result << "device_name=" << deviceName << "_";
+    for (auto& iter : configuration) {
+        result << "_" << iter.first << "_" << iter.second.as<std::string>() << "_";
+    }
+    return result.str();
+}
+
+void CompileModelLoadFromFileTestBase::SetUp() {
+    ovModelWithName funcPair;
+    std::tie(targetDevice, configuration) = GetParam();
+    target_device = targetDevice;
+    APIBaseTest::SetUp();
+    std::stringstream ss;
+    auto hash = std::hash<std::string>()(SubgraphBaseTest::GetTestName());
+    ss << "testCache_" << std::to_string(hash) << "_" << std::this_thread::get_id() << "_" << GetTimestamp();
+    m_modelName = ss.str() + ".xml";
+    m_weightsName = ss.str() + ".bin";
+    for (auto& iter : configuration) {
+        ss << "_" << iter.first << "_" << iter.second.as<std::string>() << "_";
+    }
+    m_cacheFolderName = ss.str();
+    core->set_property(ov::cache_dir());
+    ngraph::pass::Manager manager;
+    manager.register_pass<ov::pass::Serialize>(m_modelName, m_weightsName);
+    manager.run_passes(ngraph::builder::subgraph::makeConvPoolRelu(
+            {1, 3, 227, 227}, InferenceEngine::details::convertPrecision(InferenceEngine::Precision::FP32)));
+}
+
+void CompileModelLoadFromFileTestBase::TearDown() {
+    CommonTestUtils::removeFilesWithExt(m_cacheFolderName, "blob");
+    CommonTestUtils::removeIRFiles(m_modelName, m_weightsName);
+    std::remove(m_cacheFolderName.c_str());
+    core->set_property(ov::cache_dir());
+    APIBaseTest::TearDown();
+}
+
+void CompileModelLoadFromFileTestBase::run() {
+    SKIP_IF_CURRENT_TEST_IS_DISABLED();
+    core->set_property(ov::cache_dir(m_cacheFolderName));
+    try {
+        compiledModel = core->compile_model(m_modelName, targetDevice, configuration);
+        inferRequest = compiledModel.create_infer_request();
+        inferRequest.infer();
+    } catch (const Exception &ex) {
+        GTEST_FAIL() << "Can't load network with model path " << m_modelName <<
+                        "\nException [" << ex.what() << "]" << std::endl;
+    } catch (...) {
+        GTEST_FAIL() << "Can't compile network with model path " << m_modelName << std::endl;
+    }
+}
+
+TEST_P(CompileModelLoadFromFileTestBase, CanLoadFromFileWithoutException) {
+    run();
+}
+
 std::string CompiledKernelsCacheTest::getTestCaseName(testing::TestParamInfo<compileKernelsCacheParams> obj) {
     auto param = obj.param;
     std::string deviceName;
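
Note: in this patch the new CompileModelLoadFromFileTestBase suite is only instantiated for the CPU test instances. As a rough sketch (not part of the patch), another plugin's shared_tests_instances file could reuse the suite the same way; CommonTestUtils::DEVICE_GPU, the GPU-side config values, and the ov::test::behavior namespace below are assumptions for illustration only:

    // Hypothetical GPU-side instantiation mirroring smoke_Auto_CachingSupportCase_CPU above (sketch only).
    #include "behavior/ov_plugin/caching_tests.hpp"

    using namespace ov::test::behavior;  // assumed test namespace

    namespace {
    const std::vector<ov::AnyMap> GPULoadFromFileConfigs = {
        {ov::device::priorities(CommonTestUtils::DEVICE_GPU)},  // assumed device constant
    };
    const std::vector<std::string> GPUTestTargets = {
        CommonTestUtils::DEVICE_AUTO,
        CommonTestUtils::DEVICE_MULTI,
    };

    INSTANTIATE_TEST_SUITE_P(smoke_Auto_CachingSupportCase_GPU, CompileModelLoadFromFileTestBase,
                            ::testing::Combine(
                                    ::testing::ValuesIn(GPUTestTargets),
                                    ::testing::ValuesIn(GPULoadFromFileConfigs)),
                            CompileModelLoadFromFileTestBase::getTestCaseName);
    }  // namespace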