GatherTree description was extended and outdated link fixed (#2167)
* add more alrifications to description * move clarification to comment * pseudo code become more accurate * review changes
This commit is contained in:
parent
5d59a112d7
commit
43d6bf045b
@ -8,21 +8,33 @@
|
||||
|
||||
**Detailed description**
|
||||
|
||||
GatherTree operation implements the same algorithm as GatherTree operation in TensorFlow. Please see complete documentation [here](https://www.tensorflow.org/versions/r1.12/api_docs/python/tf/contrib/seq2seq/gather_tree?hl=en).
|
||||
The GatherTree operation implements the same algorithm as the [GatherTree operation in TensorFlow](https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/gather_tree).
|
||||
|
||||
Pseudo code:
|
||||
|
||||
```python
|
||||
final_idx[ :, :, :] = end_token
|
||||
for batch in range(BATCH_SIZE):
|
||||
for beam in range(BEAM_WIDTH):
|
||||
max_sequence_in_beam = min(MAX_TIME, max_seq_len[batch])
|
||||
|
||||
parent = parent_idx[max_sequence_in_beam - 1, batch, beam]
|
||||
|
||||
final_idx[max_sequence_in_beam - 1, batch, beam] = step_idx[max_sequence_in_beam - 1, batch, beam]
|
||||
|
||||
for level in reversed(range(max_sequence_in_beam - 1)):
|
||||
final_idx[level, batch, beam] = step_idx[level, batch, parent]
|
||||
|
||||
parent = parent_idx[level, batch, parent]
|
||||
|
||||
# For a given beam, past the time step containing the first decoded end_token
|
||||
# all values are filled in with end_token.
|
||||
finished = False
|
||||
for time in range(max_sequence_in_beam):
|
||||
if(finished):
|
||||
final_idx[time, batch, beam] = end_token
|
||||
elif(final_idx[time, batch, beam] == end_token):
|
||||
finished = True
|
||||
```
|
||||
|
||||
Element data types for all input tensors should match each other.
|
||||
|
Loading…
Reference in New Issue
Block a user