Implementing a Complete Vision Transformer in JAX

This article shows how to implement the Vision Transformer (ViT) with JAX/Flax, and how to train it with JAX/Flax.

Vision Transformer

When implementing the Vision Transformer, keep the architecture figure from the paper in mind.

The paper describes the ViT forward pass as follows:

1. Extract patches from the input image and flatten each patch into a vector.

2. Linearly project the patch vectors to the dimension processed by the Transformer Encoder.

3. Prepend a learnable embedding (the [class] token) and add position embeddings.

4. Encode the sequence with the Transformer Encoder.

5. Use the [class] token output as the input to an MLP head for classification.
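These steps can be traced purely in terms of array shapes. Below is a minimal numpy sketch of the shape bookkeeping, assuming the CIFAR-10 configuration used later in this article (32×32 images, patch size 16, hidden dimension 192); the arrays are zeros because only the shapes matter here:

```python
import numpy as np

# Assumed example settings (matching this article's CIFAR-10 configuration)
image_size, patch_size, hidden_dim = 32, 16, 192

# 1. Patch extraction: a 32x32 image yields (32/16)^2 = 4 flattened patches
num_patches = (image_size // patch_size) ** 2
patches = np.zeros((1, num_patches, patch_size * patch_size * 3))  # (B, 4, 768)

# 2. Linear projection to hidden_dim
x = np.zeros((1, num_patches, hidden_dim))                         # (B, 4, 192)

# 3. Prepend the [class] token -> sequence length grows by one
cls = np.zeros((1, 1, hidden_dim))
x = np.concatenate([cls, x], axis=1)                               # (B, 5, 192)

# 4. The Transformer Encoder preserves the (batch, tokens, hidden) shape
# 5. Classification uses only the [class] token output
cls_out = x[:, 0]                                                  # (B, 192)
print(num_patches, x.shape, cls_out.shape)
```

Steps 2 and 4 preserve the (batch, tokens, hidden) layout; only step 3 changes the sequence length.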

Implementation Details

Below, we build each module with JAX/Flax.

1. From Image to Flattened Image Patches

The code below extracts image patches from the input image. This is implemented as a convolution with kernel size patch_size × patch_size and stride patch_size × patch_size, so the patches do not overlap.

    import jax.numpy as jnp
    import flax.linen as nn

    class Patches(nn.Module):
        patch_size: int
        embed_dim: int

        def setup(self):
            self.conv = nn.Conv(
                features=self.embed_dim,
                kernel_size=(self.patch_size, self.patch_size),
                strides=(self.patch_size, self.patch_size),
                padding='VALID'
            )

        def __call__(self, images):
            patches = self.conv(images)
            b, h, w, c = patches.shape
            patches = jnp.reshape(patches, (b, h*w, c))
            return patches
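The convolution in Patches fuses two operations: cutting the image into non-overlapping patches and projecting each patch to embed_dim. The cutting step alone can be reproduced with plain reshapes, which is a handy way to sanity-check the sequence layout. This is an illustrative numpy sketch, not part of the article's model:

```python
import numpy as np

def extract_patches(images, patch_size):
    """Reference patch extraction via reshape/transpose (no convolution)."""
    b, h, w, c = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(b, gh, patch_size, gw, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)            # (b, gh, gw, p, p, c)
    return x.reshape(b, gh * gw, patch_size * patch_size * c)

imgs = np.arange(1 * 32 * 32 * 3, dtype=np.float32).reshape(1, 32, 32, 3)
patches = extract_patches(imgs, 16)
print(patches.shape)   # (1, 4, 768): 4 patches, each flattened to 16*16*3
```

The conv in Patches produces (b, 4, embed_dim) instead, because the per-patch flattening and the projection happen in a single strided convolution.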

2 & 3. Linear Projection of Flattened Patches / Adding the [CLS] Token and Position Embeddings

The Transformer Encoder uses the same size hidden_dim in all of its layers, so the patch vectors created above are projected to hidden_dim-dimensional vectors. As in BERT, a [CLS] token is prepended to the sequence, and a learnable position embedding is added to retain positional information.

    class PatchEncoder(nn.Module):
        hidden_dim: int

        @nn.compact
        def __call__(self, x):
            assert x.ndim == 3
            n, seq_len, _ = x.shape
            # Hidden dim
            x = nn.Dense(self.hidden_dim)(x)
            # Add cls token
            cls = self.param('cls_token', nn.initializers.zeros, (1, 1, self.hidden_dim))
            cls = jnp.tile(cls, (n, 1, 1))
            x = jnp.concatenate([cls, x], axis=1)
            # Add position embedding
            pos_embed = self.param(
                'position_embedding',
                nn.initializers.normal(stddev=0.02),  # From BERT
                (1, seq_len + 1, self.hidden_dim)
            )
            return x + pos_embed

4. Transformer Encoder

As shown in the figure above, the encoder consists of alternating multi-head self-attention (MSA) and MLP layers. LayerNorm (LN) is applied before every MSA and MLP block, and a residual connection is added after every block.

    class TransformerEncoder(nn.Module):
        embed_dim: int
        hidden_dim: int
        n_heads: int
        drop_p: float
        mlp_dim: int

        def setup(self):
            self.mha = MultiHeadSelfAttention(self.hidden_dim, self.n_heads, self.drop_p)
            self.mlp = MLP(self.mlp_dim, self.drop_p)
            self.layer_norm = nn.LayerNorm(epsilon=1e-6)

        def __call__(self, inputs, train=True):
            # Attention block
            x = self.layer_norm(inputs)
            x = self.mha(x, train)
            x = inputs + x
            # MLP block
            y = self.layer_norm(x)
            y = self.mlp(y, train)
            return x + y

The MLP is a two-layer network with GELU as the activation function. Following the paper, Dropout is applied after each Dense layer.

    from typing import Optional

    class MLP(nn.Module):
        mlp_dim: int
        drop_p: float
        out_dim: Optional[int] = None

        @nn.compact
        def __call__(self, inputs, train=True):
            actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
            x = nn.Dense(features=self.mlp_dim)(inputs)
            x = nn.gelu(x)
            x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
            x = nn.Dense(features=actual_out_dim)(x)
            x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
            return x
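The GELU activation can be made concrete in numpy. The sketch below implements the tanh approximation of GELU, which is the variant jax.nn.gelu computes by default:

```python
import numpy as np

def gelu(x):
    """Tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

x = np.array([-2.0, 0.0, 2.0])
y = gelu(x)
# GELU is zero at zero, close to the identity for large positive inputs,
# and squashes negative inputs toward zero.
print(np.round(y, 4))
```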

Multi-Head Self-Attention (MSA)

q, k, and v each take the shape [B, N, T, D]. After computing the weights and the attention output per head, as in single-head attention, the heads are merged back to the original dimension [B, T, C=N*D].

    import math

    class MultiHeadSelfAttention(nn.Module):
        hidden_dim: int
        n_heads: int
        drop_p: float

        def setup(self):
            self.q_net = nn.Dense(self.hidden_dim)
            self.k_net = nn.Dense(self.hidden_dim)
            self.v_net = nn.Dense(self.hidden_dim)
            self.proj_net = nn.Dense(self.hidden_dim)
            self.att_drop = nn.Dropout(self.drop_p)
            self.proj_drop = nn.Dropout(self.drop_p)

        def __call__(self, x, train=True):
            B, T, C = x.shape  # batch_size, seq_length, hidden_dim
            N, D = self.n_heads, C // self.n_heads  # num_heads, head_dim
            q = self.q_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)  # (B, N, T, D)
            k = self.k_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)
            v = self.v_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)
            # weights (B, N, T, T)
            weights = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(D)
            normalized_weights = nn.softmax(weights, axis=-1)
            # attention (B, N, T, D)
            attention = jnp.matmul(normalized_weights, v)
            attention = self.att_drop(attention, deterministic=not train)
            # gather heads
            attention = attention.transpose(0, 2, 1, 3).reshape(B, T, N * D)
            # project
            out = self.proj_drop(self.proj_net(attention), deterministic=not train)
            return out
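The core of this module, scaled dot-product attention per head, can be checked in isolation with numpy. The dense projections and dropout are omitted, and random arrays stand in for q, k, and v:

```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    """q, k, v: (B, N, T, D), as in the MSA module above."""
    d = q.shape[-1]
    weights = np.matmul(q, np.swapaxes(k, -2, -1)) / np.sqrt(d)   # (B, N, T, T)
    weights = weights - weights.max(axis=-1, keepdims=True)       # stable softmax
    weights = np.exp(weights)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.matmul(weights, v), weights                          # (B, N, T, D)

rng = np.random.default_rng(0)
B, N, T, D = 2, 3, 5, 64          # batch, heads, tokens, head_dim
q, k, v = (rng.normal(size=(B, N, T, D)) for _ in range(3))
out, w = scaled_dot_product_attention(q, k, v)
print(out.shape, np.allclose(w.sum(axis=-1), 1.0))
```

Each row of the (T, T) weight matrix is a probability distribution over the T tokens, which is why the softmax is taken over the last axis.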

5. Classification Using the [CLS] Embedding

Finally, the MLP head (classification head):

    class ViT(nn.Module):
        patch_size: int
        embed_dim: int
        hidden_dim: int
        n_heads: int
        drop_p: float
        num_layers: int
        mlp_dim: int
        num_classes: int

        def setup(self):
            self.patch_extracter = Patches(self.patch_size, self.embed_dim)
            self.patch_encoder = PatchEncoder(self.hidden_dim)
            self.dropout = nn.Dropout(self.drop_p)
            self.transformer_encoder = TransformerEncoder(
                self.embed_dim, self.hidden_dim, self.n_heads, self.drop_p, self.mlp_dim)
            self.cls_head = nn.Dense(features=self.num_classes)

        def __call__(self, x, train=True):
            x = self.patch_extracter(x)
            x = self.patch_encoder(x)
            x = self.dropout(x, deterministic=not train)
            # Note: the same encoder module is applied num_layers times,
            # so its weights are shared across layers.
            for i in range(self.num_layers):
                x = self.transformer_encoder(x, train)
            # MLP head
            x = x[:, 0]  # [CLS] token
            x = self.cls_head(x)
            return x

Training with JAX/Flax

Now that the model is defined, the next step is to train it with JAX/Flax.

Dataset

We use CIFAR-10 from torchvision directly.

First, some utility functions:

    import numpy as np

    def image_to_numpy(img):
        img = np.array(img, dtype=np.float32)
        img = (img / 255. - DATA_MEANS) / DATA_STD  # per-channel normalization constants
        return img

    def numpy_collate(batch):
        if isinstance(batch[0], np.ndarray):
            return np.stack(batch)
        elif isinstance(batch[0], (tuple, list)):
            transposed = zip(*batch)
            return [numpy_collate(samples) for samples in transposed]
        else:
            return np.array(batch)
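numpy_collate recursively stacks nested tuples and lists, so a batch of (image, label) pairs from torchvision becomes a pair of batched arrays. A quick self-contained check with dummy data (the function is repeated here so the snippet runs on its own):

```python
import numpy as np

def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

# A batch of three (image, label) samples, as a torchvision dataset would yield
batch = [(np.zeros((32, 32, 3), dtype=np.float32), 7) for _ in range(3)]
imgs, labels = numpy_collate(batch)
print(imgs.shape, labels.shape)   # (3, 32, 32, 3) (3,)
```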

Then the training and test dataloaders:

    test_transform = image_to_numpy
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE), scale=CROP_SCALES, ratio=CROP_RATIO),
        image_to_numpy
    ])
    # Validation set should not use the augmentation.
    train_dataset = CIFAR10('data', train=True, transform=train_transform, download=True)
    val_dataset = CIFAR10('data', train=True, transform=test_transform, download=True)
    train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
    _, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
    test_set = CIFAR10('data', train=False, transform=test_transform, download=True)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
    )

Model Initialization

Initialize the ViT model:

    import jax
    import jax.numpy as jnp
    from jax import random

    def initialize_model(
        seed=42,
        patch_size=16, embed_dim=192, hidden_dim=192,
        n_heads=3, drop_p=0.1, num_layers=12, mlp_dim=768, num_classes=10
    ):
        main_rng = jax.random.PRNGKey(seed)
        x = jnp.ones(shape=(5, 32, 32, 3))
        # ViT
        model = ViT(
            patch_size=patch_size,
            embed_dim=embed_dim,
            hidden_dim=hidden_dim,
            n_heads=n_heads,
            drop_p=drop_p,
            num_layers=num_layers,
            mlp_dim=mlp_dim,
            num_classes=num_classes
        )
        main_rng, init_rng, drop_rng = random.split(main_rng, 3)
        params = model.init({'params': init_rng, 'dropout': drop_rng}, x, train=True)['params']
        return model, params, main_rng

    vit_model, vit_params, vit_rng = initialize_model()

Creating the TrainState

A common pattern in Flax is to create a class that manages the training state, including the step count, optimizer state, and model parameters. Setting apply_fn to the model's forward pass also shortens the argument lists of the functions in the training loop.

    import optax
    from flax.training import train_state

    def create_train_state(
        model, params, learning_rate
    ):
        optimizer = optax.adam(learning_rate)
        return train_state.TrainState.create(
            apply_fn=model.apply,
            tx=optimizer,
            params=params
        )

    state = create_train_state(vit_model, vit_params, 3e-4)

Training Loop

    def train_model(train_loader, val_loader, state, rng, num_epochs=100):
        best_eval = 0.0
        for epoch_idx in tqdm(range(1, num_epochs + 1)):
            state, rng = train_epoch(train_loader, epoch_idx, state, rng)
            if epoch_idx % 1 == 0:
                eval_acc = eval_model(val_loader, state, rng)
                logger.add_scalar('val/acc', eval_acc, global_step=epoch_idx)
                if eval_acc >= best_eval:
                    best_eval = eval_acc
                    save_model(state, step=epoch_idx)
                logger.flush()
        # Evaluate after training
        test_acc = eval_model(test_loader, state, rng)
        print(f'test_acc: {test_acc}')

    def train_epoch(train_loader, epoch_idx, state, rng):
        metrics = defaultdict(list)
        for batch in tqdm(train_loader, desc='Training', leave=False):
            state, rng, loss, acc = train_step(state, rng, batch)
            metrics['loss'].append(loss)
            metrics['acc'].append(acc)
        for key in metrics.keys():
            arg_val = np.stack(jax.device_get(metrics[key])).mean()
            logger.add_scalar('train/' + key, arg_val, global_step=epoch_idx)
            print(f'[epoch {epoch_idx}] {key}: {arg_val}')
        return state, rng

Evaluation

    def eval_model(data_loader, state, rng):
        # Test model on all images of a data loader and return avg accuracy
        correct_class, count = 0, 0
        for batch in data_loader:
            rng, acc = eval_step(state, rng, batch)
            correct_class += acc * batch[0].shape[0]
            count += batch[0].shape[0]
        eval_acc = (correct_class / count).item()
        return eval_acc

Training Step

The loss function is defined inside train_step; the gradients of the loss with respect to the model parameters are computed and used to update the parameters. jax.value_and_grad computes both the loss value and its gradients with respect to state.params, and apply_gradients updates the TrainState. The cross-entropy loss is computed from the logits produced by apply_fn (identical to model.apply), which was specified when the TrainState was created.

    @jax.jit
    def train_step(state, rng, batch):
        loss_fn = lambda params: calculate_loss(params, state, rng, batch, train=True)
        # Get loss, gradients for loss, and other outputs of loss function
        (loss, (acc, rng)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        # Update parameters and batch statistics
        state = state.apply_gradients(grads=grads)
        return state, rng, loss, acc

Loss Calculation

    def calculate_loss(params, state, rng, batch, train):
        imgs, labels = batch
        rng, drop_rng = random.split(rng)
        logits = state.apply_fn({'params': params}, imgs, train=train, rngs={'dropout': drop_rng})
        loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
        acc = (logits.argmax(axis=-1) == labels).mean()
        return loss, (acc, rng)
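optax.softmax_cross_entropy_with_integer_labels computes -log softmax(logits)[label] for each example. A numpy re-implementation, a sketch rather than the optax source, makes the loss and the accuracy computation concrete:

```python
import numpy as np

def cross_entropy_with_integer_labels(logits, labels):
    """Numpy sketch of the loss above: mean of -log p(correct label)."""
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 3.0, 0.4]])
labels = np.array([0, 1])
loss = float(cross_entropy_with_integer_labels(logits, labels))
acc = (logits.argmax(axis=-1) == labels).mean()
# Both argmax predictions match the labels, so accuracy is 1.0
print(loss, acc)
```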

Results

The training results are shown below. Training took about 1.5 hours on a standard GPU in Colab Pro.

    test_acc: 0.7704000473022461

If you are interested in JAX, the full code for this article is available here:

https://github.com/satojkovic/vit-jax-flax

Author: satojkovic
