[IE] Preserve output data name after merging and update output data map (#1092)

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev
2020-06-24 14:30:25 +05:00
committed by GitHub
parent 32054ff180
commit c26ec8b312

View File

@@ -272,6 +272,48 @@ void CombineData(DataPtr& master, DataPtr& slave) {
}
}
/**
* Preserve output data name and update output data map of the network
*
* @param in_data name to update
* @param out_data name to preserve
* @param net output data map to update with in_data
*/
template <typename NET>
void SaveOutputDataName(InferenceEngine::DataPtr in_data, InferenceEngine::DataPtr out_data, NET &net) {
// TODO: update outputs of the network if out_data was output
if (out_data->getInputTo().empty()) {
auto data_name = out_data->getName();
in_data->setName(data_name);
}
}
/**
* void SaveOutputDataName(InferenceEngine::DataPtr in_data, InferenceEngine::DataPtr out_data, NET &net), where
* NET = ICNNNetwork
*/
void SaveOutputDataName(InferenceEngine::DataPtr in_data, InferenceEngine::DataPtr out_data, ICNNNetwork& net) {
if (out_data->getInputTo().empty()) {
InferenceEngine::OutputsDataMap outputs_data_map;
net.getOutputsInfo(outputs_data_map);
auto out_data_name = out_data->getName();
in_data->setName(out_data_name);
if (outputs_data_map.count(out_data_name)) {
auto parent_layer_ptr = in_data->getCreatorLayer().lock();
IE_ASSERT(parent_layer_ptr != nullptr);
auto parent_layer_name = parent_layer_ptr->name;
size_t in_data_out_index = 0;
for (size_t ind = 0; ind < parent_layer_ptr->outData.size(); ++ind) {
if (parent_layer_ptr->outData[ind] == in_data) {
in_data_out_index = ind;
}
}
net.addOutput(parent_layer_name, in_data_out_index);
}
}
}
/**
* Remove layer form graph
* May be applied only for inplace layer. One input, one output,
@@ -279,7 +321,8 @@ void CombineData(DataPtr& master, DataPtr& slave) {
*
* @param layer to remove from graph
*/
void RemoveLayer(CNNLayerPtr& layer) {
template <typename NET>
void RemoveLayer(CNNLayerPtr& layer, NET &net) {
IE_ASSERT(layer->insData.size() == 1);
IE_ASSERT(layer->outData.size() == 1);
@@ -299,10 +342,8 @@ void RemoveLayer(CNNLayerPtr& layer) {
// transfer output connections into parent data
CombineData(in_data, out_data);
// Save name for output data
if (out_data->getInputTo().empty()) {
in_data->setName(out_data->getName());
}
// save name for output data and update network output
SaveOutputDataName(in_data, out_data, net);
}
/************************************************************/
@@ -1371,7 +1412,7 @@ void fixConvertLayers(NET &net) {
}
}
for (auto &layer : to_remove) {
RemoveLayer(layer);
RemoveLayer(layer, net);
}
}