0


机器学习HW15元学习

文章目录


一、简介

任务对象是Omniglot数据集上的few-shot classification任务,内容是利用元学习找到好的初始化参数。

Task: Few-shot Classification

The Omniglot dataset
在这里插入图片描述
Omniglot数据集-背景集: 30个字母 -评估集: 20个字母
问题设置: 5-way 1-shot classification
在这里插入图片描述
Training MAML on Omniglot classification task.
在这里插入图片描述
Training / validation set:30 alphabets

  • multiple characters in one alphabet
  • 20 images for one character在这里插入图片描述 Testing set: 640 support and query pairs
  • 5 support images
  • 5 query images在这里插入图片描述

实验

1、simple

简单的迁移学习模型
训练:对随机选择的5个任务进行正常的分类训练
验证和测试:对这5个支持图像进行微调,并对查询图像进行推理
在这里插入图片描述

在这里插入图片描述

2、medium

完成元学习内部和外部循环的TODO块,使用FO-MAML。设置solver = ‘meta’,epoch调节为120。FOMAML是MAML的简化版本,可节省训练时间,它忽略了内循环梯度对结果的影响。

#TODO: Finish the inner loop update rule
grads = torch.autograd.grad(loss, fast_weights.values())
fast_weights =OrderedDict((name, param - inner_lr*grad)for((name, param), grad) in zip(fast_weights.items(), grads))#raiseNotImplementedError训练过程中需要设置该函数为损失函数#TODO: Finish the outer loop update
meta_batch_loss.backward()
optimizer.step()#raiseNotimplementedError

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

3、strong

使用MAML,可以计算更高阶的梯度,MAML就能用到内循环梯度的梯度 。

#TODO: Finish the inner loop update rule
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
fast_weights =OrderedDict((name, param - inner_lr*grad)for((name, param), grad) in zip(fast_weights.items(), grads))#raiseNotImplementedError训练过程中需要设置该函数为损失函数

在这里插入图片描述
在这里插入图片描述

4、boss

任务增强(通过元学习)-什么是合理的方法来创建新任务?
使用了task augmentation的方法来增加训练任务的变化性,有40%的可能性做augmentation,旋转90度或270度。
在这里插入图片描述

#MetaSolver函数中修改for meta_batch in x:#Get dataif torch.rand(1).item()>0.6:
        times =1if torch.rand(1).item()>0.5else3
        meta_batch = torch.rot90(meta_batch, times,[-1,-2])

在这里插入图片描述
在这里插入图片描述

三、代码

模型构建准备工作

由于我们的任务是图像分类,我们需要建立一个基于CNN的模型。但是,要实现MAML算法,我们需要调整“nn.Module”中的一些代码。在第10行,我们采用的梯度是代表原始模型参数(外环)的θ,而不是内环中的θ,因此我们需要使用functional_forward来计算输入图像的输出逻辑,而不是在nn.Module中使用forward。下面定义了这些功能。

def functional_forward(self, x, params):for block in [1,2,3,4]:
            x =ConvBlockFunction(
                x,
                params[f"conv{block}.0.weight"],
                params[f"conv{block}.0.bias"],
                params.get(f"conv{block}.1.weight"),
                params.get(f"conv{block}.1.bias"),)
        x = x.view(x.shape[0],-1)
        x = F.linear(x, params["logits.weight"], params["logits.bias"])return x

创建labels for 5-way 2-shot

def create_label(n_way, k_shot):return torch.arange(n_way).repeat_interleave(k_shot).long()#Try to create labels for5-way 2-shot settingcreate_label(5,2)

计算精度

def calculate_accuracy(logits, labels):"""utility function for accuracy calculation"""
    acc = np.asarray([(torch.argmax(logits,-1).cpu().numpy()== labels.cpu().numpy())]).mean()return acc

