0


GAN网络

目录

GAN

生成式对抗网络(GAN, Generative Adversarial Networks)是一种深度学习模型。主要包括两部分:生成模型判别模型。也就是对应神经网络的生成器与判别器:

  1. 生成器G(Generator):通过生成器G生成数据。
  2. 判别器D(Discriminator):判断这张图像是真实的还是机器生成的,目的是判别数据是否是生成器做的“假数据”

生成器与判别器互相对抗,不断调整参数。最终的目的是使判别网络无法判断生成网络的输出结果是否真实。

生成器G是一个生成图片的网络,它接收一个随机的噪声

  1. z
  2. z
  3. z,通过这个噪声生成图片,生成的图片记做
  4. G
  5. (
  6. z
  7. )
  8. G(z)
  9. G(z)。

判别器D判别一张图片是不是“真实的”。它的输入是

  1. x
  2. x
  3. x
  4. x
  5. x
  6. x代表一张图片(其中,
  7. x
  8. x
  9. x包含生成图片和真实图片,对于生成图片有
  10. x
  11. =
  12. G
  13. (
  14. z
  15. )
  16. x=G(z)
  17. x=G(z)),输出
  18. D
  19. (
  20. x
  21. )
  22. D(x)
  23. D(x)代表
  24. x
  25. x
  26. x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表图片0%是真的(或者说100%是假的)

其中网络结构如下所示:

在这里插入图片描述

其中,真实数据分布中的数据与生成数据可以认为是相同形状的。

生成网络G(Generative)

生成网络从隐空间(latent space)中随机采样作为输入,其输出结果需要尽量模仿训练集中的真实样本。

在这里插入图片描述

对抗网络D(Discriminative)

对抗网络也可称判别网络,判别网络的输入则为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能分辨出来。

在这里插入图片描述

两分布之间差异性评价

真实数据分布中的数据

  1. x
  2. x
  3. x 服从分布
  4. x
  5. P
  6. d
  7. a
  8. t
  9. a
  10. (
  11. x
  12. )
  13. x \sim P_{data} (x)
  14. xPdata​(x),生成数据分布中的数据
  15. x
  16. x
  17. x 服从分布
  18. x
  19. P
  20. G
  21. (
  22. x
  23. ;
  24. θ
  25. )
  26. x \sim P_{G} (x;\theta)
  27. xPG​(x;θ)

在这里插入图片描述

那么衡量两个分布之间的差异性指标有:KL散度,JS散度,交叉熵和Wasserstein距离。

KL散度

离散概率分布的KL散度计算公式:

  1. K
  2. L
  3. (
  4. p
  5. q
  6. )
  7. =
  8. p
  9. (
  10. x
  11. )
  12. log
  13. p
  14. (
  15. x
  16. )
  17. q
  18. (
  19. x
  20. )
  21. K L(p \| q)=\sum p(x) \log \frac{p(x)}{q(x)}
  22. KL(pq)=∑p(x)logq(x)p(x)​

连续概率分布的KL散度计算公式:

  1. K
  2. L
  3. (
  4. p
  5. q
  6. )
  7. =
  8. p
  9. (
  10. x
  11. )
  12. log
  13. p
  14. (
  15. x
  16. )
  17. q
  18. (
  19. x
  20. )
  21. d
  22. x
  23. K L(p \| q)=\int p(x) \log \frac{p(x)}{q(x)} d x
  24. KL(pq)=∫p(x)logq(x)p(x)​dx

JS散度

  1. J
  2. S
  3. (
  4. p
  5. q
  6. )
  7. =
  8. 1
  9. 2
  10. K
  11. L
  12. (
  13. p
  14. p
  15. +
  16. q
  17. 2
  18. )
  19. +
  20. 1
  21. 2
  22. K
  23. L
  24. (
  25. q
  26. p
  27. +
  28. q
  29. 2
  30. )
  31. J S(p \| q)=\frac{1}{2} K L\left(p \| \frac{p+q}{2}\right)+\frac{1}{2} K L\left(q \| \frac{p+q}{2}\right)
  32. JS(pq)=21KL(p2p+q​)+21KL(q2p+q​)

