Revert "Revert "Revert query_model logic""

This reverts commit 0df850838a.
This commit is contained in:
Ilya Churaev
2023-01-26 08:03:19 +04:00
parent f8a47e30ae
commit 3bb2af2250
7 changed files with 417 additions and 149 deletions

View File

@@ -45,11 +45,71 @@ static int32_t as_int32_t(T v) {
namespace {
uint64_t calculate_td(const InferenceEngine::TensorDesc& td, uint64_t _seed) {
uint64_t seed = _seed;
uint64_t compute_model_hash(const std::shared_ptr<const ov::Model>& model, const ov::AnyMap& compileOptions) {
OPENVINO_ASSERT(model);
uint64_t seed = 0;
// 1. Calculate hash on function
ov::pass::Manager m;
m.register_pass<ov::pass::FixRtInfo>();
m.register_pass<ov::pass::Hash>(seed);
m.run_passes(std::const_pointer_cast<ov::Model>(model));
// 2. Compute hash on serialized data and options
for (const auto& kvp : compileOptions) {
seed = ov::hash_combine(seed, kvp.first + kvp.second.as<std::string>());
}
// 3. Add runtime information which may not be serialized
for (const auto& op : model->get_ordered_ops()) {
const auto& rt = op->get_rt_info();
for (const auto& rtMapData : rt) {
seed = ov::hash_combine(seed, rtMapData.first);
std::stringstream strm;
rtMapData.second.print(strm);
seed = ov::hash_combine(seed, strm.str());
}
}
// 4. Legacy part if CNNNetwork is used with new Plugin API
for (auto&& input : model->inputs()) {
auto& rt_info = input.get_rt_info();
auto it = rt_info.find("ie_legacy_td");
if (it != rt_info.end()) {
auto td = it->second.as<InferenceEngine::TensorDesc>();
seed = ov::hash_combine(seed, ov::as_int32_t(td.getPrecision()));
seed = ov::hash_combine(seed, ov::as_int32_t(td.getLayout()));
}
it = rt_info.find("ie_legacy_preproc");
if (it != rt_info.end()) {
auto preproc = it->second.as<InferenceEngine::PreProcessInfo>();
seed = ov::hash_combine(seed, ov::as_int32_t(preproc.getMeanVariant()));
if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_VALUE) {
seed = ov::hash_combine(seed, preproc.getNumberOfChannels());
for (size_t c = 0; c < preproc.getNumberOfChannels(); ++c) {
const InferenceEngine::PreProcessChannel::Ptr& channelInfo = preproc[c];
seed = ov::hash_combine(seed, channelInfo->stdScale);
seed = ov::hash_combine(seed, channelInfo->meanValue);
}
} else if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_IMAGE) {
// TODO: think if we need to compute hash for mean image if it exists
}
}
}
for (auto&& output : model->outputs()) {
auto& rt_info = output.get_rt_info();
auto it = rt_info.find("ie_legacy_td");
if (it != rt_info.end()) {
auto td = it->second.as<InferenceEngine::TensorDesc>();
seed = ov::hash_combine(seed, ov::as_int32_t(td.getPrecision()));
seed = ov::hash_combine(seed, ov::as_int32_t(td.getLayout()));
}
}
seed = ov::hash_combine(seed, ov::as_int32_t(td.getPrecision()));
seed = ov::hash_combine(seed, ov::as_int32_t(td.getLayout()));
return seed;
}
@@ -83,64 +143,47 @@ std::string NetworkCompilationContext::compute_hash(const std::shared_ptr<const
const ov::AnyMap& compileOptions) {
OV_ITT_SCOPE(FIRST_INFERENCE, ov::itt::domains::IE_RT, "NetworkCompilationContext::compute_hash - Model");
OPENVINO_ASSERT(model);
return std::to_string(compute_model_hash(model, compileOptions));
}
std::string NetworkCompilationContext::compute_hash(const InferenceEngine::CNNNetwork& network,
const ov::AnyMap& compileOptions) {
OV_ITT_SCOPE(FIRST_INFERENCE, ov::itt::domains::IE_RT, "NetworkCompilationContext::compute_hash - CNN");
OPENVINO_ASSERT(network.getFunction());
uint64_t seed = 0;
// 1. Calculate hash on function
ov::pass::Manager m;
m.register_pass<ov::pass::FixRtInfo>();
m.register_pass<ov::pass::Hash>(seed);
m.run_passes(std::const_pointer_cast<ov::Model>(model));
// 2. Compute hash on serialized data and options
for (const auto& kvp : compileOptions) {
seed = ov::hash_combine(seed, kvp.first + kvp.second.as<std::string>());
}
// 3. Add runtime information which may not be serialized
for (const auto& op : model->get_ordered_ops()) {
const auto& rt = op->get_rt_info();
for (const auto& rtMapData : rt) {
seed = ov::hash_combine(seed, rtMapData.first);
std::stringstream strm;
rtMapData.second.print(strm);
seed = ov::hash_combine(seed, strm.str());
}
}
seed = compute_model_hash(network.getFunction(), compileOptions);
// 4. Legacy part if CNNNetwork is used with new Plugin API
for (auto&& input : model->inputs()) {
auto& rt_info = input.get_rt_info();
// 4. Add inputs info
for (const auto& input : network.getInputsInfo()) {
InferenceEngine::InputInfo::Ptr info = input.second;
seed = hash_combine(seed, as_int32_t(info->getPrecision()));
seed = hash_combine(seed, as_int32_t(info->getLayout()));
auto it = rt_info.find("ie_legacy_td");
if (it != rt_info.end()) {
seed = calculate_td(it->second.as<InferenceEngine::TensorDesc>(), seed);
}
const InferenceEngine::PreProcessInfo& preproc = info->getPreProcess();
seed = hash_combine(seed, as_int32_t(preproc.getMeanVariant()));
it = rt_info.find("ie_legacy_preproc");
if (it != rt_info.end()) {
auto preproc = it->second.as<InferenceEngine::PreProcessInfo>();
seed = ov::hash_combine(seed, ov::as_int32_t(preproc.getMeanVariant()));
if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_VALUE) {
seed = ov::hash_combine(seed, preproc.getNumberOfChannels());
for (size_t c = 0; c < preproc.getNumberOfChannels(); ++c) {
const InferenceEngine::PreProcessChannel::Ptr& channelInfo = preproc[c];
seed = ov::hash_combine(seed, channelInfo->stdScale);
seed = ov::hash_combine(seed, channelInfo->meanValue);
}
} else if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_IMAGE) {
// TODO: think if we need to compute hash for mean image if it exists
if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_VALUE) {
seed = hash_combine(seed, preproc.getNumberOfChannels());
for (size_t c = 0; c < preproc.getNumberOfChannels(); ++c) {
const InferenceEngine::PreProcessChannel::Ptr& channelInfo = preproc[c];
seed = hash_combine(seed, channelInfo->stdScale);
seed = hash_combine(seed, channelInfo->meanValue);
}
} else if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_IMAGE) {
// TODO: think if we need to compute hash for mean image if it exists
}
}
for (auto&& output : model->outputs()) {
auto& rt_info = output.get_rt_info();
auto it = rt_info.find("ie_legacy_td");
if (it != rt_info.end()) {
seed = calculate_td(it->second.as<InferenceEngine::TensorDesc>(), seed);
}
// 5. Add outputs info
for (const auto& output : network.getOutputsInfo()) {
InferenceEngine::DataPtr info = output.second;
seed = hash_combine(seed, as_int32_t(info->getPrecision()));
seed = hash_combine(seed, as_int32_t(info->getLayout()));
}
return std::to_string(seed);

View File

@@ -26,6 +26,8 @@ class Model;
struct NetworkCompilationContext final {
static std::string calculate_file_info(const std::string& filePath);
static std::string compute_hash(const InferenceEngine::CNNNetwork& network, const ov::AnyMap& compileOptions);
static std::string compute_hash(const std::shared_ptr<const ov::Model>& model, const ov::AnyMap& compileOptions);
static std::string compute_hash(const std::string& modelName, const ov::AnyMap& compileOptions);

View File

@@ -430,7 +430,9 @@ ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> ov::CoreImpl::compile_mod
res = compile_model_impl(cnnNetwork.getFunction(), plugin, parsed._config, {}, cacheContent);
}
} else if (cacheManager) {
res = plugin.compile_model(model_path, parsed._config);
auto cnnNetwork = ReadNetwork(model_path, std::string());
// TODO: 'validation' for dynamic API doesn't work for this case, as it affects a lot of plugin API
res = compile_model(plugin, cnnNetwork.getFunction(), {}, parsed._config);
} else {
auto cnnNetwork = ReadNetwork(model_path, std::string());
res = compile_model_impl(cnnNetwork.getFunction(), plugin, parsed._config, {}, cacheContent);
@@ -483,37 +485,7 @@ ov::SupportedOpsMap ov::CoreImpl::query_model(const std::shared_ptr<const ov::Mo
const ov::AnyMap& config) const {
OV_ITT_SCOPED_TASK(ov::itt::domains::IE, "Core::query_model");
auto parsed = parseDeviceNameIntoConfig(device_name, config);
auto ret = get_plugin(parsed._deviceName).query_model(model, parsed._config);
auto specialized_function = model->clone();
std::string defDevice = ret.begin()->second;
ngraph::pass::ConstantFolding().run_on_model(specialized_function);
std::unordered_set<std::string> opNames;
for (const auto& op : specialized_function->get_ops())
opNames.emplace(op->get_friendly_name());
for (const auto& op : model->get_ops()) {
if (opNames.find(op->get_friendly_name()) == opNames.end()) {
ret[op->get_friendly_name()] = defDevice;
}
}
for (const auto& op : model->get_ops()) {
if (!ret.count(op->get_friendly_name()) && std::dynamic_pointer_cast<ngraph::op::Constant>(op)) {
bool are_all_users_supported = true;
for (const auto& user : op->output(0).get_target_inputs()) {
if (!ret.count(user.get_node()->get_friendly_name())) {
are_all_users_supported = false;
break;
}
}
if (are_all_users_supported) {
ret[op->get_friendly_name()] = defDevice;
}
}
}
return ret;
return get_plugin(parsed._deviceName).query_model(model, parsed._config);
}
std::vector<std::string> ov::CoreImpl::get_available_devices() const {

View File

@@ -219,7 +219,9 @@ private:
const InferenceEngine::CNNNetwork& model,
ov::Plugin& plugin,
const std::map<std::string, std::string>& parsedConfig,
const InferenceEngine::RemoteContext::Ptr& context);
const InferenceEngine::RemoteContext::Ptr& context,
const CacheContent& cacheContent,
bool forceDisableCache = false);
public:
CoreImpl(bool _newAPI);

View File

@@ -25,7 +25,9 @@ ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> ov::CoreImpl::LoadNetwork
const InferenceEngine::CNNNetwork& network,
ov::Plugin& plugin,
const std::map<std::string, std::string>& parsedConfig,
const InferenceEngine::RemoteContext::Ptr& context) {
const InferenceEngine::RemoteContext::Ptr& context,
const CacheContent& cacheContent,
bool forceDisableCache) {
OV_ITT_SCOPED_TASK(ov::itt::domains::IE, "CoreImpl::compile_model_impl");
ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> execNetwork;
auto wrapper = std::dynamic_pointer_cast<InferenceEngine::IPluginWrapper>(plugin.m_ptr);
@@ -34,6 +36,21 @@ ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> ov::CoreImpl::LoadNetwork
execNetwork = {context ? old_plugin->LoadNetwork(network, parsedConfig, context)
: old_plugin->LoadNetwork(network, parsedConfig),
plugin.m_so};
if (!forceDisableCache && cacheContent.cacheManager && device_supports_import_export(plugin)) {
try {
// need to export network for further import from "cache"
OV_ITT_SCOPE(FIRST_INFERENCE, InferenceEngine::itt::domains::IE_LT, "Core::LoadNetwork::Export");
cacheContent.cacheManager->writeCacheEntry(cacheContent.blobId, [&](std::ostream& networkStream) {
networkStream << ov::CompiledBlobHeader(
InferenceEngine::GetInferenceEngineVersion()->buildNumber,
ov::NetworkCompilationContext::calculate_file_info(cacheContent.modelPath));
execNetwork->Export(networkStream);
});
} catch (...) {
cacheContent.cacheManager->removeCacheEntry(cacheContent.blobId);
throw;
}
}
return execNetwork;
}
@@ -72,23 +89,72 @@ ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> ov::CoreImpl::LoadNetwork
auto plugin = get_plugin(parsed._deviceName);
auto res = LoadNetworkImpl(network, plugin, parsed._config, context);
ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> res;
auto conf = ov::any_copy(parsed._config);
auto cacheManager =
coreConfig.get_cache_config_for_device(parsed._deviceName, device_supports_cache_dir(plugin), conf)
._cacheManager;
auto cacheContent = CacheContent{cacheManager};
if (cacheManager && device_supports_import_export(plugin)) {
cacheContent.blobId = ov::NetworkCompilationContext::compute_hash(
network,
create_compile_config(plugin, parsed._deviceName, ov::any_copy(parsed._config)));
bool loadedFromCache = false;
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
res = load_model_from_cache(cacheContent, plugin, conf, {context, {}}, loadedFromCache);
if (!loadedFromCache) {
res = LoadNetworkImpl(network, plugin, parsed._config, context, cacheContent);
} else {
// Temporary workaround until all plugins support caching of original model inputs
InferenceEngine::SetExeNetworkInfo(res._ptr, network.getFunction(), isNewAPI());
}
} else {
res = LoadNetworkImpl(network, plugin, parsed._config, context, cacheContent);
}
return res;
}
InferenceEngine::SoExecutableNetworkInternal ov::CoreImpl::LoadNetwork(
const InferenceEngine::CNNNetwork& network,
const std::string& deviceName,
const std::string& deviceNameOrig,
const std::map<std::string, std::string>& config) {
OV_ITT_SCOPE(FIRST_INFERENCE, InferenceEngine::itt::domains::IE_LT, "Core::LoadNetwork::CNN");
if (network.getFunction()) {
auto compiled_model =
compile_model(ov::legacy_convert::convert_model(network, isNewAPI()), deviceName, any_copy(config));
compile_model(ov::legacy_convert::convert_model(network, isNewAPI()), deviceNameOrig, any_copy(config));
return {compiled_model._ptr, compiled_model._so};
}
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
std::string deviceName = deviceNameOrig;
std::map<std::string, std::string> config_with_batch = config;
bool forceDisableCache = config_with_batch.count(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE)) > 0;
auto parsed = parseDeviceNameIntoConfig(deviceName, config_with_batch);
if (forceDisableCache) {
// remove this config key from parsed as plugins can throw unsupported exception
parsed._config.erase(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE));
}
auto plugin = get_plugin(parsed._deviceName);
auto res = LoadNetworkImpl(network, plugin, parsed._config, nullptr);
ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> res;
auto conf = ov::any_copy(parsed._config);
auto cacheManager =
coreConfig.get_cache_config_for_device(parsed._deviceName, device_supports_cache_dir(plugin), conf)
._cacheManager;
auto cacheContent = CacheContent{cacheManager};
if (!forceDisableCache && cacheManager && device_supports_import_export(plugin)) {
cacheContent.blobId = ov::NetworkCompilationContext::compute_hash(
network,
create_compile_config(plugin, parsed._deviceName, ov::any_copy(parsed._config)));
bool loadedFromCache = false;
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
res = load_model_from_cache(cacheContent, plugin, conf, {}, loadedFromCache);
if (!loadedFromCache) {
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, cacheContent, forceDisableCache);
} else {
// Temporary workaround until all plugins support caching of original model inputs
InferenceEngine::SetExeNetworkInfo(res._ptr, network.getFunction(), isNewAPI());
}
} else {
res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, cacheContent, forceDisableCache);
}
return {res._ptr, res._so};
}
@@ -98,9 +164,54 @@ InferenceEngine::SoExecutableNetworkInternal ov::CoreImpl::LoadNetwork(
const std::map<std::string, std::string>& config,
const std::function<void(const InferenceEngine::CNNNetwork&)>& val) {
OV_ITT_SCOPE(FIRST_INFERENCE, ie::itt::domains::IE_LT, "Core::LoadNetwork::Path");
auto compiled_model = compile_model(modelPath, deviceName, any_copy(config));
return {compiled_model._ptr, compiled_model._so};
auto parsed = parseDeviceNameIntoConfig(deviceName, config);
auto plugin = get_plugin(parsed._deviceName);
ov::SoPtr<InferenceEngine::IExecutableNetworkInternal> res;
auto conf = any_copy(parsed._config);
auto cacheManager =
coreConfig.get_cache_config_for_device(parsed._deviceName, device_supports_cache_dir(plugin), conf)
._cacheManager;
auto cacheContent = CacheContent{cacheManager, modelPath};
if (cacheManager && device_supports_import_export(plugin)) {
bool loadedFromCache = false;
cacheContent.blobId =
ov::NetworkCompilationContext::compute_hash(modelPath,
create_compile_config(plugin, parsed._deviceName, conf));
auto lock = cacheGuard.getHashLock(cacheContent.blobId);
res = load_model_from_cache(cacheContent, plugin, conf, {}, loadedFromCache);
if (!loadedFromCache) {
auto cnnNetwork = ReadNetwork(modelPath, std::string());
if (val) {
val(cnnNetwork);
}
if (cnnNetwork.getFunction()) {
res = compile_model_impl(ov::legacy_convert::convert_model(cnnNetwork, isNewAPI()),
plugin,
conf,
{},
cacheContent);
} else {
res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, cacheContent);
}
}
} else if (cacheManager) {
res = plugin.compile_model(modelPath, conf);
} else {
auto cnnNetwork = ReadNetwork(modelPath, std::string());
if (val) {
val(cnnNetwork);
}
if (cnnNetwork.getFunction()) {
res = compile_model_impl(ov::legacy_convert::convert_model(cnnNetwork, isNewAPI()),
plugin,
conf,
{},
cacheContent);
} else {
res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, cacheContent);
}
}
return {res._ptr, res._so};
}
InferenceEngine::SoExecutableNetworkInternal ov::CoreImpl::LoadNetwork(
@@ -136,8 +247,43 @@ InferenceEngine::QueryNetworkResult ov::CoreImpl::QueryNetwork(const InferenceEn
return ret;
}
auto res = query_model(network.getFunction(), deviceName, any_copy(config));
if (!network.getFunction() || res.empty()) {
ret.rc = InferenceEngine::GENERAL_ERROR;
return ret;
}
ret.supportedLayersMap = res;
const auto& func = network.getFunction();
auto specialized_function = func->clone();
std::string defDevice = ret.supportedLayersMap.begin()->second;
ngraph::pass::ConstantFolding().run_on_model(specialized_function);
std::unordered_set<std::string> opNames;
for (const auto& op : specialized_function->get_ops())
opNames.emplace(op->get_friendly_name());
for (const auto& op : func->get_ops()) {
if (opNames.find(op->get_friendly_name()) == opNames.end()) {
ret.supportedLayersMap[op->get_friendly_name()] = defDevice;
}
}
for (const auto& op : func->get_ops()) {
if (!ret.supportedLayersMap.count(op->get_friendly_name()) &&
std::dynamic_pointer_cast<ngraph::op::Constant>(op)) {
bool are_all_users_supported = true;
for (const auto& user : op->output(0).get_target_inputs()) {
if (!ret.supportedLayersMap.count(user.get_node()->get_friendly_name())) {
are_all_users_supported = false;
break;
}
}
if (are_all_users_supported) {
ret.supportedLayersMap[op->get_friendly_name()] = defDevice;
}
}
}
return ret;
}

