0


机器学习——从0开始构建自己的GAN网络

一 前言

本文仅作为经验分享以及学习记录,如有问题,可以在评论区和我讨论。

具体的理论知识暂且不讲,待有时间了我就会慢慢分享理论知识,目前就整点干货,直接上代码,怎么从零开始构建自己的GAN网络。

本项目Github地址

二 生成式对抗网络GAN

生成对抗网络(GAN)有两个部分:生成网络G(Generator)和判别网络D(Discriminator)。
1生成网络G:用来生成图片的网络,它接收一个随机的噪声noise,通过这个噪声生成图片。
2判别网络D:用来判别图片是否真实的网络。它的输入是一张图片img,输出是img为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

生成网络G的目的是努力生成一个图片来骗过判别网络D,判别网络D的目的是努力鉴别出生成出来的图片是假的。两个网络在不断博弈中互相进步,达到理想状态:D(G(noise))=0.5(即判别网络D也不确定是到底是不是真实的)

三 GAN的训练思路

GAN的训练要同时训练两个网络,我们使用的方法是:单独交替迭代训练(即训练一个网络的时候,固定住一个网络,去训练另一个网络)

这样做的目的是防止其中一个网络比另一个网络强大太多,导致网络性能弱化。在整个训练过程中,两个网络不断变强,达到理想状态。

四 数据集——Chinese MNIST

我的数据集选的是Kaggle网站上的Chinese MNIST,下载地址

下载速度慢的可以参考我的另一篇博客——解决Kaggle网站下载数据集速度慢,不方便下载的可以联系我发给你压缩包。

数据集举例:

五 代码——python

如果需要替换成自己数据集,我会在每部分代码首部进行特别说明。

这里我们直接开始,直接上代码,通过代码,一方面有助于我梳理本次学习思路,二是我觉得这样更直接明了一些,毕竟动手才有趣。

1.文件展示

2.代码(一) ——数据预处理

这部分函数用来加载path路径下的文件,即我的数据集,也可以根据你们需求换成别的数据集。只需要更改自己的数据集文件夹即可。

  1. def load_data(self, path):
  2. print("loading images...")
  3. data = []
  4. labels = []
  5. imagePaths = sorted(list(paths.list_images(path)))
  6. random.seed(42)
  7. random.shuffle(imagePaths)
  8. for imagePath in imagePaths:
  9. image = cv2.imread(imagePath)
  10. image = cv2.resize(image, (self.img_rows, self.img_cols))
  11. image = img_to_array(image)
  12. data.append(image)
  13. label = str(imagePath.split(os.path.sep)[-2])
  14. labels.append(label)
  15. data = np.array(data, dtype="float") / 255.0
  16. return data

3.代码(二) ——生成器的构建

这部分代码用来构建生成网络G,不需要更改,尽管网络性能不是很好,但不是必须修改的。

  1. # 构建生成器
  2. def build_generator(self):
  3. model = Sequential() # 模型选用的是传统的线性模型
  4. model.add(Dense(256, input_dim=self.latent_dim)) # 全连接层
  5. model.add(LeakyReLU(alpha=0.2)) # 带泄露修正线性单元
  6. model.add(BatchNormalization(momentum=0.8)) # 批归一化
  7. model.add(Dense(512))
  8. model.add(LeakyReLU(alpha=0.2))
  9. model.add(BatchNormalization(momentum=0.8))
  10. model.add(Dense(1024))
  11. model.add(LeakyReLU(alpha=0.2))
  12. model.add(BatchNormalization(momentum=0.8))
  13. model.add(Dense(np.prod(self.img_shape), activation='tanh')) # np.prod()计算所有乘积,输入
  14. model.add(Reshape(self.img_shape)) # reshape成图片的尺寸
  15. # model.summary() # 日志
  16. noise = Input(shape=(self.latent_dim,))
  17. img = model(noise)
  18. return Model(noise, img)

4.代码(三) ——判别器的构建

这部分代码用来构建判别网络D,不需要更改,尽管网络性能不是很好,但不是必须修改的。

  1. # 构建判别器
  2. def build_discriminator(self):
  3. # 模型选用的是传统的线性模型,CNN中用的也是这个
  4. model = Sequential()
  5. model.add(Flatten(input_shape=self.img_shape)) # 展平层
  6. model.add(Dense(512)) # 全连接层
  7. model.add(LeakyReLU(alpha=0.2)) # 带泄露修正线性单元
  8. model.add(Dense(256))
  9. model.add(LeakyReLU(alpha=0.2))
  10. model.add(Dense(1, activation='sigmoid'))
  11. # model.summary()
  12. img = Input(shape=self.img_shape) # 输入尺寸
  13. validity = model(img)
  14. return Model(img, validity)

5.代码(四) ——图像的储存

这部分代码用来储存生成网络不同epoch的输出,不必更改。

  1. def sample_images(self, epoch):
  2. r, c = 5, 5
  3. noise = np.random.normal(0, 1, (r * c, self.latent_dim))
  4. gen_imgs = self.generator.predict(noise)
  5. gen_imgs = 0.5 * gen_imgs + 0.5
  6. fig, axs = plt.subplots(r, c)
  7. cnt = 0
  8. for i in range(r):
  9. for j in range(c):
  10. axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
  11. axs[i, j].axis('off')
  12. cnt += 1
  13. # 保存地址为:"images/"
  14. fig.savefig("images/%d.png" % epoch)
  15. plt.close()

6.代码(五) ——网络的训练

这部分代码用来训练网络,不必更改。

  1. def train(self, epochs, batch_size=128, sample_interval=50, file_path=None):
  2. # 加载数据
  3. X_train = self.load_data(file_path)
  4. # 标准化
  5. # X_train = np.expand_dims(X_train, axis=3)
  6. # 创建标签
  7. valid = np.ones((batch_size, 1))
  8. fake = np.zeros((batch_size, 1))
  9. for epoch in range(epochs):
  10. idx = np.random.randint(0, X_train.shape[0], batch_size)
  11. imgs = X_train[idx]
  12. noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
  13. gen_imgs = self.generator.predict(noise)
  14. d_loss_real = self.discriminator.train_on_batch(imgs, valid)
  15. d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
  16. d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
  17. noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
  18. g_loss = self.combined.train_on_batch(noise, valid)
  19. if epoch % 200 == 0:
  20. print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
  21. # 图像的保存,每sample_interval次保存图片一次
  22. if epoch % sample_interval == 0:
  23. self.sample_images(epoch)
  24. # 模型权重的保存,每2000个epoch,保存一次模型,保存地址为"weights/"
  25. if epoch % 2000 == 0:
  26. os.makedirs('weights', exist_ok=True)
  27. self.generator.save_weights("weights/gen_epoch%d.h5" % epoch)
  28. self.discriminator.save_weights("weights/dis_epoch%d.h5" % epoch)

7.代码(六) ——网络参数的定义

这部分代码定义了网络的一些参数,比如输入尺寸(我的数据集图片大小是[64,64,3]),优化器等等。

需要根据自己的数据集图片的大小,更改self.img_rows、self.img_cols、self.channels

  1. def __init__(self):
  2. # 图片尺寸 在这里更改!!!!
  3. self.img_rows = 64
  4. self.img_cols = 64
  5. self.channels = 3
  6. # 输入的图片尺寸
  7. self.img_shape = (self.img_rows, self.img_cols, self.channels)
  8. self.latent_dim = 100
  9. # Adam优化器
  10. optimizer = Adam(0.0002, 0.5)
  11. self.discriminator = self.build_discriminator()
  12. self.discriminator.compile(loss='binary_crossentropy',
  13. optimizer=optimizer,
  14. metrics=['accuracy'])
  15. self.generator = self.build_generator()
  16. z = Input(shape=(self.latent_dim,))
  17. img = self.generator(z)
  18. self.discriminator.trainable = False
  19. validity = self.discriminator(img)
  20. self.combined = Model(z, validity)
  21. self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

8.完整代码

需要根据自己需求,更改

epochs训练次数;batch_size每组的数量;sample_interval多少次输出一张图片;file_path数据集路径

完整代码我已上传到github中,代码地址

六 运行效果

迭代0次

迭代10000次

迭代30000次

迭代50000次

由于时间以及本人显卡配置的限制,只进行了50000次迭代,为了更好的效果可以增加迭代次数。

总结

至此本博客,从0开始搭建GAN网络就结束了,有什么问题欢迎和我讨论。


本文转载自: https://blog.csdn.net/qq_52222102/article/details/125126575
版权归原作者 秦天宝. 所有, 如有侵权,请联系我们删除。

“机器学习——从0开始构建自己的GAN网络”的评论:

还没有评论