Improve node name with port resolving (#1581)

* Improve node name with port resolving

* Fix IE remove Convert on output

* Address feedback
This commit is contained in:
Maxim Vafin 2020-08-05 11:31:17 +03:00 committed by GitHub
parent 850665d992
commit 75cb10fd6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 25 deletions

View File

@ -1401,6 +1401,25 @@ void convertLayerPrecision(const CNNLayerPtr& layer) {
}
}
template <typename NET>
void RemoveConverts(NET& net, std::vector<CNNLayerPtr>& to_remove) {
for (auto& layer : to_remove) {
RemoveLayer(layer, net);
}
}
template <>
void RemoveConverts(ICNNNetwork& net, std::vector<CNNLayerPtr>& to_remove) {
OutputsDataMap outputs;
net.getOutputsInfo(outputs);
for (auto& layer : to_remove) {
if (!std::any_of(outputs.begin(), outputs.end(),
[layer](std::pair<std::string, DataPtr> p) { return p.second->getName() == layer->name; })) {
RemoveLayer(layer, net);
}
}
}
template <typename NET>
void fixConvertLayers(NET &net) {
std::vector<CNNLayerPtr> to_remove;
@ -1422,9 +1441,7 @@ void fixConvertLayers(NET &net) {
}
}
}
for (auto &layer : to_remove) {
RemoveLayer(layer, net);
}
RemoveConverts(net, to_remove);
}
template <Precision::ePrecision PREC_FROM, Precision::ePrecision PREC_TO, typename NET>

View File

@ -459,10 +459,18 @@ def get_node_id_with_ports(graph: Graph, node_name: str):
if match.group(1) and match.group(3):
log.warning('Skipping the case with both in and out port specified, only one port can be specified')
continue
node = Node(graph, graph.get_node_id_by_name(name))
if match.group(1):
in_port = int(match.group(1).replace(':', ''))
if in_port not in [e['in'] for e in node.in_edges().values()]:
# skip found node if it doesn't have such port number
continue
if match.group(3):
out_port = int(match.group(3).replace(':', ''))
if out_port not in [e['out'] for e in node.out_edges().values()]:
# skip found node if it doesn't have such port number
continue
found_names.append((in_port, out_port, name))
if len(found_names) == 0:
raise Error('No node with name {}'.format(node_name))

View File

