0


Tensorflow2数据集过大,GPU内存不够

前言:
在我们平时使用tensorflow训练模型时,有时候可能因为数据集太大(比如VOC数据集等等)导致GPU内存不够导致终止,可以自制一个数据生成器来解决此问题。

代码如下:

  1. deftrain_generator(train_path,train_labels,batch):
  2. over=len(train_path)%batch
  3. whileTrue:for i inrange(0,len(train_path)-over,batch):
  4. train_data=read_img(train_path[i:i+batch])
  5. train_label=train_labels[i:i+batch]yield(np.array(train_data), np.array(train_label))

方法就是将数据集图片的路径保存到一个列表之中,然后使用while循环在训练时进行不断读取,这里over的作用是防止图片长度不是batch整数倍,导致label的数据长度不等于batch,我在训练时出现了这样的问题,这是我的猜测。然后yield与return的不同是,return是在函数执行到return就会退出函数,而yield则不会退出函数,所以使用yield
最后一句话也可以改成:

  1. yield({'input':np.array(train_data)},{'output':np.array(train_label)})

'input’是你网络第一层的名字.。
'output’是你网络最后一层的名字。

接下来是使用代码:

  1. history=model.fit(train_generator(train_data,train_label,batch=Yolo_param.Batch_size),
  2. batch_size=Yolo_param.Batch_size,
  3. epochs=10,
  4. steps_per_epoch=1024,
  5. validation_steps=32,
  6. callbacks=[callback],
  7. validation_data=train_generator(test_data,test_label,batch=Yolo_param.Batch_size))

steps_per_epoch这个参数是每个epoch的数据大小,如果不给进度就能难显示。

最后就是显存设置:

  1. gpus = tf.config.list_physical_devices('GPU')if gpus:try:
  2. tf.config.set_logical_device_configuration(
  3. gpus[0],[tf.config.LogicalDeviceConfiguration(memory_limit=4096)])except RuntimeError as e:print(e)

4096就是你限制显卡内存的大小,可以根据自己显卡实际情况来进行设置


本文转载自: https://blog.csdn.net/darlingqx/article/details/127222191
版权归原作者 月明Mo 所有, 如有侵权,请联系我们删除。

“Tensorflow2数据集过大,GPU内存不够”的评论:

还没有评论