跳转至

5.2 多任务数据预处理方式

多任务数据预处理介绍

学习目标

  • 了解本项目中多任务数据集的格式
  • 掌握实现数据预处理的函数代码

数据预处理过程

  • 本项目中对数据部分的预处理步骤如下:
  • 查看项目数据集
  • 编写Config类项目文件配置代码
  • 编写数据处理相关代码

1 查看项目数据集

  • 数据存放位置:/Users/**/PycharmProjects/llm/ptune_chatglm/data

  • data文件夹里面包含3个jsonl文档,分别为:mixed_train_dataset.jsonl、mixed_dev_dataset.jsonl、dataset.jsonl


1.1 train.jsonl

  • mixed_train_dataset.jsonl为训练数据集,因为我们本次项目同时进行「信息抽取+文本分类」两项任务,因此数据中混合了两种任务数据类型。举例展示如下:

  • 信息抽取数据示例

  • Instruction 部分告诉模型现在需要做「阅读理解」任务,Input 部分告知模型要抽取的句子以及输出的格式。
{
    "context": "Instruction: 你现在是一个很厉害的阅读理解器,严格按照人类指令进行回答。\nInput: 找到句子中的三元组信息并输出成json给我:\n\n九玄珠是在纵横中文网连载的一部小说,作者是龙马。\nAnswer: ", 
    "target": "```json\n[{\"predicate\": \"连载网站\", \"object_type\": \"网站\", \"subject_type\": \"网络小说\", \"object\": \"纵横中文网\", \"subject\": \"九玄珠\"}, {\"predicate\": \"作者\", \"object_type\": \"人物\", \"subject_type\": \"图书作品\", \"object\": \"龙马\", \"subject\": \"九玄珠\"}]\n```"
}
  • 文本数据示例
  • Instruction 部分告诉模型现在需要做「阅读理解」任务,Input 部分告知模型要抽取的句子以及输出的格式。
{
    "context": "Instruction: 你现在是一个很厉害的阅读理解器,严格按照人类指令进行回答。\nInput: 下面句子可能是一条关于什么的评论,用列表形式回答:\n\n很不错,很新鲜,快递小哥服务很好,水果也挺甜挺脆的\nAnswer: ", 
    "target": "[\"水果\"]"
}

训练集中一共包含902条数据,每一条数据都分为 contexttarget 两部分:

  1. context 部分是接受用户的输入。2.target 部分用于指定模型的输出。

context 中又包括 2 个部分:

  1. Instruction:用于告知模型的具体指令,当需要一个模型同时解决多个任务时可以设定不同的 Instruction 来帮助模型判别当前应当做什么任务。
  2. Input:当前用户的输入。

1.2 dev.jsonl

  • mixed_dev_dataset.jsonl为训练数据集,因为我们本次项目同时进行「信息抽取+文本分类」两项任务,因此数据中混合了两种任务数据类型。举例展示如下:

  • 信息抽取数据示例

  • Instruction 部分告诉模型现在需要做「阅读理解」任务,Input 部分告知模型要抽取的句子以及输出的格式。
{
    "context": "Instruction: 你现在是一个很厉害的阅读理解器,严格按照人类指令进行回答。\nInput: 下面句子包含了哪些三元组,只用json的格式回答:\n\n《全国公共英语等级考试四级词汇科学记忆(磁带1--5)》是由人民大学出版社出版的一部教育作品,作者是钟道隆。\nAnswer: ", 
    "target": "```json\n[{\"predicate\": \"出版社\", \"object_type\": \"出版社\", \"subject_type\": \"书籍\", \"object\": \"人民大学\", \"subject\": \"全国公共英语等级考试四级词汇科学记忆(磁带1--5)\"}, {\"predicate\": \"作者\", \"object_type\": \"人物\", \"subject_type\": \"图书作品\", \"object\": \"钟道隆\", \"subject\": \"全国公共英语等级考试四级词汇科学记忆(磁带1--5)\"}]\n```"
}
  • 文本数据示例
  • Instruction 部分告诉模型现在需要做「阅读理解」任务,Input 部分告知模型要抽取的句子以及输出的格式。
{
    "context": "Instruction: 你现在是一个很厉害的阅读理解器,严格按照人类指令进行回答。\nInput: 下面句子中描述的是一个什么?用列表的方式回答。\n\n什么苹果啊,都没有苹果味,怪怪的味道,而且一点都不甜,超级难吃!\nAnswer: ", 
    "target": "[\"水果\"]"
}

训练集中一共包含122条数据,每一条数据都分为 contexttarget 两部分:

  1. context 部分是接受用户的输入。2.target 部分用于指定模型的输出。

context 中又包括 2 个部分:

  1. Instruction:用于告知模型的具体指令,当需要一个模型同时解决多个任务时可以设定不同的 Instruction 来帮助模型判别当前应当做什么任务。
  2. Input:当前用户的输入。

