[GPU] Cleanup tuning cache methods (#16000)
This commit is contained in:
parent
bde65c25c4
commit
c5c7e4ff65
@ -31,8 +31,8 @@
|
||||
|
||||
namespace kernel_selector {
|
||||
|
||||
TuningCache::TuningCache(const std::string& cacheFilePath, bool createMode)
|
||||
: cache(), needsSave(false) {
|
||||
TuningCache::TuningCache(const std::string& cacheFilePath)
|
||||
: cache() {
|
||||
// Read cache file
|
||||
std::ifstream tuningFile(cacheFilePath);
|
||||
|
||||
@ -41,13 +41,7 @@ TuningCache::TuningCache(const std::string& cacheFilePath, bool createMode)
|
||||
buffer << tuningFile.rdbuf();
|
||||
cache.Parse(buffer.str().c_str());
|
||||
} else {
|
||||
if (!createMode) {
|
||||
throw std::runtime_error("Tuning file: " + cacheFilePath +
|
||||
" could not be read! Must provide a valid cache file in USE_CACHE mode.");
|
||||
}
|
||||
|
||||
cache.SetObject();
|
||||
needsSave = true;
|
||||
throw std::runtime_error("Tuning file: " + cacheFilePath + " could not be read! Must provide a valid cache file in USE_CACHE mode.");
|
||||
}
|
||||
|
||||
if (cache.IsNull()) {
|
||||
@ -94,39 +88,33 @@ TuningCache::TuningCache(const std::string& cacheFilePath, bool createMode)
|
||||
it++;
|
||||
}
|
||||
}
|
||||
needsSave = true;
|
||||
}
|
||||
//
|
||||
}
|
||||
|
||||
// Constructs an empty tuning cache: the JSON document is turned into an object that
// contains a single member, named by version2Marker, whose value is an empty object.
// NOTE(review): two initializer lists appear below (one also sets needsSave, one does not).
// This listing looks like a diff whose +/- markers were stripped, so presumably only one
// of them belongs to any given revision - confirm against the repository.
TuningCache::TuningCache()
|
||||
: cache(), needsSave(true) {
|
||||
: cache() {
|
||||
cache.SetObject();
|
||||
// The key string is copied with the document's allocator.
auto v2Name = rapidjson::Value(version2Marker, cache.GetAllocator());
|
||||
auto v2Obj = rapidjson::Value(rapidjson::Type::kObjectType);
|
||||
cache.AddMember(v2Name, v2Obj, cache.GetAllocator());
|
||||
}
|
||||
|
||||
// Loads the cached entry for params, forwarding to the overload that takes an explicit
// compute-units count and passing params.engineInfo.computeUnitsCount for it.
// NOTE(review): two signatures and two return statements appear below (with and without
// the "update" flag). This listing looks like a diff whose +/- markers were stripped, so
// presumably only one pair belongs to any given revision - confirm against the repository.
TuningCache::Entry TuningCache::LoadKernel(const Params& params, bool update) {
|
||||
return LoadKernel(params, params.engineInfo.computeUnitsCount, update);
|
||||
TuningCache::Entry TuningCache::LoadKernel(const Params& params) {
|
||||
return LoadKernel(params, params.engineInfo.computeUnitsCount);
|
||||
}
|
||||
|
||||
// Loads the cached entry for params using the given compute-units count instead of the
// one stored in params. The version 2 section is queried first; the version 1 section is
// the fallback when the version 2 result has an empty implementation name (tuple element 0).
// Returns whatever was found; the implementation name is empty when neither lookup hit.
// NOTE(review): two signatures and two alternative "if" headers appear below (with and
// without the "update" flag). This listing looks like a diff whose +/- markers were
// stripped, so presumably only one variant belongs to any given revision - confirm
// against the repository.
TuningCache::Entry TuningCache::LoadKernel(const Params& params, uint32_t computeUnitsCount, bool update) {
|
||||
TuningCache::Entry TuningCache::LoadKernel(const Params& params, uint32_t computeUnitsCount) {
|
||||
// Set when the version 1 lookup returns a non-empty implementation name.
bool oldVersion = false;
|
||||
// Try to load from version 2
|
||||
auto result = LoadKernel_v2(params, computeUnitsCount);
|
||||
// Try to load from version 1
|
||||
if (std::get<0>(result).empty() || update) {
|
||||
if (std::get<0>(result).empty()) {
|
||||
auto result_v1 = LoadKernel_v1(params, computeUnitsCount);
|
||||
oldVersion = !std::get<0>(result_v1).empty();
|
||||
// The version 1 entry is only returned when version 2 had nothing.
if (oldVersion && std::get<0>(result).empty()) {
|
||||
result = result_v1;
|
||||
}
|
||||
}
|
||||
// Move cache from old version to newer
|
||||
if (oldVersion && update) {
|
||||
StoreKernel(params, computeUnitsCount, std::get<0>(result), std::get<1>(result));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -180,120 +168,13 @@ TuningCache::Entry TuningCache::LoadKernel_v2(const Params& params, uint32_t com
|
||||
return std::make_tuple(prog[0].GetString(), prog[1].GetInt());
|
||||
}
|
||||
|
||||
// Stores the tuning result (implementation name + tune index) for params, keyed by the
// compute-units count found in params.engineInfo. Forwards to the four-argument overload.
void TuningCache::StoreKernel(const Params& params, const std::string& implementationName, int tuneIndex) {
    const auto deviceComputeUnits = params.engineInfo.computeUnitsCount;
    StoreKernel(params, deviceComputeUnits, implementationName, tuneIndex);
}
|
||||
|
||||
void TuningCache::StoreKernel(const Params& params, uint32_t computeUnitsCount, const std::string& implementationName, int tuneIndex) {
|
||||
auto kTypeStr = toString(params.GetType());
|
||||
auto paramStr = params.to_cache_string_v2();
|
||||
auto computeUnitsStr = std::to_string(computeUnitsCount);
|
||||
auto& v2Cache = cache[version2Marker];
|
||||
|
||||
if (!v2Cache.HasMember(computeUnitsStr.c_str())) {
|
||||
auto newName = rapidjson::Value(computeUnitsStr.c_str(), cache.GetAllocator());
|
||||
auto newObj = rapidjson::Value(rapidjson::Type::kObjectType);
|
||||
v2Cache.AddMember(newName, newObj, cache.GetAllocator());
|
||||
}
|
||||
|
||||
if (!v2Cache[computeUnitsStr.c_str()].HasMember(kTypeStr.c_str())) {
|
||||
auto newName = rapidjson::Value(kTypeStr.c_str(), cache.GetAllocator());
|
||||
auto newObj = rapidjson::Value(rapidjson::Type::kObjectType);
|
||||
v2Cache[computeUnitsStr.c_str()].AddMember(newName, newObj, cache.GetAllocator());
|
||||
}
|
||||
|
||||
auto& deviceCache = v2Cache[computeUnitsStr.c_str()][kTypeStr.c_str()];
|
||||
|
||||
auto paramName = rapidjson::Value(paramStr.c_str(), cache.GetAllocator());
|
||||
auto implDetails = rapidjson::Value(rapidjson::Type::kArrayType);
|
||||
auto implName = rapidjson::Value(implementationName.c_str(), cache.GetAllocator());
|
||||
auto implIndex = rapidjson::Value(tuneIndex);
|
||||
implDetails.PushBack(implName, cache.GetAllocator());
|
||||
implDetails.PushBack(implIndex, cache.GetAllocator());
|
||||
|
||||
deviceCache.AddMember(paramName, implDetails, cache.GetAllocator());
|
||||
|
||||
// Remove from old version if present
|
||||
RemoveKernel_v1(params, computeUnitsCount);
|
||||
|
||||
needsSave = true;
|
||||
}
|
||||
|
||||
// Removes the cached entry for params from both cache versions, using the compute-units
// count from params.engineInfo. Both sections are always visited (version 2 first); the
// cache is flagged as needing a save when at least one of them actually held the entry.
void TuningCache::RemoveKernel(const Params& params) {
    const auto deviceComputeUnits = params.engineInfo.computeUnitsCount;

    const bool removedFromV2 = RemoveKernel_v2(params, deviceComputeUnits);
    const bool removedFromV1 = RemoveKernel_v1(params, deviceComputeUnits);

    if (removedFromV2 || removedFromV1) {
        needsSave = true;
    }
}
|
||||
|
||||
bool TuningCache::RemoveKernel_v1(const Params& params, uint32_t computeUnitsCount) {
|
||||
auto hashStr = std::to_string(create_hash(params.to_string()));
|
||||
auto computeUnitsStr = std::to_string(computeUnitsCount);
|
||||
|
||||
auto v1It = cache.FindMember(version1Marker);
|
||||
if (v1It == cache.MemberEnd())
|
||||
return false;
|
||||
|
||||
auto computeUnitsIt = v1It->value.FindMember(computeUnitsStr.c_str());
|
||||
if (computeUnitsIt == v1It->value.MemberEnd())
|
||||
return false;
|
||||
|
||||
auto hashIt = computeUnitsIt->value.FindMember(hashStr.c_str());
|
||||
if (hashIt == computeUnitsIt->value.MemberEnd())
|
||||
return false;
|
||||
|
||||
computeUnitsIt->value.RemoveMember(hashIt);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TuningCache::RemoveKernel_v2(const Params& params, uint32_t computeUnitsCount) {
|
||||
auto kTypeStr = toString(params.GetType());
|
||||
auto paramStr = params.to_cache_string_v2();
|
||||
auto computeUnitsStr = std::to_string(computeUnitsCount);
|
||||
|
||||
auto v2It = cache.FindMember(version2Marker);
|
||||
if (v2It == cache.MemberEnd())
|
||||
return false;
|
||||
|
||||
auto computeUnitsIt = v2It->value.FindMember(computeUnitsStr.c_str());
|
||||
if (computeUnitsIt == v2It->value.MemberEnd())
|
||||
return false;
|
||||
|
||||
auto kTypeIt = computeUnitsIt->value.FindMember(kTypeStr.c_str());
|
||||
if (kTypeIt == computeUnitsIt->value.MemberEnd())
|
||||
return false;
|
||||
|
||||
auto paramIt = kTypeIt->value.FindMember(paramStr.c_str());
|
||||
if (paramIt == kTypeIt->value.MemberEnd())
|
||||
return false;
|
||||
|
||||
kTypeIt->value.RemoveMember(paramIt);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Serializes the whole cache as pretty-printed JSON (arrays kept on a single line) and
// writes it to cacheFilePath, replacing the file's previous content.
//
// cacheFilePath - destination file.
//
// The needsSave flag is cleared only when the file was opened, written and closed without
// a stream error. If any of those steps fails the flag is left untouched, so NeedsSave()
// keeps reporting that the in-memory state has not been persisted. No exception is thrown
// for I/O failures.
void TuningCache::Save(const std::string& cacheFilePath) {
    // Serialize first, so the destination is not truncated before there is content for it.
    rapidjson::StringBuffer buffer(0, 1024);
    rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer);
    writer.SetFormatOptions(rapidjson::PrettyFormatOptions::kFormatSingleLineArray);
    cache.Accept(writer);

    std::ofstream cachedKernelsFile(cacheFilePath);
    cachedKernelsFile << buffer.GetString();
    // close() flushes; a failed open, write or flush leaves failbit/badbit set.
    cachedKernelsFile.close();

    if (!cachedKernelsFile.fail()) {
        needsSave = false;
    }
}
|
||||
|
||||
std::tuple<std::string, int> AutoTuner::LoadKernelOffline(const Params& params) {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
static const uint32_t defaultComputeUnits = 24;
|
||||
TuningCache* deviceCache = TuningCache::get();
|
||||
if (!deviceCache)
|
||||
return {};
|
||||
auto result = deviceCache->LoadKernel(params, false);
|
||||
auto result = deviceCache->LoadKernel(params);
|
||||
if (std::get<0>(result).empty() && params.engineInfo.computeUnitsCount != defaultComputeUnits) {
|
||||
result = deviceCache->LoadKernel(params, defaultComputeUnits);
|
||||
}
|
||||
@ -324,7 +205,7 @@ TuningCache* TuningCache::get() {
|
||||
|
||||
if (!cache_instance) {
|
||||
try {
|
||||
cache_instance = std::make_shared<kernel_selector::TuningCache>(path, false);
|
||||
cache_instance = std::make_shared<kernel_selector::TuningCache>(path);
|
||||
} catch (...) {
|
||||
cache_instance = std::make_shared<kernel_selector::TuningCache>();
|
||||
}
|
||||
|
@ -25,27 +25,16 @@ public:
|
||||
// which may necessitate saving afterwards.
|
||||
// This class is not thread-safe and all concurrent modifications should be synchronized by owner.
|
||||
// cacheFilePath - Path to cache file
|
||||
// createMode - Flag to enable creation if cache file does not exist.
|
||||
// If file doesn't exist and createMode is false this constructor will throw.
|
||||
explicit TuningCache(const std::string& cacheFilePath, bool createMode = false);
|
||||
explicit TuningCache(const std::string& cacheFilePath);
|
||||
|
||||
// Constructs empty tuning cache.
|
||||
TuningCache();
|
||||
|
||||
// Returns the cached kernel for the specified params. If "update" is set and the entry is found in an older cache version, it is moved to the newest version, which may require saving afterwards.
|
||||
Entry LoadKernel(const Params& params, bool update = true);
|
||||
Entry LoadKernel(const Params& params);
|
||||
// Overrides the compute units count in params.
|
||||
Entry LoadKernel(const Params& params, uint32_t computeUnitsCount, bool update = true);
|
||||
// Stores kernel for specified params.
|
||||
void StoreKernel(const Params& params, const std::string& implementationName, int tuneIndex);
|
||||
// Overrides the compute units count in params.
|
||||
void StoreKernel(const Params& params, uint32_t computeUnitsCount, const std::string& implementationName, int tuneIndex);
|
||||
// Removes the cached kernel for specified params if it exists, for all cache versions.
|
||||
void RemoveKernel(const Params& params);
|
||||
// Saves the internal cache to specified file.
|
||||
void Save(const std::string& cacheFilePath);
|
||||
|
||||
// Returns the needsSave flag: true while the in-memory cache holds changes that have not been written out by Save().
bool NeedsSave() const { return needsSave; }
|
||||
Entry LoadKernel(const Params& params, uint32_t computeUnitsCount);
|
||||
|
||||
static TuningCache* get();
|
||||
|
||||
@ -53,12 +42,7 @@ private:
|
||||
Entry LoadKernel_v1(const Params& params, uint32_t computeUnitsCount);
|
||||
Entry LoadKernel_v2(const Params& params, uint32_t computeUnitsCount);
|
||||
|
||||
bool RemoveKernel_v1(const Params& params, uint32_t computeUnitsCount);
|
||||
bool RemoveKernel_v2(const Params& params, uint32_t computeUnitsCount);
|
||||
|
||||
|
||||
rapidjson::Document cache;
|
||||
bool needsSave;
|
||||
|
||||
static constexpr const char* version1Marker = "version_1";
|
||||
static constexpr const char* version2Marker = "version_2";
|
||||
|
Loading…
Reference in New Issue
Block a user