

Local Feature Transformers (LoFTR)是一种Detector-free的局部特征匹配方法,使用带有自注意层和互注意层的Transformer模块来处理从卷积网络中提取的密集局部特征:首先在低特征分辨率(图像维度的1/8)上提取密集匹配,然后从这些匹配中选出高置信度的匹配,再使用基于相关性的方法将其细化到高分辨率的亚像素级别。这样,模型的大感受野使经过Transformer变换后的特征描述符能够同时体现上下文信息和位置信息;通过多次交替的自注意层和互注意层,LoFTR从真值(GT)匹配中学习到匹配先验。另外,LoFTR还采用Linear Attention方法将计算复杂度降低到可接受的水平。



  1. 局部特征CNN从图像对中提取粗粒度特征图 $\tilde{F}^{A}$、$\tilde{F}^{B}$,以及细粒度特征图 $\hat{F}^{A}$、$\hat{F}^{B}$
  2. 粗粒度的特征先添加位置编码,再经过flatten操作展平为序列,然后由LoFTR模块进行处理;该模块中包括自注意层和互注意层,并重复 $N_c$ 次
  3. 使用可微匹配层来匹配LoFTR模块输出的特征 $\tilde{F}_{tr}^{A}$、$\tilde{F}_{tr}^{B}$,得到一个置信矩阵 $P_c$,然后根据置信阈值和互最近邻(MNN, mutual nearest neighbor)准则选择匹配对,得到粗粒度的匹配预测 $M_c$
  4. 对于每个粗粒度匹配对,从细粒度特征图中裁剪一个大小为 $w \times w$ 的局部窗口,粗粒度匹配将在此窗口内细化到亚像素级别,并作为最终的匹配预测 $M_f$



  1. Local Feature CNN
  2. coarse-level loftr module
  3. match coarse-level
  4. fine-level refinement
  5. match fine-level

基础特征提取模块Local Feature CNN


# 1. Local Feature CNN
# Excerpt from the body of the matcher's forward pass: `self` and `data`
# come from the enclosing method, which is not shown in this snippet.
# Record the batch size and the input (H, W) of both images.
data.update({
    'bs': data['image0'].size(0),
    'hw0_i': data['image0'].shape[2:],
    'hw1_i': data['image1'].shape[2:],
})

