class Sampler(data_source=None) [source]

An abstract class to encapsulate methods and behaviors of samplers.

All samplers used by `BatchSampler` should be subclasses of `Sampler`. Sampler subclasses should implement the following methods:

__iter__: return an iterator over sample indices, i.e. the indices of the dataset elements.

__len__: return the number of samples in data_source.


Parameters

data_source (Dataset, optional) – the data to sample from. This may also be any other Python object that implements __len__, so that Sampler can generate indices in the range [0, len(data_source)). Default: None.


Returns

an iterable object that yields sample indices.

Return type

Sampler

Examples

import numpy as np
from paddle.io import Dataset, Sampler

class RandomDataset(Dataset):
    """Toy dataset of ``num_samples`` randomly generated (image, label) pairs.

    Nothing is precomputed: every ``__getitem__`` call draws new random
    values, so reading the same index twice gives different samples.
    """

    def __init__(self, num_samples):
        # Only the advertised length is kept; samples are produced on demand.
        self.num_samples = num_samples

    def __getitem__(self, idx):
        # ``idx`` is not used; each access returns a freshly drawn pair.
        # The image is drawn before the label, so the RNG stream is consumed
        # in a fixed order.
        features = np.random.random([784]).astype('float32')
        # randint's upper bound is exclusive, so labels fall in 0..8.
        # NOTE(review): label 9 never appears; use randint(0, 10) if ten
        # classes are intended.
        target = np.random.randint(0, 9, (1, )).astype('int64')
        return features, target

    def __len__(self):
        # Length is whatever was requested at construction time.
        return self.num_samples

class MySampler(Sampler):
    """Sequential sampler that yields every index of ``data_source`` in order.

    Args:
        data_source: any object implementing ``__len__``. The sampler
            yields ``0, 1, ..., len(data_source) - 1``.
    """

    def __init__(self, data_source):
        # Run the base-class initializer (Sampler(data_source=None)) so that
        # any state it sets up is present on the subclass too.
        super().__init__(data_source)
        # Stored explicitly as well, because __iter__ and __len__ below read
        # this attribute directly.
        self.data_source = data_source

    def __iter__(self):
        # Indices are produced lazily, in ascending order.
        return iter(range(len(self.data_source)))

    def __len__(self):
        # One index per element of the data source.
        return len(self.data_source)

# Wrap a 100-element dataset in the sequential sampler defined above.
sampler = MySampler(data_source=RandomDataset(100))

# Iterating the sampler yields the dataset indices 0..99 in order.
for index in sampler:
    print(index)
