
An Interpretable Neural Network for Mechanical Fault Diagnosis Based on Frequency Activation Maps

toyiye 2024-06-21 12:39

This method uses frequency activation maps (FAM) to explain part of what a network learns from machine vibration signals: the map visualizes the time-domain classification criteria learned by the model in the frequency domain. The model itself is also easy to understand, which contributes to the interpretability of mechanical fault diagnosis.
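In the notation of the freq_activation_map function implemented below, the map for class c sums the mean-normalized power spectra P[a_k] of the last convolution layer's channel activations a_k, weighted by the positive classifier weights w_{c,k}:

FAM_c(f) = \sum_{k:\, w_{c,k} > 0} w_{c,k}\, P[a_k](f), \qquad P[a](f) = \frac{|\mathrm{rFFT}(a)(f)|^2}{\mathrm{mean}_f\, |\mathrm{rFFT}(a)(f)|^2}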

The data used are the "normal baseline" data and the "12k drive-end bearing fault" data from the Case Western Reserve University (CWRU) bearing dataset. Dataset preprocessing configuration: window length = 2048, overlap = 25%.
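As a minimal sketch of this segmentation step (the CSV files loaded later are assumed to already contain windowed rows; the segment_signal helper here is illustrative, not the original preprocessing script):

import numpy as np

def segment_signal(signal, window_size=2048, overlap=0.25):
    """Slice a 1-D vibration signal into overlapping windows.

    With overlap = 25%, consecutive windows share 512 of 2048 samples,
    i.e. the hop size is window_size * (1 - overlap) = 1536.
    """
    hop = int(window_size * (1 - overlap))
    n_windows = (len(signal) - window_size) // hop + 1
    return np.stack([signal[i * hop : i * hop + window_size]
                     for i in range(n_windows)])

# e.g. two minutes of data at 12 kHz yields 937 windows
windows = segment_signal(np.random.randn(12000 * 120))
print(windows.shape)  # (937, 2048)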

First, install the required deep-learning packages

pip install torch
pip install torchvision
pip install pytorch-model-summary

Load the modules

import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import seaborn as sns
from pathlib import Path
import os
import pytorch_model_summary
from torch.utils.data import Dataset, DataLoader
import torchvision as tv
import torch.nn.functional as F
from datetime import datetime
from tqdm.notebook import trange, tqdm
import matplotlib.pyplot as plt

Detect whether the notebook is running on Colab and set the data and checkpoint paths accordingly

if 'google.colab' in str(get_ipython()):
    print('Running on Colab')
    from google.colab import drive

    drive.mount('/content/gdrive')
    data_dir = Path('/content/gdrive/My Drive', 'dataset', 'fault diagnosis')
    checkpoint_dir = Path('/content/gdrive/My Drive', 'checkpoints', 'fam')
else:
    data_dir = Path('.', 'dataset')
    checkpoint_dir = Path('.', 'checkpoints', 'fam')

data_dir.mkdir(exist_ok=True, parents=True)
checkpoint_dir.mkdir(exist_ok=True, parents=True)

Index the training and test CSV files present in the data directory

file_name = {}
files = os.listdir(data_dir)
for file in files:
    if file.find('Train') != -1:
        file_name['Train'] = file
    elif file.find('Test') != -1:
        file_name['Test'] = file

Define the convolution block

def conv_block(inp, oup, kernel, stride, padding, bias, num):
    # LayerNorm over the last (length) dimension; freeze its bias so that
    # only the per-position scale is learned.
    layer_norm = nn.LayerNorm(2048)
    for name, param in layer_norm.named_parameters():
        if name == 'bias':
            param.requires_grad = False

    seq = nn.Sequential()
    # Inputs are shaped (batch, C, 1, L), so a 2-D convolution with a
    # (1, k) kernel convolves along the time axis only.
    seq.add_module(f'Conv_{num}', nn.Conv2d(inp, oup, kernel, stride, padding, bias=bias))
    seq.add_module(f'LayerNorm_{num}', layer_norm)
    seq.add_module(f'Tanh_{num}', nn.Tanh())
    return seq

Call this function after each optimizer step to constrain the convolution weight norms

def max_norm_(model, min_val=0.8, max_val=1.2, eps=1e-8):
    # Rescale every convolution weight in place so that its squared
    # Frobenius norm (per output channel) stays within [min_val, max_val].
    for name, param in model.named_parameters():
        if 'bias' not in name and 'Conv' in name:
            with torch.no_grad():
                norm = param.norm('fro', dim=[3, 1], keepdim=True)**2
                desired = torch.clamp(norm, min=min_val, max=max_val)
                param.copy_(param * torch.sqrt(desired / (eps + norm)))
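A quick, illustrative sanity check that the constraint holds after one call (the tiny module here is hypothetical):

# After max_norm_, each filter's squared Frobenius norm should lie in
# [0.8, 1.2] (up to eps).
m = nn.Sequential()
m.add_module('Conv_0', nn.Conv2d(1, 4, (1, 7), bias=False))
max_norm_(m)
sq_norms = m.Conv_0.weight.norm('fro', dim=[3, 1])**2
print(sq_norms.min().item(), sq_norms.max().item())  # both within [0.8, 1.2]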

def global_power_pooling(x):
    # Average signal power per channel: mean of x**2 over both spatial
    # dimensions, yielding one scalar per channel.
    return torch.pow(x, 2).mean(3).mean(2)

Define the frequency-activation-map network class

class ConvolutionFAM(nn.Module):
    def __init__(self, in_channel, kernel_size, bias, stride, channels, n_class=12):
        super(ConvolutionFAM, self).__init__()
        block = conv_block
        layers = []
        for idx, out_channel in enumerate(channels):
            layers.append(block(in_channel, out_channel, kernel_size, stride[idx], 'same', bias, idx))
            in_channel = out_channel
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(channels[-1], n_class)

    def forward(self, x):
        # x_freq keeps the last conv layer's full-length activations, from
        # which the frequency activation map is computed later.
        x_freq = self.features(x)
        x_out = global_power_pooling(x_freq)
        x_out = self.classifier(x_out)
        return x_freq, x_out
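A quick shape check under the hyperparameters used below (a batch of 2 for brevity; with stride 1 and 'same' padding the activations keep the full window length):

net = ConvolutionFAM(in_channel=1, kernel_size=(1, 7), bias=False,
                     stride=[1] * 10,
                     channels=[4, 4, 8, 8, 16, 16, 16, 32, 32, 32], n_class=12)
a, logits = net(torch.zeros(2, 1, 1, 2048))
print(a.shape, logits.shape)  # torch.Size([2, 32, 1, 2048]) torch.Size([2, 12])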

Define the bearing dataset class

class BearingDataset(Dataset):
    def __init__(self, file_path, window_size):
        # Each CSV row holds one window of `window_size` samples followed
        # by its integer class label.
        self.dataset = pd.read_csv(file_path, header=None).reset_index(drop=True)
        self.window_size = window_size

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x = self.dataset.iloc[idx, 0:self.window_size]
        y = self.dataset.iloc[idx, self.window_size]
        return torch.reshape(torch.FloatTensor(x.to_numpy()), (1, 1, self.window_size)), torch.tensor(y)
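A minimal usage sketch with a synthetic CSV in the expected layout (2048 feature columns plus one label column; the file name is illustrative):

demo = pd.DataFrame(np.random.randn(4, 2048))
demo[2048] = 0  # integer class label column
demo.to_csv('demo_bearing.csv', header=False, index=False)
x0, y0 = BearingDataset('demo_bearing.csv', window_size=2048)[0]
print(x0.shape, y0)  # torch.Size([1, 1, 2048]) tensor(0)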

Define the training loop

def training(model, dataloader, optimizer):
    epoch_loss = 0.0
    epoch_accuracy = 0.0
    dataset_len = len(dataloader.dataset)
    model.train()

    for batch in tqdm(dataloader):
        x, y = batch
        if torch.cuda.is_available():
            x, y = x.cuda(), y.cuda()

        optimizer.zero_grad()
        _, pred = model(x)
        softmax_pred = F.softmax(pred, dim=1)
        loss = F.cross_entropy(pred, y)

        # cross_entropy returns the batch mean, so weight it by the batch
        # size before averaging over the whole dataset.
        epoch_loss += loss.detach() * x.size(0)
        epoch_accuracy += torch.sum(torch.argmax(softmax_pred, dim=1) == y, dtype=torch.int64).detach()

        loss.backward()
        optimizer.step()
        max_norm_(model)  # re-impose the weight-norm constraint after each step

    epoch_loss /= dataset_len
    epoch_accuracy = epoch_accuracy / dataset_len * 100

    return epoch_loss, epoch_accuracy

Define the functions related to the frequency activation map

def freq_activation_map(model, input, width, channels, target_label):
    '''
        Params:
            model : neural network object
            input : batch of time-series windows, shape (batch, 1, 1, L)
            width : length of power_spectrum(input), i.e. L // 2 + 1
            channels : number of channels in the model's last conv layer
            target_label : class index the map is computed for
    '''
    fam = torch.zeros(input.shape[0], 1, width)
    if torch.cuda.is_available():
        fam = fam.cuda()

    with torch.no_grad():
        freq, labels = model(torch.reshape(input, (-1, 1, 1, input.shape[-1])))
        # Keep only the samples the model actually predicts as target_label.
        labels = torch.argmax(F.softmax(labels, dim=1), dim=1)
        labels = torch.where(labels == target_label, 1., 0.)
        labels = torch.unsqueeze(torch.reshape(labels, [-1, 1]).repeat(1, width), 1)
        for c in range(channels):
            sp = freq[:, c, :, :]
            if torch.cuda.is_available():
                sp = sp.cuda()

            sp = power_spectrum(sp)
            sp = sp * labels

            # As in CAM, only channels with a positive classifier weight
            # contribute to the map.
            if model.classifier.weight[target_label, c] > 0:
                fam += model.classifier.weight[target_label, c] * sp

    return torch.squeeze(torch.sum(fam, dim=0)), torch.sum(labels[:, 0, 0], dim=0)

def normalize_freq(fam):
    max_fam, _ = torch.max(fam, dim=-1, keepdim=True)
    return torch.div(fam, max_fam)

# Mean-normalized periodogram: |rFFT|^2 divided by its average over the
# frequency bins (a single periodogram per window, not Welch's averaged method).
def power_spectrum(t_freq):
    result = torch.abs(torch.fft.rfft(t_freq))**2
    return result / torch.mean(result, dim=2, keepdim=True)
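A quick, illustrative sanity check of power_spectrum on a pure tone: the spectrum should peak at the tone's rFFT bin.

# A 100 Hz sine sampled at 12 kHz over 2048 samples peaks near bin
# round(100 * 2048 / 12000) = 17.
t = torch.arange(2048) / 12000.
tone = torch.sin(2 * torch.pi * 100. * t).reshape(1, 1, -1)
print(torch.argmax(power_spectrum(tone), dim=2))  # tensor([[17]])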

Training-related parameters

lr = 1e-3
batch_size = 256
epochs = 100
sampling_rate = 12000
labels = ['Normal', 'FAULT7_INNER', 'FAULT14_INNER', 'FAULT21_INNER', 'FAULT28_INNER',
    'FAULT7_BALL', 'FAULT14_BALL', 'FAULT21_BALL', 'FAULT28_BALL',
    'FAULT7_OUTER', 'FAULT14_OUTER', 'FAULT21_OUTER']

window_size = 2048
n_class = 12
bias = False
kernel_size = (1, 7)
in_channel = 1
stride = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
channels = [4, 4, 8, 8, 16, 16, 16, 32, 32, 32]


model = ConvolutionFAM(in_channel, kernel_size, bias, stride, channels, n_class)

if torch.cuda.is_available():
    model = model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

Training data

train_data = BearingDataset(Path(data_dir, 'bearing_dataset.csv'), window_size)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True)

Train the network

for epoch in trange(epochs):
    loss, accuracy = training(model, train_loader, optimizer)
    print(f'Epoch {epoch+1} :: loss : {loss:.8f} / accuracy : {accuracy:.8f}%')

torch.save(model.state_dict(), Path(checkpoint_dir, 'CNN_FAM_21_12_30'))
model.load_state_dict(torch.load(Path(checkpoint_dir, 'CNN_FAM_21_12_30')))

<All keys matched successfully>

Network architecture

dummy_input = torch.zeros(batch_size, in_channel, 1, window_size)
if torch.cuda.is_available():
    dummy_input = dummy_input.cuda()

print(pytorch_model_summary.summary(model, dummy_input))
 

Layer (type)        Output Shape          Param #   Tr. Param #
================================================================
Conv2d-1            [256, 4, 1, 2048]          28            28
LayerNorm-2         [256, 4, 1, 2048]       4,096         2,048
Tanh-3              [256, 4, 1, 2048]           0             0
Conv2d-4            [256, 4, 1, 2048]         112           112
LayerNorm-5         [256, 4, 1, 2048]       4,096         2,048
Tanh-6              [256, 4, 1, 2048]           0             0
Conv2d-7            [256, 8, 1, 2048]         224           224
LayerNorm-8         [256, 8, 1, 2048]       4,096         2,048
Tanh-9              [256, 8, 1, 2048]           0             0
Conv2d-10           [256, 8, 1, 2048]         448           448
LayerNorm-11        [256, 8, 1, 2048]       4,096         2,048
Tanh-12             [256, 8, 1, 2048]           0             0
Conv2d-13           [256, 16, 1, 2048]        896           896
LayerNorm-14        [256, 16, 1, 2048]      4,096         2,048
Tanh-15             [256, 16, 1, 2048]          0             0
Conv2d-16           [256, 16, 1, 2048]      1,792         1,792
LayerNorm-17        [256, 16, 1, 2048]      4,096         2,048
Tanh-18             [256, 16, 1, 2048]          0             0
Conv2d-19           [256, 16, 1, 2048]      1,792         1,792
LayerNorm-20        [256, 16, 1, 2048]      4,096         2,048
Tanh-21             [256, 16, 1, 2048]          0             0
Conv2d-22           [256, 32, 1, 2048]      3,584         3,584
LayerNorm-23        [256, 32, 1, 2048]      4,096         2,048
Tanh-24             [256, 32, 1, 2048]          0             0
Conv2d-25           [256, 32, 1, 2048]      7,168         7,168
LayerNorm-26        [256, 32, 1, 2048]      4,096         2,048
Tanh-27             [256, 32, 1, 2048]          0             0
Conv2d-28           [256, 32, 1, 2048]      7,168         7,168
LayerNorm-29        [256, 32, 1, 2048]      4,096         2,048
Tanh-30             [256, 32, 1, 2048]          0             0
Linear-31           [256, 12]                 396           396
================================================================
Total params: 64,568
Trainable params: 44,088
Non-trainable params: 20,480

Accumulate the class-wise frequency activation maps over the training set

freq_intervals = np.fft.rfftfreq(window_size, d=1/sampling_rate)
total_fam = torch.zeros(n_class, len(freq_intervals))
total_len = torch.zeros(n_class, 1)
if torch.cuda.is_available():
    total_fam = total_fam.cuda()
    total_len = total_len.cuda()

for batch in tqdm(train_loader):
    x, y = batch
    if torch.cuda.is_available():
        x = x.cuda()    
    for c in range(n_class):
        tmp_fam, cnt = freq_activation_map(model, x, len(freq_intervals), channels[-1], c)
        total_fam[c, :] += tmp_fam
        total_len[c] += cnt

total_fam /= total_len


Plot the normalized spectra

rtotal_fam = normalize_freq(total_fam).cpu().detach()
columns_ = freq_intervals
plt.plot(columns_, rtotal_fam[0, :])  # FAM of the 'Normal' class
plt.plot(columns_, rtotal_fam[5, :])  # FAM of the 'FAULT7_BALL' class

Plot the frequency activation map heatmap

result = pd.DataFrame(rtotal_fam.numpy(),
                      index=labels,
                      columns=columns_,
                      )

new_index = ['FAULT21_OUTER', 'FAULT14_OUTER', 'FAULT7_OUTER', 'FAULT28_INNER',
             'FAULT21_INNER', 'FAULT14_INNER', 'FAULT7_INNER', 'FAULT28_BALL', 
             'FAULT21_BALL', 'FAULT14_BALL', 'FAULT7_BALL', 'Normal']

new_index.reverse()

result = result.reindex(new_index)

sns.heatmap(result,
            cmap='viridis',
            )

Looking at the activated frequencies alongside the theoretical bearing fault frequencies, the effect is clear at a glance.
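For reference, a hedged sketch of the classical characteristic-frequency formulas; the SKF 6205 (CWRU drive-end) geometry defaults below are the commonly quoted values and should be checked against the bearing datasheet:

import math

def fault_frequencies(fr, n_balls=9, d=7.94, D=39.04, phi=0.0):
    """Classical bearing defect frequencies in Hz.

    fr      : shaft rotation frequency in Hz
    n_balls : number of rolling elements
    d, D    : ball and pitch diameter in mm (SKF 6205 defaults)
    phi     : contact angle in radians
    """
    r = d / D * math.cos(phi)
    return {
        'BPFO': n_balls / 2 * fr * (1 - r),       # outer-race defect
        'BPFI': n_balls / 2 * fr * (1 + r),       # inner-race defect
        'BSF':  D / (2 * d) * fr * (1 - r ** 2),  # ball-spin (rolling element)
        'FTF':  fr / 2 * (1 - r),                 # cage
    }

# At ~1797 rpm (the CWRU 0 hp load case), fr ≈ 29.95 Hz:
print(fault_frequencies(1797 / 60))
# roughly {'BPFO': 107.4, 'BPFI': 162.2, 'BSF': 70.6, 'FTF': 11.9}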

Related articles

Several signal denoising algorithms (part 1)

https://www.toutiao.com/article/7190201924820402721/

Several signal denoising algorithms (part 2)

https://www.toutiao.com/article/7190270349236683264/

Some examples of mechanical and industrial fault diagnosis (part 1)

https://www.toutiao.com/article/7193957227231855163/

Zhihu consulting: 哥廷根数学学派

The algorithm code is available on the author's mbd.pub homepage:

https://mbd.pub/o/GeBENHAGEN/work

Specialties: modern signal processing (improved wavelet analysis, improved variational mode decomposition, improved empirical wavelet transform, improved symplectic geometry mode decomposition, etc.), improved machine learning, improved deep learning, mechanical fault diagnosis, and improved time-series analysis (financial, ECG, and vibration signals).