if data['hw0_i'] == data['hw1_i']:  # faster & better BN convergence
    # Same spatial size: stack both images along the batch axis so the
    # backbone runs once, then split the outputs back per image.
    stacked = torch.cat([data['image0'], data['image1']], dim=0)
    print(stacked.shape)
    feats_c, feats_f = self.backbone(stacked)
    print(feats_c.shape)  # 1/8
    print(feats_f.shape)  # 1/2
    (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
    print(feat_c0.shape)
    print(feat_c1.shape)
    print(feat_f0.shape)
    print(feat_f1.shape)
else:  # handle different input shapes
    (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
    print(feat_c0.shape)
    print(feat_c1.shape)
    print(feat_f0.shape)
    print(feat_f1.shape)

# Record the (H, W) of the coarse and fine feature maps for later stages.
data.update({
    'hw0_c': feat_c0.shape[2:],
    'hw1_c': feat_c1.shape[2:],
    'hw0_f': feat_f0.shape[2:],
    'hw1_f': feat_f1.shape[2:],
})


def forward(self, x):
    """Run the ResNet + FPN backbone and return coarse and fine feature maps.

    Args:
        x (torch.Tensor): input image batch fed to ``self.conv1``.

    Returns:
        list: ``[x3_out, x1_out]`` -- the coarse map taken from ``layer3``
        (annotated 1/8 below) and the fine map taken from ``layer1``
        (annotated 1/2 below).
    """
    # ResNet Backbone
    x0 = self.relu(self.bn1(self.conv1(x)))
    x1 = self.layer1(x0)  # 1/2
    x2 = self.layer2(x1)  # 1/4
    x3 = self.layer3(x2)  # 1/8

    # FPN: upsample the deeper map by 2x and fuse it with the lateral
    # projection of the next shallower map, top-down.
    x3_out = self.layer3_outconv(x3)

    x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
    x2_out = self.layer2_outconv(x2)
    x2_out = self.layer2_outconv2(x2_out + x3_out_2x)

    x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
    x1_out = self.layer1_outconv(x1)
    x1_out = self.layer1_outconv2(x1_out + x2_out_2x)

    return [x3_out, x1_out]

注意力机制应用 coarse-level loftr module

# 2. coarse-level loftr module
# Excerpt from the body of the matcher's forward pass: `self`, `data`,
# `feat_c0` and `feat_c1` come from the enclosing method (not shown here).
# add featmap with positional encoding, then flatten it to sequence [N, HW, C]
feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
print(feat_c0.shape)
feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
print(feat_c1.shape)

mask_c0 = mask_c1 = None  # mask is useful in training
if 'mask0' in data:
    # Flatten the last two mask dimensions so they line up with the flattened features.
    mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)

# Enter the transformer module, which is the core module of the paper.
feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
print(feat_c0.shape)
print(feat_c1.shape)

进入 self.loftr_coarse 模块后,既要对每张特征图自身计算自注意力(self attention),也要在两张图像的特征之间计算互注意力(cross attention)。从代码中的循环可以看到:'self' 层把同一张特征图同时作为 x 和 source 传入,即自己和自己计算注意力;'cross' 层则把另一张图的特征作为 source 传入,即自己和对方的特征图计算注意力。这一点从传给 layer 的参数就可以明确看出。

def forward(self, feat0, feat1, mask0=None, mask1=None):
    """Apply the stacked self-/cross-attention layers to both feature sequences.

    Args:
        feat0 (torch.Tensor): [N, L, C]
        feat1 (torch.Tensor): [N, S, C]
        mask0 (torch.Tensor): [N, L] (optional)
        mask1 (torch.Tensor): [N, S] (optional)

    Returns:
        tuple: the updated ``(feat0, feat1)``.

    Raises:
        KeyError: if an entry of ``self.layer_names`` is neither ``'self'``
            nor ``'cross'``.
    """
    assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"

    for layer, name in zip(self.layers, self.layer_names):
        if name == 'self':
            # Each sequence attends to itself.
            feat0 = layer(feat0, feat0, mask0, mask0)
            print(feat0.shape)
            feat1 = layer(feat1, feat1, mask1, mask1)
            print(feat1.shape)
        elif name == 'cross':
            # Each sequence attends to the other one. Note that feat1 is
            # updated against the feat0 that was just refreshed above.
            feat0 = layer(feat0, feat1, mask0, mask1)
            print(feat0.shape)
            feat1 = layer(feat1, feat0, mask1, mask0)
            print(feat1.shape)
        else:
            raise KeyError

    print(feat0.shape)
    print(feat1.shape)
    return feat0, feat1


def forward(self, x, source, x_mask=None, source_mask=None):
    """Update ``x`` with a message attended from ``source`` (residual output).

    Args:
        x (torch.Tensor): [N, L, C]
        source (torch.Tensor): [N, S, C]
        x_mask (torch.Tensor): [N, L] (optional)
        source_mask (torch.Tensor): [N, S] (optional)

    Returns:
        torch.Tensor: ``x + message``.
    """
    bs = x.size(0)
    query, key, value = x, source, source

    # multi-head attention
    query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
    key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
    value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
    message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
    message = self.merge(message.view(bs, -1, self.nhead * self.dim))  # [N, L, C]
    message = self.norm1(message)

    # feed-forward network: the MLP sees the input concatenated with the
    # attended message along the channel axis.
    message = self.mlp(torch.cat([x, message], dim=2))
    message = self.norm2(message)

    return x + message

粗粒度匹配模块 match coarse-level

# 3. match coarse-level
        # The call's return value is not used here; its results are presumably
        # written into `data` in place -- TODO confirm against coarse_matching.
        self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)


