beam_search_decode(ids, scores, beam_size, end_id, name=None)
This operator is used after beam search has completed. It constructs the full predicted sequences for each sample by walking back along the search paths stored in the LoD of ids. The resulting sequences are stored in a LoDTensor, whose LoD is interpreted as follows:
For example, suppose lod = [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]. Each level stores offsets, so the lengths are the differences between consecutive entries.
- The first level means there are 2 samples, each with 3 (the beam width) predicted sequences.
- The second level means the lengths of the first sample's 3 predicted sequences are 12, 12, and 16, and the lengths of the second sample's 3 predicted sequences are 14, 13, and 15.
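The offset arithmetic in the LoD example can be checked with a small pure-Python sketch. It does not call Paddle: lod_offsets_to_lengths and split_by_lod are hypothetical helpers, and they assume the 2-level LoD is given as plain Python lists of offsets and the data as a flat list.

```python
from typing import List, Sequence


def lod_offsets_to_lengths(offsets: Sequence[int]) -> List[int]:
    # Each length is the difference between two consecutive offsets.
    return [end - start for start, end in zip(offsets[:-1], offsets[1:])]


def split_by_lod(flat: Sequence, lod: Sequence[Sequence[int]]) -> List[List[list]]:
    # Level 0 groups sequences into samples; level 1 groups elements into sequences.
    sample_offsets, seq_offsets = lod
    samples = []
    for s in range(len(sample_offsets) - 1):
        seqs = []
        for q in range(sample_offsets[s], sample_offsets[s + 1]):
            seqs.append(list(flat[seq_offsets[q]:seq_offsets[q + 1]]))
        samples.append(seqs)
    return samples


lod = [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
nested = split_by_lod(list(range(82)), lod)
print(lod_offsets_to_lengths(lod[0]))            # [3, 3]
print([[len(seq) for seq in sample] for sample in nested])
```

The same split applies to both returned tensors, since they share one LoD.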
- Please see the following demo for a full beam search usage example:
- Parameters
ids (Variable) – The LoDTensorArray variable containing the selected ids of all steps. Each LoDTensor in it has an int64 data type and a 2-level LoD, which can be used to recover the search paths.
scores (Variable) – The LoDTensorArray variable containing the accumulated scores corresponding to the selected ids of all steps. It has the same size as ids. Each LoDTensor in it has the same shape and LoD as its counterpart in ids, and has a float32 data type.
beam_size (int) – The beam width used in beam search.
end_id (int) – The id of the end token.
name (str, optional) – For detailed information, please refer to Name. Usually there is no need to set name; it is None by default.
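To illustrate the backtracking idea, here is a simplified pure-Python sketch for a single sample. It is not Paddle's implementation: instead of reading parent links from a 2-level LoD, it assumes hypothetical per-step lists of chosen ids, explicit parent beam indices, and accumulated scores, and it treats a hypothesis as finished at its first end_id.

```python
from typing import List, Tuple


def backtrack(step_ids: List[List[int]],
              step_parents: List[List[int]],
              step_scores: List[List[float]],
              end_id: int) -> List[Tuple[List[int], float]]:
    """Rebuild full hypotheses by walking parent pointers backwards.

    step_ids[t][k]     -- token chosen by beam k at step t (hypothetical layout)
    step_parents[t][k] -- index of the beam at step t-1 that beam k extends
    step_scores[t][k]  -- accumulated score of beam k after step t
    """
    last = len(step_ids) - 1
    results = []
    for k in range(len(step_ids[last])):
        path = []
        beam = k
        # Walk from the final step back to step 0, following parent links.
        for t in range(last, -1, -1):
            path.append((step_ids[t][beam], step_scores[t][beam]))
            beam = step_parents[t][beam]
        path.reverse()
        tokens = [tok for tok, _ in path]
        # Cut the hypothesis right after its first end token (kept in the output).
        cut = tokens.index(end_id) + 1 if end_id in tokens else len(tokens)
        results.append((tokens[:cut], path[cut - 1][1]))
    return results


# Beam width 2, three decoding steps, end_id = 0.
step_ids = [[5, 7], [8, 0], [9, 0]]
step_parents = [[0, 0], [1, 0], [0, 1]]
step_scores = [[-0.1, -0.5], [-0.7, -0.3], [-1.0, -0.3]]
print(backtrack(step_ids, step_parents, step_scores, end_id=0))
```

Walking backwards is needed because each step only records which earlier beam a token extends; the full sequence exists only once the chain of parents is followed to step 0.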
- Returns
A tuple of two LoDTensor variables containing the full sequences of ids and the corresponding accumulated scores. Both are flattened to 1-D, have the same shape, and share the same 2-level LoD, which tells how many predicted sequences each sample has and how many ids each predicted sequence has.
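Once the two outputs are available as flat lists together with their shared LoD, a common post-processing step is to keep the best sequence for each sample. The helper below is hypothetical, not part of Paddle; it assumes the last accumulated score of a sequence is that sequence's total score, and that a larger score is better (as with log-probabilities).

```python
from typing import List, Sequence, Tuple


def best_per_sample(ids: Sequence[int],
                    scores: Sequence[float],
                    lod: Sequence[Sequence[int]]) -> List[Tuple[List[int], float]]:
    # ids and scores are the flattened outputs; lod is their shared 2-level LoD.
    sample_offsets, seq_offsets = lod
    best = []
    for s in range(len(sample_offsets) - 1):
        candidates = []
        for q in range(sample_offsets[s], sample_offsets[s + 1]):
            start, end = seq_offsets[q], seq_offsets[q + 1]
            # Assumption: the final accumulated score is the whole sequence's score.
            candidates.append((list(ids[start:end]), scores[end - 1]))
        # Assumption: larger accumulated scores are better.
        best.append(max(candidates, key=lambda c: c[1]))
    return best


ids = [1, 2, 0, 3, 0, 4, 5, 6, 0, 7, 0]
scores = [-0.1, -0.4, -0.6, -0.2, -0.5, -0.1, -0.2, -0.3, -0.9, -0.3, -0.4]
lod = [[0, 2, 4], [0, 3, 5, 9, 11]]  # 2 samples, 2 sequences each
print(best_per_sample(ids, scores, lod))
```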
- Return type
tuple
- Examples
import paddle.fluid as fluid

# Suppose `ids` and `scores` are LoDTensorArray variables holding
# the selected ids and scores of all steps.
ids = fluid.layers.create_array(dtype='int64')
scores = fluid.layers.create_array(dtype='float32')
finished_ids, finished_scores = fluid.layers.beam_search_decode(
    ids, scores, beam_size=5, end_id=0)