0


Vision Transformer详解(附代码)

1 引言

  1. T
  2. r
  3. a
  4. n
  5. s
  6. f
  7. o
  8. r
  9. m
  10. e
  11. r
  12. \mathrm{Transformer}
  13. Transformer
  14. N
  15. L
  16. P
  17. \mathrm{NLP}
  18. NLP中大获成功,
  19. V
  20. i
  21. s
  22. i
  23. o
  24. n
  25. T
  26. r
  27. a
  28. n
  29. s
  30. f
  31. o
  32. r
  33. m
  34. e
  35. r
  36. \mathrm{Vision\text{ }Transformer}
  37. Vision Transformer则将
  38. T
  39. r
  40. a
  41. n
  42. s
  43. f
  44. o
  45. r
  46. m
  47. e
  48. r
  49. \mathrm{Transformer}
  50. Transformer模型架构扩展到计算机视觉的领域中,并且它可以很好的地取代卷积操作,在不依赖卷积的情况下,依然可以在图像分类任务上达到很好的效果。卷积操作只能考虑到局部的特征信息,而
  51. T
  52. r
  53. a
  54. n
  55. s
  56. f
  57. o
  58. r
  59. m
  60. e
  61. r
  62. \mathrm{Transformer}
  63. Transformer中的注意力机制可以综合考量全局的特征信息。
  64. V
  65. i
  66. s
  67. i
  68. o
  69. n
  70. T
  71. r
  72. a
  73. n
  74. s
  75. f
  76. o
  77. r
  78. m
  79. e
  80. r
  81. \mathrm{Vision\text{ }Transformer}
  82. Vision Transformer尽力做到在不改变
  83. T
  84. r
  85. a
  86. n
  87. s
  88. f
  89. o
  90. r
  91. m
  92. e
  93. r
  94. \mathrm{Transformer}
  95. Transformer
  96. E
  97. n
  98. c
  99. o
  100. d
  101. e
  102. r
  103. \mathrm{Encoder}
  104. Encoder架构的前提下,直接将其从
  105. N
  106. L
  107. P
  108. \mathrm{NLP}
  109. NLP领域迁移到计算机视觉领域中,目的是让原始的
  110. T
  111. r
  112. a
  113. n
  114. s
  115. f
  116. o
  117. r
  118. m
  119. e
  120. r
  121. \mathrm{Transformer}
  122. Transformer模型开箱即用。如果想要了解
  123. T
  124. r
  125. a
  126. n
  127. s
  128. f
  129. o
  130. r
  131. m
  132. e
  133. r
  134. \mathrm{Transformer}
  135. Transformer原理详细的介绍可以看我的上一篇文章《Transformer详解(附代码)》。

2 注意力机制应用

在正式详细介绍

  1. V
  2. i
  3. s
  4. i
  5. o
  6. n
  7. T
  8. r
  9. a
  10. n
  11. s
  12. f
  13. o
  14. r
  15. m
  16. e
  17. r
  18. \mathrm{Vision\text{ }Transformer}
  19. Vision Transformer之前,先介绍两个注意力机制在计算机视觉中应用的例子。
  20. V
  21. i
  22. s
  23. i
  24. o
  25. n
  26. T
  27. r
  28. a
  29. n
  30. s
  31. f
  32. o
  33. r
  34. m
  35. e
  36. r
  37. \mathrm{Vision\text{ }Transformer}
  38. Vision Transformer并不是第一个将注意力机制应用到计算机视觉的领域中去的,其中
  39. S
  40. A
  41. G
  42. A
  43. N
  44. \mathrm{SAGAN}
  45. SAGAN
  46. A
  47. t
  48. t
  49. n
  50. G
  51. A
  52. N
  53. \mathrm{AttnGAN}
  54. AttnGAN就早已经在
  55. G
  56. A
  57. N
  58. \mathrm{GAN}
  59. GAN的框架中引入了注意力机制,并且它们大大提高了图像生成的质量。

2.1 Self-Attention GAN

  1. S
  2. A
  3. G
  4. A
  5. N
  6. \mathrm{SAGAN}
  7. SAGAN
  8. G
  9. A
  10. N
  11. \mathrm{GAN}
  12. GAN的框架中利用自注意力机制来捕获图像特征的长距离依赖关系,使得合成的图像中考量了所有的图像特征信息。
  13. S
  14. A
  15. G
  16. A
  17. N
  18. \mathrm{SAGAN}
  19. SAGAN中自注意力机制的操作原理如下图所示。

