百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

深度学习中的损失函数

toyiye 2024-06-21 12:08 9 浏览 0 评论

前言

你可以用神经网络来完成各种各样的任务,比如对数据进行分类,把动物图片分成猫和狗,或者做回归任务,预测每个月的收入等等。不同的任务有不同的输出,也需要用不同类型的损失函数来评估。

损失函数的选择会影响算法的性能。如果你能正确地设置损失函数,你的模型就能按照你的期望运行。

幸运的是,我们可以根据机器学习任务的特点来选择合适的损失函数。

本文将介绍深度学习中常用的损失函数,以及如何构建自定义损失函数。读完本文,你应该能够为你的项目挑选合适的损失函数。

什么是损失函数?

在我们介绍损失函数的具体内容之前,让我们回顾一下损失函数的概念。

损失函数是用来衡量预测输出和给定的目标值之间的差距。损失函数告诉我们算法模型离达到预期结果还差多少。“损失”一词意味着模型因为没有产生预期结果而要承担的代价。

例如,一个损失函数(我们用J表示)可以接受以下两个参数:

预测输出(y_pred) 目标值(y)


神经网络损失函数的图示


这个函数会通过比较模型的预测输出和期望输出来评估模型的性能。如果y_pred和y之间的差距很大,那么损失值就会很高。

如果差距很小或者值几乎相等,那么损失值就会很低。因此,你需要使用一个能够在模型训练时根据数据集的情况给予适当惩罚的损失函数。

损失函数会根据算法要解决的问题的性质而有所不同。因此为了解决不同类别的问题,应选择不同类别的损失函数。

损失函数的分类

损失函数大致可以广义地分为 3 种类型:分类损失 回归损失 排名损失

当模型预测连续值(如人的年龄)时,将使用回归损失函数。 当模型预测离散值(例如电子邮件是否为垃圾邮件)时,将使用分类损失函数。 当模型预测输入之间的相对距离时,将使用排名损失函数,例如根据产品在电子商务搜索页面上的相关性对产品进行排名。

现在我们将探讨 PyTorch 中不同类型的损失函数,进行可视化展示以理解,并且给出使用方法:

分类损失(待完成)

回归损失

平均绝对误差损失 MAE Loss/L1 Loss(Mean Absolute Error Loss/L1 Loss)

平均绝对误差是目标变量和预测变量之间的绝对误差的平均值,它反映了预测的准确性,而不考虑误差的正负。平均绝对误差的取值范围是从 0 到无穷大,越接近 0 表示预测越准确。

从零实现并可视化平均绝对误差损失

import matplotlib.pyplot as plt
import numpy as np

# 定义平均绝对误差损失函数
def mae_loss(y_true, y_pred):
  return np.mean(abs(y_true - y_pred))

# 定义真值
y_true = 0

# 定义预测值的取值区间与频率
x = np.linspace(-10000, 10000, 100)

# 计算对每个预测值的均方误差损失值
y = [mae_loss(y_true, x[i]) for i in range(len(x))]
#列表推导是从另一个可迭代对象(如列表、元组、范围等)创建列表的简洁方法。列表推导的一般语法为:
#[expression for item in iterable if condition]
#这将通过将表达式应用于满足条件的可迭代对象中的每个项目来创建新列表。该条件是可选的,可以省略。

#等价于:
# for i in range(len(x)):
#   y.append(mae_loss(y_true, x[i]))
#.append 加元素到列表末尾


# 绘制损失函数
plt.plot(x, y)
plt.xlabel("Predictions")
plt.ylabel("MAE Loss")
plt.title("MAE Loss vs Predictions")
plt.show()


平均绝对误差损失(Y轴)与预测值(X轴)的关系图


均方误差损失 MSE Loss(Mean Squared Error Loss)

均方误差 (MSE) 是最常用的回归损失函数。MSE 是目标变量和预测值之间的平方距离之和。

L_{MSE} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2

其中 $n$ 是样本数量,$y_i$ 是第 $i$ 个样本的真实值,$\hat{y}_i$ 是第 $i$ 个样本的预测值。

另一种可能的写法是:

L_{MSE} = \frac{1}{n} \| y - \hat{y} \|^2

其中 $y$$\hat{y}$ 是真实值和预测值的向量,$\| \cdot \|$ 表示欧几里得范数。

从零实现并可视化均方误差损失

import matplotlib.pyplot as plt
import numpy as np

