百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 编程字典 > 正文

微调ChatGLM-6B详细完整全过程记录

toyiye 2024-06-21 12:19 12 浏览 0 评论

备注:

1.参考BLOOMZ/LLaMA微调(LoRA)模式

2.一套代码,支持单机单卡,单机多卡,多机多卡完整流程记录

3.支持FULL,LoRA微调ChatGLM-6B


一.构造数据格式,如下格式即可的json文件

{

"instruction": "What are the three primary colors?",

"input": "",

"output": "The three primary colors are red, blue, and yellow."

}

二.微调训练代码:finetune.py

from typing import List

from transformers import TrainingArguments, AutoConfig

from transformers import Trainer, HfArgumentParser

from transformers import AutoTokenizer, AutoModel,AutoModelForCausalLM

import torch

import torch.nn as nn

from peft import get_peft_model, LoraConfig, TaskType

from dataclasses import dataclass, field

import datasets

import os

os.environ["WANDB_DISABLED"] = "true"

device_map="auto"

world_size = int(os.environ.get("WORLD_SIZE", 1))

print("world_size",world_size)

ddp = world_size != 1

@dataclass

class FinetuneArguments:

dataset_path: str = field(default="data/data.json")

model_path: str = field(default="output")

lora_rank: int = field(default=8)

lora_alpha: int = field(default=16)

lora_dropout: float = field(default=0.05)

use_lora: bool = field(default=False)

cutoff_len: int = field(default=128)

val_set_size: int = field(default=1000)

val_set_rate: float = field(default=0.1)

finetune_args, training_args = HfArgumentParser(

(FinetuneArguments, TrainingArguments)

).parse_args_into_dataclasses()

if ddp:

device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}

training_args.ddp_find_unused_parameters=False

tokenizer = AutoTokenizer.from_pretrained("./chatglm-6b", trust_remote_code=True)

config = AutoConfig.from_pretrained('./chatglm-6b', trust_remote_code=True)

class CastOutputToFloat(nn.Sequential):

def forward(self, x):

return super().forward(x).to(torch.float32)

def tokenize(prompt,target,cutoff_len=finetune_args.cutoff_len):

prompt_ids = tokenizer.encode(prompt, max_length=cutoff_len, truncation=True,add_special_tokens=True)

target_ids = tokenizer.encode(

target,

max_length=cutoff_len,

truncation=True,

add_special_tokens=False)

input_ids = prompt_ids + target_ids + [tokenizer.eos_token_id]

return {"input_ids": input_ids, "seq_len": len(prompt_ids)}

def generate_and_tokenize_prompt(data_point):

instruction = data_point['instruction']

input_text = data_point["input"]

input_text = instruction + input_text

target_text = data_point["output"]

tokenized_full_prompt = tokenize(input_text,target_text)

return tokenized_full_prompt

def data_collator(features: list) -> dict:

len_ids = [len(feature["input_ids"]) for feature in features]

longest = max(len_ids) + 1

input_ids = []

attention_mask_list = []

position_ids_list = []

labels_list = []

for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):

ids = feature["input_ids"]

seq_len = feature["seq_len"]

labels = (

[-100] * (seq_len - 1)

+ ids[(seq_len - 1):]

+ [tokenizer.eos_token_id]

+ [-100] * (longest - ids_l - 1)

)

ids = ids + [tokenizer.eos_token_id] * (longest - ids_l)

_ids = torch.LongTensor(ids)

attention_mask, position_ids = get_masks_and_position_ids(

ids, seq_len, longest, _ids.device, gmask=False

)

labels_list.append(torch.LongTensor(labels))

input_ids.append(_ids)

attention_mask_list.append(attention_mask)

position_ids_list.append(position_ids)

input_ids = torch.stack(input_ids)

labels = torch.stack(labels_list)

attention_mask = torch.stack(attention_mask_list)

position_ids = torch.stack(position_ids_list)

return {

"input_ids": input_ids,

"labels": labels,

"attention_mask": attention_mask,

"position_ids": position_ids,

}

def get_masks_and_position_ids(

seq, seq_len, context_length, device, gmask=False, position_encoding_2d=True

):

mask_position = (

seq_len - 2

) # is equal to `seq.index(mask_token)` or `seq.index(150001)`

attention_mask = torch.ones((1, context_length, context_length), device=device)

attention_mask.tril_()

attention_mask[..., : mask_position - 1] = 1

attention_mask = (attention_mask < 0.5).bool()

if position_encoding_2d:

seq_length = seq_len - 1 # is equal to `seq_length = seq.index(150004)`

position_ids = torch.arange(context_length, dtype=torch.long, device=device)

if not gmask:

position_ids[seq_length:] = mask_position

