- class paddle.nn.RNN(cell, is_reverse=False, time_major=False)
Wrapper for RNN, which creates a recurrent neural network from an RNN cell. It calls cell.forward() repeatedly until it reaches the maximum length of the inputs.
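Conceptually, the wrapper unrolls the cell over the time axis, one call per step. The NumPy sketch below mimics that loop for batch-major inputs, using the shapes from the example on this page (batch 4, 23 steps, input size 16, hidden size 32). `simple_rnn_cell` and `run_rnn` are hypothetical names, and the tanh update is an illustrative cell, not Paddle's internal implementation.

```python
import numpy as np

def simple_rnn_cell(x_t, h, w_ih, w_hh, b):
    """One step of a hypothetical tanh RNN cell; returns (output, new_state)."""
    h_new = np.tanh(x_t @ w_ih + h @ w_hh + b)
    return h_new, h_new  # for a simple RNN the output equals the new state

def run_rnn(inputs, h0, w_ih, w_hh, b):
    """Call the cell once per time step; inputs are [batch, time, input_size]."""
    h = h0
    step_outputs = []
    for t in range(inputs.shape[1]):
        out, h = simple_rnn_cell(inputs[:, t, :], h, w_ih, w_hh, b)
        step_outputs.append(out)
    # Stack along the time axis to get [batch, time, hidden_size].
    return np.stack(step_outputs, axis=1), h

rng = np.random.default_rng(0)
batch, steps, input_size, hidden_size = 4, 23, 16, 32
inputs = rng.random((batch, steps, input_size))
h0 = rng.standard_normal((batch, hidden_size))
w_ih = rng.standard_normal((input_size, hidden_size)) * 0.1
w_hh = rng.standard_normal((hidden_size, hidden_size)) * 0.1
b = np.zeros(hidden_size)

outputs, final_state = run_rnn(inputs, h0, w_ih, w_hh, b)
print(outputs.shape)      # (4, 23, 32)
print(final_state.shape)  # (4, 32)
```

The returned state is simply the state after the last call, which for this simple cell coincides with the last time step of the outputs.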
- Parameters
cell (RNNCellBase) – An instance of RNNCellBase.
is_reverse (bool, optional) – Whether to process the input sequences in reverse order. Defaults to False.
time_major (bool, optional) – Whether the first dimension of the input represents the time steps. Defaults to False.
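The effect of is_reverse and time_major on the loop can be sketched as follows. This is a minimal NumPy illustration, not Paddle's code: `tanh_cell` and `run_rnn` are hypothetical, batch-major inputs are transposed to time-major before the loop, and a reversed run visits the last step first. It also assumes that each output is stored at the index of the input step that produced it, so outputs[:, t] stays aligned with inputs[:, t] even when is_reverse is True.

```python
import numpy as np

def tanh_cell(x_t, h, w_ih, w_hh):
    """Hypothetical tanh RNN step; returns (output, new_state)."""
    h_new = np.tanh(x_t @ w_ih + h @ w_hh)
    return h_new, h_new

def run_rnn(inputs, h0, w_ih, w_hh, is_reverse=False, time_major=False):
    """Sketch of how is_reverse and time_major change the unrolled loop."""
    if not time_major:
        inputs = inputs.transpose(1, 0, 2)  # -> [time, batch, input_size]
    order = range(inputs.shape[0])
    if is_reverse:
        order = reversed(order)  # visit the last time step first
    h = h0
    outputs = [None] * inputs.shape[0]
    for t in order:
        # Assumed alignment: the output for step t is stored at index t.
        outputs[t], h = tanh_cell(inputs[t], h, w_ih, w_hh)
    outputs = np.stack(outputs, axis=0)  # [time, batch, hidden]
    if not time_major:
        outputs = outputs.transpose(1, 0, 2)  # back to [batch, time, hidden]
    return outputs, h

rng = np.random.default_rng(1)
x = rng.random((2, 5, 3))  # [batch, time, input_size]
h0 = np.zeros((2, 4))
w_ih = rng.standard_normal((3, 4))
w_hh = rng.standard_normal((4, 4))

out_bm, h_bm = run_rnn(x, h0, w_ih, w_hh)
out_tm, h_tm = run_rnn(x.transpose(1, 0, 2), h0, w_ih, w_hh, time_major=True)
out_rev, h_rev = run_rnn(x, h0, w_ih, w_hh, is_reverse=True)
print(out_tm.shape)  # (5, 2, 4)
```

Under these assumptions, time_major only changes the layout of the same numbers, while a reversed run equals a forward run over the time-flipped input with its outputs flipped back.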
- Inputs
inputs (Tensor): The input sequences, given as a tensor or a possibly nested structure of tensors. If time_major is False, the shape is [batch_size, time_steps, input_size]; if time_major is True, it is [time_steps, batch_size, input_size], where input_size is the input size of the cell.
initial_states (Tensor|list|tuple, optional): A tensor or a possibly nested structure of tensors representing the initial state of the RNN cell. If not provided, cell.get_initial_states is called to produce the initial states. Defaults to None.
sequence_length (Tensor, optional): The valid lengths of the input sequences, with shape [batch_size] and dtype int64 or int32. Defaults to None. If sequence_length is not None, the inputs are treated as padded sequences: in each input sequence, elements whose time step index is greater than or equal to the valid length are treated as padding.
kwargs: Additional keyword arguments passed to the cell's forward method.
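A minimal NumPy sketch of the default initial state and of padding via sequence_length follows. The names `tanh_cell` and `run_rnn` are hypothetical. Two assumptions are made: the default initial state is zero-filled (standing in for cell.get_initial_states), and padding is honored by carrying a row's previous state forward once the step index reaches that row's valid length. The sketch masks only the state; outputs at padded steps are still computed and should be ignored by the caller.

```python
import numpy as np

def tanh_cell(x_t, h, w_ih, w_hh):
    """Hypothetical tanh RNN step; returns (output, new_state)."""
    h_new = np.tanh(x_t @ w_ih + h @ w_hh)
    return h_new, h_new

def run_rnn(inputs, w_ih, w_hh, initial_states=None, sequence_length=None):
    """Batch-major sketch with a zero default state and length masking."""
    batch, steps, _ = inputs.shape
    hidden = w_hh.shape[0]
    # Assumed stand-in for cell.get_initial_states: zeros of shape [batch, hidden].
    h = np.zeros((batch, hidden)) if initial_states is None else initial_states
    outputs = []
    for t in range(steps):
        out, h_new = tanh_cell(inputs[:, t, :], h, w_ih, w_hh)
        if sequence_length is not None:
            # Rows where t >= valid length are padding: keep their previous state.
            valid = (t < sequence_length)[:, None]
            h_new = np.where(valid, h_new, h)
        h = h_new
        outputs.append(out)
    return np.stack(outputs, axis=1), h

rng = np.random.default_rng(2)
x = rng.random((3, 6, 2))       # 3 padded sequences of up to 6 steps
lengths = np.array([6, 4, 1])   # hypothetical valid lengths per row
w_ih = rng.standard_normal((2, 4))
w_hh = rng.standard_normal((4, 4))
_, final = run_rnn(x, w_ih, w_hh, sequence_length=lengths)
```

With this masking, each row's final state matches the state obtained by running only its unpadded prefix.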
- Outputs
outputs (Tensor|list|tuple): The output sequences. If time_major is True, the shape is [time_steps, batch_size, hidden_size]; otherwise it is [batch_size, time_steps, hidden_size].
final_states (Tensor|list|tuple): The final states of the cell, as a tensor or a possibly nested structure of tensors with the same structure as the initial states. Each tensor in the final states has the same shape and dtype as the corresponding tensor in the initial states.
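When the cell's state is a nested structure, such as an LSTM's (h, c) pair, the final states keep that structure. The NumPy sketch below uses a hypothetical LSTM-style cell `lstm_like_cell` and a hypothetical `run_rnn` loop; it also forwards extra keyword arguments (here the weight matrix `w`) to every cell call, mirroring the role of kwargs. The gate layout is an illustrative assumption, not Paddle's LSTMCell.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_like_cell(x_t, state, w):
    """Hypothetical LSTM-style step whose state is a tuple (h, c)."""
    h, c = state
    gates = np.concatenate([x_t, h], axis=1) @ w  # [batch, 4 * hidden]
    i, f, g, o = np.split(gates, 4, axis=1)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, (h_new, c_new)  # (output, new nested state)

def run_rnn(cell, inputs, initial_states, **cell_kwargs):
    """Unroll `cell`; the state structure is passed through unchanged."""
    states = initial_states
    outputs = []
    for t in range(inputs.shape[1]):
        out, states = cell(inputs[:, t, :], states, **cell_kwargs)
        outputs.append(out)
    return np.stack(outputs, axis=1), states

rng = np.random.default_rng(3)
batch, steps, input_size, hidden = 2, 7, 5, 3
x = rng.random((batch, steps, input_size))
init = (np.zeros((batch, hidden)), np.zeros((batch, hidden)))
w = rng.standard_normal((input_size + hidden, 4 * hidden))

outputs, final_states = run_rnn(lstm_like_cell, x, init, w=w)
```

Here final_states is a tuple of two [batch, hidden] arrays, matching the structure, shapes, and dtype of init.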
This class is a low-level API for wrapping an RNN cell into an RNN network. Users are responsible for managing the cell's state: if initial_states is passed to the forward method, make sure it satisfies the requirements of the cell.
import paddle

inputs = paddle.rand((4, 23, 16))
prev_h = paddle.randn((4, 32))

cell = paddle.nn.SimpleRNNCell(16, 32)
rnn = paddle.nn.RNN(cell)
outputs, final_states = rnn(inputs, prev_h)

print(outputs.shape)       # [4, 23, 32]
print(final_states.shape)  # [4, 32]
- forward(inputs, initial_states=None, sequence_length=None, **kwargs)
Defines the computation performed at every call. Should be overridden by all subclasses.
*inputs (tuple) – Unpacked tuple arguments.
**kwargs (dict) – Unpacked dict arguments.