百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

使用TensorFlow.js在Node.js中进行机器学习

toyiye 2024-07-04 09:10 14 浏览 0 评论

作者:jthomas

黑胡桃实验室敲制

在本文中,作者将介绍如何在Node.js环境下使用TensorFlow.js,并使用MobileNet模型来完成一个分类任务。

前言

TensorFlow.js是流行的机器学习开源库的新版本,它为JavaScript带来了机器学习的力量。开发人员现在可以使用TensorFlow.js的高级API定义,训练和运行机器学习模型。

TensorFlow.js可以使用预先训练的模型,这意味着开发人员现在可以通过几行JavaScript 轻松执行复杂的任务,如视觉识别生成音乐或姿态估计

TensorFlow.js作为Web浏览器的前端库,最近增加了对Node.js的支持。这允许TensorFlow.js在后端JavaScript应用程序中使用,而无需使用Python。

不幸的是,官网提供的大多数文档和示例代码都在浏览器中使用TensorFlow.js库。为了简化加载和使用预先训练的模型而提供的项目实用程序尚未添加对Node.js的支持。最后,我花了很多时间阅读库中的源代码。

经过了几天的研究,我设法完成了下面这个教程。欢呼!

① 安装TensorFlow.js的库

这里,我们可以直接使用NPM进行安装

@tensorflow/tfjs TensorFlow.js核心库@tensorflow/tfjs-node TensorFlow.js的Node.js扩展库@tensorflow/tfjs-node-gpu 支持GPU的TensorFlow.js的扩展库

npm install @tensorflow/tfjs @tensorflow/tfjs-node
// or...
npm install @tensorflow/tfjs @tensorflow/tfjs-node-gpu

Node.js扩展都使用本地依赖项,这些依赖项将会被在本地进行编译。

加载TensorFlow库

先导入TensorFlow.js的核心库,再导入Node.js的扩展库。

const tf = require('@tensorflow/tfjs')// 导入CPU版的
require('@tensorflow/tfjs-node')// 导入GPU版的
require('@tensorflow/tfjs-node-gpu')

③加载TensorFlow模型

TensorFlow.js提供了一个NPM库(tfjs-models),使用它可以轻松地加载经过转换的预训练模型,可以用于图像分类,姿态估计和KNN分类器等。

这里我们将使用用于图像分类的MobileNet模型,这个模型是经过预训练的,可以识别1000个不同类别物体的深度神经网络。

import * as mobilenet from '@tensorflow-models/mobilenet';// 导入模型.
const model = await mobilenet.load();

我们遇到的第一个挑战就是这种模型加载方式并不适用于Node.js。

Error: browserHTTPRequest is not supported outside the web browser.

通过查看源代码发现,MobileNet库是底层 tf.Model类的包装器,调用 load() 方法时,它会自动从外部HTTP地址上下载正确的模型文件,并实例化TensorFlow模型。

Node.js扩展库不支持使用HTTP请求来加载模型。相反,必须从文件系统手动加载模型。

在阅读过库的源代码后,我设法创建了一个解决方案......

从文件系统加载模型

如果手动创建MobileNet类,而不是调用模块的加载方法,则可以使用本地模型文件的路径来覆盖存有模型文件的HTTP地址的路径变量。完成此操作后,在实例上调用 load() 方法将触发文件系统加载器类,而不是使用基于浏览器的HTTP加载器类。

const path = "mobilenet/model.json"
const mn = new mobilenet.MobileNet(1, 1);// 这里替换成模型文件在磁盘上的路径
mn.path = `file://${path}`
await mn.load()

太棒了,它开始工作了!

但是,模型文件是来源于哪里呢?

下载模型文件

TensorFlow.js的模型由两种文件类型组成,一种是存储在JSON中的模型配置文件,另一种是二进制格式的模型权重。模型权重通常被分片为多个文件,以便浏览器更好地进行缓存。

查看 MobileNet模型的加载代码,发现模型的配置文件和权重分配都存储在公共的数据存储空间中。

https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v${version}_${alpha}_${size}/

URL中的 version 参数可以填下面的模型版本。每个版本的分类准确性结果也显示在该页面上。

在源代码中,只能使用tensorflow-models/mobilenet库加载MobileNet v1模型。

HTTP检索代码将从此URL加载model.json文件,然后递归获取所有引用的模型权重分片。这些文件的格式为 groupX-shard1of1。

将所有模型文件保存到文件系统可以通过检索模型配置文件,解析引用的权重文件并手动下载每个权重文件来实现。

我想使用具有1.0 alpha值和224像素图像大小的MobileNet V1模块,这是相关模型的URL。

https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_1.0_224/model.json

然后我们将模型配置文件model.json保存下来,放入到jupyter中查看模型文件的结构。

从中可以看出,权重分片共有54个(图中没有显示完全),第一个权重分片叫做 group1-shard10f1 。可以通过以下URL来进行下载

https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_1.0_224/group1-shard1of1

但是共有54个权重分片,手动下载起来很麻烦,我们准备了一个下载脚本,来自动化下载模型文件这个过程。

import urllib.request
import json
## 加载 model.js文件
with open("model.json", encoding='utf-8-sig') as json_file:
 json_data = json.load(json_file)
 ## 模型下载地址
