do not carry batch configs in auto when user does not explicitly set it (#14003)

* temp resolution to support model path for CPU in auto

Signed-off-by: fishbell <bell.song@intel.com>

* disable batch when load through model path

Signed-off-by: fishbell <bell.song@intel.com>

* add mark for future release

Signed-off-by: fishbell <bell.song@intel.com>

* implement step1: do not parse batch config if user does not set it explicitly

Signed-off-by: fishbell <bell.song@intel.com>

* correct typo in case

Signed-off-by: fishbell <bell.song@intel.com>

Signed-off-by: fishbell <bell.song@intel.com>
This commit is contained in:
yanlan song 2022-12-02 10:14:18 +08:00 committed by GitHub
parent 1f16015802
commit 3eac2cd613
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 105 additions and 8 deletions

View File

@ -701,7 +701,7 @@ public:
std::map<std::string, std::string>& config_with_batch = parsed._config;
// if auto-batching is applicable, the below function will patch the device name and config accordingly:
ApplyAutoBatching(network, deviceName, config_with_batch);
CleanUpProperties(deviceName, config_with_batch);
CleanUpProperties(deviceName, config_with_batch, ov::auto_batch_timeout);
parsed = parseDeviceNameIntoConfig(deviceName, config_with_batch);
auto plugin = GetCPPPluginByName(parsed._deviceName);
@ -796,10 +796,10 @@ public:
}
}
void CleanUpProperties(std::string& deviceName, std::map<std::string, std::string>& config) {
void CleanUpProperties(std::string& deviceName, std::map<std::string, std::string>& config, ov::Any property) {
// auto-batching is not applicable, if there is auto_batch_timeout, delete it
if (deviceName.find("BATCH") == std::string::npos) {
const auto& batch_timeout_mode = config.find(ov::auto_batch_timeout.name());
const auto& batch_timeout_mode = config.find(property.as<std::string>());
if (batch_timeout_mode != config.end()) {
if (deviceName.find("AUTO") == std::string::npos && deviceName.find("MULTI") == std::string::npos)
config.erase(batch_timeout_mode);
@ -815,7 +815,7 @@ public:
std::map<std::string, std::string> config_with_batch = config;
// if auto-batching is applicable, the below function will patch the device name and config accordingly:
ApplyAutoBatching(network, deviceName, config_with_batch);
CleanUpProperties(deviceName, config_with_batch);
CleanUpProperties(deviceName, config_with_batch, ov::auto_batch_timeout);
bool forceDisableCache = config_with_batch.count(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE)) > 0;
auto parsed = parseDeviceNameIntoConfig(deviceName, config_with_batch);

View File

@ -412,8 +412,11 @@ IExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadNetworkImpl(cons
config.first.c_str(),
config.second.c_str());
}
insertPropToConfig(CONFIG_KEY(ALLOW_AUTO_BATCHING), iter->deviceName, configs);
insertPropToConfig(CONFIG_KEY(AUTO_BATCH_TIMEOUT), iter->deviceName, configs);
// carry on batch configs only if user explicitly sets
if (config.find(CONFIG_KEY(ALLOW_AUTO_BATCHING)) != config.end())
insertPropToConfig(CONFIG_KEY(ALLOW_AUTO_BATCHING), iter->deviceName, configs);
if (config.find(CONFIG_KEY(AUTO_BATCH_TIMEOUT)) != config.end())
insertPropToConfig(CONFIG_KEY(AUTO_BATCH_TIMEOUT), iter->deviceName, configs);
insertPropToConfig(CONFIG_KEY(CACHE_DIR), iter->deviceName, configs);
strDevices += iter->deviceName;
strDevices += ((iter + 1) == supportDevices.end()) ? "" : ",";
@ -463,9 +466,11 @@ IExecutableNetworkInternal::Ptr MultiDeviceInferencePlugin::LoadNetworkImpl(cons
LOG_INFO_TAG("set %s=%s", tmpiter->first.c_str(), tmpiter->second.c_str());
multiSContext->_batchingDisabled = true;
}
p.config.insert({tmpiter->first, tmpiter->second});
if (config.find(CONFIG_KEY(ALLOW_AUTO_BATCHING)) != config.end())
p.config.insert({tmpiter->first, tmpiter->second});
}
insertPropToConfig(CONFIG_KEY(AUTO_BATCH_TIMEOUT), p.deviceName, p.config);
if (config.find(CONFIG_KEY(AUTO_BATCH_TIMEOUT)) != config.end())
insertPropToConfig(CONFIG_KEY(AUTO_BATCH_TIMEOUT), p.deviceName, p.config);
insertPropToConfig(CONFIG_KEY(CACHE_DIR), p.deviceName, p.config);
const auto& deviceName = p.deviceName;
const auto& deviceConfig = p.config;

View File

@ -128,4 +128,17 @@ namespace {
::testing::ValuesIn(autoConfigs)),
CompileModelCacheTestBase::getTestCaseName);
// Configurations for compiling a model directly from its file path (no
// ov::Model object in user code): only the device priority is set, so no
// batch-related properties are carried into the meta device.
const std::vector<ov::AnyMap> LoadFromFileConfigs = {
{ov::device::priorities(CommonTestUtils::DEVICE_CPU)},
};
// Meta devices under test; both delegate to CPU via the priorities above.
const std::vector<std::string> TestTargets =
{CommonTestUtils::DEVICE_AUTO,
CommonTestUtils::DEVICE_MULTI,
};
// Register the load-from-file caching smoke tests for every
// (target device, config) combination.
INSTANTIATE_TEST_SUITE_P(smoke_Auto_CachingSupportCase_CPU, CompileModelLoadFromFileTestBase,
::testing::Combine(
::testing::ValuesIn(TestTargets),
::testing::ValuesIn(LoadFromFileConfigs)),
CompileModelLoadFromFileTestBase::getTestCaseName);
} // namespace