block_position_ids = torch.cat(

(

torch.zeros(seq_length, dtype=torch.long, device=device),

torch.arange(

context_length - seq_length, dtype=torch.long, device=device

)

+ 1,

)

)

position_ids = torch.stack((position_ids, block_position_ids), dim=0)

else:

position_ids = torch.arange(context_length, dtype=torch.long, device=device)

if not gmask:

position_ids[context_length - 1:] = mask_position

return attention_mask, position_ids

class ModifiedTrainer(Trainer):

def compute_loss(self, model, inputs, return_outputs=False):

return model(

input_ids=inputs["input_ids"],

attention_mask=inputs["attention_mask"],

position_ids=inputs["position_ids"],

labels=inputs["labels"],

).loss

def save_model(self, output_dir=None, _internal_call=False):

from transformers.trainer import TRAINING_ARGS_NAME

os.makedirs(output_dir, exist_ok=True)

torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

saved_params = {

k: v.to("cpu") for k, v in self.model.named_parameters() if v.requires_grad

}

torch.save(saved_params, os.path.join(output_dir, "adapter_model.bin"))

def main():

print("finetune_args.cutoff_len", finetune_args.cutoff_len)

print("finetune_args.use_lora", finetune_args.use_lora)

print("training_args.fp16", training_args.fp16)

print("training_args.local_rank", training_args.local_rank)

print("training_args.ddp_find_unused_parameters", training_args.ddp_find_unused_parameters)

# init model

model = AutoModel.from_pretrained(

"./chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map=device_map

)

#

if not ddp and torch.cuda.device_count() > 1:

model.is_parallelizable = True

model.model_parallel = True

model.gradient_checkpointing_enable()

model.enable_input_require_grads()

model.lm_head = CastOutputToFloat(model.lm_head)

model.config.use_cache = False

if finetune_args.use_lora:

# setup peft

peft_config = LoraConfig(

task_type=TaskType.CAUSAL_LM,

inference_mode=False,

r=finetune_args.lora_rank,

lora_alpha=finetune_args.lora_alpha,

lora_dropout=finetune_args.lora_dropout,

target_modules=["query_key_value"]

)

model = get_peft_model(model, peft_config)

# load dataset

dataset = datasets.load_dataset("json",data_files=finetune_args.dataset_path)

val_set_size = finetune_args.val_set_size

training_nums = len(dataset['train'])

if val_set_size > 0:

val_set_size = min(val_set_size, int(training_nums * finetune_args.val_set_rate))

train_val = dataset["train"].train_test_split(

test_size=val_set_size, shuffle=True, seed=42

)

train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)

val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)

else:

train_data = dataset["train"].shuffle().map(generate_and_tokenize_prompt)

val_data = None

# load dataset

#train_data = datasets.load_from_disk(finetune_args.dataset_path)

# start train

trainer = ModifiedTrainer(

model=model,

train_dataset=train_data,

eval_dataset=val_data,

args=training_args,

data_collator=data_collator

)

trainer.train()

# save model

model.save_pretrained(training_args.output_dir)

if __name__ == "__main__":

main()


三.测试运行

#普通

python finetune.py \

--dataset_path data/data.json \

--cutoff_len 8196 \

--val_set_size 1000 \

--val_set_rate 0.1 \

--lora_rank 8 \

--lora_alpha 16 \

--lora_dropout 0.05 \

--num_train_epochs 32 \

--per_device_train_batch_size 1 \

--gradient_accumulation_steps 1 \

--save_steps 1000 \

--eval_steps 1000 \

--save_total_limit 2 \

--learning_rate 3e-6 \

--remove_unused_columns False \

--warmup_steps 10 \

--logging_steps 10 \

--group_by_length True \

--output_dir trained_models/chatglm_lora \

--use_lora True


#一机多卡

torchrun --nproc_per_node=2 finetune.py \

--dataset_path data/data.json \

--cutoff_len 256 \

--val_set_size 1000 \

--val_set_rate 0.1 \

--lora_rank 8 \

--lora_alpha 16 \

--lora_dropout 0.05 \

--num_train_epochs 1 \

--per_device_train_batch_size 1 \

--gradient_accumulation_steps 1 \

--save_steps 1000 \

--eval_steps 1000 \

--save_total_limit 2 \

--learning_rate 2e-5 \

--remove_unused_columns False \

--warmup_steps 10 \

--logging_steps 10 \

--group_by_length True \

--output_dir trained_models/chatglm_lora \

--use_lora True

#多机多卡(链家50万数据量LoRA),rdzv_endpoint换成etcd服务IP地址

torchrun --nnodes=1:2 --nproc_per_node=10 --max_restarts 3 --rdzv_id=cls_chatglm_6b_belle --rdzv_backend=etcd --rdzv_endpoint=x.x.x.x:2379 finetune.py \