View File

@@ -147,3 +147,45 @@ TEST_F(CNNNetworkTests, throwsHasDynamicInputs_queryNetwork) {
EXPECT_TRUE(std::string(e.what()).find("p3_2") == std::string::npos) << e.what();
}
}
class CNNNetworkTests_LoadFromFileTest : public ::testing::Test {
protected:
std::string modelName{};
std::string weightsName{};
InferenceEngine::Core core;
public:
void SetUp() override {
auto filePrefix = CommonTestUtils::generateTestFilePrefix();
modelName = filePrefix + "_CNNNetworkTests_LoadFromFileTest.xml";
weightsName = filePrefix + "_CNNNetworkTests_LoadFromFileTest.bin";
std::shared_ptr<ov::Model> model = CNNNetworkTests_create_model();
ov::pass::Serialize(modelName, weightsName).run_on_model(model);
ASSERT_NO_THROW(
core.RegisterPlugin(ov::util::make_plugin_library_name(CommonTestUtils::getExecutableDirectory(),
std::string("mock_engine") + IE_BUILD_POSTFIX),
"mock"));
}
void TearDown() override {
CommonTestUtils::removeIRFiles(modelName, weightsName);
core.UnregisterPlugin("mock");
}
};
#if defined(ENABLE_OV_IR_FRONTEND)
TEST_F(CNNNetworkTests_LoadFromFileTest, throwsHasDynamicInputs_fromPath) {
try {
core.LoadNetwork(modelName, "mock");
FAIL() << "LoadNetwork with dynamic inputs shall throw";
} catch (const ov::AssertFailure& e) {
EXPECT_TRUE(std::string(e.what()).find("InferenceEngine::Core::LoadNetwork") != std::string::npos) << e.what();
EXPECT_TRUE(std::string(e.what()).find("p1_1") != std::string::npos) << e.what();
EXPECT_TRUE(std::string(e.what()).find("p1_2") != std::string::npos) << e.what();
EXPECT_TRUE(std::string(e.what()).find("p2_1") != std::string::npos) << e.what();
EXPECT_TRUE(std::string(e.what()).find("p2_2") != std::string::npos) << e.what();
EXPECT_TRUE(std::string(e.what()).find("p3_1") == std::string::npos) << e.what();
EXPECT_TRUE(std::string(e.what()).find("p3_2") == std::string::npos) << e.what();
}
}
#endif // defined(ENABLE_OV_IR_FRONTEND)