损失函数

以分布的角度来看GAN网络结构,然后考虑其损失函数。

在这里插入图片描述

对于生成网络G,其输入的

  1. z
  2. z
  3. z
  4. z
  5. N
  6. (
  7. 0
  8. ,
  9. I
  10. )
  11. \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})
  12. zN(0,I),表示
  13. z
  14. z
  15. z 服从正态分布的数据),通过训练出来的参数
  16. θ
  17. \theta
  18. θ 的生成网络生成的图片为
  19. G
  20. (
  21. z
  22. ,
  23. θ
  24. )
  25. G(\mathbf{z},\theta)
  26. G(z,θ)

对于判别网络,可以认为是二分类问题,一类是生成网络的输出,即

  1. x
  2. g
  3. e
  4. n
  5. e
  6. r
  7. a
  8. t
  9. i
  10. v
  11. e
  12. =
  13. G
  14. (
  15. z
  16. ,
  17. θ
  18. )
  19. x_{generative} = G(\mathbf{z},\theta)
  20. xgenerative​=G(z,θ);另一类是真实数据
  21. x
  22. r
  23. e
  24. a
  25. l
  26. x_{real}
  27. xreal​,(其中,
  28. x
  29. r
  30. e
  31. a
  32. l
  33. D
  34. r
  35. e
  36. a
  37. l
  38. x_{real} \sim D_{real}
  39. xreal​∼Dreal ,表示
  40. x
  41. r
  42. e
  43. a
  44. l
  45. x_{real}
  46. xreal 服从一种真实的分布distribution)。将
  47. x
  48. x
  49. x (其中,
  50. x
  51. =
  52. x
  53. g
  54. e
  55. n
  56. e
  57. r
  58. a
  59. t
  60. i
  61. v
  62. e
  63. x
  64. r
  65. e
  66. a
  67. l
  68. x = x_{generative} \cup x_{real}
  69. x=xgenerative​∪xreal​) 数据输入到判别网络中,输出结果分别为:
    1. D ( x r e a l , ϕ ) D(\mathbf{x_{real}}, \phi) D(xreal​,ϕ)
    1. D ( x g e n e r a t i v e , ϕ ) = D ( G ( z , θ ) , ϕ ) D(\mathbf{x_{generative}}, \phi) = D(G(\mathbf{z},\theta), \phi) D(xgenerative​,ϕ)=D(G(z,θ),ϕ)