求解器首先从训练集中选择五个任务,然后对选择的五个任务进行正常的分类训练。在推理中,模型在支持集图像上对inner_train_step步骤进行微调,然后在查询集图像上进行推理。为了与元学习解算器保持一致,基本解算器具有与元学习解算器完全相同的输入和输出格式。

def BaseSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False,):
    criterion, task_loss, task_acc = loss_fn,[],[]
    labels =[]for meta_batch in x:#Get data
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]if train:""" training loop """#Use the support set to calculate loss
            labels =create_label(n_way, k_shot).to(device)
            logits = model.forward(support_set)
            loss =criterion(logits, labels)

            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, labels))else:""" validation / testing loop """#First update model with support set images for `inner_train_step` steps
            fast_weights =OrderedDict(model.named_parameters())for inner_step in range(inner_train_step):#Simply training
                train_label =create_label(n_way, k_shot).to(device)
                logits = model.functional_forward(support_set, fast_weights)
                loss =criterion(logits, train_label)

                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)#Perform SGD
                fast_weights =OrderedDict((name, param - inner_lr * grad)for((name, param), grad) in zip(fast_weights.items(), grads))if not return_labels:""" validation """
                val_label =create_label(n_way, q_query).to(device)

                logits = model.functional_forward(query_set, fast_weights)
                loss =criterion(logits, val_label)
                task_loss.append(loss)
                task_acc.append(calculate_accuracy(logits, val_label))else:""" testing """
                logits = model.functional_forward(query_set, fast_weights)
                labels.extend(torch.argmax(logits,-1).cpu().numpy())if return_labels:return labels

    batch_loss = torch.stack(task_loss).mean()
    task_acc = np.mean(task_acc)if train:#Update model
        model.train()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()return batch_loss, task_acc

元学习

def MetaSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False
):
    criterion, task_loss, task_acc = loss_fn,[],[]
    labels =[]for meta_batch in x:#Get dataif torch.rand(1).item()>0.6:
            times =1if torch.rand(1).item()>0.5else3
            meta_batch = torch.rot90(meta_batch, times,[-1,-2])#  B =rot90(A,k) 将数组 A 按逆时针方向旋转 k*90 度
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]#Copy the params for inner loop
        fast_weights =OrderedDict(model.named_parameters())

        ### ---------- INNER TRAIN LOOP ---------- ###
        for inner_step in range(inner_train_step):#Simply training
            train_label =create_label(n_way, k_shot).to(device)
            logits = model.functional_forward(support_set, fast_weights)
            loss =criterion(logits, train_label)#Inner gradients update! vvvvvvvvvvvvvvvvvvvv #""" Inner Loop Update """#TODO: Finish the inner loop update rule
            grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
            fast_weights =OrderedDict((name, param - inner_lr*grad)for((name, param), grad) in zip(fast_weights.items(), grads))#raiseNotImplementedError
            # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ #

        ### ---------- INNER VALID LOOP ---------- ###
        if not return_labels:""" training / validation """
            val_label =create_label(n_way, q_query).to(device)#Collect gradients for outer loop
            logits = model.functional_forward(query_set, fast_weights)
            loss =criterion(logits, val_label)
            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, val_label))else:""" testing """
            logits = model.functional_forward(query_set, fast_weights)
            labels.extend(torch.argmax(logits,-1).cpu().numpy())if return_labels:return labels

    #Update outer loop
    model.train()
    optimizer.zero_grad()

    meta_batch_loss = torch.stack(task_loss).mean()if train:""" Outer Loop Update """#TODO: Finish the outer loop update
        meta_batch_loss.backward()
        optimizer.step()#raiseNotimplementedError

    task_acc = np.mean(task_acc)return meta_batch_loss, task_acc
标签: 学习 人工智能

本文转载自: https://blog.csdn.net/Raphael9900/article/details/128646394
版权归原作者 Raphael9900 所有, 如有侵权,请联系我们删除。

“机器学习HW15元学习”的评论:

还没有评论