大家好,我是钱多多先森。在学习CV的路上,我又找到了坑可以挖,今天我们一起简单挖一挖python中关于class类的问题,进一步讨论pytorch中数据读取问题。
大家都知道,在深度学习pytorch框架下,进行训练数据的准备阶段,我们经常会遇到如下这样的数据准备工作:
# ---- DATA GENERATION ----
dataset = DatasetGenerator(Config.PATH_DATA_TRAIN, Config.PATH_TXT_TRAIN,
Config.IMG_RESIZE_W, Config.IMG_RESIZE_H,
is_pretrained=False, mode="train")
data = torch.utils.data.DataLoader(dataset=dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
num_workers=Config.NUM_WORKERS, pin_memory=True)
之后,dataset按照BATCH_SIZE的大小,进行打包,批次迭代进入网络,训练优化模型。
其中,在torch.utils.data.DataLoader的传入数据dataset的预处理,一般都遵循如下的类(class)结构:
class DatasetGenerator(Dataset):
def __init__(self, path_data, path_txt, imgsize_w, imgsize_h):
self.imgsize_w = imgsize_w
self.imgsize_h = imgsize_h
self.list_path_raw = []
self.list_path_label = []
fileDescriptor = open(path_txt, "r")
line = True
while line:
line = fileDescriptor.readline()
if line:
lineItems = line.split()
path_raw = path_data + lineItems[0]
self.list_path_raw.append(path_raw)
imageLabel = lineItems[1:]
imageLabel = [int(i) for i in imageLabel]
self.list_path_label.append(imageLabel)
fileDescriptor.close()
def __getitem__(self, index):
# ---- raw ----
img_raw = Image.open(self.list_path_raw[index]).convert("RGB")
img_raw = img_raw.resize((self.imgsize_w, self.imgsize_h), resample=Image.LANCZOS)
# ---- label ----
imageLabel = torch.FloatTensor(self.list_path_label[index])
# ---- Data augmentation ----
transform_raw = transforms.Compose(transforms.ToTensor())
img_raw = transform_raw(img_raw)
return img_raw, imageLabel
def __len__(self):
return len(self.list_path_raw)
简单点就形如下面这样:
class DatasetGenerator(Dataset):
def __init__(self, param):
def __getitem__(self, index):
return
def __len__(self):
return
这里不禁就引发好奇:def __init__,def __getitem__, def __len__是什么?有什么作用呢?
本节我们就带着这样一个问题,深入浅出的探讨下, pytorch数据处理class问题。全文简单易懂,大约需要 3 minute。
1. 常规类class学习
社会我小白君,人狠话不多,上来就是一波示范,如下:
class Student:
def __init__(self, name, age):
self.name = name
self.age = age
def detail(self):
print(self.name)
print(self.age)
student = Student('chengd', 18)
student.detail()
output:
chengd
18
常规操作
- 定义类
- 实例化类
- 传参
- 调用类函数
引发思考:我们没有调用def __init__(self, name, age)函数,为什么name, age能够在函数detail中被调用呢?
原来 __init__(self, name, age)初始化实例数学,构造函数,类实例化是自动执行的。原来如此,__init__在实例化类的时候,就已经执行了。
做下大胆的猜想,既然__init__ 会自动执行,那直接把函数detail放到__init__ 里面,岂不是都不用调用了?尝试下
class Student:
def __init__(self, name, age):
self.name = name
self.age = age
self.detail()
def detail(self):
print(self.name)
print(self.age)
student = Student('chengd', 18)
output:(果然是的,验证通过)
chengd
18
2. 认识class内置函数
到这个时候,我就开始联想了:在定义类时候,偶尔还会看到很多的形如 __init__的函数,比如__getitem__,__len__等等,搜索下,果然发现了不少。他们都是class的内置方法。
- init(self,...) 初始化对象,在创建新对象时调用
- del(self) 释放对象,在对象被删除之前调用
- new(cls,*args,**kwd) 实例的生成操作
- str(self) 在使用print语句时被调用
- getitem(self,key) 获取序列的索引key对应的值,等价于seq[key]
- len(self) 在调用内联函数len()时被调用
- cmp(stc,dst) 比较两个对象src和dst
- getattr(s,name) 获取属性的值
- setattr(s,name,value) 设置属性的值
- delattr(s,name) 删除name属性
- getattribute() getattribute()功能与__getattr__()类似
- gt(self,other) 判断self对象是否大于other对象
- lt(slef,other) 判断self对象是否小于other对象
- ge(slef,other) 判断self对象是否大于或者等于other对象
- le(slef,other) 判断self对象是否小于或者等于other对象
- eq(slef,other) 判断self对象是否等于other对象
- call(self,*args) 把实例对象作为函数调用
一一罗列就太多了,挑本次我们用到的谈谈,__getitem__和__len__。
2.1 __getitem__(self)
__getitem__(self, index): 如果类把某个属性定义为序列,可以使用__getitem__()获取序列的索引 index 对应的值,输出序列属性中的某个元素,假设学校多个学生,可以通过__getitem__()方法获取每个学生的姓名和年龄。
class Student:
def __init__(self, name, age):
self.name = name
self.age = age
def __getitem__(self, index):
return self.name[index], self.age[index]
student = Student(['Xiaoming','Didi','Tom'], [18,16,26])
print(student[1], "\n")
for item in student:
print(item)
output:
('Didi', 16)
('Xiaoming', 18)
('Didi', 16)
('Tom', 26)
2.2 __len__(self)
__len__(self) :在调用内联函数len()时被调用,len 经常用来表示长度
小白胡想:len(list)就是list的长度,这里究竟是不是也是这样的呢?
class Student:
def __init__(self, name, age):
self.name = name
self.age = age
def __getitem__(self, index):
return self.name[index], self.age[index]
def __len__(self):
return len(self.name)
student = Student(['Xiaoming','Didi','Tom'], [18,16,26])
print(student[1], "\n")
for item in student:
print(item)
print('\n',len(student))
output:
('Didi', 16)
('Xiaoming', 18)
('Didi', 16)
('Tom', 26)
3
3. 总结
到这里,基本已经弄明白为什么在pytorch中,定义数据类时候,为什么需要使用了__getitem__ 和 __len__两个内置函数,而不是其他别的。我理解的原因总结有以下下3点:
- init(self) 必备,用于传参,数据初始化,比如我们这里数据的存储位置、宽、高等信息,这是必不可少的;
- getitem(self, index) 获取序列的索引index对应的值,image和label,方便后面迭代时候获取数据;
- len(self) 总长度。
torch.utils.data.DataLoader中dataset我们已经给它讲明白了,究竟后面他是如何用DataLoader打包的,建议去研读Pytorch的源码,后面有时间我们专门开一起,聊聊这个。
一个分享技术、生活、理财的头条号。关注我,更多AI(人工智能)、图像处理、Python、医疗影像、理财和理想内容,这里都有。
往期回顾