158 lines
6.0 KiB
Python
158 lines
6.0 KiB
Python
# Copyright (C) 2018-2021 Intel Corporation
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import logging as log
|
|
|
|
import numpy as np
|
|
from networkx.algorithms import isomorphism as ism
|
|
|
|
from mo.graph.graph import Node, dict_includes, Graph
|
|
|
|
|
|
def inverse_dict(d: dict):
|
|
return {v: k for k, v in d.items()}
|
|
|
|
|
|
def for_each_sub_graph(graph: Graph, func: callable):
|
|
""" Run a given function `func` for each sub-graph in a given graph not recursively.
|
|
|
|
It doesn't search for sub-graphs in found sub-graphs recursively. If the recursion is required,
|
|
a given function `func` should be implemented in a special way to enable fully recursive traversal.
|
|
"""
|
|
for node in graph.nodes():
|
|
node = Node(graph, node)
|
|
if node.has_valid('sub_graphs'):
|
|
for sub_graph_name in node.sub_graphs:
|
|
func(node[sub_graph_name])
|
|
|
|
|
|
def for_each_sub_graph_recursively(graph: Graph, func: callable):
|
|
""" Run a given function `func` for each sub-graph in a given graph `graph` recursively.
|
|
|
|
A given function `func` shouldn't contain a recursion for sub-graphs of the second level.
|
|
"""
|
|
|
|
def recursive_helper(sub_graph):
|
|
# user action
|
|
func(sub_graph)
|
|
# recursion
|
|
for_each_sub_graph(sub_graph, recursive_helper)
|
|
|
|
for_each_sub_graph(graph, recursive_helper)
|
|
|
|
|
|
def for_graph_and_each_sub_graph_recursively(graph: Graph, func: callable):
|
|
""" Run a given function `func` for a given graph `graph` and each sub-graph recursively. """
|
|
func(graph)
|
|
for_each_sub_graph_recursively(graph, func)
|
|
|
|
|
|
def all_edges_in_nodes(nodes: list, edges: list):
|
|
return all([edge[0] in nodes and edge[1] in nodes for edge in edges])
|
|
|
|
|
|
def apply_pattern(graph: Graph, nodes: list, edges: list, action: callable, node_attrs: list = None,
|
|
edge_attrs: list = None):
|
|
"""
|
|
Search for all matches of a given subgraph defined by [nodes, edges] in graph,
|
|
then apply action for each such match.
|
|
"""
|
|
if not all_edges_in_nodes([node[0] for node in nodes], edges):
|
|
log.warning("Incorrect pattern attributes: not all nodes from edges are in nodes. "
|
|
"Please, mention all nodes you need in pattern in nodes attribute. ")
|
|
|
|
matches = []
|
|
for match in find_pattern_matches(graph, nodes, edges, node_attrs, edge_attrs):
|
|
matches.append(match)
|
|
|
|
for match in matches:
|
|
match = inverse_dict(match)
|
|
still_valid = True
|
|
for k in match:
|
|
if not graph.has_node(match[k]):
|
|
# Graph changed significantly
|
|
still_valid = False
|
|
log.warning("The graph has changed significantly during applying pattern:\n"
|
|
"nodes: {}\n"
|
|
"edges: {}\n"
|
|
"node_attrs: {}\n"
|
|
"edge_attrs: {}".format(nodes, edges, node_attrs, edge_attrs))
|
|
break
|
|
match[k] = Node(graph, match[k])
|
|
if still_valid:
|
|
action(graph, match)
|
|
|
|
|
|
def check_node_usages_out_of_match(match: dict, node_name_in_match_group: str):
|
|
"""
|
|
Checks if node is consumed by nodes out of match
|
|
:param match: dictionary with pattern match
|
|
:param node_name_in_match_group: string
|
|
:return:
|
|
"""
|
|
assert node_name_in_match_group in match
|
|
graph = match[node_name_in_match_group].graph
|
|
all_node_ids = [match[name].id for name in match]
|
|
in_out_node_ids = [u for u, _ in graph.in_edges(match[node_name_in_match_group].id)]
|
|
in_out_node_ids.extend([v for _, v in graph.out_edges(match[node_name_in_match_group].id)])
|
|
return all([n in all_node_ids for n in in_out_node_ids])
|
|
|
|
|
|
def node_match(data1: dict, data2: dict):
|
|
# We have to skip _in_ports/_out_ports attributes for comparison as they are not comparable
|
|
return dict_includes(data1, data2, skip_attr_names=['_in_ports', '_out_ports'])
|
|
|
|
|
|
def edge_match(datasets1, datasets2):
|
|
attrs = list(datasets2[0].keys())
|
|
values1 = set([])
|
|
for data1 in datasets1.values():
|
|
x = tuple(data1.get(attr, None) for attr in attrs)
|
|
values1.add(x)
|
|
values2 = set([])
|
|
for data2 in datasets2.values():
|
|
x = tuple(data2.get(attr, None) for attr in attrs)
|
|
values2.add(x)
|
|
return values1 == values2
|
|
|
|
|
|
def build_matcher(graph: Graph, nodes: list, edges: list, node_attrs: list = None,
|
|
edge_attrs: list = None):
|
|
if node_attrs is not None or edge_attrs is not None:
|
|
log.warning('\'edge_attrs\' or `\'node_attrs\'` parameter was passed to function \'find_pattern_matches\', '
|
|
'but they are not used anymore. Pattern matching proceeds according to \'nodes\' and \'edges\' '
|
|
'parameters. Please avoid passing \'edge_attrs\' and \'node_attrs\' parameters to any pattern '
|
|
'matching function like \'find_pattern_matches\', \'apply_pattern\' and \'pattern\' because it '
|
|
'will be deprecated in the next release.')
|
|
|
|
subgraph = Graph(name='pattern')
|
|
subgraph.add_nodes_from(nodes)
|
|
subgraph.add_edges_from(edges)
|
|
return ism.MultiDiGraphMatcher(graph, subgraph, node_match, edge_match)
|
|
|
|
|
|
def find_pattern_matches(graph: Graph, nodes: list, edges: list, node_attrs: list = None,
|
|
edge_attrs: list = None):
|
|
"""
|
|
Find all matches of a given sub-graph defined by [nodes, edges] in graph.
|
|
"""
|
|
matcher = build_matcher(graph, nodes, edges, node_attrs, edge_attrs)
|
|
return matcher.subgraph_isomorphisms_iter()
|
|
|
|
|
|
def find_isomorphisms(graph: Graph, nodes: list, edges: list):
|
|
''' Find for isomorphism between a given graph and a pattern specified by a given nodes and edges.
|
|
Applies the same rules as apply_pattern.
|
|
'''
|
|
matcher = build_matcher(graph, nodes, edges)
|
|
result = []
|
|
for match in matcher.isomorphisms_iter():
|
|
match = inverse_dict(match)
|
|
match = {k: Node(graph, match[k]) for k in match.keys()}
|
|
result.append(match)
|
|
return result
|
|
|
|
|
|
def check_value(v: np.ndarray, check: callable):
|
|
return v is not None and np.all(np.isreal(v)) and check(v)
|