0


详细记录swfit微调interVL2-8B多模态大模型进行目标检测(附代码)

大模型相关目录

  1. 大模型,包括部署微调prompt/Agent应用开发、知识库增强、数据库增强、知识图谱增强、自然语言处理、多模态等大模型应用开发内容
  1. 0起步,扬帆起航。
  1. RAGOnMedicalKG:大模型结合知识图谱的RAG实现
  2. DSPy:变革式大模型应用开发
  3. 最简明的Few-shot Prompt指南
  4. Semantic Kernel:微软大模型开发框架——LangChain 替代
  5. 对话大模型Prompt是否需要礼貌点?
  6. swift与Internvl下的多模态大模型分布式微调指南(附代码和数据)
  7. 多模态大模型Internvl-1.5-26B微调后部署及测试实录(附代码)
  8. 多模态大模型Internvl-2-26B的OCR赋能方案(附代码)
  9. miniconda+xinference的大模型推理部署指南
  10. Mem0:大模型最强赋能“有记忆的LLM”
  11. 再谈Agent:Dify智能体实现Txet2SQL
  12. Moe模式:或将是最好的大模型应用开发路径
  13. 一文带你了解大模型RAG
  14. 详细记录swfit微调interVL2-8B多模态大模型进行目标检测(附代码)

文章目录


前言

目标检测任务已经不是一个新鲜事了,但是多模态大模型作目标检测任务并不多见,本文详细记录swfit微调interVL2-8B多模态大模型进行目标检测的过程,旨在让更多人了解多模态大模型微调技术、共享微调经验。

模型选型

并不是所有开源多模态大模型都有目标检测能力。
在这里插入图片描述
如图所示,哪怕是闭源模型,也并都不具备目标检测能力。
经调研,我们选用

  1. interVL2-8B模型

,在模型性能指标上,该模型胜过interVL1.5-26B的同时,还具备目标检测能力,且与interVL2-26B、40B、70B模型性能差不并没有非常巨大。

其回答格式也很有意思,此处分享:

  1. <ref>zs_code</ref><box>[[476,1221,814,1259]]</box>

数据集制作

本文任务数据集均为自行制作,其中,数据分布如下图:
在这里插入图片描述
其中,test文件夹用于性能测试,tain文件夹用于模型训练。pic子文件夹表示图像存储路径,xml表示标注存储路径,图像与标注一一对应。

具体内容如下:

图像示例:
在这里插入图片描述
对应标注示例

  1. <annotation><folder>code_data</folder><filename>xxx-本科毕业证.jpg</filename><path>C:\Users\12258\Desktop\code_data\xxx-本科毕业证.jpg</path><source><database>Unknown</database></source><size><width>842</width><height>596</height><depth>3</depth></size><segmented>0</segmented><object><name>zs_code</name><pose>Unspecified</pose><truncated>0</truncated><difficult>0</difficult><bndbox><xmin>142</xmin><ymin>422</ymin><xmax>351</xmax><ymax>446</ymax></bndbox></object></annotation>

该数据集使用labelimg手动标注,每张图像为典型毕业证、学位证、学历验证、资质证书影像,只含一个标签名称zs_code。

其中,测试图像30张,训练图像250张。

