fix typo + test (#5392)
This commit is contained in:
parent
fa7d67b07f
commit
3028c78594
@ -337,7 +337,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
|
|||||||
in_port = len(Node(graph, node_name).in_nodes())
|
in_port = len(Node(graph, node_name).in_nodes())
|
||||||
|
|
||||||
Node(graph, node_name).add_input_port(in_port)
|
Node(graph, node_name).add_input_port(in_port)
|
||||||
Node(graph, in_node_id).add_output_port(out_port)
|
Node(graph, in_node_id).add_output_port(out_port, skip_if_exist=True)
|
||||||
|
|
||||||
graph.add_edge(in_node_id, node_name, **create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port))
|
graph.add_edge(in_node_id, node_name, **create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port))
|
||||||
elif tokens[0] == b'output-node':
|
elif tokens[0] == b'output-node':
|
||||||
@ -528,7 +528,7 @@ def parse_specifier(string, graph, layer_node_map):
|
|||||||
const_node = Const(graph, {'name': scale_const_name, 'value': float_array([scale_value])}).create_node()
|
const_node = Const(graph, {'name': scale_const_name, 'value': float_array([scale_value])}).create_node()
|
||||||
|
|
||||||
node = Node(graph, node_name)
|
node = Node(graph, node_name)
|
||||||
graph.create_edge(const_node, scale_node, 0, 0, create_edge_attrs(const_node.id, scale_name.id, const_node.id))
|
graph.create_edge(const_node, scale_node, 0, 0, create_edge_attrs(const_node.id, scale_node.id, const_node.id))
|
||||||
out_port = len(node.out_nodes())
|
out_port = len(node.out_nodes())
|
||||||
graph.create_edge(node, scale_node, out_port, 1, create_edge_attrs(node_name, scale_node.id, node_name, 1, out_port))
|
graph.create_edge(node, scale_node, out_port, 1, create_edge_attrs(node_name, scale_node.id, node_name, 1, out_port))
|
||||||
else:
|
else:
|
||||||
|
@ -203,3 +203,31 @@ class TestKaldiModelsLoading(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
(flag, resp) = compare_graphs(graph, ref_graph, 'tdnn1.relu')
|
(flag, resp) = compare_graphs(graph, ref_graph, 'tdnn1.relu')
|
||||||
self.assertTrue(flag, resp)
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
def test_component_map_loading_scale(self):
|
||||||
|
test_map = "input-node name=input dim=16\n" + \
|
||||||
|
"component-node name=lda component=lda input=Scale(0.1, input)\n" + \
|
||||||
|
"\n"
|
||||||
|
graph = Graph(name="test_graph_component_map_loading_scale")
|
||||||
|
|
||||||
|
test_top_map = load_topology_map(io.BytesIO(bytes(test_map, 'ascii')), graph)
|
||||||
|
|
||||||
|
ref_map = {b"lda": ["lda"]}
|
||||||
|
self.assertEqual(test_top_map, ref_map)
|
||||||
|
self.assertTrue("input" in graph.nodes())
|
||||||
|
self.assertListEqual(list(Node(graph, 'input')['shape']), [1, 16])
|
||||||
|
|
||||||
|
ref_graph = build_graph({'input': {'shape': np.array([1, 16]), 'kind': 'op', 'op': 'Parameter'},
|
||||||
|
'lda': {'kind': 'op'},
|
||||||
|
'mul': {'kind': 'op'},
|
||||||
|
'scale_const': {'kind': 'op', 'op': 'Const'},
|
||||||
|
},
|
||||||
|
[
|
||||||
|
('input', 'mul', {'in': 0}),
|
||||||
|
('scale_const', 'mul', {'in': 1}),
|
||||||
|
('mul', 'lda', {'out': 0}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
(flag, resp) = compare_graphs(graph, ref_graph, 'lda')
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
Loading…
Reference in New Issue
Block a user