百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

用这两种方法向最终用户解释NLP模型的工作原理还是不错的

toyiye 2024-04-27 03:47 18 浏览 0 评论

点击上方关注,All in AI中国

上周,我看了一个关于“NLP的实践特性工程”的演讲。主要是关于LIME和SHAP在文本分类可解释性方面是如何工作的。

我决定写一篇关于它们的文章,因为它们很有趣、易于使用,而且视觉上很吸引人。

所有的机器学习模型都是在更高的维度上运行的,而不是在人脑可以直接看到的维度上运行的,这些机器学习模型都可以被称为黑盒模型,它可以归结为模型的可解释性。特别是在NLP领域中,特征的维数往往很大,说明特征的重要性变得越来越复杂。

LIME & SHAP不仅帮助我们向最终用户解释NLP模型的工作原理,而且帮助我们自己解释NLP模型是如何工作的。

利用 Stack Overflow 问题标签分类数据集,我们将构建一个多类文本分类模型,然后分别应用LIME和SHAP对模型进行解释。由于我们之前已经做过多次文本分类,所以我们将快速构建NLP模型,并着重于模型的可解释性。

数据预处理、特征工程和逻辑回归

 import pandas as pd
	import numpy as np
	import sklearn
	import sklearn.ensemble
	import sklearn.metrics
	from sklearn.utils import shuffle
	from __future__ import print_function
	from io import StringIO
	import re
	from bs4 import BeautifulSoup
from nltk.corpus import stopwords
	from sklearn.model_selection import train_test_split
	from sklearn.feature_extraction.text import CountVectorizer
	from sklearn.linear_model import LogisticRegression
	from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
	import lime
	from lime import lime_text
	from lime.lime_text import LimeTextExplainer
	from sklearn.pipeline import make_pipeline
df = pd.read_csv('stack-overflow-data.csv')
	df = df[pd.notnull(df['tags'])]
	df = df.sample(frac=0.5, random_state=99).reset_index(drop=True)
	df = shuffle(df, random_state=22)
	df = df.reset_index(drop=True)
	df['class_label'] = df['tags'].factorize()[0]
	class_label_df = df[['tags', 'class_label']].drop_duplicates().sort_values('class_label')
	label_to_id = dict(class_label_df.values)
	id_to_label = dict(class_label_df[['class_label', 'tags']].values)
REPLACE_BY_SPACE_RE = re.compile('[/(){}\[\]\|@,;]')
	BAD_SYMBOLS_RE = re.compile('[^0-9a-z #+_]')
	# STOPWORDS = set(stopwords.words('english'))
	
	def clean_text(text):
	 """
	 text: a string
	 
	 return: modified initial string
	 """
text = BeautifulSoup(text, "lxml").text # HTML decoding. BeautifulSoup's text attribute will return a string stripped of any HTML tags and metadata.
	 text = text.lower() # lowercase text
	 text = REPLACE_BY_SPACE_RE.sub(' ', text) # replace REPLACE_BY_SPACE_RE symbols by space in text. substitute the matched string in REPLACE_BY_SPACE_RE with space.
	 text = BAD_SYMBOLS_RE.sub('', text) # remove symbols which are in BAD_SYMBOLS_RE from text. substitute the matched string in BAD_SYMBOLS_RE with nothing. 
	# text = ' '.join(word for word in text.split() if word not in STOPWORDS) # remove stopwors from text
	 return text
df['post'] = df['post'].apply(clean_text)
	
	list_corpus = df["post"].tolist()
	list_labels = df["class_label"].tolist()
	X_train, X_test, y_train, y_test = train_test_split(list_corpus, list_labels, test_size=0.2, random_state=40)
	vectorizer = CountVectorizer(analyzer='word',token_pattern=r'\w{1,}', ngram_range=(1, 3), stop_words = 'english', binary=True)
	train_vectors = vectorizer.fit_transform(X_train)
	test_vectors = vectorizer.transform(X_test)
logreg = LogisticRegression(n_jobs=1, C=1e5)
	logreg.fit(train_vectors, y_train)
	pred = logreg.predict(test_vectors)
	accuracy = accuracy_score(y_test, pred)
	precision = precision_score(y_test, pred, average='weighted')
recall = recall_score(y_test, pred, average='weighted')f1 = f1_score(y_test, pred, average='weighted')print("accuracy = %.3f, precision = %.3f, recall = %.3f, f1 = %.3f" % (accuracy, precision, recall, f1))

我们现在目标并不是产生最好的结果。我想尽快进入LIME & SHAP,这就是接下来发生的事情。

用LIME解释文本预测

从现在开始,这是有趣的部分。下面的代码片段主要是从LIME教程中借来的。

c = make_pipeline(vectorizer, logreg)
	class_names=list(df.tags.unique())
	explainer = LimeTextExplainer(class_names=class_names)
