使用yolo3-tf2训练自己的数据集

You only look once (YOLO) 是一个实时目标检测模型。YOLO3在保持速度优势的前提下,提升了预测精度,尤其是加强了对小物体的识别能力。
本文将介绍如何使用yolo3 + tensorflow2.x训练自己的数据集。

配置环境

下载yolo工程

yolov3-tf2地址:https://github.com/zzh8829/yolov3-tf2

1
2
git clone  https://github.com/zzh8829/yolov3-tf2 yolov3-tf2
cd yolov3-tf2

建立conda环境

1
2
3
4
5
6
7
# Tensorflow CPU
conda env create -f conda-cpu.yml
conda activate yolov3-tf2-cpu

# Tensorflow GPU
conda env create -f conda-gpu.yml
conda activate yolov3-tf2-gpu

下载权重文件并验证

1
2
3
wget https://pjreddie.com/media/files/yolov3.weights -O data/yolov3.weights
python convert.py
python detect.py --image ./data/meme.jpg # Sanity check

制作VOC数据集

目录:

  • VOC
    • Annotations #存放xml文件,可使用LabelImg生成
    • JPEGImages #存放图片
    • ImageSets
      • Main
        • test.txt
        • train.txt
        • trainval.txt
        • val.txt

执行python make.py 在Main 下生成四个txt

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# -*- coding:utf-8 -*-

import os
import random

trainval_percent = 0.6 # 自己设定(训练集+验证集)所占(训练集+验证集+测试集)的比重
train_percent = 0.7 # 自己设定(训练集)所占(训练集+验证集)的比重
xmlfilepath = 'Annotations' #注意自己地址是否正确
txtsavepath = 'ImageSets/Main' #注意自己地址是否正确
total_xml = os.listdir(xmlfilepath)

num = len(total_xml)
print(num)
list = range(num)
tv = int(num*trainval_percent)
tr = int(tv*train_percent)
trainval = random.sample(list,tv)
train = random.sample(trainval,tr)

ftrainval = open('ImageSets/Main/trainval.txt', 'w')
ftest = open('ImageSets/Main/test.txt', 'w')
ftrain = open('ImageSets/Main/train.txt', 'w')
fval = open('ImageSets/Main/val.txt', 'w')

for i in list:
name = total_xml[i][:-4]+'\n'
#print(name)
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)

ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
print('Done')

建立标签.names文件

在data文件夹下,写入的就是自己要训练的类别

生成tfrecord文件

训练集

1
2
3
4
5
python tools/voc2012.py \
--data_dir './data/voc-5h' \
--split train \
--output_file ./data/voc5h_train.tfrecord \
--classes ./data/voc-5h.names

测试集

1
2
3
4
5
python tools/voc2012.py \
--data_dir './data/voc-5h' \
--split val \
--output_file ./data/voc5h_val.tfrecord \
--classes ./data/voc-5h.names

训练

进行迁移训练

1
2
3
4
5
6
7
8
9
10
python train.py \
--dataset ./data/voc5h_train.tfrecord \
--val_dataset ./data/voc5h_val.tfrecord \
--classes ./data/voc-5h.names \
--num_classes 181 \
--mode fit --transfer darknet \
--batch_size 16 \
--epochs 10 \
--weights ./checkpoints/yolov3.tf \
--weights_num_classes 80

使用随机权重进行训练

1
2
3
4
5
6
7
8
python train.py \
--dataset ./data/voc5h_train.tfrecord \
--val_dataset ./data/voc5h_val.tfrecord \
--classes ./data/voc-5h.names \
--num_classes 181 \
--mode fit --transfer none \
--batch_size 16 \
--epochs 50 \

模型测试

从图像中检测

1
2
3
4
5
python detect.py \
--classes ./data/voc-5h.names \
--num_classes 181 \
--weights ./checkpoints/yolov3_train_5.tf \
--image ./data/street.jpg

从验证集中检测

1
2
3
4
5
python detect.py \
--classes ./data/voc-5h.names \
--num_classes 181 \
--weights ./checkpoints/yolov3_train_5.tf \
--tfrecord ./data/voc5h_val.tfrecord