Skip to content

Commit f310c6b

Browse files
committed
init
1 parent d928b4b commit f310c6b

File tree

101 files changed

+10425
-3
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

101 files changed

+10425
-3
lines changed

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
tools/euler/*
2+
.gitignore
3+
tools/bsub_*
4+
tools/lint.sh
5+
.idea
6+
*.log
7+
teta/scripts/lint.sh

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@
186186
same "printed page" as the copyright notice for easier
187187
identification within third-party archives.
188188

189-
Copyright [yyyy] [name of copyright owner]
189+
Copyright [2022] [Siyuan Li]
190190

191191
Licensed under the Apache License, Version 2.0 (the "License");
192192
you may not use this file except in compliance with the License.

README.md

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,78 @@
1-
# Track Every Thing in the Wild, ECCV 2022
1+
# Track Every Thing in the Wild [ECCV2022]
22

3-
In construction, please check back later.
3+
This is the offical implementation of paper [Track Every Thing in the Wild](https://arxiv.org/abs/2207.12978).
4+
5+
Our project website contains more information: [vis.xyz/pub/tet](https://www.vis.xyz/pub/tet/).
6+
7+
8+
## Abstract
9+
10+
Current multi-category Multiple Object Tracking (MOT) metrics use class labels to group tracking results for per-class evaluation. Similarly, MOT methods typically only associate objects with the same class predictions.
11+
These two prevalent strategies in MOT implicitly assume that the classification performance is near-perfect.
12+
However, this is far from the case in recent large-scale MOT datasets, which contain large numbers of classes with many rare or semantically similar categories. Therefore, the resulting inaccurate classification leads to sub-optimal tracking and inadequate benchmarking of trackers.
13+
We address these issues by disentangling classification from tracking.
14+
We introduce a new metric, Track Every Thing Accuracy (TETA), breaking tracking measurement into three sub-factors: localization, association, and classification, allowing comprehensive benchmarking of tracking performance even under inaccurate classification. TETA also deals with the challenging incomplete annotation problem in large-scale tracking datasets. We further introduce a Track Every Thing tracker (TETer), that performs association using Class Exemplar Matching (CEM). Our experiments show that TETA evaluates trackers more comprehensively, and TETer achieves significant improvements on the challenging large-scale datasets BDD100K and TAO compared to the state-of-the-art.
15+
16+
## TETA
17+
[TETA](teta/README.md) builds upon the HOTA metric, while extending it to better deal with
18+
multiple categories and incomplete annotations. TETA evaluate trackers based on a novel local cluster design. TETA consists of three parts: a
19+
localization score, an association score, and a classification score, which enable
20+
us to evaluate the different aspects of each tracker properly.
21+
22+
<img src="figures/teta-teaser.png" width="400">
23+
24+
## TETer
25+
TETer follows an Associate-Every-Thing (AET) strategy.
26+
Instead of only associating objects in the same class, we associate every object in neighboring frames.
27+
We introduce Class Exemplar Matching (CEM), where the learned class exemplars incorporate valuable class information in a soft manner.
28+
In this way, we effectively exploit semantic supervision on large-scale detection datasets while not relying on the often incorrect classification output.
29+
30+
<img src="figures/teaser-teter.png" width="800">
31+
32+
## Main results
33+
Our method outperforms the states of the art on BDD100K, and TAO benchmarks.
34+
35+
### BDD100K val set
36+
37+
| Method | backbone | mMOTA | mIDF1 | TETA | LocA | AssocA | ClsA |
38+
|-----------------------------------------------------|-----------|-------|-------|------|------|--------|------|
39+
| [QDTrack(CVPR21)](https://arxiv.org/abs/2006.06664) | ResNet-50 | 36.6 | 51.6 | 47.8 | 45.9 | 48.5 | 49.2 |
40+
| TETer (Ours) | ResNet-50 | 39.1 | 53.3 | 50.8 | 47.2 | 52.9 | 52.4 |
41+
42+
43+
### BDD100K test set
44+
45+
| Method | backbone | mMOTA | mIDF1 | TETA | LocA | AssocA | ClsA |
46+
|-----------------------------------------------------|-----------|-------|-------|------|------|--------|------|
47+
| [QDTrack(CVPR21)](https://arxiv.org/abs/2006.06664) | ResNet-50 | 35.7 | 52.3 | 49.2 | 47.2 | 50.9 | 49.2 |
48+
| TETer (Ours) | ResNet-50 | 37.4 | 53.3 | 50.8 | 47.0 | 53.6 | 50.7 |
49+
50+
51+
### TAO val set
52+
53+
| Method | backbone | TETA | LocA | AssocA | ClsA |
54+
|-----------------------------------------------------|------------|------|------|--------|------|
55+
| [QDTrack(CVPR21)](https://arxiv.org/abs/2006.06664) | ResNet-101 | 30.0 | 50.5 | 27.4 | 12.1 |
56+
| TETer (Ours) | ResNet-101 | 33.3 | 51.6 | 35.0 | 13.2 |
57+
| TETer-swinT (Ours) | SwinT | 34.6 | 52.1 | 36.7 | 15.0 |
58+
59+
## Installation
60+
61+
Please refer to [INSTALL.md](docs/INSTALL.md) for installation instructions.
62+
63+
64+
## Usages
65+
Please refer to [GET_STARTED.md](docs/GET_STARTED.md) for dataset preparation and running instructions.
66+
67+
68+
## Citation
69+
70+
```
71+
@InProceedings{trackeverything,
72+
title = {Track Every Thing in the Wild},
73+
author = {Li, Siyuan and Danelljan, Martin and Ding, Henghui and Huang, Thomas E. and Yu, Fisher},
74+
booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
75+
month = {Oct},
76+
year = {2022}
77+
}
78+
```

configs/_base_/faster_rcnn_r50_fpn.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# model settings
2+
model = dict(
3+
type='FasterRCNN',
4+
backbone=dict(
5+
type='ResNet',
6+
depth=50,
7+
num_stages=4,
8+
out_indices=(0, 1, 2, 3),
9+
frozen_stages=1,
10+
norm_cfg=dict(type='BN', requires_grad=True),
11+
norm_eval=True,
12+
style='pytorch',
13+
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
14+
neck=dict(
15+
type='FPN',
16+
in_channels=[256, 512, 1024, 2048],
17+
out_channels=256,
18+
num_outs=5),
19+
rpn_head=dict(
20+
type='RPNHead',
21+
in_channels=256,
22+
feat_channels=256,
23+
anchor_generator=dict(
24+
type='AnchorGenerator',
25+
scales=[8],
26+
ratios=[0.5, 1.0, 2.0],
27+
strides=[4, 8, 16, 32, 64]),
28+
bbox_coder=dict(
29+
type='DeltaXYWHBBoxCoder',
30+
target_means=[.0, .0, .0, .0],
31+
target_stds=[1.0, 1.0, 1.0, 1.0]),
32+
loss_cls=dict(
33+
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
34+
loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
35+
roi_head=dict(
36+
type='StandardRoIHead',
37+
bbox_roi_extractor=dict(
38+
type='SingleRoIExtractor',
39+
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
40+
out_channels=256,
41+
featmap_strides=[4, 8, 16, 32]),
42+
bbox_head=dict(
43+
type='Shared2FCBBoxHead',
44+
in_channels=256,
45+
fc_out_channels=1024,
46+
roi_feat_size=7,
47+
num_classes=80,
48+
bbox_coder=dict(
49+
type='DeltaXYWHBBoxCoder',
50+
target_means=[0., 0., 0., 0.],
51+
target_stds=[0.1, 0.1, 0.2, 0.2]),
52+
reg_class_agnostic=False,
53+
loss_cls=dict(
54+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
55+
loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
56+
# model training and testing settings
57+
train_cfg=dict(
58+
rpn=dict(
59+
assigner=dict(
60+
type='MaxIoUAssigner',
61+
pos_iou_thr=0.7,
62+
neg_iou_thr=0.3,
63+
min_pos_iou=0.3,
64+
match_low_quality=True,
65+
ignore_iof_thr=-1),
66+
sampler=dict(
67+
type='RandomSampler',
68+
num=256,
69+
pos_fraction=0.5,
70+
neg_pos_ub=-1,
71+
add_gt_as_proposals=False),
72+
allowed_border=-1,
73+
pos_weight=-1,
74+
debug=False),
75+
rpn_proposal=dict(
76+
nms_pre=2000,
77+
max_per_img=1000,
78+
nms=dict(type='nms', iou_threshold=0.7),
79+
min_bbox_size=0),
80+
rcnn=dict(
81+
assigner=dict(
82+
type='MaxIoUAssigner',
83+
pos_iou_thr=0.5,
84+
neg_iou_thr=0.5,
85+
min_pos_iou=0.5,
86+
match_low_quality=False,
87+
ignore_iof_thr=-1),
88+
sampler=dict(
89+
type='RandomSampler',
90+
num=512,
91+
pos_fraction=0.25,
92+
neg_pos_ub=-1,
93+
add_gt_as_proposals=True),
94+
pos_weight=-1,
95+
debug=False)),
96+
test_cfg=dict(
97+
rpn=dict(
98+
nms_pre=1000,
99+
max_per_img=1000,
100+
nms=dict(type='nms', iou_threshold=0.7),
101+
min_bbox_size=0),
102+
rcnn=dict(
103+
score_thr=0.05,
104+
nms=dict(type='nms', iou_threshold=0.5),
105+
max_per_img=100)
106+
# soft-nms is also supported for rcnn testing
107+
# e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
108+
))
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
_base_ = './faster_rcnn_r50_fpn.py'
2+
model = dict(
3+
type='QDTrack',
4+
rpn_head=dict(
5+
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0)),
6+
roi_head=dict(
7+
type='QuasiDenseRoIHead',
8+
track_roi_extractor=dict(
9+
type='SingleRoIExtractor',
10+
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
11+
out_channels=256,
12+
featmap_strides=[4, 8, 16, 32]),
13+
track_head=dict(
14+
type='QuasiDenseEmbedHead',
15+
num_convs=4,
16+
num_fcs=1,
17+
embed_channels=256,
18+
norm_cfg=dict(type='GN', num_groups=32),
19+
loss_track=dict(type='MultiPosCrossEntropyLoss', loss_weight=0.25),
20+
loss_track_aux=dict(
21+
type='L2Loss',
22+
neg_pos_ub=3,
23+
pos_margin=0,
24+
neg_margin=0.1,
25+
hard_mining=True,
26+
loss_weight=1.0))),
27+
train_cfg=dict(
28+
embed=dict(
29+
assigner=dict(
30+
type='MaxIoUAssigner',
31+
pos_iou_thr=0.7,
32+
neg_iou_thr=0.3,
33+
min_pos_iou=0.5,
34+
match_low_quality=False,
35+
ignore_iof_thr=-1),
36+
sampler=dict(
37+
type='CombinedSampler',
38+
num=256,
39+
pos_fraction=0.5,
40+
neg_pos_ub=3,
41+
add_gt_as_proposals=True,
42+
pos_sampler=dict(type='InstanceBalancedPosSampler'),
43+
neg_sampler=dict(type='RandomSampler')))))

0 commit comments

Comments
 (0)