Remove CNN GNA1/2 compatibility enforcement when other GNA device detected (#2745)

This commit is contained in:
Krzysztof Bruniecki 2020-10-23 12:30:16 +02:00 committed by GitHub
parent c2271da637
commit 9c78a4855a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 10 deletions

View File

@ -1453,12 +1453,6 @@ void GNAPluginNS::backend::AMIntelDNN::InitGNAStruct(intel_nnet_type_t *ptr_nnet
comp.op.conv1D.num_feature_maps * comp.op.conv1D.num_feature_map_columns),
nullptr);
// TODO: GNA2: We have to explicitly enforce to use Legacy CNN
snprintf(
const_cast<char*>(gnaOperation->Operands[1]->Layout),
sizeof(gnaOperation->Operands[1]->Layout) / sizeof(char),
"GNA1");
AdvanceCnnOperationIfAllApplied(component, i, gnaOperation);
#else
pLayer->nInputRows = component[i].num_rows_in;

View File

@ -90,8 +90,22 @@ uint32_t GNADeviceHelper::propagate(const uint32_t requestConfigId, Gna2Accelera
return reqId;
}
uint32_t GNADeviceHelper::createModel(const Gna2Model& gnaModel) const {
void GNADeviceHelper::enforceLegacyCnns(Gna2Model& gnaModel) {
for (uint32_t i = 0; i < gnaModel.NumberOfOperations; i++) {
if (gnaModel.Operations->Type == Gna2OperationTypeConvolution) {
snprintf(
const_cast<char*>(gnaModel.Operations[i].Operands[1]->Layout),
sizeof(gnaModel.Operations[i].Operands[1]->Layout) / sizeof(char),
"GNA1");
}
}
}
uint32_t GNADeviceHelper::createModel(Gna2Model& gnaModel) const {
uint32_t modelId;
if (isUpTo20GnaDevice()) {
enforceLegacyCnns(gnaModel);
}
const auto status = Gna2ModelCreate(nGnaDeviceIndex, &gnaModel, &modelId);
checkGna2Status(status, gnaModel);
@ -108,7 +122,8 @@ uint32_t GNADeviceHelper::createRequestConfig(const uint32_t model_id) {
auto status = Gna2RequestConfigCreate(model_id, &reqConfId);
checkGna2Status(status);
if (gna2HwConsistency != Gna2DeviceVersionSoftwareEmulation) {
status = Gna2RequestConfigEnableHardwareConsistency(reqConfId, gna2HwConsistency);
status = Gna2RequestConfigEnableHardwareConsistency(reqConfId,
isUpTo20GnaDevice() ? gna2HwConsistency : detectedGnaDevVersion);
checkGna2Status(status);
}
status = Gna2InstrumentationConfigAssignToRequestConfig(instrumentationConfigId, reqConfId);

View File

@ -110,7 +110,7 @@ public:
void propagateSync(const uint32_t requestConfigId, Gna2AccelerationMode gna2AccelerationMode);
uint32_t propagate(const uint32_t requestConfigId, Gna2AccelerationMode gna2AccelerationMode);
#if GNA_LIB_VER == 2
uint32_t createModel(const Gna2Model& gnaModel) const;
uint32_t createModel(Gna2Model& gnaModel) const;
#else
uint32_t createModel(const intel_nnet_type_t& intel_nnet_type);
#endif
@ -119,6 +119,9 @@ public:
bool hasGnaHw() const {
return Gna2DeviceVersionSoftwareEmulation != detectedGnaDevVersion;
}
bool isUpTo20GnaDevice() const {
return detectedGnaDevVersion <= Gna2DeviceVersion2_0;
}
static void checkGna2Status(Gna2Status status);
static void checkGna2Status(Gna2Status status, const Gna2Model& gnaModel);
#endif
@ -166,6 +169,8 @@ public:
static const std::map <Gna2ErrorType, const std::string> errorReasons;
static const std::map <Gna2OperationType, const std::string> operationTypes;
static const std::map <const std::pair<Gna2OperationType, int32_t>, const std::string > operandTypes;
static void enforceLegacyCnns(Gna2Model& gnaModel);
#endif
void setOMPThreads(uint8_t const n_threads);

View File

@ -756,7 +756,7 @@ void GNAPlugin::createRequestConfigsForGnaModels() {
return;
}
for (auto& model : gnaModels) {
const auto& gnaNnet = std::get<0>(model).get()->obj;
auto& gnaNnet = std::get<0>(model).get()->obj;
const auto modelId = gnadevice->createModel(gnaNnet);
const auto requestConfigId = gnadevice->createRequestConfig(modelId);
gnaRequestConfigToRequestIdMap.push_back(std::make_tuple(requestConfigId, -1, InferenceEngine::BlobMap()));