DistributedBatchSampler

class paddle.io.DistributedBatchSampler(dataset, batch_size, num_replicas=None, rank=None, shuffle=False, drop_last=False) [source]

Sampler that restricts data loading to a subset of the dataset.

In such a case, each process can pass a DistributedBatchSampler instance as a DataLoader sampler and load the subset of the original dataset that is exclusive to it.

Note

Dataset is assumed to be of constant size.

Parameters
  • dataset (Dataset) – this can be an instance of a Dataset subclass, or any other Python object that implements __len__, which the sampler uses to determine the number of samples to index.

  • batch_size (int) – the number of samples in each mini-batch.

  • num_replicas (int, optional) – the number of processes participating in distributed training. If num_replicas is None, it will be retrieved from ParallelEnv. Default None.

  • rank (int, optional) – the rank of the current process among num_replicas processes. If rank is None, it is retrieved from ParallelEnv. Default None.

  • shuffle (bool, optional) – whether to shuffle the indices before generating batch indices. Default False.

  • drop_last (bool, optional) – whether to drop the last incomplete batch, i.e. one containing fewer than batch_size samples. Default False.

Returns

DistributedBatchSampler, an iterable object for iterating over batch indices.
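
As a sketch of how num_replicas and rank split the dataset across processes, one can construct two samplers by hand. This is purely illustrative; in real distributed training both values are normally retrieved from ParallelEnv, and ToyDataset here is a hypothetical 10-sample dataset:

>>> from paddle.io import Dataset, DistributedBatchSampler

>>> class ToyDataset(Dataset):  # hypothetical minimal dataset
...     def __getitem__(self, idx):
...         return idx
...
...     def __len__(self):
...         return 10
...
>>> toy = ToyDataset()
>>> # simulate 2 replicas: each sampler yields its own disjoint share of indices
>>> sampler_rank0 = DistributedBatchSampler(toy, batch_size=2, num_replicas=2, rank=0)
>>> sampler_rank1 = DistributedBatchSampler(toy, batch_size=2, num_replicas=2, rank=1)
>>> batches_rank0 = list(sampler_rank0)  # batches over one share of the indices
>>> batches_rank1 = list(sampler_rank1)  # batches over the other share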

Examples

>>> import numpy as np

>>> from paddle.io import Dataset, DistributedBatchSampler

>>> # init with dataset
>>> class RandomDataset(Dataset):
...     def __init__(self, num_samples):
...         self.num_samples = num_samples
...
...     def __getitem__(self, idx):
...         image = np.random.random([784]).astype('float32')
...         label = np.random.randint(0, 9, (1, )).astype('int64')
...         return image, label
...
...     def __len__(self):
...         return self.num_samples
...
>>> dataset = RandomDataset(100)
>>> sampler = DistributedBatchSampler(dataset, batch_size=64)

>>> for data in sampler:
...     # do something
...     break
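
Continuing the example above, the sampler is typically handed to paddle.io.DataLoader through its batch_sampler argument, so that each process loads only its own subset. A minimal sketch; the unpacking of each batch assumes the dataset returns (image, label) pairs as above:

>>> from paddle.io import DataLoader

>>> loader = DataLoader(dataset, batch_sampler=sampler)
>>> for image, label in loader:
...     # run one training step on this process's shard
...     break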
set_epoch(epoch)


Sets the epoch number. When shuffle=True, this number is used as the seed for the random ordering of indices. By default, if users do not set it, all replicas (workers) use a different random ordering for each epoch. If the same number is set at each epoch, this sampler yields the same ordering in every epoch.

Parameters

epoch (int) – Epoch number.

Examples

>>> import numpy as np

>>> from paddle.io import Dataset, DistributedBatchSampler

>>> # init with dataset
>>> class RandomDataset(Dataset):
...     def __init__(self, num_samples):
...         self.num_samples = num_samples
...
...     def __getitem__(self, idx):
...         image = np.random.random([784]).astype('float32')
...         label = np.random.randint(0, 9, (1, )).astype('int64')
...         return image, label
...
...     def __len__(self):
...         return self.num_samples
...
>>> dataset = RandomDataset(100)
>>> sampler = DistributedBatchSampler(dataset, batch_size=64, shuffle=True)

>>> for epoch in range(10):
...     sampler.set_epoch(epoch)  # re-seed the shuffle for this epoch
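
Because the epoch number seeds the shuffle, setting the same epoch before two passes over the sampler reproduces the same ordering. A minimal sketch continuing the example above (the exact index order is implementation-defined):

>>> sampler.set_epoch(5)
>>> first_pass = list(sampler)
>>> sampler.set_epoch(5)
>>> second_pass = list(sampler)
>>> assert first_pass == second_pass  # same epoch seed, same ordering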