

Voiceprint Recognition: Speaker Verification

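  1. A note up front: I have not put much time into blogging over the past few months, mainly because I have been busy with other things. Since August 2022 I have been preparing for my wedding, looking at and buying a home, holding the wedding, and looking at cars. At work I have been doing projects and entering some competitions, none of which produced anything worth writing about. My job recently brought me into contact with voiceprint recognition in the speech domain, and I did some preliminary research on it. This post covers what I learned about speaker verification along the way.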

I. Introduction to speaker verification

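  1. Speaker verification belongs to the field of voiceprint (speaker) recognition: given two audio clips, decide whether they were spoken by the same person. There are two types, text-dependent and text-independent. In text-dependent verification the speaker must say text from a restricted set every time; in text-independent verification the speaker may say anything. The former is relatively easy and the latter relatively hard. The core of speaker verification is a model that extracts features of each person's voice that discriminate well between different speakers.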

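  1. As shown in the figure above, there are two ways to decide whether the Enrollment and Evaluation clips come from the same speaker. One is to feed both clips into a model and train a classifier that decides whether they belong to the same class. The other is to use a trained model to extract a multi-dimensional embedding from the Enrollment clip in advance; when an Evaluation clip needs to be verified, the same model extracts its embedding, the similarity between the two vectors is computed, and a threshold decides the outcome. In practice the second approach dominates because it is efficient: embeddings of the enrolled audio are extracted ahead of time and stored in a database, the audio under test is embedded in real time, the similarity is computed, and a preset threshold decides whether the two clips come from the same person.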
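A minimal sketch of the second (enroll, then compare) option follows. It assumes some trained model has already turned each clip into a fixed-length embedding. `EmbeddingStore`, `enroll`, `verify` and the 0.7 threshold are illustrative names and values, not part of the original code.

```python
import numpy as np


class EmbeddingStore:
    """Hypothetical in-memory enrollment database: speaker_id -> L2-normalized embedding."""

    def __init__(self, threshold):
        self.threshold = threshold  # illustrative value; tune it on a validation set
        self.db = {}

    @staticmethod
    def _normalize(v):
        v = np.asarray(v, dtype=np.float32)
        return v / (np.linalg.norm(v) + 1e-12)

    def enroll(self, speaker_id, embedding):
        # Done once, ahead of time: store the normalized enrollment embedding
        self.db[speaker_id] = self._normalize(embedding)

    def verify(self, speaker_id, embedding):
        """Return (cosine score, accepted?) for a claimed identity."""
        score = float(np.dot(self.db[speaker_id], self._normalize(embedding)))
        return score, score >= self.threshold
```

Because the stored vectors are pre-normalized, each verification reduces to a single dot product against one stored vector. This is why the enroll-then-compare design scales better than running a pairwise classifier on every request.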
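  2. Before deployment, both the trained model and the speaker verification system as a whole need to be evaluated. The model is evaluated according to its training task, usually with F1, accuracy or recall. The speaker verification system, however, has its own evaluation scheme, built mainly on the following metrics:
  3. FAR (False Accept Rate)
  4. FRR (False Reject Rate)
  5. EER (Equal Error Rate)
  6. FRR = Nfr / Ntarget, where Nfr is the number of trials that should have been accepted but were rejected, and Ntarget is the total number of trials that should be accepted.
  7. FAR = Nfa / Nnontarget, where Nfa is the number of trials that should have been rejected but were accepted, and Nnontarget is the total number of trials that should be rejected.
  8. EER is the error rate at the threshold where FAR == FRR. Raising the threshold lowers FAR and raises FRR, so such a crossing point exists. EER is the most commonly used performance metric for speaker verification systems.
  9. EER does not account for the different consequences of false accepts and false rejects. To capture them, each error type is given its own cost weight, and the prior probabilities that the claimant is genuine or an impostor are also taken into account. The result is a new metric, the detection cost function (DCF).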
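Written out, the detection cost function at a threshold θ has the standard form below, where C_FR and C_FA are the costs of a false reject and a false accept. The notation is chosen here to match FRR/FAR above. It agrees with the `dcf` computed by the threshold-search code later in this post, which uses `cfr = cfa = 1` and takes the priors from the trial counts. The last line is a worked example with C_FR = C_FA = 1, P_T = 0.01, FRR = 0.05 and FAR = 0.01:

```latex
\mathrm{DCF}(\theta) = C_{FR}\,\mathrm{FRR}(\theta)\,P_T + C_{FA}\,\mathrm{FAR}(\theta)\,P_I, \qquad P_I = 1 - P_T

\mathrm{minDCF} = \min_{\theta}\,\mathrm{DCF}(\theta)

\mathrm{DCF} = 1 \cdot 0.05 \cdot 0.01 + 1 \cdot 0.01 \cdot 0.99 = 0.0005 + 0.0099 = 0.0104
```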

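  1. PT is the prior probability that the speaker is genuine (a target trial) and PI is the prior probability of an impostor. CFR and CFA are the costs of a false reject and a false accept. The stricter the system, the larger PI/PT. Common PT:PI ratios are 1:99 and 1:999.
  2. DCF changes as the threshold is adjusted. The threshold at which DCF is smallest (minDCF) gives the best overall system performance.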
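The threshold sweep described above can be sketched as follows. The sketch assumes that higher scores mean "more likely the same speaker", that a trial is accepted when score > threshold, and that every observed score is a candidate threshold. `eer_and_min_dcf` and its defaults (P_T = 0.01, unit costs) are illustrative choices, not from the original code.

```python
import numpy as np


def far_frr_curve(scores, labels, thresholds):
    """FAR and FRR at each threshold; labels are 1 for target (same-speaker) trials."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=bool)
    n_target = labels.sum()
    n_nontarget = (~labels).sum()
    fars, frrs = [], []
    for t in thresholds:
        accept = scores > t
        fars.append((accept & ~labels).sum() / n_nontarget)  # impostors accepted
        frrs.append((~accept & labels).sum() / n_target)     # genuine speakers rejected
    return np.array(fars), np.array(frrs)


def eer_and_min_dcf(scores, labels, p_target=0.01, c_fr=1.0, c_fa=1.0):
    """Return (eer, eer_threshold, min_dcf, min_dcf_threshold)."""
    scores = np.asarray(scores, dtype=float)
    # Candidate thresholds: one below the minimum (accept all) plus every observed score
    thresholds = np.concatenate(([scores.min() - 1e-6], np.sort(scores)))
    fars, frrs = far_frr_curve(scores, labels, thresholds)
    i = np.argmin(np.abs(fars - frrs))           # closest point to FAR == FRR
    eer = (fars[i] + frrs[i]) / 2
    dcfs = c_fr * frrs * p_target + c_fa * fars * (1 - p_target)
    j = np.argmin(dcfs)
    return eer, thresholds[i], dcfs[j], thresholds[j]
```

Sweeping over the observed scores rather than a fixed grid means the EER point cannot fall between two grid steps. The price is one pass per distinct score, so very large trial lists would need a sort-based cumulative count instead.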

II. Mainstream approaches and models

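  1. Speaker verification has been developed for many years, and there are many approaches. Traditional approaches rely mainly on signal processing: they convert the time-domain signal into the frequency domain and then apply various techniques to tell speakers apart. Here is a diagram of how the approaches evolved (taken from the Zhihu question "What voiceprint recognition algorithms are there?"):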

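  1. The acoustic features involved include MFCC, FBank and spectrograms, along with various data augmentations applied to them. As of 2022, attention has shifted to end-to-end approaches in which a neural network learns the acoustic representation automatically. The most popular is the ECAPA-TDNN model, proposed in 2020. It introduces SE (squeeze-excitation) modules and channel attention, and it won first place in the VoxSRC 2020 speaker recognition challenge. It was also used as the baseline for the speaker verification task of FFSVC 2022. The other family is pretrained models: the speech domain has many pretrained models analogous to BERT in NLP, and in my opinion the best performer among them is WavLM.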

1. ECAPA-TDNN model

First, the overall architecture:

  1. 可以看到ecapa_tdnnconv1D+BNSE-Res2BlockASP+BNFC+BN以及AAM-softmax等模块构成。其中SE-Res2Block能是模型学习到音频数据中更多的全局信息,这个比之前的d-vector效果更好。

SE-Res2Block:

  1. SE-Res2Block essentially adds an SE-Block to the Res2Block. The SE-Block is a channel attention module, a classic design that performs well in many kinds of networks.
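The squeeze, excite and scale steps of an SE-Block can be written with plain tensor operations. In the sketch below, random weights stand in for learned ones and `squeeze_excite` is an illustrative name. It mirrors what the `SEBlock` in the model code below does with 1x1 convolutions.

```python
import torch


def squeeze_excite(x, w1, w2):
    """x: (batch, channels, time); w1: (bottleneck, channels); w2: (channels, bottleneck)."""
    s = x.mean(dim=2)               # squeeze: per-channel average over time -> (B, C)
    s = torch.relu(s @ w1.T)        # excitation bottleneck -> (B, bottleneck)
    g = torch.sigmoid(s @ w2.T)     # one gate per channel in [0, 1] -> (B, C)
    return x * g.unsqueeze(2), g    # scale: every frame of a channel gets the same weight
```

Because the gate is computed from a time average, each channel weight depends on the whole utterance. This is the "global information" a plain TDNN layer, with its limited temporal context, does not see.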

2. WavLM

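  1. WavLM was developed by Microsoft Research Asia and the Microsoft Azure Speech group. It uses a Transformer architecture and the Denoising Masked Speech Modeling framework to perform BERT-style masked pretraining directly on raw audio waveforms. Pretrained on a massive amount of audio, it achieves strong results on speech tasks.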

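  1. The network structure is shown in the figure: a CNN feature extractor is followed by Transformer blocks that encode the features. I will not go into the model details here. You can think of it as a BERT for audio, with some implementation differences. For the concrete implementation, see the Hugging Face classes (WavLM, WavLMModel, etc.).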
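As a rough sketch of how WavLM can be used for verification through Hugging Face transformers, the function below scores two clips. It assumes the `WavLMForXVector` class and the public `microsoft/wavlm-base-plus-sv` checkpoint, a WavLM model with an x-vector head for speaker verification. The checkpoint choice and the `wavlm_cosine` helper are assumptions, not something the original post uses. The model is loaded inside the function for brevity and is downloaded on first use.

```python
import torch


def wavlm_cosine(wav_a, wav_b, model_name="microsoft/wavlm-base-plus-sv"):
    """Cosine similarity between two 16 kHz mono waveforms (1-D float arrays).

    Requires the `transformers` package; the checkpoint name is an assumption.
    """
    from transformers import AutoFeatureExtractor, WavLMForXVector

    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    model = WavLMForXVector.from_pretrained(model_name).eval()
    inputs = feature_extractor([wav_a, wav_b], sampling_rate=16000,
                               padding=True, return_tensors="pt")
    with torch.no_grad():
        embeddings = model(**inputs).embeddings          # (2, embedding_dim)
    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
    return float((embeddings[0] * embeddings[1]).sum())
```

In a real service the feature extractor and model would be loaded once and reused. The resulting score is then thresholded exactly like the ECAPA-TDNN embeddings.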

III. Hands-on code

1. ECAPA-TDNN approach

a. Model architecture

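  1. The code draws on Baidu's PaddleSpeech (Paddle version) and SpeechBrain (PyTorch version), with some parts removed, and also on the personal project VoiceprintRecognition-Pytorch. Combining ideas from all of them gives the following ECAPA-TDNN model code: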
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class TDNNBlock(nn.Module):
    """An implementation of TDNN: Conv1d -> ReLU -> BatchNorm1d."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation, groups=1, padding=0):
        super(TDNNBlock, self).__init__()
        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, dilation=dilation,
                              groups=groups, padding=padding)
        self.activation = nn.ReLU()
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.activation(x)
        x = self.bn(x)
        return x


class Res2NetBlock(torch.nn.Module):
    """An implementation of Res2NetBlock w/ dilation.

    Example
    -------
    inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    layer = Res2NetBlock(64, 64, scale=4, dilation=3, padding=3)
    out_tensor = layer(inp_tensor).transpose(1, 2)
    out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1, padding=0):
        super(Res2NetBlock, self).__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0
        in_channel = in_channels // scale
        hidden_channel = out_channels // scale
        self.blocks = nn.ModuleList([
            TDNNBlock(in_channel, hidden_channel, kernel_size=kernel_size,
                      dilation=dilation, padding=padding)
            for _ in range(scale - 1)
        ])
        self.scale = scale

    def forward(self, x):
        # Split channels into `scale` groups; group i >= 2 also receives the previous group's output
        y = []
        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
            if i == 0:
                y_i = x_i
            elif i == 1:
                y_i = self.blocks[i - 1](x_i)
            else:
                y_i = self.blocks[i - 1](x_i + y_i)
            y.append(y_i)
        return torch.cat(y, dim=1)


class SEBlock(nn.Module):
    """Squeeze-and-excitation block (the padding mask is omitted)."""

    def __init__(self, in_channels, se_channels, out_channels):
        super(SEBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=se_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(in_channels=se_channels, out_channels=out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        s = x.mean(dim=2, keepdim=True)    # squeeze: average over time
        s = self.relu(self.conv1(s))
        s = self.sigmoid(self.conv2(s))    # per-channel weights in (0, 1)
        return s * x


class SERes2NetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, res2net_scale=8, se_channels=128,
                 kernel_size=1, dilation=1, groups=1, padding=0):
        super(SERes2NetBlock, self).__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(in_channels, out_channels, kernel_size=1, dilation=1, groups=groups)
        # Keyword arguments: Res2NetBlock's positional order is (scale, kernel_size, dilation, padding)
        self.res2net_block = Res2NetBlock(out_channels, out_channels, scale=res2net_scale,
                                          kernel_size=kernel_size, dilation=dilation, padding=padding)
        self.tdnn2 = TDNNBlock(out_channels, out_channels, kernel_size=1, dilation=1, groups=groups)
        self.se_block = SEBlock(out_channels, se_channels, out_channels)
        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Processes the input tensor x and returns an output tensor."""
        residual = x
        if self.shortcut is not None:
            residual = self.shortcut(x)
        x = self.tdnn1(x)
        x = self.res2net_block(x)
        x = self.tdnn2(x)
        x = self.se_block(x)
        return x + residual


class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, bottleneck_dim):
        super(AttentiveStatsPool, self).__init__()
        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x))
        alpha = torch.softmax(self.linear2(alpha), dim=2)   # attention weights over time
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)                # (batch, 2 * in_dim)


class ECAPATDNN(nn.Module):
    def __init__(self,
                 input_size,
                 lin_neurons=192,
                 channels=[512, 512, 512, 512, 1536],
                 kernel_sizes=[5, 3, 3, 3, 1],
                 dilations=[1, 2, 3, 4, 1],
                 attention_channels=128,
                 res2net_scale=8,
                 se_channels=128,
                 groups=[1, 1, 1, 1, 1],
                 paddings=[0, 2, 3, 4, 0]):
        super(ECAPATDNN, self).__init__()
        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        self.emb_size = lin_neurons
        self.channels = channels
        self.blocks = nn.ModuleList()
        self.blocks.append(
            TDNNBlock(input_size, channels[0], kernel_sizes[0], dilations[0],
                      groups[0], padding=paddings[0])
        )
        # Three SE-Res2Net blocks; padding == dilation keeps the number of frames unchanged
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(channels[i - 1], channels[i], res2net_scale=res2net_scale,
                               se_channels=se_channels, kernel_size=kernel_sizes[i],
                               dilation=dilations[i], groups=groups[i], padding=paddings[i])
            )
        # channels[-1] (1536) must equal the concatenated SE-Res2Net outputs (3 x 512)
        self.mfa = TDNNBlock(channels[-1], channels[-1], kernel_sizes[-1], dilations[-1],
                             groups[-1], padding=paddings[-1])
        self.asp = AttentiveStatsPool(channels[-1], attention_channels)
        self.asp_bn = nn.BatchNorm1d(channels[-1] * 2)
        self.fc = nn.Conv1d(in_channels=channels[-1] * 2, out_channels=lin_neurons, kernel_size=1)

    def forward(self, x):
        # x: (batch, input_size, frames) -> embedding: (batch, lin_neurons, 1)
        xl = []
        for layer in self.blocks:
            x = layer(x)
            xl.append(x)
        # Multi-layer feature aggregation over the SE-Res2Net outputs.
        # Do not use x.data here: it detaches the tensor and blocks gradients to the earlier blocks.
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)
        # Attentive Statistical Pooling
        x = self.asp(x)
        x = self.asp_bn(x)
        x = x.unsqueeze(2)
        # Final linear transformation
        x = self.fc(x)
        return x


class SpeakerIdentificationModel(nn.Module):
    def __init__(self, backbone, num_class=1, dropout=0.1):
        super(SpeakerIdentificationModel, self).__init__()
        self.backbone = backbone
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
        input_size = self.backbone.emb_size
        # Classification weights used instead of nn.Linear, with a different initialization;
        # rows are L2-normalized in forward, so the logits are cosine similarities
        self.weight = Parameter(torch.FloatTensor(num_class, input_size), requires_grad=True)
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(self, x):
        x = self.backbone(x)
        if self.dropout is not None:
            x = self.dropout(x)
        logits = F.linear(F.normalize(x.squeeze(2)), weight=F.normalize(self.weight, dim=-1))
        return logits
```

b. Loss

  1. This code is taken from VoiceprintRecognition-Pytorch.
  2. The Additive Angular Margin loss is combined with KLDivLoss (Kullback-Leibler divergence loss) to give the final AAM loss.
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveAngularMargin(nn.Module):
    def __init__(self, margin=0.0, scale=1.0, easy_margin=False):
        """The Implementation of Additive Angular Margin (AAM) proposed
        in the following paper: '''Margin Matters: Towards More Discriminative Deep Neural Network Embeddings for Speaker Recognition'''
        (https://arxiv.org/abs/1906.07317)
        Args:
            margin (float, optional): margin factor. Defaults to 0.0.
            scale (float, optional): scale factor. Defaults to 1.0.
            easy_margin (bool, optional): easy_margin flag. Defaults to False.
        """
        super(AdditiveAngularMargin, self).__init__()
        self.margin = margin
        self.scale = scale
        self.easy_margin = easy_margin
        self.cos_m = math.cos(self.margin)
        self.sin_m = math.sin(self.margin)
        self.th = math.cos(math.pi - self.margin)
        self.mm = math.sin(math.pi - self.margin) * self.margin

    def forward(self, outputs, targets):
        # outputs: cosine logits (batch, num_class); targets: one-hot (batch, num_class)
        cosine = outputs.float()
        # clamp guards against tiny negative values from rounding, which would give NaN
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=0.0))
        phi = cosine * self.cos_m - sine * self.sin_m   # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # apply the margin only to the target class
        outputs = (targets * phi) + ((1.0 - targets) * cosine)
        return self.scale * outputs


class AAMLoss(nn.Module):
    def __init__(self, margin=0.2, scale=30, easy_margin=False):
        super(AAMLoss, self).__init__()
        self.loss_fn = AdditiveAngularMargin(margin=margin, scale=scale, easy_margin=easy_margin)
        self.criterion = torch.nn.KLDivLoss(reduction="sum")

    def forward(self, outputs, targets):
        targets = F.one_hot(targets, outputs.shape[1]).float()
        predictions = self.loss_fn(outputs, targets)
        predictions = F.log_softmax(predictions, dim=1)
        loss = self.criterion(predictions, targets) / targets.sum()
        return loss
```
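With one-hot targets, the KL divergence used in `AAMLoss` reduces to ordinary cross-entropy. Every non-target term is 0 · log 0 = 0, and the target term is −log p. Dividing the summed loss by `targets.sum()` (the batch size) then gives the mean. A quick check, where `kl_onehot_loss` is an illustrative re-statement of the reduction in `AAMLoss` without the margin:

```python
import torch
import torch.nn.functional as F


def kl_onehot_loss(logits, targets):
    # Same reduction as AAMLoss: KLDiv(sum) / number of one-hot ones (= batch size)
    onehot = F.one_hot(targets, logits.shape[1]).float()
    log_probs = F.log_softmax(logits, dim=1)
    return torch.nn.KLDivLoss(reduction="sum")(log_probs, onehot) / onehot.sum()
```

So the KL formulation trains exactly like cross-entropy on the margin-adjusted, scaled logits. It would only differ if soft targets, for example label smoothing, were used.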

c. Data processing

  1. This code extracts speech features from wav or mp3 data, such as fbank (melspectrogram), the spectrogram, and Mel-frequency cepstral coefficients (MFCC).
```python
import random

import torch
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset
from tqdm import tqdm


class AudioDataReader(Dataset):
    def __init__(self, data_list_path,
                 feature_method='melspectrogram',
                 mode='train',
                 sr=16000,
                 chunk_duration=3,
                 min_duration=0.5,
                 label2ids=None,
                 augmentors=None):
        super(AudioDataReader, self).__init__()
        assert data_list_path is not None
        with open(data_list_path, 'r', encoding='utf-8') as f:
            self.lines = f.readlines()
        self.feature_method = feature_method
        self.mode = mode
        self.sr = sr
        self.chunk_duration = chunk_duration
        self.min_duration = min_duration
        self.augmentors = augmentors
        self.label2ids = label2ids if label2ids is not None else {}
        self.audiofeatures = self.getaudiofeatures()

    def load_audio(self, audio_path,
                   feature_method='melspectrogram',
                   mode='train',
                   sr=16000,
                   chunk_duration=3,
                   min_duration=0.5,
                   augmentors=None):
        """
        Load and preprocess one audio file.
        :param audio_path: path to the audio file
        :param feature_method: melspectrogram (FBank, Mel spectrogram) / MFCC (Mel-frequency cepstral coefficients) / spectrogram
        :param mode: how the data is processed: train, eval or infer
        :param sr: sample rate
        :param chunk_duration: audio length (s) used for training or evaluation
        :param min_duration: minimum audio length (s) for training or evaluation
        :param augmentors: data augmentation methods
        :return: feature tensor of shape (n_features, frames)
        """
        wav, sample_rate = torchaudio.load(audio_path)  # returns a (channels, samples) tensor
        if sample_rate != sr:
            wav = torchaudio.functional.resample(wav, sample_rate, sr)
        wav = wav.mean(dim=0, keepdim=True)  # mix down to mono
        num_wav_samples = wav.shape[1]
        # Clips that are too short are not useful for training
        if mode == 'train' and num_wav_samples < int(min_duration * sr):
            raise Exception(f'Audio shorter than {min_duration}s, actual length: {(num_wav_samples / sr):.2f}s')
        # Pad clips shorter than chunk_duration by repeating them (all modes, so batches stack)
        num_chunk_samples = int(chunk_duration * sr)
        if num_wav_samples < num_chunk_samples:
            repeats = num_chunk_samples // num_wav_samples + 1
            wav = wav.repeat(1, repeats)[:, :num_chunk_samples]
        num_wav_samples = wav.shape[1]
        if mode == 'train':
            # Random crop to exactly num_chunk_samples
            if num_wav_samples > num_chunk_samples:
                start = random.randint(0, num_wav_samples - num_chunk_samples)
                wav = wav[:, start:start + num_chunk_samples]
            # Waveform-level augmentation (SpecAugment is applied to the features below)
            if augmentors is not None:
                for key, augmentor in augmentors.items():
                    if key == 'specaug':
                        continue
                    # the augmentors operate on numpy arrays
                    wav = torch.from_numpy(augmentor(wav.numpy())).float()
        elif mode == 'eval':
            # Crop to a fixed length to avoid running out of GPU memory
            if num_wav_samples > num_chunk_samples:
                wav = wav[:, :num_chunk_samples]
        if feature_method == "melspectrogram":
            # Mel spectrogram (FBank)
            features = torchaudio.transforms.MelSpectrogram(
                sample_rate=sr, n_fft=400, n_mels=80, hop_length=160, win_length=400)(wav)
        elif feature_method == "spectrogram":
            # Linear spectrogram
            features = torchaudio.transforms.Spectrogram(n_fft=400, win_length=400, hop_length=160)(wav)
        elif feature_method == "MFCC":
            # MFCC takes its STFT/Mel settings through melkwargs
            features = torchaudio.transforms.MFCC(
                sample_rate=sr, n_mfcc=40,
                melkwargs={'n_fft': 400, 'n_mels': 80, 'hop_length': 160, 'win_length': 400})(wav)
        else:
            raise Exception(f'Feature method {feature_method} does not exist!')
        # SpecAugment on the features
        if mode == 'train' and augmentors is not None and 'specaug' in augmentors:
            features = augmentors['specaug'](features)
        # Normalize over the time axis; parameter-free, so no autograd graph is kept in memory
        features = F.layer_norm(features, features.shape[-1:]).squeeze(0)
        return features

    def getaudiofeatures(self):
        res = []
        for line in tqdm(self.lines, desc=self.mode + ' load all audios', ncols=100):
            try:
                audio_path, label = line.replace('\n', '').split('\t')
                label = torch.as_tensor(self.label2ids[label], dtype=torch.long)
                features = self.load_audio(audio_path=audio_path, feature_method=self.feature_method,
                                           mode=self.mode, sr=self.sr,
                                           chunk_duration=self.chunk_duration,
                                           min_duration=self.min_duration,
                                           augmentors=self.augmentors)
                res.append([features, label])
            except Exception as e:
                print(f'{e}, load audio data exception')
        return res

    @property
    def input_size(self):
        if self.feature_method == 'melspectrogram':
            return 80   # n_mels
        elif self.feature_method == 'spectrogram':
            return 201  # n_fft // 2 + 1
        elif self.feature_method == 'MFCC':
            return 40   # n_mfcc
        else:
            raise Exception(f'Feature method {self.feature_method} does not exist!')

    def __getitem__(self, item):
        return self.audiofeatures[item][0], self.audiofeatures[item][1]

    def __len__(self):
        return len(self.audiofeatures)
```
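  1. Note that the audio is not read inside __getitem__(); instead, everything is loaded into memory up front. If the dataset is too large, the audio should be read inside __getitem__() to reduce memory usage, at the cost of slower training.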
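A minimal sketch of the lazy alternative: each `__getitem__` call reads and featurizes a single utterance. The `featurize` callable is an assumption standing in for something like a wrapper around `AudioDataReader.load_audio`, and `LazyAudioDataset` is a hypothetical name. A side effect is that random cropping and augmentation are redrawn every epoch, whereas the cached reader above crops and augments each clip only once, at load time.

```python
import torch
from torch.utils.data import Dataset


class LazyAudioDataset(Dataset):
    """Reads one utterance per __getitem__ call instead of caching all features.

    `lines` are 'path\\tlabel' strings; `featurize` maps an audio path to a feature tensor.
    """

    def __init__(self, lines, label2ids, featurize):
        self.items = [line.rstrip("\n").split("\t") for line in lines if line.strip()]
        self.label2ids = label2ids
        self.featurize = featurize

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        path, label = self.items[idx]
        features = self.featurize(path)  # disk I/O and feature extraction happen here
        return features, torch.tensor(self.label2ids[label], dtype=torch.long)
```

Passing `num_workers > 0` to `DataLoader` lets several worker processes do this loading in parallel, which recovers much of the speed lost by not caching.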

d. Model training and evaluation

  1. The dataset is the zhstcmds subset of the public zhvoice: Chinese voice corpus:
```json
"zhstcmds": {
    "character_W": 111.9317,
    "duration_H": 74.53628,
    "n_audio_per_speaker": 120.0,
    "n_character_per_sentence": 10.909522417153998,
    "n_minute_per_speaker": 5.230616140350877,
    "n_second_per_audio": 2.6153080701754385,
    "n_speaker": 855,
    "sentence_W": 10.26,
    "size_MB": 767.7000274658203
}
```
  1. There are 104,963 utterances in total, split randomly into 10,000 for validation and 94,963 for training.
  2. The training code is as follows:
```python
import argparse
import os
import random
from datetime import datetime

import numpy as np
import torch
import yaml
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Project-local modules
from models.loss import AAMLoss
from models.ecapa_tdnn import SpeakerIdentificationModel, ECAPATDNN
from tools.log import Logger
from tools.progressbar import ProgressBar
from data_utils.reader import AudioDataReader
from data_utils.noise_perturb import NoisePerturbAugmentor
from data_utils.speed_perturb import SpeedPerturbAugmentor
from data_utils.volum_perturb import VolumePerturbAugmentor
from data_utils.spec_augment import SpecAugmentor


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_datas_path", type=str, default='./data/train_audio_paths.txt', help="train text file")
    parser.add_argument("--val_datas_path", type=str, default='./data/val_audio_paths.txt', help="val text file")
    parser.add_argument("--log_file", type=str, default="./log_output/speaker_identification.log", help="log_file")
    parser.add_argument("--model_out", type=str, default="./output/", help="model output path")
    parser.add_argument("--batch_size", type=int, default=64, help="batch size")
    parser.add_argument("--epochs", type=int, default=30, help="epochs")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--random_seed", type=int, default=100, help="random_seed")
    parser.add_argument("--device", type=str, default='1', help="device")
    return parser.parse_args()


def build_label2ids(paths):
    """Map each speaker label (last tab-separated field) to a contiguous integer id."""
    label2ids = {}
    for path in paths:
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                label = line.strip('\n').split('\t')[-1]
                if label not in label2ids:
                    label2ids[label] = len(label2ids)
    return label2ids


def training(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    logger = Logger(log_name='SI', log_level=10, log_file=args.log_file).logger
    logger.info(args)
    label2ids = build_label2ids([args.train_datas_path, args.val_datas_path])
    augmentors = {}
    with open("augment.ymal", 'r', encoding="utf-8") as fp:
        configs = yaml.load(fp, Loader=yaml.FullLoader)
    augmentors['noise'] = NoisePerturbAugmentor(**configs['noise'])
    augmentors['speed'] = SpeedPerturbAugmentor(**configs['speed'])
    augmentors['volume'] = VolumePerturbAugmentor(**configs['volume'])
    augmentors['specaug'] = SpecAugmentor(**configs['specaug'])
    augmentors = None  # augmentation is switched off for this run
    time_srt = datetime.now().strftime('%Y-%m-%d')
    save_path = os.path.join(args.model_out, time_srt)
    os.makedirs(save_path, exist_ok=True)
    model_file = os.path.join(save_path, "ecapa_tdnn.bin")  # computed once, reused on every save
    logger.info(save_path)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    train_dataset = AudioDataReader(feature_method='melspectrogram', data_list_path=args.train_datas_path,
                                    mode='train', label2ids=label2ids, augmentors=augmentors)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
    val_dataset = AudioDataReader(feature_method='melspectrogram', data_list_path=args.val_datas_path,
                                  mode='eval', label2ids=label2ids, augmentors=augmentors)
    val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=args.batch_size)
    num_class = len(label2ids)
    logger.info('num_class:%d' % num_class)
    ecapa_tdnn = ECAPATDNN(input_size=train_dataset.input_size)
    model = SpeakerIdentificationModel(backbone=ecapa_tdnn, num_class=num_class).to(device)
    loss_function = AAMLoss()
    optimizer = AdamW(lr=args.lr, params=model.parameters())
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d" % len(train_dataset))
    logger.info("  Num Epochs = %d" % args.epochs)
    writer = SummaryWriter('./runs/' + time_srt + '/')
    best_acc = 0
    total_step = 0
    unimproving_count = 0
    for epoch in range(args.epochs):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        model.train()
        total_loss = 0
        for step, batch in enumerate(train_dataloader):
            audio, speakers = [t.to(device) for t in batch]
            output = model(audio)
            loss = loss_function(output, speakers)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_step += 1
            writer.add_scalar('Train/Learning loss', loss.item(), total_step)
            total_loss += loss.item()
            pbar(step, {'loss': loss.item()})
        val_acc = evaluate(model, val_dataloader, device)
        lr = scheduler.get_last_lr()[0]
        if best_acc < val_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), model_file)
            unimproving_count = 0
            suffix = " Save model!"
        else:
            unimproving_count += 1
            suffix = ""
        logger.info(f"Train epoch [{epoch+1}/{args.epochs}],batch [{step+1}],Best_acc: {best_acc},"
                    f"Val_acc:{val_acc}, lr:{lr}, total_loss:{round(total_loss, 4)}.{suffix}")
        writer.add_scalar('Val/val_acc', val_acc, total_step)
        writer.add_scalar('Val/best_acc', best_acc, total_step)
        writer.add_scalar('Train/Learning rate', lr, total_step)
        scheduler.step()
        if unimproving_count >= 5:
            logger.info('unimproving %d epochs, early stop!' % unimproving_count)
            break


def evaluate(model, val_dataloader, device):
    """Closed-set speaker classification accuracy on the validation set."""
    total = 0
    correct_total = 0
    model.eval()
    with torch.no_grad():
        pbar = ProgressBar(n_total=len(val_dataloader), desc='evaluate')
        for step, batch in enumerate(val_dataloader):
            audio, speakers = [t.to(device) for t in batch]
            output = model(audio)
            total += speakers.shape[0]
            preds = torch.argmax(output, dim=-1)
            correct_total += (speakers == preds).sum().item()
            pbar(step, {})
    acc = correct_total / total
    model.train()
    return acc


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':
    args = parse_args()
    set_seed(args.random_seed)
    training(args)
```

During training, the evaluation metric is simply classification accuracy. The log is shown below:

The classification accuracy on the validation set is 0.9503.

e. Speaker verification inference

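  1. The trained ECAPA-TDNN model extracts embedding vectors from the preprocessed audio, the similarity between them is computed, and a preset threshold decides whether they come from the same speaker. The threshold itself has to be found by searching for the best value on a purpose-built validation set.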
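Pair scoring can also be done without a Python loop over query utterances. The sketch below assumes `embeddings` is an (N, D) tensor and `labels` an (N,) tensor of speaker ids; `pairwise_trials` is a hypothetical helper, not part of the script. It scores every unordered pair exactly once and leaves out the trivial self-pairs.

```python
import torch
import torch.nn.functional as F


def pairwise_trials(embeddings, labels):
    """Score every unordered pair (i < j) once; returns (scores, is_same_speaker)."""
    emb = F.normalize(embeddings, dim=1)
    sim = emb @ emb.T                                    # (N, N) cosine similarity matrix
    same = labels.unsqueeze(0) == labels.unsqueeze(1)    # (N, N) same-speaker mask
    n = labels.shape[0]
    iu = torch.triu_indices(n, n, offset=1)              # strictly upper triangle: i < j
    return sim[iu[0], iu[1]], same[iu[0], iu[1]].long()
```

For N = 10,000 the similarity matrix holds 10^8 floats (about 400 MB in float32). Larger sets would need the matrix computed in row blocks.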
```python
import argparse
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local modules
from models.ecapa_tdnn import SpeakerIdentificationModel, ECAPATDNN
from tools.log import Logger
from tools.progressbar import ProgressBar
from data_utils.reader import AudioDataReader

random.seed(100)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_datas_path", type=str, default='./data/train_audio_paths.txt', help="train text file")
    parser.add_argument("--val_datas_path", type=str, default='./data/val_audio_paths.txt', help="val text file")
    parser.add_argument("--log_file", type=str, default="./log_output/speaker_identification_evaluate.log", help="log_file")
    parser.add_argument("--batch_size", type=int, default=64, help="batch size")
    parser.add_argument("--random_seed", type=int, default=100, help="random_seed")
    parser.add_argument("--device", type=str, default='0', help="device")
    return parser.parse_args()


def evaluate(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    logger = Logger(log_name='SI', log_level=10, log_file=args.log_file).logger
    logger.info(args)
    # Rebuild the label mapping exactly as in training
    label2ids = {}
    for path in (args.train_datas_path, args.val_datas_path):
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                label = line.strip('\n').split('\t')[-1]
                label2ids.setdefault(label, len(label2ids))
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    val_dataset = AudioDataReader(feature_method='melspectrogram', data_list_path=args.val_datas_path,
                                  mode='eval', label2ids=label2ids, augmentors=None)
    val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=args.batch_size)
    # Same label set as training, so the classifier shape matches the checkpoint
    num_class = len(label2ids)
    logger.info('num_class:%d' % num_class)
    ecapa_tdnn = ECAPATDNN(input_size=val_dataset.input_size)
    model = SpeakerIdentificationModel(backbone=ecapa_tdnn, num_class=num_class).to(device)
    weights = torch.load('./output/2022-11-07/ecapa_tdnn.bin', map_location=device)
    model.load_state_dict(weights)
    model.eval()
    logger.info("***** Running evaluate *****")
    logger.info("  Num examples = %d" % len(val_dataset))
    pbar = ProgressBar(n_total=len(val_dataloader), desc='extract features')
    labels = []
    features = []
    with torch.no_grad():
        for step, batch in enumerate(val_dataloader):
            audio, speakers = [t.to(device) for t in batch]
            output = model.backbone(audio)          # embeddings only: (batch, 192, 1)
            labels.append(speakers)
            features.append(output.squeeze(2))
            pbar(step, info={'step': step})
    labels = torch.cat(labels)
    features = torch.cat(features)
    scores_pos = []
    scores_neg = []
    for i in tqdm(range(features.shape[0] - 1), desc='pairwise similarity', ncols=100):
        query = features[i]
        others = features[i + 1:]                   # pairs (i, j) with j > i; no self-pairs
        same = labels[i] == labels[i + 1:]
        cos = torch.cosine_similarity(query.unsqueeze(0), others, dim=-1)
        scores_pos.extend(cos[same].cpu().tolist())
        scores_neg.extend(cos[~same].cpu().tolist())
    print('len(scores_pos)', len(scores_pos))
    print('len(scores_neg)', len(scores_neg))
    # Keep at most 99 impostor trials per target trial (a 1:99 prior, i.e. P_T = 0.01)
    if len(scores_pos) * 99 < len(scores_neg):
        scores_neg = random.sample(scores_neg, k=len(scores_pos) * 99)
    scores = scores_pos + scores_neg
    y_true = [1] * len(scores_pos) + [0] * len(scores_neg)
    print('len(scores)', len(scores))
    print('len(y_true)', len(y_true))
    scores = torch.tensor(scores, dtype=torch.float32)
    y_true = torch.tensor(y_true, dtype=torch.long)
    choice_best_threshold_dcf(scores, y_true)


def choice_best_threshold_dcf(scores, y_true):
    thresholds, fars, frrs, dcfs = [], [], [], []
    precisions, recalls, f1s = [], [], []
    max_precision = 0
    max_recall = 0
    max_f1 = 0
    f1_threshold = 0
    min_dcf = 1
    d_threshold = 0
    cfr = 1  # cost of a false reject
    cfa = 1  # cost of a false accept
    err = 0.0  # equal error rate
    err_threshold = 0
    diff = 1
    for i in tqdm(range(100), desc='choice_best_threshold', ncols=100):
        threshold = 0.01 * i
        thresholds.append(threshold)
        y_preds = (scores > threshold).long()
        tp = ((y_true == 1) & (y_preds == 1)).sum().item()
        fp = ((y_true == 0) & (y_preds == 1)).sum().item()
        tn = ((y_true == 0) & (y_preds == 0)).sum().item()
        fn = ((y_true == 1) & (y_preds == 0)).sum().item()
        pos = tp + fn
        neg = tn + fp
        precision = tp / (tp + fp + 1e-13)
        recall = tp / (tp + fn + 1e-13)
        f1 = 2 * precision * recall / (precision + recall + 1e-13)
        far = fp / (fp + tn + 1e-13)
        frr = fn / (tp + fn + 1e-13)
        # priors P_I and P_T are taken from the trial counts
        dcf = cfa * far * (neg / (neg + pos)) + cfr * frr * (pos / (pos + neg))
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
        fars.append(far)
        frrs.append(frr)
        dcfs.append(dcf)
        max_precision = max(max_precision, precision)
        max_recall = max(max_recall, recall)
        if max_f1 < f1:
            max_f1 = f1
            f1_threshold = threshold
        if min_dcf > dcf:
            min_dcf = dcf
            d_threshold = threshold
        if abs(far - frr) < diff:
            err = (far + frr) / 2
            diff = abs(far - frr)
            err_threshold = threshold
    print(pos + neg)
    print('threshold:%.4f eer:%.4f' % (err_threshold, err))
    print("d_threshold:%.4f, min_dcf:%.4f" % (d_threshold, min_dcf))
    print("f1_threshold:%.4f, max_f1:%.4f" % (f1_threshold, max_f1))
    start = time.time()
    plt.figure(figsize=(30, 30), dpi=80)
    plt.title('2D curve ')
    plt.plot(thresholds, frrs, label='frr')
    plt.plot(thresholds, fars, label='far')
    plt.plot(thresholds, dcfs, label='dcf')
    plt.plot(thresholds, precisions, label='pre')
    plt.plot(thresholds, recalls, label='recall')
    plt.plot(thresholds, f1s, label='f1')
    plt.legend(loc=0)
    plt.scatter(d_threshold, min_dcf, c='red', s=100)
    plt.text(d_threshold, min_dcf, " min_dcf(%.4f,%.4f)" % (d_threshold, min_dcf))
```
  195. plt.scatter(err_threshold,err,c='blue',s=100)
  196. plt.text(err_threshold,err," err(%.4f,%.4f)"%(err_threshold,err))
  197. plt.scatter(f1_threshold, max_f1, c='yellow', s=100)
  198. plt.text(f1_threshold, max_f1, " f1(%.4f,%.4f)"%(f1_threshold, max_f1))
  199. plt.xlabel('threshold')
  200. plt.ylabel('frr f dcf recall or precision')
  201. plt.xticks(thresholds[::2])
  202. plt.yticks(thresholds[::2])
  203. end = time.time()
  204. print('plot time is', end - start)
  205. plt.savefig('ecapatdnn_2d_curve_voiceprint_dcf.png')
  206. plt.show()
  207. print("finish")
  208. def choice_best_threshold(scores,y_true):
  209. best_precision_threshold = 0
  210. precision_best = 0
  211. precision_recall = 0
  212. precision_f1 = 0
  213. tp_1 = 0
  214. fp_1 = 0
  215. fn_1 = 0
  216. tn_1 = 0
  217. best_recall_threshold = 0
  218. recall_best = 0
  219. recall_precision = 0
  220. recall_f1 = 0
  221. tp_2 = 0
  222. fp_2 = 0
  223. fn_2 = 0
  224. tn_2 = 0
  225. best_f1_threshold = 0
  226. f1_best = 0
  227. f1_precision = 0
  228. f1_recall = 0
  229. tp_3 = 0
  230. fp_3 = 0
  231. fn_3 = 0
  232. tn_3 = 0
  233. fars = []#误接受率
  234. frrs = []#误拒识率
  235. far_min = 1
  236. frr_min = 1
  237. thresholds = []
  238. err = None
  239. tp_4 = 0
  240. fp_4 = 0
  241. fn_4 = 0
  242. tn_4 = 0
  243. diff = 1
  244. for i in tqdm( range(100),desc='choice_best_threshold',ncols=100):
  245. threshold = 0.01 * i
  246. thresholds.append(threshold)
  247. y_preds = (scores > threshold).long()
  248. tp = ((y_true == 1)*(y_preds==1)).sum().item()
  249. fp = ((y_true == 0)*(y_preds==1)).sum().item()
  250. tn = ((y_true==0)*(y_preds==0)).sum().item()
  251. fn = ((y_true==1)*(y_preds==0)).sum().item()
  252. precision = tp /(tp+fp)
  253. recall = tp/(tp+fn)
  254. f1 = 2*precision*recall/(precision+recall + 1e-13)
  255. far = fp/(fp+tn)
  256. frr = fn/(tp+fn)
  257. fars.append(far)
  258. frrs.append(frr)
  259. if precision > precision_best:
  260. precision_best = precision
  261. best_precision_threshold = threshold
  262. precision_recall = recall
  263. precision_f1 = f1
  264. tp_1 = tp
  265. fp_1 = fp
  266. fn_1 = fn
  267. tn_1 = tn
  268. if recall > recall_best:
  269. recall_best = recall
  270. best_recall_threshold = threshold
  271. recall_precision = precision
  272. recall_f1 = f1
  273. tp_2 = tp
  274. fp_2 = fp
  275. fn_2 = fn
  276. tn_2 = tn
  277. if f1 > f1_best:
  278. f1_best = f1
  279. f1_precision = precision
  280. f1_recall = recall
  281. best_f1_threshold = threshold
  282. tp_3 = tp
  283. fp_3 = fp
  284. fn_3 = fn
  285. tn_3 = tn
  286. if abs(far-frr) < diff:
  287. diff = abs(far-frr)
  288. err = (far+frr)/2
  289. far_min = far
  290. frr_min = frr
  291. tp_4 = tp
  292. fp_4 = fp
  293. fn_4 = fn
  294. tn_4 = tn
  295. print(f"tp:{tp_4} fp{fp_4} tn{tn_4} fn{fn_4}")
  296. print("frr_min:%.4f,far_min:%.4f,err:%.4f"%(frr_min,far_min,err))
  297. print("precision:%.4f recall:%.4f"%(tp_4 /(tp_4+fp_4), tp_4/(tp_4+fn_4)))
  298. print('*'*50)
  299. print(f"tp:{tp_1} fp{fp_1} tn{tn_1} fn{fn_1}")
  300. print('best_precision_threshold:%.4f, precision_best:%.4f precision_recall:%.4f precision_f1:%.4f'%(best_precision_threshold,precision_best,precision_recall, precision_f1))
  301. print('*' * 50)
  302. print(f"tp:{tp_2} fp{fp_2} tn{tn_2} fn{fn_2}")
  303. print('best_recall_threshold:%.4f, recall_best:%.4f recall_precision:%.4f recall_f1:%.4f' % (
  304. best_recall_threshold, recall_best, recall_precision, recall_f1))
  305. print('*' * 50)
  306. print(f"tp:{tp_3} fp{fp_3} tn{tn_3} fn{fn_3}")
  307. print("frr:%.4f,far:%.4f"%(fn_3/(fn_3+tp_3),fp_3/(fp_3+tn_3)))
  308. print('best_f1_threshold:%.4f, f1_best:%.4f f1_precision:%.4f f1_recall:%.4f' % (
  309. best_f1_threshold, f1_best, f1_precision, f1_recall))
  310. print('*' * 50)
  311. # print(fars[0],"--",frrs[0])
  312. # print(fars[-1], "--", frrs[-1])
  313. #
  314. # plt.figure(figsize=(20,20),dpi=80)
  315. # plt.title('2D curve ')
  316. # plt.plot(fars, frrs)
  317. # plt.plot(thresholds,thresholds)
  318. # plt.scatter(err,err,c='red',s=100)
  319. # plt.text(err,err,(err,err))
  320. #
  321. # plt.xlabel('far')
  322. # plt.ylabel('frr')
  323. # plt.xticks(thresholds[::2])
  324. # plt.yticks(thresholds[::2])
  325. # plt.show()
  326. # plt.savefig('2d_curve_voiceprint_det.png')
  327. def set_seed(seed):
  328. torch.manual_seed(seed)
  329. torch.cuda.manual_seed(seed)
  330. np.random.seed(seed)
  331. random.seed(seed)
  332. torch.backends.cudnn.deterministic = True
  333. def collate_fn(batch):
  334. features,labels = zip(*batch)
  335. return features
  336. if __name__ == '__main__':
  337. args = parse_args()
  338. set_seed(args.random_seed)
  339. evaluate(args)
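
The listing compares `features[i]` against `features[i:]`, so every utterance's trivial self-pair (cosine similarity 1) is counted as a target trial, and the extra negatives are drawn with `random.choices`, i.e. with replacement. As a hedged alternative, here is a minimal NumPy sketch of the same trial construction that scores each unordered pair exactly once without self-pairs and subsamples negatives without replacement. `build_trials` and its parameters are hypothetical names, not part of the original code:

```python
import numpy as np


def build_trials(embeddings, labels, neg_ratio=99, seed=0):
    """Score every unordered pair (i < j) once, excluding self-pairs.

    Returns (scores, targets): cosine similarities and 1/0 same-speaker flags,
    with non-target pairs randomly subsampled to at most neg_ratio * #targets.
    """
    emb = np.asarray(embeddings, dtype=np.float64)
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)  # L2-normalise each row
    sim = emb @ emb.T                                        # full cosine-similarity matrix
    labels = np.asarray(labels)
    i, j = np.triu_indices(len(labels), k=1)                 # upper triangle, diagonal excluded
    scores = sim[i, j]
    targets = (labels[i] == labels[j]).astype(np.int64)

    pos = np.flatnonzero(targets == 1)
    neg = np.flatnonzero(targets == 0)
    max_neg = neg_ratio * len(pos)
    if len(neg) > max_neg:
        rng = np.random.default_rng(seed)
        neg = rng.choice(neg, size=max_neg, replace=False)   # sample without replacement
    keep = np.concatenate([pos, neg])
    return scores[keep], targets[keep]
```

With N utterances this materialises an N x N matrix, which is fine for a validation set of a few thousand clips but would need chunking for much larger lists.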

The best threshold is chosen using FAR, FRR, EER (`err` in the code) and DCF as the evaluation metrics:

The resulting curves show that the similarity threshold at the minimum DCF is 0.4500.
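
The scripts scan a fixed grid of 100 thresholds and use the class proportions of the subsampled trial list as the DCF priors. As a hedged alternative, the sketch below tries every observed score as a threshold and takes P_target as an explicit parameter (0.01 corresponds to the 1:99 ratio discussed earlier); like the scripts, it returns the raw, unnormalised DCF. `eer_and_min_dcf` is a hypothetical helper, and accepting a trial when `score >= threshold` is an assumption (the scripts use `>`):

```python
import numpy as np


def eer_and_min_dcf(scores, targets, p_target=0.01, c_miss=1.0, c_fa=1.0):
    """Return (eer, eer_threshold, min_dcf, min_dcf_threshold)."""
    scores = np.asarray(scores, dtype=np.float64)
    targets = np.asarray(targets).astype(bool)
    n_tar = targets.sum()
    n_non = (~targets).sum()
    # Candidate thresholds: every observed score plus +inf ("reject everything")
    thresholds = np.append(np.unique(scores), np.inf)
    # accept[k, n] is True when trial n is accepted at threshold k
    accept = scores[None, :] >= thresholds[:, None]
    frr = (~accept[:, targets]).sum(axis=1) / n_tar  # misses among target trials
    far = accept[:, ~targets].sum(axis=1) / n_non    # false alarms among non-target trials
    k = np.argmin(np.abs(far - frr))                 # operating point closest to FAR == FRR
    eer = (far[k] + frr[k]) / 2
    dcf = c_miss * frr * p_target + c_fa * far * (1 - p_target)
    m = np.argmin(dcf)
    return eer, thresholds[k], dcf[m], thresholds[m]
```

The `accept` matrix has one row per distinct score, so memory grows with (#thresholds x #trials); for large trial lists a sort-and-cumulative-sum formulation is the usual replacement.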

2. WavLM pre-trained model approach

a. Model architecture and loss

from transformers import WavLMModel, WavLMPreTrainedModel
from transformers.modeling_outputs import XVectorOutput
from transformers.pytorch_utils import torch_int_div
import torch.nn as nn
import torch
from typing import Optional, Tuple, Union

_HIDDEN_STATES_START_POSITION = 2


class TDNNLayer(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
        self.out_conv_dim = config.tdnn_dim[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]
        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states):
        # (batch, time, dim) -> (batch, 1, time, dim) so unfold can slide a dilated window over time
        hidden_states = hidden_states.unsqueeze(1)
        hidden_states = nn.functional.unfold(
            hidden_states,
            (self.kernel_size, self.in_conv_dim),
            stride=(1, self.in_conv_dim),
            dilation=(self.dilation, 1),
        )
        hidden_states = hidden_states.transpose(1, 2)
        hidden_states = self.kernel(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class AMSoftmaxLoss(nn.Module):
    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        super(AMSoftmaxLoss, self).__init__()
        self.scale = scale
        self.margin = margin
        self.num_labels = num_labels
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels=None):
        # Cosine similarity between L2-normalised embeddings and L2-normalised class weights
        weight = nn.functional.normalize(self.weight, dim=0)
        hidden_states = nn.functional.normalize(hidden_states, dim=1)
        cos_theta = torch.mm(hidden_states, weight)
        if labels is not None:
            # Subtract the additive margin from the target-class cosine only, then scale
            psi = cos_theta - self.margin
            labels = labels.flatten()
            onehot = nn.functional.one_hot(labels, self.num_labels)
            logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
            loss = self.loss(logits, labels)
            return loss, cos_theta
        else:
            return cos_theta


class WavLm(WavLMPreTrainedModel):
    def __init__(self, config):
        super(WavLm, self).__init__(config)
        self.wavlm = WavLMModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)
        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
        self.init_weights()

    def forward(self, input_values: Optional[torch.Tensor],
                attention_mask: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                labels: Optional[torch.Tensor] = None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if self.config.use_weighted_layer_sum:
            # Learned softmax-weighted sum over all hidden layers
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]
        hidden_states = self.projector(hidden_states)
        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)
        # Statistic pooling: per-utterance mean and std over time
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)
        loss = None
        if labels is not None:
            loss, cos_theta = self.objective(logits, labels)
        else:
            cos_theta = self.objective(logits, labels)
        # Return the class cosines as logits, so argmax gives the predicted speaker
        logits = cos_theta
        if not return_dict:
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output
        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the TDNN layers
        """
        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size in self.config.tdnn_kernel:
            input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
        return input_lengths

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """
        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch_int_div(input_length - kernel_size, stride) + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
        return input_lengths
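
`AMSoftmaxLoss` computes cosines between normalised embeddings and normalised class weights, subtracts `margin` from the target-class cosine only, multiplies by `scale` and applies cross-entropy. The NumPy sketch below reproduces that computation outside the model for illustration (it is not library code; `am_softmax_loss` is a hypothetical name). It makes it easy to check that, for the same inputs, a positive margin always raises the loss, so the target cosine has to beat the other classes by roughly the margin before the loss becomes small:

```python
import numpy as np


def am_softmax_loss(embeddings, weight, labels, scale=30.0, margin=0.4):
    """Mean AM-Softmax loss; weight has shape (dim, num_labels) as in AMSoftmaxLoss."""
    x = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)  # normalise each embedding
    w = weight / np.linalg.norm(weight, axis=0, keepdims=True)          # normalise each class column
    cos = x @ w                                                         # (batch, num_labels)
    rows = np.arange(len(labels))
    logits = cos.copy()
    logits[rows, labels] -= margin            # margin on the target class only
    logits *= scale
    logits -= logits.max(axis=1, keepdims=True)                         # stable log-softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[rows, labels].mean(), cos
```

With `margin=0` this reduces to an ordinary softmax cross-entropy over scaled cosines (sometimes called normalised softmax), which is a convenient baseline when tuning `scale` and `margin`.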

b. Data processing

import random
import torch
from torch.utils.data import Dataset
import torchaudio
from tqdm import tqdm


class AudioDataReader(Dataset):
    def __init__(self, data_list_path,
                 mode='train',
                 sr=16000,
                 chunk_duration=3,
                 min_duration=0.5,
                 label2ids={},
                 augmentors=None):
        super(AudioDataReader, self).__init__()
        assert data_list_path is not None
        with open(data_list_path, 'r', encoding='utf-8') as f:
            self.lines = f.readlines()[0:]
        self.mode = mode
        self.sr = sr
        self.chunk_duration = chunk_duration
        self.min_duration = min_duration
        self.augmentors = augmentors
        self.label2ids = label2ids
        self.audiofeatures = self.getaudiofeatures()

    def handle_features(self, wav, sr, mode, chunk_duration, min_duration):
        num_wav_samples = wav.shape[1]
        # Clips that are too short are not useful for training
        if mode == 'train':
            if num_wav_samples < int(min_duration * sr):
                raise Exception(f'Audio is shorter than {min_duration}s, actual length: {(num_wav_samples / sr):.2f}s')
                # print(f'Audio is shorter than {min_duration}s, actual length: {(num_wav_samples / sr):.2f}s')
                # return None
        # Repeat-pad clips that are shorter than the chunk length
        num_chunk_samples = int(chunk_duration * sr)
        if num_wav_samples < num_chunk_samples:
            times = int(num_chunk_samples / num_wav_samples) - 1
            shortages = []
            temp_num_wav_samples = num_wav_samples
            shortages.append(wav)
            if times >= 1:
                for _ in range(times):
                    shortages.append(wav)
                    temp_num_wav_samples += num_wav_samples
                shortages.append(wav[:, 0:(num_chunk_samples - temp_num_wav_samples)])
            else:
                shortages.append(wav[:, 0:(num_chunk_samples - num_wav_samples)])
            wav = torch.cat(shortages, dim=1)
        # Crop to the required length
        if mode == 'train':
            # Random crop
            num_wav_samples = wav.shape[1]
            num_chunk_samples = int(chunk_duration * sr)
            if num_wav_samples > num_chunk_samples + 1:
                start = random.randint(0, num_wav_samples - num_chunk_samples - 1)
                end = start + num_chunk_samples
                wav = wav[:, start:end]
                # # Optionally trim clips that are always full length
                # if random.random() > 0.5:
                #     wav[:random.randint(1, sr // 4)] = 0  # inserts silence
                #     wav = wav[:-random.randint(1, sr // 4)]
        elif mode == 'eval':
            # To avoid running out of GPU memory, keep only the first chunk_duration seconds
            num_wav_samples = wav.shape[1]
            num_chunk_samples = int(chunk_duration * sr)
            if num_wav_samples > num_chunk_samples + 1:
                wav = wav[:, 0:num_chunk_samples]
        return wav

    def getaudiofeatures(self):
        res = []
        for line in tqdm(self.lines, desc=self.mode + ' load all audios', ncols=100):
            temp = []
            try:
                audio_path, label = line.replace('\n', '').split('\t')
                label = self.label2ids[label]
                wav, sample_rate = torchaudio.load(audio_path)  # returns a (channels, samples) tensor and the sample rate
                wav = self.handle_features(wav, sr=self.sr, mode=self.mode, chunk_duration=self.chunk_duration, min_duration=self.min_duration)
                features = wav[:, 0:int(self.sr * self.chunk_duration)].squeeze(0)
                attention_mask = torch.ones_like(features, dtype=torch.long)
                label = torch.as_tensor(label, dtype=torch.long)
                temp.append(features)
                temp.append(attention_mask)
                temp.append(label)
                res.append(temp)
            except Exception as e:
                # An Exception cannot be concatenated with a str directly, so format it
                print(f'{e}, load audio data exception')
        return res

    def __getitem__(self, item):
        return self.audiofeatures[item][0], self.audiofeatures[item][1], self.audiofeatures[item][2]

    def __len__(self):
        return len(self.audiofeatures)
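
The difference from Ecapa_TDNN is that the raw time-domain waveform is fed to the model directly, instead of frequency-domain features obtained from acoustic feature analysis. The code only controls the length of the training and validation samples, so it is quite simple.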
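
The length control works as follows: a clip shorter than `chunk_duration` seconds is repeat-padded by tiling the waveform, a training clip is then randomly cropped to the chunk length, and an evaluation clip keeps its first chunk. A minimal NumPy sketch of that logic on a 1-D array (the reader itself operates on `(channels, samples)` torch tensors and also rejects training clips shorter than `min_duration`, which is omitted here); `repeat_pad_and_crop` is a hypothetical helper, not part of the original code:

```python
import numpy as np


def repeat_pad_and_crop(wav, num_samples, train=True, rng=None):
    """Return exactly num_samples samples from a 1-D waveform."""
    wav = np.asarray(wav)
    if len(wav) < num_samples:
        reps = -(-num_samples // len(wav))        # ceiling division
        wav = np.tile(wav, reps)[:num_samples]    # repeat the clip, then cut the tail
    if len(wav) > num_samples:
        if train:
            if rng is None:
                rng = np.random.default_rng()
            # integers() excludes the upper bound, so start is in [0, len - num_samples]
            start = rng.integers(0, len(wav) - num_samples + 1)
            wav = wav[start:start + num_samples]  # random crop
        else:
            wav = wav[:num_samples]               # deterministic crop for evaluation
    return wav
```

For example, a 3-sample clip `[1, 2, 3]` padded to 7 samples becomes `[1, 2, 3, 1, 2, 3, 1]`. Repeating the speech rather than zero-padding keeps the whole chunk voiced, which is why an all-ones attention mask is sufficient in the reader.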

c. Model training

from transformers import WavLMConfig
from models.wavlm import WavLm
from tools.log import Logger
from tools.progressbar import ProgressBar
from data_utils.wavlm_reader import AudioDataReader
from torch.utils.data import DataLoader
import torch
import os
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import argparse
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from torch.nn.utils.rnn import pad_sequence


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_datas_path", type=str, default='./data/train_audio_paths.txt', help="train text file")
    parser.add_argument("--val_datas_path", type=str, default='./data/val_audio_paths.txt', help="val text file")
    # parser.add_argument("--train_datas_path", type=str, default='./data/train_audio_paths_small.txt', help="train text file")
    # parser.add_argument("--val_datas_path", type=str, default='./data/val_audio_paths_small.txt', help="val text file")
    parser.add_argument("--log_file", type=str, default="./log_output/speaker_identification_wavlm.log", help="log_file")
    parser.add_argument("--model_out", type=str, default="./output/wavlm/", help="model output path")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser.add_argument("--epochs", type=int, default=30, help="epochs")
    parser.add_argument("--lr", type=float, default=1e-5, help="learning rate")
    parser.add_argument("--random_seed", type=int, default=100, help="random_seed")
    parser.add_argument("--device", type=str, default='0', help="device")
    args = parser.parse_args()
    return args


def training(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    logger = Logger(log_name='SI', log_level=10, log_file=args.log_file).logger
    logger.info(args)
    label2ids = {}
    # WavLMConfig matches the model type of the wavlm-base-plus-sv checkpoint
    config = WavLMConfig.from_pretrained('./pretrained_models/torch/wavlm-base-plus-sv/')
    id = 0
    with open(args.train_datas_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip('\n')
            if line.split('\t')[-1] not in label2ids:
                label2ids[line.split('\t')[-1]] = id
                id += 1
    with open(args.val_datas_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip('\n')
            if line.split('\t')[-1] not in label2ids:
                label2ids[line.split('\t')[-1]] = id
                id += 1
    time_srt = datetime.now().strftime('%Y-%m-%d')
    save_path = os.path.join(args.model_out, time_srt)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    logger.info(save_path)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    train_dataset = AudioDataReader(data_list_path=args.train_datas_path, mode='train', label2ids=label2ids)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, collate_fn=collate_fn)
    val_dataset = AudioDataReader(data_list_path=args.val_datas_path, mode='eval', label2ids=label2ids)
    val_dataloader = DataLoader(val_dataset, shuffle=True, batch_size=args.batch_size, collate_fn=collate_fn)
    num_class = len(label2ids)
    logger.info('num_class:%d' % num_class)
    config.num_labels = num_class
    # ignore_mismatched_sizes: the AM-Softmax weight is re-initialised for the new number of speakers
    model = WavLm.from_pretrained('./pretrained_models/torch/wavlm-base-plus-sv/', config=config, ignore_mismatched_sizes=True).to(device)
    model.eval()
    # ecapa_tdnn = EcapaTdnn(input_size=train_dataset.input_size)
    # model = SpeakerIdetification(backbone=ecapa_tdnn, num_class=num_class).to(device)
    # logger.info(model)
    optimizer = AdamW(lr=args.lr, params=model.parameters())
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d" % len(train_dataset))
    logger.info(" Num Epochs = %d" % args.epochs)
    writer = SummaryWriter('./runs/' + time_srt + '/')
    best_acc = 0
    total_step = 0
    unimproving_count = 0
    for epoch in range(args.epochs):
        pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
        model.train()
        total_loss = 0
        for step, batch in enumerate(train_dataloader):
            batch = [t.to(device) for t in batch]
            wav = batch[0]
            mask = batch[1]
            speakers = batch[2]
            inputs = {
                "input_values": wav,
                "attention_mask": mask
            }
            output = model(**inputs, labels=speakers)
            loss = output.loss
            optimizer.zero_grad()
            # loss.backward(retain_graph=True)
            loss.backward()
            optimizer.step()
            total_step += 1
            writer.add_scalar('Train/Learning loss', loss.item(), total_step)
            total_loss += loss.item()
            pbar(step, {'loss': loss.item()})
        val_acc = evaluate(model, val_dataloader, device)
        if best_acc < val_acc:
            best_acc = val_acc
            model.save_pretrained(save_path)
            is_improving = True
            unimproving_count = 0
        else:
            is_improving = False
            unimproving_count += 1
        if is_improving:
            logger.info(f"Train epoch [{epoch+1}/{args.epochs}],batch [{step+1}],Best_acc: {best_acc},Val_acc:{val_acc}, lr:{scheduler.get_last_lr()[0]}, total_loss:{round(total_loss,4)}. Save model!")
        else:
            logger.info(f"Train epoch [{epoch+1}/{args.epochs}],batch [{step+1}],Best_acc: {best_acc},Val_acc:{val_acc}, lr:{scheduler.get_last_lr()[0]}, total_loss:{round(total_loss,4)}.")
        writer.add_scalar('Val/val_acc', val_acc, total_step)
        writer.add_scalar('Val/best_acc', best_acc, total_step)
        # writer.add_scalar('Train/Learning rate', scheduler.get_lr()[0], total_step)
        writer.add_scalar('Train/Learning rate', scheduler.get_last_lr()[0], total_step)
        scheduler.step()
        # Early stopping after 5 epochs without a new best validation accuracy
        if unimproving_count >= 5:
            logger.info('unimproving %d epochs, early stop!' % unimproving_count)
            break


def evaluate(model, val_dataloader, device):
    total = 0
    correct_total = 0
    model.eval()
    with torch.no_grad():
        pbar = ProgressBar(n_total=len(val_dataloader), desc='evaluate')
        for step, batch in enumerate(val_dataloader):
            batch = [t.to(device) for t in batch]
            wav = batch[0]
            mask = batch[1]
            speakers = batch[2]
            inputs = {
                "input_values": wav,
                "attention_mask": mask
            }
            output = model(**inputs)
            logits = output.logits
            total += speakers.shape[0]
            preds = torch.argmax(logits, dim=-1)
            correct = (speakers == preds).sum().item()
            pbar(step, {})
            correct_total += correct
    acc = correct_total / total
    return acc


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def collate_fn(batch):
    # Pad waveforms and masks to the longest item in the batch
    features, attention_mask, labels = zip(*batch)
    features = pad_sequence(features, batch_first=True, padding_value=0.0)
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels = torch.stack(labels, dim=-1)
    return features, attention_mask, labels


if __name__ == '__main__':
    args = parse_args()
    set_seed(args.random_seed)
    training(args)

Results:

Speaker classification accuracy: 0.9684
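
Both the training and the evaluation scripts build `label2ids` with two identical loops over the train and validation lists, assigning IDs in first-seen order (train file first, then validation file). A small helper could factor this out; `build_label2ids` is a hypothetical name, and unlike the original loops it skips blank lines instead of mapping them to an empty label:

```python
def build_label2ids(list_paths):
    """Map each speaker label (last tab-separated field) to a consecutive integer ID."""
    label2ids = {}
    for path in list_paths:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.rstrip("\n")
                if not line:
                    continue  # ignore blank lines
                label = line.split("\t")[-1]
                if label not in label2ids:
                    label2ids[label] = len(label2ids)  # next free ID
    return label2ids
```

Usage in either script would then be `label2ids = build_label2ids([args.train_datas_path, args.val_datas_path])`.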

d. Inference and evaluation

As before, FAR, FRR, EER and DCF, together with F1, recall and precision, are used for evaluation:

  1. from transformers import WavLMForXVector
  2. from tools.log import Logger
  3. from tools.progressbar import ProgressBar
  4. from data_utils.wavlm_reader import AudioDataReader
  5. from torch.utils.data import DataLoader
  6. import torch
  7. import os
  8. import argparse
  9. import random
  10. import numpy as np
  11. from tqdm import tqdm
  12. import matplotlib.pyplot as plt
  13. from torch.nn.utils.rnn import pad_sequence
  14. import time
  15. def parse_args():
  16. parser = argparse.ArgumentParser()
  17. parser.add_argument("--train_datas_path", type=str, default='./data/train_audio_paths.txt', help="train text file")
  18. parser.add_argument("--val_datas_path", type=str, default='./data/val_audio_paths.txt', help="val text file")
  19. # parser.add_argument("--train_datas_path", type=str, default='./data/train_audio_paths_small.txt', help="train text file")
  20. # parser.add_argument("--val_datas_path", type=str, default='./data/val_audio_paths_small.txt', help="val text file")
  21. parser.add_argument("--log_file", type=str, default="./log_output/speaker_identification_evaluate.log", help="log_file")
  22. parser.add_argument("--batch_size", type=int, default=64, help="batch size")
  23. parser.add_argument("--random_seed", type=int, default=100, help="random_seed")
  24. parser.add_argument("--device", type=str, default='0', help="device")
  25. args = parser.parse_args()
  26. return args
  27. def evaluate(args):
  28. os.environ['CUDA_VISIBLE_DEVICES'] = args.device
  29. logger = Logger(log_name='SI',log_level=10,log_file=args.log_file).logger
  30. logger.info(args)
  31. label2ids = {}
  32. id = 0
  33. with open(args.train_datas_path,'r',encoding='utf-8') as f:
  34. lines = f.readlines()
  35. for line in lines:
  36. line = line.strip('\n')
  37. if line.split('\t')[-1] not in label2ids:
  38. label2ids[line.split('\t')[-1]] = id
  39. id += 1
  40. with open(args.val_datas_path,'r',encoding='utf-8') as f:
  41. lines = f.readlines()
  42. for line in lines:
  43. line = line.strip('\n')
  44. if line.split('\t')[-1] not in label2ids:
  45. label2ids[line.split('\t')[-1]] = id
  46. id += 1
  47. device = "cuda:0" if torch.cuda.is_available() else "cpu"
  48. val_dataset = AudioDataReader( data_list_path=args.val_datas_path, mode='eval', label2ids = label2ids)
  49. val_dataloader = DataLoader(val_dataset, shuffle=True, batch_size=args.batch_size,collate_fn=collate_fn)
  50. num_class = 875
  51. logger.info('num_class:%d'%num_class)
  52. model = WavLMForXVector.from_pretrained('./output/wavlm/2022-11-11/').to(device)
  53. model.eval()
  54. logger.info("***** Running evaluate *****")
  55. logger.info(" Num examples = %d" % len(val_dataset))
  56. pbar = ProgressBar(n_total=len(val_dataloader), desc='extract features')
  57. model.eval()
  58. labels = []
  59. features = []
  60. with torch.no_grad():
  61. for step, batch in enumerate(val_dataloader):
  62. batch = [t.to(device) for t in batch]
  63. wav = batch[0]
  64. mask = batch[1]
  65. speakers = batch[2]
  66. inputs = {
  67. "input_values": wav,
  68. "attention_mask": mask
  69. }
  70. output = model(**inputs)
  71. labels.append(speakers)
  72. features.append(output.embeddings)
  73. pbar(step,info={'step':step})
  74. labels = torch.cat(labels)
  75. features = torch.cat(features)
  76. scores_pos = []
  77. scores_neg = []
  78. y_true_pos = []
  79. y_true_neg = []
  80. for i in tqdm(range(features.shape[0]), desc='两两计算相似度', ncols=100):
  81. query = features[i]
  82. inside = features[i:, :]
  83. temp = (labels[i] == labels[i:]).detach().long()
  84. pos_index = torch.nonzero(temp == 1)
  85. neg_index = torch.nonzero(temp == 0)
  86. pos_label = torch.take(temp, pos_index).squeeze(1).detach().cpu().tolist()
  87. neg_label = torch.take(temp, neg_index).squeeze(1).detach().cpu().tolist()
  88. cos = torch.cosine_similarity(query, inside, dim=-1)
  89. pos_score = torch.take(cos, pos_index).squeeze(1).detach().cpu().tolist()
  90. neg_score = torch.take(cos, neg_index).squeeze(1).detach().cpu().tolist()
  91. y_true_pos.extend(pos_label)
  92. y_true_neg.extend(neg_label)
  93. scores_pos.extend(pos_score)
  94. scores_neg.extend(neg_score)
  95. print('len(y_true_neg)', len(y_true_neg))
  96. print('len(y_true_pos)', len(y_true_pos))
  97. print('len(scores_pos)', len(scores_pos))
  98. print('len(scores_neg)', len(scores_neg))
  99. if len(y_true_pos) * 99 < len(y_true_neg):
  100. indexs = random.choices(list(range(len(y_true_neg))), k=len(y_true_pos) * 99)
  101. scores = scores_pos
  102. y_true = y_true_pos
  103. for index in indexs:
  104. scores.append(scores_neg[index])
  105. y_true.append(y_true_neg[index])
  106. else:
  107. scores = scores_pos + scores_neg
  108. y_true = y_true_pos + y_true_neg
  109. print('len(scores)', len(scores))
  110. print('len(y_true)', len(y_true))
  111. scores = torch.tensor(scores,dtype=torch.float32)
  112. y_true = torch.tensor(y_true,dtype=torch.long)
  113. choice_best_threshold_dcf(scores, y_true)
def choice_best_threshold_dcf(scores, y_true):
    """Sweep thresholds 0.00-0.99 and report the EER, minimum-DCF and maximum-F1 operating points."""
    thresholds = []
    fars = []
    frrs = []
    dcfs = []
    precisions = []
    recalls = []
    f1s = []
    max_precision = 0
    max_recall = 0
    max_f1 = 0
    f1_threshold = 0
    min_dcf = 1
    d_threshold = 0
    cfr = 1  # cost of a false rejection
    cfa = 1  # cost of a false acceptance
    err = 0.0  # EER estimate (named err throughout this code)
    err_threshold = 0
    diff = 1
    for i in tqdm(range(100), desc='choice_best_threshold', ncols=100):
        threshold = 0.01 * i
        thresholds.append(threshold)
        # a trial is accepted as "same speaker" when its score exceeds the threshold
        y_preds = (scores > threshold).long()
        tp = ((y_true == 1) & (y_preds == 1)).sum().item()
        fp = ((y_true == 0) & (y_preds == 1)).sum().item()
        tn = ((y_true == 0) & (y_preds == 0)).sum().item()
        fn = ((y_true == 1) & (y_preds == 0)).sum().item()
        pos = tp + fn  # number of target (same-speaker) trials
        neg = tn + fp  # number of non-target trials
        precision = tp / (tp + fp + 1e-13)
        recall = tp / (tp + fn + 1e-13)
        f1 = 2 * precision * recall / (precision + recall + 1e-13)
        far = fp / (fp + tn + 1e-13)
        frr = fn / (tp + fn + 1e-13)
        # DCF that uses the test set's own class proportions as the priors
        dcf = cfa * far * (neg / (neg + pos)) + cfr * frr * (pos / (pos + neg))
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
        fars.append(far)
        frrs.append(frr)
        dcfs.append(dcf)
        if max_precision < precision:
            max_precision = precision
        if max_recall < recall:
            max_recall = recall
        if max_f1 < f1:
            max_f1 = f1
            f1_threshold = threshold
        if min_dcf > dcf:
            min_dcf = dcf
            d_threshold = threshold
        # EER: the point where FAR and FRR are closest
        if abs(far - frr) < diff:
            err = (far + frr) / 2
            diff = abs(far - frr)
            err_threshold = threshold
    print(pos + neg)  # total number of trials
    print('threshold:%.4f err:%.4f' % (err_threshold, err))
    print("d_threshold:%.4f, min_dcf:%.4f" % (d_threshold, min_dcf))
    print("f1_threshold:%.4f, max_f1:%.4f" % (f1_threshold, max_f1))
    start = time.time()
    plt.figure(figsize=(30, 30), dpi=80)
    plt.title('2D curve')
    plt.plot(thresholds, frrs, label='frr')
    plt.plot(thresholds, fars, label='far')
    plt.plot(thresholds, dcfs, label='dcf')
    plt.plot(thresholds, precisions, label='pre')
    plt.plot(thresholds, recalls, label='recall')
    plt.plot(thresholds, f1s, label='f1')
    plt.legend(loc=0)
    # mark the minimum-DCF, EER and maximum-F1 operating points
    plt.scatter(d_threshold, min_dcf, c='red', s=100)
    plt.text(d_threshold, min_dcf, " min_dcf(%.4f,%.4f)" % (d_threshold, min_dcf))
    plt.scatter(err_threshold, err, c='blue', s=100)
    plt.text(err_threshold, err, " err(%.4f,%.4f)" % (err_threshold, err))
    plt.scatter(f1_threshold, max_f1, c='yellow', s=100)
    plt.text(f1_threshold, max_f1, " f1(%.4f,%.4f)" % (f1_threshold, max_f1))
    plt.xlabel('threshold')
    plt.ylabel('frr, far, dcf, precision, recall, f1')
    plt.xticks(thresholds[::2])
    plt.yticks(thresholds[::2])
    end = time.time()
    print('plot time is', end - start)
    plt.savefig('wavlm_2d_curve_voiceprint_dcf.png')
    plt.show()
    print("finish")

def set_seed(seed):
    # fix all random seeds so that evaluation runs are reproducible
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

def collate_fn(batch):
    # zero-pad variable-length audio features and attention masks to the longest item in the batch
    features, attention_mask, labels = zip(*batch)
    features = pad_sequence(features, batch_first=True, padding_value=0.0)
    attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels = torch.stack(labels, dim=-1)
    return features, attention_mask, labels

if __name__ == '__main__':
    args = parse_args()
    set_seed(args.random_seed)
    evaluate(args)
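
The sweep above tries only 100 fixed thresholds. In its DCF, the priors are the test set's own class proportions. For comparison, below is a minimal NumPy sketch that tries every distinct score as a threshold instead. It computes a DCF with a fixed target prior, normalised the way NIST SRE evaluations do, and the default p_target=0.01 corresponds to the 1:99 ratio mentioned earlier. The function name `eer_and_min_dcf` and its parameters are hypothetical and are not part of the original code. The EER here is taken without interpolation.

```python
import numpy as np


def eer_and_min_dcf(scores, labels, p_target=0.01, c_miss=1.0, c_fa=1.0):
    """Return (eer, eer_threshold, min_norm_dcf, dcf_threshold).

    A trial is accepted as "same speaker" when score >= threshold.
    labels: 1 for target (same-speaker) trials, 0 for non-target trials.
    """
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels)
    tar = scores[labels == 1]
    non = scores[labels == 0]
    if tar.size == 0 or non.size == 0:
        raise ValueError("need at least one target and one non-target trial")

    # Candidate thresholds: every distinct score, plus +inf (reject everything).
    thresholds = np.concatenate([np.unique(scores), [np.inf]])

    # FRR = share of target trials rejected; FAR = share of non-target trials accepted.
    frr = (tar[None, :] < thresholds[:, None]).mean(axis=1)
    far = (non[None, :] >= thresholds[:, None]).mean(axis=1)

    # EER: taken at the threshold where FAR and FRR are closest (no interpolation).
    i = int(np.argmin(np.abs(far - frr)))
    eer = (far[i] + frr[i]) / 2

    # DCF with a fixed target prior, divided by the cost of the better trivial
    # system (always accept or always reject), as in NIST SRE evaluations.
    dcf = c_miss * p_target * frr + c_fa * (1 - p_target) * far
    dcf_norm = dcf / min(c_miss * p_target, c_fa * (1 - p_target))
    j = int(np.argmin(dcf_norm))
    return eer, thresholds[i], dcf_norm[j], thresholds[j]


eer, eer_t, min_dcf, dcf_t = eer_and_min_dcf(
    [0.9, 0.6, 0.4, 0.5, 0.3, 0.1], [1, 1, 1, 0, 0, 0]
)
```

For these six toy trials, the EER is 1/3 at threshold 0.5, and the normalised minDCF is 1/3 at threshold 0.6. With a low target prior, the cost-optimal threshold sits above the EER threshold. The broadcasting uses memory proportional to (number of thresholds) × (number of trials). For very large trial lists, a sort-plus-cumulative-count version would be the better choice.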

The results are as follows:

  1. threshold=0.69 dcf f1值都处于最佳状态 而且f1=0.9765 errdcf值都非常低,明显wavLm模型在该数据集上的效果要优于Ecapa_TDNN

IV. Demo

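  1. I spent nearly two weeks of after-work evenings and weekends learning Vue 2.0 and Vue 3.0 from the 尚硅谷 (Atguigu) video course on Bilibili. With that, I built a front-end demo for speaker verification in Vue 3.0. Here is what the overall page looks like: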

A rough outline of how the demo is implemented:

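  1. Backend: plain Python + Flask, which is very simple.
  2. Front end: Vue 3.0 + HTML + CSS. Building a few simple pages is quite easy, although someone with no front-end background will still need some time to learn it.
  3. Algorithm side: Python + PyTorch, using the WavLM and ECAPA-TDNN models.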
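
Section I describes the usual deployment flow for the algorithm side. Each enrollment recording is embedded ahead of time and stored in a library. At verification time, a fresh embedding is scored against the stored one and compared with a threshold. Below is a minimal sketch of that enrollment library, under three assumptions:

- The embeddings already come from a model such as WavLM or ECAPA-TDNN (not shown).
- The score is cosine similarity.
- `VoiceprintStore`, `enroll` and `verify` are hypothetical names, not part of the author's demo.

The default threshold of 0.69 is just the value that worked best on the author's test set and would need re-tuning on other data.

```python
import numpy as np


class VoiceprintStore:
    """Hypothetical in-memory enrollment library: speaker_id -> unit-length embedding."""

    def __init__(self, threshold=0.69):
        # 0.69 is the best threshold reported on the author's test set; re-tune per dataset.
        self.threshold = threshold
        self._db = {}

    @staticmethod
    def _normalize(embedding):
        vec = np.asarray(embedding, dtype=np.float64)
        norm = np.linalg.norm(vec)
        if norm == 0:
            raise ValueError("embedding has zero norm")
        return vec / norm

    def enroll(self, speaker_id, embedding):
        # Enrollment happens once, ahead of time: store the normalised vector.
        self._db[speaker_id] = self._normalize(embedding)

    def verify(self, speaker_id, embedding):
        # Cosine similarity between unit vectors is just their dot product.
        if speaker_id not in self._db:
            raise KeyError(f"unknown speaker: {speaker_id}")
        score = float(self._db[speaker_id] @ self._normalize(embedding))
        # Same decision rule as the evaluation code: accept when score > threshold.
        return score, score > self.threshold


store = VoiceprintStore()
store.enroll("alice", [1.0, 0.0, 0.0])
print(store.verify("alice", [1.0, 1.0, 0.0]))  # score ~0.707 -> accepted
print(store.verify("alice", [1.0, 2.0, 0.0]))  # score ~0.447 -> rejected
```

Normalising once at enrollment keeps each verification to a single dot product. In a Flask backend, a route could extract the embedding from the uploaded audio and call `verify`. A real system would also persist the library, for example in a vector database, instead of a Python dict.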

V. Summary

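  1. This article only makes a simple attempt at voiceprint recognition, checking how well the mainstream model approaches perform. It does not consider real business scenarios or discuss how to optimize the system and what to watch out for when deploying it. Examples of such scenarios:
     - background noise in the audio;
     - cross-device and cross-distance recording;
     - a recording being played back in place of a live speaker.

     There is still a lot worth learning here. My knowledge is limited, and I will come back to these topics later.
  2. Between the pre-trained WavLM model and CNN-based models, I think WavLM is the more mainstream direction, and I am more optimistic about it. Given suitable audio data, continued pre-training plus fine-tuning should be able to solve some domain-specific problems, provided that large-scale data is available.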


Reposted from: https://blog.csdn.net/HUSTHY/article/details/127808132
Copyright belongs to the original author, colourmind. In case of infringement, please contact us for removal.
