百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

用机器给心爱的人写首诗

toyiye 2024-06-21 12:38 15 浏览 0 评论

现在AI的应用很广泛,机器学习提供了许多有效的解决方案,其中深度学习在图片识别、自然语言处理方面是当前有效的解决方案之一,通过深度学习的方式可以实现许多有趣的功能。本文将显示如何用机器学习的方式,给心爱的人写古诗,本项目使用LSTM进行古诗生成,并使用当前流量的AI框架TensorFlow2.4版本进行代码的实现。关于环境的搭建可以查看之前的文章 Ubuntu快速安装tensorflow2.4的gpu版本

本文也提供了源码,需要的可以查看:

https://gitee.com/fork-out-project/LotteryPredict

先看一下效果:

因为训练速度比较慢,本文只取了1000首诗,训练的效果不算太好。加大诗的数量,效果会更加好。

千言万语不如看代码。

具体代码如下:

#!/usr/bin/env python
# coding: utf-8

# # 古诗生成
# 本项目使用LSTM进行古诗生成,TensorFlow2.0实现。

# ## LSTM介绍
# 我们需要从过往的历史数据中寻找规律,[`LSTM`](https://en.wikipedia.org/wiki/Long_short-term_memory)再适合不过了。如果你对LSTM不熟悉的话,以下几篇文章建议你阅读:
# 
# [`Understanding LSTM Networks`](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
# 
# 
# [`RNN以及LSTM的介绍和公式梳理`](http://blog.csdn.net/Dark_Scope/article/details/47056361)
# 

# ## 加载数据集

# In[1]:


data_dir = './data/newtxt.txt'  # new_poetry
#data_dir = './data/new_poetry.txt'  # new_poetry
text = open(data_dir, 'rb').read().decode(encoding='utf-8')


# 下面函数是用来处理初始数据集poetry.txt的,使用newtxt.txt可以不调用。

# In[2]:


import os
import re
pattern = '[a-zA-Z0-9’"#$%&\'()*+-./:;<=>@★…【】《》“”‘’[\\]^_`{|}~]+'

def preprocess_poetry(outdir, datadir):
    with open(os.path.join(outdir, 'new_poetry.txt'), 'w') as out_f:
        with open(os.path.join(datadir, 'poetry.txt'), 'r') as f:
            for line in f:
                content = line.strip().rstrip('\n').split(':')[1]  # .rstrip('\n').
                content = content.replace(' ','')
                if '】' in content or '_' in content or '(' in content or '(' in content or '《' in content or '[' in content:
                    continue
                if len(content) < 20:
                    continue
                content=re.sub(pattern, '', content)
                out_f.write(content + '\n')


# In[3]:


preprocess_poetry('./data/', './data/')


# In[4]:


len(text)


# ## 预测网络介绍

# 网络的输入是每一个汉字,总共有1020个字,用one hot编码是一个1020维的稀疏向量。

# 使用one hot稀疏向量在输入层与网络第一层做矩阵乘法时会很没有效率,因为向量里面大部分都是0, 矩阵乘法浪费了大量的计算,最终矩阵运算得出的结果是向量中值为1的列所对应的矩阵中的行向量。
# <img src="assets/lookup_matrix.png">
# 
# 这看起来很像用索引查表一样,one hot向量中值为1的位置作为下标,去索引参数矩阵中的行向量。

# 为了代替矩阵乘法,我们将参数矩阵当作一个查找表(lookup table)或者叫做嵌入矩阵(embedding matrix),使用每个汉字所对应索引,比如汉字“你”,索引是958,然后在查找表中找第958行。
# 
# 这其实跟替换之前的模型没有什么不同,嵌入矩阵就是参数矩阵,嵌入层仍然是隐层。查找表只是矩阵乘法的一种便捷方式,它会像参数矩阵一样被训练,是要学习的参数。

# 下面就是我们要构建的网络架构,从嵌入层输出的向量进入LSTM层进行时间序列的学习,然后经过softmax预测出下一个汉字。

# ## 编码实现

# ### 实现数据预处理
# 
# 我们需要先准备好汉字和ID之间的转换关系。在这个函数中,创建并返回两个字典:
# - 汉字到ID的转换字典: `vocab_to_int`
# - ID到汉字的转换字典: `int_to_vocab`
# 

# In[5]:


import numpy as np
from collections import Counter
import pickle