编写脚本,构建可用于微调训练的数据集jsonl,jsonl配合图像即可完成swift框架下的多模态模型微调。

  1. import os
  2. import random
  3. import matplotlib.pyplot as plt
  4. import matplotlib.patches as patches
  5. from PIL import Image
  6. import json
  7. from PIL import Image, ExifTags
  8. import xml.etree.ElementTree as ET
  9. defcreate_directory(path):"""Create a new directory at the given path."""try:
  10. os.makedirs(path, exist_ok=True)returnf"Directory created at {path}"except Exception as e:returnf"An error occurred: {e}"deflist_files(directory):"""List all files in the given directory."""return[fileforfilein os.listdir(directory)if os.path.isfile(os.path.join(directory,file))]deflist_files_with_absolute_paths(directory):"""List all files in the given directory with their absolute paths."""return[os.path.abspath(os.path.join(directory,file))forfilein os.listdir(directory)if os.path.isfile(os.path.join(directory,file))]defextract_xml_info(xml_file_path):withopen(xml_file_path,'r',encoding='utf-8')asfile:
  11. xml_content =file.read()# 解析XML内容
  12. root = ET.fromstring(xml_content)# 初始化一个列表来保存提取的信息
  13. extracted_info =[]# 遍历所有<object>标签for obj in root.findall('object'):
  14. name = obj.find('name').text
  15. bndbox = obj.find('bndbox')
  16. xmin =int(bndbox.find('xmin').text)
  17. ymin =int(bndbox.find('ymin').text)
  18. xmax =int(bndbox.find('xmax').text)
  19. ymax =int(bndbox.find('ymax').text)# 将提取的信息保存到列表中
  20. extracted_info.append({'name': name,'xmin': xmin,'ymin': ymin,'xmax': xmax,'ymax': ymax})
  21. name =str(extracted_info[0]['name'])
  22. xmin =str(extracted_info[0]['xmin'])
  23. ymin =str(extracted_info[0]['ymin'])
  24. xmax =str(extracted_info[0]['xmax'])
  25. ymax =str(extracted_info[0]['ymax'])# 仅仅用于单标注图像
  26. result =f'<ref>{name}</ref><box>[[{xmin},{ymin},{xmax},{ymax}]]</box>'return result
  27. defget_elements_with_string(lst, target_string):return[element for element in lst if target_string in element]
  28. train_pic_path ='/home/super/lyq/zsbm_mbjc/data/train/pic'
  29. train_xml_path ='/home/super/lyq/zsbm_mbjc/data/train/xml'
  30. test_pic_path ='/home/super/lyq/zsbm_mbjc/data/test/pic'
  31. test_xml_path ='/home/super/lyq/zsbm_mbjc/data/test/xml'
  32. train_pic_absolute_paths = list_files_with_absolute_paths(train_pic_path)
  33. train_xml_absolute_paths = list_files_with_absolute_paths(train_xml_path)
  34. test_pic_absolute_paths = list_files_with_absolute_paths(test_pic_path)
  35. test_xml_absolute_paths = list_files_with_absolute_paths(test_xml_path)
  36. train_pic_paths = list_files(train_pic_path)
  37. train_xml_paths = list_files(train_xml_path)
  38. test_pic_paths = list_files(test_pic_path)
  39. test_xml_paths = list_files(test_xml_path)
  40. dataset =[]for train_pic_absolute_path in train_pic_absolute_paths:# 图像路径
  41. mid_dict ={}
  42. file_head = train_pic_absolute_path.split('/')[-1].split('.')[0]# print(file_head,train_pic_absolute_path)
  43. xml_path = get_elements_with_string(train_xml_absolute_paths,file_head)[0]# print(xml_path)
  44. xml_info = extract_xml_info(xml_path)# response
  45. mid_dict ={'system':'''职位:你是一个面向证书图像的目标检测大师,具备精准识别、定位图像中证书编码的能力。
  46. 职能:从毕业证、学历验证报告、证书等图像中检测到证书编码区域并给出边界框。
  47. **注意**:仅以给定格式返回检测结果,不要给出其它任何解释。
  48. **注意**:若图片中没有典型违章场景,返回<ref> class_name </ref><box>[[0, 0, 0, 0]]</box>即可。
  49. ''','query':'请目标检测图像中的证书编码并给出边界框','response':xml_info,'images':train_pic_absolute_path
  50. }
  51. dataset.append(mid_dict)# 指定输出文件的名称
  52. output_file ='train_dataset.jsonl'# 打开文件并写入JSONL格式的数据withopen(output_file,'w', encoding='utf-8')as f:for item in dataset:# 将字典转换为JSON字符串并写入文件,每个字典占一行
  53. json_string = json.dumps(item,ensure_ascii=False)
  54. f.write(json_string +'\n')
  55. dataset =[]for test_pic_absolute_path in test_pic_absolute_paths:# 图像路径
  56. mid_dict ={}
  57. file_head = test_pic_absolute_path.split('/')[-1].split('.')[0]
  58. xml_path = get_elements_with_string(test_xml_absolute_paths,file_head)[0]
  59. xml_info = extract_xml_info(xml_path)# response
  60. mid_dict ={'system':'''职位:你是一个面向证书图像的目标检测大师,具备精准识别、定位图像中证书编码的能力。
  61. 职能:从毕业证、学历验证报告、证书等图像中检测到证书编码区域并给出边界框。
  62. **注意**:仅以给定格式返回检测结果,不要给出其它任何解释。
  63. **注意**:若图片中没有典型违章场景,返回<ref> class_name </ref><box>[[0, 0, 0, 0]]</box>即可。
  64. ''','query':'请目标检测图像中的证书编码并给出边界框','response':xml_info,'images':test_pic_absolute_path
  65. }
  66. dataset.append(mid_dict)# 指定输出文件的名称
  67. output_file ='test_dataset.jsonl'# 打开文件并写入JSONL格式的数据withopen(output_file,'w', encoding='utf-8')as f:for item in dataset:# 将字典转换为JSON字符串并写入文件,每个字典占一行
  68. json_string = json.dumps(item,ensure_ascii=False)
  69. f.write(json_string +'\n')

