

Implementing the VGG16 Network and Transfer Learning with Keras

VGGnet

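VGGnet is a deep learning network developed in 2014 by the University of Oxford and DeepMind. It grew out of AlexNet. Its best-known variants are VGG16 and VGG19, which are still widely used for image feature extraction.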

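The structures of VGG11 through VGG19 are shown below. The number after "VGG" is the number of weight layers in the network: convolutional plus fully connected layers, with pooling layers not counted.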
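You can check this naming rule with a small sketch. It writes the convolutional part of VGG11, VGG13, VGG16 and VGG19 (configurations A, B, D and E of the VGG paper) as lists. In each list, a number is the filter count of a 3×3 convolution and `"M"` is a 2×2 max pooling. `VGG_CONFIGS` and `weight_layers` are illustrative names, not part of Keras. Adding the three fully connected layers to the convolution count gives the number in each name.

```python
# Convolutional stacks of the VGG variants (configurations A, B, D, E of the paper).
# Each integer is the filter count of a 3x3 conv layer; "M" is a 2x2 max pooling.
VGG_CONFIGS = {
    "vgg11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "vgg13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "vgg16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
              512, 512, 512, "M", 512, 512, 512, "M"],
    "vgg19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
              512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}

NUM_FC_LAYERS = 3  # FC-4096, FC-4096, FC-1000 in every variant


def weight_layers(config):
    """Count layers with weights: conv layers plus the fully connected layers."""
    conv_layers = sum(1 for item in config if item != "M")
    return conv_layers + NUM_FC_LAYERS


for name, config in VGG_CONFIGS.items():
    print(name, weight_layers(config))
```

All four variants share the same five pooling stages. They differ only in how many convolutions sit in each stage.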

Defining VGG16 in Keras

VGG16 architecture:

The input in the figure above is 224×224. You can change it to fit your own data.
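Every convolution here uses `padding='same'` with stride 1, so only the five 2×2 max-pooling layers change the spatial size, each one halving it. A minimal sketch under that assumption (the helper names are hypothetical) shows what the choice of input size means. A 224×224 input ends at a 7×7×512 feature map, which is 25,088 values after flattening. A 32×32 CIFAR-10 image ends at 1×1×512.

```python
def final_feature_size(input_size, num_pools=5):
    """Spatial size after num_pools 2x2/stride-2 max poolings ('same' convs keep size)."""
    size = input_size
    for _ in range(num_pools):
        size //= 2  # MaxPool2D with default pool_size=(2, 2) and padding='valid'
    return size


def flattened_length(input_size, channels=512):
    """Number of values the Flatten layer receives after the last pooling."""
    side = final_feature_size(input_size)
    return side * side * channels


print(final_feature_size(224), flattened_length(224))  # 7 25088
print(final_feature_size(32), flattened_length(32))    # 1 512
```

This also explains why, in the transfer-learning model later, the `Flatten` layer gets a (1, 1, 512) input shape from `vgg_model.output_shape[1:]`.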

Keras makes it easy to build deep learning models.

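The VGG16 model below is defined from the figure above. It differs from the standard VGG16 (configuration D in the VGG paper) in two places. First, the last convolution of blocks 3, 4 and 5 uses a 1×1 kernel, as in configuration C. Second, an extra 1000-unit ReLU layer with dropout sits before the softmax output.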

  from keras.models import Sequential
  from keras.layers import Conv2D, Dense, Flatten, Dropout, MaxPool2D

  def vgg16_model():
      model = Sequential()
      # Block 1: 224x224x64 ('same' padding keeps the spatial size)
      model.add(Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), input_shape=(224, 224, 3), padding='same', activation='relu', kernel_initializer='glorot_uniform'))
      model.add(Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same', activation='relu', kernel_initializer='glorot_uniform'))
      model.add(MaxPool2D())  # pool_size=(2, 2), strides=(2, 2)
      # Block 2: 112x112x128
      model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu', kernel_initializer='glorot_uniform'))
      model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu', kernel_initializer='glorot_uniform'))
      model.add(MaxPool2D())
      # Block 3: 56x56x256
      model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu', kernel_initializer='glorot_uniform'))
      model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu', kernel_initializer='glorot_uniform'))
      model.add(Conv2D(256, (1, 1), strides=(1, 1), padding='same', activation='relu', kernel_initializer='glorot_uniform'))  # 1x1 kernel (configuration C)
      model.add(MaxPool2D())
      # Block 4: 28x28x512
      model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu', kernel_initializer='glorot_uniform'))
      model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu', kernel_initializer='glorot_uniform'))
      model.add(Conv2D(512, (1, 1), strides=(1, 1), padding='same', activation='relu', kernel_initializer='glorot_uniform'))  # 1x1 kernel (configuration C)
      model.add(MaxPool2D())
      # Block 5: 14x14x512
      model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu', kernel_initializer='glorot_uniform'))
      model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu', kernel_initializer='glorot_uniform'))
      model.add(Conv2D(512, (1, 1), strides=(1, 1), padding='same', activation='relu', kernel_initializer='glorot_uniform'))  # 1x1 kernel (configuration C)
      model.add(MaxPool2D())
      # Classifier: flatten the 7x7x512 feature map, then fully connected layers
      model.add(Flatten())
      model.add(Dense(4096, activation='relu'))
      model.add(Dropout(0.5))
      model.add(Dense(4096, activation='relu'))
      model.add(Dropout(0.5))
      model.add(Dense(1000, activation='relu'))  # extra layer, not in the standard VGG16
      model.add(Dropout(0.5))
      model.add(Dense(1000, activation='softmax'))  # 1000 output classes
      return model
  model = vgg16_model()
  model.summary()
  _________________________________________________________________
  Layer (type)                     Output Shape              Param #
  =================================================================
  conv2d_146 (Conv2D)              (None, 224, 224, 64)      1792
  conv2d_147 (Conv2D)              (None, 224, 224, 64)      36928
  max_pooling2d_56 (MaxPooling2D)  (None, 112, 112, 64)      0
  conv2d_148 (Conv2D)              (None, 112, 112, 128)     73856
  conv2d_149 (Conv2D)              (None, 112, 112, 128)     147584
  max_pooling2d_57 (MaxPooling2D)  (None, 56, 56, 128)       0
  conv2d_150 (Conv2D)              (None, 56, 56, 256)       295168
  conv2d_151 (Conv2D)              (None, 56, 56, 256)       590080
  conv2d_152 (Conv2D)              (None, 56, 56, 256)       65792
  max_pooling2d_58 (MaxPooling2D)  (None, 28, 28, 256)       0
  ...
  Total params: 134,639,952
  Trainable params: 134,639,952
  Non-trainable params: 0
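You can check the total in the summary by hand. A k×k convolution with c_in input channels and c_out output channels has k·k·c_in·c_out weights plus c_out biases. A Dense layer has n_in·n_out weights plus n_out biases. The sketch below applies these formulas to the layers of `vgg16_model()` above. For comparison, it also applies them to the standard configuration D, which uses 3×3 kernels throughout and has no extra 1000-unit layer. That version should reproduce the 138,357,544 parameters usually quoted for VGG16 with its top. The helper names are illustrative.

```python
def conv_params(k, c_in, c_out):
    return k * k * c_in * c_out + c_out  # weights + biases


def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases


def total_params(conv_layers, dense_units, input_channels=3, flat_dim=7 * 7 * 512):
    """conv_layers: list of (kernel_size, filters); dense_units: Dense widths in order.
    Pooling, Flatten and Dropout layers have no parameters."""
    total, c_in = 0, input_channels
    for k, c_out in conv_layers:
        total += conv_params(k, c_in, c_out)
        c_in = c_out
    n_in = flat_dim
    for n_out in dense_units:
        total += dense_params(n_in, n_out)
        n_in = n_out
    return total


# vgg16_model() in this post: 1x1 kernels at the end of blocks 3-5,
# plus an extra Dense(1000) before the softmax layer.
post_convs = [(3, 64), (3, 64), (3, 128), (3, 128),
              (3, 256), (3, 256), (1, 256),
              (3, 512), (3, 512), (1, 512),
              (3, 512), (3, 512), (1, 512)]
post_dense = [4096, 4096, 1000, 1000]

# Standard VGG16 (configuration D): 3x3 kernels everywhere, three Dense layers.
std_convs = [(3, 64), (3, 64), (3, 128), (3, 128)] + [(3, 256)] * 3 + [(3, 512)] * 6
std_dense = [4096, 4096, 1000]

print(f"{total_params(post_convs, post_dense):,}")  # 134,639,952
print(f"{total_params(std_convs, std_dense):,}")    # 138,357,544
```

Most of the parameters sit in the first Dense layer: 25,088 × 4,096 weights alone is over 102 million.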

Transfer Learning with VGG16 in Keras

Training data:

CIFAR-10 is a small dataset for recognizing everyday objects, collected by Alex Krizhevsky, Vinod Nair and Geoffrey Hinton. It contains RGB color images in 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship and truck. Each image is 32×32 pixels. The dataset has 50,000 training images and 10,000 test images.

keras.applications:

Provides many deep learning models with pretrained weights.

livelossplot:

A library for plotting loss curves. It updates the plots live during training.

  from keras.datasets import cifar10
  from keras.utils import to_categorical
  from keras import applications
  from keras.models import Sequential, Model
  from keras.layers import Dense, Dropout, Flatten
  from keras import optimizers

  # Load the CIFAR-10 dataset
  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  # Scale pixel values from [0, 255] to [0, 1]
  x_train = x_train.astype('float32') / 255
  x_test = x_test.astype('float32') / 255
  # One-hot encode the labels over the 10 classes
  y_train = to_categorical(y_train, 10)
  y_test = to_categorical(y_test, 10)
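`to_categorical` turns each integer label into a length-10 vector that contains a single 1. Below is a dependency-free sketch of the same mapping for a flat list of integer labels. `one_hot` is a hypothetical helper, not a Keras function. The real `to_categorical` also accepts the (n, 1) label arrays that `cifar10.load_data()` returns.

```python
def one_hot(labels, num_classes):
    """Return one row per label with 1.0 at the label's index and 0.0 elsewhere."""
    vectors = []
    for label in labels:
        row = [0.0] * num_classes
        row[label] = 1.0
        vectors.append(row)
    return vectors


# CIFAR-10 labels are integers 0-9, e.g. 3 = cat, 9 = truck
print(one_hot([3, 9], 10))
```

One-hot targets are what `loss='categorical_crossentropy'` expects. With integer labels you would use `sparse_categorical_crossentropy` instead.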

  # Load Keras's VGG16 with ImageNet weights, without its fully connected top
  vgg_model = applications.VGG16(weights='imagenet', include_top=False, input_shape=(32, 32, 3))
  # Freeze the first 15 layers; the rest keep training
  for layer in vgg_model.layers[:15]:
      layer.trainable = False
  # Define the new classifier layers for the 10 CIFAR-10 classes
  top_model = Sequential()
  top_model.add(Flatten(input_shape=vgg_model.output_shape[1:]))
  top_model.add(Dense(32, activation='relu'))
  top_model.add(Dropout(0.5))
  top_model.add(Dense(10, activation='softmax'))
  # Combine the pretrained base and the new layers into one model
  model = Model(
      inputs=vgg_model.input,
      outputs=top_model(vgg_model.output))
  model.compile(
      loss='categorical_crossentropy',
      optimizer=optimizers.Adam(learning_rate=0.0001),
      metrics=['accuracy'])
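To see which layers `layers[:15]` covers, the sketch below rebuilds the layer-name order of `applications.VGG16(include_top=False)` and applies the same slice. It assumes, as in tf.keras, that `vgg_model.layers` starts with the input layer. That layer's name is generated automatically and varies, so `"input_layer"` here is a placeholder. Under this assumption, the frozen part runs from the input through `block4_pool`. The three convolutions of block 5 stay trainable; the pooling layer has no weights either way.

```python
# Layer order of keras.applications.VGG16(include_top=False), as an assumption;
# "input_layer" is a placeholder for the auto-generated input layer name.
VGG16_NO_TOP_LAYERS = ["input_layer"]
for block, n_convs in enumerate([2, 2, 3, 3, 3], start=1):
    VGG16_NO_TOP_LAYERS += [f"block{block}_conv{i}" for i in range(1, n_convs + 1)]
    VGG16_NO_TOP_LAYERS.append(f"block{block}_pool")

frozen = VGG16_NO_TOP_LAYERS[:15]     # same slice as vgg_model.layers[:15]
trainable = VGG16_NO_TOP_LAYERS[15:]

print(len(VGG16_NO_TOP_LAYERS))  # 19
print(frozen[-1])                # block4_pool
print(trainable)                 # ['block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_pool']
```

On a real model, `[layer.name for layer in vgg_model.layers]` prints the actual list to compare against.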

  # Train while plotting the loss live
  from livelossplot import PlotLossesKeras
  plotlosses = PlotLossesKeras()
  model.fit(x_train, y_train, epochs=5, verbose=1, validation_split=0.1,
            batch_size=32, callbacks=[plotlosses])
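`validation_split=0.1` tells Keras to hold out the last 10% of `x_train` and `y_train`, taken before any shuffling, as validation data. With the 50,000 CIFAR-10 training images, each epoch therefore trains on 45,000 images. Here is the arithmetic as a small sketch, using a hypothetical helper `split_sizes`:

```python
import math


def split_sizes(num_samples, validation_split):
    """Return (training samples, validation samples); validation is the tail."""
    train = int(num_samples * (1 - validation_split))
    return train, num_samples - train


train_n, val_n = split_sizes(50000, 0.1)
steps_per_epoch = math.ceil(train_n / 32)  # 45000 = 1406 * 32 + 8, so the last batch has 8

print(train_n, val_n, steps_per_epoch)  # 45000 5000 1407
```

The 1407 matches the step count the Keras progress bar shows per epoch with `batch_size=32`.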

Loss and classification accuracy during training:

  # Loss and accuracy on the test set
  test_loss, test_acc = model.evaluate(x_test, y_test)
  print('test_loss: {:.4f}  test_acc: {:.2f}%'.format(test_loss, test_acc * 100))
  import os
  # Save the model (create the target directory first)
  os.makedirs('trained', exist_ok=True)
  model.save('trained/cifar10.h5')

Predicting a new image: recognizing a cat

Use the trained model to classify a photo of a cat.

  from keras.models import load_model
  from keras.utils import load_img, img_to_array
  from PIL import Image
  import matplotlib.pyplot as plt
  import numpy as np

  # Load the trained model
  model = load_model("trained/cifar10.h5")
  # Preprocess the image the same way as the training data
  path = 'data/my_cat.jpg'
  img_height, img_width = 32, 32
  x = load_img(path=path, target_size=(img_height, img_width))
  x = img_to_array(x)            # shape (32, 32, 3), values in [0, 255]
  x = x.astype('float32') / 255  # same scaling as in training
  x = x[None]                    # add a batch dimension: (1, 32, 32, 3)
  # Predict
  y = model.predict(x)
  labels = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
  # The outputs are softmax probabilities, so take the index of the largest one
  result = np.argmax(y, axis=1)
  print("This is a {}".format(labels[result[0]]))
  # Show the original image
  img = Image.open(path)
  # img.show() would open the system image viewer instead
  plt.figure(dpi=120)
  plt.imshow(img)
  plt.axis('off')
  plt.show()
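Softmax outputs are probabilities that are almost never exactly 1.0. The predicted class is therefore the index of the largest probability (an argmax), not the position of a 1. Below is a pure-Python sketch with a made-up probability vector; the numbers are illustrative, not real model output.

```python
LABELS = ["airplane", "automobile", "bird", "cat", "deer",
          "dog", "frog", "horse", "ship", "truck"]


def predicted_label(probabilities, labels=LABELS):
    """Return the label with the largest probability (same idea as np.argmax)."""
    best_index = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return labels[best_index]


# Hypothetical softmax output for one image: no entry equals exactly 1.0
probs = [0.01, 0.00, 0.02, 0.81, 0.01, 0.12, 0.01, 0.01, 0.00, 0.01]
print(predicted_label(probs))  # cat
print(1.0 in probs)            # False, so probs.index(1) would raise ValueError
```

The probability itself, 0.81 here, can also be printed as a rough confidence score next to the label.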

It correctly recognized my cat!


Reposted from: https://blog.csdn.net/skycol/article/details/127152048
Copyright belongs to the original author, skycol. In case of infringement, please contact us for removal.