def create_lookup_tables():
    # 去重排序
    vocab = sorted(set(text))
    vocab_to_int = {u:i for i, u in enumerate(vocab)}
    int_to_vocab = np.array(vocab)
    
    int_text = np.array([vocab_to_int[word] for word in text if word != '\n'])

    # 保存数据
    pickle.dump((int_text, vocab_to_int, int_to_vocab), open('preprocess.p', 'wb'))


# ### 处理所有数据并保存
# 将每期结果按照从第一期开始的顺序保存到文件中。

# In[6]:


create_lookup_tables()


# In[7]:


import numpy as np
# 读取保存的数据
int_text, vocab_to_int, int_to_vocab = pickle.load(open('preprocess.p', mode='rb'))


# In[8]:


def get_batches(int_text, batch_size, seq_length):

    batchCnt = len(int_text) // (batch_size * seq_length)
    int_text_inputs = int_text[:batchCnt * (batch_size * seq_length)]
    int_text_targets = int_text[1:batchCnt * (batch_size * seq_length)+1]

    result_list = []
    x = np.array(int_text_inputs).reshape(1, batch_size, -1)
    y = np.array(int_text_targets).reshape(1, batch_size, -1)

    x_new = np.dsplit(x, batchCnt)
    y_new = np.dsplit(y, batchCnt)

    for ii in range(batchCnt):
        x_list = []
        x_list.append(x_new[ii][0])
        x_list.append(y_new[ii][0])
        result_list.append(x_list)

    return np.array(result_list)


# In[9]:


len(int_to_vocab)


# ## 训练神经网络
# ### 超参数
# 

# In[10]:


vocab_size = len(int_to_vocab)

# 批次大小
batch_size = 32  # 64
# RNN的大小(隐藏节点的维度)
rnn_size = 1000
# 嵌入层的维度
embed_dim = 256  # 这里做了调整,跟彩票预测的也不同了
# 序列的长度
#seq_length = 15  # 注意这里已经不是1了,在古诗预测里面这个数值可以大一些,比如100也可以的
seq_length = 80
save_dir = './save'


# ### 构建计算图
# 使用实现的神经网络构建计算图。

# In[11]:


import tensorflow as tf
import datetime
from tensorflow import keras
from tensorflow.python.ops import summary_ops_v2
import time

MODEL_DIR = "./poetry_models"

train_batches = get_batches(int_text, batch_size, seq_length)  
losses = {'train': [], 'test': []}


class poetry_network(object):
    def __init__(self, batch_size=32):
        self.batch_size = batch_size  
        self.best_loss = 9999

        self.model = tf.keras.Sequential([
            tf.keras.layers.Embedding(vocab_size, embed_dim,
                                      batch_input_shape=[batch_size, None]),
            tf.keras.layers.LSTM(rnn_size,
                                 return_sequences=True,
                                 stateful=True,
                                 recurrent_initializer='glorot_uniform'),
            tf.keras.layers.Dense(vocab_size)
        ])
        self.model.summary()

        self.optimizer = tf.keras.optimizers.Adam()
        self.ComputeLoss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        if tf.io.gfile.exists(MODEL_DIR):
            #             print('Removing existing model dir: {}'.format(MODEL_DIR))
            #             tf.io.gfile.rmtree(MODEL_DIR)
            pass
        else:
            tf.io.gfile.makedirs(MODEL_DIR)

        train_dir = os.path.join(MODEL_DIR, 'summaries', 'train')

        self.train_summary_writer = summary_ops_v2.create_file_writer(train_dir, flush_millis=10000)
 
        checkpoint_dir = os.path.join(MODEL_DIR, 'checkpoints')
        self.checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
        self.checkpoint = tf.train.Checkpoint(model=self.model, optimizer=self.optimizer)

        # Restore variables on creation if a checkpoint exists.
        self.checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

    @tf.function
    def train_step(self, x, y):
        # Record the operations used to compute the loss, so that the gradient
        # of the loss with respect to the variables can be computed.

        with tf.GradientTape() as tape:
            logits = self.model(x, training=True)

            loss = self.ComputeLoss(y, logits)

        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        return loss, logits

    def training(self, epochs=1, log_freq=50):

        batchCnt = len(int_text) // (batch_size * seq_length)
        print("batchCnt : ", batchCnt)
        for i in range(epochs):
            train_start = time.time()
            with self.train_summary_writer.as_default():
                start = time.time()
                # Metrics are stateful. They accumulate values and return a cumulative
                # result when you call .result(). Clear accumulated values with .reset_states()
                avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)

                # Datasets can be iterated over like any other Python iterable.
                for batch_i, (x, y) in enumerate(train_batches):
                    loss, logits = self.train_step(x, y)
                    avg_loss(loss)
                    losses['train'].append(loss)

                    if tf.equal(self.optimizer.iterations % log_freq, 0):
                        summary_ops_v2.scalar('loss', avg_loss.result(), step=self.optimizer.iterations)

                        rate = log_freq / (time.time() - start)
                        print('Step #{}\tLoss: {:0.6f} ({} steps/sec)'.format(
                            self.optimizer.iterations.numpy(), loss, rate))

                        avg_loss.reset_states()

                        start = time.time()
