基于DetNet-FPN的视频实时检测demo,pytorch版

最近要使用faster-rcnn,DetNet-FPN以及Light-Head三种目标检测方案训练自己的数据集,并做一个比较。在GitHub上搜罗了一番,发现下面三个开源项目一脉相承,正合我意。其中DetNet_pytorch和pytorch-lighthead我认为都是基于faster-rcnn.pytorch这个项目改的,作者貌似是佐治亚理工的Ph.D.,确实是高手中的高高手,代码写得很流畅。我三个项目都fork了下来,做了不少的修改以适配自己的数据集(不能一天到晚玩coco、voc嘛),其中faster-rcnn.pytorch和DetNet_pytorch都成功地进行了适配并做了大量实验,唯独pytorch-lighthead死活没有搞成功。哎,暂且不管了,考虑到我自己的数据集分类数很有限,用lighthead理论上并不会有啥优势。事实上,我后面尝试了LightHead的tensorflow版本,也印证了我的想法,mAP差异甚小。

https://github.com/jwyang/faster-rcnn.pytorch

https://github.com/guoruoqian/DetNet_pytorch

https://github.com/chengsq/pytorch-lighthead

以使用的较多的DetNet_pytorch为例,配合代码介绍视频流demo(输入为视频文件而非单张图片)。对于DetNet的实现原理,不再做过多阐述,这一部分网络上已有大量详细介绍的文章。

0. 我的环境

Python 2.7

Pytorch 0.3.1

CUDA 9.0

1. 准备

按照作者文档上的要求安装好一些依赖包,这个过程可能会有一点小挫折,别一开始就放弃,以下是作者建议的环境配置。

Python 2.7 or 3.6

Pytorch 0.2.0 or higher(not support pytorch version >=0.4.0)

CUDA 8.0 or higher

从Git上下载代码,完成编译。编译也有可能遇到错误,继续锻炼你的Linux运维能力。

cd lib

sh make.sh

2.代码调整

作者已经提供了一个demo.py的程序,虽然不是我想要的视频流demo,但主体架构没问题,只要在这上面进行二次开发即可。以下是我修改后的demo.py程序,以供参考。如果是想要修改为faster rcnn版的视频流demo也是很容易的,可自行修改。

# coding: utf-8

# --------------------------------------------------------

# Tensorflow Faster R-CNN

# Licensed under The MIT License [see LICENSE for details]

# Written by Jiasen Lu, Jianwei Yang, based on code from Ross Girshick

# --------------------------------------------------------

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

import _init_paths

import os

import sys

import numpy as np

import argparse

import pprint

import pdb

import time

import cv2

import cPickle

import torch

from torch.autograd import Variable

import torch.nn as nn

import torch.optim as optim

import glob

import torchvision.transforms as transforms

import torchvision.datasets as dset

from PIL import Image

from scipy.misc import imread

from roi_data_layer.roidb import combined_roidb

from roi_data_layer.roibatchLoader import roibatchLoader

from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir

from model.rpn.bbox_transform import clip_boxes

from model.nms.nms_wrapper import nms

from model.rpn.bbox_transform import bbox_transform_inv

from model.utils.net_utils import save_net, load_net, vis_detections

from model.utils.blob import prep_im_for_blob,im_list_to_blob

import pdb

from model.fpn.detnet_backbone import detnet

import warnings

warnings.filterwarnings("ignore")

def parse_args():
    """
    Parse input arguments

    NOTE(review): the article leaves out the argparse setup, so ``parser`` is
    not defined anywhere in this listing. The script entry point reads the
    following attributes from the returned namespace: net, set_cfgs,
    exp_name, load_dir, dataset, checksession, checkepoch, checkpoint,
    class_agnostic, ngpu and video_file.
    """
    # Omitted to keep the article short: creation of `parser` and its arguments.
    args = parser.parse_args()
    return args

# Training hyper-parameters copied from the config at import time, i.e. before
# the script entry point calls cfg_from_file / cfg_from_list. Nothing in the
# visible demo code reads these three names.
lr = cfg.TRAIN.LEARNING_RATE

momentum = cfg.TRAIN.MOMENTUM

weight_decay = cfg.TRAIN.WEIGHT_DECAY

