improved remote-blobs tests

This commit is contained in:
myshevts 2021-11-30 19:06:48 +03:00
parent 515a0f7591
commit 43266c585e
3 changed files with 34 additions and 26 deletions

View File

@ -474,7 +474,7 @@ RemoteContext::Ptr AutoBatchInferencePlugin::CreateContext(const InferenceEngin
auto val = it->second;
auto metaDevice = ParseMetaDevice(val, {{}});
cfg.erase(it);
std::cout << "AutoBatchInferencePlugin::CreateContext" << std::endl;
return GetCore()->CreateContext(metaDevice.deviceName, cfg);
}
@ -604,6 +604,7 @@ InferenceEngine::IExecutableNetworkInternal::Ptr AutoBatchInferencePlugin::LoadE
const InferenceEngine::CNNNetwork& network,
const std::shared_ptr<InferenceEngine::RemoteContext>& context,
const std::map<std::string, std::string>& config) {
std::cout << "AutoBatchInferencePlugin::LoadExeNetworkImpl with context for " << context-> getDeviceName() << std::endl;
return LoadNetworkImpl(network, context, config);
}

View File

@ -25,13 +25,17 @@ class RemoteBlob_Test : public CommonTestUtils::TestsCommon, public testing::Wit
protected:
std::shared_ptr<ngraph::Function> fn_ptr;
std::string deviceName;
std::map<std::string, std::string> config;
public:
void SetUp() override {
fn_ptr = ngraph::builder::subgraph::makeSplitMultiConvConcat();
deviceName = CommonTestUtils::DEVICE_GPU;
if (this->GetParam()) // BATCH:GPU(1)
auto with_auto_batching = this->GetParam();
if (with_auto_batching) { // BATCH:GPU(1)
deviceName = std::string(CommonTestUtils::DEVICE_BATCH) + ":" + deviceName + "(1)";
config = {{CONFIG_KEY(ALLOW_AUTO_BATCHING), CONFIG_VALUE(YES)}};
}
}
static std::string getTestCaseName(const testing::TestParamInfo<bool>& obj) {
auto with_auto_batch = obj.param;
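
For reference, a minimal sketch (not part of this commit) of the device-name path that the parameterized SetUp above produces when with_auto_batching is true: the load goes through the "BATCH:GPU(1)" name, and ALLOW_AUTO_BATCHING is additionally set in the config for the remote-context case shown further below. It assumes the standard InferenceEngine::Core API and a network object such as the test's makeSplitMultiConvConcat subgraph.

#include <inference_engine.hpp>

// Sketch: loading through the auto-batch wrapper by device name.
InferenceEngine::ExecutableNetwork loadViaBatchDeviceName(InferenceEngine::Core& ie,
                                                          const InferenceEngine::CNNNetwork& net) {
    // "BATCH:GPU(1)" routes the load through the BATCH wrapper plugin on GPU with batch size 1
    const std::string deviceName = "BATCH:GPU(1)";
    const std::map<std::string, std::string> config = {
        {CONFIG_KEY(ALLOW_AUTO_BATCHING), CONFIG_VALUE(YES)}};
    return ie.LoadNetwork(net, deviceName, config);
}
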
@ -172,7 +176,10 @@ TEST_P(RemoteBlob_Test, smoke_canInferOnUserContext) {
// inference using remote blob
auto ocl_instance = std::make_shared<OpenCL>();
auto remote_context = make_shared_context(*ie, deviceName, ocl_instance->_context.get());
auto exec_net_shared = ie->LoadNetwork(net, remote_context);
// since there is no way to enable the Auto-Batching through the device name when loading with the RemoteContext
// (as the device name is deduced from the context, which is just "GPU"),
// the only way to test the auto-batching is an explicit config with ALLOW_AUTO_BATCHING set to YES
auto exec_net_shared = ie->LoadNetwork(net, remote_context, config);
auto inf_req_shared = exec_net_shared.CreateInferRequest();
inf_req_shared.SetBlob(net.getInputsInfo().begin()->first, fakeImageData);
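
For context, a minimal sketch of the remote-context path this test exercises, where the explicit ALLOW_AUTO_BATCHING key is the only way to turn auto-batching on. It assumes the GPU remote-context helpers from <gpu/gpu_context_api_ocl.hpp>, and user_cl_context stands in for a cl_context handle owned by the application.

#include <inference_engine.hpp>
#include <gpu/gpu_context_api_ocl.hpp>

// Sketch: loading on a user-supplied OpenCL context; the device name is
// deduced from the context ("GPU"), so auto-batching is requested via config only.
InferenceEngine::ExecutableNetwork loadOnUserContext(InferenceEngine::Core& ie,
                                                     const InferenceEngine::CNNNetwork& net,
                                                     cl_context user_cl_context) {
    auto remote_context = InferenceEngine::gpu::make_shared_context(ie, "GPU", user_cl_context);
    const std::map<std::string, std::string> config = {
        {CONFIG_KEY(ALLOW_AUTO_BATCHING), CONFIG_VALUE(YES)}};
    return ie.LoadNetwork(net, remote_context, config);
}
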

View File

@ -484,10 +484,10 @@ public:
return newAPI;
}
ov::runtime::SoPtr<ie::IExecutableNetworkInternal> LoadNetwork(const ie::CNNNetwork& network,
ov::runtime::SoPtr<ie::IExecutableNetworkInternal> LoadNetwork(
const ie::CNNNetwork& network,
const std::shared_ptr<ie::RemoteContext>& context,
const std::map<std::string, std::string>& config)
override {
const std::map<std::string, std::string>& config) override {
OV_ITT_SCOPE(FIRST_INFERENCE, ie::itt::domains::IE_LT, "Core::LoadNetwork::RemoteContext");
if (context == nullptr) {
IE_THROW() << "Remote context is null";
@ -497,6 +497,7 @@ public:
std::map<std::string, std::string>& config_with_batch = parsed._config;
// if auto-batching is applicable, the below function will patch the device name and config accordingly:
ApplyAutoBatching(network, deviceName, config_with_batch);
parsed = parseDeviceNameIntoConfig(deviceName, config_with_batch);
auto plugin = GetCPPPluginByName(parsed._deviceName);
ov::runtime::SoPtr<ie::IExecutableNetworkInternal> res;
@ -515,7 +516,7 @@ public:
return res;
}
void ApplyAutoBatching (const ie::CNNNetwork& network,
void ApplyAutoBatching(const ie::CNNNetwork& network,
std::string& deviceName,
std::map<std::string, std::string>& config_with_batch) {
std::string deviceNameWithBatchSize;
@ -529,9 +530,8 @@ public:
const auto& batch_mode = config_with_batch.find(CONFIG_KEY(ALLOW_AUTO_BATCHING));
if (batch_mode != config_with_batch.end() && batch_mode->second == CONFIG_VALUE(YES)) {
auto deviceNameWithoutBatch = !deviceNameWithBatchSize.empty()
? DeviceIDParser::getBatchDevice(deviceNameWithBatchSize)
: deviceName;
auto deviceNameWithoutBatch =
!deviceNameWithBatchSize.empty() ? DeviceIDParser::getBatchDevice(deviceNameWithBatchSize) : deviceName;
unsigned int requests = 0;
unsigned int optimalBatchSize = 0;
if (deviceNameWithBatchSize.empty()) {
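
For illustration only: the batched-device spec handled above has the form "GPU(4)" once the "BATCH:" prefix is stripped, and DeviceIDParser::getBatchDevice returns the device part. The stand-alone parser below is a stand-in for that helper, assuming only the "DEVICE(n)" convention visible in this diff; the real parser lives in the core library and may differ in details such as validation and device-ID handling.

#include <string>
#include <utility>

// Illustrative parser for a "GPU(4)"-style spec: returns the device name and
// the batch size, or 0 when no explicit batch size is given (plain "GPU").
std::pair<std::string, unsigned int> splitBatchDevice(const std::string& spec) {
    const auto open = spec.find('(');
    if (open == std::string::npos)
        return {spec, 0};  // no explicit batch size
    const auto close = spec.find(')', open);
    const std::string device = spec.substr(0, open);  // e.g. "GPU"
    const auto batch = static_cast<unsigned int>(
        std::stoul(spec.substr(open + 1, close - open - 1)));  // e.g. 4
    return {device, batch};
}
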