0


TensorFlow-slim包进行图像数据集分类---具体流程

TensorFlow中slim包的具体用法

本次使用的TensorFlow版本是1.13.0
地址:https://github.com/tensorflow/models/tree/r1.13.0
到tensorflow-models的GitHub下载research下面的slim这个包到本地
在这里插入图片描述

TensorFlow中slim包的目录结构:

  1. -- slim
  2. |-- BUILD
  3. |-- README.md
  4. |-- WORKSPACE
  5. |-- __init__.py
  6. |-- datasets
  7. ||-- __init__.py
  8. ||-- __pycache__
  9. |||-- __init__.cpython-37.pyc
  10. |||-- dataset_utils.cpython-37.pyc
  11. |||-- download_and_convert_cifar10.cpython-37.pyc
  12. |||-- download_and_convert_flowers.cpython-37.pyc
  13. || `-- download_and_convert_mnist.cpython-37.pyc
  14. ||-- build_imagenet_data.py
  15. ||-- cifar10.py
  16. ||-- dataset_factory.py
  17. ||-- dataset_utils.py
  18. ||-- download_and_convert_cifar10.py
  19. ||-- download_and_convert_flowers.py
  20. ||-- download_and_convert_imagenet.sh
  21. ||-- download_and_convert_mnist.py
  22. ||-- download_imagenet.sh
  23. ||-- flowers.py
  24. ||-- imagenet.py
  25. ||-- imagenet_2012_validation_synset_labels.txt
  26. ||-- imagenet_lsvrc_2015_synsets.txt
  27. ||-- imagenet_metadata.txt
  28. ||-- mnist.py
  29. ||-- preprocess_imagenet_validation_data.py
  30. | `-- process_bounding_boxes.py
  31. |-- deployment
  32. ||-- __init__.py
  33. ||-- model_deploy.py
  34. | `-- model_deploy_test.py
  35. |-- download_and_convert_data.py # 下载相应的数据集,并将数据打包成TF-record的格式|-- eval_image_classifier.py # 测试模型分类效果|-- export_inference_graph.py
  36. |-- export_inference_graph_test.py
  37. |-- nets
  38. ||-- __init__.py
  39. ||-- alexnet.py
  40. ||-- alexnet_test.py
  41. ||-- cifarnet.py
  42. ||-- cyclegan.py
  43. ||-- cyclegan_test.py
  44. ||-- dcgan.py
  45. ||-- dcgan_test.py
  46. ||-- i3d.py
  47. ||-- i3d_test.py
  48. ||-- i3d_utils.py
  49. ||-- inception.py
  50. ||-- inception_resnet_v2.py
  51. ||-- inception_resnet_v2_test.py
  52. ||-- inception_utils.py
  53. ||-- inception_v1.py
  54. ||-- inception_v1_test.py
  55. ||-- inception_v2.py
  56. ||-- inception_v2_test.py
  57. ||-- inception_v3.py
  58. ||-- inception_v3_test.py
  59. ||-- inception_v4.py
  60. ||-- inception_v4_test.py
  61. ||-- lenet.py
  62. ||-- mobilenet
  63. |||-- README.md
  64. |||-- __init__.py
  65. |||-- conv_blocks.py
  66. |||-- madds_top1_accuracy.png
  67. |||-- mnet_v1_vs_v2_pixel1_latency.png
  68. |||-- mobilenet.py
  69. |||-- mobilenet_example.ipynb
  70. |||-- mobilenet_v2.py
  71. || `-- mobilenet_v2_test.py
  72. ||-- mobilenet_v1.md
  73. ||-- mobilenet_v1.png
  74. ||-- mobilenet_v1.py
  75. ||-- mobilenet_v1_eval.py
  76. ||-- mobilenet_v1_test.py
  77. ||-- mobilenet_v1_train.py
  78. ||-- nasnet
  79. |||-- README.md
  80. |||-- __init__.py
  81. |||-- nasnet.py
  82. |||-- nasnet_test.py
  83. |||-- nasnet_utils.py
  84. |||-- nasnet_utils_test.py
  85. |||-- pnasnet.py
  86. || `-- pnasnet_test.py
  87. ||-- nets_factory.py
  88. ||-- nets_factory_test.py
  89. ||-- overfeat.py
  90. ||-- overfeat_test.py
  91. ||-- pix2pix.py
  92. ||-- pix2pix_test.py
  93. ||-- resnet_utils.py
  94. ||-- resnet_v1.py
  95. ||-- resnet_v1_test.py
  96. ||-- resnet_v2.py
  97. ||-- resnet_v2_test.py
  98. ||-- s3dg.py
  99. ||-- s3dg_test.py
  100. ||-- vgg.py
  101. | `-- vgg_test.py
  102. |-- preprocessing
  103. ||-- __init__.py
  104. ||-- cifarnet_preprocessing.py
  105. ||-- inception_preprocessing.py
  106. ||-- lenet_preprocessing.py
  107. ||-- preprocessing_factory.py
  108. | `-- vgg_preprocessing.py
  109. |-- scripts # gqr:存储的是相关的模型训练脚本 ||-- export_mobilenet.sh
  110. ||-- finetune_inception_resnet_v2_on_flowers.sh
  111. ||-- finetune_inception_v1_on_flowers.sh
  112. ||-- finetune_inception_v3_on_flowers.sh
  113. ||-- finetune_resnet_v1_50_on_flowers.sh
  114. ||-- train_cifarnet_on_cifar10.sh
  115. | `-- train_lenet_on_mnist.sh
  116. |-- setup.py
  117. |-- slim_walkthrough.ipynb
  118. `-- train_image_classifier.py # 训练模型的脚本

1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)

scripts/finetune_resnet_v1_50_on_flowers.sh

  1. #!/bin/bash# Copyright 2017 The TensorFlow Authors. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# ==============================================================================## This script performs the following operations:# 1. Downloads the Flowers dataset# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.# 3. Evaluates the model on the Flowers validation set.## Usage:# cd slim# ./slim/scripts/finetune_resnet_v1_50_on_flowers.shset-e
  2. # Where the pre-trained ResNetV1-50 checkpoint is saved to.
  3. PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints # gqr:预训练模型存放路径# Where the training (fine-tuned) checkpoint and logs will be saved to.
  4. TRAIN_DIR=/tmp/flowers-models/resnet_v1_50
  5. # Where the dataset is saved to.
  6. DATASET_DIR=/tmp/flowers # gqr:数据集存放路径# Download the pre-trained checkpoint.if[ ! -d "$PRETRAINED_CHECKPOINT_DIR"]; then
  7. mkdir ${PRETRAINED_CHECKPOINT_DIR}
  8. fi
  9. if[ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then
  10. wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
  11. tar -xvf resnet_v1_50_2016_08_28.tar.gz
  12. mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt
  13. rm resnet_v1_50_2016_08_28.tar.gz
  14. fi
  15. # Download the dataset
  16. python download_and_convert_data.py \
  17. --dataset_name=flowers \
  18. --dataset_dir=${DATASET_DIR}# Fine-tune only the new layers for 3000 steps.
  19. python train_image_classifier.py \
  20. --train_dir=${TRAIN_DIR} \
  21. --dataset_name=flowers \
  22. --dataset_split_name=train \
  23. --dataset_dir=${DATASET_DIR} \
  24. --model_name=resnet_v1_50 \
  25. --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \
  26. --checkpoint_exclude_scopes=resnet_v1_50/logits \
  27. --trainable_scopes=resnet_v1_50/logits \
  28. --max_number_of_steps=3000 \
  29. --batch_size=32 \
  30. --learning_rate=0.01 \
  31. --save_interval_secs=60 \
  32. --save_summaries_secs=60 \
  33. --log_every_n_steps=100 \
  34. --optimizer=rmsprop \
  35. --weight_decay=0.00004# Run evaluation.
  36. python eval_image_classifier.py \
  37. --checkpoint_path=${TRAIN_DIR} \
  38. --eval_dir=${TRAIN_DIR} \
  39. --dataset_name=flowers \
  40. --dataset_split_name=validation \
  41. --dataset_dir=${DATASET_DIR} \
  42. --model_name=resnet_v1_50
  43. # Fine-tune all the new layers for 1000 steps.
  44. python train_image_classifier.py \
  45. --train_dir=${TRAIN_DIR}/all \
  46. --dataset_name=flowers \
  47. --dataset_split_name=train \
  48. --dataset_dir=${DATASET_DIR} \
  49. --checkpoint_path=${TRAIN_DIR} \
  50. --model_name=resnet_v1_50 \
  51. --max_number_of_steps=1000 \
  52. --batch_size=32 \
  53. --learning_rate=0.001 \
  54. --save_interval_secs=60 \
  55. --save_summaries_secs=60 \
  56. --log_every_n_steps=100 \
  57. --optimizer=rmsprop \
  58. --weight_decay=0.00004# Run evaluation.
  59. python eval_image_classifier.py \
  60. --checkpoint_path=${TRAIN_DIR}/all \
  61. --eval_dir=${TRAIN_DIR}/all \
  62. --dataset_name=flowers \
  63. --dataset_split_name=validation \
  64. --dataset_dir=${DATASET_DIR} \
  65. --model_name=resnet_v1_50

以上文件以下载并打包flowers数据集为例会调用slim/datasets下的****download_and_convert_flowers.py
在这里插入图片描述
代码43行:**_NUM_VALIDATION = 350值的意思的测试数据集的数量,我们一般2,8分数据集,这里只用填写测试集的数据代码会自动吧总数据集分成2部分
代码48行:
_NUM_SHARDS = 1**这个的意思是生成几个tfrecord文件,这个数量是根据你数据量来划分
在这里插入图片描述
代码190行:dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 函数为下载数据集函数,如果本地已经存在数据集,可将将其注释掉
在这里插入图片描述
代码210行:_clean_up_temporary_files(dataset_dir) 函数为打包完毕后删除下载的数据集文件,如果需要下载的数据集可以将其注释掉

上述文件执行完毕后,会得到以下文件
在这里插入图片描述

3、模型训练

模型训练文件为
在这里插入图片描述
以下是该文件中各个模块相关内容

1、数据集相关模块:

在这里插入图片描述

2、设置网络模型模块

在这里插入图片描述

3、数据预处理模块

在这里插入图片描述

4、定义损失loss

在这里插入图片描述

5、定义优化器模块

在这里插入图片描述

运行训练指令:

  1. python train_image_classifier.py \
  2. --train_dir=./data/flowers-models/resnet_v1_50\
  3. --dataset_name=flowers \
  4. --dataset_split_name=train \
  5. --dataset_dir=./data/flowers \
  6. --model_name=resnet_v1_50 \
  7. --checkpoint_path=./data/checkpoints/resnet_v1_50.ckpt \
  8. --checkpoint_exclude_scopes=resnet_v1_50/logits \
  9. --trainable_scopes=resnet_v1_50/logits \
  10. --max_number_of_steps=3000 \
  11. --batch_size=32 \
  12. --learning_rate=0.01 \
  13. --save_interval_secs=60 \
  14. --save_summaries_secs=60 \
  15. --log_every_n_steps=100 \
  16. --optimizer=rmsprop \
  17. --weight_decay=0.00004

–dataset_name=指定模板
–model_name=指定预训练模板
–dataset_dir=指定训练集目录
–checkpoint_exclude_scopes=指定忘记那几层的参数,不带进训练里面,记住提取特征的部分
–train_dir=训练参数存放地址
–trainable_scopes=设定只对那几层变量进行调整,其他层都不进行调整,不设定就会对所有层训练(所以是必须要给定的)
–learning_rate=学习率
–optimizer=优化器
–checkpoint_path:预训练模型存放地址
–max_number_of_steps=训练步数
–batch_size=一次训练所选取的样本数。 (Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。)
–weight_decay=即模型中所有参数的二次正则化超参数(这个的加入就是为了防止过拟合加入正则项,weight_decay 是乘在正则项的前面,控制正则化项在损失函数中所占权重的)

注意:在模型训练前,需要下载预训练模型,
wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz

解压后存放在相应目录


本文转载自: https://blog.csdn.net/guoqingru0311/article/details/132514699
版权归原作者 郭庆汝 所有, 如有侵权,请联系我们删除。

“TensorFlow-slim包进行图像数据集分类---具体流程”的评论:

还没有评论