如果想使用自定义数据训练,只需要仿照上述示例数据构建数据集即可。

2 编写项目Config类配置文件

  • 代码路径:/Users/**/PycharmProjects/llm/ptune_chatglm/glm_config.py

  • config文件目的:配置项目常用变量,一般这些变量属于不经常改变的,比如:训练文件路径、模型训练次数、模型超参数等等

具体代码实现:

# -*- coding:utf-8 -*-
import torch


class ProjectConfig(object):
    def __init__(self):
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.pre_model = './llm/ChatGLM-6B/THUDM/chatglm-6b'
        self.train_path = './llm/ptune_chatglm/data/mixed_train_dataset.jsonl'
        self.dev_path = './llm/ptune_chatglm/data/mixed_dev_dataset.jsonl'
        self.use_lora = True
        self.use_ptuning = False
        # 低秩矩阵的秩是8
        self.lora_rank = 8
        self.batch_size = 1
        self.epochs = 2
        self.learning_rate = 3e-5
        self.weight_decay = 0
        self.warmup_ratio = 0.06
        self.max_source_seq_len = 400
        self.max_target_seq_len = 300
        self.logging_steps = 10
        self.save_freq = 200
        self.pre_seq_len = 128
        self.prefix_projection = False # 默认为False,即p-tuning,如果为True,即p-tuning-v2
        self.save_dir = './llm/ptune_chatglm/checkpoints/ptune'


if __name__ == '__main__':
    pc = ProjectConfig()
    print(pc.save_dir)

3 编写数据处理相关代码

  • 代码路径:/Users/**/PycharmProjects/llm/ptune_chatglm/data_handle.
  • data_handle文件夹中一共包含两个py脚本:data_preprocess.py、data_loader.py

3.1 data_preprocess.py

  • 目的: 将样本数据转换为模型接受的输入数据
  • 导入必备的工具包
import json
# 返回的字符串包含有关异常的详细信
import traceback
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer
from functools import partial
import sys
sys.path.append('..')

from glm_config import *

  • 定义数据转换方法convert_example()
def convert_example(
        examples: dict,
        tokenizer,
        max_source_seq_len: int,
        max_target_seq_len: int,
    ):
    """
    将样本数据转换为Ptuning模型接收的输入数据。

    Args:
        examples (dict): 训练数据样本, e.g. -> {
                                                "text": [
                                                            '{"context": "年基准利率4.35%。从实际看...", "target": "2017年银行贷款基准利率"}',
                                                            ...
                                                ]
                                            }
        max_source_seq_len (int): prompt最大长度
        max_target_seq_len (int): 答案最大长度

    Returns:
        dict (str: np.array) -> tokenized_output = {
                            'input_ids': [[1525, 10, ...], [758, 2345, ...]],
                            'labels': [[822, 10, ...], [125, 58...]]
                        }
    """
    tokenized_output = {
        'input_ids': [],
        'labels': []
    }

    max_seq_length = max_source_seq_len + max_target_seq_len

    for example in examples['text']:
        try:
            example = json.loads(example)
            context = example["context"]
            target = example["target"]
            # print(f'context-->{context}')
            # print(f'target-->{target}')
            prompts_ids = tokenizer.encode(
                text=context,
                add_special_tokens=False
            )
            # print(f'prompts_ids--》{prompts_ids}\n{len(prompts_ids)}')

            target_ids = tokenizer.encode(
                text=target,
                add_special_tokens=False
            )
            # print(f'target_ids--》{target_ids}\n{len(target_ids)}')

            if len(prompts_ids) >= max_source_seq_len:
                # source 需要留一个 [gMASK] token 在结尾
                prompts_ids = prompts_ids[:max_source_seq_len - 1]

            if len(target_ids) >= max_target_seq_len - 1: 
              # target 需要留一个 <sop> 在开头和一个 <eop> token 在结尾
                target_ids = target_ids[:max_target_seq_len - 2]

                        # source_ids + [gMASK] + <sop> + target_ids + <eop>
            input_ids = tokenizer.build_inputs_with_special_tokens(prompts_ids, target_ids)     
            # print(f'input_ids-->{input_ids}')

            # bos 在 target 的第一位
            context_length = input_ids.index(tokenizer.bos_token_id) 
            # print(f'context_length-->{context_length}')
                        # [gMASK] 在 source 的最后一位
            mask_position = context_length - 1

            # 从 bos 开始到后面所有的 target 到 eos 都为 label
            labels = [-100] * context_length + input_ids[mask_position + 1:]                    
            # print(f'labels-->{labels}')

            pad_len = max_seq_length - len(input_ids)
            # print(f'pad_len-->{pad_len}')

            input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
            # print(f'input_ids-->{input_ids}\n{len(input_ids)}')
            labels = labels + [-100] * pad_len
            # print(f'labels-->{labels}\n{len(labels)}')


            tokenized_output['input_ids'].append(input_ids)
            tokenized_output['labels'].append(labels)
        except:
            print(f'"{example}" -> {traceback.format_exc()}')
            continue

    for k, v in tokenized_output.items():
        tokenized_output[k] = np.array(v)

    return tokenized_output

  • 定义获取训练或验证数据最大长度方法get_max_length()
