split( x, size, operation, axis=0, num_partitions=1, gather_out=True, weight_attr=None, bias_attr=None, name=None )
Splits the weight of the specified operation across multiple devices and performs the computation in parallel.
Currently, the following three cases are supported.
- Case 1: Parallel Embedding
The weight of the embedding operation is an NxM matrix with N rows and M columns. With parallel embedding, the weight is split into num_partitions partitions, each of which is a matrix with (N/num_partitions + 1) rows and M columns, where the last row serves as the padding index.
Suppose we split the NxM weight into two partitions, on device_0 and device_1 respectively. Then, on each device, the final weight has (N/2 + 1) rows with indices ranging from 0 to N/2. On device_0, all input values within [0, N/2 - 1] are kept unchanged, and all other values are changed to N/2, which is the padding index and is mapped to all zeros after embedding. Likewise, on device_1, each input value V within [N/2, N - 1] is changed to (V - N/2), and all other values are changed to N/2 and mapped to all zeros after embedding. Finally, the results on the two devices are sum-reduced.
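The index remapping described above can be checked with a small pure-Python simulation. This is a sketch, not Paddle's implementation: the helper names (`shard_tables`, `local_lookup`, `parallel_embedding`) are hypothetical, each "device" is just a Python list, the padding row is assumed to be all zeros, and N is assumed to be divisible by num_partitions.

```python
# Pure-Python simulation of parallel embedding (Case 1); no Paddle calls.
# Hypothetical helpers; assumes N is divisible by num_partitions.

def shard_tables(table, num_partitions):
    """Split an N x M table into row blocks, each with an all-zero padding row appended."""
    n, m = len(table), len(table[0])
    assert n % num_partitions == 0, "N must be divisible by num_partitions"
    rows = n // num_partitions
    shards = []
    for k in range(num_partitions):
        block = [list(r) for r in table[k * rows:(k + 1) * rows]]
        block.append([0.0] * m)  # local index `rows` (== N/num_partitions) is the padding idx
        shards.append(block)
    return shards


def local_lookup(shard, ids, k, rows):
    """Look up ids on shard k: in-range ids are shifted by k*rows, all others hit the padding row."""
    start = k * rows
    return [shard[v - start] if start <= v < start + rows else shard[rows] for v in ids]


def parallel_embedding(table, ids, num_partitions):
    """Run the lookup on every shard, then sum-reduce the partial results element-wise."""
    rows = len(table) // num_partitions
    shards = shard_tables(table, num_partitions)
    partials = [local_lookup(s, ids, k, rows) for k, s in enumerate(shards)]
    m = len(table[0])
    return [[sum(p[i][j] for p in partials) for j in range(m)] for i in range(len(ids))]


table = [[float(10 * r + c) for c in range(3)] for r in range(8)]  # N = 8, M = 3
ids = [0, 3, 4, 7, 5]
print(parallel_embedding(table, ids, 2) == [table[v] for v in ids])  # True
```

Because every out-of-range id lands on a zero row, exactly one partition contributes a non-zero row per id, so the sum-reduction reproduces an ordinary lookup on the full table.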
- Case 2: Row Parallel Linear
The weight of the linear operation is an NxM matrix with N rows and M columns. With row parallel linear, the weight is split along its rows into num_partitions partitions, each of which is a matrix with N/num_partitions rows and M columns.
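Row parallel linear relies on the identity X·W = Σ_k X_k·W_k, where W_k is the k-th row block of W and X_k is the matching block of input columns. The sketch below assumes the input is split along its last dimension to match the weight's row blocks and that the per-device partial products are sum-reduced; `matmul` and `row_parallel_linear` are hypothetical helpers working on plain lists, not Paddle APIs.

```python
# Pure-Python check of row parallel linear (Case 2); no Paddle calls.
# Hypothetical helpers; assumes N is divisible by num_partitions and that the
# input is split along its last dimension to match the weight's row blocks.

def matmul(a, b):
    """Plain matrix product of a (B x K) and b (K x M)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def row_parallel_linear(x, w, num_partitions):
    n, m = len(w), len(w[0])
    assert n % num_partitions == 0, "N must be divisible by num_partitions"
    rows = n // num_partitions
    partials = []
    for k in range(num_partitions):
        w_k = w[k * rows:(k + 1) * rows]               # (N/num_partitions) x M weight block
        x_k = [r[k * rows:(k + 1) * rows] for r in x]  # matching slice of the input features
        partials.append(matmul(x_k, w_k))              # B x M partial result on device k
    # Sum-reduce the partial results, as an all-reduce would across devices.
    return [[sum(p[i][j] for p in partials) for j in range(m)] for i in range(len(x))]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]                  # B = 2, N = 4
w = [[1, 0, 2], [0, 1, 1], [3, 1, 0], [2, 2, 1]]  # N = 4, M = 3
print(row_parallel_linear(x, w, 2) == matmul(x, w))  # True
```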
- Case 3: Column Parallel Linear
The weight of the linear operation is an NxM matrix with N rows and M columns. With column parallel linear, the weight is split along its columns into num_partitions partitions, each of which is a matrix with N rows and M/num_partitions columns.
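For column parallel linear, each partition multiplies the full input X by its own column block W_k (N x M/num_partitions), and gathering the output amounts to concatenating the slices along the last dimension. A pure-Python sketch with hypothetical helper names, assuming M is divisible by num_partitions; in this sketch `gather_out=False` simply returns the per-device slices without concatenating them.

```python
# Pure-Python check of column parallel linear (Case 3); no Paddle calls.
# Hypothetical helpers; assumes M is divisible by num_partitions.

def matmul(a, b):
    """Plain matrix product of a (B x K) and b (K x M)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def column_parallel_linear(x, w, num_partitions, gather_out=True):
    m = len(w[0])
    assert m % num_partitions == 0, "M must be divisible by num_partitions"
    cols = m // num_partitions
    outs = []
    for k in range(num_partitions):
        w_k = [r[k * cols:(k + 1) * cols] for r in w]  # N x (M/num_partitions) weight block
        outs.append(matmul(x, w_k))                    # each device uses the full input
    if not gather_out:
        return outs  # one B x (M/num_partitions) slice per device
    # Gather: concatenate the per-device slices along the last dimension.
    return [sum((o[i] for o in outs), []) for i in range(len(x))]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]                              # B = 2, N = 4
w = [[1, 0, 2, 1], [0, 1, 1, 0], [3, 1, 0, 2], [2, 2, 1, 1]]  # N = 4, M = 4
print(column_parallel_linear(x, w, 2) == matmul(x, w))  # True
```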
x (Tensor) – Input tensor. Its data type should be float16, float32, float64, int32 or int64.
size (list|tuple) – A list or tuple with two elements indicating the shape of the weight.
operation (str) – The name of the operation. The supported operations are ‘linear’ and ‘embedding’.
axis (int, Optional) – The axis along which to split the weight: 0 splits its rows (Cases 1 and 2) and 1 splits its columns (Case 3). Default: 0.
num_partitions (int, Optional) – The number of partitions into which the weight is split. Default: 1.
gather_out (bool, Optional) – Whether to gather the output after computation. By default, the outputs of all partitions are gathered after computation. Default: True.
weight_attr (ParamAttr, Optional) – The parameter attribute for the learnable weight (Parameter) of the specified operation. Default: None.
bias_attr (ParamAttr, Optional) – The parameter attribute for the bias of the specified operation. Default: None.
name (str, Optional) – Normally there is no need for the user to set this property. For more information, please refer to Name. Default: None.
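The three cases determine the shape of the weight that each partition holds. The hypothetical helper below (not a Paddle API) maps `size`, `operation`, `axis` and `num_partitions` to that per-partition shape. It assumes axis 0 for Cases 1 and 2, axis 1 for Case 3, and exact divisibility, and it rejects other combinations purely as a choice of this sketch.

```python
# Hypothetical helper (not a Paddle API): per-partition weight shape for each case.
# Assumes the embedding weight is split along axis 0 and that sizes divide evenly.

def partition_weight_shape(size, operation, axis=0, num_partitions=1):
    n, m = size
    if operation == "embedding":
        if axis != 0:
            raise ValueError("this sketch only models the row split of Case 1")
        if n % num_partitions:
            raise ValueError("N must be divisible by num_partitions")
        return (n // num_partitions + 1, m)  # Case 1: one extra padding row
    if operation == "linear":
        if axis == 0:
            if n % num_partitions:
                raise ValueError("N must be divisible by num_partitions")
            return (n // num_partitions, m)  # Case 2: row parallel
        if axis == 1:
            if m % num_partitions:
                raise ValueError("M must be divisible by num_partitions")
            return (n, m // num_partitions)  # Case 3: column parallel
        raise ValueError("axis must be 0 or 1")
    raise ValueError("operation must be 'linear' or 'embedding'")


print(partition_weight_shape((8, 8), "embedding", num_partitions=2))  # (5, 8)
```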
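```python
# Launch with 2 GPUs, e.g.: python -m paddle.distributed.launch --gpus=0,1 demo.py
import paddle
from paddle.distributed import init_parallel_env

# Bind each process to its own GPU, then initialize the parallel environment.
paddle.set_device('gpu:%d' % paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
# Input ids in [0, 8) looked up in an 8x8 embedding table split into 2 partitions.
data = paddle.randint(0, 8, shape=[10, 4])
emb_out = paddle.distributed.split(
    data, (8, 8), operation="embedding", num_partitions=2)
```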