

LAHeart2018 Left Atrium Segmentation in Practice

2018 Atrial Segmentation Challenge

Data Preparation

The Left Atrium (LA) MR dataset from the Atrial Segmentation Challenge

Dataset download: Data – 2018 Atrial Segmentation Challenge (cardiacatlas.org)


Dataset structure:

```
Training_Set
├── 0RZDK210BSMWAA6467LU
│   ├── laendo.nrrd
│   └── lgemri.nrrd
├── 1D7CUD1955YZPGK8XHJX
│   ├── laendo.nrrd
│   └── lgemri.nrrd
└── ......
Testing_Set
├── 4URSJYI2QUH1T5S5PP47
│   ├── laendo.nrrd
│   └── lgemri.nrrd
├── 6HDYMTGBRI27MN763XTS
│   ├── laendo.nrrd
│   └── lgemri.nrrd
└── ......
```
  • There are 154 3D MRI volumes in total, from patients with atrial fibrillation
  • The data are split into a Training Set and a Testing Set (now open-sourced); each case folder contains one patient's MRI (lgemri.nrrd) and the corresponding label volume (laendo.nrrd)
  • MRI intensities lie in [0, 255]; the spatial resolution is 0.625 x 0.625 x 0.625 mm³; the in-plane slice size varies from patient to patient, while the Z axis always has 88 slices
  • Labels are binary: 0 is background, 255 is the segmented region (the left atrium)
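
These properties are easy to verify with pynrrd. A minimal sketch, assuming the working directory contains Training_Set and using one of the case IDs shown above:

```python
import numpy as np
import nrrd

# pynrrd's read() returns a (data, header) tuple
image, _ = nrrd.read('Training_Set/0RZDK210BSMWAA6467LU/lgemri.nrrd')
label, _ = nrrd.read('Training_Set/0RZDK210BSMWAA6467LU/laendo.nrrd')

print(image.shape)               # in-plane size varies per patient; the last axis has 88 slices
print(image.min(), image.max())  # intensities lie in [0, 255]
print(np.unique(label))          # binary label: [0, 255]
```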

Data Processing

All of the MRI volumes share the same 0.625 x 0.625 x 0.625 mm³ spacing, so no resampling is needed. Intensities already lie in [0, 255], so no intensity clipping is needed either.

First, put the training and testing cases into a single folder so they can be processed uniformly.

Before processing 3D volumes, it is best to fix the target size up front, i.e., the volume size you will feed into the network. You can inspect a few cases in 3D Slicer to estimate how large the segmented region is; the chosen size must at least contain it. I chose 112 x 112 x 80. Do not crop straight down to the target size: crop slightly larger than 112 x 112 x 80, so that data augmentation (e.g., translation) still has spatial room to produce diverse samples.

The details are in the code:

  • data_path is the merged dataset folder, containing all 154 image/label pairs
  • out_path is the output folder where the cropped data are saved
```python
import os
import numpy as np
from tqdm import tqdm
import h5py
import nrrd

output_size = [112, 112, 80]
data_path = 'E:/data/LASet/origin'
out_path = 'E:/data/LASet/data'

def covert_h5():
    listt = os.listdir(data_path)
    for case in tqdm(listt):
        image, img_header = nrrd.read(os.path.join(data_path, case, 'lgemri.nrrd'))
        label, gt_header = nrrd.read(os.path.join(data_path, case, 'laendo.nrrd'))
        label = (label == 255).astype(np.uint8)
        w, h, d = label.shape
        # indices of all nonzero voxels (the segmentation target) in the label
        tempL = np.nonzero(label)
        # min/max of the nonzero region along each axis, so the crop is sure to contain the target
        minx, maxx = np.min(tempL[0]), np.max(tempL[0])
        miny, maxy = np.min(tempL[1]), np.max(tempL[1])
        minz, maxz = np.min(tempL[2]), np.max(tempL[2])
        # how much larger the target size is than the object's bounding box
        px = max(output_size[0] - (maxx - minx), 0) // 2
        py = max(output_size[1] - (maxy - miny), 0) // 2
        pz = max(output_size[2] - (maxz - minz), 0) // 2
        # randomly extend the crop in all three directions
        minx = max(minx - np.random.randint(10, 20) - px, 0)
        maxx = min(maxx + np.random.randint(10, 20) + px, w)
        miny = max(miny - np.random.randint(10, 20) - py, 0)
        maxy = min(maxy + np.random.randint(10, 20) + py, h)
        minz = max(minz - np.random.randint(5, 10) - pz, 0)
        maxz = min(maxz + np.random.randint(5, 10) + pz, d)
        # normalize the image and convert to float32 (numpy defaults to float64)
        image = (image - np.mean(image)) / np.std(image)
        image = image.astype(np.float32)
        # crop
        image = image[minx:maxx, miny:maxy, minz:maxz]
        label = label[minx:maxx, miny:maxy, minz:maxz]
        print(label.shape)
        case_dir = os.path.join(out_path, case)
        os.mkdir(case_dir)
        f = h5py.File(os.path.join(case_dir, 'mri_norm2.h5'), 'w')
        f.create_dataset('image', data=image, compression="gzip")
        f.create_dataset('label', data=label, compression="gzip")
        f.close()

if __name__ == '__main__':
    covert_h5()
```

The cropped data are saved as mri_norm2.h5 files. Each mri_norm2.h5 behaves like a dictionary with the keys image and label, whose values are the corresponding arrays.

If you want to look at the cropped 3D volumes, you can use SimpleITK or nibabel to save the image and label as .nii files.
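
For example, with SimpleITK — a minimal sketch, where the case ID and the output file names are placeholders:

```python
import h5py
import SimpleITK as sitk

# read a cropped case back from HDF5
with h5py.File('E:/data/LASet/data/0RZDK210BSMWAA6467LU/mri_norm2.h5', 'r') as f:
    image = f['image'][:]
    label = f['label'][:]

# GetImageFromArray treats the first array axis as Z, which is fine for a quick look
sitk.WriteImage(sitk.GetImageFromArray(image), 'image.nii')
sitk.WriteImage(sitk.GetImageFromArray(label.astype('uint8')), 'label.nii')
```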

Randomly Splitting the Dataset

Normally you would split into training, validation, and test sets; here I take a shortcut and only split into training and test sets.

The split ratio is 4:1:

```python
import os
from sklearn.model_selection import train_test_split

data_path = 'E:/data/LASet'
names = os.listdir(os.path.join(data_path, 'origin'))
train_ids, test_ids = train_test_split(names, test_size=0.2, random_state=367)
with open(os.path.join(data_path, 'train.list'), 'w') as f:
    f.write('\n'.join(train_ids))
with open(os.path.join(data_path, 'test.list'), 'w') as f:
    f.write('\n'.join(test_ids))
print(len(names), len(train_ids), len(test_ids))
```

Out of the 154 cases, 123 go to the training set and 31 to the test set.

Data Augmentation

Loading

```python
import h5py
from torch.utils.data import Dataset

class LAHeart(Dataset):
    """ LA Dataset """
    def __init__(self, base_dir=None, split='train', num=None, transform=None):
        self._base_dir = base_dir
        self.transform = transform
        self.sample_list = []
        if split == 'train':
            with open(self._base_dir + '/../train.list', 'r') as f:
                self.image_list = f.readlines()
        elif split == 'test':
            with open(self._base_dir + '/../test.list', 'r') as f:
                self.image_list = f.readlines()
        self.image_list = [item.strip() for item in self.image_list]
        if num is not None:
            self.image_list = self.image_list[:num]
        print("total {} samples".format(len(self.image_list)))

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_name = self.image_list[idx]
        print(image_name)
        h5f = h5py.File(self._base_dir + "/" + image_name + "/mri_norm2.h5", 'r')
        image = h5f['image'][:]
        label = h5f['label'][:]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample

if __name__ == '__main__':
    train_set = LAHeart('E:/data/LASet/data')
    print(len(train_set))
    data = train_set[0]
    image, label = data['image'], data['label']
    print(image.shape, label.shape)
```

Augmentation

1. Random crop

```python
import numpy as np
import torch  # used by ToTensor below; all transforms live in the same file

class RandomCrop(object):
    """
    Crop randomly the image in a sample
    Args:
        output_size (int): Desired output size
    """
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        # pad the sample if necessary
        if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
                self.output_size[2]:
            pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
            ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
            pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
            image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
            label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)

        (w, h, d) = image.shape
        w1 = np.random.randint(0, w - self.output_size[0])
        h1 = np.random.randint(0, h - self.output_size[1])
        d1 = np.random.randint(0, d - self.output_size[2])
        label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        return {'image': image, 'label': label}
```

2. Center crop

```python
class CenterCrop(object):
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        # pad the sample if necessary
        if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \
                self.output_size[2]:
            pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0)
            ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
            pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
            image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
            label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)

        (w, h, d) = image.shape
        w1 = int(round((w - self.output_size[0]) / 2.))
        h1 = int(round((h - self.output_size[1]) / 2.))
        d1 = int(round((d - self.output_size[2]) / 2.))
        label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        return {'image': image, 'label': label}
```

3. Random rotation and flip

```python
class RandomRotFlip(object):
    """
    Randomly rotate the sample by a multiple of 90 degrees and randomly flip it.
    """
    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        k = np.random.randint(0, 4)
        image = np.rot90(image, k)
        label = np.rot90(label, k)
        axis = np.random.randint(0, 2)
        image = np.flip(image, axis=axis).copy()
        label = np.flip(label, axis=axis).copy()
        return {'image': image, 'label': label}
```

4. Arrays to tensors

```python
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
        image = sample['image']
        # add a channel axis: (w, h, d) -> (1, w, h, d)
        image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32)
        return {'image': torch.from_numpy(image),
                'label': torch.from_numpy(sample['label']).long()}
```
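
Chained together with torchvision's Compose (exactly as the training framework below does), the transforms can be sanity-checked like this — a small sketch assuming the classes above live in one file alongside LAHeart:

```python
from torchvision import transforms

train_transform = transforms.Compose([
    RandomRotFlip(),
    RandomCrop((112, 112, 80)),
    ToTensor(),
])
train_set = LAHeart('E:/data/LASet/data', split='train', transform=train_transform)
sample = train_set[0]
print(sample['image'].shape)  # torch.Size([1, 112, 112, 80])
print(sample['label'].shape)  # torch.Size([112, 112, 80])
```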

Model Training

Network Architecture


Here I take a simple 3D V-Net as an example; the full code is on my GitHub (linked at the end of the post).

```python
import torch
import torch.nn as nn

# ConvBlock, DownsamplingConvBlock and UpsamplingDeconvBlock are defined in the
# repo's networks/vnet.py

class VNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False):
        super(VNet, self).__init__()
        self.has_dropout = has_dropout

        self.block_one = ConvBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)
        self.block_two = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)
        self.block_three = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)
        self.block_four = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, stride=(2, 2, 1), normalization=normalization)
        self.block_five = ConvBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        self.block_five_up = UpsamplingDeconvBlock(n_filters * 16, n_filters * 8, stride=(2, 2, 1), normalization=normalization)
        self.block_six = ConvBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = UpsamplingDeconvBlock(n_filters * 8, n_filters * 4, normalization=normalization)
        self.block_seven = ConvBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = UpsamplingDeconvBlock(n_filters * 4, n_filters * 2, normalization=normalization)
        self.block_eight = ConvBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = UpsamplingDeconvBlock(n_filters * 2, n_filters, normalization=normalization)
        self.block_nine = ConvBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
        self.dropout = nn.Dropout3d(p=0.5, inplace=False)
        # self.__init_weight()

    def encoder(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)
        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)
        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)
        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)
        x5 = self.block_five(x4_dw)
        # x5 = F.dropout3d(x5, p=0.5, training=True)
        if self.has_dropout:
            x5 = self.dropout(x5)
        res = [x1, x2, x3, x4, x5]
        # print(x5.shape)
        return res

    def decoder(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]
        x5_up = self.block_five_up(x5)
        # print(x5_up.shape)
        x5_up = x5_up + x4
        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3
        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2
        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1
        x9 = self.block_nine(x8_up)
        # x9 = F.dropout3d(x9, p=0.5, training=True)
        if self.has_dropout:
            x9 = self.dropout(x9)
        out = self.out_conv(x9)
        return out

    def forward(self, input, turnoff_drop=False):
        if turnoff_drop:
            has_dropout = self.has_dropout
            self.has_dropout = False
        features = self.encoder(input)
        out = self.decoder(features)
        if turnoff_drop:
            self.has_dropout = has_dropout
        return out
```
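
A quick shape check, continuing from the snippet above (a sketch; it assumes the helper blocks from the repo's networks/vnet.py are importable). Since the decoder mirrors the encoder, the output should match the input spatially:

```python
model = VNet(n_channels=1, n_classes=2, normalization='batchnorm')
x = torch.randn(1, 1, 112, 112, 80)  # (batch, channels, H, W, D), matching the 112 x 112 x 80 patch
print(model(x).shape)                # expected: torch.Size([1, 2, 112, 112, 80])
```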

Loss Function

The loss function is still Dice loss plus cross-entropy:

total_loss = (1 − α) · CE(μ, v) + α · (1 − Dice(μ, v)), with α = 0.5 by default

Dice loss:

Dice(μ, v) = 2 · Σ(μ · v) / (Σμ + Σv)

  • μ is the softmax output of the network
  • v is the one-hot encoding of the segmentation label
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# The Dice loss for binary segmentation could actually be written more simply,
# but I have not bothered to simplify it
class Loss(nn.Module):
    def __init__(self, n_classes, alpha=0.5):
        "dice_loss_plus_cetr_weighted"
        super(Loss, self).__init__()
        self.n_classes = n_classes
        self.alpha = alpha

    def forward(self, input, target):
        smooth = 0.01
        input1 = F.softmax(input, dim=1)
        target1 = F.one_hot(target, self.n_classes)
        input1 = rearrange(input1, 'b n h w s -> b n (h w s)')
        target1 = rearrange(target1, 'b h w s n -> b n (h w s)')
        # keep only the foreground channel
        input1 = input1[:, 1:, :]
        target1 = target1[:, 1:, :].float()
        # dice loss
        inter = torch.sum(input1 * target1)
        union = torch.sum(input1) + torch.sum(target1) + smooth
        dice = 2.0 * inter / union
        # cross-entropy
        loss = F.cross_entropy(input, target)
        total_loss = (1 - self.alpha) * loss + (1 - dice) * self.alpha
        return total_loss

if __name__ == '__main__':
    torch.manual_seed(3)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    losser = Loss(n_classes=2).to(device)
    x = torch.randn((2, 2, 16, 16, 16)).to(device)
    y = torch.randint(0, 2, (2, 16, 16, 16)).to(device)
    print(losser(x, y))
```
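
As the comment in the code says, the binary case allows a more compact Dice. A minimal equivalent sketch (binary_dice is a hypothetical helper, not part of the repo):

```python
import torch
import torch.nn.functional as F

def binary_dice(input, target, smooth=0.01):
    """Foreground Dice for 2-class logits; equivalent to the Dice term above."""
    probs = F.softmax(input, dim=1)[:, 1]   # foreground probability, shape (b, h, w, s)
    fg = (target == 1).float()
    inter = torch.sum(probs * fg)
    union = torch.sum(probs) + torch.sum(fg) + smooth
    return 2.0 * inter / union
```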

Training Framework

```python
import os
import torch
import argparse
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms
from networks.vnet import VNet
from loss import Loss
from dataloaders.la_heart import LAHeart, RandomCrop, CenterCrop, RandomRotFlip, ToTensor


def cal_dice(output, target, eps=1e-3):
    output = torch.argmax(output, dim=1)
    inter = torch.sum(output * target) + eps
    union = torch.sum(output) + torch.sum(target) + eps * 2
    dice = 2 * inter / union
    return dice


def train_loop(model, optimizer, criterion, train_loader, device):
    model.train()
    running_loss = 0
    dice_train = 0
    pbar = tqdm(train_loader)
    for sampled_batch in pbar:
        volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
        volume_batch, label_batch = volume_batch.to(device), label_batch.to(device)
        outputs = model(volume_batch)
        loss = criterion(outputs, label_batch)
        dice = cal_dice(outputs, label_batch)
        dice_train += dice.item()
        pbar.set_postfix(loss="{:.3f}".format(loss.item()), dice="{:.3f}".format(dice.item()))
        running_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss = running_loss / len(train_loader)
    dice = dice_train / len(train_loader)
    return {'loss': loss, 'dice': dice}


def eval_loop(model, criterion, valid_loader, device):
    model.eval()
    running_loss = 0
    dice_valid = 0
    pbar = tqdm(valid_loader)
    with torch.no_grad():
        for sampled_batch in pbar:
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.to(device), label_batch.to(device)
            outputs = model(volume_batch)
            loss = criterion(outputs, label_batch)
            dice = cal_dice(outputs, label_batch)
            running_loss += loss.item()
            dice_valid += dice.item()
            pbar.set_postfix(loss="{:.3f}".format(loss.item()), dice="{:.3f}".format(dice.item()))
    loss = running_loss / len(valid_loader)
    dice = dice_valid / len(valid_loader)
    return {'loss': loss, 'dice': dice}


def train(args, model, optimizer, criterion, train_loader, valid_loader, epochs,
          device, train_log, loss_min=999.0):
    for e in range(epochs):
        # train for one epoch
        train_metrics = train_loop(model, optimizer, criterion, train_loader, device)
        # evaluate for one epoch
        valid_metrics = eval_loop(model, criterion, valid_loader, device)
        info1 = "Epoch:[{}/{}] train_loss: {:.3f} valid_loss: {:.3f}".format(
            e + 1, epochs, train_metrics["loss"], valid_metrics['loss'])
        info2 = "train_dice: {:.3f} valid_dice: {:.3f}".format(train_metrics['dice'], valid_metrics['dice'])
        print(info1 + '\n' + info2)
        with open(train_log, 'a') as f:
            f.write(info1 + '\n' + info2 + '\n')
        if valid_metrics['loss'] < loss_min:
            loss_min = valid_metrics['loss']
            torch.save(model.state_dict(), args.save_path)
    print("Finished Training!")


def main(args):
    torch.manual_seed(args.seed)           # seed the CPU RNG so results are reproducible
    torch.cuda.manual_seed_all(args.seed)  # seed all GPUs so results are reproducible
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # data
    db_train = LAHeart(base_dir=args.train_path,
                       split='train',
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(args.patch_size),
                           ToTensor(),
                       ]))
    db_test = LAHeart(base_dir=args.train_path,
                      split='test',
                      transform=transforms.Compose([
                          CenterCrop(args.patch_size),
                          ToTensor()
                      ]))
    print('Using {} images for training, {} images for testing.'.format(len(db_train), len(db_test)))
    trainloader = DataLoader(db_train, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True,
                             drop_last=True)
    testloader = DataLoader(db_test, batch_size=1, num_workers=4, pin_memory=True)
    model = VNet(n_channels=1, n_classes=args.num_classes, normalization='batchnorm', has_dropout=True).to(device)
    criterion = Loss(n_classes=args.num_classes).to(device)
    optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=args.lr, weight_decay=1e-4)
    # resume from a checkpoint if one exists
    if os.path.exists(args.weight_path):
        weight_dict = torch.load(args.weight_path, map_location=device)
        model.load_state_dict(weight_dict)
        print('Successfully loading checkpoint.')
    train(args, model, optimizer, criterion, trainloader, testloader, args.epochs, device, train_log=args.train_log)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=2)
    parser.add_argument('--seed', type=int, default=21)
    parser.add_argument('--epochs', type=int, default=160)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--patch_size', type=float, default=(112, 112, 80))
    parser.add_argument('--train_path', type=str, default='/***/LASet/data')
    parser.add_argument('--train_log', type=str, default='results/VNet_sup.txt')
    parser.add_argument('--weight_path', type=str, default='results/VNet_sup.pth')  # checkpoint to load
    parser.add_argument('--save_path', type=str, default='results/VNet_sup.pth')    # checkpoint to save
    args = parser.parse_args()
    main(args)
```

Results

Training


```
Epoch:[1/160] train_loss: 0.670 valid_loss: 0.559
train_dice: 0.337 valid_dice: 0.192
Epoch:[2/160] train_loss: 0.522 valid_loss: 0.567
train_dice: 0.317 valid_dice: 0.143
......
Epoch:[160/160] train_loss: 0.066 valid_loss: 0.090
train_dice: 0.939 valid_dice: 0.924
```

The task is fairly easy, so the network converges quickly.

Note that the Dice reported here is computed on center-cropped test volumes; the real metrics require sliding-window inference, which is implemented in inference.py.

Inference

[Figure: ground-truth contour (red) vs. V-Net prediction contour (blue)]

In the figure, red is the contour of the label and blue is the contour of the V-Net prediction.

```python
import math
import torch
import torch.nn.functional as F
import numpy as np
import h5py
import nibabel as nib
from medpy import metric
from networks.vnet import VNet


def calculate_metric_percase(pred, gt):
    dice = metric.binary.dc(pred, gt)
    jc = metric.binary.jc(pred, gt)
    hd = metric.binary.hd95(pred, gt)
    asd = metric.binary.asd(pred, gt)
    return dice, jc, hd, asd


def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1):
    w, h, d = image.shape
    # if the image is smaller than patch_size, pad it
    add_pad = False
    if w < patch_size[0]:
        w_pad = patch_size[0] - w
        add_pad = True
    else:
        w_pad = 0
    if h < patch_size[1]:
        h_pad = patch_size[1] - h
        add_pad = True
    else:
        h_pad = 0
    if d < patch_size[2]:
        d_pad = patch_size[2] - d
        add_pad = True
    else:
        d_pad = 0
    wl_pad, wr_pad = w_pad // 2, w_pad - w_pad // 2
    hl_pad, hr_pad = h_pad // 2, h_pad - h_pad // 2
    dl_pad, dr_pad = d_pad // 2, d_pad - d_pad // 2
    if add_pad:
        image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), (dl_pad, dr_pad)],
                       mode='constant', constant_values=0)
    ww, hh, dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    # print("{}, {}, {}".format(sx, sy, sz))
    score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
    cnt = np.zeros(image.shape).astype(np.float32)

    for x in range(0, sx):
        xs = min(stride_xy * x, ww - patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd - patch_size[2])
                test_patch = image[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]]
                test_patch = np.expand_dims(np.expand_dims(test_patch, axis=0), axis=0).astype(np.float32)
                test_patch = torch.from_numpy(test_patch).cuda()
                y1 = net(test_patch)
                y = F.softmax(y1, dim=1)
                y = y.cpu().data.numpy()
                y = y[0, :, :, :, :]
                score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = score_map[:, xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + y
                cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] \
                    = cnt[xs:xs + patch_size[0], ys:ys + patch_size[1], zs:zs + patch_size[2]] + 1
    # average overlapping predictions
    score_map = score_map / np.expand_dims(cnt, axis=0)
    label_map = np.argmax(score_map, axis=0)
    if add_pad:
        label_map = label_map[wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
        score_map = score_map[:, wl_pad:wl_pad + w, hl_pad:hl_pad + h, dl_pad:dl_pad + d]
    return label_map, score_map


def test_all_case(net, image_list, num_classes=2, patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
                  save_result=True, test_save_path=None, preproc_fn=None):
    total_metric = 0.0
    for ith, image_path in enumerate(image_list):
        h5f = h5py.File(image_path, 'r')
        image = h5f['image'][:]
        label = h5f['label'][:]
        if preproc_fn is not None:
            image = preproc_fn(image)
        prediction, score_map = test_single_case(net, image, stride_xy, stride_z, patch_size,
                                                 num_classes=num_classes)
        if np.sum(prediction) == 0:
            single_metric = (0, 0, 0, 0)
        else:
            single_metric = calculate_metric_percase(prediction, label[:])
        print('%02d,\t%.5f, %.5f, %.5f, %.5f' % (ith, single_metric[0], single_metric[1],
                                                 single_metric[2], single_metric[3]))
        total_metric += np.asarray(single_metric)
        if save_result:
            nib.save(nib.Nifti1Image(prediction.astype(np.float32), np.eye(4)), test_save_path + "%02d_pred.nii.gz" % ith)
            nib.save(nib.Nifti1Image(image[:].astype(np.float32), np.eye(4)), test_save_path + "%02d_img.nii.gz" % ith)
            nib.save(nib.Nifti1Image(label[:].astype(np.float32), np.eye(4)), test_save_path + "%02d_gt.nii.gz" % ith)
    avg_metric = total_metric / len(image_list)
    print('average metric is {}'.format(avg_metric))
    return avg_metric


if __name__ == '__main__':
    data_path = '/***/data_set/LASet/data/'
    test_save_path = 'predictions/'
    save_mode_path = 'results/VNet.pth'
    net = VNet(n_channels=1, n_classes=2, normalization='batchnorm').cuda()
    net.load_state_dict(torch.load(save_mode_path))
    print("init weight from {}".format(save_mode_path))
    net.eval()
    with open(data_path + '/../test.list', 'r') as f:
        image_list = f.readlines()
    image_list = [data_path + item.replace('\n', '') + "/mri_norm2.h5" for item in image_list]
    # sliding-window inference
    avg_metric = test_all_case(net, image_list, num_classes=2,
                               patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
                               save_result=True, test_save_path=test_save_path)
```
```
init weight from results/VNet.pth
00,	0.90632, 0.82868, 6.40312, 1.27997
01,	0.89492, 0.80982, 6.48074, 1.14056
......
30,	0.94105, 0.88866, 3.16228, 1.03454
average metric is [0.91669405 0.84675762 5.33117527 1.42431875]
```

The four columns per case are, in the order computed by calculate_metric_percase: Dice, Jaccard, 95% Hausdorff distance, and average surface distance.

This dataset is fairly easy and is commonly used for semi-supervised segmentation; I will post more semi-supervised learning content later. Writing this up took effort, so if it helped, please leave a like.

Project GitHub: LASeg: 2018 Left Atrium Segmentation (MRI)


The code draws on https://github.com/yulequan/UA-MT and https://github.com/ycwu1997/MC-Net


Reposted from: https://blog.csdn.net/weixin_44858814/article/details/127149601
Copyright belongs to the original author, 宁远x. In case of infringement, please contact us for removal.