def get_image_blob(im):
    """Turn one decoded frame into the blob fed to the network.

    Arguments:
        im: data of image

    Returns:
        blob (ndarray): the output of im_list_to_blob for the single
            preprocessed image
        im_scale_factors (list): one-element list holding the scale
            (relative to im) returned by prep_im_for_blob
    """
    # One index is drawn from numpy's global RNG to choose among
    # cfg.TRAIN.SCALES (the training scales are what this demo uses).
    picked = np.random.randint(0, high=len(cfg.TRAIN.SCALES), size=1)
    target_size = cfg.TRAIN.SCALES[picked[0]]

    resized, scale = prep_im_for_blob(
        im, cfg.PIXEL_MEANS, cfg.PIXEL_STDS, target_size, cfg.TRAIN.MAX_SIZE)

    # Callers expect list-shaped results even though only one image is handled.
    return im_list_to_blob([resized]), [scale]

if __name__ == '__main__':
    # Video demo entry point: load a trained detnet checkpoint, run it on every
    # frame of args.video_file and show the detections in an OpenCV window.
    # Press 'q' in the window to stop early.

    args = parse_args()

    print('Called with args:')
    print(args)

    # The config path is always derived from the network name, so it can never
    # be None here and is loaded unconditionally.
    args.cfg_file = "cfgs/{}.yml".format(args.net)
    cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    # get_image_blob draws from numpy's global RNG, so seed it for repeatability.
    np.random.seed(cfg.RNG_SEED)
    cfg.TRAIN.USE_FLIPPED = False

    # Checkpoint directory: <load_dir>/<net>/<dataset>[/<exp_name>]
    input_dir = args.load_dir + "/" + args.net + "/" + args.dataset
    if args.exp_name is not None:
        input_dir = input_dir + '/' + args.exp_name
    if not os.path.exists(input_dir):
        raise Exception('There is no input directory for loading network from ' + input_dir)
    load_name = os.path.join(
        input_dir,
        'fpn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))

    use_cuda = args.ngpu > 0

    classes = cfg.TRAIN.CLASSES
    fpn = detnet(classes, 59, pretrained=False, class_agnostic=args.class_agnostic)
    fpn.create_architecture()

    print('load checkpoint %s' % (load_name))
    checkpoint = torch.load(load_name)
    fpn.load_state_dict(checkpoint['model'])
    if 'pooling_mode' in checkpoint.keys():
        cfg.POOLING_MODE = checkpoint['pooling_mode']

    # Placeholder tensors; they are resized and refilled for every frame.
    im_data = torch.FloatTensor(1)
    im_info = torch.FloatTensor(1)
    num_boxes = torch.LongTensor(1)
    gt_boxes = torch.FloatTensor(1)

    # ship to cuda
    if use_cuda:
        im_data = im_data.cuda()
        im_info = im_info.cuda()
        num_boxes = num_boxes.cuda()
        gt_boxes = gt_boxes.cuda()

    # volatile=True is the PyTorch 0.3.x inference-mode flag for Variables.
    im_data = Variable(im_data, volatile=True)
    im_info = Variable(im_info, volatile=True)
    num_boxes = Variable(num_boxes, volatile=True)
    gt_boxes = Variable(gt_boxes, volatile=True)

    if use_cuda:
        cfg.CUDA = True
        fpn.cuda()

    fpn.eval()

    thresh = 0.05      # minimum class score for a box to enter NMS
    vis_thresh = 0.8   # threshold handed to vis_detections
    vis = True

    # The constants that undo the regression-target normalisation do not change
    # between frames, so build them once and keep them on the same device as
    # the network output (the original called .cuda() even when ngpu == 0).
    denormalize_deltas = cfg.TEST.BBOX_REG and cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED
    if denormalize_deltas:
        bbox_stds = torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_STDS)
        bbox_means = torch.FloatTensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS)
        if use_cuda:
            bbox_stds = bbox_stds.cuda()
            bbox_means = bbox_means.cuda()
        # One 4-vector per box when class agnostic, otherwise one per class.
        delta_width = 4 if args.class_agnostic else 4 * len(classes)

    if not os.path.exists(args.video_file):
        # Fixed: the original mixed a '%s' placeholder with str.format, so the
        # file name never appeared in the message.
        raise Exception("video {} not exist".format(args.video_file))

    vc = cv2.VideoCapture(args.video_file)
    if not vc.isOpened():
        vc.release()
        raise Exception("video {} can not be opened".format(args.video_file))

    i = 0
    try:
        while True:
            i += 1
            ok, im = vc.read()
            if not ok or im is None:
                # End of stream (or a frame that could not be decoded).
                break

            blobs, im_scales = get_image_blob(im)
            assert len(im_scales) == 1, "Only single-image batch implemented"
            im_blob = blobs
            # (h,w,scale)
            im_info_np = np.array(
                [[im_blob.shape[1], im_blob.shape[2], im_scales[0]]], dtype=np.float32)

            im_data_pt = torch.from_numpy(im_blob)
            # exchange dimension->(b,c,h,w)
            im_data_pt = im_data_pt.permute(0, 3, 1, 2)
            im_info_pt = torch.from_numpy(im_info_np)

            im_data.data.resize_(im_data_pt.size()).copy_(im_data_pt)
            im_info.data.resize_(im_info_pt.size()).copy_(im_info_pt)
            # No ground truth at inference time: feed zeroed placeholders.
            gt_boxes.data.resize_(1, 1, 5).zero_()
            num_boxes.data.resize_(1).zero_()

            det_tic = time.time()
            rois, cls_prob, bbox_pred, \
                _, _, _, _, _ = fpn(im_data, im_info, gt_boxes, num_boxes)

            scores = cls_prob.data
            boxes = rois.data[:, :, 1:5]

            if cfg.TEST.BBOX_REG:
                # Apply bounding-box regression deltas
                box_deltas = bbox_pred.data
                if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
                    # Optionally normalize targets by a precomputed mean and stdev
                    box_deltas = box_deltas.view(-1, 4) * bbox_stds + bbox_means
                    box_deltas = box_deltas.view(1, -1, delta_width)
                pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
                pred_boxes = clip_boxes(pred_boxes, im_info.data, 1)
            else:
                # Simply repeat the boxes, once for each class
                # NOTE(review): np.tile is applied to a torch tensor here and the
                # code below indexes the result with torch tensors; this branch
                # is kept as in the original and looks untested - confirm.
                pred_boxes = np.tile(boxes, (1, scores.shape[1]))

            # Map the boxes back to the coordinates of the original frame.
            pred_boxes /= im_scales[0]

            scores = scores.squeeze()
            pred_boxes = pred_boxes.squeeze()
            det_toc = time.time()
            detect_time = det_toc - det_tic

            if vis:
                im2show = np.copy(im)

            sys.stdout.write('im_detect: {:d} {:.3f}s \r'.format(i, detect_time))
            sys.stdout.flush()

            for j in range(1, len(classes)):  # 0 for background
                inds = torch.nonzero(scores[:, j] > thresh).view(-1)
                # if there is det
                if inds.numel() > 0:
                    cls_scores = scores[:, j][inds]  # confidence of the specified class
                    _, order = torch.sort(cls_scores, 0, True)  # descending order
                    if args.class_agnostic:
                        cls_boxes = pred_boxes[inds, :]
                    else:
                        cls_boxes = pred_boxes[inds][:, j * 4:(j + 1) * 4]

                    cls_dets = torch.cat((cls_boxes, cls_scores.unsqueeze(1)), 1)
                    cls_dets = cls_dets[order]
                    keep = nms(cls_dets, cfg.TEST.NMS)  # after nms
                    cls_dets = cls_dets[keep.view(-1).long()]  # keep shape is ?x1
                    if vis:
                        # cls_dets.cpu().numpy() make tensor->numpy array
                        im2show = vis_detections(
                            im2show, classes[j], cls_dets.cpu().numpy(), vis_thresh)

            if vis:
                # im2show only exists when visualisation is enabled.
                cv2.imshow('demo', im2show)
                if (cv2.waitKey(25) & 0xFF) == ord('q'):
                    break
    finally:
        # Release the capture and close the window even if a frame fails.
        vc.release()
        cv2.destroyAllWindows()

3.运行

【代码】

https://github.com/LeftThink/DetNet_pytorch.git

ref:https://blog.csdn.net/ChuiGeDaQiQiu/article/details/82969397