百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

PyTorch 项目实战开发教程自动文本摘要生成

toyiye 2024-06-21 12:43 10 浏览 0 评论

介绍

在这个教程中,我们将使用 PyTorch 构建一个自动文本摘要生成模型,该模型能够从长文本中生成简短的摘要。我们将使用 Seq2Seq 模型,其中包含编码器和解码器,用于处理不定长的输入文本并生成摘要。通过这个项目,你将学到如何处理文本数据、构建 Seq2Seq 模型以及进行文本摘要生成。

步骤 1:环境设置

首先,确保你已经安装了 PyTorch 和其他必要的库。在终端中运行以下命令:

pip install torch torchvision


步骤 2:准备数据集

我们将使用一个包含长文本和对应摘要的数据集进行模型训练。你可以根据实际需求选择合适的文本摘要数据集。

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from nltk.tokenize import word_tokenize
import string

class TextSummaryDataset(Dataset):
    def __init__(self, texts, summaries):
        self.texts = texts
        self.summaries = summaries

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx], self.summaries[idx]

# 使用示例数据
texts = ["This is a sample text for summarization.",
         "PyTorch is an open-source deep learning library.",
         "Natural language processing involves understanding and analyzing human language."]
summaries = ["Sample text summarization example.",
             "PyTorch is a deep learning library.",
             "NLP analyzes human language."]

# 创建数据集实例
dataset = TextSummaryDataset(texts, summaries)

步骤 3:文本处理与Tokenization

我们需要将文本数据转换为模型可接受的格式。这包括文本分词(tokenization)、去除标点符号等处理。

def preprocess_text(text):
    # 分词并去除标点符号
    tokens = word_tokenize(text)
    tokens = [token.lower() for token in tokens if token.isalpha()]
    return tokens

# 处理示例数据
processed_texts = [preprocess_text(text) for text in texts]
processed_summaries = [preprocess_text(summary) for summary in summaries]

步骤 4:构建Seq2Seq模型

我们将使用Seq2Seq模型,其中包括编码器(Encoder)和解码器(Decoder)。

import torch.nn as nn
import torch.optim as optim

class Encoder(nn.Module):
    def __init__(self, input_size, embedding_size, hidden_size, num_layers=1):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_size, embedding_size)
        self.rnn = nn.GRU(embedding_size, hidden_size, num_layers)

    def forward(self, x):
        embedded = self.embedding(x)
        output, hidden = self.rnn(embedded)
        return output, hidden

class Decoder(nn.Module):
    def __init__(self, output_size, embedding_size, hidden_size, num_layers=1):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_size, embedding_size)
        self.rnn = nn.GRU(embedding_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        x = x.unsqueeze(0)
        embedded = self.embedding(x)
        output, hidden = self.rnn(embedded, hidden)
        prediction = self.fc(output.squeeze(0))
        return prediction, hidden

# 示例模型创建
input_size = len(processed_texts)  # 根据实际情况确定
output_size = len(processed_summaries)  # 根据实际情况确定
embedding_size = 128
hidden_size = 256

encoder = Encoder(input_size, embedding_size, hidden_size)
decoder = Decoder(output_size, embedding_size, hidden_size)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001)

步骤 5:训练Seq2Seq模型

from tqdm import tqdm

def train_seq2seq_model(encoder, decoder, dataset, criterion, encoder_optimizer, decoder_optimizer, num_epochs=5):
    for epoch in range(num_epochs):
        total_loss = 0
        for text, summary in tqdm(dataset):
            # 将文本和摘要转换为模型输入所需格式
            text_tokens = torch.tensor([word2index[word] for word in text], dtype=torch.long)
            summary_tokens = torch.tensor([word2index[word] for word in summary], dtype=torch.long)

            # 模型训练
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            # 编码器
            encoder_output, encoder_hidden = encoder(text_tokens)

            # 解码器输入,初始为起始符"<sos>"
            decoder_input = torch.tensor(word2index["<sos>"], dtype=torch.long)
            decoder_hidden = encoder_hidden

            # 解码器输出序列
            predicted_sequence = []

            for _ in range(len(summary_tokens)):
                prediction, decoder_hidden = decoder(decoder_input, decoder_hidden)
                predicted_sequence.append(prediction)

                # 使用教师强制,将下一个时刻的输入设为真实的摘要单词
                decoder_input = summary_tokens[_]

            # 计算损失
            predicted_sequence = torch.stack(predicted_sequence, dim=0)
            loss = criterion(predicted_sequence, summary_tokens)
            total_loss += loss.item()

            # 反向传播和优化
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(dataset)}")

# 实际训练过程
word2index = {word: idx for idx, word in enumerate(set([word for sent in processed_texts+processed_summaries for word in sent]))}
index2word = {idx: word for word, idx in word2index.items()}

# 构建数据集
dataset = TextSummaryDataset(processed_texts, processed_summaries)

# 调用训练函数
train_seq2seq_model(encoder, decoder, dataset, criterion, encoder_optimizer, decoder_optimizer, num_epochs=5)

步骤 6:生成摘要

def generate_summary(encoder, decoder, text, max_length=50):
    # 将输入文本转换为模型输入所需格式
    text_tokens = torch.tensor([word2index[word] for word in preprocess_text(text)], dtype=torch.long)

    # 编码器
    encoder_output, encoder_hidden = encoder(text_tokens)

    # 解码器输入,初始为起始符"<sos>"
    decoder_input = torch.tensor(word2index["<sos>"], dtype=torch.long)
    decoder_hidden = encoder_hidden

    # 解码器生成序列
    summary = []

    for _ in range(max_length):
        prediction, decoder_hidden = decoder(decoder_input, decoder_hidden)
        predicted_word_index = torch.argmax(prediction).item()
        summary.append(index2word[predicted_word_index])

        # 如果生成终止符"<eos>",则停止生成
        if predicted_word_index == word2index["<eos>"]:
            break

        # 下一个时刻的输入
        decoder_input = torch.tensor(predicted_word_index, dtype=torch.long)

    return ' '.join(summary)

# 示例生成摘要
input_text = "This is a sample text for summarization."
generated_summary = generate_summary(encoder, decoder, input_text)
print("Original Text:", input_text)
print("Generated Summary:", generated_summary)

这个项目中,我们通过构建 Seq2Seq 模型进行文本摘要生成。你可以根据实际需求选择更大的模型、更大的数据集,并进行更多的训练迭代,以获得更好的摘要生成效果。

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码