def get_max_length(
        tokenizer,
        dataset_file: str
    ):
    """
    测试数据集最大的输入/输出tokens是多少。

    Args:
        dataset_file (str): _description_
    """
    source_seq_len_list = []
    target_seq_len_list = []
    with open(dataset_file, 'r') as f:
        for line in tqdm(f.readlines()):
            line = json.loads(line)

            source_len = tokenizer.encode(line['context'])
            source_seq_len_list.append(len(source_len))

            target_len = tokenizer.encode(line['target'])
            target_seq_len_list.append(len(target_len))

    print(dataset_file)
    print(f"【Source Sequence】 Max: {max(source_seq_len_list)}, Avg: {int(sum(source_seq_len_list) / len(source_seq_len_list))}, Middle: {sorted(source_seq_len_list)[int(len(source_seq_len_list) / 2)]}.")
    print(f"【Target Sequence】 Max: {max(target_seq_len_list)}, Avg: {int(sum(target_seq_len_list) / len(target_seq_len_list))}, Middle: {sorted(target_seq_len_list)[int(len(target_seq_len_list) / 2)]}.")

3.3 data_loader.py

  • 目的:定义数据加载器
  • 导入必备的工具包
# coding:utf-8
from torch.utils.data import DataLoader
from transformers import default_data_collator, AutoTokenizer
from data_handle.data_preprocess import *
from glm_config import *

pc = ProjectConfig() # 实例化项目配置文件

tokenizer = AutoTokenizer.from_pretrained(pc.pre_model, trust_remote_code=True)

  • 定义获取数据加载器的方法get_data()
def get_data():
    dataset = load_dataset('text', data_files={'train': pc.train_path,
                                               'dev': pc.dev_path})


    new_func = partial(convert_example,
                       tokenizer=tokenizer,
                       max_source_seq_len=100,
                       max_target_seq_len=100)

    dataset = dataset.map(new_func, batched=True)
    train_dataset = dataset["train"]
    dev_dataset = dataset["dev"]
    train_dataloader = DataLoader(train_dataset,
                                  shuffle=True,
                                  collate_fn=default_data_collator,
                                  batch_size=pc.batch_size)
    dev_dataloader = DataLoader(dev_dataset,
                                collate_fn=default_data_collator,
                                batch_size=pc.batch_size)
    return train_dataloader, dev_dataloader
if __name__ == '__main__':
    train_dataloader, dev_dataloader = get_data()
    print(len(train_dataloader))
    print(len(dev_dataloader))
    for i, value in enumerate(train_dataloader):
        print(value)
        print(value['input_ids'].shape)
        print(value['labels'].shape)
        break
  • 打印结果:
902
122
{
    'input_ids': tensor([[ 37010,     12,      5,  76331,  83362,  92831, 
103593,  64464,      6,
          77115,  65077,  72863,  63891,  66207,  63823,      4,   3430,     12,
          68327,  74351,  77756,  66263,  81577,  64536,      6,  82145,   2031,
          63825,  69574,  66207,     12,      4,      4,  64590,  67748,  69958,
          66152,  63923,  65024,  64676,  65102,  66089,  64101,  73127,  64025,
          64236,      6,  72996,  73518,  64236,  82273,  63823,      4,  13049,
             12, 130001, 130004,      5, 125827,   2031,      4, 127903,  38861,
             83,     28,  66845,  67541,     57,     28,   1932,     24,    317,
             83,     28,  64069,     57,     28,   9832,     24,    317,     83,
             28,  65210,     57,     28,   1932,     83,     28,  73127,  64025,
          64236,     57,     28,   9832,     83,     28,  64590,  67748,  69958,
          66152, 127731,      4, 125827, 130005,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3,      3,      3,      3,      3,      3,      3,      3,
              3,      3]]),
    'labels': tensor([[  -100,   -100,   -100,   -100,   -100,   -100,   -100,  
-100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100, 130004,      5, 125827,   2031,      4, 127903,  38861,
             83,     28,  66845,  67541,     57,     28,   1932,     24,    317,
             83,     28,  64069,     57,     28,   9832,     24,    317,     83,
             28,  65210,     57,     28,   1932,     83,     28,  73127,  64025,
          64236,     57,     28,   9832,     83,     28,  64590,  67748,  69958,
          66152, 127731,      4, 125827, 130005,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100]])
}
torch.Size([1, 200])
torch.Size([1, 200])