Implementing OpenAI CLIP on a Custom Dataset

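In January 2021, OpenAI announced two new models, DALL-E and CLIP, both multimodal models that connect text and images in some way. CLIP stands for Contrastive Language–Image Pre-training, a pre-training method based on contrasting text-image pairs. Why introduce CLIP? Because the currently popular Stable Diffusion is not a single model but a combination of several. One of them is a text encoder that encodes the user's text input, and that text encoder is the text encoder from the CLIP model.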

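Once trained, CLIP can be given an input sentence and retrieve the images most relevant to it. CLIP learns the relationship between a complete sentence and the image it describes. In other words, it is trained on full sentences rather than on discrete class labels such as "car" or "dog", which is crucial for applications. Training on full phrases lets the model learn much more and pick up patterns between photos and text. The authors also showed that, when trained on a sufficiently large dataset of photos and their corresponding sentences, the model can act as a classifier. At release, CLIP's zero-shot classification accuracy on ImageNet, with no fine-tuning at all, was on par with a ResNet-50 trained on ImageNet itself, which shows how useful the model is.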
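
To make the "model as classifier" idea concrete, here is a minimal sketch of zero-shot classification by embedding similarity. It assumes the class prompts and the image have already been encoded into a shared space; the function name `zero_shot_classify` and the random tensors are illustrative stand-ins, not part of CLIP's API or of the code in this article.

```python
import torch
import torch.nn.functional as F

def zero_shot_classify(image_embedding, class_text_embeddings):
    """Return the index of the class prompt closest to the image (cosine similarity)."""
    image_n = F.normalize(image_embedding, dim=-1)
    texts_n = F.normalize(class_text_embeddings, dim=-1)
    similarities = texts_n @ image_n  # shape: (num_classes,)
    return int(similarities.argmax())

# Toy stand-ins for encoder outputs; real embeddings would come from trained encoders.
torch.manual_seed(0)
class_text_embeddings = torch.randn(3, 8)  # e.g. embeddings of 3 class prompts
# Build an "image" embedding that lies very close to the prompt of class 1.
image_embedding = class_text_embeddings[1] + 0.01 * torch.randn(8)

predicted = zero_shot_classify(image_embedding, class_text_embeddings)
print(predicted)
```

The predicted class is simply the prompt whose embedding points in the most similar direction to the image embedding, so no classifier head has to be trained.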



    import os
    import cv2
    import gc
    import numpy as np
    import pandas as pd
    import itertools
    from tqdm.autonotebook import tqdm
    import albumentations as A
    import matplotlib.pyplot as plt
    import torch
    from torch import nn
    import torch.nn.functional as F
    import timm
    from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer

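The next step is data preprocessing and a general config. The config would normally be a plain Python file that holds all the hyperparameters; when working in a Jupyter Notebook, it is a class defined at the top of the notebook.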

    class CFG:
        debug = False
        image_path = "../input/flickr-image-dataset/flickr30k_images/flickr30k_images"
        captions_path = "."
        batch_size = 32
        num_workers = 4
        head_lr = 1e-3
        image_encoder_lr = 1e-4
        text_encoder_lr = 1e-5
        weight_decay = 1e-3
        patience = 1
        factor = 0.8
        epochs = 2
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model_name = 'resnet50'
        image_embedding = 2048
        text_encoder_model = "distilbert-base-uncased"
        text_embedding = 768
        text_tokenizer = "distilbert-base-uncased"
        max_length = 200

        pretrained = True  # for both image encoder and text encoder
        trainable = True  # for both image encoder and text encoder
        temperature = 1.0

        # image size
        size = 224

        # for projection head; used for both image and text encoders
        num_projection_layers = 1
        projection_dim = 256
        dropout = 0.1


    class AvgMeter:
        """Keeps a running, count-weighted average of a metric (e.g. the loss)."""

        def __init__(self, name="Metric"):
            self.name = name
            self.reset()

        def reset(self):
            self.avg, self.sum, self.count = [0] * 3

        def update(self, val, count=1):
            self.count += count
            self.sum += val * count
            self.avg = self.sum / self.count

        def __repr__(self):
            text = f"{self.name}: {self.avg:.4f}"
            return text


    def get_lr(optimizer):
        # Return the learning rate of the first parameter group.
        for param_group in optimizer.param_groups:
            return param_group["lr"]

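Our goal is to encode both images and sentences, so the dataset must return both the sentence and the image. We therefore tokenize the sentences (captions) with the DistilBERT tokenizer and then feed the token ids (input_ids) and attention masks to DistilBERT. DistilBERT is smaller than BERT yet gives comparable results, which is why we use it.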
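
To illustrate what `input_ids` and `attention_mask` look like after padding, here is a toy sketch in plain Python and PyTorch. The helper `pad_and_mask` is hypothetical and is not how the HuggingFace tokenizer is implemented (for one thing, a real tokenizer keeps the special tokens when truncating); the token ids are illustrative only.

```python
import torch

def pad_and_mask(token_id_lists, max_length, pad_id=0):
    """Truncate/pad each list of token ids to one length and build the attention mask."""
    truncated = [ids[:max_length] for ids in token_id_lists]
    target_len = max(len(ids) for ids in truncated)
    input_ids, attention_mask = [], []
    for ids in truncated:
        n_pad = target_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should ignore.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
    }

# Illustrative token ids for two captions of different lengths.
encoded = pad_and_mask([[101, 7, 8, 102], [101, 9, 102]], max_length=200)
print(encoded["input_ids"])
print(encoded["attention_mask"])
```

The shorter caption is padded to the length of the longer one, and its attention mask ends in a zero so that the padded position is ignored by the text encoder.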

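Tokenization is done with the HuggingFace tokenizer. The tokenizer object is passed to __init__ and is loaded when the model is run. Captions are padded and truncated to a predetermined maximum length. In __getitem__, we first load an encoded caption, which is a dictionary with the keys input_ids and attention_mask, and convert its values to tensors. We then load the corresponding image, apply the transforms and augmentations (if any), turn it into a tensor, and store it in the dictionary under the key "image". Finally, we put the raw caption text into the dictionary under the key "caption".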

    class CLIPDataset(torch.utils.data.Dataset):
        def __init__(self, image_filenames, captions, tokenizer, transforms):
            """
            image_filenames and captions must have the same length; so, if there are
            multiple captions for each image, the image_filenames must have repetitive
            file names
            """
            self.image_filenames = image_filenames
            self.captions = list(captions)
            self.encoded_captions = tokenizer(
                list(captions), padding=True, truncation=True, max_length=CFG.max_length
            )
            self.transforms = transforms

        def __getitem__(self, idx):
            # Encoded caption: input_ids and attention_mask as tensors
            item = {
                key: torch.tensor(values[idx])
                for key, values in self.encoded_captions.items()
            }

            # OpenCV loads images as BGR; convert to RGB before transforming
            image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = self.transforms(image=image)['image']
            # HWC -> CHW, as expected by the image encoder
            item['image'] = torch.tensor(image).permute(2, 0, 1).float()
            item['caption'] = self.captions[idx]

            return item

        def __len__(self):
            return len(self.captions)


    def get_transforms(mode="train"):
        if mode == "train":
            return A.Compose(
                [
                    A.Resize(CFG.size, CFG.size, always_apply=True),
                    A.Normalize(max_pixel_value=255.0, always_apply=True),
                ]
            )
        else:
            return A.Compose(
                [
                    A.Resize(CFG.size, CFG.size, always_apply=True),
                    A.Normalize(max_pixel_value=255.0, always_apply=True),
                ]
            )


    class ImageEncoder(nn.Module):
        """
        Encode images to a fixed size vector
        """

        def __init__(
            self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
        ):
            super().__init__()
            self.model = timm.create_model(
                model_name, pretrained, num_classes=0, global_pool="avg"
            )
            for p in self.model.parameters():
                p.requires_grad = trainable

        def forward(self, x):
            return self.model(x)


    class TextEncoder(nn.Module):
        def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
            super().__init__()
            if pretrained:
                self.model = DistilBertModel.from_pretrained(model_name)
            else:
                self.model = DistilBertModel(config=DistilBertConfig())

            for p in self.model.parameters():
                p.requires_grad = trainable

            # we are using the CLS token hidden representation as the sentence's embedding
            self.target_token_idx = 0

        def forward(self, input_ids, attention_mask):
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
            last_hidden_state = output.last_hidden_state
            return last_hidden_state[:, self.target_token_idx, :]


    class ProjectionHead(nn.Module):
        def __init__(
            self,
            embedding_dim,
            projection_dim=CFG.projection_dim,
            dropout=CFG.dropout
        ):
            super().__init__()
            self.projection = nn.Linear(embedding_dim, projection_dim)
            self.gelu = nn.GELU()
            self.fc = nn.Linear(projection_dim, projection_dim)
            self.dropout = nn.Dropout(dropout)
            self.layer_norm = nn.LayerNorm(projection_dim)

        def forward(self, x):
            projected = self.projection(x)
            x = self.gelu(projected)
            x = self.fc(x)
            x = self.dropout(x)
            x = x + projected  # residual connection around the small MLP
            x = self.layer_norm(x)
            return x
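
The image encoder produces 2048-dimensional vectors and the text encoder 768-dimensional ones, so they cannot be compared directly. The following sketch, using random tensors in place of real encoder outputs and a hypothetical `TinyProjection` class that mirrors the head above, suggests how the two projection heads bring both modalities to the same 256-dimensional space.

```python
import torch
from torch import nn

class TinyProjection(nn.Module):
    """Hypothetical stand-in for the projection head: linear map plus residual block."""

    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.dropout(self.fc(self.gelu(projected)))
        return self.layer_norm(x + projected)

image_features = torch.randn(4, 2048)  # random stand-in for pooled ResNet-50 features
text_features = torch.randn(4, 768)    # random stand-in for DistilBERT CLS vectors

image_head = TinyProjection(2048).eval()
text_head = TinyProjection(768).eval()
with torch.no_grad():
    image_embeddings = image_head(image_features)
    text_embeddings = text_head(text_features)
    # Every text can now be compared with every image in the batch.
    logits = text_embeddings @ image_embeddings.T

print(image_embeddings.shape, text_embeddings.shape, logits.shape)
```

Once both sets of embeddings share a dimension, a single matrix product yields a batch-by-batch similarity matrix, which is what the loss operates on.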


    class CLIPModel(nn.Module):
        def __init__(
            self,
            temperature=CFG.temperature,
            image_embedding=CFG.image_embedding,
            text_embedding=CFG.text_embedding,
        ):
            super().__init__()
            self.image_encoder = ImageEncoder()
            self.text_encoder = TextEncoder()
            self.image_projection = ProjectionHead(embedding_dim=image_embedding)
            self.text_projection = ProjectionHead(embedding_dim=text_embedding)
            self.temperature = temperature

        def forward(self, batch):
            # Getting Image and Text Features
            image_features = self.image_encoder(batch["image"])
            text_features = self.text_encoder(
                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
            )
            # Getting Image and Text Embeddings (with same dimension)
            image_embeddings = self.image_projection(image_features)
            text_embeddings = self.text_projection(text_features)

            # Calculating the Loss
            logits = (text_embeddings @ image_embeddings.T) / self.temperature
            images_similarity = image_embeddings @ image_embeddings.T
            texts_similarity = text_embeddings @ text_embeddings.T
            targets = F.softmax(
                (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
            )
            texts_loss = cross_entropy(logits, targets, reduction='none')
            images_loss = cross_entropy(logits.T, targets.T, reduction='none')
            loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
            return loss.mean()


    # A cross-entropy helper that accepts soft (probability) targets
    def cross_entropy(preds, targets, reduction='none'):
        log_softmax = nn.LogSoftmax(dim=-1)
        loss = (-targets * log_softmax(preds)).sum(1)
        if reduction == "none":
            return loss
        elif reduction == "mean":
            return loss.mean()

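A note on the loss: CLIP uses a symmetric cross entropy as its loss function, which reduces the impact of noise and makes the model more robust. For simplicity, we only use a plain hand-written cross entropy here.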
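
For comparison, here is a minimal sketch of a symmetric cross entropy with hard targets, in the spirit of the pseudocode in the CLIP paper: the i-th text is treated as the only correct match for the i-th image, and the loss is averaged over both directions. The function name `symmetric_cross_entropy`, the fixed temperature, and the random embeddings are assumptions for illustration; this is not the loss used by the model code above.

```python
import torch
import torch.nn.functional as F

def symmetric_cross_entropy(text_embeddings, image_embeddings, temperature=1.0):
    """Average of text-to-image and image-to-text cross entropy with identity targets."""
    logits = (text_embeddings @ image_embeddings.T) / temperature
    # Pair i in the batch is the positive; every other pair is a negative.
    labels = torch.arange(logits.size(0), device=logits.device)
    texts_loss = F.cross_entropy(logits, labels)     # each text picks its image
    images_loss = F.cross_entropy(logits.T, labels)  # each image picks its text
    return (texts_loss + images_loss) / 2

torch.manual_seed(0)
text_embeddings = F.normalize(torch.randn(4, 256), dim=-1)
image_embeddings = F.normalize(torch.randn(4, 256), dim=-1)

# Unrelated embeddings give a high loss; perfectly aligned ones give a low loss.
loss_random = symmetric_cross_entropy(text_embeddings, image_embeddings, temperature=0.1)
loss_matched = symmetric_cross_entropy(text_embeddings, text_embeddings, temperature=0.1)
print(loss_random.item(), loss_matched.item())
```

The design difference is in the targets: this sketch uses a one-hot identity matrix, whereas the model code above builds soft targets from image-image and text-text similarities.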


    # A simple Example
    batch_size = 4
    dim = 256
    embeddings = torch.randn(batch_size, dim)
    out = embeddings @ embeddings.T
    print(F.softmax(out, dim=-1))
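
In `CLIPModel.forward`, the logits are divided by the temperature while the target similarities are multiplied by it; with the configured `temperature = 1.0`, both are no-ops. The sketch below, on hand-picked similarity values chosen purely for illustration, suggests how other temperatures would change a softmax: dividing by a small temperature sharpens the distribution, dividing by a large one flattens it.

```python
import torch
import torch.nn.functional as F

# Hand-picked similarities of one query against three candidates (illustrative only).
similarities = torch.tensor([[1.0, 0.5, 0.0]])

for temperature in (2.0, 1.0, 0.1):
    probs = F.softmax(similarities / temperature, dim=-1)
    print(temperature, probs)

soft = F.softmax(similarities / 2.0, dim=-1)   # flatter distribution
sharp = F.softmax(similarities / 0.1, dim=-1)  # close to one-hot
```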


    def make_train_valid_dfs():
        dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
        max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
        image_ids = np.arange(0, max_id)
        # Fixed seed so the 80/20 train/validation split is reproducible
        np.random.seed(42)
        valid_ids = np.random.choice(
            image_ids, size=int(0.2 * len(image_ids)), replace=False
        )
        train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
        train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
        valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
        return train_dataframe, valid_dataframe


    def build_loaders(dataframe, tokenizer, mode):
        transforms = get_transforms(mode=mode)
        dataset = CLIPDataset(
            dataframe["image"].values,
            dataframe["caption"].values,
            tokenizer=tokenizer,
            transforms=transforms,
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            num_workers=CFG.num_workers,
            shuffle=True if mode == "train" else False,
        )
        return dataloader


    def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
        loss_meter = AvgMeter()
        tqdm_object = tqdm(train_loader, total=len(train_loader))
        for batch in tqdm_object:
            # Move tensors to the device; the raw caption string stays behind
            batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
            loss = model(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step == "batch":
                lr_scheduler.step()

            count = batch["image"].size(0)
            loss_meter.update(loss.item(), count)

            tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
        return loss_meter


    def valid_epoch(model, valid_loader):
        loss_meter = AvgMeter()

        tqdm_object = tqdm(valid_loader, total=len(valid_loader))
        for batch in tqdm_object:
            batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
            loss = model(batch)

            count = batch["image"].size(0)
            loss_meter.update(loss.item(), count)

            tqdm_object.set_postfix(valid_loss=loss_meter.avg)
        return loss_meter


    def main():
        train_df, valid_df = make_train_valid_dfs()
        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
        train_loader = build_loaders(train_df, tokenizer, mode="train")
        valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

        model = CLIPModel().to(CFG.device)
        # Separate learning rates for the two encoders and the projection heads
        params = [
            {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
            {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
            {"params": itertools.chain(
                model.image_projection.parameters(), model.text_projection.parameters()
            ), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
        ]
        optimizer = torch.optim.AdamW(params, weight_decay=0.)
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
        )
        step = "epoch"

        best_loss = float('inf')
        for epoch in range(CFG.epochs):
            print(f"Epoch: {epoch + 1}")
            model.train()
            train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
            model.eval()
            with torch.no_grad():
                valid_loss = valid_epoch(model, valid_loader)

            if valid_loss.avg < best_loss:
                best_loss = valid_loss.avg
                torch.save(model.state_dict(), "best.pt")
                print("Saved Best Model!")

            lr_scheduler.step(valid_loss.avg)
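
The training loop hands the validation loss to `ReduceLROnPlateau` once per epoch. As a rough sketch of what the `patience=1` and `factor=0.8` settings from the config imply, the snippet below drives the same scheduler with a dummy parameter and made-up losses that never improve; the learning rate should be multiplied by 0.8 once the loss has failed to improve for more than one epoch in a row.

```python
import torch

# A single dummy parameter is enough to build an optimizer for this illustration.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=1, factor=0.8
)

lrs = []
for valid_loss in [1.0, 1.0, 1.0]:  # made-up validation losses that never improve
    scheduler.step(valid_loss)
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

Because the scheduler needs a metric, it is stepped after validation rather than after every batch, which matches `step = "epoch"` in the training code.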


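How do we actually use the model once training is done? We need a function that loads the trained model, feeds it the images from the validation set, and returns the image_embeddings with shape (valid_set_size, 256) together with the model itself.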

    def get_image_embeddings(valid_df, model_path):
        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
        valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

        model = CLIPModel().to(CFG.device)
        model.load_state_dict(torch.load(model_path, map_location=CFG.device))
        model.eval()

        valid_image_embeddings = []
        with torch.no_grad():
            for batch in tqdm(valid_loader):
                image_features = model.image_encoder(batch["image"].to(CFG.device))
                image_embeddings = model.image_projection(image_features)
                valid_image_embeddings.append(image_embeddings)
        return model, torch.cat(valid_image_embeddings)


    _, valid_df = make_train_valid_dfs()
    model, image_embeddings = get_image_embeddings(valid_df, "best.pt")


    def find_matches(model, image_embeddings, query, image_filenames, n=9):
        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
        encoded_query = tokenizer([query])
        batch = {
            key: torch.tensor(values).to(CFG.device)
            for key, values in encoded_query.items()
        }
        with torch.no_grad():
            text_features = model.text_encoder(
                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
            )
            text_embeddings = model.text_projection(text_features)

        # L2-normalize so the dot product is a cosine similarity
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        dot_similarity = text_embeddings_n @ image_embeddings_n.T

        # Take n * 5 hits and keep every 5th one; this presumably avoids showing the
        # same image several times, since each image is repeated once per caption
        values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
        matches = [image_filenames[idx] for idx in indices[::5]]

        _, axes = plt.subplots(3, 3, figsize=(10, 10))
        for match, ax in zip(matches, axes.flatten()):
            image = cv2.imread(f"{CFG.image_path}/{match}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            ax.imshow(image)
            ax.axis("off")

        plt.show()


    find_matches(model,
                 image_embeddings,
                 query="one dog sitting on the grass",
                 image_filenames=valid_df['image'].values,
                 n=9)
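
The retrieval step inside `find_matches` boils down to a cosine-similarity ranking. The sketch below isolates that step on random tensors, so it runs without a trained model or image files; the helper name `top_matches` and the toy data are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def top_matches(query_embedding, image_embeddings, n=3):
    """Return indices of the n images with the highest cosine similarity to the query."""
    images_n = F.normalize(image_embeddings, p=2, dim=-1)
    query_n = F.normalize(query_embedding, p=2, dim=-1)
    similarity = query_n @ images_n.T  # shape: (1, num_images)
    _, indices = torch.topk(similarity.squeeze(0), n)
    return indices.tolist()

torch.manual_seed(0)
image_embeddings = torch.randn(10, 256)        # random stand-ins for projected images
# A "query" pointing in the same direction as image 7, but with a different norm.
query_embedding = image_embeddings[7:8] * 3.0

matches = top_matches(query_embedding, image_embeddings)
print(matches)
```

Because both sides are L2-normalized first, the ranking depends only on direction, so rescaling the query does not change which image comes out on top.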




Author: Jyoti Dabass, Ph.D

