百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

PyTorch的Tensor和自动求导

toyiye 2024-06-21 12:39 12 浏览 0 评论

深度学习的基本流程是:设计神经网络结构,形成计算图,选择一个合适的损失函数,使用损失函数衡量预测值与真实值之间的差异。具体到一轮计算:我们需要将一个批次的数据喂给神经网络,神经网络前向传播(Forward),求得损失函数的值;然后是反向传播(Back Propagation)过程:求得损失函数对模型各参数的导数,利用梯度下降法来更新模型各参数。深度学习框架的最重要的一项功能就是帮我们完成了求导的过程。本文简单介绍一下PyTorch中张量Tensor以及其自动求导功能。

PyTorch中的Tensor

PyTorch的Tensor尽量兼容了NumPy的数据结构,除了可以像NumPy那样存储和计算张量外,Tensor还可以:

  1. 被拷贝到GPU上进行计算加速。
  2. 包含了自动求导的功能。

我们先看一下在PyTorch里Tensor的定义,以及如何创建Tensor。

torch.tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) → Tensor

如果想创建Tensor,需要使用torch.tensor(),其中:

  • data是张量本身,可以是列表、元组、NumPy的ndarray,标量等等。
  • dtype是数据类型,比如float32等。
  • device指定存储该张量的的设备。这个变量主要是针对CPU/GPU异构编程,在CUDA这种异构编程体系里,CPU被称作主机(Host),某张GPU卡被称作设备(Device)。
  • requires_grad表示是否需要求梯度(Gradient)或者说是否需要求导。如果设置为True,表示需要求导。默认这个值是False

我们创建Tensor:

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()

print(x)
print(y)
print(z)
print(out)

输出:

tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)
tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>)
tensor(27., grad_fn=<MeanBackward0>)

这里,我们可以将x假想成神经网络的输入,y是神经网络的隐藏层,z是神经网络的输出,最后的out是损失函数;或者说我们建立了一个简单的计算图,数据从x流向out。可以看到,只要x设置了requires_grad=True,那么计算图后续的节点用grad_fn记录了计算图中各步的传播过程。

现在,我们从out开始进行反向传播:

out.backward()

执行完.backward()后,PyTorch帮我们把计算图中的梯度都计算好了,如下:

print(x.grad)

输出:

tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])

拓展到深度学习,从输入开始,每层都有大量参数Wb,这些参数也是Tensor结构。给Tensor设置了requires_grad=True后,PyTorch会跟踪Tensor之后的所有计算,经过.backward()后,PyTorch自动帮我们计算损失函数对于这些参数的梯度,梯度存储在了.grad属性里,PyTorch会按照梯度下降法更新参数。

在PyTorch中,.backward()方法默认只会对计算图中的叶子节点求导。在上面的例子里,x就是叶子节点,yz都是中间变量,他们的.grad属性都是None。而且,PyTorch目前只支持浮点数的求导。

另外,PyTorch的自动求导一般只是标量向量/矩阵求导。在深度学习中,最后的损失函数一般是一个标量值,是样本数据经过前向传播得到的损失值的和,而输入数据是一个向量或矩阵。在刚才的例子中,y是一个矩阵,.mean()y求导,得到的是标量。

Tensor与自动求导

下面是一个使用PyTorch训练神经网络的例子。在这个例子中,我们随机初始化了输入x和输出y,分别作为模型的特征和要拟合的目标值。这个模型有两层,第一层是输入层,第二层为隐藏层,模型的前向传播如下所示:



dtype = torch.float
device = torch.device("cpu") # 使用CPU
# device = torch.device("cuda:0") # 如果使用GPU,请打开注释

# N: batch size
# D_in: 输入维度
# H: 隐藏层
# D_out: 输出维度 
N, D_in, H, D_out = 64, 1000, 100, 10

# 初始化随机数x, y
# x, y用来模拟机器学习的输入和输出
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# 初始化模型的参数w1和w2
# 均设置为 requires_grad=True
# PyTorch会跟踪w1和w2上的计算,帮我们自动求导
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # 前向传播过程:
    # h1 = relu(x * w1)
    # y = h1 * w2
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    # 计算损失函数loss
    # loss是误差的平方和
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # 反向传播过程:
    # PyTorch会对设置了requires_grad=True的Tensor自动求导,本例中是w1和w2
    # 执行完backward()后,w1.grad 和 w2.grad 里存储着对于loss的梯度
    loss.backward()

    # 根据梯度,更新参数w1和w2
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # 将 w1.grad 和 w2.grad 中的梯度设为零
        # PyTorch的backward()方法计算梯度会默认将本次计算的梯度与.grad中已有的梯度加和
        # 必须在下次反向传播前先将.grad中的梯度清零
        w1.grad.zero_()
        w2.grad.zero_()

在这个例子中,我们对w1w2设置了requires_grad=True,损失函数是loss,PyTorch会自动跟踪这两个变量上的计算,当执行backward()时,PyTorch帮我们计算了loss对于w1w2的导数。每次迭代后,我们都要根据导数,更新w1w2。PyTorch的backward()方法计算梯度会默认将本次计算的梯度与.grad中已有的梯度加和,下次迭代前,需要将.grad中的梯度清零,否则影响下一轮迭代的梯度值。经过多轮训练,模型的loss不断变小。

以上就是使用PyTorch的Tensor和自动求导来构建神经网络模型的过程。

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码