Training on your own dataset

#Directory structure after annotation
project
└── labelimg
    ├── 20190128155421222575013.jpg
    ├── 20190128155421222575013.xml
    ├── 20190128155703035712899.jpg
    ├── 20190128155703035712899.xml
    ├── 20190129091126392737624.jpg
    └── 20190129091126392737624.xml
  • Pull the container image
    $ sudo docker pull gouchicao/keras-retinanet:latest
    
  • Run the container
    $ sudo docker run -it --runtime=nvidia --name=keras-retinanet -p 8888:8888 -p 6006:6006 \
                  -v /home/wjunjian/ailab/datasets/helmet:/keras-retinanet/project \
                  gouchicao/keras-retinanet bash
    
  • Convert the VOC annotations to CSV format and split them into training and validation sets
    $ python voc2csv.py --data_dir=project/labelimg/ --output_dir=project/dataset
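
voc2csv.py is provided by the gouchicao/keras-retinanet image, and its source is not reproduced here. To give a rough idea of what the conversion involves, here is a minimal sketch (not that script) that assumes keras-retinanet's CSV input format: one path,x1,y1,x2,y2,class_name row per box in the annotation file, an image without objects written as path,,,,, and class_name,id rows in the class file. The helper names and the val_ratio and seed parameters are hypothetical. Class ids are built from all images so that train.csv and val.csv share one class.csv.

```python
import csv
import random
import xml.etree.ElementTree as ET
from pathlib import PurePosixPath


def parse_voc_xml(xml_text):
    """Parse one labelImg/VOC XML string into (filename, [(x1, y1, x2, y2, name), ...])."""
    root = ET.fromstring(xml_text)
    boxes = []
    for obj in root.iter("object"):
        bb = obj.find("bndbox")
        # Coordinates are stored as text; float() first tolerates values such as "12.0".
        x1, y1, x2, y2 = (int(float(bb.findtext(k))) for k in ("xmin", "ymin", "xmax", "ymax"))
        boxes.append((x1, y1, x2, y2, obj.findtext("name")))
    return root.findtext("filename"), boxes


def split_names(names, val_ratio=0.2, seed=0):
    """Hypothetical helper: deterministic shuffle, then split into (train, val)."""
    names = sorted(names)
    random.Random(seed).shuffle(names)
    n_val = int(round(len(names) * val_ratio))
    return sorted(names[n_val:]), sorted(names[:n_val])


def write_class_csv(annotations, path):
    """Write 'class_name,id' rows; ids follow the sorted class names across ALL images."""
    classes = sorted({box[4] for boxes in annotations.values() for box in boxes})
    with open(path, "w", newline="") as f:
        csv.writer(f).writerows([name, i] for i, name in enumerate(classes))
    return classes


def write_annotation_csv(annotations, names, image_dir, path):
    """Write 'path,x1,y1,x2,y2,class_name' rows for the selected image names."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        for name in names:
            image_path = str(PurePosixPath(image_dir) / name)
            boxes = annotations[name]
            if not boxes:
                # An image with no objects is written as a row with empty box fields.
                writer.writerow([image_path, "", "", "", "", ""])
            for x1, y1, x2, y2, label in boxes:
                writer.writerow([image_path, x1, y1, x2, y2, label])
```

With the three images above and val_ratio=0.2, split_names returns two training images and one validation image, the same proportions as in the generated tree; which image ends up in val depends on the seed. The real script also copies each image and XML file into train/ or val/, which this sketch leaves out.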
    
#Generated directory structure
project
├── dataset
│   ├── class.csv
│   ├── train
│   │   ├── 20190128155421222575013.jpg
│   │   ├── 20190128155421222575013.xml
│   │   ├── 20190129091126392737624.jpg
│   │   └── 20190129091126392737624.xml
│   ├── train.csv
│   ├── val
│   │   ├── 20190128155703035712899.jpg
│   │   └── 20190128155703035712899.xml
│   └── val.csv
└── labelimg
    ├── 20190128155421222575013.jpg
    ├── 20190128155421222575013.xml
    ├── 20190128155703035712899.jpg
    ├── 20190128155703035712899.xml
    ├── 20190129091126392737624.jpg
    └── 20190129091126392737624.xml
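
A quick consistency check on the generated files can catch labeling typos before a long training run. This sketch assumes the keras-retinanet CSV layout (class_name,id rows in class.csv; path,x1,y1,x2,y2,class_name rows in train.csv and val.csv, with empty box fields for an image without objects). load_classes and check_annotations are hypothetical helpers, not part of voc2csv.py or keras-retinanet.

```python
import csv


def load_classes(lines):
    """Read 'class_name,id' rows into {class_name: id}, rejecting duplicate names."""
    classes = {}
    for row in csv.reader(lines):
        if not row:
            continue
        name, class_id = row
        if name in classes:
            raise ValueError(f"duplicate class name: {name}")
        classes[name] = int(class_id)
    return classes


def check_annotations(lines, classes):
    """Return human-readable problems found in 'path,x1,y1,x2,y2,class_name' rows."""
    problems = []
    for lineno, row in enumerate(csv.reader(lines), start=1):
        if not row:
            continue
        if len(row) != 6:
            problems.append(f"line {lineno}: expected 6 fields, got {len(row)}")
            continue
        path, x1, y1, x2, y2, name = row
        if (x1, y1, x2, y2, name) == ("", "", "", "", ""):
            continue  # image without objects
        if name not in classes:
            problems.append(f"line {lineno}: unknown class '{name}'")
        x1, y1, x2, y2 = (int(v) for v in (x1, y1, x2, y2))
        if x2 <= x1 or y2 <= y1:
            problems.append(f"line {lineno}: degenerate box in {path}")
    return problems
```

Both functions accept any iterable of text lines, so an open file works directly: load project/dataset/class.csv with load_classes, then pass an open train.csv or val.csv to check_annotations. An empty list means no problems were found.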
  • Train the model
    $ python keras-retinanet/keras_retinanet/bin/train.py --tensorboard-dir=project/logs --snapshot-path project/snapshots \
      csv project/dataset/train.csv project/dataset/class.csv --val-annotations project/dataset/val.csv
    $ ls -lh project/snapshots/resnet50_csv_01.h5
    -rw-r--r-- 1 root     root     417M Jul 27 22:58 resnet50_csv_01.h5
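
train.py saves snapshots into the directory given by --snapshot-path (project/snapshots here), and the file listed above is the epoch-1 snapshot. The evaluation and conversion steps below hard-code resnet50_csv_01.h5; after a longer run you will likely want a later epoch. Here is a minimal sketch for finding the highest-numbered snapshot, assuming the backbone_dataset_epoch.h5 naming seen in resnet50_csv_01.h5. latest_snapshot is a hypothetical helper.

```python
import re
from pathlib import Path

# Assumed naming: <backbone>_<dataset_type>_<epoch>.h5, e.g. resnet50_csv_01.h5
SNAPSHOT_RE = re.compile(r"^(?P<backbone>.+)_(?P<dataset>[^_]+)_(?P<epoch>\d+)\.h5$")


def latest_snapshot(filenames):
    """Return the filename with the highest epoch number, or None if nothing matches."""
    best, best_epoch = None, -1
    for name in filenames:
        m = SNAPSHOT_RE.match(Path(name).name)
        # Compare epochs as integers so that "100" beats "99".
        if m and int(m.group("epoch")) > best_epoch:
            best, best_epoch = name, int(m.group("epoch"))
    return best
```

A typical call is latest_snapshot(str(p) for p in Path("project/snapshots").glob("*.h5")). The latest snapshot is not necessarily the best one; comparing the mAP reported by the evaluation step is a safer way to choose.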
    
  • Visualize training with TensorBoard
    $ tensorboard --logdir=project/logs --bind_all
    

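    Open http://localhost:6006 in a browser on the host machine (port 6006 is published by -p 6006:6006 in the docker run command above).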

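  • Evaluate the model (--convert-model converts the training snapshot to an inference model before evaluating)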
    $ python keras-retinanet/keras_retinanet/bin/evaluate.py csv project/dataset/val.csv project/dataset/class.csv \
      project/snapshots/resnet50_csv_01.h5 --convert-model
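
evaluate.py reports average precision (AP) per class and their mean (mAP) on the validation set. A detection counts as a true positive only if its intersection-over-union (IoU) with a not-yet-matched ground-truth box reaches a threshold. To our knowledge, evaluate.py exposes this as --iou-threshold with a default of 0.5. The sketch below is a generic illustration of that matching rule, not keras-retinanet's code, and its IoU may differ slightly from the library's in how box edges are counted.

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2) in pixel coordinates."""
    iw = min(a[2], b[2]) - max(a[0], b[0])
    ih = min(a[3], b[3]) - max(a[1], b[1])
    if iw <= 0 or ih <= 0:
        return 0.0
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union


def match_detections(detections, ground_truth, iou_threshold=0.5):
    """Label each detection True (TP) or False (FP), highest score first.

    detections: [(score, box), ...] for one class in one image
    ground_truth: [box, ...] for the same class and image
    Each ground-truth box can be matched at most once.
    """
    matched = set()
    results = []
    for score, box in sorted(detections, key=lambda d: d[0], reverse=True):
        overlaps = [iou(box, gt) for gt in ground_truth]
        # Index of the best-overlapping ground-truth box, or None if there is none.
        best = max(range(len(overlaps)), key=overlaps.__getitem__, default=None)
        if best is not None and overlaps[best] >= iou_threshold and best not in matched:
            matched.add(best)
            results.append((score, True))
        else:
            results.append((score, False))
    return results
```

Sorting all results for a class by score and accumulating true and false positives yields the precision-recall curve from which AP is computed. A second, lower-scoring detection of an already matched object counts as a false positive.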
    
  • Convert the training snapshot into an inference model
    $ mkdir project/inference
    $ python keras-retinanet/keras_retinanet/bin/convert_model.py --no-class-specific-filter \
      project/snapshots/resnet50_csv_01.h5 project/inference/model.h5
    $ ls -lh project/inference/model.h5
    -rw-r--r-- 1 root     root     140M Jul 27 23:14 model.h5
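
convert_model.py turns the training snapshot into an inference model that decodes box predictions and filters them with a score threshold and non-maximum suppression (NMS). The drop from 417M to 140M is most likely because the training snapshot also stores optimizer state. As we understand keras-retinanet's filtering layer, the default runs NMS separately for each class, so one object can be reported under two classes. With --no-class-specific-filter, each box instead keeps only its highest-scoring class, and a single NMS pass runs over all boxes. The sketch below is a simplified pure-Python illustration of that difference (greedy NMS, no cap on the number of detections). filter_detections and its thresholds are hypothetical, not the library's API.

```python
def iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes."""
    iw = min(a[2], b[2]) - max(a[0], b[0])
    ih = min(a[3], b[3]) - max(a[1], b[1])
    if iw <= 0 or ih <= 0:
        return 0.0
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def nms(candidates, nms_threshold):
    """Greedy NMS over [(score, box, label), ...]: keep a box unless it overlaps a kept one."""
    kept = []
    for cand in sorted(candidates, key=lambda c: c[0], reverse=True):
        if all(iou(cand[1], k[1]) < nms_threshold for k in kept):
            kept.append(cand)
    return kept


def filter_detections(boxes, class_scores, class_specific=True,
                      score_threshold=0.05, nms_threshold=0.5):
    """boxes: [(x1, y1, x2, y2), ...]; class_scores: one list of per-class scores per box.

    Returns [(score, box, label), ...] sorted by descending score.
    """
    if class_specific:
        # Every (box, class) pair above the threshold is a candidate; NMS runs per class.
        result = []
        num_classes = len(class_scores[0]) if class_scores else 0
        for c in range(num_classes):
            cands = [(s[c], b, c) for b, s in zip(boxes, class_scores) if s[c] > score_threshold]
            result.extend(nms(cands, nms_threshold))
    else:
        # Keep only each box's best class, then run one class-agnostic NMS pass.
        cands = []
        for b, s in zip(boxes, class_scores):
            label = max(range(len(s)), key=s.__getitem__)
            if s[label] > score_threshold:
                cands.append((s[label], b, label))
        result = nms(cands, nms_threshold)
    return sorted(result, key=lambda r: r[0], reverse=True)
```

For a single helmet that also scores highly as "head", the class-specific default may return the same box twice with different labels, while the class-agnostic variant returns it once.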
    
  • Run predictions
    $ python predict.py --model project/inference/model.h5 \
      --class_csv project/dataset/class.csv \
      --data_dir project/test \
      --predict_dir project/predict
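
predict.py comes from the gouchicao/keras-retinanet image, and its internals are not shown here. For reference, a keras-retinanet inference model returns boxes, scores and labels for each image. The boxes refer to the resized network input, so they are divided by the scale that keras-retinanet's resize_image returned, and unused slots are padded with label -1. The sketch below is a pure-Python post-processing step under those assumptions: it maps label ids back to names through class.csv and drops low-confidence detections. load_label_names, postprocess and score_threshold=0.5 are hypothetical.

```python
import csv


def load_label_names(lines):
    """Map class id -> class name from 'class_name,id' rows (the class.csv format)."""
    return {int(row[1]): row[0] for row in csv.reader(lines) if row}


def postprocess(boxes, scores, labels, scale, label_names, score_threshold=0.5):
    """Turn one image's raw model outputs into readable detections.

    boxes/scores/labels: per-detection sequences for a single image.
    scale: the factor the image was resized by before inference.
    """
    detections = []
    for box, score, label in zip(boxes, scores, labels):
        if label < 0 or score < score_threshold:
            continue  # padding slot or low-confidence detection
        # Map coordinates from the resized input back to the original image.
        x1, y1, x2, y2 = (v / scale for v in box)
        detections.append({
            "label": label_names.get(int(label), str(label)),
            "score": float(score),
            "box": (round(x1), round(y1), round(x2), round(y2)),
        })
    return detections
```

The 0.5 threshold only affects what is reported; raising it trades recall for fewer false alarms, which is usually worth tuning on the validation images.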
    

References