[GPU] Cleanup tuning cache methods (#16000)

This commit is contained in:
Vladimir Paramuzov 2023-03-01 16:30:47 +04:00 committed by GitHub
parent bde65c25c4
commit c5c7e4ff65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 148 deletions

View File

@ -31,8 +31,8 @@
namespace kernel_selector {
TuningCache::TuningCache(const std::string& cacheFilePath, bool createMode)
: cache(), needsSave(false) {
TuningCache::TuningCache(const std::string& cacheFilePath)
: cache() {
// Read cache file
std::ifstream tuningFile(cacheFilePath);
@ -41,13 +41,7 @@ TuningCache::TuningCache(const std::string& cacheFilePath, bool createMode)
buffer << tuningFile.rdbuf();
cache.Parse(buffer.str().c_str());
} else {
if (!createMode) {
throw std::runtime_error("Tuning file: " + cacheFilePath +
" could not be read! Must provide a valid cache file in USE_CACHE mode.");
}
cache.SetObject();
needsSave = true;
throw std::runtime_error("Tuning file: " + cacheFilePath + " could not be read! Must provide a valid cache file in USE_CACHE mode.");
}
if (cache.IsNull()) {
@ -94,39 +88,33 @@ TuningCache::TuningCache(const std::string& cacheFilePath, bool createMode)
it++;
}
}
needsSave = true;
}
//
}
TuningCache::TuningCache()
: cache(), needsSave(true) {
: cache() {
cache.SetObject();
auto v2Name = rapidjson::Value(version2Marker, cache.GetAllocator());
auto v2Obj = rapidjson::Value(rapidjson::Type::kObjectType);
cache.AddMember(v2Name, v2Obj, cache.GetAllocator());
}
TuningCache::Entry TuningCache::LoadKernel(const Params& params, bool update) {
return LoadKernel(params, params.engineInfo.computeUnitsCount, update);
TuningCache::Entry TuningCache::LoadKernel(const Params& params) {
return LoadKernel(params, params.engineInfo.computeUnitsCount);
}
TuningCache::Entry TuningCache::LoadKernel(const Params& params, uint32_t computeUnitsCount, bool update) {
TuningCache::Entry TuningCache::LoadKernel(const Params& params, uint32_t computeUnitsCount) {
bool oldVersion = false;
// Try to load from version 2
auto result = LoadKernel_v2(params, computeUnitsCount);
// Try to load from version 1
if (std::get<0>(result).empty() || update) {
if (std::get<0>(result).empty()) {
auto result_v1 = LoadKernel_v1(params, computeUnitsCount);
oldVersion = !std::get<0>(result_v1).empty();
if (oldVersion && std::get<0>(result).empty()) {
result = result_v1;
}
}
// Move cache from old version to newer
if (oldVersion && update) {
StoreKernel(params, computeUnitsCount, std::get<0>(result), std::get<1>(result));
}
return result;
}
@ -180,120 +168,13 @@ TuningCache::Entry TuningCache::LoadKernel_v2(const Params& params, uint32_t com
return std::make_tuple(prog[0].GetString(), prog[1].GetInt());
}
void TuningCache::StoreKernel(const Params& params, const std::string& implementationName, int tuneIndex) {
StoreKernel(params, params.engineInfo.computeUnitsCount, implementationName, tuneIndex);
}
void TuningCache::StoreKernel(const Params& params, uint32_t computeUnitsCount, const std::string& implementationName, int tuneIndex) {
auto kTypeStr = toString(params.GetType());
auto paramStr = params.to_cache_string_v2();
auto computeUnitsStr = std::to_string(computeUnitsCount);
auto& v2Cache = cache[version2Marker];
if (!v2Cache.HasMember(computeUnitsStr.c_str())) {
auto newName = rapidjson::Value(computeUnitsStr.c_str(), cache.GetAllocator());
auto newObj = rapidjson::Value(rapidjson::Type::kObjectType);
v2Cache.AddMember(newName, newObj, cache.GetAllocator());
}
if (!v2Cache[computeUnitsStr.c_str()].HasMember(kTypeStr.c_str())) {
auto newName = rapidjson::Value(kTypeStr.c_str(), cache.GetAllocator());
auto newObj = rapidjson::Value(rapidjson::Type::kObjectType);
v2Cache[computeUnitsStr.c_str()].AddMember(newName, newObj, cache.GetAllocator());
}
auto& deviceCache = v2Cache[computeUnitsStr.c_str()][kTypeStr.c_str()];
auto paramName = rapidjson::Value(paramStr.c_str(), cache.GetAllocator());
auto implDetails = rapidjson::Value(rapidjson::Type::kArrayType);
auto implName = rapidjson::Value(implementationName.c_str(), cache.GetAllocator());
auto implIndex = rapidjson::Value(tuneIndex);
implDetails.PushBack(implName, cache.GetAllocator());
implDetails.PushBack(implIndex, cache.GetAllocator());
deviceCache.AddMember(paramName, implDetails, cache.GetAllocator());
// Remove from old version if present
RemoveKernel_v1(params, computeUnitsCount);
needsSave = true;
}
void TuningCache::RemoveKernel(const Params& params) {
bool removed = false;
// Remove from version 2
removed |= RemoveKernel_v2(params, params.engineInfo.computeUnitsCount);
// Remove from version 1
removed |= RemoveKernel_v1(params, params.engineInfo.computeUnitsCount);
needsSave |= removed;
}
bool TuningCache::RemoveKernel_v1(const Params& params, uint32_t computeUnitsCount) {
auto hashStr = std::to_string(create_hash(params.to_string()));
auto computeUnitsStr = std::to_string(computeUnitsCount);
auto v1It = cache.FindMember(version1Marker);
if (v1It == cache.MemberEnd())
return false;
auto computeUnitsIt = v1It->value.FindMember(computeUnitsStr.c_str());
if (computeUnitsIt == v1It->value.MemberEnd())
return false;
auto hashIt = computeUnitsIt->value.FindMember(hashStr.c_str());
if (hashIt == computeUnitsIt->value.MemberEnd())
return false;
computeUnitsIt->value.RemoveMember(hashIt);
return true;
}
bool TuningCache::RemoveKernel_v2(const Params& params, uint32_t computeUnitsCount) {
auto kTypeStr = toString(params.GetType());
auto paramStr = params.to_cache_string_v2();
auto computeUnitsStr = std::to_string(computeUnitsCount);
auto v2It = cache.FindMember(version2Marker);
if (v2It == cache.MemberEnd())
return false;
auto computeUnitsIt = v2It->value.FindMember(computeUnitsStr.c_str());
if (computeUnitsIt == v2It->value.MemberEnd())
return false;
auto kTypeIt = computeUnitsIt->value.FindMember(kTypeStr.c_str());
if (kTypeIt == computeUnitsIt->value.MemberEnd())
return false;
auto paramIt = kTypeIt->value.FindMember(paramStr.c_str());
if (paramIt == kTypeIt->value.MemberEnd())
return false;
kTypeIt->value.RemoveMember(paramIt);
return true;
}
void TuningCache::Save(const std::string& cacheFilePath) {
std::ofstream cachedKernelsFile(cacheFilePath);
rapidjson::StringBuffer buffer(0, 1024);
rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
writer.SetFormatOptions(rapidjson::PrettyFormatOptions::kFormatSingleLineArray);
cache.Accept(writer);
auto temp = buffer.GetString();
cachedKernelsFile << temp;
cachedKernelsFile.close();
needsSave = false;
}
std::tuple<std::string, int> AutoTuner::LoadKernelOffline(const Params& params) {
std::lock_guard<std::mutex> lock(mutex);
static const uint32_t defaultComputeUnits = 24;
TuningCache* deviceCache = TuningCache::get();
if (!deviceCache)
return {};
auto result = deviceCache->LoadKernel(params, false);
auto result = deviceCache->LoadKernel(params);
if (std::get<0>(result).empty() && params.engineInfo.computeUnitsCount != defaultComputeUnits) {
result = deviceCache->LoadKernel(params, defaultComputeUnits);
}
@ -324,7 +205,7 @@ TuningCache* TuningCache::get() {
if (!cache_instance) {
try {
cache_instance = std::make_shared<kernel_selector::TuningCache>(path, false);
cache_instance = std::make_shared<kernel_selector::TuningCache>(path);
} catch (...) {
cache_instance = std::make_shared<kernel_selector::TuningCache>();
}

View File

@ -25,27 +25,16 @@ public:
// which may necessitate saving afterwards.
// This class is not thread-safe and all concurrent modifications should be synchronized by owner.
// cacheFilePath - Path to cache file
// createMode - Flag to enable creation if cache file does not exist.
// If file doesn't exist and createMode is false this constructor will throw.
explicit TuningCache(const std::string& cacheFilePath, bool createMode = false);
explicit TuningCache(const std::string& cacheFilePath);
// Constructs empty tuning cache.
TuningCache();
// Returns cached kernel for specified params. If "update" moves it to newest version if found, which may require saving afterwards.
Entry LoadKernel(const Params& params, bool update = true);
Entry LoadKernel(const Params& params);
// Overrides the compute units count in params.
Entry LoadKernel(const Params& params, uint32_t computeUnitsCount, bool update = true);
// Stores kernel for specified params.
void StoreKernel(const Params& params, const std::string& implementationName, int tuneIndex);
// Overrides the compute units count in params.
void StoreKernel(const Params& params, uint32_t computeUnitsCount, const std::string& implementationName, int tuneIndex);
// Removes the cached kernel for specified params if it exists, for all cache versions.
void RemoveKernel(const Params& params);
// Saves the internal cache to specified file.
void Save(const std::string& cacheFilePath);
bool NeedsSave() const { return needsSave; }
Entry LoadKernel(const Params& params, uint32_t computeUnitsCount);
static TuningCache* get();
@ -53,12 +42,7 @@ private:
Entry LoadKernel_v1(const Params& params, uint32_t computeUnitsCount);
Entry LoadKernel_v2(const Params& params, uint32_t computeUnitsCount);
bool RemoveKernel_v1(const Params& params, uint32_t computeUnitsCount);
bool RemoveKernel_v2(const Params& params, uint32_t computeUnitsCount);
rapidjson::Document cache;
bool needsSave;
static constexpr const char* version1Marker = "version_1";
static constexpr const char* version2Marker = "version_2";