[DO NOT REVIEW OR MERGE] LLM in SubgraphsDumper (#20756)

Irina Efode 2023-10-31 19:57:41 +04:00 committed by GitHub
parent 65f6950f56
commit dd10a520e3
6 changed files with 28 additions and 9 deletions

View File

@@ -45,7 +45,9 @@ public:
     bool is_model_large_to_store_const(const std::shared_ptr<ov::Model>& model) {
         auto model_bytesize = model->get_graph_size();
-        if (mem_size < model_bytesize * 4) {
+        size_t gb_8 = 1;
+        gb_8 <<= 33;
+        if (mem_size <= model_bytesize * 4 || model_bytesize >= gb_8) {
             return true;
         }
         return false;

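For context: `gb_8` is built by shifting 1 left by 33 bits, i.e. 2^33 bytes = 8 GiB. A minimal, self-contained sketch of the updated predicate (free-function form; `mem_size` and `model_bytesize` stand in for the class members used in the commit):

#include <cstddef>
#include <iostream>

// Sketch of the check above, with the class members passed as parameters.
static bool is_model_large_to_store_const(size_t mem_size, size_t model_bytesize) {
    size_t gb_8 = 1;
    gb_8 <<= 33;  // 2^33 bytes == 8 GiB
    // "Large" if four copies of the model would exhaust RAM,
    // or if the serialized graph alone is at least 8 GiB.
    return mem_size <= model_bytesize * 4 || model_bytesize >= gb_8;
}

int main() {
    const size_t GiB = 1ull << 30;
    std::cout << std::boolalpha
              << is_model_large_to_store_const(64 * GiB, 9 * GiB) << '\n'   // true: model >= 8 GiB
              << is_model_large_to_store_const(64 * GiB, 2 * GiB) << '\n';  // false: fits comfortably
}
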
View File

@@ -25,7 +25,7 @@ DEFINE_bool(h, false, help_message);
 DEFINE_string(input_folders, "", input_folders_message);
 DEFINE_string(local_cache, "", local_cache_message);
 DEFINE_string(output_folder, "output", output_folder_message);
-DEFINE_string(device, "CPU", device_message);
+DEFINE_string(device, "TEMPLATE", device_message);
 DEFINE_string(path_regex, ".*", output_folder_message);
 DEFINE_bool(extract_body, true, extract_body_message);
 DEFINE_string(cache_type, "", cache_type_message);

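Each DEFINE_string above declares a FLAGS_<name> global that gflags fills in at startup; the change only flips the default target from CPU to the TEMPLATE reference plugin. A hedged sketch of how such a flag is typically consumed (the help text here is illustrative, not the tool's real device_message):

#include <gflags/gflags.h>
#include <iostream>

DEFINE_string(device, "TEMPLATE", "device to run extracted subgraphs on");  // illustrative help text

int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
    // FLAGS_device stays "TEMPLATE" unless e.g. --device=CPU is passed.
    std::cout << "target device: " << FLAGS_device << std::endl;
    return 0;
}
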
View File

@@ -37,11 +37,16 @@ void GraphCache::update_cache(const std::shared_ptr<ov::Model>& model,
     auto model_bytesize = model->get_graph_size();
     // check that free RAM memory is enough; serialize the cache otherwise
     // serialize the graph cache in case its bytesize exceeds 4GB, to avoid long searches over the same graphs
-    if (m_graph_cache_bytesize + 2 * model_bytesize > mem_size || m_graph_cache_bytesize >> 20 != 0) {
+    if (m_graph_cache_bytesize + 2 * model_bytesize >= mem_size) {
         std::cout << "[ GRAPH CACHE ][ WARNING ] There is not enough RAM memory! Serializing graph cache" << std::endl;
         serialize_cache();
         m_graph_cache_bytesize = 0;
     }
+    if (m_graph_cache_bytesize * 4 >= mem_size) {
+        std::cout << "[ GRAPH CACHE ][ WARNING ] 25% of RAM is used by the cache! Serializing graph cache" << std::endl;
+        serialize_cache();
+        m_graph_cache_bytesize = 0;
+    }
     auto is_large_model = is_model_large_to_store_const(model);
     if (is_large_model) {
         auto model_bytesize_gb = model_bytesize;
@@ -49,7 +54,7 @@ void GraphCache::update_cache(const std::shared_ptr<ov::Model>& model,
         auto mem_size_gb = mem_size;
         mem_size_gb >>= 30;
         std::cout << "[ GRAPH CACHE ][ WARNING ] Model bytesize is " << model_bytesize_gb <<
-            "GB. It is larger than 25% RAM size: " << mem_size_gb << ". Constants won't be copied!" << std::endl;
+            "GB. It is larger than 25% RAM size or >= 8GB: " << mem_size_gb << ". Constants won't be copied!" << std::endl;
     }
     auto extracted_patterns = m_manager.extract(model, extract_body, !is_large_model);
     if (extracted_patterns.empty()) {
@@ -169,11 +174,12 @@ void GraphCache::update_cache(const std::shared_ptr<ov::Model>& extracted_model,
 }
 
 void GraphCache::serialize_cache() {
-    for (const auto& cache_item : m_graph_cache) {
-        auto rel_dir = ov::util::path_join({ m_cache_subdir, get_model_type(cache_item.first), cache_item.second.get_any_extractor() });
-        serialize_model(cache_item, rel_dir);
+    while (!m_graph_cache.empty()) {
+        auto cache_item = m_graph_cache.begin();
+        auto rel_dir = ov::util::path_join({ m_cache_subdir, get_model_type(cache_item->first), cache_item->second.get_any_extractor() });
+        serialize_model(*cache_item, rel_dir);
+        m_graph_cache.erase(cache_item);
     }
-    m_graph_cache.clear();
 }
 
 } // namespace subgraph_dumper

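Two things change in this file. First, the cache is now flushed on either of two triggers: not enough free RAM for the incoming model, or the cache growing past a quarter of RAM. Second, serialize_cache() drains the map entry by entry instead of iterating and then calling clear(), so each model's memory is released as soon as it is written out. A self-contained sketch of that drain pattern (std::map and a print stand in for the cache and serialize_model()):

#include <iostream>
#include <map>
#include <string>

int main() {
    std::map<std::string, int> cache{{"model_a", 1}, {"model_b", 2}};
    // Drain pattern from serialize_cache(): flush the first entry, then
    // erase it immediately so its memory is freed before the next one is
    // handled. Erasing inside a range-for would invalidate the loop
    // iterator; begin()/erase() avoids that.
    while (!cache.empty()) {
        auto item = cache.begin();
        std::cout << "serialized " << item->first << '\n';  // stands in for serialize_model()
        cache.erase(item);  // entry freed immediately
    }
}
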
View File

@@ -6,6 +6,7 @@
 #include "openvino/op/tensor_iterator.hpp"
 #include "openvino/op/if.hpp"
 #include "openvino/op/loop.hpp"
+#include "openvino/util/file_util.hpp"
 
 #include "common_test_utils/common_utils.hpp"
@@ -16,6 +17,13 @@ using namespace ov::tools::subgraph_dumper;
 
 void FusedNamesExtractor::set_target_device(const std::string& _device) {
     auto available_devices = core->get_available_devices();
+    if (_device == std::string(ov::test::utils::DEVICE_TEMPLATE) &&
+        std::find(available_devices.begin(), available_devices.end(), _device) == available_devices.end()) {
+        auto plugin_path = ov::util::make_plugin_library_name(ov::test::utils::getExecutableDirectory(),
+                                                              std::string(ov::test::utils::TEMPLATE_LIB) + OV_BUILD_POSTFIX);
+        core->register_plugin(plugin_path, _device);
+        available_devices = core->get_available_devices();
+    }
     if (_device.empty() && !available_devices.empty()) {
         device = available_devices.front();
         std::cout << "[ WARNING ][ GRAPH CACHE ] " << device <<

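The new branch (and the file_util.hpp include above) makes the extractor self-sufficient on machines where the TEMPLATE reference plugin is not registered: it builds the plugin library path next to the executable via ov::util::make_plugin_library_name() and registers it on the fly. A hedged sketch of the same fallback against the public ov::Core API (the library filename is illustrative and platform-dependent):

#include <algorithm>
#include <string>
#include <vector>
#include "openvino/runtime/core.hpp"

int main() {
    ov::Core core;
    std::vector<std::string> devices = core.get_available_devices();
    if (std::find(devices.begin(), devices.end(), "TEMPLATE") == devices.end()) {
        // Illustrative path; the commit derives it from the executable
        // directory plus the build postfix instead of hard-coding it.
        core.register_plugin("./libopenvino_template_plugin.so", "TEMPLATE");
        devices = core.get_available_devices();  // TEMPLATE is now visible
    }
}
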
View File

@@ -22,7 +22,7 @@ using namespace ov::tools::subgraph_dumper;
 
 // ======================= ExtractorsManagerTest Unit tests =======================
 
 class FusedNamesExtractorTest : public SubgraphsDumperBaseTest {
-    FusedNamesExtractor extractor;
+    FusedNamesExtractor extractor = FusedNamesExtractor("TEMPLATE");
 
 protected:
     void is_match(const std::shared_ptr<ov::Model>& model) {

View File

@@ -90,6 +90,9 @@ def generate_model_list_file(input_str: str, re_exp_file_path: str, output_file_
         except:
             pass
         for line in model_list:
+            str_line = str(line)
+            if "tfhub_module.pb" in str_line or "_metadata.pb" in str_line:
+                continue
             output_file.write(f"{line}\n")
     output_file.close()