本文用来记录PhysioNet/CinC Challenge 2021比赛 中的参赛队伍代码复现过程。首先选择了排名第三位的NIMA队伍代码 ,对其进行复现。
目录
环境配置
首先使用anaconda创建虚拟环境,之后conda activate激活创建的环境。
1 2 conda create -n NIMA python=3 .9 conda activate NIMA
定位到项目目录,安装依赖
1 pip install -r requirements.txt
运行
模型训练
运行以下代码
1 python train_model.py training_data model
首次运行发现accuracy始终为1.0,检查代码发现预处理部分有问题,在helper_code.py文件202行的get_lables()函数中,l.startswith(“#Dx”)应为l.startswith(“# Dx”)。原代码如下:
1 2 3 4 5 6 7 8 9 10 11 def get_labels (header ): labels = list () for l in header.split('\n' ): if l.startswith('#Dx' ): try : entries = l.split(': ' )[1 ].split(',' ) for entry in entries: labels.append(entry.strip()) except : pass return labels
修正后:
1 2 3 4 5 6 7 8 9 10 11 def get_labels (header ): labels = list () for l in header.split('\n' ): if l.startswith("# Dx" ): try : entries = l.split(': ' )[1 ].split(',' ) for entry in entries: labels.append(entry.strip()) except : pass return labels
再次运行,发现accuracy固定在0.2111,检查代码发现在helper_code.py文件164行的get_adc_gains()函数中,分隔符出错。原代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 def get_adc_gains (header, leads ): adc_gains = np.zeros(len (leads)) for i, l in enumerate (header.split('\n' )): entries = l.split(' ' ) if i==0 : num_leads = int (entries[1 ]) elif i<=num_leads: current_lead = entries[-1 ] if current_lead in leads: j = leads.index(current_lead) try : adc_gains[j] = float (entries[2 ].split('/' )[0 ]) except : pass else : break return adc_gains
修改后:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 def get_adc_gains (header, leads ): adc_gains = np.zeros(len (leads)) for i, l in enumerate (header.split('\n' )): entries = l.split(' ' ) if i==0 : num_leads = int (entries[1 ]) elif i<=num_leads: current_lead = entries[-1 ] if current_lead in leads: j = leads.index(current_lead) try : adc_gains[j] = float (entries[2 ].split('(' )[0 ]) except : pass else : break return adc_gains
修改以后运行正常,部分输出结果如下:
1 2 3 4 Epoch 2 /17 77 /77 [==============================] - 263 s 3 s/step - loss: 0 .1003 - accuracy: 0 .7224 - AUROC: 0 .5515 - AUPRC: 0 .1923 Epoch 3 /17 77 /77 [==============================] - 265 s 3 s/step - loss: 0 .0879 - accuracy: 0 .7509 - AUROC: 0 .5756 - AUPRC: 0 .2163
感觉跑的太慢了,查看用的是gpu还是cpu
1 2 from tensorflow.python.client import device_libprint (device_lib.list_local_devices())
原来是用cpu在跑,一顿折腾换成gpu,参考GPU 支持 | TensorFlow (google.cn)
注意
helper_code.py文件由比赛主办方提供,且禁止修改,然而其中的函数确实和提供的数据文件不匹配,所以为了跑通还是把它改了。