百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

主动学习简介及示例

toyiye 2024-06-21 12:38 11 浏览 0 评论

主动学习是一种半监督机器学习技术,它通过从学习过程(损失)的角度选择最重要的样本来标记较少的数据。在数据量大、贴标签率高的情况下,会对项目成本产生巨大影响。例如,对象检测和np - ner问题。

导入Python库

import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
sess = tf.InteractiveSession()

实验数据

#load 4000 of MNIST data for train and 400 for testing
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_full = x_train[:4000] / 255
y_full = y_train[:4000]
x_test = x_test[:400] /255
y_test = y_test[:400]
x_full.shape, y_full.shape, x_test.shape, y_test.shape

((4000, 28, 28), (4000,), (400, 28, 28), (400,))

plt.imshow(x_full[3999])

我将使用MNIST机器学习数据集的一个子集,该机器学习数据集是60K的带有标签的数字图片和10K的测试样本。为了更快的训练,需要4000个样本(图片)进行训练,400个样本(图片)进行测试(神经网络在训练过程中永远看不到)。为了归一化,我将灰度图像点除以255。

机器学习模型,训练和labeling过程

#build computation graph
x = tf.placeholder(tf.float32, [None, 28, 28])
x_flat = tf.reshape(x, [-1, 28 * 28])
y_ = tf.placeholder(tf.int32, [None])
W = tf.Variable(tf.zeros([28 * 28, 10]), tf.float32)
b = tf.Variable(tf.zeros([10]), tf.float32)
y = tf.matmul(x_flat, W) + b
y_sm = tf.nn.softmax(y)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_, logits=y))
train = tf.train.AdamOptimizer(0.1).minimize(loss)
accuracy = tf.reduce_mean(tf.cast(tf.equal(y_, tf.cast(tf.argmax(y, 1), tf.int32)), tf.float32))

作为一个框架,我们可以使用TensorFlow计算图来构建10个神经元(每个数字)。W和b是神经元的权重。softmax输出y_sm将帮助处理数字的概率(置信度)。损失将是一个典型的“softmaxed”交叉熵之间的预测和标记的数据。优化器的选择是流行的Adam,其学习率几乎是默认的- 0.1。测试数据的准确性将作为主要度量指标

def reset():
 '''Initialize data sets and session'''
 global x_labeled, y_labeled, x_unlabeled, y_unlabeled
 x_labeled = x_full[:0]
 y_labeled = y_full[:0]
 x_unlabeled = x_full
 y_unlabeled = y_full
 tf.global_variables_initializer().run()
 tf.local_variables_initializer().run() 
def fit():
 '''Train current labeled dataset until overfit.'''
 trial_count = 10
 acc = sess.run(accuracy, feed_dict={x:x_test, y_:y_test})
 weights = sess.run([W, b])
 while trial_count > 0:
 sess.run(train, feed_dict={x:x_labeled, y_:y_labeled})
 acc_new = sess.run(accuracy, feed_dict={x:x_test, y_:y_test})
 if acc_new <= acc:
 trial_count -= 1
 else:
 trial_count = 10
 weights = sess.run([W, b])
 acc = acc_new
 sess.run([W.assign(weights[0]), b.assign(weights[1])]) 
 acc = sess.run(accuracy, feed_dict={x:x_test, y_:y_test})
 print('Labels:', x_labeled.shape[0], '\tAccuracy:', acc)
def label_manually(n):
 '''Human powered labeling (actually copying from the prelabeled MNIST dataset).'''
 global x_labeled, y_labeled, x_unlabeled, y_unlabeled
 x_labeled = np.concatenate([x_labeled, x_unlabeled[:n]])
 y_labeled = np.concatenate([y_labeled, y_unlabeled[:n]])
 x_unlabeled = x_unlabeled[n:]
 y_unlabeled = y_unlabeled[n:]

为了更方便的编码,我在这里定义了这三个过程。

  • reset() -清空已标记的机器学习数据集,将所有数据放入未标记的机器学习数据集中,并重置会话变量
  • fit()——运行一个试图达到最佳准确度的训练。如果在前十次尝试中不能提高,训练就会停在最后一个最好的结果。我们不能只使用大量的训练时间,因为模型往往很快就会过度拟合,或者需要进行L2正则化。
  • label_manual()——这是一种模拟人类数据标记的方法。实际上,我们从已经标记的MNIST数据集中获取标签。

Ground Truth

#train full dataset of 1000
reset()
label_manually(4000)
fit()

Labels: 4000 Accuracy: 0.9225

如果我们足够幸运,有足够的资源来标记整个数据集,我们将获得92.25%的准确性。

聚类

#apply clustering
kmeans = tf.contrib.factorization.KMeansClustering(10, use_mini_batch=False)
kmeans.train(lambda: tf.train.limit_epochs(x_full.reshape(4000, 784).astype(np.float32), 10))

