From 43d6bf045bb18dc583718a160a885fa48227eec4 Mon Sep 17 00:00:00 2001
From: Svetlana Dolinina
Date: Mon, 14 Sep 2020 19:49:29 +0300
Subject: [PATCH] GatherTree description was extended and outdated link fixed (#2167)

* add more clarifications to description

* move clarification to comment

* pseudo code become more accurate

* review changes
---
 docs/ops/movement/GatherTree_1.md | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/docs/ops/movement/GatherTree_1.md b/docs/ops/movement/GatherTree_1.md
index 773beeae794..0b81bea248e 100644
--- a/docs/ops/movement/GatherTree_1.md
+++ b/docs/ops/movement/GatherTree_1.md
@@ -8,21 +8,33 @@
 
 **Detailed description**
 
-GatherTree operation implements the same algorithm as GatherTree operation in TensorFlow. Please see complete documentation [here](https://www.tensorflow.org/versions/r1.12/api_docs/python/tf/contrib/seq2seq/gather_tree?hl=en).
+The GatherTree operation implements the same algorithm as the [GatherTree operation in TensorFlow](https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/gather_tree).
 
 Pseudo code:
 
 ```python
+final_idx[ :, :, :] = end_token
 for batch in range(BATCH_SIZE):
     for beam in range(BEAM_WIDTH):
         max_sequence_in_beam = min(MAX_TIME, max_seq_len[batch])
 
         parent = parent_idx[max_sequence_in_beam - 1, batch, beam]
 
+        final_idx[max_sequence_in_beam - 1, batch, beam] = step_idx[max_sequence_in_beam - 1, batch, beam]
+
         for level in reversed(range(max_sequence_in_beam - 1)):
             final_idx[level, batch, beam] = step_idx[level, batch, parent]
 
             parent = parent_idx[level, batch, parent]
+
+        # For a given beam, past the time step containing the first decoded end_token
+        # all values are filled in with end_token.
+        finished = False
+        for time in range(max_sequence_in_beam):
+            if(finished):
+                final_idx[time, batch, beam] = end_token
+            elif(final_idx[time, batch, beam] == end_token):
+                finished = True
 ```
 
 Element data types for all input tensors should match each other.
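
To try the pseudo code from this patch on concrete data, a minimal pure-Python sketch on nested lists might look like the following. The function name `gather_tree_reference`, the nested-list layout `[MAX_TIME][BATCH_SIZE][BEAM_WIDTH]`, and the guard that skips batches with a non-positive sequence length are assumptions made for illustration; they are not part of the patch or of any OpenVINO or TensorFlow API.

```python
def gather_tree_reference(step_idx, parent_idx, max_seq_len, end_token):
    """Hypothetical reference version of the patched pseudo code.

    step_idx and parent_idx are nested lists indexed as [time][batch][beam];
    max_seq_len is a list with one entry per batch.
    """
    max_time = len(step_idx)
    batch_size = len(step_idx[0])
    beam_width = len(step_idx[0][0])

    # final_idx[:, :, :] = end_token
    final_idx = [[[end_token] * beam_width for _ in range(batch_size)]
                 for _ in range(max_time)]

    for batch in range(batch_size):
        max_sequence_in_beam = min(max_time, max_seq_len[batch])
        if max_sequence_in_beam <= 0:
            # Assumption: nothing to decode, leave the beam filled with end_token.
            continue
        last = max_sequence_in_beam - 1

        for beam in range(beam_width):
            # The last valid step is copied as-is, then parents are followed backwards.
            final_idx[last][batch][beam] = step_idx[last][batch][beam]
            parent = parent_idx[last][batch][beam]

            for level in reversed(range(last)):
                final_idx[level][batch][beam] = step_idx[level][batch][parent]
                parent = parent_idx[level][batch][parent]

            # Past the first decoded end_token, every value becomes end_token.
            finished = False
            for time in range(max_sequence_in_beam):
                if finished:
                    final_idx[time][batch][beam] = end_token
                elif final_idx[time][batch][beam] == end_token:
                    finished = True

    return final_idx


# Illustrative input: 4 time steps, 1 batch, 3 beams, only 3 valid steps.
step_idx = [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[-1, -1, -1]]]
parent_idx = [[[0, 0, 0]], [[0, 1, 1]], [[2, 1, 2]], [[-1, -1, -1]]]
print(gather_tree_reference(step_idx, parent_idx, [3], 10))
# prints [[[2, 2, 2]], [[6, 5, 6]], [[7, 8, 9]], [[10, 10, 10]]]
```

The time step beyond `max_seq_len` stays at `end_token` because of the initial fill, which is the behavior the first added line of the patch introduces.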