Files
openvino/model-optimizer/mo/middle/pattern_match.py
Alexey Suhov 6478f1742a Align copyright notice in python scripts (CVS-51320) (#4974)
* Align copyright notice in python scripts (CVS-51320)
2021-03-26 17:54:28 +03:00

158 lines
6.0 KiB
Python

# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging as log
import numpy as np
from networkx.algorithms import isomorphism as ism
from mo.graph.graph import Node, dict_includes, Graph
def inverse_dict(d: dict):
return {v: k for k, v in d.items()}
def for_each_sub_graph(graph: Graph, func: callable):
""" Run a given function `func` for each sub-graph in a given graph not recursively.
It doesn't search for sub-graphs in found sub-graphs recursively. If the recursion is required,
a given function `func` should be implemented in a special way to enable fully recursive traversal.
"""
for node in graph.nodes():
node = Node(graph, node)
if node.has_valid('sub_graphs'):
for sub_graph_name in node.sub_graphs:
func(node[sub_graph_name])
def for_each_sub_graph_recursively(graph: Graph, func: callable):
""" Run a given function `func` for each sub-graph in a given graph `graph` recursively.
A given function `func` shouldn't contain a recursion for sub-graphs of the second level.
"""
def recursive_helper(sub_graph):
# user action
func(sub_graph)
# recursion
for_each_sub_graph(sub_graph, recursive_helper)
for_each_sub_graph(graph, recursive_helper)
def for_graph_and_each_sub_graph_recursively(graph: Graph, func: callable):
""" Run a given function `func` for a given graph `graph` and each sub-graph recursively. """
func(graph)
for_each_sub_graph_recursively(graph, func)
def all_edges_in_nodes(nodes: list, edges: list):
return all([edge[0] in nodes and edge[1] in nodes for edge in edges])
def apply_pattern(graph: Graph, nodes: list, edges: list, action: callable, node_attrs: list = None,
edge_attrs: list = None):
"""
Search for all matches of a given subgraph defined by [nodes, edges] in graph,
then apply action for each such match.
"""
if not all_edges_in_nodes([node[0] for node in nodes], edges):
log.warning("Incorrect pattern attributes: not all nodes from edges are in nodes. "
"Please, mention all nodes you need in pattern in nodes attribute. ")
matches = []
for match in find_pattern_matches(graph, nodes, edges, node_attrs, edge_attrs):
matches.append(match)
for match in matches:
match = inverse_dict(match)
still_valid = True
for k in match:
if not graph.has_node(match[k]):
# Graph changed significantly
still_valid = False
log.warning("The graph has changed significantly during applying pattern:\n"
"nodes: {}\n"
"edges: {}\n"
"node_attrs: {}\n"
"edge_attrs: {}".format(nodes, edges, node_attrs, edge_attrs))
break
match[k] = Node(graph, match[k])
if still_valid:
action(graph, match)
def check_node_usages_out_of_match(match: dict, node_name_in_match_group: str):
"""
Checks if node is consumed by nodes out of match
:param match: dictionary with pattern match
:param node_name_in_match_group: string
:return:
"""
assert node_name_in_match_group in match
graph = match[node_name_in_match_group].graph
all_node_ids = [match[name].id for name in match]
in_out_node_ids = [u for u, _ in graph.in_edges(match[node_name_in_match_group].id)]
in_out_node_ids.extend([v for _, v in graph.out_edges(match[node_name_in_match_group].id)])
return all([n in all_node_ids for n in in_out_node_ids])
def node_match(data1: dict, data2: dict):
# We have to skip _in_ports/_out_ports attributes for comparison as they are not comparable
return dict_includes(data1, data2, skip_attr_names=['_in_ports', '_out_ports'])
def edge_match(datasets1, datasets2):
attrs = list(datasets2[0].keys())
values1 = set([])
for data1 in datasets1.values():
x = tuple(data1.get(attr, None) for attr in attrs)
values1.add(x)
values2 = set([])
for data2 in datasets2.values():
x = tuple(data2.get(attr, None) for attr in attrs)
values2.add(x)
return values1 == values2
def build_matcher(graph: Graph, nodes: list, edges: list, node_attrs: list = None,
edge_attrs: list = None):
if node_attrs is not None or edge_attrs is not None:
log.warning('\'edge_attrs\' or `\'node_attrs\'` parameter was passed to function \'find_pattern_matches\', '
'but they are not used anymore. Pattern matching proceeds according to \'nodes\' and \'edges\' '
'parameters. Please avoid passing \'edge_attrs\' and \'node_attrs\' parameters to any pattern '
'matching function like \'find_pattern_matches\', \'apply_pattern\' and \'pattern\' because it '
'will be deprecated in the next release.')
subgraph = Graph(name='pattern')
subgraph.add_nodes_from(nodes)
subgraph.add_edges_from(edges)
return ism.MultiDiGraphMatcher(graph, subgraph, node_match, edge_match)
def find_pattern_matches(graph: Graph, nodes: list, edges: list, node_attrs: list = None,
edge_attrs: list = None):
"""
Find all matches of a given sub-graph defined by [nodes, edges] in graph.
"""
matcher = build_matcher(graph, nodes, edges, node_attrs, edge_attrs)
return matcher.subgraph_isomorphisms_iter()
def find_isomorphisms(graph: Graph, nodes: list, edges: list):
''' Find for isomorphism between a given graph and a pattern specified by a given nodes and edges.
Applies the same rules as apply_pattern.
'''
matcher = build_matcher(graph, nodes, edges)
result = []
for match in matcher.isomorphisms_iter():
match = inverse_dict(match)
match = {k: Node(graph, match[k]) for k in match.keys()}
result.append(match)
return result
def check_value(v: np.ndarray, check: callable):
return v is not None and np.all(np.isreal(v)) and check(v)