SoftMarginLoss

class paddle.nn. SoftMarginLoss ( reduction='mean', name=None ) [源代码]

生成一个可以计算输入 inputlabel 间的二分类损失的类。

损失函数按照下列公式计算

\[\text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()}\]

最后,添加 reduce 操作到前面的输出 Out 上。当 reductionnone 时,直接返回最原始的 Out 结果。当 reductionmean 时,返回输出的均值 \(Out = MEAN(Out)\) 。当 reductionsum 时,返回输出的求和 \(Out = SUM(Out)\)

参数

  • reduction (str,可选) - 指定应用于输出结果的计算方式,可选值有: 'none', 'mean', 'sum' 。默认为 'mean',计算 Loss 的均值;设置为 'sum' 时,计算 Loss 的总和;设置为 'none' 时,则返回原始 Loss。

  • name (str,可选) - 操作的名称(可选,默认值为 None)。更多信息请参见 Name

形状

  • input (Tensor) - \([N, *]\) , 其中 N 是 batch_size, * 是任意其他维度。数据类型是 float32、float64。

  • label (Tensor) - \([N, *]\) ,标签 label 的维度、数据类型与输入 input 相同。

  • output (Tensor) - 输出的 Tensor。如果 reduction'none',则输出的维度为 \([N, *]\),与输入 input 的形状相同。如果 reduction'mean''sum',则输出的维度为 \([]\)

返回

返回计算 SoftMarginLoss 的可调用对象。

代码示例

>>> import paddle
>>> paddle.seed(2023)
>>> input = paddle.to_tensor([[0.5, 0.6, 0.7],[0.3, 0.5, 0.2]], 'float32')
>>> label = paddle.to_tensor([[1.0, -1.0, 1.0],[-1.0, 1.0, 1.0]], 'float32')
>>> soft_margin_loss = paddle.nn.SoftMarginLoss()
>>> output = soft_margin_loss(input, label)
>>> print(output)
Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
0.64022040)

>>> input_np = paddle.uniform(shape=(5, 5), min=0.1, max=0.8, dtype="float64")
>>> label_np = paddle.randint(high=2, shape=(5, 5), dtype="int64")
>>> label_np[label_np==0]=-1
>>> input = paddle.to_tensor(input_np)
>>> label = paddle.to_tensor(label_np)
>>> soft_margin_loss = paddle.nn.SoftMarginLoss(reduction='none')
>>> output = soft_margin_loss(input, label)
>>> print(output)
Tensor(shape=[5, 5], dtype=float64, place=Place(cpu), stop_gradient=True,
[[1.10725628, 0.48778139, 0.56217249, 1.12581404, 0.51430043],
 [0.90375795, 0.37761249, 0.43007557, 0.95089798, 0.43288319],
 [1.16043599, 0.63015939, 0.51362715, 0.43617541, 0.57783301],
 [0.81927846, 0.52558369, 0.59713908, 0.83100696, 0.50811616],
 [0.82684205, 1.02064907, 0.50296995, 1.13461733, 0.93222519]])