View File

@@ -147,52 +147,64 @@ static std::shared_ptr<ngraph::Function> create_simple_function() {
return func;
}
static CNNNetwork createNetwork() {
CNNNetwork res(create_simple_function());
return res;
}
static CNNNetwork createNetworkWithLayout(const ov::Layout& layout) {
auto fun = create_simple_function();
fun->get_parameters()[0]->set_layout(layout);
fun->get_results()[0]->set_layout(layout);
return CNNNetwork(fun);
}
static void checkCustomRt(const std::function<void(Node::RTMap&)>& emptyCb,
const std::function<void(Node::RTMap&, const std::string& name)>& nameCb) {
auto model1 = create_simple_function();
auto model2 = create_simple_function();
auto& op1 = model1->get_ops().front()->get_rt_info();
auto& op2 = model2->get_ops().front()->get_rt_info();
auto net1 = createNetwork();
auto net2 = createNetwork();
auto& op1 = net1.getFunction()->get_ops().front()->get_rt_info();
auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info();
emptyCb(op2);
ASSERT_NE(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {}));
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
emptyCb(op1);
ASSERT_EQ(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {}));
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
nameCb(op1, "test");
ASSERT_NE(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {}));
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
nameCb(op2, "test");
ASSERT_EQ(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {}));
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
nameCb(op1, "test2");
ASSERT_NE(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {}));
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
}
TEST(NetworkContext, HashOfSame) {
auto model1 = create_simple_function();
auto model2 = create_simple_function();
ASSERT_EQ(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {}));
TEST(NetworkContext_CNNNetwork, HashOfSame) {
auto net1 = createNetwork();
auto net2 = createNetwork();
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
}
TEST(NetworkContext, HashWithConfig) {
auto net1 = create_simple_function();
auto net2 = create_simple_function();
TEST(NetworkContext_CNNNetwork, HashWithConfig) {
auto net1 = createNetwork();
auto net2 = createNetwork();
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {{"key", "value"}}),
NetworkCompilationContext::compute_hash(net2, {}));
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {{"key", "value"}}),
NetworkCompilationContext::compute_hash(net2, {{"key", "value"}}));
}
TEST(NetworkContext, HashWithPrimitivesPriority) {
auto net1 = create_simple_function();
auto net2 = create_simple_function();
auto net3 = create_simple_function();
auto& op2 = net2->get_ops().front()->get_rt_info();
TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriority) {
auto net1 = createNetwork();
auto net2 = createNetwork();
auto net3 = createNetwork();
auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info();
op2[ov::PrimitivesPriority::get_type_info_static()] = ov::PrimitivesPriority("testPriority");
auto& op3 = net3->get_ops().front()->get_rt_info();
auto& op3 = net3.getFunction()->get_ops().front()->get_rt_info();
op3["PrimitivesPriority"] = "testPriority";
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
@@ -200,7 +212,7 @@ TEST(NetworkContext, HashWithPrimitivesPriority) {
ASSERT_EQ(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
}
TEST(NetworkContext, HashWithFusedNames) {
TEST(NetworkContext_CNNNetwork, HashWithFusedNames) {
auto setFusedEmpty = [&](Node::RTMap& rtInfo) {
rtInfo[ov::FusedNames::get_type_info_static()] = ov::FusedNames();
};
@@ -210,7 +222,7 @@ TEST(NetworkContext, HashWithFusedNames) {
checkCustomRt(setFusedEmpty, setFused);
}
TEST(NetworkContext, HashWithPrimitivesPriorityType) {
TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriorityType) {
auto setPrimEmpty = [&](Node::RTMap& rtInfo) {
rtInfo[ov::PrimitivesPriority::get_type_info_static()] = ov::PrimitivesPriority("");
};
@@ -220,14 +232,14 @@ TEST(NetworkContext, HashWithPrimitivesPriorityType) {
checkCustomRt(setPrimEmpty, setPrim);
}
TEST(NetworkContext, HashWithAffinity) {
auto net1 = create_simple_function();
auto net2 = create_simple_function();
auto net3 = create_simple_function();
auto& op2 = net2->get_ops().front()->get_rt_info();
TEST(NetworkContext_CNNNetwork, HashWithAffinity) {
auto net1 = createNetwork();
auto net2 = createNetwork();
auto net3 = createNetwork();
auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info();
op2["affinity"] = "testAffinity";
auto& op3 = net3->get_ops().front()->get_rt_info();
auto& op3 = net3.getFunction()->get_ops().front()->get_rt_info();
op3["affinity"] = "testAffinity";
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
@@ -235,18 +247,18 @@ TEST(NetworkContext, HashWithAffinity) {
ASSERT_EQ(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
}
TEST(NetworkContext, HashWithFutureRt_string) {
auto net1 = create_simple_function();
auto net2 = create_simple_function();
auto net3 = create_simple_function();
TEST(NetworkContext_CNNNetwork, HashWithFutureRt_string) {
auto net1 = createNetwork();
auto net2 = createNetwork();
auto net3 = createNetwork();
auto& op1 = net1->get_ops().front()->get_rt_info();
auto& op1 = net1.getFunction()->get_ops().front()->get_rt_info();
op1["someFutureKey"] = "hello";
auto& op2 = net2->get_ops().front()->get_rt_info();
auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info();
op2["someFutureKey"] = "hello";
auto& op3 = net3->get_ops().front()->get_rt_info();
auto& op3 = net3.getFunction()->get_ops().front()->get_rt_info();
op3["someFutureKey"] = "olleh";
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
@@ -254,18 +266,18 @@ TEST(NetworkContext, HashWithFutureRt_string) {
ASSERT_NE(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
}
TEST(NetworkContext, HashWithFutureRt_int64) {
auto net1 = create_simple_function();
auto net2 = create_simple_function();
auto net3 = create_simple_function();
TEST(NetworkContext_CNNNetwork, HashWithFutureRt_int64) {
auto net1 = createNetwork();
auto net2 = createNetwork();
auto net3 = createNetwork();
auto& op1 = net1->get_ops().front()->get_rt_info();
auto& op1 = net1.getFunction()->get_ops().front()->get_rt_info();
op1["someFutureKey"] = int64_t(42);
auto& op2 = net2->get_ops().front()->get_rt_info();
auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info();
op2["someFutureKey"] = int64_t(42);
auto& op3 = net3->get_ops().front()->get_rt_info();
auto& op3 = net3.getFunction()->get_ops().front()->get_rt_info();
op3["someFutureKey"] = int64_t(43);
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
@@ -273,7 +285,31 @@ TEST(NetworkContext, HashWithFutureRt_int64) {
ASSERT_NE(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
}
TEST(NetworkContext, HashWithTensorNames) {
TEST(NetworkContext_CNNNetwork, HashWithLayout) {
auto net1 = createNetworkWithLayout("NCH");
auto net2 = createNetworkWithLayout("nch");
auto net3 = createNetworkWithLayout("?CH");
auto net3_1 = createNetworkWithLayout("?C?");
auto net4 = createNetworkWithLayout("");
auto fun5 = create_simple_function();
fun5->get_parameters()[0]->set_layout("NCH");
fun5->get_parameters()[0]->set_layout("");
fun5->get_results()[0]->set_layout("NHC");
fun5->get_results()[0]->set_layout(ov::Layout());
auto net5 = CNNNetwork(fun5);
EXPECT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
EXPECT_NE(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
EXPECT_NE(NetworkCompilationContext::compute_hash(net3, {}), NetworkCompilationContext::compute_hash(net3_1, {}));
EXPECT_NE(NetworkCompilationContext::compute_hash(net3, {}), NetworkCompilationContext::compute_hash(net4, {}));
EXPECT_EQ(NetworkCompilationContext::compute_hash(net4, {}), NetworkCompilationContext::compute_hash(net5, {}));
}
TEST(NetworkContext_CNNNetwork, HashWithTensorNames) {
auto fun1 = create_simple_function();
auto fun2 = create_simple_function();
auto fun3 = create_simple_function();
@@ -293,25 +329,50 @@ TEST(NetworkContext, HashWithTensorNames) {
fun1->input().set_names(names1);
fun2->input().set_names(names2);
ASSERT_EQ(NetworkCompilationContext::compute_hash(fun1, {}), NetworkCompilationContext::compute_hash(fun2, {}));
auto net1 = CNNNetwork(fun1);
auto net2 = CNNNetwork(fun2);
auto net3 = CNNNetwork(fun3);
ASSERT_NE(NetworkCompilationContext::compute_hash(fun2, {}), NetworkCompilationContext::compute_hash(fun3, {}));
ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
ASSERT_NE(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
}
TEST(NetworkContext, HashWithDifferentResults) {
auto net1 = create_simple_function();
auto net2 = create_simple_function();
net2->remove_result(net2->get_results().front());
auto net3 = create_simple_function();
net3->remove_result(net3->get_results().front());
TEST(NetworkContext_CNNNetwork, HashWithDifferentResults) {
auto net1 = createNetwork();
auto net2 = createNetwork();
net2.getFunction()->remove_result(net2.getFunction()->get_results().front());
auto net3 = createNetwork();
net3.getFunction()->remove_result(net3.getFunction()->get_results().front());
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
ASSERT_EQ(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
}
TEST(NetworkContext_CNNNetwork, HashWithDifferentMeanValues) {
auto updatePreprocess = [&](CNNNetwork& cnnNet) {
auto& preProcess = cnnNet.getInputsInfo().begin()->second->getPreProcess();
preProcess.init(3);
preProcess[0]->stdScale = 2;
preProcess[1]->stdScale = 3;
preProcess[2]->stdScale = 4;
preProcess[0]->meanValue = 0;
preProcess[1]->meanValue = 1;
preProcess[2]->meanValue = 2;
preProcess.setVariant(InferenceEngine::MEAN_VALUE);
};
auto net1 = createNetwork();
auto net2 = createNetwork();
updatePreprocess(net2);
auto net3 = createNetwork();
updatePreprocess(net3);
ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {}));
ASSERT_EQ(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {}));
}
// Verify all internal hash calculations are thread-safe (like ngraph::function serialization)
TEST(NetworkContext, HashOfSameMultiThreading) {
auto net1 = create_simple_function();
auto net2 = create_simple_function();
TEST(NetworkContext_CNNNetwork, HashOfSameMultiThreading) {
auto net1 = createNetwork();
auto net2 = createNetwork();
std::atomic_bool fail{false};
const auto TEST_DURATION_MS = 1000;
auto start = high_resolution_clock::now();