

Spatial Transcriptomics: STAGATE

I have recently been reading and reproducing spatial transcriptomics papers from various groups, and I am recording my notes here for study and discussion. If you spot any mistakes, please point them out.

Preface

First up is STAGATE, a method proposed by the Chinese Academy of Sciences and published in Nature Communications. Its overall idea follows the common spatial transcriptomics recipe: extract gene expression, spatial information, and image features, then cluster to identify the type of each spot. Notably, STAGATE does not use image information at all, yet at the time of writing it achieved the best results among published methods.

Overall Architecture

The overall architecture is as follows.

In short, the model is a four-layer autoencoder, two encoder layers and two decoder layers, with every layer replaced by a GAT (graph attention) layer. The gene expression matrix X is fed in and reconstructed as X', so the loss is naturally the MSE between X and X'. Note that layers 2 and 3, and layers 1 and 4, each share one set of weights W related by transposition, as the paper's architecture figure indicates. For spot-level data, that is the entire model. For cell-level data, an SNN (shared nearest neighbor) graph is additionally constructed, i.e., a second adjacency matrix for the GAT, and each layer's output is a weighted sum of the GAT over the new adjacency matrix and the GAT over the original one, which then feeds the next layer (a small sketch of this weighting follows).
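To make that cell-level weighting concrete, here is a minimal sketch of the idea. This is my own illustration, not the authors' code: in the paper the two graphs are actually mixed at the level of attention coefficients inside a single layer with a weight alpha, whereas this sketch simply mixes the layer outputs.

    import torch
    from torch_geometric.nn import GATConv

    class MixedGraphLayer(torch.nn.Module):
        """Hypothetical sketch: apply one GAT layer over the spatial graph and
        over the SNN graph, then mix the two outputs with a weight alpha."""
        def __init__(self, in_dim, out_dim, alpha=0.5):
            super().__init__()
            self.alpha = alpha
            self.gat = GATConv(in_dim, out_dim, heads=1, concat=False)

        def forward(self, x, edge_index_spatial, edge_index_snn):
            h_spatial = self.gat(x, edge_index_spatial)  # pre-defined spatial network
            h_snn = self.gat(x, edge_index_snn)          # data-driven SNN graph
            return (1 - self.alpha) * h_spatial + self.alpha * h_snn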

Code

The authors originally released TensorFlow 1 code, and in March of this year they also published a PyTorch version. However, the torch version does not build the SNN, and it differs from the TensorFlow one in some details. For example, in the TensorFlow loss, a weight regularization term is added on top of the MSE to prevent overfitting. I will point out such differences wherever I noticed them in the code. Below I walk through my understanding of the paper based on the torch version. (It is best run on Linux; on Windows all sorts of odd errors kept appearing.)

First, data preprocessing. This covers reading the data (just download it as described in the paper), then normalization: selecting highly variable genes, normalizing the totals, and taking the log. (Note that the seurat_v3 flavor of highly_variable_genes expects raw counts, which is why gene selection happens before normalization.) Finally, the ground-truth labels are read in for the final evaluation and visualized.

    input_dir = os.path.join('Data', section_id)
    adata = sc.read_visium(path=input_dir, count_file=section_id + '_filtered_feature_bc_matrix.h5')
    adata.var_names_make_unique()

    # Normalization
    sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    # Read the manual annotation (ground truth) and visualize it
    Ann_df = pd.read_csv(os.path.join('Data', section_id, "cluster_labels_" + section_id + '.csv'),
                         sep=',', header=0, index_col=0)
    adata.obs['ground_truth'] = Ann_df.loc[adata.obs_names, 'ground_truth']
    plt.rcParams["figure.figsize"] = (3, 3)
    sc.pl.spatial(adata, img_key="hires", color=["ground_truth"])

Next, the distances between spots. Pairs of spots whose distance is greater than 0 and less than 150 are considered connected: the corresponding entry of the adjacency matrix is 1, otherwise 0. The code below computes the distances for spot pairs within this range and saves them in adata.uns['Spatial_Net'].

    import numpy as np
    import pandas as pd
    import sklearn.neighbors

    def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None, model='Radius', verbose=True):
        """\
        Construct the spatial neighbor networks.

        Parameters
        ----------
        adata
            AnnData object of scanpy package.
        rad_cutoff
            radius cutoff when model='Radius'
        k_cutoff
            The number of nearest neighbors when model='KNN'
        model
            The network construction model. When model=='Radius', the spot is
            connected to spots whose distance is less than rad_cutoff. When
            model=='KNN', the spot is connected to its first k_cutoff nearest neighbors.

        Returns
        -------
        The spatial networks are saved in adata.uns['Spatial_Net']
        """
        assert model in ['Radius', 'KNN']
        if verbose:
            print('------Calculating spatial graph...')
        coor = pd.DataFrame(adata.obsm['spatial'])
        coor.index = adata.obs.index
        coor.columns = ['imagerow', 'imagecol']
        if model == 'Radius':
            nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor)
            distances, indices = nbrs.radius_neighbors(coor, return_distance=True)
            KNN_list = []
            for it in range(indices.shape[0]):
                KNN_list.append(pd.DataFrame(zip([it] * indices[it].shape[0], indices[it], distances[it])))
        if model == 'KNN':
            nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff + 1).fit(coor)
            distances, indices = nbrs.kneighbors(coor)
            KNN_list = []
            for it in range(indices.shape[0]):
                KNN_list.append(pd.DataFrame(zip([it] * indices.shape[1], indices[it, :], distances[it, :])))
        KNN_df = pd.concat(KNN_list)
        KNN_df.columns = ['Cell1', 'Cell2', 'Distance']
        Spatial_Net = KNN_df.copy()
        # Drop self-pairs (distance 0); self-loops are added back later
        Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance'] > 0, ]
        # Map integer indices back to cell/spot barcodes
        id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index)))
        Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
        Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)
        if verbose:
            print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], adata.n_obs))
            print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / adata.n_obs))
        adata.uns['Spatial_Net'] = Spatial_Net

Next comes a visualization of how many neighbors each spot has on average (for 10x Visium's hexagonal layout and a radius cutoff of 150, this should be close to six for interior spots).

    def Stats_Spatial_Net(adata):
        import matplotlib.pyplot as plt
        Num_edge = adata.uns['Spatial_Net']['Cell1'].shape[0]
        Mean_edge = Num_edge / adata.shape[0]
        # Distribution of the number of neighbors per spot
        plot_df = pd.value_counts(pd.value_counts(adata.uns['Spatial_Net']['Cell1']))
        plot_df = plot_df / adata.shape[0]
        fig, ax = plt.subplots(figsize=[3, 2])
        plt.ylabel('Percentage')
        plt.xlabel('')
        plt.title('Number of Neighbors (Mean=%.2f)' % Mean_edge)
        ax.bar(plot_df.index, plot_df)

Now we move on to the STAGATE training stage proper.

First comes data preparation, which has two parts: building the adjacency matrix from the selected neighbors, and the gene expression data.

    import numpy as np
    import scipy.sparse as sp
    import torch
    from torch_geometric.data import Data

    def Transfer_pytorch_Data(adata):
        G_df = adata.uns['Spatial_Net'].copy()
        cells = np.array(adata.obs_names)
        cells_id_tran = dict(zip(cells, range(cells.shape[0])))
        G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
        G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)
        # Sparse adjacency matrix, plus self-loops on the diagonal
        G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])),
                          shape=(adata.n_obs, adata.n_obs))
        G = G + sp.eye(G.shape[0])
        edgeList = np.nonzero(G)
        if type(adata.X) == np.ndarray:
            data = Data(edge_index=torch.LongTensor(np.array([edgeList[0], edgeList[1]])),
                        x=torch.FloatTensor(adata.X))
        else:
            data = Data(edge_index=torch.LongTensor(np.array([edgeList[0], edgeList[1]])),
                        x=torch.FloatTensor(adata.X.todense()))
        return data
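As a quick usage sketch (my own, assuming the function above and a preprocessed adata), the resulting Data object is what gets fed to the model, typically after moving it to the GPU:

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data = Transfer_pytorch_Data(adata).to(device)
    # data.x: [n_obs, n_genes] expression; data.edge_index: [2, n_edges] graph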

Then the STAGATE model is built. As mentioned above, it consists of four GAT layers; h2 is the final feature vector (the embedding) and h4 is the reconstructed gene expression.

    class STAGATE(torch.nn.Module):
        def __init__(self, hidden_dims):
            super(STAGATE, self).__init__()
            [in_dim, num_hidden, out_dim] = hidden_dims
            # Note: GATConv here is the modified version shipped with STAGATE_pyG
            # (it accepts `attention` and `tied_attention`), not the stock PyG layer.
            self.conv1 = GATConv(in_dim, num_hidden, heads=1, concat=False,
                                 dropout=0, add_self_loops=False, bias=False)
            self.conv2 = GATConv(num_hidden, out_dim, heads=1, concat=False,
                                 dropout=0, add_self_loops=False, bias=False)
            self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False,
                                 dropout=0, add_self_loops=False, bias=False)
            self.conv4 = GATConv(num_hidden, in_dim, heads=1, concat=False,
                                 dropout=0, add_self_loops=False, bias=False)

        def forward(self, features, edge_index):
            h1 = F.elu(self.conv1(features, edge_index))
            h2 = self.conv2(h1, edge_index, attention=False)
            # Tie the decoder weights to the transposed encoder weights
            self.conv3.lin_src.data = self.conv2.lin_src.transpose(0, 1)
            self.conv3.lin_dst.data = self.conv2.lin_dst.transpose(0, 1)
            self.conv4.lin_src.data = self.conv1.lin_src.transpose(0, 1)
            self.conv4.lin_dst.data = self.conv1.lin_dst.transpose(0, 1)
            h3 = F.elu(self.conv3(h2, edge_index, attention=True,
                                  tied_attention=self.conv1.attentions))
            h4 = self.conv4(h3, edge_index, attention=False)
            return h2, h4  # h2: embedding, h4: reconstruction
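For reference, the model is instantiated roughly like this; the widths 512 and 30 (a 30-dimensional embedding) match the defaults reported in the paper, but treat the exact call as an assumption on my part:

    # Minimal sketch, assuming `data` and `device` from the usage example above
    model = STAGATE(hidden_dims=[data.x.shape[1], 512, 30]).to(device)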

The GAT layer code itself is omitted here; see "Graph Attention Networks" (https://arxiv.org/abs/1710.10903). The core attention computation is sketched below.
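This is my own dense-adjacency sketch of the single-head attention mechanism from that paper, not the STAGATE implementation; it assumes every node has at least one neighbor (e.g., a self-loop), otherwise the softmax row would be undefined.

    import torch
    import torch.nn.functional as F

    def gat_attention(h, W, a, adj):
        # h: [N, in_dim] features, W: [in_dim, d] weights, a: [2 * d] attention vector,
        # adj: [N, N] 0/1 adjacency. Returns the attention-weighted neighbor aggregation.
        Wh = h @ W                                      # [N, d] transformed features
        d = Wh.shape[1]
        src = Wh @ a[:d].unsqueeze(1)                   # [N, 1] source term
        dst = (Wh @ a[d:].unsqueeze(1)).T               # [1, N] target term
        e = F.leaky_relu(src + dst, 0.2)                # [N, N] raw scores e_ij
        e = e.masked_fill(adj == 0, float('-inf'))      # keep only real edges
        alpha = torch.softmax(e, dim=1)                 # normalize over each node's neighbors
        return alpha @ Wh                               # [N, d] aggregated output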

The training code is below. One notable addition is gradient clipping; also, the weight_decay argument of Adam presumably plays the same anti-overfitting role as the extra weight penalty in the TensorFlow loss. At the end, h2, i.e., z, the feature vector used for the subsequent clustering analysis, is returned and saved into adata.

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_list = []
    for epoch in tqdm(range(1, n_epochs + 1)):
        model.train()
        optimizer.zero_grad()
        z, out = model(data.x, data.edge_index)
        loss = F.mse_loss(data.x, out)  # reconstruction loss between X and X'
        loss_list.append(loss.item())   # .item() avoids keeping the graph alive
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()

    model.eval()
    with torch.no_grad():
        z, out = model(data.x, data.edge_index)
    STAGATE_rep = z.to('cpu').detach().numpy()
    adata.obsm[key_added] = STAGATE_rep

    if save_loss:
        adata.uns['STAGATE_loss'] = loss.item()
    if save_reconstrction:  # (sic: the repo keeps this misspelled argument name)
        ReX = out.to('cpu').detach().numpy()
        ReX[ReX < 0] = 0  # clip negative reconstructed expression to zero
        adata.layers['STAGATE_ReX'] = ReX

Finally, the mclust package in R is called for clustering.

    def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=2020):
        """\
        Clustering using the mclust algorithm.
        The parameters are the same as those in the R package mclust.
        """
        np.random.seed(random_seed)
        import rpy2.robjects as robjects
        robjects.r.library("mclust")
        import rpy2.robjects.numpy2ri
        rpy2.robjects.numpy2ri.activate()
        r_random_seed = robjects.r['set.seed']
        r_random_seed(random_seed)
        rmclust = robjects.r['Mclust']
        # Run Mclust on the STAGATE embedding with a fixed number of clusters
        res = rmclust(rpy2.robjects.numpy2ri.numpy2rpy(adata.obsm[used_obsm]), num_cluster, modelNames)
        mclust_res = np.array(res[-2])  # the classification vector of the Mclust result
        adata.obs['mclust'] = mclust_res
        adata.obs['mclust'] = adata.obs['mclust'].astype('int')
        adata.obs['mclust'] = adata.obs['mclust'].astype('category')
        return adata

Missing values are dropped and the ARI is computed. I will write up the TensorFlow version and the downstream analyses once I have understood them. Finally, here is the main function for testing on the DLPFC dataset. All code, data, and the paper can be downloaded from GitHub; feedback is welcome.

    import warnings
    warnings.filterwarnings("ignore")
    import pandas as pd
    import numpy as np
    import scanpy as sc
    import matplotlib.pyplot as plt
    import os
    import sys
    from sklearn.metrics.cluster import adjusted_rand_score
    import STAGATE_pyG as STAGATE

    os.environ['R_HOME'] = '/home/admin/anaconda3/envs/lib/R'
    # os.environ['R_USER'] = '/home/admin/Anaconda3\Lib\site-packages/rpy2'

    # The 12 DLPFC sections and the number of annotated layers in each
    dataset = ["151507", "151508", "151509", "151510", "151669", "151670",
               "151671", "151672", "151673", "151674", "151675", "151676"]
    knn = [7, 7, 7, 7, 5, 5, 5, 5, 7, 7, 7, 7]
    ARIlist = []
    for section_id, k in zip(dataset, knn):
        print(section_id, k)
        input_dir = os.path.join('Data', section_id)
        adata = sc.read_visium(path=input_dir, count_file=section_id + '_filtered_feature_bc_matrix.h5')
        adata.var_names_make_unique()
        # Normalization
        sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        Ann_df = pd.read_csv(os.path.join('Data', section_id, "cluster_labels_" + section_id + '.csv'),
                             sep=',', header=0, index_col=0)
        adata.obs['ground_truth'] = Ann_df.loc[adata.obs_names, 'ground_truth']
        plt.rcParams["figure.figsize"] = (3, 3)
        sc.pl.spatial(adata, img_key="hires", color=["ground_truth"])
        # Build the spatial graph, train, and cluster the embedding
        STAGATE.Cal_Spatial_Net(adata, rad_cutoff=150)
        STAGATE.Stats_Spatial_Net(adata)
        adata = STAGATE.train_STAGATE(adata)
        sc.pp.neighbors(adata, use_rep='STAGATE')
        sc.tl.umap(adata)
        adata = STAGATE.mclust_R(adata, used_obsm='STAGATE', num_cluster=k)
        # Drop spots without annotation, then compute the ARI
        obs_df = adata.obs.dropna()
        ARI = adjusted_rand_score(obs_df['mclust'], obs_df['ground_truth'])
        ARIlist.append(ARI)
        print('Adjusted rand index = %.2f' % ARI)

    print("ari mean", np.mean(ARIlist))
    print("ari median", np.median(ARIlist))

Reprinted from: https://blog.csdn.net/qq_27648263/article/details/125779049
