0


Datawhale AI 夏令营

Datawhale AI 夏令营 Task3 baseline02 微调方案 学习笔记

1、微调技术及微调任务介绍

1.1产生背景

机器学习模型只能代表其训练数据的逻辑和理解。对于未见过的数据样本,模型可能无法准确识别或理解。对于大型模型而言,它们虽然能够处理广泛的语言信息并进行流畅的对话,但在特定场景下可能无法提供准确的答案。此时,需要对基础模型进行微调。

1.2环境配置

1、需要的数据、模型文件、代码文件地址:魔搭社区

2、环境:魔搭社区,Notebook,GPU环境ubuntu22.04-cuda12.1.0-py310-torch2.1.2-tf2.14.0-1.14.0

1.3LoRA微调

LoRA优势
  • 可以针对不同的下游任务构建小型 LoRA 模块,从而在共享预训练模型参数基础上有效地切换下游任务。

  • LoRA 使用自适应优化器(Adaptive Optimizer),不需要计算梯度或维护大多数参数的优化器状态,训练更有效、硬件门槛更低。

  • LoRA 使用简单的线性设计,在部署时将可训练矩阵与冻结权重合并,不存在推理延迟。

  • LoRA 与其他方法正交,可以组合。

LoRA原理

2、vllm加速

2.1 vllm介绍

vLLM(Virtual Large Language Model)是一个由伯克利大学LMSYS组织开源的大规模语言模型高速推理框架。它的设计目标是在实时应用场景中大幅提升语言模型服务的吞吐量和内存使用效率。vLLM的特点包括易于使用、与Hugging Face等流行工具无缝集成以及高效的性能。

3、多路LLM投票

3.1思路介绍

所谓的“多路召回策略”就是指采用不同的策略、特征或者简单模型,分别召回一部分候选集,然后再把这些候选集混合在一起后供后续排序模型使用的策略。

4.2 实现原理

设计投票函数:

通过三次结果推理,将选择答案最多的结果作为最终结果:

 def most_frequent_char(char1, char2, char3):
     # 创建一个字典来存储每个字符的出现次数
     frequency = {char1: 0, char2: 0, char3: 0}
     
     # 增加每个字符的出现次数
     frequency[char1] += 1
     frequency[char2] += 1
     frequency[char3] += 1
     
     # 找到出现次数最多的字符
     most_frequent = max(frequency, key=frequency.get)
     
     return most_frequent

设计多路LLM:

改写process函数,三次调用llm,做出现次数统计,最终返回投票数最多的结果。

 def process_datas(datas,MODEL_NAME):
     results = []
 ​
     # 送入多线程任务
     for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):
         problem = data['problem']
         for id,question in enumerate(data['questions']):
             prompt = get_prompt(problem, 
                                 question['question'], 
                                 question['options'],
                                     )
             # 统一使用llm 三次调用
             res,res1,res2 = api_retry(MODEL_NAME, prompt),api_retry(MODEL_NAME, prompt),api_retry(MODEL_NAME, prompt)
             # 统一做结果抽取
             extract_response,extract_response1,extract_response2 = extract(res),extract(res1),extract(res2)
             # 通过投票函数获取最终结果并返回
             ans = most_frequent_char(extract_response,extract_response1,extract_response2)
             data['questions'][id]['answer'] = ans
             results.append(data) 
     return results

5、总结

这一个task我们学习了如何使用LoRA微调,并且使用vllm加速。这里和大家同步一下微调后模型的性能。
原模型LoRA_anLoRA_an_投票成绩0.660.680.71用时120min2min7min
原模型为Qwen2-7B-Instruct模型

LoRA_an为使用an数据集对Qwen2-7B-Instruct lora微调后的模型

LoRA_an_投票为使用an数据集对Qwen2-7B-Instruct lora微调后再进行多路投票的模型

标签: 大数据

本文转载自: https://blog.csdn.net/2301_80708318/article/details/140882048
版权归原作者 2301_80708318 所有, 如有侵权,请联系我们删除。

“Datawhale AI 夏令营”的评论:

还没有评论