url = "https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_1.0_224/"
weights = json_data['weightsManifest']
## 下载权重分片
for index in range(len(weights)):
 filename = weights[index]['paths'][0]
 response = urllib.request.urlopen(url + filename)
 data = response.read() # a `bytes` object
 with open("model/" + filename, 'wb') as f:
 f.write(data)

进行图像分类

参考TensorFlow.js中的示例代码,加载图像并对图像进行分类

const img = document.getElementById('img');// 对图像进行分类.
const predictions = await model.classify(img);

由于缺少DOM对象,这在Node.js上并不能使用

MobileNet的实例对象的classify()方法接收一个DOM对象(Canvas,Video,Img),并自动检索这些对象中的图像字节并将其转换为 tf.Tensor3D 类型,作为模型的输入。也可以直接向 classify()方法中直接传入 tf.Tensor3D 类型的变量。

我没有尝试使用外部包来模拟Node.js中的DOM元素,而是采用手动构建 tf.Tensor3D 的变量,因为这样更加容易

利用下面的代码,可以将一个图像数组转换为相应的Tensor3D的变量。

const values = new Int32Array(image.height * image.width * numChannels);
// 整理输出变量的形状
const outShape = [image.height, image.width, numChannels];
const input = tf.tensor3d(values, outShape, 'int32');

values 是一个int32类型的2D数组,其包含了每个像素通道值的顺序列表,numChannels是每个像素的通道值。

jpeg-js 是一个可以在Node.js上使用的JavaScript JPEG编码器和解码器,使用该库可以提取每个像素的RGB值。

const pixels = jpeg.decode(buffer, true);

这将为每个像素(宽度*高度)返回一个带有四个通道值(RGBA)的Uint8Array。MobileNet模型仅使用三个颜色通道(RGB)进行分类,忽略alpha通道。下面的代码将四通道数组转换为正确的三通道版本。

const numChannels = 3;
const numPixels = image.width * image.height;
const values = new Int32Array(numPixels * numChannels);
for (let i = 0; i < numPixels; i++) {
 for (let channel = 0; channel < numChannels; ++channel) {
 values[i * numChannels + channel] = pixels[i * 4 + channel];
 }
}

MobileNet模型将宽度和高度为224像素的图像进行分类。对于三个通道像素值中的每一个,输入张量必须是介于-1和1之间的浮点值。

因此,在分类之前,需要重新调整不同尺寸图像的输入值。另外,来自JPEG解码器的像素值在0到255范围内 ,而不是 -1到1。这些值还需要在进行分类之前进行转换。

TensorFlow.js库中有方法使这个过程更容易,但幸运的是,tfjs-models/mobilenets 会自动处理这个问题!

开发人员可以将类型为int32和不同维度的Tensor3D输入传递给classify方法,并在分类之前将输入转换为正确的格式。超级。

最后,对代码进行整理,最终结果如下:

// 导入相应的库
const tf = require('@tensorflow/tfjs')
const mobilenet = require('@tensorflow-models/mobilenet');
require('@tensorflow/tfjs-node')
const fs = require('fs');
const jpeg = require('jpeg-js');
const NUMBER_OF_CHANNELS = 3
// 读取图片
const readImage = path => {
 const buf = fs.readFileSync(path)
 const pixels = jpeg.decode(buf, true)
 return pixels
}
// 将 4通道的图像数组 转换为 3通道的图像数组
const imageByteArray = (image, numChannels) => {
 const pixels = image.data
 const numPixels = image.width * image.height;
 const values = new Int32Array(numPixels * numChannels);
 for (let i = 0; i < numPixels; i++) {
 for (let channel = 0; channel < numChannels; ++channel) {
 values[i * numChannels + channel] = pixels[i * 4 + channel];
 }
 }
 return values
}
// 将图片数组转换为 Tensor3D类型 
const imageToInput = (image, numChannels) => {
 const values = imageByteArray(image, numChannels)
 const outShape = [image.height, image.width, numChannels];
 const input = tf.tensor3d(values, outShape, 'int32');
 return input
}
// 从文件系统中导入模型
const loadModel = async path => {
 const mn = new mobilenet.MobileNet(1, 1);
 mn.path = `file://${path}`
 await mn.load()
 return mn
}
const classify = async (model, path) => {
 // 读取图片
 const image = readImage(path)
 // 将图片数组 转换为 Tensor3D类型的变量
 const input = imageToInput(image, NUMBER_OF_CHANNELS)
 // 导入模型
 const mn_model = await loadModel(model)
 // 对图片进行推理
 const predictions = await mn_model.classify(input)
 console.log('classification results:', predictions)
}
if (process.argv.length !== 4) throw new Error('incorrect arguments: node script.js <MODEL> <IMAGE_FILE>')
classify(process.argv[2], process.argv[3])

测试模型

下载示例图片,准备进行分类。

wget http://bit.ly/2JYSal9 -O panda.jpg

运行脚本对图像进行分类,在运行脚本前需要传入相应的参数

node script.js mobilenet/model.json panda.jpg

如果一切正常,则应将以下输出打印到控制台。

classification results: [ {
 className: 'giant panda, panda, panda bear, coon bear',
 probability: 0.9993536472320557 
} ]

这幅图片被正确分类成熊猫,概率为99.93% 。

原文链接:https://dev.to/jthomas/machine-learning-in-nodejs-with-tensorflowjs-1g1p


本文由黑胡桃实验室敲制,转载请获得授权。

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码