[GNA] Fix KEY_EXEC_TARGET (cherry pick #7671) (#7701)

* Use Gna2DeviceCreateForExport when GNA_EXEC_TARGET is != detected

* Update detected GNA device version field in GNA Device helper

   * Use EXEC instead of COMPILE TARGET to append
   CNN Legacy enforcement (GNA1)

* Apply review
This commit is contained in:
Krzysztof Bruniecki 2021-09-28 17:39:49 +02:00 committed by GitHub
parent 068d31511b
commit eee864aed6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 8 deletions

View File

@ -169,8 +169,21 @@ void GNADeviceHelper::releaseModel(const uint32_t model_id) {
} }
bool GNADeviceHelper::enforceLegacyCnnNeeded() const { bool GNADeviceHelper::enforceLegacyCnnNeeded() const {
const auto compileTargetDevice = getTargetDevice(false); const auto execTargetDevice = getTargetDevice(true);
return (isGnaLibVersion3_0 || isGnaLibVersion2_1) && isUpTo20HwGnaDevice(compileTargetDevice); return (isGnaLibVersion3_0 || isGnaLibVersion2_1) && isUpTo20HwGnaDevice(execTargetDevice);
}
Gna2DeviceVersion GNADeviceHelper::parseTarget(const std::string& target) {
const std::map<std::string, Gna2DeviceVersion> targetMap {
{InferenceEngine::GNAConfigParams::GNA_TARGET_2_0, Gna2DeviceVersion2_0},
{InferenceEngine::GNAConfigParams::GNA_TARGET_3_0, Gna2DeviceVersion3_0},
{"", Gna2DeviceVersionSoftwareEmulation},
};
const auto f = targetMap.find(target);
if (f != targetMap.end()) {
return f->second;
}
THROW_GNA_EXCEPTION << "Unsupported " << "GNAConfigParams::GNA_TARGET = \"" << target << "\"\n";
} }
Gna2DeviceVersion GNADeviceHelper::parseDeclaredTarget(std::string target, const bool execTarget) const { Gna2DeviceVersion GNADeviceHelper::parseDeclaredTarget(std::string target, const bool execTarget) const {
@ -476,6 +489,16 @@ void GNADeviceHelper::dumpXnnForDeviceVersion(
outStream.write("Gna2ModelSueCreekHeader", 24); outStream.write("Gna2ModelSueCreekHeader", 24);
outStream.write(reinterpret_cast<const char*>(&sueHeader), sizeof(sueHeader)); outStream.write(reinterpret_cast<const char*>(&sueHeader), sizeof(sueHeader));
} }
void GNADeviceHelper::createVirtualDevice(Gna2DeviceVersion devVersion, std::string purpose) {
const auto status = Gna2DeviceCreateForExport(devVersion, &nGnaDeviceIndex);
GNADeviceHelper::checkGna2Status(status, "Gna2DeviceCreateForExport(" + std::to_string(devVersion) + ")" + purpose);
}
void GNADeviceHelper::updateGnaDeviceVersion() {
const auto status = Gna2DeviceGetVersion(nGnaDeviceIndex, &detectedGnaDevVersion);
checkGna2Status(status, "Gna2DeviceGetVersion");
}
#endif #endif
#if GNA_LIB_VER == 1 #if GNA_LIB_VER == 1
@ -492,14 +515,18 @@ void GNADeviceHelper::open(uint8_t n_threads) {
nGNAHandle = GNADeviceOpenSetThreads(&nGNAStatus, n_threads); nGNAHandle = GNADeviceOpenSetThreads(&nGNAStatus, n_threads);
checkStatus(); checkStatus();
#else #else
auto status = Gna2DeviceGetVersion(nGnaDeviceIndex, &detectedGnaDevVersion); updateGnaDeviceVersion();
checkGna2Status(status, "Gna2DeviceGetVersion"); const auto gnaExecTarget = parseTarget(executionTarget);
if (useDeviceEmbeddedExport) { if (useDeviceEmbeddedExport) {
status = Gna2DeviceCreateForExport(exportGeneration, &nGnaDeviceIndex); createVirtualDevice(exportGeneration, "export");
GNADeviceHelper::checkGna2Status(status, "Gna2DeviceCreateForExport"); } else if (!executionTarget.empty() && gnaExecTarget != detectedGnaDevVersion) {
createVirtualDevice(gnaExecTarget, "execution");
updateGnaDeviceVersion();
if (detectedGnaDevVersion != gnaExecTarget) {
THROW_GNA_EXCEPTION << "Wrong virtual GNA device version reported: " << detectedGnaDevVersion << " instead of: " << gnaExecTarget;
}
} else { } else {
status = Gna2DeviceOpen(nGnaDeviceIndex); const auto status = Gna2DeviceOpen(nGnaDeviceIndex);
checkGna2Status(status, "Gna2DeviceOpen"); checkGna2Status(status, "Gna2DeviceOpen");
} }

View File

@ -202,9 +202,13 @@ public:
static void enforceLegacyCnns(Gna2Model& gnaModel); static void enforceLegacyCnns(Gna2Model& gnaModel);
static void enforceLegacyCnnsWhenNeeded(Gna2Model& gnaModel); static void enforceLegacyCnnsWhenNeeded(Gna2Model& gnaModel);
static Gna2DeviceVersion parseTarget(const std::string& target);
Gna2DeviceVersion parseDeclaredTarget(std::string target, const bool execTarget) const; Gna2DeviceVersion parseDeclaredTarget(std::string target, const bool execTarget) const;
Gna2DeviceVersion getDefaultTarget() const; Gna2DeviceVersion getDefaultTarget() const;
Gna2DeviceVersion getTargetDevice(bool execTarget) const; Gna2DeviceVersion getTargetDevice(bool execTarget) const;
void createVirtualDevice(Gna2DeviceVersion devVersion, std::string purpose = "");
void updateGnaDeviceVersion();
#endif #endif
void setOMPThreads(uint8_t const n_threads); void setOMPThreads(uint8_t const n_threads);