argwhere
返回输入 x
中非零元素的坐标。如果输入 x
有 n
维,共包含 z
个非零元素,返回结果是一个 shape
等于 [z x n]
的 Tensor
,第 i
行代表输入中第 i
个非零元素的坐标。
参数
input (Tensor)– 输入的 Tensor。
返回
Tensor(1-D Tensor),数据类型为 INT64 。
代码示例
>>> import paddle
>>> x = paddle.to_tensor([[1.0, 0.0, 0.0],
... [0.0, 2.0, 0.0],
... [0.0, 0.0, 3.0]])
>>> out = paddle.tensor.search.argwhere(x)
>>> print(out)
Tensor(shape=[3, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
[[0, 0],
[1, 1],
[2, 2]])