Notes:
1. Based on the BLOOMZ/LLaMA (LoRA) fine-tuning approach.
2. A single codebase covering single-machine single-GPU, single-machine multi-GPU, and multi-machine multi-GPU training; the complete workflow is recorded here.
3. Supports both full-parameter and LoRA fine-tuning of ChatGLM-6B.
I. Data format: construct a JSON file whose records follow the format below
{
"instruction": "What are the three primary colors?",
"input": "",
"output": "The three primary colors are red, blue, and yellow."
}
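If your raw samples are already in a Python list, a minimal sketch for writing them to data/data.json (the default path used by finetune.py below; the sample record is only illustrative) could look like this:

import json

# Hypothetical raw samples; replace with your own instruction data
samples = [
    {
        "instruction": "What are the three primary colors?",
        "input": "",
        "output": "The three primary colors are red, blue, and yellow.",
    },
]

# datasets.load_dataset("json", ...) in finetune.py accepts a JSON array like this
# (one-object-per-line JSON Lines also works)
with open("data/data.json", "w", encoding="utf-8") as f:
    json.dump(samples, f, ensure_ascii=False, indent=2)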
II. Fine-tuning code: finetune.py
from typing import List
from transformers import TrainingArguments, AutoConfig
from transformers import Trainer, HfArgumentParser
from transformers import AutoTokenizer, AutoModel,AutoModelForCausalLM
import torch
import torch.nn as nn
from peft import get_peft_model, LoraConfig, TaskType
from dataclasses import dataclass, field
import datasets
import os
os.environ["WANDB_DISABLED"] = "true"
device_map="auto"
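# WORLD_SIZE is set by torchrun / torch.distributed launchers; more than one process means we are running distributed data parallel (DDP)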
world_size = int(os.environ.get("WORLD_SIZE", 1))
print("world_size",world_size)
ddp = world_size != 1
@dataclass
class FinetuneArguments:
dataset_path: str = field(default="data/data.json")
model_path: str = field(default="output")
lora_rank: int = field(default=8)
lora_alpha: int = field(default=16)
lora_dropout: float = field(default=0.05)
use_lora: bool = field(default=False)
cutoff_len: int = field(default=128)
val_set_size: int = field(default=1000)
val_set_rate: float = field(default=0.1)
finetune_args, training_args = HfArgumentParser(
(FinetuneArguments, TrainingArguments)
).parse_args_into_dataclasses()
if ddp:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
training_args.ddp_find_unused_parameters=False
tokenizer = AutoTokenizer.from_pretrained("./chatglm-6b", trust_remote_code=True)
config = AutoConfig.from_pretrained('./chatglm-6b', trust_remote_code=True)
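# Cast lm_head outputs back to float32 so the loss is computed in full precision when the base model runs in fp16/int8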
class CastOutputToFloat(nn.Sequential):
def forward(self, x):
return super().forward(x).to(torch.float32)
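# The prompt is encoded with ChatGLM's special tokens appended (add_special_tokens=True), the target without; seq_len marks where the answer starts so the collator can mask the prompt out of the loss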
def tokenize(prompt,target,cutoff_len=finetune_args.cutoff_len):
prompt_ids = tokenizer.encode(prompt, max_length=cutoff_len, truncation=True,add_special_tokens=True)
target_ids = tokenizer.encode(
target,
max_length=cutoff_len,
truncation=True,
add_special_tokens=False)
input_ids = prompt_ids + target_ids + [tokenizer.eos_token_id]
return {"input_ids": input_ids, "seq_len": len(prompt_ids)}
def generate_and_tokenize_prompt(data_point):
instruction = data_point['instruction']
input_text = data_point["input"]
input_text = instruction + input_text
target_text = data_point["output"]
tokenized_full_prompt = tokenize(input_text,target_text)
return tokenized_full_prompt
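# Pad each batch to its longest sequence, build ChatGLM-style attention masks and position ids, and set prompt positions in labels to -100 so only the answer contributes to the loss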
def data_collator(features: list) -> dict:
len_ids = [len(feature["input_ids"]) for feature in features]
longest = max(len_ids) + 1
input_ids = []
attention_mask_list = []
position_ids_list = []
labels_list = []
for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):
ids = feature["input_ids"]
seq_len = feature["seq_len"]
labels = (
[-100] * (seq_len - 1)
+ ids[(seq_len - 1):]
+ [tokenizer.eos_token_id]
+ [-100] * (longest - ids_l - 1)
)
ids = ids + [tokenizer.eos_token_id] * (longest - ids_l)
_ids = torch.LongTensor(ids)
attention_mask, position_ids = get_masks_and_position_ids(
ids, seq_len, longest, _ids.device, gmask=False
)
labels_list.append(torch.LongTensor(labels))
input_ids.append(_ids)
attention_mask_list.append(attention_mask)
position_ids_list.append(position_ids)
input_ids = torch.stack(input_ids)
labels = torch.stack(labels_list)
attention_mask = torch.stack(attention_mask_list)
position_ids = torch.stack(position_ids_list)
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
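# Mirrors ChatGLM-6B's prefix attention mask and 2D position encoding: the prompt part attends bidirectionally while the generated part stays causal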
def get_masks_and_position_ids(
seq, seq_len, context_length, device, gmask=False, position_encoding_2d=True
):
mask_position = (
seq_len - 2
) # is equal to `seq.index(mask_token)` or `seq.index(150001)`
attention_mask = torch.ones((1, context_length, context_length), device=device)
attention_mask.tril_()
attention_mask[..., : mask_position - 1] = 1
attention_mask = (attention_mask < 0.5).bool()
if position_encoding_2d:
seq_length = seq_len - 1 # is equal to `seq_length = seq.index(150004)`
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
if not gmask:
position_ids[seq_length:] = mask_position
block_position_ids = torch.cat(
(
torch.zeros(seq_length, dtype=torch.long, device=device),
torch.arange(
context_length - seq_length, dtype=torch.long, device=device
)
+ 1,
)
)
position_ids = torch.stack((position_ids, block_position_ids), dim=0)
else:
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
if not gmask:
position_ids[context_length - 1:] = mask_position
return attention_mask, position_ids
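# Trainer subclass that feeds the custom attention_mask/position_ids to the model and, on save, stores only the trainable (LoRA) parameters as adapter_model.bin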
class ModifiedTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
return model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
position_ids=inputs["position_ids"],
labels=inputs["labels"],
).loss
def save_model(self, output_dir=None, _internal_call=False):
from transformers.trainer import TRAINING_ARGS_NAME
os.makedirs(output_dir, exist_ok=True)
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
saved_params = {
k: v.to("cpu") for k, v in self.model.named_parameters() if v.requires_grad
}
torch.save(saved_params, os.path.join(output_dir, "adapter_model.bin"))
def main():
print("finetune_args.cutoff_len", finetune_args.cutoff_len)
print("finetune_args.use_lora", finetune_args.use_lora)
print("training_args.fp16", training_args.fp16)
print("training_args.local_rank", training_args.local_rank)
print("training_args.ddp_find_unused_parameters", training_args.ddp_find_unused_parameters)
# init model
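    # load_in_8bit quantizes the frozen base weights to int8 and requires the bitsandbytes package (CUDA only)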
model = AutoModel.from_pretrained(
"./chatglm-6b", load_in_8bit=True, trust_remote_code=True, device_map=device_map
)
#
if not ddp and torch.cuda.device_count() > 1:
model.is_parallelizable = True
model.model_parallel = True
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
model.lm_head = CastOutputToFloat(model.lm_head)
model.config.use_cache = False
if finetune_args.use_lora:
# setup peft
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=finetune_args.lora_rank,
lora_alpha=finetune_args.lora_alpha,
lora_dropout=finetune_args.lora_dropout,
target_modules=["query_key_value"]
)
model = get_peft_model(model, peft_config)
# load dataset
dataset = datasets.load_dataset("json",data_files=finetune_args.dataset_path)
val_set_size = finetune_args.val_set_size
training_nums = len(dataset['train'])
if val_set_size > 0:
val_set_size = min(val_set_size, int(training_nums * finetune_args.val_set_rate))
train_val = dataset["train"].train_test_split(
test_size=val_set_size, shuffle=True, seed=42
)
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
else:
train_data = dataset["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = None
    # alternatively, load a pre-tokenized dataset from disk:
#train_data = datasets.load_from_disk(finetune_args.dataset_path)
# start train
trainer = ModifiedTrainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=training_args,
data_collator=data_collator
)
trainer.train()
# save model
model.save_pretrained(training_args.output_dir)
if __name__ == "__main__":
main()
III. Running the training
# Single machine, single GPU
python finetune.py \
--dataset_path data/data.json \
--cutoff_len 8196 \
--val_set_size 1000 \
--val_set_rate 0.1 \
--lora_rank 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--num_train_epochs 32 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--save_steps 1000 \
--eval_steps 1000 \
--save_total_limit 2 \
--learning_rate 3e-6 \
--remove_unused_columns False \
--warmup_steps 10 \
--logging_steps 10 \
--group_by_length True \
--output_dir trained_models/chatglm_lora \
--use_lora True
# Single machine, multiple GPUs
torchrun --nproc_per_node=2 finetune.py \
--dataset_path data/data.json \
--cutoff_len 256 \
--val_set_size 1000 \
--val_set_rate 0.1 \
--lora_rank 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--save_steps 1000 \
--eval_steps 1000 \
--save_total_limit 2 \
--learning_rate 2e-5 \
--remove_unused_columns False \
--warmup_steps 10 \
--logging_steps 10 \
--group_by_length True \
--output_dir trained_models/chatglm_lora \
--use_lora True
# Multiple machines, multiple GPUs (LoRA on a 500k-sample Lianjia dataset); replace rdzv_endpoint with the IP address of your etcd service
torchrun --nnodes=1:2 --nproc_per_node=10 --max_restarts 3 --rdzv_id=cls_chatglm_6b_belle --rdzv_backend=etcd --rdzv_endpoint=x.x.x.x:2379 finetune.py \
--dataset_path data/belle.json \
--cutoff_len 256 \
--val_set_size 1000 \
--val_set_rate 0.1 \
--lora_rank 8 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--num_train_epochs 3 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 8 \
--save_steps 1000 \
--eval_steps 1000 \
--save_total_limit 2 \
--learning_rate 2e-4 \
--remove_unused_columns False \
--warmup_steps 10 \
--logging_steps 10 \
--group_by_length True \
--output_dir trained_models/chatglm_lora_belle \
--use_lora True
IV. Inference and validation
Please refer to the earlier articles on ChatGLM-6B for the inference code.
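For a quick sanity check, a minimal LoRA inference sketch (assuming the adapter was saved to trained_models/chatglm_lora as in the commands above, a CUDA GPU is available, and the prompt is only illustrative):

import torch
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel

tokenizer = AutoTokenizer.from_pretrained("./chatglm-6b", trust_remote_code=True)
# Load the base model in fp16 and attach the trained LoRA adapter
model = AutoModel.from_pretrained("./chatglm-6b", trust_remote_code=True).half().cuda()
model = PeftModel.from_pretrained(model, "trained_models/chatglm_lora")
model.eval()

with torch.no_grad():
    response, history = model.chat(tokenizer, "What are the three primary colors?", history=[])
    print(response)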
V. References
https://github.com/THUDM/ChatGLM-6B
https://github.com/yanqiangmiffy/InstructGLM