nanmedian
PyTorch 兼容的 nanmedian 版本,提供完全一致的函数签名与行为: - 忽略 NaN 元素计算中位数。 - 当 dim
为 None
时,返回所有非 NaN 元素的中位数; - 当 dim
指定维度时,返回该维度上非 NaN 元素的中位数与对应索引;若某个切片全为 NaN,则返回 NaN 和索引 0。
备注
此 API 遵循 torch.nanmedian
的函数签名和行为以实现 PyTorch 兼容。如需使用 Paddle 原生实现,请参考 nanmedian。
参数
input (Tensor) - 输入 N 维 Tensor,支持 bfloat16、float16、float32、float64 数据类型。
dim (int,可选) - 指定计算中位数的维度。为
None
时计算全局中位数。默认None
。keepdim (bool,可选) - 是否保留被约简的维度。默认
False
。out (tuple(Tensor, Tensor)|Tensor,可选) - 关键字参数。当指定
dim
时,可传入二元组(values, indices)
用于原位写回中位数与索引;当未指定dim
时,可传入单个Tensor
用于写回标量结果。默认None
。
返回
当
dim
为None
:返回一个标量Tensor
,为忽略 NaN 后的中位数。当
dim
指定:返回具名元组(values, indices)
,分别是中位数与其索引。
代码示例
>>> import paddle
>>> import numpy as np
>>> x = paddle.to_tensor([[1, float('nan'), 3], [4, 5, 6], [float('nan'), 8, 9]], dtype='float32')
>>> result = paddle.compat.nanmedian(x)
>>> print(result)
Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 5.0)
>>> ret = paddle.compat.nanmedian(x, dim=1)
>>> print(ret.values)
Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, [1.0, 5.0, 8.0])
>>> print(ret.indices)
Tensor(shape=[3], dtype=int64, place=Place(cpu), stop_gradient=True, [0, 1, 1])
>>> # Using out parameter
>>> out_values = paddle.zeros([3], dtype='float32')
>>> out_indices = paddle.zeros([3], dtype='int64')
>>> paddle.compat.nanmedian(x, dim=1, out=(out_values, out_indices))
>>> print(out_values)
Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, [1.0, 5.0, 8.0])