BatchSampler
- class paddle.io.BatchSampler(dataset=None, sampler=None, shuffle=False, batch_size=1, drop_last=False)
- 
         A base implementation of the batch sampler used by paddle.io.DataLoader. It iteratively yields mini-batch indices (a list or tuple whose length is the mini-batch size and whose items are sample indices). Any batch sampler passed to paddle.io.DataLoader should be a subclass of paddle.io.BatchSampler, and subclasses should implement the following methods:
         __iter__: iteratively return mini-batch indices.
         __len__: return the number of mini-batches in an epoch.
- Parameters
- 
           - dataset (Dataset) – a paddle.io.Dataset implementation, or any other Python object that implements __len__, from which BatchSampler takes sample indices as the range of the dataset length. Default None.
- sampler (Sampler) – a paddle.io.Sampler instance that implements __iter__ to yield sample indices. sampler and dataset cannot be set at the same time. If sampler is set, shuffle should not be set. Default None.
- shuffle (bool) – whether to shuffle the order of the indices before generating batch indices. Default False.
- batch_size (int) – the number of sample indices in each mini-batch. Default 1.
- drop_last (bool) – whether to drop the last incomplete batch when the dataset size is not divisible by the batch size. Default False.
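
The interaction between batch_size and drop_last can be sketched in plain Python. This is only an illustration of the grouping rule described above, not Paddle's actual implementation; the helper names batch_indices and num_batches are hypothetical and do not exist in paddle.io.

```python
import math


def batch_indices(num_samples, batch_size=1, drop_last=False):
    """Group 0..num_samples-1 into consecutive mini-batches (illustrative only)."""
    batch = []
    for idx in range(num_samples):
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # A trailing partial batch is kept unless drop_last is set.
    if batch and not drop_last:
        yield batch


def num_batches(num_samples, batch_size=1, drop_last=False):
    """Number of mini-batches per epoch under the same rule."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


kept = list(batch_indices(100, batch_size=16, drop_last=False))
dropped = list(batch_indices(100, batch_size=16, drop_last=True))
print(len(kept), len(kept[-1]))        # 7 batches, the last holds 4 indices
print(len(dropped), len(dropped[-1]))  # 6 batches, all of size 16
```

With 100 samples and batch_size=16, six full batches cover 96 indices, so the remaining 4 form a seventh batch only when drop_last is False.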
 
- Returns
- 
           an iterable object that yields mini-batch indices
- Return type
- 
           BatchSampler 
 Examples

    import numpy as np
    from paddle.io import RandomSampler, BatchSampler, Dataset

    # init with dataset
    class RandomDataset(Dataset):
        def __init__(self, num_samples):
            self.num_samples = num_samples

        def __getitem__(self, idx):
            # generate a random image/label pair for any index
            image = np.random.random([784]).astype('float32')
            label = np.random.randint(0, 9, (1, )).astype('int64')
            return image, label

        def __len__(self):
            return self.num_samples

    bs = BatchSampler(dataset=RandomDataset(100),
                      shuffle=False,
                      batch_size=16,
                      drop_last=False)

    for batch_indices in bs:
        print(batch_indices)

    # init with sampler
    sampler = RandomSampler(RandomDataset(100))
    bs = BatchSampler(sampler=sampler,
                      batch_size=8,
                      drop_last=True)

    for batch_indices in bs:
        print(batch_indices)

 See paddle.io.DataLoader
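
The two-method contract for custom batch samplers (__iter__ and __len__) might look like the following minimal sketch. The class name LengthBucketBatchSampler and its lengths argument are hypothetical and are not part of paddle.io. The class is written without Paddle imports so that it runs on its own; in real use it would derive from paddle.io.BatchSampler, as the description above requires.

```python
class LengthBucketBatchSampler:
    """Hypothetical sampler that batches indices of similar sequence length.

    In real use this would subclass paddle.io.BatchSampler; it is kept
    free of Paddle imports here so the sketch is self-contained.
    """

    def __init__(self, lengths, batch_size=1, drop_last=False):
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Sort sample indices by length so that neighbours are similar.
        self.order = sorted(range(len(lengths)), key=lambda i: lengths[i])

    def __iter__(self):
        # Yield one list of sample indices per mini-batch.
        for start in range(0, len(self.order), self.batch_size):
            batch = self.order[start:start + self.batch_size]
            if len(batch) < self.batch_size and self.drop_last:
                return
            yield batch

    def __len__(self):
        # Number of mini-batches in one epoch.
        full, rest = divmod(len(self.order), self.batch_size)
        return full if (self.drop_last or rest == 0) else full + 1


lengths = [5, 2, 9, 1, 7]
bs = LengthBucketBatchSampler(lengths, batch_size=2)
print(list(bs))  # [[3, 1], [0, 4], [2]]
print(len(bs))   # 3
```

Keeping __len__ consistent with the number of batches that __iter__ actually yields matters because code consuming the sampler may rely on the reported count.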