View File

@ -53,6 +53,24 @@ public:
static std::vector<ovModelWithName> getStandardFunctions();
};
// Parameter pack for CompileModelLoadFromFileTestBase: target device plus
// the properties passed to compile_model.
using compileModelLoadFromFileParams = std::tuple<
std::string, // device name
ov::AnyMap // device configuration
>;
// Fixture that serializes a model to disk in SetUp and then compiles it from
// the file path (rather than from an ov::Model) with ov::cache_dir enabled,
// to verify the model-path compilation flow for the parameterized devices.
class CompileModelLoadFromFileTestBase : public testing::WithParamInterface<compileModelLoadFromFileParams>,
virtual public SubgraphBaseTest,
virtual public OVPluginTestBase {
std::string m_cacheFolderName; // folder passed to ov::cache_dir for this test
std::string m_modelName; // path of the serialized IR .xml
std::string m_weightsName; // path of the serialized IR .bin
public:
static std::string getTestCaseName(testing::TestParamInfo<compileModelLoadFromFileParams> obj);
void SetUp() override;
void TearDown() override;
void run() override;
};
using compileKernelsCacheParams = std::tuple<
std::string, // device name
std::pair<ov::AnyMap, std::string> // device and cache configuration

View File

@ -225,6 +225,67 @@ TEST_P(CompileModelCacheTestBase, CompareWithRefImpl) {
run();
}
// Builds a readable test-case name from the (device, config) parameters.
// ':' in the device name is replaced by '.' so compound names such as
// "AUTO:CPU" remain usable in gtest filters.
std::string CompileModelLoadFromFileTestBase::getTestCaseName(testing::TestParamInfo<compileModelLoadFromFileParams> obj) {
    std::string deviceName;
    ov::AnyMap configuration;
    std::tie(deviceName, configuration) = obj.param;
    std::replace(deviceName.begin(), deviceName.end(), ':', '.');
    std::ostringstream name;
    name << "device_name=" << deviceName << "_";
    for (const auto& entry : configuration) {
        name << "_" << entry.first << "_" << entry.second.as<std::string>() << "_";
    }
    return name.str();
}
// Prepares a unique on-disk IR (xml/bin) for the load-from-file test:
// serializes a small ConvPoolRelu model under a name derived from the test
// name, thread id and timestamp so concurrent runs do not collide, and
// derives the cache-folder name from it.
// Fix: removed the unused local `ovModelWithName funcPair;` — it was never
// read or written in this function.
void CompileModelLoadFromFileTestBase::SetUp() {
std::tie(targetDevice, configuration) = GetParam();
target_device = targetDevice;
APIBaseTest::SetUp();
std::stringstream ss;
// Hash of the test name keeps the file name unique yet deterministic per test.
auto hash = std::hash<std::string>()(SubgraphBaseTest::GetTestName());
ss << "testCache_" << std::to_string(hash) << "_" << std::this_thread::get_id() << "_" << GetTimestamp();
m_modelName = ss.str() + ".xml";
m_weightsName = ss.str() + ".bin";
// The cache folder additionally encodes the configuration entries so that
// different configs cache into different folders.
for (auto& iter : configuration) {
ss << "_" << iter.first << "_" << iter.second.as<std::string>() << "_";
}
m_cacheFolderName = ss.str();
// Reset cache_dir to empty; run() enables it explicitly for this test.
core->set_property(ov::cache_dir());
ngraph::pass::Manager manager;
manager.register_pass<ov::pass::Serialize>(m_modelName, m_weightsName);
manager.run_passes(ngraph::builder::subgraph::makeConvPoolRelu(
{1, 3, 227, 227}, InferenceEngine::details::convertPrecision(InferenceEngine::Precision::FP32)));
}
// Cleans up everything SetUp/run created: cached blobs first, then the
// serialized IR files, then the (now empty) cache folder itself; finally
// resets cache_dir so later tests start with caching disabled.
void CompileModelLoadFromFileTestBase::TearDown() {
CommonTestUtils::removeFilesWithExt(m_cacheFolderName, "blob");
CommonTestUtils::removeIRFiles(m_modelName, m_weightsName);
// std::remove on a directory only succeeds once it is empty — hence the
// file removals above must run first.
std::remove(m_cacheFolderName.c_str());
core->set_property(ov::cache_dir());
APIBaseTest::TearDown();
}
// Enables model caching and compiles the model from its file path, then runs
// one inference; any exception is converted into a test failure that carries
// the model path (and, for ov::Exception, the exception text).
void CompileModelLoadFromFileTestBase::run() {
SKIP_IF_CURRENT_TEST_IS_DISABLED();
core->set_property(ov::cache_dir(m_cacheFolderName));
try {
// Compile by file path — this is the code path under test (model-path
// support for AUTO/MULTI), distinct from compiling an ov::Model object.
compiledModel = core->compile_model(m_modelName, targetDevice, configuration);
inferRequest = compiledModel.create_infer_request();
inferRequest.infer();
} catch (const Exception &ex) {
GTEST_FAIL() << "Can't loadNetwork with model path " << m_modelName <<
"\nException [" << ex.what() << "]" << std::endl;
} catch (...) {
GTEST_FAIL() << "Can't compile network with model path " << m_modelName << std::endl;
}
}
// Smoke check: compiling a model from a file path with caching enabled must
// succeed without throwing.
// Fix: corrected the typo "Execption" -> "Exception" in the test name.
// NOTE(review): if any skip-filter config references the old name by regex,
// update it to match — no in-file code refers to the test by name.
TEST_P(CompileModelLoadFromFileTestBase, CanLoadFromFileWithoutException) {
run();
}
std::string CompiledKernelsCacheTest::getTestCaseName(testing::TestParamInfo<compileKernelsCacheParams> obj) {
auto param = obj.param;
std::string deviceName;