Remove CNN GNA1/2 compatibility enforcement when other GNA device detected (#2745)
This commit is contained in:
parent
c2271da637
commit
9c78a4855a
@ -1453,12 +1453,6 @@ void GNAPluginNS::backend::AMIntelDNN::InitGNAStruct(intel_nnet_type_t *ptr_nnet
|
||||
comp.op.conv1D.num_feature_maps * comp.op.conv1D.num_feature_map_columns),
|
||||
nullptr);
|
||||
|
||||
// TODO: GNA2: We have to explicitly enforce to use Legacy CNN
|
||||
snprintf(
|
||||
const_cast<char*>(gnaOperation->Operands[1]->Layout),
|
||||
sizeof(gnaOperation->Operands[1]->Layout) / sizeof(char),
|
||||
"GNA1");
|
||||
|
||||
AdvanceCnnOperationIfAllApplied(component, i, gnaOperation);
|
||||
#else
|
||||
pLayer->nInputRows = component[i].num_rows_in;
|
||||
|
@ -90,8 +90,22 @@ uint32_t GNADeviceHelper::propagate(const uint32_t requestConfigId, Gna2Accelera
|
||||
return reqId;
|
||||
}
|
||||
|
||||
uint32_t GNADeviceHelper::createModel(const Gna2Model& gnaModel) const {
|
||||
/**
 * @brief Marks every convolution operation of the model as a legacy CNN.
 *
 * Walks all operations of the model and, for each one whose type is
 * Gna2OperationTypeConvolution, overwrites the Layout string of operand 1
 * in place with "GNA1".
 *
 * @param gnaModel model whose operations are inspected and patched in place
 */
void GNADeviceHelper::enforceLegacyCnns(Gna2Model& gnaModel) {
    for (uint32_t i = 0; i < gnaModel.NumberOfOperations; i++) {
        auto& operation = gnaModel.Operations[i];
        // The type of the i-th operation must be checked; looking at
        // gnaModel.Operations->Type would only ever test operation 0.
        if (operation.Type != Gna2OperationTypeConvolution) {
            continue;
        }
        // const_cast is kept from the original code: the Layout buffer is not
        // writable through this access path. snprintf is bounded by the size
        // of the Layout array, so the tag is always NUL-terminated.
        snprintf(
            const_cast<char*>(operation.Operands[1]->Layout),
            sizeof(operation.Operands[1]->Layout) / sizeof(char),
            "GNA1");
    }
}
|
||||
|
||||
uint32_t GNADeviceHelper::createModel(Gna2Model& gnaModel) const {
|
||||
uint32_t modelId;
|
||||
if (isUpTo20GnaDevice()) {
|
||||
enforceLegacyCnns(gnaModel);
|
||||
}
|
||||
const auto status = Gna2ModelCreate(nGnaDeviceIndex, &gnaModel, &modelId);
|
||||
|
||||
checkGna2Status(status, gnaModel);
|
||||
@ -108,7 +122,8 @@ uint32_t GNADeviceHelper::createRequestConfig(const uint32_t model_id) {
|
||||
auto status = Gna2RequestConfigCreate(model_id, &reqConfId);
|
||||
checkGna2Status(status);
|
||||
if (gna2HwConsistency != Gna2DeviceVersionSoftwareEmulation) {
|
||||
status = Gna2RequestConfigEnableHardwareConsistency(reqConfId, gna2HwConsistency);
|
||||
status = Gna2RequestConfigEnableHardwareConsistency(reqConfId,
|
||||
isUpTo20GnaDevice() ? gna2HwConsistency : detectedGnaDevVersion);
|
||||
checkGna2Status(status);
|
||||
}
|
||||
status = Gna2InstrumentationConfigAssignToRequestConfig(instrumentationConfigId, reqConfId);
|
||||
|
@ -110,7 +110,7 @@ public:
|
||||
void propagateSync(const uint32_t requestConfigId, Gna2AccelerationMode gna2AccelerationMode);
|
||||
uint32_t propagate(const uint32_t requestConfigId, Gna2AccelerationMode gna2AccelerationMode);
|
||||
#if GNA_LIB_VER == 2
|
||||
uint32_t createModel(const Gna2Model& gnaModel) const;
|
||||
uint32_t createModel(Gna2Model& gnaModel) const;
|
||||
#else
|
||||
uint32_t createModel(const intel_nnet_type_t& intel_nnet_type);
|
||||
#endif
|
||||
@ -119,6 +119,9 @@ public:
|
||||
bool hasGnaHw() const {
|
||||
return Gna2DeviceVersionSoftwareEmulation != detectedGnaDevVersion;
|
||||
}
|
||||
bool isUpTo20GnaDevice() const {
|
||||
return detectedGnaDevVersion <= Gna2DeviceVersion2_0;
|
||||
}
|
||||
static void checkGna2Status(Gna2Status status);
|
||||
static void checkGna2Status(Gna2Status status, const Gna2Model& gnaModel);
|
||||
#endif
|
||||
@ -166,6 +169,8 @@ public:
|
||||
static const std::map <Gna2ErrorType, const std::string> errorReasons;
|
||||
static const std::map <Gna2OperationType, const std::string> operationTypes;
|
||||
static const std::map <const std::pair<Gna2OperationType, int32_t>, const std::string > operandTypes;
|
||||
|
||||
static void enforceLegacyCnns(Gna2Model& gnaModel);
|
||||
#endif
|
||||
void setOMPThreads(uint8_t const n_threads);
|
||||
|
||||
|
@ -756,7 +756,7 @@ void GNAPlugin::createRequestConfigsForGnaModels() {
|
||||
return;
|
||||
}
|
||||
for (auto& model : gnaModels) {
|
||||
const auto& gnaNnet = std::get<0>(model).get()->obj;
|
||||
auto& gnaNnet = std::get<0>(model).get()->obj;
|
||||
const auto modelId = gnadevice->createModel(gnaNnet);
|
||||
const auto requestConfigId = gnadevice->createRequestConfig(modelId);
|
||||
gnaRequestConfigToRequestIdMap.push_back(std::make_tuple(requestConfigId, -1, InferenceEngine::BlobMap()));
|
||||
|
Loading…
Reference in New Issue
Block a user