GatherTree specification refactored (#7326)

* GatherTree specification refactored

* Fix typos
This commit is contained in:
Gabriele Galiero Casay
2021-09-10 05:31:35 +02:00
committed by GitHub
parent b282c74386
commit c862abae03

View File

@@ -2,63 +2,60 @@
**Versioned name**: *GatherTree-1*
**Category**: Beam search post-processing
**Category**: *Data movement*
**Short description**: Generates the complete beams from the ids per each step and the parent beam ids.
**Detailed description**
The GatherTree operation implements the same algorithm as the [GatherTree operation in TensorFlow](https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/gather_tree).
*GatherTree* operation reorders token IDs of a given input tensor `step_id` representing IDs per each step of beam search, based on input tensor `parent_ids` representing the parent beam IDs. For a given beam, past the time step containing the first decoded `end_token` all values are filled in with end_token.
Pseudo code:
The algorithm in pseudocode is as follows:
```python
final_idx[ :, :, :] = end_token
final_ids[ :, :, :] = end_token
for batch in range(BATCH_SIZE):
for beam in range(BEAM_WIDTH):
max_sequence_in_beam = min(MAX_TIME, max_seq_len[batch])
parent = parent_idx[max_sequence_in_beam - 1, batch, beam]
parent = parent_ids[max_sequence_in_beam - 1, batch, beam]
final_idx[max_sequence_in_beam - 1, batch, beam] = step_idx[max_sequence_in_beam - 1, batch, beam]
final_ids[max_sequence_in_beam - 1, batch, beam] = step_ids[max_sequence_in_beam - 1, batch, beam]
for level in reversed(range(max_sequence_in_beam - 1)):
final_idx[level, batch, beam] = step_idx[level, batch, parent]
final_ids[level, batch, beam] = step_ids[level, batch, parent]
parent = parent_idx[level, batch, parent]
parent = parent_ids[level, batch, parent]
# For a given beam, past the time step containing the first decoded end_token
# all values are filled in with end_token.
finished = False
for time in range(max_sequence_in_beam):
if(finished):
final_idx[time, batch, beam] = end_token
elif(final_idx[time, batch, beam] == end_token):
final_ids[time, batch, beam] = end_token
elif(final_ids[time, batch, beam] == end_token):
finished = True
```
Element data types for all input tensors should match each other.
*GatherTree* operation is equivalent to [GatherTree operation in TensorFlow](https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/gather_tree).
**Attributes**: *GatherTree* has no attributes
**Attributes**: *GatherTree* operation has no attributes.
**Inputs**
* **1**: `step_ids` -- a tensor of shape `[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]` of type *T* with indices from per each step. **Required.**
* **2**: `parent_idx` -- a tensor of shape `[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]` of type *T* with parent beam indices. **Required.**
* **3**: `max_seq_len` -- a tensor of shape `[BATCH_SIZE]` of type *T* with maximum lengths for each sequence in the batch. **Required.**
* **4**: `end_token` -- a scalar tensor of type *T* with value of the end marker in a sequence. **Required.**
* **1**: `step_ids` - Indices per each step. A tensor of type *T* and rank 3. Layout is `[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]`. **Required.**
* **2**: `parent_ids` - Parent beam indices. A tensor of type *T* and rank 3. Layout is `[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]`. **Required.**
* **3**: `max_seq_len` - Maximum lengths for each sequence in the batch. A tensor of type *T* and rank 1. Layout is `[BATCH_SIZE]`. **Required.**
* **4**: `end_token` - Value of the end marker in a sequence. A scalar of type *T*. **Required.**
* **Note**: Inputs should have integer values only.
**Outputs**
* **1**: `final_idx` -- a tensor of shape `[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]` of type *T*.
* **1**: `final_ids` - The reordered token IDs based on `parent_ids` input. A tensor of type *T* and rank 3. Layout is `[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]`.
**Types**
* *T*: `float32` or `int32`; `float32` should have integer values only.
* *T*: any supported numeric type.
**Example**