fix typo + test (#5392)

This commit is contained in:
Svetlana Dolinina 2021-04-27 13:37:02 +03:00 committed by GitHub
parent fa7d67b07f
commit 3028c78594
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 2 deletions

View File

@ -337,7 +337,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
in_port = len(Node(graph, node_name).in_nodes()) in_port = len(Node(graph, node_name).in_nodes())
Node(graph, node_name).add_input_port(in_port) Node(graph, node_name).add_input_port(in_port)
Node(graph, in_node_id).add_output_port(out_port) Node(graph, in_node_id).add_output_port(out_port, skip_if_exist=True)
graph.add_edge(in_node_id, node_name, **create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port)) graph.add_edge(in_node_id, node_name, **create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port))
elif tokens[0] == b'output-node': elif tokens[0] == b'output-node':
@ -528,7 +528,7 @@ def parse_specifier(string, graph, layer_node_map):
const_node = Const(graph, {'name': scale_const_name, 'value': float_array([scale_value])}).create_node() const_node = Const(graph, {'name': scale_const_name, 'value': float_array([scale_value])}).create_node()
node = Node(graph, node_name) node = Node(graph, node_name)
graph.create_edge(const_node, scale_node, 0, 0, create_edge_attrs(const_node.id, scale_name.id, const_node.id)) graph.create_edge(const_node, scale_node, 0, 0, create_edge_attrs(const_node.id, scale_node.id, const_node.id))
out_port = len(node.out_nodes()) out_port = len(node.out_nodes())
graph.create_edge(node, scale_node, out_port, 1, create_edge_attrs(node_name, scale_node.id, node_name, 1, out_port)) graph.create_edge(node, scale_node, out_port, 1, create_edge_attrs(node_name, scale_node.id, node_name, 1, out_port))
else: else:

View File

@ -203,3 +203,31 @@ class TestKaldiModelsLoading(unittest.TestCase):
) )
(flag, resp) = compare_graphs(graph, ref_graph, 'tdnn1.relu') (flag, resp) = compare_graphs(graph, ref_graph, 'tdnn1.relu')
self.assertTrue(flag, resp) self.assertTrue(flag, resp)
def test_component_map_loading_scale(self):
test_map = "input-node name=input dim=16\n" + \
"component-node name=lda component=lda input=Scale(0.1, input)\n" + \
"\n"
graph = Graph(name="test_graph_component_map_loading_scale")
test_top_map = load_topology_map(io.BytesIO(bytes(test_map, 'ascii')), graph)
ref_map = {b"lda": ["lda"]}
self.assertEqual(test_top_map, ref_map)
self.assertTrue("input" in graph.nodes())
self.assertListEqual(list(Node(graph, 'input')['shape']), [1, 16])
ref_graph = build_graph({'input': {'shape': np.array([1, 16]), 'kind': 'op', 'op': 'Parameter'},
'lda': {'kind': 'op'},
'mul': {'kind': 'op'},
'scale_const': {'kind': 'op', 'op': 'Const'},
},
[
('input', 'mul', {'in': 0}),
('scale_const', 'mul', {'in': 1}),
('mul', 'lda', {'out': 0}),
]
)
(flag, resp) = compare_graphs(graph, ref_graph, 'lda')
self.assertTrue(flag, resp)