# Define the MSE loss function
def mse_loss(y_true, y_pred):
  return np.mean((y_true - y_pred) ** 2)

# Define the true value
y_true = 0

# Define the range of predictions
x = np.linspace(-10000, 10000, 100)

# Compute the MSE loss for each prediction using a list comprehension
y = [mse_loss(y_true, x[i]) for i in range(len(x))]
#列表推导是从另一个可迭代对象(如列表、元组、范围等)创建列表的简洁方法。列表推导的一般语法为:
#[expression for item in iterable if condition]
#这将通过将表达式应用于满足条件的可迭代对象中的每个项目来创建新列表。该条件是可选的,可以省略。

#等价于:
# for i in range(len(x)):
#   y.append(mse_loss(y_true, x[i]))
#.append 加元素到列表末尾


# Plot the loss function
plt.plot(x, y)
plt.xlabel("Predictions")
plt.ylabel("MSE Loss")
plt.title("MSE Loss vs Predictions")
plt.show()

均方误差损失(Y轴)与预测值(X轴)的关系图


均方根误差 RMSE()

RMSE 是均方根误差的英文缩写,它是一种衡量预测值和真实值之间差异的指标。RMSE 是均方误差(MSE)的平方根,MSE 是预测值和真实值之差的平方的平均值。RMSE 越小,说明预测值和真实值越接近,模型的拟合效果越好。均方根误差是与平均绝对误差处于相同的尺度,这方便了我们后续对平均绝对误差与均方误差的比较。

RMSE 的计算公式如下:

RMSE = \sqrt{\frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2}

其中 $n$ 是样本数量,$y_i$ 是第 $i$ 个样本的真实值,$\hat{y}_i$ 是第 $i$ 个样本的预测值。

另一种写法是:

RMSE = \sqrt{\frac{1}{n} \| y - \hat{y} \|^2}

其中 $y$$\hat{y}$ 是真实值和预测值的向量,$\| \cdot \|$ 表示欧几里得范数。

从零实现并可视化均方根误差损失函数

import matplotlib.pyplot as plt
import numpy as np
import math

# Define the MAE loss function
def RMSE_loss(y_true, y_pred):
  return math.sqrt(np.mean((y_true - y_pred)**2))

# Define the true value
y_true = 0

# Define the range of predictions
x = np.linspace(-10000, 10000, 100)

# Compute the RMSE loss for each prediction using a list comprehension
y = [RMSE_loss(y_true, x[i]) for i in range(len(x))]
#列表推导是从另一个可迭代对象(如列表、元组、范围等)创建列表的简洁方法。列表推导的一般语法为:
#[expression for item in iterable if condition]
#这将通过将表达式应用于满足条件的可迭代对象中的每个项目来创建新列表。该条件是可选的,可以省略。

#等价于:
# for i in range(len(x)):
#   y.append(mae_loss(y_true, x[i]))
#.append 加元素到列表末尾


# Plot the loss function
plt.plot(x, y)
plt.xlabel("Predictions")
plt.ylabel("RMSE Loss")
plt.title("RMSE Loss vs Predictions")
plt.show()

均方根误差损失函数(Y轴)与预测值(X轴)的关系图


对平均绝对误差损失函数,与均方误差损失函数的分析

简而言之,使用使用平方误差更容易求解,但使用绝对误差对异常值更可靠。但是为什么?

在训练机器学习模型时,我们的目标都是找到最小化损失函数的点。当然,当预测完全等于真实值时,这两个函数都达到最小值。

我们来比较一下两种情况下的 MAE 和 RMSE(它是 MSE 开平方,使它和 MAE 在同一量级)的数值。在第一种情况下,预测值和真实值很接近,误差的方差很小。在第二种情况下,有一个异常的观测值,误差很大。

左: 错误彼此接近 右: 与其他错误相比,一个错误相差甚远


我们从中观察到什么,它如何帮助我们选择使用哪个损失函数?

MSE 是对误差(y — y_predicted = e)的平方,所以如果 e > 1,误差(e)就会放大很多。如果我们的数据里有一个离群值,那么 e 就会很大,e2 就会更大。这样一来,用 MSE 作为损失函数的模型就会比用 MAE 作为损失函数的模型更倾向于拟合离群值。在上面的第二种情况中,用 RMSE 作为损失函数的模型会为了减小一个离群值的影响而牺牲其他正常的样本,这会降低它的整体性能。

