take_along_axis
基于输入索引矩阵,沿着指定 axis 从 arr 矩阵里选取 1d 切片。索引矩阵必须和 arr 矩阵有相同的维度,需要能够 broadcast 与 arr 矩阵对齐。
备注
别名支持: 参数名 input 可替代 arr 和 dim 可替代 axis,如 input=tensor_arr 等价于 arr=tensor_arr, dim=1 等价于 axis=1。
参数
arr (Tensor) - 输入的 Tensor 作为源矩阵,数据类型为:bfloat16、float16、float32、float64、int32、int64、uint8。 别名:
inputindices (Tensor) - 索引矩阵,包含沿轴提取 1d 切片的下标,必须和 arr 矩阵有相同的维度,需要能够 broadcast 与 arr 矩阵对齐,数据类型为:int32、int64。
axis (int) - 指定沿着哪个维度获取对应的值,数据类型为:int。 别名:
dimbroadcast (bool,可选) - 是否广播
index矩阵,默认为True。
返回
输出 Tensor,包含 indeces 矩阵选定的元素,与 arr 数据类型相同。
代码示例
>>> import paddle
>>> x = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7,8,9]])
>>> index = paddle.to_tensor([[0]])
>>> axis = 0
>>> result = paddle.take_along_axis(x, index, axis)
>>> print(result)
Tensor(shape=[1, 3], dtype=int64, place=Place(cpu), stop_gradient=True,
[[1, 2, 3]])