# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import json
|
|
import logging as log
|
|
import os
|
|
from re import compile, match
|
|
|
|
from mo.graph.graph import Node, Graph
|
|
from mo.utils.error import Error
|
|
from mo.utils.graph import nodes_matching_name_pattern, sub_graph_between_nodes
|
|
from mo.utils.utils import refer_to_faq_msg
|
|
|
|
|
|
class _MissingDescriptorAttribute(AttributeError, KeyError):
    """
    Raised when a replacement description does not contain the requested attribute.

    Deriving from AttributeError lets 'hasattr', 'getattr' with a default, 'copy' and 'pickle' treat the attribute as
    absent. Deriving from KeyError keeps code working that caught the KeyError raised by earlier versions.
    """


class CustomReplacementDescriptor(object):
    """
    Base class describing one custom replacement read from the configuration file.

    The description is kept in the '_replacement_desc' dictionary. Its keys are readable as instance attributes and
    only already existing keys can be assigned through attribute assignment.
    """
    # maps match kind to the class handling it; the same dictionary is used by all sub-classes
    registered_types = dict()

    def __init__(self, replacement_id: str, attrs: dict = None):
        """
        Create class instance based on attrs dictionary which is read from the configuration file.
        :param replacement_id: id of the replacement.
        :param attrs: description of the replacement. The key 'custom_attributes' is added to this dictionary if it
        is missing. None is handled like an empty dictionary.
        """
        if attrs is None:
            # without this the '_replacement_desc' storage would never be created and every attribute access would
            # recurse endlessly in __getattr__
            attrs = dict()
        # the parent __setattr__ is used because the overridden one accepts keys of the description only
        set_raw = super(CustomReplacementDescriptor, self).__setattr__
        set_raw('replacement_id', replacement_id)
        set_raw('custom_attributes', attrs.setdefault('custom_attributes', {}))
        set_raw('_replacement_desc', attrs.copy())

    def __getattr__(self, k):
        """
        Returns value of the description attribute 'k'. Called only when the regular attribute lookup fails.
        :param k: name of the attribute.
        :return: value stored in the description.
        :raises AttributeError: (which is also a KeyError) if the description does not contain 'k'.
        """
        # read the storage through __dict__ so that a missing storage cannot trigger __getattr__ recursively
        desc = self.__dict__.get('_replacement_desc')
        if desc is None or k not in desc:
            raise _MissingDescriptorAttribute(k)
        return desc[k]

    def __setattr__(self, k, v):
        """
        Updates value of the existing description attribute 'k'.
        :param k: name of the attribute.
        :param v: new value.
        :raises AttributeError: if the description does not contain 'k'.
        """
        # you can assign only existing attributes
        if k not in self._replacement_desc:
            raise AttributeError("Attribute '{}' is not defined in the replacement description".format(k))
        self._replacement_desc[k] = v

    def has(self, attr):
        """
        Check that attribute 'attr' is defined for the CustomReplacementDescriptor.
        :param attr: attribute to check.
        :return: True if the attribute exists and False otherwise.
        """
        return attr in self._replacement_desc

    @classmethod
    def register_type(cls, match_kind: str, class_type: type):
        """
        Registers class to be used for the match kind. The first registration wins: a repeated registration of the
        same match kind is ignored with a warning.
        :param match_kind: match kind.
        :param class_type: class to instantiate for this match kind.
        """
        if match_kind in cls.registered_types:
            log.warning('Class for match kind "{}" is already registered'.format(match_kind))
        else:
            cls.registered_types[match_kind] = class_type

    @classmethod
    def create_instance(cls, match_kind: str, replacement_id: str, attrs: dict = None):
        """
        Fabric method to create proper object based on match_kind.
        :param match_kind: match kind.
        :param replacement_id: id of the replacement.
        :param attrs: optional attributes to be set.
        :return: object of the class registered for the match kind.
        :raises Error: if the match kind is not registered.
        """
        if attrs is None:
            attrs = dict()
        if match_kind not in cls.registered_types:
            raise Error('No class registered for match kind "{}". Supported match kinds are "{}". '.format(
                match_kind, list(cls.registered_types.keys())) +
                        refer_to_faq_msg(65))
        return cls.registered_types[match_kind](replacement_id, attrs)

    def sub_graph_instances(self):
        """
        Returns instances of the sub-graph. Must be overridden in the sub-class.
        :raises NotImplementedError: always.
        """
        raise NotImplementedError("The function 'sub_graph_instances' must be implemented in the sub-class.")

    def get_config_file_representation(self):
        """
        Builds dictionary with the attributes of the replacement in the form used in the configuration file.
        :return: dictionary with 'match_kind', 'instances', 'inputs', 'outputs', 'custom_attributes', 'id' and, when
        it is defined, 'op'.
        """
        result = {
            'match_kind': self.match_kind, 'instances': self.instances,
            'inputs': self.inputs, 'outputs': self.outputs,
            'custom_attributes': self.custom_attributes, 'id': self.id
        }
        if self.has('op'):
            result.update({'op': self.op})
        return result

    def get_inputs_description(self):
        """
        Returns description of inputs of the replacement. The format of inputs is the following: list of lists where
        each inner list is built from one element of the 'inputs' attribute. Each element of the inner list is a pair
        where the first element is the 'node' value (a regular expression for the name of the node in the sub-graph)
        and the second is the 'port' value (the input port of this node).
        :return: description of inputs or None if information about inputs is not available.
        """
        if 'inputs' not in self._replacement_desc:
            log.error("Information about inputs of layer with id '{}' is not available".format(self.replacement_id))
            return None
        return [[(inp['node'], inp['port']) for inp in input_desc]
                for input_desc in self._replacement_desc['inputs']]

    def get_outputs_description(self):
        """
        Returns description of outputs of the replacement. The format of outputs is the following: list of pairs
        where the first element of the pair is the 'node' value (a regular expression for the name of the node that
        produces output of the sub-graph) and the second is the 'port' value (the output port of this node).
        :return: description of outputs or None if information about outputs is not available.
        """
        if 'outputs' not in self._replacement_desc:
            log.error("Information about outputs of layer with id '{}' is not available".format(self.replacement_id))
            return None
        return [(out['node'], out['port']) for out in self._replacement_desc['outputs']]

    def update_custom_replacement_attributes(self, graph: 'Graph'):
        """
        The function run specific functions to update attributes of the custom replacement description. Must be
        overridden in the sub-class.
        :param graph: graph to operate on.
        :raises NotImplementedError: always.
        """
        raise NotImplementedError(
            "The function 'update_custom_replacement_attributes' must be implemented in the sub-class.")

    def validate_data(self):
        """
        Validates layer description dictionary.
        :return: list of errors identified.
        """
        errors = list()
        if not self.has('id'):
            errors.append("Replacement id is not specified for custom replacement '{}'".format(self.replacement_id))
        if not self.has('instances') or self.instances == '':
            errors.append("Attribute 'instances' is not specified for replacement '{}'".format(self.replacement_id))
        if not self.has('match_kind'):
            errors.append("Replacement match type is not specified for replacement '{}'".format(self.replacement_id))
        return errors