分别从生成网络和判别网络的角度来看:

  1. 对于生成网络的标准就是:我希望我生成的图片越接近真实越好,那么也就是使 D ( x g e n e r a t i v e , ϕ ) = D ( G ( z , θ ) , ϕ ) D(\mathbf{x_{generative}}, \phi) = D(G(\mathbf{z},\theta), \phi) D(xgenerative​,ϕ)=D(G(z,θ),ϕ) 越接近1越好。也就是训练生成网络中的参数 θ \theta θ 满足: max ⁡ θ ( E z ∼ p ( z ) [ log ⁡ D ( G ( z ; θ ) ; ϕ ) ] ) \max {\theta}\left(\mathbb{E}{z \sim p(z)}[\log D(G(\boldsymbol{z} ; \theta) ; \phi)]\right) θmax​(Ez∼p(z)​[logD(G(z;θ);ϕ)])
  2. 对于判别网络的标准就是:我能够很好的区分哪些是真的,哪些是假的。也就是说能够很好的将真的和假的区分开来。也就是希望真实数据的输出 D ( x r e a l , ϕ ) D(\mathbf{x_{real}}, \phi) D(xreal​,ϕ) 越趋近于1,而生成数据的输出 D ( x g e n e r a t i v e , ϕ ) = D ( G ( z , θ ) , ϕ ) D(\mathbf{x_{generative}}, \phi) = D(G(\mathbf{z},\theta), \phi) D(xgenerative​,ϕ)=D(G(z,θ),ϕ) 越趋近于0。> 将其看成二分类问题,二分类问题的损失函数可以使用交叉熵损失函数来表示,对于二分类,只有正样本(label=1)与负样本(label=0)。并且两者概率之和为1。对于一个输入> > > > > x> > > > x> > > x,经过模型输出为> > > > > p> > > (> > > x> > > )> > > > p(x)> > > p(x)。y是真实的标签。于是单个样本的损失函数就是:> > > > > > L> > > O> > > S> > > S> > > => > > −> > > y> > > ∗> > > l> > > o> > > g> > > (> > > p> > > (> > > x> > > )> > > )> > > +> > > (> > > 1> > > −> > > y> > > )> > > l> > > o> > > g> > > (> > > 1> > > −> > > p> > > (> > > x> > > )> > > )> > > > LOSS = -y * log(p(x)) + (1-y)log(1-p(x)) > > > LOSS=−y∗log(p(x))+(1−y)log(1−p(x))> 如果是计算 N 个样本的平均损失函数,只要将 N 个 Loss 叠加起来再除以N就行:> > > > > > L> > > O> > > S> > > S> > > => > > > 1> > > N> > > > > ∑> > > > i> > > => > > 1> > > > N> > > > > y> > > > (> > > i> > > )> > > > > log> > > ⁡> > > (> > > > > p> > > (> > > x> > > )> > > > > (> > > i> > > )> > > > > )> > > +> > > > (> > > 1> > > −> > > > y> > > > (> > > i> > > )> > > > > )> > > > log> > > ⁡> > > > (> > > 1> > > −> > > > > p> > > (> > > x> > > )> > > > > (> > > i> > > )> > > > > )> > > > > LOSS=\frac{1}{N} \sum_{i=1}^{N} y^{(i)} \log ({p(x)}^{(i)})+\left(1-y^{(i)}\right) \log \left(1-{p(x)}^{(i)}\right) > > > LOSS=N1​i=1∑N​y(i)log(p(x)(i))+(1−y(i))log(1−p(x)(i))对于 x ∼ p d a t a ( x ) x \sim p_{data}(x) x∼pdata​(x) 的输出 D ( x ; ϕ ) D(x; \phi) D(x;ϕ)是真实的,也就是标签为1,那么其单个样本损失函数就是: L O S S = − 1 ∗ log ⁡ D ( x ; ϕ ) + ( 1 − 1 ) ( 1 − log ⁡ D ( x ; ϕ ) ) = − log ⁡ D ( x ; ϕ ) LOSS = -1 * \log D(\boldsymbol{x} ; \phi) + (1-1)(1- \log D(\boldsymbol{x} ; \phi)) = - \log D(\boldsymbol{x} ; \phi) LOSS=−1∗logD(x;ϕ)+(1−1)(1−logD(x;ϕ))=−logD(x;ϕ) 平均损失函数就是(其实就是求平均): L O S S = − E x ∼ p d a t a ( x ) [ log ⁡ D ( x ; ϕ ) ] LOSS = - \mathbb{E}{\boldsymbol{x} \sim p{data}(\boldsymbol{x})}[\log D(\boldsymbol{x} ; \phi)] LOSS=−Ex∼pdata​(x)​[logD(x;ϕ)] 那么对于 z ∼ p ( z ) z \sim p(z) z∼p(z) 的输出 D ( G ( z ; θ ) ; ϕ ) ) D(G(\boldsymbol{z} ; \theta) ; \phi)) D(G(z;θ);ϕ))是假的,也就是标签为0,那么其单个样本损失函数就是: L O S S = 0 ∗ D ( G ( z ; θ ) ; ϕ ) ) + ( 1 − 0 ) ( 1 − D ( G ( z ; θ ) ; ϕ ) ) ) = 1 − D ( G ( z ; θ ) ; ϕ ) ) LOSS = 0*D(G(\boldsymbol{z} ; \theta) ; \phi))+(1-0)(1-D(G(\boldsymbol{z} ; \theta) ; \phi))) = 1 - D(G(\boldsymbol{z} ; \theta) ; \phi)) LOSS=0∗D(G(z;θ);ϕ))+(1−0)(1−D(G(z;θ);ϕ)))=1−D(G(z;θ);ϕ)) 平均损失函数就是: L O S S = − E z ∼ p ( z ) [ log ⁡ ( 1 − D ( G ( z ; θ ) ; ϕ ) ) ] LOSS = - \mathbb{E}{\boldsymbol{z} \sim p(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z} ; \theta) ; \phi))] LOSS=−Ez∼p(z)​[log(1−D(G(z;θ);ϕ))] 由于上面损失函数是负数,并且需要最小化损失函数,那么反过来的最大化损失函数就是: max ⁡ ϕ E x ∼ p d a t a ( x ) [ log ⁡ D ( x ; ϕ ) ] + E z ∼ p ( z ) [ log ⁡ ( 1 − D ( G ( z ; θ ) ; ϕ ) ) ] \max {\phi} \mathbb{E}{\boldsymbol{x} \sim p{data}(\boldsymbol{x})}[\log D(\boldsymbol{x} ; \phi)]+\mathbb{E}_{\boldsymbol{z} \sim p(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z} ; \theta) ; \phi))] ϕmax​Ex∼pdata​(x)​[logD(x;ϕ)]+Ez∼p(z)​[log(1−D(G(z;θ);ϕ))]

