0


语义分割数据集:Cityscapes的使用

本文主要介绍Cityscapes在语义分割方向上的理解和使用。

Cityscapes官网:官方网站

6d5314b7833b4b32b0562dd68803ab01.png

Cityscapes

简介

Cityscapes大致有两个数据集,分别为精细的标注数据集(3475张训练图像,1525张测试图像)和粗糙的标注数据集(3475+19888张额外的粗糙标注),见图1。

6efbdcddde664fd5b9a43074b49d460f.png
标题

一般只需要用到精细的部分,也就是4375+1525张图像,在官网直接下载即可,一共5000张。

数据集的原始图片为图2中所示,左边摄像头拍到的图像。共11GB。

02130d6d245e4b1f9ff9be8298a01407.png
图2 原始数据

数据集标注方法

数据集下载以后,需要通过代码文件来生成标注,需要上github下载:cityscapes数据集生成工具在下载好工具后,需要pip安装相应工具包。

  1. pip install cityscapesscripts

在jupyter notebook中也可以输入以下代码进行安装。

  1. !pip install cityscapesscripts

选其一即可。


将下载的工具包打开,进入到preparation文件夹,找到如下文件:打开createTrainIdLabelImages.py

2276e79986be48c3998f73158788dcaa.png

在其中添加一行代码,保证能读取到你的标注文件路径。

  1. os.environ['CITYSCAPES_DATASET'] = "你的CityScapes gtFine路径"

26de434952554b1888ad89e5e13eefda.png

运行createTrainIdLabelImgs.py,即可生成如下数据集(19类)。

a561c58fe17b42719a7b5c9bffef8abb.png
生成的数据集-labelTrainIds结尾的图像

0e9243f1841f42618afa2bf3f7200bd5.png
原始的数据集-labelIds结尾图像(33类)

补充说明

在原始的gtFine数据集中就有的以labelIds结尾的数据:是所有类别的数据共有33类。

而在DeepLab论文中,只使用了其中19类,于是我们可以生成19类的数据集:以labelTrainIds结尾。

生成任意类别数的数据集

如果我们想生成任意类别数的数据集,可以修改工具包中的py文件。

进入工具包的helpers文件夹,找到labels.py文件,修改其中类别对应的trainId即可,把想要训练的类别标签设为1,2,3,4....,把不想要参与训练的类别标签设置为255。然后重新运行createTrainIdLabelImgs.py文件,生成新的数据集。

fba863bd2df342c2a369c9a9275badda.png

训练中需要注意的点

因为我们把不感兴趣的区域设置成为了255,所以,在定义损失函数的时候,需要设置ignore_index=255这个参数,来忽略我们不感兴趣的区域。

  1. lossf = nn.CrossEntropyLoss(ignore_index=255)

在pytorch中构建Dataset

现在,我们有了两个文件夹,一个是leftImg8bit的原始图像文件夹,一个是gtFine标注文件夹。

现在,我们要将这两个文件夹里面的图像都提取出来,存入train、val、test文件夹中。

运行下面的代码,即可将原始图像提取并处理。

  1. import os
  2. import random
  3. import shutil
  4. # 数据集路径
  5. dataset_path = r"dataset/cityscapes/leftImg8bit_trainvaltest/leftImg8bit"
  6. #原始的train, valid文件夹路径
  7. train_dataset_path = os.path.join(dataset_path,'train')
  8. val_dataset_path = os.path.join(dataset_path,'val')
  9. test_dataset_path = os.path.join(dataset_path,'test')
  10. #创建train,valid的文件夹
  11. train_images_path = os.path.join(dataset_path,'cityscapes_train')
  12. val_images_path = os.path.join(dataset_path,'cityscapes_val')
  13. test_images_path = os.path.join(dataset_path,'cityscapes_test')
  14. if os.path.exists(train_images_path)==False:
  15. os.mkdir(train_images_path )
  16. if os.path.exists(val_images_path)==False:
  17. os.mkdir(val_images_path)
  18. if os.path.exists(test_images_path)==False:
  19. os.mkdir(test_images_path)
  20. #-----------------移动文件夹-------------------------------------------------
  21. for file_name in os.listdir(train_dataset_path):
  22. file_path = os.path.join(train_dataset_path,file_name)
  23. for image in os.listdir(file_path):
  24. shutil.copy(os.path.join(file_path,image), os.path.join(train_images_path,image))
  25. for file_name in os.listdir(val_dataset_path):
  26. file_path = os.path.join(val_dataset_path,file_name)
  27. for image in os.listdir(file_path):
  28. shutil.copy(os.path.join(file_path,image), os.path.join(val_images_path,image))
  29. for file_name in os.listdir(test_dataset_path):
  30. file_path = os.path.join(test_dataset_path,file_name)
  31. for image in os.listdir(file_path):
  32. shutil.copy(os.path.join(file_path,image), os.path.join(test_images_path,image))

运行后生成如下文件夹。

b063b0f5bdbe4988bc67fd7abf761762.png