|
|
|
|
|
|
class CustomReplacementDescriptorPoints(CustomReplacementDescriptor):
    """
    Class that is used to describe custom replacement which is a sub-graph specified by start and end points.
    """

    def __init__(self, replacement_id: str, attrs: dict = None):
        """
        Creates the descriptor and enables 'include_inputs_to_sub_graph' and 'include_outputs_to_sub_graph' when the
        description does not specify them.
        :param replacement_id: id of the replacement.
        :param attrs: description of the replacement read from the configuration file.
        """
        # pass an empty dictionary instead of None so the description storage is always created
        super().__init__(replacement_id, dict() if attrs is None else attrs)
        # The defaults are written directly to the description. Going through the parent's __setattr__ is not
        # possible here: it accepts only keys which already exist, so it raised AttributeError exactly when a
        # default had to be added.
        for flag in ('include_inputs_to_sub_graph', 'include_outputs_to_sub_graph'):
            self._replacement_desc.setdefault(flag, True)

    def get_config_file_representation(self):
        """
        Builds dictionary with the attributes of the replacement in the form used in the configuration file.
        :return: dictionary with 'match_kind', 'instances', 'custom_attributes', 'id', both 'include_*_to_sub_graph'
        flags converted to bool and, when it is defined, 'op'.
        """
        result = {
            'match_kind': self.match_kind, 'instances': self.instances,
            'custom_attributes': self.custom_attributes, 'id': self.id,
            'include_inputs_to_sub_graph': bool(self.include_inputs_to_sub_graph),
            'include_outputs_to_sub_graph': bool(self.include_outputs_to_sub_graph)
        }
        if self.has('op'):
            result.update({'op': self.op})
        return result

    def get_inputs_description(self):
        """
        Returns description of inputs built from the start points.
        :return: list with one single-element list per start point; the element is the pair of the pattern
        '^<start point name>$' and port 0.
        """
        return [[('^' + node_name + '$', 0)] for node_name in self.instances['start_points']]

    def get_outputs_description(self):
        """
        Returns description of outputs built from the end points.
        :return: list of pairs of the pattern '^<end point name>$' and port 0, one per end point.
        """
        return [('^' + node_name + '$', 0) for node_name in self.instances['end_points']]

    def get_internal_input_nodes(self, graph: Graph):
        """
        Gets list of node names getting input from outside of the sub-graph. This function checks whether input nodes
        specified in the configuration file should be added to the sub-graph or not. If they should not be added to the
        sub-graph then input nodes of the sub-graph are children of these nodes.
        :param graph: graph to operate on.
        :return: list of input node names.
        """
        if self.include_inputs_to_sub_graph:
            return self.instances['start_points']

        log.debug('Do not include inputs to sub-graph for replacement with id {}'.format(self.replacement_id))
        # a set is used because several start points may feed the same consumer
        new_start_nodes = set()
        for start_node in self.instances['start_points']:
            for _, out_node_name in graph.out_edges(start_node):
                new_start_nodes.add(out_node_name)
        start_nodes = list(new_start_nodes)
        log.debug('New inputs are: {}'.format(start_nodes))
        return start_nodes

    def get_internal_output_nodes(self, graph: Graph):
        """
        Gets list of node names producing output outside of the sub-graph. This function checks whether output nodes
        specified in the configuration file should be added to the sub-graph or not. If they should not be added to the
        sub-graph then output nodes of the sub-graph are parents of these nodes.
        :param graph: graph to operate on.
        :return: list of output node names.
        """
        if self.include_outputs_to_sub_graph:
            return self.instances['end_points']

        log.debug('Do not include outputs of sub-graph for replacement with id {}'.format(self.replacement_id))
        # a set is used because several end points may be fed by the same producer
        new_end_nodes = set()
        for end_node in self.instances['end_points']:
            for in_node_name, _ in graph.in_edges(end_node):
                new_end_nodes.add(in_node_name)
        end_nodes = list(new_end_nodes)
        log.debug('New outputs are: {}'.format(end_nodes))
        return end_nodes

    def update_custom_replacement_attributes(self, graph: Graph):
        """
        Fills the 'inputs' and 'outputs' attributes of the description from the edges crossing the border of the
        sub-graph located between the start and end points. Attributes which are already defined are not changed.
        :param graph: graph to operate on.
        :raises Error: if 'instances' is not defined or is not a dictionary.
        """
        if not self.has('instances'):
            raise Error("No instance(s) is(are) defined for the custom replacement '{}'. ".format(self.replacement_id) +
                        refer_to_faq_msg(66))
        if not isinstance(self.instances, dict):
            raise Error("The instance must be a single dictionary for the custom replacement with id '{}'. ".format(
                self.replacement_id) +
                        refer_to_faq_msg(67))

        start_points = self.get_internal_input_nodes(graph)
        end_points = self.get_internal_output_nodes(graph)

        matched_nodes = sub_graph_between_nodes(graph, start_points, end_points, include_control_flow=False)
        output_tensors = set()
        # key is the input tensor name "<producer>:<output port>", value is the list of pairs
        # (pattern of the consumer node name, input port of the consumer)
        input_nodes_mapping = dict()
        for src_node_name, dst_node_name, edge_attrs in graph.edges(data=True):
            dst_node = graph.node[dst_node_name]

            # edge outside sub-graph into sub-graph
            if (src_node_name not in matched_nodes) and (dst_node_name in matched_nodes):
                tensor_name = src_node_name + ":" + str(edge_attrs['out'])
                input_nodes_mapping.setdefault(tensor_name, list()).append(
                    ('^' + dst_node_name + '$', edge_attrs['in']))

            # edge from inside sub-graph to outside sub-graph
            if (src_node_name in matched_nodes) and (dst_node_name not in matched_nodes):
                output_tensors.add(('^' + dst_node['pb'].input[edge_attrs['in']] + '$', edge_attrs['out']))

        for node_name in graph.nodes():
            node = Node(graph, node_name)
            if node_name in matched_nodes and len(node.out_nodes()) == 0 and node['pb'].op != 'Const':
                log.debug("Node {} doesn't have output edges. Consider it output".format(node_name))
                output_tensors.add(('^' + node_name + '$', 0))

        if not self.has('inputs'):
            self._replacement_desc['inputs'] = [[{'node': desc[0], 'port': desc[1]} for desc in inp]
                                                for inp in sorted(input_nodes_mapping.values())]
            log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

        if not self.has('outputs'):
            self._replacement_desc['outputs'] = [{'node': node, 'port': port} for node, port in sorted(output_tensors)]
            log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))

    def sub_graph_instances(self):
        """
        Returns instances of the sub-graph.
        :return: list with the single element: the 'instances' attribute.
        """
        return [self.instances]