--dataset_path data/belle.json \

--cutoff_len 256 \

--val_set_size 1000 \

--val_set_rate 0.1 \

--lora_rank 8 \

--lora_alpha 16 \

--lora_dropout 0.05 \

--num_train_epochs 3 \

--per_device_train_batch_size 1 \

--gradient_accumulation_steps 8 \

--save_steps 1000 \

--eval_steps 1000 \

--save_total_limit 2 \

--learning_rate 2e-4 \

--remove_unused_columns False \

--warmup_steps 10 \

--logging_steps 10 \

--group_by_length True \

--output_dir trained_models/chatglm_lora_belle \

--use_lora True


四.验证推理

请参考之前关于ChatGLM-6B相关代码文章

五.参考链接

https://github.com/THUDM/ChatGLM-6B

https://github.com/yanqiangmiffy/InstructGLM

相关推荐

为何越来越多的编程语言使用JSON(为什么编程)

JSON是JavascriptObjectNotation的缩写,意思是Javascript对象表示法,是一种易于人类阅读和对编程友好的文本数据传递方法,是JavaScript语言规范定义的一个子...

何时在数据库中使用 JSON(数据库用json格式存储)

在本文中,您将了解何时应考虑将JSON数据类型添加到表中以及何时应避免使用它们。每天?分享?最新?软件?开发?,Devops,敏捷?,测试?以及?项目?管理?最新?,最热门?的?文章?,每天?花?...

MySQL 从零开始:05 数据类型(mysql数据类型有哪些,并举例)

前面的讲解中已经接触到了表的创建,表的创建是对字段的声明,比如:上述语句声明了字段的名称、类型、所占空间、默认值和是否可以为空等信息。其中的int、varchar、char和decimal都...

JSON对象花样进阶(json格式对象)

一、引言在现代Web开发中,JSON(JavaScriptObjectNotation)已经成为数据交换的标准格式。无论是从前端向后端发送数据,还是从后端接收数据,JSON都是不可或缺的一部分。...

深入理解 JSON 和 Form-data(json和formdata提交区别)

在讨论现代网络开发与API设计的语境下,理解客户端和服务器间如何有效且可靠地交换数据变得尤为关键。这里,特别值得关注的是两种主流数据格式:...

JSON 语法(json 语法 priority)

JSON语法是JavaScript语法的子集。JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔花括号保存对象方括号保存数组JS...

JSON语法详解(json的语法规则)

JSON语法规则JSON语法是JavaScript对象表示法语法的子集。数据在名称/值对中数据由逗号分隔大括号保存对象中括号保存数组注意:json的key是字符串,且必须是双引号,不能是单引号...

MySQL JSON数据类型操作(mysql的json)

概述mysql自5.7.8版本开始,就支持了json结构的数据存储和查询,这表明了mysql也在不断的学习和增加nosql数据库的有点。但mysql毕竟是关系型数据库,在处理json这种非结构化的数据...

JSON的数据模式(json数据格式示例)

像XML模式一样,JSON数据格式也有Schema,这是一个基于JSON格式的规范。JSON模式也以JSON格式编写。它用于验证JSON数据。JSON模式示例以下代码显示了基本的JSON模式。{"...

前端学习——JSON格式详解(后端json格式)

JSON(JavaScriptObjectNotation)是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScriptProgrammingLa...

什么是 JSON:详解 JSON 及其优势(什么叫json)

现在程序员还有谁不知道JSON吗?无论对于前端还是后端,JSON都是一种常见的数据格式。那么JSON到底是什么呢?JSON的定义...

PostgreSQL JSON 类型:处理结构化数据

PostgreSQL提供JSON类型,以存储结构化数据。JSON是一种开放的数据格式,可用于存储各种类型的值。什么是JSON类型?JSON类型表示JSON(JavaScriptO...

JavaScript:JSON、三种包装类(javascript 包)

JOSN:我们希望可以将一个对象在不同的语言中进行传递,以达到通信的目的,最佳方式就是将一个对象转换为字符串的形式JSON(JavaScriptObjectNotation)-JS的对象表示法...

Python数据分析 只要1分钟 教你玩转JSON 全程干货

Json简介:Json,全名JavaScriptObjectNotation,JSON(JavaScriptObjectNotation(记号、标记))是一种轻量级的数据交换格式。它基于J...

比较一下JSON与XML两种数据格式?(json和xml哪个好)

JSON(JavaScriptObjectNotation)和XML(eXtensibleMarkupLanguage)是在日常开发中比较常用的两种数据格式,它们主要的作用就是用来进行数据的传...

取消回复欢迎 发表评论:

请填写验证码