上述代码结果为

  1. test_dataset.jsonl

  1. train_dataset.jsonl

两个jsonl文件,分别对应train、test文件夹。

test_dataset.jsonl

  1. {"system":"职位:你是一个面向证书图像的目标检测大师,具备精准识别、定位图像中证书编码的能力。\n 职能:从毕业证、学历验证报告、证书等图像中检测到证书编码区域并给出边界框。\n **注意**:仅以给定格式返回检测结果,不要给出其它任何解释。\n **注意**:若图片中没有典型违章场景,返回<ref> class_name </ref><box>[[0, 0, 0, 0]]</box>即可。\n ","query":"请目标检测图像中的证书编码并给出边界框","response":"<ref>zs_code</ref><box>[[67,761,302,798]]</box>","images":"/home/super/lyq/zsbm_mbjc/data/train/pic/xxx-专科毕业证.jpg"}{"system":"职位:你是一个面向证书图像的目标检测大师,具备精准识别、定位图像中证书编码的能力。\n 职能:从毕业证、学历验证报告、证书等图像中检测到证书编码区域并给出边界框。\n **注意**:仅以给定格式返回检测结果,不要给出其它任何解释。\n **注意**:若图片中没有典型违章场景,返回<ref> class_name </ref><box>[[0, 0, 0, 0]]</box>即可。\n ","query":"请目标检测图像中的证书编码并给出边界框","response":"<ref>zs_code</ref><box>[[455,1272,1083,1356]]</box>","images":"/home/super/lyq/zsbm_mbjc/data/train/pic/xxx-本科毕业证.jpg"}{"system":"职位:你是一个面向证书图像的目标检测大师,具备精准识别、定位图像中证书编码的能力。\n 职能:从毕业证、学历验证报告、证书等图像中检测到证书编码区域并给出边界框。\n **注意**:仅以给定格式返回检测结果,不要给出其它任何解释。\n **注意**:若图片中没有典型违章场景,返回<ref> class_name </ref><box>[[0, 0, 0, 0]]</box>即可。\n ","query":"请目标检测图像中的证书编码并给出边界框","response":"<ref>zs_code</ref><box>[[90,484,329,508]]</box>","images":"/home/super/lyq/zsbm_mbjc/data/train/pic/xxx-本科毕业证.jpg"}

其中内容大概如上,人名已脱敏。

数据集于swift框架进行注册:
可参考我的历史文章

https://blog.csdn.net/qq_43128256/article/details/140314241

在这里插入图片描述

在这里插入图片描述

模型微调