在训练过程中,当G固定的时候,有:

  1. E
  2. z
  3. p
  4. (
  5. z
  6. )
  7. [
  8. log
  9. (
  10. 1
  11. D
  12. (
  13. G
  14. (
  15. z
  16. ;
  17. θ
  18. )
  19. ;
  20. ϕ
  21. )
  22. )
  23. ]
  24. =
  25. E
  26. x
  27. p
  28. g
  29. (
  30. x
  31. )
  32. [
  33. log
  34. (
  35. 1
  36. D
  37. (
  38. x
  39. ;
  40. ϕ
  41. )
  42. )
  43. ]
  44. \mathbb{E}_{\boldsymbol{z} \sim p(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z} ; \theta) ; \phi))] = \mathbb{E}_{\boldsymbol{x} \sim p_{g} (\boldsymbol{x})}[\log (1-D(x; \phi))]
  45. Ezp(z)​[log(1D(G(z;θ);ϕ))]=Expg​(x)​[log(1D(x;ϕ))]

全局优化首先固定G,然后优化D(这种情况也就是生成数据的分布

  1. x
  2. p
  3. g
  4. (
  5. x
  6. )
  7. \boldsymbol{x} \sim p_{g} (\boldsymbol{x})
  8. xpg​(x) 与真实分布
  9. x
  10. p
  11. d
  12. a
  13. t
  14. a
  15. (
  16. x
  17. )
  18. \boldsymbol{x} \sim p_{data} (\boldsymbol{x})
  19. xpdata​(x)已知),D的最佳情况为:(推导可以看GAN入门理解及公式推导 - 知乎 (zhihu.com))
  20. D
  21. G
  22. (
  23. x
  24. )
  25. =
  26. p
  27. data
  28. (
  29. x
  30. )
  31. p
  32. data
  33. (
  34. x
  35. )
  36. +
  37. p
  38. g
  39. (
  40. x
  41. )
  42. D_{G}^{*}(\boldsymbol{x})=\frac{p_{\text {data }}(\boldsymbol{x})}{p_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}
  43. DG∗​(x)=pdata ​(x)+pg​(x)pdata ​(x)​

将最佳的D代入目标loss函数,有:

  1. =
  2. 2
  3. log
  4. 2
  5. +
  6. 2
  7. D
  8. J
  9. S
  10. (
  11. P
  12. data
  13. P
  14. g
  15. )
  16. =-2 \log 2+2 D_{JS}\left(P_{\text {data }} \| P_{g}\right)
  17. =−2log2+2DJS​(Pdata ​∥Pg​)

也就是说,原始GAN的loss实际上等价于JS散度

一次代码实验

很多的GAN网络结构代码可以参考:PyTorch-GAN

其中一个网络结构如下:

在这里插入图片描述