CustomReplacementDescriptor.register_type('points', CustomReplacementDescriptorPoints)
|
|
|
|
|
|
class CustomReplacementDescriptorScope(CustomReplacementDescriptor):
    """
    Class that is used to describe custom layer which is a sub-graph specified by scope name.
    """

    def __init__(self, replacement_id: str, attrs: dict = None):
        """
        Creates the descriptor; all work is done by the base class.
        :param replacement_id: id of the replacement.
        :param attrs: description of the replacement read from the configuration file.
        """
        super(CustomReplacementDescriptorScope, self).__init__(replacement_id, attrs)

    def update_custom_replacement_attributes(self, graph: Graph):
        """
        Fills the 'inputs' and 'outputs' attributes of the description from the edges crossing the border of the
        sub-graph matched by the first instance pattern. Attributes which are already defined are not changed.
        :param graph: graph to operate on.
        :raises Error: if 'instances' is not defined or is empty.
        """
        if not (self.has('instances') and len(self.instances) != 0):
            raise Error("No instances are defined for replacement with id '{}'. ".format(self.replacement_id) +
                        refer_to_faq_msg(68))

        pattern = self.instances[0]  # use the first instance pattern to find input/output nodes patterns
        # TODO verify that all instances will produce the same sub-graph
        matched_nodes = nodes_matching_name_pattern(graph, pattern)

        produced = set()
        # key is the input tensor name "<producer>:<output port>", value is the list of pairs
        # (pattern of the consumer node name, input port of the consumer)
        consumers_by_tensor = dict()
        for producer, consumer, attrs in graph.edges(data=True):
            consumer_attrs = graph.node[consumer]
            producer_inside = producer in matched_nodes

            # edge outside sub-graph into sub-graph
            if not producer_inside and consumer in matched_nodes:
                tensor = producer + ":" + str(attrs['out'])
                consumers_by_tensor.setdefault(tensor, list()).append(
                    (generate_pattern_for_node(graph, pattern, consumer), attrs['in']))

            # edge from inside sub-graph to outside sub-graph
            if producer_inside and consumer not in matched_nodes:
                source_name = consumer_attrs['pb'].input[attrs['in']]
                produced.add((generate_pattern_for_node(graph, pattern, source_name), attrs['out']))

        for name in graph.nodes():
            candidate = Node(graph, name)
            if name not in matched_nodes:
                continue
            # a matched node without consumers is an output of the sub-graph unless it is a constant
            if len(candidate.out_nodes()) != 0 or candidate['pb'].op == 'Const':
                continue
            log.debug("Node {} doesn't have output edges. Consider it output".format(name))
            produced.add((generate_pattern_for_node(graph, pattern, name), 0))

        if not self.has('inputs'):
            inputs = []
            for group in sorted(consumers_by_tensor.values()):
                inputs.append([{'node': node_pattern, 'port': port} for node_pattern, port in group])
            self._replacement_desc['inputs'] = inputs
            log.debug('Updated inputs of sub-graph for instance "{}"'.format(self.instances))

        if not self.has('outputs'):
            outputs = []
            for node_pattern, port in sorted(produced):
                outputs.append({'node': node_pattern, 'port': port})
            self._replacement_desc['outputs'] = outputs
            log.debug('Updated outputs of sub-graph for instance "{}"'.format(self.instances))

    def sub_graph_instances(self):
        """
        Returns instances of the sub-graph.
        :return: the 'instances' attribute as is.
        """
        return self.instances


