added delete of reversesequences to avoid run of transformation twice

This commit is contained in:
sadolini
2021-11-10 10:30:17 +03:00
parent 0a00e621df
commit fcb7de9c9d

View File

@@ -214,6 +214,8 @@ class ReverseTensorIteratorLSTMWithSqueeze(MiddleReplacementPattern):
inverse_reverse = match['inverse_reverse']
squeeze = match['squeeze']
unsqueeze = match['unsqueeze']
squeeze_1 = match['squeeze_1']
unsqueeze_1 = match['unsqueeze_1']
assert direct_reverse.seq_axis == inverse_reverse.seq_axis
assert direct_reverse.batch_axis is None and inverse_reverse.batch_axis is None or \
@@ -226,7 +228,6 @@ class ReverseTensorIteratorLSTMWithSqueeze(MiddleReplacementPattern):
direct_reverse.seq_axis = direct_reverse.seq_axis - 1 if direct_reverse.seq_axis > 0 else direct_reverse.seq_axis
update_ti(ti, direct_reverse)
direct_reverse.seq_axis = direct_reverse.seq_axis + 1 if direct_reverse.seq_axis > 0 else direct_reverse.seq_axis
# Remove reverses
unsqueeze.out_port(0).get_destinations()
@@ -235,9 +236,9 @@ class ReverseTensorIteratorLSTMWithSqueeze(MiddleReplacementPattern):
dest.get_connection().set_source(in_unsqueeze)
unsqueeze.in_port(0).disconnect()
squeeze_1 = match['squeeze_1']
unsqueeze_1 = match['unsqueeze_1']
in_unsqueeze_1 = unsqueeze_1.in_port(0).get_source()
for dest in squeeze_1.out_port(0).get_destinations():
dest.get_connection().set_source(in_unsqueeze_1)
unsqueeze_1.in_port(0).disconnect()
graph.remove_node(direct_reverse.id)
graph.remove_node(inverse_reverse.id)