RNNTLoss

class paddle.nn.RNNTLoss(blank=0, fastemit_lambda=0.001, reduction='mean', name=None)
Parameters
  • blank (int, optional) – index of the blank label. Default: 0.

  • fastemit_lambda (float, optional) – regularization parameter for FastEmit (https://arxiv.org/pdf/2010.11148.pdf). Default: 0.001.

  • reduction (string, optional) – specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction is applied; ‘mean’: the output losses are divided by the target lengths and then the mean over the batch is taken; ‘sum’: the output losses are summed. Default: ‘mean’.

  • name (str, optional) – name for the operation. Default: None.

Shape:

  • input: logprob Tensor of shape (batch, seqLength, maxLabelLength + 1, outputDim) containing the output from the network.
  • label: 2-D Tensor of shape (batch, maxLabelLength) containing all the targets of the batch, zero-padded.
  • input_lengths: Tensor of shape (batch) containing the length of each output sequence from the network.
  • label_lengths: Tensor of shape (batch) containing the label length of each example.
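The label and label_lengths inputs are usually built from variable-length target lists. The sketch below shows one way to do the zero-padding in plain Python; the helper name pad_labels is hypothetical (not part of Paddle), and in practice the resulting lists would be wrapped with paddle.to_tensor(..., paddle.int32) as in the example further down. Padded positions are ignored because label_lengths records each true length.

```python
def pad_labels(label_seqs, pad_id=0):
    """Hypothetical helper: zero-pad label lists into a (batch, max_len) grid.

    Returns the padded grid and the per-example label lengths.
    """
    max_len = max((len(seq) for seq in label_seqs), default=0)
    # Right-pad every sequence with pad_id up to the longest one.
    padded = [list(seq) + [pad_id] * (max_len - len(seq)) for seq in label_seqs]
    lengths = [len(seq) for seq in label_seqs]
    return padded, lengths


labels, label_lengths = pad_labels([[1, 2], [3]])
print(labels)         # [[1, 2], [3, 0]]
print(label_lengths)  # [2, 1]
# The third dimension of `input` would then be max(label_lengths) + 1 = 3.
```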

Returns

Tensor, the RNN-T loss between logprobs and labels. If reduction is 'none', the shape of loss is [batch_size]; otherwise, the shape of loss is []. The data type is the same as that of logprobs.

Return type

Tensor

Examples

>>> # dynamic graph (imperative) mode
>>> import numpy as np
>>> import paddle
>>> from paddle.nn import RNNTLoss

>>> fn = RNNTLoss(reduction='sum', fastemit_lambda=0.0)

>>> acts = np.array([[[[0.1, 0.6, 0.1, 0.1, 0.1],
...                    [0.1, 0.1, 0.6, 0.1, 0.1],
...                    [0.1, 0.1, 0.2, 0.8, 0.1]],
...                   [[0.1, 0.6, 0.1, 0.1, 0.1],
...                    [0.1, 0.1, 0.2, 0.1, 0.1],
...                    [0.7, 0.1, 0.2, 0.1, 0.1]]]])
>>> labels = [[1, 2]]

>>> acts = paddle.to_tensor(acts, stop_gradient=False)

>>> lengths = [acts.shape[1]] * acts.shape[0]
>>> label_lengths = [len(l) for l in labels]
>>> labels = paddle.to_tensor(labels, paddle.int32)
>>> lengths = paddle.to_tensor(lengths, paddle.int32)
>>> label_lengths = paddle.to_tensor(label_lengths, paddle.int32)

>>> costs = fn(acts, labels, lengths, label_lengths)
>>> print(costs)
Tensor(shape=[], dtype=float64, place=Place(cpu), stop_gradient=False,
-2.85042444)
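The printed value can be checked by hand with the standard RNN-T forward (alpha) recursion over the (seqLength, maxLabelLength + 1) lattice. The pure-Python sketch below is a reference under stated assumptions, not Paddle's implementation: the names log_sum_exp, rnnt_nll and rnnt_loss_batch are hypothetical, it computes only the loss value (no gradients), it has no FastEmit term (the example uses fastemit_lambda=0.0), and it uses the entries of acts directly as log-probabilities without normalizing them. Under those assumptions it reproduces the value above; because the example's activations are not normalized, that value is negative.

```python
import math


def log_sum_exp(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_nll(logprobs, labels, blank=0):
    """Negative log-likelihood of one label sequence (hypothetical helper).

    logprobs: nested list of shape (T, U + 1, V), used as log-probabilities as-is.
    labels:   list of U target ids, without padding.
    """
    T, U = len(logprobs), len(labels)
    # alpha[t][u]: log-probability of reaching frame t having emitted u labels.
    alpha = [[-math.inf] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # Blank emitted at (t - 1, u): advance one frame.
            from_blank = alpha[t - 1][u] + logprobs[t - 1][u][blank] if t > 0 else -math.inf
            # Label u emitted at (t, u - 1): advance one label.
            from_label = alpha[t][u - 1] + logprobs[t][u - 1][labels[u - 1]] if u > 0 else -math.inf
            alpha[t][u] = log_sum_exp(from_blank, from_label)
    # A final blank at the last node terminates every alignment.
    return -(alpha[T - 1][U] + logprobs[T - 1][U][blank])


def rnnt_loss_batch(acts, labels, input_lengths, label_lengths, blank=0):
    """Per-example losses, analogous to the shape [batch_size] result of reduction='none'."""
    losses = []
    for b in range(len(acts)):
        T, U = input_lengths[b], label_lengths[b]
        # Crop away time and label padding before running the recursion.
        lattice = [row[: U + 1] for row in acts[b][:T]]
        losses.append(rnnt_nll(lattice, labels[b][:U], blank))
    return losses


acts = [[[[0.1, 0.6, 0.1, 0.1, 0.1],
          [0.1, 0.1, 0.6, 0.1, 0.1],
          [0.1, 0.1, 0.2, 0.8, 0.1]],
         [[0.1, 0.6, 0.1, 0.1, 0.1],
          [0.1, 0.1, 0.2, 0.1, 0.1],
          [0.7, 0.1, 0.2, 0.1, 0.1]]]]
losses = rnnt_loss_batch(acts, [[1, 2]], [2], [2])
print(sum(losses))  # about -2.850424, the reduction='sum' value printed above
```

With seqLength = 2 and two labels there are only three alignments, with summed scores 2.0, 1.6 and 1.6, so the loss is -log(e^2.0 + 2e^1.6), roughly -2.850424.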
forward(input, label, input_lengths, label_lengths)

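Computes the RNN-T loss for a batch. It is invoked when the layer instance is called, e.g. fn(acts, labels, lengths, label_lengths) in the example above.

Parameters
  • input (Tensor) – logprob Tensor of shape (batch, seqLength, maxLabelLength + 1, outputDim).

  • label (Tensor) – zero-padded targets of shape (batch, maxLabelLength).

  • input_lengths (Tensor) – length of each output sequence from the network, shape (batch).

  • label_lengths (Tensor) – length of each label sequence, shape (batch).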