class paddle.nn.BiRNN(cell_fw, cell_bw, time_major=False)
Wrapper for a bidirectional RNN, which builds a bidirectional RNN from a forward RNN cell and a backward RNN cell. A BiRNN runs a forward RNN and a backward RNN separately, each with its corresponding cell, and concatenates their outputs along the last axis.
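To make the forward/backward/concatenate idea concrete, here is a minimal NumPy sketch. It is a conceptual illustration, not Paddle's implementation: `simple_cell`, `run_rnn`, `bi_rnn`, and `make_params` are hypothetical helpers, and the plain tanh cell stands in for an RNNCellBase such as LSTMCell. The sketch stores each backward output at its original time index, so that concatenating per time step pairs the forward and backward outputs for the same input position.

```python
import numpy as np

def simple_cell(x_t, h, W, U, b):
    # Hypothetical Elman-style cell: h_new = tanh(x_t @ W + h @ U + b).
    # Returns (output, new_state); for this cell both are h_new.
    h_new = np.tanh(x_t @ W + h @ U + b)
    return h_new, h_new

def run_rnn(x, params, hidden_size, reverse=False):
    # x: [batch_size, time_steps, input_size] (batch-major layout).
    batch_size, time_steps, _ = x.shape
    h = np.zeros((batch_size, hidden_size))  # zero initial state
    steps = range(time_steps - 1, -1, -1) if reverse else range(time_steps)
    outputs = [None] * time_steps
    for t in steps:
        out, h = simple_cell(x[:, t], h, *params)
        outputs[t] = out  # keep outputs aligned with the original time index
    return np.stack(outputs, axis=1), h

def bi_rnn(x, params_fw, params_bw, hidden_fw, hidden_bw):
    out_fw, h_fw = run_rnn(x, params_fw, hidden_fw, reverse=False)
    out_bw, h_bw = run_rnn(x, params_bw, hidden_bw, reverse=True)
    # Concatenate along the last axis: size = hidden_fw + hidden_bw.
    return np.concatenate([out_fw, out_bw], axis=-1), (h_fw, h_bw)

def make_params(rng, input_size, hidden_size):
    # Small random weights and a zero bias for the hypothetical cell.
    return (rng.standard_normal((input_size, hidden_size)) * 0.1,
            rng.standard_normal((hidden_size, hidden_size)) * 0.1,
            np.zeros(hidden_size))

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 4))  # batch 2, 5 time steps, input size 4
p_fw = make_params(rng, 4, 3)       # forward hidden size 3
p_bw = make_params(rng, 4, 6)       # backward hidden size 6
outputs, (h_fw, h_bw) = bi_rnn(x, p_fw, p_bw, 3, 6)
print(outputs.shape)  # (2, 5, 9)
```

The forward direction's final state comes from the last time step, while the backward direction's final state comes from the first time step, since that is where the reversed pass ends.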
Parameters:
cell_fw (RNNCellBase) – An RNNCellBase instance used for the forward RNN.
cell_bw (RNNCellBase) – An RNNCellBase instance used for the backward RNN.
time_major (bool, optional) – Whether the first dimension of the input is the time-step dimension. Defaults to False.
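The two layouts selected by time_major differ only in the order of the first two axes. The NumPy sketch below (illustrative sizes borrowed from the LSTMCell example) converts a batch-major array to time-major by swapping those axes.

```python
import numpy as np

batch_size, time_steps, input_size = 2, 23, 16
rng = np.random.default_rng(0)

# time_major=False: [batch_size, time_steps, input_size]
x_batch_major = rng.standard_normal((batch_size, time_steps, input_size)).astype(np.float32)

# time_major=True: swap the first two axes -> [time_steps, batch_size, input_size]
x_time_major = np.transpose(x_batch_major, (1, 0, 2))

print(x_batch_major.shape)  # (2, 23, 16)
print(x_time_major.shape)   # (23, 2, 16)
```

For Paddle tensors, paddle.transpose(x, perm=[1, 0, 2]) performs the same axis swap.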
Inputs:
inputs (Tensor): The input sequences of both RNNs. If time_major is True, the shape is [time_steps, batch_size, input_size]; otherwise it is [batch_size, time_steps, input_size], where input_size is the input size of both cells.
initial_states (list|tuple, optional): A tuple/list containing the initial states of the forward cell and the backward cell. If not provided, cell.get_initial_states is called to produce the initial states for each cell. Defaults to None.
sequence_length (Tensor, optional): The valid lengths of the input sequences, with shape [batch_size] and dtype int64 or int32. If sequence_length is not None, the inputs are treated as padded sequences: in each input sequence, elements whose time step index is not less than the valid length are treated as padding. Defaults to None.
kwargs: Additional keyword arguments passed to the forward method of each cell.
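The padding rule for sequence_length (a time step t is padding when t is not less than the valid length) can be expressed as a boolean mask. This NumPy sketch uses a hypothetical helper, `padding_mask`, only to illustrate that rule; for tensors, paddle.nn.functional.sequence_mask builds a comparable mask.

```python
import numpy as np

def padding_mask(sequence_length, time_steps):
    # True where the time step index t is less than the valid length (real data),
    # False where t >= length (padding).
    t = np.arange(time_steps)[None, :]              # [1, time_steps]
    lengths = np.asarray(sequence_length)[:, None]  # [batch_size, 1]
    return t < lengths                              # [batch_size, time_steps]

mask = padding_mask([3, 5, 1], time_steps=5)
print(mask.astype(int))
# [[1 1 1 0 0]
#  [1 1 1 1 1]
#  [1 0 0 0 0]]
```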
Outputs:
outputs (Tensor): The outputs of the bidirectional RNN, i.e. the concatenation of the outputs of the forward RNN and the backward RNN along the last axis. If time_major is True, the shape is [time_steps, batch_size, size]; otherwise it is [batch_size, time_steps, size], where size is cell_fw.hidden_size + cell_bw.hidden_size.
final_states (tuple): A tuple of the final states of the forward cell and the backward cell.
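The output shape follows directly from the rule above. The small function below, `birnn_output_shape`, is a hypothetical helper (not part of Paddle) that computes the expected shape for both layouts.

```python
def birnn_output_shape(batch_size, time_steps, hidden_fw, hidden_bw, time_major=False):
    # The last axis holds the forward and backward outputs side by side.
    size = hidden_fw + hidden_bw
    if time_major:
        return [time_steps, batch_size, size]
    return [batch_size, time_steps, size]

# Values from the LSTMCell example: batch 2, 23 steps, hidden size 32 per direction.
print(birnn_output_shape(2, 23, 32, 32))                   # [2, 23, 64]
print(birnn_output_shape(2, 23, 32, 32, time_major=True))  # [23, 2, 64]
```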
Note: This class is a low-level API for wrapping RNN cells into a BiRNN network. Users are responsible for managing the states of the cells: if initial_states is passed to the forward method, make sure it satisfies the requirements of the cells.
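As an illustration of what "satisfies the requirements of the cells" can mean, the NumPy sketch below builds initial_states for two LSTM-style cells. It assumes each LSTM-style cell's state is a pair (h, c) with shape [batch_size, hidden_size], as for paddle.nn.LSTMCell; `zero_lstm_states` and `states_match` are hypothetical helpers. Other cell types use different state structures.

```python
import numpy as np

def zero_lstm_states(batch_size, hidden_size, dtype=np.float32):
    # Assumed LSTM-style structure: a pair (h, c), each [batch_size, hidden_size].
    return (np.zeros((batch_size, hidden_size), dtype=dtype),
            np.zeros((batch_size, hidden_size), dtype=dtype))

def states_match(states, batch_size, hidden_size):
    # Hypothetical check: True if states is an (h, c) pair with the expected shapes.
    if not isinstance(states, (tuple, list)) or len(states) != 2:
        return False
    return all(s.shape == (batch_size, hidden_size) for s in states)

batch_size, hidden_fw, hidden_bw = 2, 32, 32
# One entry per direction: (forward-cell states, backward-cell states).
initial_states = (zero_lstm_states(batch_size, hidden_fw),
                  zero_lstm_states(batch_size, hidden_bw))

print(all(states_match(s, batch_size, h)
          for s, h in zip(initial_states, (hidden_fw, hidden_bw))))  # True
```

With Paddle, the same nested structure would be built from tensors (for example with paddle.zeros) and passed as the initial_states argument.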
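Examples:
import paddle
# Forward and backward LSTM cells: input size 16, hidden size 32
cell_fw = paddle.nn.LSTMCell(16, 32)
cell_bw = paddle.nn.LSTMCell(16, 32)
rnn = paddle.nn.BiRNN(cell_fw, cell_bw)
# Batch-major input (time_major=False): [batch_size=2, time_steps=23, input_size=16]
inputs = paddle.rand((2, 23, 16))
outputs, final_states = rnn(inputs)
print(outputs.shape)  # [2, 23, 64]
# final_states = ((h_fw, c_fw), (h_bw, c_bw)); each state is [batch_size, hidden_size]
print(final_states[0][0].shape, len(final_states), len(final_states[0]))  # [2, 32] 2 2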
forward(*inputs, **kwargs): Defines the computation performed at every call. Should be overridden by all subclasses.
*inputs (tuple) – unpacked tuple arguments
**kwargs (dict) – unpacked dict arguments