Fixed proxy plugin initialization with internal names (#18783)

This commit is contained in:
Ilya Churaev 2023-07-25 22:14:42 +04:00 committed by GitHub
parent 5a4cf4c8b6
commit 7767af3529
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 128 additions and 11 deletions

View File

@ -50,12 +50,23 @@ ov::ICore::~ICore() = default;
namespace {
#ifdef PROXY_PLUGIN_ENABLED
std::string get_internal_plugin_name(const std::string& device_name, const ov::AnyMap& properties) {
std::string get_internal_plugin_name(std::unordered_map<std::string, std::string>& remapped_devices,
const std::string& device_name,
const ov::AnyMap& properties) {
static constexpr const char* internal_plugin_suffix = "_ov_internal";
auto it = properties.find(ov::proxy::configuration::internal_name.name());
if (it != properties.end())
return it->second.as<std::string>();
return device_name + internal_plugin_suffix;
{
auto it = remapped_devices.find(device_name);
if (it != remapped_devices.end())
return it->second;
}
auto plugin_name = device_name + internal_plugin_suffix;
{
auto it = properties.find(ov::proxy::configuration::internal_name.name());
if (it != properties.end())
plugin_name = it->second.as<std::string>();
}
remapped_devices[device_name] = plugin_name;
return plugin_name;
}
#endif
@ -367,8 +378,6 @@ void ov::CoreImpl::register_plugin_in_registry_unsafe(const std::string& device_
if (it != config.end()) {
auto fallback = it->second.as<std::string>();
// Change fallback name if fallback is configured to the HW plugin under the proxy with the same name
if (alias == fallback)
fallback = get_internal_plugin_name(fallback, config);
if (defaultConfig.find(ov::device::priorities.name()) == defaultConfig.end()) {
defaultConfig[ov::device::priorities.name()] = std::vector<std::string>{dev_name, fallback};
} else {
@ -402,7 +411,7 @@ void ov::CoreImpl::register_plugin_in_registry_unsafe(const std::string& device_
// Create proxy plugin for alias
auto alias = config.at(ov::proxy::configuration::alias.name()).as<std::string>();
if (alias == device_name)
dev_name = get_internal_plugin_name(dev_name, config);
dev_name = get_internal_plugin_name(remapped_devices, dev_name, config);
// Alias can be registered by several plugins
if (pluginRegistry.find(alias) == pluginRegistry.end()) {
// Register new plugin
@ -422,7 +431,7 @@ void ov::CoreImpl::register_plugin_in_registry_unsafe(const std::string& device_
}
} else if (config.find(ov::proxy::configuration::fallback.name()) != config.end()) {
// Fallback without alias means that we need to replace original plugin to proxy
dev_name = get_internal_plugin_name(dev_name, config);
dev_name = get_internal_plugin_name(remapped_devices, dev_name, config);
PluginDescriptor desc = PluginDescriptor(ov::proxy::create_plugin);
fill_config(desc.defaultConfig, config, dev_name);
pluginRegistry[device_name] = desc;
@ -618,7 +627,18 @@ ov::Plugin ov::CoreImpl::get_plugin(const std::string& pluginName) const {
}
it = desc.defaultConfig.find(ov::device::priorities.name());
if (it != desc.defaultConfig.end()) {
initial_config[ov::device::priorities.name()] = it->second;
// Fix fallback names in case if proxy plugin got a conflict in the process of plugins registration
auto priorities = it->second.as<std::vector<std::string>>();
for (auto&& priority : priorities) {
if (priority == deviceName) {
OPENVINO_ASSERT(remapped_devices.find(deviceName) != remapped_devices.end(),
"Cannot create proxy device ",
deviceName,
". Device has incorrect configuration.");
priority = remapped_devices.at(deviceName);
}
}
initial_config[ov::device::priorities.name()] = priorities;
}
plugin.set_property(initial_config);
try {
@ -1186,7 +1206,8 @@ void ov::CoreImpl::register_plugin(const std::string& plugin,
std::lock_guard<std::mutex> lock(get_mutex());
auto it = pluginRegistry.find(device_name);
if (it != pluginRegistry.end()) {
// Proxy plugins can be configured in the runtime
if (it != pluginRegistry.end() && !is_proxy_device(device_name)) {
IE_THROW() << "Device with \"" << device_name << "\" is already registered in the OpenVINO Runtime";
}

View File

@ -152,6 +152,8 @@ private:
mutable std::vector<ov::Extension::Ptr> ov_extensions;
std::map<std::string, PluginDescriptor> pluginRegistry;
// Map of remapped devices which have conflict with proxy device
std::unordered_map<std::string, std::string> remapped_devices;
const bool m_new_api;

View File

@ -35,6 +35,34 @@ TEST_F(ProxyTests, alias_for_the_same_name) {
EXPECT_TRUE(mock_reference_dev.empty());
}
TEST_F(ProxyTests, alias_for_the_same_name_with_custom_internal_name_inversed_order) {
register_plugin_support_subtract(core, "DEK", {{ov::proxy::configuration::alias.name(), "CBD"}});
register_plugin_support_reshape(core,
"CBD",
{{ov::proxy::configuration::alias.name(), "CBD"},
{ov::proxy::configuration::fallback.name(), "DEK"},
{ov::proxy::configuration::internal_name.name(), "CBD_INTERNAL"},
{ov::proxy::configuration::priority.name(), 0}});
auto available_devices = core.get_available_devices();
// 0, 1, 2 is ABC plugin
// 1, 3, 4 is BDE plugin
// ABC doesn't support subtract operation
std::unordered_map<std::string, std::string> mock_reference_dev = {{"CBD.0", "CBD_INTERNAL"},
{"CBD.1", "CBD_INTERNAL DEK"},
{"CBD.2", "CBD_INTERNAL"}};
for (const auto& it : mock_reference_dev) {
EXPECT_EQ(core.get_property(it.first, ov::device::priorities), it.second);
}
for (const auto& dev : available_devices) {
auto it = mock_reference_dev.find(dev);
if (it != mock_reference_dev.end()) {
mock_reference_dev.erase(it);
}
}
// All devices should be found
EXPECT_TRUE(mock_reference_dev.empty());
}
TEST_F(ProxyTests, alias_for_the_same_name_with_custom_internal_name) {
register_plugin_support_reshape(core,
"CBD",
@ -95,6 +123,72 @@ TEST_F(ProxyTests, fallback_to_alias_name) {
EXPECT_TRUE(mock_reference_dev.empty());
}
TEST_F(ProxyTests, fallback_to_alias_name_with_custom_internal_name) {
register_plugin_support_reshape(core,
"CBD",
{{ov::proxy::configuration::alias.name(), "CBD"},
{ov::proxy::configuration::internal_name.name(), "CBD_INTERNAL"},
{ov::proxy::configuration::priority.name(), 0}});
register_plugin_support_subtract(core,
"DEK",
{{ov::proxy::configuration::alias.name(), "CBD"},
{ov::proxy::configuration::fallback.name(), "CBD"},
{ov::proxy::configuration::priority.name(), 1}});
auto available_devices = core.get_available_devices();
// 0, 1, 2 is ABC plugin
// 1, 3, 4 is BDE plugin
// ABC doesn't support subtract operation
std::unordered_map<std::string, std::string> mock_reference_dev = {{"CBD.0", "CBD_INTERNAL"},
{"CBD.1", "DEK CBD_INTERNAL"},
{"CBD.2", "CBD_INTERNAL"},
{"CBD.3", "DEK"},
{"CBD.4", "DEK"}};
for (const auto& it : mock_reference_dev) {
EXPECT_EQ(core.get_property(it.first, ov::device::priorities), it.second);
}
for (const auto& dev : available_devices) {
auto it = mock_reference_dev.find(dev);
if (it != mock_reference_dev.end()) {
mock_reference_dev.erase(it);
}
}
// All devices should be found
EXPECT_TRUE(mock_reference_dev.empty());
}
TEST_F(ProxyTests, fallback_to_alias_name_with_custom_internal_name_inverted_order) {
register_plugin_support_subtract(core,
"DEK",
{{ov::proxy::configuration::alias.name(), "CBD"},
{ov::proxy::configuration::fallback.name(), "CBD"},
{ov::proxy::configuration::priority.name(), 1}});
register_plugin_support_reshape(core,
"CBD",
{{ov::proxy::configuration::alias.name(), "CBD"},
{ov::proxy::configuration::internal_name.name(), "CBD_INTERNAL"},
{ov::proxy::configuration::priority.name(), 0}});
auto available_devices = core.get_available_devices();
// 0, 1, 2 is ABC plugin
// 1, 3, 4 is BDE plugin
// ABC doesn't support subtract operation
std::unordered_map<std::string, std::string> mock_reference_dev = {{"CBD.0", "CBD_INTERNAL"},
{"CBD.1", "DEK CBD_INTERNAL"},
{"CBD.2", "CBD_INTERNAL"},
{"CBD.3", "DEK"},
{"CBD.4", "DEK"}};
for (const auto& it : mock_reference_dev) {
EXPECT_EQ(core.get_property(it.first, ov::device::priorities), it.second);
}
for (const auto& dev : available_devices) {
auto it = mock_reference_dev.find(dev);
if (it != mock_reference_dev.end()) {
mock_reference_dev.erase(it);
}
}
// All devices should be found
EXPECT_TRUE(mock_reference_dev.empty());
}
TEST_F(ProxyTests, load_proxy_on_plugin_without_devices_with_the_same_name) {
auto available_devices = core.get_available_devices();
register_plugin_without_devices(