#                         self.checkpoint.save(self.checkpoint_prefix)
            self.checkpoint.save(self.checkpoint_prefix)
            print("save model\n")


# ## 训练
# 在预处理过的数据上训练神经网络。 

# In[30]:


net = poetry_network()
net.training(20)


# ## 显示训练Loss

# In[29]:


get_ipython().run_line_magic('matplotlib', 'inline')
get_ipython().run_line_magic('config', "InlineBackend.figure_format = 'retina'")
# import seaborn as sns
import matplotlib.pyplot as plt

plt.plot(losses['train'], label='Training loss')
plt.legend()
_ = plt.ylim()


# ## 加载保存的模型,准备预测

# In[14]:


restore_net=poetry_network(1)
restore_net.model.build(tf.TensorShape([1, None]))


# ## 生成古诗
# 开始生成古诗了。
#  - `prime_word` 是开始的头一个字。
#  - `top_n` 从前N个候选汉字中随机选择
#  - `rule` 默认是7言绝句
#  - `sentence_lines` 生成几句古诗,默认是4句(,和。都算一句)
#  - `hidden_head` 藏头诗的前几个字

# In[15]:


def gen_poetry(prime_word='白', top_n=5, rule=7, sentence_lines=4, hidden_head=None):
    gen_length = sentence_lines * (rule + 1) - len(prime_word)
    gen_sentences = [prime_word] if hidden_head==None else [hidden_head[0]]
    temperature = 1.0

    dyn_input = [vocab_to_int[s] for s in prime_word]
    dyn_input = tf.expand_dims(dyn_input, 0)

    dyn_seq_length = len(dyn_input[0])

    restore_net.model.reset_states()
    index=len(prime_word) if hidden_head==None else 1
    for n in range(gen_length):

        index += 1
        predictions = restore_net.model(np.array(dyn_input))
        
        predictions = tf.squeeze(predictions, 0)

        if index!=0 and (index % (rule+1)) == 0:
            if ((index / (rule+1)) + 1) % 2 == 0:
                predicted_id=vocab_to_int[',']
            else:
                predicted_id=vocab_to_int['。']
        else:
            if hidden_head != None and (index-1)%(rule+1)==0 and (index-1)//(rule+1) < len(hidden_head):
                predicted_id=vocab_to_int[hidden_head[(index-1)//(rule+1)]]
            else:
                while True:
                    predictions = predictions / temperature
                    predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

                    # p = np.squeeze(predictions[-1].numpy())
                    # p[np.argsort(p)[:-top_n]] = 0
                    # p = p / np.sum(p)
                    # c = np.random.choice(vocab_size, 1, p=p)[0]
                    # predicted_id=c
                    if(predicted_id != vocab_to_int[','] and predicted_id != vocab_to_int['。'] ):
                        break
    # using a multinomial distribution to predict the word returned by the model
    #         predictions = predictions / temperature
    #         predicted_id = tf.multinomial(predictions, num_samples=1)[-1,0].numpy()

        dyn_input = tf.expand_dims([predicted_id], 0)
        gen_sentences.append(int_to_vocab[predicted_id])

    poetry_script = ' '.join(gen_sentences)
    poetry_script = poetry_script.replace('\n ', '\n')
    poetry_script = poetry_script.replace('( ', '(')

    return poetry_script


# ## 给定开头

# In[95]:


gen_poetry(prime_word='举头望明月', top_n=10, rule=5, sentence_lines=4)


# ## 7言绝句

# In[79]:


gen_poetry(prime_word='春', top_n=10, rule=7, sentence_lines=4)


# ## 5言

# In[109]:


gen_poetry(prime_word='月', top_n=10, rule=5, sentence_lines=4)


# ## 藏头诗

# In[158]:


gen_poetry(prime_word='夏', top_n=10, rule=7, sentence_lines=4, hidden_head='风和日丽')


# # 结论
# 
# 使用原始的数据集训练精度还是可以。

喜欢的朋友评论、点赞、转发、收藏本文。有疑问的在评论区留言。谢谢!

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码