
paddle.distributed.to_static(layer: paddle.nn.layer.layers.Layer, loader=None, loss=None, optimizer=None, strategy=None)

Converts a layer whose parameters or inputs are distributed tensors (created with paddle.distributed.shard_tensor) into a static graph. to_static returns a DistModel instance that holds the static graph and can be used for distributed training, evaluation and prediction.
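The returned DistModel is switched between modes with train(), eval() and predict(). The comments in the example at the end of this page describe what calling it returns in each mode: in train mode it updates the parameters and returns the loss, in eval mode it returns the loss, and in predict mode it returns a dict of outputs keyed like "out0". The framework-free mock below imitates only that calling contract for a one-weight linear model with MSE loss and plain SGD. ToyDistModel and all of its internals are hypothetical and say nothing about how DistModel is implemented.

```python
from typing import Dict, List, Optional, Union


class ToyDistModel:
    """Hypothetical mock of the train/eval/predict calling contract."""

    def __init__(self, weight: float, lr: float) -> None:
        self.weight = weight  # single parameter of the model y = weight * x
        self.lr = lr  # SGD learning rate
        self.mode = "train"

    def train(self) -> None:
        self.mode = "train"

    def eval(self) -> None:
        self.mode = "eval"

    def predict(self) -> None:
        self.mode = "predict"

    def _forward(self, xs: List[float]) -> List[float]:
        return [self.weight * x for x in xs]

    @staticmethod
    def _mse(outs: List[float], labels: List[float]) -> float:
        return sum((o - y) ** 2 for o, y in zip(outs, labels)) / len(labels)

    def __call__(
        self, xs: List[float], labels: Optional[List[float]] = None
    ) -> Union[float, Dict[str, List[float]]]:
        outs = self._forward(xs)
        if self.mode == "predict":
            # Outputs keyed by position, mirroring "out0" in the example.
            return {"out0": outs}
        if labels is None:
            raise ValueError("labels are required in train and eval mode")
        loss = self._mse(outs, labels)
        if self.mode == "train":
            # Gradient of the MSE loss with respect to weight, then one SGD step.
            grad = sum(2 * (o - y) * x for o, y, x in zip(outs, labels, xs)) / len(xs)
            self.weight -= self.lr * grad
        return loss


if __name__ == "__main__":
    model = ToyDistModel(weight=0.0, lr=0.1)
    xs, ys = [1.0, 2.0], [2.0, 4.0]
    model.train()
    print(model(xs, ys), model.weight)  # 10.0 1.0
    model.eval()
    print(model(xs, ys), model.weight)  # 2.5 1.0
    model.predict()
    print(model(xs))  # {'out0': [1.0, 2.0]}
```

The loss returned in train mode is computed before the update, so the second call (in eval mode) already sees the updated weight.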

Parameters

  • layer (paddle.nn.Layer) – The layer in dygraph mode. Its parameters or its inputs can be distributed tensors.

  • loader (ShardDataloader|paddle.io.DataLoader, optional) – The data loader used in dygraph mode; to_static uses it to infer inputs_spec and labels_spec. Default: None.

  • loss (Loss|Callable|None, optional) – The loss function for training or evaluating the model. Can be a paddle.nn.Layer instance or any callable function. Default: None.

  • optimizer (paddle.optimizer.Optimizer|_ShardOptimizer|None, optional) – The optimizer for training. It can be a paddle.optimizer.Optimizer, or a _ShardOptimizer obtained by wrapping one with shard_optimizer. Default: None.

  • strategy (paddle.distributed.Strategy|None, optional) – Configs for parallel strategies and optimization settings (e.g. sharding, pipeline parallelism). Default: None.
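The loader entry above only says that the loader is used to infer inputs_spec and labels_spec. A rough, framework-free sketch of that idea follows, assuming that one (input, label) batch is enough to read off each field's shape and element type. Spec, shape_of, leaf_type and infer_specs are hypothetical names chosen for illustration, not part of paddle.distributed, and nested Python lists stand in for tensors.

```python
from typing import Any, Iterable, List, NamedTuple, Tuple


class Spec(NamedTuple):
    """Hypothetical stand-in for an input spec: a shape and an element type."""

    shape: Tuple[int, ...]
    dtype: str


def shape_of(value: Any) -> Tuple[int, ...]:
    # Follow the first element of each nested list to recover the shape.
    dims = []
    while isinstance(value, list):
        dims.append(len(value))
        value = value[0] if value else None
    return tuple(dims)


def leaf_type(value: Any) -> str:
    # Descend to the first scalar and report its Python type name.
    while isinstance(value, list):
        value = value[0] if value else None
    return type(value).__name__


def infer_specs(loader: Iterable[Tuple[Any, Any]]) -> Tuple[List[Spec], List[Spec]]:
    # Inspect only the first (input, label) batch the loader yields.
    inputs, labels = next(iter(loader))
    inputs_spec = [Spec(shape_of(inputs), leaf_type(inputs))]
    labels_spec = [Spec(shape_of(labels), leaf_type(labels))]
    return inputs_spec, labels_spec


if __name__ == "__main__":
    BATCH_SIZE, IMAGE_SIZE, CLASS_NUM = 4, 16, 8
    batch = (
        [[0.5] * IMAGE_SIZE for _ in range(BATCH_SIZE)],
        [[0.5] * CLASS_NUM for _ in range(BATCH_SIZE)],
    )
    inputs_spec, labels_spec = infer_specs([batch])
    print(inputs_spec)  # [Spec(shape=(4, 16), dtype='float')]
    print(labels_spec)  # [Spec(shape=(4, 8), dtype='float')]
```

Peeking at a single batch keeps the cost constant, but it also means the shapes recorded by this sketch include that batch's size.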


Returns

A DistModel instance converted from the input layer.

Return type

DistModel

Examples

>>> import numpy as np
>>> import paddle
>>> import paddle.distributed as dist
>>> from paddle import nn
>>> from paddle.distributed import Replicate, Shard

>>> BATCH_SIZE = 4
>>> BATCH_NUM = 4
>>> IMAGE_SIZE = 16
>>> CLASS_NUM = 8
>>> class RandomDataset(paddle.io.Dataset):
...     def __init__(self, images, labels, num_samples):
...         self.images = images
...         self.labels = labels
...         self.num_samples = num_samples
...     def __getitem__(self, idx):
...         return self.images[idx], self.labels[idx]
...     def __len__(self):
...         return self.num_samples

>>> class DemoNet(nn.Layer):
...     def __init__(self, mesh):
...         super().__init__()
...         self._mesh = mesh
...         self.linear_0 = nn.Linear(IMAGE_SIZE, IMAGE_SIZE)
...         self.linear_1 = nn.Linear(IMAGE_SIZE, CLASS_NUM)
...         self.relu = nn.ReLU()
...         # shard the weights of this layer
...         self.linear_0.weight = dist.shard_tensor(
...             self.linear_0.weight,
...             self._mesh,
...             [Shard(1)],
...             stop_gradient=False,
...         )
...         self.linear_1.weight = dist.shard_tensor(
...             self.linear_1.weight,
...             self._mesh,
...             [Shard(0)],
...             stop_gradient=False,
...         )
...     def forward(self, x):
...         out = self.linear_0(x)
...         out = self.relu(out)
...         out = self.linear_1(out)
...         return out

>>> images = np.random.rand(BATCH_SIZE, IMAGE_SIZE).astype('float32')
>>> labels = np.random.rand(BATCH_SIZE, CLASS_NUM).astype('float32')
>>> dataset = RandomDataset(images, labels, BATCH_SIZE)
>>> loader = paddle.io.DataLoader(dataset, batch_size=BATCH_SIZE)

>>> mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
>>> layer = DemoNet(mesh)
>>> opt = paddle.optimizer.SGD(
...     learning_rate=0.1, parameters=layer.parameters()
... )
>>> loss_fn = nn.MSELoss()
>>> dist_loader = dist.shard_dataloader(loader, meshes=[mesh])
>>> dist_model = dist.to_static(
...     layer, dist_loader, loss_fn, opt
... )
>>> # training
>>> dist_model.train()
>>> for batch_id, (image, label) in enumerate(dist_loader()):
...     # in train mode, executing the __call__ method will
...     # update the parameters of the model and return the
...     # loss
...     loss = dist_model(image, label)

>>> # evaluation
>>> dist_model.eval()
>>> for batch_id, (image, label) in enumerate(dist_loader()):
...     # in eval mode, executing the __call__ method will
...     # return the loss
...     loss = dist_model(image, label)

>>> # prediction
>>> dist_model.predict()
>>> for batch_id, (image, label) in enumerate(dist_loader()):
...     # in predict mode, executing the __call__ method will
...     # return a dict that contains the outputs of the model,
...     # where the value of "out0" is the first output.
...     outs = dist_model(image)

>>> # This case needs to be executed in a multi-card environment
>>> # export CUDA_VISIBLE_DEVICES=0,1
>>> # python -m paddle.distributed.launch {test_case}.py
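In DemoNet, linear_0.weight is split with Shard(1) and linear_1.weight with Shard(0) over the two-process mesh. Because paddle.nn.Linear stores its weight as [in_features, out_features], this is a column split followed by a matching row split. The pure-Python check below (biases omitted; matmul, relu, add, dense_forward, sharded_forward and rand_matrix are hypothetical helpers) shows why that pairing works: each process can apply ReLU to its own column slice without communication, and summing the per-process partial products, which is what an all-reduce would do, reproduces the unsharded result. It is a numerical illustration of the idea only, not how Paddle executes the sharded program.

```python
import random
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # (n x k) @ (k x m) -> (n x m), written out with plain loops.
    return [
        [sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def relu(a: Matrix) -> Matrix:
    return [[max(v, 0.0) for v in row] for row in a]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def dense_forward(x: Matrix, w0: Matrix, w1: Matrix) -> Matrix:
    # Unsharded reference: relu(x @ w0) @ w1, biases omitted.
    return matmul(relu(matmul(x, w0)), w1)


def sharded_forward(x: Matrix, w0: Matrix, w1: Matrix, world_size: int) -> Matrix:
    hidden = len(w1)  # rows of w1 == columns of w0
    assert hidden % world_size == 0, "hidden size must split evenly"
    chunk = hidden // world_size
    total = None
    for rank in range(world_size):
        lo, hi = rank * chunk, (rank + 1) * chunk
        w0_local = [row[lo:hi] for row in w0]  # Shard(1): a block of columns
        w1_local = w1[lo:hi]  # Shard(0): the matching block of rows
        # ReLU is element-wise, so it can be applied to the local slice directly.
        partial = matmul(relu(matmul(x, w0_local)), w1_local)
        # Summing the partial results plays the role of an all-reduce.
        total = partial if total is None else add(total, partial)
    return total


def rand_matrix(rows: int, cols: int) -> Matrix:
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    BATCH_SIZE, IMAGE_SIZE, CLASS_NUM = 4, 16, 8
    random.seed(0)
    x = rand_matrix(BATCH_SIZE, IMAGE_SIZE)
    w0 = rand_matrix(IMAGE_SIZE, IMAGE_SIZE)  # [in_features, out_features]
    w1 = rand_matrix(IMAGE_SIZE, CLASS_NUM)
    dense = dense_forward(x, w0, w1)
    sharded = sharded_forward(x, w0, w1, world_size=2)
    max_diff = max(abs(a - b) for ra, rb in zip(dense, sharded) for a, b in zip(ra, rb))
    print(max_diff < 1e-9)  # True
```

The two results agree only up to floating-point rounding, because the sharded version adds the hidden-dimension terms in a different order.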