[GNA] Allow 2d reshape of the first diagonal layer (#6115)

This commit is contained in:
Elizaveta Lobanova 2021-06-16 16:19:21 +03:00 committed by GitHub
parent 2c775d48b2
commit 5c55d390e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 12 deletions

View File

@ -52,13 +52,6 @@ inline bool HasTo2DReshapeData(InferenceEngine::CNNLayerPtr layer) {
if (!GNAPluginNS::LayerInfo(layer).isSyntheticScaleShift())
return false;
// Don't reshape the first dnn layer since it breaks groups recognition
auto prevLayer = InferenceEngine::CNNNetPrevLayerSkipCertain(layer, 0, [](InferenceEngine::CNNLayerPtr ptr) {
return LayerInfo(ptr).isNonValuesChangable();
});
IE_ASSERT(prevLayer != nullptr);
if (LayerInfo(prevLayer).isInput()) return false;
// Don't reshape diagonallayers with bias connection
return !GNAPluginNS::LayerInfo(getCreatorLayer(layer->insData.front().lock()).lock()).has32BOutput();
}

View File

@ -85,9 +85,8 @@ static void insertDiagonalLayerBetween(InferenceEngine::CNNLayerPtr prevLayer,
return LayerInfo(ptr).isNonValuesChangable();
});
IE_ASSERT(inputLayer != nullptr);
size_t weightsSize = (LayerInfo(prevLayer).has32BOutput() || LayerInfo(inputLayer).isInput()) ?
nextLayer->outData[0]->getDims().back() :
Get2DReshapedData(nextLayer->outData[0], 8)->getDims()[1];
size_t weightsSize = LayerInfo(prevLayer).has32BOutput() ? nextLayer->outData[0]->getDims().back() :
Get2DReshapedData(nextLayer->outData[0], 8)->getDims()[1];
std::vector<float> weightsValues(weightsSize, fillValue);
IE_ASSERT(diagLayer != nullptr);
diagLayer->_weights = make_shared_blob<float>(

View File

@ -217,8 +217,7 @@ INSTANTIATE_TEST_CASE_P(smoke_ConvertMatmulToPointwiseConvTest, ConvertMatmulToP
::testing::ValuesIn(inputShape)),
ConvertMatmulToPointwiseConv::getTestCaseName);
// Issue 55662
INSTANTIATE_TEST_CASE_P(DISABLED_smoke_ConvertMatmulToPointwiseConvTest, ConvertMatmulToPointwiseConvWithFq,
INSTANTIATE_TEST_CASE_P(smoke_ConvertMatmulToPointwiseConvTest, ConvertMatmulToPointwiseConvWithFq,
::testing::Combine(
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),