torchvision.datasets中folder.py文件是数据处理中的重要文件,每个入门的学生都会被要求看这个代码,这里是我阅读这个代码的阅读笔记,希望可以对你们理解代码有一点帮助。
理解:
ImageFolder :
把按类别分文件夹存放的图片转化为torch可识别的Dataset格式,可被DataLoader包装
文件夹格式:root/类别名/图片文件,例如 root/dog/xxx.png、root/cat/123.png
class ImageFolder(data.Dataset):  # 继承data.Dataset
    def __init__(self):
        # 初始化属性和参数,例如 self.name = name(挂在self上的属性可在整个类中使用)
        # 计算classes(所有类别名)
        # 计算self.imgs:[(图片路径, 图片类别), ...]
    def __getitem__(self, index):
        # 使数据集可以按下标索引
        # 返回(图片, 图片类别),其中图片是由图片路径加载得到的图片对象
    def __len__(self):
        # 返回数据集的大小
重点函数:
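1)for root, _, fnames in sorted(os.walk(d)):os.walk会递归遍历目录d下的所有内容,每到一层产生一个三元组(dirpath, dirnames, filenames),即【当前文件夹路径, 子文件夹名列表, 文件名列表】
2)注意:图片路径 => 图片对象。self.imgs中保存的只是图片路径,在__getitem__中才通过loader把路径加载为图片(pil_loader会把图片转换为RGB模式)
3)类别文件夹名(str)=> 类别索引(int):由find_classes返回的class_to_idx完成映射,例如class_to_idx['3'] = 3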
代码:
import os
import os.path

import torch.utils.data as data
# PIL (Python Imaging Library) is the image-processing package; besides
# ``Image`` it also ships ImageFont, ImageDraw, ImageFilter, ...
from PIL import Image

# File-name suffixes that identify an image file.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True if ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The suffix comparison ignores case, so mixed-case names such as
    ``photo.Jpg`` are accepted in addition to the spellings listed in
    ``IMG_EXTENSIONS``.

    Args:
        filename (str): File name or path to test.

    Returns:
        bool: Whether the name carries a supported image suffix.
    """
    # The suffix tuple is rebuilt on every call on purpose: callers that
    # extend IMG_EXTENSIONS at runtime are picked up immediately.
    suffixes = tuple(ext.lower() for ext in IMG_EXTENSIONS)
    return filename.lower().endswith(suffixes)


def find_classes(dir):
    """Discover the class names under ``dir`` and number them.

    Every immediate sub-directory of ``dir`` is one class; plain files are
    ignored because ``os.listdir`` does not distinguish files from
    directories and the result is filtered with ``os.path.isdir``.

    Example result for a directory holding the folders ``0`` .. ``9``::

        classes      -> ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        class_to_idx -> {'0': 0, '1': 1, '2': 2, ..., '9': 9}

    Args:
        dir (str): Root directory; a leading ``~`` is expanded.

    Returns:
        tuple: ``(classes, class_to_idx)`` where ``classes`` is the sorted
        list of sub-directory names and ``class_to_idx`` maps each name
        (str) to its position in that list (int).
    """
    # Expand "~" here as well so this function accepts the same paths as
    # make_dataset, which has always expanded them.
    dir = os.path.expanduser(dir)
    classes = sorted(
        entry for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx):
    """Collect ``(image path, class index)`` pairs for every image file.

    Args:
        dir (str): Root directory; a leading ``~`` (or ``~user``) is
            expanded to the user's home directory.
        class_to_idx (dict): Maps class folder name (str) to class index
            (int). Only folders named in this mapping are scanned; names
            that are not directories under ``dir`` are skipped.

    Returns:
        list: ``(path, class_index)`` tuples, ordered by class name, then
        by directory path, then by file name.
    """
    images = []
    dir = os.path.expanduser(dir)
    # Iterate the mapping rather than the directory listing so that a
    # folder absent from class_to_idx is ignored instead of raising
    # KeyError on its first image.
    for target in sorted(class_to_idx):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        # os.walk yields one (dirpath, dirnames, filenames) triple per
        # directory below d, so images in nested folders are found too.
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    images.append((path, class_to_idx[target]))
    return images


def pil_loader(path):
    """Open the image at ``path`` with PIL and return it in RGB mode."""
    # Open the path as a file object to avoid a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835). 'rb' means
    # read-only, binary mode; the with-statements close both handles.
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            # convert() switches between PIL image modes; here to 'RGB'.
            return img.convert('RGB')


def accimage_loader(path):
    """Load ``path`` with accimage, falling back to ``pil_loader``."""
    # accimage is an optional image-loading backend, imported lazily so
    # the module works without it.
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    """Load ``path`` with the backend reported by torchvision."""
    # get_image_backend returns the name of the package used to load images.
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in
            a PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that
            takes in the target and transforms it.
        loader (callable, optional): A function to load an image given its
            path.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples

    Raises:
        RuntimeError: If no image file is found below ``root``.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        # Class names come from the sub-folders of root; imgs holds one
        # (path, class_index) pair per image found below them.
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if not imgs:
            raise RuntimeError(
                "Found 0 images in subfolders of: " + root + "\n"
                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the
            target class.
        """
        # imgs stores only the path; the image is loaded on access.
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.imgs)