margin_cross_entropy
- paddle.nn.functional.margin_cross_entropy(logits, label, margin1=1.0, margin2=0.5, margin3=0.0, scale=64.0, group=None, return_softmax=False, reduction='mean')
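Computes the margin softmax cross-entropy loss (shown here for reduction='mean'):

\[L = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{e^{s(\cos(m_1\theta_{y_i}+m_2)-m_3)}}{e^{s(\cos(m_1\theta_{y_i}+m_2)-m_3)}+\sum_{j=1,\,j\neq y_i}^{n}e^{s\cos\theta_j}}\]

where \(\theta_{y_i}\) is the angle between the feature \(x_i\) and the weight \(w_{y_i}\) of its target class, \(\theta_j\) is the angle between \(x_i\) and the class weight \(w_j\), and \(n\) is the total number of classes. With the default margins (\(m_1=1.0\), \(m_2=0.5\), \(m_3=0.0\)) this is the ArcFace loss; for details see https://arxiv.org/abs/1801.07698 .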
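To make the formula concrete, here is a minimal plain-Python sketch of the per-sample loss (the reduction='none' case) on a single card. `margin_ce_reference` is a hypothetical helper written for illustration, not part of Paddle. It assumes the logits are already cosines in [-1, 1] and applies \(s(\cos(m_1\theta_{y_i}+m_2)-m_3)\) to the target class exactly as written, with no extra thresholding. Fed the logits and labels printed by the single-GPU code example on this page, it should land close to the loss values printed there.

```python
import math
from typing import List, Sequence


def margin_ce_reference(
    cos_logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    m1: float = 1.0,
    m2: float = 0.5,
    m3: float = 0.0,
    s: float = 64.0,
) -> List[float]:
    """Hypothetical helper: per-sample loss (reduction='none'), single card."""
    losses = []
    for row, y in zip(cos_logits, labels):
        # Angle to the target class; the clamp guards acos against rounding.
        theta_y = math.acos(max(-1.0, min(1.0, row[y])))
        # Non-target classes: s * cos(theta_j); target: s * (cos(m1*theta + m2) - m3).
        z = [s * c for c in row]
        z[y] = s * (math.cos(m1 * theta_y + m2) - m3)
        # -log softmax(z)[y], via a numerically stable log-sum-exp.
        z_max = max(z)
        log_sum_exp = z_max + math.log(sum(math.exp(v - z_max) for v in z))
        losses.append(log_sum_exp - z[y])
    return losses


# Logits and labels printed by the single-GPU example on this page.
example_logits = [[-0.59561850, 0.32797505, 0.80279214, 0.00144975],
                  [-0.16265212, 0.84155098, 0.62008629, 0.79126072]]
example_labels = [1, 0]
print(margin_ce_reference(example_logits, example_labels))  # approx. [61.9439, 93.3085]
```

Setting m2=0 (with m1=1, m3=0) removes the margin, and the sketch reduces to ordinary softmax cross-entropy on s·cos θ.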
Note:
This API supports both single-card and multi-card (model-parallel) training. With model parallelism, logits.shape[-1] may differ from card to card.
Parameters
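logits (Tensor) - 2-D Tensor of shape [N, local_num_classes], obtained as the matrix product of the L2-normalized X and the L2-normalized W. Data type: float16, float32 or float64. With model parallelism, logits is the local shard held by each card (logits == shard_logits).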
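The logits contract above can be sketched in plain Python: normalize each feature row of X, normalize each class column of W, then multiply, so every entry is a cosine. `cosine_logits` is a hypothetical helper that mirrors the `axis=1` / `axis=0` normalization used in the code examples; under model parallelism each card would pass only its own column slice of W, giving local_num_classes columns.

```python
import math
from typing import List

Matrix = List[List[float]]


def cosine_logits(X: Matrix, W: Matrix) -> Matrix:
    """Hypothetical helper: normalize(X rows) @ normalize(W columns)."""
    n, d, c = len(X), len(W), len(W[0])
    # L2-normalize each feature row of X (shape [N, feature_length]).
    Xn = [[v / math.sqrt(sum(u * u for u in row)) for v in row] for row in X]
    # L2-normalize each class column of W (shape [feature_length, num_classes]).
    col_norm = [math.sqrt(sum(W[k][j] ** 2 for k in range(d))) for j in range(c)]
    Wn = [[W[k][j] / col_norm[j] for j in range(c)] for k in range(d)]
    # Entry (i, j) is the cosine of the angle between x_i and w_j.
    return [[sum(Xn[i][k] * Wn[k][j] for k in range(d)) for j in range(c)]
            for i in range(n)]


logits = cosine_logits([[3.0, 4.0]], [[3.0, 0.0], [4.0, 1.0]])
print(logits)
```

Here the first class weight points the same way as the feature (cosine 1) and the second is at cosine 0.8, so both values lie in [-1, 1] as the loss expects.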
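label (Tensor) - Labels of shape [N] or [N, 1]. With model parallelism, the labels are global class indices spanning the classes of all cards.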
margin1 (float, optional) - m1 in the formula. Default: 1.0.
margin2 (float, optional) - m2 in the formula. Default: 0.5.
margin3 (float, optional) - m3 in the formula. Default: 0.0.
scale (float, optional) - s in the formula. Default: 64.0.
group (Group, optional) - The communication group; see paddle.distributed.collective.Group for details. Default: None.
return_softmax (bool, optional) - Whether to also return the softmax probabilities. Default: False.
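reduction (str, optional) - How to reduce the loss. One of 'none', 'mean' or 'sum' ('none' can also be given as None, as in the examples below). 'mean' averages the loss over the batch, 'sum' adds it up, and 'none' returns the per-sample loss unreduced. Default: 'mean'.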
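The reduction rule can be summarized in a few lines of plain Python. `reduce_loss` is a hypothetical helper for illustration: Paddle returns Tensors (a 0-D Tensor for 'mean' and 'sum', shape [N, 1] otherwise), while this sketch uses floats and nested lists as stand-ins.

```python
from typing import List, Optional, Union


def reduce_loss(
    per_sample: List[float], reduction: Optional[str] = "mean"
) -> Union[float, List[List[float]]]:
    """Hypothetical helper mirroring the documented `reduction` options."""
    if reduction is None or reduction == "none":
        # Unreduced: one loss per sample, laid out as shape [N, 1].
        return [[v] for v in per_sample]
    if reduction == "mean":
        # Average over the batch.
        return sum(per_sample) / len(per_sample)
    if reduction == "sum":
        # Sum over the batch.
        return sum(per_sample)
    # This sketch rejects anything else.
    raise ValueError(f"unsupported reduction: {reduction!r}")


print(reduce_loss([1.0, 3.0], "none"))  # [[1.0], [3.0]]
print(reduce_loss([1.0, 3.0]))          # 2.0
```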
Returns
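Tensor (loss), or a tuple of Tensors (loss, softmax): loss alone if return_softmax=False, otherwise (loss, softmax). With model parallelism, softmax is the local shard on each card (softmax == shard_softmax, shape [N, local_num_classes]); otherwise softmax has the same shape as logits. If reduction is 'none' (or None), loss has shape [N, 1]; otherwise it is a 0-D Tensor (shape []).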
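The sharded return value is easier to see in a single-process simulation. This plain-Python sketch assumes the usual model-parallel softmax recipe: each card scales its own logits, the card owning the global label applies the margin, and a global maximum and a global sum of exponentials are combined across cards (a real run would use all-reduce collectives; here they are ordinary `max`/`sum` calls). It is not Paddle's kernel, and `model_parallel_forward` is a hypothetical name. Given the row-0 logits logged by the two cards in the multi-GPU example and global label 5 (local column 1 on the second card, since the first card holds classes 0 to 3), it should reproduce the logged loss of about 104.2744 and the per-card softmax slices.

```python
import math
from typing import List, Optional, Sequence, Tuple


def model_parallel_forward(
    shards: Sequence[Sequence[float]],
    label: int,
    m1: float = 1.0,
    m2: float = 0.5,
    m3: float = 0.0,
    s: float = 64.0,
) -> Tuple[float, List[List[float]]]:
    """Hypothetical single-process simulation of one sample's sharded forward.

    shards[r] holds card r's cosine logits; label is a global class index.
    Returns (loss, per-card softmax slices).
    """
    offset = 0
    scaled: List[List[float]] = []
    target: Optional[float] = None
    for shard in shards:
        # Each card scales its own logits by s.
        z = [s * c for c in shard]
        # The card whose class range [offset, offset + len(shard)) contains
        # the label applies the margin to its local column.
        local = label - offset
        if 0 <= local < len(shard):
            theta = math.acos(max(-1.0, min(1.0, shard[local])))
            z[local] = s * (math.cos(m1 * theta + m2) - m3)
            target = z[local]
        scaled.append(z)
        offset += len(shard)
    if target is None:
        raise ValueError("label is outside the total class range")
    # Stand-in for all-reduce(max): global maximum for numerical stability.
    g_max = max(max(z) for z in scaled)
    # Stand-in for all-reduce(sum): global softmax denominator.
    g_sum = sum(math.exp(v - g_max) for z in scaled for v in z)
    # Each card keeps only its own slice of the softmax (shard_softmax).
    shard_softmax = [[math.exp(v - g_max) / g_sum for v in z] for z in scaled]
    # Cross-entropy: global log-sum-exp minus the margin-adjusted target logit.
    loss = g_max + math.log(g_sum) - target
    return loss, shard_softmax


# Row 0 of the logits logged by the two cards in the multi-GPU example.
rank0_row0 = [-0.59561850, 0.32797505, 0.80279214, 0.00144975]
rank1_row0 = [-0.34913275, -0.35180883, -0.53976657, -0.75234331,
              0.70534995, 0.87157838, 0.31064437, 0.19537700]
loss, (sm0, sm1) = model_parallel_forward([rank0_row0, rank1_row0], label=5)
print(loss)
```

Only the two scalars (max and sum) need to cross cards, which is why each card can keep a softmax slice of its own width while the slices together still form one distribution over all classes.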
Code Examples
 >>> import paddle
 >>> paddle.seed(2023)
 >>> paddle.device.set_device('gpu')
 >>> m1 = 1.0
 >>> m2 = 0.5
 >>> m3 = 0.0
 >>> s = 64.0
 >>> batch_size = 2
 >>> feature_length = 4
 >>> num_classes = 4
 >>> label = paddle.randint(low=0, high=num_classes, shape=[batch_size], dtype='int64')
 >>> X = paddle.randn(
 ...     shape=[batch_size, feature_length],
 ...     dtype='float64')
 >>> # L2-normalize each feature row of X.
 >>> X_l2 = paddle.sqrt(paddle.sum(paddle.square(X), axis=1, keepdim=True))
 >>> X = paddle.divide(X, X_l2)
 >>> W = paddle.randn(
 ...     shape=[feature_length, num_classes],
 ...     dtype='float64')
 >>> # L2-normalize each class column of W.
 >>> W_l2 = paddle.sqrt(paddle.sum(paddle.square(W), axis=0, keepdim=True))
 >>> W = paddle.divide(W, W_l2)
 >>> # Entry (i, j) of logits is the cosine between x_i and w_j.
 >>> logits = paddle.matmul(X, W)
 >>> loss, softmax = paddle.nn.functional.margin_cross_entropy(
 ...     logits, label, margin1=m1, margin2=m2, margin3=m3, scale=s, return_softmax=True, reduction=None)
 >>> print(logits)
 Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=True,
        [[-0.59561850,  0.32797505,  0.80279214,  0.00144975],
         [-0.16265212,  0.84155098,  0.62008629,  0.79126072]])
 >>> print(label)
 Tensor(shape=[2], dtype=int64, place=Place(gpu:0), stop_gradient=True,
        [1, 0])
 >>> print(loss)
 Tensor(shape=[2, 1], dtype=float64, place=Place(gpu:0), stop_gradient=True,
        [[61.94391901],
         [93.30853839]])
 >>> print(softmax)
 Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=True,
        [[0.00000000, 0.00000000, 1.        , 0.00000000],
         [0.00000000, 0.96152676, 0.00000067, 0.03847257]])
 >>> # Multi-GPU example, saved as test_margin_cross_entropy.py
 >>> from typing import List
 >>> import paddle
 >>> import paddle.distributed as dist
 >>> paddle.seed(2023)
 >>> strategy = dist.fleet.DistributedStrategy()
 >>> dist.fleet.init(is_collective=True, strategy=strategy)
 >>> rank_id = dist.get_rank()
 >>> m1 = 1.0
 >>> m2 = 0.5
 >>> m3 = 0.0
 >>> s = 64.0
 >>> batch_size = 2
 >>> feature_length = 4
 >>> num_class_per_card = [4, 8]
 >>> num_classes = paddle.sum(paddle.to_tensor(num_class_per_card))
 >>> label = paddle.randint(low=0, high=num_classes.item(), shape=[batch_size], dtype='int64')
 >>> label_list: List[paddle.Tensor] = []
 >>> dist.all_gather(label_list, label)
 >>> label = paddle.concat(label_list, axis=0)
 >>> X = paddle.randn(
 ...     shape=[batch_size, feature_length],
 ...     dtype='float64'
 ... )
 >>> X_list: List[paddle.Tensor] = []
 >>> dist.all_gather(X_list, X)
 >>> X = paddle.concat(X_list, axis=0)
 >>> X_l2 = paddle.sqrt(paddle.sum(paddle.square(X), axis=1, keepdim=True))
 >>> X = paddle.divide(X, X_l2)
 >>> W = paddle.randn(
 ...     shape=[feature_length, num_class_per_card[rank_id]],
 ...     dtype='float64')
 >>> W_l2 = paddle.sqrt(paddle.sum(paddle.square(W), axis=0, keepdim=True))
 >>> W = paddle.divide(W, W_l2)
 >>> logits = paddle.matmul(X, W)
 >>> loss, softmax = paddle.nn.functional.margin_cross_entropy(
 ...     logits, label, margin1=m1, margin2=m2, margin3=m3, scale=s, return_softmax=True, reduction=None)
 >>> print(logits)
 >>> print(label)
 >>> print(loss)
 >>> print(softmax)
 >>> # python -m paddle.distributed.launch --gpus=0,1 --log_dir log test_margin_cross_entropy.py
 >>> # cat log/workerlog.0
 >>> # Tensor(shape=[4, 4], dtype=float64, place=Place(gpu:0), stop_gradient=True,
 >>> #        [[-0.59561850,  0.32797505,  0.80279214,  0.00144975],
 >>> #         [-0.16265212,  0.84155098,  0.62008629,  0.79126072],
 >>> #         [-0.59561850,  0.32797505,  0.80279214,  0.00144975],
 >>> #         [-0.16265212,  0.84155098,  0.62008629,  0.79126072]])
 >>> # Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True,
 >>> #        [5, 4, 5, 4])
 >>> # Tensor(shape=[4, 1], dtype=float64, place=Place(gpu:0), stop_gradient=True,
 >>> #        [[104.27437027],
 >>> #         [113.40243782],
 >>> #         [104.27437027],
 >>> #         [113.40243782]])
 >>> # Tensor(shape=[4, 4], dtype=float64, place=Place(gpu:0), stop_gradient=True,
 >>> #        [[0.00000000, 0.00000000, 0.01210039, 0.00000000],
 >>> #         [0.00000000, 0.96152674, 0.00000067, 0.03847257],
 >>> #         [0.00000000, 0.00000000, 0.01210039, 0.00000000],
 >>> #         [0.00000000, 0.96152674, 0.00000067, 0.03847257]])
 >>> # cat log/workerlog.1
 >>> # Tensor(shape=[4, 8], dtype=float64, place=Place(gpu:1), stop_gradient=True,
 >>> #        [[-0.34913275, -0.35180883, -0.53976657, -0.75234331,  0.70534995,
 >>> #           0.87157838,  0.31064437,  0.19537700],
 >>> #         [-0.63941012, -0.05631600, -0.02561853,  0.09363013,  0.56571130,
 >>> #           0.13611246,  0.08849565,  0.39219619],
 >>> #         [-0.34913275, -0.35180883, -0.53976657, -0.75234331,  0.70534995,
 >>> #           0.87157838,  0.31064437,  0.19537700],
 >>> #         [-0.63941012, -0.05631600, -0.02561853,  0.09363013,  0.56571130,
 >>> #           0.13611246,  0.08849565,  0.39219619]])
 >>> # Tensor(shape=[4], dtype=int64, place=Place(gpu:1), stop_gradient=True,
 >>> #        [5, 4, 5, 4])
 >>> # Tensor(shape=[4, 1], dtype=float64, place=Place(gpu:1), stop_gradient=True,
 >>> #        [[104.27437027],
 >>> #         [113.40243782],
 >>> #         [104.27437027],
 >>> #         [113.40243782]])
 >>> # Tensor(shape=[4, 8], dtype=float64, place=Place(gpu:1), stop_gradient=True,
 >>> #        [[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00002368, 0.98787593,
 >>> #          0.00000000, 0.00000000],
 >>> #         [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000002, 0.00000000,
 >>> #          0.00000000, 0.00000000],
 >>> #         [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00002368, 0.98787593,
 >>> #          0.00000000, 0.00000000],
 >>> #         [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000002, 0.00000000,
 >>> #          0.00000000, 0.00000000]])