百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

PyTorch系列之DataLoader的dataset类定义,小迷糊的一步搞定

toyiye 2024-06-24 19:13 11 浏览 0 评论

大家好,我是钱多多先森。在学习CV的路上,我又找到了坑可以挖,今天我们一起简单挖一挖python中关于class类的问题,进一步讨论pytorch中数据读取问题。


大家都知道,在深度学习pytorch框架下,进行训练数据的准备阶段,我们经常会遇到如下这样的数据准备工作:

# ---- DATA GENERATION ----
dataset = DatasetGenerator(Config.PATH_DATA_TRAIN, Config.PATH_TXT_TRAIN, 
                            Config.IMG_RESIZE_W, Config.IMG_RESIZE_H, 
                            is_pretrained=False, mode="train")
data = torch.utils.data.DataLoader(dataset=dataset, batch_size=Config.BATCH_SIZE, shuffle=True,
                  num_workers=Config.NUM_WORKERS, pin_memory=True)

之后,dataset按照BATCH_SIZE的大小,进行打包,批次迭代进入网络,训练优化模型。

其中,在torch.utils.data.DataLoader的传入数据dataset的预处理,一般都遵循如下的类(class)结构:

class DatasetGenerator(Dataset):

    def __init__(self, path_data, path_txt, imgsize_w, imgsize_h):
        self.imgsize_w = imgsize_w
        self.imgsize_h = imgsize_h
        self.list_path_raw = []
        self.list_path_label = []
       
        fileDescriptor = open(path_txt, "r")
        line = True
        while line:      
            line = fileDescriptor.readline() 
            if line:  
                lineItems = line.split()
                path_raw = path_data + lineItems[0]  
                self.list_path_raw.append(path_raw)

                imageLabel = lineItems[1:]
                imageLabel = [int(i) for i in imageLabel]
                self.list_path_label.append(imageLabel)   
        fileDescriptor.close()
    
    def __getitem__(self, index):
        # ---- raw ----
        img_raw = Image.open(self.list_path_raw[index]).convert("RGB")
        img_raw = img_raw.resize((self.imgsize_w, self.imgsize_h), resample=Image.LANCZOS)
        # ---- label ----
        imageLabel = torch.FloatTensor(self.list_path_label[index])
        # ---- Data augmentation ----
        transform_raw = transforms.Compose(transforms.ToTensor())
        img_raw = transform_raw(img_raw)
        return img_raw, imageLabel

    def __len__(self):
        return len(self.list_path_raw)

简单点就形如下面这样:

class DatasetGenerator(Dataset):

    def __init__(self, param):
        
    def __getitem__(self, index):
        return 

    def __len__(self):
        return 

这里不禁就引发好奇:def __init__,def __getitem__, def __len__是什么?有什么作用呢?

本节我们就带着这样一个问题,深入浅出的探讨下, pytorch数据处理class问题。全文简单易懂,大约需要 3 minute。

1. 常规类class学习

社会我小白君,人狠话不多,上来就是一波示范,如下:

class Student:

    def __init__(self, name, age):
        self.name = name
        self.age = age

    def detail(self):
        print(self.name)
        print(self.age)


student = Student('chengd', 18)
student.detail()

output:

chengd
18

常规操作

  1. 定义类
  2. 实例化类
  3. 传参
  4. 调用类函数

引发思考:我们没有调用def __init__(self, name, age)函数,为什么name, age能够在函数detail中被调用呢?

原来 __init__(self, name, age)初始化实例数学,构造函数,类实例化是自动执行的。原来如此,__init__在实例化类的时候,就已经执行了。

做下大胆的猜想,既然__init__ 会自动执行,那直接把函数detail放到__init__ 里面,岂不是都不用调用了?尝试下

class Student:

    def __init__(self, name, age):
        self.name = name
        self.age = age

        self.detail()

    def detail(self):
        print(self.name)
        print(self.age)

student = Student('chengd', 18)

output:(果然是的,验证通过)

chengd
18

2. 认识class内置函数

到这个时候,我就开始联想了:在定义类时候,偶尔还会看到很多的形如 __init__的函数,比如__getitem__,__len__等等,搜索下,果然发现了不少。他们都是class的内置方法。

  • init(self,...) 初始化对象,在创建新对象时调用
  • del(self) 释放对象,在对象被删除之前调用
  • new(cls,*args,**kwd) 实例的生成操作
  • str(self) 在使用print语句时被调用
  • getitem(self,key) 获取序列的索引key对应的值,等价于seq[key]
  • len(self) 在调用内联函数len()时被调用
  • cmp(stc,dst) 比较两个对象src和dst
  • getattr(s,name) 获取属性的值
  • setattr(s,name,value) 设置属性的值
  • delattr(s,name) 删除name属性
  • getattribute() getattribute()功能与__getattr__()类似
  • gt(self,other) 判断self对象是否大于other对象
  • lt(slef,other) 判断self对象是否小于other对象
  • ge(slef,other) 判断self对象是否大于或者等于other对象
  • le(slef,other) 判断self对象是否小于或者等于other对象
  • eq(slef,other) 判断self对象是否等于other对象
  • call(self,*args) 把实例对象作为函数调用

一一罗列就太多了,挑本次我们用到的谈谈,__getitem__和__len__

2.1 __getitem__(self)

__getitem__(self, index): 如果类把某个属性定义为序列,可以使用__getitem__()获取序列的索引 index 对应的值,输出序列属性中的某个元素,假设学校多个学生,可以通过__getitem__()方法获取每个学生的姓名和年龄。

class Student:

    def __init__(self, name, age):
        self.name = name
        self.age = age

    def __getitem__(self, index):
        return self.name[index], self.age[index]


student = Student(['Xiaoming','Didi','Tom'], [18,16,26])
print(student[1], "\n")
for item in student:
    print(item)

output:

('Didi', 16) 

('Xiaoming', 18)
('Didi', 16)
('Tom', 26)

2.2 __len__(self)

__len__(self) :在调用内联函数len()时被调用,len 经常用来表示长度

小白胡想:len(list)就是list的长度,这里究竟是不是也是这样的呢?

class Student:

    def __init__(self, name, age):
        self.name = name
        self.age = age

    def __getitem__(self, index):
        return self.name[index], self.age[index]

    def __len__(self):
        return len(self.name)


student = Student(['Xiaoming','Didi','Tom'], [18,16,26])
print(student[1], "\n")
for item in student:
    print(item)

print('\n',len(student))

output:

('Didi', 16) 

('Xiaoming', 18)
('Didi', 16)
('Tom', 26)

 3

3. 总结

到这里,基本已经弄明白为什么在pytorch中,定义数据类时候,为什么需要使用了__getitem__ 和 __len__两个内置函数,而不是其他别的。我理解的原因总结有以下下3点:

  1. init(self) 必备,用于传参,数据初始化,比如我们这里数据的存储位置、宽、高等信息,这是必不可少的;
  2. getitem(self, index) 获取序列的索引index对应的值,image和label,方便后面迭代时候获取数据;
  3. len(self) 总长度。

torch.utils.data.DataLoader中dataset我们已经给它讲明白了,究竟后面他是如何用DataLoader打包的,建议去研读Pytorch的源码,后面有时间我们专门开一起,聊聊这个。

一个分享技术、生活、理财的头条号。关注我,更多AI(人工智能)、图像处理、Python、医疗影像、理财和理想内容,这里都有。

往期回顾

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码