百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

PyTorch踩过的12坑精选(pytorch cgan)

toyiye 2024-08-25 15:42 4 浏览 0 评论

| hyk_1996 来源 | CSDN博客

1. nn.Module.cuda() 和 Tensor.cuda() 的作用效果差异

无论是对于模型还是数据,cuda()函数都能实现从CPU到GPU的内存迁移,但是他们的作用效果有所不同。

对于nn.Module:

model = model.cuda() 
model.cuda() 

上面两句能够达到一样的效果,即对model自身进行的内存迁移。

对于Tensor:

和nn.Module不同,调用tensor.cuda()只是返回这个tensor对象在GPU内存上的拷贝,而不会对自身进行改变。因此必须对tensor进行重新赋值,即tensor=tensor.cuda().

例子:

model = create_a_model()
tensor = torch.zeros([2,3,10,10])
model.cuda()
tensor.cuda()
model(tensor) # 会报错
tensor = tensor.cuda()
model(tensor) # 正常运行

2. PyTorch 0.4 计算累积损失的不同

以广泛使用的模式total_loss += loss.data[0]为例。Python0.4.0之前,loss是一个封装了(1,)张量的Variable,但Python0.4.0的loss现在是一个零维的标量。对标量进行索引是没有意义的(似乎会报 invalid index to scalar variable 的错误)。使用loss.item()可以从标量中获取Python数字。所以改为:

total_loss += loss.item()

如果在累加损失时未将其转换为Python数字,则可能出现程序内存使用量增加的情况。这是因为上面表达式的右侧原本是一个Python浮点数,而它现在是一个零维张量。因此,总损失累加了张量和它们的梯度历史,这可能会产生很大的autograd 图,耗费内存和计算资源。

3. PyTorch 0.4 编写不限制设备的代码

# torch.device object used throughout this script
device = torch.device("cuda" if use_cuda else "cpu")
model = MyRNN().to(device)

# train
total_loss= 0
for input, target in train_loader:
 input, target = input.to(device), target.to(device)
 hidden = input.new_zeros(*h_shape) # has the same device & dtype as `input`
 ... # get loss and optimize
 total_loss += loss.item()

# test
with torch.no_grad(): # operations inside don't track history
 for input, targetin test_loader:
 ...

4. torch.Tensor.detach()的使用

detach()的官方说明如下:

Returns a new Tensor, detached from the current graph.
 The result will never require gradient.

假设有模型A和模型B,我们需要将A的输出作为B的输入,但训练时我们只训练模型B. 那么可以这样做:

input_B = output_A.detach()

它可以使两个计算图的梯度传递断开,从而实现我们所需的功能。

5. ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm)

出现这个错误的情况是,在服务器上的docker中运行训练代码时,batch size设置得过大,shared memory不够(因为docker限制了shm).解决方法是,将Dataloader的num_workers设置为0.

6. pytorch中loss函数的参数设置

以CrossEntropyLoss为例:

CrossEntropyLoss(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='elementwise_mean')
  • 若 reduce = False,那么 size_average 参数失效,直接返回向量形式的 loss,即batch中每个元素对应的loss.
  • 若 reduce = True,那么 loss 返回的是标量:
    • 如果 size_average = True,返回 loss.mean().
    • 如果 size_average = False,返回 loss.sum().
  • weight : 输入一个1D的权值向量,为各个类别的loss加权,如下公式所示:
  • ignore_index : 选择要忽视的目标值,使其对输入梯度不作贡献。如果 size_average = True,那么只计算不被忽视的目标的loss的均值。
  • reduction : 可选的参数有:‘none’ | ‘elementwise_mean’ | ‘sum’, 正如参数的字面意思,不解释。

7. pytorch的可重复性问题

参考这篇博文:

https://blog.csdn.net/hyk_1996/article/details/84307108

8. 多GPU的处理机制

使用多GPU时,应该记住pytorch的处理逻辑是:

1)在各个GPU上初始化模型。

2)前向传播时,把batch分配到各个GPU上进行计算。

