shard_layer

paddle.distributed. shard_layer ( layer, process_mesh, shard_fn=None, input_fn=None, output_fn=None ) [源代码]

根据参数 shard_fn 将传入的 paddle.nn.Layer 所有的参数转换为带有分布式切分信息的 Tensor。同时也支持指定 input_fnoutput_fn 用于控制输入和输出 Tensor 的转换。(具体指的是,将输入转换为带有分布式切分信息的 Tensor,将输出转回不带分布式切分信息的 Tensor。)

shard_fn 的函数签名为:def shard_fn(layer_name, layer, process_mesh) -> None。

input_fn 的函数签名为:def input_fn(inputs, process_mesh) -> list(paddle.Tensor),一般地,input_fn 返回值的类型为带有分布式切分信息的 Tensor

output_fn 的函数签名为:def output_fn(outputs, process_mesh) -> list(paddle.Tensor),一般地,output_fn 返回值的类型为不带分布式切分信息的 Tensor

参数

  • layer (paddle.nn.Layer) - 需要被切分的 Layer 对象。

  • process_mesh (paddle.distributed.ProcessMesh) - 执行当前 LayerProcessMesh 信息。

  • shard_fn (Callable) - 用于切分当前 Layer 参数的函数。如果没有指定,默认地我们将在当前 ProcessMesh 上复制所有的参数。

  • input_fn (Callable) - 指定如何切分 Layer 的输入。input_fn 函数将被注册为 Layer 的一个 forward pre-hook。默认我们将不会切分 Layer 的输入。

  • output_fn (Callable) - 指定如何切分 Layer 的输出,或者将 Layer 的输出转回不带分布式切分信息的 Tensoroutput_fn 函数将被注册为 Layer 的一个 forward post-hook。默认我们将不会切分或者转换 Layer 的输出。

返回

Layer:一个参数全部为带有分布式切分信息 TensorLayer 对象。

代码示例

>>> import paddle
>>> import paddle.distributed as dist

>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))

>>> def shard_fn(layer_name, layer, process_mesh):
...     if layer_name == 'fc1':
...         layer.weight = dist.shard_tensor(layer.weight, process_mesh, [dist.Shard(0)])

>>> layer = MLP()
>>> layer = dist.shard_layer(layer, mesh, shard_fn)
>>> print(layer)

>>> # This case need to be excuted in multi-card environment
>>> # export CUDA_VISIBLE_DEVICES=0,1
>>> # python -m paddle.distributed.launch {test_case}.py