added delete of reversesequences to avoid run of transformation twice
This commit is contained in:
@@ -214,6 +214,8 @@ class ReverseTensorIteratorLSTMWithSqueeze(MiddleReplacementPattern):
|
||||
inverse_reverse = match['inverse_reverse']
|
||||
squeeze = match['squeeze']
|
||||
unsqueeze = match['unsqueeze']
|
||||
squeeze_1 = match['squeeze_1']
|
||||
unsqueeze_1 = match['unsqueeze_1']
|
||||
|
||||
assert direct_reverse.seq_axis == inverse_reverse.seq_axis
|
||||
assert direct_reverse.batch_axis is None and inverse_reverse.batch_axis is None or \
|
||||
@@ -226,7 +228,6 @@ class ReverseTensorIteratorLSTMWithSqueeze(MiddleReplacementPattern):
|
||||
|
||||
direct_reverse.seq_axis = direct_reverse.seq_axis - 1 if direct_reverse.seq_axis > 0 else direct_reverse.seq_axis
|
||||
update_ti(ti, direct_reverse)
|
||||
direct_reverse.seq_axis = direct_reverse.seq_axis + 1 if direct_reverse.seq_axis > 0 else direct_reverse.seq_axis
|
||||
|
||||
# Remove reverses
|
||||
unsqueeze.out_port(0).get_destinations()
|
||||
@@ -235,9 +236,9 @@ class ReverseTensorIteratorLSTMWithSqueeze(MiddleReplacementPattern):
|
||||
dest.get_connection().set_source(in_unsqueeze)
|
||||
unsqueeze.in_port(0).disconnect()
|
||||
|
||||
squeeze_1 = match['squeeze_1']
|
||||
unsqueeze_1 = match['unsqueeze_1']
|
||||
in_unsqueeze_1 = unsqueeze_1.in_port(0).get_source()
|
||||
for dest in squeeze_1.out_port(0).get_destinations():
|
||||
dest.get_connection().set_source(in_unsqueeze_1)
|
||||
unsqueeze_1.in_port(0).disconnect()
|
||||
graph.remove_node(direct_reverse.id)
|
||||
graph.remove_node(inverse_reverse.id)
|
||||
|
||||
Reference in New Issue
Block a user