5.2 多任务数据预处理方式
多任务数据预处理介绍¶
学习目标¶
- 了解本项目中多任务数据集的格式
- 掌握实现数据预处理的函数代码
数据预处理过程¶
- 本项目中对数据部分的预处理步骤如下:
- 查看项目数据集
- 编写Config类项目文件配置代码
- 编写数据处理相关代码
1 查看项目数据集¶
-
数据存放位置:/Users/**/PycharmProjects/llm/ptune_chatglm/data
-
data文件夹里面包含3个jsonl文档,分别为:mixed_train_dataset.jsonl、mixed_dev_dataset.jsonl、dataset.jsonl
1.1 train.jsonl¶
-
mixed_train_dataset.jsonl为训练数据集,因为我们本次项目同时进行「信息抽取+文本分类」两项任务,因此数据中混合了两种任务数据类型。举例展示如下:
-
信息抽取数据示例
- Instruction 部分告诉模型现在需要做「阅读理解」任务,Input 部分告知模型要抽取的句子以及输出的格式。
{
"context": "Instruction: 你现在是一个很厉害的阅读理解器,严格按照人类指令进行回答。\nInput: 找到句子中的三元组信息并输出成json给我:\n\n九玄珠是在纵横中文网连载的一部小说,作者是龙马。\nAnswer: ",
"target": "```json\n[{\"predicate\": \"连载网站\", \"object_type\": \"网站\", \"subject_type\": \"网络小说\", \"object\": \"纵横中文网\", \"subject\": \"九玄珠\"}, {\"predicate\": \"作者\", \"object_type\": \"人物\", \"subject_type\": \"图书作品\", \"object\": \"龙马\", \"subject\": \"九玄珠\"}]\n```"
}
- 文本数据示例
- Instruction 部分告诉模型现在需要做「阅读理解」任务,Input 部分告知模型要抽取的句子以及输出的格式。
{
"context": "Instruction: 你现在是一个很厉害的阅读理解器,严格按照人类指令进行回答。\nInput: 下面句子可能是一条关于什么的评论,用列表形式回答:\n\n很不错,很新鲜,快递小哥服务很好,水果也挺甜挺脆的\nAnswer: ",
"target": "[\"水果\"]"
}
训练集中一共包含902条数据,每一条数据都分为
context和target两部分:
context部分是接受用户的输入。2.target部分用于指定模型的输出。在
context中又包括 2 个部分:
- Instruction:用于告知模型的具体指令,当需要一个模型同时解决多个任务时可以设定不同的 Instruction 来帮助模型判别当前应当做什么任务。
- Input:当前用户的输入。
1.2 dev.jsonl¶
-
mixed_dev_dataset.jsonl为训练数据集,因为我们本次项目同时进行「信息抽取+文本分类」两项任务,因此数据中混合了两种任务数据类型。举例展示如下:
-
信息抽取数据示例
- Instruction 部分告诉模型现在需要做「阅读理解」任务,Input 部分告知模型要抽取的句子以及输出的格式。
{
"context": "Instruction: 你现在是一个很厉害的阅读理解器,严格按照人类指令进行回答。\nInput: 下面句子包含了哪些三元组,只用json的格式回答:\n\n《全国公共英语等级考试四级词汇科学记忆(磁带1--5)》是由人民大学出版社出版的一部教育作品,作者是钟道隆。\nAnswer: ",
"target": "```json\n[{\"predicate\": \"出版社\", \"object_type\": \"出版社\", \"subject_type\": \"书籍\", \"object\": \"人民大学\", \"subject\": \"全国公共英语等级考试四级词汇科学记忆(磁带1--5)\"}, {\"predicate\": \"作者\", \"object_type\": \"人物\", \"subject_type\": \"图书作品\", \"object\": \"钟道隆\", \"subject\": \"全国公共英语等级考试四级词汇科学记忆(磁带1--5)\"}]\n```"
}
- 文本数据示例
- Instruction 部分告诉模型现在需要做「阅读理解」任务,Input 部分告知模型要抽取的句子以及输出的格式。
{
"context": "Instruction: 你现在是一个很厉害的阅读理解器,严格按照人类指令进行回答。\nInput: 下面句子中描述的是一个什么?用列表的方式回答。\n\n什么苹果啊,都没有苹果味,怪怪的味道,而且一点都不甜,超级难吃!\nAnswer: ",
"target": "[\"水果\"]"
}
训练集中一共包含122条数据,每一条数据都分为
context和target两部分:
context部分是接受用户的输入。2.target部分用于指定模型的输出。在
context中又包括 2 个部分:
- Instruction:用于告知模型的具体指令,当需要一个模型同时解决多个任务时可以设定不同的 Instruction 来帮助模型判别当前应当做什么任务。
- Input:当前用户的输入。
如果想使用自定义数据训练,只需要仿照上述示例数据构建数据集即可。
2 编写项目Config类配置文件¶
-
代码路径:/Users/**/PycharmProjects/llm/ptune_chatglm/glm_config.py
-
config文件目的:配置项目常用变量,一般这些变量属于不经常改变的,比如:训练文件路径、模型训练次数、模型超参数等等
具体代码实现:
# -*- coding:utf-8 -*-
import torch
class ProjectConfig(object):
def __init__(self):
self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
self.pre_model = './llm/ChatGLM-6B/THUDM/chatglm-6b'
self.train_path = './llm/ptune_chatglm/data/mixed_train_dataset.jsonl'
self.dev_path = './llm/ptune_chatglm/data/mixed_dev_dataset.jsonl'
self.use_lora = True
self.use_ptuning = False
# 低秩矩阵的秩是8
self.lora_rank = 8
self.batch_size = 1
self.epochs = 2
self.learning_rate = 3e-5
self.weight_decay = 0
self.warmup_ratio = 0.06
self.max_source_seq_len = 400
self.max_target_seq_len = 300
self.logging_steps = 10
self.save_freq = 200
self.pre_seq_len = 128
self.prefix_projection = False # 默认为False,即p-tuning,如果为True,即p-tuning-v2
self.save_dir = './llm/ptune_chatglm/checkpoints/ptune'
if __name__ == '__main__':
pc = ProjectConfig()
print(pc.save_dir)
3 编写数据处理相关代码¶
- 代码路径:/Users/**/PycharmProjects/llm/ptune_chatglm/data_handle.
- data_handle文件夹中一共包含两个py脚本:data_preprocess.py、data_loader.py
3.1 data_preprocess.py¶
- 目的: 将样本数据转换为模型接受的输入数据
- 导入必备的工具包
import json
# 返回的字符串包含有关异常的详细信
import traceback
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer
from functools import partial
import sys
sys.path.append('..')
from glm_config import *
- 定义数据转换方法convert_example()
def convert_example(
examples: dict,
tokenizer,
max_source_seq_len: int,
max_target_seq_len: int,
):
"""
将样本数据转换为Ptuning模型接收的输入数据。
Args:
examples (dict): 训练数据样本, e.g. -> {
"text": [
'{"context": "年基准利率4.35%。从实际看...", "target": "2017年银行贷款基准利率"}',
...
]
}
max_source_seq_len (int): prompt最大长度
max_target_seq_len (int): 答案最大长度
Returns:
dict (str: np.array) -> tokenized_output = {
'input_ids': [[1525, 10, ...], [758, 2345, ...]],
'labels': [[822, 10, ...], [125, 58...]]
}
"""
tokenized_output = {
'input_ids': [],
'labels': []
}
max_seq_length = max_source_seq_len + max_target_seq_len
for example in examples['text']:
try:
example = json.loads(example)
context = example["context"]
target = example["target"]
# print(f'context-->{context}')
# print(f'target-->{target}')
prompts_ids = tokenizer.encode(
text=context,
add_special_tokens=False
)
# print(f'prompts_ids--》{prompts_ids}\n{len(prompts_ids)}')
target_ids = tokenizer.encode(
text=target,
add_special_tokens=False
)
# print(f'target_ids--》{target_ids}\n{len(target_ids)}')
if len(prompts_ids) >= max_source_seq_len:
# source 需要留一个 [gMASK] token 在结尾
prompts_ids = prompts_ids[:max_source_seq_len - 1]
if len(target_ids) >= max_target_seq_len - 1:
# target 需要留一个 <sop> 在开头和一个 <eop> token 在结尾
target_ids = target_ids[:max_target_seq_len - 2]
# source_ids + [gMASK] + <sop> + target_ids + <eop>
input_ids = tokenizer.build_inputs_with_special_tokens(prompts_ids, target_ids)
# print(f'input_ids-->{input_ids}')
# bos 在 target 的第一位
context_length = input_ids.index(tokenizer.bos_token_id)
# print(f'context_length-->{context_length}')
# [gMASK] 在 source 的最后一位
mask_position = context_length - 1
# 从 bos 开始到后面所有的 target 到 eos 都为 label
labels = [-100] * context_length + input_ids[mask_position + 1:]
# print(f'labels-->{labels}')
pad_len = max_seq_length - len(input_ids)
# print(f'pad_len-->{pad_len}')
input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
# print(f'input_ids-->{input_ids}\n{len(input_ids)}')
labels = labels + [-100] * pad_len
# print(f'labels-->{labels}\n{len(labels)}')
tokenized_output['input_ids'].append(input_ids)
tokenized_output['labels'].append(labels)
except:
print(f'"{example}" -> {traceback.format_exc()}')
continue
for k, v in tokenized_output.items():
tokenized_output[k] = np.array(v)
return tokenized_output
- 定义获取训练或验证数据最大长度方法get_max_length()
def get_max_length(
tokenizer,
dataset_file: str
):
"""
测试数据集最大的输入/输出tokens是多少。
Args:
dataset_file (str): _description_
"""
source_seq_len_list = []
target_seq_len_list = []
with open(dataset_file, 'r') as f:
for line in tqdm(f.readlines()):
line = json.loads(line)
source_len = tokenizer.encode(line['context'])
source_seq_len_list.append(len(source_len))
target_len = tokenizer.encode(line['target'])
target_seq_len_list.append(len(target_len))
print(dataset_file)
print(f"【Source Sequence】 Max: {max(source_seq_len_list)}, Avg: {int(sum(source_seq_len_list) / len(source_seq_len_list))}, Middle: {sorted(source_seq_len_list)[int(len(source_seq_len_list) / 2)]}.")
print(f"【Target Sequence】 Max: {max(target_seq_len_list)}, Avg: {int(sum(target_seq_len_list) / len(target_seq_len_list))}, Middle: {sorted(target_seq_len_list)[int(len(target_seq_len_list) / 2)]}.")
3.3 data_loader.py¶
- 目的:定义数据加载器
- 导入必备的工具包
# coding:utf-8
from torch.utils.data import DataLoader
from transformers import default_data_collator, AutoTokenizer
from data_handle.data_preprocess import *
from glm_config import *
pc = ProjectConfig() # 实例化项目配置文件
tokenizer = AutoTokenizer.from_pretrained(pc.pre_model, trust_remote_code=True)
- 定义获取数据加载器的方法get_data()
def get_data():
dataset = load_dataset('text', data_files={'train': pc.train_path,
'dev': pc.dev_path})
new_func = partial(convert_example,
tokenizer=tokenizer,
max_source_seq_len=100,
max_target_seq_len=100)
dataset = dataset.map(new_func, batched=True)
train_dataset = dataset["train"]
dev_dataset = dataset["dev"]
train_dataloader = DataLoader(train_dataset,
shuffle=True,
collate_fn=default_data_collator,
batch_size=pc.batch_size)
dev_dataloader = DataLoader(dev_dataset,
collate_fn=default_data_collator,
batch_size=pc.batch_size)
return train_dataloader, dev_dataloader
if __name__ == '__main__':
train_dataloader, dev_dataloader = get_data()
print(len(train_dataloader))
print(len(dev_dataloader))
for i, value in enumerate(train_dataloader):
print(value)
print(value['input_ids'].shape)
print(value['labels'].shape)
break
- 打印结果:
902
122
{
'input_ids': tensor([[ 37010, 12, 5, 76331, 83362, 92831,
103593, 64464, 6,
77115, 65077, 72863, 63891, 66207, 63823, 4, 3430, 12,
68327, 74351, 77756, 66263, 81577, 64536, 6, 82145, 2031,
63825, 69574, 66207, 12, 4, 4, 64590, 67748, 69958,
66152, 63923, 65024, 64676, 65102, 66089, 64101, 73127, 64025,
64236, 6, 72996, 73518, 64236, 82273, 63823, 4, 13049,
12, 130001, 130004, 5, 125827, 2031, 4, 127903, 38861,
83, 28, 66845, 67541, 57, 28, 1932, 24, 317,
83, 28, 64069, 57, 28, 9832, 24, 317, 83,
28, 65210, 57, 28, 1932, 83, 28, 73127, 64025,
64236, 57, 28, 9832, 83, 28, 64590, 67748, 69958,
66152, 127731, 4, 125827, 130005, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3]]),
'labels': tensor([[ -100, -100, -100, -100, -100, -100, -100,
-100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, 130004, 5, 125827, 2031, 4, 127903, 38861,
83, 28, 66845, 67541, 57, 28, 1932, 24, 317,
83, 28, 64069, 57, 28, 9832, 24, 317, 83,
28, 65210, 57, 28, 1932, 83, 28, 73127, 64025,
64236, 57, 28, 9832, 83, 28, 64590, 67748, 69958,
66152, 127731, 4, 125827, 130005, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100]])
}
torch.Size([1, 200])
torch.Size([1, 200])