0


机器学习——从0开始构建自己的GAN网络

一 前言

本文仅作为经验分享以及学习记录,如有问题,可以在评论区和我讨论。

具体的理论知识暂且不讲,待有时间了我就会慢慢分享理论知识,目前就整点干货,直接上代码,怎么从零开始构建自己的GAN网络。

本项目Github地址

二 生成式对抗网络GAN

生成对抗网络(GAN)有两个部分:生成网络G(Generator)和判别网络D(Discriminator)。
1生成网络G:用来生成图片的网络,它接收一个随机的噪声noise,通过这个噪声生成图片。
2判别网络D:用来判别图片是否真实的网络。它的输入是一张图片img,输出是img为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。

生成网络G的目的是努力生成一个图片来骗过判别网络D,判别网络D的目的是努力鉴别出生成出来的图片是假的。两个网络在不断博弈中互相进步,达到理想状态:D(G(noise))=0.5(即判别网络D也不确定是到底是不是真实的)

三 GAN的训练思路

GAN的训练要同时训练两个网络,我们使用的方法是:单独交替迭代训练(即训练一个网络的时候,固定住一个网络,去训练另一个网络)

这样做的目的是防止其中一个网络比另一个网络强大太多,导致网络性能弱化。在整个训练过程中,两个网络不断变强,达到理想状态。

四 数据集——Chinese MNIST

我的数据集选的是Kaggle网站上的Chinese MNIST,下载地址

下载速度慢的可以参考我的另一篇博客——解决Kaggle网站下载数据集速度慢,不方便下载的可以联系我发给你压缩包。

数据集举例:

五 代码——python

如果需要替换成自己数据集,我会在每部分代码首部进行特别说明。

这里我们直接开始,直接上代码,通过代码,一方面有助于我梳理本次学习思路,二是我觉得这样更直接明了一些,毕竟动手才有趣。

1.文件展示

2.代码(一) ——数据预处理

这部分函数用来加载path路径下的文件,即我的数据集,也可以根据你们需求换成别的数据集。只需要更改自己的数据集文件夹即可。

 def load_data(self, path):
        print("loading images...")
        data = []
        labels = []
        imagePaths = sorted(list(paths.list_images(path)))
        random.seed(42)
        random.shuffle(imagePaths)
        for imagePath in imagePaths:
            image = cv2.imread(imagePath)
            image = cv2.resize(image, (self.img_rows, self.img_cols))
            image = img_to_array(image)
            data.append(image)

            label = str(imagePath.split(os.path.sep)[-2])
            labels.append(label)

        data = np.array(data, dtype="float") / 255.0
        return data

3.代码(二) ——生成器的构建

这部分代码用来构建生成网络G,不需要更改,尽管网络性能不是很好,但不是必须修改的。

# 构建生成器
    def build_generator(self):
        model = Sequential()        # 模型选用的是传统的线性模型
        model.add(Dense(256, input_dim=self.latent_dim))  # 全连接层
        model.add(LeakyReLU(alpha=0.2))  # 带泄露修正线性单元
        model.add(BatchNormalization(momentum=0.8))  # 批归一化
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))  # np.prod()计算所有乘积,输入
        model.add(Reshape(self.img_shape))  # reshape成图片的尺寸

        # model.summary()  # 日志

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

4.代码(三) ——判别器的构建

这部分代码用来构建判别网络D,不需要更改,尽管网络性能不是很好,但不是必须修改的。

# 构建判别器
    def build_discriminator(self):
        # 模型选用的是传统的线性模型,CNN中用的也是这个
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))  # 展平层
        model.add(Dense(512))  # 全连接层
        model.add(LeakyReLU(alpha=0.2))  # 带泄露修正线性单元
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        # model.summary()

        img = Input(shape=self.img_shape)  # 输入尺寸
        validity = model(img)

        return Model(img, validity)

5.代码(四) ——图像的储存

这部分代码用来储存生成网络不同epoch的输出,不必更改。

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        # 保存地址为:"images/"
        fig.savefig("images/%d.png" % epoch)
        plt.close()

6.代码(五) ——网络的训练

这部分代码用来训练网络,不必更改。

    def train(self, epochs, batch_size=128, sample_interval=50, file_path=None):

        # 加载数据
        X_train = self.load_data(file_path)
        # 标准化
        # X_train = np.expand_dims(X_train, axis=3)

        # 创建标签
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            g_loss = self.combined.train_on_batch(noise, valid)

            if epoch % 200 == 0:
                print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # 图像的保存,每sample_interval次保存图片一次
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
            # 模型权重的保存,每2000个epoch,保存一次模型,保存地址为"weights/"
            if epoch % 2000 == 0:
                os.makedirs('weights', exist_ok=True)
                self.generator.save_weights("weights/gen_epoch%d.h5" % epoch)
                self.discriminator.save_weights("weights/dis_epoch%d.h5" % epoch)

7.代码(六) ——网络参数的定义

这部分代码定义了网络的一些参数,比如输入尺寸(我的数据集图片大小是[64,64,3]),优化器等等。

需要根据自己的数据集图片的大小,更改self.img_rows、self.img_cols、self.channels

    def __init__(self):
        # 图片尺寸 在这里更改!!!!
        self.img_rows = 64
        self.img_cols = 64
        self.channels = 3
        # 输入的图片尺寸
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        # Adam优化器
        optimizer = Adam(0.0002, 0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        self.generator = self.build_generator()

        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        self.discriminator.trainable = False

        validity = self.discriminator(img)

        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

8.完整代码

需要根据自己需求,更改

epochs训练次数;batch_size每组的数量;sample_interval多少次输出一张图片;file_path数据集路径

完整代码我已上传到github中,代码地址

六 运行效果

迭代0次

迭代10000次

迭代30000次

迭代50000次

由于时间以及本人显卡配置的限制,只进行了50000次迭代,为了更好的效果可以增加迭代次数。

总结

至此本博客,从0开始搭建GAN网络就结束了,有什么问题欢迎和我讨论。


本文转载自: https://blog.csdn.net/qq_52222102/article/details/125126575
版权归原作者 秦天宝. 所有, 如有侵权,请联系我们删除。

“机器学习——从0开始构建自己的GAN网络”的评论:

还没有评论