

🐱 CNN-RNN-Based Medical Text Generation

This project extracts image features with a ResNet-101 network pretrained on ImageNet,
then feeds those features into an LSTM to generate a text description of the radiograph.

It is a first, simple implementation of image-to-text generation.


📖 0 Project Background

With the rapid progress of deep learning in recent years, the technology has shown great potential in healthcare. If a computer could take over the mechanical work of drafting imaging reports, it would both reduce the misdiagnoses that less experienced readers make when interpreting scans and free senior radiologists from heavy, repetitive work, letting them devote more time to patient care.

Automatic generation of medical imaging reports is an emerging intersection of computer science and medical imaging. Current report-generation models mostly borrow the Encoder-Decoder framework from machine translation: a Convolutional Neural Network (CNN) extracts the image features, and a Recurrent Neural Network (RNN) then generates the textual description of the image.

📌 1 Dataset

The Indiana University Chest X-ray Collection (IU X-ray) is a set of chest radiographs paired with diagnostic reports. The dataset contains 7,470 image-report pairs (split 6,470:500:500). Each report consists of the sections Impression, Findings, Tags, Comparison, and Indication. On average, each image is associated with 2.2 tags and 5.7 sentences, and each sentence contains 6.5 words.

This project uses only the FINDINGS section as the generation target for each image.


🐰 2 Dataset Generation

🐅 2.1 Generating the Medical-Text CSV


Unzip the raw data, parse the XML files, extract each image's file name and the corresponding FINDINGS text, and write the result to a CSV file.

```shell
# Unzip the dataset
!unzip -o data/data123482/IU数据集.zip -d /home/aistudio/work/
# ... inflating: /home/aistudio/work/IU数据集/NLMCXR_reports/ecgen-radiology/1504.xml
```
```python
# Suppress warning output
import warnings
warnings.filterwarnings("ignore")
```
```python
## Build the CSV dataset
# Average caption length is about 31.65 words
import os
import glob
import pandas as pd
from xml.dom import minidom
import re
import numpy as np

LENGTH = []

def EmptyDrop(data):
    for i in range(len(data)):
        if data.loc[i, 'dir'] == [] or data.loc[i, 'caption'] == []:
            # Drop rows with an empty image list or empty caption
            data.drop([i], axis=0, inplace=True)
        else:
            data.loc[i, 'dir'] = data.loc[i, 'dir'][0]
            data.loc[i, 'caption'] = data.loc[i, 'caption'][0]
    data.reset_index(drop=True, inplace=True)
    return data

def clean_text(origin_text):
    # Remove punctuation and invalid characters
    text = re.sub("[^a-zA-Z]", " ", origin_text)
    # Lowercase
    cleaned_text = text.lower()
    return cleaned_text

def xml2csv(path):
    num = 0
    column_name = ['dir', 'caption']
    xml_csv = pd.DataFrame(columns=column_name)
    # Image directory
    pic_path = 'work/IU数据集/NLMCXR_png'
    for xml_file in glob.glob(path + '/*.xml'):
        # For each xml, record the image paths and the FINDINGS text
        xml_list = []
        # Parse the xml document and get its root element
        dom = minidom.parse(xml_file)
        root = dom.documentElement
        # Image paths
        itemlists = root.getElementsByTagName('parentImage')
        dirAll = []
        for itemlist in itemlists:
            # figureId of the image (frontal / lateral view)
            figureId = itemlist.getElementsByTagName('figureId')
            figure = figureId[0].childNodes[0].nodeValue
            # Image file name
            ID = itemlist.getAttribute('id')
            IdPath = ID
            figurePath = [figure + ' ' + IdPath]
            dirAll.extend(figurePath)
        xml_list.append(dirAll)
        # Record the FINDINGS text
        CaptionAll = []
        itemlists = root.getElementsByTagName('AbstractText')
        for i in range(len(itemlists)):
            Label = itemlists[i].getAttribute('Label')
            if Label == 'FINDINGS':  # or Label == 'IMPRESSION':
                # Skip empty sections
                if len(itemlists[i].childNodes) != 0:
                    text = itemlists[i].childNodes[0].nodeValue
                    # Lowercase and filter invalid characters
                    text = clean_text(text)
                    text = text.replace('.', '')
                    text = text.replace(',', '')
                    text = [text + '']
                    CaptionAll.extend(text)
        if len(CaptionAll) >= 1:
            LENGTH.append(len(CaptionAll[0].split(' ')))
        xml_list.append(CaptionAll)
        xml_csv.loc[num] = [item for item in xml_list]
        num = num + 1
        print('epoch[{}/{}]'.format(num, len(glob.glob(path + '/*.xml'))))
    # print(np.mean(LENGTH))
    return xml_csv

def main():
    xml_path = os.path.join('work', 'IU数据集', 'NLMCXR_reports', 'ecgen-radiology')
    csv = xml2csv(xml_path)
    csv1 = EmptyDrop(csv)
    csv1.to_csv('work/IUxRay.csv', index=None)

if __name__ == '__main__':
    main()
```

🐅 2.2 Image Feature Extraction


  • Extract image features with a ResNet-101 model pretrained on ImageNet (the final fully connected layer is removed and replaced with an identity mapping).
  • Save the features to an HDF5 file.
```python
## Extract image features with the pretrained resnet101
import paddle
from paddle.vision.models import resnet101
import h5py
import cv2

csv_file = pd.read_csv('work/IUxRay.csv')
h5_png_file = list(csv_file['dir'])

# Create the output directory
save_path = 'work/util_IUxRay'
if not os.path.exists(save_path):
    os.makedirs(save_path)

# Load the ImageNet-pretrained resnet101 and replace the final
# fully connected layer with an identity mapping
model = resnet101(pretrained=True)
del model.fc
model.fc = lambda x: x

h5f = h5py.File(os.path.join(save_path, 'resnet101_festures.h5'), 'w')
for idx, item in enumerate(h5_png_file):
    # Each row may list several views (F1, F2, ...)
    print(idx, len(h5_png_file))
    item_all = item.split(',')
    for item_t in item_all:
        item_t = item_t.replace('\'', '').replace('[', '').replace(']', '')
        # Separate the view tag from the file name
        for orie in ['F1', 'F2', 'F3', 'F4']:
            if orie in item_t:
                orie_fin = orie
                item_fin = item_t.replace(orie, '').replace(' ', '')
        item_fin_png = item_fin + '.png'
        print(orie_fin + '_' + item_fin)
        # Read the image, run it through the model, and store the feature in h5
        img = cv2.imread(os.path.join('work/IU数据集/NLMCXR_png', item_fin_png))
        # BGR -> RGB, HWC -> CHW
        img = img[:, :, ::-1].transpose((2, 0, 1))
        # Add a batch dimension
        img = np.expand_dims(img, 0)
        img_tensor = paddle.to_tensor(img, dtype='float32', place=paddle.CPUPlace())
        out = model(img_tensor)
        data = out.numpy().astype('float32')
        h5f.create_dataset(orie_fin + '_' + item_fin, data=data[0])
h5f.close()
```
```python
# Read the h5 file back
import h5py
h5f = h5py.File('work/util_IUxRay/resnet101_festures.h5', 'r')
# Before the first underscore: view tag; after it: the original image name
data = h5f['F1_CXR3027_IM-1402-1001']
print(np.array(data).shape)  # each image is stored as a 2048-dimensional vector
h5f.close()
```

🐅 2.3 Vocabulary Construction


  • Scan the training data and build a vocabulary at word granularity.
  • Vocabulary pruning: drop words that occur only once in the dataset.
```python
# Build the vocabulary from the training data, at word granularity
import pandas as pd
import numpy as np
import re
from collections import Counter

csv_file = pd.read_csv('work/IUxRay.csv')
csv_file.head()

CaptionWordAll = []
CaptionWordLength = []
for idx, data_ in enumerate(csv_file.iterrows()):
    caption = data_[1][1]
    CaptionWordLength.append(len(caption.split(' ')))
    CaptionWordAll.extend(caption.split(' '))

print('Mean sentence length:', np.mean(CaptionWordLength))
print('Max sentence length:', np.max(CaptionWordLength))
print('Min sentence length:', np.min(CaptionWordLength))
print('Total word count:', len(CaptionWordAll))
print('Vocabulary size:', len(set(CaptionWordAll)))

# Sort by frequency, highest first, so the most common words get the
# smallest ids and lookups stay fast
counts = Counter(CaptionWordAll)
count_sorted = counts.most_common()
# Drop words that occur only once
count_sorted_ = {k: v for k, v in count_sorted if v > 1}

# Reserve <pad> 0, <unk> 1, <start> 2, <end> 3 as special symbols
word2id_dict = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3}
id2word_dict = {0: '<pad>', 1: '<unk>', 2: '<start>', 3: '<end>'}
for idx, item in enumerate(count_sorted_):
    idx_ = idx + 4  # offset for the four reserved symbols
    word2id_dict[item] = idx_
    id2word_dict[idx_] = item

print('Pruned vocabulary size:', len(word2id_dict))
```
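As a quick sanity check, the two mappings built above can be exercised on a toy sentence. The mini-vocabulary below is hypothetical (the real one comes from the IU X-ray captions), but the encode/decode logic mirrors how the dictionaries are used later:

```python
# Hypothetical mini-vocabulary; the real word2id_dict is built from the captions
word2id_dict = {'<pad>': 0, '<unk>': 1, '<start>': 2, '<end>': 3,
                'the': 4, 'lungs': 5, 'are': 6, 'clear': 7}
id2word_dict = {v: k for k, v in word2id_dict.items()}

def encode(sentence, w2i):
    # Out-of-vocabulary words map to <unk> (id 1)
    return [w2i.get(w, w2i['<unk>']) for w in sentence.split(' ')]

def decode(ids, i2w):
    return ' '.join(i2w[i] for i in ids)

ids = encode('the lungs are clear today', word2id_dict)
print(ids)                          # 'today' is out of vocabulary -> 1
print(decode(ids, id2word_dict))
```

Round-tripping through the dictionaries is lossy exactly where the pruning step removed rare words: anything mapped to `<unk>` cannot be recovered.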

🥝 3 Dataset Reader


  • Split the data 8:2 into training and validation sets.
  • Map the text through the vocabulary. Unlike a translation task, the `<start>` position of the decoder input is replaced by the image feature.
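The tensor layout implied by the second bullet can be sketched with plain NumPy shapes (this is only an illustration of the shapes; the real model does the concatenation with Paddle tensors in its forward pass):

```python
import numpy as np

batch, seq_len, embed_dim = 2, 5, 512

# Projected image feature, one vector per image, shaped as a length-1 "token"
img_features = np.zeros((batch, 1, embed_dim), dtype='float32')
# Embedded target tokens (teacher-forcing input, <start> removed)
embeddings = np.zeros((batch, seq_len, embed_dim), dtype='float32')

# The image feature takes the slot a <start> embedding would normally occupy
inputs = np.concatenate([img_features, embeddings], axis=1)
print(inputs.shape)  # (2, 6, 512)
```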
```python
## Data loading
import paddle
from paddle.io import Dataset
import numpy as np
from sklearn.model_selection import train_test_split

# Custom dataset class
class CaptionDataset(Dataset):
    def __init__(self, csvData, word2id_dict, h5f, maxlength=40, mode='train'):
        self.mode = mode
        self.w2i_dict = word2id_dict
        self.maxlength = maxlength  # maximum input length
        self.padid = 0              # 0 is the padding symbol
        self.h5f = h5f
        # Split the data into train/test by ratio
        train = csvData.iloc[:int(0.8 * len(csvData)), :]
        test = csvData.iloc[int(0.8 * len(csvData)):, :]
        if self.mode == 'train':
            train.reset_index(drop=True)
            self.data = train
        else:
            test.reset_index(drop=True)
            self.data = test

    def __getitem__(self, index):
        # Return one sample: image feature, caption ids, caption length, image name
        path_name, trg_ = self.data.iloc[index, :]
        # Read the image feature; before the first underscore is the view tag,
        # after it the original image name
        temp = path_name.split(' ')
        names = '_'.join(temp)
        img_feature = np.array(self.h5f[names])
        # Convert the caption string into a list of indices
        trg, trg_length = self.generIdxList(trg_)
        img_name = temp[-1]
        return img_feature, trg, trg_length, img_name

    def __len__(self):
        return len(self.data)

    def generIdxList(self, tdata):
        # Turn the input string into a list of indices
        data = tdata.split(' ')
        data_out = []
        # Bound the length, adding '<start>' and '<end>'
        data = ['<start>'] + data
        if len(data) > self.maxlength - 1:
            data = data[:self.maxlength - 1]  # leave one slot for '<end>'
            data = data + ['<end>']
        else:
            # Pad up to maxlength
            occupy_ = ['<pad>'] * (self.maxlength - 1 - len(data))
            data = data + ['<end>']
            data = data + occupy_
        # word -> index, falling back to '<unk>'
        for word in data:
            if self.w2i_dict.get(word) != None:
                data_out.append(self.w2i_dict[word])
            else:
                data_out.append(self.w2i_dict['<unk>'])
        length = len(data_out) - 1
        return data_out, length

def stackInput(inputs):
    img_features = np.stack([inputsub[0] for inputsub in inputs], axis=0)
    trg = np.stack([inputsub[1] for inputsub in inputs], axis=0)
    trg_length = np.stack([inputsub[2] for inputsub in inputs], axis=0)
    trg_mask = (trg[:, :-1] != 0).astype(paddle.get_default_dtype())
    trg_ = trg[:, 1:]
    # The <start> slot will be fed the image feature instead
    return img_features, trg_length, trg_[:, :-1], trg[:, 1:, np.newaxis], trg_mask
```
```python
# Test the data pipeline
import pandas as pd
import numpy as np
import h5py
from sklearn.model_selection import train_test_split

csvData = pd.read_csv('work/IUxRay.csv')
h5f = h5py.File('work/util_IUxRay/resnet101_festures.h5', 'r')
maxlength = 40
dataset = CaptionDataset(csvData, word2id_dict, h5f, maxlength, 'train')
data_loader = paddle.io.DataLoader(dataset, batch_size=1, collate_fn=stackInput, shuffle=False)
for item in data_loader:
    print(item[0].shape, item[1].shape, item[2].shape, item[3].shape, item[4].shape)
    break
```

💡 4 Model Definition


  • Define the LSTM model for text generation.
  • Define a beam search procedure to improve the generated results.
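Before the full implementation below, the idea of beam search can be shown on a toy language model. The next-token distribution here is made up and context-independent, which keeps the example tiny while still demonstrating the beam, the `<end>` handling, and the `length_normalization_factor` parameter used by the real generator:

```python
import math

# Hypothetical, context-independent next-token log-probabilities
log_probs = {'a': math.log(0.5), 'b': math.log(0.3), '<end>': math.log(0.2)}

def beam_search_toy(beam_size=2, max_len=3, length_norm=1.0):
    beams = [([], 0.0)]            # (token list, cumulative log-prob)
    completed = []
    for _ in range(max_len):
        candidates = []
        for seq, lp_sum in beams:
            for tok, lp in log_probs.items():
                item = (seq + [tok], lp_sum + lp)
                (completed if tok == '<end>' else candidates).append(item)
        # Keep only the beam_size best partial hypotheses
        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]
    completed.extend(beams)        # fall back to partials if nothing finished
    # Length normalization: score = logprob / len^length_norm
    return max(completed, key=lambda x: x[1] / len(x[0]) ** length_norm)[0]

print(beam_search_toy(length_norm=0.0))  # raw log-prob favors stopping immediately
print(beam_search_toy(length_norm=1.0))  # normalization favors longer captions
```

With `length_norm=0.0` the single token `<end>` wins because every extra token can only lower the raw log-probability; with `length_norm=1.0` the per-token average is compared instead, so the longer sequence wins. This is exactly the trade-off the `length_normalization_factor` argument controls in the code below.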
```python
# Model definition
import paddle
import paddle.nn as nn

class CaptionModel(paddle.nn.Layer):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers,
                 word2id_dict, id2word_dict):
        super(CaptionModel, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.fc = paddle.nn.Linear(2048, embedding_dim)
        self.embedding = paddle.nn.Embedding(vocab_size, embedding_dim)
        self.rnn = paddle.nn.LSTM(input_size=embedding_dim,
                                  hidden_size=hidden_size,
                                  num_layers=num_layers)
        self.word2ix = word2id_dict
        self.ix2word = id2word_dict
        self.classifier = paddle.nn.Linear(hidden_size, vocab_size)

    def forward(self, img_features, trg, trg_length):
        # The projected image feature takes the <start> slot of the input sequence
        img_features = paddle.unsqueeze(self.fc(img_features), axis=1)
        embeddings = self.embedding(trg)
        inputs = paddle.concat([img_features, embeddings], axis=1)
        outputs, state = self.rnn(inputs, sequence_length=trg_length)
        predict = self.classifier(outputs)
        return predict

    def generate(self, img_feat, eos_token='<end>',
                 beam_size=2,
                 max_caption_length=40,
                 length_normalization_factor=0.0):
        """Generate a caption for an image, using beam search for better results."""
        cap_gen = CaptionGenerator(embedder=self.embedding,
                                   rnn=self.rnn,
                                   classifier=self.classifier,
                                   eos_id=self.word2ix[eos_token],
                                   beam_size=beam_size,
                                   max_caption_length=max_caption_length,
                                   length_normalization_factor=length_normalization_factor)
        img_feat = paddle.unsqueeze(img_feat, axis=0)
        img = paddle.unsqueeze(self.fc(img_feat), axis=0)
        sentences, score = cap_gen.beam_search(img)
        sentences = [' '.join([self.ix2word[int(idx)] for idx in sent])
                     for sent in sentences]
        return sentences
```
```python
# Beam search
import heapq
import paddle
import paddle.nn as nn

class TopN(object):
    """Maintains the top n elements of an incrementally provided set."""

    def __init__(self, n):
        self._n = n
        self._data = []

    def size(self):
        assert self._data is not None
        return len(self._data)

    def push(self, x):
        """Pushes a new element."""
        assert self._data is not None
        if len(self._data) < self._n:
            heapq.heappush(self._data, x)
        else:
            heapq.heappushpop(self._data, x)

    def extract(self, sort=False):
        """
        Extracts all elements from the TopN. This is a destructive operation.
        The only method that can be called immediately after extract() is reset().
        Args:
            sort: Whether to return the elements in descending sorted order.
        Returns:
            A list of data; the top n elements provided to the set.
        """
        assert self._data is not None
        data = self._data
        self._data = None
        if sort:
            data.sort(reverse=True)
        return data

    def reset(self):
        """Returns the TopN to an empty state."""
        self._data = []

class Caption(object):
    """Represents a complete or partial caption."""

    def __init__(self, sentence, state, logprob, score, metadata=None):
        """Initializes the Caption.
        Args:
            sentence: List of word ids in the caption.
            state: Model state after generating the previous word.
            logprob: Log-probability of the caption.
            score: Score of the caption.
            metadata: Optional metadata associated with the partial sentence. If not
                None, a list of strings with the same length as 'sentence'.
        """
        self.sentence = sentence
        self.state = state
        self.logprob = logprob
        self.score = score
        self.metadata = metadata

    def __cmp__(self, other):
        """Compares Captions by score."""
        assert isinstance(other, Caption)
        if self.score == other.score:
            return 0
        elif self.score < other.score:
            return -1
        else:
            return 1

    # For Python 3 compatibility (__cmp__ is deprecated).
    def __lt__(self, other):
        assert isinstance(other, Caption)
        return self.score < other.score

    # Also for Python 3 compatibility.
    def __eq__(self, other):
        assert isinstance(other, Caption)
        return self.score == other.score

class CaptionGenerator(object):
    """Class to generate captions from an image-to-text model."""

    def __init__(self,
                 embedder,
                 rnn,
                 classifier,
                 eos_id,
                 beam_size=3,
                 max_caption_length=100,
                 length_normalization_factor=0.0):
        """Initializes the generator.
        Args:
            model: recurrent model, with inputs: (input, state) and outputs len(vocab) values
            beam_size: Beam size to use when generating captions.
            max_caption_length: The maximum caption length before stopping the search.
            length_normalization_factor: If != 0, a number x such that captions are
                scored by logprob/length^x, rather than logprob. This changes the
                relative scores of captions depending on their lengths. For example, if
                x > 0 then longer captions will be favored.
        """
        self.embedder = embedder
        self.rnn = rnn
        self.classifier = classifier
        self.eos_id = eos_id
        self.beam_size = beam_size
        self.max_caption_length = max_caption_length
        self.length_normalization_factor = length_normalization_factor

    def beam_search(self, rnn_input, initial_state=None):
        """Runs beam search caption generation on a single image.
        Args:
            initial_state: An initial state for the recurrent model
        Returns:
            A list of Caption sorted by descending score.
        """

        def get_topk_words(embeddings, state):
            output, new_states = self.rnn(embeddings, state)
            output = self.classifier(paddle.squeeze(output, axis=0))
            logprobs = nn.functional.log_softmax(output, axis=-1)
            if len(logprobs.shape) == 3:
                logprobs = paddle.squeeze(logprobs)
            logprobs, words = logprobs.topk(self.beam_size, 1)
            return words, logprobs, new_states

        partial_captions = TopN(self.beam_size)
        complete_captions = TopN(self.beam_size)

        words, logprobs, new_state = get_topk_words(rnn_input, initial_state)
        for k in range(self.beam_size):
            cap = Caption(
                sentence=[words[0, k]],
                state=new_state,
                logprob=logprobs[0, k],
                score=logprobs[0, k])
            partial_captions.push(cap)

        # Run beam search.
        for _ in range(self.max_caption_length - 1):
            partial_captions_list = partial_captions.extract()
            partial_captions.reset()
            input_feed = [c.sentence[-1] for c in partial_captions_list]
            input_feed = paddle.to_tensor(input_feed)
            state_feed = [c.state for c in partial_captions_list]
            if isinstance(state_feed[0], tuple):
                state_feed_h, state_feed_c = zip(*state_feed)
                state_feed = (paddle.concat(state_feed_h, 1),
                              paddle.concat(state_feed_c, 1))
            else:
                state_feed = paddle.concat(state_feed, 1)
            embeddings = self.embedder(input_feed)
            words, logprobs, new_states = get_topk_words(embeddings, state_feed)
            for i, partial_caption in enumerate(partial_captions_list):
                if isinstance(new_states, tuple):
                    state = (paddle.slice(new_states[0], axes=[1], starts=[i], ends=[i + 1]),
                             paddle.slice(new_states[1], axes=[1], starts=[i], ends=[i + 1]))
                else:
                    state = new_states[i]
                for k in range(self.beam_size):
                    w = words[i, k]
                    sentence = partial_caption.sentence + [w]
                    logprob = partial_caption.logprob + logprobs[i, k]
                    score = logprob
                    if w == self.eos_id:
                        if self.length_normalization_factor > 0:
                            score /= len(sentence) ** self.length_normalization_factor
                        beam = Caption(sentence, state, logprob, score)
                        complete_captions.push(beam)
                    else:
                        beam = Caption(sentence, state, logprob, score)
                        partial_captions.push(beam)
            if partial_captions.size() == 0:
                # We have run out of partial candidates; happens when beam_size = 1.
                break

        # If we have no complete captions then fall back to the partial captions.
        # But never output a mixture of complete and partial captions because a
        # partial caption could have a higher score than all the complete captions.
        if not complete_captions.size():
            complete_captions = partial_captions
        caps = complete_captions.extract(sort=True)
        return [c.sentence for c in caps], [c.score for c in caps]
```

🥝 5 Loss Function


  • Use a standard cross-entropy loss.
  • Use the `trg_mask` defined above so that no loss is computed over padding positions.
```python
# Loss function
class CrossEntropy(paddle.nn.Layer):
    def __init__(self):
        super(CrossEntropy, self).__init__()

    def forward(self, pre, real, trg_mask):
        cost = paddle.nn.functional.softmax_with_cross_entropy(logits=pre, label=real)
        # Squeeze out the size-1 dimension at axis=2
        cost = paddle.squeeze(cost, axis=[2])
        # trg_mask has shape [batch_size, sequence_len]
        masked_cost = cost * trg_mask
        return paddle.mean(paddle.mean(masked_cost, axis=[0]))
```
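The effect of `trg_mask` can be checked with a small NumPy example. The per-token loss values here are made up; the point is only the masking arithmetic:

```python
import numpy as np

# Hypothetical per-token cross-entropy values, shape [batch=2, seq_len=4]
cost = np.array([[0.5, 0.2, 0.1, 0.4],
                 [0.3, 0.6, 0.0, 0.0]], dtype='float32')
# 1 for real tokens, 0 for <pad> positions
trg_mask = np.array([[1, 1, 1, 1],
                     [1, 1, 0, 0]], dtype='float32')

masked_cost = cost * trg_mask  # padding positions contribute zero loss
loss = masked_cost.mean()      # mean over batch, then over sequence
print(round(float(loss), 4))   # → 0.2625
```

Note that this divides by the full tensor size including padded positions; normalizing by `trg_mask.sum()` instead would give a per-real-token average, a common variant.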

🦃 6 Parameters and Training


  • Add perplexity as an evaluation metric.
  • Set the training hyperparameters.
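Perplexity is simply the exponential of the mean per-token cross-entropy, which a one-line sketch makes concrete (the loss value here is hypothetical; `paddlenlp.metrics.Perplexity` accumulates the same quantity over batches):

```python
import math

# Hypothetical mean per-token cross-entropy (natural log)
mean_nll = 0.693147  # ≈ ln(2): as uncertain as a fair coin flip per token

perplexity = math.exp(mean_nll)
print(round(perplexity, 3))  # → 2.0
```

Lower perplexity means the model assigns higher probability to the reference captions; a perplexity of 2 corresponds to choosing between two equally likely tokens at every step.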
```python
# Hyperparameters
import h5py

epochs = 60
word_size = 1151
eos_id = word2id_dict['<end>']
num_layers = 32
hidden_size = 512
embedding_dim = 512
lr = 1e-3
maxlength = 40
model_path = './output'
csvData = pd.read_csv('work/IUxRay.csv')
h5f = h5py.File('work/util_IUxRay/resnet101_festures.h5', 'r')
```
```python
import paddlenlp

model = CaptionModel(word_size, embedding_dim, hidden_size, num_layers,
                     word2id_dict, id2word_dict)

train_dataset = CaptionDataset(csvData, word2id_dict, h5f, maxlength, 'train')
train_loader = paddle.io.DataLoader(train_dataset, batch_size=128,
                                    collate_fn=stackInput, shuffle=True)
val_dataset = CaptionDataset(csvData, word2id_dict, h5f, maxlength, 'test')
val_loader = paddle.io.DataLoader(val_dataset, batch_size=64,
                                  collate_fn=stackInput, shuffle=True)

# Optimizer
optimizer = paddle.optimizer.Adam(learning_rate=lr, parameters=model.parameters())
# Loss function
loss_fn = CrossEntropy()
# Perplexity metric
perplexity = paddlenlp.metrics.Perplexity()

model.train()
for epoch in range(epochs):
    for batch_id, data in enumerate(train_loader()):
        img_features, trg_length, inputs, label, label_mask = \
            data[0], data[1], data[2], data[3], data[4]
        predicts = model(img_features, inputs, trg_length)  # forward pass
        # Loss (the equivalent of Model.prepare()'s loss setting)
        loss = loss_fn(predicts, label, label_mask)
        # Perplexity (the equivalent of Model.prepare()'s metric setting)
        correct = perplexity.compute(predicts, label)
        perplexity.update(correct.numpy())
        ppl = perplexity.accumulate()
        # Backpropagation, logging, parameter update, and gradient reset;
        # Model.fit() would normally wrap all of these
        loss.backward()
        if (batch_id + 1) % 20 == 0:
            print("epoch: {}, batch_id: {}, loss is: {}, ppl is: {}".format(
                epoch + 1, batch_id + 1, loss.item(), ppl))
            # Save the model parameters to LSTM_model.pdparams
            paddle.save(model.state_dict(), 'work/LSTM_model.pdparams')
        optimizer.step()
        optimizer.clear_grad()
```
```python
model.eval()
for batch_id, data in enumerate(val_loader()):
    img_features, trg_length, inputs, label, label_mask = \
        data[0], data[1], data[2], data[3], data[4]
    predicts = model(img_features, inputs, trg_length)  # forward pass
    loss = loss_fn(predicts, label, label_mask)
    # Perplexity on the validation set
    correct = perplexity.compute(predicts, label)
    perplexity.update(correct.numpy())
    ppl = perplexity.accumulate()
    if (batch_id + 1) % 1 == 0:
        print(" batch_id: {}, loss is: {}, ppl is: {}".format(
            batch_id + 1, loss.item(), ppl))
```

🍓 7 Inference

```python
# Evaluate on the held-out split
from IPython.display import display
from PIL import Image
import numpy as np
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu

path = 'work/IU数据集/NLMCXR_png/'
csvData = pd.read_csv('work/IUxRay.csv')
h5f = h5py.File('work/util_IUxRay/resnet101_festures.h5', 'r')
data = csvData.iloc[int(0.8 * len(csvData)):, :]
scores = []
Beam_Size = 3
for idx, data_ in tqdm(enumerate(data.iterrows())):
    F_name = data_[1][0]
    F_text = data_[1][1]
    img_name = F_name.split(' ')[-1]
    h5f_name = '_'.join(F_name.split(' '))
    img_feature = np.array(h5f[h5f_name])
    img_path = path + img_name + '.png'
    img_feature = paddle.to_tensor(img_feature)
    results = model.generate(img_feature, beam_size=Beam_Size)
    # Compute BLEU against the reference FINDINGS
    reference = [F_text.split(' ')]
    candidate = results[Beam_Size - 1].split(' ')
    score = sentence_bleu(reference, candidate)
    scores.append(score)
    print('Prediction:', results[Beam_Size - 1])
    print('Reference:', F_text)
    print('BLEU:', np.mean(scores))
    img = Image.open(img_path).convert('RGB')
    display(img, Image.BILINEAR)
```

🎖️ 8 Project Summary

  • The project demonstrates medical imaging report generation with a CNN+RNN architecture.

  • The beam search code originally had a small bug that made decoding effectively greedy (maximum probability only); this has been fixed, and the beam size parameter can now be passed in normally.

  • The project is a simple implementation of the image captioning task in the medical-text domain.

  • All code and data are presented as a notebook, which keeps the project easy to follow.

  • Generation quality is evaluated with BLEU.
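What `sentence_bleu` measures can be illustrated with a minimal clipped-unigram-precision sketch (full BLEU also averages higher-order n-gram precisions and applies a brevity penalty; the sentences below are made up):

```python
from collections import Counter

def unigram_precision(reference, candidate):
    # Clipped counts: each candidate word is credited at most as many
    # times as it appears in the reference
    ref_counts = Counter(reference)
    cand_counts = Counter(candidate)
    overlap = sum(min(cnt, ref_counts[w]) for w, cnt in cand_counts.items())
    return overlap / len(candidate)

reference = 'the lungs are clear no acute disease'.split()
candidate = 'the lungs are clear'.split()
print(unigram_precision(reference, candidate))  # → 1.0

# Repeating a word beyond its reference count is not rewarded
print(unigram_precision('a b'.split(), 'a a'.split()))  # → 0.5
```

The brevity penalty in full BLEU exists precisely because clipped precision alone, as above, would let very short candidates score perfectly.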


Note: this project was inspired by Chapter 10 of *Deep Learning Framework PyTorch: Introduction and Practice* (《深度学习框架Pytorch入门与实践》).


Questions and feedback are welcome in the comment section.

Reposted from: https://blog.csdn.net/Magic_Zsir/article/details/125429965
Copyright belongs to the original author, 猿知. If there is any infringement, please contact us for removal.
