split

paddle.distributed. split ( x, size, operation, axis=0, num_partitions=1, gather_out=True, weight_attr=None, bias_attr=None, name=None ) [源代码]

切分指定操作的参数到多个设备,并且并行计算得到结果。

当前,支持一下三种情形。

情形1:并行Embedding

Embedding操作的参数是个NxM的矩阵,行数为N,列数为M。并行Embedding情形下,参数切分到num_partitions个设备,每个设备上的参数是 (N/num_partitions + 1)行、M列的矩阵。其中,最后一行作为padding idx。

假设将NxM的参数矩阵切分到两个设备device_0和device_1。那么每个设备上的参数矩阵为(N/2+1)行和M列。device_0上,输入x中的值如果介于[0, N/2-1],则其值保持不变;否则值变更为N/2,经过embedding映射为全0值。类似地,device_1上,输入x中的值V如果介于[N/2, N-1]之间,那么这些值将变更为(V-N/2);否则,值变更为N/2,经过embedding映射为全0值。最后,使用all_reduce_sum操作汇聚各个卡上的结果。

单卡Embedding情况如下图所示

single_embedding

并行Embedding情况如下图所示

split_embedding
情形2:行并行Linear

Linear操作是将输入变量X(N*N)与权重矩阵W(N*M)进行矩阵相乘。行并行Linear情形下,参数切分到num_partitions个设备,每个设备上的参数是N/num_partitions行、M列的矩阵。

单卡Linear情况如下图所示,输入变量用X表示,权重矩阵用W表示,输出变量用O表示,单卡Linear就是一个简单的矩阵乘操作,O = X * W。

single_linear

行并行Linear情况如下图所示,顾名思义,行并行是按照权重矩阵W的行切分权重矩阵为 [[W_row1], [W_row2]],对应的输入X也按照列切成了两份[X_col1, X_col2],分别与各自对应的权重矩阵相乘, 最后通过AllReduce规约每张卡的输出得到最终输出。

split_row
情形3:列并行Linear

Linear操作是将输入变量X(N*N)与权重矩阵W(N*M)进行矩阵相乘。列并行Linear情形下,参数切分到num_partitions个设备,每个设备上的参数是N行、M/num_partitions列的矩阵。

单卡并行Linear可以看上面对应的图,列并行Linear情况如下图所示。列并行是按照权重矩阵W的列切分权重矩阵为[W_col1, W_col2], X分别与切分出来的矩阵相乘,最后通过AllGather拼接每张卡的输出得到最终输出。

split_col

我们观察到,可以把上述按列切分矩阵乘法和按行切分矩阵乘法串联起来,从而省略掉一次AllGather通信操作,如下图所示。同时,我们注意到Transformer的Attention和MLP组件中各种两次矩阵乘法操作。因此,我们可以按照这种串联方式分别把Attention和MLP组件中的两次矩阵乘法串联起来,从而进一步优化性能。

split_col_row

参数

  • x (Tensor) - 输入Tensor。Tensor的数据类型为:float16、float32、float64、int32、int64。

  • size (list|tuple) - 指定参数形状的列表或元组,包含2个元素。

  • operation (str) - 指定操作名称,当前支持的操作名称为'embedding'或'linear'。

  • axis (int,可选) - 指定沿哪个维度切分参数。默认值:0。

  • num_partitions (int,可选) - 指定参数的划分数。默认值:1。

  • gather_out (bool,可选) - 是否聚合所有设备的计算结果。默认地,聚合所有设备的计算结果。默认值:True。

  • weight_attr (ParamAttr,可选) - 指定参数的属性。默认值:None。

  • bias_attr (ParamAttr,可选) - 指定偏置的属性。默认值:None。

  • name (str,可选) - 默认值为None,通常用户不需要设置该属性。更多信息请参考 Name

返回

Tensor

代码示例

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
fleet.init(is_collective=True)
data = paddle.randint(0, 8, shape=[10,4])
emb_out = paddle.distributed.split(
    data,
    (8, 8),
    operation="embedding",
    num_partitions=2)