Revert "Revert "Revert query_model logic""
This reverts commit 0df850838a.
This commit is contained in:
@@ -45,11 +45,71 @@ static int32_t as_int32_t(T v) {
|
||||
|
||||
namespace {
|
||||
|
||||
uint64_t calculate_td(const InferenceEngine::TensorDesc& td, uint64_t _seed) {
|
||||
uint64_t seed = _seed;
|
||||
uint64_t compute_model_hash(const std::shared_ptr<const ov::Model>& model, const ov::AnyMap& compileOptions) {
|
||||
OPENVINO_ASSERT(model);
|
||||
|
||||
uint64_t seed = 0;
|
||||
// 1. Calculate hash on function
|
||||
ov::pass::Manager m;
|
||||
m.register_pass<ov::pass::FixRtInfo>();
|
||||
m.register_pass<ov::pass::Hash>(seed);
|
||||
m.run_passes(std::const_pointer_cast<ov::Model>(model));
|
||||
|
||||
// 2. Compute hash on serialized data and options
|
||||
for (const auto& kvp : compileOptions) {
|
||||
seed = ov::hash_combine(seed, kvp.first + kvp.second.as<std::string>());
|
||||
}
|
||||
|
||||
// 3. Add runtime information which may not be serialized
|
||||
for (const auto& op : model->get_ordered_ops()) {
|
||||
const auto& rt = op->get_rt_info();
|
||||
for (const auto& rtMapData : rt) {
|
||||
seed = ov::hash_combine(seed, rtMapData.first);
|
||||
std::stringstream strm;
|
||||
rtMapData.second.print(strm);
|
||||
seed = ov::hash_combine(seed, strm.str());
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Legacy part if CNNNetwork is used with new Plugin API
|
||||
for (auto&& input : model->inputs()) {
|
||||
auto& rt_info = input.get_rt_info();
|
||||
|
||||
auto it = rt_info.find("ie_legacy_td");
|
||||
if (it != rt_info.end()) {
|
||||
auto td = it->second.as<InferenceEngine::TensorDesc>();
|
||||
seed = ov::hash_combine(seed, ov::as_int32_t(td.getPrecision()));
|
||||
seed = ov::hash_combine(seed, ov::as_int32_t(td.getLayout()));
|
||||
}
|
||||
|
||||
it = rt_info.find("ie_legacy_preproc");
|
||||
if (it != rt_info.end()) {
|
||||
auto preproc = it->second.as<InferenceEngine::PreProcessInfo>();
|
||||
|
||||
seed = ov::hash_combine(seed, ov::as_int32_t(preproc.getMeanVariant()));
|
||||
|
||||
if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_VALUE) {
|
||||
seed = ov::hash_combine(seed, preproc.getNumberOfChannels());
|
||||
for (size_t c = 0; c < preproc.getNumberOfChannels(); ++c) {
|
||||
const InferenceEngine::PreProcessChannel::Ptr& channelInfo = preproc[c];
|
||||
seed = ov::hash_combine(seed, channelInfo->stdScale);
|
||||
seed = ov::hash_combine(seed, channelInfo->meanValue);
|
||||
}
|
||||
} else if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_IMAGE) {
|
||||
// TODO: think if we need to compute hash for mean image if it exists
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto&& output : model->outputs()) {
|
||||
auto& rt_info = output.get_rt_info();
|
||||
auto it = rt_info.find("ie_legacy_td");
|
||||
if (it != rt_info.end()) {
|
||||
auto td = it->second.as<InferenceEngine::TensorDesc>();
|
||||
seed = ov::hash_combine(seed, ov::as_int32_t(td.getPrecision()));
|
||||
seed = ov::hash_combine(seed, ov::as_int32_t(td.getLayout()));
|
||||
}
|
||||
}
|
||||
|
||||
seed = ov::hash_combine(seed, ov::as_int32_t(td.getPrecision()));
|
||||
seed = ov::hash_combine(seed, ov::as_int32_t(td.getLayout()));
|
||||
return seed;
|
||||
}
|
||||
|
||||
@@ -83,64 +143,47 @@ std::string NetworkCompilationContext::compute_hash(const std::shared_ptr<const
|
||||
const ov::AnyMap& compileOptions) {
|
||||
OV_ITT_SCOPE(FIRST_INFERENCE, ov::itt::domains::IE_RT, "NetworkCompilationContext::compute_hash - Model");
|
||||
|
||||
OPENVINO_ASSERT(model);
|
||||
return std::to_string(compute_model_hash(model, compileOptions));
|
||||
}
|
||||
|
||||
std::string NetworkCompilationContext::compute_hash(const InferenceEngine::CNNNetwork& network,
|
||||
const ov::AnyMap& compileOptions) {
|
||||
OV_ITT_SCOPE(FIRST_INFERENCE, ov::itt::domains::IE_RT, "NetworkCompilationContext::compute_hash - CNN");
|
||||
|
||||
OPENVINO_ASSERT(network.getFunction());
|
||||
|
||||
uint64_t seed = 0;
|
||||
// 1. Calculate hash on function
|
||||
ov::pass::Manager m;
|
||||
m.register_pass<ov::pass::FixRtInfo>();
|
||||
m.register_pass<ov::pass::Hash>(seed);
|
||||
m.run_passes(std::const_pointer_cast<ov::Model>(model));
|
||||
|
||||
// 2. Compute hash on serialized data and options
|
||||
for (const auto& kvp : compileOptions) {
|
||||
seed = ov::hash_combine(seed, kvp.first + kvp.second.as<std::string>());
|
||||
}
|
||||
|
||||
// 3. Add runtime information which may not be serialized
|
||||
for (const auto& op : model->get_ordered_ops()) {
|
||||
const auto& rt = op->get_rt_info();
|
||||
for (const auto& rtMapData : rt) {
|
||||
seed = ov::hash_combine(seed, rtMapData.first);
|
||||
std::stringstream strm;
|
||||
rtMapData.second.print(strm);
|
||||
seed = ov::hash_combine(seed, strm.str());
|
||||
}
|
||||
}
|
||||
seed = compute_model_hash(network.getFunction(), compileOptions);
|
||||
|
||||
// 4. Legacy part if CNNNetwork is used with new Plugin API
|
||||
for (auto&& input : model->inputs()) {
|
||||
auto& rt_info = input.get_rt_info();
|
||||
// 4. Add inputs info
|
||||
for (const auto& input : network.getInputsInfo()) {
|
||||
InferenceEngine::InputInfo::Ptr info = input.second;
|
||||
seed = hash_combine(seed, as_int32_t(info->getPrecision()));
|
||||
seed = hash_combine(seed, as_int32_t(info->getLayout()));
|
||||
|
||||
auto it = rt_info.find("ie_legacy_td");
|
||||
if (it != rt_info.end()) {
|
||||
seed = calculate_td(it->second.as<InferenceEngine::TensorDesc>(), seed);
|
||||
}
|
||||
const InferenceEngine::PreProcessInfo& preproc = info->getPreProcess();
|
||||
seed = hash_combine(seed, as_int32_t(preproc.getMeanVariant()));
|
||||
|
||||
it = rt_info.find("ie_legacy_preproc");
|
||||
if (it != rt_info.end()) {
|
||||
auto preproc = it->second.as<InferenceEngine::PreProcessInfo>();
|
||||
|
||||
seed = ov::hash_combine(seed, ov::as_int32_t(preproc.getMeanVariant()));
|
||||
|
||||
if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_VALUE) {
|
||||
seed = ov::hash_combine(seed, preproc.getNumberOfChannels());
|
||||
for (size_t c = 0; c < preproc.getNumberOfChannels(); ++c) {
|
||||
const InferenceEngine::PreProcessChannel::Ptr& channelInfo = preproc[c];
|
||||
seed = ov::hash_combine(seed, channelInfo->stdScale);
|
||||
seed = ov::hash_combine(seed, channelInfo->meanValue);
|
||||
}
|
||||
} else if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_IMAGE) {
|
||||
// TODO: think if we need to compute hash for mean image if it exists
|
||||
if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_VALUE) {
|
||||
seed = hash_combine(seed, preproc.getNumberOfChannels());
|
||||
for (size_t c = 0; c < preproc.getNumberOfChannels(); ++c) {
|
||||
const InferenceEngine::PreProcessChannel::Ptr& channelInfo = preproc[c];
|
||||
seed = hash_combine(seed, channelInfo->stdScale);
|
||||
seed = hash_combine(seed, channelInfo->meanValue);
|
||||
}
|
||||
} else if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_IMAGE) {
|
||||
// TODO: think if we need to compute hash for mean image if it exists
|
||||
}
|
||||
}
|
||||
for (auto&& output : model->outputs()) {
|
||||
auto& rt_info = output.get_rt_info();
|
||||
auto it = rt_info.find("ie_legacy_td");
|
||||
if (it != rt_info.end()) {
|
||||
seed = calculate_td(it->second.as<InferenceEngine::TensorDesc>(), seed);
|
||||
}
|
||||
|
||||
// 5. Add outputs info
|
||||
for (const auto& output : network.getOutputsInfo()) {
|
||||
InferenceEngine::DataPtr info = output.second;
|
||||
seed = hash_combine(seed, as_int32_t(info->getPrecision()));
|
||||
seed = hash_combine(seed, as_int32_t(info->getLayout()));
|
||||
}
|
||||
|
||||
return std::to_string(seed);
|
||||
|
||||
@@ -26,6 +26,8 @@ class Model;
|
||||
struct NetworkCompilationContext final {
|
||||
static std::string calculate_file_info(const std::string& filePath);
|
||||
|
||||
static std::string compute_hash(const InferenceEngine::CNNNetwork& network, const ov::AnyMap& compileOptions);
|
||||
|
||||
static std::string compute_hash(const std::shared_ptr<const ov::Model>& model, const ov::AnyMap& compileOptions);
|
||||
|
||||
static std::string compute_hash(const std::string& modelName, const ov::AnyMap& compileOptions);
|
||||
|
||||
@@ -430,7 +430,9 @@ ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> ov::CoreImpl::compile_mod
|
||||
res = compile_model_impl(cnnNetwork.getFunction(), plugin, parsed._config, {}, cacheContent);
|
||||
}
|
||||
} else if (cacheManager) {
|
||||
res = plugin.compile_model(model_path, parsed._config);
|
||||
auto cnnNetwork = ReadNetwork(model_path, std::string());
|
||||
// TODO: 'validation' for dynamic API doesn't work for this case, as it affects a lot of plugin API
|
||||
res = compile_model(plugin, cnnNetwork.getFunction(), {}, parsed._config);
|
||||
} else {
|
||||
auto cnnNetwork = ReadNetwork(model_path, std::string());
|
||||
res = compile_model_impl(cnnNetwork.getFunction(), plugin, parsed._config, {}, cacheContent);
|
||||
@@ -483,37 +485,7 @@ ov::SupportedOpsMap ov::CoreImpl::query_model(const std::shared_ptr<const ov::Mo
|
||||
const ov::AnyMap& config) const {
|
||||
OV_ITT_SCOPED_TASK(ov::itt::domains::IE, "Core::query_model");
|
||||
auto parsed = parseDeviceNameIntoConfig(device_name, config);
|
||||
auto ret = get_plugin(parsed._deviceName).query_model(model, parsed._config);
|
||||
auto specialized_function = model->clone();
|
||||
|
||||
std::string defDevice = ret.begin()->second;
|
||||
ngraph::pass::ConstantFolding().run_on_model(specialized_function);
|
||||
std::unordered_set<std::string> opNames;
|
||||
|
||||
for (const auto& op : specialized_function->get_ops())
|
||||
opNames.emplace(op->get_friendly_name());
|
||||
|
||||
for (const auto& op : model->get_ops()) {
|
||||
if (opNames.find(op->get_friendly_name()) == opNames.end()) {
|
||||
ret[op->get_friendly_name()] = defDevice;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& op : model->get_ops()) {
|
||||
if (!ret.count(op->get_friendly_name()) && std::dynamic_pointer_cast<ngraph::op::Constant>(op)) {
|
||||
bool are_all_users_supported = true;
|
||||
for (const auto& user : op->output(0).get_target_inputs()) {
|
||||
if (!ret.count(user.get_node()->get_friendly_name())) {
|
||||
are_all_users_supported = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (are_all_users_supported) {
|
||||
ret[op->get_friendly_name()] = defDevice;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
return get_plugin(parsed._deviceName).query_model(model, parsed._config);
|
||||
}
|
||||
|
||||
std::vector<std::string> ov::CoreImpl::get_available_devices() const {
|
||||
|
||||
@@ -219,7 +219,9 @@ private:
|
||||
const InferenceEngine::CNNNetwork& model,
|
||||
ov::Plugin& plugin,
|
||||
const std::map<std::string, std::string>& parsedConfig,
|
||||
const InferenceEngine::RemoteContext::Ptr& context);
|
||||
const InferenceEngine::RemoteContext::Ptr& context,
|
||||
const CacheContent& cacheContent,
|
||||
bool forceDisableCache = false);
|
||||
|
||||
public:
|
||||
CoreImpl(bool _newAPI);
|
||||
|
||||
@@ -25,7 +25,9 @@ ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> ov::CoreImpl::LoadNetwork
|
||||
const InferenceEngine::CNNNetwork& network,
|
||||
ov::Plugin& plugin,
|
||||
const std::map<std::string, std::string>& parsedConfig,
|
||||
const InferenceEngine::RemoteContext::Ptr& context) {
|
||||
const InferenceEngine::RemoteContext::Ptr& context,
|
||||
const CacheContent& cacheContent,
|
||||
bool forceDisableCache) {
|
||||
OV_ITT_SCOPED_TASK(ov::itt::domains::IE, "CoreImpl::compile_model_impl");
|
||||
ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> execNetwork;
|
||||
auto wrapper = std::dynamic_pointer_cast<InferenceEngine::IPluginWrapper>(plugin.m_ptr);
|
||||
@@ -34,6 +36,21 @@ ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> ov::CoreImpl::LoadNetwork
|
||||
execNetwork = {context ? old_plugin->LoadNetwork(network, parsedConfig, context)
|
||||
: old_plugin->LoadNetwork(network, parsedConfig),
|
||||
plugin.m_so};
|
||||
if (!forceDisableCache && cacheContent.cacheManager && device_supports_import_export(plugin)) {
|
||||
try {
|
||||
// need to export network for further import from "cache"
|
||||
OV_ITT_SCOPE(FIRST_INFERENCE, InferenceEngine::itt::domains::IE_LT, "Core::LoadNetwork::Export");
|
||||
cacheContent.cacheManager->writeCacheEntry(cacheContent.blobId, [&](std::ostream& networkStream) {
|
||||
networkStream << ov::CompiledBlobHeader(
|
||||
InferenceEngine::GetInferenceEngineVersion()->buildNumber,
|
||||
ov::NetworkCompilationContext::calculate_file_info(cacheContent.modelPath));
|
||||
execNetwork->Export(networkStream);
|
||||
});
|
||||
} catch (...) {
|
||||
cacheContent.cacheManager->removeCacheEntry(cacheContent.blobId);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
return execNetwork;
|
||||
}
|
||||
|
||||
@@ -72,23 +89,72 @@ ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> ov::CoreImpl::LoadNetwork
|
||||
|
||||
auto plugin = get_plugin(parsed._deviceName);
|
||||
|
||||
auto res = LoadNetworkImpl(network, plugin, parsed._config, context);
|
||||
ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> res;
|
||||
auto conf = ov::any_copy(parsed._config);
|
||||
auto cacheManager =
|
||||
coreConfig.get_cache_config_for_device(parsed._deviceName, device_supports_cache_dir(plugin), conf)
|
||||
._cacheManager;
|
||||
auto cacheContent = CacheContent{cacheManager};
|
||||
if (cacheManager && device_supports_import_export(plugin)) {
|
||||
cacheContent.blobId = ov::NetworkCompilationContext::compute_hash(
|
||||
network,
|
||||
create_compile_config(plugin, parsed._deviceName, ov::any_copy(parsed._config)));
|
||||
bool loadedFromCache = false;
|
||||
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
|
||||
res = load_model_from_cache(cacheContent, plugin, conf, {context, {}}, loadedFromCache);
|
||||
if (!loadedFromCache) {
|
||||
res = LoadNetworkImpl(network, plugin, parsed._config, context, cacheContent);
|
||||
} else {
|
||||
// Temporary workaround until all plugins support caching of original model inputs
|
||||
InferenceEngine::SetExeNetworkInfo(res._ptr, network.getFunction(), isNewAPI());
|
||||
}
|
||||
} else {
|
||||
res = LoadNetworkImpl(network, plugin, parsed._config, context, cacheContent);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
InferenceEngine::SoExecutableNetworkInternal ov::CoreImpl::LoadNetwork(
|
||||
const InferenceEngine::CNNNetwork& network,
|
||||
const std::string& deviceName,
|
||||
const std::string& deviceNameOrig,
|
||||
const std::map<std::string, std::string>& config) {
|
||||
OV_ITT_SCOPE(FIRST_INFERENCE, InferenceEngine::itt::domains::IE_LT, "Core::LoadNetwork::CNN");
|
||||
if (network.getFunction()) {
|
||||
auto compiled_model =
|
||||
compile_model(ov::legacy_convert::convert_model(network, isNewAPI()), deviceName, any_copy(config));
|
||||
compile_model(ov::legacy_convert::convert_model(network, isNewAPI()), deviceNameOrig, any_copy(config));
|
||||
return {compiled_model._ptr, compiled_model._so};
|
||||
}
|
||||
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
|
||||
std::string deviceName = deviceNameOrig;
|
||||
std::map<std::string, std::string> config_with_batch = config;
|
||||
bool forceDisableCache = config_with_batch.count(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE)) > 0;
|
||||
auto parsed = parseDeviceNameIntoConfig(deviceName, config_with_batch);
|
||||
if (forceDisableCache) {
|
||||
// remove this config key from parsed as plugins can throw unsupported exception
|
||||
parsed._config.erase(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE));
|
||||
}
|
||||
auto plugin = get_plugin(parsed._deviceName);
|
||||
auto res = LoadNetworkImpl(network, plugin, parsed._config, nullptr);
|
||||
ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> res;
|
||||
auto conf = ov::any_copy(parsed._config);
|
||||
auto cacheManager =
|
||||
coreConfig.get_cache_config_for_device(parsed._deviceName, device_supports_cache_dir(plugin), conf)
|
||||
._cacheManager;
|
||||
auto cacheContent = CacheContent{cacheManager};
|
||||
if (!forceDisableCache && cacheManager && device_supports_import_export(plugin)) {
|
||||
cacheContent.blobId = ov::NetworkCompilationContext::compute_hash(
|
||||
network,
|
||||
create_compile_config(plugin, parsed._deviceName, ov::any_copy(parsed._config)));
|
||||
bool loadedFromCache = false;
|
||||
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
|
||||
res = load_model_from_cache(cacheContent, plugin, conf, {}, loadedFromCache);
|
||||
if (!loadedFromCache) {
|
||||
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, cacheContent, forceDisableCache);
|
||||
} else {
|
||||
// Temporary workaround until all plugins support caching of original model inputs
|
||||
InferenceEngine::SetExeNetworkInfo(res._ptr, network.getFunction(), isNewAPI());
|
||||
}
|
||||
} else {
|
||||
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, cacheContent, forceDisableCache);
|
||||
}
|
||||
return {res._ptr, res._so};
|
||||
}
|
||||
|
||||
@@ -98,9 +164,54 @@ InferenceEngine::SoExecutableNetworkInternal ov::CoreImpl::LoadNetwork(
|
||||
const std::map<std::string, std::string>& config,
|
||||
const std::function<void(const InferenceEngine::CNNNetwork&)>& val) {
|
||||
OV_ITT_SCOPE(FIRST_INFERENCE, ie::itt::domains::IE_LT, "Core::LoadNetwork::Path");
|
||||
|
||||
auto compiled_model = compile_model(modelPath, deviceName, any_copy(config));
|
||||
return {compiled_model._ptr, compiled_model._so};
|
||||
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
|
||||
auto plugin = get_plugin(parsed._deviceName);
|
||||
ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> res;
|
||||
auto conf = any_copy(parsed._config);
|
||||
auto cacheManager =
|
||||
coreConfig.get_cache_config_for_device(parsed._deviceName, device_supports_cache_dir(plugin), conf)
|
||||
._cacheManager;
|
||||
auto cacheContent = CacheContent{cacheManager, modelPath};
|
||||
if (cacheManager && device_supports_import_export(plugin)) {
|
||||
bool loadedFromCache = false;
|
||||
cacheContent.blobId =
|
||||
ov::NetworkCompilationContext::compute_hash(modelPath,
|
||||
create_compile_config(plugin, parsed._deviceName, conf));
|
||||
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
|
||||
res = load_model_from_cache(cacheContent, plugin, conf, {}, loadedFromCache);
|
||||
if (!loadedFromCache) {
|
||||
auto cnnNetwork = ReadNetwork(modelPath, std::string());
|
||||
if (val) {
|
||||
val(cnnNetwork);
|
||||
}
|
||||
if (cnnNetwork.getFunction()) {
|
||||
res = compile_model_impl(ov::legacy_convert::convert_model(cnnNetwork, isNewAPI()),
|
||||
plugin,
|
||||
conf,
|
||||
{},
|
||||
cacheContent);
|
||||
} else {
|
||||
res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, cacheContent);
|
||||
}
|
||||
}
|
||||
} else if (cacheManager) {
|
||||
res = plugin.compile_model(modelPath, conf);
|
||||
} else {
|
||||
auto cnnNetwork = ReadNetwork(modelPath, std::string());
|
||||
if (val) {
|
||||
val(cnnNetwork);
|
||||
}
|
||||
if (cnnNetwork.getFunction()) {
|
||||
res = compile_model_impl(ov::legacy_convert::convert_model(cnnNetwork, isNewAPI()),
|
||||
plugin,
|
||||
conf,
|
||||
{},
|
||||
cacheContent);
|
||||
} else {
|
||||
res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, cacheContent);
|
||||
}
|
||||
}
|
||||
return {res._ptr, res._so};
|
||||
}
|
||||
|
||||
InferenceEngine::SoExecutableNetworkInternal ov::CoreImpl::LoadNetwork(
|
||||
@@ -136,8 +247,43 @@ InferenceEngine::QueryNetworkResult ov::CoreImpl::QueryNetwork(const InferenceEn
|
||||
return ret;
|
||||
}
|
||||
auto res = query_model(network.getFunction(), deviceName, any_copy(config));
|
||||
if (!network.getFunction() || res.empty()) {
|
||||
ret.rc = InferenceEngine::GENERAL_ERROR;
|
||||
return ret;
|
||||
}
|
||||
ret.supportedLayersMap = res;
|
||||
|
||||
const auto& func = network.getFunction();
|
||||
auto specialized_function = func->clone();
|
||||
|
||||
std::string defDevice = ret.supportedLayersMap.begin()->second;
|
||||
ngraph::pass::ConstantFolding().run_on_model(specialized_function);
|
||||
std::unordered_set<std::string> opNames;
|
||||
|
||||
for (const auto& op : specialized_function->get_ops())
|
||||
opNames.emplace(op->get_friendly_name());
|
||||
|
||||
for (const auto& op : func->get_ops()) {
|
||||
if (opNames.find(op->get_friendly_name()) == opNames.end()) {
|
||||
ret.supportedLayersMap[op->get_friendly_name()] = defDevice;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& op : func->get_ops()) {
|
||||
if (!ret.supportedLayersMap.count(op->get_friendly_name()) &&
|
||||
std::dynamic_pointer_cast<ngraph::op::Constant>(op)) {
|
||||
bool are_all_users_supported = true;
|
||||
for (const auto& user : op->output(0).get_target_inputs()) {
|
||||
if (!ret.supportedLayersMap.count(user.get_node()->get_friendly_name())) {
|
||||
are_all_users_supported = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (are_all_users_supported) {
|
||||
ret.supportedLayersMap[op->get_friendly_name()] = defDevice;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
@@ -147,3 +147,45 @@ TEST_F(CNNNetworkTests, throwsHasDynamicInputs_queryNetwork) {
|
||||
EXPECT_TRUE(std::string(e.what()).find("p3_2") == std::string::npos) << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
class CNNNetworkTests_LoadFromFileTest : public ::testing::Test {
|
||||
protected:
|
||||
std::string modelName{};
|
||||
std::string weightsName{};
|
||||
InferenceEngine::Core core;
|
||||
|
||||
public:
|
||||
void SetUp() override {
|
||||
auto filePrefix = CommonTestUtils::generateTestFilePrefix();
|
||||
modelName = filePrefix + "_CNNNetworkTests_LoadFromFileTest.xml";
|
||||
weightsName = filePrefix + "_CNNNetworkTests_LoadFromFileTest.bin";
|
||||
|
||||
std::shared_ptr<ov::Model> model = CNNNetworkTests_create_model();
|
||||
ov::pass::Serialize(modelName, weightsName).run_on_model(model);
|
||||
ASSERT_NO_THROW(
|
||||
core.RegisterPlugin(ov::util::make_plugin_library_name(CommonTestUtils::getExecutableDirectory(),
|
||||
std::string("mock_engine") + IE_BUILD_POSTFIX),
|
||||
"mock"));
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
CommonTestUtils::removeIRFiles(modelName, weightsName);
|
||||
core.UnregisterPlugin("mock");
|
||||
}
|
||||
};
|
||||
#if defined(ENABLE_OV_IR_FRONTEND)
|
||||
TEST_F(CNNNetworkTests_LoadFromFileTest, throwsHasDynamicInputs_fromPath) {
|
||||
try {
|
||||
core.LoadNetwork(modelName, "mock");
|
||||
FAIL() << "LoadNetwork with dynamic inputs shall throw";
|
||||
} catch (const ov::AssertFailure& e) {
|
||||
EXPECT_TRUE(std::string(e.what()).find("InferenceEngine::Core::LoadNetwork") != std::string::npos) << e.what();
|
||||
EXPECT_TRUE(std::string(e.what()).find("p1_1") != std::string::npos) << e.what();
|
||||
EXPECT_TRUE(std::string(e.what()).find("p1_2") != std::string::npos) << e.what();
|
||||
EXPECT_TRUE(std::string(e.what()).find("p2_1") != std::string::npos) << e.what();
|
||||
EXPECT_TRUE(std::string(e.what()).find("p2_2") != std::string::npos) << e.what();
|
||||
EXPECT_TRUE(std::string(e.what()).find("p3_1") == std::string::npos) << e.what();
|
||||
EXPECT_TRUE(std::string(e.what()).find("p3_2") == std::string::npos) << e.what();
|
||||
}
|
||||
}
|
||||
#endif // defined(ENABLE_OV_IR_FRONTEND)
|
||||
|
||||
@@ -147,52 +147,64 @@ static std::shared_ptr<ngraph::Function> create_simple_function() {
|
||||
return func;
|
||||
}
|
||||
|
||||
static CNNNetwork createNetwork() {
|
||||
CNNNetwork res(create_simple_function());
|
||||
return res;
|
||||
}
|
||||
|
||||
static CNNNetwork createNetworkWithLayout(const ov::Layout& layout) {
|
||||
auto fun = create_simple_function();
|
||||
fun->get_parameters()[0]->set_layout(layout);
|
||||
fun->get_results()[0]->set_layout(layout);
|
||||
return CNNNetwork(fun);
|
||||
}
|
||||
|
||||
static void checkCustomRt(const std::function<void(Node::RTMap&)>& emptyCb,
|
||||
const std::function<void(Node::RTMap&, const std::string& name)>& nameCb) {
|
||||
auto model1 = create_simple_function();
|
||||
auto model2 = create_simple_function();
|
||||
auto& op1 = model1->get_ops().front()->get_rt_info();
|
||||
auto& op2 = model2->get_ops().front()->get_rt_info();
|
||||
auto net1 = createNetwork();
|
||||
auto net2 = createNetwork();
|
||||
auto& op1 = net1.getFunction()->get_ops().front()->get_rt_info();
|
||||
auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info();
|
||||
|
||||
emptyCb(op2);
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {}));
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
|
||||
emptyCb(op1);
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {}));
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
|
||||
nameCb(op1, "test");
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {}));
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
|
||||
nameCb(op2, "test");
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {}));
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
|
||||
nameCb(op1, "test2");
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {}));
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
}
|
||||
|
||||
TEST(NetworkContext, HashOfSame) {
|
||||
auto model1 = create_simple_function();
|
||||
auto model2 = create_simple_function();
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {}));
|
||||
TEST(NetworkContext_CNNNetwork, HashOfSame) {
|
||||
auto net1 = createNetwork();
|
||||
auto net2 = createNetwork();
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
}
|
||||
|
||||
TEST(NetworkContext, HashWithConfig) {
|
||||
auto net1 = create_simple_function();
|
||||
auto net2 = create_simple_function();
|
||||
TEST(NetworkContext_CNNNetwork, HashWithConfig) {
|
||||
auto net1 = createNetwork();
|
||||
auto net2 = createNetwork();
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {{"key", "value"}}),
|
||||
NetworkCompilationContext::compute_hash(net2, {}));
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {{"key", "value"}}),
|
||||
NetworkCompilationContext::compute_hash(net2, {{"key", "value"}}));
|
||||
}
|
||||
|
||||
TEST(NetworkContext, HashWithPrimitivesPriority) {
|
||||
auto net1 = create_simple_function();
|
||||
auto net2 = create_simple_function();
|
||||
auto net3 = create_simple_function();
|
||||
auto& op2 = net2->get_ops().front()->get_rt_info();
|
||||
TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriority) {
|
||||
auto net1 = createNetwork();
|
||||
auto net2 = createNetwork();
|
||||
auto net3 = createNetwork();
|
||||
auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info();
|
||||
op2[ov::PrimitivesPriority::get_type_info_static()] = ov::PrimitivesPriority("testPriority");
|
||||
|
||||
auto& op3 = net3->get_ops().front()->get_rt_info();
|
||||
auto& op3 = net3.getFunction()->get_ops().front()->get_rt_info();
|
||||
op3["PrimitivesPriority"] = "testPriority";
|
||||
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
@@ -200,7 +212,7 @@ TEST(NetworkContext, HashWithPrimitivesPriority) {
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
|
||||
}
|
||||
|
||||
TEST(NetworkContext, HashWithFusedNames) {
|
||||
TEST(NetworkContext_CNNNetwork, HashWithFusedNames) {
|
||||
auto setFusedEmpty = [&](Node::RTMap& rtInfo) {
|
||||
rtInfo[ov::FusedNames::get_type_info_static()] = ov::FusedNames();
|
||||
};
|
||||
@@ -210,7 +222,7 @@ TEST(NetworkContext, HashWithFusedNames) {
|
||||
checkCustomRt(setFusedEmpty, setFused);
|
||||
}
|
||||
|
||||
TEST(NetworkContext, HashWithPrimitivesPriorityType) {
|
||||
TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriorityType) {
|
||||
auto setPrimEmpty = [&](Node::RTMap& rtInfo) {
|
||||
rtInfo[ov::PrimitivesPriority::get_type_info_static()] = ov::PrimitivesPriority("");
|
||||
};
|
||||
@@ -220,14 +232,14 @@ TEST(NetworkContext, HashWithPrimitivesPriorityType) {
|
||||
checkCustomRt(setPrimEmpty, setPrim);
|
||||
}
|
||||
|
||||
TEST(NetworkContext, HashWithAffinity) {
|
||||
auto net1 = create_simple_function();
|
||||
auto net2 = create_simple_function();
|
||||
auto net3 = create_simple_function();
|
||||
auto& op2 = net2->get_ops().front()->get_rt_info();
|
||||
TEST(NetworkContext_CNNNetwork, HashWithAffinity) {
|
||||
auto net1 = createNetwork();
|
||||
auto net2 = createNetwork();
|
||||
auto net3 = createNetwork();
|
||||
auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info();
|
||||
op2["affinity"] = "testAffinity";
|
||||
|
||||
auto& op3 = net3->get_ops().front()->get_rt_info();
|
||||
auto& op3 = net3.getFunction()->get_ops().front()->get_rt_info();
|
||||
op3["affinity"] = "testAffinity";
|
||||
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
@@ -235,18 +247,18 @@ TEST(NetworkContext, HashWithAffinity) {
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
|
||||
}
|
||||
|
||||
TEST(NetworkContext, HashWithFutureRt_string) {
|
||||
auto net1 = create_simple_function();
|
||||
auto net2 = create_simple_function();
|
||||
auto net3 = create_simple_function();
|
||||
TEST(NetworkContext_CNNNetwork, HashWithFutureRt_string) {
|
||||
auto net1 = createNetwork();
|
||||
auto net2 = createNetwork();
|
||||
auto net3 = createNetwork();
|
||||
|
||||
auto& op1 = net1->get_ops().front()->get_rt_info();
|
||||
auto& op1 = net1.getFunction()->get_ops().front()->get_rt_info();
|
||||
op1["someFutureKey"] = "hello";
|
||||
|
||||
auto& op2 = net2->get_ops().front()->get_rt_info();
|
||||
auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info();
|
||||
op2["someFutureKey"] = "hello";
|
||||
|
||||
auto& op3 = net3->get_ops().front()->get_rt_info();
|
||||
auto& op3 = net3.getFunction()->get_ops().front()->get_rt_info();
|
||||
op3["someFutureKey"] = "olleh";
|
||||
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
@@ -254,18 +266,18 @@ TEST(NetworkContext, HashWithFutureRt_string) {
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
|
||||
}
|
||||
|
||||
TEST(NetworkContext, HashWithFutureRt_int64) {
|
||||
auto net1 = create_simple_function();
|
||||
auto net2 = create_simple_function();
|
||||
auto net3 = create_simple_function();
|
||||
TEST(NetworkContext_CNNNetwork, HashWithFutureRt_int64) {
|
||||
auto net1 = createNetwork();
|
||||
auto net2 = createNetwork();
|
||||
auto net3 = createNetwork();
|
||||
|
||||
auto& op1 = net1->get_ops().front()->get_rt_info();
|
||||
auto& op1 = net1.getFunction()->get_ops().front()->get_rt_info();
|
||||
op1["someFutureKey"] = int64_t(42);
|
||||
|
||||
auto& op2 = net2->get_ops().front()->get_rt_info();
|
||||
auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info();
|
||||
op2["someFutureKey"] = int64_t(42);
|
||||
|
||||
auto& op3 = net3->get_ops().front()->get_rt_info();
|
||||
auto& op3 = net3.getFunction()->get_ops().front()->get_rt_info();
|
||||
op3["someFutureKey"] = int64_t(43);
|
||||
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
@@ -273,7 +285,31 @@ TEST(NetworkContext, HashWithFutureRt_int64) {
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
|
||||
}
|
||||
|
||||
TEST(NetworkContext, HashWithTensorNames) {
|
||||
TEST(NetworkContext_CNNNetwork, HashWithLayout) {
|
||||
auto net1 = createNetworkWithLayout("NCH");
|
||||
auto net2 = createNetworkWithLayout("nch");
|
||||
auto net3 = createNetworkWithLayout("?CH");
|
||||
auto net3_1 = createNetworkWithLayout("?C?");
|
||||
auto net4 = createNetworkWithLayout("");
|
||||
auto fun5 = create_simple_function();
|
||||
fun5->get_parameters()[0]->set_layout("NCH");
|
||||
fun5->get_parameters()[0]->set_layout("");
|
||||
fun5->get_results()[0]->set_layout("NHC");
|
||||
fun5->get_results()[0]->set_layout(ov::Layout());
|
||||
auto net5 = CNNNetwork(fun5);
|
||||
|
||||
EXPECT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
|
||||
EXPECT_NE(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
|
||||
|
||||
EXPECT_NE(NetworkCompilationContext::compute_hash(net3, {}), NetworkCompilationContext::compute_hash(net3_1, {}));
|
||||
|
||||
EXPECT_NE(NetworkCompilationContext::compute_hash(net3, {}), NetworkCompilationContext::compute_hash(net4, {}));
|
||||
|
||||
EXPECT_EQ(NetworkCompilationContext::compute_hash(net4, {}), NetworkCompilationContext::compute_hash(net5, {}));
|
||||
}
|
||||
|
||||
TEST(NetworkContext_CNNNetwork, HashWithTensorNames) {
|
||||
auto fun1 = create_simple_function();
|
||||
auto fun2 = create_simple_function();
|
||||
auto fun3 = create_simple_function();
|
||||
@@ -293,25 +329,50 @@ TEST(NetworkContext, HashWithTensorNames) {
|
||||
fun1->input().set_names(names1);
|
||||
fun2->input().set_names(names2);
|
||||
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(fun1, {}), NetworkCompilationContext::compute_hash(fun2, {}));
|
||||
auto net1 = CNNNetwork(fun1);
|
||||
auto net2 = CNNNetwork(fun2);
|
||||
auto net3 = CNNNetwork(fun3);
|
||||
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(fun2, {}), NetworkCompilationContext::compute_hash(fun3, {}));
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
|
||||
}
|
||||
|
||||
TEST(NetworkContext, HashWithDifferentResults) {
|
||||
auto net1 = create_simple_function();
|
||||
auto net2 = create_simple_function();
|
||||
net2->remove_result(net2->get_results().front());
|
||||
auto net3 = create_simple_function();
|
||||
net3->remove_result(net3->get_results().front());
|
||||
TEST(NetworkContext_CNNNetwork, HashWithDifferentResults) {
|
||||
auto net1 = createNetwork();
|
||||
auto net2 = createNetwork();
|
||||
net2.getFunction()->remove_result(net2.getFunction()->get_results().front());
|
||||
auto net3 = createNetwork();
|
||||
net3.getFunction()->remove_result(net3.getFunction()->get_results().front());
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
|
||||
}
|
||||
|
||||
TEST(NetworkContext_CNNNetwork, HashWithDifferentMeanValues) {
|
||||
auto updatePreprocess = [&](CNNNetwork& cnnNet) {
|
||||
auto& preProcess = cnnNet.getInputsInfo().begin()->second->getPreProcess();
|
||||
preProcess.init(3);
|
||||
preProcess[0]->stdScale = 2;
|
||||
preProcess[1]->stdScale = 3;
|
||||
preProcess[2]->stdScale = 4;
|
||||
preProcess[0]->meanValue = 0;
|
||||
preProcess[1]->meanValue = 1;
|
||||
preProcess[2]->meanValue = 2;
|
||||
preProcess.setVariant(InferenceEngine::MEAN_VALUE);
|
||||
};
|
||||
auto net1 = createNetwork();
|
||||
auto net2 = createNetwork();
|
||||
updatePreprocess(net2);
|
||||
auto net3 = createNetwork();
|
||||
updatePreprocess(net3);
|
||||
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
|
||||
ASSERT_EQ(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
|
||||
}
|
||||
|
||||
// Verify all internal hash calculations are thread-safe (like ngraph::function serialization)
|
||||
TEST(NetworkContext, HashOfSameMultiThreading) {
|
||||
auto net1 = create_simple_function();
|
||||
auto net2 = create_simple_function();
|
||||
TEST(NetworkContext_CNNNetwork, HashOfSameMultiThreading) {
|
||||
auto net1 = createNetwork();
|
||||
auto net2 = createNetwork();
|
||||
std::atomic_bool fail{false};
|
||||
const auto TEST_DURATION_MS = 1000;
|
||||
auto start = high_resolution_clock::now();
|
||||
|
||||
Reference in New Issue
Block a user