文章目录
摘要
论文翻译:https://wanghao.blog.csdn.net/article/details/129379410
官方源码: https://github.com/OpenGVLab/InternImage
他来了!他来了!他带着氩弧焊的光芒过来了!作为CV的大模型,InternImage的光芒太强了。
- 2023年3月14日: 🚀 “书生2.5”发布!
- 2023年2月28日: 🚀 InternImage 被CVPR 2023接收!
- 2022年11月18日: 🚀 基于 InternImage-XL 主干网络,BEVFormer v2 在nuScenes的纯视觉3D检测任务上取得了最佳性能
63.4 NDS
! - 2022年11月10日: 🚀 InternImage-H 在COCO目标检测任务上以
65.4 mAP
斩获冠军,是唯一突破65.0 mAP
的超强物体检测模型! - 2022年11月10日: 🚀 InternImage-H 在ADE20K语义分割数据集上取得
62.9 mIoU
的SOTA性能!
“书生2.5”的应用
1. 图像模态任务性能
- 在图像分类标杆数据集ImageNet上,“书生2.5”仅基于公开数据便达到了 90.1% 的Top-1准确率。这是除谷歌与微软两个未公开模型及额外数据集外,唯一准确率超过90.0%的模型,同时也是世界上开源模型中ImageNet准确度最高,规模最大的模型;
- 在物体检测标杆数据集COCO上,“书生2.5” 取得了 65.5 的 mAP,是世界上唯一超过65 mAP的模型;
- 在另外16个重要的视觉基础数据集(覆盖分类、检测和分割任务)上取得世界最好性能。
分类任务
图像分类 场景分类 长尾分类ImageNetPlaces365Places 205iNaturalist 201890.161.271.792.3
检测任务
常规物体检测长尾物体检测 自动驾驶物体检测密集物体检测COCOVOC 2007VOC 2012OpenImageLVIS minivalLVIS valBDD100KnuScenesCrowdHuman65.594.097.274.165.863.238.864.897.2
分割任务
语义分割街景分割RGBD分割ADE20KCOCO Stuff-10KPascal ContextCityScapesNYU Depth V262.959.670.386.169.7
2. 图文跨模态任务性能
- 图文检索
“书生2.5”可根据文本内容需求快速定位检索出语义最相关的图像。这一能力既可应用于视频和图像集合,也可进一步结合物体检测框,具有丰富的应用模式,帮助用户更便捷、快速地找到所需图像资源, 例如可在相册中返回文本所指定的相关图像。
- 以图生文
“书生2.5”的“以图生文”在图像描述、视觉问答、视觉推理和文字识别等多个方面均拥有强大的理解能力。例如在自动驾驶场景下,可以提升场景感知理解能力,辅助车辆判断交通信号灯状态、道路标志牌等信息,为车辆的决策规划提供有效的感知信息支持。
图文多模态任务
图像描述微调图文检索零样本图文检索COCO CaptionCOCO CaptionFlickr30kFlickr30k148.276.494.889.1
核心技术
“书生2.5”在图文跨模态领域卓越的性能表现,源自于在多模态多任务通用模型技术核心方面的多项创新,实现了视觉核心视觉感知大模型主干网络(InternImage)、用于文本核心的超大规模文本预训练网络(LLM)和用于多任务的兼容解码建模(Uni-Perceiver)的创新组合。 视觉主干网络InternImage参数量高达30亿,能够基于动态稀疏卷积算子自适应地调整卷积的位置和组合方式,从而为多功能视觉感知提供强大的表示。Uni-Perceiver通才任务解码建模通过将不同模态的数据编码到统一的表示空间,并将不同任务统一为相同的任务范式,从而能够以相同的任务架构和共享的模型参数同时处理各种模态和任务。
这篇文章主要讲解如何使用InternImage完成图像分类任务,接下来我们一起完成项目的实战。本例选用的模型是InternImageNet_s,在植物幼苗数据集上实现了97%的准确率。
通过这篇文章能让你学到:
- 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
- 如何实现InternImage模型实现训练?
- 如何使用pytorch自带混合精度?
- 如何使用梯度裁剪防止梯度爆炸?
- 如何使用DP多显卡训练?
- 如何绘制loss和acc曲线?
- 如何生成val的测评报告?
- 如何编写测试脚本测试测试集?
- 如何使用余弦退火策略调整学习率?
- 如何使用AverageMeter类统计ACC和loss等自定义变量?
- 如何理解和统计ACC1和ACC5?
- 如何使用EMA?
- 如果使用Grad-CAM 实现热力图可视化?
如果基础薄弱,对上面的这些功能难以理解可以看我的专栏:经典主干网络精讲与实战
这个专栏,从零开始时,一步一步的讲解这些,让大家更容易接受。
安装包
安装timm
使用pip就行,命令:
pip install timm
本文实战用的timm里面的模型。
安装 grad-cam
pip install grad-cam
安装DCNV3
从
https://github.com/OpenGVLab/InternImage
下载源码,然后进入ops_dcnv3,如果是Ubuntu系统,
则可以执行:
sh ./make.sh
通用的方式:
python setup.py install
注意:本机要有显卡,并安装CUDA环境。
数据增强Cutout和Mixup
为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:
pip install torchtoolbox
Cutout实现,在transforms中。
from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([
transforms.Resize((224,224)),
Cutout(),
transforms.ToTensor(),
transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
需要导入包:from timm.data.mixup import Mixup,
定义Mixup,和SoftTargetCrossEntropy
mixup_fn = Mixup(
mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
prob=0.1, switch_prob=0.5, mode='batch',
label_smoothing=0.1, num_classes=12)
criterion_train = SoftTargetCrossEntropy()
参数详解:
mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。
cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。
cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。
如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0
prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。
switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。
mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。
correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正
label_smoothing (float):将标签平滑应用于混合目标张量。
num_classes (int): 目标的类数。
EMA
EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:
import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn
_logger = logging.getLogger(__name__)classModelEma:def__init__(self, model, decay=0.9999, device='', resume=''):# make a copy of the model for accumulating moving average of weights
self.ema = deepcopy(model)
self.ema.eval()
self.decay = decay
self.device = device # perform ema on different device from model if setif device:
self.ema.to(device=device)
self.ema_has_module =hasattr(self.ema,'module')if resume:
self._load_checkpoint(resume)for p in self.ema.parameters():
p.requires_grad_(False)def_load_checkpoint(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')assertisinstance(checkpoint,dict)if'state_dict_ema'in checkpoint:
new_state_dict = OrderedDict()for k, v in checkpoint['state_dict_ema'].items():# ema model may have been wrapped by DataParallel, and need module prefixif self.ema_has_module:
name ='module.'+ k ifnot k.startswith('module')else k
else:
name = k
new_state_dict[name]= v
self.ema.load_state_dict(new_state_dict)
_logger.info("Loaded state_dict_ema")else:
_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")defupdate(self, model):# correct a mismatch in state dict keys
needs_module =hasattr(model,'module')andnot self.ema_has_module
with torch.no_grad():
msd = model.state_dict()for k, ema_v in self.ema.state_dict().items():if needs_module:
k ='module.'+ k
model_v = msd[k].detach()if self.device:
model_v = model_v.to(device=self.device)
ema_v.copy_(ema_v * self.decay +(1.- self.decay)* model_v)
加入到模型中。
#初始化if use_ema:
model_ema = ModelEma(
model_ft,
decay=model_ema_decay,
device='cpu',
resume=resume)# 训练过程中,更新完参数后,同步update shadow weightsdeftrain():
optimizer.step()if model_ema isnotNone:
model_ema.update(model)# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)
针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!
项目结构
InternImage_Demo
├─data1
│ ├─Black-grass
│ ├─Charlock
│ ├─Cleavers
│ ├─Common Chickweed
│ ├─Common wheat
│ ├─Fat Hen
│ ├─Loose Silky-bent
│ ├─Maize
│ ├─Scentless Mayweed
│ ├─Shepherds Purse
│ ├─Small-flowered Cranesbill
│ └─Sugar beet
├─models
│ ├─__init__.py
│ └─intern_image.py
├─ops_dcnv3
├─mean_std.py
├─makedata.py
├─train.py
├─cam_image.py
└─test.py
ops_dcnv3:来着官方的文件夹,存放dcnv3的文件。
models:来源官方代码,对面的代码做了一些适应性修改。增加了一些加载预训练,调用模型的逻辑。
mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
ema.py:EMA脚本
train.py:训练PoolFormer模型
cam_image.py:热力图可视化
计算mean和std
为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:
from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms
defget_mean_and_std(train_data):
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=1, shuffle=False, num_workers=0,
pin_memory=True)
mean = torch.zeros(3)
std = torch.zeros(3)for X, _ in train_loader:for d inrange(3):
mean[d]+= X[:, d,:,:].mean()
std[d]+= X[:, d,:,:].std()
mean.div_(len(train_data))
std.div_(len(train_data))returnlist(mean.numpy()),list(std.numpy())if __name__ =='__main__':
train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())print(get_mean_and_std(train_dataset))
数据集结构:
运行结果:
([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])
把这个结果记录下来,后面要用!
生成数据集
我们整理还的图像分类的数据集结构是这样的
data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet
pytorch和keras默认加载方式是ImageNet数据集格式,格式是
├─data
│ ├─val
│ │ ├─Black-grass
│ │ ├─Charlock
│ │ ├─Cleavers
│ │ ├─Common Chickweed
│ │ ├─Common wheat
│ │ ├─Fat Hen
│ │ ├─Loose Silky-bent
│ │ ├─Maize
│ │ ├─Scentless Mayweed
│ │ ├─Shepherds Purse
│ │ ├─Small-flowered Cranesbill
│ │ └─Sugar beet
│ └─train
│ ├─Black-grass
│ ├─Charlock
│ ├─Cleavers
│ ├─Common Chickweed
│ ├─Common wheat
│ ├─Fat Hen
│ ├─Loose Silky-bent
│ ├─Maize
│ ├─Scentless Mayweed
│ ├─Shepherds Purse
│ ├─Small-flowered Cranesbill
│ └─Sugar beet
新增格式转化脚本makedata.py,插入代码:
import glob
import os
import shutil
image_list=glob.glob('data1/*/*.png')print(image_list)
file_dir='data'if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)
shutil.rmtree(file_dir)#删除再建立
os.makedirs(file_dir)else:
os.makedirs(file_dir)from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)forfilein trainval_files:
file_class=file.replace("\\","/").split('/')[-2]
file_name=file.replace("\\","/").split('/')[-1]
file_class=os.path.join(train_root,file_class)ifnot os.path.isdir(file_class):
os.makedirs(file_class)
shutil.copy(file, file_class +'/'+ file_name)forfilein val_files:
file_class=file.replace("\\","/").split('/')[-2]
file_name=file.replace("\\","/").split('/')[-1]
file_class=os.path.join(val_root,file_class)ifnot os.path.isdir(file_class):
os.makedirs(file_class)
shutil.copy(file, file_class +'/'+ file_name)
完成上面的内容就可以开启训练和测试了。
版权归原作者 AI浩 所有, 如有侵权,请联系我们删除。