idx = 1877
	exp = explainer.explain_instance(X_test[idx], c.predict_proba, num_features=6, labels=[4, 8])
	print('Document id: %d' % idx)
print('Predicted class =', class_names[logreg.predict(test_vectors[idx]).reshape(1,-1)[0,0]])
	print('True class: %s' % class_names[y_test[idx]])

我们在测试集中随机选择一个文档,它恰好是一个标记为sql的文档,我们的模型也预测它是sql。使用这个文档,我们为标签4 (sql)和标签8 (python)生成解释。

print ('Explanation for class %s' % class_names[4])
print ('\n'.join(map(str, exp.as_list(label=4))))

print ('Explanation for class %s' % class_names[8])
print ('\n'.join(map(str, exp.as_list(label=8))))

很明显,这个文档对标签sql有最高的解释。我们还注意到正负号与特定的标签有关,例如单词"sql"对类sql是正的,而对类python是负的,反之亦然。

我们要为这个文档生成2类标签顶部。

exp = explainer.explain_instance(X_test[idx], c.predict_proba, num_features=6, top_labels=2)
print(exp.available_labels())

它给出了sql和python。

exp.show_in_notebook(text=False)

让我来解释一下这种可视化:

1. 对于本文档,词 "sql"对于类sql具有最高的正分数。

2. 我们的模型预测该文档应该标记为sql,其概率为100%。

3. 如果我们从文档中删除word"sql",我们期望模型预测label sql的概率为100% - 65% = 35%。

4. 另一方面,单词"sql"对于类python是负面的,我们的模型已经了解到单词"range"对于类python有一个小的正面得分。

我们可能想放大并研究类sql的解释,以及文档本身。

exp.show_in_notebook(text=y_test[idx], labels=(4,))

使用SHAP解释文本预测

以下过程是从本教程中学到的。「链接」

from sklearn.preprocessing import MultiLabelBinarizer
	import tensorflow as tf
	from tensorflow.keras.preprocessing import text
	import keras.backend.tensorflow_backend as K
	K.set_session
	import shap
tags_split = [tags.split(',') for tags in df['tags'].values]
	tag_encoder = MultiLabelBinarizer()
	tags_encoded = tag_encoder.fit_transform(tags_split)
	num_tags = len(tags_encoded[0])
	train_size = int(len(df) * .8)
y_train = tags_encoded[: train_size]
	y_test = tags_encoded[train_size:]
class TextPreprocessor(object):
	 def __init__(self, vocab_size):
	 self._vocab_size = vocab_size
	 self._tokenizer = None
	 def create_tokenizer(self, text_list):
	 tokenizer = text.Tokenizer(num_words = self._vocab_size)
	 tokenizer.fit_on_texts(text_list)
	 self._tokenizer = tokenizer
	 def transform_text(self, text_list):
	 text_matrix = self._tokenizer.texts_to_matrix(text_list)
	 return text_matrix
VOCAB_SIZE = 500
	train_post = df['post'].values[: train_size]
	test_post = df['post'].values[train_size: ]
	processor = TextPreprocessor(VOCAB_SIZE)
	processor.create_tokenizer(train_post)
	X_train = processor.transform_text(train_post)
	X_test = processor.transform_text(test_post)
def create_model(vocab_size, num_tags):
	 model = tf.keras.models.Sequential()
	 model.add(tf.keras.layers.Dense(50, input_shape = (VOCAB_SIZE,), activation='relu'))
	 model.add(tf.keras.layers.Dense(25, activation='relu'))
model.add(tf.keras.layers.Dense(num_tags, activation='sigmoid'))
	 model.compile(loss = 'binary_crossentropy', optimizer='adam', metrics = ['accuracy'])
	 return model
 model = create_model(VOCAB_SIZE, num_tags)
	model.fit(X_train, y_train, epochs = 2, batch_size=128, validation_split=0.1)
 print('Eval loss/accuracy:{}'.format(model.evaluate(X_test, y_test, batch_size = 128)))
  • 模型训练完成后,我们使用前200个训练文档作为背景数据集进行集成,并创建一个SHAP explainer对象。
  • 我们在测试集的子集上获得各个预测的属性值。
  • 将索引转换为单词。
  • 使用SHAP的summary_plot方法来显示影响模型预测的主要特性。
attrib_data = X_train[:200]
explainer = shap.DeepExplainer(model, attrib_data)
num_explanations = 20
shap_vals = explainer.shap_values(X_test[:num_explanations])
words = processor._tokenizer.word_index
word_lookup = list()
for i in words.keys():
 word_lookup.append(i)
word_lookup = [''] + word_lookup
shap.summary_plot(shap_vals, feature_names=word_lookup, class_names=tag_encoder.classes_)

  • 单词"want"是我们模型使用的最大信号词,对类jquery预测贡献最大。
  • 单词"php"是我们模型使用的第四大信号词,当然对PHP类贡献最大。
  • 另一方面,单词"php"可能对另一个类有负面信号,因为它不太可能在python文档中看到单词"php"。

