Simplify Reshape inserted while converting RNN from mxnet (#13570)
This commit is contained in:
parent
5e25341904
commit
1047bb7732
@ -174,19 +174,17 @@ class RNNSequenceNormalize(MiddleReplacementPattern):
|
||||
shape = Shape(graph, dict(name=rnn_layer_name + '/ShapeOf')).create_node()
|
||||
rnn_layer.in_port(0).get_source().connect(shape.in_port(0))
|
||||
|
||||
batch = node_to_get_shape_value_of_indices(shape, int64_array([rnn_layer.batch_dim]))
|
||||
new_dim = create_op_node_with_second_input(graph, Concat, second_input_value=int64_array([hidden_size]),
|
||||
op_attrs=dict(name=rnn_layer_name + '/HiddenStateResizeDim',
|
||||
in_ports_count=2, axis=0), input_node=batch)
|
||||
reshape_h = Reshape(graph, dict(name=rnn_layer_name + '/HiddenStateResize', override_output_shape=True)).create_node()
|
||||
new_dim.out_port(0).connect(reshape_h.in_port(1))
|
||||
reshape_h = create_op_node_with_second_input(graph, Reshape, second_input_value=int64_array([-1, hidden_size]),
|
||||
op_attrs={'name': rnn_layer_name + '/HiddenStateResize',
|
||||
'override_output_shape': True})
|
||||
rnn_layer.in_port(hidden_init_port).get_connection().insert_node(reshape_h)
|
||||
|
||||
if rnn_layer.op == 'LSTM':
|
||||
assert cell_init_port in rnn_layer.in_nodes()
|
||||
|
||||
reshape_c = Reshape(graph, dict(name=rnn_layer_name + '/CellStateResize', override_output_shape=True)).create_node()
|
||||
new_dim.out_port(0).connect(reshape_c.in_port(1))
|
||||
reshape_c = create_op_node_with_second_input(graph, Reshape,
|
||||
second_input_value=int64_array([-1, hidden_size]),
|
||||
op_attrs={'name': rnn_layer_name + '/CellStateResize',
|
||||
'override_output_shape': True})
|
||||
rnn_layer.in_port(cell_init_port).get_connection().insert_node(reshape_c)
|
||||
|
||||
@staticmethod
|
||||
|
Loading…
Reference in New Issue
Block a user