fix typo + test (#5392)
This commit is contained in:
parent
fa7d67b07f
commit
3028c78594
@ -337,7 +337,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
|
||||
in_port = len(Node(graph, node_name).in_nodes())
|
||||
|
||||
Node(graph, node_name).add_input_port(in_port)
|
||||
Node(graph, in_node_id).add_output_port(out_port)
|
||||
Node(graph, in_node_id).add_output_port(out_port, skip_if_exist=True)
|
||||
|
||||
graph.add_edge(in_node_id, node_name, **create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port))
|
||||
elif tokens[0] == b'output-node':
|
||||
@ -528,7 +528,7 @@ def parse_specifier(string, graph, layer_node_map):
|
||||
const_node = Const(graph, {'name': scale_const_name, 'value': float_array([scale_value])}).create_node()
|
||||
|
||||
node = Node(graph, node_name)
|
||||
graph.create_edge(const_node, scale_node, 0, 0, create_edge_attrs(const_node.id, scale_name.id, const_node.id))
|
||||
graph.create_edge(const_node, scale_node, 0, 0, create_edge_attrs(const_node.id, scale_node.id, const_node.id))
|
||||
out_port = len(node.out_nodes())
|
||||
graph.create_edge(node, scale_node, out_port, 1, create_edge_attrs(node_name, scale_node.id, node_name, 1, out_port))
|
||||
else:
|
||||
|
@ -203,3 +203,31 @@ class TestKaldiModelsLoading(unittest.TestCase):
|
||||
)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'tdnn1.relu')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_component_map_loading_scale(self):
|
||||
test_map = "input-node name=input dim=16\n" + \
|
||||
"component-node name=lda component=lda input=Scale(0.1, input)\n" + \
|
||||
"\n"
|
||||
graph = Graph(name="test_graph_component_map_loading_scale")
|
||||
|
||||
test_top_map = load_topology_map(io.BytesIO(bytes(test_map, 'ascii')), graph)
|
||||
|
||||
ref_map = {b"lda": ["lda"]}
|
||||
self.assertEqual(test_top_map, ref_map)
|
||||
self.assertTrue("input" in graph.nodes())
|
||||
self.assertListEqual(list(Node(graph, 'input')['shape']), [1, 16])
|
||||
|
||||
ref_graph = build_graph({'input': {'shape': np.array([1, 16]), 'kind': 'op', 'op': 'Parameter'},
|
||||
'lda': {'kind': 'op'},
|
||||
'mul': {'kind': 'op'},
|
||||
'scale_const': {'kind': 'op', 'op': 'Const'},
|
||||
},
|
||||
[
|
||||
('input', 'mul', {'in': 0}),
|
||||
('scale_const', 'mul', {'in': 1}),
|
||||
('mul', 'lda', {'out': 0}),
|
||||
]
|
||||
)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'lda')
|
||||
self.assertTrue(flag, resp)
|
||||
|
Loading…
Reference in New Issue
Block a user