take_along_axis

paddle.take_along_axis(arr, indices, axis)

Takes values from the input tensor along the designated axis, at the positions given by the indices tensor.

Parameters
  • arr (Tensor) – The input Tensor. Supported data types are float32 and float64.

  • indices (Tensor) – Indices to take along each 1-D slice of arr. It must have the same number of dimensions as arr and must be broadcastable against arr. Supported data types are int and int64.

  • axis (int) – The axis along which the 1-D slices are taken.
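The indexing rule behind these parameters can be sketched in plain Python for the 2-D case. This is only an illustration, not Paddle's implementation: the helper name take_along_axis_2d is hypothetical, it works on nested lists rather than tensors, and it assumes indices already has the shape of the output (no broadcasting).

```python
def take_along_axis_2d(arr, indices, axis):
    """Illustrative 2-D reference (hypothetical helper, not part of Paddle).

    Assumes `indices` already has the output shape, i.e. no broadcasting.
    """
    out = []
    for i in range(len(indices)):
        row = []
        for j in range(len(indices[0])):
            k = indices[i][j]
            # axis=0: k picks the row, the column stays j.
            # axis=1: k picks the column, the row stays i.
            row.append(arr[k][j] if axis == 0 else arr[i][k])
        out.append(row)
    return out


arr = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(take_along_axis_2d(arr, [[2, 0, 1]], axis=0))      # [[7, 2, 6]]
print(take_along_axis_2d(arr, [[0], [2], [1]], axis=1))  # [[1], [6], [8]]
```

In both calls every dimension other than axis matches between arr and indices, which is the case that needs no broadcasting.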

Returns

The indexed elements, with the same dtype as arr.

Return type

Tensor

Examples

import paddle
import numpy as np

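x_np = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# indices must have the same number of dimensions as x (here 2)
index_np = np.array([[0]])
x = paddle.to_tensor(x_np)
index = paddle.to_tensor(index_np)
axis = 0
# the (1, 1) index is broadcast across the columns, so row 0 is taken in every column
result = paddle.take_along_axis(x, index, axis)
print(result)
# [[1, 2, 3]]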
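A common use of this kind of function is to apply indices produced by argsort or argmax back to the data. The sketch below uses NumPy's np.take_along_axis as a stand-in so that it runs without Paddle. It assumes, rather than verifies, that paddle.take_along_axis returns the same values for equivalent tensors.

```python
import numpy as np

x = np.array([[3, 1, 2], [9, 7, 8]])

# Indices that would sort each row; same shape as x.
order = np.argsort(x, axis=1)
sorted_x = np.take_along_axis(x, order, axis=1)
print(sorted_x)  # [[1 2 3]
                 #  [7 8 9]]

# argmax drops the reduced axis, so add it back to keep indices 2-D.
top = np.argmax(x, axis=1)[:, None]
row_max = np.take_along_axis(x, top, axis=1)
print(row_max)   # [[3]
                 #  [9]]
```

Note the `[:, None]` step: the indices need the same number of dimensions as the input, so a reduced index array has to be expanded before it is passed in.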