argsort

paddle.fluid.layers.tensor.argsort(input, axis=-1, descending=False, name=None)
alias_main

paddle.argsort

alias

paddle.argsort,paddle.tensor.argsort,paddle.tensor.search.argsort

old_api

paddle.fluid.layers.argsort

This OP sorts the input along the given axis. It returns a tuple of two Variables, each with the same shape as input: the sorted data, and the indices of the sorted elements in the original input.
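For a 1-D input, the pairing between the two outputs can be illustrated in plain Python. This is a sketch of the semantics only, not Paddle's implementation. The helper name `argsort_1d` is hypothetical. It relies on Python's `sorted` being stable, so equal values keep their original relative order; this page does not state whether Paddle guarantees that.

```python
def argsort_1d(values, descending=False):
    """Hypothetical helper: return (sorted_values, indices) for a 1-D list."""
    # Sort the positions 0..n-1 by the value stored at each position.
    indices = sorted(range(len(values)), key=lambda i: values[i], reverse=descending)
    # Gather the values in that order, so sorted_values[k] == values[indices[k]].
    sorted_values = [values[i] for i in indices]
    return sorted_values, indices


vals, idx = argsort_1d([6, 9, 2, 4])
print(vals)  # [2, 4, 6, 9]
print(idx)   # [2, 3, 0, 1]
```

The index list records where each sorted element came from, which is what the second element of the returned tuple holds.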

Parameters
  • input (Variable) – An input N-D Tensor with type float32, float64, int16, int32, int64, uint8.

  • axis (int, optional) – Axis along which to sort. The effective range is [-R, R), where R is Rank(input). When axis < 0, it works the same way as axis + R. Default is -1 (the last axis).

  • descending (bool, optional) – If True, the input is sorted in descending order; otherwise it is sorted in ascending order. Default is False.

  • name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
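The rule for negative axis values can be written out as a small check. This is a sketch, not Paddle's code. `normalize_axis` is a hypothetical name, and raising `ValueError` for an out-of-range axis is an assumption about how an invalid axis might be reported.

```python
def normalize_axis(axis, rank):
    """Hypothetical helper: map an axis in [-rank, rank) to [0, rank)."""
    if not -rank <= axis < rank:
        # Assumed error behavior; Paddle may report this differently.
        raise ValueError(f"axis {axis} is out of range [-{rank}, {rank})")
    # A negative axis counts from the end: it behaves like axis + R.
    return axis + rank if axis < 0 else axis


# For the 3-D tensor in the Examples section (R = 3):
print(normalize_axis(-1, 3))  # 2, i.e. the last axis
print(normalize_axis(0, 3))   # 0
```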

Returns

A tuple of the sorted data Variable (with the same shape and data type as input) and the sorted indices (with the same shape as input and data type int64).

Return type

tuple
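The return contract can be summarized with a NumPy reference sketch. It only illustrates the shapes and dtypes described above and is not Paddle's kernel. `argsort_reference` is a hypothetical name. Stable tie-breaking, and the reversed tie order it implies when `descending=True`, are assumptions this page does not confirm.

```python
import numpy as np


def argsort_reference(x, axis=-1, descending=False):
    """Hypothetical NumPy reference for the (values, indices) tuple."""
    # A stable sort keeps equal elements in their original relative order.
    indices = np.argsort(x, axis=axis, kind="stable")
    if descending:
        # Reversing an ascending order gives a descending one; equal
        # elements then appear in reverse of their original order.
        indices = np.flip(indices, axis=axis)
    # The documented index dtype is int64.
    indices = indices.astype(np.int64)
    # Gathering with the indices keeps x's shape and dtype for the values.
    values = np.take_along_axis(x, indices, axis=axis)
    return values, indices


x = np.array([[5, 8, 9, 5],
              [0, 0, 1, 7]], dtype=np.float32)
values, indices = argsort_reference(x, axis=-1)
print(values.dtype, indices.dtype)  # float32 int64
```

Using `take_along_axis` to build the values makes the relationship between the two tuple elements explicit: every sorted value is the input element found at the matching index.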

Examples

import paddle.fluid as fluid
import numpy as np

in1 = np.array([[[5,8,9,5],
                [0,0,1,7],
                [6,9,2,4]],
                [[5,2,4,2],
                [4,7,7,9],
                [1,7,0,6]]]).astype(np.float32)
with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(in1)
    out1 = fluid.layers.argsort(input=x, axis=-1)
    out2 = fluid.layers.argsort(input=x, axis=0)
    out3 = fluid.layers.argsort(input=x, axis=1)
    print(out1[0].numpy())
    # [[[5. 5. 8. 9.]
    #   [0. 0. 1. 7.]
    #   [2. 4. 6. 9.]]
    #  [[2. 2. 4. 5.]
    #   [4. 7. 7. 9.]
    #   [0. 1. 6. 7.]]]
    print(out1[1].numpy())
    # [[[0 3 1 2]
    #   [0 1 2 3]
    #   [2 3 0 1]]
    #  [[1 3 2 0]
    #   [0 1 2 3]
    #   [2 0 3 1]]]
    print(out2[0].numpy())
    # [[[5. 2. 4. 2.]
    #   [0. 0. 1. 7.]
    #   [1. 7. 0. 4.]]
    #  [[5. 8. 9. 5.]
    #   [4. 7. 7. 9.]
    #   [6. 9. 2. 6.]]]
    print(out3[0].numpy())
    # [[[0. 0. 1. 4.]
    #   [5. 8. 2. 5.]
    #   [6. 9. 9. 7.]]
    #  [[1. 2. 0. 2.]
    #   [4. 7. 4. 6.]
    #   [5. 7. 7. 9.]]]
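The example above does not exercise `descending=True`. As a hedged cross-check computed with NumPy rather than Paddle, the sorted values for `in1` in descending order along the last axis can be obtained by sorting ascending and reversing. Only the values are compared here, because the index order for tied elements is not specified on this page.

```python
import numpy as np

in1 = np.array([[[5, 8, 9, 5],
                 [0, 0, 1, 7],
                 [6, 9, 2, 4]],
                [[5, 2, 4, 2],
                 [4, 7, 7, 9],
                 [1, 7, 0, 6]]]).astype(np.float32)

# Descending values along the last axis: sort ascending, then reverse.
# Tie order does not matter here because equal values are indistinguishable.
desc_values = np.flip(np.sort(in1, axis=-1), axis=-1)
print(desc_values[0, 0])  # [9. 8. 5. 5.]
print(desc_values[1, 2])  # [7. 6. 1. 0.]
```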