Support ":" in node name for mean/scale application (#4082)
* Support ":" in node name for mean/scale application * Apply review feedback
This commit is contained in:
@@ -19,7 +19,7 @@ import numpy as np
|
||||
|
||||
from extensions.ops.elementwise import Add, Mul
|
||||
from mo.front.common.layout import get_features_dim
|
||||
from mo.front.extractor import split_node_in_port
|
||||
from mo.front.extractor import split_node_in_port, get_node_id_with_ports
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
@@ -85,28 +85,27 @@ class AddMeanScaleValues(MiddleReplacementPattern):
|
||||
input_nodes = graph.get_op_nodes(op='Parameter')
|
||||
|
||||
if not isinstance(values, dict):
|
||||
# The case when input names to apply mean/scales weren't specified
|
||||
if len(values) != len(input_nodes):
|
||||
raise Error('Numbers of inputs and mean/scale values do not match. ' + refer_to_faq_msg(61))
|
||||
|
||||
data = np.copy(values)
|
||||
values = {}
|
||||
for idx, node in enumerate(input_nodes):
|
||||
assert node.has_valid('name')
|
||||
values.update(
|
||||
{
|
||||
node['name']: {
|
||||
node.soft_get('name', node.id): {
|
||||
'mean': data[idx][0],
|
||||
'scale': data[idx][1]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
for node_name in values:
|
||||
node_mean_scale_values = values[node_name]
|
||||
node_name, port = split_node_in_port(node_name)
|
||||
for node_name, node_mean_scale_values in values.items():
|
||||
node_id = None
|
||||
try:
|
||||
node_id = graph.get_node_id_by_name(node_name)
|
||||
node_id, direction, port = get_node_id_with_ports(graph, node_name, skip_if_no_port=False)
|
||||
assert direction != 'out', 'Only input port can be specified for mean/scale application'
|
||||
except Error as e:
|
||||
log.warning('node_name {} is not found in graph'.format(node_name))
|
||||
if Node(graph, node_id) not in input_nodes:
|
||||
@@ -121,7 +120,7 @@ class AddMeanScaleValues(MiddleReplacementPattern):
|
||||
log.debug('Can not get the port number from the node {}'.format(placeholder.id))
|
||||
log.debug('Port will be defined as None')
|
||||
port = None
|
||||
if placeholder.has('initial_node_name') and placeholder.initial_node_name == node_name and (
|
||||
if placeholder.has('initial_node_name') and placeholder.initial_node_name == node_id and (
|
||||
port is None or placeholder_port == port):
|
||||
new_node_id = placeholder.id
|
||||
break
|
||||
|
||||
@@ -255,6 +255,44 @@ class AddMeanScaleValuesTest(unittest.TestCase):
|
||||
self.assertTrue(flag, resp)
|
||||
self.check_graph_attrs(graph, graph_ref, ['parameter'])
|
||||
|
||||
def test_mean_values_with_colon_in_node_name(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values={'param:0': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')], {'parameter': {'name': 'param:0'}},
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_mean_values_with_colon_in_node_name_and_port(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:add_mean'),
|
||||
*connect('mean', '1:add_mean'),
|
||||
*connect('add_mean', 'result'),
|
||||
])
|
||||
|
||||
argv = Namespace(mean_scale_values={'0:param:0': {'scale': np.array([1.]), 'mean': np.array([1., 2., 3.])}})
|
||||
graph = build_graph(nodes, [*connect('parameter', 'result')],
|
||||
{'parameter': {'name': 'param:0', 'id': 'param:0/placeholder_0',
|
||||
'initial_node_name': 'param:0'}},
|
||||
nodes_with_edges_only=True, cli=argv)
|
||||
self.set_graph_attrs(graph, ['parameter'])
|
||||
self.set_graph_attrs(graph_ref, ['parameter'])
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
AddMeanScaleValues().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_scale_input(self):
|
||||
graph_ref = build_graph(nodes, [
|
||||
*connect('parameter', '0:mul_scale'),
|
||||
|
||||
@@ -455,7 +455,7 @@ def extract_node_attrs(graph: Graph, extractor: callable):
|
||||
return graph
|
||||
|
||||
|
||||
def get_node_id_with_ports(graph: Graph, node_name: str):
|
||||
def get_node_id_with_ports(graph: Graph, node_name: str, skip_if_no_port=True):
|
||||
"""
|
||||
Extracts port and node ID out of user provided name
|
||||
:param graph: graph to operate on
|
||||
@@ -476,12 +476,12 @@ def get_node_id_with_ports(graph: Graph, node_name: str):
|
||||
node = Node(graph, graph.get_node_id_by_name(name))
|
||||
if match.group(1):
|
||||
in_port = int(match.group(1).replace(':', ''))
|
||||
if in_port not in [e['in'] for e in node.in_edges().values()]:
|
||||
if skip_if_no_port and in_port not in [e['in'] for e in node.in_edges().values()]:
|
||||
# skip found node if it doesn't have such port number
|
||||
continue
|
||||
if match.group(3):
|
||||
out_port = int(match.group(3).replace(':', ''))
|
||||
if out_port not in [e['out'] for e in node.out_edges().values()]:
|
||||
if skip_if_no_port and out_port not in [e['out'] for e in node.out_edges().values()]:
|
||||
# skip found node if it doesn't have such port number
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user