margin_cross_entropy
- paddle.nn.functional.margin_cross_entropy(logits, label, margin1=1.0, margin2=0.5, margin3=0.0, scale=64.0, group=None, return_softmax=False, reduction='mean')
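For reference, the combined-margin loss from the ArcFace paper is sketched here. It assumes that \(m_1\), \(m_2\), \(m_3\) correspond to margin1, margin2, margin3, that \(s\) corresponds to scale, that \(N\) is the batch size, and that \(n\) is the total number of classes:

\[
L = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{e^{s\,(\cos(m_1\theta_{y_i}+m_2)-m_3)}}{e^{s\,(\cos(m_1\theta_{y_i}+m_2)-m_3)}+\sum_{j=1,\,j\neq y_i}^{n}e^{s\cos\theta_j}}
\]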
where \(\theta_{y_i}\) is the angle between the feature \(x\) and the class representation \(w_{i}\). For a more detailed description, see the ArcFace loss paper: https://arxiv.org/abs/1801.07698.
Note:
This API supports both single-GPU and multi-GPU (model-parallel) use. With model parallelism,
logits.shape[-1] may differ on each card.
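The following NumPy sketch illustrates what a different logits.shape[-1] per card means. It is a single-process simulation, not Paddle's implementation: it assumes classes are assigned to cards in contiguous ranges in rank order, and it replaces the cross-card communication with plain NumPy reductions. The per-card logits and the labels are the first two rows printed by the multi-GPU example on this page; all variable names are illustrative.

```python
import numpy as np

# Per-card logits (cosines) taken from the multi-GPU example output on this page.
shard_logits = [
    np.array([[-0.59561850, 0.32797505, 0.80279214, 0.00144975],
              [-0.16265212, 0.84155098, 0.62008629, 0.79126072]]),
    np.array([[-0.34913275, -0.35180883, -0.53976657, -0.75234331,
               0.70534995, 0.87157838, 0.31064437, 0.19537700],
              [-0.63941012, -0.05631600, -0.02561853, 0.09363013,
               0.56571130, 0.13611246, 0.08849565, 0.39219619]]),
]
label = np.array([5, 4])  # global class indices
m1, m2, m3, s = 1.0, 0.5, 0.0, 64.0

# Assumption: card 0 owns classes [0, 4) and card 1 owns classes [4, 12).
sizes = [x.shape[-1] for x in shard_logits]
starts = np.concatenate([[0], np.cumsum(sizes)[:-1]])

scaled = []
target = np.zeros(len(label))
for x, start, size in zip(shard_logits, starts, sizes):
    x = x.copy()
    local = label - start
    rows = np.nonzero((local >= 0) & (local < size))[0]  # samples whose class is on this card
    theta = np.arccos(x[rows, local[rows]])
    x[rows, local[rows]] = np.cos(m1 * theta + m2) - m3  # margin on the target class only
    x *= s
    target[rows] = x[rows, local[rows]]  # stands in for a sum all-reduce
    scaled.append(x)

# Global log-sum-exp from per-card partial results (stand-ins for max and sum all-reduce).
global_max = np.max([x.max(axis=1) for x in scaled], axis=0)
global_sum = np.sum([np.exp(x - global_max[:, None]).sum(axis=1) for x in scaled], axis=0)

loss = (global_max + np.log(global_sum) - target).reshape(-1, 1)
shard_softmax = [np.exp(x - global_max[:, None]) / global_sum[:, None] for x in scaled]

print(loss)
print([p.shape for p in shard_softmax])
```

Each card keeps a softmax slice with the same width as its own logits, while the loss is identical on every card because it depends only on the globally reduced quantities.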
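Parameters
- logits (Tensor) - A 2-D Tensor of shape [N, local_num_classes], obtained by matrix-multiplying the normalized X with the normalized W. The data type is float16, float32, or float64. With model parallelism, logits == shard_logits.
- label (Tensor) - The labels, with shape [N] or [N, 1].
- margin1 (float, optional) - m1 in the formula. Default: 1.0.
- margin2 (float, optional) - m2 in the formula. Default: 0.5.
- margin3 (float, optional) - m3 in the formula. Default: 0.0.
- scale (float, optional) - s in the formula. Default: 64.0.
- group (Group, optional) - An abstract description of the communication group; see paddle.distributed.collective.Group for details. Default: None.
- return_softmax (bool, optional) - Whether to return the softmax probabilities. Default: False.
- reduction (str, optional) - How to reduce the loss. One of 'none' | 'mean' | 'sum'. With reduction='mean' the loss is averaged, with reduction='sum' it is summed, and with reduction='none' the unreduced loss is returned. Default: 'mean'.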
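Returns
Tensor (loss) or a tuple of two Tensors (loss, softmax) - If return_softmax=False, loss is returned; otherwise (loss, softmax) is returned. With model parallelism, softmax == shard_softmax; otherwise softmax has the same shape as logits. If reduction is 'none' (or None, as passed in the examples below), loss has shape [N, 1]; otherwise it has shape [].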
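As a cross-check of the return values, here is a minimal single-card reference in NumPy. It is a sketch of the combined-margin formula, not Paddle's kernel; the function name margin_cross_entropy_ref is hypothetical. The inputs are the logits and labels printed by the single-GPU example on this page.

```python
import numpy as np

def margin_cross_entropy_ref(logits, label, m1=1.0, m2=0.5, m3=0.0, s=64.0,
                             reduction="mean"):
    """Reference sketch: logits are cosines of shape [N, C]; label has shape [N] or [N, 1]."""
    logits = np.asarray(logits, dtype=np.float64)
    label = np.asarray(label).reshape(-1)
    rows = np.arange(logits.shape[0])

    # Apply the margin to the ground-truth class only: cos(m1 * theta + m2) - m3.
    theta = np.arccos(np.clip(logits[rows, label], -1.0, 1.0))
    modified = logits.copy()
    modified[rows, label] = np.cos(m1 * theta + m2) - m3
    scaled = s * modified

    # Numerically stable softmax cross-entropy over the scaled logits.
    shifted = scaled - scaled.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    softmax = np.exp(log_softmax)
    loss = -log_softmax[rows, label].reshape(-1, 1)  # unreduced, shape [N, 1]

    if reduction == "mean":
        return loss.mean(), softmax
    if reduction == "sum":
        return loss.sum(), softmax
    return loss, softmax  # 'none' (or None): keep the per-sample loss

logits = [[-0.59561850, 0.32797505, 0.80279214, 0.00144975],
          [-0.16265212, 0.84155098, 0.62008629, 0.79126072]]
label = [1, 0]

loss, softmax = margin_cross_entropy_ref(logits, label, reduction="none")
mean_loss, _ = margin_cross_entropy_ref(logits, label, reduction="mean")
print(loss.shape, softmax.shape, np.shape(mean_loss))
```

With reduction="none" the loss keeps one value per sample, while "mean" and "sum" collapse it to a scalar, which mirrors the [N, 1] versus [] shapes described above.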
Code Examples
>>> import paddle
>>> paddle.seed(2023)
>>> paddle.device.set_device('gpu')
>>> m1 = 1.0
>>> m2 = 0.5
>>> m3 = 0.0
>>> s = 64.0
>>> batch_size = 2
>>> feature_length = 4
>>> num_classes = 4
>>> label = paddle.randint(low=0, high=num_classes, shape=[batch_size], dtype='int64')
>>> X = paddle.randn(
... shape=[batch_size, feature_length],
... dtype='float64')
>>> X_l2 = paddle.sqrt(paddle.sum(paddle.square(X), axis=1, keepdim=True))
>>> X = paddle.divide(X, X_l2)
>>> W = paddle.randn(
... shape=[feature_length, num_classes],
... dtype='float64')
>>> W_l2 = paddle.sqrt(paddle.sum(paddle.square(W), axis=0, keepdim=True))
>>> W = paddle.divide(W, W_l2)
>>> logits = paddle.matmul(X, W)
>>> loss, softmax = paddle.nn.functional.margin_cross_entropy(
... logits, label, margin1=m1, margin2=m2, margin3=m3, scale=s, return_softmax=True, reduction=None)
>>> print(logits)
Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=True,
[[-0.59561850, 0.32797505, 0.80279214, 0.00144975],
[-0.16265212, 0.84155098, 0.62008629, 0.79126072]])
>>> print(label)
Tensor(shape=[2], dtype=int64, place=Place(gpu:0), stop_gradient=True,
[1, 0])
>>> print(loss)
Tensor(shape=[2, 1], dtype=float64, place=Place(gpu:0), stop_gradient=True,
[[61.94391901],
[93.30853839]])
>>> print(softmax)
Tensor(shape=[2, 4], dtype=float64, place=Place(gpu:0), stop_gradient=True,
[[0.00000000, 0.00000000, 1. , 0.00000000],
[0.00000000, 0.96152676, 0.00000067, 0.03847257]])
>>> # Multi GPU, test_margin_cross_entropy.py
>>> from typing import List
>>> import paddle
>>> import paddle.distributed as dist
>>> paddle.seed(2023)
>>> strategy = dist.fleet.DistributedStrategy()
>>> dist.fleet.init(is_collective=True, strategy=strategy)
>>> rank_id = dist.get_rank()
>>> m1 = 1.0
>>> m2 = 0.5
>>> m3 = 0.0
>>> s = 64.0
>>> batch_size = 2
>>> feature_length = 4
>>> num_class_per_card = [4, 8]
>>> num_classes = paddle.sum(paddle.to_tensor(num_class_per_card))
>>> label = paddle.randint(low=0, high=num_classes.item(), shape=[batch_size], dtype='int64')
>>> label_list: List[paddle.Tensor] = []
>>> dist.all_gather(label_list, label)
>>> label = paddle.concat(label_list, axis=0)
>>> X = paddle.randn(
... shape=[batch_size, feature_length],
... dtype='float64'
... )
>>> X_list: List[paddle.Tensor] = []
>>> dist.all_gather(X_list, X)
>>> X = paddle.concat(X_list, axis=0)
>>> X_l2 = paddle.sqrt(paddle.sum(paddle.square(X), axis=1, keepdim=True))
>>> X = paddle.divide(X, X_l2)
>>> W = paddle.randn(
... shape=[feature_length, num_class_per_card[rank_id]],
... dtype='float64')
>>> W_l2 = paddle.sqrt(paddle.sum(paddle.square(W), axis=0, keepdim=True))
>>> W = paddle.divide(W, W_l2)
>>> logits = paddle.matmul(X, W)
>>> loss, softmax = paddle.nn.functional.margin_cross_entropy(
... logits, label, margin1=m1, margin2=m2, margin3=m3, scale=s, return_softmax=True, reduction=None)
>>> print(logits)
>>> print(label)
>>> print(loss)
>>> print(softmax)
>>> # python -m paddle.distributed.launch --gpus=0,1 --log_dir log test_margin_cross_entropy.py
>>> # cat log/workerlog.0
>>> # Tensor(shape=[4, 4], dtype=float64, place=Place(gpu:0), stop_gradient=True,
>>> # [[-0.59561850, 0.32797505, 0.80279214, 0.00144975],
>>> # [-0.16265212, 0.84155098, 0.62008629, 0.79126072],
>>> # [-0.59561850, 0.32797505, 0.80279214, 0.00144975],
>>> # [-0.16265212, 0.84155098, 0.62008629, 0.79126072]])
>>> # Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True,
>>> # [5, 4, 5, 4])
>>> # Tensor(shape=[4, 1], dtype=float64, place=Place(gpu:0), stop_gradient=True,
>>> # [[104.27437027],
>>> # [113.40243782],
>>> # [104.27437027],
>>> # [113.40243782]])
>>> # Tensor(shape=[4, 4], dtype=float64, place=Place(gpu:0), stop_gradient=True,
>>> # [[0.00000000, 0.00000000, 0.01210039, 0.00000000],
>>> # [0.00000000, 0.96152674, 0.00000067, 0.03847257],
>>> # [0.00000000, 0.00000000, 0.01210039, 0.00000000],
>>> # [0.00000000, 0.96152674, 0.00000067, 0.03847257]])
>>> # cat log/workerlog.1
>>> # Tensor(shape=[4, 8], dtype=float64, place=Place(gpu:1), stop_gradient=True,
>>> # [[-0.34913275, -0.35180883, -0.53976657, -0.75234331, 0.70534995,
>>> # 0.87157838, 0.31064437, 0.19537700],
>>> # [-0.63941012, -0.05631600, -0.02561853, 0.09363013, 0.56571130,
>>> # 0.13611246, 0.08849565, 0.39219619],
>>> # [-0.34913275, -0.35180883, -0.53976657, -0.75234331, 0.70534995,
>>> # 0.87157838, 0.31064437, 0.19537700],
>>> # [-0.63941012, -0.05631600, -0.02561853, 0.09363013, 0.56571130,
>>> # 0.13611246, 0.08849565, 0.39219619]])
>>> # Tensor(shape=[4], dtype=int64, place=Place(gpu:1), stop_gradient=True,
>>> # [5, 4, 5, 4])
>>> # Tensor(shape=[4, 1], dtype=float64, place=Place(gpu:1), stop_gradient=True,
>>> # [[104.27437027],
>>> # [113.40243782],
>>> # [104.27437027],
>>> # [113.40243782]])
>>> # Tensor(shape=[4, 8], dtype=float64, place=Place(gpu:1), stop_gradient=True,
>>> # [[0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00002368, 0.98787593,
>>> # 0.00000000, 0.00000000],
>>> # [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000002, 0.00000000,
>>> # 0.00000000, 0.00000000],
>>> # [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00002368, 0.98787593,
>>> # 0.00000000, 0.00000000],
>>> # [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000002, 0.00000000,
>>> # 0.00000000, 0.00000000]])