轻量级大语言模型MiniMind源码解读(九):指令微调详解-让大模型从“能说”变得“会听”
一、查看有监督微调数据集格式
import jsonpretrain_dataset_path=r'D:\MyFile\github\minimind-master\minimind_dataset\sft_mini_512.jsonl'with open(pretrain_dataset_path, 'r', encoding='utf-8') as f:for line_num, line in enumerate(f, 1): data = json.loads(line.strip())breakprint(data.keys()) # dict_keys(['text'])print(data)
{'conversations': [ {'role': 'user', 'content': '请告诉我在中国古代的“四大发明”是什么?' }, {'role': 'assistant', 'content': '中国古代的“四大发明”是指造纸术、印刷术、火药和指南针。这四项发明对世界文明的发展产生了深远的影响:\n\n1. **造纸术**:据史书记载,东汉时期的蔡伦改进了造纸工艺,使得纸张的生产更加便捷、成本更低,质量也更加优良。这一发明极大地促进了文化的传播和保存。\n\n2. **印刷术**:中国古代的印刷术最早可以追溯到唐代的雕版印刷,到了宋代发展出了活字印刷技术。印刷术的发明极大地促进了知识的传播,降低了书籍的成本,对教育和文化的发展起到了重要的推动作用。\n\n3. **火药**:火药最初是在唐代被发现的,最初可能用于医疗或炼金术。到了宋代,人们开始将火药用于军事目的,发明了各种火器。火药的发明改变了战争的面貌,对世界军事技术的发展产生了重大影响。\n\n4. **指南针**:指南针最初在中国被用于风水测量,后来逐渐发展成为航海导航的重要工具。这一发明极大地促进了海上航行技术的发展,对于新航路的开辟和世界地理大发现起到了关键作用。\n\n这四项发明不仅在中国历史上占有重要地位,而且对全世界的科技进步和文明发展都产生了深远的影响。' } ]}
二、准备有监督微调数据加载器
构建符合PyTorch的Dataloader的Dataset类:
import jsonimport torchfrom torch.utils.data import DatasetclassSFTDataset(Dataset):def__init__(self, jsonl_path, tokenizer, max_length=1024): super().__init__() self.tokenizer = tokenizer # 分词器 self.max_length = max_length # 最大输入长度(会进行截断或填充) self.samples = self.load_data(jsonl_path) # 加载数据样本 self.bos_id = tokenizer('<|im_start|>assistant', add_special_tokens=False).input_ids# [1, 1078, 538, 501], [1]是<|im_start|>这个特殊token的id,[1078, 538, 501]是assistant的分词id self.eos_id = tokenizer('<|im_end|>', add_special_tokens=False).input_ids# [2]def__len__(self):return len(self.samples) # 返回样本数量defload_data(self, path):"""从 jsonl 文件加载对话数据""" samples = []with open(path, 'r', encoding='utf-8') as f:for line_num, line in enumerate(f, 1): data = json.loads(line.strip()) # 每行为一个 JSON 对象 samples.append(data)return samplesdef_create_chat_prompt(self, conversations):""" 将对话轮构造成符合 ChatML 格式的字符串: 每一轮用户/助手对话被标注为 'user' / 'assistant' 最终用 tokenizer 的 apply_chat_template 统一构造 prompt。 """ messages = []for i, turn in enumerate(conversations): role = 'user'if i % 2 == 0else'assistant'# 偶数轮为用户,奇数轮为助手 messages.append({"role": role, "content": turn['content']})# 返回字符串形式的 prompt,而非直接 tokenizereturn self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=False )def_generate_loss_mask(self, input_ids):""" 构建损失掩码,只有 assistant 的回答部分才参与 loss 计算。 找出每一段 assistant 的响应,在其 <|im_start|>assistant 和 <|im_end|> 之间设置 loss_mask 为 1。 """ loss_mask = [0] * len(input_ids) i = 0while i < len(input_ids):# 找 assistant 开头标志if input_ids[i:i + len(self.bos_id)] == self.bos_id: start = i + len(self.bos_id) # 答案起点 end = startwhile end < len(input_ids):# 查找 assistant 的回答终止符 <|im_end|>if input_ids[end:end + len(self.eos_id)] == self.eos_id:break end += 1# 为 assistant 回答部分(从 start + 1 到 end 之间)设置 loss maskfor j in range(start + 1, min(end + len(self.eos_id) + 1, self.max_length)): loss_mask[j] = 1# 跳过到下一个 segment i = end + len(self.eos_id) if end < len(input_ids) else len(input_ids)else: i += 1return loss_maskdef__getitem__(self, index): sample = self.samples[index]# 构建 ChatML 格式 prompt(字符串) prompt = self._create_chat_prompt(sample['conversations'])# 分词并截断,确保长度 <= max_length input_ids = self.tokenizer(prompt).input_ids[:self.max_length]# 右侧填充 pad_token 直到 max_length 长度 input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))# 生成动态 loss mask,仅对 assistant 响应位置计算 loss loss_mask = self._generate_loss_mask(input_ids)# 构建训练样本:# 模型输入为前 n-1 个 token,预测目标为第 2 到第 n 个 token X = torch.tensor(input_ids[:-1], dtype=torch.long) # 输入序列 Y = torch.tensor(input_ids[1:], dtype=torch.long) # 目标标签(shifted) loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long) # 对齐 Y 的位置(从第一个预测 token 开始)return X, Y, loss_mask
沿着__getitem__方法,逐行向下解析。
2.1 sample = self.samples[index]
sample = self.samples[index]用于获取self.samples中对应index的一条数据,这是从原始.jsonl数据集中读取的,如上所述,它只有一个key叫做conversations,取出其value,示例如下:
[ {'role': 'user', 'content': '请告诉我在中国古代的“四大发明”是什么?' }, {'role': 'assistant', 'content': '中国古代的“四大发明”是指造纸术、印刷术、火药和指南针。这四项发明对世界文明的发展产生了深远的影响:\n\n1. **造纸术**:据史书记载,东汉时期的蔡伦改进了造纸工艺,使得纸张的生产更加便捷、成本更低,质量也更加优良。这一发明极大地促进了文化的传播和保存。\n\n2. **印刷术**:中国古代的印刷术最早可以追溯到唐代的雕版印刷,到了宋代发展出了活字印刷技术。印刷术的发明极大地促进了知识的传播,降低了书籍的成本,对教育和文化的发展起到了重要的推动作用。\n\n3. **火药**:火药最初是在唐代被发现的,最初可能用于医疗或炼金术。到了宋代,人们开始将火药用于军事目的,发明了各种火器。火药的发明改变了战争的面貌,对世界军事技术的发展产生了重大影响。\n\n4. **指南针**:指南针最初在中国被用于风水测量,后来逐渐发展成为航海导航的重要工具。这一发明极大地促进了海上航行技术的发展,对于新航路的开辟和世界地理大发现起到了关键作用。\n\n这四项发明不仅在中国历史上占有重要地位,而且对全世界的科技进步和文明发展都产生了深远的影响。' }]
2.2 prompt = self._create_chat_prompt(sample['conversations'])
self._create_chat_prompt(sample['conversations'])将上述sample应用一种称之为ChatML的模板,它是一种专门为多轮对话任务设计的Prompt模板格式,用于格式化输入,模板如下:
{% for message in messages %}<|im_start|>{{ message['role'] }}{{ message['content'] }}<|im_end|>{% endfor %}
上述代码使用了self.tokenizer.apply_chat_template方法来应用ChatML模板,其中tokenize=False表示只返回字符串形式的prompt,不进行分词。add_generation_prompt=False表示是否在最后自动添加<|im_start|>assistant这样的生成提示,用于推理阶段.如果是训练数据(已经包括答案),一般设为 False。
应用ChatML模板后得到的prompt为:
'<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n请告诉我在中国古代的“四大发明”是什么?<|im_end|>\n<|im_start|>assistant\n中国古代的“四大发明”是指造纸术、印刷术、火药和指南针。这四项发明对世界文明的发展产生了深远的影响:\n\n1. **造纸术**:据史书记载,东汉时期的蔡伦改进了造纸工艺,使得纸张的生产更加便捷、成本更低,质量也更加优良。这一发明极大地促进了文化的传播和保存。\n\n2. **印刷术**:中国古代的印刷术最早可以追溯到唐代的雕版印刷,到了宋代发展出了活字印刷技术。印刷术的发明极大地促进了知识的传播,降低了书籍的成本,对教育和文化的发展起到了重要的推动作用。\n\n3. **火药**:火药最初是在唐代被发现的,最初可能用于医疗或炼金术。到了宋代,人们开始将火药用于军事目的,发明了各种火器。火药的发明改变了战争的面貌,对世界军事技术的发展产生了重大影响。\n\n4. **指南针**:指南针最初在中国被用于风水测量,后来逐渐发展成为航海导航的重要工具。这一发明极大地促进了海上航行技术的发展,对于新航路的开辟和世界地理大发现起到了关键作用。\n\n这四项发明不仅在中国历史上占有重要地位,而且对全世界的科技进步和文明发展都产生了深远的影响。<|im_end|>\n'
紧接着对这个prompt使用tokenizer编码成input_ids,并根据最大序列长度进行padding处理。
2.3 loss_mask = self._generate_loss_mask(input_ids)
这里仅对assistant响应位置(也就是assistant回复的内容)计算loss,因此需要找出每一段assistant的响应,在其<|im_start|>assistant和<|im_end|>之间设置loss_mask为1,其余位置的loss_mask均为0。
使用_generate_loss_mask方法实现上述功能。
基本思想就是遍历整个input_ids,查找出现<|im_start|>assistant的位置start,这是模型回复开始的标志;然后继续遍历,找到第一个出现的<|im_end|>的位置end,start到end之间的计算模型的回复,loss_mask设置为1。
如果是多轮对话,就继续往后遍历,查找第二个模型预测开始的位置<|im_start|>assistant,以此类推。
最后,和预训练一样,返回X, Y以及Y对应的loss mask。
现在来构建数据加载器:
from torch.utils.data import DataLoaderfrom transformers import AutoTokenizermax_length=512data_path=r'D:\MyFile\github\minimind-master\minimind_dataset\sft_mini_512.jsonl'tokenizer = AutoTokenizer.from_pretrained(r'D:\MyFile\github\minimind-master\model')train_ds = SFTDataset(data_path, tokenizer, max_length)train_loader = DataLoader( train_ds, batch_size=2, pin_memory=True, drop_last=False, shuffle=False, num_workers=0,)
查看一下有监督微调的数据总量以及数据的维度信息:
print(len(train_loader)) # 607362for item in train_loader: print([i.shape for i in item]) # [torch.Size([2, 511]), torch.Size([2, 511]), torch.Size([2, 511])]break
通过打印看到,有监督微调的数据总量为607362,每一条数据都包含3个PyTorch Tensor,分别是X, Y以及Y对应的padding mask(用于掩掉padding token的loss),shape都是2x511,2是batch_size,511是max_length-1,因为X和Y是正好是偏移一位的。这一点和预训练一样。
三、开始有监督微调
有监督微调代码和常规的模型预训练代码几乎没有区别,直接核心代码段粘贴过来:
loss_fct = nn.CrossEntropyLoss(reduction='none')for step, (X, Y, loss_mask) in enumerate(train_loader): X = X.to(args.device) Y = Y.to(args.device) loss_mask = loss_mask.to(args.device) lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)for param_group in optimizer.param_groups: param_group['lr'] = lrwith ctx: res = model(X) loss = loss_fct( res.logits.view(-1, res.logits.size(-1)), Y.view(-1) ).view(Y.size()) loss = (loss * loss_mask).sum() / loss_mask.sum() loss += res.aux_loss loss = loss / args.accumulation_steps scaler.scale(loss).backward()if (step + 1) % args.accumulation_steps == 0: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True)
夜雨聆风