百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

Java调用Keras、Tensorflow模型

toyiye 2024-06-21 12:07 9 浏览 0 评论

内容导读

但是有很多公司后台应用是用Java开发的,如果用python提供HTTP接口,对业务延迟要求比较高的话,仍然会有一定得延迟,所以能不能使用Java调用模型,python可以离线的训练模型?Deeplearning4j目前支持导入Keras训练的模型,并且提供了类似python中numpy的一些功能,更方便地处理结构化的数据。该方法可以将tensor为Variable的graph全部转为constant并且使用训练后的weight。注意output_name比较重要,后面Java调用模型的时候会用到。下面的代码可以查看定义好的Keras模型的输入、输出的name,这对之后Java调用有帮助。至此,已经可以实现Keras离线训练,Java在线预测的功能。

实现python离线训练模型,Java在线预测部署。查看原文

目前深度学习主流使用python训练自己的模型,有非常多的框架提供了能快速搭建神经网络的功能,其中Keras提供了high-level的语法,底层可以使用tensorflow或者theano。

但是有很多公司后台应用是用Java开发的,如果用python提供HTTP接口,对业务延迟要求比较高的话,仍然会有一定得延迟,所以能不能使用Java调用模型,python可以离线的训练模型?(tensorflow也提供了成熟的部署方案TensorFlow Serving)

手头上有一个用Keras训练的模型,网上关于Java调用Keras模型的资料不是很多,而且大部分是重复的,并且也没有讲的很详细。大致有两种方案,一种是基于Java的深度学习库导入Keras模型实现,另外一种是用tensorflow提供的Java接口调用。

Deeplearning4J

Eclipse Deeplearning4j is the first commercial-grade, open-source, distributed deep-learning library written for Java and Scala. Integrated with Hadoop and Spark, DL4J brings AIAI to business environments for use on distributed GPUs and CPUs.

Deeplearning4j目前支持导入Keras训练的模型,并且提供了类似python中numpy的一些功能,更方便地处理结构化的数据。遗憾的是,Deeplearning4j现在只覆盖了Keras <2.0版本的大部分Layer,如果你是用Keras 2.0以上的版本,在导入模型的时候可能会报错。

了解更多:

Keras Model Import: Supported Features

Importing Models From Keras to Deeplearning4j

Tensorflow

文档,Java的文档很少,不过调用模型的过程也很简单。采用这种方式调用模型需要先将Keras导出的模型转成tensorflow的protobuf协议的模型。

1、Keras的h5模型转为pb模型

在Keras中使用model.save(model.h5)保存当前模型为HDF5格式的文件中。

Keras的后端框架使用的是tensorflow,所以先把模型导出为pb模型。在Java中只需要调用模型进行预测,所以将当前的graph中的Variable全部变成Constant,并且使用训练后的weight。以下是freeze graph的代码:

 def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 """
 :param session: 需要转换的tensorflow的session
 :param keep_var_names:需要保留的variable,默认全部转换constant
 :param output_names:output的名字
 :param clear_devices:是否移除设备指令以获得更好的可移植性
 :return:
 """
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
 freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
 output_names = output_names or []
 # 如果指定了output名字,则复制一个新的Tensor,并且以指定的名字命名
 if len(output_names) > 0:
 for i in range(output_names):
 # 当前graph中复制一个新的Tensor,指定名字
 tf.identity(model.model.outputs[i], name=output_names[i])
 output_names += [v.op.name for v in tf.global_variables()]
 input_graph_def = graph.as_graph_def()
 if clear_devices:
 for node in input_graph_def.node:
 node.device = ""
 frozen_graph = convert_variables_to_constants(session, input_graph_def,
 output_names, freeze_var_names)
 return frozen_graph

该方法可以将tensor为Variable的graph全部转为constant并且使用训练后的weight。注意output_name比较重要,后面Java调用模型的时候会用到。

在Keras中,模型是这么定义的:

 def create_model(self):
 input_tensor = Input(shape=(self.maxlen,), name="input")
 x = Embedding(len(self.text2id) + 1, 200)(input_tensor)
 x = Bidirectional(LSTM(128))(x)
 x = Dense(256, activation="relu")(x)
 x = Dropout(self.dropout)(x)
 x = Dense(len(self.id2class), activation='softmax', name="output_softmax")(x)
 model = Model(inputs=input_tensor, outputs=x)
 model.compile(loss='categorical_crossentropy',
 optimizer='adam',
 metrics=['accuracy'])

下面的代码可以查看定义好的Keras模型的输入、输出的name,这对之后Java调用有帮助。

print(model.input.op.name)
print(model.output.op.name)

训练好Keras模型后,转换为pb模型:

from keras import backend as K
import tensorflow as tf
model.load_model("model.h5")
print(model.input.op.name)
print(model.output.op.name)
# 自定义output_names
frozen_graph = freeze_session(K.get_session(), output_names=["output"])
tf.train.write_graph(frozen_graph, "./", "model.pb", as_text=False)
### 输出:
# input
# output_softmax/Softmax
# 如果不自定义output_name,则生成的pb模型的output_name为output_softmax/Softmax,如果自定义则以自定义名为output_name

运行之后会生成model.pb的模型,这将是之后调用的模型。

2、Java调用

新建一个maven项目,pom里面导入tensorflow包:

<dependency>
 <groupId>org.tensorflow</groupId>
 <artifactId>tensorflow</artifactId>
 <version>1.6.0</version>
</dependency>

核心代码:

public void predict() throws Exception {
 try (Graph graph = new Graph()) {
 graph.importGraphDef(Files.readAllBytes(Paths.get(
 "path/to/model.pb"
 )));
 try (Session sess = new Session(graph)) {
 // 自己构造一个输入
 float[][] input = {{56, 632, 675, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
 try (Tensor x = Tensor.create(input);
 // input是输入的name,output是输出的name
 Tensor y = sess.runner().feed("input", x).fetch("output").run().get(0)) {
 float[][] result = new float[1][y.shape[1]];
 y.copyTo(result);
 System.out.println(Arrays.toString(y.shape()));
 System.out.println(Arrays.toString(result[0]));
 }
 }
 }
 }

Graph和Tensor对象都是需要通过close()方法显式地释放占用的资源,代码中使用了try-with-resources的方法实现的。

至此,已经可以实现Keras离线训练,Java在线预测的功能。

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码