在实验中,对应的形状如下所示:

  1. 高斯随机变量:torch.Size([batch_size, 100])
  2. 生成的fake_image, 真实image:torch.Size([batch_size, 3,64,64])
  3. 判别真假:torch.Size([batch_size, 1])

代码如下:

生成网络代码

  1. classGenerator(nn.Module):def__init__(self):super(Generator, self).__init__()defblock(in_feat, out_feat, normalize=True):
  2. layers =[nn.Linear(in_feat, out_feat)]if normalize:
  3. layers.append(nn.BatchNorm1d(out_feat,0.8))
  4. layers.append(nn.LeakyReLU(0.2, inplace=True))return layers
  5. self.model = nn.Sequential(*block(opt.latent_dim,128, normalize=False),*block(128,256),*block(256,512),*block(512,1024),
  6. nn.Linear(1024,int(np.prod(img_shape))),
  7. nn.Tanh())defforward(self, z):
  8. img = self.model(z)
  9. img = img.view(img.shape[0],*img_shape)return img

对抗网络代码

  1. classDiscriminator(nn.Module):def__init__(self):super(Discriminator, self).__init__()
  2. self.model = nn.Sequential(
  3. nn.Linear(int(np.prod(img_shape)),512),
  4. nn.LeakyReLU(0.2, inplace=True),
  5. nn.Linear(512,256),
  6. nn.LeakyReLU(0.2, inplace=True),
  7. nn.Linear(256,1),)

然后在Anime_Faces数据集中训练,获得的1000个epochs后生成的数据有:

在这里插入图片描述

感觉效果不太行,估计是网络结构的局限性。

WGAN

参考论文:[1701.07875] Wasserstein GAN (arxiv.org)

在生成对抗网络中,当判断网络为最优时,生成网络的优化目标是最小化真实分布

  1. p
  2. r
  3. (
  4. x
  5. )
  6. p_r (x)
  7. pr​(x) 和模型分布
  8. p
  9. θ
  10. (
  11. x
  12. )
  13. p_θ (x)
  14. pθ​(x) 之间的JS散度。当两个分布相同时,JS散度为0,最优生成网络对应的损失为−2log2。但是使用JS散度来训练生成对抗网络的一个问题是当两个分布没有重叠时,它们之间的JS散度恒等于常数log2。对生成网络来说,目标函数**关于参数的梯度为0**。

在GAN的基础上加入了Wasserstein距离,Wasserstein距离用于衡量两个分布之间的距离。相比KL散度和JS散度的优势在于即使两个分布没有重叠或者重叠非常少,Wasserstein距离仍然能反映两个分布的远近。其数学公式如下:

  1. W
  2. p
  3. (
  4. q
  5. 1
  6. ,
  7. q
  8. 2
  9. )
  10. =
  11. (
  12. inf
  13. γ
  14. (
  15. x
  16. ,
  17. y
  18. )
  19. Γ
  20. (
  21. q
  22. 1
  23. ,
  24. q
  25. 2
  26. )
  27. E
  28. (
  29. x
  30. ,
  31. y
  32. )
  33. γ
  34. (
  35. x
  36. ,
  37. y
  38. )
  39. [
  40. d
  41. (
  42. x
  43. ,
  44. y
  45. )
  46. p
  47. ]
  48. )
  49. 1
  50. p
  51. W_{p}\left(q_{1}, q_{2}\right)=\left(\inf _{\gamma(x, y) \in \Gamma\left(q_{1}, q_{2}\right)} \mathbb{E}_{(x, y) \sim \gamma(x, y)}\left[d(x, y)^{p}\right]\right)^{\frac{1}{p}}
  52. Wp​(q1​,q2​)=(γ(x,y)∈Γ(q1​,q2​)infE(x,y)∼γ(x,y)​[d(x,y)p])p1

