Support ":" in node name for mean/scale application (#4082)

* Support ":" in node name for mean/scale application

* Apply review feedback
This commit is contained in:
Maxim Vafin
2021-02-08 16:51:53 +03:00
committed by GitHub
parent 2ad7db7b25
commit a157cc2a55
3 changed files with 48 additions and 11 deletions

View File

@@ -19,7 +19,7 @@ import numpy as np
from extensions.ops.elementwise import Add, Mul
from mo.front.common.layout import get_features_dim
from mo.front.extractor import split_node_in_port
from mo.front.extractor import split_node_in_port, get_node_id_with_ports
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
@@ -85,28 +85,27 @@ class AddMeanScaleValues(MiddleReplacementPattern):
input_nodes = graph.get_op_nodes(op='Parameter')
if not isinstance(values, dict):
# The case when input names to apply mean/scales weren't specified
if len(values) != len(input_nodes):
raise Error('Numbers of inputs and mean/scale values do not match. ' + refer_to_faq_msg(61))
data = np.copy(values)
values = {}
for idx, node in enumerate(input_nodes):
assert node.has_valid('name')
values.update(
{
node['name']: {
node.soft_get('name', node.id): {
'mean': data[idx][0],
'scale': data[idx][1]
}
}
)
for node_name in values:
node_mean_scale_values = values[node_name]
node_name, port = split_node_in_port(node_name)
for node_name, node_mean_scale_values in values.items():
node_id = None
try:
node_id = graph.get_node_id_by_name(node_name)
node_id, direction, port = get_node_id_with_ports(graph, node_name, skip_if_no_port=False)
assert direction != 'out', 'Only input port can be specified for mean/scale application'
except Error as e:
log.warning('node_name {} is not found in graph'.format(node_name))
if Node(graph, node_id) not in input_nodes:
@@ -121,7 +120,7 @@ class AddMeanScaleValues(MiddleReplacementPattern):
log.debug('Can not get the port number from the node {}'.format(placeholder.id))
log.debug('Port will be defined as None')
port = None
if placeholder.has('initial_node_name') and placeholder.initial_node_name == node_name and (
if placeholder.has('initial_node_name') and placeholder.initial_node_name == node_id and (
port is None or placeholder_port == port):
new_node_id = placeholder.id
break

View File

@@ -255,6 +255,44 @@ class AddMeanScaleValuesTest(unittest.TestCase):
self.assertTrue(flag, resp)
self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_mean_values_with_colon_in_node_name(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:add_mean'),
*connect('mean', '1:add_mean'),
*connect('add_mean', 'result'),
])
argv = Namespace(mean_scale_values={'param:0': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
graph = build_graph(nodes, [*connect('parameter', 'result')], {'parameter': {'name': 'param:0'}},
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_mean_values_with_colon_in_node_name_and_port(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:add_mean'),
*connect('mean', '1:add_mean'),
*connect('add_mean', 'result'),
])
argv = Namespace(mean_scale_values={'0:param:0': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
graph = build_graph(nodes, [*connect('parameter', 'result')],
{'parameter': {'name': 'param:0', 'id': 'param:0/placeholder_0',
'initial_node_name': 'param:0'}},
nodes_with_edges_only=True, cli=argv)
self.set_graph_attrs(graph, ['parameter'])
self.set_graph_attrs(graph_ref, ['parameter'])
graph.graph['layout'] = 'NCHW'
AddMeanScaleValues().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_scale_input(self):
graph_ref = build_graph(nodes, [
*connect('parameter', '0:mul_scale'),

View File

@@ -455,7 +455,7 @@ def extract_node_attrs(graph: Graph, extractor: callable):
return graph
def get_node_id_with_ports(graph: Graph, node_name: str):
def get_node_id_with_ports(graph: Graph, node_name: str, skip_if_no_port=True):
"""
Extracts port and node ID out of user provided name
:param graph: graph to operate on
@@ -476,12 +476,12 @@ def get_node_id_with_ports(graph: Graph, node_name: str):
node = Node(graph, graph.get_node_id_by_name(name))
if match.group(1):
in_port = int(match.group(1).replace(':', ''))
if in_port not in [e['in'] for e in node.in_edges().values()]:
if skip_if_no_port and in_port not in [e['in'] for e in node.in_edges().values()]:
# skip found node if it doesn't have such port number
continue
if match.group(3):
out_port = int(match.group(3).replace(':', ''))
if out_port not in [e['out'] for e in node.out_edges().values()]:
if skip_if_no_port and out_port not in [e['out'] for e in node.out_edges().values()]:
# skip found node if it doesn't have such port number
continue