百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

基于1D-CNN的齿轮故障诊断及TSNE可视化

toyiye 2024-06-21 12:02 11 浏览 0 评论

水一水吧。


数据来自kaggle。

#加载相关模块
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
#数据路径
Directory='Gear Data\BrokenTooth'
for root, dirs, files in os.walk(Directory):
for i in range (len(files)):
print(files[i])

path = os.path.join(root,files[0])
path
df_temp = pd.read_csv(path)
df_temp
#时域波形
plt.plot(df_temp.iloc[:,0])
#设置标签
load_col = [int(files[0][5:-4])/100 for j in range(len(df_temp))]
lab='F'
label_col = [lab for j in range(len(df_temp))]
label_col
df_temp['load']=load_col
df_temp['fault']=label_col
df_temp
#数据集处理
def MakeDataset(Directory,lab):
df=pd.DataFrame(columns=['a1','a2','a3','a4'])
for root, dirs, files in os.walk(Directory):
for i in range (len(files)):
path = os.path.join(root,files[i])
df_temp = pd.read_csv(path)
load_col = [int(files[i][5:-4])/100 for j in range(len(df_temp))]
label_col = [lab for j in range(len(df_temp))]
df_temp['load']=load_col
df_temp['fault']=label_col
df = pd.concat([df,df_temp],axis=0)
print(path)

return df
#故障数据
Directory='Gear Data\BrokenTooth'
df_F = MakeDataset(Directory,lab='F')
df_F
#健康数据
Directory='Gear Data\Healthy'
df_H = MakeDataset(Directory,lab='H')
df_H
#将故障数据和健康数据利用concat函数进行连接并输出,便于后续模型使用
df = pd.concat([df_F,df_H],axis=0)
df.to_csv('Gear_Fault_data.csv',index=False)
#数据归一化操作
df = pd.read_csv('Gear_Fault_data.csv')
from sklearn.preprocessing import StandardScaler
scaler=StandardScaler()
df.iloc[:,:-2]=scaler.fit_transform(df.iloc[:,:-2])
##为 CNN 创建数据集
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical
win_len=100 #窗口长度
stride=200 #移动步长
X=[]
Y=[]
for k in ['F','H']:

df_temp_1 = df[df['fault']==k]

for j in (np.arange(0,1,0.1)):
df_temp_2=df_temp_1[df_temp_1['load']==j]
for i in np.arange(0,len(df_temp_2)-(win_len),stride):
X.append(df_temp_2.iloc[i:i+win_len,:-1])
Y.append(df_temp_2.iloc[i+win_len,-1])
#训练数据
X=np.array(X)
X=X.reshape((X.shape[0],X.shape[1],X.shape[2],1))
#X = np.repeat(X, 3, axis=3) # To repeat into 3 chanel format
#标签
Y=np.array(Y)
encoder= LabelEncoder()
encoder.fit(Y)
encoded_Y = encoder.transform(Y)
OHE_Y = to_categorical(encoded_Y)
#训练集尺寸
X.shape
##T-sne可视化
X_pre_cnn = X.reshape(X.shape[0],X.shape[1]*X.shape[2])
from sklearn.manifold import TSNE
X_t_sne = TSNE(n_components=2, learning_rate='auto',verbose=1, perplexity=40, n_iter=300).fit_transform(X_pre_cnn)
tSNEdf = pd.DataFrame(data = X_t_sne, columns = ['t-SNE component 1', 't-SNE component 2'])
tSNEdf['Fault']=Y
#绘制2个主成分
fig, ax = plt.subplots(figsize=(7,7))
sns.scatterplot(x=tSNEdf['t-SNE component 1'],y=tSNEdf['t-SNE component 2'],hue='Fault',
data=tSNEdf,
legend="full",
alpha=0.3)
plt.show()
#训练集和测试集划分
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,OHE_Y,test_size=0.3,shuffle=True)
#构建CNN模型
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.layers import Input,Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D
no_classes = 2 #2个类别
cnn_model = Sequential()
cnn_model.add(Conv2D(32, kernel_size=(20, 3),activation='relu',input_shape=(X.shape[1],X.shape[2],1),padding='same'))
cnn_model.add(MaxPooling2D((20, 2),strides=(5, 5),padding='same'))
cnn_model.add(Conv2D(64, (10, 3), activation='relu',padding='same'))
cnn_model.add(MaxPooling2D(pool_size=(10, 2),strides=(3, 3),padding='same'))
cnn_model.add(Flatten())
cnn_model.add(Dense(128, activation='relu'))

cnn_model.add(Dense(no_classes, activation='softmax'))
cnn_model.summary()
cnn_model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'])
#设置训练参数并训练CNN
batch_size = 128
epochs = 5
history = cnn_model.fit(X_train, y_train, batch_size=batch_size,epochs=epochs,verbose=1,validation_data=(X_test,y_test),shuffle=True)
#保存模型
cnn_model.save('CNN_model_gear.h5')
##模型性能计算
def inv_Transform_result(y_pred):
y_pred = y_pred.argmax(axis=1)
y_pred = encoder.inverse_transform(y_pred)
return y_pred
#预测
y_pred=cnn_model.predict(X_test)
Y_pred=inv_Transform_result(y_pred)
Y_test = inv_Transform_result(y_test)
from sklearn.metrics import confusion_matrix
#混淆矩阵
plt.figure(figsize=(5,5))
cm = confusion_matrix(Y_test, Y_pred)
f = sns.heatmap(cm, annot=True, fmt='d',xticklabels=encoder.classes_,yticklabels=encoder.classes_)
plt.show()
#输出可视化
dummy_cnn = Model(inputs=cnn_model.input,outputs=cnn_model.layers[5].output)
y_viz = dummy_cnn.predict(X_train)
y_viz.shape
from sklearn.manifold import TSNE
X_t_sne = TSNE(n_components=2, learning_rate='auto',verbose=1, perplexity=40, n_iter=300).fit_transform(y_viz)
tSNEdf = pd.DataFrame(data = X_t_sne, columns = ['principal component 1', 'principal component 2'])
tSNEdf['Fault']=inv_Transform_result(y_train)
# 绘制两个主成分分量
fig, ax = plt.subplots(figsize=(10,10))
sns.scatterplot(x=tSNEdf['principal component 1'],y=tSNEdf['principal component 2'],hue='Fault',
data=tSNEdf,
legend="full",
alpha=0.3)
plt.show()
#Flatten层可视化
dummy_cnn = Model(inputs=cnn_model.input,outputs=cnn_model.layers[4].output)
y_viz = dummy_cnn.predict(X_train)
from sklearn.manifold import TSNE
X_t_sne = TSNE(n_components=2, learning_rate='auto',verbose=1, perplexity=40, n_iter=300).fit_transform(y_viz)
tSNEdf = pd.DataFrame(data = X_t_sne, columns = ['t-SNE component 1', 't-SNE component 2'])
tSNEdf['Fault']=inv_Transform_result(y_train)
# 绘制两个主成分
fig, ax = plt.subplots(figsize=(7,7))
sns.scatterplot(x=tSNEdf['t-SNE component 1'],y=tSNEdf['t-SNE component 2'],hue='Fault',
data=tSNEdf,
legend="full",
alpha=0.3)
plt.show()





完整代码:https://mbd.pub/o/bread/mbd-Y5yYk5hv

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码