Hetero creates single result node (#9572)
This commit is contained in:
@@ -303,29 +303,37 @@ HeteroExecutableNetwork::HeteroExecutableNetwork(const InferenceEngine::CNNNetwo
|
||||
// Break graph using insertion of result parameter split
|
||||
NodeMap<ngraph::Node*> subgraphParameterToPrevResult;
|
||||
std::vector<std::shared_ptr<ngraph::op::Result>> results;
|
||||
for (auto&& input : subgraphInputs) {
|
||||
if (!ngraph::op::is_parameter(input.get_node()) && !ngraph::op::is_constant(input.get_node())) {
|
||||
auto output = input.get_source_output();
|
||||
output.remove_target_input(input);
|
||||
{
|
||||
std::set<ngraph::Output<ngraph::Node>> subgraphOutputs;
|
||||
for (auto&& input : subgraphInputs) {
|
||||
if (!ngraph::op::is_parameter(input.get_node()) && !ngraph::op::is_constant(input.get_node())) {
|
||||
subgraphOutputs.insert(input.get_source_output());
|
||||
}
|
||||
}
|
||||
for (auto&& output : subgraphOutputs) {
|
||||
auto inputs = output.get_target_inputs();
|
||||
auto result = std::make_shared<ngraph::op::Result>(output);
|
||||
result->set_friendly_name(output.get_node()->get_friendly_name() + "_" +
|
||||
std::to_string(output.get_index()) + "_result");
|
||||
ngraph::copy_runtime_info(output.get_node_shared_ptr(), result);
|
||||
auto parameter =
|
||||
std::make_shared<ngraph::op::Parameter>(output.get_element_type(), output.get_partial_shape());
|
||||
parameter->set_friendly_name(input.get_node()->get_friendly_name() + "_" +
|
||||
std::to_string(input.get_index()) + "_parameter");
|
||||
ngraph::copy_runtime_info(input.get_node()->shared_from_this(), parameter);
|
||||
input.replace_source_output(parameter->output(0));
|
||||
results.push_back(result);
|
||||
subgraphIds.emplace(result.get(), subgraphIds[output.get_node()]);
|
||||
subgraphIds.emplace(parameter.get(), subgraphIds[input.get_node()]);
|
||||
subgraphParameterToPrevResult.emplace(parameter.get(), result.get());
|
||||
_blobNameMap.emplace(
|
||||
parameter->get_friendly_name(),
|
||||
output.get_node()->get_friendly_name() + ((output.get_node()->get_output_size() != 1)
|
||||
? ("." + std::to_string(output.get_index()))
|
||||
: std::string{}));
|
||||
results.push_back(result);
|
||||
for (auto&& input : inputs) {
|
||||
output.remove_target_input(input);
|
||||
auto parameter =
|
||||
std::make_shared<ngraph::op::Parameter>(output.get_element_type(), output.get_partial_shape());
|
||||
parameter->set_friendly_name(input.get_node()->get_friendly_name() + "_" +
|
||||
std::to_string(input.get_index()) + "_parameter");
|
||||
ngraph::copy_runtime_info(input.get_node()->shared_from_this(), parameter);
|
||||
input.replace_source_output(parameter->output(0));
|
||||
subgraphIds.emplace(parameter.get(), subgraphIds[input.get_node()]);
|
||||
subgraphParameterToPrevResult.emplace(parameter.get(), result.get());
|
||||
_blobNameMap.emplace(
|
||||
parameter->get_friendly_name(),
|
||||
output.get_node()->get_friendly_name() + ((output.get_node()->get_output_size() != 1)
|
||||
? ("." + std::to_string(output.get_index()))
|
||||
: std::string{}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -353,6 +361,7 @@ HeteroExecutableNetwork::HeteroExecutableNetwork(const InferenceEngine::CNNNetwo
|
||||
subgraph._affinity = itAffinity->second;
|
||||
}
|
||||
}
|
||||
results = {};
|
||||
|
||||
// Subgraph topological sort
|
||||
std::vector<Subgraph> allSubgraphs;
|
||||
|
||||
Reference in New Issue
Block a user