Tensorflow学习笔记3: Object_detection之配置Training Pipeline

参考:Configuring an object detection pipeline

1、config文件

配置好的config文件存放路径:object_detection/samples/configs

2、PASCAL VOC数据集配置

选取faster_rcnn_resnet101_voc07.config作为该数据集的config文件,并复制到对应目录。下面是该文件的内容,需要修改的部分见行尾注释

  1 # Faster R-CNN with Resnet-101 (v1), configured for Pascal VOC Dataset.
  2 # Users should configure the fine_tune_checkpoint field in the train config as
  3 # well as the label_map_path and input_path fields in the train_input_reader and
  4 # eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
  5 # should be configured.
  6 
  7 model {
  8   faster_rcnn {
  9     num_classes: 20 # 类别数(不含背景类,VOC为20类);使用自己的数据集时需改为对应的类别数,并与label map中的类别保持一致
 10     image_resizer {
 11       keep_aspect_ratio_resizer {
 12         min_dimension: 600
 13         max_dimension: 1024
 14       }
 15     }
 16     feature_extractor {
 17       type: 'faster_rcnn_resnet101'
 18       first_stage_features_stride: 16
 19     }
 20     first_stage_anchor_generator {
 21       grid_anchor_generator {
 22         scales: [0.25, 0.5, 1.0, 2.0]
 23         aspect_ratios: [0.5, 1.0, 2.0]
 24         height_stride: 16
 25         width_stride: 16
 26       }
 27     }
 28     first_stage_box_predictor_conv_hyperparams {
 29       op: CONV
 30       regularizer {
 31         l2_regularizer {
 32           weight: 0.0
 33         }
 34       }
 35       initializer {
 36         truncated_normal_initializer {
 37           stddev: 0.01
 38         }
 39       }
 40     }
 41     first_stage_nms_score_threshold: 0.0
 42     first_stage_nms_iou_threshold: 0.7
 43     first_stage_max_proposals: 300
 44     first_stage_localization_loss_weight: 2.0
 45     first_stage_objectness_loss_weight: 1.0
 46     initial_crop_size: 14
 47     maxpool_kernel_size: 2
 48     maxpool_stride: 2
 49     second_stage_box_predictor {
 50       mask_rcnn_box_predictor {
 51         use_dropout: false
 52         dropout_keep_probability: 1.0
 53         fc_hyperparams {
 54           op: FC
 55           regularizer {
 56             l2_regularizer {
 57               weight: 0.0
 58             }
 59           }
 60           initializer {
 61             variance_scaling_initializer {
 62               factor: 1.0
 63               uniform: true
 64               mode: FAN_AVG
 65             }
 66           }
 67         }
 68       }
 69     }
 70     second_stage_post_processing {
 71       batch_non_max_suppression {
 72         score_threshold: 0.0
 73         iou_threshold: 0.6
 74         max_detections_per_class: 100
 75         max_total_detections: 300
 76       }
 77       score_converter: SOFTMAX
 78     }
 79     second_stage_localization_loss_weight: 2.0
 80     second_stage_classification_loss_weight: 1.0
 81   }
 82 }
 83 
 84 train_config: {
 85   batch_size: 1 # 每个训练批次(batch)输入的样本数
 86   optimizer {
 87     momentum_optimizer: {
 88       learning_rate: {
 89         manual_step_learning_rate {
 90           initial_learning_rate: 0.0001
 91           schedule {
 92             step: 500000
 93             learning_rate: .00001
 94           }
 95           schedule {
 96             step: 700000
 97             learning_rate: .000001
 98           }
 99         }
100       }
101       momentum_optimizer_value: 0.9
102     }
103     use_moving_average: false
104   }
105   gradient_clipping_by_norm: 10.0
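106   fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"  # 预训练模型的checkpoint路径,如需使用需填写完整文件路径,预训练模型可以从这里下载:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md;注意:只要此字段非空,训练时就会尝试从该路径加载,若不使用预训练模型需删除或注释掉此行,否则会因找不到文件而报错
107   from_detection_checkpoint: false  # 布尔值,说明fine_tune_checkpoint的类型:true表示它是完整的检测模型checkpoint(恢复整个检测网络),false表示它是分类模型checkpoint(只恢复特征提取器);是否加载预训练模型由fine_tune_checkpoint是否为空决定。这里暂时先不加,后面再来研究怎么去匹配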
108   num_steps: 800000 # 最大训练步数(迭代次数)
109   data_augmentation_options {
110     random_horizontal_flip {
111     }
112   }
113 }
114 
115 train_input_reader: {
116   tf_record_input_reader {
117     input_path: "/data/zxx/models/research/date/VOCdevkit/pascal_train.record" # 修改为实际的训练集TFRecord文件路径
118   }
119   label_map_path: "object_detection/data/pascal_label_map.pbtxt" # 修改为实际的label map文件路径(自己的数据集需使用自己的label map)
120 }
121 
122 eval_config: {
123   num_examples: 4952 # 评估使用的样本数,4952是VOC2007 test集的图片数;应与eval_input_reader中record文件的实际样本数保持一致
124 }
125 
126 eval_input_reader: {
127   tf_record_input_reader {
128     input_path: "/data/zxx/models/research/date/VOCdevkit/pascal_val.record" # 修改为实际的验证集TFRecord文件路径
129   }
130   label_map_path: "object_detection/data/pascal_label_map.pbtxt"  # 修改为实际的label map文件路径,需与train_input_reader中的一致
131   shuffle: false
132   num_readers: 1
133 }

修改完成后,保存该config文件