diff --git a/inference-engine/src/gna_plugin/backend/am_intel_dnn.cpp b/inference-engine/src/gna_plugin/backend/am_intel_dnn.cpp
index 43b0fa34642..b41166c7715 100644
--- a/inference-engine/src/gna_plugin/backend/am_intel_dnn.cpp
+++ b/inference-engine/src/gna_plugin/backend/am_intel_dnn.cpp
@@ -1453,12 +1453,6 @@ void GNAPluginNS::backend::AMIntelDNN::InitGNAStruct(intel_nnet_type_t *ptr_nnet
                     comp.op.conv1D.num_feature_maps * comp.op.conv1D.num_feature_map_columns),
                 nullptr);
 
-            // TODO: GNA2: We have to explicitly enforce to use Legacy CNN
-            snprintf(
-                const_cast<char*>(gnaOperation->Operands[1]->Layout),
-                sizeof(gnaOperation->Operands[1]->Layout) / sizeof(char),
-                "GNA1");
-
             AdvanceCnnOperationIfAllApplied(component, i, gnaOperation);
 #else
             pLayer->nInputRows = component[i].num_rows_in;
diff --git a/inference-engine/src/gna_plugin/gna_device.cpp b/inference-engine/src/gna_plugin/gna_device.cpp
index ce15f2b895b..7cd9fc27127 100644
--- a/inference-engine/src/gna_plugin/gna_device.cpp
+++ b/inference-engine/src/gna_plugin/gna_device.cpp
@@ -90,8 +90,22 @@ uint32_t GNADeviceHelper::propagate(const uint32_t requestConfigId, Gna2Accelera
     return reqId;
 }
 
-uint32_t GNADeviceHelper::createModel(const Gna2Model& gnaModel) const {
+void GNADeviceHelper::enforceLegacyCnns(Gna2Model& gnaModel) {
+    for (uint32_t i = 0; i < gnaModel.NumberOfOperations; i++) {
+        if (gnaModel.Operations[i].Type == Gna2OperationTypeConvolution) {
+            snprintf(
+                const_cast<char*>(gnaModel.Operations[i].Operands[1]->Layout),
+                sizeof(gnaModel.Operations[i].Operands[1]->Layout) / sizeof(char),
+                "GNA1");
+        }
+    }
+}
+
+uint32_t GNADeviceHelper::createModel(Gna2Model& gnaModel) const {
     uint32_t modelId;
+    if (isUpTo20GnaDevice()) {
+        enforceLegacyCnns(gnaModel);
+    }
     const auto status = Gna2ModelCreate(nGnaDeviceIndex, &gnaModel, &modelId);
 
     checkGna2Status(status, gnaModel);
@@ -108,7 +122,8 @@ uint32_t GNADeviceHelper::createRequestConfig(const uint32_t model_id) {
     auto status = Gna2RequestConfigCreate(model_id, &reqConfId);
     checkGna2Status(status);
     if (gna2HwConsistency != Gna2DeviceVersionSoftwareEmulation) {
-        status = Gna2RequestConfigEnableHardwareConsistency(reqConfId, gna2HwConsistency);
+        status = Gna2RequestConfigEnableHardwareConsistency(reqConfId,
+            isUpTo20GnaDevice() ? gna2HwConsistency : detectedGnaDevVersion);
         checkGna2Status(status);
     }
     status = Gna2InstrumentationConfigAssignToRequestConfig(instrumentationConfigId, reqConfId);
diff --git a/inference-engine/src/gna_plugin/gna_device.hpp b/inference-engine/src/gna_plugin/gna_device.hpp
index 0f71772d62e..d067ce9f410 100644
--- a/inference-engine/src/gna_plugin/gna_device.hpp
+++ b/inference-engine/src/gna_plugin/gna_device.hpp
@@ -110,7 +110,7 @@ public:
     void propagateSync(const uint32_t requestConfigId, Gna2AccelerationMode gna2AccelerationMode);
     uint32_t propagate(const uint32_t requestConfigId, Gna2AccelerationMode gna2AccelerationMode);
 #if GNA_LIB_VER == 2
-    uint32_t createModel(const Gna2Model& gnaModel) const;
+    uint32_t createModel(Gna2Model& gnaModel) const;
 #else
     uint32_t createModel(const intel_nnet_type_t& intel_nnet_type);
 #endif
@@ -119,6 +119,9 @@ public:
     bool hasGnaHw() const {
         return Gna2DeviceVersionSoftwareEmulation != detectedGnaDevVersion;
     }
+    bool isUpTo20GnaDevice() const {
+        return detectedGnaDevVersion <= Gna2DeviceVersion2_0;
+    }
     static void checkGna2Status(Gna2Status status);
     static void checkGna2Status(Gna2Status status, const Gna2Model& gnaModel);
 #endif
@@ -166,6 +169,8 @@ public:
     static const std::map<Gna2ErrorType, const std::string> errorReasons;
     static const std::map<Gna2OperationType, const std::string> operationTypes;
     static const std::map<const std::pair<Gna2OperationType, int32_t>, const std::string> operandTypes;
+
+    static void enforceLegacyCnns(Gna2Model& gnaModel);
 #endif
 
     void setOMPThreads(uint8_t const n_threads);
diff --git a/inference-engine/src/gna_plugin/gna_plugin.cpp b/inference-engine/src/gna_plugin/gna_plugin.cpp
index 3905da7c039..d7b704f5c88 100644
--- a/inference-engine/src/gna_plugin/gna_plugin.cpp
+++ b/inference-engine/src/gna_plugin/gna_plugin.cpp
@@ -756,7 +756,7 @@ void GNAPlugin::createRequestConfigsForGnaModels() {
         return;
     }
     for (auto& model : gnaModels) {
-        const auto& gnaNnet = std::get<0>(model).get()->obj;
+        auto& gnaNnet = std::get<0>(model).get()->obj;
         const auto modelId = gnadevice->createModel(gnaNnet);
         const auto requestConfigId = gnadevice->createRequestConfig(modelId);
         gnaRequestConfigToRequestIdMap.push_back(std::make_tuple(requestConfigId, -1, InferenceEngine::BlobMap()));
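
For reference, below is a minimal standalone sketch (not part of the patch) of the layout-tagging technique that the patch moves from InitGNAStruct into GNADeviceHelper::enforceLegacyCnns(): on devices up to GNA 2.0, the second operand's Layout string of every convolution is overwritten with "GNA1" so the library falls back to the legacy CNN kernel. The types here are simplified stand-ins, not the real Gna2Model API.

// Sketch only: Operand, Operation, and Model approximate the Gna2Tensor /
// Gna2Operation / Gna2Model shapes relevant to this patch.
#include <cstdio>
#include <cstdint>
#include <vector>

namespace sketch {

enum OperationType { Convolution, Other };

struct Operand {
    char Layout[8];  // fixed-size layout tag, as in Gna2Tensor
};

struct Operation {
    OperationType Type;
    std::vector<Operand*> Operands;
};

struct Model {
    uint32_t NumberOfOperations;
    Operation* Operations;
};

// Mirrors the patched loop. Note the Operations[i].Type check: indexing the
// array per iteration (rather than Operations->Type, which would always
// inspect the first operation) is what makes this correct for models with
// more than one operation.
void enforceLegacyCnns(Model& model) {
    for (uint32_t i = 0; i < model.NumberOfOperations; i++) {
        if (model.Operations[i].Type == Convolution) {
            snprintf(model.Operations[i].Operands[1]->Layout,
                     sizeof(model.Operations[i].Operands[1]->Layout),
                     "GNA1");
        }
    }
}

}  // namespace sketch

int main() {
    sketch::Operand weights{};
    sketch::Operation conv{sketch::Convolution, {nullptr, &weights}};
    sketch::Model model{1, &conv};
    sketch::enforceLegacyCnns(model);
    std::printf("weights layout: %s\n", weights.Layout);  // prints "GNA1"
    return 0;
}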