[IE] Preserve output data name after merging and update output data map (#1092)
Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
@@ -272,6 +272,48 @@ void CombineData(DataPtr& master, DataPtr& slave) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Preserve output data name and update output data map of the network
|
||||
*
|
||||
* @param in_data name to update
|
||||
* @param out_data name to preserve
|
||||
* @param net output data map to update with in_data
|
||||
*/
|
||||
template <typename NET>
|
||||
void SaveOutputDataName(InferenceEngine::DataPtr in_data, InferenceEngine::DataPtr out_data, NET &net) {
|
||||
// TODO: update outputs of the network if out_data was output
|
||||
if (out_data->getInputTo().empty()) {
|
||||
auto data_name = out_data->getName();
|
||||
in_data->setName(data_name);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* void SaveOutputDataName(InferenceEngine::DataPtr in_data, InferenceEngine::DataPtr out_data, NET &net), where
|
||||
* NET = ICNNNetwork
|
||||
*/
|
||||
void SaveOutputDataName(InferenceEngine::DataPtr in_data, InferenceEngine::DataPtr out_data, ICNNNetwork& net) {
|
||||
if (out_data->getInputTo().empty()) {
|
||||
InferenceEngine::OutputsDataMap outputs_data_map;
|
||||
net.getOutputsInfo(outputs_data_map);
|
||||
auto out_data_name = out_data->getName();
|
||||
in_data->setName(out_data_name);
|
||||
if (outputs_data_map.count(out_data_name)) {
|
||||
auto parent_layer_ptr = in_data->getCreatorLayer().lock();
|
||||
IE_ASSERT(parent_layer_ptr != nullptr);
|
||||
auto parent_layer_name = parent_layer_ptr->name;
|
||||
size_t in_data_out_index = 0;
|
||||
for (size_t ind = 0; ind < parent_layer_ptr->outData.size(); ++ind) {
|
||||
if (parent_layer_ptr->outData[ind] == in_data) {
|
||||
in_data_out_index = ind;
|
||||
}
|
||||
}
|
||||
net.addOutput(parent_layer_name, in_data_out_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Remove layer form graph
|
||||
* May be applied only for inplace layer. One input, one output,
|
||||
@@ -279,7 +321,8 @@ void CombineData(DataPtr& master, DataPtr& slave) {
|
||||
*
|
||||
* @param layer to remove from graph
|
||||
*/
|
||||
void RemoveLayer(CNNLayerPtr& layer) {
|
||||
template <typename NET>
|
||||
void RemoveLayer(CNNLayerPtr& layer, NET &net) {
|
||||
IE_ASSERT(layer->insData.size() == 1);
|
||||
IE_ASSERT(layer->outData.size() == 1);
|
||||
|
||||
@@ -299,10 +342,8 @@ void RemoveLayer(CNNLayerPtr& layer) {
|
||||
// transfer output connections into parent data
|
||||
CombineData(in_data, out_data);
|
||||
|
||||
// Save name for output data
|
||||
if (out_data->getInputTo().empty()) {
|
||||
in_data->setName(out_data->getName());
|
||||
}
|
||||
// save name for output data and update network output
|
||||
SaveOutputDataName(in_data, out_data, net);
|
||||
}
|
||||
|
||||
/************************************************************/
|
||||
@@ -1371,7 +1412,7 @@ void fixConvertLayers(NET &net) {
|
||||
}
|
||||
}
|
||||
for (auto &layer : to_remove) {
|
||||
RemoveLayer(layer);
|
||||
RemoveLayer(layer, net);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user