Simplify Reshape inserted while converting RNN from mxnet (#13570)

This commit is contained in:
Evgenya Stepyreva 2022-10-21 14:29:39 +04:00 committed by GitHub
parent 5e25341904
commit 1047bb7732
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -174,19 +174,17 @@ class RNNSequenceNormalize(MiddleReplacementPattern):
shape = Shape(graph, dict(name=rnn_layer_name + '/ShapeOf')).create_node()
rnn_layer.in_port(0).get_source().connect(shape.in_port(0))
batch = node_to_get_shape_value_of_indices(shape, int64_array([rnn_layer.batch_dim]))
new_dim = create_op_node_with_second_input(graph, Concat, second_input_value=int64_array([hidden_size]),
op_attrs=dict(name=rnn_layer_name + '/HiddenStateResizeDim',
in_ports_count=2, axis=0), input_node=batch)
reshape_h = Reshape(graph, dict(name=rnn_layer_name + '/HiddenStateResize', override_output_shape=True)).create_node()
new_dim.out_port(0).connect(reshape_h.in_port(1))
reshape_h = create_op_node_with_second_input(graph, Reshape, second_input_value=int64_array([-1, hidden_size]),
op_attrs={'name': rnn_layer_name + '/HiddenStateResize',
'override_output_shape': True})
rnn_layer.in_port(hidden_init_port).get_connection().insert_node(reshape_h)
if rnn_layer.op == 'LSTM':
assert cell_init_port in rnn_layer.in_nodes()
reshape_c = Reshape(graph, dict(name=rnn_layer_name + '/CellStateResize', override_output_shape=True)).create_node()
new_dim.out_port(0).connect(reshape_c.in_port(1))
reshape_c = create_op_node_with_second_input(graph, Reshape,
second_input_value=int64_array([-1, hidden_size]),
op_attrs={'name': rnn_layer_name + '/CellStateResize',
'override_output_shape': True})
rnn_layer.in_port(cell_init_port).get_connection().insert_node(reshape_c)
@staticmethod