global_scatter
- paddle.distributed.utils.global_scatter(x, local_count, global_count, group=None, use_calc_stream=True)
The global_scatter operator distributes the data in x to n_expert * world_size experts according to local_count, and then receives data according to global_count. Here an expert refers to a user-defined expert network, n_expert refers to the number of expert networks owned by each card, and world_size refers to the number of graphics cards running the network.
As shown below, the world size is 2, n_expert is 2, the batch size of x is 4 and local_count is [2, 0, 2, 0]. The global_count of rank 0 is [2, 0, , ] and that of rank 1 is [2, 0, , ] (due to limited space, only the data computed from rank 0 is shown here). In the global_scatter operator, local_count[i] means sending local_count[i] batches of data to the (i % n_expert)th expert of the (i // n_expert)th card, and global_count[i] means receiving global_count[i] batches of data from the (i // n_expert)th card into the (i % n_expert)th expert of this card (a standalone sketch of this index arithmetic follows the walkthrough below). The rank in the figure represents the rank of the current card among all cards.
The process by which global_scatter sends data is as follows:
local_count[0] means taking out 2 batches from x and sending them to the 0th expert of the 0th card;
local_count[1] means taking out 0 batches from x and sending 0 batches to the 1st expert of the 0th card;
local_count[2] means taking out 2 batches from x and sending them to the 0th expert of the 1st card;
local_count[3] means taking out 0 batches from x and sending 0 batches to the 1st expert of the 1st card;
Therefore, global_count[0] of the 0th card equals 2, which means the 0th expert of that card receives 2 batches of data from the 0th card;
global_count[1] of the 0th card equals 0, which means the 1st expert of that card receives 0 batches of data from the 0th card;
global_count[0] of the 1st card equals 2, which means the 0th expert of that card receives 2 batches of data from the 0th card;
global_count[1] of the 1st card equals 0, which means the 1st expert of that card receives 0 batches of data from the 0th card.
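To make the index arithmetic above concrete, here is a minimal single-process sketch (plain Python, no distributed runtime; it reuses the per-card counts from the Examples section below) that derives each card's global_count from every card's local_count via the (i // n_expert, i % n_expert) mapping:

# Illustrative sketch only: derive global_count on each card from the
# local_count of all cards, using the mapping described above.
# On card src, local_count[i] rows go to expert (i % n_expert) of card (i // n_expert).
# On card dst, global_count[j] counts rows arriving from card (j // n_expert)
# for expert (j % n_expert).
n_expert, world_size = 2, 2
local_counts = {0: [2, 1, 1, 1], 1: [1, 1, 2, 1]}  # same counts as the Examples below

global_counts = {r: [0] * (n_expert * world_size) for r in range(world_size)}
for src, counts in local_counts.items():
    for i, cnt in enumerate(counts):
        dst, expert = i // n_expert, i % n_expert
        global_counts[dst][src * n_expert + expert] += cnt

print(global_counts)  # {0: [2, 1, 1, 1], 1: [1, 1, 2, 1]}

The derived values match the global_count tensors used on rank 0 and rank 1 in the Examples section.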
- Parameters
x (Tensor) – The input tensor. Its data type should be float16, float32, float64, int32 or int64.
local_count (Tensor) – Tensor with n_expert * world_size elements that indicates how much data needs to be sent. The tensor data type should be int64.
global_count (Tensor) – Tensor with n_expert * world_size elements that indicates how much data needs to be received. The tensor data type should be int64.
group (Group, optional) – The group instance returned by new_group, or None for the global default group. Default: None.
use_calc_stream (bool, optional) – Whether to use the calculation stream (True) or the communication stream (False). Default: True.
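If the collective should run on an explicitly created group rather than the global default one, a group handle can be passed in. The following is only a hedged sketch under assumed settings (a two-card job with illustrative counts), not an additional example from this page:

# Illustrative sketch: pass an explicit group created with new_group
# instead of relying on the global default group.
import numpy as np
import paddle
from paddle.distributed import init_parallel_env

init_parallel_env()
group = paddle.distributed.new_group(ranks=[0, 1])  # assumed two-card job
x = paddle.to_tensor(np.zeros((2, 2), dtype=np.float32))
# Each card sends one row to expert 0 of card 0 and one row to expert 0 of card 1,
# so each card also receives one row from every card (illustrative counts).
local_count = paddle.to_tensor(np.array([1, 0, 1, 0]), dtype="int64")
global_count = paddle.to_tensor(np.array([1, 0, 1, 0]), dtype="int64")
out = paddle.distributed.utils.global_scatter(
    x, local_count, global_count, group=group, use_calc_stream=True)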
- Returns
The data received from all experts.
- Return type
out (Tensor)
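Since the result stacks everything that was received, its leading dimension equals the sum of global_count, just as the number of rows taken from x equals the sum of local_count; this is inferred from the semantics above rather than stated separately on this page. A tiny self-contained check using rank 0's counts from the Examples below:

# Count bookkeeping inferred from the semantics above (not a separate API guarantee):
# rows taken from x == sum(local_count); rows returned == sum(global_count).
local_count_rank0 = [2, 1, 1, 1]
global_count_rank0 = [2, 1, 1, 1]
assert sum(local_count_rank0) == 5   # all 5 rows of x on rank 0 are sent out
assert sum(global_count_rank0) == 5  # rank 0's output in the example has 5 rows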
Examples
# required: distributed
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
init_parallel_env()
n_expert = 2
world_size = 2
d_model = 2
in_feat = d_model
local_input_buf = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], dtype=np.float32)
if paddle.distributed.ParallelEnv().local_rank == 0:
    local_count = np.array([2, 1, 1, 1])
    global_count = np.array([2, 1, 1, 1])
else:
    local_count = np.array([1, 1, 2, 1])
    global_count = np.array([1, 1, 2, 1])
local_input_buf = paddle.to_tensor(local_input_buf, dtype="float32", stop_gradient=False)
local_count = paddle.to_tensor(local_count, dtype="int64")
global_count = paddle.to_tensor(global_count, dtype="int64")
a = paddle.distributed.utils.global_scatter(local_input_buf, local_count, global_count)
a.stop_gradient = False
print(a)
# out for rank 0: [[1, 2], [3, 4], [1, 2], [5, 6], [3, 4]]
# out for rank 1: [[7, 8], [5, 6], [7, 8], [9, 10], [9, 10]]
# backward test
c = a * a
c.backward()
print("local_input_buf.grad: ", local_input_buf.grad)
# out for rank 0: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
# out for rank 1: [[2, 4], [6, 8], [10, 12], [14, 16], [18, 20]]
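The printed results can also be reproduced without launching two cards. The following single-process sketch (numpy only; variable names are illustrative) replays the routing with the same counts as the example; the expert-major, then source-card ordering of the received rows is inferred from the outputs printed above rather than documented separately:

import numpy as np

n_expert, world_size = 2, 2
# In the example both cards use the same local_input_buf.
x = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], dtype=np.float32)
local_count = {0: [2, 1, 1, 1], 1: [1, 1, 2, 1]}  # per-card counts from the example

# Split each card's x into chunks addressed to (destination card, expert).
chunks = {}
for src in range(world_size):
    offset = 0
    for i, cnt in enumerate(local_count[src]):
        dst, expert = i // n_expert, i % n_expert
        chunks[(dst, expert, src)] = x[offset:offset + cnt]
        offset += cnt

# Assemble each card's received buffer: expert-major, then source card.
for dst in range(world_size):
    rows = [chunks[(dst, expert, src)]
            for expert in range(n_expert)
            for src in range(world_size)]
    print(dst, np.concatenate(rows).tolist())
# 0 [[1.0, 2.0], [3.0, 4.0], [1.0, 2.0], [5.0, 6.0], [3.0, 4.0]]
# 1 [[7.0, 8.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [9.0, 10.0]]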