1. Introduction
The task is few-shot classification on the Omniglot dataset; the goal is to use meta-learning to find a good set of initialization parameters.
Task: few-shot classification
The Omniglot dataset
- background set: 30 alphabets
- evaluation set: 20 alphabets
Problem setup: 5-way 1-shot classification
We train MAML on the Omniglot classification task.
Training / validation set: 30 alphabets
- multiple characters in one alphabet
- 20 images for one character
Testing set: 640 support-and-query pairs
- 5 support images
- 5 query images
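The episode layout above can be sketched as follows. This is an illustration with made-up tensors (the names are ours, not from the course code): each meta-batch stacks the support images first, then the query images, which is exactly how the solvers below slice it.

```python
import torch

# Hypothetical 5-way 1-shot episode of 1x28x28 Omniglot-sized images:
# the first n_way * k_shot images are the support set, the rest are queries.
n_way, k_shot, q_query = 5, 1, 1
episode = torch.randn(n_way * (k_shot + q_query), 1, 28, 28)

support_set = episode[: n_way * k_shot]  # 5 support images, one per class
query_set = episode[n_way * k_shot :]    # 5 query images

print(support_set.shape)  # torch.Size([5, 1, 28, 28])
print(query_set.shape)    # torch.Size([5, 1, 28, 28])
```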
2. Experiments
1. simple
A simple transfer-learning baseline.
Training: run ordinary classification training on 5 randomly chosen tasks.
Validation and testing: fine-tune on the 5 support images, then run inference on the query images.
2. medium
Complete the TODO blocks for the inner and outer loops of meta-learning, using FO-MAML. Set solver = 'meta' and raise the number of epochs to 120. FO-MAML is a simplified version of MAML that saves training time by ignoring the second-order effect of the inner-loop gradients on the result.
# TODO: Finish the inner loop update rule
grads = torch.autograd.grad(loss, fast_weights.values())
fast_weights = OrderedDict(
    (name, param - inner_lr * grad)
    for ((name, param), grad) in zip(fast_weights.items(), grads)
)

# TODO: Finish the outer loop update
meta_batch_loss.backward()
optimizer.step()
3. strong
Use full MAML, which can compute higher-order gradients: by passing create_graph=True, MAML backpropagates through the inner-loop gradients themselves.
# TODO: Finish the inner loop update rule
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
fast_weights = OrderedDict(
    (name, param - inner_lr * grad)
    for ((name, param), grad) in zip(fast_weights.items(), grads)
)
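The effect of create_graph=True can be seen on a toy scalar problem (this is our own illustration, not the course model): with it, the inner-loop gradient stays in the autograd graph, so the outer loss backpropagates through the adaptation step; this is exactly the difference between MAML and FO-MAML.

```python
import torch

theta = torch.tensor(2.0, requires_grad=True)
inner_lr = 0.1

# Inner loop: one SGD step on L(theta) = theta^2, keeping the graph.
inner_loss = theta ** 2
(g,) = torch.autograd.grad(inner_loss, theta, create_graph=True)
fast_theta = theta - inner_lr * g  # = theta - 0.1 * 2*theta = 0.8 * theta

# Outer loop: loss on the adapted weights.
outer_loss = fast_theta ** 2       # = 0.64 * theta^2
outer_loss.backward()
print(theta.grad)                  # 1.28 * theta = 2.56 (full MAML)
# Without create_graph=True, g would be a constant (4.0), giving the
# first-order gradient 2 * (theta - 0.4) = 3.2 instead (FO-MAML).
```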
4. boss
Task augmentation (for meta-learning): what is a reasonable way to create new tasks?
We use task augmentation to increase the variety of the training tasks: with 40% probability, the task's images are rotated by 90 or 270 degrees.
# Modification inside the MetaSolver function
for meta_batch in x:
    # Get data
    if torch.rand(1).item() > 0.6:
        times = 1 if torch.rand(1).item() > 0.5 else 3
        meta_batch = torch.rot90(meta_batch, times, [-1, -2])
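A quick sanity check of the augmentation on a tiny fake "image" (our own example): torch.rot90(x, k, dims) rotates by k * 90 degrees in the plane given by dims, so k = 1 and k = 3 give the 90- and 270-degree rotations used above.

```python
import torch

img = torch.arange(4).reshape(1, 2, 2)   # one-channel 2x2 "image": [[0, 1], [2, 3]]
rot90 = torch.rot90(img, 1, [-1, -2])    # rotate once in the (W, H) plane
rot270 = torch.rot90(img, 3, [-1, -2])   # rotate three times

print(rot90)   # tensor([[[2, 0], [3, 1]]])
print(rot270)  # tensor([[[1, 3], [0, 2]]])
```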
3. Code
Model-building preliminaries
Since our task is image classification, we build a CNN-based model. To implement the MAML algorithm, however, we need to adjust some code in nn.Module: the gradients we take are with respect to θ, the original model's parameters (the outer loop), not the fast weights of the inner loop. We therefore use functional_forward to compute the output logits from the input images, instead of nn.Module's forward. These functions are defined below.
def functional_forward(self, x, params):
    for block in [1, 2, 3, 4]:
        x = ConvBlockFunction(
            x,
            params[f"conv{block}.0.weight"],
            params[f"conv{block}.0.bias"],
            params.get(f"conv{block}.1.weight"),
            params.get(f"conv{block}.1.bias"),
        )
    x = x.view(x.shape[0], -1)
    x = F.linear(x, params["logits.weight"], params["logits.bias"])
    return x
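ConvBlockFunction itself is not shown above. The following is a plausible sketch of such a functional conv block (conv → batch norm → ReLU → 2x2 max pool), written by us under the assumption that the batch-norm running statistics are not tracked; the actual course code may differ in details.

```python
import torch
import torch.nn.functional as F

def ConvBlockFunction(x, w, b, w_bn, b_bn):
    # Convolution with the externally supplied (fast) weights.
    x = F.conv2d(x, w, b, padding=1)
    # Batch norm in training mode without running stats (an assumption).
    x = F.batch_norm(x, running_mean=None, running_var=None,
                     weight=w_bn, bias=b_bn, training=True)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    return x

# Shape check on dummy tensors: 8x8 input, 4 output channels, pooled to 4x4.
out = ConvBlockFunction(
    torch.randn(2, 1, 8, 8),
    torch.randn(4, 1, 3, 3), torch.zeros(4),
    torch.ones(4), torch.zeros(4),
)
print(out.shape)  # torch.Size([2, 4, 4, 4])
```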
Create labels for 5-way 2-shot
def create_label(n_way, k_shot):
    return torch.arange(n_way).repeat_interleave(k_shot).long()

# Try to create labels for a 5-way 2-shot setting
create_label(5, 2)  # tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
Accuracy calculation
def calculate_accuracy(logits, labels):
    """Utility function for accuracy calculation."""
    acc = np.asarray(
        [(torch.argmax(logits, -1).cpu().numpy() == labels.cpu().numpy())]
    ).mean()
    return acc
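A minimal sanity check of this helper on made-up logits (our example): three rows whose argmax predicts classes [1, 0, 2] against labels [1, 0, 0], so two of three are correct.

```python
import numpy as np
import torch

def calculate_accuracy(logits, labels):
    """Utility function for accuracy calculation."""
    acc = np.asarray(
        [(torch.argmax(logits, -1).cpu().numpy() == labels.cpu().numpy())]
    ).mean()
    return acc

logits = torch.tensor([[0.1, 0.9, 0.0],   # predicts class 1
                       [0.8, 0.1, 0.1],   # predicts class 0
                       [0.2, 0.1, 0.7]])  # predicts class 2
labels = torch.tensor([1, 0, 0])
acc = calculate_accuracy(logits, labels)
print(acc)  # 2 correct out of 3
```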
The base solver first samples five tasks from the training set and runs ordinary classification training on them. At inference time, the model is fine-tuned on the support-set images for inner_train_step steps and then evaluated on the query-set images. To stay consistent with the meta-learning solver, the base solver has exactly the same input and output format.
def BaseSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False,
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    for meta_batch in x:
        # Get data
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        if train:
            """ training loop """
            # Use the support set to calculate the loss
            labels = create_label(n_way, k_shot).to(device)
            logits = model.forward(support_set)
            loss = criterion(logits, labels)

            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, labels))
        else:
            """ validation / testing loop """
            # First update the model on the support set images for `inner_train_step` steps
            fast_weights = OrderedDict(model.named_parameters())

            for inner_step in range(inner_train_step):
                # Simply training
                train_label = create_label(n_way, k_shot).to(device)
                logits = model.functional_forward(support_set, fast_weights)
                loss = criterion(logits, train_label)

                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
                # Perform SGD
                fast_weights = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights.items(), grads)
                )

            if not return_labels:
                """ validation """
                val_label = create_label(n_way, q_query).to(device)

                logits = model.functional_forward(query_set, fast_weights)
                loss = criterion(logits, val_label)
                task_loss.append(loss)
                task_acc.append(calculate_accuracy(logits, val_label))
            else:
                """ testing """
                logits = model.functional_forward(query_set, fast_weights)
                labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    batch_loss = torch.stack(task_loss).mean()
    task_acc = np.mean(task_acc)

    if train:
        # Update the model
        model.train()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return batch_loss, task_acc
Meta-learning
def MetaSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False,
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    for meta_batch in x:
        # Get data, with task augmentation: rotate by 90 or 270 degrees with 40% probability
        if torch.rand(1).item() > 0.6:
            times = 1 if torch.rand(1).item() > 0.5 else 3
            # B = rot90(A, k) rotates the array A by k * 90 degrees in the plane given by dims
            meta_batch = torch.rot90(meta_batch, times, [-1, -2])
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        # Copy the params for the inner loop
        fast_weights = OrderedDict(model.named_parameters())

        ### ---------- INNER TRAIN LOOP ---------- ###
        for inner_step in range(inner_train_step):
            # Simply training
            train_label = create_label(n_way, k_shot).to(device)
            logits = model.functional_forward(support_set, fast_weights)
            loss = criterion(logits, train_label)

            """ Inner Loop Update """
            grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads)
            )

        ### ---------- INNER VALID LOOP ---------- ###
        if not return_labels:
            """ training / validation """
            val_label = create_label(n_way, q_query).to(device)

            # Collect gradients for the outer loop
            logits = model.functional_forward(query_set, fast_weights)
            loss = criterion(logits, val_label)
            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, val_label))
        else:
            """ testing """
            logits = model.functional_forward(query_set, fast_weights)
            labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    # Update the outer loop
    model.train()
    optimizer.zero_grad()
    meta_batch_loss = torch.stack(task_loss).mean()

    if train:
        """ Outer Loop Update """
        meta_batch_loss.backward()
        optimizer.step()

    task_acc = np.mean(task_acc)
    return meta_batch_loss, task_acc
Copyright belongs to the original author, Raphael9900. In case of infringement, please contact us for removal.