scatter

paddle.fluid.layers.scatter(input, index, updates, name=None, overwrite=True)

Scatter Layer

The output is obtained by copying input and then updating the rows (slices along the first dimension) selected by index with the corresponding rows of updates.
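A minimal pure-Python sketch of this row-update behavior in overwrite mode. It assumes a 2-D input represented as a list of rows and an index with no repeated entries. `scatter_rows` is a hypothetical helper for illustration, not part of Paddle:

```python
# Hypothetical pure-Python model of scatter in overwrite mode.
# A 2-D "tensor" is represented as a list of rows (lists of floats),
# and index is assumed to contain no duplicates.

def scatter_rows(input_rows, index, updates):
    """Return a copy of input_rows where row index[i] is replaced by updates[i]."""
    out = [list(row) for row in input_rows]  # leave the caller's data untouched
    for i, target in enumerate(index):
        out[target] = list(updates[i])  # overwrite the selected row
    return out


x = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
print(scatter_rows(x, [2, 0], [[9.0, 9.0], [7.0, 7.0]]))
# [[7.0, 7.0], [2.0, 2.0], [9.0, 9.0]]
```

Rows not named in index (row 1 here) keep their input values.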

Parameters
  • input (Variable) – The input N-D Tensor with rank>=1. Data type can be float32.

  • index (Variable) – The index 1-D Tensor. Data type can be int32 or int64. The length of index cannot exceed the length of updates, and each value in index must be a valid position in the first dimension of input (that is, less than the length of input).

  • updates (Variable) – The values used to update input: row i of updates is applied to row index[i] of the output. updates must have the same rank as input, and every dimension except the first must match the corresponding dimension of input.

  • name (str, optional) – The default value is None. Normally there is no need for the user to set this property. For more information, please refer to Name.

  • overwrite (bool) – How the output is updated when index contains duplicate entries. If True (overwrite mode), the row at a repeated index is overwritten by its updates. If False (accumulate mode), the rows selected by index are reset and the updates for each index are summed, so the original input values of those rows do not contribute. Default value is True.
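The two update modes differ only when index repeats a row. The following pure-Python sketch models both for a 2-D input, together with the range and shape constraints listed above. `scatter_reference` is a hypothetical helper, not a Paddle API. Its overwrite branch applies updates in index order so the last duplicate wins; that ordering is an assumption of this sequential model, since the parameter description does not say which update is kept.

```python
# Hypothetical pure-Python reference for scatter's two update modes (2-D only).
# Assumption: in overwrite mode, updates are applied in index order, so the
# last occurrence of a repeated index wins in this sequential model.

def scatter_reference(input_rows, index, updates, overwrite=True):
    width = len(input_rows[0])
    # Constraints taken from the parameter descriptions.
    if len(index) > len(updates):
        raise ValueError("index is longer than updates")
    for i, target in enumerate(index):
        if not 0 <= target < len(input_rows):
            raise IndexError(f"index value {target} is out of range")
        if len(updates[i]) != width:
            raise ValueError("updates must match input in every dimension except the first")

    out = [list(row) for row in input_rows]  # copy; input is not modified
    if not overwrite:
        # Accumulate mode: rows named in index are reset to zero first,
        # then every update row that targets them is added.
        for target in index:
            out[target] = [0.0] * width
    for i, target in enumerate(index):
        if overwrite:
            out[target] = list(updates[i])
        else:
            out[target] = [a + b for a, b in zip(out[target], updates[i])]
    return out


x = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
idx = [2, 1, 0, 1]  # row 1 is targeted twice
upd = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]

print(scatter_reference(x, idx, upd, overwrite=False))
# [[3.0, 3.0], [6.0, 6.0], [1.0, 1.0]]
print(scatter_reference(x, idx, upd, overwrite=True))
# [[3.0, 3.0], [4.0, 4.0], [1.0, 1.0]]
```

The accumulate result agrees with the output printed in the Paddle example on this page; the overwrite result depends on the last-write-wins assumption stated above.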

Returns

The output is a Tensor with the same shape as input.

Return type

Variable(Tensor|LoDTensor)

Examples

import numpy as np
import paddle.fluid as fluid

# Declare the three inputs with fixed shapes (no implicit batch dimension).
input = fluid.layers.data(name='data', shape=[3, 2], dtype='float32', append_batch_size=False)
index = fluid.layers.data(name='index', shape=[4], dtype='int64', append_batch_size=False)
updates = fluid.layers.data(name='update', shape=[4, 2], dtype='float32', append_batch_size=False)

# Accumulate mode: updates that target the same row are summed.
output = fluid.layers.scatter(input, index, updates, overwrite=False)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

in_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float32)
index_data = np.array([2, 1, 0, 1]).astype(np.int64)  # row 1 is targeted twice
update_data = np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype(np.float32)

res = exe.run(fluid.default_main_program(),
              feed={'data': in_data, 'index': index_data, 'update': update_data},
              fetch_list=[output])
print(res)
# [array([[3., 3.],
#        [6., 6.],
#        [1., 1.]], dtype=float32)]
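The printed result can be checked by hand. With index = [2, 1, 0, 1] and overwrite=False, each output row is the sum of the update rows that point to it. Row 1 is named twice, so it receives two updates. Every row of input appears in index, so no original input value survives:

```latex
\begin{aligned}
\text{out}[0] &= \text{updates}[2] = [3, 3] \\
\text{out}[1] &= \text{updates}[1] + \text{updates}[3] = [2, 2] + [4, 4] = [6, 6] \\
\text{out}[2] &= \text{updates}[0] = [1, 1]
\end{aligned}
```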