
Commit 0714159

Full repository refactor

1 parent 3b2dd71 commit 0714159

13 files changed: +219 −148 lines

TODO

-12
This file was deleted.

dataset/__init__.py

+1 −1

@@ -1,4 +1,4 @@
 
-from dataset.bdd import *
+from dataset.berkeley_deepdrive import *
 from dataset.utils import *
 from dataset.transforms import *

dataset/bdd.py

-66
This file was deleted.

dataset/berkeley_deepdrive.py

+84

@@ -0,0 +1,84 @@
+''' Berkeley Deepdrive Segmentation Dataset loader '''
+
+import os
+import re
+
+from PIL import Image
+import torch
+from torch.utils.data import Dataset
+
+from dataset.utils import listdir
+
+class BDDSegmentationDataset(Dataset):
+    ''' Dataset loader for the Berkeley Deepdrive Segmentation dataset '''
+
+    def __init__(self, path, split, transforms=None):
+        assert split in ['train', 'val', 'test'], 'split must be one of: {train, val, test}'
+        image_re = re.compile(r'(.*)\.jpg')
+        label_re = re.compile(r'(.*)_train_id\.png')
+        images = sorted(listdir(os.path.join(path, 'seg/images', split), image_re))
+        labels = sorted(listdir(os.path.join(path, 'seg/labels', split), label_re))
+        for (image, label) in zip(images, labels):
+            assert (image_re.match(os.path.basename(image)).group(1) ==
+                    label_re.match(os.path.basename(label)).group(1))
+        self.images, self.labels = images, labels
+        self.transforms = transforms
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, key):
+        image = Image.open(self.images[key])
+        label = Image.open(self.labels[key])
+        if self.transforms:
+            image, label = self.transforms(image, label)
+        return image, label
+
+
+def bdd_palette(labels):
+    ''' Applies a color palette to either a single label
+        tensor or a batch of tensors '''
+    assert len(labels.shape) in [2, 3], 'Invalid labels shape'
+
+    # pylint: disable=bad-whitespace
+    color_map = torch.Tensor([
+        [128,  67, 125], # Road
+        [247,  48, 227], # Sidewalk
+        [ 72,  72,  72], # Building
+        [101, 103, 153], # Wall
+        [190, 151, 152], # Fence
+        [152, 152, 152], # Pole
+        [254, 167,  56], # Light
+        [221, 217,  55], # Sign
+        [106, 140,  51], # Vegetation
+        [146, 250, 157], # Terrain
+        [ 65, 130, 176], # Sky
+        [224,  20,  64], # Person
+        [255,   0,  25], # Rider
+        [  0,  22, 138], # Car
+        [  0,  11,  70], # Truck
+        [  0,  63,  98], # Bus
+        [  0,  82,  99], # Train
+        [  0,  36, 224], # Motorcycle
+        [121,  17,  38], # Bicycle
+        [  0,   0,   0]  # Other
+    ]).to(labels.device) / 255.0
+
+    batched_input = True
+    if len(labels.shape) == 2:
+        batched_input = False
+        labels = torch.unsqueeze(labels, 0)
+
+    # Map the ignore index (255) onto the last label, 19 (Other)
+    labels = torch.clamp(labels, 0, 20 - 1).long()
+
+    n, h, w = labels.shape
+    labels_one_hot = torch.zeros(n, 20, h, w).to(labels.device)
+    labels_one_hot.scatter_(1, torch.unsqueeze(labels, 1), 1)
+
+    color_labels = torch.einsum('nlhw,lc->nchw', labels_one_hot, color_map)
+
+    if not batched_input:
+        color_labels = torch.squeeze(color_labels, 0)
+
+    return color_labels
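
A minimal usage sketch of the new loader and palette (not part of the commit); the dataset root path is a placeholder, and transforms is the pipeline from dataset/transforms.py:

from dataset import BDDSegmentationDataset, bdd_palette
from dataset.transforms import transforms

bdd = BDDSegmentationDataset('/data/bdd100k', 'train', transforms=transforms)
image, label = bdd[0]        # transformed image and label tensors
color = bdd_palette(label)   # (3, H, W) RGB tensor in [0, 1] for visualization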

dataset/transforms.py

+2 −2

@@ -11,8 +11,8 @@ def transforms(img, seg, size=(360, 640), augment=True, hflip_prob=0.5,
                five_crop_prob=0.5, five_crop_scale=0.6,
                rotate_prob=0.5, max_rotate=30.0,
                tensor_output=True,
-               normalize_mean=torch.Tensor([0.0, 0.0, 0.0]),
-               normalize_std=torch.Tensor([1.0, 1.0, 1.0]),
+               normalize_mean=torch.Tensor([0.3518, 0.3932, 0.4011]),
+               normalize_std=torch.Tensor([0.2363, 0.2494, 0.2611]),
                _ignore_index=255):
     ''' BDD transforms pipeline '''
 
dataset/utils.py

+18

@@ -1,6 +1,14 @@
 
+import os
+import re
+
 import torch
 
+def listdir(path, filter_=re.compile(r'.*')):
+    ''' Enumerates full paths of files in a directory matching a filter '''
+    return [os.path.join(path, f) for f in os.listdir(path) if filter_.match(f)]
+
+
 def median_frequency_balance(dataset, num_classes=19, ignore_index=255, _eps=1e-5):
     '''
     For more details refer to Section 6.3.2 in
@@ -14,3 +22,13 @@ def median_frequency_balance(dataset, num_classes=19, ignore_index=255, _eps=1e-
         frequency[cid] += torch.sum(seg == cid)
     frequency /= torch.sum(frequency)
     return torch.median(frequency) / frequency
+
+
+def mean_std(dataset):
+    ''' Returns the channel means and standard deviations
+        of the images in the dataset '''
+    mean, std = 0.0, 0.0
+    for image, _ in dataset:
+        mean += image.mean(dim=(1, 2))  # CHW -> C
+        std += image.view((3, -1)).std(dim=1) ** 2
+    return mean / len(dataset), (std / len(dataset)) ** 0.5
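
For context (not part of this file's diff), the new normalization defaults in dataset/transforms.py above are plausibly the output of this mean_std helper over the training images. A hedged sketch, where to_tensor_pair is a hypothetical transform and the dataset path is a placeholder:

from torchvision.transforms import functional as tvf

from dataset import BDDSegmentationDataset
from dataset.utils import mean_std

def to_tensor_pair(img, seg):
    ''' Hypothetical transform: image to a CHW float tensor in [0, 1],
        label passed through untouched '''
    return tvf.to_tensor(img), seg

bdd = BDDSegmentationDataset('/data/bdd100k', 'train', transforms=to_tensor_pair)
mean, std = mean_std(bdd)  # expected near ([0.3518, 0.3932, 0.4011], [0.2363, 0.2494, 0.2611])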

eval.py

-35
This file was deleted.

metrics/__init__.py

+1 −1

@@ -1,2 +1,2 @@
 
-from metrics.metrics import *
+from metrics.metrics import mean_iou, pixel_accuracy

metrics/metrics.py

+37 −14

@@ -1,22 +1,45 @@
 
+#pylint: disable=invalid-name
+
 import torch
 
-def mean_iou(y_pred, y, logits_dim=1, ignore_index=255, eps=1e-8):
+def mean_iou(y_pred, y, num_classes, ignore_index=255):
     ''' Evaluates mean IoU between prediction and ground truth '''
-    y_pred = torch.argmax(y_pred, dim=logits_dim)
-    classes = set(torch.unique(torch.cat((y_pred, y))))
-    classes.discard(ignore_index)
-    mask = (y != ignore_index)
+    ignore_mask = (y != ignore_index)
+    y_pred, y = y_pred[ignore_mask], y[ignore_mask]
 
-    miou = 0.0
-    for i in classes:
-        intersect = torch.sum((y_pred[mask] == i) & (y[mask] == i)).float()
-        union = torch.sum((y_pred[mask] == i) | (y[mask] == i)).float()
-        miou += (intersect + eps) / (union + eps)
-    return (miou + eps) / (len(classes) + eps)
+    conf_matrix = _confusion_matrix(y, y_pred, num_classes)
+    true_pos = torch.diag(conf_matrix)
+    false_pos = torch.sum(conf_matrix, dim=0) - true_pos
+    false_neg = torch.sum(conf_matrix, dim=1) - true_pos
+    tp_fp_fn = true_pos + false_pos + false_neg
 
-def pixel_accuracy(y_pred, y, logits_dim=1, ignore_index=255):
+    exist_class_mask = tp_fp_fn > 0
+    true_pos, tp_fp_fn = true_pos[exist_class_mask], tp_fp_fn[exist_class_mask]
+    return torch.mean(true_pos / tp_fp_fn)
+
+def pixel_accuracy(y_pred, y, num_classes, ignore_index=255):
     ''' Evaluates pixel accuracy between prediction and ground truth '''
-    y_pred = torch.argmax(y_pred, dim=logits_dim)
     mask = (y != ignore_index)
-    return torch.sum(y[mask] == y_pred[mask]).float() / torch.sum(mask).float()
+    y_pred, y = y_pred[mask], y[mask]
+
+    conf_matrix = _confusion_matrix(y, y_pred, num_classes)
+    return torch.sum(torch.diag(conf_matrix)) / torch.sum(conf_matrix)
+
+# Helper functions
+
+def _one_hot(labels, num_classes, class_dim=1):
+    ''' Converts a labels tensor (NHW) into a one-hot tensor (NLHW) '''
+    labels = torch.unsqueeze(labels, class_dim)
+    labels_one_hot = torch.zeros_like(labels).repeat(
+        [num_classes if d == class_dim else 1
+         for d in range(len(labels.shape))])
+    labels_one_hot.scatter_(class_dim, labels, 1)
+    return labels_one_hot
+
+def _confusion_matrix(y_pred, y, num_classes):
+    ''' Computes the confusion matrix between two predictions '''
+    b_size = y_pred.shape[0]
+    y, y_pred = _one_hot(y, num_classes), _one_hot(y_pred, num_classes)
+    y, y_pred = y.reshape(b_size, num_classes, -1), y_pred.reshape(b_size, num_classes, -1)
+    return torch.einsum('iaj,ibj->ab', y.float(), y_pred.float())
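
A usage sketch (not part of the commit) highlighting the API change: the metrics no longer argmax logits internally, so callers pass hard class predictions plus num_classes. Shapes below are illustrative:

import torch
from metrics import mean_iou, pixel_accuracy

logits = torch.randn(2, 19, 90, 160)         # NCHW model output (shapes assumed)
target = torch.randint(0, 19, (2, 90, 160))  # NHW ground-truth class indices
pred = torch.argmax(logits, dim=1)           # argmax now happens at the call site

print(mean_iou(pred, target, num_classes=19))
print(pixel_accuracy(pred, target, num_classes=19))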

model/deeplab.py

+2 −3

@@ -1,5 +1,4 @@
-# pylint: disable=W0221,C0414,C0103
-
+# pylint: disable=arguments-differ, too-many-arguments
 ''' DeepLab V3+ '''
 
 import torch
@@ -83,7 +82,7 @@ def forward(self, x_in):
         logits = nn_func.interpolate(logits, size=x_in.shape[2:4],
                                      mode='bilinear', align_corners=True)
         return logits
-
+
     def _init_weights(self):
         ''' Initializes weights of the model.
         - Conv2d parameters initialized using Kaiming normal

model/nn_ext.py

+1 −1

@@ -1,4 +1,4 @@
-# pylint: disable=W0221,C0414
+# pylint: disable=arguments-differ, too-many-arguments
 
 ''' Extensions to standard torch.nn primitives '''
 
