[IE CORE] enable plugins & dependent libs loading using absolute path (#3639)

* [IE CORE] enable plugins & dependent libs loading using absolute path

urrently this allowed to use plugins.xml file to specify full path to specific plugin with it's all dependency, not to be persisted in CWD or in PATH

* Code review fixes
This commit is contained in:
Mikhail Ryzhov 2020-12-22 21:02:05 +03:00 committed by GitHub
parent 00181d5179
commit f224c52f38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,7 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "details/ie_exception.hpp"
#include "details/ie_so_loader.h"
#include "file_utils.h"
@ -67,16 +67,19 @@
namespace InferenceEngine {
namespace details {
typedef DWORD(*GetDllDirectoryA_Fnc)(DWORD, LPSTR);
typedef DWORD(*GetDllDirectoryW_Fnc)(DWORD, LPWSTR);
static GetDllDirectoryA_Fnc IEGetDllDirectoryA;
static GetDllDirectoryW_Fnc IEGetDllDirectoryW;
/**
* @brief WINAPI based implementation for loading a shared object
*/
class SharedObjectLoader::Impl {
private:
private:
HMODULE shared_object;
typedef DWORD(* GetDllDirectoryA_Fnc)(DWORD, LPSTR);
typedef DWORD(* GetDllDirectoryW_Fnc)(DWORD, LPWSTR);
static GetDllDirectoryA_Fnc IEGetDllDirectoryA;
static GetDllDirectoryW_Fnc IEGetDllDirectoryW;
void LoadSymbols() {
static std::once_flag loadFlag;
std::call_once(loadFlag, [&] () {
@ -94,7 +97,7 @@ private:
// path was set to "" or NULL so reset it to "" to keep
// application safe.
void ExcludeCurrentDirectoryA() {
#ifndef WINAPI_FAMILY
#if !WINAPI_PARTITION_SYSTEM
LoadSymbols();
if (IEGetDllDirectoryA && IEGetDllDirectoryA(0, NULL) <= 1) {
SetDllDirectoryA("");
@ -104,7 +107,7 @@ private:
#ifdef ENABLE_UNICODE_PATH_SUPPORT
void ExcludeCurrentDirectoryW() {
#ifndef WINAPI_FAMILY
#if !WINAPI_PARTITION_SYSTEM
LoadSymbols();
if (IEGetDllDirectoryW && IEGetDllDirectoryW(0, NULL) <= 1) {
SetDllDirectoryW(L"");
@ -113,12 +116,93 @@ private:
}
#endif
public:
static const char kPathSeparator = '\\';
static const char* FindLastPathSeparator(LPCSTR path) {
const char* const last_sep = strchr(path, kPathSeparator);
return last_sep;
}
static std::string GetDirname(LPCSTR path) {
auto pos = FindLastPathSeparator(path);
if (pos == nullptr) {
return path;
}
std::string original(path);
original[pos - path] = 0;
return original;
}
#ifdef ENABLE_UNICODE_PATH_SUPPORT
static const wchar_t* FindLastPathSeparator(LPCWSTR path) {
const wchar_t* const last_sep = wcsrchr(path, kPathSeparator);
return last_sep;
}
static std::wstring GetDirname(LPCWSTR path) {
auto pos = FindLastPathSeparator(path);
if (pos == nullptr) {
return path;
}
std::wstring original(path);
original[pos - path] = 0;
return original;
}
void LoadPluginFromDirectoryW(LPCWSTR path) {
#if !WINAPI_PARTITION_SYSTEM
LoadSymbols();
if (IEGetDllDirectoryW) {
DWORD nBufferLength = IEGetDllDirectoryW(0, NULL);
std::vector<WCHAR> lpBuffer(nBufferLength);
IEGetDllDirectoryW(nBufferLength, &lpBuffer.front());
auto dirname = GetDirname(path);
SetDllDirectoryW(dirname.c_str());
shared_object = LoadLibraryW(path);
SetDllDirectoryW(&lpBuffer.front());
}
#endif
}
#endif
void LoadPluginFromDirectoryA(LPCSTR path) {
#if !WINAPI_PARTITION_SYSTEM
LoadSymbols();
if (IEGetDllDirectoryA) {
DWORD nBufferLength = IEGetDllDirectoryA(0, NULL);
std::vector<CHAR> lpBuffer(nBufferLength);
IEGetDllDirectoryA(nBufferLength, &lpBuffer.front());
auto dirname = GetDirname(path);
SetDllDirectoryA(dirname.c_str());
shared_object = LoadLibraryA(path);
SetDllDirectoryA(&lpBuffer.front());
}
#endif
}
public:
/**
* @brief A shared pointer to SharedObjectLoader
*/
using Ptr = std::shared_ptr<SharedObjectLoader>;
#ifdef ENABLE_UNICODE_PATH_SUPPORT
/**
* @brief Loads a library with the name specified. The library is loaded according to the
* WinAPI LoadLibrary rules
* @param pluginName Full or relative path to the plugin library
*/
explicit Impl(const wchar_t* pluginName) {
ExcludeCurrentDirectoryW();
LoadPluginFromDirectoryW(pluginName);
if(!shared_object) {
shared_object = LoadLibraryW(pluginName);
}
shared_object = LoadLibraryW(pluginName);
if (!shared_object) {
char cwd[1024];
THROW_IE_EXCEPTION << "Cannot load library '" << FileUtils::wStringtoMBCSstringChar(std::wstring(pluginName)) << "': " << GetLastError()
@ -129,8 +213,12 @@ public:
explicit Impl(const char* pluginName) {
ExcludeCurrentDirectoryA();
LoadPluginFromDirectoryA(pluginName);
if (!shared_object) {
shared_object = LoadLibraryA(pluginName);
}
shared_object = LoadLibraryA(pluginName);
if (!shared_object) {
char cwd[1024];
THROW_IE_EXCEPTION << "Cannot load library '" << pluginName << "': " << GetLastError()
@ -142,6 +230,12 @@ public:
FreeLibrary(shared_object);
}
/**
* @brief Searches for a function symbol in the loaded module
* @param symbolName Name of function to find
* @return A pointer to the function if found
* @throws InferenceEngineException if the function is not found
*/
void* get_symbol(const char* symbolName) const {
if (!shared_object) {
THROW_IE_EXCEPTION << "Cannot get '" << symbolName << "' content from unknown library!";
@ -154,18 +248,17 @@ public:
}
};
#ifdef ENABLE_UNICODE_PATH_SUPPORT
SharedObjectLoader::SharedObjectLoader(const wchar_t* pluginName) {
_impl = std::make_shared<Impl>(pluginName);
}
#endif
SharedObjectLoader::~SharedObjectLoader() noexcept(false) {
}
SharedObjectLoader::SharedObjectLoader(const char * pluginName) {
_impl = std::make_shared<Impl>(pluginName);
}
#ifdef ENABLE_UNICODE_PATH_SUPPORT
SharedObjectLoader::SharedObjectLoader(const wchar_t* pluginName) {
_impl = std::make_shared<Impl>(pluginName);
}
#endif
void* SharedObjectLoader::get_symbol(const char* symbolName) const {
return _impl->get_symbol(symbolName);