关于LIME & SHAP的机器学习可解释性,还有很多需要学习的地方。我只介绍了一小部分NLP。其余的可以在Github上找到。NLP-with-Python/LIME_SHAP_StackOverflow.ipynb at master · susanli2016/NLP-with-Python · GitHub

相关推荐

Asterisk-ARI对通道中的DTMF事件处理

Asterisk通道中关于DTMF处理是一个非常重要的功能。通过DTMF可以实现很多的业务处理。现在我们介绍一下关于ARI对通道中的DTMF处理,我们通过自动话务员实例来说明Asterisk如何创建一...

PyQt5 初次使用(pyqt5下载官网)

本篇文章默认已安装Python3,本篇文章默认使用虚拟环境。安装pipinstallPyQt5PyQt一些图形界面开发工具QtDesigner、国际化翻译工具Liguist需要另外...

Qt开发,使用Qt for Python还是Qt C++ Qt开发,使用Qt for

Qt开发使用QtforPython还是QtC++?1.早些年写过一个PyQt5的项目,最近几年重构成QtC++了,其中有个人原因,如早期代码写得烂,...

最简单方法!!用python生成动态条形图

最近非常流行动态条形图,在B站等视频网站上,此类视频经常会有上百万的播放量,今天我们通过第三方库:bar_chart_race(0.2版本)来实现动态条形图的生成;生成的效果如图:问题:...

Asterisk通道和ARI接口的通信(aau通道数)

Asterisk通道和ARI详解什么是通道Asterisk中,通道是介于终端和Asterisk自己本身的一个通信媒介。它包含了所有相关信息传递到终端,或者从终端传递到Asterisk服务器端。这些信...

Python GUI-长链转短链(长链接转化成短链接java)

当我们要分享某一个链接给别人,或是要把某个链接放入帖子中时,如果链接太长,则会占用大量空间,而且很不美观。这时候,我们可以结束长链转短链工具进行转换。当然可以直接搜索在线的网站进行转换,但我们可以借此...

Python 的hash 函数(python的hash函数)

今天在看python的hash函数源码的时候,发现针对不同的数据类型python实现了不同的hash函数,今天简单介绍源码中提到的hash函数。(https://github.com/pyth...

8款Python GUI开源框架,谁才是你的菜?

作为Python开发者,你迟早都会用到图形用户界面来开发应用。本文千锋武汉Python培训小编将推荐一些PythonGUI框架,希望对你有所帮助。1、Python的UI开发工具包Kivy...

python适合开发桌面软件吗?(python可不可以开发桌面应用软件)

其实Python/Java/PHP都不适合用来做桌面开发,Java还是有几个比较成熟的产品的,比如大名鼎鼎的Java集成开发环境IntelliJIDEA、Eclipse就是用Java开发的,不过PH...

CryptoChat:一款功能强大的纯Python消息加密安全传输工具

关于CryptoChatCryptoChat是一款功能强大的纯Python消息加密安全传输工具,该工具专为安全研究专家、渗透测试人员和红蓝队专家设计,该工具可以完全保证数据传输中的隐私安全。该工具建立...

为什么都说Python简单,但我觉得难?

Python普遍被大家认为是编程语言中比较简单的一种,但有一位电子信息的学生说自己已经学了C语言,但仍然觉得Python挺难的,感觉有很多疑问,像迭代器、装饰器什么的……所以他提出疑问:Python真...

蓝牙电话-关联FreeSwitch中继SIP账号通过Rest接口

蓝牙电话-关联FreeSwitch中继SIP账号通过Rest接口前言上一篇章《蓝牙电话-与FreeSwitch服务器和UA坐席的通话.docx》中,我们使用开源的B2B-UA当中经典的FreeSWIT...

技术分享|Sip与WebRTC互通-SRProxy开源库讲解

SRProxy介绍目前WebRTC协议跟SIP协议互通场景主要运用在企业呼叫中心、企业内部通信、电话会议(PSTN)、智能门禁等场景,要想让WebRTC与SIP互通,要解决两个层面的...

全网第N篇SIP协议之GB28181注册 JAVA版本

鉴于网上大部分关于SIP注册服务器编写都是C/C++/python,故开此贴,JAVA实现也贴出分享GB28181定义了了基于SIP架构的视频监控互联规范,而对于多数私有协议实现的监控系统...

「linux专栏」top命令用法详解,再也不怕看不懂top了

在linux系统中,我们经常使用到的一个命令就是top,它主要是用来显示系统运行中所有的进程和进程对应资源的使用等信息,所有的用户都可以使用top命令。top命令内容量丰富,可令使用者头疼的是无法全部...

取消回复欢迎 发表评论:

请填写验证码