diff --git a/inference-engine/src/gna_plugin/gna_device.cpp b/inference-engine/src/gna_plugin/gna_device.cpp index a9bd9eec15e..ee6bb08d597 100644 --- a/inference-engine/src/gna_plugin/gna_device.cpp +++ b/inference-engine/src/gna_plugin/gna_device.cpp @@ -169,8 +169,21 @@ void GNADeviceHelper::releaseModel(const uint32_t model_id) { } bool GNADeviceHelper::enforceLegacyCnnNeeded() const { - const auto compileTargetDevice = getTargetDevice(false); - return (isGnaLibVersion3_0 || isGnaLibVersion2_1) && isUpTo20HwGnaDevice(compileTargetDevice); + const auto execTargetDevice = getTargetDevice(true); + return (isGnaLibVersion3_0 || isGnaLibVersion2_1) && isUpTo20HwGnaDevice(execTargetDevice); +} + +Gna2DeviceVersion GNADeviceHelper::parseTarget(const std::string& target) { + const std::map targetMap { + {InferenceEngine::GNAConfigParams::GNA_TARGET_2_0, Gna2DeviceVersion2_0}, + {InferenceEngine::GNAConfigParams::GNA_TARGET_3_0, Gna2DeviceVersion3_0}, + {"", Gna2DeviceVersionSoftwareEmulation}, + }; + const auto f = targetMap.find(target); + if (f != targetMap.end()) { + return f->second; + } + THROW_GNA_EXCEPTION << "Unsupported " << "GNAConfigParams::GNA_TARGET = \"" << target << "\"\n"; } Gna2DeviceVersion GNADeviceHelper::parseDeclaredTarget(std::string target, const bool execTarget) const { @@ -476,6 +489,16 @@ void GNADeviceHelper::dumpXnnForDeviceVersion( outStream.write("Gna2ModelSueCreekHeader", 24); outStream.write(reinterpret_cast(&sueHeader), sizeof(sueHeader)); } + +void GNADeviceHelper::createVirtualDevice(Gna2DeviceVersion devVersion, std::string purpose) { + const auto status = Gna2DeviceCreateForExport(devVersion, &nGnaDeviceIndex); + GNADeviceHelper::checkGna2Status(status, "Gna2DeviceCreateForExport(" + std::to_string(devVersion) + ")" + purpose); +} + +void GNADeviceHelper::updateGnaDeviceVersion() { + const auto status = Gna2DeviceGetVersion(nGnaDeviceIndex, &detectedGnaDevVersion); + checkGna2Status(status, "Gna2DeviceGetVersion"); +} #endif #if GNA_LIB_VER == 1 @@ -492,14 +515,18 @@ void GNADeviceHelper::open(uint8_t n_threads) { nGNAHandle = GNADeviceOpenSetThreads(&nGNAStatus, n_threads); checkStatus(); #else - auto status = Gna2DeviceGetVersion(nGnaDeviceIndex, &detectedGnaDevVersion); - checkGna2Status(status, "Gna2DeviceGetVersion"); - + updateGnaDeviceVersion(); + const auto gnaExecTarget = parseTarget(executionTarget); if (useDeviceEmbeddedExport) { - status = Gna2DeviceCreateForExport(exportGeneration, &nGnaDeviceIndex); - GNADeviceHelper::checkGna2Status(status, "Gna2DeviceCreateForExport"); + createVirtualDevice(exportGeneration, "export"); + } else if (!executionTarget.empty() && gnaExecTarget != detectedGnaDevVersion) { + createVirtualDevice(gnaExecTarget, "execution"); + updateGnaDeviceVersion(); + if (detectedGnaDevVersion != gnaExecTarget) { + THROW_GNA_EXCEPTION << "Wrong virtual GNA device version reported: " << detectedGnaDevVersion << " instead of: " << gnaExecTarget; + } } else { - status = Gna2DeviceOpen(nGnaDeviceIndex); + const auto status = Gna2DeviceOpen(nGnaDeviceIndex); checkGna2Status(status, "Gna2DeviceOpen"); } diff --git a/inference-engine/src/gna_plugin/gna_device.hpp b/inference-engine/src/gna_plugin/gna_device.hpp index 5e6719607d6..4024b3c1807 100644 --- a/inference-engine/src/gna_plugin/gna_device.hpp +++ b/inference-engine/src/gna_plugin/gna_device.hpp @@ -202,9 +202,13 @@ public: static void enforceLegacyCnns(Gna2Model& gnaModel); static void enforceLegacyCnnsWhenNeeded(Gna2Model& gnaModel); + static Gna2DeviceVersion parseTarget(const std::string& target); Gna2DeviceVersion parseDeclaredTarget(std::string target, const bool execTarget) const; Gna2DeviceVersion getDefaultTarget() const; Gna2DeviceVersion getTargetDevice(bool execTarget) const; + + void createVirtualDevice(Gna2DeviceVersion devVersion, std::string purpose = ""); + void updateGnaDeviceVersion(); #endif void setOMPThreads(uint8_t const n_threads);