[GPU] Cleanup tuning cache methods (#16000)

This commit is contained in:
Vladimir Paramuzov 2023-03-01 16:30:47 +04:00 committed by GitHub
parent bde65c25c4
commit c5c7e4ff65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 148 deletions

View File

@ -31,8 +31,8 @@
namespace kernel_selector { namespace kernel_selector {
TuningCache::TuningCache(const std::string& cacheFilePath, bool createMode) TuningCache::TuningCache(const std::string& cacheFilePath)
: cache(), needsSave(false) { : cache() {
// Read cache file // Read cache file
std::ifstream tuningFile(cacheFilePath); std::ifstream tuningFile(cacheFilePath);
@ -41,13 +41,7 @@ TuningCache::TuningCache(const std::string& cacheFilePath, bool createMode)
buffer << tuningFile.rdbuf(); buffer << tuningFile.rdbuf();
cache.Parse(buffer.str().c_str()); cache.Parse(buffer.str().c_str());
} else { } else {
if (!createMode) { throw std::runtime_error("Tuning file: " + cacheFilePath + " could not be read! Must provide a valid cache file in USE_CACHE mode.");
throw std::runtime_error("Tuning file: " + cacheFilePath +
" could not be read! Must provide a valid cache file in USE_CACHE mode.");
}
cache.SetObject();
needsSave = true;
} }
if (cache.IsNull()) { if (cache.IsNull()) {
@ -94,39 +88,33 @@ TuningCache::TuningCache(const std::string& cacheFilePath, bool createMode)
it++; it++;
} }
} }
needsSave = true;
} }
//
} }
TuningCache::TuningCache() TuningCache::TuningCache()
: cache(), needsSave(true) { : cache() {
cache.SetObject(); cache.SetObject();
auto v2Name = rapidjson::Value(version2Marker, cache.GetAllocator()); auto v2Name = rapidjson::Value(version2Marker, cache.GetAllocator());
auto v2Obj = rapidjson::Value(rapidjson::Type::kObjectType); auto v2Obj = rapidjson::Value(rapidjson::Type::kObjectType);
cache.AddMember(v2Name, v2Obj, cache.GetAllocator()); cache.AddMember(v2Name, v2Obj, cache.GetAllocator());
} }
TuningCache::Entry TuningCache::LoadKernel(const Params& params, bool update) { TuningCache::Entry TuningCache::LoadKernel(const Params& params) {
return LoadKernel(params, params.engineInfo.computeUnitsCount, update); return LoadKernel(params, params.engineInfo.computeUnitsCount);
} }
TuningCache::Entry TuningCache::LoadKernel(const Params& params, uint32_t computeUnitsCount, bool update) { TuningCache::Entry TuningCache::LoadKernel(const Params& params, uint32_t computeUnitsCount) {
bool oldVersion = false; bool oldVersion = false;
// Try to load from version 2 // Try to load from version 2
auto result = LoadKernel_v2(params, computeUnitsCount); auto result = LoadKernel_v2(params, computeUnitsCount);
// Try to load from version 1 // Try to load from version 1
if (std::get<0>(result).empty() || update) { if (std::get<0>(result).empty()) {
auto result_v1 = LoadKernel_v1(params, computeUnitsCount); auto result_v1 = LoadKernel_v1(params, computeUnitsCount);
oldVersion = !std::get<0>(result_v1).empty(); oldVersion = !std::get<0>(result_v1).empty();
if (oldVersion && std::get<0>(result).empty()) { if (oldVersion && std::get<0>(result).empty()) {
result = result_v1; result = result_v1;
} }
} }
// Move cache from old version to newer
if (oldVersion && update) {
StoreKernel(params, computeUnitsCount, std::get<0>(result), std::get<1>(result));
}
return result; return result;
} }
@ -180,120 +168,13 @@ TuningCache::Entry TuningCache::LoadKernel_v2(const Params& params, uint32_t com
return std::make_tuple(prog[0].GetString(), prog[1].GetInt()); return std::make_tuple(prog[0].GetString(), prog[1].GetInt());
} }
void TuningCache::StoreKernel(const Params& params, const std::string& implementationName, int tuneIndex) {
StoreKernel(params, params.engineInfo.computeUnitsCount, implementationName, tuneIndex);
}
void TuningCache::StoreKernel(const Params& params, uint32_t computeUnitsCount, const std::string& implementationName, int tuneIndex) {
auto kTypeStr = toString(params.GetType());
auto paramStr = params.to_cache_string_v2();
auto computeUnitsStr = std::to_string(computeUnitsCount);
auto& v2Cache = cache[version2Marker];
if (!v2Cache.HasMember(computeUnitsStr.c_str())) {
auto newName = rapidjson::Value(computeUnitsStr.c_str(), cache.GetAllocator());
auto newObj = rapidjson::Value(rapidjson::Type::kObjectType);
v2Cache.AddMember(newName, newObj, cache.GetAllocator());
}
if (!v2Cache[computeUnitsStr.c_str()].HasMember(kTypeStr.c_str())) {
auto newName = rapidjson::Value(kTypeStr.c_str(), cache.GetAllocator());
auto newObj = rapidjson::Value(rapidjson::Type::kObjectType);
v2Cache[computeUnitsStr.c_str()].AddMember(newName, newObj, cache.GetAllocator());
}
auto& deviceCache = v2Cache[computeUnitsStr.c_str()][kTypeStr.c_str()];
auto paramName = rapidjson::Value(paramStr.c_str(), cache.GetAllocator());
auto implDetails = rapidjson::Value(rapidjson::Type::kArrayType);
auto implName = rapidjson::Value(implementationName.c_str(), cache.GetAllocator());
auto implIndex = rapidjson::Value(tuneIndex);
implDetails.PushBack(implName, cache.GetAllocator());
implDetails.PushBack(implIndex, cache.GetAllocator());
deviceCache.AddMember(paramName, implDetails, cache.GetAllocator());
// Remove from old version if present
RemoveKernel_v1(params, computeUnitsCount);
needsSave = true;
}
void TuningCache::RemoveKernel(const Params& params) {
bool removed = false;
// Remove from version 2
removed |= RemoveKernel_v2(params, params.engineInfo.computeUnitsCount);
// Remove from version 1
removed |= RemoveKernel_v1(params, params.engineInfo.computeUnitsCount);
needsSave |= removed;
}
bool TuningCache::RemoveKernel_v1(const Params& params, uint32_t computeUnitsCount) {
auto hashStr = std::to_string(create_hash(params.to_string()));
auto computeUnitsStr = std::to_string(computeUnitsCount);
auto v1It = cache.FindMember(version1Marker);
if (v1It == cache.MemberEnd())
return false;
auto computeUnitsIt = v1It->value.FindMember(computeUnitsStr.c_str());
if (computeUnitsIt == v1It->value.MemberEnd())
return false;
auto hashIt = computeUnitsIt->value.FindMember(hashStr.c_str());
if (hashIt == computeUnitsIt->value.MemberEnd())
return false;
computeUnitsIt->value.RemoveMember(hashIt);
return true;
}
bool TuningCache::RemoveKernel_v2(const Params& params, uint32_t computeUnitsCount) {
auto kTypeStr = toString(params.GetType());
auto paramStr = params.to_cache_string_v2();
auto computeUnitsStr = std::to_string(computeUnitsCount);
auto v2It = cache.FindMember(version2Marker);
if (v2It == cache.MemberEnd())
return false;
auto computeUnitsIt = v2It->value.FindMember(computeUnitsStr.c_str());
if (computeUnitsIt == v2It->value.MemberEnd())
return false;
auto kTypeIt = computeUnitsIt->value.FindMember(kTypeStr.c_str());
if (kTypeIt == computeUnitsIt->value.MemberEnd())
return false;
auto paramIt = kTypeIt->value.FindMember(paramStr.c_str());
if (paramIt == kTypeIt->value.MemberEnd())
return false;
kTypeIt->value.RemoveMember(paramIt);
return true;
}
void TuningCache::Save(const std::string& cacheFilePath) {
std::ofstream cachedKernelsFile(cacheFilePath);
rapidjson::StringBuffer buffer(0, 1024);
rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
writer.SetFormatOptions(rapidjson::PrettyFormatOptions::kFormatSingleLineArray);
cache.Accept(writer);
auto temp = buffer.GetString();
cachedKernelsFile << temp;
cachedKernelsFile.close();
needsSave = false;
}
std::tuple<std::string, int> AutoTuner::LoadKernelOffline(const Params& params) { std::tuple<std::string, int> AutoTuner::LoadKernelOffline(const Params& params) {
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
static const uint32_t defaultComputeUnits = 24; static const uint32_t defaultComputeUnits = 24;
TuningCache* deviceCache = TuningCache::get(); TuningCache* deviceCache = TuningCache::get();
if (!deviceCache) if (!deviceCache)
return {}; return {};
auto result = deviceCache->LoadKernel(params, false); auto result = deviceCache->LoadKernel(params);
if (std::get<0>(result).empty() && params.engineInfo.computeUnitsCount != defaultComputeUnits) { if (std::get<0>(result).empty() && params.engineInfo.computeUnitsCount != defaultComputeUnits) {
result = deviceCache->LoadKernel(params, defaultComputeUnits); result = deviceCache->LoadKernel(params, defaultComputeUnits);
} }
@ -324,7 +205,7 @@ TuningCache* TuningCache::get() {
if (!cache_instance) { if (!cache_instance) {
try { try {
cache_instance = std::make_shared<kernel_selector::TuningCache>(path, false); cache_instance = std::make_shared<kernel_selector::TuningCache>(path);
} catch (...) { } catch (...) {
cache_instance = std::make_shared<kernel_selector::TuningCache>(); cache_instance = std::make_shared<kernel_selector::TuningCache>();
} }

View File

@ -25,27 +25,16 @@ public:
// which may necessitate saving afterwards. // which may necessitate saving afterwards.
// This class is not thread-safe and all concurrent modifications should be synchronized by owner. // This class is not thread-safe and all concurrent modifications should be synchronized by owner.
// cacheFilePath - Path to cache file // cacheFilePath - Path to cache file
// createMode - Flag to enable creation if cache file does not exist.
// If file doesn't exist and createMode is false this constructor will throw. // If file doesn't exist and createMode is false this constructor will throw.
explicit TuningCache(const std::string& cacheFilePath, bool createMode = false); explicit TuningCache(const std::string& cacheFilePath);
// Constructs empty tuning cache. // Constructs empty tuning cache.
TuningCache(); TuningCache();
// Returns cached kernel for specified params. If "update" moves it to newest version if found, which may require saving afterwards. // Returns cached kernel for specified params. If "update" moves it to newest version if found, which may require saving afterwards.
Entry LoadKernel(const Params& params, bool update = true); Entry LoadKernel(const Params& params);
// Overrides the compute units count in params. // Overrides the compute units count in params.
Entry LoadKernel(const Params& params, uint32_t computeUnitsCount, bool update = true); Entry LoadKernel(const Params& params, uint32_t computeUnitsCount);
// Stores kernel for specified params.
void StoreKernel(const Params& params, const std::string& implementationName, int tuneIndex);
// Overrides the compute units count in params.
void StoreKernel(const Params& params, uint32_t computeUnitsCount, const std::string& implementationName, int tuneIndex);
// Removes the cached kernel for specified params if it exists, for all cache versions.
void RemoveKernel(const Params& params);
// Saves the internal cache to specified file.
void Save(const std::string& cacheFilePath);
bool NeedsSave() const { return needsSave; }
static TuningCache* get(); static TuningCache* get();
@ -53,12 +42,7 @@ private:
Entry LoadKernel_v1(const Params& params, uint32_t computeUnitsCount); Entry LoadKernel_v1(const Params& params, uint32_t computeUnitsCount);
Entry LoadKernel_v2(const Params& params, uint32_t computeUnitsCount); Entry LoadKernel_v2(const Params& params, uint32_t computeUnitsCount);
bool RemoveKernel_v1(const Params& params, uint32_t computeUnitsCount);
bool RemoveKernel_v2(const Params& params, uint32_t computeUnitsCount);
rapidjson::Document cache; rapidjson::Document cache;
bool needsSave;
static constexpr const char* version1Marker = "version_1"; static constexpr const char* version1Marker = "version_1";
static constexpr const char* version2Marker = "version_2"; static constexpr const char* version2Marker = "version_2";