3)得到的输出在主GPU上进行汇总,计算loss并反向传播,更新主GPU上的权值。

4)把主GPU上的模型复制到其它GPU上。

9. num_batches_tracked参数

今天读取模型参数时出现了错误

KeyError: 'unexpected key "module.bn1.num_batches_tracked" in state_dict'

经过研究发现,在pytorch 0.4.1及后面的版本里,BatchNorm层新增了num_batches_tracked参数,用来统计训练时的forward过的batch数目,源码如下(pytorch0.4.1):

 if self.training and self.track_running_stats:
 self.num_batches_tracked += 1
 if self.momentum is None: # use cumulative moving average
 exponential_average_factor = 1.0 / self.num_batches_tracked.item()
 else: # use exponential moving average
 exponential_average_factor = self.momentum

大概可以看出,这个参数和训练时的归一化的计算方式有关。

因此,我们可以知道该错误是由于训练和测试所用的pytorch版本(0.4.1版本前后的差异)不一致引起的。具体的解决方案是:如果是模型参数(Orderdict格式,很容易修改)里少了num_batches_tracked变量,就加上去,如果是多了就删掉。偷懒的做法是将load_state_dict的strict参数置为False,如下所示:

load_state_dict(torch.load(weight_path), strict=False)

还看到有人直接修改pytorch 0.4.1的源代码把num_batches_tracked参数删掉的,这就非常不建议了。

10. 训练时损失出现nan的问题

最近在训练模型时出现了损失为nan的情况,发现是个大坑。暂时先记录着。

可能导致梯度出现nan的三个原因:

1.梯度爆炸。也就是说梯度数值超出范围变成nan. 通常可以调小学习率、加BN层或者做梯度裁剪来试试看有没有解决。

2.损失函数或者网络设计。比方说,出现了除0,或者出现一些边界情况导致函数不可导,比方说log(0)、sqrt(0).

3.脏数据。可以事先对输入数据进行判断看看是否存在nan.

补充一下nan数据的判断方法:

注意!像nan或者inf这样的数值不能使用 == 或者 is 来判断!为了安全起见统一使用 math.isnan() 或者 numpy.isnan() 吧。

例如:

import numpy as np

# 判断输入数据是否存在nan
if np.any(np.isnan(input.cpu().numpy())):
 print('Input data has NaN!')

# 判断损失是否为nan
if np.isnan(loss.item()):
 print('Loss value is NaN!')

11. ValueError: Expected more than 1 value per channel when training

当batch里只有一个样本时,再调用batch_norm就会报下面这个错误:

 raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

没有什么特别好的解决办法,在训练前用 num_of_samples % batch_size 算一下会不会正好剩下一个样本。

12. 优化器的weight_decay项导致的隐蔽bug

我们都知道weight_decay指的是权值衰减,即在原损失的基础上加上一个L2惩罚项,使得模型趋向于选择更小的权重参数,起到正则化的效果。但是我经常会忽略掉这一项的存在,从而引发了意想不到的问题。

这次的坑是这样的,在训练一个ResNet50的时候,网络的高层部分layer4暂时没有用到,因此也并不会有梯度回传,于是我就放心地将ResNet50的所有参数都传递给Optimizer进行更新了,想着layer4应该能保持原来的权重不变才对。但是实际上,尽管layer4没有梯度回传,但是weight_decay的作用仍然存在,它使得layer4权值越来越小,趋向于0。后面需要用到layer4的时候,发现输出异常(接近于0),才注意到这个问题的存在。

虽然这样的情况可能不容易遇到,但是还是要谨慎:暂时不需要更新的权值,一定不要传递给Optimizer,避免不必要的麻烦。

最后,我自己是一名从事了多年开发的Python老程序员,辞职目前在做自己的Python私人定制课程,今年年初我花了一个月整理了一份最适合2019年学习的Python学习干货,可以送给每一位喜欢Python的小伙伴,想要获取的可以关注我的头条号并在后台私信我:01,即可免费获取。

