Improve CopyTIBody util to cover disconnected graphs

Signed-off-by: Alexander Peskov <alexander.peskov@intel.com>
This commit is contained in:
Alexander Peskov 2020-10-19 20:16:34 +03:00
parent 16bb29f29c
commit ec1561a234

View File

@ -46,15 +46,28 @@ static std::vector<DataPtr> getAllInputs(const std::vector<DataPtr>& heads) {
CNNLayerSet inputLayers;
std::unordered_set<CNNLayer*> allLayers;
// define any layer connected to provided Data object (consumer or creator)
auto findConnectedLayer = [] (const DataPtr& data) -> CNNLayerPtr {
auto consumerLayers = getInputTo(data);
if (!consumerLayers.empty())
return consumerLayers.begin()->second;
auto creator = getCreatorLayer(data).lock();
if (creator != nullptr)
return creator;
return nullptr;
};
// Define all start layers
for (const auto& data : heads) {
auto& secondLayers = getInputTo(data);
auto entryLayer = findConnectedLayer(data);
if (secondLayers.empty()) continue;
if (entryLayer == nullptr) continue;
details::UnorderedDFS(
allLayers, secondLayers.begin()->second,
[&](CNNLayerPtr layer) {
allLayers, entryLayer,
[&inputLayers](const CNNLayerPtr &layer) {
if (layer->insData.empty()) {
inputLayers.insert(layer);
}
@ -77,10 +90,17 @@ static std::vector<DataPtr> getAllInputs(const std::vector<DataPtr>& heads) {
std::vector<CNNLayerPtr> TIBodySortTopologically(const TensorIterator::Body& body) {
std::vector<CNNLayerPtr> all_layers;
auto all_input_layers = getAllInputs(body.inputs);
// In case of graph with several connected component
// total entry point is a union of [inputs]U[outputs]
// All internal nodes are achievable starting from this.
auto total_entry_point = body.inputs;
total_entry_point.insert(total_entry_point.end(),
body.outputs.begin(), body.outputs.end());
auto all_input_layers = getAllInputs(total_entry_point);
CNNNetForestDFS(
all_input_layers,
[&](CNNLayerPtr current) {
[&all_layers](const CNNLayerPtr &current) {
all_layers.push_back(current);
},
false);