对于label文件也同样如此,比如下面生成19类的标注文件夹。

  1. import os
  2. import random
  3. import shutil
  4. # 数据集路径
  5. dataset_path = r"dataset\cityscapes\gtFine_trainvaltest\gtFine"
  6. #原始的train, valid文件夹路径
  7. train_dataset_path = os.path.join(dataset_path,'train')
  8. val_dataset_path = os.path.join(dataset_path,'val')
  9. test_dataset_path = os.path.join(dataset_path,'test')
  10. #创建train,valid的文件夹
  11. train_images_path = os.path.join(dataset_path,'cityscapes_19classes_train')
  12. val_images_path = os.path.join(dataset_path,'cityscapes_19classes_val')
  13. test_images_path = os.path.join(dataset_path,'cityscapes_19classes_test')
  14. if os.path.exists(train_images_path)==False:
  15. os.mkdir(train_images_path )
  16. if os.path.exists(val_images_path)==False:
  17. os.mkdir(val_images_path)
  18. if os.path.exists(test_images_path)==False:
  19. os.mkdir(test_images_path)
  20. #-----------------移动文件---对于19类语义分割, 主需要原始图像中的labelIds结尾图片-----------------------
  21. for file_name in os.listdir(train_dataset_path):
  22. file_path = os.path.join(train_dataset_path,file_name)
  23. for image in os.listdir(file_path):
  24. #查找对应的后缀名,然后保存到文件中
  25. if image.split('.png')[0][-13:] == "labelTrainIds":
  26. #print(image)
  27. shutil.copy(os.path.join(file_path,image), os.path.join(train_images_path,image))
  28. for file_name in os.listdir(val_dataset_path):
  29. file_path = os.path.join(val_dataset_path,file_name)
  30. for image in os.listdir(file_path):
  31. if image.split('.png')[0][-13:] == "labelTrainIds":
  32. shutil.copy(os.path.join(file_path,image), os.path.join(val_images_path,image))
  33. for file_name in os.listdir(test_dataset_path):
  34. file_path = os.path.join(test_dataset_path,file_name)
  35. for image in os.listdir(file_path):
  36. if image.split('.png')[0][-13:] == "labelTrainIds":
  37. shutil.copy(os.path.join(file_path,image), os.path.join(test_images_path,image))

得到如下结果。

1b271fc8a98f4b70a63a55bc4856dd8b.png

到这里,我们已经提取了所有的图像文件和标注文件。


读取数据集

现在我们可以读取对应的数据集。

  1. # 导入库
  2. import os
  3. os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  4. import torch
  5. import torch.nn as nn
  6. import torch.optim as optim
  7. import torch.nn.functional as F
  8. from torch import optim
  9. from torch.utils.data import Dataset, DataLoader, random_split
  10. from tqdm import tqdm
  11. import warnings
  12. warnings.filterwarnings("ignore")
  13. import os.path as osp
  14. import matplotlib.pyplot as plt
  15. from PIL import Image
  16. import numpy as np
  17. import albumentations as A
  18. from albumentations.pytorch.transforms import ToTensorV2
  19. torch.manual_seed(17)
  20. # 自定义数据集CamVidDataset
  21. class CityScapesDataset(torch.utils.data.Dataset):
  22. """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.
  23. Args:
  24. images_dir (str): path to images folder
  25. masks_dir (str): path to segmentation masks folder
  26. class_values (list): values of classes to extract from segmentation mask
  27. augmentation (albumentations.Compose): data transfromation pipeline
  28. (e.g. flip, scale, etc.)
  29. preprocessing (albumentations.Compose): data preprocessing
  30. (e.g. noralization, shape manipulation, etc.)
  31. """
  32. def __init__(self, images_dir, masks_dir):
  33. self.transform = A.Compose([
  34. A.Resize(224, 448),
  35. A.HorizontalFlip(),
  36. #A.RandomBrightnessContrast(),
  37. A.RandomSnow(),
  38. A.Normalize(),
  39. ToTensorV2(),
  40. ])
  41. self.ids = os.listdir(images_dir)
  42. self.ids2 = os.listdir(masks_dir)
  43. self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
  44. self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids2]
  45. def __getitem__(self, i):
  46. # read data
  47. image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
  48. mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
  49. image = self.transform(image=image,mask=mask)
  50. return image['image'], image['mask'][:,:,0]
  51. def __len__(self):
  52. return len(self.ids)
  53. # 设置数据集路径
  54. x_train_dir = r"dataset\cityscapes\leftImg8bit_trainvaltest\leftImg8bit\cityscapes_train"
  55. y_train_dir = r"dataset\cityscapes\gtFine_trainvaltest\gtFine\cityscapes_19classes_train"
  56. x_valid_dir = r"dataset\cityscapes\leftImg8bit_trainvaltest\leftImg8bit\cityscapes_val"
  57. y_valid_dir = r"dataset\cityscapes\gtFine_trainvaltest\gtFine\cityscapes_19classes_val"
  58. train_dataset = CityScapesDataset(
  59. x_train_dir,
  60. y_train_dir,
  61. )
  62. val_dataset = CityScapesDataset(
  63. x_valid_dir,
  64. y_valid_dir,
  65. )
  66. train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
  67. val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

测试一下结果

  1. for index, (img, label) in enumerate(train_loader):
  2. print(img.shape)
  3. print(label.shape)
  4. plt.figure(figsize=(10,10))
  5. plt.subplot(221)
  6. plt.imshow((img[0,:,:,:].moveaxis(0,2)))
  7. plt.subplot(222)
  8. plt.imshow(label[0,:,:])
  9. plt.subplot(223)
  10. plt.imshow((img[6,:,:,:].moveaxis(0,2)))
  11. plt.subplot(224)
  12. plt.imshow(label[6,:,:])
  13. plt.show()
  14. if index==0:
  15. break

a8f11f0b31f9410387b0e4f7cd581319.png


本文转载自: https://blog.csdn.net/yumaomi/article/details/124847721
版权归原作者 yumaomi 所有, 如有侵权,请联系我们删除。

“语义分割数据集:Cityscapes的使用”的评论:

还没有评论