0


【深度学习实践(八)】生成对抗网络(GAN)之手写数字生成

文章目录

活动地址:CSDN21天学习挑战赛

学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。
热爱写作,愿意让自己成为更好的人…

👉引言💎

在这里插入图片描述
铭记于心🎉✨🎉我唯一知道的,便是我一无所知🎉✨🎉

【深度学习实践(八)】对抗生成网络(GAN)之手写数字生成

一、🌹对抗生成网络

1 定义与背景 :

生成对抗网络GAN是由蒙特利尔大学Ian Goodfellow在2014年提出的机器学习架构,GAN的核心本质是通过对抗训练将随机噪声的分布拉近到真实的数据分布

2 基本结构:

  • GAN本身是一个不断博弈,识别真假的过程,下面通过手写数字生成案例 窥探GAN对抗生成网络的原理及操作流程:

在这里插入图片描述

  • 定义一个模型来作为生成器(图三中蓝色部分Generator),能够输入一个向量,输出手写数字大小的像素图像(生成噪声
  • 定义一个分类器来作为判别器(图三中红色部分Discriminator)用来判别图片是真的还是假的(或者说是来自数据集中的还是生成器中生成的),输入为手写图片,输出为判别图片的标签

并且,既然是神经网络,那么模型就可以根据 外界反馈 自行调整参数,也就是会根据 标签匹配结果进行相应的学习与调整, 训练完成后 可以达到 ** 生成以假乱真的 手写数字图片效果**

二、🌹模型训练

💎1 设置GPU

  • GPU能够为大量数据的运算提供算力支持
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
    
warnings.filterwarnings("ignore")             
plt.rcParams['font.sans-serif']=['SimHei']  
plt.rcParams['axes.unicode_minus']= False    

💎2 构建GAN对抗网络生成器

def build_generator():
    # ======================================= #
    #     生成器,输入一串随机数字生成图片
    # ======================================= #
    model =Sequential([
        layers.Dense(256, input_dim=latent_dim),
        layers.LeakyReLU(alpha=0.2),               # 高级一点的激活函数
        layers.BatchNormalization(momentum=0.8),   # BN 归一化
        
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        layers.Dense(np.prod(img_shape), activation='tanh'),
        layers.Reshape(img_shape)])

    noise = layers.Input(shape=(latent_dim,))
    img =model(noise)returnModel(noise, img)

💎3 构造鉴别器

def build_discriminator():
    model =Sequential([
        layers.Flatten(input_shape=img_shape),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1, activation='sigmoid')])

    img = layers.Input(shape=img_shape)
    validity =model(img)
  • 最后传入img以及model参数构造Model对象return Model(img, validity)

  • 鉴别器训练原理:通过对输入的图片进行鉴别,从而达到提升的效果

  • 生成器训练原理:通过鉴别器对其生成的图片进行鉴别,来实现提升


💎4 构造生成器

# 创建判别器
dis =build_discriminator()

# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
dis.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])
                      
# 创建生成器                       
generator =build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img =generator(gan_input) 

#训练generate时候停止训练判别器
dis.trainable = False  

# 测试:对生成的假图片进行预测 
validity =discriminator(img)
combined =Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

💎5 训练模型

  • train_on_batch详解:keras在compile完模型后需要训练,除了常用的model.fit()与model.fit_generator外 还有model.train_on_bantch作用:对一批样品进行单梯度更新,即对一个epoch中的一个样本进行一次训练
  • 使用train_on_batch优点:更精细自定义训练过程,更精准的收集 loss 和 metrics分布训练模型-GAN生成对抗神经网络的实现多GPU训练保存模型更加方便def train(epochs, batch_size=128, sample_interval=50):
  • 加载数据(train_images,_),(_,_)= tf.keras.datasets.mnist.load_data()
  • 将图片标准化到 [-1, 1] 区间内train_images =(train_images -127.5)/127.5
  • 数据 train_images = np.expand_dims(train_images, axis=3)
  • 创建标签true = np.ones((batch_size,1))fake = np.zeros((batch_size,1))
  • 开始训练for epoch in range(epochs): idx = np.random.randint(0, train_images.shape[0], batch_size) imgs = train_images[idx] noise = np.random.normal(0,1,(batch_size, latent_dim)) gen_imgs = generator.predict(noise) d_loss_true = discriminator.train_on_batch(imgs, true) d_loss_fake = discriminator.train_on_batch(gen_imgs, fake) d_loss =0.5* np.add(d_loss_true, d_loss_fake) noise = np.random.normal(0,1,(batch_size, latent_dim)) g_loss = combined.train_on_batch(noise, true)print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]"%(epoch, d_loss[0],100*d_loss[1], g_loss)) # 保存样例图片 if epoch % sample_interval ==0:sample_images(epoch)

在这里插入图片描述

  • 动图展示def compose_gif(): # 图片地址 data_dir ="F:/jupyter notebook/DL-100-days/code/images" data_dir = pathlib.Path(data_dir) paths =list(data_dir.glob('*')) gif_images =[]for path in paths:print(path) gif_images.append(imageio.imread(path)) imageio.mimsave("test.gif",gif_images,fps=2)

🌹写在最后💖
路漫漫其修远兮,吾将上下而求索!伙伴们,再见!🌹🌹🌹在这里插入图片描述


本文转载自: https://blog.csdn.net/runofsun/article/details/126446166
版权归原作者 梦想new的出来 所有, 如有侵权,请联系我们删除。

“【深度学习实践(八)】生成对抗网络(GAN)之手写数字生成”的评论:

还没有评论