BeamSearchDecoder

class paddle.nn.BeamSearchDecoder(cell, start_token, end_token, beam_size, embedding_fn=None, output_fn=None)

Decoder with a beam search decoding strategy. It wraps a cell to compute probabilities, then performs a beam search step to calculate scores and select candidate token ids at each decoding step.

Please refer to Beam search for more details.

NOTE: When decoding with beam search, the inputs and states of cell are tiled to beam_size (unsqueezed and tiled), resulting in shapes like [batch_size * beam_size, …]. This is built into BeamSearchDecoder and done automatically. Consequently, any other tensor with shape [batch_size, …] used in cell.call must first be tiled manually, which can be done with BeamSearchDecoder.tile_beam_merge_with_batch. The most common case is the encoder output in an attention mechanism.
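A minimal NumPy sketch of the attention case, assuming dot-product attention (any attention form has the same shape problem). The decoder hidden state arrives with the tiled leading dimension batch_size * beam_size, so the encoder output must be repeated per beam before the two can be combined. Here np.repeat along axis 0 stands in for BeamSearchDecoder.tile_beam_merge_with_batch, and attention_step is a hypothetical helper, not a Paddle API.

```python
import numpy as np

def attention_step(hidden, encoder_output, beam_size):
    """Dot-product attention for one decoding step (hypothetical helper).

    hidden:         [batch_size * beam_size, H], already tiled by the decoder
    encoder_output: [batch_size, src_len, H], not tiled yet
    """
    # Repeat each batch entry beam_size times along axis 0:
    # [batch_size, src_len, H] -> [batch_size * beam_size, src_len, H]
    enc_tiled = np.repeat(encoder_output, beam_size, axis=0)
    # Attention scores: [batch_size * beam_size, src_len]
    scores = np.einsum("nh,nth->nt", hidden, enc_tiled)
    # Softmax over source positions (shifted for numerical stability)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    # Context vectors: [batch_size * beam_size, H]
    return np.einsum("nt,nth->nh", weights, enc_tiled)

batch_size, beam_size, src_len, hidden_size = 2, 4, 5, 8
rng = np.random.default_rng(0)
encoder_output = rng.standard_normal((batch_size, src_len, hidden_size))
hidden = rng.standard_normal((batch_size * beam_size, hidden_size))
context = attention_step(hidden, encoder_output, beam_size)
print(context.shape)  # (8, 8)
```

Passing the untiled encoder_output straight into the einsum would raise a shape error, since its leading dimension is batch_size rather than batch_size * beam_size.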

Examples

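import paddle
from paddle.nn import BeamSearchDecoder, dynamic_decode
from paddle.nn import GRUCell, Linear, Embedding
# Target-side embedding (vocabulary of 100) and an output projection to 32 logits
trg_embeder = Embedding(100, 32)
output_layer = Linear(32, 32)
decoder_cell = GRUCell(input_size=32, hidden_size=32)
decoder = BeamSearchDecoder(decoder_cell,
                            start_token=0,
                            end_token=1,
                            beam_size=4,
                            embedding_fn=trg_embeder,
                            output_fn=output_layer)
# Decode for at most 10 steps; the initial GRU states are sized from a
# dummy encoder output with batch_size = 4
encoder_output = paddle.ones((4, 8, 32), dtype=paddle.get_default_dtype())
outputs = dynamic_decode(decoder=decoder,
                         inits=decoder_cell.get_initial_states(encoder_output),
                         max_step_num=10)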

static tile_beam_merge_with_batch(x, beam_size)


Tile the batch dimension of a tensor. Specifically, this function takes a tensor t shaped [batch_size, s0, s1, …] composed of minibatch entries t[0], …, t[batch_size - 1] and tiles it to have a shape [batch_size * beam_size, s0, s1, …] composed of minibatch entries t[0], t[0], …, t[1], t[1], … where each minibatch entry is repeated beam_size times.

Parameters
  • x (Variable) – A tensor with shape [batch_size, …]. The data type should be float32, float64, int32, int64 or bool.

  • beam_size (int) – The beam width used in beam search.

Returns

A tensor with shape [batch_size * beam_size, …], whose data type is the same as that of x.

Return type

Variable
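The tiling can be reproduced in NumPy to make the resulting layout concrete. This sketch follows the unsqueeze-then-tile recipe and then merges the first two axes; tile_beam_merge_with_batch_np is a hypothetical name, and it operates on NumPy arrays rather than Paddle tensors.

```python
import numpy as np

def tile_beam_merge_with_batch_np(x, beam_size):
    """NumPy sketch of the documented tiling (hypothetical name):
    [batch_size, s0, s1, ...] -> [batch_size * beam_size, s0, s1, ...]."""
    # Unsqueeze: [batch_size, ...] -> [batch_size, 1, ...]
    x = np.expand_dims(x, axis=1)
    # Tile along the new axis: [batch_size, beam_size, ...]
    reps = [1, beam_size] + [1] * (x.ndim - 2)
    x = np.tile(x, reps)
    # Merge the first two axes: [batch_size * beam_size, ...]
    return x.reshape((-1,) + x.shape[2:])

t = np.array([[1, 2], [3, 4]])  # batch_size = 2, s0 = 2
tiled = tile_beam_merge_with_batch_np(t, beam_size=3)
print(tiled.shape)  # (6, 2)
print(tiled[:, 0])  # [1 1 1 3 3 3]
```

The result is identical to np.repeat(t, beam_size, axis=0): every batch entry is copied beam_size times in a contiguous block, so row i of the output belongs to batch entry i // beam_size.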

class OutputWrapper(scores, predicted_ids, parent_ids)

The structure of the outputs value returned by decoder.step: a namedtuple with the fields scores, predicted_ids and parent_ids.

class StateWrapper(cell_states, log_probs, finished, lengths)

The structure of the states argument of decoder.step: a namedtuple with the fields cell_states, log_probs, finished and lengths.
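Since both wrappers are namedtuples, their fields can be read by name, unpacked positionally, or updated with _replace. The sketch below defines stand-in namedtuples with the documented field names (the real classes are attributes of BeamSearchDecoder) and uses plain lists as hypothetical placeholder values for tensors.

```python
from collections import namedtuple

# Hypothetical stand-ins mirroring the documented field names
OutputWrapper = namedtuple("OutputWrapper", ["scores", "predicted_ids", "parent_ids"])
StateWrapper = namedtuple("StateWrapper", ["cell_states", "log_probs", "finished", "lengths"])

# Placeholder per-step values for batch_size = 1, beam_size = 2
out = OutputWrapper(scores=[[-0.1, -0.7]], predicted_ids=[[5, 3]], parent_ids=[[0, 0]])
state = StateWrapper(cell_states=None, log_probs=[[-0.1, -0.7]],
                     finished=[[False, False]], lengths=[[1, 1]])

# Fields can be unpacked positionally or read by name
scores, predicted_ids, parent_ids = out
print(out.predicted_ids)  # [[5, 3]]

# _replace returns an updated copy, e.g. after a beam emits end_token
state = state._replace(finished=[[True, False]])
print(state.finished)  # [[True, False]]
```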

initialize(initial_cell_states)


Initialize the BeamSearchDecoder.

Parameters

initial_cell_states (Variable) – A tensor, or a possibly nested structure of tensors, provided by the caller.

Returns

A tuple (initial_inputs, initial_states, finished).

initial_inputs is a tensor t of shape [batch_size, beam_size] filled with start_token when embedding_fn is None, or the value returned by embedding_fn(t) when embedding_fn is provided. initial_states is a nested structure (a namedtuple with the fields cell_states, log_probs, finished and lengths) of tensors, where log_probs, finished and lengths each have shape [batch_size, beam_size] with data types float32, bool and int64 respectively. cell_states has the same structure as the input argument initial_cell_states, but with the tiled shape [batch_size, beam_size, …]. finished is a bool tensor of shape [batch_size, beam_size] filled with False.

Return type

tuple
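The following NumPy sketch shows the shapes and dtypes listed above for batch_size = 2 and beam_size = 3. The initial log_probs values are an assumption not stated in this reference: a common choice is 0 for beam 0 and a very negative number for the other beams, because all beams start from the same start_token and would otherwise yield beam_size identical hypotheses after the first step. beam_initialize_np and expand_to_beam_size are hypothetical helper names.

```python
import numpy as np

def beam_initialize_np(batch_size, beam_size, start_token):
    """Hypothetical NumPy sketch of the non-cell values returned by initialize()."""
    # Token ids before embedding_fn is applied: [batch_size, beam_size]
    initial_inputs = np.full((batch_size, beam_size), start_token, dtype=np.int64)
    # Assumption: only beam 0 starts "alive" (log-prob 0); the others get a very
    # negative score so the first step cannot select duplicate hypotheses.
    row = np.array([0.0] + [-1e9] * (beam_size - 1), dtype=np.float32)
    log_probs = np.tile(row, (batch_size, 1))
    finished = np.zeros((batch_size, beam_size), dtype=bool)
    lengths = np.zeros((batch_size, beam_size), dtype=np.int64)
    return initial_inputs, log_probs, finished, lengths

def expand_to_beam_size(x, beam_size):
    """[batch_size, ...] -> [batch_size, beam_size, ...] (hypothetical helper)."""
    x = np.expand_dims(x, axis=1)
    return np.tile(x, [1, beam_size] + [1] * (x.ndim - 2))

inputs, log_probs, finished, lengths = beam_initialize_np(2, 3, start_token=0)
h0 = np.zeros((2, 16), dtype=np.float32)  # e.g. an initial GRU hidden state
cell_states = expand_to_beam_size(h0, beam_size=3)
print(inputs.shape, log_probs.shape, cell_states.shape)  # (2, 3) (2, 3) (2, 3, 16)
```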

step(time, inputs, states, **kwargs)


Perform a beam search decoding step, which uses cell to get probabilities, and follows a beam search step to calculate scores and select candidate token ids.

Parameters
  • time (Variable) – An int64 tensor with shape [1] provided by the caller, representing the current time step number of decoding.

  • inputs (Variable) – A tensor. It is the same as initial_inputs returned by initialize() for the first decoding step, and next_inputs returned by step() for the others.

  • states (Variable) – A structure of tensors. It is the same as initial_states returned by initialize() for the first decoding step, and beam_search_state returned by step() for the others.

  • **kwargs – Additional keyword arguments, provided by the caller.

Returns

A tuple (beam_search_output, beam_search_state, next_inputs, finished).

beam_search_state and next_inputs have the same structure, shape and data type as the input arguments states and inputs, respectively. beam_search_output is a namedtuple (with the fields scores, predicted_ids and parent_ids) of tensors, each of shape [batch_size, beam_size] with data types float32, int64 and int64 respectively. finished is a bool tensor of shape [batch_size, beam_size].

Return type

tuple
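One step of the scoring and selection described above can be sketched in NumPy, assuming step_log_probs is already the log-softmax of the cell output after output_fn. The handling of finished beams (forcing them to emit end_token at zero extra cost) is a standard technique assumed here, not something this reference specifies; lengths bookkeeping is omitted, and beam_search_step_np is a hypothetical name.

```python
import numpy as np

def beam_search_step_np(step_log_probs, log_probs, finished, end_token, beam_size):
    """Hypothetical NumPy sketch of one beam search step.

    step_log_probs: [batch_size, beam_size, vocab] log-softmax of the cell output
    log_probs:      [batch_size, beam_size] accumulated scores so far
    finished:       [batch_size, beam_size] bool
    """
    batch_size, _, vocab = step_log_probs.shape
    # Assumed: finished beams may only emit end_token, at no extra cost
    noend = np.full(vocab, -1e9, dtype=step_log_probs.dtype)
    noend[end_token] = 0.0
    step_log_probs = np.where(finished[:, :, None], noend, step_log_probs)
    # Add each beam's running score to every candidate token
    scores = log_probs[:, :, None] + step_log_probs
    # Flatten beams x vocab and keep the beam_size best candidates per batch entry
    flat = scores.reshape(batch_size, beam_size * vocab)
    top = np.argsort(-flat, axis=1, kind="stable")[:, :beam_size]
    top_scores = np.take_along_axis(flat, top, axis=1)
    parent_ids = top // vocab     # which beam each candidate extends
    predicted_ids = top % vocab   # which token it appends
    # Inherit the parents' finished flags and mark newly emitted end tokens
    next_finished = np.take_along_axis(finished, parent_ids, axis=1)
    next_finished = next_finished | (predicted_ids == end_token)
    return top_scores, predicted_ids, parent_ids, next_finished

step_lp = np.log(np.array([[[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]]]))
scores, ids, parents, done = beam_search_step_np(
    step_lp, log_probs=np.array([[0.0, -1e9]]), finished=np.array([[False, False]]),
    end_token=1, beam_size=2)
print(ids, parents)  # [[2 1]] [[0 0]]
```

Flattening the beam and vocabulary axes before the top-k lets a strong extension of one beam displace every extension of another, which is why the step reports parent_ids alongside predicted_ids.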

finalize(outputs, final_states, sequence_lengths)


Uses gather_tree to backtrack along the beam search tree and construct the full predicted sequences.

Parameters
  • outputs (Variable) – A structure (namedtuple) of tensors whose structure and data types are the same as output_dtype. Each tensor stacks the outputs of all time steps and thus has shape [time_step, batch_size, …]; the stacking is done by the caller.

  • final_states (Variable) – A structure (namedtuple) of tensors. It is the next_states returned by decoder.step at the last decoding step, and thus has the same structure, shape and data type as states at any time step.

  • sequence_lengths (Variable) – An int64 tensor shaped [batch_size, beam_size]. It contains sequence lengths for each beam determined during decoding.

Returns

A tuple (predicted_ids, final_states).

predicted_ids is an int64 tensor shaped [time_step, batch_size, beam_size]. final_states is the same as the input argument final_states.

Return type

tuple
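The backtracking performed by gather_tree can be sketched in NumPy: starting from the last time step, each beam follows its parent_ids backwards and collects the token chosen at every earlier step. gather_tree_np is a hypothetical reimplementation for illustration, assuming time-major inputs of shape [time_step, batch_size, beam_size].

```python
import numpy as np

def gather_tree_np(ids, parents):
    """NumPy sketch of gather_tree-style backtracking (hypothetical name).

    ids, parents: [time_step, batch_size, beam_size]
    """
    T, batch_size, beam_size = ids.shape
    out = np.empty_like(ids)
    out[-1] = ids[-1]
    # parent[b, k]: beam index, at the step being filled, on the path ending in beam k
    parent = parents[-1].copy()
    batch_idx = np.arange(batch_size)[:, None]
    for t in range(T - 2, -1, -1):
        out[t] = ids[t][batch_idx, parent]
        parent = parents[t][batch_idx, parent]
    return out

ids = np.array([[[2, 2], [6, 1]], [[3, 9], [6, 1]], [[0, 1], [9, 0]]])
parents = np.array([[[0, 0], [1, 1]], [[1, 0], [1, 0]], [[0, 0], [0, 1]]])
result = gather_tree_np(ids, parents)
print(result.shape)  # (3, 2, 2)
```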

property tracks_own_finished

BeamSearchDecoder reorders its beams together with their finished states, which conflicts with the finished-state tracking done by dynamic_decode. This property is therefore set to True so that decoding does not stop early because of mismanaged finished states.

Returns

The Python bool True.

Return type

bool
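A small NumPy sketch of the conflict, using hypothetical values for batch_size = 1 and beam_size = 2: slot 0 holds a finished beam, but at this step both surviving candidates extend the unfinished beam in slot 1. The decoder gathers the finished flags through parent_ids, whereas a caller that kept one flag per slot and ORed it into the new flags (assumed here for illustration) would wrongly keep slot 0 marked as finished.

```python
import numpy as np

end_token = 1
# Slot 0 has finished, slot 1 has not (hypothetical values)
finished = np.array([[True, False]])
# Both selected candidates extend beam 1 and emit non-end tokens
parent_ids = np.array([[1, 1]])
predicted_ids = np.array([[4, 7]])

# Decoder-side tracking: flags follow their parent beams after reordering
decoder_finished = np.take_along_axis(finished, parent_ids, axis=1) | (predicted_ids == end_token)

# Naive per-slot tracking (hypothetical caller) ignores the reordering
naive_finished = finished | (predicted_ids == end_token)

print(decoder_finished)  # [[False False]]
print(naive_finished)    # [[ True False]]
```

Repeated over several steps, such stale flags can mark every slot as finished while live hypotheses remain, which is the early stopping this property prevents.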