class DatasetFolder ( root, loader=None, extensions=None, transform=None, is_valid_file=None ) [source]

A generic data loader where the samples are arranged in this way:

root/class_a/1.ext root/class_a/2.ext root/class_a/3.ext

root/class_b/123.ext root/class_b/456.ext root/class_b/789.ext

  • root (string) – Root directory path.

  • loader (callable|optional) – A function to load a sample given its path.

  • extensions (list[str]|tuple[str]|optional) – A list of allowed file extensions. extensions and is_valid_file should not both be passed.

  • transform (callable|optional) – A function/transform that takes in a sample and returns a transformed version.

  • is_valid_file (callable|optional) – A function that takes the path of a file and checks whether it is a valid file (used to check for corrupt files). extensions and is_valid_file should not both be passed.


import os
import cv2
import tempfile
import shutil
import numpy as np
from paddle.vision.datasets import DatasetFolder

def make_fake_dir():
    """Create a temporary dataset directory filled with random JPEG images.

    The layout matches what ``DatasetFolder`` expects: one sub-directory per
    class, each holding that class's samples::

        <data_dir>/class_0/0.jpg
        <data_dir>/class_0/1.jpg
        <data_dir>/class_1/0.jpg
        <data_dir>/class_1/1.jpg

    Returns:
        str: Path of the newly created root directory. The caller owns it and
        should remove it (e.g. with ``shutil.rmtree``) when finished.
    """
    data_dir = tempfile.mkdtemp()

    for i in range(2):
        sub_dir = os.path.join(data_dir, 'class_' + str(i))
        # The class directory must exist before images are written into it;
        # cv2.imwrite does not create missing parent directories.
        os.makedirs(sub_dir, exist_ok=True)
        for j in range(2):
            # Random 32x32 3-channel uint8 image with values in [0, 255).
            fake_img = (np.random.random((32, 32, 3)) * 255).astype('uint8')
            cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img)
    return data_dir

temp_dir = make_fake_dir()
# temp_dir is the dataset root; make_fake_dir lays it out as:
#   temp_dir/class_0/0.jpg, temp_dir/class_0/1.jpg
#   temp_dir/class_1/0.jpg, temp_dir/class_1/1.jpg
# Only the root is passed; loader, extensions, transform and is_valid_file
# keep their default of None.
data_folder = DatasetFolder(temp_dir)

for items in data_folder: