* Use Gna2DeviceCreateForExport when GNA_EXEC_TARGET is != detected * Update detected GNA device version field in GNA Device helper * Use EXEC instead of COMPILE TARGET to append CNN Legacy enforcement (GNA1) * Apply review
This commit is contained in:
parent
068d31511b
commit
eee864aed6
@ -169,8 +169,21 @@ void GNADeviceHelper::releaseModel(const uint32_t model_id) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool GNADeviceHelper::enforceLegacyCnnNeeded() const {
|
bool GNADeviceHelper::enforceLegacyCnnNeeded() const {
|
||||||
const auto compileTargetDevice = getTargetDevice(false);
|
const auto execTargetDevice = getTargetDevice(true);
|
||||||
return (isGnaLibVersion3_0 || isGnaLibVersion2_1) && isUpTo20HwGnaDevice(compileTargetDevice);
|
return (isGnaLibVersion3_0 || isGnaLibVersion2_1) && isUpTo20HwGnaDevice(execTargetDevice);
|
||||||
|
}
|
||||||
|
|
||||||
|
Gna2DeviceVersion GNADeviceHelper::parseTarget(const std::string& target) {
|
||||||
|
const std::map<std::string, Gna2DeviceVersion> targetMap {
|
||||||
|
{InferenceEngine::GNAConfigParams::GNA_TARGET_2_0, Gna2DeviceVersion2_0},
|
||||||
|
{InferenceEngine::GNAConfigParams::GNA_TARGET_3_0, Gna2DeviceVersion3_0},
|
||||||
|
{"", Gna2DeviceVersionSoftwareEmulation},
|
||||||
|
};
|
||||||
|
const auto f = targetMap.find(target);
|
||||||
|
if (f != targetMap.end()) {
|
||||||
|
return f->second;
|
||||||
|
}
|
||||||
|
THROW_GNA_EXCEPTION << "Unsupported " << "GNAConfigParams::GNA_TARGET = \"" << target << "\"\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
Gna2DeviceVersion GNADeviceHelper::parseDeclaredTarget(std::string target, const bool execTarget) const {
|
Gna2DeviceVersion GNADeviceHelper::parseDeclaredTarget(std::string target, const bool execTarget) const {
|
||||||
@ -476,6 +489,16 @@ void GNADeviceHelper::dumpXnnForDeviceVersion(
|
|||||||
outStream.write("Gna2ModelSueCreekHeader", 24);
|
outStream.write("Gna2ModelSueCreekHeader", 24);
|
||||||
outStream.write(reinterpret_cast<const char*>(&sueHeader), sizeof(sueHeader));
|
outStream.write(reinterpret_cast<const char*>(&sueHeader), sizeof(sueHeader));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GNADeviceHelper::createVirtualDevice(Gna2DeviceVersion devVersion, std::string purpose) {
|
||||||
|
const auto status = Gna2DeviceCreateForExport(devVersion, &nGnaDeviceIndex);
|
||||||
|
GNADeviceHelper::checkGna2Status(status, "Gna2DeviceCreateForExport(" + std::to_string(devVersion) + ")" + purpose);
|
||||||
|
}
|
||||||
|
|
||||||
|
void GNADeviceHelper::updateGnaDeviceVersion() {
|
||||||
|
const auto status = Gna2DeviceGetVersion(nGnaDeviceIndex, &detectedGnaDevVersion);
|
||||||
|
checkGna2Status(status, "Gna2DeviceGetVersion");
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if GNA_LIB_VER == 1
|
#if GNA_LIB_VER == 1
|
||||||
@ -492,14 +515,18 @@ void GNADeviceHelper::open(uint8_t n_threads) {
|
|||||||
nGNAHandle = GNADeviceOpenSetThreads(&nGNAStatus, n_threads);
|
nGNAHandle = GNADeviceOpenSetThreads(&nGNAStatus, n_threads);
|
||||||
checkStatus();
|
checkStatus();
|
||||||
#else
|
#else
|
||||||
auto status = Gna2DeviceGetVersion(nGnaDeviceIndex, &detectedGnaDevVersion);
|
updateGnaDeviceVersion();
|
||||||
checkGna2Status(status, "Gna2DeviceGetVersion");
|
const auto gnaExecTarget = parseTarget(executionTarget);
|
||||||
|
|
||||||
if (useDeviceEmbeddedExport) {
|
if (useDeviceEmbeddedExport) {
|
||||||
status = Gna2DeviceCreateForExport(exportGeneration, &nGnaDeviceIndex);
|
createVirtualDevice(exportGeneration, "export");
|
||||||
GNADeviceHelper::checkGna2Status(status, "Gna2DeviceCreateForExport");
|
} else if (!executionTarget.empty() && gnaExecTarget != detectedGnaDevVersion) {
|
||||||
|
createVirtualDevice(gnaExecTarget, "execution");
|
||||||
|
updateGnaDeviceVersion();
|
||||||
|
if (detectedGnaDevVersion != gnaExecTarget) {
|
||||||
|
THROW_GNA_EXCEPTION << "Wrong virtual GNA device version reported: " << detectedGnaDevVersion << " instead of: " << gnaExecTarget;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
status = Gna2DeviceOpen(nGnaDeviceIndex);
|
const auto status = Gna2DeviceOpen(nGnaDeviceIndex);
|
||||||
checkGna2Status(status, "Gna2DeviceOpen");
|
checkGna2Status(status, "Gna2DeviceOpen");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,9 +202,13 @@ public:
|
|||||||
|
|
||||||
static void enforceLegacyCnns(Gna2Model& gnaModel);
|
static void enforceLegacyCnns(Gna2Model& gnaModel);
|
||||||
static void enforceLegacyCnnsWhenNeeded(Gna2Model& gnaModel);
|
static void enforceLegacyCnnsWhenNeeded(Gna2Model& gnaModel);
|
||||||
|
static Gna2DeviceVersion parseTarget(const std::string& target);
|
||||||
Gna2DeviceVersion parseDeclaredTarget(std::string target, const bool execTarget) const;
|
Gna2DeviceVersion parseDeclaredTarget(std::string target, const bool execTarget) const;
|
||||||
Gna2DeviceVersion getDefaultTarget() const;
|
Gna2DeviceVersion getDefaultTarget() const;
|
||||||
Gna2DeviceVersion getTargetDevice(bool execTarget) const;
|
Gna2DeviceVersion getTargetDevice(bool execTarget) const;
|
||||||
|
|
||||||
|
void createVirtualDevice(Gna2DeviceVersion devVersion, std::string purpose = "");
|
||||||
|
void updateGnaDeviceVersion();
|
||||||
#endif
|
#endif
|
||||||
void setOMPThreads(uint8_t const n_threads);
|
void setOMPThreads(uint8_t const n_threads);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user