diff --git a/src/inference/src/compilation_context.cpp b/src/inference/src/compilation_context.cpp index 90826a63687..6471a2861dd 100644 --- a/src/inference/src/compilation_context.cpp +++ b/src/inference/src/compilation_context.cpp @@ -45,11 +45,71 @@ static int32_t as_int32_t(T v) { namespace { -uint64_t calculate_td(const InferenceEngine::TensorDesc& td, uint64_t _seed) { - uint64_t seed = _seed; +uint64_t compute_model_hash(const std::shared_ptr& model, const ov::AnyMap& compileOptions) { + OPENVINO_ASSERT(model); + + uint64_t seed = 0; + // 1. Calculate hash on function + ov::pass::Manager m; + m.register_pass(); + m.register_pass(seed); + m.run_passes(std::const_pointer_cast(model)); + + // 2. Compute hash on serialized data and options + for (const auto& kvp : compileOptions) { + seed = ov::hash_combine(seed, kvp.first + kvp.second.as()); + } + + // 3. Add runtime information which may not be serialized + for (const auto& op : model->get_ordered_ops()) { + const auto& rt = op->get_rt_info(); + for (const auto& rtMapData : rt) { + seed = ov::hash_combine(seed, rtMapData.first); + std::stringstream strm; + rtMapData.second.print(strm); + seed = ov::hash_combine(seed, strm.str()); + } + } + + // 4. Legacy part if CNNNetwork is used with new Plugin API + for (auto&& input : model->inputs()) { + auto& rt_info = input.get_rt_info(); + + auto it = rt_info.find("ie_legacy_td"); + if (it != rt_info.end()) { + auto td = it->second.as(); + seed = ov::hash_combine(seed, ov::as_int32_t(td.getPrecision())); + seed = ov::hash_combine(seed, ov::as_int32_t(td.getLayout())); + } + + it = rt_info.find("ie_legacy_preproc"); + if (it != rt_info.end()) { + auto preproc = it->second.as(); + + seed = ov::hash_combine(seed, ov::as_int32_t(preproc.getMeanVariant())); + + if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_VALUE) { + seed = ov::hash_combine(seed, preproc.getNumberOfChannels()); + for (size_t c = 0; c < preproc.getNumberOfChannels(); ++c) { + const InferenceEngine::PreProcessChannel::Ptr& channelInfo = preproc[c]; + seed = ov::hash_combine(seed, channelInfo->stdScale); + seed = ov::hash_combine(seed, channelInfo->meanValue); + } + } else if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_IMAGE) { + // TODO: think if we need to compute hash for mean image if it exists + } + } + } + for (auto&& output : model->outputs()) { + auto& rt_info = output.get_rt_info(); + auto it = rt_info.find("ie_legacy_td"); + if (it != rt_info.end()) { + auto td = it->second.as(); + seed = ov::hash_combine(seed, ov::as_int32_t(td.getPrecision())); + seed = ov::hash_combine(seed, ov::as_int32_t(td.getLayout())); + } + } - seed = ov::hash_combine(seed, ov::as_int32_t(td.getPrecision())); - seed = ov::hash_combine(seed, ov::as_int32_t(td.getLayout())); return seed; } @@ -83,64 +143,47 @@ std::string NetworkCompilationContext::compute_hash(const std::shared_ptr(); - m.register_pass(seed); - m.run_passes(std::const_pointer_cast(model)); - // 2. Compute hash on serialized data and options - for (const auto& kvp : compileOptions) { - seed = ov::hash_combine(seed, kvp.first + kvp.second.as()); - } - // 3. Add runtime information which may not be serialized - for (const auto& op : model->get_ordered_ops()) { - const auto& rt = op->get_rt_info(); - for (const auto& rtMapData : rt) { - seed = ov::hash_combine(seed, rtMapData.first); - std::stringstream strm; - rtMapData.second.print(strm); - seed = ov::hash_combine(seed, strm.str()); - } - } + seed = compute_model_hash(network.getFunction(), compileOptions); - // 4. Legacy part if CNNNetwork is used with new Plugin API - for (auto&& input : model->inputs()) { - auto& rt_info = input.get_rt_info(); + // 4. Add inputs info + for (const auto& input : network.getInputsInfo()) { + InferenceEngine::InputInfo::Ptr info = input.second; + seed = hash_combine(seed, as_int32_t(info->getPrecision())); + seed = hash_combine(seed, as_int32_t(info->getLayout())); - auto it = rt_info.find("ie_legacy_td"); - if (it != rt_info.end()) { - seed = calculate_td(it->second.as(), seed); - } + const InferenceEngine::PreProcessInfo& preproc = info->getPreProcess(); + seed = hash_combine(seed, as_int32_t(preproc.getMeanVariant())); - it = rt_info.find("ie_legacy_preproc"); - if (it != rt_info.end()) { - auto preproc = it->second.as(); - - seed = ov::hash_combine(seed, ov::as_int32_t(preproc.getMeanVariant())); - - if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_VALUE) { - seed = ov::hash_combine(seed, preproc.getNumberOfChannels()); - for (size_t c = 0; c < preproc.getNumberOfChannels(); ++c) { - const InferenceEngine::PreProcessChannel::Ptr& channelInfo = preproc[c]; - seed = ov::hash_combine(seed, channelInfo->stdScale); - seed = ov::hash_combine(seed, channelInfo->meanValue); - } - } else if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_IMAGE) { - // TODO: think if we need to compute hash for mean image if it exists + if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_VALUE) { + seed = hash_combine(seed, preproc.getNumberOfChannels()); + for (size_t c = 0; c < preproc.getNumberOfChannels(); ++c) { + const InferenceEngine::PreProcessChannel::Ptr& channelInfo = preproc[c]; + seed = hash_combine(seed, channelInfo->stdScale); + seed = hash_combine(seed, channelInfo->meanValue); } + } else if (preproc.getMeanVariant() == InferenceEngine::MeanVariant::MEAN_IMAGE) { + // TODO: think if we need to compute hash for mean image if it exists } } - for (auto&& output : model->outputs()) { - auto& rt_info = output.get_rt_info(); - auto it = rt_info.find("ie_legacy_td"); - if (it != rt_info.end()) { - seed = calculate_td(it->second.as(), seed); - } + + // 5. Add outputs info + for (const auto& output : network.getOutputsInfo()) { + InferenceEngine::DataPtr info = output.second; + seed = hash_combine(seed, as_int32_t(info->getPrecision())); + seed = hash_combine(seed, as_int32_t(info->getLayout())); } return std::to_string(seed); diff --git a/src/inference/src/compilation_context.hpp b/src/inference/src/compilation_context.hpp index b06deb24284..06c20148830 100644 --- a/src/inference/src/compilation_context.hpp +++ b/src/inference/src/compilation_context.hpp @@ -26,6 +26,8 @@ class Model; struct NetworkCompilationContext final { static std::string calculate_file_info(const std::string& filePath); + static std::string compute_hash(const InferenceEngine::CNNNetwork& network, const ov::AnyMap& compileOptions); + static std::string compute_hash(const std::shared_ptr& model, const ov::AnyMap& compileOptions); static std::string compute_hash(const std::string& modelName, const ov::AnyMap& compileOptions); diff --git a/src/inference/src/dev/core_impl.cpp b/src/inference/src/dev/core_impl.cpp index 31448e6e54d..8080a6d5aba 100644 --- a/src/inference/src/dev/core_impl.cpp +++ b/src/inference/src/dev/core_impl.cpp @@ -430,7 +430,9 @@ ov::SoPtr ov::CoreImpl::compile_mod res = compile_model_impl(cnnNetwork.getFunction(), plugin, parsed._config, {}, cacheContent); } } else if (cacheManager) { - res = plugin.compile_model(model_path, parsed._config); + auto cnnNetwork = ReadNetwork(model_path, std::string()); + // TODO: 'validation' for dynamic API doesn't work for this case, as it affects a lot of plugin API + res = compile_model(plugin, cnnNetwork.getFunction(), {}, parsed._config); } else { auto cnnNetwork = ReadNetwork(model_path, std::string()); res = compile_model_impl(cnnNetwork.getFunction(), plugin, parsed._config, {}, cacheContent); @@ -483,37 +485,7 @@ ov::SupportedOpsMap ov::CoreImpl::query_model(const std::shared_ptrclone(); - - std::string defDevice = ret.begin()->second; - ngraph::pass::ConstantFolding().run_on_model(specialized_function); - std::unordered_set opNames; - - for (const auto& op : specialized_function->get_ops()) - opNames.emplace(op->get_friendly_name()); - - for (const auto& op : model->get_ops()) { - if (opNames.find(op->get_friendly_name()) == opNames.end()) { - ret[op->get_friendly_name()] = defDevice; - } - } - - for (const auto& op : model->get_ops()) { - if (!ret.count(op->get_friendly_name()) && std::dynamic_pointer_cast(op)) { - bool are_all_users_supported = true; - for (const auto& user : op->output(0).get_target_inputs()) { - if (!ret.count(user.get_node()->get_friendly_name())) { - are_all_users_supported = false; - break; - } - } - if (are_all_users_supported) { - ret[op->get_friendly_name()] = defDevice; - } - } - } - return ret; + return get_plugin(parsed._deviceName).query_model(model, parsed._config); } std::vector ov::CoreImpl::get_available_devices() const { diff --git a/src/inference/src/dev/core_impl.hpp b/src/inference/src/dev/core_impl.hpp index c070765826c..46f9e0b22f9 100644 --- a/src/inference/src/dev/core_impl.hpp +++ b/src/inference/src/dev/core_impl.hpp @@ -219,7 +219,9 @@ private: const InferenceEngine::CNNNetwork& model, ov::Plugin& plugin, const std::map& parsedConfig, - const InferenceEngine::RemoteContext::Ptr& context); + const InferenceEngine::RemoteContext::Ptr& context, + const CacheContent& cacheContent, + bool forceDisableCache = false); public: CoreImpl(bool _newAPI); diff --git a/src/inference/src/dev/core_impl_ie.cpp b/src/inference/src/dev/core_impl_ie.cpp index 02c24bfdada..d070779e538 100644 --- a/src/inference/src/dev/core_impl_ie.cpp +++ b/src/inference/src/dev/core_impl_ie.cpp @@ -25,7 +25,9 @@ ov::SoPtr ov::CoreImpl::LoadNetwork const InferenceEngine::CNNNetwork& network, ov::Plugin& plugin, const std::map& parsedConfig, - const InferenceEngine::RemoteContext::Ptr& context) { + const InferenceEngine::RemoteContext::Ptr& context, + const CacheContent& cacheContent, + bool forceDisableCache) { OV_ITT_SCOPED_TASK(ov::itt::domains::IE, "CoreImpl::compile_model_impl"); ov::SoPtr execNetwork; auto wrapper = std::dynamic_pointer_cast(plugin.m_ptr); @@ -34,6 +36,21 @@ ov::SoPtr ov::CoreImpl::LoadNetwork execNetwork = {context ? old_plugin->LoadNetwork(network, parsedConfig, context) : old_plugin->LoadNetwork(network, parsedConfig), plugin.m_so}; + if (!forceDisableCache && cacheContent.cacheManager && device_supports_import_export(plugin)) { + try { + // need to export network for further import from "cache" + OV_ITT_SCOPE(FIRST_INFERENCE, InferenceEngine::itt::domains::IE_LT, "Core::LoadNetwork::Export"); + cacheContent.cacheManager->writeCacheEntry(cacheContent.blobId, [&](std::ostream& networkStream) { + networkStream << ov::CompiledBlobHeader( + InferenceEngine::GetInferenceEngineVersion()->buildNumber, + ov::NetworkCompilationContext::calculate_file_info(cacheContent.modelPath)); + execNetwork->Export(networkStream); + }); + } catch (...) { + cacheContent.cacheManager->removeCacheEntry(cacheContent.blobId); + throw; + } + } return execNetwork; } @@ -72,23 +89,72 @@ ov::SoPtr ov::CoreImpl::LoadNetwork auto plugin = get_plugin(parsed._deviceName); - auto res = LoadNetworkImpl(network, plugin, parsed._config, context); + ov::SoPtr res; + auto conf = ov::any_copy(parsed._config); + auto cacheManager = + coreConfig.get_cache_config_for_device(parsed._deviceName, device_supports_cache_dir(plugin), conf) + ._cacheManager; + auto cacheContent = CacheContent{cacheManager}; + if (cacheManager && device_supports_import_export(plugin)) { + cacheContent.blobId = ov::NetworkCompilationContext::compute_hash( + network, + create_compile_config(plugin, parsed._deviceName, ov::any_copy(parsed._config))); + bool loadedFromCache = false; + auto lock = cacheGuard.getHashLock(cacheContent.blobId); + res = load_model_from_cache(cacheContent, plugin, conf, {context, {}}, loadedFromCache); + if (!loadedFromCache) { + res = LoadNetworkImpl(network, plugin, parsed._config, context, cacheContent); + } else { + // Temporary workaround until all plugins support caching of original model inputs + InferenceEngine::SetExeNetworkInfo(res._ptr, network.getFunction(), isNewAPI()); + } + } else { + res = LoadNetworkImpl(network, plugin, parsed._config, context, cacheContent); + } return res; } InferenceEngine::SoExecutableNetworkInternal ov::CoreImpl::LoadNetwork( const InferenceEngine::CNNNetwork& network, - const std::string& deviceName, + const std::string& deviceNameOrig, const std::map& config) { OV_ITT_SCOPE(FIRST_INFERENCE, InferenceEngine::itt::domains::IE_LT, "Core::LoadNetwork::CNN"); if (network.getFunction()) { auto compiled_model = - compile_model(ov::legacy_convert::convert_model(network, isNewAPI()), deviceName, any_copy(config)); + compile_model(ov::legacy_convert::convert_model(network, isNewAPI()), deviceNameOrig, any_copy(config)); return {compiled_model._ptr, compiled_model._so}; } - auto parsed = parseDeviceNameIntoConfig(deviceName, config); + std::string deviceName = deviceNameOrig; + std::map config_with_batch = config; + bool forceDisableCache = config_with_batch.count(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE)) > 0; + auto parsed = parseDeviceNameIntoConfig(deviceName, config_with_batch); + if (forceDisableCache) { + // remove this config key from parsed as plugins can throw unsupported exception + parsed._config.erase(CONFIG_KEY_INTERNAL(FORCE_DISABLE_CACHE)); + } auto plugin = get_plugin(parsed._deviceName); - auto res = LoadNetworkImpl(network, plugin, parsed._config, nullptr); + ov::SoPtr res; + auto conf = ov::any_copy(parsed._config); + auto cacheManager = + coreConfig.get_cache_config_for_device(parsed._deviceName, device_supports_cache_dir(plugin), conf) + ._cacheManager; + auto cacheContent = CacheContent{cacheManager}; + if (!forceDisableCache && cacheManager && device_supports_import_export(plugin)) { + cacheContent.blobId = ov::NetworkCompilationContext::compute_hash( + network, + create_compile_config(plugin, parsed._deviceName, ov::any_copy(parsed._config))); + bool loadedFromCache = false; + auto lock = cacheGuard.getHashLock(cacheContent.blobId); + res = load_model_from_cache(cacheContent, plugin, conf, {}, loadedFromCache); + if (!loadedFromCache) { + res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, cacheContent, forceDisableCache); + } else { + // Temporary workaround until all plugins support caching of original model inputs + InferenceEngine::SetExeNetworkInfo(res._ptr, network.getFunction(), isNewAPI()); + } + } else { + res = LoadNetworkImpl(network, plugin, parsed._config, nullptr, cacheContent, forceDisableCache); + } return {res._ptr, res._so}; } @@ -98,9 +164,54 @@ InferenceEngine::SoExecutableNetworkInternal ov::CoreImpl::LoadNetwork( const std::map& config, const std::function& val) { OV_ITT_SCOPE(FIRST_INFERENCE, ie::itt::domains::IE_LT, "Core::LoadNetwork::Path"); - - auto compiled_model = compile_model(modelPath, deviceName, any_copy(config)); - return {compiled_model._ptr, compiled_model._so}; + auto parsed = parseDeviceNameIntoConfig(deviceName, config); + auto plugin = get_plugin(parsed._deviceName); + ov::SoPtr res; + auto conf = any_copy(parsed._config); + auto cacheManager = + coreConfig.get_cache_config_for_device(parsed._deviceName, device_supports_cache_dir(plugin), conf) + ._cacheManager; + auto cacheContent = CacheContent{cacheManager, modelPath}; + if (cacheManager && device_supports_import_export(plugin)) { + bool loadedFromCache = false; + cacheContent.blobId = + ov::NetworkCompilationContext::compute_hash(modelPath, + create_compile_config(plugin, parsed._deviceName, conf)); + auto lock = cacheGuard.getHashLock(cacheContent.blobId); + res = load_model_from_cache(cacheContent, plugin, conf, {}, loadedFromCache); + if (!loadedFromCache) { + auto cnnNetwork = ReadNetwork(modelPath, std::string()); + if (val) { + val(cnnNetwork); + } + if (cnnNetwork.getFunction()) { + res = compile_model_impl(ov::legacy_convert::convert_model(cnnNetwork, isNewAPI()), + plugin, + conf, + {}, + cacheContent); + } else { + res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, cacheContent); + } + } + } else if (cacheManager) { + res = plugin.compile_model(modelPath, conf); + } else { + auto cnnNetwork = ReadNetwork(modelPath, std::string()); + if (val) { + val(cnnNetwork); + } + if (cnnNetwork.getFunction()) { + res = compile_model_impl(ov::legacy_convert::convert_model(cnnNetwork, isNewAPI()), + plugin, + conf, + {}, + cacheContent); + } else { + res = LoadNetworkImpl(cnnNetwork, plugin, parsed._config, nullptr, cacheContent); + } + } + return {res._ptr, res._so}; } InferenceEngine::SoExecutableNetworkInternal ov::CoreImpl::LoadNetwork( @@ -136,8 +247,43 @@ InferenceEngine::QueryNetworkResult ov::CoreImpl::QueryNetwork(const InferenceEn return ret; } auto res = query_model(network.getFunction(), deviceName, any_copy(config)); + if (!network.getFunction() || res.empty()) { + ret.rc = InferenceEngine::GENERAL_ERROR; + return ret; + } ret.supportedLayersMap = res; + const auto& func = network.getFunction(); + auto specialized_function = func->clone(); + + std::string defDevice = ret.supportedLayersMap.begin()->second; + ngraph::pass::ConstantFolding().run_on_model(specialized_function); + std::unordered_set opNames; + + for (const auto& op : specialized_function->get_ops()) + opNames.emplace(op->get_friendly_name()); + + for (const auto& op : func->get_ops()) { + if (opNames.find(op->get_friendly_name()) == opNames.end()) { + ret.supportedLayersMap[op->get_friendly_name()] = defDevice; + } + } + + for (const auto& op : func->get_ops()) { + if (!ret.supportedLayersMap.count(op->get_friendly_name()) && + std::dynamic_pointer_cast(op)) { + bool are_all_users_supported = true; + for (const auto& user : op->output(0).get_target_inputs()) { + if (!ret.supportedLayersMap.count(user.get_node()->get_friendly_name())) { + are_all_users_supported = false; + break; + } + } + if (are_all_users_supported) { + ret.supportedLayersMap[op->get_friendly_name()] = defDevice; + } + } + } return ret; } diff --git a/src/inference/tests/functional/cnn_network_test.cpp b/src/inference/tests/functional/cnn_network_test.cpp index ba461899086..c82591f14b3 100644 --- a/src/inference/tests/functional/cnn_network_test.cpp +++ b/src/inference/tests/functional/cnn_network_test.cpp @@ -147,3 +147,45 @@ TEST_F(CNNNetworkTests, throwsHasDynamicInputs_queryNetwork) { EXPECT_TRUE(std::string(e.what()).find("p3_2") == std::string::npos) << e.what(); } } + +class CNNNetworkTests_LoadFromFileTest : public ::testing::Test { +protected: + std::string modelName{}; + std::string weightsName{}; + InferenceEngine::Core core; + +public: + void SetUp() override { + auto filePrefix = CommonTestUtils::generateTestFilePrefix(); + modelName = filePrefix + "_CNNNetworkTests_LoadFromFileTest.xml"; + weightsName = filePrefix + "_CNNNetworkTests_LoadFromFileTest.bin"; + + std::shared_ptr model = CNNNetworkTests_create_model(); + ov::pass::Serialize(modelName, weightsName).run_on_model(model); + ASSERT_NO_THROW( + core.RegisterPlugin(ov::util::make_plugin_library_name(CommonTestUtils::getExecutableDirectory(), + std::string("mock_engine") + IE_BUILD_POSTFIX), + "mock")); + } + + void TearDown() override { + CommonTestUtils::removeIRFiles(modelName, weightsName); + core.UnregisterPlugin("mock"); + } +}; +#if defined(ENABLE_OV_IR_FRONTEND) +TEST_F(CNNNetworkTests_LoadFromFileTest, throwsHasDynamicInputs_fromPath) { + try { + core.LoadNetwork(modelName, "mock"); + FAIL() << "LoadNetwork with dynamic inputs shall throw"; + } catch (const ov::AssertFailure& e) { + EXPECT_TRUE(std::string(e.what()).find("InferenceEngine::Core::LoadNetwork") != std::string::npos) << e.what(); + EXPECT_TRUE(std::string(e.what()).find("p1_1") != std::string::npos) << e.what(); + EXPECT_TRUE(std::string(e.what()).find("p1_2") != std::string::npos) << e.what(); + EXPECT_TRUE(std::string(e.what()).find("p2_1") != std::string::npos) << e.what(); + EXPECT_TRUE(std::string(e.what()).find("p2_2") != std::string::npos) << e.what(); + EXPECT_TRUE(std::string(e.what()).find("p3_1") == std::string::npos) << e.what(); + EXPECT_TRUE(std::string(e.what()).find("p3_2") == std::string::npos) << e.what(); + } +} +#endif // defined(ENABLE_OV_IR_FRONTEND) diff --git a/src/inference/tests/unit/ie_compilation_context_test.cpp b/src/inference/tests/unit/ie_compilation_context_test.cpp index b08d78c58e1..654cdc34e0f 100644 --- a/src/inference/tests/unit/ie_compilation_context_test.cpp +++ b/src/inference/tests/unit/ie_compilation_context_test.cpp @@ -147,52 +147,64 @@ static std::shared_ptr create_simple_function() { return func; } +static CNNNetwork createNetwork() { + CNNNetwork res(create_simple_function()); + return res; +} + +static CNNNetwork createNetworkWithLayout(const ov::Layout& layout) { + auto fun = create_simple_function(); + fun->get_parameters()[0]->set_layout(layout); + fun->get_results()[0]->set_layout(layout); + return CNNNetwork(fun); +} + static void checkCustomRt(const std::function& emptyCb, const std::function& nameCb) { - auto model1 = create_simple_function(); - auto model2 = create_simple_function(); - auto& op1 = model1->get_ops().front()->get_rt_info(); - auto& op2 = model2->get_ops().front()->get_rt_info(); + auto net1 = createNetwork(); + auto net2 = createNetwork(); + auto& op1 = net1.getFunction()->get_ops().front()->get_rt_info(); + auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info(); emptyCb(op2); - ASSERT_NE(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {})); + ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); emptyCb(op1); - ASSERT_EQ(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {})); + ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); nameCb(op1, "test"); - ASSERT_NE(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {})); + ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); nameCb(op2, "test"); - ASSERT_EQ(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {})); + ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); nameCb(op1, "test2"); - ASSERT_NE(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {})); + ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); } -TEST(NetworkContext, HashOfSame) { - auto model1 = create_simple_function(); - auto model2 = create_simple_function(); - ASSERT_EQ(NetworkCompilationContext::compute_hash(model1, {}), NetworkCompilationContext::compute_hash(model2, {})); +TEST(NetworkContext_CNNNetwork, HashOfSame) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); } -TEST(NetworkContext, HashWithConfig) { - auto net1 = create_simple_function(); - auto net2 = create_simple_function(); +TEST(NetworkContext_CNNNetwork, HashWithConfig) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {{"key", "value"}}), NetworkCompilationContext::compute_hash(net2, {})); ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {{"key", "value"}}), NetworkCompilationContext::compute_hash(net2, {{"key", "value"}})); } -TEST(NetworkContext, HashWithPrimitivesPriority) { - auto net1 = create_simple_function(); - auto net2 = create_simple_function(); - auto net3 = create_simple_function(); - auto& op2 = net2->get_ops().front()->get_rt_info(); +TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriority) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + auto net3 = createNetwork(); + auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info(); op2[ov::PrimitivesPriority::get_type_info_static()] = ov::PrimitivesPriority("testPriority"); - auto& op3 = net3->get_ops().front()->get_rt_info(); + auto& op3 = net3.getFunction()->get_ops().front()->get_rt_info(); op3["PrimitivesPriority"] = "testPriority"; ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); @@ -200,7 +212,7 @@ TEST(NetworkContext, HashWithPrimitivesPriority) { ASSERT_EQ(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {})); } -TEST(NetworkContext, HashWithFusedNames) { +TEST(NetworkContext_CNNNetwork, HashWithFusedNames) { auto setFusedEmpty = [&](Node::RTMap& rtInfo) { rtInfo[ov::FusedNames::get_type_info_static()] = ov::FusedNames(); }; @@ -210,7 +222,7 @@ TEST(NetworkContext, HashWithFusedNames) { checkCustomRt(setFusedEmpty, setFused); } -TEST(NetworkContext, HashWithPrimitivesPriorityType) { +TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriorityType) { auto setPrimEmpty = [&](Node::RTMap& rtInfo) { rtInfo[ov::PrimitivesPriority::get_type_info_static()] = ov::PrimitivesPriority(""); }; @@ -220,14 +232,14 @@ TEST(NetworkContext, HashWithPrimitivesPriorityType) { checkCustomRt(setPrimEmpty, setPrim); } -TEST(NetworkContext, HashWithAffinity) { - auto net1 = create_simple_function(); - auto net2 = create_simple_function(); - auto net3 = create_simple_function(); - auto& op2 = net2->get_ops().front()->get_rt_info(); +TEST(NetworkContext_CNNNetwork, HashWithAffinity) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + auto net3 = createNetwork(); + auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info(); op2["affinity"] = "testAffinity"; - auto& op3 = net3->get_ops().front()->get_rt_info(); + auto& op3 = net3.getFunction()->get_ops().front()->get_rt_info(); op3["affinity"] = "testAffinity"; ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); @@ -235,18 +247,18 @@ TEST(NetworkContext, HashWithAffinity) { ASSERT_EQ(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {})); } -TEST(NetworkContext, HashWithFutureRt_string) { - auto net1 = create_simple_function(); - auto net2 = create_simple_function(); - auto net3 = create_simple_function(); +TEST(NetworkContext_CNNNetwork, HashWithFutureRt_string) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + auto net3 = createNetwork(); - auto& op1 = net1->get_ops().front()->get_rt_info(); + auto& op1 = net1.getFunction()->get_ops().front()->get_rt_info(); op1["someFutureKey"] = "hello"; - auto& op2 = net2->get_ops().front()->get_rt_info(); + auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info(); op2["someFutureKey"] = "hello"; - auto& op3 = net3->get_ops().front()->get_rt_info(); + auto& op3 = net3.getFunction()->get_ops().front()->get_rt_info(); op3["someFutureKey"] = "olleh"; ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); @@ -254,18 +266,18 @@ TEST(NetworkContext, HashWithFutureRt_string) { ASSERT_NE(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {})); } -TEST(NetworkContext, HashWithFutureRt_int64) { - auto net1 = create_simple_function(); - auto net2 = create_simple_function(); - auto net3 = create_simple_function(); +TEST(NetworkContext_CNNNetwork, HashWithFutureRt_int64) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + auto net3 = createNetwork(); - auto& op1 = net1->get_ops().front()->get_rt_info(); + auto& op1 = net1.getFunction()->get_ops().front()->get_rt_info(); op1["someFutureKey"] = int64_t(42); - auto& op2 = net2->get_ops().front()->get_rt_info(); + auto& op2 = net2.getFunction()->get_ops().front()->get_rt_info(); op2["someFutureKey"] = int64_t(42); - auto& op3 = net3->get_ops().front()->get_rt_info(); + auto& op3 = net3.getFunction()->get_ops().front()->get_rt_info(); op3["someFutureKey"] = int64_t(43); ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); @@ -273,7 +285,31 @@ TEST(NetworkContext, HashWithFutureRt_int64) { ASSERT_NE(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {})); } -TEST(NetworkContext, HashWithTensorNames) { +TEST(NetworkContext_CNNNetwork, HashWithLayout) { + auto net1 = createNetworkWithLayout("NCH"); + auto net2 = createNetworkWithLayout("nch"); + auto net3 = createNetworkWithLayout("?CH"); + auto net3_1 = createNetworkWithLayout("?C?"); + auto net4 = createNetworkWithLayout(""); + auto fun5 = create_simple_function(); + fun5->get_parameters()[0]->set_layout("NCH"); + fun5->get_parameters()[0]->set_layout(""); + fun5->get_results()[0]->set_layout("NHC"); + fun5->get_results()[0]->set_layout(ov::Layout()); + auto net5 = CNNNetwork(fun5); + + EXPECT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); + + EXPECT_NE(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {})); + + EXPECT_NE(NetworkCompilationContext::compute_hash(net3, {}), NetworkCompilationContext::compute_hash(net3_1, {})); + + EXPECT_NE(NetworkCompilationContext::compute_hash(net3, {}), NetworkCompilationContext::compute_hash(net4, {})); + + EXPECT_EQ(NetworkCompilationContext::compute_hash(net4, {}), NetworkCompilationContext::compute_hash(net5, {})); +} + +TEST(NetworkContext_CNNNetwork, HashWithTensorNames) { auto fun1 = create_simple_function(); auto fun2 = create_simple_function(); auto fun3 = create_simple_function(); @@ -293,25 +329,50 @@ TEST(NetworkContext, HashWithTensorNames) { fun1->input().set_names(names1); fun2->input().set_names(names2); - ASSERT_EQ(NetworkCompilationContext::compute_hash(fun1, {}), NetworkCompilationContext::compute_hash(fun2, {})); + auto net1 = CNNNetwork(fun1); + auto net2 = CNNNetwork(fun2); + auto net3 = CNNNetwork(fun3); - ASSERT_NE(NetworkCompilationContext::compute_hash(fun2, {}), NetworkCompilationContext::compute_hash(fun3, {})); + ASSERT_EQ(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); + + ASSERT_NE(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {})); } -TEST(NetworkContext, HashWithDifferentResults) { - auto net1 = create_simple_function(); - auto net2 = create_simple_function(); - net2->remove_result(net2->get_results().front()); - auto net3 = create_simple_function(); - net3->remove_result(net3->get_results().front()); +TEST(NetworkContext_CNNNetwork, HashWithDifferentResults) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); + net2.getFunction()->remove_result(net2.getFunction()->get_results().front()); + auto net3 = createNetwork(); + net3.getFunction()->remove_result(net3.getFunction()->get_results().front()); + ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); + ASSERT_EQ(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {})); +} + +TEST(NetworkContext_CNNNetwork, HashWithDifferentMeanValues) { + auto updatePreprocess = [&](CNNNetwork& cnnNet) { + auto& preProcess = cnnNet.getInputsInfo().begin()->second->getPreProcess(); + preProcess.init(3); + preProcess[0]->stdScale = 2; + preProcess[1]->stdScale = 3; + preProcess[2]->stdScale = 4; + preProcess[0]->meanValue = 0; + preProcess[1]->meanValue = 1; + preProcess[2]->meanValue = 2; + preProcess.setVariant(InferenceEngine::MEAN_VALUE); + }; + auto net1 = createNetwork(); + auto net2 = createNetwork(); + updatePreprocess(net2); + auto net3 = createNetwork(); + updatePreprocess(net3); ASSERT_NE(NetworkCompilationContext::compute_hash(net1, {}), NetworkCompilationContext::compute_hash(net2, {})); ASSERT_EQ(NetworkCompilationContext::compute_hash(net2, {}), NetworkCompilationContext::compute_hash(net3, {})); } // Verify all internal hash calculations are thread-safe (like ngraph::function serialization) -TEST(NetworkContext, HashOfSameMultiThreading) { - auto net1 = create_simple_function(); - auto net2 = create_simple_function(); +TEST(NetworkContext_CNNNetwork, HashOfSameMultiThreading) { + auto net1 = createNetwork(); + auto net2 = createNetwork(); std::atomic_bool fail{false}; const auto TEST_DURATION_MS = 1000; auto start = high_resolution_clock::now();