[GPU] Fix GPU remote context name initialization (#17850)

This commit is contained in:
Ilya Churaev 2023-06-05 12:00:04 +04:00 committed by GitHub
parent db8d23231a
commit 36625404eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 9 deletions

View File

@ -24,7 +24,10 @@ class Plugin : public InferenceEngine::IInferencePlugin {
std::map<std::string, cldnn::device::ptr> device_map;
std::map<std::string, ExecutionConfig> m_configs_map;
std::map<std::string, RemoteCLContext::Ptr> m_default_contexts;
mutable std::map<std::string, RemoteCLContext::Ptr> m_default_contexts;
mutable std::once_flag m_default_contexts_once;
std::map<std::string, RemoteCLContext::Ptr> get_default_contexts() const;
InferenceEngine::CNNNetwork clone_and_transform_model(const InferenceEngine::CNNNetwork& network,
const ExecutionConfig& config) const;

View File

@ -4,6 +4,7 @@
#include <limits>
#include <algorithm>
#include <mutex>
#include <string>
#include <map>
#include <vector>
@ -137,7 +138,18 @@ InferenceEngine::CNNNetwork Plugin::clone_and_transform_model(const InferenceEng
return clonedNetwork;
}
Plugin::Plugin() : m_default_contexts({}) {
std::map<std::string, RemoteCLContext::Ptr> Plugin::get_default_contexts() const {
std::call_once(m_default_contexts_once, [this]() {
// Create default context
for (auto& device : device_map) {
auto ctx = std::make_shared<RemoteCLContext>(GetName() + "." + device.first, std::vector<cldnn::device::ptr>{ device.second });
m_default_contexts.insert({device.first, ctx});
}
});
return m_default_contexts;
}
Plugin::Plugin() {
_pluginName = "GPU";
register_primitives();
// try loading gpu engine and get info from it
@ -149,8 +161,6 @@ Plugin::Plugin() : m_default_contexts({}) {
// Set default configs for each device
for (auto& device : device_map) {
m_configs_map.insert({device.first, ExecutionConfig(ov::device::id(device.first))});
auto ctx = std::make_shared<RemoteCLContext>(GetName() + "." + device.first, std::vector<cldnn::device::ptr>{ device.second });
m_default_contexts.insert({device.first, ctx});
}
}
}
@ -226,7 +236,7 @@ InferenceEngine::RemoteContext::Ptr Plugin::CreateContext(const AnyMap& params)
}
std::vector<RemoteContextImpl::Ptr> known_contexts;
for (auto& c : m_default_contexts) {
for (auto& c : get_default_contexts()) {
known_contexts.push_back(c.second->get_impl());
}
std::string context_type = extract_object<std::string>(params, GPU_PARAM_KEY(CONTEXT_TYPE));
@ -245,9 +255,9 @@ InferenceEngine::RemoteContext::Ptr Plugin::CreateContext(const AnyMap& params)
}
RemoteCLContext::Ptr Plugin::get_default_context(const std::string& device_id) const {
OPENVINO_ASSERT(m_default_contexts.find(device_id) != m_default_contexts.end(), "[GPU] Context was not initialized for ", device_id, " device");
OPENVINO_ASSERT(get_default_contexts().find(device_id) != get_default_contexts().end(), "[GPU] Context was not initialized for ", device_id, " device");
return m_default_contexts.at(device_id);;
return get_default_contexts().at(device_id);;
}
InferenceEngine::RemoteContext::Ptr Plugin::GetDefaultContext(const AnyMap& params) {
@ -773,7 +783,7 @@ std::vector<std::string> Plugin::get_device_capabilities(const cldnn::device_inf
uint32_t Plugin::get_max_batch_size(const std::map<std::string, Parameter>& options) const {
GPU_DEBUG_GET_INSTANCE(debug_config);
auto device_id = GetConfig(ov::device::id.name(), options).as<std::string>();
auto context = m_default_contexts.at(device_id)->get_impl();
auto context = get_default_contexts().at(device_id)->get_impl();
const auto& device_info = context->get_engine().get_device_info();
const auto& config = m_configs_map.at(device_id);
uint32_t n_streams = static_cast<uint32_t>(config.get_property(ov::num_streams));
@ -923,7 +933,7 @@ uint32_t Plugin::get_max_batch_size(const std::map<std::string, Parameter>& opti
uint32_t Plugin::get_optimal_batch_size(const std::map<std::string, Parameter>& options) const {
auto device_id = GetConfig(ov::device::id.name(), options).as<std::string>();
auto context = m_default_contexts.at(device_id)->get_impl();
auto context = get_default_contexts().at(device_id)->get_impl();
const auto& device_info = context->get_engine().get_device_info();
auto next_pow_of_2 = [] (float x) {
return pow(2, ceil(std::log(x)/std::log(2)));