@ -502,25 +502,25 @@ class TestUserDataRepack(unittest.TestCase):
]
def test_input_user_data_repack_none(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
input, freeze_placeholder = input_user_data_repack(graph, None, None)
self.assertEqual(input, None)
self.assertEqual(freeze_placeholder, None)
def test_input_user_data_repack_names_to_ids_list(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
input, freeze_placeholder = input_user_data_repack(graph, ['Aa', 'Bb'], None)
self.assertDictEqual(input, {'A': [{'shape': None, 'port': None}], 'B': [{'shape': None, 'port': None}]})
self.assertEqual(freeze_placeholder, None)
def test_input_user_data_repack_names_ports_in_out(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
input, freeze_placeholder = input_user_data_repack(graph, ['Aa:1', '0:Bb'], None)
self.assertDictEqual(input, {'A': [{'shape': None, 'out': 1}], 'B': [{'shape': None, 'in': 0}]})
graph = build_graph(self.nodes, self.edges)
input, freeze_placeholder = input_user_data_repack(graph, ['Aa:0', '1:Cc'], None)
self.assertDictEqual(input, {'A': [{'shape': None, 'out': 0}], 'C': [{'shape': None, 'in': 1}]})
self.assertEqual(freeze_placeholder, None)
def test_input_user_data_repack_dict_with_shapes(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
shape_1 = np.array([1, 160, 160, 3])
shape_2 = np.array([1, 127, 127, 3])
input, freeze_placeholder = input_user_data_repack(graph, {'Aa': shape_1, 'Bb': shape_2}, None)
@ -528,34 +528,34 @@ class TestUserDataRepack(unittest.TestCase):
self.assertEqual(freeze_placeholder, None)
def test_input_user_data_repack_dict_with_shapes_and_ports(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
shape_1 = np.array([1, 160, 160, 3])
shape_2 = np.array([1, 127, 127, 3])
input, freeze_placeholder = input_user_data_repack(graph, {'Aa:0': shape_1, 'Bb:1': shape_2}, None)
self.assertDictEqual(input, {'A': [{'shape': shape_1, 'out': 0}], 'B': [{'shape': shape_2, 'out': 1}]})
input, freeze_placeholder = input_user_data_repack(graph, {'Aa:0': shape_1, 'Bb:0': shape_2}, None)
self.assertDictEqual(input, {'A': [{'shape': shape_1, 'out': 0}], 'B': [{'shape': shape_2, 'out': 0}]})
self.assertEqual(freeze_placeholder, None)
def test_freeze_placeholder_and_input(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
shape_1 = np.array([1, 160, 160, 3])
input, freeze_placeholder = input_user_data_repack(graph, {'Aa:0': shape_1}, {'Bb': False})
self.assertDictEqual(input, {'A': [{'shape': shape_1, 'out': 0}], 'B': [{'shape': None, 'port': None}]})
self.assertEqual(freeze_placeholder, {'B': False})
def test_error(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None)
def test_error_2(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
self.assertRaises(Error, input_user_data_repack, graph, np.array([1, 227, 227, 3]), None)
def test_error_3(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
self.assertRaises(Error, input_user_data_repack, graph, ['Bcb'], None)
def test_input_and_freeze(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
shape_1 = np.array([1, 160, 160, 3])
input, freeze_placeholder = input_user_data_repack(graph, shape_1, {'Bb': True})
self.assertDictEqual(input, {'A': [{'shape': shape_1, 'port': None}], 'B': [{'shape': None, 'port': None}]})
@ -563,7 +563,7 @@ class TestUserDataRepack(unittest.TestCase):
def test_freeze_new_placeholder_1(self):
# create a new placeholder Cc:0 by cutting output port with shape_2 = [5] and freeze a value [1.0 1.0 2.0 3.0 5.0]
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
shape_1 = np.array([1, 160, 160, 3])
shape_2 = np.array([5])
input, freeze_placeholder = input_user_data_repack(graph, {'Aa:0': shape_1, 'Cc:0' : shape_2}, {'Bb': False, 'Cc:0' : [1.0, 1.0, 2.0, 3.0, 5.0]})
@ -572,7 +572,7 @@ class TestUserDataRepack(unittest.TestCase):
def test_freeze_new_placeholder_2(self):
# create a new placeholder Ee by cutting input port with shape_2 = [2, 2] and freeze a value [[1.0, 1.0], [2.0, 3.0]]
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
shape_1 = np.array([1, 160, 160, 3])
shape_2 = np.array([2, 2])
input, freeze_placeholder = input_user_data_repack(graph, {'Aa:0': shape_1, 'Ee' : shape_2}, {'Bb': False, 'Ee' : [[1.0, 1.0], [2.0, 3.0]]})
@ -581,22 +581,22 @@ class TestUserDataRepack(unittest.TestCase):
def test_freeze_new_placeholder_error(self):
# shape is not specified for new placeholder Cc:0 with frozen value
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
shape_1 = np.array([1, 160, 160, 3])
self.assertRaises(Error, input_user_data_repack, graph, {'Aa:0': shape_1}, {'Bb': False, 'Cc:0' : [1.0, 1.0, 2.0, 3.0, 5.0]})
def test_output_user_data_repack(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
output = output_user_data_repack(graph, ['Cc'])
self.assertDictEqual(output, {'C': [{'port': None}]})
def test_output_user_data_repack_ports(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
output = output_user_data_repack(graph, ['Cc:1', '0:Cc'])
self.assertDictEqual(output, {'C': [{'out': 1}, {'in': 0}]})
def test_output_user_data_repack_none(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges)
graph = build_graph(self.nodes, self.edges)
output = output_user_data_repack(graph, None)
self.assertEqual(output, None)
@ -629,8 +629,8 @@ class TestExtractPort(unittest.TestCase):
self.assertEqual(port, 0)
def test_in_port2(self):
node_id, direction, port = get_node_id_with_ports(self.graph, '0:1input1:0')
self.assertEqual(node_id, 'input_id')
node_id, direction, port = get_node_id_with_ports(self.graph, '0:relu:0')
self.assertEqual(node_id, 'squeeze_id')
self.assertEqual(direction, 'in')
self.assertEqual(port, 0)
@ -649,6 +649,22 @@ class TestExtractPort(unittest.TestCase):
def test_two_ports(self):
self.assertRaises(Error, get_node_id_with_ports, self.graph, '0:1input1:1')
def test_name_looks_like_port_number(self):
nodes = {
'input_id': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter', 'name': '0'},
'conv_id': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder', 'name': '1'},
'relu_id': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder', 'name': '2'},
}
edges = [
('input_id', 'conv_id'),
('conv_id', 'relu_id'),
]
graph = build_graph(nodes, edges)
node_id, direction, port = get_node_id_with_ports(graph, '0:2')
self.assertEqual(node_id, 'relu_id')
self.assertEqual(direction, 'in')
self.assertEqual(port, 0)
class TestCaffePythonFrontExtractorOp(unittest.TestCase):
def test_get_attrs(self):