argmax

paddle.argmax(x, axis=None, keepdim=False, dtype='int64', name=None)

Computes the indices of the maximum elements of the input tensor along the given axis.

Parameters
  • x (Tensor) – An N-D input Tensor of type float32, float64, int16, int32, int64, or uint8.

  • axis (int, optional) – Axis to compute indices along. The effective range is [-R, R), where R is x.ndim. When axis < 0, it works the same way as axis + R. Default is None, in which case x is flattened and the index of the maximum value in the flattened tensor is returned.

  • keepdim (bool, optional) – Whether to keep the given axis in the output. If True, the output has the same number of dimensions as x, with size one along axis. Otherwise the output has one fewer dimension than x, because the axis is squeezed. Default is False.

  • dtype (str|np.dtype, optional) – Data type of the output tensor, which can be int32 or int64. The default value is 'int64', which returns int64 indices.

  • name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.
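
The axis and keepdim rules above can be sketched in plain Python. This is only an illustration of the documented behavior, not Paddle's implementation; the helper names normalize_axis and argmax_output_shape are hypothetical, and the axis=None case is deliberately left out.

```python
def normalize_axis(axis, ndim):
    """Map an axis in [-ndim, ndim) to its non-negative equivalent (hypothetical helper)."""
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range [{-ndim}, {ndim})")
    # A negative axis works the same way as axis + ndim.
    return axis + ndim if axis < 0 else axis


def argmax_output_shape(shape, axis, keepdim=False):
    """Shape of the index tensor for an integer axis (hypothetical helper)."""
    axis = normalize_axis(axis, len(shape))
    if keepdim:
        # The reduced axis is kept with size one.
        return tuple(1 if i == axis else n for i, n in enumerate(shape))
    # The reduced axis is squeezed, so the result has one fewer dimension.
    return tuple(n for i, n in enumerate(shape) if i != axis)


print(normalize_axis(-1, 2))                           # 1
print(argmax_output_shape((3, 4), 0))                  # (4,)
print(argmax_output_shape((3, 4), 0, keepdim=True))    # (1, 4)
```

For a 3 x 4 input, axis=-1 is treated as axis=1, and keepdim=True along axis 0 gives a shape of (1, 4) instead of (4,).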

Returns

Tensor of indices. The data type is int32 if dtype is set to int32, otherwise int64.

Examples

import paddle

x = paddle.to_tensor([[5, 8, 9, 5],
                      [0, 0, 1, 7],
                      [6, 9, 2, 4]])
out1 = paddle.argmax(x)
print(out1) # 2
out2 = paddle.argmax(x, axis=0)
print(out2)
# [2, 2, 0, 1]
out3 = paddle.argmax(x, axis=-1)
print(out3)
# [2, 3, 1]
out4 = paddle.argmax(x, axis=0, keepdim=True)
print(out4)
# [[2, 2, 0, 1]]
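
The index values in the example can be cross-checked without Paddle. The following is a minimal plain-Python sketch for a 2-D nested list. The function names are hypothetical, and it assumes the first occurrence wins on ties, which is consistent with the 2 shown for the flattened case (the value 9 appears at flat positions 2 and 9).

```python
def first_argmax(values):
    """Index of the first maximum in a sequence (max returns the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def argmax_2d(rows, axis=None):
    """Reference argmax over a 2-D nested list (hypothetical helper, not Paddle's code)."""
    if axis is None:
        # Flatten row by row, then take the index of the maximum.
        return first_argmax([v for row in rows for v in row])
    if axis < 0:
        axis += 2  # a negative axis works the same way as axis + ndim
    if axis == 0:
        # Reduce over rows: one index per column.
        return [first_argmax(col) for col in zip(*rows)]
    # Reduce over columns: one index per row.
    return [first_argmax(row) for row in rows]


x = [[5, 8, 9, 5],
     [0, 0, 1, 7],
     [6, 9, 2, 4]]
print(argmax_2d(x))            # 2
print(argmax_2d(x, axis=0))    # [2, 2, 0, 1]
print(argmax_2d(x, axis=-1))   # [2, 3, 1]
```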