Hetero creates single result node (#9572)

This commit is contained in:
Anton Pankratov
2022-01-14 12:26:39 +03:00
committed by GitHub
parent 12ab842970
commit 790f02c0b1
3 changed files with 95 additions and 18 deletions

View File

@@ -303,29 +303,37 @@ HeteroExecutableNetwork::HeteroExecutableNetwork(const InferenceEngine::CNNNetwo
// Break graph using insertion of result parameter split
NodeMap<ngraph::Node*> subgraphParameterToPrevResult;
std::vector<std::shared_ptr<ngraph::op::Result>> results;
for (auto&& input : subgraphInputs) {
if (!ngraph::op::is_parameter(input.get_node()) && !ngraph::op::is_constant(input.get_node())) {
auto output = input.get_source_output();
output.remove_target_input(input);
{
std::set<ngraph::Output<ngraph::Node>> subgraphOutputs;
for (auto&& input : subgraphInputs) {
if (!ngraph::op::is_parameter(input.get_node()) && !ngraph::op::is_constant(input.get_node())) {
subgraphOutputs.insert(input.get_source_output());
}
}
for (auto&& output : subgraphOutputs) {
auto inputs = output.get_target_inputs();
auto result = std::make_shared<ngraph::op::Result>(output);
result->set_friendly_name(output.get_node()->get_friendly_name() + "_" +
std::to_string(output.get_index()) + "_result");
ngraph::copy_runtime_info(output.get_node_shared_ptr(), result);
auto parameter =
std::make_shared<ngraph::op::Parameter>(output.get_element_type(), output.get_partial_shape());
parameter->set_friendly_name(input.get_node()->get_friendly_name() + "_" +
std::to_string(input.get_index()) + "_parameter");
ngraph::copy_runtime_info(input.get_node()->shared_from_this(), parameter);
input.replace_source_output(parameter->output(0));
results.push_back(result);
subgraphIds.emplace(result.get(), subgraphIds[output.get_node()]);
subgraphIds.emplace(parameter.get(), subgraphIds[input.get_node()]);
subgraphParameterToPrevResult.emplace(parameter.get(), result.get());
_blobNameMap.emplace(
parameter->get_friendly_name(),
output.get_node()->get_friendly_name() + ((output.get_node()->get_output_size() != 1)
? ("." + std::to_string(output.get_index()))
: std::string{}));
results.push_back(result);
for (auto&& input : inputs) {
output.remove_target_input(input);
auto parameter =
std::make_shared<ngraph::op::Parameter>(output.get_element_type(), output.get_partial_shape());
parameter->set_friendly_name(input.get_node()->get_friendly_name() + "_" +
std::to_string(input.get_index()) + "_parameter");
ngraph::copy_runtime_info(input.get_node()->shared_from_this(), parameter);
input.replace_source_output(parameter->output(0));
subgraphIds.emplace(parameter.get(), subgraphIds[input.get_node()]);
subgraphParameterToPrevResult.emplace(parameter.get(), result.get());
_blobNameMap.emplace(
parameter->get_friendly_name(),
output.get_node()->get_friendly_name() + ((output.get_node()->get_output_size() != 1)
? ("." + std::to_string(output.get_index()))
: std::string{}));
}
}
}
@@ -353,6 +361,7 @@ HeteroExecutableNetwork::HeteroExecutableNetwork(const InferenceEngine::CNNNetwo
subgraph._affinity = itAffinity->second;
}
}
results = {};
// Subgraph topological sort
std::vector<Subgraph> allSubgraphs;