Skip to content

Commit 8cfe246

Browse files
committed
Merge pull request #184 from tenghehan-feature/mmdetection_support
support MMDetection Signed-off-by: ZQPei <dfzspzq@163.com>
1 parent aa9b4a9 commit 8cfe246

File tree

9 files changed

+116
-7
lines changed

9 files changed

+116
-7
lines changed

.gitmodules

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
[submodule "thirdparty/fast-reid"]
22
path = thirdparty/fast-reid
33
url = https://github.com/JDAI-CV/fast-reid.git
4+
[submodule "thirdparty/mmdetection"]
5+
path = thirdparty/mmdetection
6+
url = https://github.com/open-mmlab/mmdetection.git

README.md

+20-5
Original file line numberDiff line numberDiff line change
@@ -79,26 +79,38 @@ cd ../../..
7979
Notice:
8080
If compiling failed, the simplist way is to **Upgrade your pytorch >= 1.1 and torchvision >= 0.3" and you can avoid the troublesome compiling problems which are most likely caused by either `gcc version too low` or `libraries missing`.
8181

82-
5. (Optional) Prepare [fast-reid](https://github.com/JDAI-CV/fast-reid)
82+
5. (Optional) Prepare third party submodules
83+
84+
[fast-reid](https://github.com/JDAI-CV/fast-reid)
8385

8486
This library supports bagtricks, AGW and other mainstream ReID methods through providing an fast-reid adapter.
8587

88+
to prepare our bundled fast-reid, then follow instructions in its README to install it.
89+
90+
Please refer to `configs/fastreid.yaml` for a sample of using fast-reid. See [Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/docs/MODEL_ZOO.md) for available methods and trained models.
91+
92+
[MMDetection](https://github.com/open-mmlab/mmdetection)
93+
94+
This library supports Faster R-CNN and other mainstream detection methods through providing an MMDetection adapter.
95+
96+
to prepare our bundled MMDetection, then follow instructions in its README to install it.
97+
98+
Please refer to `configs/mmdet.yaml` for a sample of using MMDetection. See [Model Zoo](https://github.com/open-mmlab/mmdetection/blob/master/docs/model_zoo.md) for available methods and trained models.
99+
86100
Run
87101

88102
```
89103
git submodule update --init --recursive
90104
```
91105

92-
to prepare our bundled fast-reid, then follow instructions in its README to install it.
93-
94-
Please refer to `configs/fastreid.yaml` for a sample of using fast-reid. See [Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/docs/MODEL_ZOO.md) for available methods and trained models.
95-
96106

97107
6. Run demo
98108
```
99109
usage: deepsort.py [-h]
100110
[--fastreid]
101111
[--config_fastreid CONFIG_FASTREID]
112+
[--mmdet]
113+
[--config_mmdetection CONFIG_MMDETECTION]
102114
[--config_detection CONFIG_DETECTION]
103115
[--config_deepsort CONFIG_DEEPSORT] [--display]
104116
[--frame_interval FRAME_INTERVAL]
@@ -121,6 +133,9 @@ python3 deepsort.py /dev/video0 --config_detection ./configs/yolov3_tiny.yaml --
121133
122134
# fast-reid + deepsort
123135
python deepsort.py [VIDEO_PATH] --fastreid [--config_fastreid ./configs/fastreid.yaml]
136+
137+
# MMDetection + deepsort
138+
python deepsort.py [VIDEO_PATH] --mmdet [--config_mmdetection ./configs/mmdet.yaml]
124139
```
125140
Use `--display` to enable display.
126141
Results will be saved to `./output/results.avi` and `./output/results.txt`.

configs/mmdet.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
MMDET:
2+
CFG: "thirdparty/mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py"
3+
CHECKPOINT: "detector/MMDet/weight/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
4+
5+
SCORE_THRESH: 0.5

deepsort.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,12 @@ def run(self):
135135
def parse_args():
136136
parser = argparse.ArgumentParser()
137137
parser.add_argument("VIDEO_PATH", type=str)
138+
parser.add_argument("--config_mmdetection", type=str, default="./configs/mmdet.yaml")
138139
parser.add_argument("--config_detection", type=str, default="./configs/yolov3.yaml")
139140
parser.add_argument("--config_deepsort", type=str, default="./configs/deep_sort.yaml")
140141
parser.add_argument("--config_fastreid", type=str, default="./configs/fastreid.yaml")
141142
parser.add_argument("--fastreid", action="store_true")
143+
parser.add_argument("--mmdet", action="store_true")
142144
# parser.add_argument("--ignore_display", dest="display", action="store_false", default=True)
143145
parser.add_argument("--display", action="store_true")
144146
parser.add_argument("--frame_interval", type=int, default=1)
@@ -153,7 +155,12 @@ def parse_args():
153155
if __name__ == "__main__":
154156
args = parse_args()
155157
cfg = get_config()
156-
cfg.merge_from_file(args.config_detection)
158+
if args.mmdet:
159+
cfg.merge_from_file(args.config_mmdetection)
160+
cfg.USE_MMDET = True
161+
else:
162+
cfg.merge_from_file(args.config_detection)
163+
cfg.USE_MMDET = False
157164
cfg.merge_from_file(args.config_deepsort)
158165
if args.fastreid:
159166
cfg.merge_from_file(args.config_fastreid)

detector/MMDet/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .detector import MMDet
2+
__all__ = ['MMDet']

detector/MMDet/detector.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import logging
2+
import numpy as np
3+
import torch
4+
5+
from mmdet.apis import init_detector, inference_detector
6+
from .mmdet_utils import xyxy_to_xywh
7+
8+
class MMDet(object):
9+
def __init__(self, cfg_file, checkpoint_file, score_thresh=0.7,
10+
is_xywh=False, use_cuda=True):
11+
# net definition
12+
self.device = "cuda" if use_cuda else "cpu"
13+
self.net = init_detector(cfg_file, checkpoint_file, device=self.device)
14+
logger = logging.getLogger("root.detector")
15+
logger.info('Loading weights from %s... Done!' % (checkpoint_file))
16+
17+
#constants
18+
self.score_thresh = score_thresh
19+
self.use_cuda = use_cuda
20+
self.is_xywh = is_xywh
21+
self.class_names = self.net.CLASSES
22+
self.num_classes = len(self.class_names)
23+
24+
def __call__(self, ori_img):
25+
# forward
26+
bbox_result = inference_detector(self.net, ori_img)
27+
bboxes = np.vstack(bbox_result)
28+
29+
if len(bboxes) == 0:
30+
bbox = np.array([]).reshape([0, 4])
31+
cls_conf = np.array([])
32+
cls_ids = np.array([])
33+
return bbox, cls_conf, cls_ids
34+
35+
bbox = bboxes[:, :4]
36+
cls_conf = bboxes[:, 4]
37+
cls_ids = [
38+
np.full(bbox.shape[0], i, dtype=np.int32)
39+
for i, bbox in enumerate(bbox_result)
40+
]
41+
cls_ids = np.concatenate(cls_ids)
42+
43+
selected_idx = cls_conf > self.score_thresh
44+
bbox = bbox[selected_idx, :]
45+
cls_conf = cls_conf[selected_idx]
46+
cls_ids = cls_ids[selected_idx]
47+
48+
if self.is_xywh:
49+
bbox = xyxy_to_xywh(bbox)
50+
51+
return bbox, cls_conf, cls_ids
52+
53+
54+
55+

detector/MMDet/mmdet_utils.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import torch
2+
import numpy as np
3+
4+
def xyxy_to_xywh(boxes_xyxy):
5+
if isinstance(boxes_xyxy, torch.Tensor):
6+
boxes_xywh = boxes_xyxy.clone()
7+
elif isinstance(boxes_xyxy, np.ndarray):
8+
boxes_xywh = boxes_xyxy.copy()
9+
10+
boxes_xywh[:, 0] = (boxes_xyxy[:, 0] + boxes_xyxy[:, 2]) / 2.
11+
boxes_xywh[:, 1] = (boxes_xyxy[:, 1] + boxes_xyxy[:, 3]) / 2.
12+
boxes_xywh[:, 2] = boxes_xyxy[:, 2] - boxes_xyxy[:, 0]
13+
boxes_xywh[:, 3] = boxes_xyxy[:, 3] - boxes_xyxy[:, 1]
14+
15+
return boxes_xywh

detector/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
from .YOLOv3 import YOLOv3
2+
from .MMDet import MMDet
23

34

45
__all__ = ['build_detector']
56

67
def build_detector(cfg, use_cuda):
7-
return YOLOv3(cfg.YOLOV3.CFG, cfg.YOLOV3.WEIGHT, cfg.YOLOV3.CLASS_NAMES,
8+
if cfg.USE_MMDET:
9+
return MMDet(cfg.MMDET.CFG, cfg.MMDET.CHECKPOINT,
10+
score_thresh=cfg.MMDET.SCORE_THRESH,
11+
is_xywh=True, use_cuda=use_cuda)
12+
else:
13+
return YOLOv3(cfg.YOLOV3.CFG, cfg.YOLOV3.WEIGHT, cfg.YOLOV3.CLASS_NAMES,
814
score_thresh=cfg.YOLOV3.SCORE_THRESH, nms_thresh=cfg.YOLOV3.NMS_THRESH,
915
is_xywh=True, use_cuda=use_cuda)

thirdparty/mmdetection

Submodule mmdetection added at 3e902c3

0 commit comments

Comments
 (0)