TripletMarginWithDistanceLoss

class paddle.nn.TripletMarginWithDistanceLoss(distance_function=None, margin=1.0, swap=False, reduction: str = 'mean', name=None)

Creates a criterion that measures the triplet loss given three input tensors \(x1\), \(x2\), \(x3\) and a margin with a value greater than \(0\). It is used to measure the relative similarity between samples. A triplet is composed of input, positive, and negative (i.e., the input (anchor) examples, positive examples, and negative examples, respectively). All input tensors should have shape \((N, D)\).

The loss function for each sample in the mini-batch is:

\[L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\}\]

where the default distance_function is the pairwise L2 distance

\[d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_2\]

Users can also supply their own distance function. margin is a nonnegative value representing the minimum difference between the positive and negative distances that is required for the loss to be 0. If swap is True, the distance between positive and negative is also computed and compared with the distance between input and negative; the smaller of the two is then used as the negative distance. For more details, see http://www.bmva.org/bmvc/2016/papers/paper119/paper119.pdf.
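The per-sample formula and the swap rule can be checked by hand with a short framework-free sketch. This is not Paddle's implementation: the helper names `l2_distance` and `triplet_loss_per_sample` are hypothetical, and the plain Euclidean distance used here ignores any small numerical offset a library may add, so results should only be expected to agree with the Examples section to within about 1e-5.

```python
import math


def l2_distance(x, y):
    """Plain Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def triplet_loss_per_sample(anchor, positive, negative,
                            margin=1.0, swap=False, distance=l2_distance):
    """max(d(anchor, pos) - d(anchor, neg) + margin, 0) for one triplet."""
    d_pos = distance(anchor, positive)
    d_neg = distance(anchor, negative)
    if swap:
        # Distance swap: use d(positive, negative) when it is the smaller one.
        d_neg = min(d_neg, distance(positive, negative))
    return max(d_pos - d_neg + margin, 0.0)


# Same values as in the Examples section of this page.
anchors = [[1, 5, 3], [0, 3, 2], [1, 4, 1]]
positives = [[5, 1, 2], [3, 2, 1], [3, -1, 1]]
negatives = [[2, 1, -3], [1, 1, -1], [4, -2, 1]]

losses = [triplet_loss_per_sample(a, p, n)
          for a, p, n in zip(anchors, positives, negatives)]
print(losses)  # approximately [0.0, 0.575, 0.0]

swapped = [triplet_loss_per_sample(a, p, n, swap=True)
           for a, p, n in zip(anchors, positives, negatives)]
print(swapped)  # each value is >= the corresponding value without swap
```

Because the swap can only shrink the negative distance, enabling it never lowers the loss for a given triplet.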

Parameters
  • distance_function (Callable, optional) – Quantifies the distance between two tensors. If not specified, the pairwise L2 norm distance is used. Default: None.

  • margin (float, optional) – A nonnegative margin representing the minimum difference between the positive and negative distances required for the loss to be 0. Larger margins penalize cases where the negative examples are not distant enough from the anchors, relative to the positives. Default: \(1\).

  • swap (bool, optional) – Whether to apply the distance swap. If True, the swap distance (the distance between the positive and negative samples) replaces the negative distance whenever the swap distance is smaller than the negative distance. Default: False.

  • reduction (str, optional) – Specifies the reduction applied to the per-sample loss. The candidates are 'none' | 'mean' | 'sum'. If reduction is 'none', the unreduced loss is returned; if reduction is 'mean', the mean of the loss is returned; if reduction is 'sum', the summed loss is returned. Default: 'mean'.

  • name (str, optional) – Name for the operation. Default: None. For more information, please refer to Name.
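How distance_function, margin, swap, and reduction interact can be illustrated with a small pure-Python sketch. It is a reference model under stated assumptions, not the library code: `cosine_distance` and `triplet_loss` are hypothetical helpers, and the cosine distance is just one possible choice of custom distance.

```python
import math


def cosine_distance(x, y):
    """Hypothetical custom distance: 1 - cosine similarity (non-zero vectors only)."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return 1.0 - dot / (norm_x * norm_y)


def triplet_loss(anchors, positives, negatives, distance,
                 margin=1.0, swap=False, reduction="mean"):
    """Batch triplet loss with a user-supplied distance and a reduction mode."""
    per_sample = []
    for a, p, n in zip(anchors, positives, negatives):
        d_pos = distance(a, p)
        d_neg = distance(a, n)
        if swap:
            d_neg = min(d_neg, distance(p, n))
        per_sample.append(max(d_pos - d_neg + margin, 0.0))

    if reduction == "none":
        return per_sample                         # one loss value per triplet
    if reduction == "sum":
        return sum(per_sample)                    # single summed value
    if reduction == "mean":
        return sum(per_sample) / len(per_sample)  # single averaged value
    raise ValueError("reduction must be 'none', 'mean', or 'sum'")


anchors = [[1.0, 0.0], [0.0, 1.0]]
positives = [[1.0, 0.0], [0.0, 2.0]]
negatives = [[0.0, 1.0], [0.0, 3.0]]

per_sample = triplet_loss(anchors, positives, negatives, cosine_distance,
                          margin=0.5, reduction="none")
total = triplet_loss(anchors, positives, negatives, cosine_distance,
                     margin=0.5, reduction="sum")
mean = triplet_loss(anchors, positives, negatives, cosine_distance,
                    margin=0.5, reduction="mean")
print(per_sample, total, mean)  # [0.0, 0.5] 0.5 0.25
```

In the second triplet the negative points in the same direction as the anchor, so its cosine distance is 0 and the loss equals the margin. A magnitude-sensitive distance such as the default L2 distance would score that triplet differently.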

Shapes:
  • input (Tensor): Input tensor, the data type is float32 or float64. The shape is [N, *], where N is the batch size and * means any number of additional dimensions.

  • positive (Tensor): Positive tensor, the data type is float32 or float64. The shape of positive is the same as the shape of input.

  • negative (Tensor): Negative tensor, the data type is float32 or float64. The shape of negative is the same as the shape of input.

  • output (Tensor): The tensor storing the triplet_margin_with_distance_loss of input, positive, and negative.

Returns

A callable object of TripletMarginWithDistanceLoss

Examples

>>> import paddle
>>> from paddle.nn import TripletMarginWithDistanceLoss

>>> input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
>>> positive = paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
>>> negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
>>> triplet_margin_with_distance_loss = TripletMarginWithDistanceLoss(reduction='none')
>>> loss = triplet_margin_with_distance_loss(input, positive, negative)
>>> print(loss)
Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.        , 0.57496595, 0.        ])

>>> triplet_margin_with_distance_loss = TripletMarginWithDistanceLoss(reduction='mean')
>>> loss = triplet_margin_with_distance_loss(input, positive, negative)
>>> print(loss)
Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
0.19165532)

forward(input, positive, negative)

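Defines the computation performed at every call: computes the triplet margin loss of input, positive, and negative using the configured distance function.

Parameters
  • input (Tensor) – Input tensor; see Shapes above.

  • positive (Tensor) – Positive tensor, with the same shape as input.

  • negative (Tensor) – Negative tensor, with the same shape as input.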