本文不再采取UI,纯指令如下:

  1. CUDA_VISIBLE_DEVICES=0,1,2,3 swift sft \
  2. --model_id_or_path /data/hfd/InternVL2-8B \
  3. --template_type internvl2 \
  4. --dataset /home/super/lyq/train_dataset.jsonl \
  5. --lora_target_modules ALL \
  6. --lora_lr_ratio 16.0 \
  7. --lora_rank 16 \
  8. --learning_rate 1e-4 \
  9. --num_train_epochs 5 \
  10. --use_flash_attn True \
  11. --gradient_accumulation_steps 4 \
  12. --batch_size 2 \
  13. --eval_steps 50 \
  14. --save_steps 500 \
  15. --neftune_noise_alpha 5 \
  16. --model_type internvl2-8b \
  17. --device_max_memory 15GB 15GB 15GB 15GB \
  18. --output_dir /home/super/sgq/swift/llm-yolo/detection2/v1 \
  19. --logging_dir /home/super/sgq/swift/llm-yolo/detection2/v1/runs

其中需注意:

–model_id_or_path /data/hfd/InternVL2-8B
该参数为模型路径

–dataset /home/super/lyq/train_dataset.jsonl
该参数为微调数据集

–num_train_epochs 5
该参数为训练轮次,视情况调整

–use_flash_attn True
加速项,服务器未配置可不选

–output_dir /home/super/sgq/swift/llm-yolo/detection2/v1
为训练结果保存路径,结果包含微调训练参数和精度损失记录等

–logging_dir /home/super/sgq/swift/llm-yolo/detection2/v1/runs
为tensorboard查看结果内容存储路径

在这里插入图片描述

结果如上,其中checkpoint-135为训练后的lora权重;images为训练曲线;其他文件为训练参数。
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

训练后的模型部署及测试

合并权重

  1. CUDA_VISIBLE_DEVICES=0,1,2,3 swift export --ckpt_dir '/home/super/lyq/zsbm_mbjc/train_240731_1/internvl2-8b/v0-20240731-154920/checkpoint-135'--merge_lora true

生成合并模型:
在这里插入图片描述

推理部署

在这里插入图片描述

测试

