global_scatter

paddle.distributed.utils.global_scatter(x, local_count, global_count, group=None, use_calc_stream=True) [source]

The global_scatter operator distributes the data of x to n_expert * world_size experts according to local_count, and then receives data according to global_count. Here, an expert refers to a user-defined expert network, n_expert is the number of expert networks on each card, and world_size is the number of cards running the network.

As shown in the figure below, world_size is 2, n_expert is 2, the batch size of x is 4, and local_count is [2, 0, 2, 0]. The global_count of rank 0 is [2, 0, , ] and that of rank 1 is [2, 0, , ] (due to limited space, only the data computed on rank 0 is shown here). In the global_scatter operator, local_count[i] means sending local_count[i] items of data to the (i % n_expert)th expert of the (i // n_expert)th card, and global_count[i] means receiving global_count[i] items of data from the (i // n_expert)th card for the (i % n_expert)th expert of this card; a short sketch after the figure makes this index mapping concrete. The rank in the figure denotes the rank of the current card among all cards.

The process of global_scatter sending data is as follows:

local_count[0] represents taking out 2 batches from x and sending 2 batches to the 0th expert of the 0th card;

local_count[1] represents taking out 0 batches from x and sending 0 batches to the 1st expert of the 0th card;

local_count[2] represents taking out 2 batches from x and sending 2 batches to the 0th expert of the 1st card;

local_count[3] represents taking out 0 batches from x and sending 0 batches to the 1st expert of the 1st card;

Therefore, the global_count[0] of the 0th card is equal to 2, which means that 2 batches of data are received from the 0th card by the 0th expert;

the global_count[1] of the 0th card is equal to 0, which means that 0 batches of data are received from the 0th card by the 1st expert;

the global_count[0] of the 1st card is equal to 2, which means that 2 batches of data are received from the 0th card by the 0th expert;

the global_count[1] of the 1st card is equal to 0, which means that 0 batches of data are received from the 0th card by the 1st expert.

[Figure: global_scatter_gather]
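
The routing rule above can be checked with a minimal sketch in plain Python (no distributed setup; the counts are copied from the rank 0 walkthrough above, so the printed destinations are only illustrative):

n_expert = 2
world_size = 2
local_count = [2, 0, 2, 0]  # rank 0 in the figure
assert len(local_count) == n_expert * world_size  # one entry per (card, expert) pair

for i, count in enumerate(local_count):
    card = i // n_expert    # destination card (rank)
    expert = i % n_expert   # destination expert on that card
    print(f"send {count} batch(es) to expert {expert} of card {card}")
# send 2 batch(es) to expert 0 of card 0
# send 0 batch(es) to expert 1 of card 0
# send 2 batch(es) to expert 0 of card 1
# send 0 batch(es) to expert 1 of card 1
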
Parameters
  • x (Tensor) – Tensor. The tensor data type should be float16, float32, float64, int32 or int64.

  • local_count (Tensor) – Tensor with n_expert * world_size elements, indicating how many items of data need to be sent. The tensor data type should be int64.

  • global_count (Tensor) – Tensor with n_expert * world_size elements, indicating how many items of data need to be received (see the sketch after this list). The tensor data type should be int64.

  • group (Group, optional) – The group instance returned by new_group, or None for the global default group. Default: None.

  • use_calc_stream (bool, optional) – Whether to use the calculation stream (True) or the communication stream (False). Default: True.
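
As a sanity check on the two count tensors (a sketch based on the semantics described above, not a separate API guarantee): local_count is assumed to account for every row of x sent from this card, and the returned tensor is assumed to have sum(global_count) rows. Using the rank 0 values from the example below:

import numpy as np

n_expert, world_size = 2, 2
local_count = np.array([2, 1, 1, 1], dtype=np.int64)   # rank 0 counts from the example below
global_count = np.array([2, 1, 1, 1], dtype=np.int64)
x_rows = 5                                             # rows of local_input_buf on rank 0

assert local_count.shape[0] == n_expert * world_size   # one entry per (card, expert) pair
assert local_count.sum() == x_rows                     # assumption: every row of x is routed
print(int(global_count.sum()))                         # expected rows of the output: 5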

Returns

The data received from all experts.

Return type

out (Tensor)

Examples

# required: distributed
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
init_parallel_env()
n_expert = 2
world_size = 2
d_model = 2
in_feat = d_model
local_input_buf = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], dtype=np.float32)
# each count tensor has n_expert * world_size = 4 entries
if paddle.distributed.ParallelEnv().local_rank == 0:
    local_count = np.array([2, 1, 1, 1])
    global_count = np.array([2, 1, 1, 1])
else:
    local_count = np.array([1, 1, 2, 1])
    global_count = np.array([1, 1, 2, 1])
local_input_buf = paddle.to_tensor(local_input_buf, dtype="float32", stop_gradient=False)
local_count = paddle.to_tensor(local_count, dtype="int64")
global_count = paddle.to_tensor(global_count, dtype="int64")
# scatter rows of local_input_buf according to local_count and
# receive rows according to global_count
a = paddle.distributed.utils.global_scatter(local_input_buf, local_count, global_count)
a.stop_gradient = False
print(a)
# out for rank 0: [[1, 2], [3, 4], [1, 2], [5, 6], [3, 4]]
# out for rank 1: [[7, 8], [5, 6], [7, 8], [9, 10], [9, 10]]
# backward test
c = a * a
c.backward()
print("local_input_buf.grad: ", local_input_buf.grad)
# out for rank 0: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
# out for rank 1: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]