相关推荐

# Python 3 # Python 3字典Dictionary(1)

Python3字典字典是另一种可变容器模型,且可存储任意类型对象。字典的每个键值(key=>value)对用冒号(:)分割,每个对之间用逗号(,)分割,整个字典包括在花括号({})中,格式如...

Python第八课:数据类型中的字典及其函数与方法

Python3字典字典是另一种可变容器模型,且可存储任意类型对象。字典的每个键值...

Python中字典详解(python 中字典)

字典是Python中使用键进行索引的重要数据结构。它们是无序的项序列(键值对),这意味着顺序不被保留。键是不可变的。与列表一样,字典的值可以保存异构数据,即整数、浮点、字符串、NaN、布尔值、列表、数...

Python3.9又更新了:dict内置新功能,正式版十月见面

机器之心报道参与:一鸣、JaminPython3.8的热乎劲还没过去,Python就又双叒叕要更新了。近日,3.9版本的第四个alpha版已经开源。从文档中,我们可以看到官方透露的对dic...

Python3 基本数据类型详解(python三种基本数据类型)

文章来源:加米谷大数据Python中的变量不需要声明。每个变量在使用前都必须赋值,变量赋值以后该变量才会被创建。在Python中,变量就是变量,它没有类型,我们所说的"类型"是变...

一文掌握Python的字典(python字典用法大全)

字典是Python中最强大、最灵活的内置数据结构之一。它们允许存储键值对,从而实现高效的数据检索、操作和组织。本文深入探讨了字典,涵盖了它们的创建、操作和高级用法,以帮助中级Python开发...

超级完整|Python字典详解(python字典的方法或操作)

一、字典概述01字典的格式Python字典是一种可变容器模型,且可存储任意类型对象,如字符串、数字、元组等其他容器模型。字典的每个键值key=>value对用冒号:分割,每个对之间用逗号,...

Python3.9版本新特性:字典合并操作的详细解读

处于测试阶段的Python3.9版本中有一个新特性:我们在使用Python字典时,将能够编写出更可读、更紧凑的代码啦!Python版本你现在使用哪种版本的Python?3.7分?3.5分?还是2.7...

python 自学,字典3(一些例子)(python字典有哪些基本操作)

例子11;如何批量复制字典里的内容2;如何批量修改字典的内容3;如何批量修改字典里某些指定的内容...

Python3.9中的字典合并和更新,几乎影响了所有Python程序员

全文共2837字,预计学习时长9分钟Python3.9正在积极开发,并计划于今年10月发布。2月26日,开发团队发布了alpha4版本。该版本引入了新的合并(|)和更新(|=)运算符,这个新特性几乎...

Python3大字典:《Python3自学速查手册.pdf》限时下载中

最近有人会想了,2022了,想学Python晚不晚,学习python有前途吗?IT行业行业薪资高,发展前景好,是很多求职群里严重的香饽饽,而要进入这个高薪行业,也不是那么轻而易举的,拿信工专业的大学生...

python学习——字典(python字典基本操作)

字典Python的字典数据类型是基于hash散列算法实现的,采用键值对(key:value)的形式,根据key的值计算value的地址,具有非常快的查取和插入速度。但它是无序的,包含的元素个数不限,值...

324页清华教授撰写【Python 3 菜鸟查询手册】火了,小白入门字典

如何入门学习python...

Python3.9中的字典合并和更新,了解一下

全文共2837字,预计学习时长9分钟Python3.9正在积极开发,并计划于今年10月发布。2月26日,开发团队发布了alpha4版本。该版本引入了新的合并(|)和更新(|=)运算符,这个新特性几乎...

python3基础之字典(python中字典的基本操作)

字典和列表一样,也是python内置的一种数据结构。字典的结构如下图:列表用中括号[]把元素包起来,而字典是用大括号{}把元素包起来,只不过字典的每一个元素都包含键和值两部分。键和值是一一对应的...

取消回复欢迎 发表评论:

请填写验证码