[GPU] Improve IsNodeOnConstPath() method (#7458)

This commit is contained in:
mei, yang 2021-09-22 15:52:58 +08:00 committed by GitHub
parent 377f46898c
commit e2272331be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -325,28 +325,24 @@ void Program::InitProfileInfo(const std::string& layerName,
// TODO: Does it make sense to add such method to ngraph core? // TODO: Does it make sense to add such method to ngraph core?
bool IsNodeOnConstPath(const std::shared_ptr<ngraph::Node>& node) { bool IsNodeOnConstPath(const std::shared_ptr<ngraph::Node>& node) {
std::list<std::shared_ptr<ngraph::Node>> nodes_to_process = { node }; std::set<std::shared_ptr<ngraph::Node>> nodes_processed = {};
while (!nodes_to_process.empty()) { std::function<bool(const std::shared_ptr<ngraph::Node>&)> is_const_node = [&nodes_processed, &is_const_node](const std::shared_ptr<ngraph::Node>& node) {
auto current_node = nodes_to_process.front(); if (nodes_processed.count(node)) return true;
nodes_to_process.pop_front(); nodes_processed.insert(node);
// If input is constant, then drop if from the processing list
for (size_t i = 0; i < current_node->get_input_size(); i++) { if (std::dynamic_pointer_cast<ngraph::op::v0::Constant>(node) != nullptr)
auto input_node = current_node->get_input_node_shared_ptr(i); return true;
// If the node doesn't have any parents and it's not a constant, then we deal with dynamic path
// If input is constant, then drop if from the processing list if (node->get_input_size() == 0)
if (std::dynamic_pointer_cast<ngraph::op::v0::Constant>(input_node) != nullptr) return false;
continue; for (size_t i = 0; i < node->get_input_size(); i++) {
auto input_node = node->get_input_node_shared_ptr(i);
// If the node doesn't have any parents and it's not a constant, then we deal with dynamic path if (!is_const_node(input_node))
if (input_node->get_input_size() == 0) {
return false; return false;
}
nodes_to_process.insert(nodes_to_process.end(), input_node);
} }
} return true;
};
return true; return is_const_node(node);
} }
} // namespace CLDNNPlugin } // namespace CLDNNPlugin