CustomReplacementDescriptor.register_type('scope', CustomReplacementDescriptorScope)
|
|
|
|
|
|
class CustomReplacementDescriptorGeneral(CustomReplacementDescriptor):
    """
    Class that is used to describe custom replacement registered for the 'general' match kind. Unlike the base class
    it does not require the 'instances' attribute during validation.
    """

    def __init__(self, replacement_id: str, attrs: dict = None):
        """
        Creates the descriptor; all work is done by the base class.
        :param replacement_id: id of the replacement.
        :param attrs: description of the replacement read from the configuration file.
        """
        super(CustomReplacementDescriptorGeneral, self).__init__(replacement_id, attrs)

    def validate_data(self):
        """
        Validates layer description dictionary: only 'id' and 'match_kind' are required.
        :return: list of errors identified.
        """
        # pairs of the required attribute and the message reported when it is missing, in reporting order
        required = (
            ('id', "Replacement id is not specified for custom replacement '{}'"),
            ('match_kind', "Replacement match type is not specified for replacement '{}'"),
        )
        return [message.format(self.replacement_id) for attr, message in required if not self.has(attr)]


CustomReplacementDescriptor.register_type('general', CustomReplacementDescriptorGeneral)
|
|
|
|
|
|
def parse_custom_replacement_config_file(file_name: str):
    """
    Reads custom replacement configuration file file_name.
    :param file_name: name of the file to read from.
    :return: The list of descriptors created with CustomReplacementDescriptor.create_instance, one per entry of the
    file, in the order of the entries.
    :raises Error: if the file does not exist, cannot be parsed as JSON, is not a list of objects, an entry misses
    'id' or 'match_kind', the match kind is not registered or validation of the descriptors fails.
    """
    if not os.path.exists(file_name):
        raise Error("Custom replacements configuration file '{}' does not exist. ".format(file_name) +
                    refer_to_faq_msg(69))
    try:
        # JSON text is UTF-8, so do not depend on the locale default encoding
        with open(file_name, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as exc:
        raise Error("Failed to parse custom replacements configuration file '{}': {}. ".format(file_name, exc) +
                    refer_to_faq_msg(70)) from exc

    # Without this check a top-level object is iterated by its keys, which ends with a misleading
    # "does not contain attribute" error.
    if not isinstance(data, list):
        raise Error("Custom replacements configuration file '{}' must contain a list of replacement descriptions, "
                    "but the top-level element has type '{}'. ".format(file_name, type(data).__name__))

    result = list()
    validation_errors = list()
    for attrs in data:
        if not isinstance(attrs, dict):
            raise Error("Custom replacements configuration file '{}' contains an entry of type '{}' where a "
                        "replacement description object is expected. ".format(file_name, type(attrs).__name__))
        if 'id' not in attrs:
            raise Error('One of the custom replacements in the configuration file "{}" does not contain attribute '
                        '"id". '.format(file_name) +
                        refer_to_faq_msg(71))
        if 'match_kind' not in attrs:
            raise Error('One of the custom replacements in the configuration file "{}" does not contain attribute '
                        '"match_kind". Possible values are "points", "scope" and "general". '.format(file_name) +
                        refer_to_faq_msg(71))
        desc = CustomReplacementDescriptor.create_instance(attrs['match_kind'], attrs['id'], attrs)
        # errors of all descriptors are collected to report them together after the loop
        validation_errors.extend(desc.validate_data())
        result.append(desc)
    if validation_errors:
        raise Error("File '{}' validation failed:\n{}. ".format(file_name, "\n".join(validation_errors)) +
                    refer_to_faq_msg(72))
    return result
|
|
|
|
|
|
def generate_pattern_for_node(graph: 'Graph', sub_graph_pattern: str, node_name: str):
    """
    Generates the pattern identifying the node inside of the sub-graph defined with the sub_graph_pattern.

    The node name is split by "/" and the shortest prefix of components (with "/" appended) matching the sub-graph
    pattern is dropped. The rest of the name followed by "$" is the result, provided that exactly one node of the
    sub-graph matches the sub-graph pattern followed by this result.
    :param graph: graph to operate on.
    :param sub_graph_pattern: regular expression for the sub-graph nodes. An empty pattern disables the generation.
    :param node_name: name of the node to generate the pattern for.
    :return: node_name itself if sub_graph_pattern is empty, otherwise the generated pattern.
    :raises RuntimeError: if no prefix of the node name matches the pattern or the generated pattern does not match
    exactly one node of the sub-graph.
    """
    if sub_graph_pattern == '':
        return node_name

    node_name_components = node_name.split("/")
    compiled_pattern = compile(sub_graph_pattern)

    matched_index = None  # index of the node name component to start new pattern from
    cur_name = ''
    for index, component in enumerate(node_name_components):
        cur_name += component + "/"
        if compiled_pattern.match(cur_name):
            matched_index = index
            break
    if matched_index is None:
        raise RuntimeError('Node name "{}" does not match pattern "{}"'.format(node_name, sub_graph_pattern))

    # the pattern is not empty here because of the early return above
    if sub_graph_pattern[-1] != '/':
        sub_graph_pattern += '/'

    sub_graph_nodes = nodes_matching_name_pattern(graph, sub_graph_pattern)
    name_suffix = '/'.join(node_name_components[matched_index + 1:]) + '$'
    # the node name is used as a part of the regular expression as is, without escaping
    node_pattern = sub_graph_pattern + name_suffix
    if sum(1 for node in sub_graph_nodes if match(node_pattern, node)) == 1:
        return name_suffix

    raise RuntimeError('The pattern that uniquely identifies node "{}" using sub-graph pattern "{}" has not been found'.
                       format(node_name, sub_graph_pattern))
|