processing sinks in hetero plugin, update single layer tests (#6090)

This commit is contained in:
Ivan Tikhonov
2021-06-09 19:03:38 +03:00
committed by GitHub
parent ec1134532a
commit 3bedd051dc
4 changed files with 19 additions and 2 deletions

View File

@@ -312,6 +312,7 @@ HeteroExecutableNetwork::HeteroExecutableNetwork(const InferenceEngine::CNNNetwo
struct Subgraph {
ngraph::ResultVector _results;
ngraph::ParameterVector _parameters;
ngraph::SinkVector _sinks;
std::string _affinity;
};
std::unordered_map<int, Subgraph> subgraphs;
@@ -325,6 +326,9 @@ HeteroExecutableNetwork::HeteroExecutableNetwork(const InferenceEngine::CNNNetwo
} else if (ngraph::op::is_parameter(node)) {
subgraph._parameters.emplace_back(
std::dynamic_pointer_cast<ngraph::op::v0::Parameter>(node->shared_from_this()));
} else if (ngraph::op::is_sink(node)) {
subgraph._sinks.emplace_back(
std::dynamic_pointer_cast<ngraph::op::Sink>(node->shared_from_this()));
}
auto itAffinity = affinities.find(node);
if (itAffinity != affinities.end()) {
@@ -373,7 +377,7 @@ HeteroExecutableNetwork::HeteroExecutableNetwork(const InferenceEngine::CNNNetwo
for (auto&& subgraph : orderedSubgraphs) {
_networks[id]._device = subgraph._affinity;
subFunctions[id] =
std::make_shared<ngraph::Function>(subgraph._results, subgraph._parameters,
std::make_shared<ngraph::Function>(subgraph._results, subgraph._sinks, subgraph._parameters,
_name + '_' + std::to_string(id));
_networks[id]._clonedNetwork = CNNNetwork{subFunctions[id]};
// update of pre-processing info

View File

@@ -38,7 +38,7 @@ INSTANTIATE_TEST_CASE_P(smoke_MemoryTest, MemoryTest,
::testing::ValuesIn(iterationCount),
::testing::ValuesIn(inShapes),
::testing::ValuesIn(inputPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
::testing::Values(CommonTestUtils::DEVICE_CPU, "HETERO:CPU")),
MemoryTest::getTestCaseName);
} // namespace

View File

@@ -34,6 +34,8 @@ namespace ngraph
NGRAPH_API
bool is_output(const ngraph::Node* node);
NGRAPH_API
bool is_sink(const ngraph::Node* node);
NGRAPH_API
bool is_constant(const ngraph::Node* node);
NGRAPH_API
bool is_commutative(const ngraph::Node* node);
@@ -60,6 +62,8 @@ namespace ngraph
NGRAPH_API
bool is_output(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool is_sink(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool is_constant(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool is_commutative(const std::shared_ptr<ngraph::Node>& node);

View File

@@ -76,6 +76,11 @@ bool ngraph::op::is_output(const ngraph::Node* node)
return dynamic_cast<const ngraph::op::Result*>(node) != nullptr;
}
bool ngraph::op::is_sink(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::Sink*>(node) != nullptr;
}
bool ngraph::op::is_constant(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::Constant*>(node) != nullptr;
@@ -134,6 +139,10 @@ bool ngraph::op::is_output(const std::shared_ptr<ngraph::Node>& node)
{
return is_output(node.get());
}
bool ngraph::op::is_sink(const std::shared_ptr<ngraph::Node>& node)
{
return is_sink(node.get());
}
bool ngraph::op::is_constant(const std::shared_ptr<ngraph::Node>& node)
{
return is_constant(node.get());