GatherTree specification refactored (#7326)
* GatherTree specification refactored * Fix typos
This commit is contained in:
committed by
GitHub
parent
b282c74386
commit
c862abae03
@@ -2,63 +2,60 @@
|
||||
|
||||
**Versioned name**: *GatherTree-1*
|
||||
|
||||
**Category**: Beam search post-processing
|
||||
**Category**: *Data movement*
|
||||
|
||||
**Short description**: Generates the complete beams from the ids per each step and the parent beam ids.
|
||||
|
||||
**Detailed description**
|
||||
|
||||
The GatherTree operation implements the same algorithm as the [GatherTree operation in TensorFlow](https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/gather_tree).
|
||||
*GatherTree* operation reorders token IDs of a given input tensor `step_id` representing IDs per each step of beam search, based on input tensor `parent_ids` representing the parent beam IDs. For a given beam, past the time step containing the first decoded `end_token` all values are filled in with end_token.
|
||||
|
||||
Pseudo code:
|
||||
The algorithm in pseudocode is as follows:
|
||||
|
||||
```python
|
||||
final_idx[ :, :, :] = end_token
|
||||
final_ids[ :, :, :] = end_token
|
||||
for batch in range(BATCH_SIZE):
|
||||
for beam in range(BEAM_WIDTH):
|
||||
max_sequence_in_beam = min(MAX_TIME, max_seq_len[batch])
|
||||
|
||||
parent = parent_idx[max_sequence_in_beam - 1, batch, beam]
|
||||
parent = parent_ids[max_sequence_in_beam - 1, batch, beam]
|
||||
|
||||
final_idx[max_sequence_in_beam - 1, batch, beam] = step_idx[max_sequence_in_beam - 1, batch, beam]
|
||||
final_ids[max_sequence_in_beam - 1, batch, beam] = step_ids[max_sequence_in_beam - 1, batch, beam]
|
||||
|
||||
for level in reversed(range(max_sequence_in_beam - 1)):
|
||||
final_idx[level, batch, beam] = step_idx[level, batch, parent]
|
||||
final_ids[level, batch, beam] = step_ids[level, batch, parent]
|
||||
|
||||
parent = parent_idx[level, batch, parent]
|
||||
parent = parent_ids[level, batch, parent]
|
||||
|
||||
# For a given beam, past the time step containing the first decoded end_token
|
||||
# all values are filled in with end_token.
|
||||
finished = False
|
||||
for time in range(max_sequence_in_beam):
|
||||
if(finished):
|
||||
final_idx[time, batch, beam] = end_token
|
||||
elif(final_idx[time, batch, beam] == end_token):
|
||||
final_ids[time, batch, beam] = end_token
|
||||
elif(final_ids[time, batch, beam] == end_token):
|
||||
finished = True
|
||||
```
|
||||
|
||||
Element data types for all input tensors should match each other.
|
||||
*GatherTree* operation is equivalent to [GatherTree operation in TensorFlow](https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/gather_tree).
|
||||
|
||||
**Attributes**: *GatherTree* has no attributes
|
||||
**Attributes**: *GatherTree* operation has no attributes.
|
||||
|
||||
**Inputs**
|
||||
|
||||
* **1**: `step_ids` -- a tensor of shape `[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]` of type *T* with indices from per each step. **Required.**
|
||||
|
||||
* **2**: `parent_idx` -- a tensor of shape `[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]` of type *T* with parent beam indices. **Required.**
|
||||
|
||||
* **3**: `max_seq_len` -- a tensor of shape `[BATCH_SIZE]` of type *T* with maximum lengths for each sequence in the batch. **Required.**
|
||||
|
||||
* **4**: `end_token` -- a scalar tensor of type *T* with value of the end marker in a sequence. **Required.**
|
||||
|
||||
* **1**: `step_ids` - Indices per each step. A tensor of type *T* and rank 3. Layout is `[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]`. **Required.**
|
||||
* **2**: `parent_ids` - Parent beam indices. A tensor of type *T* and rank 3. Layout is `[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]`. **Required.**
|
||||
* **3**: `max_seq_len` - Maximum lengths for each sequence in the batch. A tensor of type *T* and rank 1. Layout is `[BATCH_SIZE]`. **Required.**
|
||||
* **4**: `end_token` - Value of the end marker in a sequence. A scalar of type *T*. **Required.**
|
||||
* **Note**: Inputs should have integer values only.
|
||||
|
||||
**Outputs**
|
||||
|
||||
* **1**: `final_idx` -- a tensor of shape `[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]` of type *T*.
|
||||
* **1**: `final_ids` - The reordered token IDs based on `parent_ids` input. A tensor of type *T* and rank 3. Layout is `[MAX_TIME, BATCH_SIZE, BEAM_WIDTH]`.
|
||||
|
||||
**Types**
|
||||
|
||||
* *T*: `float32` or `int32`; `float32` should have integer values only.
|
||||
* *T*: any supported numeric type.
|
||||
|
||||
**Example**
|
||||
|
||||
|
||||
Reference in New Issue
Block a user