视频介绍更加详细,包含数据集、微调命令的下载。
从0到1一步步学会ChatGLM3-6B的部署和微调:
ChatGLM3-6B的安装部署、微调、训练智能客服_哔哩哔哩_bilibili
1 下载安装
1.1 下载推理代码(运行AI大模型使用)
GitHub地址:https://github.com/THUDM/ChatGLM3
git clone https://github.com/THUDM/ChatGLM3.git
1.2 下载模型文件
魔塔地址:https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary
git clone https://www.modelscope.cn/ZhipuAI/chatglm3-6b.git
1.3 下载微调代码LLaMaFactory(微调AI大模型使用)
https://github.com/hiyouga/LLaMA-Factory
git clone https://github.com/hiyouga/LLaMA-Factory.git
1.4 下载API运行方式的向量模型
git clone https://www.modelscope.cn/xrunda/m3e-base.git
2 配置
2.1 修改ChatGLM的AI大模型地址
修改/mnt/workspace/apps/ChatGLM3/basic_demo/cli_demo.py,将默认的THUDM/chatglm3-6b修改为你的模型地址,比如我的是/mnt/workspace/models/chatglm3-6b
2.2 修改API运行方式的向量模型路径
修改/mnt/workspace/models/m3e-base/api_server.py
3 推理,也就是运行AI大模型
3.1 命令行方式运行
python cli_demo.py
3.2 网页流方式运行(网页有打字效果)
3.3 API调用方式运行
【启动AI模型服务端】
/mnt/workspace/apps/ChatGLM3/openai_api_demo# python api_server.py
【curl测试命令】
curl -X POST "http://127.0.0.1:8000/v1/chat/completions" -H "Content-Type: application/json" -d "{"model": "chatglm3-6b", "messages": [{"role": "system", "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown."}, {"role": "user", "content": "你好"}], "stream": false, "max_tokens": 100, "temperature": 0.8, "top_p": 0.8}"
4 微调、训练智能客服
微调前,ChatGLM大模型只能根据互联网抓取到的文本进行回答。
4.1 微调数据集整理
4.1.1 整理数据集
【数据集格式】
{ "instruction": "请告诉我怎么下单", "input": "", "output": "好的,请到https://space.bilibili.com/400138805。进入下单平台——》选择您的商品——》输入你的订单信息后进行下单" }
【数据集下载】
可观看视频查看下载方式
4.1.2 加入数据集
计算文件的sha后加入dataset_info.json
步骤一:计算sha代码
import java.io.FileInputStream;
import java.io.IOException;
import java.security.DigestInputStream;
import java.security.MessageDigest;
public class FileHashCalculator {
public static void main(String[] args) {
String filePath = "D:\\train\\chonger_ai_cussrv.json"; // 替换为你的文件路径
try {
String sha256 = getSHA256(filePath);
System.out.println("The SHA-256 hash of the file is: " + sha256);
} catch (IOException e) {
System.err.println("Error reading the file: " + e.getMessage());
} catch (Exception e) {
System.err.println("Error calculating SHA-256 hash: " + e.getMessage());
}
}
private static String getSHA256(String filePath) throws Exception {
try (FileInputStream fis = new FileInputStream(filePath)) {
MessageDigest digest = MessageDigest.getInstance("SHA-256");
DigestInputStream dis = new DigestInputStream(fis, digest);
byte[] buffer = new byte[4096];
while (dis.read(buffer) > 0) {
// 这里读取数据,但不需要实际使用数据,所以我们简单跳过
}
byte[] hashBytes = digest.digest();
StringBuilder hexString = new StringBuilder();
for (byte b : hashBytes) {
String hex = Integer.toHexString(0xff & b);
if (hex.length() == 1) hexString.append('0');
hexString.append(hex);
}
return hexString.toString();
}
}
}
步骤二:将sha加入dataset_info.json
"chonger_ai_cussrv": {
"file_name": "chonger_ai_cussrv.json",
"file_sha1": "27c60f462738034236900bd294e28af893bbb647e9a69f41dcd2fd062b0a62e4"
}
4.2 微调
python src/train_bash.py
--stage sft
--do_train True
--model_name_or_path /mnt/workspace/models/chatglm3-6b
--overwrite_output_dir
--dataset_dir data
--cutoff_len 1024
--dataset chonger_ai_cussrv
--template chatglm3
--finetuning_type lora
--lora_target query_key_value
--output_dir chonger_ai_cussrv
--overwrite_cache
--per_device_train_batch_size 2
--gradient_accumulation_steps 8
--lr_scheduler_type cosine
--logging_steps 5
--save_steps 100
--warmup_steps 0
--learning_rate 1e-4
--num_train_epochs 50.0
--max_samples 100000
--max_grad_norm 1.0
--plot_loss True
--fp16 True
--lora_rank 8
--lora_alpha 16
--lora_dropout 0.1
参数说明:
--stage sft: 指定训练阶段或模式,sft 可能代表 "supervised fine-tuning"(监督微调)。
--do_train True: 表示要执行训练过程。
--model_name_or_path /mnt/workspace/models/chatglm3-6b: 指定预训练模型的位置。在这个例子中,使用的模型是 chatglm3-6b。
--overwrite_output_dir: 如果输出目录已经存在,则覆盖它。
--dataset_dir data: 数据集所在的目录。
--cutoff_len 1024: 输入序列的最大长度。如果输入序列超过此长度,会被截断。
--dataset chonger_ai_cussrv: 指定具体的数据集名称。
--template chatglm3: 使用的模板类型,这里指定了与 chatglm3 相关的特定模板。
--finetuning_type lora: 微调类型,这里使用的是 LoRA (Low-Rank Adaptation) 方法。
--lora_target query_key_value: 指定 LoRA 应用到哪些层或模块上,在这里是针对 query_key_value 层。
--output_dir chonger_ai_cussrv: 训练结果和模型保存的目录。
--overwrite_cache: 如果缓存已存在,则覆盖它。
--per_device_train_batch_size 2: 每个设备上的训练批次大小。
--gradient_accumulation_steps 8: 梯度累积步数,可以用来模拟更大的批次大小。
--lr_scheduler_type cosine: 学习率调度器类型,这里是余弦退火策略。
--logging_steps 5: 日志记录的频率,每5个步骤记录一次。
--save_steps 100: 模型保存的频率,每100个步骤保存一次。
--warmup_steps 0: 学习率热身步骤的数量,这里没有设置热身期。
--learning_rate 1e-4: 初始学习率。
--num_train_epochs 50.0: 训练的总轮次(epoch)。
--max_samples 100000: 训练数据的最大样本数量。
--max_grad_norm 1.0: 最大梯度范数,用于梯度裁剪以避免梯度爆炸。
--plot_loss True: 是否绘制损失曲线。
--fp16 True: 使用混合精度训练(FP16),这可以节省显存并加速训练。
--lora_rank 8: LoRA 的秩(rank),控制了 LoRA 参数矩阵的大小。
--lora_alpha 16: LoRA 的缩放因子,影响 LoRA 的学习率。
--lora_dropout 0.1: LoRA 层中的 dropout 比率。
4.3 导出微调模型
python src/export_model.py
--model_name_or_path /mnt/workspace/models/chatglm3-6b/
--adapter_name_or_path /mnt/workspace/apps/LLaMA-Factory/chonger_ai_cussrv/
--template chatglm3
--finetuning_type lora
--export_dir export_chonger_ai_cussrv
--export_size 2
--export_legacy_format False
4.4 运行微调模型
python src/cli_demo.py
--model_name_or_path /mnt/workspace/apps/LLaMA-Factory/export_chonger_ai_cussrv/
--template chatglm3
--finetuning_type lora
版权归原作者 congzi1984 所有, 如有侵权,请联系我们删除。