shard_optimizer

paddle.distributed.shard_optimizer(optimizer, shard_fn=None) [source]

Wrap the optimizer from a global view into a distributed view.

Note

The shard_fn should have the following signature:

def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator

Parameters
  • optimizer (paddle.optimizer.Optimizer) – The optimizer to be sharded.

  • shard_fn (Callable, optional) – The function used to shard the optimizer's accumulators. If not specified, each accumulator simply inherits the distributed attributes of its corresponding parameter.
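To illustrate the expected shard_fn contract, the sketch below dispatches on the accumulator name: scalar state (such as Adam's beta power accumulators) is typically replicated, while moment tensors follow the parameter's sharding. This is a minimal, self-contained sketch; a real shard_fn would return tensors placed via dist.shard_tensor on a ProcessMesh, and the accumulator names and placement strings here are illustrative assumptions, not part of the Paddle API.

```python
# Hypothetical shard_fn sketch. The placement strings ("Replicate",
# "Shard(0)") and the "_pow_acc" name suffix are illustrative only;
# a real implementation would return dist.shard_tensor(...) results.
def shard_fn(accumulator_name, param, accumulator):
    if accumulator_name.endswith("_pow_acc"):
        # Scalar state: keep a full copy on every rank.
        return ("Replicate", accumulator)
    # Moment tensors: shard along the same axis as the parameter.
    return ("Shard(0)", accumulator)

print(shard_fn("beta1_pow_acc", None, "acc")[0])   # Replicate
print(shard_fn("moment1", None, "acc")[0])         # Shard(0)
```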

Returns

An optimizer with a distributed view.

Examples

>>> import paddle
>>> import paddle.distributed as dist
>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> class MLP(paddle.nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc1 = paddle.nn.Linear(8, 8)
...         self.fc2 = paddle.nn.Linear(8, 8)
...
...     def forward(self, input):
...         return self.fc2(self.fc1(input))
>>> layer = MLP()
>>> batch = paddle.rand(shape=[8, 8])
>>> opt = paddle.optimizer.AdamW(parameters=layer.parameters())
>>> opt = dist.shard_optimizer(opt)
>>> for _ in range(5):
...     loss = layer(batch)
...     loss.backward()
...     opt.step()
...     opt.clear_grad()
>>> # This example needs to be run in a multi-card environment, e.g.:
>>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py