GatherTree description was extended and outdated link fixed (#2167)

* add more alrifications to description

* move clarification to comment

* pseudo code become more accurate

* review changes
This commit is contained in:
Svetlana Dolinina 2020-09-14 19:49:29 +03:00 committed by GitHub
parent 5d59a112d7
commit 43d6bf045b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,21 +8,33 @@
**Detailed description**
GatherTree operation implements the same algorithm as GatherTree operation in TensorFlow. Please see complete documentation [here](https://www.tensorflow.org/versions/r1.12/api_docs/python/tf/contrib/seq2seq/gather_tree?hl=en).
The GatherTree operation implements the same algorithm as the [GatherTree operation in TensorFlow](https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/gather_tree).
Pseudo code:
```python
final_idx[ :, :, :] = end_token
for batch in range(BATCH_SIZE):
for beam in range(BEAM_WIDTH):
max_sequence_in_beam = min(MAX_TIME, max_seq_len[batch])
parent = parent_idx[max_sequence_in_beam - 1, batch, beam]
final_idx[max_sequence_in_beam - 1, batch, beam] = step_idx[max_sequence_in_beam - 1, batch, beam]
for level in reversed(range(max_sequence_in_beam - 1)):
final_idx[level, batch, beam] = step_idx[level, batch, parent]
parent = parent_idx[level, batch, parent]
# For a given beam, past the time step containing the first decoded end_token
# all values are filled in with end_token.
finished = False
for time in range(max_sequence_in_beam):
if(finished):
final_idx[time, batch, beam] = end_token
elif(final_idx[time, batch, beam] == end_token):
finished = True
```
Element data types for all input tensors should match each other.