api_ask.py

  1. from openai import OpenAI
  2. import base64
  3. client = OpenAI(api_key='YOUR_API_KEY', base_url='http://172.20.32.127:23333/v1')
  4. model_name = client.models.list().data[0].id#图片转base64函数defencode_image(image_path):withopen(image_path,"rb")as image_file:return base64.b64encode(image_file.read()).decode('utf-8')#原图片转base64defget_response(input_image_path):
  5. base64_image = encode_image(input_image_path)
  6. response = client.chat.completions.create(
  7. model=model_name,
  8. messages=[{"role":"system","content":'''职位:你是一个面向证书图像的目标检测大师,具备精准识别、定位图像中证书编码的能力。
  9. 职能:从毕业证、学历验证报告、证书等图像中检测到证书编码区域并给出边界框。
  10. **注意**:仅以给定格式返回检测结果,不要给出其它任何解释。
  11. **注意**:若图片中没有典型违章场景,返回<ref> class_name </ref><box>[[0, 0, 0, 0]]</box>即可。
  12. '''},{"role":"user","content":[{"type":"text","text":'请目标检测图像中的证书编码并给出边界框'},{"type":"image_url","image_url":{"url":f"data:image/jpeg;base64,{base64_image}"# "url": 'https://i-blog.csdnimg.cn/direct/253ad27104b7466792511f78e9f636a9.png'}},]}],
  13. temperature=0.8,
  14. top_p=0.8)return response.choices[0].message.content

get_llm_response.py

  1. import json
  2. import api_ask as llm_api
  3. defread_jsonl(file_path):"""
  4. Read a JSONL file and return a list of dictionaries.
  5. :param file_path: Absolute path of the JSONL file to be read.
  6. :return: List of dictionaries representing the JSON objects in the file.
  7. """
  8. data =[]withopen(file_path,'r', encoding='utf-8')asfile:for line infile:
  9. data.append(json.loads(line))return data
  10. data = read_jsonl('/home/super/lyq/test_dataset.jsonl')
  11. result =[]for single_data in data:
  12. img_path = single_data['images']
  13. single_result = llm_api.get_response(img_path)print(single_result)
  14. result.append({'images':img_path,'response':single_result})import pandas as pd
  15. pd.DataFrame(result).to_excel('llm_response.xlsx',index=False)

结果如下图:
在这里插入图片描述
result_test.py

  1. import pandas as pd
  2. from PIL import Image, ImageDraw
  3. import re
  4. import json
  5. from PIL import Image, ExifTags
  6. # 添加这个函数来处理图片方向defcorrect_image_orientation(image):try:for orientation in ExifTags.TAGS.keys():if ExifTags.TAGS[orientation]=='Orientation':break
  7. exif =dict(image._getexif().items())if exif[orientation]==3:
  8. image = image.rotate(180, expand=True)elif exif[orientation]==6:
  9. image = image.rotate(270, expand=True)elif exif[orientation]==8:
  10. image = image.rotate(90, expand=True)except(AttributeError, KeyError, IndexError):# 如果没有EXIF信息,就不做任何处理passreturn image
  11. defdraw_rectangle(image_path, coordinates, output_path):"""
  12. 在图像上标出矩形框。
  13. :param image_path: 图像的路径
  14. :param coordinates: 包含矩形框坐标的列表,格式为 [x1, y1, x2, y2]
  15. :param output_path: 输出图像的路径
  16. """# 打开图像with Image.open(image_path)as img:
  17. img = correct_image_orientation(img)
  18. img = correct_image_orientation(img)# 创建一个可以在给定图像上绘图的对象
  19. draw = ImageDraw.Draw(img)# 计算矩形的左上角和右下角坐标
  20. x1, y1, x2, y2 = coordinates
  21. # 在图像上绘制矩形
  22. draw.rectangle([x1, y1, x2, y2], outline="red", width=2)# 保存修改后的图像
  23. img.save(output_path)defextract_string(s):"""
  24. 从给定的字符串中提取方括号内的内容。
  25. :param s: 包含方括号的字符串
  26. :return: 提取出的字符串
  27. """# 使用正则表达式匹配方括号内的内容match= re.search(r'\[(.*?)\]', s)ifmatch:# 提取匹配的内容
  28. extracted_str =match.group(0)returneval(extracted_str+']')else:returnNonedefread_jsonl(file_path):"""
  29. 读取JSONL文件并返回一个包含所有条目的列表。
  30. :param file_path: JSONL文件的路径
  31. :return: 包含JSON对象的列表
  32. """
  33. data =[]withopen(file_path,'r', encoding='utf-8')asfile:for line infile:
  34. data.append(json.loads(line))return data
  35. data = pd.read_excel('/home/super/lyq/llm_response.xlsx')
  36. images = data['images'].tolist()
  37. responses = data['response'].tolist()
  38. n =len(images)print(images)for index inrange(n):print(images[index])
  39. img_path = images[index]
  40. zuobiao = extract_string(responses[index])
  41. draw_rectangle(img_path,zuobiao[0],'/home/super/lyq/zsbm_mbjc/test_result_pic'+'/'+img_path.split('/')[-1])

在这里插入图片描述
在这里插入图片描述

总结

实际上,interVL2-8B多模态大模型在该任务上微调后的表现并不好。与此同时,我们还就电力巡检场景进行了微调测试,精度达到了80左右,其实也比较一般,综合来看,大模型其实并不那么擅长目标检测。

此处引申一个结论,大模型在分类任务上表现则好得多,且提升精度微调是必要的。
最近做了实验,测试集微调前精度57%,微调后97%,不过面向的是单轮问答。


本文转载自: https://blog.csdn.net/qq_43128256/article/details/140829075
版权归原作者 写代码的中青年 所有, 如有侵权,请联系我们删除。

“详细记录swfit微调interVL2-8B多模态大模型进行目标检测(附代码)”的评论:

还没有评论