remove InsertIdentityToLSTMCellPass and fix InsertIdentity (#6962)

* check whether the node is a final non-functional node for grouping; remove InsertIdentityToLSTMCellPass

* code style fix
This commit is contained in:
Evgeny Kotov 2021-08-10 13:39:16 +03:00 committed by GitHub
parent 7c82ad78ee
commit cc76d38920
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 40 deletions

View File

@ -757,7 +757,6 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) {
passes->registerPass<RemoveConstPass>();
passes->registerPass<UnrollTIPass>();
passes->registerPass<RemoveConstPass>();
passes->registerPass<InsertIdentityToLSTMCellPass>();
passes->registerPass<UnrollLSTMCellPass>();
passes->registerPass<RemoveSingleInputConcatPass>();

View File

@ -131,6 +131,12 @@ static CNNLayerPtr InsertCopyLayer(CNNLayerPtr prevLayer, CNNLayerPtr nextLayer,
return copyWithQuant;
}
/**
 * @brief Reports whether CNNNetHasNextLayerSkipCertain finds a following layer for @p layer
 *        once every layer that LayerInfo classifies as non-functional is skipped.
 * @param layer layer whose successors are inspected
 * @return the value returned by CNNNetHasNextLayerSkipCertain
 */
static bool hasNextFuncLayer(const CNNLayerPtr layer) {
    // Predicate telling the traversal which layers it should step over.
    const auto skipNonFunctional = [](CNNLayerPtr candidate) {
        return LayerInfo(candidate).isNonFunctional();
    };
    // NOTE(review): the two zeros are presumably the output/input indices to follow - confirm against the helper.
    return CNNNetHasNextLayerSkipCertain(layer, 0, 0, skipNonFunctional);
}
static std::vector<CNNLayerPtr> getCandidatesForIdentityInsertion(const CNNLayerPtr l, std::shared_ptr<IPassManager> passmanager) {
std::vector<CNNLayerPtr> prevLayers;
@ -796,7 +802,8 @@ void InsertIdentityLayerPass::run() {
for (auto && nextLayer : getInputTo(nextData)) {
if (nextLayer.second.get() == l.get())
continue;
if (getCandidatesForIdentityInsertion(nextLayer.second, getPassManager()).empty()) {
if (getCandidatesForIdentityInsertion(nextLayer.second, getPassManager()).empty() &&
hasNextFuncLayer(nextLayer.second)) {
notAll = true;
}
}
@ -1608,44 +1615,6 @@ void BroadcastConstPass::run() {
}
}
void InsertIdentityToLSTMCellPass::run() {
OV_ITT_SCOPED_TASK(itt::domains::GNA_LT, "InsertIdentityToLSTMCellPass");
for (auto layer : *pLayers) {
if (layer->type == "LSTMCell") {
// This fixed the cases when both functional and non-functional outputs are mixed (or not outputs are used)
// which results in scratch buffer being used so outputs cannot be used in form of blob or by non-functional layers
// downside is scaling down from i32 to i16 which may
for (int output_idx = 0; output_idx < layer->outData.size(); output_idx++) {
int numOfIdentityLayers = ((this->getPassManager())->getIntVar(identityLayersCounterName))++;
auto activationName = std::string("lstm_identity_") + std::to_string(numOfIdentityLayers);
auto& output = layer->outData[output_idx];
auto& input_to = getInputTo(output);
CNNLayerPtr activationLayer =
std::make_shared<GenericLayer>(LayerParams({activationName, "identity", InferenceEngine::Precision::FP32}));
auto dataPtr = std::make_shared<Data>("lstm_identity_data_" + std::to_string(numOfIdentityLayers), output->getTensorDesc());
auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);
auto activationLayerWithQuant = quantized ? InferenceEngine::injectData<QuantizedLayerParams>(activationLayer) : activationLayer;
getCreatorLayer(dataPtr) = activationLayerWithQuant;
activationLayerWithQuant->outData.push_back(dataPtr);
activationLayerWithQuant->insData.push_back(output);
auto& activationInputTo = getInputTo(dataPtr);
for (auto& input : input_to) {
auto& next_layer = input.second;
activationInputTo[input.first] = next_layer;
std::replace_if(std::begin(next_layer->insData), std::end(next_layer->insData),
[output](DataWeakPtr data) { return data.lock() == output; }, dataPtr);
}
input_to.clear();
input_to[activationName] = activationLayerWithQuant;
}
}
}
}
void BreakFusingOfOutputLayersPass::run() {
OV_ITT_SCOPED_TASK(itt::domains::GNA_LT, "BreakFusingOfOutputLayersPass");
#if GNA_LIB_VER == 1