如果训练数据被离群值污染了(也就是说,在训练环境中我们错误地收到了一些不切实际的极大或极小的值,而在测试环境中没有),那么 MAE 损失就很有用。

直观地说,我们可以这样理解:如果我们要用一个预测值来拟合所有试图最小化 MSE 的观测值,那么这个预测值应该是所有目标值的均值。但是,如果我们要最小化 MAE,那么这个预测值应该是所有观测值的中位数。我们知道中位数对离群值的抗干扰能力比均值强,所以 MAE 对离群值的抗干扰能力也比 MSE 强。

使用 MAE 损失(特别是对于神经网络)有一个很大的问题,就是它的梯度始终是一样的,这意味着即使损失值很小,梯度也很大。这对学习不利。为了解决这个问题,我们可以使用动态学习率,当我们接近最优解时,动态学习率会变小。在这种情况下,MSE 就表现得很好,即使用固定的学习率也能收敛。对于较大的损失值,MSE 损失的梯度较高,当损失接近 0 时变小,使得在训练结束时更加精确(见下图)。


选择使用哪个损失函数

应结合具体问题的实践,选择合适的损失函数。

如果离群值对实际问题很重要,需要检出,那么应该用 MSE 作为损失函数。反之,如果我们认为离群值只是数据损坏的表现,那么应该用 MAE 作为损失函数。

L1 vs. L2 Loss function – Rishabh Shukla (rishy.github.io),这篇文章对使用 L1 损失和 L2 损失的回归模型在有无离群值的情况下的性能进行了比较。文中,L1 损失和 L2 损失就是 MAE 和 MSE 的别名。

L1 损失对离群值更鲁棒,但它的导数不连续,所以求解效率低。L2 损失对离群值很敏感,但它有更稳定和封闭的解(通过令导数为 0 得到)。

两者的问题:在某些实际应用场景下,这两种损失函数都不能给出理想的预测。

比如说,如果我们的数据中 90% 的观测值的真实目标值是 150,剩下 10% 的目标值在 0-30 之间。那么,用 MAE 作为损失函数的模型可能会把所有观测值都预测为 150,忽略了 10% 的离群情况,因为它会试图接近中位数。同样的情况下,用 MSE 的模型会给出很多在 0 到 30 范围内的预测,因为它会偏向于离群值。在很多实际应用场景中,这两种结果都不是我们想要的。

那么这种情况怎么办呢?一个简单的办法是对目标变量进行转换。另一个办法是尝试不同的损失函数。这就引出了我们的第三种损失函数:平滑平均绝对误差

平滑平均绝对误差(Huber Loss)

平滑平均绝对误差损失函数比平方误差损失函数更不受离群值的影响。它在 0 点也是可导的。它基本上是绝对误差,但当误差很小的时候,它就变成了平方误差。误差多小才算很小,取决于一个可以调节的超参数 δ(delta)。当 δ 接近 0 时,Huber 损失和 MSE 很像,当 δ 很大时,Huber 损失和 MAE 很像。

import matplotlib.pyplot as plt
import numpy as np
import math
#import torch
# Define the smooth_MAE loss function

def smooth_mae_loss(y_true, y_pred, delta=1.0):
    diff = np.abs(y_true - y_pred)
    mask = diff < delta
    return np.where(mask, diff - 0.5 * delta, 0.5 * diff ** 2 / delta)

# Define the true value
y_true = 0

# Define the range of predictions
x = np.linspace(-10, 10, 100)

# Define the list of delta values
deltas = [0.1, 1, 10]

# Define the list of colors or linestyles for plotting
colors = ["r", "g", "b"]
linestyles = ["-", "--", "-."]

# Plot the loss function for each delta value
plt.figure()
for i in range(len(deltas)):
    # Compute the smooth_MAE loss for each prediction using a list comprehension
    y = [smooth_mae_loss(y_true, x[j], delta=deltas[i]) for j in range(len(x))]
    # Plot the loss function with the corresponding color and linestyle
    plt.plot(x, y, color=colors[i], linestyle=linestyles[i], label=f"delta={deltas[i]}")
plt.xlabel("Predictions")
plt.ylabel("Smooth MAE Loss")
plt.title("Smooth MAE Loss vs Predictions")
plt.legend()
plt.show()

Y:平滑平均绝对误差,X:预测值


δ 的选择很重要,因为它决定了我们把哪些误差看作是离群值。误差大于 δ 的时候,用 L1 损失来最小化(对大的离群值不敏感),误差小于 delta 的时候,用 L2 损失来最小化(对小的误差更精确)。

