[GNA] KSOFunction test fix (#6678)
* [GNA] KSOFunction test fix
* Lambda for dimensions matching check
parent c3c26b4807
commit 955c0a6c05
@@ -752,12 +752,14 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
    passes->registerPass<FuseFQIntoWeightsPass>();
    passes->registerPass<MoveFakeQuantizeLayerIntoQuantParamsPass>();

    passes->registerPass<SubstituteScaleShiftBroadCastPass>();
    passes->registerPass<BroadcastConstPass>();

    passes->registerPass<TransposeWeightsFromNCHWToNHWCPass>();

    passes->registerPass<SubstitutePReluPass>();
    passes->registerPass<SubstituteSoftSignPass>();

    passes->registerPass<BroadcastConstPass>();
    passes->registerPass<ReorderMaxPoolPass>();
    passes->registerPass<EltwiseSplitOverChannelsPass>();
    passes->registerPass<InsertSplitAligningFilterPass>();
@@ -775,7 +777,6 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
#if GNA_LIB_VER == 2
    passes->registerPass<ForbidActivationFusingPass>();
#endif
    passes->registerPass<SubstituteScaleShiftBroadCastPass>();
    passes->registerPass<FuseMultipleIdentitiesPass>();
    passIdx = passes->run(passIdx);
};
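For readers unfamiliar with the GNA pass pipeline touched above, the registerPass/run idiom queues transformation passes and then executes them in registration order, which is why the position of a registerPass<> call matters. Below is a minimal, self-contained sketch of that idiom; PassManager, Pass, and the two pass classes are simplified, hypothetical stand-ins rather than the plugin's real types.

#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical, simplified pass interface; the real GNA passes operate on the network.
struct Pass {
    virtual ~Pass() = default;
    virtual void run() = 0;
};

struct SubstituteScaleShiftBroadCastPass : Pass {
    void run() override { std::cout << "broadcast scaleshift weights\n"; }
};

struct TransposeWeightsFromNCHWToNHWCPass : Pass {
    void run() override { std::cout << "transpose weights to NHWC\n"; }
};

// Minimal pass manager: passes execute in the order they were registered.
class PassManager {
public:
    template <typename T>
    void registerPass() { passes_.push_back(std::make_unique<T>()); }

    size_t run(size_t passIdx) {
        for (; passIdx < passes_.size(); ++passIdx) {
            passes_[passIdx]->run();
        }
        return passIdx;
    }

private:
    std::vector<std::unique_ptr<Pass>> passes_;
};

int main() {
    PassManager passes;
    passes.registerPass<SubstituteScaleShiftBroadCastPass>();
    passes.registerPass<TransposeWeightsFromNCHWToNHWCPass>();
    size_t passIdx = 0;
    passIdx = passes.run(passIdx);
    return 0;
}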
@@ -1530,16 +1530,7 @@ void SubstituteScaleShiftBroadCastPass::run() {
        continue;
    }

    // only 3d scaleshift supported where number of c is arbitrary
    auto lastD = reshape_batch ? dataDims[1] : dataDims.back();
    if (lastD != weightsElements) {
        THROW_GNA_EXCEPTION << "Unsupported layer: " << l->name
            << " should have last dim(" << lastD << ") equal to weights(" << weightsElements << ") length";
    }
    if (dataDims.size() == 2) {
        THROW_GNA_EXCEPTION << "For layer: " << l->name
            << " weights size(" << weightsElements << ") invalid: should match input size of(" << lastD << ")";
    }
    // TODO: add broadcasting rules checks

    gnalog() << "Substitution ScaleShift broadcast for layer: " << l->name << "\n";
    if (nElements % scaleShift->_weights->size()) {
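The checks shown in this hunk encode the broadcast precondition that a ScaleShift's weights vector must be as long as the innermost data dimension (or the second dimension when the batch is reshaped). A minimal standalone sketch of that rule, using plain STL types and std::runtime_error in place of THROW_GNA_EXCEPTION, could look like this:

#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Illustrative only: checks that the weights length matches the dimension the
// ScaleShift is applied along (last dim, or dim 1 when the batch was reshaped).
void checkScaleShiftBroadcast(const std::vector<size_t>& dataDims,
                              size_t weightsElements,
                              bool reshape_batch,
                              const std::string& layerName) {
    const size_t lastD = reshape_batch ? dataDims[1] : dataDims.back();
    if (lastD != weightsElements) {
        throw std::runtime_error("Unsupported layer: " + layerName +
                                 " should have last dim(" + std::to_string(lastD) +
                                 ") equal to weights(" + std::to_string(weightsElements) + ") length");
    }
}

For example, dataDims = {1, 8, 64, 64} with 64 weights elements passes this sketch, while 8 weights elements would throw.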
@@ -2220,6 +2211,17 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
        }
    };

    auto transpInfoMatchWeightsSize = [](const std::vector<TranspositionInfo> &transpositionInfo, size_t weightsSize, const std::string &layerName) {
        size_t totalElements = 0;
        for (auto && transpositionInfoPart : transpositionInfo) {
            totalElements += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
        }
        if (totalElements != weightsSize) {
            THROW_GNA_EXCEPTION << layerName << " weights elements from transposition info (" << totalElements
                << ") don't match input dimensions (" << weightsSize << ")";
        }
    };

    for (auto &&l : *pLayers) {
        if (LayerInfo(l).isScaleShift()) {
            std::vector<TranspositionInfo> transpositionInfo;
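This transpInfoMatchWeightsSize lambda is the "lambda for dimensions matching check" named in the commit description: it sums rows x columns over all transposition parts and compares the total against the expected weights size, so one helper can validate the whole-weights, per-column, and per-row cases that are checked later in this pass. A self-contained sketch of the same check, with a simplified TranspositionInfo and std::runtime_error standing in for THROW_GNA_EXCEPTION, is:

#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Simplified stand-in for the plugin's TranspositionInfo; only the two
// fields used by the size check are modelled here.
struct TranspositionInfo {
    size_t num_transpose_rows;
    size_t num_transpose_columns;
};

// Same rule as the lambda in the pass: the transposition parts must cover
// exactly weightsSize elements, otherwise the layer cannot be transposed safely.
void transpInfoMatchWeightsSize(const std::vector<TranspositionInfo>& transpositionInfo,
                                size_t weightsSize,
                                const std::string& layerName) {
    size_t totalElements = 0;
    for (const auto& part : transpositionInfo) {
        totalElements += part.num_transpose_rows * part.num_transpose_columns;
    }
    if (totalElements != weightsSize) {
        throw std::runtime_error(layerName + " weights elements from transposition info (" +
                                 std::to_string(totalElements) + ") don't match input dimensions (" +
                                 std::to_string(weightsSize) + ")");
    }
}

int main() {
    // Two parts of 8x64 each cover 1024 elements, so a 1024-element weights
    // blob passes; a 512-element blob would throw.
    std::vector<TranspositionInfo> info = {{8, 64}, {8, 64}};
    transpInfoMatchWeightsSize(info, 1024, "example_scaleshift");
    return 0;
}

In the hunks that follow, the helper is invoked with the full weights size for ScaleShift layers and with weightsColumns / weightsRows for the two weightable-layer cases, so the size validation lives in one place instead of being repeated at every call site.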
@@ -2237,6 +2239,10 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
            }
            auto weightable = dynamic_cast<WeightableLayer*>(l.get());
            IE_ASSERT(weightable != nullptr);

            size_t totalWeights = weightable->_weights->size();
            transpInfoMatchWeightsSize(transpositionInfo, totalWeights, l->name);

            ConvertTensorFromNCHWToNHWC(weightable->precision.size(), 1, weightable->_weights->size(),
                                        weightable->_weights->cbuffer().as<uint8_t*>(), true, transpositionInfo);
            if (weightable->_biases) {
@@ -2270,14 +2276,9 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
                // If we found a split it's not possible to rotate data
                THROW_GNA_EXCEPTION << l->name << " won't be transposed due to a split before it";
            }
            size_t totalColumns = 0;
            for (auto && transpositionInfoPart : transpositionInfo) {
                totalColumns += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
            }
            if (weightsColumns != totalColumns) {
                THROW_GNA_EXCEPTION << l->name << " weights columns from transposition info (" << totalColumns
                    << ") don't match input dimensions (" << weightsColumns << ")";
            }

            transpInfoMatchWeightsSize(transpositionInfo, weightsColumns, l->name);

            ConvertTensorFromNCHWToNHWC(precision, weightsRows, weightsColumns, weightable->_weights->cbuffer().as<uint8_t*>(),
                                        true, transpositionInfo);
            gnalog() << l->name << " weights rows transposition info:\n";
@@ -2297,14 +2298,9 @@ void TransposeWeightsFromNCHWToNHWCPass::run() {
                // If we found a concat it's not possible to rotate data
                THROW_GNA_EXCEPTION << l->name << " won't be transposed due to a concat after it";
            }
            size_t totalRows = 0;
            for (const auto& transpositionInfoPart : transpositionInfo) {
                totalRows += transpositionInfoPart.num_transpose_rows * transpositionInfoPart.num_transpose_columns;
            }
            if (weightsRows != totalRows) {
                THROW_GNA_EXCEPTION << l->name << " weights rows from transposition info (" << totalRows
                    << ") don't match output dimensions (" << weightsRows << ")";
            }

            transpInfoMatchWeightsSize(transpositionInfo, weightsRows, l->name);

            ConvertTensorFromNCHWToNHWC(precision, weightsRows, weightsColumns, weightable->_weights->cbuffer().as<uint8_t*>(),
                                        false, transpositionInfo);
            gnalog() << l->name << " weights columns transposition info:\n";
@@ -44,8 +44,6 @@ std::vector<std::string> disabledTestPatterns() {
        R"(.*ConstantResultSubgraphTest.*inPrc=(U8|I8|I32|U64|I64|BOOL).*)",
        // TODO: Issue 51528
        R"(.*CachingSupport.*_(u8|i16)_.*)",
        // TODO: Issue 51525
        R"(.*CachingSupport.*KSOFunction.*)",
        // TODO: Issue 57363 (Param -> Result subgraphs)
        R"(.*smoke_MemoryTest.*LOW_LATENCY.*iteration_count=1_.*)",
        // TODO: Issue 57368 (accuracy)
@@ -13,7 +13,8 @@ const std::vector<std::vector<std::vector<size_t>>> shapes = {
    {{1, 64}, {64, 1}},
    {{8, 256}, {16, 128}},
    {{6, 384}, {18, 128}},
    {{8, 2048}, {32, 512}}
    {{8, 2048}, {32, 512}},
    {{2, 4, 64, 64}, {1, 8, 64, 64}}
};

const std::vector<InferenceEngine::Precision> netPrecisions = {
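Judging from the existing entries, each pair above lists two shapes describing the same number of elements (1*64 = 64*1, 8*256 = 16*128, 6*384 = 18*128, 8*2048 = 32*512), and the newly added 4-D pair keeps that property; a compile-time check of this invariant (the equal-size interpretation is an assumption from the values, not stated in the diff) is:

// Each shape pair preserves the total element count;
// the new 4-D pair: 2*4*64*64 == 1*8*64*64 == 32768.
static_assert(2 * 4 * 64 * 64 == 1 * 8 * 64 * 64, "4-D pair must preserve element count");
static_assert(8 * 2048 == 32 * 512, "2-D pair must preserve element count");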