Model caching support - Core part (#4661)
* Model caching support - Core part Introducing model caching support Use core.SetConfig({{CONFIG_KEY(CACHE_DIR), <dir>}}); to enable caching of models OpenVINO will try to create caching folder if it doesn't exist, but it is recommended for client to create caching folder with necessary permissions before enabling cache in config For caching, plugins shall support import/export functionality Plugin requirements: - Add METRIC_KEY(IMPORT_EXPORT_SUPPORT) in SUPPORTED_METRICS to support caching If plugin has different device architectures with different caches, i.e. For "GNA.0" - one cache, for "GNA.10" - another cache In this case plugin shall support DEVICE_ARCHITECTURE metric and return different strings for different DEVICE_ID's Added functional tests * Fix CentOS build issues * Few updates according to code review * Revert unnecessary changes for Import/Export core implementation These changes affect old behavior and may be undesired For caching support these is no need to change anything in this area If needed, such removal of 'Magic' usage can be done under separate task in future * More tests: 1) Verify that Imported data from stream is the same as was exported 2) Verify that cache is not loaded when config in LoadNetwork is changed 3) Verify that if CNN Network is changed between ReadNetwork and LoadNetwork - cache is not loaded * Update of NetworkCompilationContext Put back functionality of calculating hash based on runtime information, weights Implemented OstreamHashWrapper to avoid serialization to buffer * Correction of CACHE_DIR key description * Unit tests for compilation_context Changes: 1) Improved handling of OstreamHashAdapter 2) Improved runtime info serialization (not just PrimitivesPriority and affinity) 3) Removed redundant weights hash calculation * Fix GCC 4.8 build issues * Compilation context updates 1) Use hash of sum of serialized data to get hash of network. It is more efficient comparing to weights sum calculation 2) CalculateFileInfo - convert path to absolute ("./test.blob" and "test.blob" shall give same hash) * Hash - added more rt_info attributes + tests - PrimitivesPriority - FusedNames - Dequantization * Moved "get_absolute_path" macro to file_utils.h * Make 'absoluteFilePath' a library API, not macro * One more unit test for fileName hashing * Fix compilation error after merge with latest master * Allow tests to be executed in parallel (stress mode) * More minor updates for stress testing Now it allows to execute tests with '--repeat=100' option where one test is executed in multiple processes simultaneously Example: ./gtest-parallel <openvino_dir>/bin/intel64/Debug/ieFuncTests --gtest_filter=CachingTest* --repeat=10 * Use absolute model file path for calculating blob name * Added 'createDirectoryRecursive' API to plugin_api/file_utils
This commit is contained in:
parent
0518840630
commit
7f9daadc08
@ -1,4 +1,4 @@
|
|||||||
// Copyright (C) 2018-2020 Intel Corporation
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
@ -108,6 +108,23 @@ public:
|
|||||||
const CNNNetwork& network, const std::string& deviceName,
|
const CNNNetwork& network, const std::string& deviceName,
|
||||||
const std::map<std::string, std::string>& config = {});
|
const std::map<std::string, std::string>& config = {});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Reads model and creates an executable network from IR or ONNX file
|
||||||
|
*
|
||||||
|
* This can be more efficient than using ReadNetwork + LoadNetwork(CNNNetwork) flow
|
||||||
|
* especially for cases when caching is enabled and cached model is available
|
||||||
|
*
|
||||||
|
* @param modelPath path to model
|
||||||
|
* @param deviceName Name of device to load network to
|
||||||
|
* @param config Optional map of pairs: (config parameter name, config parameter value) relevant only for this load
|
||||||
|
* operation/
|
||||||
|
*
|
||||||
|
* @return An executable network reference
|
||||||
|
*/
|
||||||
|
ExecutableNetwork LoadNetwork(
|
||||||
|
const std::string& modelPath, const std::string& deviceName,
|
||||||
|
const std::map<std::string, std::string>& config = {});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Registers extension
|
* @brief Registers extension
|
||||||
* @param extension Pointer to already loaded extension
|
* @param extension Pointer to already loaded extension
|
||||||
@ -137,8 +154,8 @@ public:
|
|||||||
/**
|
/**
|
||||||
* @brief Creates an executable network from a previously exported network
|
* @brief Creates an executable network from a previously exported network
|
||||||
*
|
*
|
||||||
* @param deviceName Name of device load executable network on
|
|
||||||
* @param modelFileName Path to the location of the exported file
|
* @param modelFileName Path to the location of the exported file
|
||||||
|
* @param deviceName Name of device load executable network on
|
||||||
* @param config Optional map of pairs: (config parameter name, config parameter value) relevant only for this load
|
* @param config Optional map of pairs: (config parameter name, config parameter value) relevant only for this load
|
||||||
* operation*
|
* operation*
|
||||||
* @return An executable network reference
|
* @return An executable network reference
|
||||||
@ -149,8 +166,8 @@ public:
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Creates an executable network from a previously exported network
|
* @brief Creates an executable network from a previously exported network
|
||||||
* @param deviceName Name of device load executable network on
|
|
||||||
* @param networkModel network model stream
|
* @param networkModel network model stream
|
||||||
|
* @param deviceName Name of device load executable network on
|
||||||
* @param config Optional map of pairs: (config parameter name, config parameter value) relevant only for this load
|
* @param config Optional map of pairs: (config parameter name, config parameter value) relevant only for this load
|
||||||
* operation*
|
* operation*
|
||||||
* @return An executable network reference
|
* @return An executable network reference
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright (C) 2018-2020 Intel Corporation
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
@ -143,6 +143,16 @@ DECLARE_METRIC_KEY(NUMBER_OF_WAITING_INFER_REQUESTS, unsigned int);
|
|||||||
*/
|
*/
|
||||||
DECLARE_METRIC_KEY(NUMBER_OF_EXEC_INFER_REQUESTS, unsigned int);
|
DECLARE_METRIC_KEY(NUMBER_OF_EXEC_INFER_REQUESTS, unsigned int);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Metric which defines the device architecture.
|
||||||
|
*/
|
||||||
|
DECLARE_METRIC_KEY(DEVICE_ARCHITECTURE, std::string);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Metric which defines support of import/export functionality by plugin
|
||||||
|
*/
|
||||||
|
DECLARE_METRIC_KEY(IMPORT_EXPORT_SUPPORT, bool);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Metric to get a name of network. String value is "NETWORK_NAME".
|
* @brief Metric to get a name of network. String value is "NETWORK_NAME".
|
||||||
*/
|
*/
|
||||||
@ -365,13 +375,20 @@ DECLARE_CONFIG_KEY(ENFORCE_BF16);
|
|||||||
/**
|
/**
|
||||||
* @brief This key defines the directory which will be used to store any data cached by plugins.
|
* @brief This key defines the directory which will be used to store any data cached by plugins.
|
||||||
*
|
*
|
||||||
* This key supports unicode symbols in path
|
|
||||||
* The underlying cache structure is not defined and might differ between OpenVINO releases
|
* The underlying cache structure is not defined and might differ between OpenVINO releases
|
||||||
* Cached data might be platform / device specific and might be invalid after OpenVINO version change
|
* Cached data might be platform / device specific and might be invalid after OpenVINO version change
|
||||||
* If this key is not specified or value is empty string, then caching is disabled.
|
* If this key is not specified or value is empty string, then caching is disabled.
|
||||||
* The key might enable caching for all plugin or some specific ones, e.g.:
|
* The key might enable caching for the plugin using the following code:
|
||||||
* ie.SetConfig({{CONFIG_KEY(CACHE_DIR), "cache/"}}) - enables cache for all plugins that might want to use it
|
*
|
||||||
* ie.SetConfig({{CONFIG_KEY(CACHE_DIR), "cache/"}}, {"GPU"}) - enables cache only for GPU plugin
|
* @code
|
||||||
|
* ie.SetConfig({{CONFIG_KEY(CACHE_DIR), "cache/"}}, "GPU"); // enables cache for GPU plugin
|
||||||
|
* @endcode
|
||||||
|
*
|
||||||
|
* The following code enables caching of compiled network blobs for devices where import/export is supported
|
||||||
|
*
|
||||||
|
* @code
|
||||||
|
* ie.SetConfig({{CONFIG_KEY(CACHE_DIR), "cache/"}}); // enables models cache
|
||||||
|
* @endcode
|
||||||
*/
|
*/
|
||||||
DECLARE_CONFIG_KEY(CACHE_DIR);
|
DECLARE_CONFIG_KEY(CACHE_DIR);
|
||||||
|
|
||||||
|
219
inference-engine/src/inference_engine/compilation_context.cpp
Normal file
219
inference-engine/src/inference_engine/compilation_context.cpp
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
#include "compilation_context.hpp"
|
||||||
|
|
||||||
|
#include <sys/types.h>
|
||||||
|
#include <sys/stat.h>
|
||||||
|
|
||||||
|
#ifndef WIN32
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
#include <xml_parse_utils.h>
|
||||||
|
|
||||||
|
#include "ie_itt.hpp"
|
||||||
|
#include "cpp_interfaces/exception2status.hpp"
|
||||||
|
#include "transformations/serialize.hpp"
|
||||||
|
#include "cpp/ie_cnn_network.h"
|
||||||
|
#include "details/ie_exception.hpp"
|
||||||
|
|
||||||
|
#include "ngraph/variant.hpp"
|
||||||
|
#include "ngraph/opsets/opset6.hpp"
|
||||||
|
#include "transformations/rt_info/dequantization_attribute.hpp"
|
||||||
|
#include "transformations/rt_info/fused_names_attribute.hpp"
|
||||||
|
#include "transformations/rt_info/primitives_priority_attribute.hpp"
|
||||||
|
#include "file_utils.h"
|
||||||
|
|
||||||
|
#ifdef WIN32
|
||||||
|
#define stat _stat
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace InferenceEngine {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static std::size_t hash_combine(std::size_t seed, const T& a) {
|
||||||
|
// Hash combine formula from boost
|
||||||
|
return seed ^ (std::hash<T>()(a) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static int32_t as_int32_t(T v) {
|
||||||
|
return static_cast<int32_t>(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
class OstreamHashWrapper final: public std::streambuf {
|
||||||
|
std::size_t m_res = {};
|
||||||
|
public:
|
||||||
|
std::size_t getResult() const { return m_res; }
|
||||||
|
std::streamsize xsputn(const char* s, std::streamsize n) override {
|
||||||
|
const std::int64_t* intS = (const std::int64_t *)s;
|
||||||
|
std::streamsize n64 = n / sizeof(std::int64_t);
|
||||||
|
std::streamsize i = 0;
|
||||||
|
// Using 64-bit values executes much faster than char
|
||||||
|
while (i++ < n64) {
|
||||||
|
m_res += *(intS++);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::streamsize rest = n % sizeof(std::int64_t);
|
||||||
|
for (i = 0; i < rest; i++) {
|
||||||
|
m_res += s[n - rest + i];
|
||||||
|
}
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////
|
||||||
|
|
||||||
|
std::string NetworkCompilationContext::calculateFileInfo(const std::string& filePath) {
|
||||||
|
size_t seed {};
|
||||||
|
auto absPath = filePath;
|
||||||
|
try {
|
||||||
|
absPath = FileUtils::absoluteFilePath(filePath);
|
||||||
|
} catch (...) {
|
||||||
|
// can't get absolute path, will use filePath for hash
|
||||||
|
}
|
||||||
|
|
||||||
|
seed = hash_combine(seed, absPath);
|
||||||
|
|
||||||
|
std::string res;
|
||||||
|
struct stat result;
|
||||||
|
if (stat(absPath.c_str(), &result) == 0) {
|
||||||
|
seed = hash_combine(seed, result.st_mtime);
|
||||||
|
seed = hash_combine(seed, result.st_size);
|
||||||
|
}
|
||||||
|
return std::to_string(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string NetworkCompilationContext::computeHash(const CNNNetwork& network,
|
||||||
|
const std::map<std::string, std::string>& compileOptions) {
|
||||||
|
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "NetworkCompilationContext::computeHash - CNN");
|
||||||
|
OstreamHashWrapper xmlHash;
|
||||||
|
OstreamHashWrapper binHash;
|
||||||
|
std::ostream xml(&xmlHash);
|
||||||
|
std::ostream bin(&binHash);
|
||||||
|
|
||||||
|
IE_ASSERT(network.getFunction());
|
||||||
|
|
||||||
|
// 1. Serialize
|
||||||
|
CNNNetwork net(network);
|
||||||
|
ngraph::pass::Serialize serializer(xml, bin,
|
||||||
|
ngraph::pass::Serialize::Version::IR_V10);
|
||||||
|
serializer.run_on_function(net.getFunction());
|
||||||
|
|
||||||
|
// 2. Compute hash on serialized data and options
|
||||||
|
size_t seed {};
|
||||||
|
seed = hash_combine(seed, xmlHash.getResult());
|
||||||
|
seed = hash_combine(seed, binHash.getResult());
|
||||||
|
|
||||||
|
for (const auto& kvp : compileOptions) {
|
||||||
|
seed = hash_combine(seed, kvp.first + kvp.second);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Add runtime information which may not be serialized
|
||||||
|
for (const auto& op : network.getFunction()->get_ordered_ops()) {
|
||||||
|
const auto& rt = op->get_rt_info();
|
||||||
|
for (const auto& rtMapData : rt) {
|
||||||
|
seed = hash_combine(seed, rtMapData.first);
|
||||||
|
|
||||||
|
if (auto stringData = std::dynamic_pointer_cast<ngraph::VariantWrapper<std::string>>(rtMapData.second)) {
|
||||||
|
seed = hash_combine(seed, stringData->get());
|
||||||
|
} else if (auto intData = std::dynamic_pointer_cast<ngraph::VariantWrapper<std::int64_t>>(rtMapData.second)) {
|
||||||
|
seed = hash_combine(seed, intData->get());
|
||||||
|
} else if (auto deq = std::dynamic_pointer_cast<ngraph::VariantWrapper<ngraph::DequantizationAttr>>(rtMapData.second)) {
|
||||||
|
seed = hash_combine(seed, deq->get().getDequantizationAttr());
|
||||||
|
} else if (auto fNames = std::dynamic_pointer_cast<ngraph::VariantWrapper<ngraph::FusedNames>>(rtMapData.second)) {
|
||||||
|
seed = hash_combine(seed, fNames->get().getNames());
|
||||||
|
} else if (auto prim = std::dynamic_pointer_cast<ngraph::VariantWrapper<ngraph::PrimitivesPriority>>(rtMapData.second)) {
|
||||||
|
seed = hash_combine(seed, prim->get().getPrimitivesPriority());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Add inputs info
|
||||||
|
for (const auto& input : network.getInputsInfo()) {
|
||||||
|
InputInfo::Ptr info = input.second;
|
||||||
|
seed = hash_combine(seed, as_int32_t(info->getPrecision()));
|
||||||
|
seed = hash_combine(seed, as_int32_t(info->getLayout()));
|
||||||
|
|
||||||
|
const InferenceEngine::PreProcessInfo& preproc = info->getPreProcess();
|
||||||
|
seed = hash_combine(seed, as_int32_t(preproc.getMeanVariant()));
|
||||||
|
|
||||||
|
if (preproc.getMeanVariant() == MeanVariant::MEAN_VALUE) {
|
||||||
|
seed = hash_combine(seed, preproc.getNumberOfChannels());
|
||||||
|
for (size_t c = 0; c < preproc.getNumberOfChannels(); ++c) {
|
||||||
|
const PreProcessChannel::Ptr & channelInfo = preproc[c];
|
||||||
|
seed = hash_combine(seed, channelInfo->stdScale);
|
||||||
|
seed = hash_combine(seed, channelInfo->meanValue);
|
||||||
|
}
|
||||||
|
} else if (preproc.getMeanVariant() == MeanVariant::MEAN_IMAGE) {
|
||||||
|
// TODO: think if we need to compute hash for mean image if it exists
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Add outputs info
|
||||||
|
for (const auto& output : network.getOutputsInfo()) {
|
||||||
|
DataPtr info = output.second;
|
||||||
|
seed = hash_combine(seed, as_int32_t(info->getPrecision()));
|
||||||
|
seed = hash_combine(seed, as_int32_t(info->getLayout()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::to_string(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string NetworkCompilationContext::computeHash(const std::string& modelName,
|
||||||
|
const std::map<std::string, std::string>& compileOptions) {
|
||||||
|
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "NetworkCompilationContext::computeHash - ModelName");
|
||||||
|
size_t seed {};
|
||||||
|
try {
|
||||||
|
seed = hash_combine(seed, FileUtils::absoluteFilePath(modelName));
|
||||||
|
} catch (...) {
|
||||||
|
// can't get absolute path, use modelName for hash calculation
|
||||||
|
seed = hash_combine(seed, modelName);
|
||||||
|
}
|
||||||
|
for (const auto& kvp : compileOptions) {
|
||||||
|
seed = hash_combine(seed, kvp.first + kvp.second);
|
||||||
|
}
|
||||||
|
return std::to_string(seed);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////
|
||||||
|
|
||||||
|
CompiledBlobHeader::CompiledBlobHeader() {}
|
||||||
|
|
||||||
|
CompiledBlobHeader::CompiledBlobHeader(const std::string& ieVersion, const std::string& fileInfo) :
|
||||||
|
m_ieVersion(ieVersion),
|
||||||
|
m_fileInfo(fileInfo) {
|
||||||
|
}
|
||||||
|
|
||||||
|
std::istream& operator >> (std::istream& stream, CompiledBlobHeader& header) {
|
||||||
|
std::string xmlStr;
|
||||||
|
std::getline(stream, xmlStr);
|
||||||
|
|
||||||
|
pugi::xml_document document;
|
||||||
|
pugi::xml_parse_result res = document.load_string(xmlStr.c_str());
|
||||||
|
|
||||||
|
if (res.status != pugi::status_ok) {
|
||||||
|
THROW_IE_EXCEPTION_WITH_STATUS(NETWORK_NOT_READ) << "Error reading compiled blob header";
|
||||||
|
}
|
||||||
|
|
||||||
|
pugi::xml_node compiledBlobNode = document.document_element();
|
||||||
|
header.m_ieVersion = XMLParseUtils::GetStrAttr(compiledBlobNode, "ie_version");
|
||||||
|
header.m_fileInfo = XMLParseUtils::GetStrAttr(compiledBlobNode, "file_info");
|
||||||
|
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ostream& operator << (std::ostream& stream, const CompiledBlobHeader& header) {
|
||||||
|
pugi::xml_document document;
|
||||||
|
auto compiledBlobNode = document.append_child("compiled_blob");
|
||||||
|
compiledBlobNode.append_attribute("ie_version").set_value(header.m_ieVersion.c_str());
|
||||||
|
compiledBlobNode.append_attribute("file_info").set_value(header.m_fileInfo.c_str());
|
||||||
|
|
||||||
|
document.save(stream, nullptr, pugi::format_raw);
|
||||||
|
document.reset();
|
||||||
|
stream << std::endl;
|
||||||
|
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace InferenceEngine
|
@ -0,0 +1,47 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
#include <istream>
|
||||||
|
#include <ostream>
|
||||||
|
|
||||||
|
namespace InferenceEngine {
|
||||||
|
|
||||||
|
class CNNNetwork;
|
||||||
|
|
||||||
|
struct NetworkCompilationContext final {
|
||||||
|
static std::string calculateFileInfo(const std::string& filePath);
|
||||||
|
|
||||||
|
static std::string computeHash(const CNNNetwork& network,
|
||||||
|
const std::map<std::string, std::string>& compileOptions);
|
||||||
|
|
||||||
|
static std::string computeHash(const std::string& modelName,
|
||||||
|
const std::map<std::string, std::string>& compileOptions);
|
||||||
|
};
|
||||||
|
|
||||||
|
class CompiledBlobHeader final {
|
||||||
|
std::string m_ieVersion;
|
||||||
|
std::string m_fileInfo;
|
||||||
|
|
||||||
|
public:
|
||||||
|
CompiledBlobHeader();
|
||||||
|
CompiledBlobHeader(const std::string& ieVersion, const std::string& fileInfo);
|
||||||
|
|
||||||
|
const std::string& getIeVersion() const {
|
||||||
|
return m_ieVersion;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string& getFileInfo() const {
|
||||||
|
return m_fileInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
friend std::istream & operator >> (std::istream& stream, CompiledBlobHeader& header);
|
||||||
|
|
||||||
|
friend std::ostream & operator << (std::ostream& stream, const CompiledBlobHeader& header);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace InferenceEngine
|
@ -13,6 +13,8 @@
|
|||||||
|
|
||||||
#include <file_utils.h>
|
#include <file_utils.h>
|
||||||
#include <details/ie_exception.hpp>
|
#include <details/ie_exception.hpp>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <sys/stat.h>
|
||||||
|
|
||||||
#ifndef _WIN32
|
#ifndef _WIN32
|
||||||
# include <limits.h>
|
# include <limits.h>
|
||||||
@ -32,6 +34,38 @@
|
|||||||
# include <Windows.h>
|
# include <Windows.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
|
||||||
|
#include <direct.h>
|
||||||
|
|
||||||
|
// Copied from linux libc sys/stat.h:
|
||||||
|
# define S_ISDIR(m) (((m) & S_IFMT) == S_IFDIR)
|
||||||
|
|
||||||
|
/// @brief Windows-specific 'mkdir' wrapper
|
||||||
|
#define makedir(dir) _mkdir(dir)
|
||||||
|
|
||||||
|
/// @brief Max length of absolute file path
|
||||||
|
#define MAX_ABS_PATH _MAX_PATH
|
||||||
|
/// @brief Get absolute file path, returns NULL in case of error
|
||||||
|
#define get_absolute_path(result, path) _fullpath(result, path.c_str(), MAX_ABS_PATH)
|
||||||
|
|
||||||
|
/// @brief Windows-specific 'stat' wrapper
|
||||||
|
#define stat _stat
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
|
/// @brief mkdir wrapper
|
||||||
|
#define makedir(dir) mkdir(dir, 0755)
|
||||||
|
|
||||||
|
/// @brief Max length of absolute file path
|
||||||
|
#define MAX_ABS_PATH PATH_MAX
|
||||||
|
/// @brief Get absolute file path, returns NULL in case of error
|
||||||
|
#define get_absolute_path(result, path) realpath(path.c_str(), result)
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef ENABLE_UNICODE_PATH_SUPPORT
|
#ifdef ENABLE_UNICODE_PATH_SUPPORT
|
||||||
|
|
||||||
std::string FileUtils::wStringtoMBCSstringChar(const std::wstring& wstr) {
|
std::string FileUtils::wStringtoMBCSstringChar(const std::wstring& wstr) {
|
||||||
@ -73,6 +107,44 @@ long long FileUtils::fileSize(const char* charfilepath) {
|
|||||||
return in.tellg();
|
return in.tellg();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string FileUtils::absoluteFilePath(const std::string& filePath) {
|
||||||
|
std::string absolutePath;
|
||||||
|
absolutePath.resize(MAX_ABS_PATH);
|
||||||
|
auto absPath = get_absolute_path(&absolutePath[0], filePath);
|
||||||
|
if (!absPath) {
|
||||||
|
THROW_IE_EXCEPTION << "Can't get absolute file path for [" << filePath << "], err = " << strerror(errno);
|
||||||
|
}
|
||||||
|
absolutePath.resize(strlen(absPath));
|
||||||
|
return absolutePath;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool FileUtils::directoryExists(const std::string &path) {
|
||||||
|
struct stat sb;
|
||||||
|
|
||||||
|
if (stat(path.c_str(), &sb) == 0 && S_ISDIR(sb.st_mode)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void FileUtils::createDirectoryRecursive(const std::string& dirPath) {
|
||||||
|
if (dirPath.empty() || directoryExists(dirPath)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t pos = dirPath.rfind(FileUtils::FileSeparator);
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
createDirectoryRecursive(dirPath.substr(0, pos));
|
||||||
|
}
|
||||||
|
|
||||||
|
int err = makedir(dirPath.c_str());
|
||||||
|
if (err != 0 && errno != EEXIST) {
|
||||||
|
// TODO: in case of exception it may be needed to remove all created sub-directories
|
||||||
|
THROW_IE_EXCEPTION << "Couldn't create directory ["
|
||||||
|
<< dirPath << "], err=" << strerror(errno) << ")";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
namespace InferenceEngine {
|
namespace InferenceEngine {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
121
inference-engine/src/inference_engine/ie_cache_manager.hpp
Normal file
121
inference-engine/src/inference_engine/ie_cache_manager.hpp
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief This is a header file for the Inference Engine Cache Manager class C++ API
|
||||||
|
*
|
||||||
|
* @file ie_cache_manager.hpp
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <fstream>
|
||||||
|
#include <string>
|
||||||
|
#include <functional>
|
||||||
|
#include "ie_api.h"
|
||||||
|
#include "file_utils.h"
|
||||||
|
|
||||||
|
namespace InferenceEngine {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief This class represents private interface for Cache Manager
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class ICacheManager {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Default destructor
|
||||||
|
*/
|
||||||
|
virtual ~ICacheManager() = default;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Function passing created output stream
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
using StreamWriter = std::function<void(std::ostream&)>;
|
||||||
|
/**
|
||||||
|
* @brief Callback when Inference Engine intends to write network to cache
|
||||||
|
*
|
||||||
|
* Client needs to call create std::ostream object and call writer(ostream)
|
||||||
|
* Otherwise, network will not be cached
|
||||||
|
*
|
||||||
|
* @param id Id of cache (hash of the network)
|
||||||
|
* @param writer Lambda function to be called when stream is created
|
||||||
|
*/
|
||||||
|
virtual void writeCacheEntry(const std::string& id, StreamWriter writer) = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Function passing created input stream
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
using StreamReader = std::function<void(std::istream&)>;
|
||||||
|
/**
|
||||||
|
* @brief Callback when Inference Engine intends to read network from cache
|
||||||
|
*
|
||||||
|
* Client needs to call create std::istream object and call reader(istream)
|
||||||
|
* Otherwise, network will not be read from cache and will be loaded as usual
|
||||||
|
*
|
||||||
|
* @param id Id of cache (hash of the network)
|
||||||
|
* @param reader Lambda function to be called when input stream is created
|
||||||
|
*/
|
||||||
|
virtual void readCacheEntry(const std::string& id, StreamReader reader) = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Callback when Inference Engine intends to remove cache entry
|
||||||
|
*
|
||||||
|
* Client needs to perform appropriate cleanup (e.g. delete a cache file)
|
||||||
|
*
|
||||||
|
* @param id Id of cache (hash of the network)
|
||||||
|
*/
|
||||||
|
virtual void removeCacheEntry(const std::string& id) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief File storage-based Implementation of ICacheManager
|
||||||
|
*
|
||||||
|
* Uses simple file for read/write cached models.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
class FileStorageCacheManager final : public ICacheManager {
|
||||||
|
std::string m_cachePath;
|
||||||
|
|
||||||
|
std::string getBlobFile(const std::string& blobHash) const {
|
||||||
|
return FileUtils::makePath(m_cachePath, blobHash + ".blob");
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Constructor
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
FileStorageCacheManager(std::string&& cachePath) : m_cachePath(std::move(cachePath)) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Destructor
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
~FileStorageCacheManager() override = default;
|
||||||
|
|
||||||
|
private:
|
||||||
|
void writeCacheEntry(const std::string& id, StreamWriter writer) override {
|
||||||
|
std::ofstream stream(getBlobFile(id), std::ios_base::binary | std::ofstream::out);
|
||||||
|
writer(stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
void readCacheEntry(const std::string& id, StreamReader reader) override {
|
||||||
|
auto blobFileName = getBlobFile(id);
|
||||||
|
if (FileUtils::fileExist(blobFileName)) {
|
||||||
|
std::ifstream stream(blobFileName, std::ios_base::binary);
|
||||||
|
reader(stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void removeCacheEntry(const std::string& id) override {
|
||||||
|
auto blobFileName = getBlobFile(id);
|
||||||
|
if (FileUtils::fileExist(blobFileName))
|
||||||
|
std::remove(blobFileName.c_str());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace InferenceEngine
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright (C) 2018-2020 Intel Corporation
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
@ -6,8 +6,8 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <istream>
|
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
#include <sys/stat.h>
|
||||||
|
|
||||||
#include <ie_core.hpp>
|
#include <ie_core.hpp>
|
||||||
#include <multi-device/multi_device_config.hpp>
|
#include <multi-device/multi_device_config.hpp>
|
||||||
@ -17,14 +17,17 @@
|
|||||||
#include <ngraph/pass/constant_folding.hpp>
|
#include <ngraph/pass/constant_folding.hpp>
|
||||||
|
|
||||||
#include <cpp_interfaces/exception2status.hpp>
|
#include <cpp_interfaces/exception2status.hpp>
|
||||||
|
#include "compilation_context.hpp"
|
||||||
#include "ie_plugin_cpp.hpp"
|
#include "ie_plugin_cpp.hpp"
|
||||||
#include "ie_plugin_config.hpp"
|
#include "ie_plugin_config.hpp"
|
||||||
|
#include "ie_cache_manager.hpp"
|
||||||
#include "ie_itt.hpp"
|
#include "ie_itt.hpp"
|
||||||
#include "file_utils.h"
|
#include "file_utils.h"
|
||||||
#include "ie_network_reader.hpp"
|
#include "ie_network_reader.hpp"
|
||||||
#include "xml_parse_utils.h"
|
#include "xml_parse_utils.h"
|
||||||
|
|
||||||
using namespace InferenceEngine::PluginConfigParams;
|
using namespace InferenceEngine::PluginConfigParams;
|
||||||
|
using namespace std::placeholders;
|
||||||
|
|
||||||
namespace InferenceEngine {
|
namespace InferenceEngine {
|
||||||
|
|
||||||
@ -158,6 +161,41 @@ class Core::Impl : public ICore {
|
|||||||
|
|
||||||
mutable std::map<std::string, InferencePlugin> plugins;
|
mutable std::map<std::string, InferencePlugin> plugins;
|
||||||
|
|
||||||
|
class CoreConfig final {
|
||||||
|
public:
|
||||||
|
struct CacheConfig {
|
||||||
|
std::shared_ptr<ICacheManager> _cacheManager;
|
||||||
|
};
|
||||||
|
|
||||||
|
void setAndUpdate(std::map<std::string, std::string>& config) {
|
||||||
|
auto it = config.find(CONFIG_KEY(CACHE_DIR));
|
||||||
|
if (it != config.end()) {
|
||||||
|
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
|
||||||
|
if (!it->second.empty()) {
|
||||||
|
FileUtils::createDirectoryRecursive(it->second);
|
||||||
|
_cacheConfig._cacheManager = std::make_shared<FileStorageCacheManager>(std::move(it->second));
|
||||||
|
} else {
|
||||||
|
_cacheConfig._cacheManager = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
config.erase(it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creating thread-safe copy of config including shared_ptr to ICacheManager
|
||||||
|
CacheConfig getCacheConfig() const {
|
||||||
|
std::lock_guard<std::mutex> lock(_cacheConfigMutex);
|
||||||
|
return _cacheConfig;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
mutable std::mutex _cacheConfigMutex;
|
||||||
|
CacheConfig _cacheConfig;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Core settings (cache config, etc)
|
||||||
|
CoreConfig coreConfig;
|
||||||
|
|
||||||
struct PluginDescriptor {
|
struct PluginDescriptor {
|
||||||
FileUtils::FilePath libraryLocation;
|
FileUtils::FilePath libraryLocation;
|
||||||
std::map<std::string, std::string> defaultConfig;
|
std::map<std::string, std::string> defaultConfig;
|
||||||
@ -170,9 +208,141 @@ class Core::Impl : public ICore {
|
|||||||
std::map<std::string, PluginDescriptor> pluginRegistry;
|
std::map<std::string, PluginDescriptor> pluginRegistry;
|
||||||
mutable std::mutex pluginsMutex; // to lock parallel access to pluginRegistry and plugins
|
mutable std::mutex pluginsMutex; // to lock parallel access to pluginRegistry and plugins
|
||||||
|
|
||||||
|
bool DeviceSupportsImportExport(const InferencePlugin& plugin) const {
|
||||||
|
std::vector<std::string> supportedMetricKeys = plugin.GetMetric(METRIC_KEY(SUPPORTED_METRICS), {});
|
||||||
|
auto it = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),
|
||||||
|
METRIC_KEY(IMPORT_EXPORT_SUPPORT));
|
||||||
|
bool supported = (it != supportedMetricKeys.end()) &&
|
||||||
|
plugin.GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), {});
|
||||||
|
return supported;
|
||||||
|
}
|
||||||
|
|
||||||
|
ExecutableNetwork LoadNetworkImpl(const CNNNetwork& network,
|
||||||
|
InferencePlugin& plugin,
|
||||||
|
const std::map<std::string, std::string>& parsedConfig,
|
||||||
|
const RemoteContext::Ptr& context,
|
||||||
|
const std::string& blobID,
|
||||||
|
const std::string& modelPath = std::string()) {
|
||||||
|
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::Impl::LoadNetworkImpl");
|
||||||
|
ExecutableNetwork execNetwork;
|
||||||
|
execNetwork = context ? plugin.LoadNetwork(network, context, parsedConfig) :
|
||||||
|
plugin.LoadNetwork(network, parsedConfig);
|
||||||
|
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||||
|
if (cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||||
|
try {
|
||||||
|
// need to export network for further import from "cache"
|
||||||
|
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::Export");
|
||||||
|
cacheManager->writeCacheEntry(blobID, [&](std::ostream& networkStream) {
|
||||||
|
networkStream << CompiledBlobHeader(GetInferenceEngineVersion()->buildNumber,
|
||||||
|
NetworkCompilationContext::calculateFileInfo(modelPath));
|
||||||
|
execNetwork.Export(networkStream);
|
||||||
|
});
|
||||||
|
} catch (...) {
|
||||||
|
cacheManager->removeCacheEntry(blobID);
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return execNetwork;
|
||||||
|
}
|
||||||
|
|
||||||
|
ExecutableNetwork LoadNetworkFromCache(const std::shared_ptr<ICacheManager>& cacheManager,
|
||||||
|
const std::string& blobId,
|
||||||
|
InferencePlugin& plugin,
|
||||||
|
const std::map<std::string, std::string>& config,
|
||||||
|
const RemoteContext::Ptr& context,
|
||||||
|
bool& networkIsImported,
|
||||||
|
const std::string& modelPath = std::string()) {
|
||||||
|
ExecutableNetwork execNetwork;
|
||||||
|
struct HeaderException {};
|
||||||
|
|
||||||
|
IE_ASSERT(cacheManager != nullptr);
|
||||||
|
try {
|
||||||
|
cacheManager->readCacheEntry(blobId, [&](std::istream &networkStream) {
|
||||||
|
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetworkFromCache::ReadStreamAndImport");
|
||||||
|
try {
|
||||||
|
CompiledBlobHeader header;
|
||||||
|
networkStream >> header;
|
||||||
|
if (header.getIeVersion() != GetInferenceEngineVersion()->buildNumber) {
|
||||||
|
// Build number mismatch, don't use this cache
|
||||||
|
throw NetworkNotRead("Version does not match");
|
||||||
|
}
|
||||||
|
if (header.getFileInfo() != NetworkCompilationContext::calculateFileInfo(modelPath)) {
|
||||||
|
// Original file is changed, don't use cache
|
||||||
|
throw NetworkNotRead("Original model file is changed");
|
||||||
|
}
|
||||||
|
} catch (...) {
|
||||||
|
throw HeaderException();
|
||||||
|
}
|
||||||
|
|
||||||
|
execNetwork = context ?
|
||||||
|
plugin.ImportNetwork(networkStream, context, config) :
|
||||||
|
plugin.ImportNetwork(networkStream, config);
|
||||||
|
networkIsImported = true;
|
||||||
|
});
|
||||||
|
} catch (const HeaderException& ex) {
|
||||||
|
// For these exceptions just remove old cache and set that import didn't work
|
||||||
|
cacheManager->removeCacheEntry(blobId);
|
||||||
|
networkIsImported = false;
|
||||||
|
} catch (...) {
|
||||||
|
cacheManager->removeCacheEntry(blobId);
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
return execNetwork;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<std::string, std::string> CreateCompileConfig(const InferencePlugin& plugin,
|
||||||
|
const std::string& deviceFamily,
|
||||||
|
const std::map<std::string, std::string>& origConfig) const {
|
||||||
|
std::map<std::string, Parameter> getMetricConfig;
|
||||||
|
auto compileConfig = origConfig;
|
||||||
|
|
||||||
|
// 0. remove DEVICE_ID key
|
||||||
|
auto deviceIt = compileConfig.find(CONFIG_KEY(DEVICE_ID));
|
||||||
|
if (deviceIt != compileConfig.end()) {
|
||||||
|
getMetricConfig[deviceIt->first] = deviceIt->second;
|
||||||
|
compileConfig.erase(deviceIt);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. replace it with DEVICE_ARCHITECTURE value
|
||||||
|
std::vector<std::string> supportedMetricKeys =
|
||||||
|
plugin.GetMetric(METRIC_KEY(SUPPORTED_METRICS), getMetricConfig);
|
||||||
|
auto archIt = std::find(supportedMetricKeys.begin(), supportedMetricKeys.end(),
|
||||||
|
METRIC_KEY(DEVICE_ARCHITECTURE));
|
||||||
|
if (archIt != supportedMetricKeys.end()) {
|
||||||
|
auto value = plugin.GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), getMetricConfig);
|
||||||
|
compileConfig[METRIC_KEY(DEVICE_ARCHITECTURE)] = value.as<std::string>();
|
||||||
|
} else {
|
||||||
|
// Take device name if device does not support DEVICE_ARCHITECTURE metric
|
||||||
|
compileConfig[METRIC_KEY(DEVICE_ARCHITECTURE)] = deviceFamily;
|
||||||
|
}
|
||||||
|
return compileConfig;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string CalculateNetworkHash(const CNNNetwork& network, const std::string& deviceFamily,
|
||||||
|
const InferencePlugin& plugin,
|
||||||
|
const std::map<std::string, std::string>& config) const {
|
||||||
|
auto compileConfig = CreateCompileConfig(plugin, deviceFamily, config);
|
||||||
|
return NetworkCompilationContext::computeHash(network, compileConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string CalculateFileHash(const std::string& modelName, const std::string& deviceFamily,
|
||||||
|
const InferencePlugin& plugin,
|
||||||
|
const std::map<std::string, std::string>& config) const {
|
||||||
|
auto compileConfig = CreateCompileConfig(plugin, deviceFamily, config);
|
||||||
|
return NetworkCompilationContext::computeHash(modelName, compileConfig);
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Impl();
|
Impl() {
|
||||||
~Impl() override;
|
opsetNames.insert("opset1");
|
||||||
|
opsetNames.insert("opset2");
|
||||||
|
opsetNames.insert("opset3");
|
||||||
|
opsetNames.insert("opset4");
|
||||||
|
opsetNames.insert("opset5");
|
||||||
|
opsetNames.insert("opset6");
|
||||||
|
}
|
||||||
|
|
||||||
|
~Impl() override = default;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Register plugins for devices which are located in .xml configuration file. The function supports UNICODE path
|
* @brief Register plugins for devices which are located in .xml configuration file. The function supports UNICODE path
|
||||||
@ -250,20 +420,80 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath) const override {
|
CNNNetwork ReadNetwork(const std::string& modelPath, const std::string& binPath) const override {
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::IE);
|
OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::Impl::ReadNetwork from file");
|
||||||
return details::ReadNetwork(modelPath, binPath, extensions);
|
return details::ReadNetwork(modelPath, binPath, extensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights) const override {
|
CNNNetwork ReadNetwork(const std::string& model, const Blob::CPtr& weights) const override {
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::Impl::ReadNetwork");
|
OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::Impl::ReadNetwork from memory");
|
||||||
return details::ReadNetwork(model, weights, extensions);
|
return details::ReadNetwork(model, weights, extensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: In future this method can be added to ICore interface
|
||||||
|
ExecutableNetwork LoadNetwork(const CNNNetwork& network, const RemoteContext::Ptr& context,
|
||||||
|
const std::map<std::string, std::string>& config) {
|
||||||
|
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::RemoteContext");
|
||||||
|
if (context == nullptr) {
|
||||||
|
THROW_IE_EXCEPTION << "Remote context is null";
|
||||||
|
}
|
||||||
|
auto parsed = parseDeviceNameIntoConfig(context->getDeviceName(), config);
|
||||||
|
auto plugin = GetCPPPluginByName(parsed._deviceName);
|
||||||
|
bool loadedFromCache = false;
|
||||||
|
ExecutableNetwork res;
|
||||||
|
std::string hash;
|
||||||
|
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||||
|
if (cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||||
|
hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
|
||||||
|
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, context, loadedFromCache);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!loadedFromCache) {
|
||||||
|
res = LoadNetworkImpl(network, plugin, parsed._config, context, hash);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::string& deviceName,
|
ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::string& deviceName,
|
||||||
const std::map<std::string, std::string>& config) override {
|
const std::map<std::string, std::string>& config) override {
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::Impl::LoadNetwork");
|
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::CNN");
|
||||||
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
|
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
|
||||||
return GetCPPPluginByName(parsed._deviceName).LoadNetwork(network, parsed._config);
|
auto plugin = GetCPPPluginByName(parsed._deviceName);
|
||||||
|
bool loadedFromCache = false;
|
||||||
|
ExecutableNetwork res;
|
||||||
|
std::string hash;
|
||||||
|
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||||
|
if (cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||||
|
hash = CalculateNetworkHash(network, parsed._deviceName, plugin, parsed._config);
|
||||||
|
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config, nullptr, loadedFromCache);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!loadedFromCache) {
|
||||||
|
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, hash);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: In future this method can be added to ICore interface
|
||||||
|
ExecutableNetwork LoadNetwork(const std::string& modelPath, const std::string& deviceName,
|
||||||
|
const std::map<std::string, std::string>& config) {
|
||||||
|
OV_ITT_SCOPED_TASK(itt::domains::IE_LT, "Core::LoadNetwork::Path");
|
||||||
|
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
|
||||||
|
auto plugin = GetCPPPluginByName(parsed._deviceName);
|
||||||
|
bool loadedFromCache = false;
|
||||||
|
ExecutableNetwork res;
|
||||||
|
std::string hash;
|
||||||
|
auto cacheManager = coreConfig.getCacheConfig()._cacheManager;
|
||||||
|
if (cacheManager && DeviceSupportsImportExport(plugin)) {
|
||||||
|
hash = CalculateFileHash(modelPath, parsed._deviceName, plugin, parsed._config);
|
||||||
|
res = LoadNetworkFromCache(cacheManager, hash, plugin, parsed._config,
|
||||||
|
nullptr, loadedFromCache, modelPath);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!loadedFromCache) {
|
||||||
|
auto cnnNetwork = ReadNetwork(modelPath, std::string());
|
||||||
|
res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, hash, modelPath);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
ExecutableNetwork ImportNetwork(std::istream& networkModel, const std::string& deviceName,
|
ExecutableNetwork ImportNetwork(std::istream& networkModel, const std::string& deviceName,
|
||||||
@ -286,6 +516,7 @@ public:
|
|||||||
|
|
||||||
QueryNetworkResult QueryNetwork(const CNNNetwork& network, const std::string& deviceName,
|
QueryNetworkResult QueryNetwork(const CNNNetwork& network, const std::string& deviceName,
|
||||||
const std::map<std::string, std::string>& config) const override {
|
const std::map<std::string, std::string>& config) const override {
|
||||||
|
OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::QueryNetwork");
|
||||||
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
|
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
|
||||||
auto res = GetCPPPluginByName(parsed._deviceName).QueryNetwork(network, parsed._config);
|
auto res = GetCPPPluginByName(parsed._deviceName).QueryNetwork(network, parsed._config);
|
||||||
if (!network.getFunction() || res.supportedLayersMap.empty())
|
if (!network.getFunction() || res.supportedLayersMap.empty())
|
||||||
@ -338,7 +569,6 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @deprecated
|
|
||||||
* @brief Returns reference to CPP plugin wrapper by a device name
|
* @brief Returns reference to CPP plugin wrapper by a device name
|
||||||
* @param deviceName A name of device
|
* @param deviceName A name of device
|
||||||
* @return Reference to a CPP plugin wrapper
|
* @return Reference to a CPP plugin wrapper
|
||||||
@ -463,10 +693,18 @@ public:
|
|||||||
* @brief Sets config values for a plugin or set of plugins
|
* @brief Sets config values for a plugin or set of plugins
|
||||||
* @param deviceName A device name to set config to
|
* @param deviceName A device name to set config to
|
||||||
* If empty, config is set for all the plugins / plugin's meta-data
|
* If empty, config is set for all the plugins / plugin's meta-data
|
||||||
|
* @note `deviceName` is not allowed in form of MULTI:CPU, HETERO:FPGA,CPU
|
||||||
|
* just simple forms like CPU, GPU, MULTU, GPU.0, etc
|
||||||
*/
|
*/
|
||||||
void SetConfigForPlugins(const std::map<std::string, std::string>& config, const std::string& deviceName) {
|
void SetConfigForPlugins(const std::map<std::string, std::string>& configMap, const std::string& deviceName) {
|
||||||
|
auto config = configMap;
|
||||||
|
|
||||||
std::lock_guard<std::mutex> lock(pluginsMutex);
|
std::lock_guard<std::mutex> lock(pluginsMutex);
|
||||||
|
|
||||||
|
if (deviceName.empty()) {
|
||||||
|
coreConfig.setAndUpdate(config);
|
||||||
|
}
|
||||||
|
|
||||||
// set config for plugins in registry
|
// set config for plugins in registry
|
||||||
bool configIsSet = false;
|
bool configIsSet = false;
|
||||||
for (auto& desc : pluginRegistry) {
|
for (auto& desc : pluginRegistry) {
|
||||||
@ -524,15 +762,6 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Core::Impl::Impl() {
|
|
||||||
opsetNames.insert("opset1");
|
|
||||||
opsetNames.insert("opset2");
|
|
||||||
opsetNames.insert("opset3");
|
|
||||||
opsetNames.insert("opset4");
|
|
||||||
}
|
|
||||||
|
|
||||||
Core::Impl::~Impl() {}
|
|
||||||
|
|
||||||
Core::Core(const std::string& xmlConfigFile) {
|
Core::Core(const std::string& xmlConfigFile) {
|
||||||
_impl = std::make_shared<Impl>();
|
_impl = std::make_shared<Impl>();
|
||||||
|
|
||||||
@ -603,20 +832,14 @@ ExecutableNetwork Core::LoadNetwork(const CNNNetwork& network, const std::string
|
|||||||
return _impl->LoadNetwork(network, deviceName, config);
|
return _impl->LoadNetwork(network, deviceName, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Core::AddExtension(const IExtensionPtr& extension) {
|
|
||||||
_impl->AddExtension(extension);
|
|
||||||
}
|
|
||||||
|
|
||||||
ExecutableNetwork Core::LoadNetwork(const CNNNetwork& network, RemoteContext::Ptr context,
|
ExecutableNetwork Core::LoadNetwork(const CNNNetwork& network, RemoteContext::Ptr context,
|
||||||
const std::map<std::string, std::string>& config) {
|
const std::map<std::string, std::string>& config) {
|
||||||
OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::LoadNetwork");
|
return _impl->LoadNetwork(network, context, config);
|
||||||
|
|
||||||
if (context == nullptr) {
|
|
||||||
THROW_IE_EXCEPTION << "Remote context is null";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto parsed = parseDeviceNameIntoConfig(context->getDeviceName(), config);
|
ExecutableNetwork Core::LoadNetwork(const std::string& modelPath, const std::string& deviceName,
|
||||||
return _impl->GetCPPPluginByName(parsed._deviceName).LoadNetwork(network, parsed._config, context);
|
const std::map<std::string, std::string>& config) {
|
||||||
|
return _impl->LoadNetwork(modelPath, deviceName, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
RemoteContext::Ptr Core::CreateContext(const std::string& deviceName, const ParamMap& params) {
|
RemoteContext::Ptr Core::CreateContext(const std::string& deviceName, const ParamMap& params) {
|
||||||
@ -656,8 +879,15 @@ void Core::AddExtension(IExtensionPtr extension, const std::string& deviceName_)
|
|||||||
_impl->AddExtension(extension);
|
_impl->AddExtension(extension);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Core::AddExtension(const IExtensionPtr& extension) {
|
||||||
|
_impl->AddExtension(extension);
|
||||||
|
}
|
||||||
|
|
||||||
ExecutableNetwork Core::ImportNetwork(const std::string& modelFileName, const std::string& deviceName,
|
ExecutableNetwork Core::ImportNetwork(const std::string& modelFileName, const std::string& deviceName,
|
||||||
const std::map<std::string, std::string>& config) {
|
const std::map<std::string, std::string>& config) {
|
||||||
|
OV_ITT_SCOPED_TASK(itt::domains::IE, "Core::ImportNetwork");
|
||||||
|
|
||||||
|
// TODO: remove once NotImplemented exception is deprecated and not used
|
||||||
if (deviceName.find("HETERO") == 0) {
|
if (deviceName.find("HETERO") == 0) {
|
||||||
THROW_IE_EXCEPTION << "HETERO device does not support ImportNetwork";
|
THROW_IE_EXCEPTION << "HETERO device does not support ImportNetwork";
|
||||||
}
|
}
|
||||||
@ -698,19 +928,21 @@ QueryNetworkResult Core::QueryNetwork(const CNNNetwork& network, const std::stri
|
|||||||
|
|
||||||
void Core::SetConfig(const std::map<std::string, std::string>& config, const std::string& deviceName) {
|
void Core::SetConfig(const std::map<std::string, std::string>& config, const std::string& deviceName) {
|
||||||
// HETERO case
|
// HETERO case
|
||||||
{
|
|
||||||
if (deviceName.find("HETERO:") == 0) {
|
if (deviceName.find("HETERO:") == 0) {
|
||||||
THROW_IE_EXCEPTION << "SetConfig is supported only for HETERO itself (without devices). "
|
THROW_IE_EXCEPTION << "SetConfig is supported only for HETERO itself (without devices). "
|
||||||
"You can configure the devices with SetConfig before creating the HETERO on top.";
|
"You can configure the devices with SetConfig before creating the HETERO on top.";
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// MULTI case
|
// MULTI case
|
||||||
{
|
|
||||||
if (deviceName.find("MULTI:") == 0) {
|
if (deviceName.find("MULTI:") == 0) {
|
||||||
THROW_IE_EXCEPTION << "SetConfig is supported only for MULTI itself (without devices). "
|
THROW_IE_EXCEPTION << "SetConfig is supported only for MULTI itself (without devices). "
|
||||||
"You can configure the devices with SetConfig before creating the MULTI on top.";
|
"You can configure the devices with SetConfig before creating the MULTI on top.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GPU.0, FPGA.1 cases
|
||||||
|
if (deviceName.find(".") != std::string::npos) {
|
||||||
|
THROW_IE_EXCEPTION << "SetConfig is supported only for device family itself (without particular device .#). "
|
||||||
|
"You can pass .# as a particular device instance to QueryNetwork, LoadNetwork, ImportNetwork only";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (deviceName.empty()) {
|
if (deviceName.empty()) {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright (C) 2018-2020 Intel Corporation
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
@ -75,10 +75,6 @@ public:
|
|||||||
CALL_STATEMENT(return actual->GetVersion());
|
CALL_STATEMENT(return actual->GetVersion());
|
||||||
}
|
}
|
||||||
|
|
||||||
ExecutableNetwork LoadNetwork(CNNNetwork network, const std::map<std::string, std::string>& config) {
|
|
||||||
CALL_STATEMENT(return ExecutableNetwork(actual->LoadNetwork(network, config), actual));
|
|
||||||
}
|
|
||||||
|
|
||||||
void AddExtension(InferenceEngine::IExtensionPtr extension) {
|
void AddExtension(InferenceEngine::IExtensionPtr extension) {
|
||||||
CALL_STATEMENT(actual->AddExtension(extension));
|
CALL_STATEMENT(actual->AddExtension(extension));
|
||||||
}
|
}
|
||||||
@ -87,9 +83,12 @@ public:
|
|||||||
CALL_STATEMENT(actual->SetConfig(config));
|
CALL_STATEMENT(actual->SetConfig(config));
|
||||||
}
|
}
|
||||||
|
|
||||||
ExecutableNetwork ImportNetwork(const std::string& modelFileName,
|
ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::map<std::string, std::string>& config) {
|
||||||
const std::map<std::string, std::string>& config) {
|
CALL_STATEMENT(return ExecutableNetwork(actual->LoadNetwork(network, config), actual));
|
||||||
CALL_STATEMENT(return ExecutableNetwork(actual->ImportNetwork(modelFileName, config), actual));
|
}
|
||||||
|
|
||||||
|
ExecutableNetwork LoadNetwork(const CNNNetwork& network, RemoteContext::Ptr context, const std::map<std::string, std::string>& config) {
|
||||||
|
CALL_STATEMENT(return ExecutableNetwork(actual->LoadNetwork(network, config, context), actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
QueryNetworkResult QueryNetwork(const CNNNetwork& network,
|
QueryNetworkResult QueryNetwork(const CNNNetwork& network,
|
||||||
@ -100,18 +99,24 @@ public:
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ExecutableNetwork ImportNetwork(const std::string& modelFileName,
|
||||||
|
const std::map<std::string, std::string>& config) {
|
||||||
|
CALL_STATEMENT(return ExecutableNetwork(actual->ImportNetwork(modelFileName, config), actual));
|
||||||
|
}
|
||||||
|
|
||||||
ExecutableNetwork ImportNetwork(std::istream& networkModel,
|
ExecutableNetwork ImportNetwork(std::istream& networkModel,
|
||||||
const std::map<std::string, std::string>& config) {
|
const std::map<std::string, std::string>& config) {
|
||||||
CALL_STATEMENT(return ExecutableNetwork(actual->ImportNetwork(networkModel, config), actual));
|
CALL_STATEMENT(return ExecutableNetwork(actual->ImportNetwork(networkModel, config), actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
Parameter GetMetric(const std::string& name, const std::map<std::string, Parameter>& options) const {
|
ExecutableNetwork ImportNetwork(std::istream& networkModel,
|
||||||
CALL_STATEMENT(return actual->GetMetric(name, options));
|
const RemoteContext::Ptr& context,
|
||||||
|
const std::map<std::string, std::string>& config) {
|
||||||
|
CALL_STATEMENT(return ExecutableNetwork(actual->ImportNetwork(networkModel, context, config), actual));
|
||||||
}
|
}
|
||||||
|
|
||||||
ExecutableNetwork LoadNetwork(const CNNNetwork& network, const std::map<std::string, std::string>& config,
|
Parameter GetMetric(const std::string& name, const std::map<std::string, Parameter>& options) const {
|
||||||
RemoteContext::Ptr context) {
|
CALL_STATEMENT(return actual->GetMetric(name, options));
|
||||||
CALL_STATEMENT(return ExecutableNetwork(actual->LoadNetwork(network, config, context), actual));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
RemoteContext::Ptr CreateContext(const ParamMap& params) {
|
RemoteContext::Ptr CreateContext(const ParamMap& params) {
|
||||||
@ -122,12 +127,6 @@ public:
|
|||||||
CALL_STATEMENT(return actual->GetDefaultContext(params));
|
CALL_STATEMENT(return actual->GetDefaultContext(params));
|
||||||
}
|
}
|
||||||
|
|
||||||
ExecutableNetwork ImportNetwork(std::istream& networkModel,
|
|
||||||
const RemoteContext::Ptr& context,
|
|
||||||
const std::map<std::string, std::string>& config) {
|
|
||||||
CALL_STATEMENT(return ExecutableNetwork(actual->ImportNetwork(networkModel, context, config), actual));
|
|
||||||
}
|
|
||||||
|
|
||||||
Parameter GetConfig(const std::string& name, const std::map<std::string, Parameter>& options) const {
|
Parameter GetConfig(const std::string& name, const std::map<std::string, Parameter>& options) const {
|
||||||
CALL_STATEMENT(return actual->GetConfig(name, options));
|
CALL_STATEMENT(return actual->GetConfig(name, options));
|
||||||
}
|
}
|
||||||
|
@ -81,6 +81,31 @@ template<> struct FileTraits<wchar_t> {
|
|||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Interface function to get absolute path of file
|
||||||
|
* @ingroup ie_dev_api_file_utils
|
||||||
|
* @param filePath - path to file, can be relative to current working directory
|
||||||
|
* @return Absolute path of file
|
||||||
|
* @throw InferenceEngineException if any error occurred
|
||||||
|
*/
|
||||||
|
INFERENCE_ENGINE_API_CPP(std::string) absoluteFilePath(const std::string& filePath);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Interface function to create directorty recursively by given path
|
||||||
|
* @ingroup ie_dev_api_file_utils
|
||||||
|
* @param dirPath - path to file, can be relative to current working directory
|
||||||
|
* @throw InferenceEngineException if any error occurred
|
||||||
|
*/
|
||||||
|
INFERENCE_ENGINE_API_CPP(void) createDirectoryRecursive(const std::string& dirPath);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Interface function to check if directory exists for given path
|
||||||
|
* @ingroup ie_dev_api_file_utils
|
||||||
|
* @param path - path to directory
|
||||||
|
* @return true if directory exists, false otherwise
|
||||||
|
*/
|
||||||
|
INFERENCE_ENGINE_API_CPP(bool) directoryExists(const std::string& path);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Interface function to get the size of a file. The function supports UNICODE path
|
* @brief Interface function to get the size of a file. The function supports UNICODE path
|
||||||
* @ingroup ie_dev_api_file_utils
|
* @ingroup ie_dev_api_file_utils
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright (C) 2018-2020 Intel Corporation
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
@ -68,8 +68,8 @@ public:
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Creates an executable network from a previously exported network
|
* @brief Creates an executable network from a previously exported network
|
||||||
* @param deviceName Name of device load executable network on
|
|
||||||
* @param networkModel network model stream
|
* @param networkModel network model stream
|
||||||
|
* @param deviceName Name of device load executable network on
|
||||||
* @param config Optional map of pairs: (config parameter name, config parameter value) relevant only for this load
|
* @param config Optional map of pairs: (config parameter name, config parameter value) relevant only for this load
|
||||||
* operation*
|
* operation*
|
||||||
* @return An executable network reference
|
* @return An executable network reference
|
||||||
|
@ -0,0 +1,956 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <thread>
|
||||||
|
#include <chrono>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <gmock/gmock.h>
|
||||||
|
|
||||||
|
#include "ie_plugin_ptr.hpp"
|
||||||
|
#include "ngraph/function.hpp"
|
||||||
|
#include "details/ie_so_loader.h"
|
||||||
|
#include "ie_metric_helpers.hpp"
|
||||||
|
#include "ie_iexecutable_network.hpp"
|
||||||
|
|
||||||
|
#include "cpp_interfaces/impl/ie_executable_network_internal.hpp"
|
||||||
|
#include "cpp_interfaces/impl/ie_plugin_internal.hpp"
|
||||||
|
|
||||||
|
#include "common_test_utils/unicode_utils.hpp"
|
||||||
|
#include "common_test_utils/file_utils.hpp"
|
||||||
|
#include "common_test_utils/test_constants.hpp"
|
||||||
|
|
||||||
|
#include "functional_test_utils/test_model/test_model.hpp"
|
||||||
|
#include "functional_test_utils/network_utils.hpp"
|
||||||
|
|
||||||
|
#include "unit_test_utils/mocks/mock_iexecutable_network.hpp"
|
||||||
|
|
||||||
|
using namespace InferenceEngine;
|
||||||
|
using namespace ::testing;
|
||||||
|
using namespace InferenceEngine::details;
|
||||||
|
using namespace std::placeholders;
|
||||||
|
using namespace std::chrono;
|
||||||
|
|
||||||
|
enum class TestLoadType {
|
||||||
|
ECNN,
|
||||||
|
EContext,
|
||||||
|
EModelName
|
||||||
|
};
|
||||||
|
using TestParam = std::tuple<TestLoadType, std::string, bool>;
|
||||||
|
|
||||||
|
// GCC4.8 limitation: have to specify type of each element in list
|
||||||
|
static const std::vector<TestParam> loadVariants = {
|
||||||
|
TestParam { TestLoadType::ECNN, std::string("ByCNNNetwork"), false },
|
||||||
|
TestParam { TestLoadType::EContext, std::string("ByRemoteContext"), true },
|
||||||
|
TestParam { TestLoadType::EModelName, std::string("ByModelName"), false },
|
||||||
|
};
|
||||||
|
|
||||||
|
static const std::vector<std::string> cacheFolders {
|
||||||
|
std::string("testCache"),
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string getTestCaseName(const testing::TestParamInfo<std::tuple<TestParam, std::string>> &obj) {
|
||||||
|
return std::get<1>(std::get<0>(obj.param)) + "_" + std::get<1>(obj.param);
|
||||||
|
}
|
||||||
|
|
||||||
|
class MockRemoteContext : public RemoteContext {
|
||||||
|
std::string m_name;
|
||||||
|
public:
|
||||||
|
MockRemoteContext(std::string name): m_name(std::move(name)) {}
|
||||||
|
std::string getDeviceName() const noexcept { return m_name; }
|
||||||
|
MOCK_METHOD2(CreateBlob, RemoteBlob::Ptr(const TensorDesc&, const ParamMap&));
|
||||||
|
MOCK_QUALIFIED_METHOD0(getParams, const, ParamMap());
|
||||||
|
};
|
||||||
|
|
||||||
|
class MockCachingInferencePlugin : public InferenceEngine::InferencePluginInternal {
|
||||||
|
public:
|
||||||
|
MockCachingInferencePlugin() = default;
|
||||||
|
~MockCachingInferencePlugin() = default;
|
||||||
|
|
||||||
|
MOCK_METHOD2(LoadExeNetworkImpl, ExecutableNetworkInternal::Ptr(const CNNNetwork& network,
|
||||||
|
const std::map<std::string, std::string>& config));
|
||||||
|
|
||||||
|
MOCK_METHOD3(LoadExeNetworkImpl, ExecutableNetworkInternal::Ptr(const CNNNetwork& network, RemoteContext::Ptr context,
|
||||||
|
const std::map<std::string, std::string>& config));
|
||||||
|
|
||||||
|
MOCK_METHOD2(ImportNetworkImpl, ExecutableNetwork(std::istream& networkModel,
|
||||||
|
const std::map<std::string, std::string>& config));
|
||||||
|
|
||||||
|
MOCK_METHOD3(ImportNetworkImpl, ExecutableNetwork(std::istream& networkModel,
|
||||||
|
const RemoteContext::Ptr& context,
|
||||||
|
const std::map<std::string, std::string>& config));
|
||||||
|
|
||||||
|
MOCK_CONST_METHOD2(QueryNetwork, QueryNetworkResult(const CNNNetwork& network,
|
||||||
|
const std::map<std::string, std::string>& config));
|
||||||
|
|
||||||
|
MOCK_CONST_METHOD2(GetMetric, Parameter(const std::string& name, const std::map<std::string, Parameter>& options));
|
||||||
|
MOCK_METHOD1(GetDefaultContext, RemoteContext::Ptr(const ParamMap& params));
|
||||||
|
};
|
||||||
|
|
||||||
|
class MockExecutableNetwork : public ExecutableNetworkInternal {
|
||||||
|
public:
|
||||||
|
MockExecutableNetwork() {}
|
||||||
|
MOCK_METHOD1(ExportImpl, void(std::ostream& networkModel));
|
||||||
|
MOCK_METHOD0(CreateInferRequest, IInferRequest::Ptr());
|
||||||
|
};
|
||||||
|
|
||||||
|
//------------------------------------------------------
|
||||||
|
class MkDirGuard {
|
||||||
|
std::string m_dir;
|
||||||
|
public:
|
||||||
|
MkDirGuard(const std::string &dir = std::string()): m_dir(dir) {
|
||||||
|
if (!m_dir.empty()) {
|
||||||
|
CommonTestUtils::createDirectory(m_dir);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MkDirGuard(const MkDirGuard&) = delete;
|
||||||
|
MkDirGuard& operator=(const MkDirGuard&) = delete;
|
||||||
|
|
||||||
|
~MkDirGuard() {
|
||||||
|
if (!m_dir.empty()) {
|
||||||
|
CommonTestUtils::removeFilesWithExt(m_dir, "blob");
|
||||||
|
CommonTestUtils::removeDir(m_dir);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class CachingTest : public ::testing::TestWithParam<std::tuple<TestParam, std::string>> {
|
||||||
|
public:
|
||||||
|
std::unique_ptr<SharedObjectLoader> sharedObjectLoader;
|
||||||
|
std::function<void(IInferencePlugin*)> injectProxyEngine;
|
||||||
|
std::string modelName = "Caching_test.xml";
|
||||||
|
std::string weightsName = "Caching_test.bin";
|
||||||
|
std::string deviceName = "mock";
|
||||||
|
std::string deviceToLoad = "mock";
|
||||||
|
std::shared_ptr<MockCachingInferencePlugin> mockPlugin;
|
||||||
|
std::shared_ptr<MockExecutableNetwork> net;
|
||||||
|
std::unique_ptr<MkDirGuard> m_dirCreator;
|
||||||
|
TestLoadType m_type;
|
||||||
|
std::string m_cacheDir;
|
||||||
|
using LoadFunction = std::function<void(Core&)>;
|
||||||
|
using LoadFunctionWithCfg = std::function<void(Core&, const std::map<std::string, std::string> &)>;
|
||||||
|
LoadFunction m_testFunction;
|
||||||
|
LoadFunctionWithCfg m_testFunctionWithCfg;
|
||||||
|
bool m_remoteContext = false;
|
||||||
|
using CNNCallback = std::function<void(CNNNetwork&)>;
|
||||||
|
CNNCallback m_cnnCallback = nullptr;
|
||||||
|
|
||||||
|
|
||||||
|
std::string get_mock_engine_name() {
|
||||||
|
std::string mockEngineName("mock_engine");
|
||||||
|
return CommonTestUtils::pre + mockEngineName + IE_BUILD_POSTFIX + CommonTestUtils::ext;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string generateTestFilePrefix() {
|
||||||
|
// Generate unique file names based on test name, thread id and timestamp
|
||||||
|
// This allows execution of tests in parallel (stress mode)
|
||||||
|
auto testInfo = UnitTest::GetInstance()->current_test_info();
|
||||||
|
std::string testName = testInfo->test_case_name();
|
||||||
|
testName += testInfo->name();
|
||||||
|
testName = std::to_string(std::hash<std::string>()(testName));
|
||||||
|
std::stringstream ss;
|
||||||
|
auto ts = duration_cast<microseconds>(high_resolution_clock::now().time_since_epoch());
|
||||||
|
ss << testName << "_" << std::this_thread::get_id() << "_" << ts.count();
|
||||||
|
testName = ss.str();
|
||||||
|
return testName;
|
||||||
|
}
|
||||||
|
|
||||||
|
void initParamTest() {
|
||||||
|
m_type = std::get<0>(std::get<0>(GetParam()));
|
||||||
|
m_cacheDir = std::get<1>(GetParam());
|
||||||
|
m_testFunction = getLoadFunction(m_type);
|
||||||
|
m_testFunctionWithCfg = getLoadFunctionWithCfg(m_type);
|
||||||
|
m_remoteContext = std::get<2>(std::get<0>(GetParam()));
|
||||||
|
auto testName = generateTestFilePrefix();
|
||||||
|
modelName = testName + ".xml";
|
||||||
|
weightsName = testName + ".bin";
|
||||||
|
m_cacheDir = testName + m_cacheDir;
|
||||||
|
m_dirCreator = std::unique_ptr<MkDirGuard>(new MkDirGuard(m_cacheDir));
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetUp() override {
|
||||||
|
initParamTest();
|
||||||
|
mockPlugin = std::make_shared<MockCachingInferencePlugin>();
|
||||||
|
net = std::make_shared<MockExecutableNetwork>();
|
||||||
|
setupMock(*mockPlugin);
|
||||||
|
std::string libraryName = get_mock_engine_name();
|
||||||
|
sharedObjectLoader.reset(new SharedObjectLoader(libraryName.c_str()));
|
||||||
|
injectProxyEngine = make_std_function<void(IInferencePlugin*)>("InjectProxyEngine");
|
||||||
|
|
||||||
|
FuncTestUtils::TestModel::generateTestModel(modelName, weightsName);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TearDown() override {
|
||||||
|
CommonTestUtils::removeIRFiles(modelName, weightsName);
|
||||||
|
}
|
||||||
|
|
||||||
|
void testLoad(std::function<void(Core& ie)> func) {
|
||||||
|
Core ie;
|
||||||
|
injectProxyEngine(mockPlugin.get());
|
||||||
|
ie.RegisterPlugin(std::string("mock_engine") + IE_BUILD_POSTFIX, deviceName);
|
||||||
|
func(ie);
|
||||||
|
ie.UnregisterPlugin(deviceName);
|
||||||
|
}
|
||||||
|
|
||||||
|
LoadFunction getLoadFunction(TestLoadType type) const {
|
||||||
|
switch (type) {
|
||||||
|
case TestLoadType::ECNN:
|
||||||
|
return [&](Core& ie) { performReadAndLoad(ie); };
|
||||||
|
case TestLoadType::EContext:
|
||||||
|
return [&](Core& ie) { performReadAndLoadWithContext(ie); };
|
||||||
|
case TestLoadType::EModelName:
|
||||||
|
return [&](Core& ie) { performLoadByName(ie); };
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
LoadFunctionWithCfg getLoadFunctionWithCfg(TestLoadType type) const {
|
||||||
|
switch (type) {
|
||||||
|
case TestLoadType::ECNN:
|
||||||
|
return std::bind(&CachingTest::performReadAndLoad, this, _1, _2);
|
||||||
|
case TestLoadType::EContext:
|
||||||
|
return std::bind(&CachingTest::performReadAndLoadWithContext, this, _1, _2);
|
||||||
|
case TestLoadType::EModelName:
|
||||||
|
return std::bind(&CachingTest::performLoadByName, this, _1, _2);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void performLoadByName(Core& ie, const std::map<std::string, std::string>& config = {}) const {
|
||||||
|
ie.LoadNetwork(modelName, deviceToLoad, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
void performReadAndLoad(Core& ie, const std::map<std::string, std::string>& config = {}) const {
|
||||||
|
auto cnnNetwork = ie.ReadNetwork(modelName);
|
||||||
|
if (m_cnnCallback) m_cnnCallback(cnnNetwork);
|
||||||
|
ie.LoadNetwork(cnnNetwork, deviceToLoad, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
void performReadAndLoadWithContext(Core& ie, const std::map<std::string, std::string>& config = {}) const {
|
||||||
|
auto cnnNetwork = ie.ReadNetwork(modelName);
|
||||||
|
EXPECT_CALL(*mockPlugin, GetDefaultContext(_)).Times(AnyNumber());
|
||||||
|
auto context = ie.GetDefaultContext(deviceToLoad);
|
||||||
|
if (m_cnnCallback) m_cnnCallback(cnnNetwork);
|
||||||
|
ie.LoadNetwork(cnnNetwork, context, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <class T>
|
||||||
|
std::function<T> make_std_function(const std::string& functionName) {
|
||||||
|
std::function <T> ptr(reinterpret_cast<T*>(sharedObjectLoader->get_symbol(functionName.c_str())));
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void setupMock(MockCachingInferencePlugin& plugin) {
|
||||||
|
ON_CALL(plugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).
|
||||||
|
WillByDefault(Invoke([&](const std::string &, const std::map<std::string, Parameter> &) {
|
||||||
|
std::vector<std::string> res;
|
||||||
|
res.push_back(METRIC_KEY(IMPORT_EXPORT_SUPPORT));
|
||||||
|
res.push_back(METRIC_KEY(DEVICE_ARCHITECTURE));
|
||||||
|
return res;
|
||||||
|
}));
|
||||||
|
ON_CALL(plugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).
|
||||||
|
WillByDefault(Return(true));
|
||||||
|
|
||||||
|
ON_CALL(plugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).
|
||||||
|
WillByDefault(Invoke([&](const std::string &, const std::map<std::string, Parameter> &) {
|
||||||
|
std::vector<std::string> res;
|
||||||
|
res.push_back("SomeConfig");
|
||||||
|
return res;
|
||||||
|
}));
|
||||||
|
|
||||||
|
ON_CALL(plugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).
|
||||||
|
WillByDefault(Return("mock"));
|
||||||
|
|
||||||
|
ON_CALL(plugin, ImportNetworkImpl(_, _, _)).
|
||||||
|
WillByDefault(Invoke([&](std::istream &, RemoteContext::Ptr,
|
||||||
|
const std::map<std::string, std::string> &) {
|
||||||
|
return ExecutableNetwork(std::make_shared<MockIExecutableNetwork>());
|
||||||
|
}));
|
||||||
|
|
||||||
|
ON_CALL(plugin, ImportNetworkImpl(_, _)).
|
||||||
|
WillByDefault(Invoke([&](std::istream &, const std::map<std::string, std::string> &) {
|
||||||
|
return ExecutableNetwork(std::make_shared<MockIExecutableNetwork>());
|
||||||
|
}));
|
||||||
|
|
||||||
|
ON_CALL(plugin, LoadExeNetworkImpl(_, _, _)).
|
||||||
|
WillByDefault(Invoke([&](const CNNNetwork &, RemoteContext::Ptr,
|
||||||
|
const std::map<std::string, std::string> &) {
|
||||||
|
return net;
|
||||||
|
}));
|
||||||
|
|
||||||
|
ON_CALL(plugin, LoadExeNetworkImpl(_, _)).
|
||||||
|
WillByDefault(Invoke([&](const CNNNetwork &,
|
||||||
|
const std::map<std::string, std::string> &) {
|
||||||
|
return net;
|
||||||
|
}));
|
||||||
|
|
||||||
|
ON_CALL(plugin, GetDefaultContext(_)).
|
||||||
|
WillByDefault(Invoke([&](const ParamMap &) {
|
||||||
|
return std::make_shared<MockRemoteContext>(deviceToLoad);
|
||||||
|
}));
|
||||||
|
|
||||||
|
ON_CALL(plugin, QueryNetwork(_, _)).
|
||||||
|
WillByDefault(Invoke([&](const CNNNetwork &network, const std::map<std::string, std::string>&) {
|
||||||
|
QueryNetworkResult res;
|
||||||
|
auto function = network.getFunction();
|
||||||
|
EXPECT_TRUE(function);
|
||||||
|
|
||||||
|
for (auto &&node : function->get_ops()) {
|
||||||
|
res.supportedLayersMap.emplace(node->get_friendly_name(), deviceName);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestLoad) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestLoadCustomImportExport) {
|
||||||
|
const int customNumber = 1234;
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
ON_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).
|
||||||
|
WillByDefault(Invoke([&](std::istream& s, RemoteContext::Ptr,
|
||||||
|
const std::map<std::string, std::string> &) {
|
||||||
|
int a;
|
||||||
|
s >> a;
|
||||||
|
EXPECT_EQ(customNumber, a);
|
||||||
|
return ExecutableNetwork(std::make_shared<MockIExecutableNetwork>());
|
||||||
|
}));
|
||||||
|
|
||||||
|
ON_CALL(*mockPlugin, ImportNetworkImpl(_, _)).
|
||||||
|
WillByDefault(Invoke([&](std::istream &s, const std::map<std::string, std::string> &) {
|
||||||
|
int a;
|
||||||
|
s >> a;
|
||||||
|
EXPECT_EQ(customNumber, a);
|
||||||
|
return ExecutableNetwork(std::make_shared<MockIExecutableNetwork>());
|
||||||
|
}));
|
||||||
|
|
||||||
|
ON_CALL(*net, ExportImpl(_)).WillByDefault(Invoke([&] (std::ostream& s) {
|
||||||
|
s << customNumber;
|
||||||
|
}));
|
||||||
|
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Brief: when LoadNetwork is called from different config - old cache shall not be used
|
||||||
|
TEST_P(CachingTest, TestChangeLoadConfig) {
|
||||||
|
const std::string CUSTOM_KEY = "CUSTOM_KEY";
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
ON_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_CONFIG_KEYS), _)).
|
||||||
|
WillByDefault(Invoke([&](const std::string &, const std::map<std::string, Parameter> &) {
|
||||||
|
std::vector<std::string> res;
|
||||||
|
res.push_back(CUSTOM_KEY);
|
||||||
|
return res;
|
||||||
|
}));
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunctionWithCfg(ie, {{CUSTOM_KEY, "0"}});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunctionWithCfg(ie, {{CUSTOM_KEY, "1"}});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestNoCacheEnabled) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestNoCacheSupported) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _))
|
||||||
|
.Times(AnyNumber()).WillRepeatedly(Return(false));
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestNoCacheMetricSupported) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _))
|
||||||
|
.Times(AnyNumber()).WillRepeatedly(Return(std::vector<std::string>{}));
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(0);
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestLoadChangeCacheDir) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
std::string newCacheDir = m_cacheDir + "2";
|
||||||
|
MkDirGuard dir(newCacheDir);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), newCacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestClearCacheDir) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), ""}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestChangeOtherConfig) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
ie.SetConfig({{"someKey", "someValue"}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestChangeCacheDirFailure) {
|
||||||
|
std::string longName(1000000, ' ');
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
EXPECT_ANY_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir + "/" + longName}}));
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestCacheDirCreateRecursive) {
|
||||||
|
std::string newCacheDir1 = m_cacheDir + CommonTestUtils::FileSeparator + "a";
|
||||||
|
std::string newCacheDir2 = newCacheDir1 + CommonTestUtils::FileSeparator + "b";
|
||||||
|
std::string newCacheDir3 = newCacheDir2 + CommonTestUtils::FileSeparator + CommonTestUtils::FileSeparator;
|
||||||
|
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), newCacheDir3}}));
|
||||||
|
EXPECT_NO_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
CommonTestUtils::removeFilesWithExt(newCacheDir2, "blob");
|
||||||
|
std::remove(newCacheDir2.c_str());
|
||||||
|
std::remove(newCacheDir1.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestDeviceArchitecture) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber())
|
||||||
|
.WillRepeatedly(Invoke([&](const std::string&, const std::map<std::string, Parameter>& options) {
|
||||||
|
auto id = options.at("DEVICE_ID").as<std::string>();
|
||||||
|
if (std::stoi(id) < 10) {
|
||||||
|
return "mock_first_architecture";
|
||||||
|
} else {
|
||||||
|
return "mock_another_architecture";
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
deviceToLoad = "mock.0";
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
deviceToLoad = "mock.1";
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
deviceToLoad = "mock.50";
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
deviceToLoad = "mock.51";
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestNoDeviceArchitecture) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber())
|
||||||
|
.WillRepeatedly(Invoke([&] (const std::string&, const std::map<std::string, Parameter>&) {
|
||||||
|
return std::vector<std::string>{METRIC_KEY(IMPORT_EXPORT_SUPPORT)};
|
||||||
|
}));
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(0);
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
deviceToLoad = "mock.0";
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
deviceToLoad = "mock.50";
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestThrowOnExport) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1).WillOnce(Throw(1));
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
EXPECT_ANY_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestThrowOnImport) {
|
||||||
|
ON_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).WillByDefault(Throw(1));
|
||||||
|
ON_CALL(*mockPlugin, ImportNetworkImpl(_, _)).WillByDefault(Throw(1));
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
EXPECT_ANY_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
{ // Step 3: same load, cache should be deleted due to unsuccessful import on step 2
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
EXPECT_NO_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestNetworkModified) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}));
|
||||||
|
EXPECT_NO_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if (m_type == TestLoadType::EModelName) {
|
||||||
|
// Modify model file
|
||||||
|
std::fstream stream(modelName, std::fstream::out | std::fstream::app);
|
||||||
|
stream << " ";
|
||||||
|
} else {
|
||||||
|
// Modify loaded CNN network
|
||||||
|
m_cnnCallback = [&](CNNNetwork& network) {
|
||||||
|
auto f = network.getFunction();
|
||||||
|
auto res = f->get_results();
|
||||||
|
f->remove_result(res.front());
|
||||||
|
};
|
||||||
|
}
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}));
|
||||||
|
EXPECT_NO_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
{ // Step 3: same load, should be ok now
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}));
|
||||||
|
EXPECT_NO_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestCacheFileCorrupted) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}));
|
||||||
|
EXPECT_NO_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto blobs = CommonTestUtils::listFilesWithExt(m_cacheDir, "blob");
|
||||||
|
for (const auto& fileName : blobs) {
|
||||||
|
std::ofstream stream(fileName, std::ios_base::binary);
|
||||||
|
stream << "SomeCorruptedText";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{ // Step 2. Cache is corrupted, will be silently removed
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}));
|
||||||
|
EXPECT_NO_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
{ // Step 3: same load, should be ok now due to re-creation of cache
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}));
|
||||||
|
EXPECT_NO_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, TestCacheFileOldVersion) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(SUPPORTED_METRICS), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(IMPORT_EXPORT_SUPPORT), _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(METRIC_KEY(DEVICE_ARCHITECTURE), _)).Times(AnyNumber());
|
||||||
|
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}));
|
||||||
|
EXPECT_NO_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto blobs = CommonTestUtils::listFilesWithExt(m_cacheDir, "blob");
|
||||||
|
for (const auto& fileName : blobs) {
|
||||||
|
std::string content;
|
||||||
|
{
|
||||||
|
std::ifstream inp(fileName, std::ios_base::binary);
|
||||||
|
std::ostringstream ostr;
|
||||||
|
ostr << inp.rdbuf();
|
||||||
|
content = ostr.str();
|
||||||
|
}
|
||||||
|
std::string buildNum = GetInferenceEngineVersion()->buildNumber;
|
||||||
|
std::string zeroBuild(buildNum.size(), '0');
|
||||||
|
auto index = content.find(buildNum);
|
||||||
|
if (index != std::string::npos) {
|
||||||
|
content.replace(index, buildNum.size(), zeroBuild);
|
||||||
|
} else {
|
||||||
|
SKIP();
|
||||||
|
}
|
||||||
|
std::ofstream out(fileName, std::ios_base::binary);
|
||||||
|
out.write(content.c_str(), content.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{ // Step 2. Build number mismatch, cache will be silently removed
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}));
|
||||||
|
EXPECT_NO_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
{ // Step 3: same load, should be ok now due to re-creation of cache
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(!m_remoteContext ? 1 : 0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
EXPECT_NO_THROW(ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}}));
|
||||||
|
EXPECT_NO_THROW(m_testFunction(ie));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(CachingTest, LoadHeteroWithCorrectConfig) {
|
||||||
|
EXPECT_CALL(*mockPlugin, GetMetric(_, _)).Times(AnyNumber());
|
||||||
|
EXPECT_CALL(*mockPlugin, QueryNetwork(_, _)).Times(AnyNumber());
|
||||||
|
// TODO: test also HETERO with 1 plugin but different architectures, e.g. "HETERO:mock.1,mock.51"
|
||||||
|
deviceToLoad = CommonTestUtils::DEVICE_HETERO + std::string(":mock.1,mock.2");
|
||||||
|
if (m_remoteContext) {
|
||||||
|
return; // skip the remote Context test for Hetero plugin
|
||||||
|
}
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(1);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(1);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, LoadExeNetworkImpl(_, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _, _)).Times(0);
|
||||||
|
EXPECT_CALL(*mockPlugin, ImportNetworkImpl(_, _)).Times(1);
|
||||||
|
EXPECT_CALL(*net, ExportImpl(_)).Times(0);
|
||||||
|
testLoad([&](Core &ie) {
|
||||||
|
ie.SetConfig({{CONFIG_KEY(CACHE_DIR), m_cacheDir}});
|
||||||
|
m_testFunction(ie);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(CachingTest, CachingTest,
|
||||||
|
::testing::Combine(
|
||||||
|
::testing::ValuesIn(loadVariants),
|
||||||
|
::testing::ValuesIn(cacheFolders)),
|
||||||
|
getTestCaseName);
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright (C) 2019 Intel Corporation
|
// Copyright (C) 2021 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
#pragma once
|
#pragma once
|
||||||
@ -18,7 +18,6 @@
|
|||||||
#define rmdir(dir) _rmdir(dir)
|
#define rmdir(dir) _rmdir(dir)
|
||||||
#else // _WIN32
|
#else // _WIN32
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
|
||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
|
|
||||||
namespace CommonTestUtils {
|
namespace CommonTestUtils {
|
||||||
@ -103,6 +102,27 @@ inline int removeFilesWithExt(std::string path, std::string ext) {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Lists all files with extension=ext from the given directory
|
||||||
|
// Return value:
|
||||||
|
// vector of strings representing file paths
|
||||||
|
inline std::vector<std::string> listFilesWithExt(const std::string& path, const std::string& ext) {
|
||||||
|
struct dirent *ent;
|
||||||
|
DIR *dir = opendir(path.c_str());
|
||||||
|
std::vector<std::string> res;
|
||||||
|
if (dir != nullptr) {
|
||||||
|
while ((ent = readdir(dir)) != NULL) {
|
||||||
|
auto file = makePath(path, std::string(ent->d_name));
|
||||||
|
struct stat stat_path;
|
||||||
|
stat(file.c_str(), &stat_path);
|
||||||
|
if (!S_ISDIR(stat_path.st_mode) && endsWith(file, "." + ext)) {
|
||||||
|
res.push_back(std::move(file));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
closedir(dir);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
inline int removeDir(const std::string &path) {
|
inline int removeDir(const std::string &path) {
|
||||||
return rmdir(path.c_str());
|
return rmdir(path.c_str());
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright (C) 2018-2020 Intel Corporation
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
@ -16,9 +16,13 @@ public:
|
|||||||
|
|
||||||
MOCK_METHOD3(LoadNetwork, InferenceEngine::ExecutableNetwork(
|
MOCK_METHOD3(LoadNetwork, InferenceEngine::ExecutableNetwork(
|
||||||
const InferenceEngine::CNNNetwork&, const std::string&, const std::map<std::string, std::string>&));
|
const InferenceEngine::CNNNetwork&, const std::string&, const std::map<std::string, std::string>&));
|
||||||
|
MOCK_METHOD3(LoadNetwork, InferenceEngine::ExecutableNetwork(
|
||||||
|
const InferenceEngine::CNNNetwork&, const InferenceEngine::RemoteContext::Ptr &, const std::map<std::string, std::string>&));
|
||||||
|
|
||||||
MOCK_METHOD3(ImportNetwork, InferenceEngine::ExecutableNetwork(
|
MOCK_METHOD3(ImportNetwork, InferenceEngine::ExecutableNetwork(
|
||||||
std::istream&, const std::string&, const std::map<std::string, std::string>&));
|
std::istream&, const std::string&, const std::map<std::string, std::string>&));
|
||||||
|
MOCK_METHOD3(ImportNetwork, InferenceEngine::ExecutableNetwork(
|
||||||
|
std::istream&, const InferenceEngine::RemoteContext::Ptr&, const std::map<std::string, std::string>&));
|
||||||
|
|
||||||
MOCK_QUALIFIED_METHOD3(QueryNetwork, const, InferenceEngine::QueryNetworkResult(
|
MOCK_QUALIFIED_METHOD3(QueryNetwork, const, InferenceEngine::QueryNetworkResult(
|
||||||
const InferenceEngine::CNNNetwork&, const std::string&, const std::map<std::string, std::string>&));
|
const InferenceEngine::CNNNetwork&, const std::string&, const std::map<std::string, std::string>&));
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright (C) 2018-2020 Intel Corporation
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
@ -22,6 +22,14 @@ void MockPlugin::SetConfig(const std::map<std::string, std::string>& config) {
|
|||||||
this->config = config;
|
this->config = config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Parameter MockPlugin::GetMetric(const std::string& name, const std::map<std::string, InferenceEngine::Parameter>& options) const {
|
||||||
|
if (_target) {
|
||||||
|
return _target->GetMetric(name, options);
|
||||||
|
} else {
|
||||||
|
THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ExecutableNetwork
|
ExecutableNetwork
|
||||||
MockPlugin::LoadNetwork(const CNNNetwork &network,
|
MockPlugin::LoadNetwork(const CNNNetwork &network,
|
||||||
const std::map<std::string, std::string> &config) {
|
const std::map<std::string, std::string> &config) {
|
||||||
@ -32,12 +40,62 @@ MockPlugin::LoadNetwork(const CNNNetwork &network,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
InferenceEngine::ExecutableNetworkInternal::Ptr
|
ExecutableNetwork
|
||||||
MockPlugin::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork& network,
|
MockPlugin::LoadNetwork(const CNNNetwork& network, const std::map<std::string, std::string>& config,
|
||||||
|
RemoteContext::Ptr context) {
|
||||||
|
if (_target) {
|
||||||
|
return _target->LoadNetwork(network, config, context);
|
||||||
|
} else {
|
||||||
|
THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ExecutableNetworkInternal::Ptr
|
||||||
|
MockPlugin::LoadExeNetworkImpl(const CNNNetwork& network,
|
||||||
const std::map<std::string, std::string>& config) {
|
const std::map<std::string, std::string>& config) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
InferenceEngine::ExecutableNetwork
|
||||||
|
MockPlugin::ImportNetworkImpl(std::istream& networkModel,
|
||||||
|
const std::map<std::string, std::string>& config) {
|
||||||
|
if (_target) {
|
||||||
|
return _target->ImportNetwork(networkModel, config);
|
||||||
|
} else {
|
||||||
|
THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
InferenceEngine::ExecutableNetwork
|
||||||
|
MockPlugin::ImportNetworkImpl(std::istream& networkModel,
|
||||||
|
const InferenceEngine::RemoteContext::Ptr& context,
|
||||||
|
const std::map<std::string, std::string>& config) {
|
||||||
|
if (_target) {
|
||||||
|
return _target->ImportNetwork(networkModel, context, config);
|
||||||
|
} else {
|
||||||
|
THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
InferenceEngine::RemoteContext::Ptr MockPlugin::GetDefaultContext(const InferenceEngine::ParamMap& params) {
|
||||||
|
if (_target) {
|
||||||
|
return _target->GetDefaultContext(params);
|
||||||
|
} else {
|
||||||
|
THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
InferenceEngine::QueryNetworkResult
|
||||||
|
MockPlugin::QueryNetwork(const InferenceEngine::CNNNetwork& network,
|
||||||
|
const std::map<std::string, std::string>& config) const {
|
||||||
|
if (_target) {
|
||||||
|
return _target->QueryNetwork(network, config);
|
||||||
|
} else {
|
||||||
|
THROW_IE_EXCEPTION_WITH_STATUS(NOT_IMPLEMENTED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
InferenceEngine::IInferencePlugin *__target = nullptr;
|
InferenceEngine::IInferencePlugin *__target = nullptr;
|
||||||
|
|
||||||
INFERENCE_PLUGIN_API(void) CreatePluginEngine(std::shared_ptr<InferenceEngine::IInferencePlugin>& plugin) {
|
INFERENCE_PLUGIN_API(void) CreatePluginEngine(std::shared_ptr<InferenceEngine::IInferencePlugin>& plugin) {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright (C) 2018-2020 Intel Corporation
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
@ -17,12 +17,34 @@ public:
|
|||||||
explicit MockPlugin(InferenceEngine::IInferencePlugin*target);
|
explicit MockPlugin(InferenceEngine::IInferencePlugin*target);
|
||||||
|
|
||||||
void SetConfig(const std::map<std::string, std::string>& config) override;
|
void SetConfig(const std::map<std::string, std::string>& config) override;
|
||||||
|
|
||||||
InferenceEngine::ExecutableNetwork
|
InferenceEngine::ExecutableNetwork
|
||||||
LoadNetwork(const InferenceEngine::CNNNetwork &network,
|
LoadNetwork(const InferenceEngine::CNNNetwork &network,
|
||||||
const std::map<std::string, std::string> &config) override;
|
const std::map<std::string, std::string> &config) override;
|
||||||
|
|
||||||
|
InferenceEngine::ExecutableNetwork
|
||||||
|
LoadNetwork(const InferenceEngine::CNNNetwork& network,
|
||||||
|
const std::map<std::string, std::string>& config,
|
||||||
|
InferenceEngine::RemoteContext::Ptr context) override;
|
||||||
|
|
||||||
InferenceEngine::ExecutableNetworkInternal::Ptr
|
InferenceEngine::ExecutableNetworkInternal::Ptr
|
||||||
LoadExeNetworkImpl(const InferenceEngine::CNNNetwork& network,
|
LoadExeNetworkImpl(const InferenceEngine::CNNNetwork& network,
|
||||||
const std::map<std::string, std::string>& config) override;
|
const std::map<std::string, std::string>& config) override;
|
||||||
|
|
||||||
|
InferenceEngine::ExecutableNetwork ImportNetworkImpl(std::istream& networkModel,
|
||||||
|
const std::map<std::string, std::string>& config) override;
|
||||||
|
|
||||||
|
InferenceEngine::ExecutableNetwork ImportNetworkImpl(std::istream& networkModel,
|
||||||
|
const InferenceEngine::RemoteContext::Ptr& context,
|
||||||
|
const std::map<std::string, std::string>& config) override;
|
||||||
|
|
||||||
|
InferenceEngine::Parameter GetMetric(const std::string& name,
|
||||||
|
const std::map<std::string, InferenceEngine::Parameter>& options) const override;
|
||||||
|
|
||||||
|
InferenceEngine::RemoteContext::Ptr GetDefaultContext(const InferenceEngine::ParamMap& params) override;
|
||||||
|
|
||||||
|
InferenceEngine::QueryNetworkResult QueryNetwork(const InferenceEngine::CNNNetwork& network,
|
||||||
|
const std::map<std::string, std::string>& config) const override;
|
||||||
|
|
||||||
std::map<std::string, std::string> config;
|
std::map<std::string, std::string> config;
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,388 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <fstream>
|
||||||
|
#include <thread>
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
|
#include "compilation_context.hpp"
|
||||||
|
#include "ngraph/function.hpp"
|
||||||
|
#include "ngraph/ops.hpp"
|
||||||
|
#include "ngraph/variant.hpp"
|
||||||
|
#include "ngraph/opsets/opset6.hpp"
|
||||||
|
#include "transformations/rt_info/dequantization_attribute.hpp"
|
||||||
|
#include "transformations/rt_info/fused_names_attribute.hpp"
|
||||||
|
#include "transformations/rt_info/primitives_priority_attribute.hpp"
|
||||||
|
#include "cpp/ie_cnn_network.h"
|
||||||
|
|
||||||
|
#include "common_test_utils/test_constants.hpp"
|
||||||
|
|
||||||
|
using namespace InferenceEngine;
|
||||||
|
using namespace ngraph;
|
||||||
|
using namespace ::testing;
|
||||||
|
using namespace std::chrono;
|
||||||
|
|
||||||
|
static std::string generateTestFilePrefix() {
|
||||||
|
// Generate unique file names based on test name, thread id and timestamp
|
||||||
|
// This allows execution of tests in parallel (stress mode)
|
||||||
|
auto testInfo = UnitTest::GetInstance()->current_test_info();
|
||||||
|
std::string testName = testInfo->test_case_name();
|
||||||
|
testName += testInfo->name();
|
||||||
|
testName = std::to_string(std::hash<std::string>()(testName));
|
||||||
|
std::stringstream ss;
|
||||||
|
auto ts = duration_cast<microseconds>(high_resolution_clock::now().time_since_epoch());
|
||||||
|
ss << testName << "_" << std::this_thread::get_id() << "_" << ts.count();
|
||||||
|
testName = ss.str();
|
||||||
|
return testName;
|
||||||
|
}
|
||||||
|
|
||||||
|
class FileGuard {
|
||||||
|
std::string m_fileName;
|
||||||
|
public:
|
||||||
|
FileGuard(const std::string& name): m_fileName(name) {}
|
||||||
|
~FileGuard() { std::remove(m_fileName.c_str()); }
|
||||||
|
};
|
||||||
|
|
||||||
|
class NetworkContext_CalcFileInfoTests : public Test {
|
||||||
|
public:
|
||||||
|
std::string m_fileName = "test.blob";
|
||||||
|
|
||||||
|
static void createFile(const std::string& fileName, std::size_t size = 1) {
|
||||||
|
std::ofstream str(fileName, std::ios::binary);
|
||||||
|
if (!str.good()) {
|
||||||
|
GTEST_SKIP();
|
||||||
|
}
|
||||||
|
for (std::size_t i = 0; i < size; i++)
|
||||||
|
str.put('a');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets up the test fixture.
|
||||||
|
void SetUp() override {
|
||||||
|
auto testName = generateTestFilePrefix();
|
||||||
|
m_fileName = testName + m_fileName;
|
||||||
|
createFile(m_fileName);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tears down the test fixture.
|
||||||
|
void TearDown() override {
|
||||||
|
std::remove(m_fileName.c_str());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(NetworkContext_CalcFileInfoTests, NoFile) {
|
||||||
|
ASSERT_NE(NetworkCompilationContext::calculateFileInfo("notexisting.abc"),
|
||||||
|
NetworkCompilationContext::calculateFileInfo("notexisting2.abc"));
|
||||||
|
|
||||||
|
std::string fileName(100, 'a');
|
||||||
|
std::string fileName2(fileName);
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::calculateFileInfo(fileName),
|
||||||
|
NetworkCompilationContext::calculateFileInfo(fileName2));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NetworkContext_CalcFileInfoTests, ExistingFile) {
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::calculateFileInfo(m_fileName),
|
||||||
|
NetworkCompilationContext::calculateFileInfo(m_fileName));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NetworkContext_CalcFileInfoTests, ExistingDiffFiles) {
|
||||||
|
auto hash1 = NetworkCompilationContext::calculateFileInfo(m_fileName);
|
||||||
|
std::string newName = m_fileName + "2";
|
||||||
|
std::rename(m_fileName.c_str(), newName.c_str());
|
||||||
|
m_fileName = std::move(newName);
|
||||||
|
auto hash2 = NetworkCompilationContext::calculateFileInfo(m_fileName);
|
||||||
|
ASSERT_NE(hash1, hash2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NetworkContext_CalcFileInfoTests, ExistingFile_sameAbsPath) {
|
||||||
|
std::string file1 = m_fileName;
|
||||||
|
std::string file2 = std::string(".") + CommonTestUtils::FileSeparator + m_fileName;
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::calculateFileInfo(file1),
|
||||||
|
NetworkCompilationContext::calculateFileInfo(file2)) <<
|
||||||
|
"Hash of [" << file1 << "] is not equal to hash of [" << file2 << "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NetworkContext_CalcFileInfoTests, DateModified) {
|
||||||
|
auto info1 = NetworkCompilationContext::calculateFileInfo(m_fileName);
|
||||||
|
std::this_thread::sleep_for(std::chrono::seconds(2));
|
||||||
|
createFile(m_fileName);
|
||||||
|
auto info2 = NetworkCompilationContext::calculateFileInfo(m_fileName);
|
||||||
|
ASSERT_NE(info1, info2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(NetworkContext_CalcFileInfoTests, SizeModified) {
|
||||||
|
createFile(m_fileName, 1);
|
||||||
|
auto info1 = NetworkCompilationContext::calculateFileInfo(m_fileName);
|
||||||
|
createFile(m_fileName, 2);
|
||||||
|
auto info2 = NetworkCompilationContext::calculateFileInfo(m_fileName);
|
||||||
|
ASSERT_NE(info1, info2);
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
static std::shared_ptr<ngraph::Function> create_simple_function() {
|
||||||
|
// This example is taken from docs, shows how to create ngraph::Function
|
||||||
|
//
|
||||||
|
// Parameter--->Multiply--->Add--->Result
|
||||||
|
// Constant---' /
|
||||||
|
// Constant---'
|
||||||
|
|
||||||
|
// Create opset6::Parameter operation with static shape
|
||||||
|
auto data = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::i8, ngraph::Shape{3, 1, 2});
|
||||||
|
data->set_friendly_name("Parameter");
|
||||||
|
|
||||||
|
auto mul_constant = ngraph::opset6::Constant::create(ngraph::element::i8, ngraph::Shape{1}, {3});
|
||||||
|
mul_constant->set_friendly_name("mul_constant");
|
||||||
|
auto mul = std::make_shared<ngraph::opset6::Multiply>(data, mul_constant);
|
||||||
|
mul->set_friendly_name("mul");
|
||||||
|
|
||||||
|
auto add_constant = ngraph::opset6::Constant::create(ngraph::element::i8, ngraph::Shape{1}, {2});
|
||||||
|
add_constant->set_friendly_name("add_constant");
|
||||||
|
auto add = std::make_shared<ngraph::opset6::Add>(mul, add_constant);
|
||||||
|
add->set_friendly_name("add");
|
||||||
|
|
||||||
|
// Create opset3::Result operation
|
||||||
|
auto res = std::make_shared<ngraph::opset6::Result>(add);
|
||||||
|
res->set_friendly_name("res");
|
||||||
|
|
||||||
|
// Create nGraph function
|
||||||
|
auto func = std::make_shared<ngraph::Function>(ngraph::ResultVector{res}, ngraph::ParameterVector{data});
|
||||||
|
func->set_friendly_name("function");
|
||||||
|
return func;
|
||||||
|
}
|
||||||
|
|
||||||
|
static CNNNetwork createNetwork() {
|
||||||
|
CNNNetwork res(create_simple_function());
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void checkCustomRt(std::function<void(Node::RTMap&)> emptyCb,
|
||||||
|
std::function<void(Node::RTMap&, const std::string& name)> nameCb) {
|
||||||
|
auto net1 = createNetwork();
|
||||||
|
auto net2 = createNetwork();
|
||||||
|
auto & op1 = net1.getFunction()->get_ops().front()->get_rt_info();
|
||||||
|
auto & op2 = net2.getFunction()->get_ops().front()->get_rt_info();
|
||||||
|
|
||||||
|
emptyCb(op2);
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
|
||||||
|
emptyCb(op1);
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(net1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
|
||||||
|
nameCb(op1, "test");
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
|
||||||
|
nameCb(op2, "test");
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(net1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
|
||||||
|
nameCb(op1, "test2");
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NetworkContext_CNNNetwork, HashOfSame) {
|
||||||
|
auto net1 = createNetwork();
|
||||||
|
auto net2 = createNetwork();
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(net1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NetworkContext_CNNNetwork, HashWithConfig) {
|
||||||
|
auto net1 = createNetwork();
|
||||||
|
auto net2 = createNetwork();
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash(net1, {{"key", "value"}}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(net1, {{"key", "value"}}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {{"key", "value"}}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriority) {
|
||||||
|
auto net1 = createNetwork();
|
||||||
|
auto net2 = createNetwork();
|
||||||
|
auto net3 = createNetwork();
|
||||||
|
auto & op2 = net2.getFunction()->get_ops().front()->get_rt_info();
|
||||||
|
op2["PrimitivesPriority"] = std::make_shared<ngraph::VariantWrapper<std::string> > ("testPriority");
|
||||||
|
|
||||||
|
auto & op3 = net3.getFunction()->get_ops().front()->get_rt_info();
|
||||||
|
op3["PrimitivesPriority"] = std::make_shared<ngraph::VariantWrapper<std::string> > ("testPriority");
|
||||||
|
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(net2, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net3, {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NetworkContext_CNNNetwork, HashWithDequantization) {
|
||||||
|
auto setDeqEmpty = [&](Node::RTMap& rtInfo) {
|
||||||
|
rtInfo[VariantWrapper<DequantizationAttr>::type_info.name] =
|
||||||
|
std::make_shared<VariantWrapper<DequantizationAttr>>(DequantizationAttr());
|
||||||
|
};
|
||||||
|
auto setDeq = [&](Node::RTMap& rtInfo, const std::string& name) {
|
||||||
|
rtInfo[VariantWrapper<DequantizationAttr>::type_info.name] =
|
||||||
|
std::make_shared<VariantWrapper<DequantizationAttr>>(DequantizationAttr(name));
|
||||||
|
};
|
||||||
|
checkCustomRt(setDeqEmpty, setDeq);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NetworkContext_CNNNetwork, HashWithFusedNames) {
|
||||||
|
auto setFusedEmpty = [&](Node::RTMap& rtInfo) {
|
||||||
|
rtInfo[VariantWrapper<FusedNames>::type_info.name] =
|
||||||
|
std::make_shared<VariantWrapper<FusedNames>>(FusedNames());
|
||||||
|
};
|
||||||
|
auto setFused = [&](Node::RTMap& rtInfo, const std::string& name) {
|
||||||
|
rtInfo[VariantWrapper<FusedNames>::type_info.name] =
|
||||||
|
std::make_shared<VariantWrapper<FusedNames>>(FusedNames(name));
|
||||||
|
};
|
||||||
|
checkCustomRt(setFusedEmpty, setFused);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriorityType) {
|
||||||
|
auto setPrimEmpty = [&](Node::RTMap& rtInfo) {
|
||||||
|
rtInfo[VariantWrapper<PrimitivesPriority>::type_info.name] =
|
||||||
|
std::make_shared<VariantWrapper<PrimitivesPriority>>(PrimitivesPriority());
|
||||||
|
};
|
||||||
|
auto setPrim = [&](Node::RTMap& rtInfo, const std::string& name) {
|
||||||
|
rtInfo[VariantWrapper<PrimitivesPriority>::type_info.name] =
|
||||||
|
std::make_shared<VariantWrapper<PrimitivesPriority>>(PrimitivesPriority(name));
|
||||||
|
};
|
||||||
|
checkCustomRt(setPrimEmpty, setPrim);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NetworkContext_CNNNetwork, HashWithAffinity) {
|
||||||
|
auto net1 = createNetwork();
|
||||||
|
auto net2 = createNetwork();
|
||||||
|
auto net3 = createNetwork();
|
||||||
|
auto & op2 = net2.getFunction()->get_ops().front()->get_rt_info();
|
||||||
|
op2["affinity"] = std::make_shared<ngraph::VariantWrapper<std::string>>("testAffinity");
|
||||||
|
|
||||||
|
auto & op3 = net3.getFunction()->get_ops().front()->get_rt_info();
|
||||||
|
op3["affinity"] = std::make_shared<ngraph::VariantWrapper<std::string>>("testAffinity");
|
||||||
|
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(net2, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net3, {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NetworkContext_CNNNetwork, HashWithFutureRt_string) {
|
||||||
|
auto net1 = createNetwork();
|
||||||
|
auto net2 = createNetwork();
|
||||||
|
auto net3 = createNetwork();
|
||||||
|
|
||||||
|
auto & op1 = net1.getFunction()->get_ops().front()->get_rt_info();
|
||||||
|
op1["someFutureKey"] = std::make_shared<ngraph::VariantWrapper<std::string>>("hello");
|
||||||
|
|
||||||
|
auto & op2 = net2.getFunction()->get_ops().front()->get_rt_info();
|
||||||
|
op2["someFutureKey"] = std::make_shared<ngraph::VariantWrapper<std::string>>("hello");
|
||||||
|
|
||||||
|
auto & op3 = net3.getFunction()->get_ops().front()->get_rt_info();
|
||||||
|
op3["someFutureKey"] = std::make_shared<ngraph::VariantWrapper<std::string>>("olleh");
|
||||||
|
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(net1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash(net2, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net3, {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NetworkContext_CNNNetwork, HashWithFutureRt_int64) {
|
||||||
|
auto net1 = createNetwork();
|
||||||
|
auto net2 = createNetwork();
|
||||||
|
auto net3 = createNetwork();
|
||||||
|
|
||||||
|
auto & op1 = net1.getFunction()->get_ops().front()->get_rt_info();
|
||||||
|
op1["someFutureKey"] = std::make_shared<ngraph::VariantWrapper<int64_t>>(42);
|
||||||
|
|
||||||
|
auto & op2 = net2.getFunction()->get_ops().front()->get_rt_info();
|
||||||
|
op2["someFutureKey"] = std::make_shared<ngraph::VariantWrapper<int64_t>>(42);
|
||||||
|
|
||||||
|
auto & op3 = net3.getFunction()->get_ops().front()->get_rt_info();
|
||||||
|
op3["someFutureKey"] = std::make_shared<ngraph::VariantWrapper<int64_t>>(43);
|
||||||
|
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(net1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash(net2, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net3, {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NetworkContext_CNNNetwork, HashWithDifferentResults) {
|
||||||
|
auto net1 = createNetwork();
|
||||||
|
auto net2 = createNetwork();
|
||||||
|
net2.getFunction()->remove_result(net2.getFunction()->get_results().front());
|
||||||
|
auto net3 = createNetwork();
|
||||||
|
net3.getFunction()->remove_result(net3.getFunction()->get_results().front());
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(net2, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net3, {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NetworkContext_CNNNetwork, HashWithDifferentMeanValues) {
|
||||||
|
auto updatePreprocess = [&](CNNNetwork& cnnNet) {
|
||||||
|
auto &preProcess = cnnNet.getInputsInfo().begin()->second->getPreProcess();
|
||||||
|
preProcess.init(3);
|
||||||
|
preProcess[0]->stdScale = 2;
|
||||||
|
preProcess[1]->stdScale = 3;
|
||||||
|
preProcess[2]->stdScale = 4;
|
||||||
|
preProcess[0]->meanValue = 0;
|
||||||
|
preProcess[1]->meanValue = 1;
|
||||||
|
preProcess[2]->meanValue = 2;
|
||||||
|
preProcess.setVariant(InferenceEngine::MEAN_VALUE);
|
||||||
|
};
|
||||||
|
auto net1 = createNetwork();
|
||||||
|
auto net2 = createNetwork();
|
||||||
|
updatePreprocess(net2);
|
||||||
|
auto net3 = createNetwork();
|
||||||
|
updatePreprocess(net3);
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash(net1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net2, {}));
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(net2, {}),
|
||||||
|
NetworkCompilationContext::computeHash(net3, {}));
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////
|
||||||
|
|
||||||
|
TEST(NetworkContext_ModelName, HashOfSame) {
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash("model1", {}),
|
||||||
|
NetworkCompilationContext::computeHash("model1", {}));
|
||||||
|
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash("model1", {}),
|
||||||
|
NetworkCompilationContext::computeHash("model2", {}));
|
||||||
|
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash("model1", {{"key", "value"}}),
|
||||||
|
NetworkCompilationContext::computeHash("model1", {}));
|
||||||
|
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash("model1", {{"key", "value"}}),
|
||||||
|
NetworkCompilationContext::computeHash("model1", {{"key", "value"}}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NetworkContext_ModelName, HashOfExistingFile) {
|
||||||
|
auto file1 = generateTestFilePrefix() + ".xml";
|
||||||
|
auto file2 = std::string(".") + CommonTestUtils::FileSeparator + file1;
|
||||||
|
|
||||||
|
FileGuard guard(file1);
|
||||||
|
{
|
||||||
|
std::ofstream os(file1);
|
||||||
|
os << "test";
|
||||||
|
}
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(file1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(file1, {}));
|
||||||
|
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(file1, {}),
|
||||||
|
NetworkCompilationContext::computeHash(file2, {}));
|
||||||
|
|
||||||
|
ASSERT_NE(NetworkCompilationContext::computeHash(file1, {{"key", "value"}}),
|
||||||
|
NetworkCompilationContext::computeHash(file2, {}));
|
||||||
|
|
||||||
|
ASSERT_EQ(NetworkCompilationContext::computeHash(file1, {{"key", "value"}}),
|
||||||
|
NetworkCompilationContext::computeHash(file2, {{"key", "value"}}));
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user