百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

可视化深度学习模型架构的6个常用的方法总结

toyiye 2024-06-21 12:24 11 浏览 0 评论

可视化有助于解释和理解深度学习模型的内部结构。 通过模型计算图的可视化可以弄清楚神经网络是如何计算的,对于模型的可视化主要包括以下几个方面:

  • 模型有多少层
  • 每层的输入和输出形状
  • 不同的层是如何连接的?
  • 每层使用的参数
  • 使用了不同的激活函数

本文将使用 Keras 和 PyTorch 构建一个简单的深度学习模型,然后使用不同的工具和技术可视化其架构。

使用Keras构建模型

import keras
# Train the model on Fashion MNIST dataset
(train_images, train_labels), _ = keras.datasets.fashion_mnist.load_data()
train_images = train_images / 255.0
# Define the model.
model = keras.models.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(32, activation='relu'),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation='softmax')
])
#Compile the model
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

Keras 内置可视化模型

在 Keras 中显示模型架构的最简单就是使用 summary()方法

model.summary()

这个方法是keras内置的实现,他的原理很简单。就是遍历所有模型层并打印相关细节,如层的输入维度和输出维度、参数数量、激活类型等,我们也可以用for训练遍历实现,代码如下:

for layer in model.layers:
print("Layer Name: " + layer.name)
print("Type of layer: " + layer.__class__.__name__)
print("Input dimesion: {}".format(layer.input_shape[1:]))
print("Output dimesion: {}".format(layer.output_shape[1:]))
print("Parameter Count: {}".format( layer.count_params()))
try:
print("Activation : " + layer.activation.__name__)
print(" ")
except:
print(" ")

这种方法只能提供一些简单的信息,下面我们介绍一些更好用的方法

Keras vis_utils

keras.utils.vis_utils 提供了使用 Graphviz 绘制 Keras 模型的实用函数。但是在使用之前需要安装一些其他的依赖:

pip install pydot
pip install pydotplus
pip install graphviz

使用Graphviz,还需要在系统 PATH 中添加 Graphviz bin 文件夹的路径,设置完成后就可以使用了

model_img_file = 'model.png'
tf.keras.utils.plot_model(model, to_file=model_img_file, 
show_shapes=True, 
show_layer_activations=True, 
show_dtype=True,
show_layer_names=True )

Visualkears

Visualkears 库只支持 CNN(卷积神经网络)的分层样式架构生成和大多数模型的图形样式架构,包括普通的前馈网络。

pip install visualkeras

layered view() 用于查看 CNN 模型架构

visualkeras.layered_view(model,legend=True, draw_volume=True)

TensorBoard

TensorBoard 的 Graphs 可查看模型结构图。对于 Tensorboard,使用如下的方法。

import tensorflow as tf
from datetime import datetime
import tensorboard

如果需要在notebook中使用,可以用下面的语句加载 Tensorboard 扩展

%load_ext tensorboard

在 fit() 中使用的 Keras Tensorboard Callback

# Define the Keras TensorBoard callback.
logdir="logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
# Train the model.
model.fit(
train_images,
train_labels, 
batch_size=64,
epochs=5, 
callbacks=[tensorboard_callback])

model.save("model.h5")

模型训练完成后,启动 TensorBoard 并等待 UI 加载。

%tensorboard --logdir logs

通过单击的“Graphs”就可以看到模型的可视化结果了。

注:在Pytorch 1.8以后中提供了from torch.utils.tensorboard import SummaryWriter也可以生成tensorboard的数据,与tensorboard 对接。

Netron

Netron 是专门为神经网络、深度学习和机器学习模型设计的查看器。 它支持 Keras、TensorFlow lite、ONNX、Caffe,并对 PyTorch、TensorFlow 有实验性支持。

pip install netron

浏览器并输入netron.app ,请单击“打开模型”并选择 h5 文件的路径上传。

就可以看到每一层的可视化结果了。

在 PyTorch 中构建一个简单的深度学习模型

import torch
from torch import nn
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
pytorch_model = NeuralNetwork().to(device)
x = torch.randn( 512, 28,28,1).requires_grad_(True)
y = pytorch_model(x)

查看模型架构最直接的方法是打印它。

print(pytorch_model)

虽然可以看到完整的模型架构,但是效果还没有Keras的内置函数效果好,下面介绍一个很好用的库解决这个问题。

PyTorchViz

PyTorchViz 依赖于graphviz,所以也需要安装:

pip install graphviz
pip install torchviz

使用PyTorchViz 可视化模型非常简单,只需要一个方法即可:

from torchviz import make_dot
make_dot(y, params=dict(list(pytorch_model.named_parameters()))).render("torchviz", format="png")

上面的代码生成了一个torchviz.png文件,如下图。

总结

可视化模型架构可以更好的解释深度学习模型。 模型结构可视化显示层数、每层数据的输入和输出形状、使用的激活函数以及每层中的参数数量,为优化模型提供更好的理解。

作者:Renu Khandelwal

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码