def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
    """Build the coarse confidence matrix and store coarse matches in ``data``.

    Args:
        feat_c0 (torch.Tensor): [N, L, C]
        feat_c1 (torch.Tensor): [N, S, C]
        data (dict)
        mask_c0 (torch.Tensor): [N, L] (optional)
        mask_c1 (torch.Tensor): [N, S] (optional)

    Update:
        data (dict): {
            'b_ids' (torch.Tensor): [M'],
            'i_ids' (torch.Tensor): [M'],
            'j_ids' (torch.Tensor): [M'],
            'gt_mask' (torch.Tensor): [M'],
            'mkpts0_c' (torch.Tensor): [M, 2],
            'mkpts1_c' (torch.Tensor): [M, 2],
            'mconf' (torch.Tensor): [M]}
        NOTE: M' != M during training.

    Returns:
        None. All results are written into ``data``.
    """
    N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)

    # normalize: scale each feature by 1 / sqrt(C)
    feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, [feat_c0, feat_c1])

    if self.match_type == 'dual_softmax':
        sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) / self.temperature
        print(sim_matrix.shape)
        if mask_c0 is not None:
            # Entries where either side is masked out are pushed to -INF.
            sim_matrix.masked_fill_(~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF)
        # Product of the softmax over dim 1 and the softmax over dim 2.
        conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
        print(conf_matrix.shape)
    elif self.match_type == 'sinkhorn':
        # sinkhorn, dustbin included
        sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
        if mask_c0 is not None:
            sim_matrix[:, :L, :S].masked_fill_(~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF)

        # build uniform prior & use sinkhorn
        log_assign_matrix = self.log_optimal_transport(
            sim_matrix, self.bin_score, self.skh_iters)
        assign_matrix = log_assign_matrix.exp()
        # Drop the last row and column (the dustbin entries).
        conf_matrix = assign_matrix[:, :-1, :-1]

        # filter prediction with dustbin score (only in evaluation mode)
        if not self.training and self.skh_prefilter:
            filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1]  # [N, L]
            filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1]  # [N, S]
            conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
            conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0

        if self.config['sparse_spvs']:
            data.update({'conf_matrix_with_bin': assign_matrix.clone()})

    data.update({'conf_matrix': conf_matrix})

    # predict coarse matches from conf_matrix
    data.update(**self.get_coarse_match(conf_matrix, data))


def get_coarse_match(self, conf_matrix, data):
    """Select coarse matches from the confidence matrix.

    Args:
        conf_matrix (torch.Tensor): [N, L, S]
        data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']

    Returns:
        coarse_matches (dict): {
            'b_ids' (torch.Tensor): [M'],
            'i_ids' (torch.Tensor): [M'],
            'j_ids' (torch.Tensor): [M'],
            'gt_mask' (torch.Tensor): [M'],
            'm_bids' (torch.Tensor): [M],
            'mkpts0_c' (torch.Tensor): [M, 2],
            'mkpts1_c' (torch.Tensor): [M, 2],
            'mconf' (torch.Tensor): [M]}
    """
    axes_lengths = {
        'h0c': data['hw0_c'][0],
        'w0c': data['hw0_c'][1],
        'h1c': data['hw1_c'][0],
        'w1c': data['hw1_c'][1],
    }
    _device = conf_matrix.device

    # 1. confidence thresholding
    mask = conf_matrix > self.thr
    mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', **axes_lengths)
    print(mask.shape)
    # The border helpers are called for their effect on `mask`; their return
    # values are not used.
    if 'mask0' not in data:
        mask_border(mask, self.border_rm, False)
    else:
        mask_border_with_padding(mask, self.border_rm, False,
                                 data['mask0'], data['mask1'])
    mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', **axes_lengths)
    print(mask.shape)

    # 2. mutual nearest: keep entries that are the maximum of both their row
    # and their column.
    print(conf_matrix.max(dim=2, keepdim=True)[0].shape)
    mask = mask \
        * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
        * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
    print(mask.shape)

    # 3. find all valid coarse matches
    # this only works when at most one `True` in each row
    mask_v, all_j_ids = mask.max(dim=2)
    print(mask_v.shape)
    print(all_j_ids.shape)
    b_ids, i_ids = torch.where(mask_v)
    print(b_ids.shape)
    print(i_ids.shape)
    j_ids = all_j_ids[b_ids, i_ids]
    print(j_ids.shape)
    mconf = conf_matrix[b_ids, i_ids, j_ids]
    print(mconf.shape)

    # 4. Random sampling of training samples for fine-level LoFTR
    # (optional) pad samples with gt coarse-level matches
    if self.training:
        # NOTE:
        # The sampling is performed across all pairs in a batch without manually balancing
        # #samples for fine-level increases w.r.t. batch_size
        if 'mask0' not in data:
            num_candidates_max = mask.size(0) * max(
                mask.size(1), mask.size(2))
        else:
            num_candidates_max = compute_max_candidates(
                data['mask0'], data['mask1'])
        # NOTE(review): this expression was truncated in the original snippet
        # ("int(num_candidates_max *"); the `self.train_coarse_percent` factor
        # is restored from the upstream LoFTR implementation -- confirm.
        num_matches_train = int(num_candidates_max *
                                self.train_coarse_percent)
        num_matches_pred = len(b_ids)
        assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"

        # pred_indices is to select from prediction
        if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
            pred_indices = torch.arange(num_matches_pred, device=_device)
        else:
            pred_indices = torch.randint(
                num_matches_pred,
                (num_matches_train - self.train_pad_num_gt_min,),
                device=_device)

        # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
        # NOTE(review): the tail of this call was truncated in the original
        # snippet; the second `max` operand and `device=` are restored from
        # the upstream LoFTR implementation -- confirm.
        gt_pad_indices = torch.randint(
            len(data['spv_b_ids']),
            (max(num_matches_train - num_matches_pred,
                 self.train_pad_num_gt_min),),
            device=_device)
        mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device)  # set conf of gt paddings to all zero

        b_ids, i_ids, j_ids, mconf = map(
            lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
                                   dim=0),
            *zip([b_ids, data['spv_b_ids']],
                 [i_ids, data['spv_i_ids']],
                 [j_ids, data['spv_j_ids']],
                 [mconf, mconf_gt]))

    # These matches select patches that feed into fine-level network
    coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}

    # 4. Update with matches in original image resolution
    scale = data['hw0_i'][0] / data['hw0_c'][0]
    scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
    scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
    # Flat index -> (x, y) = (index % width, index // width), then rescale.
    mkpts0_c = torch.stack(
        [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
        dim=1) * scale0
    mkpts1_c = torch.stack(
        [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
        dim=1) * scale1

    # These matches is the current prediction (for visualization)
    coarse_matches.update({
        'gt_mask': mconf == 0,
        'm_bids': b_ids[mconf != 0],  # mconf == 0 => gt matches
        'mkpts0_c': mkpts0_c[mconf != 0],
        'mkpts1_c': mkpts1_c[mconf != 0],
        'mconf': mconf[mconf != 0],
    })

    return coarse_matches

精细化调整 fine-level refinement


# 4. fine-level refinement
# Excerpt from the body of the matcher's forward pass: `self`, `data` and the
# feature tensors come from the enclosing method (not shown here).
# Preprocessing: unfold the feature maps and extract the matched candidate points.
feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
if feat_f0_unfold.size(0) != 0:  # at least one coarse level predicted
    feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)


细粒度匹配 match fine-level


# 5. match fine-level
        # The call's return value is not used here; its results are presumably
        # written into `data` in place -- TODO confirm against fine_matching.
        self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)


