topk

paddle.fluid.layers.topk(input, k, name=None)

This OP finds the values and indices of the k largest entries along the last dimension of the input.

If the input is a 1-D Tensor, the OP finds its k largest entries and outputs their values and indices.

If the input is a Tensor of higher rank, the OP computes the top k entries of every slice along the last dimension.
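The rule above (apply the 1-D top-k independently to every last-dimension slice) can be illustrated with a plain-Python reference sketch on nested lists. `topk_last_dim` is a hypothetical helper, not Paddle's implementation; it assumes results are returned largest first and that ties resolve to the lower index, which is an assumption of this sketch rather than a documented guarantee.

```python
from typing import Any, Tuple


def topk_last_dim(data: Any, k: int) -> Tuple[Any, Any]:
    """Return (values, indices) of the k largest entries along the last dimension.

    `data` is a nested Python list whose innermost lists form the last dimension.
    """
    # Higher rank: apply the 1-D rule independently to every last-dimension slice.
    if data and isinstance(data[0], list):
        results = [topk_last_dim(row, k) for row in data]
        return [r[0] for r in results], [r[1] for r in results]
    # 1-D case: validate k against the size of the last dimension.
    if k < 1 or k > len(data):
        raise ValueError("k must satisfy 1 <= k <= size of the last dimension")
    # Rank positions by value, largest first. sorted() stays stable with
    # reverse=True, so equal values keep their original (lower index first) order.
    order = sorted(range(len(data)), key=lambda i: data[i], reverse=True)
    indices = order[:k]
    return [data[i] for i in indices], indices


values, indices = topk_last_dim([3.0, 1.0, 4.0, 1.5], k=2)
print(values, indices)  # [4.0, 3.0] [2, 0]
```

The recursion keeps the output shape equal to the input shape with the last dimension replaced by k, which is the same shape rule stated under Returns.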

Case 1:

  Input:
    input.shape = [3, 4]
    input.data = [[5, 4, 2, 3],
                  [9, 7, 10, 25],
                  [6, 2, 10, 1]]
    k = 2

  Output:
    The first output:
    values.shape = [3, 2]
    values.data = [[5, 4],
                   [25, 10],
                   [10, 6]]

    The second output:
    indices.shape = [3, 2]
    indices.data = [[0, 1],
                    [3, 2],
                    [2, 0]]
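To check the Case 1 arithmetic by hand, each row can be ranked with Python's standard `heapq.nlargest`, pairing every value with its column index. This is a plain-Python illustration, not a call into Paddle, and it assumes each row's results are listed from largest to smallest.

```python
import heapq

rows = [[5, 4, 2, 3],
        [9, 7, 10, 25],
        [6, 2, 10, 1]]
k = 2

values, indices = [], []
for row in rows:
    # Keep the k (index, value) pairs with the largest values, largest first.
    top = heapq.nlargest(k, enumerate(row), key=lambda pair: pair[1])
    indices.append([i for i, _ in top])
    values.append([v for _, v in top])

print(values)   # [[5, 4], [25, 10], [10, 6]]
print(indices)  # [[0, 1], [3, 2], [2, 0]]
```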

Parameters
  • input (Variable) – The input tensor. Supported data types: float32, float64.

  • k (int | Variable) – The number of top elements to look for along the last dimension of the input tensor.

  • name (str, optional) – For details, refer to Name. Default: None.

Returns

Values (Variable): The k largest elements of input along each slice of the last dimension. Its shape is \(input.shape[:-1]+[k]\).
Indices (Variable): The indices of the k largest elements along the last dimension of input. Its shape is the same as that of Values.

Return type

tuple of two Variables (Values, Indices)
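The shape rule \(input.shape[:-1]+[k]\) applies to both outputs. The hypothetical helper `topk_output_shape` below only spells that rule out, using `None` for an unknown batch dimension as the examples on this page do.

```python
from typing import List, Optional


def topk_output_shape(input_shape: List[Optional[int]], k: int) -> List[Optional[int]]:
    """Shape of both Values and Indices: every dimension but the last, then k."""
    return list(input_shape[:-1]) + [k]


print(topk_output_shape([None, 13, 11], 5))  # [None, 13, 5]
print(topk_output_shape([3, 4], 2))          # [3, 2]
```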

Raises

ValueError – If \(k < 1\) or \(k\) is greater than the size of the last dimension of input.
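A minimal sketch of the documented error condition, written as a hypothetical `check_k` helper. It only mirrors the rule stated above; it is not the check Paddle runs internally, and it assumes the size of the last dimension is known as an integer.

```python
def check_k(k: int, last_dim: int) -> None:
    """Raise ValueError unless 1 <= k <= last_dim (hypothetical helper)."""
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if k > last_dim:
        raise ValueError(f"k={k} exceeds the last dimension size {last_dim}")


check_k(5, 11)  # valid for the [None, 13, 11] input in the examples; returns None
```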

Examples

import paddle.fluid as fluid
import paddle.fluid.layers as layers
# The batch size is left as None (variable)
input = fluid.data(name="input", shape=[None, 13, 11], dtype='float32')
top5_values, top5_indices = layers.topk(input, k=5)  # top5_values.shape = [None, 13, 5], top5_indices.shape = [None, 13, 5]

# 1-D Tensor per sample (2-D including the batch dimension)
input1 = fluid.data(name="input1", shape=[None, 13], dtype='float32')
top5_values, top5_indices = layers.topk(input1, k=5)  # top5_values.shape = [None, 5], top5_indices.shape = [None, 5]

# k passed as a Variable
input2 = fluid.data(name="input2", shape=[None, 13, 11], dtype='float32')
vk = fluid.data(name="vk", shape=[None, 1], dtype='int32')  # k is stored in vk.data[0]
vk_values, vk_indices = layers.topk(input2, k=vk)  # vk_values.shape = [None, 13, k], vk_indices.shape = [None, 13, k]