DataLoader

class paddle.io.DataLoader(dataset, feed_list=None, places=None, return_list=True, batch_sampler=None, batch_size=1, shuffle=False, drop_last=False, collate_fn=None, num_workers=0, use_buffer_reader=True, prefetch_factor=2, use_shared_memory=True, timeout=0, worker_init_fn=None, persistent_workers=False)

DataLoader provides an iterator that iterates over the given dataset once, drawing batches as specified by batch_sampler.

DataLoader supports single-process and multi-process data loading. If num_workers is set to a positive number, multi-process workers are used to load data asynchronously.

DataLoader supports map-style and iterable-style datasets.

For map-style datasets (a sample can be fetched from the dataset by a given index), please see paddle.io.Dataset.

For iterable-style datasets (samples are fetched from the dataset iteratively, like a Python iterator), please see paddle.io.IterableDataset.
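
The difference between the two dataset styles comes down to which methods the dataset implements. The sketch below uses plain Python classes with hypothetical names (SquaresMapStyle, SquaresIterableStyle) so that it runs without Paddle; a real dataset would subclass paddle.io.Dataset or paddle.io.IterableDataset respectively.

```python
class SquaresMapStyle:
    """Map-style: random access by index plus a known length."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx * idx

    def __len__(self):
        return self.n


class SquaresIterableStyle:
    """Iterable-style: samples are only available in sequence."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        for i in range(self.n):
            yield i * i


map_ds = SquaresMapStyle(5)
iter_ds = SquaresIterableStyle(5)

third_sample = map_ds[3]      # direct lookup by index
all_samples = list(iter_ds)   # must be consumed from the start
```

Because a map-style dataset has a length and supports indexing, its samples can be shuffled and grouped by index; an iterable-style dataset can only be read in the order it yields samples.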

For batch_sampler, please see paddle.io.BatchSampler.
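
Conceptually, a batch sampler turns batch_size, shuffle and drop_last into a sequence of index lists, one list per mini-batch. The function below is an illustrative stand-in written in plain Python, not the implementation of paddle.io.BatchSampler; the name batch_indices and its seed argument are hypothetical.

```python
import random


def batch_indices(num_samples, batch_size, shuffle=False, drop_last=False, seed=None):
    """Yield lists of sample indices, one list per mini-batch."""
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        batch = indices[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            return  # discard the incomplete final batch
        yield batch


kept = list(batch_indices(10, batch_size=4))
dropped = list(batch_indices(10, batch_size=4, drop_last=True))
print(kept)     # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(dropped)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With 10 samples and a batch size of 4, the last batch holds only 2 samples, so drop_last decides whether it is kept.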

Notes

GPU tensor operations are currently not supported in subprocesses. Do not use GPU tensor operations in any part of the pipeline that runs in a subprocess, such as dataset transforms or collate_fn. NumPy array and CPU tensor operations are supported.

Disable automatic batching

In certain cases, such as some NLP tasks, users need to handle batching manually in the dataset instead of relying on automatic batching. For these cases, automatic batching is disabled if both batch_size and batch_sampler are set to None. Each item returned by the dataset should then already be a batch, and it will be processed by the function given as collate_fn, or by default_collate_fn if collate_fn is not set.

Notes

When automatic batching is disabled, default_collate_fn will do nothing to data from dataset.
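
One case where manual batching can help is grouping sentences of similar length so that each batch stays under a token budget. The sketch below is a plain Python illustration under that assumption; the class name LengthBucketedSentences and the max_tokens_per_batch argument are hypothetical. A real dataset would subclass paddle.io.Dataset and be passed to DataLoader with batch_size=None and batch_sampler=None.

```python
class LengthBucketedSentences:
    """Hypothetical dataset whose items are already complete batches."""

    def __init__(self, sentences, max_tokens_per_batch):
        self.batches = []
        current, longest = [], 0
        # Sort by length so each batch holds sentences of similar size.
        for sent in sorted(sentences, key=len):
            new_longest = max(longest, len(sent))
            # Padded size of the batch if this sentence were added.
            if current and new_longest * (len(current) + 1) > max_tokens_per_batch:
                self.batches.append(current)
                current, new_longest = [], len(sent)
            current.append(sent)
            longest = new_longest
        if current:
            self.batches.append(current)

    def __getitem__(self, idx):
        return self.batches[idx]  # one whole batch per index

    def __len__(self):
        return len(self.batches)


sentences = [[1], [2, 3], [4, 5, 6], [7, 8, 9, 10], [11, 12]]
dataset = LengthBucketedSentences(sentences, max_tokens_per_batch=6)
batch_sizes = [len(dataset[i]) for i in range(len(dataset))]
print(batch_sizes)  # [3, 1, 1]
```

The batches differ in size, which a single fixed batch_size cannot express; this is why the dataset produces the batches itself.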

Parameters
  • dataset (Dataset) – the dataset to load data from. It should be an instance of a subclass of paddle.io.Dataset or paddle.io.IterableDataset.

  • feed_list (list(Tensor)|tuple(Tensor), optional) – the list of feed Tensors. The Tensors should be created by paddle.static.data(). feed_list must be set if return_list is False. Default None.

  • places (list(Place)|tuple(Place)|list(str), optional) – a list of Place objects on which to put the data. If places is None, the default place (CPUPlace or CUDAPlace(0)) is used. If places is a list of strings, each string can be cpu, gpu:x or gpu_pinned, where x is the index of the GPU. Default None.

  • return_list (bool, optional) – whether the return value on each device is presented as a list. If return_list=False, the return value on each device is a dict of str -> Tensor, where each key is the name of a fed Tensor. If return_list=True, the return value on each device is a list(Tensor). return_list can only be True in dynamic graph mode. Default True.

  • batch_sampler (BatchSampler, optional) – an instance of paddle.io.BatchSampler that generates the batch indices used to draw samples from the dataset and combine them into a batch. Default None.

  • batch_size (int|None, optional) – the number of samples in a mini-batch, a substitute parameter for batch_sampler. If batch_sampler is not set, a default paddle.io.BatchSampler is used, initialized with batch_size, shuffle and drop_last. Default 1.

  • shuffle (bool, optional) – whether to shuffle the order of the indices before generating batch indices, a substitute parameter for batch_sampler, see batch_size. Default False.

  • drop_last (bool, optional) – whether to drop the last incomplete batch when the dataset size is not divisible by the batch size, a substitute parameter for batch_sampler, see batch_size. Default False.

  • collate_fn (callable, optional) – function that generates mini-batch data by merging a list of samples. If None, each field of the samples is simply stacked along axis 0 (same as np.stack(..., axis=0)). Default None.

  • num_workers (int, optional) – the number of subprocesses used to load data. 0 means no subprocess is used and data is loaded in the main process. Default 0.

  • use_buffer_reader (bool, optional) – whether to use a buffered reader. If use_buffer_reader=True, the DataLoader prefetches batch data asynchronously, which speeds up data feeding at the cost of a little more CPU or GPU memory, i.e., the memory of one batch of input data. Default True.

  • prefetch_factor (int, optional) – the number of batches the DataLoader prefetches if use_buffer_reader=True. Default 2.

  • use_shared_memory (bool, optional) – whether to use shared memory to speed up putting data into the inter-process queue. Set use_shared_memory to True only when the shared memory space on your machine (e.g. the space of '/dev/shm' on a Linux operating system) is large enough. Shared memory is only enabled in multi-process mode (num_workers > 0). Default True.

  • timeout (int, optional) – the timeout value for getting data from the output queue of the subprocesses. Default 0.

  • worker_init_fn (callable, optional) – initialization function that, if not None, is called with the worker id when each subprocess starts. Default None.
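
Stacking along axis 0, as the default for collate_fn does, requires every sample field to have the same shape. When samples differ in length, a custom collate_fn can pad them first. The following is a minimal NumPy sketch under that assumption; pad_collate and pad_value are hypothetical names, and each sample is assumed to be a (tokens, label) pair.

```python
import numpy as np


def pad_collate(samples, pad_value=0):
    """Merge (tokens, label) samples into a padded token matrix and a label vector."""
    max_len = max(len(tokens) for tokens, _ in samples)
    tokens_batch = np.full((len(samples), max_len), pad_value, dtype="int64")
    for row, (tokens, _) in enumerate(samples):
        # Copy the real tokens; the rest of the row keeps pad_value.
        tokens_batch[row, :len(tokens)] = tokens
    labels_batch = np.stack(
        [np.asarray(label, dtype="int64") for _, label in samples], axis=0
    )
    return tokens_batch, labels_batch


samples = [([5, 6, 7], 1), ([8], 0), ([9, 10], 1)]
tokens_batch, labels_batch = pad_collate(samples)
print(tokens_batch.shape)  # (3, 3)
print(labels_batch.shape)  # (3,)
```

A function of this shape could be passed as collate_fn=pad_collate. It uses only NumPy operations, which keeps it usable when collation runs in a subprocess.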

Returns

An iterable object for iterating over the data; each element of the generated data is a Tensor.

Return type

DataLoader

Examples

>>> import numpy as np

>>> import paddle
>>> import paddle.nn as nn
>>> import paddle.nn.functional as F
>>> from paddle.io import Dataset, BatchSampler, DataLoader

>>> BATCH_NUM = 20
>>> BATCH_SIZE = 16
>>> EPOCH_NUM = 4

>>> IMAGE_SIZE = 784
>>> CLASS_NUM = 10

>>> # define a random dataset
>>> class RandomDataset(Dataset):
...     def __init__(self, num_samples):
...         self.num_samples = num_samples
...
...     def __getitem__(self, idx):
...         image = np.random.random([IMAGE_SIZE]).astype('float32')
...         label = np.random.randint(0, CLASS_NUM, (1, )).astype('int64')
...         return image, label
...
...     def __len__(self):
...         return self.num_samples
...
>>> dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)

>>> class SimpleNet(nn.Layer):
...     def __init__(self):
...         super().__init__()
...         self.fc = nn.Linear(IMAGE_SIZE, CLASS_NUM)
...
...     def forward(self, image, label=None):
...         return self.fc(image)
...
>>> simple_net = SimpleNet()
>>> opt = paddle.optimizer.SGD(learning_rate=1e-3,
...                             parameters=simple_net.parameters())
...
>>> loader = DataLoader(dataset,
...                     batch_size=BATCH_SIZE,
...                     shuffle=True,
...                     drop_last=True,
...                     num_workers=2)
...
>>> for e in range(EPOCH_NUM):
...     for i, (image, label) in enumerate(loader()):
...         out = simple_net(image)
...         loss = F.cross_entropy(out, label)
...         avg_loss = paddle.mean(loss)
...         avg_loss.backward()
...         opt.minimize(avg_loss)
...         simple_net.clear_gradients()
...         print("Epoch {} batch {}: loss = {}".format(e, i, np.mean(loss.numpy())))

Notes

For reading an iterable dataset with a multi-process DataLoader, please see paddle.io.IterableDataset.
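
The reason this case needs care is that each worker process iterates over its own copy of an iterable-style dataset, so the same samples can be produced more than once unless the dataset splits them among the workers. A common way to split is to give each worker every num_workers-th sample. The helper below is a plain Python sketch of that idea; shard_for_worker is a hypothetical name, and in a real dataset worker_id and num_workers would come from the worker information that Paddle exposes inside each worker (see paddle.io.IterableDataset).

```python
from itertools import islice


def shard_for_worker(samples, worker_id, num_workers):
    """Return every num_workers-th sample, starting at position worker_id."""
    return islice(samples, worker_id, None, num_workers)


num_workers = 3
shards = [
    list(shard_for_worker(range(10), wid, num_workers))
    for wid in range(num_workers)
]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

Taken together, the shards cover every sample exactly once.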