sort

paddle.compat. sort ( input, dim=-1, descending=False, stable=False, out=None ) [源代码]

PyTorch 兼容的 sort 版本,同时返回排序的值结果以及索引值。对输入变量沿给定轴进行排序,输出排序好的数据,其维度和输入相同。默认升序排列,如果需要降序排列设置 descending=True

使用前请详细参考:【返回参数类型不一致】torch.sort 以确定是否使用此模块。

参数

  • input (Tensor) - 输入的多维 Tensor,支持的数据类型:float32, float64, int16, int32, int64, uint8, float16, bfloat16。

  • dim (int,可选) - 指定对输入 Tensor 进行运算的轴,dim 的有效范围是[-R, R),R 是输入 x 的 Rank, dim 为负时与 dim +R 等价。默认值为-1。

  • descending (bool,可选) - 指定算法排序的方向。如果设置为 True,算法按照降序排序。如果设置为 False 或者不设置,按照升序排序。默认值为 False。

  • stable (bool,可选) - 是否使用稳定排序算法。若设置为 True,则使用稳定排序算法,即相同元素的顺序在排序结果中将会被保留。默认值为 False,此时的算法不一定是稳定排序算法。

  • out (tuple(Tensor, Tensor),可选) - 用于引用式传入输出值。values 在前,indices 在后。注意:动态图下 out 可以是任意 Tensor,默认值为 None。out 返回方法与静态图联合使用是被禁止的行为,静态图下将报错。

返回

SortRetType(Tensor, Tensor),此处的 SortRetType 是一个具名元组,含有 values (在前)和 indices (在后)两个域,用法与 tuple 一致。

代码示例

>>> import paddle

>>> x = paddle.to_tensor([[5,8,9,5],
...                       [0,0,1,7],
...                       [6,9,2,4]],
...                      dtype='float32')
>>> out1 = paddle.compat.sort(input=x, dim=-1)
>>> out2 = paddle.compat.sort(x, 1, descending=True)
>>> out1
SortRetType(values=Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
       [[5., 5., 8., 9.],
        [0., 0., 1., 7.],
        [2., 4., 6., 9.]]), indices=Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True,
       [[0, 3, 1, 2],
        [0, 1, 2, 3],
        [2, 3, 0, 1]]))
>>> out2
SortRetType(values=Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
       [[9., 8., 5., 5.],
        [7., 1., 0., 0.],
        [9., 6., 4., 2.]]), indices=Tensor(shape=[3, 4], dtype=int64, place=Place(cpu), stop_gradient=True,
       [[2, 1, 0, 3],
        [3, 2, 0, 1],
        [1, 0, 3, 2]]))