scatter¶
通过基于 updates 来更新选定索引 index 上的输入来获得输出。具体行为如下:
 >>> import paddle
 >>> #input:
 >>> x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32')
 >>> index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')
 >>> # shape of updates should be the same as x
 >>> # shape of updates with dim > 1 should be the same as input
 >>> updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32')
 >>> overwrite = False
 >>> # calculation:
 >>> if not overwrite:
 ...     for i in range(len(index)):
 ...         x[index[i]] = paddle.zeros([2])
 >>> for i in range(len(index)):
 ...     if (overwrite):
 ...         x[index[i]] = updates[i]
 ...     else:
 ...         x[index[i]] += updates[i]
 >>> # output:
 >>> out = paddle.to_tensor([[3, 3], [6, 6], [1, 1]])
 >>> print(out.shape)
 [3, 2]
        Notice: 因为 updates 的应用顺序是不确定的,因此,如果索引 index 包含重复项,则输出将具有不确定性。
参数¶
x (Tensor) - ndim> = 1 的输入 N-D Tensor。数据类型可以是 float32,float64。
index (Tensor)- 一维或者零维 Tensor。数据类型可以是 int32,int64。
index的长度不能超过updates的长度,并且index中的值不能超过输入的长度。updates (Tensor)- 根据
index使用update参数更新输入x。当index为一维 tensor 时,updates形状应与输入x相同,并且 dim>1 的 dim 值应与输入x相同。当index为零维 tensor 时,updates应该是一个 (N-1)-D 的 Tensor,并且updates的第 i 个维度应该与x的 i+1 个维度相同。overwrite (bool,可选)- 指定索引
index相同时,更新输出的方式。如果为 True,则使用覆盖模式更新相同索引的输出,如果为 False,则使用累加模式更新相同索引的输出。默认值为 True。name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。
返回¶
Tensor,与 x 有相同形状和数据类型。
代码示例¶
 >>> import paddle
 >>> x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32')
 >>> index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')
 >>> updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32')
 >>> output1 = paddle.scatter(x, index, updates, overwrite=False)
 >>> print(output1)
 Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
 [[3., 3.],
  [6., 6.],
  [1., 1.]])
 >>> output2 = paddle.scatter(x, index, updates, overwrite=True)
 >>> # CPU device:
 >>> # [[3., 3.],
 >>> #  [4., 4.],
 >>> #  [1., 1.]]
 >>> # GPU device maybe have two results because of the repeated numbers in index
 >>> # result 1:
 >>> # [[3., 3.],
 >>> #  [4., 4.],
 >>> #  [1., 1.]]
 >>> # result 2:
 >>> # [[3., 3.],
 >>> #  [2., 2.],
 >>> #  [1., 1.]]