centers = kmeans.cluster_centers().reshape([10, 28, 28])
plt.imshow(np.concatenate([centers[i] for i in range(10)], axis=1))

在这里,我尝试使用k-means聚类来找到一组数字,并使用这些信息进行自动标记。我运行Tensorflow聚类估计器,然后可视化结果的十个质心。正如你所看到的,结果远非完美——数字“9”出现了三次,有时还夹杂着“8”和“3”。

随机标记

#try to run on random 400
reset()
label_manually(400)
fit()

Labels: 400 Accuracy: 0.8375

让我们尝试仅标记10%的数据(400个样本),我们将获得83.75%的准确度,远远低于92.25%的ground truth。

#now try to run on 10
reset()
label_manually(10)
fit()

Labels: 10 Accuracy: 0.38

让我们尝试仅标记10个样本,我们仅获得38%的准确度。

#pass unlabeled rest 3990 through the early model
res = sess.run(y_sm, feed_dict={x:x_unlabeled})
#find less confident samples
pmax = np.amax(res, axis=1)
pidx = np.argsort(pmax)
#sort the unlabeled corpus on the confidency
x_unlabeled = x_unlabeled[pidx]
y_unlabeled = y_unlabeled[pidx]
plt.plot(pmax[pidx])

现在我们将使用active learning标记相同10%的数据(400个样本)。为了做到这一点,我们从10个样本中抽取一批样本,并训练一个非常原始的机器学习模型。然后,我们将剩余的数据(3990个样本)通过该机器学习模型传递,并计算最大softmax输出。这将显示所选类是正确答案的概率(换句话说,是神经网络的置信度)。排序后,我们可以在图中看到置信度的分布在20%到100%之间。我们的想法是从LESS CONFIDENT样本中精确选择下一批标记。

#do the same in a loop for 400 samples
for i in range(39):
 label_manually(10)
 fit()
 
 res = sess.run(y_sm, feed_dict={x:x_unlabeled})
 pmax = np.amax(res, axis=1)
 pidx = np.argsort(pmax)
 x_unlabeled = x_unlabeled[pidx]
 y_unlabeled = y_unlabeled[pidx]

在对40批10个样品进行这样的程序之后,我们可以看到所得到的精度几乎为90%。这远远超过随机标记数据的83.75%。

如何处理剩余的未标记数据

#pass rest unlabeled data through the model and try to autolabel
res = sess.run(y_sm, feed_dict={x:x_unlabeled})
y_autolabeled = res.argmax(axis=1)
x_labeled = np.concatenate([x_labeled, x_unlabeled])
y_labeled = np.concatenate([y_labeled, y_autolabeled])
#train on 400 labeled by active learning and 3600 stochasticly autolabeled data
fit()

经典的方法是通过现有机器学习模型运行数据集的其余部分并自动标记数据。然后,使其在训练过程中也许会有助于更好的优化模型。但在我们的例子中,它并没有给我们任何更好的结果。

#pass rest of unlabeled (3600) data trough the model for automatic labeling and show most confident samples
res = sess.run(y_sm, feed_dict={x:x_unlabeled})
y_autolabeled = res.argmax(axis=1)
pmax = np.amax(res, axis=1)
pidx = np.argsort(pmax)
#sort by confidency
x_unlabeled = x_unlabeled[pidx]
y_autolabeled = y_autolabeled[pidx]
plt.plot(pmax[pidx])

#automatically label 10 most confident sample and train for it
x_labeled = np.concatenate([x_labeled, x_unlabeled[-10:]])
y_labeled = np.concatenate([y_labeled, y_autolabeled[-10:]])
x_unlabeled = x_unlabeled[:-10]
fit()

Labels: 410 Accuracy: 0.8975

在这里,我们通过模型评估运行剩余的未标记数据,我们仍然可以看到其余样本的置信度不同。因此,我们的想法是采取一批十个MOST CONFIDENT样本并训练模型。

#run rest of unlabelled samples starting from most confident
for i in range(359):
 res = sess.run(y_sm, feed_dict={x:x_unlabeled})
 y_autolabeled = res.argmax(axis=1)
 pmax = np.amax(res, axis=1)
 pidx = np.argsort(pmax)
 x_unlabeled = x_unlabeled[pidx]
 y_autolabeled = y_autolabeled[pidx]
 x_labeled = np.concatenate([x_labeled, x_unlabeled[-10:]])
 y_labeled = np.concatenate([y_labeled, y_autolabeled[-10:]])
 x_unlabeled = x_unlabeled[:-10]
 fit()

结果

实验准确度

4000个样本92.25%

400个随机样本83.75%

400个活跃学习样本89.75%

+自动标记90.50%

结论

当然,这种方法有其缺点,如计算资源的大量使用以及与早期模型评估相结合的数据标记需要特殊程序的事实。此外,为了测试目的,还需要标记数据。但是,如果标签的成本很高(特别是对于NLP,CV项目),此方法可以节省大量资源并推动更好的项目结果。

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码