Improve CopyTIBody util to cover disconnected graphs
Signed-off-by: Alexander Peskov <alexander.peskov@intel.com>
This commit is contained in:
parent
16bb29f29c
commit
ec1561a234
@ -46,15 +46,28 @@ static std::vector<DataPtr> getAllInputs(const std::vector<DataPtr>& heads) {
|
||||
CNNLayerSet inputLayers;
|
||||
std::unordered_set<CNNLayer*> allLayers;
|
||||
|
||||
// define any layer connected to provided Data object (consumer or creator)
|
||||
auto findConnectedLayer = [] (const DataPtr& data) -> CNNLayerPtr {
|
||||
auto consumerLayers = getInputTo(data);
|
||||
if (!consumerLayers.empty())
|
||||
return consumerLayers.begin()->second;
|
||||
|
||||
auto creator = getCreatorLayer(data).lock();
|
||||
if (creator != nullptr)
|
||||
return creator;
|
||||
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
// Define all start layers
|
||||
for (const auto& data : heads) {
|
||||
auto& secondLayers = getInputTo(data);
|
||||
auto entryLayer = findConnectedLayer(data);
|
||||
|
||||
if (secondLayers.empty()) continue;
|
||||
if (entryLayer == nullptr) continue;
|
||||
|
||||
details::UnorderedDFS(
|
||||
allLayers, secondLayers.begin()->second,
|
||||
[&](CNNLayerPtr layer) {
|
||||
allLayers, entryLayer,
|
||||
[&inputLayers](const CNNLayerPtr &layer) {
|
||||
if (layer->insData.empty()) {
|
||||
inputLayers.insert(layer);
|
||||
}
|
||||
@ -77,10 +90,17 @@ static std::vector<DataPtr> getAllInputs(const std::vector<DataPtr>& heads) {
|
||||
std::vector<CNNLayerPtr> TIBodySortTopologically(const TensorIterator::Body& body) {
|
||||
std::vector<CNNLayerPtr> all_layers;
|
||||
|
||||
auto all_input_layers = getAllInputs(body.inputs);
|
||||
// In case of graph with several connected component
|
||||
// total entry point is a union of [inputs]U[outputs]
|
||||
// All internal nodes are achievable starting from this.
|
||||
auto total_entry_point = body.inputs;
|
||||
total_entry_point.insert(total_entry_point.end(),
|
||||
body.outputs.begin(), body.outputs.end());
|
||||
|
||||
auto all_input_layers = getAllInputs(total_entry_point);
|
||||
CNNNetForestDFS(
|
||||
all_input_layers,
|
||||
[&](CNNLayerPtr current) {
|
||||
[&all_layers](const CNNLayerPtr ¤t) {
|
||||
all_layers.push_back(current);
|
||||
},
|
||||
false);
|
||||
|
Loading…
Reference in New Issue
Block a user