def forward(self, feat_f0, feat_f1, data):
    """Refine coarse matches inside each local window and update ``data``.

    Args:
        feat_f0 (torch.Tensor): [M, WW, C]
        feat_f1 (torch.Tensor): [M, WW, C]
        data (dict)

    Update:
        data (dict): {
            'expec_f' (torch.Tensor): [M, 3],
            'mkpts0_f' (torch.Tensor): [M, 2],
            'mkpts1_f' (torch.Tensor): [M, 2]}

    Returns:
        None. Results are written into ``data``.
    """
    M, WW, C = feat_f0.shape
    W = int(math.sqrt(WW))
    scale = data['hw0_i'][0] / data['hw0_f'][0]
    self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale

    # corner case: if no coarse matches found
    if M == 0:
        assert not self.training, "M is always >0, when training, see coarse_matching.py"
        # logger.warning('No matches found in coarse-level.')
        # Fall back to the coarse keypoints and an empty expectation tensor.
        data.update({
            'expec_f': torch.empty(0, 3, device=feat_f0.device),
            'mkpts0_f': data['mkpts0_c'],
            'mkpts1_f': data['mkpts1_c'],
        })
        return

    # Take the entry at index WW // 2 of every window in image 0 and correlate
    # it against all WW positions of the matching window in image 1.
    feat_f0_picked = feat_f0[:, WW // 2, :]
    print(feat_f0_picked.shape)
    sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
    print(sim_matrix.shape)
    softmax_temp = 1. / C**.5
    heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
    print(heatmap.shape)

    # compute coordinates from heatmap
    coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]  # [M, 2]
    print(coords_normalized.shape)
    grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2)  # [1, WW, 2]
    print(grid_normalized.shape)

    # compute std over <x, y>
    var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2  # [M, 2]
    std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1)  # [M]  clamp needed for numerical stability

    # for fine-level supervision
    data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})

    # compute absolute kpt coords
    self.get_fine_match(coords_normalized, data)
标签: 人工智能 算法

本文转载自: https://blog.csdn.net/zhoulizhu/article/details/135379188
版权归原作者 AI柱子哥 所有, 如有侵权,请联系我们删除。

