百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

如何不调用后端接口,在浏览器中直接运行训练好的模型

toyiye 2024-05-19 19:35 26 浏览 0 评论


00、背景

要实现一个机器学习的算法,我们javaer程序员习惯编写后端代码,经常会遇到nginx配置、NullPointerException、CPU、IO的报警、服务器的不定期扩容,及运维的不可控等等一系列工作和问题。

有没好办法,直接调用云原生的服务,没有专人运维,解决以上痛点?

01、效果

将使用在 MNIST 数据集上训练的手写数字识别模型的演示示例。

静态文件:

可以部署在oss上,执行:

npm install http-server -g

02、ONNX简介

开放神经网络交换ONNX(Open Neural Network Exchange)是一套表示深度神经网络模型的开放格式,由微软和Facebook于2017推出,然后迅速得到了各大厂商和框架的支持。通过短短几年的发展,已经成为表示深度学习模型的实际标准,并且通过ONNX-ML,可以支持传统非神经网络机器学习模型,统一整个AI模型交换标准。

ONNX定义了一组与环境和平台无关的标准格式,为AI模型的互操作性提供了基础,使AI模型可以在不同框架和环境下交互使用。硬件和软件厂商可以基于ONNX标准优化模型性能,让所有兼容ONNX标准的框架受益。目前,ONNX主要关注在模型预测方面(inferring),使用不同框架训练的模型,转化为ONNX格式后,可以很容易地部署在兼容ONNX的运行环境中。

03、再介绍下ONNX.JS

微软开源的,目前已经由ONNX Runtime Web替代

1、ONNX.JS 是一个 JavaScript 库,用于在浏览器和 Node.js 上运行ONNX模型

2、ONNX.JS 采用了 WebAssembly 和 WebGL 技术,为 CPU 和 GPU 提供优化的 ONNX 模型推理 Runtime

04、使用 ONNX.js 在浏览器中运行 PyTorch 模型

共有三步:

1、将 PyTorch 模型转换为 ONNX 格式

2、使用 ONNX.js 在您的网站或应用程序中加载该 ONNX 模型

3、使用 JavaScript 在浏览器中运行 PyTorch 模型

05、在浏览器中运行模型的好处

  • 使用较小的模型进行更快的推理
  • 易于托管和扩展(仅限静态文件)
  • 离线支持
  • 用户隐私(可以保留设备上的数据)

06、使用后端服务器的好处

  • 加载时间更快(不必下载模型)
  • 较大模型的推理时间更快且一致(可以利用 GPU 或其他加速器)
  • 模型隐私(如果您想保密,则不必共享您的模型)

07、主要代码

可视化数据集

def main():
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST(
'data', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.RandomAffine(
degrees=30, translate=(0.5, 0.5), scale=(0.25, 1),
shear=(-30, 30, -30, 30)),

torchvision.transforms.ToTensor(),
])),
batch_size=800)
inputs_batch, labels_batch = next(iter(train_loader))
grid = torchvision.utils.make_grid(inputs_batch, nrow=40, pad_value=1)
torchvision.utils.save_image(grid, 'inputs_batch_preview.png')

训练模型

def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=14, metavar='N',
help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
help='Learning rate step gamma (default: 0.7)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')

parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True,
transform=transforms.Compose([
# Add random transformations to the image.
transforms.RandomAffine(
degrees=30, translate=(0.5, 0.5), scale=(0.25, 1),
shear=(-30, 30, -30, 30)),

transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
scheduler.step()

torch.save(model.state_dict(), "pytorch_model.pt")

PyTorch 模型转换为 ONNX 格式

def main():
pytorch_model = Net()
pytorch_model.load_state_dict(torch.load('pytorch_model.pt'))
pytorch_model.eval()
dummy_input = torch.zeros(280 * 280 * 4)
torch.onnx.export(pytorch_model, dummy_input, 'onnx_model.onnx', verbose=True)

前端代码

const CANVAS_SIZE = 280;
const CANVAS_SCALE = 0.5;

const canvas = document.getElementById("canvas");
const ctx = canvas.getContext("2d");
const clearButton = document.getElementById("clear-button");

let isMouseDown = false;
let hasIntroText = true;
let lastX = 0;
let lastY = 0;

// Load our model.
const sess = new onnx.InferenceSession();
const loadingModelPromise = sess.loadModel("./onnx_model.onnx");

ctx.lineWidth = 28;
ctx.lineJoin = "round";
ctx.font = "28px sans-serif";
ctx.textAlign = "center";
ctx.textBaseline = "middle";
ctx.fillStyle = "#212121";
ctx.fillText("加载中...", CANVAS_SIZE / 2, CANVAS_SIZE / 2);

ctx.strokeStyle = "#212121";

function clearCanvas() {
ctx.clearRect(0, 0, CANVAS_SIZE, CANVAS_SIZE);
for (let i = 0; i < 10; i++) {
const element = document.getElementById(`prediction-${i}`);
element.className = "prediction-col";
element.children[0].children[0].style.height = "0";
}
}

function drawLine(fromX, fromY, toX, toY) {
// Draws a line from (fromX, fromY) to (toX, toY).
ctx.beginPath();
ctx.moveTo(fromX, fromY);
ctx.lineTo(toX, toY);
ctx.closePath();
ctx.stroke();
updatePredictions();
}

async function updatePredictions() {
const imgData = ctx.getImageData(0, 0, CANVAS_SIZE, CANVAS_SIZE);
const input = new onnx.Tensor(new Float32Array(imgData.data), "float32");

const outputMap = await sess.run([input]);
const outputTensor = outputMap.values().next().value;
const predictions = outputTensor.data;
const maxPrediction = Math.max(...predictions);

for (let i = 0; i < predictions.length; i++) {
const element = document.getElementById(`prediction-${i}`);
element.children[0].children[0].style.height = `${predictions[i] * 100}%`;
element.className =
predictions[i] === maxPrediction
? "prediction-col top-prediction"
: "prediction-col";
}
}

function canvasMouseDown(event) {
isMouseDown = true;
if (hasIntroText) {
clearCanvas();
hasIntroText = false;
}
const x = event.offsetX / CANVAS_SCALE;
const y = event.offsetY / CANVAS_SCALE;

// To draw a dot on the mouse down event, we set laxtX and lastY to be
// slightly offset from x and y, and then we call `canvasMouseMove(event)`,
// which draws a line from (laxtX, lastY) to (x, y) that shows up as a
// dot because the difference between those points is so small. However,
// if the points were the same, nothing would be drawn, which is why the
// 0.001 offset is added.
lastX = x + 0.001;
lastY = y + 0.001;
canvasMouseMove(event);
}

function canvasMouseMove(event) {
const x = event.offsetX / CANVAS_SCALE;
const y = event.offsetY / CANVAS_SCALE;
if (isMouseDown) {
drawLine(lastX, lastY, x, y);
}
lastX = x;
lastY = y;
}

function bodyMouseUp() {
isMouseDown = false;
}

function bodyMouseOut(event) {
if (!event.relatedTarget || event.relatedTarget.nodeName === "HTML") {
isMouseDown = false;
}
}

loadingModelPromise.then(() => {
canvas.addEventListener("mousedown", canvasMouseDown);
canvas.addEventListener("mousemove", canvasMouseMove);
document.body.addEventListener("mouseup", bodyMouseUp);
document.body.addEventListener("mouseout", bodyMouseOut);
clearButton.addEventListener("mousedown", clearCanvas);

ctx.clearRect(0, 0, CANVAS_SIZE, CANVAS_SIZE);
ctx.fillText("在这里写一个数字!", CANVAS_SIZE / 2, CANVAS_SIZE / 2);
})

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码