class ImageFolder(root, loader=None, extensions=None, transform=None, is_valid_file=None)

A generic data loader where the samples are arranged in this way:

root/1.ext
root/2.ext
root/sub_dir/3.ext

  • root (string) – Root directory path.

  • loader (callable, optional) – A function to load a sample given its path.

  • extensions (list[str]|tuple[str], optional) – A list of allowed extensions. extensions and is_valid_file should not both be passed.

  • transform (callable, optional) – A function/transform that takes in a sample and returns a transformed version.

  • is_valid_file (callable, optional) – A function that takes the path of a file and checks whether it is a valid file (used to detect corrupt files). extensions and is_valid_file should not both be passed.


import os
import shutil
import tempfile

import cv2
import numpy as np
from paddle.vision.datasets import ImageFolder

def make_fake_dir(num_classes=2, images_per_class=2):
    """Create a temporary directory tree filled with random JPEG images.

    The resulting layout is ``<tmp>/class_<i>/<j>.jpg`` for ``i`` in
    ``range(num_classes)`` and ``j`` in ``range(images_per_class)``. Each
    image is a random 32x32x3 ``uint8`` array written with ``cv2.imwrite``.

    Args:
        num_classes (int): Number of ``class_<i>`` sub-directories to create.
            Defaults to 2.
        images_per_class (int): Number of images written into each
            sub-directory. Defaults to 2.

    Returns:
        str: Path of the newly created temporary directory. ``tempfile.mkdtemp``
        does not clean up after itself, so the caller should remove it
        (e.g. with ``shutil.rmtree``) when done.
    """
    data_dir = tempfile.mkdtemp()

    for i in range(num_classes):
        sub_dir = os.path.join(data_dir, 'class_' + str(i))
        # exist_ok=True makes this a no-op if the directory already exists,
        # without a separate (race-prone) existence check.
        os.makedirs(sub_dir, exist_ok=True)
        for j in range(images_per_class):
            # Scale [0, 1) floats to [0, 255) and cast to 8-bit pixels.
            fake_img = (np.random.random((32, 32, 3)) * 255).astype('uint8')
            cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img)
    return data_dir

# Build a throwaway image tree: class_0/ and class_1/, each holding two
# random 32x32 JPEGs (see make_fake_dir).
# NOTE(review): presumably removed later with shutil.rmtree(temp_dir), since
# shutil is imported — confirm against the rest of the example.
temp_dir = make_fake_dir()
# Wrap the fake tree in an ImageFolder dataset; loader, extensions, transform
# and is_valid_file are all left at their defaults (None).
data_folder = ImageFolder(temp_dir)

for items in data_folder: