一 前言
本文仅作为经验分享以及学习记录,如有问题,可以在评论区和我讨论。
具体的理论知识暂且不讲,待有时间了我就会慢慢分享理论知识,目前就整点干货,直接上代码,怎么从零开始构建自己的GAN网络。
本项目Github地址
二 生成式对抗网络GAN
生成对抗网络(GAN)有两个部分:生成网络G(Generator)和判别网络D(Discriminator)。
(1)生成网络G:用来生成图片的网络,它接收一个随机的噪声noise,通过这个噪声生成图片。
(2)判别网络D:用来判别图片是否真实的网络。它的输入是一张图片img,输出是img为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。
生成网络G的目的是努力生成一个图片来骗过判别网络D,判别网络D的目的是努力鉴别出生成出来的图片是假的。两个网络在不断博弈中互相进步,达到理想状态:D(G(noise))=0.5(即判别网络D也不确定是到底是不是真实的)
三 GAN的训练思路
GAN的训练要同时训练两个网络,我们使用的方法是:单独交替迭代训练(即训练一个网络的时候,固定住一个网络,去训练另一个网络)
这样做的目的是防止其中一个网络比另一个网络强大太多,导致网络性能弱化。在整个训练过程中,两个网络不断变强,达到理想状态。
四 数据集——Chinese MNIST
我的数据集选的是Kaggle网站上的Chinese MNIST,下载地址
下载速度慢的可以参考我的另一篇博客——解决Kaggle网站下载数据集速度慢,不方便下载的可以联系我发给你压缩包。
数据集举例:
五 代码——python
如果需要替换成自己数据集,我会在每部分代码首部进行特别说明。
这里我们直接开始,直接上代码,通过代码,一方面有助于我梳理本次学习思路,二是我觉得这样更直接明了一些,毕竟动手才有趣。
1.文件展示
2.代码(一) ——数据预处理
这部分函数用来加载path路径下的文件,即我的数据集,也可以根据你们需求换成别的数据集。只需要更改自己的数据集文件夹即可。
def load_data(self, path):
print("loading images...")
data = []
labels = []
imagePaths = sorted(list(paths.list_images(path)))
random.seed(42)
random.shuffle(imagePaths)
for imagePath in imagePaths:
image = cv2.imread(imagePath)
image = cv2.resize(image, (self.img_rows, self.img_cols))
image = img_to_array(image)
data.append(image)
label = str(imagePath.split(os.path.sep)[-2])
labels.append(label)
data = np.array(data, dtype="float") / 255.0
return data
3.代码(二) ——生成器的构建
这部分代码用来构建生成网络G,不需要更改,尽管网络性能不是很好,但不是必须修改的。
# 构建生成器
def build_generator(self):
model = Sequential() # 模型选用的是传统的线性模型
model.add(Dense(256, input_dim=self.latent_dim)) # 全连接层
model.add(LeakyReLU(alpha=0.2)) # 带泄露修正线性单元
model.add(BatchNormalization(momentum=0.8)) # 批归一化
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh')) # np.prod()计算所有乘积,输入
model.add(Reshape(self.img_shape)) # reshape成图片的尺寸
# model.summary() # 日志
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
4.代码(三) ——判别器的构建
这部分代码用来构建判别网络D,不需要更改,尽管网络性能不是很好,但不是必须修改的。
# 构建判别器
def build_discriminator(self):
# 模型选用的是传统的线性模型,CNN中用的也是这个
model = Sequential()
model.add(Flatten(input_shape=self.img_shape)) # 展平层
model.add(Dense(512)) # 全连接层
model.add(LeakyReLU(alpha=0.2)) # 带泄露修正线性单元
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
# model.summary()
img = Input(shape=self.img_shape) # 输入尺寸
validity = model(img)
return Model(img, validity)
5.代码(四) ——图像的储存
这部分代码用来储存生成网络不同epoch的输出,不必更改。
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
gen_imgs = 0.5 * gen_imgs + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
axs[i, j].axis('off')
cnt += 1
# 保存地址为:"images/"
fig.savefig("images/%d.png" % epoch)
plt.close()
6.代码(五) ——网络的训练
这部分代码用来训练网络,不必更改。
def train(self, epochs, batch_size=128, sample_interval=50, file_path=None):
# 加载数据
X_train = self.load_data(file_path)
# 标准化
# X_train = np.expand_dims(X_train, axis=3)
# 创建标签
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
for epoch in range(epochs):
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
gen_imgs = self.generator.predict(noise)
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
g_loss = self.combined.train_on_batch(noise, valid)
if epoch % 200 == 0:
print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
# 图像的保存,每sample_interval次保存图片一次
if epoch % sample_interval == 0:
self.sample_images(epoch)
# 模型权重的保存,每2000个epoch,保存一次模型,保存地址为"weights/"
if epoch % 2000 == 0:
os.makedirs('weights', exist_ok=True)
self.generator.save_weights("weights/gen_epoch%d.h5" % epoch)
self.discriminator.save_weights("weights/dis_epoch%d.h5" % epoch)
7.代码(六) ——网络参数的定义
这部分代码定义了网络的一些参数,比如输入尺寸(我的数据集图片大小是[64,64,3]),优化器等等。
需要根据自己的数据集图片的大小,更改self.img_rows、self.img_cols、self.channels
def __init__(self):
# 图片尺寸 在这里更改!!!!
self.img_rows = 64
self.img_cols = 64
self.channels = 3
# 输入的图片尺寸
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
# Adam优化器
optimizer = Adam(0.0002, 0.5)
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
self.generator = self.build_generator()
z = Input(shape=(self.latent_dim,))
img = self.generator(z)
self.discriminator.trainable = False
validity = self.discriminator(img)
self.combined = Model(z, validity)
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
8.完整代码
需要根据自己需求,更改
epochs训练次数;batch_size每组的数量;sample_interval多少次输出一张图片;file_path数据集路径
完整代码我已上传到github中,代码地址
六 运行效果
迭代0次
迭代10000次
迭代30000次
迭代50000次
由于时间以及本人显卡配置的限制,只进行了50000次迭代,为了更好的效果可以增加迭代次数。
总结
至此本博客,从0开始搭建GAN网络就结束了,有什么问题欢迎和我讨论。
版权归原作者 秦天宝. 所有, 如有侵权,请联系我们删除。