给定一个

  1. 3
  2. 3
  3. 3通道的输入特征图
  4. X
  5. =
  6. (
  7. X
  8. 1
  9. ,
  10. X
  11. 2
  12. ,
  13. X
  14. 3
  15. )
  16. R
  17. 3
  18. ×
  19. 3
  20. ×
  21. 3
  22. X=(X^1,X^2,X^3)\in \mathbb{R}^{3\times 3\times 3}
  23. X=(X1,X2,X3)∈R3×3×3,其中
  24. X
  25. i
  26. R
  27. 3
  28. ×
  29. 3
  30. X^{i}\in \mathbb{R}^{3\times 3}
  31. XiR3×3
  32. i
  33. {
  34. 1
  35. ,
  36. 2
  37. ,
  38. 3
  39. }
  40. i\in\{1,2,3\}
  41. i∈{1,2,3}。将
  42. X
  43. X
  44. X分别输入到三个不同的
  45. 1
  46. ×
  47. 1
  48. 1\times 1
  49. 1×1的卷积层中,并生成
  50. q
  51. u
  52. e
  53. r
  54. y
  55. \mathrm{query}
  56. query特征图
  57. Q
  58. R
  59. 3
  60. ×
  61. 3
  62. ×
  63. 3
  64. Q\in \mathbb{R}^{3\times 3\times 3}
  65. QR3×3×3
  66. k
  67. e
  68. y
  69. \mathrm{key}
  70. key特征图
  71. K
  72. R
  73. 3
  74. ×
  75. 3
  76. ×
  77. 3
  78. K\in \mathbb{R}^{3\times 3\times 3}
  79. KR3×3×3
  80. v
  81. a
  82. l
  83. u
  84. e
  85. \mathrm{value}
  86. value特征图
  87. V
  88. R
  89. 3
  90. ×
  91. 3
  92. ×
  93. 3
  94. V\in \mathbb{R}^{3\times 3\times 3}
  95. VR3×3×3。生成
  96. Q
  97. Q
  98. Q具体的计算过程为,给定三个卷积核
  99. W
  100. q
  101. 1
  102. W^{q1}
  103. Wq1
  104. W
  105. q
  106. 2
  107. W^{q2}
  108. Wq2
  109. W
  110. q
  111. 3
  112. R
  113. 1
  114. ×
  115. 1
  116. ×
  117. 3
  118. W^{q3}\in\mathbb{R}^{1\times1\times3}
  119. Wq3R1×1×3,并用这三个卷积核分别与
  120. X
  121. X
  122. X做卷积运算得到
  123. Q
  124. 1
  125. Q^1
  126. Q1
  127. Q
  128. 2
  129. Q^2
  130. Q2
  131. Q
  132. 3
  133. R
  134. 3
  135. ×
  136. 3
  137. Q^3\in \mathbb{R}^{3 \times 3}
  138. Q3R3×3,即
  139. {
  140. Q
  141. 1
  142. =
  143. X
  144. W
  145. q
  146. 1
  147. Q
  148. 2
  149. =
  150. X
  151. W
  152. q
  153. 2
  154. Q
  155. 3
  156. =
  157. X
  158. W
  159. q
  160. 3
  161. \left\{\begin{aligned}Q^1&=X * W^{q1}\\Q^2&=X * W^{q2}\\Q^3&=X*W^{q3}\end{aligned}\right.
  162. ⎩⎪⎨⎪⎧​Q1Q2Q3​=XWq1=XWq2=XWq3​其中
  163. *
  164. ∗表示卷积运算符号。同理生成
  165. K
  166. K
  167. K
  168. V
  169. V
  170. V的计算过程与
  171. Q
  172. Q
  173. Q的计算过程类似。然后再利用
  174. Q
  175. Q
  176. Q
  177. K
  178. K
  179. K进行注意力分数的计算得到矩阵
  180. A
  181. R
  182. 3
  183. ×
  184. 3
  185. A\in \mathbb{R}^{3 \times 3}
  186. AR3×3,其中矩阵
  187. A
  188. A
  189. A的元素
  190. a
  191. m
  192. l
  193. a_{ml}
  194. aml​的计算公式为
  195. a
  196. m
  197. l
  198. =
  199. Q
  200. m
  201. K
  202. l
  203. ,
  204. m
  205. {
  206. 1
  207. ,
  208. 2
  209. ,
  210. 3
  211. }
  212. ,
  213. l
  214. {
  215. 1
  216. ,
  217. 2
  218. ,
  219. 3
  220. }
  221. a_{ml}=Q^m * K^l,\quad m \in \{1,2,3\},l\in \{1,2,3\}
  222. aml​=QmKl,m∈{1,2,3},l∈{1,2,3}再对矩阵
  223. A
  224. A
  225. A利用
  226. s
  227. o
  228. f
  229. t
  230. m
  231. a
  232. x
  233. \mathrm{softmax}
  234. softmax函数进行注意力分布的计算得到注意力分布矩阵
  235. S
  236. R
  237. 3
  238. ×
  239. 3
  240. S\in \mathbb{R}^{3\times 3}
  241. SR3×3,其中矩阵
  242. S
  243. S
  244. S的元素
  245. s
  246. m
  247. l
  248. s_{ml}
  249. sml​的计算公式为
  250. s
  251. m
  252. l
  253. =
  254. exp
  255. (
  256. a
  257. m
  258. l
  259. )
  260. i
  261. =
  262. j
  263. 3
  264. exp
  265. (
  266. a
  267. m
  268. j
  269. )
  270. ,
  271. m
  272. {
  273. 1
  274. ,
  275. 2
  276. ,
  277. 3
  278. }
  279. ,
  280. l
  281. {
  282. 1
  283. ,
  284. 2
  285. ,
  286. 3
  287. }
  288. s_{ml}=\frac{\exp(a_{ml})}{\sum\limits_{i=j}^{3}\exp(a_{mj})},\quad m \in \{1,2,3\},l\in\{1,2,3\}
  289. sml​=i=j3exp(amj​)exp(aml​)​,m∈{1,2,3},l∈{1,2,3}最后利用注意力分布矩阵
  290. S
  291. S
  292. S
  293. v
  294. a
  295. l
  296. u
  297. e
  298. \mathrm{value}
  299. value特征图
  300. V
  301. V
  302. V得到最后的输出
  303. O
  304. =
  305. (
  306. O
  307. 1
  308. ,
  309. O
  310. 2
  311. ,
  312. O
  313. 3
  314. )
  315. R
  316. 3
  317. ×
  318. 3
  319. ×
  320. 3
  321. O=(O^1,O^2,O^3)\in \mathbb{R}^{3\times 3\times 3}
  322. O=(O1,O2,O3)∈R3×3×3,即
  323. {
  324. O
  325. 1
  326. =
  327. s
  328. 11
  329. V
  330. 1
  331. +
  332. s
  333. 12
  334. V
  335. 2
  336. +
  337. s
  338. 13
  339. V
  340. 3
  341. O
  342. 2
  343. =
  344. s
  345. 21
  346. V
  347. 1
  348. +
  349. s
  350. 22
  351. V
  352. 2
  353. +
  354. s
  355. 23
  356. V
  357. 3
  358. O
  359. 3
  360. =
  361. s
  362. 31
  363. V
  364. 1
  365. +
  366. s
  367. 32
  368. V
  369. 2
  370. +
  371. s
  372. 33
  373. V
  374. 3
  375. \left\{\begin{aligned}O^1&=s_{11}\cdot V^1+s_{12}\cdot V^2+s_{13}\cdot V^3\\O^2&=s_{21}\cdot V^1+s_{22}\cdot V^2+s_{23}\cdot V^3\\O^3&=s_{31}\cdot V^1+s_{32}\cdot V^2+s_{33}\cdot V^3\end{aligned}\right.
  376. ⎩⎪⎨⎪⎧​O1O2O3​=s11​⋅V1+s12​⋅V2+s13​⋅V3=s21​⋅V1+s22​⋅V2+s23​⋅V3=s31​⋅V1+s32​⋅V2+s33​⋅V3

2.2 AttnGAN

  1. A
  2. t
  3. t
  4. n
  5. G
  6. A
  7. N
  8. \mathrm{AttnGAN}
  9. AttnGAN通过利用注意力机制来实现多阶段细颗粒度的文本到图像的生成,它可以通过关注自然语言中的一些重要单词来对图像的不同子区域进行合成。比如通过文本“一只鸟有黄色的羽毛和黑色的眼睛”来生成图像时,会对关键词“鸟”,“羽毛”,“眼睛”,“黄色”,“黑色”给予不同的生成权重,并根据这些关键词的引导在图像的不同的子区域中进行细节的丰富。
  10. A
  11. t
  12. t
  13. n
  14. G
  15. A
  16. N
  17. \mathrm{AttnGAN}
  18. AttnGAN中注意力机制的操作原理如下图所示。

 给定输入图像特征向量

  1. h
  2. =
  3. (
  4. h
  5. 1
  6. ,
  7. h
  8. 2
  9. ,
  10. h
  11. 3
  12. ,
  13. h
  14. 4
  15. )
  16. R
  17. D
  18. ^
  19. ×
  20. 4
  21. h=(h^1,h^2,h^3,h^4)\in\mathbb{R}^{\hat{D}\times 4}
  22. h=(h1,h2,h3,h4)∈RD4和词特征向量
  23. e
  24. =
  25. (
  26. e
  27. 1
  28. ,
  29. e
  30. 2
  31. ,
  32. e
  33. 3
  34. ,
  35. e
  36. 4
  37. )
  38. e=(e^1,e^2,e^3,e^4)
  39. e=(e1,e2,e3,e4),其中
  40. h
  41. i
  42. R
  43. D
  44. ^
  45. ×
  46. 1
  47. h^i\in \mathbb{R}^{\hat{D}\times 1}
  48. hiRD1
  49. e
  50. i
  51. R
  52. D
  53. ×
  54. 1
  55. e^i\in \mathbb{R}^{D\times 1}
  56. eiRD×1
  57. i
  58. {
  59. 1
  60. ,
  61. 2
  62. ,
  63. 3
  64. ,
  65. 4
  66. }
  67. i\in \{1,2,3,4\}
  68. i∈{1,2,3,4}。首先利用矩阵
  69. W
  70. W
  71. W进行线性变换将词特征空间
  72. R
  73. D
  74. \mathbb{R}^{D}
  75. RD的向量转换成图像特征空间
  76. R
  77. D
  78. ^
  79. \mathbb{R}^{\hat{D}}
  80. RD^的向量,则有
  81. e
  82. ^
  83. =
  84. W
  85. e
  86. =
  87. (
  88. e
  89. ^
  90. 1
  91. ,
  92. e
  93. ^
  94. 2
  95. ,
  96. e
  97. ^
  98. 3
  99. ,
  100. e
  101. ^
  102. 4
  103. )
  104. R
  105. D
  106. ^
  107. ×
  108. 4
  109. \hat{e}=W\cdot e=(\hat{e}^1,\hat{e}^2,\hat{e}^3,\hat{e}^4)\in \mathbb{R}^{\hat{D}\times 4}
  110. e^=We=(e^1,e^2,e^3,e^4)∈RD4然后再利用转换后的词特征
  111. e
  112. ^
  113. \hat{e}
  114. e^与图像特征
  115. h
  116. h
  117. h进行注意力分数的计算得到注意力分数矩阵
  118. S
  119. S
  120. S,其中的分量
  121. s
  122. i
  123. j
  124. s_{ij}
  125. sij​的计算公式为
  126. s
  127. i
  128. j
  129. =
  130. (
  131. h
  132. i
  133. )
  134. e
  135. ^
  136. j
  137. ,
  138. i
  139. {
  140. 1
  141. ,
  142. 2
  143. ,
  144. 3
  145. ,
  146. 4
  147. }
  148. ,
  149. j
  150. {
  151. 1
  152. ,
  153. 2
  154. ,
  155. 3
  156. ,
  157. 4
  158. }
  159. s_{ij}=(h^i)^{\top}\cdot \hat{e}^j,\quad i\in \{1,2,3,4\},j\in\{1,2,3,4\}
  160. sij​=(hi)⊤⋅e^j,i∈{1,2,3,4},j∈{1,2,3,4} 再对矩阵
  161. S
  162. S
  163. S利用
  164. s
  165. o
  166. f
  167. t
  168. m
  169. a
  170. x
  171. \mathrm{softmax}
  172. softmax函数进行注意力分布的计算得到注意力分布矩阵
  173. β
  174. R
  175. 4
  176. ×
  177. 4
  178. \beta\in \mathbb{R}^{4\times 4}
  179. β∈R4×4,其中矩阵
  180. β
  181. \beta
  182. β的元素
  183. β
  184. i
  185. j
  186. \beta_{ij}
  187. βij​的计算公式为
  188. β
  189. i
  190. j
  191. =
  192. exp
  193. (
  194. s
  195. i
  196. j
  197. )
  198. k
  199. =
  200. 1
  201. 3
  202. exp
  203. (
  204. s
  205. i
  206. k
  207. )
  208. ,
  209. i
  210. {
  211. 1
  212. ,
  213. 2
  214. ,
  215. 3
  216. ,
  217. 4
  218. }
  219. ,
  220. l
  221. {
  222. 1
  223. ,
  224. 2
  225. ,
  226. 3
  227. ,
  228. 4
  229. }
  230. \beta_{ij}=\frac{\exp(s_{ij})}{\sum\limits_{k=1}^{3}\exp(s_{ik})},\quad i \in \{1,2,3,4\},l\in\{1,2,3,4\}
  231. βij​=k=13exp(sik​)exp(sij​)​,i∈{1,2,3,4},l∈{1,2,3,4}最后利用注意力分布矩阵
  232. β
  233. \beta
  234. β和图像特征
  235. h
  236. h
  237. h得到最后的输出
  238. o
  239. =
  240. (
  241. o
  242. 1
  243. ,
  244. o
  245. 2
  246. ,
  247. o
  248. 3
  249. ,
  250. o
  251. 4
  252. )
  253. R
  254. D
  255. ^
  256. ×
  257. 4
  258. o=(o^1,o^2,o^3,o^4)\in \mathbb{R}^{\hat{D}\times 4}
  259. o=(o1,o2,o3,o4)∈RD4,即
  260. {
  261. o
  262. 1
  263. =
  264. β
  265. 11
  266. h
  267. 1
  268. +
  269. β
  270. 12
  271. h
  272. 2
  273. +
  274. β
  275. 13
  276. h
  277. 3
  278. +
  279. β
  280. 14
  281. h
  282. 4
  283. o
  284. 2
  285. =
  286. β
  287. 21
  288. h
  289. 1
  290. +
  291. β
  292. 22
  293. h
  294. 2
  295. +
  296. β
  297. 23
  298. h
  299. 3
  300. +
  301. β
  302. 24
  303. h
  304. 4
  305. o
  306. 3
  307. =
  308. β
  309. 31
  310. h
  311. 1
  312. +
  313. β
  314. 32
  315. h
  316. 2
  317. +
  318. β
  319. 33
  320. h
  321. 3
  322. +
  323. β
  324. 34
  325. h
  326. 4
  327. o
  328. 4
  329. =
  330. β
  331. 41
  332. h
  333. 1
  334. +
  335. β
  336. 42
  337. h
  338. 2
  339. +
  340. β
  341. 43
  342. h
  343. 3
  344. +
  345. β
  346. 44
  347. h
  348. 4
  349. \left\{\begin{aligned}o^1&=\beta_{11}\cdot h^1+\beta_{12}\cdot h^2+\beta_{13}\cdot h^3+\beta_{14}\cdot h^4\\o^2&=\beta_{21}\cdot h^1+\beta_{22}\cdot h^2+\beta_{23}\cdot h^3+\beta_{24}\cdot h^4\\o^3&=\beta_{31}\cdot h^1+\beta_{32}\cdot h^2+\beta_{33}\cdot h^3+\beta_{34}\cdot h^4\\o^4&=\beta_{41}\cdot h^1+\beta_{42}\cdot h^2+\beta_{43}\cdot h^3+\beta_{44}\cdot h^4\end{aligned}\right.
  350. ⎩⎪⎪⎪⎪⎨⎪⎪⎪⎪⎧​o1o2o3o4​=β11​⋅h112​⋅h213​⋅h314​⋅h421​⋅h122​⋅h223​⋅h324​⋅h431​⋅h132​⋅h233​⋅h334​⋅h441​⋅h142​⋅h243​⋅h344​⋅h4

3 Vision Transformer

本节主要详细介绍

  1. V
  2. i
  3. s
  4. i
  5. o
  6. n
  7. T
  8. r
  9. a
  10. n
  11. s
  12. f
  13. o
  14. r
  15. m
  16. e
  17. r
  18. \mathrm{Vision\text{ }Transformer}
  19. Vision Transformer的工作原理,3.1节是关于
  20. V
  21. i
  22. s
  23. i
  24. o
  25. n
  26. T
  27. r
  28. a
  29. n
  30. s
  31. f
  32. o
  33. r
  34. m
  35. e
  36. r
  37. \mathrm{Vision\text{ }Transformer}
  38. Vision Transformer的整体框架,3.2节是关于
  39. T
  40. r
  41. a
  42. n
  43. s
  44. f
  45. o
  46. r
  47. m
  48. e
  49. r
  50. E
  51. n
  52. c
  53. o
  54. d
  55. e
  56. r
  57. \mathrm{Transformer\text{ }Encoder}
  58. Transformer Encoder的内部操作细节。对于
  59. T
  60. r
  61. a
  62. n
  63. s
  64. f
  65. o
  66. r
  67. m
  68. e
  69. r
  70. E
  71. n
  72. c
  73. o
  74. d
  75. e
  76. r
  77. \mathrm{Transformer\text{ }Encoder}
  78. Transformer Encoder
  79. M
  80. u
  81. l
  82. t
  83. i
  84. \mathrm{Multi}
  85. Multi-
  86. H
  87. e
  88. a
  89. d
  90. A
  91. t
  92. t
  93. e
  94. n
  95. t
  96. i
  97. o
  98. n
  99. \mathrm{Head\text{ }Attention}
  100. Head Attention的原理本文不会赘述,具体想了解的可以参考上一篇文章《Transformer详解(附代码)》中相关原理的介绍。不难发现,不管是自然语言处理中的
  101. T
  102. r
  103. a
  104. n
  105. s
  106. f
  107. o
  108. r
  109. m
  110. e
  111. r
  112. \mathrm{Transformer}
  113. Transformer,还是计算机视觉中图像生成的
  114. S
  115. A
  116. G
  117. A
  118. N
  119. \mathrm{SAGAN}
  120. SAGAN,以及文本生成图像的
  121. A
  122. t
  123. t
  124. n
  125. G
  126. A
  127. N
  128. \mathrm{AttnGAN}
  129. AttnGAN,它们核心模块中注意力机制的主要目的就是求出注意力分布。

3.1 Vision Transformer整体框架

如果下图所示为

  1. V
  2. i
  3. s
  4. i
  5. o
  6. n
  7. T
  8. r
  9. a
  10. n
  11. s
  12. f
  13. o
  14. r
  15. m
  16. e
  17. r
  18. \mathrm{Vision\text{ }Transformer}
  19. Vision Transformer的整体框架以及相应的训练流程
  • 给定一张图片 X ∈ R 3 n × 3 n X\in \mathbb{R}^{3n\times 3n} X∈R3n×3n,并将它分割成 9 9 9个 p a t c h \mathrm{patch} patch分别为 x 1 , ⋯   , x 9 ∈ R n × n x^1,\cdots,x^9\in\mathbb{R}^{n\times n} x1,⋯,x9∈Rn×n。然后再将这个 9 9 9个 p a t c h \mathrm{patch} patch拉平,则有 x 1 , ⋯   , x 9 ∈ R n 2 x^1,\cdots,x^9\in\mathbb{R}^{n^2} x1,⋯,x9∈Rn2
  • 利用矩阵 W ∈ R l × n 2 W\in \mathbb{R}^{l \times n^2} W∈Rl×n2将拉平后的向量 x i ∈ R n 2 , i ∈ { 1 , ⋯   , 9 } x^i\in\mathbb{R}^{n^2},i\in{1,\cdots,9} xi∈Rn2,i∈{1,⋯,9}经过线性变换得到图像编码向量 z i ∈ R l , i ∈ { 1 , ⋯   , 9 } z^i\in \mathbb{R}^{l},i\in{1,\cdots,9} zi∈Rl,i∈{1,⋯,9},具体的计算公式为 z i = W ⋅ x i , i ∈ { 1 , ⋯ 9 } z^i = W\cdot x^i,\quad i\in{1,\cdots9} zi=W⋅xi,i∈{1,⋯9}
  • 然后将图像编码向量 z i , i ∈ { 1 , ⋅ , 9 } z^{i},i\in{1,\cdot,9} zi,i∈{1,⋅,9}和类编码向量 z 0 z^0 z0分别与对应的位置编进行加和得到输入编码向量,则有 z i + p i ∈ R l , i ∈ { 0 , ⋯ 9 } z^{i}+p^{i}\in\mathbb{R}^l,\quad i\in{0,\cdots 9} zi+pi∈Rl,i∈{0,⋯9}
  • 接着将输入编码向量输入到 V i s i o n T r a n s f o r m e r E n c o d e r \mathrm{Vision\text{ }Transformer\text{ }Encoder} Vision Transformer Encoder中得到对应的输出 o i ∈ R l , i ∈ { 0 , ⋯   , 9 } o^i\in \mathbb{R}^l,i\in{0,\cdots,9} oi∈Rl,i∈{0,⋯,9}
  • 最后将类编码向量 o 0 o^0 o0输入全连接神经网络中 M L P \mathrm{MLP} MLP得到类别预测向量 y ^ ∈ R c \hat{y}\in\mathbb{R}^c y^​∈Rc,并与真实类别向量 y ∈ R c y\in\mathbb{R}^c y∈Rc计算交叉熵损失得到损失值 l o s s loss loss,利用优化算法更新模型的权重参数

注意事项: 看到这里可能会有一个疑问为什么预测类别的时候只用到了类别编码向量

  1. o
  2. 0
  3. o^0
  4. o0
  5. V
  6. i
  7. s
  8. i
  9. o
  10. n
  11. T
  12. r
  13. a
  14. n
  15. s
  16. f
  17. o
  18. r
  19. m
  20. e
  21. r
  22. E
  23. n
  24. c
  25. o
  26. d
  27. e
  28. r
  29. \mathrm{Vision\text{ }Transformer\text{ }Encoder}
  30. Vision Transformer Encoder其它的输出为什么没有输入到
  31. M
  32. L
  33. P
  34. \mathrm{MLP}
  35. MLP中?为了回答这个问题,我们令函数
  36. f
  37. 0
  38. (
  39. )
  40. f_0(\cdot)
  41. f0​(⋅)为
  42. V
  43. i
  44. s
  45. i
  46. o
  47. n
  48. T
  49. r
  50. a
  51. n
  52. s
  53. f
  54. o
  55. r
  56. m
  57. e
  58. r
  59. E
  60. n
  61. c
  62. o
  63. d
  64. e
  65. r
  66. \mathrm{Vision\text{ }Transformer\text{ }Encoder}
  67. Vision Transformer Encoder,则类编码向量
  68. o
  69. 0
  70. o^{0}
  71. o0可以表示为
  72. o
  73. 0
  74. =
  75. f
  76. 0
  77. (
  78. z
  79. 0
  80. +
  81. p
  82. 0
  83. ,
  84. ,
  85. z
  86. 9
  87. +
  88. p
  89. 9
  90. )
  91. o^0=f_0(z^0+p^0,\cdots,z^9+p^9)
  92. o0=f0​(z0+p0,⋯,z9+p9)由上公式可以发现,类编码向量
  93. o
  94. 0
  95. o^{0}
  96. o0是属于高层特征,其实它综合了所有的图像编码信息,所以可以用它来进行分类,这个可以类比在卷积神经网络中最后的类别输出向量其实就是一层层卷积得到的高层特征。

3.2 Transformer Encoder操作原理

如下图所示分别为

  1. V
  2. i
  3. s
  4. i
  5. o
  6. n
  7. T
  8. r
  9. a
  10. n
  11. s
  12. f
  13. o
  14. r
  15. m
  16. e
  17. r
  18. E
  19. n
  20. c
  21. o
  22. d
  23. e
  24. r
  25. \mathrm{Vision\text{ }Transformer\text{ }Encoder}
  26. Vision Transformer Encoder模型结构图和原始
  27. T
  28. r
  29. a
  30. n
  31. s
  32. f
  33. o
  34. r
  35. m
  36. e
  37. r
  38. E
  39. n
  40. c
  41. o
  42. d
  43. e
  44. r
  45. \mathrm{Transformer\text{ }Encoder}
  46. Transformer Encoder的模型结构图。可以直观的发现
  47. V
  48. i
  49. s
  50. i
  51. o
  52. n
  53. T
  54. r
  55. a
  56. n
  57. s
  58. f
  59. o
  60. r
  61. m
  62. e
  63. r
  64. E
  65. n
  66. c
  67. o
  68. d
  69. e
  70. r
  71. \mathrm{Vision\text{ }Transformer\text{ }Encoder}
  72. Vision Transformer Encoder
  73. T
  74. r
  75. a
  76. n
  77. s
  78. f
  79. o
  80. r
  81. m
  82. e
  83. r
  84. E
  85. n
  86. c
  87. o
  88. d
  89. e
  90. r
  91. \mathrm{Transformer\text{ }Encoder}
  92. Transformer Encoder都有层归一化,多头注意力机制,残差连接和线性变换这四个操作,只是在操作顺序有所不同。在以下的
  93. T
  94. r
  95. a
  96. n
  97. s
  98. f
  99. o
  100. r
  101. m
  102. e
  103. r
  104. \mathrm{ \text{ }Transformer}
  105. Transformer代码实例中,将以下两种
  106. E
  107. n
  108. c
  109. o
  110. d
  111. e
  112. r
  113. \mathrm{Encoder}
  114. Encoder网络结构都进行了实现,可以发现两种网络结构都可以进行很好的训练。

下图左半部分

  1. V
  2. i
  3. s
  4. i
  5. o
  6. n
  7. T
  8. r
  9. a
  10. n
  11. s
  12. f
  13. o
  14. r
  15. m
  16. e
  17. r
  18. E
  19. n
  20. c
  21. o
  22. d
  23. e
  24. r
  25. \mathrm{Vision\text{ }Transformer\text{ }Encoder}
  26. Vision Transformer Encoder具体的操作流程为
  • 给定输入编码矩阵 Z ∈ R l × n Z\in\mathbb{R}^{l\times n} Z∈Rl×n,首先将其进行层归一化得到 Z ′ ∈ R l × n Z^{\prime}\in\mathbb{R}^{l \times n} Z′∈Rl×n
  • 利用矩阵 W q , W k , W v ∈ R l × l W^{q},W^{k},W^{v}\in \mathbb{R}^{l\times l} Wq,Wk,Wv∈Rl×l对 Z ′ Z^{\prime} Z′进行线性变换得到矩阵 Q , K , W ∈ R l × n Q,K,W\in\mathbb{R}^{l\times n} Q,K,W∈Rl×n具体的计算过程为 { Q = W q ⋅ Z ′ K = W k ⋅ Z ′ V = W v ⋅ Z ′ \left{\begin{aligned}Q &= W^{q}\cdot Z^{\prime}\K&=W^{k}\cdot Z^{\prime}\V&=W^v \cdot Z^{\prime}\end{aligned}\right. ⎩⎪⎨⎪⎧​QKV​=Wq⋅Z′=Wk⋅Z′=Wv⋅Z′​再将这三个矩阵输入到 M u l t i \mathrm{Multi} Multi- H e a d A t t e n t i o n \mathrm{Head\text{ }Attention} Head Attention(该原理参考《Transformer详解(附代码)》)中得到矩阵 Z ′ ′ ∈ R l × n Z^{\prime\prime}\in \mathbb{R}^{l \times n} Z′′∈Rl×n将最原始的输入矩阵 Z Z Z与 Z ′ ′ Z^{\prime\prime} Z′′进行残差计算得到 Z + Z ′ ′ ∈ R l × n Z+Z^{\prime\prime}\in \mathbb{R}^{l\times n} Z+Z′′∈Rl×n
  • 将 Z + Z ′ ′ Z+Z^{\prime\prime} Z+Z′′进行第二次层归一化得到 Z ′ ′ ′ ∈ R l × n Z^{\prime\prime\prime}\in\mathbb{R}^{l\times n} Z′′′∈Rl×n,然后再将 Z ′ ′ ′ Z^{\prime\prime\prime} Z′′′输入到全连接神经网络中进行线性变换得到 Z ′ ′ ′ ′ ∈ R l × n Z^{\prime\prime\prime\prime}\in\mathbb{R}^{l\times n} Z′′′′∈Rl×n。最后将 Z + Z ′ ′ Z+Z^{\prime\prime} Z+Z′′与 Z ′ ′ ′ ′ Z^{\prime\prime\prime\prime} Z′′′′进行残差操作得到该 B l o c k \mathrm{Block} Block的输出 Z + Z ′ ′ + Z ′ ′ ′ ′ ∈ R l × n Z+Z^{\prime\prime}+Z^{\prime\prime\prime\prime}\in\mathbb{R}^{l\times n} Z+Z′′+Z′′′′∈Rl×n。一个 E n c o d e r \mathrm{Encoder} Encoder可以将 N N N个 B l o c k \mathrm{Block} Block进行堆叠,最后得到的输出为 O ∈ R l × n O\in\mathbb{R}^{l\times n} O∈Rl×n。

4 程序代码

  1. V
  2. i
  3. s
  4. i
  5. o
  6. n
  7. T
  8. r
  9. a
  10. n
  11. s
  12. f
  13. o
  14. r
  15. m
  16. e
  17. r
  18. \mathrm{Vision\text{ }Transformer}
  19. Vision Transformer的代码示例如下所示。该代码是由上一篇《Transformer详解(附代码)》的代码的基础上改编而来。
  20. V
  21. i
  22. s
  23. i
  24. o
  25. n
  26. T
  27. r
  28. a
  29. n
  30. s
  31. f
  32. o
  33. r
  34. m
  35. e
  36. r
  37. \mathrm{Vision\text{ }Transformer}
  38. Vision Transformer的作者的本意就是想让在
  39. N
  40. L
  41. P
  42. \mathrm{NLP}
  43. NLP中的
  44. T
  45. r
  46. a
  47. n
  48. s
  49. f
  50. o
  51. r
  52. m
  53. e
  54. r
  55. \mathrm{Transformer}
  56. Transformer模型架构做尽可能少的修改可以直接迁移到
  57. C
  58. V
  59. \mathrm{CV}
  60. CV中,所以以下程序尽可能保持作者的原意,并在代码实现了两种
  61. E
  62. n
  63. c
  64. o
  65. d
  66. e
  67. r
  68. \mathrm{Encoder}
  69. Encoder的网络结构,即3.2节图片所示的两个网络结构,一种是最原始的
  70. E
  71. n
  72. c
  73. o
  74. d
  75. e
  76. r
  77. \mathrm{Encoder}
  78. Encoder网络结构,一种是
  79. V
  80. i
  81. s
  82. i
  83. o
  84. n
  85. T
  86. r
  87. a
  88. n
  89. s
  90. f
  91. o
  92. r
  93. m
  94. e
  95. r
  96. \mathrm{Vision\text{ }Transformer}
  97. Vision Transformer论文里的
  98. E
  99. n
  100. c
  101. o
  102. d
  103. e
  104. r
  105. \mathrm{Encoder}
  106. Encoder的网络结构。这里需要注意的是,
  107. V
  108. i
  109. s
  110. i
  111. o
  112. n
  113. T
  114. r
  115. a
  116. n
  117. s
  118. f
  119. o
  120. r
  121. m
  122. e
  123. r
  124. \mathrm{Vision\text{ }Transformer}
  125. Vision Transformer里并能没有
  126. D
  127. e
  128. c
  129. o
  130. d
  131. e
  132. r
  133. \mathrm{Decoder}
  134. Decoder模块,所以不需要计算
  135. E
  136. n
  137. c
  138. o
  139. d
  140. e
  141. r
  142. \mathrm{Encoder}
  143. Encoder
  144. D
  145. e
  146. c
  147. o
  148. d
  149. e
  150. r
  151. \mathrm{Decoder}
  152. Decoder的交叉注意力分布,这就进一步给
  153. V
  154. i
  155. s
  156. i
  157. o
  158. n
  159. T
  160. r
  161. a
  162. n
  163. s
  164. f
  165. o
  166. r
  167. m
  168. e
  169. r
  170. \mathrm{Vision\text{ }Transformer}
  171. Vision Transformer的编程带来了简便。
  172. V
  173. i
  174. s
  175. i
  176. o
  177. n
  178. T
  179. r
  180. a
  181. n
  182. s
  183. f
  184. o
  185. r
  186. m
  187. e
  188. r
  189. \mathrm{Vision\text{ }Transformer}
  190. Vision Transformer的开源代码的网址为https://github.com/lucidrains/vit-pytorch/tree/main/vit_pytorch。
  1. import torch
  2. import torch.nn as nn
  3. import os
  4. from einops import rearrange
  5. from einops import repeat
  6. from einops.layers.torch import Rearrange
  7. definputs_deal(inputs):return inputs ifisinstance(inputs,tuple)else(inputs, inputs)classSelfAttention(nn.Module):def__init__(self, embed_size, heads):super(SelfAttention, self).__init__()
  8. self.embed_size = embed_size
  9. self.heads = heads
  10. self.head_dim = embed_size // heads
  11. assert(self.head_dim * heads == embed_size),"Embed size needs to be div by heads"
  12. self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
  13. self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
  14. self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
  15. self.fc_out = nn.Linear(heads * self.head_dim, embed_size)defforward(self, values, keys, query):
  16. N =query.shape[0]
  17. value_len , key_len , query_len = values.shape[1], keys.shape[1], query.shape[1]# split embedding into self.heads pieces
  18. values = values.reshape(N, value_len, self.heads, self.head_dim)
  19. keys = keys.reshape(N, key_len, self.heads, self.head_dim)
  20. queries = query.reshape(N, query_len, self.heads, self.head_dim)
  21. values = self.values(values)
  22. keys = self.keys(keys)
  23. queries = self.queries(queries)
  24. energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)# queries shape: (N, query_len, heads, heads_dim)# keys shape : (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)
  25. attention = torch.softmax(energy/(self.embed_size **(1/2)), dim=3)
  26. out = torch.einsum("nhql, nlhd->nqhd",[attention, values]).reshape(N, query_len, self.heads*self.head_dim)# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, heads_dim)# (N, query_len, heads, head_dim)
  27. out = self.fc_out(out)return out
  28. classTransformerBlock(nn.Module):def__init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()
  29. self.attention = SelfAttention(embed_size, heads)
  30. self.norm = nn.LayerNorm(embed_size)
  31. self.feed_forward = nn.Sequential(
  32. nn.Linear(embed_size, forward_expansion*embed_size),
  33. nn.ReLU(),
  34. nn.Linear(forward_expansion*embed_size, embed_size))
  35. self.dropout = nn.Dropout(dropout)defforward(self, value, key, query, x, type_mode):if type_mode =='original':
  36. attention = self.attention(value, key, query)
  37. x = self.dropout(self.norm(attention + x))
  38. forward = self.feed_forward(x)
  39. out = self.dropout(self.norm(forward + x))return out
  40. else:
  41. attention = self.attention(self.norm(value), self.norm(key), self.norm(query))
  42. x =self.dropout(attention + x)
  43. forward = self.feed_forward(self.norm(x))
  44. out = self.dropout(forward + x)return out
  45. classTransformerEncoder(nn.Module):def__init__(
  46. self,
  47. embed_size,
  48. num_layers,
  49. heads,
  50. forward_expansion,
  51. dropout =0,
  52. type_mode ='original'):super(TransformerEncoder, self).__init__()
  53. self.embed_size = embed_size
  54. self.type_mode = type_mode
  55. self.Query_Key_Value = nn.Linear(embed_size, embed_size *3, bias =False)
  56. self.layers = nn.ModuleList([
  57. TransformerBlock(
  58. embed_size,
  59. heads,
  60. dropout=dropout,
  61. forward_expansion=forward_expansion,)for _ inrange(num_layers)])
  62. self.dropout = nn.Dropout(dropout)defforward(self, x):for layer in self.layers:
  63. QKV_list = self.Query_Key_Value(x).chunk(3, dim =-1)
  64. x = layer(QKV_list[0], QKV_list[1], QKV_list[2], x, self.type_mode)return x
  65. classVisionTransformer(nn.Module):def__init__(self,
  66. image_size,
  67. patch_size,
  68. num_classes,
  69. embed_size,
  70. num_layers,
  71. heads,
  72. mlp_dim,
  73. pool ='cls',
  74. channels =3,
  75. dropout =0,
  76. emb_dropout =0.1,
  77. type_mode ='vit'):super(VisionTransformer, self).__init__()
  78. img_h, img_w = inputs_deal(image_size)
  79. patch_h, patch_w = inputs_deal(patch_size)assert img_h % patch_h ==0and img_w % patch_w ==0,'Img dimensions can be divisible by the patch dimensions'
  80. num_patches =(img_h // patch_h)*(img_w // patch_w)
  81. patch_size = channels * patch_h * patch_w
  82. self.patch_embedding = nn.Sequential(
  83. Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_h, p2=patch_w),
  84. nn.Linear(patch_size, embed_size, bias=False))
  85. self.pos_embedding = nn.Parameter(torch.randn(1, num_patches +1, embed_size))
  86. self.cls_token = nn.Parameter(torch.randn(1,1, embed_size))
  87. self.dropout = nn.Dropout(emb_dropout)
  88. self.transformer = TransformerEncoder(embed_size,
  89. num_layers,
  90. heads,
  91. mlp_dim,
  92. dropout)
  93. self.pool = pool
  94. self.to_latent = nn.Identity()
  95. self.mlp_head = nn.Sequential(
  96. nn.LayerNorm(embed_size),
  97. nn.Linear(embed_size, num_classes))defforward(self, img):
  98. x = self.patch_embedding(img)
  99. b, n, _ = x.shape
  100. cls_tokens = repeat(self.cls_token,'() n d ->b n d', b = b)
  101. x = torch.cat((cls_tokens, x), dim =1)
  102. x += self.pos_embedding[:,:(n +1)]
  103. x = self.dropout(x)
  104. x = self.transformer(x)
  105. x = x.mean(dim =1)if self.pool =='mean'else x[:,0]
  106. x = self.to_latent(x)return self.mlp_head(x)if __name__ =='__main__':
  107. vit = VisionTransformer(
  108. image_size =256,
  109. patch_size =16,
  110. num_classes =10,
  111. embed_size =256,
  112. num_layers =6,
  113. heads =8,
  114. mlp_dim =512,
  115. dropout =0.1,
  116. emb_dropout =0.1)
  117. img = torch.randn(3,3,256,256)
  118. pred = vit(img)print(pred)

以下代码是利用

  1. V
  2. i
  3. s
  4. i
  5. o
  6. n
  7. T
  8. r
  9. a
  10. n
  11. s
  12. f
  13. o
  14. r
  15. m
  16. e
  17. r
  18. \mathrm{Vision \text{ }Transformer}
  19. Vision Transformer网络结构训练一个分类
  20. m
  21. n
  22. i
  23. s
  24. t
  25. \mathrm{mnist}
  26. mnist数据集的主程序代码。
  1. from torchvision import datasets, transforms
  2. from torch.utils.data import DataLoader, Dataset
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. import torch.nn.functional as F
  7. import VIT
  8. import os
  9. deftrain():
  10. batch_size =4
  11. device = torch.device('cuda'if torch.cuda.is_available()else'cpu')
  12. epoches =20
  13. mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
  14. train_loader = torch.utils.data.DataLoader(mnist_train, batch_size= batch_size, shuffle=True)
  15. mnist_model = VIT.VisionTransformer(
  16. image_size =28,
  17. patch_size =7,
  18. num_classes =10,
  19. channels =1,
  20. embed_size =512,
  21. num_layers =1,
  22. heads =2,
  23. mlp_dim =1024,
  24. dropout =0,
  25. emb_dropout =0)
  26. loss_fn = nn.CrossEntropyLoss()
  27. mnist_model = mnist_model.to(device)
  28. opitimizer = optim.Adam(mnist_model.parameters(), lr=0.00001)
  29. mnist_model.train()for epoch inrange(epoches):
  30. total_loss =0
  31. corrects =0
  32. num =0for batch_X, batch_Y in train_loader:
  33. batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
  34. opitimizer.zero_grad()
  35. outputs = mnist_model(batch_X)
  36. _, pred = torch.max(outputs.data,1)
  37. loss = loss_fn(outputs, batch_Y)
  38. loss.backward()
  39. opitimizer.step()
  40. total_loss += loss.item()
  41. corrects = torch.sum(pred == batch_Y.data)
  42. num += batch_size
  43. print(epoch, total_loss/float(num), corrects.item()/float(batch_size))if __name__ =='__main__':
  44. train()

训练的过程如下所示,可以发现损失函数可以稳定下降。但是训练一个

  1. V
  2. i
  3. s
  4. i
  5. o
  6. n
  7. T
  8. r
  9. a
  10. n
  11. s
  12. f
  13. o
  14. r
  15. m
  16. e
  17. r
  18. \mathrm{Vision \text{ }Transformer}
  19. Vision Transformer模型真的是很烧硬件,跟训练一个普通的
  20. C
  21. N
  22. N
  23. \mathrm{CNN}
  24. CNN模型相比,训练一个
  25. V
  26. i
  27. s
  28. i
  29. o
  30. n
  31. T
  32. r
  33. a
  34. n
  35. s
  36. f
  37. o
  38. r
  39. m
  40. e
  41. r
  42. \mathrm{Vision \text{ }Transformer}
  43. Vision Transformer模型更加耗时耗力。


本文转载自: https://blog.csdn.net/qq_38406029/article/details/122157116
版权归原作者 鬼道2022 所有, 如有侵权,请联系我们删除。

“Vision Transformer详解(附代码)”的评论:

还没有评论