diff --git a/inference-engine/include/ie_core.hpp b/inference-engine/include/ie_core.hpp index c5197f0a41c..fddf2b29069 100644 --- a/inference-engine/include/ie_core.hpp +++ b/inference-engine/include/ie_core.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2020 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -108,6 +108,23 @@ public: const CNNNetwork& network, const std::string& deviceName, const std::map& config = {}); + /** + * @brief Reads model and creates an executable network from IR or ONNX file + * + * This can be more efficient than using ReadNetwork + LoadNetwork(CNNNetwork) flow + * especially for cases when caching is enabled and cached model is available + * + * @param modelPath path to model + * @param deviceName Name of device to load network to + * @param config Optional map of pairs: (config parameter name, config parameter value) relevant only for this load + * operation/ + * + * @return An executable network reference + */ + ExecutableNetwork LoadNetwork( + const std::string& modelPath, const std::string& deviceName, + const std::map& config = {}); + /** * @brief Registers extension * @param extension Pointer to already loaded extension @@ -137,8 +154,8 @@ public: /** * @brief Creates an executable network from a previously exported network * - * @param deviceName Name of device load executable network on * @param modelFileName Path to the location of the exported file + * @param deviceName Name of device load executable network on * @param config Optional map of pairs: (config parameter name, config parameter value) relevant only for this load * operation* * @return An executable network reference @@ -149,8 +166,8 @@ public: /** * @brief Creates an executable network from a previously exported network - * @param deviceName Name of device load executable network on * @param networkModel network model stream + * @param deviceName Name of device load executable network on * @param config Optional map of pairs: (config parameter name, config parameter value) relevant only for this load * operation* * @return An executable network reference diff --git a/inference-engine/include/ie_plugin_config.hpp b/inference-engine/include/ie_plugin_config.hpp index e6175eb5356..a211c491be6 100644 --- a/inference-engine/include/ie_plugin_config.hpp +++ b/inference-engine/include/ie_plugin_config.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2020 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -143,6 +143,16 @@ DECLARE_METRIC_KEY(NUMBER_OF_WAITING_INFER_REQUESTS, unsigned int); */ DECLARE_METRIC_KEY(NUMBER_OF_EXEC_INFER_REQUESTS, unsigned int); +/** + * @brief Metric which defines the device architecture. + */ +DECLARE_METRIC_KEY(DEVICE_ARCHITECTURE, std::string); + +/** + * @brief Metric which defines support of import/export functionality by plugin + */ +DECLARE_METRIC_KEY(IMPORT_EXPORT_SUPPORT, bool); + /** * @brief Metric to get a name of network. String value is "NETWORK_NAME". */ @@ -363,16 +373,23 @@ DECLARE_CONFIG_KEY(DUMP_EXEC_GRAPH_AS_DOT); DECLARE_CONFIG_KEY(ENFORCE_BF16); /** -* @brief This key defines the directory which will be used to store any data cached by plugins. -* -* This key supports unicode symbols in path -* The underlying cache structure is not defined and might differ between OpenVINO releases -* Cached data might be platform/device specific and might be invalid after OpenVINO version change -* If this key is not specified or value is empty string, then caching is disabled. -* The key might enable caching for all plugin or some specific ones, e.g.: -* ie.SetConfig({{CONFIG_KEY(CACHE_DIR), "cache/"}}) - enables cache for all plugins that might want to use it -* ie.SetConfig({{CONFIG_KEY(CACHE_DIR), "cache/"}}, {"GPU"}) - enables cache only for GPU plugin -*/ + * @brief This key defines the directory which will be used to store any data cached by plugins. + * + * The underlying cache structure is not defined and might differ between OpenVINO releases + * Cached data might be platform / device specific and might be invalid after OpenVINO version change + * If this key is not specified or value is empty string, then caching is disabled. + * The key might enable caching for the plugin using the following code: + * + * @code + * ie.SetConfig({{CONFIG_KEY(CACHE_DIR), "cache/"}}, "GPU"); // enables cache for GPU plugin + * @endcode + * + * The following code enables caching of compiled network blobs for devices where import/export is supported + * + * @code + * ie.SetConfig({{CONFIG_KEY(CACHE_DIR), "cache/"}}); // enables models cache + * @endcode + */ DECLARE_CONFIG_KEY(CACHE_DIR); } // namespace PluginConfigParams diff --git a/inference-engine/src/inference_engine/compilation_context.cpp b/inference-engine/src/inference_engine/compilation_context.cpp new file mode 100644 index 00000000000..05cd6879f0a --- /dev/null +++ b/inference-engine/src/inference_engine/compilation_context.cpp @@ -0,0 +1,219 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include "compilation_context.hpp" + +#include +#include + +#ifndef WIN32 +#include +#endif +#include + +#include "ie_itt.hpp" +#include "cpp_interfaces/exception2status.hpp" +#include "transformations/serialize.hpp" +#include "cpp/ie_cnn_network.h" +#include "details/ie_exception.hpp" + +#include "ngraph/variant.hpp" +#include "ngraph/opsets/opset6.hpp" +#include "transformations/rt_info/dequantization_attribute.hpp" +#include "transformations/rt_info/fused_names_attribute.hpp" +#include "transformations/rt_info/primitives_priority_attribute.hpp" +#include "file_utils.h" + +#ifdef WIN32 +#define stat _stat +#endif + +namespace InferenceEngine { + +template +static std::size_t hash_combine(std::size_t seed, const T& a) { + // Hash combine formula from boost + return seed ^ (std::hash()(a) + 0x9e3779b9 + (seed << 6) + (seed >> 2)); +} + +template +static int32_t as_int32_t(T v) { + return static_cast(v); +} + +class OstreamHashWrapper final: public std::streambuf { + std::size_t m_res = {}; +public: + std::size_t getResult() const { return m_res; } + std::streamsize xsputn(const char* s, std::streamsize n) override { + const std::int64_t* intS = (const std::int64_t *)s; + std::streamsize n64 = n / sizeof(std::int64_t); + std::streamsize i = 0; + // Using 64-bit values executes much faster than char + while (i++ < n64) { + m_res += *(intS++); + } + + std::streamsize rest = n % sizeof(std::int64_t); + for (i = 0; i < rest; i++) { + m_res += s[n - rest + i]; + } + return n; + } +}; + +////////////////////////////////////////////////// + +std::string NetworkCompilationContext::calculateFileInfo(const std::string& filePath) { + size_t seed {}; + auto absPath = filePath; + try { + absPath = FileUtils::absoluteFilePath(filePath); + } catch (...) { + // can't get absolute path, will use filePath for hash + } + + seed = hash_combine(seed, absPath); + + std::string res; + struct stat result; + if (stat(absPath.c_str(), &result) == 0) { + seed = hash_combine(seed, result.st_mtime); + seed = hash_combine(seed, result.st_size); + } + return std::to_string(seed); +} + +std::string NetworkCompilationContext::computeHash(const CNNNetwork& network, + const std::map& compileOptions) { + OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "NetworkCompilationContext::computeHash - CNN"); + OstreamHashWrapper xmlHash; + OstreamHashWrapper binHash; + std::ostream xml(&xmlHash); + std::ostream bin(&binHash); + + IE_ASSERT(network.getFunction()); + + // 1. Serialize + CNNNetwork net(network); + ngraph::pass::Serialize serializer(xml, bin, + ngraph::pass::Serialize::Version::IR_V10); + serializer.run_on_function(net.getFunction()); + + // 2. Compute hash on serialized data and options + size_t seed {}; + seed = hash_combine(seed, xmlHash.getResult()); + seed = hash_combine(seed, binHash.getResult()); + + for (const auto& kvp : compileOptions) { + seed = hash_combine(seed, kvp.first + kvp.second); + } + + // 3. Add runtime information which may not be serialized + for (const auto& op : network.getFunction()->get_ordered_ops()) { + const auto& rt = op->get_rt_info(); + for (const auto& rtMapData : rt) { + seed = hash_combine(seed, rtMapData.first); + + if (auto stringData = std::dynamic_pointer_cast>(rtMapData.second)) { + seed = hash_combine(seed, stringData->get()); + } else if (auto intData = std::dynamic_pointer_cast>(rtMapData.second)) { + seed = hash_combine(seed, intData->get()); + } else if (auto deq = std::dynamic_pointer_cast>(rtMapData.second)) { + seed = hash_combine(seed, deq->get().getDequantizationAttr()); + } else if (auto fNames = std::dynamic_pointer_cast>(rtMapData.second)) { + seed = hash_combine(seed, fNames->get().getNames()); + } else if (auto prim = std::dynamic_pointer_cast>(rtMapData.second)) { + seed = hash_combine(seed, prim->get().getPrimitivesPriority()); + } + } + } + + // 4. Add inputs info + for (const auto& input : network.getInputsInfo()) { + InputInfo::Ptr info = input.second; + seed = hash_combine(seed, as_int32_t(info->getPrecision())); + seed = hash_combine(seed, as_int32_t(info->getLayout())); + + const InferenceEngine::PreProcessInfo& preproc = info->getPreProcess(); + seed = hash_combine(seed, as_int32_t(preproc.getMeanVariant())); + + if (preproc.getMeanVariant() == MeanVariant::MEAN_VALUE) { + seed = hash_combine(seed, preproc.getNumberOfChannels()); + for (size_t c = 0; c < preproc.getNumberOfChannels(); ++c) { + const PreProcessChannel::Ptr & channelInfo = preproc[c]; + seed = hash_combine(seed, channelInfo->stdScale); + seed = hash_combine(seed, channelInfo->meanValue); + } + } else if (preproc.getMeanVariant() == MeanVariant::MEAN_IMAGE) { + // TODO: think if we need to compute hash for mean image if it exists + } + } + + // 5. Add outputs info + for (const auto& output : network.getOutputsInfo()) { + DataPtr info = output.second; + seed = hash_combine(seed, as_int32_t(info->getPrecision())); + seed = hash_combine(seed, as_int32_t(info->getLayout())); + } + + return std::to_string(seed); +} + +std::string NetworkCompilationContext::computeHash(const std::string& modelName, + const std::map& compileOptions) { + OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "NetworkCompilationContext::computeHash - ModelName"); + size_t seed {}; + try { + seed = hash_combine(seed, FileUtils::absoluteFilePath(modelName)); + } catch (...) { + // can't get absolute path, use modelName for hash calculation + seed = hash_combine(seed, modelName); + } + for (const auto& kvp : compileOptions) { + seed = hash_combine(seed, kvp.first + kvp.second); + } + return std::to_string(seed); +} + +////////////////////////////////////////////////// + +CompiledBlobHeader::CompiledBlobHeader() {} + +CompiledBlobHeader::CompiledBlobHeader(const std::string& ieVersion, const std::string& fileInfo) : + m_ieVersion(ieVersion), + m_fileInfo(fileInfo) { +} + +std::istream& operator >> (std::istream& stream, CompiledBlobHeader& header) { + std::string xmlStr; + std::getline(stream, xmlStr); + + pugi::xml_document document; + pugi::xml_parse_result res = document.load_string(xmlStr.c_str()); + + if (res.status != pugi::status_ok) { + THROW_IE_EXCEPTION_WITH_STATUS(NETWORK_NOT_READ) << "Error reading compiled blob header"; + } + + pugi::xml_node compiledBlobNode = document.document_element(); + header.m_ieVersion = XMLParseUtils::GetStrAttr(compiledBlobNode, "ie_version"); + header.m_fileInfo = XMLParseUtils::GetStrAttr(compiledBlobNode, "file_info"); + + return stream; +} + +std::ostream& operator << (std::ostream& stream, const CompiledBlobHeader& header) { + pugi::xml_document document; + auto compiledBlobNode = document.append_child("compiled_blob"); + compiledBlobNode.append_attribute("ie_version").set_value(header.m_ieVersion.c_str()); + compiledBlobNode.append_attribute("file_info").set_value(header.m_fileInfo.c_str()); + + document.save(stream, nullptr, pugi::format_raw); + document.reset(); + stream << std::endl; + + return stream; +} + +} // namespace InferenceEngine diff --git a/inference-engine/src/inference_engine/compilation_context.hpp b/inference-engine/src/inference_engine/compilation_context.hpp new file mode 100644 index 00000000000..c91c613ed9d --- /dev/null +++ b/inference-engine/src/inference_engine/compilation_context.hpp @@ -0,0 +1,47 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +namespace InferenceEngine { + +class CNNNetwork; + +struct NetworkCompilationContext final { + static std::string calculateFileInfo(const std::string& filePath); + + static std::string computeHash(const CNNNetwork& network, + const std::map& compileOptions); + + static std::string computeHash(const std::string& modelName, + const std::map& compileOptions); +}; + +class CompiledBlobHeader final { + std::string m_ieVersion; + std::string m_fileInfo; + +public: + CompiledBlobHeader(); + CompiledBlobHeader(const std::string& ieVersion, const std::string& fileInfo); + + const std::string& getIeVersion() const { + return m_ieVersion; + } + + const std::string& getFileInfo() const { + return m_fileInfo; + } + + friend std::istream & operator >> (std::istream& stream, CompiledBlobHeader& header); + + friend std::ostream & operator << (std::ostream& stream, const CompiledBlobHeader& header); +}; + +} // namespace InferenceEngine diff --git a/inference-engine/src/inference_engine/file_utils.cpp b/inference-engine/src/inference_engine/file_utils.cpp index 15898395770..a9a57fba744 100644 --- a/inference-engine/src/inference_engine/file_utils.cpp +++ b/inference-engine/src/inference_engine/file_utils.cpp @@ -13,6 +13,8 @@ #include #include
+#include +#include #ifndef _WIN32 # include @@ -32,6 +34,38 @@ # include #endif +#ifdef _WIN32 + +#include + +// Copied from linux libc sys/stat.h: +# define S_ISDIR(m) (((m) & S_IFMT) == S_IFDIR) + +/// @brief Windows-specific 'mkdir' wrapper +#define makedir(dir) _mkdir(dir) + +/// @brief Max length of absolute file path +#define MAX_ABS_PATH _MAX_PATH +/// @brief Get absolute file path, returns NULL in case of error +#define get_absolute_path(result, path) _fullpath(result, path.c_str(), MAX_ABS_PATH) + +/// @brief Windows-specific 'stat' wrapper +#define stat _stat + +#else + +#include + +/// @brief mkdir wrapper +#define makedir(dir) mkdir(dir, 0755) + +/// @brief Max length of absolute file path +#define MAX_ABS_PATH PATH_MAX +/// @brief Get absolute file path, returns NULL in case of error +#define get_absolute_path(result, path) realpath(path.c_str(), result) + +#endif + #ifdef ENABLE_UNICODE_PATH_SUPPORT std::string FileUtils::wStringtoMBCSstringChar(const std::wstring& wstr) { @@ -73,6 +107,44 @@ long long FileUtils::fileSize(const char* charfilepath) { return in.tellg(); } +std::string FileUtils::absoluteFilePath(const std::string& filePath) { + std::string absolutePath; + absolutePath.resize(MAX_ABS_PATH); + auto absPath = get_absolute_path(&absolutePath[0], filePath); + if (!absPath) { + THROW_IE_EXCEPTION << "Can't get absolute file path for [" << filePath << "], err = " << strerror(errno); + } + absolutePath.resize(strlen(absPath)); + return absolutePath; +} + +bool FileUtils::directoryExists(const std::string &path) { + struct stat sb; + + if (stat(path.c_str(), &sb) == 0 && S_ISDIR(sb.st_mode)) { + return true; + } + return false; +} + +void FileUtils::createDirectoryRecursive(const std::string& dirPath) { + if (dirPath.empty() || directoryExists(dirPath)) { + return; + } + + std::size_t pos = dirPath.rfind(FileUtils::FileSeparator); + if (pos != std::string::npos) { + createDirectoryRecursive(dirPath.substr(0, pos)); + } + + int err = makedir(dirPath.c_str()); + if (err != 0 && errno != EEXIST) { + // TODO: in case of exception it may be needed to remove all created sub-directories + THROW_IE_EXCEPTION << "Couldn't create directory [" + << dirPath << "], err=" << strerror(errno) << ")"; + } +} + namespace InferenceEngine { namespace { diff --git a/inference-engine/src/inference_engine/ie_cache_manager.hpp b/inference-engine/src/inference_engine/ie_cache_manager.hpp new file mode 100644 index 00000000000..a45e004b97f --- /dev/null +++ b/inference-engine/src/inference_engine/ie_cache_manager.hpp @@ -0,0 +1,121 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +/** + * @brief This is a header file for the Inference Engine Cache Manager class C++ API + * + * @file ie_cache_manager.hpp + */ +#pragma once + +#include +#include +#include +#include +#include "ie_api.h" +#include "file_utils.h" + +namespace InferenceEngine { + +/** + * @brief This class represents private interface for Cache Manager + * + */ +class ICacheManager { +public: + /** + * @brief Default destructor + */ + virtual ~ICacheManager() = default; + + /** + * @brief Function passing created output stream + * + */ + using StreamWriter = std::function; + /** + * @brief Callback when Inference Engine intends to write network to cache + * + * Client needs to call create std::ostream object and call writer(ostream) + * Otherwise, network will not be cached + * + * @param id Id of cache (hash of the network) + * @param writer Lambda function to be called when stream is created + */ + virtual void writeCacheEntry(const std::string& id, StreamWriter writer) = 0; + + /** + * @brief Function passing created input stream + * + */ + using StreamReader = std::function; + /** + * @brief Callback when Inference Engine intends to read network from cache + * + * Client needs to call create std::istream object and call reader(istream) + * Otherwise, network will not be read from cache and will be loaded as usual + * + * @param id Id of cache (hash of the network) + * @param reader Lambda function to be called when input stream is created + */ + virtual void readCacheEntry(const std::string& id, StreamReader reader) = 0; + + /** + * @brief Callback when Inference Engine intends to remove cache entry + * + * Client needs to perform appropriate cleanup (e.g. delete a cache file) + * + * @param id Id of cache (hash of the network) + */ + virtual void removeCacheEntry(const std::string& id) = 0; +}; + +/** + * @brief File storage-based Implementation of ICacheManager + * + * Uses simple file for read/write cached models. + * + */ +class FileStorageCacheManager final : public ICacheManager { + std::string m_cachePath; + + std::string getBlobFile(const std::string& blobHash) const { + return FileUtils::makePath(m_cachePath, blobHash + ".blob"); + } + +public: + /** + * @brief Constructor + * + */ + FileStorageCacheManager(std::string&& cachePath) : m_cachePath(std::move(cachePath)) {} + + /** + * @brief Destructor + * + */ + ~FileStorageCacheManager() override = default; + +private: + void writeCacheEntry(const std::string& id, StreamWriter writer) override { + std::ofstream stream(getBlobFile(id), std::ios_base::binary | std::ofstream::out); + writer(stream); + } + + void readCacheEntry(const std::string& id, StreamReader reader) override { + auto blobFileName = getBlobFile(id); + if (FileUtils::fileExist(blobFileName)) { + std::ifstream stream(blobFileName, std::ios_base::binary); + reader(stream); + } + } + + void removeCacheEntry(const std::string& id) override { + auto blobFileName = getBlobFile(id); + if (FileUtils::fileExist(blobFileName)) + std::remove(blobFileName.c_str()); + } +}; + +} // namespace InferenceEngine diff --git a/inference-engine/src/inference_engine/ie_core.cpp b/inference-engine/src/inference_engine/ie_core.cpp index 0d77b16a092..243df52b0e8 100644 --- a/inference-engine/src/inference_engine/ie_core.cpp +++ b/inference-engine/src/inference_engine/ie_core.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2020 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -6,8 +6,8 @@ #include #include #include -#include #include +#include #include #include @@ -17,14 +17,17 @@ #include #include +#include "compilation_context.hpp" #include "ie_plugin_cpp.hpp" #include "ie_plugin_config.hpp" +#include "ie_cache_manager.hpp" #include "ie_itt.hpp" #include "file_utils.h" #include "ie_network_reader.hpp" #include "xml_parse_utils.h" using namespace InferenceEngine::PluginConfigParams; +using namespace std::placeholders; namespace InferenceEngine { @@ -158,6 +161,41 @@ class Core::Impl : public ICore { mutable std::map plugins; + class CoreConfig final { + public: + struct CacheConfig { + std::shared_ptr _cacheManager; + }; + + void setAndUpdate(std::map& config) { + auto it = config.find(CONFIG_KEY(CACHE_DIR)); + if (it != config.end()) { + std::lock_guard lock(_cacheConfigMutex); + if (!it->second.empty()) { + FileUtils::createDirectoryRecursive(it->second); + _cacheConfig._cacheManager = std::make_shared(std::move(it->second)); + } else { + _cacheConfig._cacheManager = nullptr; + } + + config.erase(it); + } + } + + // Creating thread-safe copy of config including shared_ptr to ICacheManager + CacheConfig getCacheConfig() const { + std::lock_guard lock(_cacheConfigMutex); + return _cacheConfig; + } + + private: + mutable std::mutex _cacheConfigMutex; + CacheConfig _cacheConfig; + }; + + // Core settings (cache config, etc) + CoreConfig coreConfig; + struct PluginDescriptor { FileUtils::FilePath libraryLocation; std::map defaultConfig; @@ -170,9 +208,141 @@ class Core::Impl : public ICore { std::map pluginRegistry; mutable std::mutex pluginsMutex; // to lock parallel access to pluginRegistry and plugins + bool DeviceSupportsImportExport(const InferencePlugin& plugin) const { + std::vector supportedMetricKeys = plugin.GetMetric(METRIC_KEY(SUPPORTED_METRICS), {}); + auto it = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(), + METRIC_KEY(IMPORT_EXPORT_SUPPORT)); + bool supported = (it != supportedMetricKeys.end()) && + plugin.GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), {}); + return supported; + } + + ExecutableNetwork LoadNetworkImpl(const CNNNetwork& network, + InferencePlugin& plugin, + const std::map& parsedConfig, + const RemoteContext::Ptr& context, + const std::string& blobID, + const std::string& modelPath = std::string()) { + OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::Impl::LoadNetworkImpl"); + ExecutableNetwork execNetwork; + execNetwork = context ? plugin.LoadNetwork(network, context, parsedConfig) : + plugin.LoadNetwork(network, parsedConfig); + auto cacheManager = coreConfig.getCacheConfig()._cacheManager; + if (cacheManager && DeviceSupportsImportExport(plugin)) { + try { + // need to export network for further import from "cache" + OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::Export"); + cacheManager->writeCacheEntry(blobID, [&](std::ostream& networkStream) { + networkStream << CompiledBlobHeader(GetInferenceEngineVersion()->buildNumber, + NetworkCompilationContext::calculateFileInfo(modelPath)); + execNetwork.Export(networkStream); + }); + } catch (...) { + cacheManager->removeCacheEntry(blobID); + throw; + } + } + return execNetwork; + } + + ExecutableNetwork LoadNetworkFromCache(const std::shared_ptr& cacheManager, + const std::string& blobId, + InferencePlugin& plugin, + const std::map& config, + const RemoteContext::Ptr& context, + bool& networkIsImported, + const std::string& modelPath = std::string()) { + ExecutableNetwork execNetwork; + struct HeaderException {}; + + IE_ASSERT(cacheManager != nullptr); + try { + cacheManager->readCacheEntry(blobId, [&](std::istream &networkStream) { + OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetworkFromCache::ReadStreamAndImport"); + try { + CompiledBlobHeader header; + networkStream >> header; + if (header.getIeVersion() != GetInferenceEngineVersion()->buildNumber) { + // Build number mismatch, don't use this cache + throw NetworkNotRead("Version does not match"); + } + if (header.getFileInfo() != NetworkCompilationContext::calculateFileInfo(modelPath)) { + // Original file is changed, don't use cache + throw NetworkNotRead("Original model file is changed"); + } + } catch (...) { + throw HeaderException(); + } + + execNetwork = context ? + plugin.ImportNetwork(networkStream, context, config) : + plugin.ImportNetwork(networkStream, config); + networkIsImported = true; + }); + } catch (const HeaderException& ex) { + // For these exceptions just remove old cache and set that import didn't work + cacheManager->removeCacheEntry(blobId); + networkIsImported = false; + } catch (...) { + cacheManager->removeCacheEntry(blobId); + throw; + } + return execNetwork; + } + + std::map CreateCompileConfig(const InferencePlugin& plugin, + const std::string& deviceFamily, + const std::map& origConfig) const { + std::map getMetricConfig; + auto compileConfig = origConfig; + + // 0. remove DEVICE_ID key + auto deviceIt = compileConfig.find(CONFIG_KEY(DEVICE_ID)); + if (deviceIt != compileConfig.end()) { + getMetricConfig[deviceIt->first] = deviceIt->second; + compileConfig.erase(deviceIt); + } + + // 1. replace it with DEVICE_ARCHITECTURE value + std::vector supportedMetricKeys = + plugin.GetMetric(METRIC_KEY(SUPPORTED_METRICS), getMetricConfig); + auto archIt = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(), + METRIC_KEY(DEVICE_ARCHITECTURE)); + if (archIt != supportedMetricKeys.end()) { + auto value = plugin.GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), getMetricConfig); + compileConfig[METRIC_KEY(DEVICE_ARCHITECTURE)] = value.as(); + } else { + // Take device name if device does not support DEVICE_ARCHITECTURE metric + compileConfig[METRIC_KEY(DEVICE_ARCHITECTURE)] = deviceFamily; + } + return compileConfig; + } + + std::string CalculateNetworkHash(const CNNNetwork& network, const std::string& deviceFamily, + const InferencePlugin& plugin, + const std::map& config) const { + auto compileConfig = CreateCompileConfig(plugin, deviceFamily, config); + return NetworkCompilationContext::computeHash(network, compileConfig); + } + + std::string CalculateFileHash(const std::string& modelName, const std::string& deviceFamily, + const InferencePlugin& plugin, + const std::map& config) const { + auto compileConfig = CreateCompileConfig(plugin, deviceFamily, config); + return NetworkCompilationContext::computeHash(modelName, compileConfig); + } + public: - Impl(); - ~Impl() override; + Impl() { + opsetNames.insert("opset1"); + opsetNames.insert("opset2"); + opsetNames.insert("opset3"); + opsetNames.insert("opset4"); + opsetNames.insert("opset5"); + opsetNames.insert("opset6"); + } + + ~Impl() override = default; /** * @brief Register plugins for devices which are located in .xml configuration file. The function supports UNICODE path @@ -250,20 +420,80 @@ public: } CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath) const override { - OV_ITT_SCOPED_TASK(itt::domains::IE); + OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::Impl::ReadNetwork from file"); return details::ReadNetwork(modelPath, binPath, extensions); } CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights) const override { - OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::Impl::ReadNetwork"); + OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::Impl::ReadNetwork from memory"); return details::ReadNetwork(model, weights, extensions); } + // TODO: In future this method can be added to ICore interface + ExecutableNetwork LoadNetwork(const CNNNetwork& network, const RemoteContext::Ptr& context, + const std::map& config) { + OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::RemoteContext"); + if (context == nullptr) { + THROW_IE_EXCEPTION << "Remote context is null"; + } + auto parsed = parseDeviceNameIntoConfig(context->getDeviceName(), config); + auto plugin = GetCPPPluginByName(parsed._deviceName); + bool loadedFromCache = false; + ExecutableNetwork res; + std::string hash; + auto cacheManager = coreConfig.getCacheConfig()._cacheManager; + if (cacheManager && DeviceSupportsImportExport(plugin)) { + hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config); + res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, context, loadedFromCache); + } + + if (!loadedFromCache) { + res = LoadNetworkImpl(network, plugin, parsed._config, context, hash); + } + return res; + } + ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::string& deviceName, const std::map& config) override { - OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::Impl::LoadNetwork"); + OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::CNN"); auto parsed = parseDeviceNameIntoConfig(deviceName, config); - return GetCPPPluginByName(parsed._deviceName).LoadNetwork(network, parsed._config); + auto plugin = GetCPPPluginByName(parsed._deviceName); + bool loadedFromCache = false; + ExecutableNetwork res; + std::string hash; + auto cacheManager = coreConfig.getCacheConfig()._cacheManager; + if (cacheManager && DeviceSupportsImportExport(plugin)) { + hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config); + res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, nullptr, loadedFromCache); + } + + if (!loadedFromCache) { + res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, hash); + } + return res; + } + + // TODO: In future this method can be added to ICore interface + ExecutableNetwork LoadNetwork(const std::string& modelPath, const std::string& deviceName, + const std::map& config) { + OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::Path"); + auto parsed = parseDeviceNameIntoConfig(deviceName, config); + auto plugin = GetCPPPluginByName(parsed._deviceName); + bool loadedFromCache = false; + ExecutableNetwork res; + std::string hash; + auto cacheManager = coreConfig.getCacheConfig()._cacheManager; + if (cacheManager && DeviceSupportsImportExport(plugin)) { + hash = CalculateFileHash(modelPath, parsed._deviceName, plugin, parsed._config); + res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, + nullptr, loadedFromCache, modelPath); + } + + if (!loadedFromCache) { + auto cnnNetwork = ReadNetwork(modelPath, std::string()); + res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, hash, modelPath); + } + return res; } ExecutableNetwork ImportNetwork(std::istream& networkModel, const std::string& deviceName, @@ -286,6 +516,7 @@ public: QueryNetworkResult QueryNetwork(const CNNNetwork& network, const std::string& deviceName, const std::map& config) const override { + OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::QueryNetwork"); auto parsed = parseDeviceNameIntoConfig(deviceName, config); auto res = GetCPPPluginByName(parsed._deviceName).QueryNetwork(network, parsed._config); if (!network.getFunction() || res.supportedLayersMap.empty()) @@ -338,7 +569,6 @@ public: } /** - * @deprecated * @brief Returns reference to CPP plugin wrapper by a device name * @param deviceName A name of device * @return Reference to a CPP plugin wrapper @@ -463,10 +693,18 @@ public: * @brief Sets config values for a plugin or set of plugins * @param deviceName A device name to set config to * If empty, config is set for all the plugins / plugin's meta-data + * @note `deviceName` is not allowed in form of MULTI:CPU, HETERO:FPGA,CPU + * just simple forms like CPU, GPU, MULTU, GPU.0, etc */ - void SetConfigForPlugins(const std::map& config, const std::string& deviceName) { + void SetConfigForPlugins(const std::map& configMap, const std::string& deviceName) { + auto config = configMap; + std::lock_guard lock(pluginsMutex); + if (deviceName.empty()) { + coreConfig.setAndUpdate(config); + } + // set config for plugins in registry bool configIsSet = false; for (auto& desc : pluginRegistry) { @@ -524,15 +762,6 @@ public: } }; -Core::Impl::Impl() { - opsetNames.insert("opset1"); - opsetNames.insert("opset2"); - opsetNames.insert("opset3"); - opsetNames.insert("opset4"); -} - -Core::Impl::~Impl() {} - Core::Core(const std::string& xmlConfigFile) { _impl = std::make_shared(); @@ -603,20 +832,14 @@ ExecutableNetwork Core::LoadNetwork(const CNNNetwork& network, const std::string return _impl->LoadNetwork(network, deviceName, config); } -void Core::AddExtension(const IExtensionPtr& extension) { - _impl->AddExtension(extension); -} - ExecutableNetwork Core::LoadNetwork(const CNNNetwork& network, RemoteContext::Ptr context, const std::map& config) { - OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::LoadNetwork"); + return _impl->LoadNetwork(network, context, config); +} - if (context == nullptr) { - THROW_IE_EXCEPTION << "Remote context is null"; - } - - auto parsed = parseDeviceNameIntoConfig(context->getDeviceName(), config); - return _impl->GetCPPPluginByName(parsed._deviceName).LoadNetwork(network, parsed._config, context); +ExecutableNetwork Core::LoadNetwork(const std::string& modelPath, const std::string& deviceName, + const std::map& config) { + return _impl->LoadNetwork(modelPath, deviceName, config); } RemoteContext::Ptr Core::CreateContext(const std::string& deviceName, const ParamMap& params) { @@ -656,8 +879,15 @@ void Core::AddExtension(IExtensionPtr extension, const std::string& deviceName_) _impl->AddExtension(extension); } +void Core::AddExtension(const IExtensionPtr& extension) { + _impl->AddExtension(extension); +} + ExecutableNetwork Core::ImportNetwork(const std::string& modelFileName, const std::string& deviceName, const std::map& config) { + OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::ImportNetwork"); + + // TODO: remove once NotImplemented exception is deprecated and not used if (deviceName.find("HETERO") == 0) { THROW_IE_EXCEPTION << "HETERO device does not support ImportNetwork"; } @@ -698,19 +928,21 @@ QueryNetworkResult Core::QueryNetwork(const CNNNetwork& network, const std::stri void Core::SetConfig(const std::map& config, const std::string& deviceName) { // HETERO case - { - if (deviceName.find("HETERO:") == 0) { - THROW_IE_EXCEPTION << "SetConfig is supported only for HETERO itself (without devices). " - "You can configure the devices with SetConfig before creating the HETERO on top."; - } + if (deviceName.find("HETERO:") == 0) { + THROW_IE_EXCEPTION << "SetConfig is supported only for HETERO itself (without devices). " + "You can configure the devices with SetConfig before creating the HETERO on top."; } // MULTI case - { - if (deviceName.find("MULTI:") == 0) { - THROW_IE_EXCEPTION << "SetConfig is supported only for MULTI itself (without devices). " - "You can configure the devices with SetConfig before creating the MULTI on top."; - } + if (deviceName.find("MULTI:") == 0) { + THROW_IE_EXCEPTION << "SetConfig is supported only for MULTI itself (without devices). " + "You can configure the devices with SetConfig before creating the MULTI on top."; + } + + // GPU.0, FPGA.1 cases + if (deviceName.find(".") != std::string::npos) { + THROW_IE_EXCEPTION << "SetConfig is supported only for device family itself (without particular device .#). " + "You can pass .# as a particular device instance to QueryNetwork, LoadNetwork, ImportNetwork only"; } if (deviceName.empty()) { diff --git a/inference-engine/src/inference_engine/ie_plugin_cpp.hpp b/inference-engine/src/inference_engine/ie_plugin_cpp.hpp index e57239e95ae..f0474e190c3 100644 --- a/inference-engine/src/inference_engine/ie_plugin_cpp.hpp +++ b/inference-engine/src/inference_engine/ie_plugin_cpp.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2020 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -75,10 +75,6 @@ public: CALL_STATEMENT(return actual->GetVersion()); } - ExecutableNetwork LoadNetwork(CNNNetwork network, const std::map& config) { - CALL_STATEMENT(return ExecutableNetwork(actual->LoadNetwork(network, config), actual)); - } - void AddExtension(InferenceEngine::IExtensionPtr extension) { CALL_STATEMENT(actual->AddExtension(extension)); } @@ -87,9 +83,12 @@ public: CALL_STATEMENT(actual->SetConfig(config)); } - ExecutableNetwork ImportNetwork(const std::string& modelFileName, - const std::map& config) { - CALL_STATEMENT(return ExecutableNetwork(actual->ImportNetwork(modelFileName, config), actual)); + ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::map& config) { + CALL_STATEMENT(return ExecutableNetwork(actual->LoadNetwork(network, config), actual)); + } + + ExecutableNetwork LoadNetwork(const CNNNetwork& network, RemoteContext::Ptr context, const std::map& config) { + CALL_STATEMENT(return ExecutableNetwork(actual->LoadNetwork(network, config, context), actual)); } QueryNetworkResult QueryNetwork(const CNNNetwork& network, @@ -100,20 +99,26 @@ public: return res; } + ExecutableNetwork ImportNetwork(const std::string& modelFileName, + const std::map& config) { + CALL_STATEMENT(return ExecutableNetwork(actual->ImportNetwork(modelFileName, config), actual)); + } + ExecutableNetwork ImportNetwork(std::istream& networkModel, - const std::map &config) { + const std::map& config) { CALL_STATEMENT(return ExecutableNetwork(actual->ImportNetwork(networkModel, config), actual)); } + ExecutableNetwork ImportNetwork(std::istream& networkModel, + const RemoteContext::Ptr& context, + const std::map& config) { + CALL_STATEMENT(return ExecutableNetwork(actual->ImportNetwork(networkModel, context, config), actual)); + } + Parameter GetMetric(const std::string& name, const std::map& options) const { CALL_STATEMENT(return actual->GetMetric(name, options)); } - ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::map& config, - RemoteContext::Ptr context) { - CALL_STATEMENT(return ExecutableNetwork(actual->LoadNetwork(network, config, context), actual)); - } - RemoteContext::Ptr CreateContext(const ParamMap& params) { CALL_STATEMENT(return actual->CreateContext(params)); } @@ -122,12 +127,6 @@ public: CALL_STATEMENT(return actual->GetDefaultContext(params)); } - ExecutableNetwork ImportNetwork(std::istream& networkModel, - const RemoteContext::Ptr& context, - const std::map& config) { - CALL_STATEMENT(return ExecutableNetwork(actual->ImportNetwork(networkModel, context, config), actual)); - } - Parameter GetConfig(const std::string& name, const std::map& options) const { CALL_STATEMENT(return actual->GetConfig(name, options)); } diff --git a/inference-engine/src/plugin_api/file_utils.h b/inference-engine/src/plugin_api/file_utils.h index cb3dd6a1756..c9e9d0bd763 100644 --- a/inference-engine/src/plugin_api/file_utils.h +++ b/inference-engine/src/plugin_api/file_utils.h @@ -81,6 +81,31 @@ template<> struct FileTraits { }; #endif +/** + * @brief Interface function to get absolute path of file + * @ingroup ie_dev_api_file_utils + * @param filePath - path to file, can be relative to current working directory + * @return Absolute path of file + * @throw InferenceEngineException if any error occurred + */ +INFERENCE_ENGINE_API_CPP(std::string) absoluteFilePath(const std::string& filePath); + +/** + * @brief Interface function to create directorty recursively by given path + * @ingroup ie_dev_api_file_utils + * @param dirPath - path to file, can be relative to current working directory + * @throw InferenceEngineException if any error occurred + */ +INFERENCE_ENGINE_API_CPP(void) createDirectoryRecursive(const std::string& dirPath); + +/** + * @brief Interface function to check if directory exists for given path + * @ingroup ie_dev_api_file_utils + * @param path - path to directory + * @return true if directory exists, false otherwise + */ +INFERENCE_ENGINE_API_CPP(bool) directoryExists(const std::string& path); + /** * @brief Interface function to get the size of a file. The function supports UNICODE path * @ingroup ie_dev_api_file_utils diff --git a/inference-engine/src/plugin_api/ie_icore.hpp b/inference-engine/src/plugin_api/ie_icore.hpp index 42210ef59fe..7534c8c765e 100644 --- a/inference-engine/src/plugin_api/ie_icore.hpp +++ b/inference-engine/src/plugin_api/ie_icore.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2020 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -68,8 +68,8 @@ public: /** * @brief Creates an executable network from a previously exported network - * @param deviceName Name of device load executable network on * @param networkModel network model stream + * @param deviceName Name of device load executable network on * @param config Optional map of pairs: (config parameter name, config parameter value) relevant only for this load * operation* * @return An executable network reference diff --git a/inference-engine/tests/functional/inference_engine/caching_test.cpp b/inference-engine/tests/functional/inference_engine/caching_test.cpp new file mode 100644 index 00000000000..7f25324aebf --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/caching_test.cpp @@ -0,0 +1,956 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include +#include + +#include "ie_plugin_ptr.hpp" +#include "ngraph/function.hpp" +#include "details/ie_so_loader.h" +#include "ie_metric_helpers.hpp" +#include "ie_iexecutable_network.hpp" + +#include "cpp_interfaces/impl/ie_executable_network_internal.hpp" +#include "cpp_interfaces/impl/ie_plugin_internal.hpp" + +#include "common_test_utils/unicode_utils.hpp" +#include "common_test_utils/file_utils.hpp" +#include "common_test_utils/test_constants.hpp" + +#include "functional_test_utils/test_model/test_model.hpp" +#include "functional_test_utils/network_utils.hpp" + +#include "unit_test_utils/mocks/mock_iexecutable_network.hpp" + +using namespace InferenceEngine; +using namespace ::testing; +using namespace InferenceEngine::details; +using namespace std::placeholders; +using namespace std::chrono; + +enum class TestLoadType { + ECNN, + EContext, + EModelName +}; +using TestParam = std::tuple; + +// GCC4.8 limitation: have to specify type of each element in list +static const std::vector loadVariants = { + TestParam { TestLoadType::ECNN, std::string("ByCNNNetwork"), false }, + TestParam { TestLoadType::EContext, std::string("ByRemoteContext"), true }, + TestParam { TestLoadType::EModelName, std::string("ByModelName"), false }, +}; + +static const std::vector cacheFolders { + std::string("testCache"), +}; + +std::string getTestCaseName(const testing::TestParamInfo> &obj) { + return std::get<1>(std::get<0>(obj.param)) + "_" + std::get<1>(obj.param); +} + +class MockRemoteContext : public RemoteContext { + std::string m_name; +public: + MockRemoteContext(std::string name): m_name(std::move(name)) {} + std::string getDeviceName() const noexcept { return m_name; } + MOCK_METHOD2(CreateBlob, RemoteBlob::Ptr(const TensorDesc&, const ParamMap&)); + MOCK_QUALIFIED_METHOD0(getParams, const, ParamMap()); +}; + +class MockCachingInferencePlugin : public InferenceEngine::InferencePluginInternal { +public: + MockCachingInferencePlugin() = default; + ~MockCachingInferencePlugin() = default; + + MOCK_METHOD2(LoadExeNetworkImpl, ExecutableNetworkInternal::Ptr(const CNNNetwork& network, + const std::map& config)); + + MOCK_METHOD3(LoadExeNetworkImpl, ExecutableNetworkInternal::Ptr(const CNNNetwork& network, RemoteContext::Ptr context, + const std::map& config)); + + MOCK_METHOD2(ImportNetworkImpl, ExecutableNetwork(std::istream& networkModel, + const std::map& config)); + + MOCK_METHOD3(ImportNetworkImpl, ExecutableNetwork(std::istream& networkModel, + const RemoteContext::Ptr& context, + const std::map& config)); + + MOCK_CONST_METHOD2(QueryNetwork, QueryNetworkResult(const CNNNetwork& network, + const std::map& config)); + + MOCK_CONST_METHOD2(GetMetric, Parameter(const std::string& name, const std::map& options)); + MOCK_METHOD1(GetDefaultContext, RemoteContext::Ptr(const ParamMap& params)); +}; + +class MockExecutableNetwork : public ExecutableNetworkInternal { +public: + MockExecutableNetwork() {} + MOCK_METHOD1(ExportImpl, void(std::ostream& networkModel)); + MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr()); +}; + +//------------------------------------------------------ +class MkDirGuard { + std::string m_dir; +public: + MkDirGuard(const std::string &dir = std::string()): m_dir(dir) { + if (!m_dir.empty()) { + CommonTestUtils::createDirectory(m_dir); + } + } + + MkDirGuard(const MkDirGuard&) = delete; + MkDirGuard& operator=(const MkDirGuard&) = delete; + + ~MkDirGuard() { + if (!m_dir.empty()) { + CommonTestUtils::removeFilesWithExt(m_dir, "blob"); + CommonTestUtils::removeDir(m_dir); + } + } +}; + +class CachingTest : public ::testing::TestWithParam> { +public: + std::unique_ptr sharedObjectLoader; + std::function injectProxyEngine; + std::string modelName = "Caching_test.xml"; + std::string weightsName = "Caching_test.bin"; + std::string deviceName = "mock"; + std::string deviceToLoad = "mock"; + std::shared_ptr mockPlugin; + std::shared_ptr net; + std::unique_ptr m_dirCreator; + TestLoadType m_type; + std::string m_cacheDir; + using LoadFunction = std::function; + using LoadFunctionWithCfg = std::function &)>; + LoadFunction m_testFunction; + LoadFunctionWithCfg m_testFunctionWithCfg; + bool m_remoteContext = false; + using CNNCallback = std::function; + CNNCallback m_cnnCallback = nullptr; + + + std::string get_mock_engine_name() { + std::string mockEngineName("mock_engine"); + return CommonTestUtils::pre + mockEngineName + IE_BUILD_POSTFIX + CommonTestUtils::ext; + } + + static std::string generateTestFilePrefix() { + // Generate unique file names based on test name, thread id and timestamp + // This allows execution of tests in parallel (stress mode) + auto testInfo = UnitTest::GetInstance()->current_test_info(); + std::string testName = testInfo->test_case_name(); + testName += testInfo->name(); + testName = std::to_string(std::hash()(testName)); + std::stringstream ss; + auto ts = duration_cast(high_resolution_clock::now().time_since_epoch()); + ss << testName << "_" << std::this_thread::get_id() << "_" << ts.count(); + testName = ss.str(); + return testName; + } + + void initParamTest() { + m_type = std::get<0>(std::get<0>(GetParam())); + m_cacheDir = std::get<1>(GetParam()); + m_testFunction = getLoadFunction(m_type); + m_testFunctionWithCfg = getLoadFunctionWithCfg(m_type); + m_remoteContext = std::get<2>(std::get<0>(GetParam())); + auto testName = generateTestFilePrefix(); + modelName = testName + ".xml"; + weightsName = testName + ".bin"; + m_cacheDir = testName + m_cacheDir; + m_dirCreator = std::unique_ptr(new MkDirGuard(m_cacheDir)); + } + + void SetUp() override { + initParamTest(); + mockPlugin = std::make_shared(); + net = std::make_shared(); + setupMock(*mockPlugin); + std::string libraryName = get_mock_engine_name(); + sharedObjectLoader.reset(new SharedObjectLoader(libraryName.c_str())); + injectProxyEngine = make_std_function("InjectProxyEngine"); + + FuncTestUtils::TestModel::generateTestModel(modelName, weightsName); + } + + void TearDown() override { + CommonTestUtils::removeIRFiles(modelName, weightsName); + } + + void testLoad(std::function func) { + Core ie; + injectProxyEngine(mockPlugin.get()); + ie.RegisterPlugin(std::string("mock_engine") + IE_BUILD_POSTFIX, deviceName); + func(ie); + ie.UnregisterPlugin(deviceName); + } + + LoadFunction getLoadFunction(TestLoadType type) const { + switch (type) { + case TestLoadType::ECNN: + return [&](Core& ie) { performReadAndLoad(ie); }; + case TestLoadType::EContext: + return [&](Core& ie) { performReadAndLoadWithContext(ie); }; + case TestLoadType::EModelName: + return [&](Core& ie) { performLoadByName(ie); }; + } + return nullptr; + } + + LoadFunctionWithCfg getLoadFunctionWithCfg(TestLoadType type) const { + switch (type) { + case TestLoadType::ECNN: + return std::bind(&CachingTest::performReadAndLoad, this, _1, _2); + case TestLoadType::EContext: + return std::bind(&CachingTest::performReadAndLoadWithContext, this, _1, _2); + case TestLoadType::EModelName: + return std::bind(&CachingTest::performLoadByName, this, _1, _2); + } + return nullptr; + } + + void performLoadByName(Core& ie, const std::map& config = {}) const { + ie.LoadNetwork(modelName, deviceToLoad, config); + } + + void performReadAndLoad(Core& ie, const std::map& config = {}) const { + auto cnnNetwork = ie.ReadNetwork(modelName); + if (m_cnnCallback) m_cnnCallback(cnnNetwork); + ie.LoadNetwork(cnnNetwork, deviceToLoad, config); + } + + void performReadAndLoadWithContext(Core& ie, const std::map& config = {}) const { + auto cnnNetwork = ie.ReadNetwork(modelName); + EXPECT_CALL(*mockPlugin, GetDefaultContext(_)).Times(AnyNumber()); + auto context = ie.GetDefaultContext(deviceToLoad); + if (m_cnnCallback) m_cnnCallback(cnnNetwork); + ie.LoadNetwork(cnnNetwork, context, config); + } + +private: + template + std::function make_std_function(const std::string& functionName) { + std::function ptr(reinterpret_cast(sharedObjectLoader->get_symbol(functionName.c_str()))); + return ptr; + } + + void setupMock(MockCachingInferencePlugin& plugin) { + ON_CALL(plugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)). + WillByDefault(Invoke([&](const std::string &, const std::map &) { + std::vector res; + res.push_back(METRIC_KEY(IMPORT_EXPORT_SUPPORT)); + res.push_back(METRIC_KEY(DEVICE_ARCHITECTURE)); + return res; + })); + ON_CALL(plugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)). + WillByDefault(Return(true)); + + ON_CALL(plugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)). + WillByDefault(Invoke([&](const std::string &, const std::map &) { + std::vector res; + res.push_back("SomeConfig"); + return res; + })); + + ON_CALL(plugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)). + WillByDefault(Return("mock")); + + ON_CALL(plugin, ImportNetworkImpl(_, _, _)). + WillByDefault(Invoke([&](std::istream &, RemoteContext::Ptr, + const std::map &) { + return ExecutableNetwork(std::make_shared()); + })); + + ON_CALL(plugin, ImportNetworkImpl(_, _)). + WillByDefault(Invoke([&](std::istream &, const std::map &) { + return ExecutableNetwork(std::make_shared()); + })); + + ON_CALL(plugin, LoadExeNetworkImpl(_, _, _)). + WillByDefault(Invoke([&](const CNNNetwork &, RemoteContext::Ptr, + const std::map &) { + return net; + })); + + ON_CALL(plugin, LoadExeNetworkImpl(_, _)). + WillByDefault(Invoke([&](const CNNNetwork &, + const std::map &) { + return net; + })); + + ON_CALL(plugin, GetDefaultContext(_)). + WillByDefault(Invoke([&](const ParamMap &) { + return std::make_shared(deviceToLoad); + })); + + ON_CALL(plugin, QueryNetwork(_, _)). + WillByDefault(Invoke([&](const CNNNetwork &network, const std::map&) { + QueryNetworkResult res; + auto function = network.getFunction(); + EXPECT_TRUE(function); + + for (auto &&node : function->get_ops()) { + res.supportedLayersMap.emplace(node->get_friendly_name(), deviceName); + } + return res; + })); + } +}; + +TEST_P(CachingTest, TestLoad) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } +} + +TEST_P(CachingTest, TestLoadCustomImportExport) { + const int customNumber = 1234; + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + ON_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)). + WillByDefault(Invoke([&](std::istream& s, RemoteContext::Ptr, + const std::map &) { + int a; + s >> a; + EXPECT_EQ(customNumber, a); + return ExecutableNetwork(std::make_shared()); + })); + + ON_CALL(*mockPlugin, ImportNetworkImpl(_, _)). + WillByDefault(Invoke([&](std::istream &s, const std::map &) { + int a; + s >> a; + EXPECT_EQ(customNumber, a); + return ExecutableNetwork(std::make_shared()); + })); + + ON_CALL(*net, ExportImpl(_)).WillByDefault(Invoke([&] (std::ostream& s) { + s << customNumber; + })); + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } +} + +// Brief: when LoadNetwork is called from different config - old cache shall not be used +TEST_P(CachingTest, TestChangeLoadConfig) { + const std::string CUSTOM_KEY = "CUSTOM_KEY"; + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + ON_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)). + WillByDefault(Invoke([&](const std::string &, const std::map &) { + std::vector res; + res.push_back(CUSTOM_KEY); + return res; + })); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunctionWithCfg(ie, {{CUSTOM_KEY, "0"}}); + }); + } + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunctionWithCfg(ie, {{CUSTOM_KEY, "1"}}); + }); + } +} + +TEST_P(CachingTest, TestNoCacheEnabled) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(0); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + m_testFunction(ie); + }); + } +} + +TEST_P(CachingTest, TestNoCacheSupported) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)) + .Times(AnyNumber()).WillRepeatedly(Return(false)); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } +} + +TEST_P(CachingTest, TestNoCacheMetricSupported) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)) + .Times(AnyNumber()).WillRepeatedly(Return(std::vector{})); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(0); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(0); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } +} + +TEST_P(CachingTest, TestLoadChangeCacheDir) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } + + { + std::string newCacheDir = m_cacheDir + "2"; + MkDirGuard dir(newCacheDir); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), newCacheDir}}); + m_testFunction(ie); + }); + } +} + +TEST_P(CachingTest, TestClearCacheDir) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(0); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), ""}}); + m_testFunction(ie); + }); + } +} + +TEST_P(CachingTest, TestChangeOtherConfig) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + ie.SetConfig({{"someKey", "someValue"}}); + m_testFunction(ie); + }); + } +} + +TEST_P(CachingTest, TestChangeCacheDirFailure) { + std::string longName(1000000, ' '); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + EXPECT_ANY_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir + "/" + longName}})); + m_testFunction(ie); + }); + } +} + +TEST_P(CachingTest, TestCacheDirCreateRecursive) { + std::string newCacheDir1 = m_cacheDir + CommonTestUtils::FileSeparator + "a"; + std::string newCacheDir2 = newCacheDir1 + CommonTestUtils::FileSeparator + "b"; + std::string newCacheDir3 = newCacheDir2 + CommonTestUtils::FileSeparator + CommonTestUtils::FileSeparator; + + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), newCacheDir3}})); + EXPECT_NO_THROW(m_testFunction(ie)); + }); + } + CommonTestUtils::removeFilesWithExt(newCacheDir2, "blob"); + std::remove(newCacheDir2.c_str()); + std::remove(newCacheDir1.c_str()); +} + +TEST_P(CachingTest, TestDeviceArchitecture) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()) + .WillRepeatedly(Invoke([&](const std::string&, const std::map& options) { + auto id = options.at("DEVICE_ID").as(); + if (std::stoi(id) < 10) { + return "mock_first_architecture"; + } else { + return "mock_another_architecture"; + } + })); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + deviceToLoad = "mock.0"; + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + deviceToLoad = "mock.1"; + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + deviceToLoad = "mock.50"; + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + deviceToLoad = "mock.51"; + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } +} + +TEST_P(CachingTest, TestNoDeviceArchitecture) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()) + .WillRepeatedly(Invoke([&] (const std::string&, const std::map&) { + return std::vector{METRIC_KEY(IMPORT_EXPORT_SUPPORT)}; + })); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(0); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + deviceToLoad = "mock.0"; + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + deviceToLoad = "mock.50"; + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } +} + +TEST_P(CachingTest, TestThrowOnExport) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1).WillOnce(Throw(1)); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + EXPECT_ANY_THROW(m_testFunction(ie)); + }); + } +} + +TEST_P(CachingTest, TestThrowOnImport) { + ON_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).WillByDefault(Throw(1)); + ON_CALL(*mockPlugin, ImportNetworkImpl(_, _)).WillByDefault(Throw(1)); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + EXPECT_ANY_THROW(m_testFunction(ie)); + }); + } + { // Step 3: same load, cache should be deleted due to unsuccessful import on step 2 + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + EXPECT_NO_THROW(m_testFunction(ie)); + }); + } +} + +TEST_P(CachingTest, TestNetworkModified) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}})); + EXPECT_NO_THROW(m_testFunction(ie)); + }); + } + if (m_type == TestLoadType::EModelName) { + // Modify model file + std::fstream stream(modelName, std::fstream::out | std::fstream::app); + stream << " "; + } else { + // Modify loaded CNN network + m_cnnCallback = [&](CNNNetwork& network) { + auto f = network.getFunction(); + auto res = f->get_results(); + f->remove_result(res.front()); + }; + } + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}})); + EXPECT_NO_THROW(m_testFunction(ie)); + }); + } + { // Step 3: same load, should be ok now + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}})); + EXPECT_NO_THROW(m_testFunction(ie)); + }); + } +} + +TEST_P(CachingTest, TestCacheFileCorrupted) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}})); + EXPECT_NO_THROW(m_testFunction(ie)); + }); + } + { + auto blobs = CommonTestUtils::listFilesWithExt(m_cacheDir, "blob"); + for (const auto& fileName : blobs) { + std::ofstream stream(fileName, std::ios_base::binary); + stream << "SomeCorruptedText"; + } + } + { // Step 2. Cache is corrupted, will be silently removed + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}})); + EXPECT_NO_THROW(m_testFunction(ie)); + }); + } + { // Step 3: same load, should be ok now due to re-creation of cache + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}})); + EXPECT_NO_THROW(m_testFunction(ie)); + }); + } +} + +TEST_P(CachingTest, TestCacheFileOldVersion) { + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber()); + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}})); + EXPECT_NO_THROW(m_testFunction(ie)); + }); + } + { + auto blobs = CommonTestUtils::listFilesWithExt(m_cacheDir, "blob"); + for (const auto& fileName : blobs) { + std::string content; + { + std::ifstream inp(fileName, std::ios_base::binary); + std::ostringstream ostr; + ostr << inp.rdbuf(); + content = ostr.str(); + } + std::string buildNum = GetInferenceEngineVersion()->buildNumber; + std::string zeroBuild(buildNum.size(), '0'); + auto index = content.find(buildNum); + if (index != std::string::npos) { + content.replace(index, buildNum.size(), zeroBuild); + } else { + SKIP(); + } + std::ofstream out(fileName, std::ios_base::binary); + out.write(content.c_str(), content.size()); + } + } + { // Step 2. Build number mismatch, cache will be silently removed + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}})); + EXPECT_NO_THROW(m_testFunction(ie)); + }); + } + { // Step 3: same load, should be ok now due to re-creation of cache + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}})); + EXPECT_NO_THROW(m_testFunction(ie)); + }); + } +} + +TEST_P(CachingTest, LoadHeteroWithCorrectConfig) { + EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber()); + EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber()); + // TODO: test also HETERO with 1 plugin but different architectures, e.g. "HETERO:mock.1,mock.51" + deviceToLoad = CommonTestUtils::DEVICE_HETERO + std::string(":mock.1,mock.2"); + if (m_remoteContext) { + return; // skip the remote Context test for Hetero plugin + } + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*net, ExportImpl(_)).Times(1); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } + + { + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0); + EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(1); + EXPECT_CALL(*net, ExportImpl(_)).Times(0); + testLoad([&](Core &ie) { + ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}); + m_testFunction(ie); + }); + } +} + +INSTANTIATE_TEST_CASE_P(CachingTest, CachingTest, + ::testing::Combine( + ::testing::ValuesIn(loadVariants), + ::testing::ValuesIn(cacheFolders)), + getTestCaseName); diff --git a/inference-engine/tests/ie_test_utils/common_test_utils/file_utils.hpp b/inference-engine/tests/ie_test_utils/common_test_utils/file_utils.hpp index 9747638aa64..b983fa149d6 100644 --- a/inference-engine/tests/ie_test_utils/common_test_utils/file_utils.hpp +++ b/inference-engine/tests/ie_test_utils/common_test_utils/file_utils.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2019 Intel Corporation +// Copyright (C) 2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #pragma once @@ -18,7 +18,6 @@ #define rmdir(dir) _rmdir(dir) #else // _WIN32 #include - #endif // _WIN32 namespace CommonTestUtils { @@ -103,6 +102,27 @@ inline int removeFilesWithExt(std::string path, std::string ext) { return ret; } +// Lists all files with extension=ext from the given directory +// Return value: +// vector of strings representing file paths +inline std::vector listFilesWithExt(const std::string& path, const std::string& ext) { + struct dirent *ent; + DIR *dir = opendir(path.c_str()); + std::vector res; + if (dir != nullptr) { + while ((ent = readdir(dir)) != NULL) { + auto file = makePath(path, std::string(ent->d_name)); + struct stat stat_path; + stat(file.c_str(), &stat_path); + if (!S_ISDIR(stat_path.st_mode) && endsWith(file, "." + ext)) { + res.push_back(std::move(file)); + } + } + closedir(dir); + } + return res; +} + inline int removeDir(const std::string &path) { return rmdir(path.c_str()); } diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_icore.hpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_icore.hpp index 8bffc712ac1..e52f79bb156 100644 --- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_icore.hpp +++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/cpp_interfaces/interface/mock_icore.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2020 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -16,9 +16,13 @@ public: MOCK_METHOD3(LoadNetwork, InferenceEngine::ExecutableNetwork( const InferenceEngine::CNNNetwork&, const std::string&, const std::map&)); + MOCK_METHOD3(LoadNetwork, InferenceEngine::ExecutableNetwork( + const InferenceEngine::CNNNetwork&, const InferenceEngine::RemoteContext::Ptr &, const std::map&)); MOCK_METHOD3(ImportNetwork, InferenceEngine::ExecutableNetwork( std::istream&, const std::string&, const std::map&)); + MOCK_METHOD3(ImportNetwork, InferenceEngine::ExecutableNetwork( + std::istream&, const InferenceEngine::RemoteContext::Ptr&, const std::map&)); MOCK_QUALIFIED_METHOD3(QueryNetwork, const, InferenceEngine::QueryNetworkResult( const InferenceEngine::CNNNetwork&, const std::string&, const std::map&)); diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_engine/mock_plugin.cpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_engine/mock_plugin.cpp index 8b5acd2a8f9..18b6f7388bf 100644 --- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_engine/mock_plugin.cpp +++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_engine/mock_plugin.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2020 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -22,6 +22,14 @@ void MockPlugin::SetConfig(const std::map& config) { this->config = config; } +Parameter MockPlugin::GetMetric(const std::string& name, const std::map& options) const { + if (_target) { + return _target->GetMetric(name, options); + } else { + THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED); + } +} + ExecutableNetwork MockPlugin::LoadNetwork(const CNNNetwork &network, const std::map &config) { @@ -32,12 +40,62 @@ MockPlugin::LoadNetwork(const CNNNetwork &network, } } -InferenceEngine::ExecutableNetworkInternal::Ptr -MockPlugin::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork& network, +ExecutableNetwork +MockPlugin::LoadNetwork(const CNNNetwork& network, const std::map& config, + RemoteContext::Ptr context) { + if (_target) { + return _target->LoadNetwork(network, config, context); + } else { + THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED); + } +} + +ExecutableNetworkInternal::Ptr +MockPlugin::LoadExeNetworkImpl(const CNNNetwork& network, const std::map& config) { return {}; } +InferenceEngine::ExecutableNetwork +MockPlugin::ImportNetworkImpl(std::istream& networkModel, + const std::map& config) { + if (_target) { + return _target->ImportNetwork(networkModel, config); + } else { + THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED); + } +} + +InferenceEngine::ExecutableNetwork +MockPlugin::ImportNetworkImpl(std::istream& networkModel, + const InferenceEngine::RemoteContext::Ptr& context, + const std::map& config) { + if (_target) { + return _target->ImportNetwork(networkModel, context, config); + } else { + THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED); + } +} + +InferenceEngine::RemoteContext::Ptr MockPlugin::GetDefaultContext(const InferenceEngine::ParamMap& params) { + if (_target) { + return _target->GetDefaultContext(params); + } else { + THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED); + } +} + +InferenceEngine::QueryNetworkResult +MockPlugin::QueryNetwork(const InferenceEngine::CNNNetwork& network, + const std::map& config) const { + if (_target) { + return _target->QueryNetwork(network, config); + } else { + THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED); + } +} + + InferenceEngine::IInferencePlugin *__target = nullptr; INFERENCE_PLUGIN_API(void) CreatePluginEngine(std::shared_ptr& plugin) { diff --git a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_engine/mock_plugin.hpp b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_engine/mock_plugin.hpp index 1015f6a5a54..73254a4e51d 100644 --- a/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_engine/mock_plugin.hpp +++ b/inference-engine/tests/ie_test_utils/unit_test_utils/mocks/mock_engine/mock_plugin.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2020 Intel Corporation +// Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -17,12 +17,34 @@ public: explicit MockPlugin(InferenceEngine::IInferencePlugin*target); void SetConfig(const std::map& config) override; + InferenceEngine::ExecutableNetwork LoadNetwork(const InferenceEngine::CNNNetwork &network, const std::map &config) override; + + InferenceEngine::ExecutableNetwork + LoadNetwork(const InferenceEngine::CNNNetwork& network, + const std::map& config, + InferenceEngine::RemoteContext::Ptr context) override; + InferenceEngine::ExecutableNetworkInternal::Ptr LoadExeNetworkImpl(const InferenceEngine::CNNNetwork& network, const std::map& config) override; + InferenceEngine::ExecutableNetwork ImportNetworkImpl(std::istream& networkModel, + const std::map& config) override; + + InferenceEngine::ExecutableNetwork ImportNetworkImpl(std::istream& networkModel, + const InferenceEngine::RemoteContext::Ptr& context, + const std::map& config) override; + + InferenceEngine::Parameter GetMetric(const std::string& name, + const std::map& options) const override; + + InferenceEngine::RemoteContext::Ptr GetDefaultContext(const InferenceEngine::ParamMap& params) override; + + InferenceEngine::QueryNetworkResult QueryNetwork(const InferenceEngine::CNNNetwork& network, + const std::map& config) const override; + std::map config; }; diff --git a/inference-engine/tests/unit/inference_engine/ie_compilation_context_test.cpp b/inference-engine/tests/unit/inference_engine/ie_compilation_context_test.cpp new file mode 100644 index 00000000000..96808f2ccbc --- /dev/null +++ b/inference-engine/tests/unit/inference_engine/ie_compilation_context_test.cpp @@ -0,0 +1,388 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include + +#include "compilation_context.hpp" +#include "ngraph/function.hpp" +#include "ngraph/ops.hpp" +#include "ngraph/variant.hpp" +#include "ngraph/opsets/opset6.hpp" +#include "transformations/rt_info/dequantization_attribute.hpp" +#include "transformations/rt_info/fused_names_attribute.hpp" +#include "transformations/rt_info/primitives_priority_attribute.hpp" +#include "cpp/ie_cnn_network.h" + +#include "common_test_utils/test_constants.hpp" + +using namespace InferenceEngine; +using namespace ngraph; +using namespace ::testing; +using namespace std::chrono; + +static std::string generateTestFilePrefix() { + // Generate unique file names based on test name, thread id and timestamp + // This allows execution of tests in parallel (stress mode) + auto testInfo = UnitTest::GetInstance()->current_test_info(); + std::string testName = testInfo->test_case_name(); + testName += testInfo->name(); + testName = std::to_string(std::hash()(testName)); + std::stringstream ss; + auto ts = duration_cast(high_resolution_clock::now().time_since_epoch()); + ss << testName << "_" << std::this_thread::get_id() << "_" << ts.count(); + testName = ss.str(); + return testName; +} + +class FileGuard { + std::string m_fileName; +public: + FileGuard(const std::string& name): m_fileName(name) {} + ~FileGuard() { std::remove(m_fileName.c_str()); } +}; + +class NetworkContext_CalcFileInfoTests : public Test { +public: + std::string m_fileName = "test.blob"; + + static void createFile(const std::string& fileName, std::size_t size = 1) { + std::ofstream str(fileName, std::ios::binary); + if (!str.good()) { + GTEST_SKIP(); + } + for (std::size_t i = 0; i < size; i++) + str.put('a'); + } + + // Sets up the test fixture. + void SetUp() override { + auto testName = generateTestFilePrefix(); + m_fileName = testName + m_fileName; + createFile(m_fileName); + } + + // Tears down the test fixture. + void TearDown() override { + std::remove(m_fileName.c_str()); + } +}; + +TEST_F(NetworkContext_CalcFileInfoTests, NoFile) { + ASSERT_NE(NetworkCompilationContext::calculateFileInfo("notexisting.abc"), + NetworkCompilationContext::calculateFileInfo("notexisting2.abc")); + + std::string fileName(100, 'a'); + std::string fileName2(fileName); + ASSERT_EQ(NetworkCompilationContext::calculateFileInfo(fileName), + NetworkCompilationContext::calculateFileInfo(fileName2)); +} + +TEST_F(NetworkContext_CalcFileInfoTests, ExistingFile) { + ASSERT_EQ(NetworkCompilationContext::calculateFileInfo(m_fileName), + NetworkCompilationContext::calculateFileInfo(m_fileName)); +} + +TEST_F(NetworkContext_CalcFileInfoTests, ExistingDiffFiles) { + auto hash1 = NetworkCompilationContext::calculateFileInfo(m_fileName); + std::string newName = m_fileName + "2"; + std::rename(m_fileName.c_str(), newName.c_str()); + m_fileName = std::move(newName); + auto hash2 = NetworkCompilationContext::calculateFileInfo(m_fileName); + ASSERT_NE(hash1, hash2); +} + +TEST_F(NetworkContext_CalcFileInfoTests, ExistingFile_sameAbsPath) { + std::string file1 = m_fileName; + std::string file2 = std::string(".") + CommonTestUtils::FileSeparator + m_fileName; + ASSERT_EQ(NetworkCompilationContext::calculateFileInfo(file1), + NetworkCompilationContext::calculateFileInfo(file2)) << + "Hash of [" << file1 << "] is not equal to hash of [" << file2 << "]"; +} + +TEST_F(NetworkContext_CalcFileInfoTests, DateModified) { + auto info1 = NetworkCompilationContext::calculateFileInfo(m_fileName); + std::this_thread::sleep_for(std::chrono::seconds(2)); + createFile(m_fileName); + auto info2 = NetworkCompilationContext::calculateFileInfo(m_fileName); + ASSERT_NE(info1, info2); +} + +TEST_F(NetworkContext_CalcFileInfoTests, SizeModified) { + createFile(m_fileName, 1); + auto info1 = NetworkCompilationContext::calculateFileInfo(m_fileName); + createFile(m_fileName, 2); + auto info2 = NetworkCompilationContext::calculateFileInfo(m_fileName); + ASSERT_NE(info1, info2); +} + +//////////////////////////////////////////////////// + +static std::shared_ptr create_simple_function() { + // This example is taken from docs, shows how to create ngraph::Function + // + // Parameter--->Multiply--->Add--->Result + // Constant---' / + // Constant---' + + // Create opset6::Parameter operation with static shape + auto data = std::make_shared(ngraph::element::i8, ngraph::Shape{3, 1, 2}); + data->set_friendly_name("Parameter"); + + auto mul_constant = ngraph::opset6::Constant::create(ngraph::element::i8, ngraph::Shape{1}, {3}); + mul_constant->set_friendly_name("mul_constant"); + auto mul = std::make_shared(data, mul_constant); + mul->set_friendly_name("mul"); + + auto add_constant = ngraph::opset6::Constant::create(ngraph::element::i8, ngraph::Shape{1}, {2}); + add_constant->set_friendly_name("add_constant"); + auto add = std::make_shared(mul, add_constant); + add->set_friendly_name("add"); + + // Create opset3::Result operation + auto res = std::make_shared(add); + res->set_friendly_name("res"); + + // Create nGraph function + auto func = std::make_shared(ngraph::ResultVector{res}, ngraph::ParameterVector{data}); + func->set_friendly_name("function"); + return func; +} + +static CNNNetwork createNetwork() { + CNNNetwork res(create_simple_function()); + return res; +} + +static void checkCustomRt(std::function emptyCb, + std::function nameCb) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + auto & op1 = net1.getFunction()->get_ops().front()->get_rt_info(); + auto & op2 = net2.getFunction()->get_ops().front()->get_rt_info(); + + emptyCb(op2); + ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}), + NetworkCompilationContext::computeHash(net2, {})); + + emptyCb(op1); + ASSERT_EQ(NetworkCompilationContext::computeHash(net1, {}), + NetworkCompilationContext::computeHash(net2, {})); + + nameCb(op1, "test"); + ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}), + NetworkCompilationContext::computeHash(net2, {})); + + nameCb(op2, "test"); + ASSERT_EQ(NetworkCompilationContext::computeHash(net1, {}), + NetworkCompilationContext::computeHash(net2, {})); + + nameCb(op1, "test2"); + ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}), + NetworkCompilationContext::computeHash(net2, {})); +} + +TEST(NetworkContext_CNNNetwork, HashOfSame) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + ASSERT_EQ(NetworkCompilationContext::computeHash(net1, {}), + NetworkCompilationContext::computeHash(net2, {})); +} + +TEST(NetworkContext_CNNNetwork, HashWithConfig) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + ASSERT_NE(NetworkCompilationContext::computeHash(net1, {{"key", "value"}}), + NetworkCompilationContext::computeHash(net2, {})); + ASSERT_EQ(NetworkCompilationContext::computeHash(net1, {{"key", "value"}}), + NetworkCompilationContext::computeHash(net2, {{"key", "value"}})); +} + +TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriority) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + auto net3 = createNetwork(); + auto & op2 = net2.getFunction()->get_ops().front()->get_rt_info(); + op2["PrimitivesPriority"] = std::make_shared > ("testPriority"); + + auto & op3 = net3.getFunction()->get_ops().front()->get_rt_info(); + op3["PrimitivesPriority"] = std::make_shared > ("testPriority"); + + ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}), + NetworkCompilationContext::computeHash(net2, {})); + + ASSERT_EQ(NetworkCompilationContext::computeHash(net2, {}), + NetworkCompilationContext::computeHash(net3, {})); +} + +TEST(NetworkContext_CNNNetwork, HashWithDequantization) { + auto setDeqEmpty = [&](Node::RTMap& rtInfo) { + rtInfo[VariantWrapper::type_info.name] = + std::make_shared>(DequantizationAttr()); + }; + auto setDeq = [&](Node::RTMap& rtInfo, const std::string& name) { + rtInfo[VariantWrapper::type_info.name] = + std::make_shared>(DequantizationAttr(name)); + }; + checkCustomRt(setDeqEmpty, setDeq); +} + +TEST(NetworkContext_CNNNetwork, HashWithFusedNames) { + auto setFusedEmpty = [&](Node::RTMap& rtInfo) { + rtInfo[VariantWrapper::type_info.name] = + std::make_shared>(FusedNames()); + }; + auto setFused = [&](Node::RTMap& rtInfo, const std::string& name) { + rtInfo[VariantWrapper::type_info.name] = + std::make_shared>(FusedNames(name)); + }; + checkCustomRt(setFusedEmpty, setFused); +} + +TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriorityType) { + auto setPrimEmpty = [&](Node::RTMap& rtInfo) { + rtInfo[VariantWrapper::type_info.name] = + std::make_shared>(PrimitivesPriority()); + }; + auto setPrim = [&](Node::RTMap& rtInfo, const std::string& name) { + rtInfo[VariantWrapper::type_info.name] = + std::make_shared>(PrimitivesPriority(name)); + }; + checkCustomRt(setPrimEmpty, setPrim); +} + +TEST(NetworkContext_CNNNetwork, HashWithAffinity) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + auto net3 = createNetwork(); + auto & op2 = net2.getFunction()->get_ops().front()->get_rt_info(); + op2["affinity"] = std::make_shared>("testAffinity"); + + auto & op3 = net3.getFunction()->get_ops().front()->get_rt_info(); + op3["affinity"] = std::make_shared>("testAffinity"); + + ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}), + NetworkCompilationContext::computeHash(net2, {})); + + ASSERT_EQ(NetworkCompilationContext::computeHash(net2, {}), + NetworkCompilationContext::computeHash(net3, {})); +} + +TEST(NetworkContext_CNNNetwork, HashWithFutureRt_string) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + auto net3 = createNetwork(); + + auto & op1 = net1.getFunction()->get_ops().front()->get_rt_info(); + op1["someFutureKey"] = std::make_shared>("hello"); + + auto & op2 = net2.getFunction()->get_ops().front()->get_rt_info(); + op2["someFutureKey"] = std::make_shared>("hello"); + + auto & op3 = net3.getFunction()->get_ops().front()->get_rt_info(); + op3["someFutureKey"] = std::make_shared>("olleh"); + + ASSERT_EQ(NetworkCompilationContext::computeHash(net1, {}), + NetworkCompilationContext::computeHash(net2, {})); + + ASSERT_NE(NetworkCompilationContext::computeHash(net2, {}), + NetworkCompilationContext::computeHash(net3, {})); +} + +TEST(NetworkContext_CNNNetwork, HashWithFutureRt_int64) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + auto net3 = createNetwork(); + + auto & op1 = net1.getFunction()->get_ops().front()->get_rt_info(); + op1["someFutureKey"] = std::make_shared>(42); + + auto & op2 = net2.getFunction()->get_ops().front()->get_rt_info(); + op2["someFutureKey"] = std::make_shared>(42); + + auto & op3 = net3.getFunction()->get_ops().front()->get_rt_info(); + op3["someFutureKey"] = std::make_shared>(43); + + ASSERT_EQ(NetworkCompilationContext::computeHash(net1, {}), + NetworkCompilationContext::computeHash(net2, {})); + + ASSERT_NE(NetworkCompilationContext::computeHash(net2, {}), + NetworkCompilationContext::computeHash(net3, {})); +} + +TEST(NetworkContext_CNNNetwork, HashWithDifferentResults) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + net2.getFunction()->remove_result(net2.getFunction()->get_results().front()); + auto net3 = createNetwork(); + net3.getFunction()->remove_result(net3.getFunction()->get_results().front()); + ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}), + NetworkCompilationContext::computeHash(net2, {})); + ASSERT_EQ(NetworkCompilationContext::computeHash(net2, {}), + NetworkCompilationContext::computeHash(net3, {})); +} + +TEST(NetworkContext_CNNNetwork, HashWithDifferentMeanValues) { + auto updatePreprocess = [&](CNNNetwork& cnnNet) { + auto &preProcess = cnnNet.getInputsInfo().begin()->second->getPreProcess(); + preProcess.init(3); + preProcess[0]->stdScale = 2; + preProcess[1]->stdScale = 3; + preProcess[2]->stdScale = 4; + preProcess[0]->meanValue = 0; + preProcess[1]->meanValue = 1; + preProcess[2]->meanValue = 2; + preProcess.setVariant(InferenceEngine::MEAN_VALUE); + }; + auto net1 = createNetwork(); + auto net2 = createNetwork(); + updatePreprocess(net2); + auto net3 = createNetwork(); + updatePreprocess(net3); + ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}), + NetworkCompilationContext::computeHash(net2, {})); + ASSERT_EQ(NetworkCompilationContext::computeHash(net2, {}), + NetworkCompilationContext::computeHash(net3, {})); +} + +//////////////////////////////////////////// + +TEST(NetworkContext_ModelName, HashOfSame) { + ASSERT_EQ(NetworkCompilationContext::computeHash("model1", {}), + NetworkCompilationContext::computeHash("model1", {})); + + ASSERT_NE(NetworkCompilationContext::computeHash("model1", {}), + NetworkCompilationContext::computeHash("model2", {})); + + ASSERT_NE(NetworkCompilationContext::computeHash("model1", {{"key", "value"}}), + NetworkCompilationContext::computeHash("model1", {})); + + ASSERT_EQ(NetworkCompilationContext::computeHash("model1", {{"key", "value"}}), + NetworkCompilationContext::computeHash("model1", {{"key", "value"}})); +} + +TEST(NetworkContext_ModelName, HashOfExistingFile) { + auto file1 = generateTestFilePrefix() + ".xml"; + auto file2 = std::string(".") + CommonTestUtils::FileSeparator + file1; + + FileGuard guard(file1); + { + std::ofstream os(file1); + os << "test"; + } + ASSERT_EQ(NetworkCompilationContext::computeHash(file1, {}), + NetworkCompilationContext::computeHash(file1, {})); + + ASSERT_EQ(NetworkCompilationContext::computeHash(file1, {}), + NetworkCompilationContext::computeHash(file2, {})); + + ASSERT_NE(NetworkCompilationContext::computeHash(file1, {{"key", "value"}}), + NetworkCompilationContext::computeHash(file2, {})); + + ASSERT_EQ(NetworkCompilationContext::computeHash(file1, {{"key", "value"}}), + NetworkCompilationContext::computeHash(file2, {{"key", "value"}})); +}