为什么要使用平滑平均绝对误差损失函数? 使用 MAE 训练神经网络有一个很大的问题,就是它的梯度一直很大,这可能导致在用梯度下降法训练结束时错过最优解。对于 MSE,当损失接近最优解时,梯度变小,使得训练更加精确。

在这种情况下,平滑平均绝对误差损失函数就很有用,因为它在最优解附近变平滑,从而减小梯度。而且它比 MSE 更能抵抗离群值的干扰。所以,它结合了 MSE 和 MAE 的优点。不过,平滑平均绝对误差损失函数的问题是我们可能需要调整超参数 delta,这是一个反复试验的过程。

对数损失函数 Log-cosh Loss

对数损失函数是预测值和真实值之间差异的双曲余弦的对数的平均值,它比 L2 更平滑。它的公式如下:

L(y, \hat{y}) = \sum_{i=1}^n \log(\cosh(y_i - \hat{y}_i))

其中,$y$是真实值,$\hat{y}$是预测值,$n$是样本数量。

Log-cosh loss的含义是:

  • 当预测值和真实值之间的差异很小(接近0)时,log-cosh loss近似于平方损失(MSE),即$\log(\cosh(x)) \approx \frac{x^2}{2}$。这时,log-cosh loss对小误差有较大的惩罚,可以提高模型的精度。
  • 当预测值和真实值之间的差异很大(远离0)时,log-cosh loss近似于绝对损失(MAE),即$\log(\cosh(x)) \approx |x| - \log(2)$。这时,log-cosh loss对大误差有较小的惩罚,可以提高模型的鲁棒性。

因此,log-cosh loss可以看作是平方损失和绝对损失的折中,既能保证模型的精度,又能抵抗异常值的影响。

绘制对数损失函数的图形

import matplotlib.pyplot as plt
import numpy as np
import math

# Define the MAE loss function
def log_cosh_loss(y_true, y_pred):
  return np.log(np.cosh(y_true-y_pred))

# Define the true value
y_true = 0

# Define the range of predictions
x = np.linspace(-10, 10, 100)

# Compute the RMSE loss for each prediction using a list comprehension
y = [log_cosh_loss(y_true, x[i]) for i in range(len(x))]


# Plot the loss function
plt.plot(x, y)
plt.xlabel("Predictions")
plt.ylabel("RMSE Loss")
plt.title("RMSE Loss vs Predictions")
plt.show()

log-cosh主要像均方误差一样工作,但不会被偶尔出现的非常大的误差所影响。它具有 Huber 损失的所有优点,而且和 Huber 损失不同的是,它在任何地方都可以求二阶导数。log(cosh(x))(x ** 2) / 2xabs(x) - log(2)x

为什么我们需要二阶导数?许多像 XGBoost 这样的 ML 模型实现都使用牛顿法来寻找最优解,这就需要用到二阶导数(Hessian)。对于像 XGBoost 这样的 ML 框架来说,能求二阶导数的函数更有利。

XgBoost中使用的目标函数。注意对一阶和二阶导数的依赖性


分位数损失 Quantile Loss (未完成)

分位数损失是一种用于解决回归问题的损失函数,它是预测值和真实值之间差异的绝对值乘以一个分位数权重的平均值。它的公式如下:

分位数损失的含义是:

  • 当预测值和真实值之间的差异为正(即预测值大于真实值)时,分位数损失函数乘以一个正的权重$\tau$,这个权重表示了对高估的惩罚程度。当$\tau$接近1时,高估的惩罚很大;当$\tau$接近0时,高估的惩罚很小。
  • 当预测值和真实值之间的差异为负(即预测值小于真实值)时,分位数损失函数乘以一个负的权重$\tau - 1$,这个权重表示了对低估的惩罚程度。当$\tau$接近1时,低估的惩罚很小;当$\tau$接近0时,低估的惩罚很大。

因此,分位数损失可以看作是一种灵活的损失函数,它可以根据不同的分位数参数来调整对高估和低估的偏好3。

在深度学习中,分位数损失有一些应用,例如:

  • 可以用于变分自编码器(VAE)中,来提高输出质量和不确定性量化。
  • 可以用于二分类问题中,来估计条件分位数和置信区间。
  • 可以用于时间序列预测中,来提高模型的鲁棒性和稳定性。

排名损失

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码