百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

使用ONNX使模型通用化

toyiye 2024-04-07 14:12 18 浏览 0 评论



什么是ONNX?

ONNX(Open Neural Network Exchange)- 开放神经网络交换格式,作为框架共用的一种模型交换格式,使用protobuf 二进制格式来序列化模型,可以提供更好的传输性能我们可能会在某一任务中Pytorch或者TensorFlow模型转化为ONNX模型(ONNX模型一般用于中间部署阶段),然后再拿转化后的ONNX模型进而转化为我们使用不同框架部署需要的类型,ONNX相当于一个翻译的作用。


为什么要用ONNX?

深度学习算法大多通过计算数据流图来完成神经网络的深度学习过程。一些框架(例如CNTK,Caffe2,Theano和TensorFlow)使用静态图形,而其他框架(例如PyTorch和Chainer)使用动态图形。但是这些框架都提供了接口,使开发人员可以轻松构建计算图和运行时,以优化的方式处理图。这些图用作中间表示(IR),捕获开发人员源代码的特定意图,有助于优化和转换在特定设备(CPU,GPU,FPGA等)上运行。假设一个场景:现在某组织因为主要开发用TensorFlow为基础的框架,现在有一个深度算法,需要将其部署在移动设备上,以观测变现。传统地我们需要用Caffe2重新将模型写好,然后再训练参数;试想下这将是一个多么耗时耗力的过程。此时,ONNX便应运而生,Caffe2,PyTorch,Microsoft Cognitive Toolkit,Apache MXNet等主流框架都对ONNX有着不同程度的支持。这就便于我们的算法及模型在不同框架之间的迁移。


ONNX结构分析

ONNX将每一个网络的每一层或者说是每一个算子当作节点Node,再由这些Node去构建一个Graph,相当于是一个网络。最后将Graph和这个ONNX模型的其他信息结合在一起,生成一个Model,也就是最终的.onnx的模型。构建一个简单的ONNX模型,实质上,只要构建好每一个node,然后将它们和输入输出超参数一起塞到Graph,最后转成Model就可以了。

graph{    node{        input: "1"        input: "2"        output: "12"        op_type: "Conv"    }    attribute{        name: "strides"        ints: 1        ints: 1    }    attribute{        name: "pads"        ints: 2        ints: 2    }    ...}


我们查看ONNX网络结构和参数(查看网址:https://netron.app/)


ONNX安装、使用

安装ONNX环境,在终端中执行以下命令,环境中需要提前准本 python3.6. 以下流程以ubunt 20.04 为例。


模型转换流程

超分辨率是一种提高图像、视频分辨率的算法,广泛用于图像处理或视频编辑。首先,让我们在PyTorch中创建一个SuperResolution 模型。该模型使用描述的高效子像素卷积层将图像的分辨率提高了一个放大因子。该模型将图像的YCbCr的Y分量作为输入,并以超分辨率输出放大的Y分量。

# Some standard importsimport ioimport numpy as npfrom torch import nnimport torch.utils.model_zoo as model_zooimport torch.onnx# Super Resolution model definition in PyTorchimport torch.nn as nnimport torch.nn.init as initclass SuperResolutionNet(nn.Module):    def __init__(self, upscale_factor, inplace=False):        super(SuperResolutionNet, self).__init__()        self.relu = nn.ReLU(inplace=inplace)        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)        self._initialize_weights()    def forward(self, x):        x = self.relu(self.conv1(x))        x = self.relu(self.conv2(x))        x = self.relu(self.conv3(x))        x = self.pixel_shuffle(self.conv4(x))        return x    def _initialize_weights(self):        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))        init.orthogonal_(self.conv4.weight)# Create the super-resolution model by using the above model definition.torch_model = SuperResolutionNet(upscale_factor=3)



模型下载

由于本教程以演示为目的,因此采用下载预先训练好的权重。在导出模型之前调用torch_model.eval()或torch_model.train(False)将模型转换为推理模式很重要。因为dropout或batchnorm等运算符在推理和训练模式下的行为不同。

# Load pretrained model weightsmodel_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'batch_size = 1    # just a random number# Initialize model with the pretrained weightsmap_location = lambda storage, loc: storageif torch.cuda.is_available():    map_location = Nonetorch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))# set the model to inference modetorch_model.eval()


模型导出

要导出模型,我们调用该torch.onnx.export() 函数。这将执行模型,记录用于计算输出的运算符。因为export运行模型,我们需要提供一个输入张量x。只要它是正确的类型和大小,其中的值可以是随机的。请注意,除非指定为动态轴,否则所有输入维度的导出ONNX图中的输入大小将是固定的。在此示例中,我们使用batch_size 1的输入导出模型,但随后在dynamic_axes参数中将第一个维度指定为动态 torch.onnx.export() . 因此,导出的模型将接受大小为[batch_size, 1, 224, 224]的输入,其中batch_size可以是可变的。


# Input to the modelx = torch.randn(batch_size, 1, 224, 224, requires_grad=True)torch_out = torch_model(x)# Export the modeltorch.onnx.export(torch_model,               # model being run                  x,                         # model input (or a tuple for multiple inputs)                  "super_resolution.onnx",   # where to save the model (can be a file or file-like object)                  export_params=True,        # store the trained parameter weights inside the model file                  opset_version=10,          # the ONNX version to export the model to                  do_constant_folding=True,  # whether to execute constant folding for optimization                  input_names = ['input'],   # the model's input names                  output_names = ['output'], # the model's output names                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes                                'output' : {0 : 'batch_size'}})


导出模型测试

在使用ONNX Runtime验证模型的输出之前,我们将使用ONNX的 API检查ONNX 模型。首先,onnx.load("super_resolution.onnx") 将加载保存的模型并输出 onnx.ModelProto结构(用于捆绑 ML 模型的顶级文件/容器格式)。然后,onnx.checker.check_model(onnx_model) 将验证模型的结构并确认模型具有有效的架构。ONNX 图的有效性通过检查模型的版本、图的结构以及节点及其输入和输出来验证。


import onnxonnx_model = onnx.load("super_resolution.onnx")onnx.checker.check_model(onnx_model)import onnxruntimeort_session = onnxruntime.InferenceSession("super_resolution.onnx")def to_numpy(tensor):    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# compute ONNX Runtime output predictionort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}ort_outs = ort_session.run(None, ort_inputs)# compare ONNX Runtime and PyTorch resultsnp.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)print("Exported model has been tested with ONNXRuntime, and the result looks good!")


理前图片:


1.加载处理前图片,使用标准PIL python库对其进行预处理。

2.调整图像大小以适应模型输入的大小 (224x224)。

处理后结果:

注:文章仅代表作者个人的观点,欢迎大家留言交流。


作者介绍

塔超,海云捷迅研发工程师。本科毕业于内蒙古科技大学并获得计算机主修学士学位。拥有丰富的项目经验,开发过AI、K8S相关的项目。



相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码