

Confusion Matrix (ResNet34 on CIFAR10)

1. Confusion Matrix

The confusion matrix lays out the ground-truth labels and the predicted labels as a matrix, which is more intuitive than the single accuracy value (acc) computed earlier.

As an example, the original post shows the confusion matrix for a flower-classification task.

The accuracy computed earlier is: acc = number of correct predictions / total number of samples = sum of the diagonal / sum of all entries in the matrix.
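As a minimal sketch of this relationship in NumPy (the small 3×3 matrix here is made-up, purely for illustration):

import numpy as np

# made-up 3x3 confusion matrix, only to illustrate the formula
m = np.array([[8, 1, 0],
              [2, 7, 1],
              [0, 2, 9]])
acc = np.trace(m) / m.sum()  # sum of the diagonal / sum of the whole matrix
print(acc)                   # 24 / 30 = 0.8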

2. Other performance metrics

Besides accuracy, there are other metrics that make it easier to see how well each individual class is predicted.

Before introducing them, a few terms need to be defined.

Here T stands for True (the prediction is correct) and F stands for False (the prediction is wrong); P means the prediction is the positive class and N means the prediction is the negative class.

TP and TN are both correct predictions, i.e. cases where the network classifies properly. TP: the ground truth is, say, cat and the prediction is also cat. TN: the ground truth is non-cat and the prediction is also non-cat.

FP and FN are both wrong predictions, i.e. cases where the network misclassifies. FN: the ground truth is cat but the prediction is non-cat. FP: the ground truth is non-cat but the prediction is cat.

A handy way to remember it: T/F tells you whether the network's prediction is correct, and P/N tells you what the network predicted (positive or negative).

For example:

TP: the prediction is correct (T) and the predicted class is positive (P), so the sample really is a cat and is predicted as a cat.

FP: the prediction is wrong (F) and the predicted class is positive (P), so the label is actually non-cat but the network predicts cat.

FN: the prediction is wrong (F) and the predicted class is negative (N), so the label is actually cat but the network predicts non-cat.

TN: the prediction is correct (T) and the predicted class is negative (N), so the sample is non-cat and is predicted as non-cat.
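Putting the four cases into a 2×2 table (rows are predictions, columns are ground truth, the same convention the code later in this post uses):

                        truth: cat (P)    truth: non-cat (N)
  predicted cat              TP                  FP
  predicted non-cat          FN                  TN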

The commonly used performance metrics are listed below; except for accuracy, each of them is computed for one particular class at a time.
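Based on the four counts above, the metrics used in the rest of this post are computed as:

  Accuracy    = (TP + TN) / (TP + TN + FP + FN)
  Precision   = TP / (TP + FP)
  Recall      = TP / (TP + FN)
  Specificity = TN / (TN + FP)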

3. Example

For example, take the three-class confusion matrix shown as a figure in the original post; the numbers used in the calculations below come from it.

Accuracy = number of correct predictions / total number of samples = (TP + TN) / (TP + TN + FP + FN) = (10 + 15 + 20) / 66 ≈ 0.68

The following three metrics are all computed for the cat class:

Precision = TP / (TP + FP) = 10 / (10 + 1 + 2) ≈ 0.77

Precision: of the samples predicted as positive, the proportion that are truly positive.

Recall = TP / (TP + FN) = 10 / (10 + 3 + 5) ≈ 0.56

Recall: of the samples that are truly positive, the proportion that are predicted as positive.

Specificity = TN / (TN + FP) = (15 + 4 + 20 + 6) / (15 + 4 + 20 + 6 + 1 + 2) ≈ 0.94
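As a quick arithmetic check, the per-class counts for cat can be plugged into the formulas directly (a minimal sketch; all numbers come from the example above):

# per-class counts for "cat", taken from the three-class example
TP = 10                # true cat, predicted cat
FP = 1 + 2             # true non-cat, predicted cat
FN = 3 + 5             # true cat, predicted non-cat
TN = 15 + 4 + 20 + 6   # true non-cat, predicted non-cat

accuracy = (10 + 15 + 20) / 66   # 0.68, uses the whole matrix
precision = TP / (TP + FP)       # 10 / 13 ≈ 0.77
recall = TP / (TP + FN)          # 10 / 18 ≈ 0.56
specificity = TN / (TN + FP)     # 45 / 48 ≈ 0.94
print(accuracy, precision, recall, specificity)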

4. Implementing the confusion matrix in code

First, implement a confusion-matrix class (the full code is given in section 7).

Then update() fills in the matrix values from the predicted and ground-truth labels; the horizontal axis (columns) holds the ground truth and the vertical axis (rows) holds the predictions.

In the code, p indexes the matrix rows, i.e. the predictions, and t indexes the columns, i.e. the ground truth, as the small sketch below shows.
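A tiny usage sketch (the three class names here are placeholders, and ConfusionMatrix is the class from section 7, assumed to be importable from utils):

import numpy as np
from utils import ConfusionMatrix

cm = ConfusionMatrix(num_classes=3, labels=['cat', 'dog', 'bird'])
preds = np.array([0, 1, 2, 0])   # predicted class indices
truth = np.array([0, 1, 1, 2])   # ground-truth class indices
cm.update(preds, truth)
print(cm.matrix)
# matrix[p, t]: matrix[0, 0] = 1, matrix[1, 1] = 1, matrix[2, 1] = 1, matrix[0, 2] = 1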

summary() computes the individual metrics (overall accuracy, then per-class precision, recall and specificity).

Finally, plot() prints and draws the confusion matrix.

5. Testing: computing the confusion matrix

Here the ResNet34 transfer-learning model from earlier is used, evaluated on the CIFAR10 dataset.

First create the ConfusionMatrix object. In the main script, the commented-out lines define the class names by hand, while the lines below them read the names from a json file.

Note that the first argument to ConfusionMatrix is the size of the matrix, i.e. the number of classes; labels is a list holding the class names.
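The contents of class_indices.json are not shown in the post; for CIFAR10 it is presumably an index-to-name mapping, so after json.load the dict would look roughly like this (class names taken from the commented-out tuple in the code):

class_indict = {"0": "plane", "1": "car", "2": "bird", "3": "cat", "4": "deer",
                "5": "dog", "6": "frog", "7": "horse", "8": "ship", "9": "truck"}
labels = [label for _, label in class_indict.items()]
# labels == ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']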

Then update and plot the confusion matrix.

6. Results

The confusion matrix (figure in the original post):

Console output (figure in the original post):

Looking at the output, the recall of each class is exactly the corresponding diagonal value divided by 1000.

This is easy to explain: recall = TP / (TP + FN), and the denominator is simply the number of test samples with that label. The CIFAR10 test set contains 10,000 images spread evenly over 10 classes, exactly 1,000 per class, so every recall has 1,000 in the denominator.

Recall, again, is the proportion of truly positive samples that are predicted as positive; the small sketch below reads it directly off the matrix.
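A small sketch of that observation (assuming confusion is the ConfusionMatrix accumulated over the CIFAR10 test set by the main script below):

import numpy as np

# each column sums to 1000 because every CIFAR10 class has exactly 1000 test images
col_sums = confusion.matrix.sum(axis=0)                  # TP + FN per class
recall_per_class = np.diag(confusion.matrix) / col_sums  # here this equals diagonal / 1000
print(col_sums)            # [1000. 1000. ... 1000.]
print(recall_per_class)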

Since plt.show() blocks, the performance-metric table is printed only after the confusion-matrix figure window is closed.

7. Code

The confusion-matrix class lives in utils; the utils code:

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import matplotlib.pyplot as plt
import numpy as np
from prettytable import PrettyTable

# confusion matrix
class ConfusionMatrix(object):
    def __init__(self, num_classes: int, labels: list):
        self.matrix = np.zeros((num_classes, num_classes))  # initialize the confusion matrix
        self.num_classes = num_classes
        self.labels = labels

    def update(self, preds, labels):  # accumulate counts into the matrix
        for p, t in zip(preds, labels):
            self.matrix[p, t] += 1  # row = prediction, column = ground truth

    def summary(self):  # compute the metrics
        # calculate accuracy
        sum_TP = 0
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]  # sum of the diagonal
        acc = sum_TP / np.sum(self.matrix)  # divided by the sum of the whole matrix
        print("the model accuracy is ", acc)
        # precision, recall, specificity
        table = PrettyTable()
        table.field_names = ["", "Precision", "Recall", "Specificity"]  # table header
        for i in range(self.num_classes):
            TP = self.matrix[i, i]                   # true label i, predicted i
            FP = np.sum(self.matrix[i, :]) - TP      # predicted i, true label is another class
            FN = np.sum(self.matrix[:, i]) - TP      # true label i, predicted as another class
            TN = np.sum(self.matrix) - TP - FP - FN  # neither true label nor prediction is i
            Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0.
            Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0.
            Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0.
            table.add_row([self.labels[i], Precision, Recall, Specificity])
        print(table)

    def plot(self):
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)
        plt.xticks(range(self.num_classes), self.labels, rotation=45)  # x-axis tick labels
        plt.yticks(range(self.num_classes), self.labels)  # y-axis tick labels
        plt.colorbar()  # show the colorbar
        plt.xlabel('True Labels')
        plt.ylabel('Predicted Labels')
        plt.title('Confusion matrix')
        # annotate each cell with its count
        thresh = matrix.max() / 2
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # note matrix[y, x] here, not matrix[x, y]: rows (predictions) are drawn on the y axis
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        plt.show()

The network model (model), i.e. the ResNet code:

import torch
import torch.nn as nn

# residual block for ResNet-18/34
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # the first conv of a stage may halve the spatial size, in which case stride = 2
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:  # dashed shortcut: a 1x1 conv adjusts the size and doubles the channels
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out

# bottleneck block for ResNet-50/101/152
class Bottleneck(nn.Module):
    expansion = 4  # the last conv expands the channels by a factor of 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 conv reduces the channels; padding defaults to 0, size unchanged
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        # 3x3 conv; padding=1 keeps the spatial size compatible with the shortcut addition
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # 1x1 conv expands the channels back by `expansion`
        self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        # other
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out += identity
        out = self.relu(out)
        return out

# resnet
class ResNet(nn.Module):
    def __init__(self, block, block_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64  # channel depth after the max pooling
        # stem: the input is an RGB image; the 7x7 conv halves the size and outputs 64 channels
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # halves the size, padding = 1
        self.layer1 = self.__make_layer(block, 64, block_num[0])             # conv2_x
        self.layer2 = self.__make_layer(block, 128, block_num[1], stride=2)  # conv3_x
        self.layer3 = self.__make_layer(block, 256, block_num[2], stride=2)  # conv4_x
        self.layer4 = self.__make_layer(block, 512, block_num[3], stride=2)  # conv5_x
        if self.include_top:  # classification head
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size 1x1
            self.fc = nn.Linear(512 * block.expansion, num_classes)

    def __make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:  # dashed shortcut: 1x1 conv matches the dimensions
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion)
            )
        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion
        for _ in range(1, block_num):  # remaining blocks use the solid (identity) shortcut
            layers.append(block(self.in_channel, channel))
        return nn.Sequential(*layers)

    def forward(self, x):
        # stem convolutions
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        # residual feature-extraction stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # classification
        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, start_dim=1)
            x = self.fc(x)
        return x

# model constructors
def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

def resnet101(num_classes=1000, include_top=True):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)
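
As a quick shape check for the model (a minimal sketch; the 224×224 input size matches the transforms used in the main script below):

net = resnet34(num_classes=10)
x = torch.randn(1, 3, 224, 224)   # dummy batch with one RGB image
print(net(x).shape)               # expected: torch.Size([1, 10])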

The main script:

import torch
from torchvision import transforms, datasets
from tqdm import tqdm
from model import resnet34
from utils import ConfusionMatrix
import json

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    # load the data
    validate_dataset = datasets.CIFAR10(root='./data', train=False, transform=data_transform)
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=16, shuffle=True)
    # load the network
    net = resnet34(num_classes=10)
    model_weight_path = "./resnet.pth"
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    net.to(device)
    # class names written by hand
    # classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    # labels = [label for label in classes]
    # confusion = ConfusionMatrix(num_classes=10, labels=labels)
    # class names read from the json file
    json_label_path = './class_indices.json'
    json_file = open(json_label_path, 'r')
    class_indict = json.load(json_file)
    labels = [label for _, label in class_indict.items()]
    confusion = ConfusionMatrix(num_classes=10, labels=labels)
    net.eval()
    with torch.no_grad():
        for val_data in tqdm(validate_loader):
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            outputs = torch.softmax(outputs, dim=1)
            outputs = torch.argmax(outputs, dim=1)
            confusion.update(outputs.to("cpu").numpy(), val_labels.to("cpu").numpy())  # accumulate into the confusion matrix
    confusion.plot()     # plot the confusion matrix
    confusion.summary()  # compute the metrics

This article is reposted from: https://blog.csdn.net/qq_44886601/article/details/129952744
Copyright belongs to the original author 听风吹等浪起. If there is any infringement, please contact us for removal.
