median
PyTorch 兼容的 median 版本,提供完全一致的函数签名与行为: - 当 dim
为 None
时,返回所有元素的中位数。 - 当 dim
指定维度时,返回该维度的中位数与对应索引。
备注
此 API 遵循 torch.median
的函数签名和行为以实现 PyTorch 兼容。如需使用 Paddle 原生实现,请参考 median。
参数
input (Tensor) - 输入 N 维 Tensor,支持 bool、bfloat16、float16、float32、float64、int32、int64 数据类型。
dim (int,可选) - 指定计算中位数的维度。为
None
时计算全局中位数。默认None
。keepdim (bool,可选) - 是否保留被约简的维度。默认
False
。out (tuple(Tensor, Tensor)|Tensor,可选) - 关键字参数。当指定
dim
时,可传入二元组(values, indices)
用于原位写回中位数与索引;当未指定dim
时,可传入单个Tensor
用于写回标量结果。默认None
。
返回
当
dim
为None
:返回一个标量Tensor
,为input
的中位数。当
dim
指定:返回具名元组(values, indices)
,分别是中位数与其索引。
代码示例
>>> import paddle
>>> x = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> result = paddle.compat.median(x)
>>> print(result)
Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, 5)
>>> ret = paddle.compat.median(x, dim=1)
>>> print(ret.values)
Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [2, 5, 8])
>>> print(ret.indices)
Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [1, 1, 1])
>>> # Using out parameter
>>> out_values = paddle.zeros([3], dtype='int64')
>>> out_indices = paddle.zeros([3], dtype='int64')
>>> paddle.compat.median(x, dim=1, out=(out_values, out_indices))
>>> print(out_values)
Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [2, 5, 8])