之前简单地写了一个pytorch的U-net 复现过程,有很多小伙伴在评论里有很多疑问,抽空又复现了一遍,简单整理了常见的问题。
之前写的教程:U-net复现pytorch版本 以及制作自己的数据集并训练_candice5566的博客-CSDN博客
2021.11.14复现过程:
github代码链接:
https://github.com/milesial/Pytorch-UNet
代码说明:这个代码新加入了一个wandb的可视化库,能够可视化训练过程中的参数变化情况
1.下载github的代码,并且配置好环境:
2.下载数据集:
我自己用来测试的数据集,链接:https://pan.baidu.com/s/1lqwZ6XvAtPhw5EYn4bwLUQ
提取码:upb3
- 修改train.py文件里
a.修改数据集的路径
b.修改channel,如果是RGB图像,channel=3,如果是灰度图,channel=1;
修改classes,就是背景+你的数据集里有几个类别;比如我给的那个数据集有汽车和背景两个类,那么classes=2;
c. 其他参数可以自行进行修改
- 开始训练,训练顺利的话应该是这样:
训练完成,wandb可视化结果:
可以看出来最后dice系数在0.95左右,loss在0.2左右。
- 训练完成,模型会保存在checkpoints路径下:
我们找到predict.py文件,修改参数:
然后 找一张测试照片放在当前路径下,
# -i 是指定预测的照片 其他参数可以自己看get_args部分
python predict.py -i 7.jpg --viz -v
预测结果大概就是这样:
输入:
输出:
完结撒花~
** 训练过程常见问题整理:**
1.block: [0,0,0], thread: [828,0,0] Assertion t >= 0 && t < n_classes
failed. 如图
多半是classes设置错了
- f'Either no mask or multiple masks found for the ID {name}: {mask_file}',如图
多半是数据集出错了,
自查方法:
a.看路径出没出错
b.看后缀是不是对,可以看一下utils/date_lodading文件,
还有一定要注意,这里自定义了后缀是 "_mask",也就是说你的掩码和原图一个应该是1.jpg,另外一个就是1_mask.jpg
c.看一下数据集路径下有没有其他文件,这个代码的data_load方法很粗暴,不能有其他文件或者文件夹
- OSError: [WinError 1455] 页面文件太小,无法完成操作
这个直接说解决方法,在train.py文件里修改
num_worker为2,如果还是不行,改成0
4.dice系数一直都小,或者不变,那么有可能是学习率设置太小了,改成0.0001试试,
其他可能原因可以看一下原作者的解答:
Dice coefficient no change during training,is always very close to 0 · Issue #173 · milesial/Pytorch-UNet · GitHub
over
有问题可以评论区补充。
版权归原作者 奶盖芒果 所有, 如有侵权,请联系我们删除。