shard_layer

paddle.distributed.shard_layer(layer: paddle.nn.layer.layers.Layer, process_mesh: paddle.distributed.auto_parallel.process_mesh.ProcessMesh, shard_fn: Optional[Callable] = None, input_fn: Optional[Callable] = None, output_fn: Optional[Callable] = None) -> paddle.nn.layer.layers.Layer [source]

Converts all of the layer's parameters to DistTensor parameters according to the specified shard_fn. It can also control the conversion of the layer's inputs and outputs by specifying input_fn and output_fn (i.e. convert the inputs to paddle.Tensor backed by DistTensor, and convert the outputs back to paddle.Tensor backed by DenseTensor).

The shard_fn should have the following signature:

def shard_fn(layer_name, layer, process_mesh) -> None
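
A minimal shard_fn sketch (assuming paddle is imported and paddle.distributed is imported as dist; sharding every Linear weight along its first axis is just one possible policy, not a library default):

def shard_fn(layer_name, layer, process_mesh):
    # Illustrative policy: shard the weight of every Linear sublayer along axis 0.
    # Parameters left untouched keep the default replicated placement.
    if isinstance(layer, paddle.nn.Linear):
        layer.weight = dist.shard_tensor(layer.weight, process_mesh, [dist.Shard(0)])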

The input_fn should have the following signature:

def input_fn(inputs, process_mesh) -> list(paddle.Tensor)

In general, input_fn should return paddle.Tensor values backed by DistTensor.
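
For example, a minimal input_fn sketch (assuming paddle.distributed is imported as dist; the replicated placement here is an illustrative choice, not a library default) could be:

def input_fn(inputs, process_mesh):
    # Convert each dense input tensor to a DistTensor placed on the mesh.
    return [dist.shard_tensor(t, process_mesh, [dist.Replicate()]) for t in inputs]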

The output_fn should have the following signature:

def output_fn(outputs, process_mesh) -> list(paddle.Tensor)

In general, output_fn should return paddle.Tensor values backed by DenseTensor.
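
For example, a hypothetical output_fn sketch (assuming paddle.distributed is imported as dist) could reshard every output to a replicated placement; whether a further conversion to a plain DenseTensor is needed depends on the use case:

def output_fn(outputs, process_mesh):
    # Reshard each output so that every rank holds the full, replicated value.
    return [dist.reshard(t, process_mesh, [dist.Replicate()]) for t in outputs]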

Parameters
  • layer (paddle.nn.Layer) – The Layer object to be sharded.

  • process_mesh (paddle.distributed.ProcessMesh) – The ProcessMesh on which to place the input layer.

  • shard_fn (Callable) – The function used to shard the layer's parameters across the process_mesh. If not specified, all parameters of the layer are replicated across the process_mesh by default.

  • input_fn (Callable) – Specifies how the input of the layer is sharded. input_fn will be registered as a forward pre-hook of the Layer. By default the input is not sharded.

  • output_fn (Callable) – Specifies how the output of the layer is sharded, or how it is converted back to paddle.Tensor backed by DenseTensor. output_fn will be registered as a forward post-hook of the Layer. By default the output is neither sharded nor converted.

Returns

A layer whose parameters/buffers are all paddle.Tensor backed by DistTensor.

Return type

Layer

Examples

>>> import paddle
>>> import paddle.distributed as dist

>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))

>>> def shard_fn(layer_name, layer, process_mesh):
...     if layer_name == 'fc1':
...         layer.weight = dist.shard_tensor(layer.weight, process_mesh, [dist.Shard(0)])

>>> layer = MLP()
>>> layer = dist.shard_layer(layer, mesh, shard_fn)
>>> print(layer)

>>> # This case needs to be executed in a multi-card environment
>>> # export CUDA_VISIBLE_DEVICES=0,1
>>> # python -m paddle.distributed.launch {test_case}.py
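
The example above only uses shard_fn. As a further illustrative sketch (hypothetical hooks following the input_fn/output_fn sketches earlier on this page; the replicated placements are an assumption, not a default), the hooks can be wired in as well:

>>> def input_fn(inputs, process_mesh):
...     return [dist.shard_tensor(t, process_mesh, [dist.Replicate()]) for t in inputs]

>>> def output_fn(outputs, process_mesh):
...     return [dist.reshard(t, process_mesh, [dist.Replicate()]) for t in outputs]

>>> layer_with_hooks = dist.shard_layer(MLP(), mesh, shard_fn,
...                                     input_fn=input_fn, output_fn=output_fn)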