其中 ,

  1. Γ
  2. (
  3. q
  4. 1
  5. ,
  6. q
  7. 2
  8. )
  9. \Gamma\left(q_{1}, q_{2}\right)
  10. Γ(q1​,q2​) 是边际分布为
  11. q
  12. 1
  13. ,
  14. q
  15. 2
  16. q_{1}, q_{2}
  17. q1​,q2 的所有可能的联合分布集合,
  18. d
  19. (
  20. x
  21. ,
  22. y
  23. )
  24. \mathrm{d}(\mathrm{x}, \mathrm{y})
  25. d(x,y)
  26. x
  27. \mathrm{x}
  28. x
  29. y
  30. \mathrm{y}
  31. y 距离, 比如
  32. p
  33. \ell_{p}
  34. p 距离等,
  35. E
  36. \mathbb{E}
  37. E 表示期望,
  38. inf
  39. \inf
  40. inf表示下确界。
  1. 下确界

,如果是一个集合的下确界, 即表示

  1. 小于或等于集合E

的所有其他元素的

  1. 最大元素

, 这个数

  1. 不一定

在集合E中。举例来说:

    1. i n f { 1 , 2 , 3 } = 1 inf\{1,2,3\} = 1 inf{1,2,3}=1; 也就是说集合 { 1 , 2 , 3 } \{1,2,3\} {1,2,3}的下确界为1
    1. i n f { x R , 0 < x < 1 } = 0 inf\{x \in \mathbb{R}, 0<x<1 \} = 0 inf{xR,0<x<1}=0 ;
    1. i n f { ( 1 ) n + 1 / n : n = 1 , 2 , 3 , . . . } = 1 inf\{(-1)^{n} + 1/n : n = 1, 2, 3,...\} = -1 inf{(−1)n+1/n:n=1,2,3,...}=−1;

当然,换一种角度解读:将两个分布看作是两个土堆,联合分布

  1. γ
  2. (
  3. x
  4. ,
  5. y
  6. )
  7. \gamma(x, y)
  8. γ(x,y) 看作是从土堆
  9. q
  10. 1
  11. q_1
  12. q1 的位置
  13. x
  14. x
  15. x 到土堆
  16. q
  17. 2
  18. q_2
  19. q2 的位置
  20. y
  21. y
  22. y 的搬运土的数量。Wasserstein距离可以理解为搬运土堆的最小工作量,也称为推土机距离(Earth-Movers DistanceEMD

在这里插入图片描述

WGAN-GP

参考论文:[1704.00028] Improved Training of Wasserstein GANs (arxiv.org)

WGAN还是有问题:

  • 权重裁剪会导致参数基本都在限制的边界值,极大浪费了模型的参数。
  • 还是很容易梯度消失或者梯度爆炸,需要仔细的调参

WGAN-GP,核心只有一个:Gradient Penalty

Gradient Penalty:判别器相对于输入的梯度的二范数要约束在1附近,这样就能够保证Lipschitz连续。

在这里插入图片描述

  1. # 计算Gradient Penaltydefcompute_gradient_penalty(D, real_samples, fake_samples):"""Calculates the gradient penalty loss for WGAN GP"""# Random weight term for interpolation between real and fake samples
  2. alpha = Tensor(np.random.random((real_samples.size(0),1,1,1)))# Get random interpolation between real and fake samples
  3. interpolates =(alpha * real_samples +((1- alpha)* fake_samples)).requires_grad_(True)
  4. d_interpolates = D(interpolates)
  5. fake = Variable(Tensor(real_samples.shape[0],1).fill_(1.0), requires_grad=False)# Get gradient w.r.t. interpolates
  6. gradients = autograd.grad(
  7. outputs=d_interpolates,
  8. inputs=interpolates,
  9. grad_outputs=fake,
  10. create_graph=True,
  11. retain_graph=True,
  12. only_inputs=True,)[0]
  13. gradients = gradients.view(gradients.size(0),-1)
  14. gradient_penalty =((gradients.norm(2, dim=1)-1)**2).mean()return gradient_penalty

Conditional GAN

条件GAN,顾名思义就是根据条件针对性的生成数据。具体有AC-GAN。


本文转载自: https://blog.csdn.net/weixin_41012765/article/details/125711857
版权归原作者 牵一发而动全身 所有, 如有侵权,请联系我们删除。

“GAN网络”的评论:

还没有评论