1
+ import random
2
+ import torch
3
+ import torchvision
4
+ import torchvision .datasets as dset
5
+ import torchvision .transforms as transforms
6
+ from torch .utils .data import DataLoader , Dataset
7
+ import torch .nn as nn
8
+ from torch import optim
9
+ import torch .nn .functional as F
10
+ from PIL import Image
11
+ import PIL .ImageOps
12
+ import pytorch_lightning as pl
13
+ import numpy as np
14
+ from pytorch_lightning .callbacks import ModelCheckpoint
15
+ import argparse
16
+
17
+
18
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss (Hadsell, Chopra & LeCun, 2006).

    Pulls embeddings of similar pairs (label == 0) together and pushes
    dissimilar pairs (label == 1) at least ``margin`` apart.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # One Euclidean distance per pair; keepdim matches label's shape.
        dist = F.pairwise_distance(output1, output2, keepdim=True)
        # Similar pairs: penalise squared distance.
        pull = (1 - label) * torch.pow(dist, 2)
        # Dissimilar pairs: penalise only while closer than the margin.
        push = label * torch.pow(torch.clamp(self.margin - dist, min=0.0), 2)
        return torch.mean(pull + push)
33
+
34
class SiameseNetworkDataset(Dataset):
    """Pair dataset for siamese training.

    Each item is ``(img0, img1, label)`` where ``label`` is 1.0 when the two
    images come from different classes and 0.0 when they share a class.
    Roughly half of the sampled pairs are same-class.
    """

    def __init__(self, imageFolderDataset, transform=None, should_invert=None):
        # imageFolderDataset: an ImageFolder-like object; its .imgs attribute
        # is a list of (path, class_index) tuples.
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        # Truthy -> invert grayscale images before any transform.
        self.should_invert = should_invert

    def __getitem__(self, index):
        # NOTE: `index` is ignored on purpose — pairs are drawn at random,
        # so an "epoch" is just len(dataset) random pairs.
        img0_tuple = random.choice(self.imageFolderDataset.imgs)
        # Make approx. 50% of the sampled pairs same-class.
        should_get_same_class = random.randint(0, 1)
        if should_get_same_class:
            while True:
                # Keep looping until a same-class image is found.
                img1_tuple = random.choice(self.imageFolderDataset.imgs)
                if img0_tuple[1] == img1_tuple[1]:
                    break
        else:
            while True:
                # Keep looping until a different-class image is found.
                img1_tuple = random.choice(self.imageFolderDataset.imgs)
                if img0_tuple[1] != img1_tuple[1]:
                    break

        # Use context managers so the underlying file handles are closed
        # promptly (the original left them open until garbage collection).
        # convert("L") (grayscale) loads the data and returns a new image,
        # so it is safe to close the source file afterwards.
        with Image.open(img0_tuple[0]) as raw0:
            img0 = raw0.convert("L")
        with Image.open(img1_tuple[0]) as raw1:
            img1 = raw1.convert("L")

        if self.should_invert:
            img0 = PIL.ImageOps.invert(img0)
            img1 = PIL.ImageOps.invert(img1)

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        # Label: 1.0 if the classes differ, else 0.0 (shape (1,), float32).
        return img0, img1, torch.from_numpy(
            np.array([int(img1_tuple[1] != img0_tuple[1])], dtype=np.float32))

    def __len__(self):
        return len(self.imageFolderDataset.imgs)
76
+
77
+
78
class SiameseNetwork(pl.LightningModule):
    """Siamese CNN trained with contrastive loss on grayscale image pairs.

    The two inputs share the same convolutional tower + MLP head
    (``forward_once``); ``forward`` returns the pair of 5-d embeddings.
    Data loading is handled by the LightningModule itself via
    ``prepare_data``/``setup`` and the ``*_dataloader`` hooks.
    """

    def __init__(self, margin, learning_rate, resize, imageFolderTrain,
                 imageFolderTest, batch_size, should_invert):
        super().__init__()
        self.imageFolderTrain = imageFolderTrain
        self.imageFolderTest = imageFolderTest
        self.learning_rate = learning_rate
        self.criterion = ContrastiveLoss(margin=margin)
        self.batch_size = batch_size
        self.should_invert = should_invert
        self.transform = transforms.Compose([
            transforms.Resize((resize, resize)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()])

        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),

            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
        )

        # BUGFIX: the flattened feature size was hard-coded as 8 * 100 * 100,
        # which crashed for any resize != 100. Each ReflectionPad2d(1) +
        # Conv2d(kernel_size=3) stage preserves spatial size, so the tower's
        # output is (8, resize, resize) and the correct flat size is
        # 8 * resize * resize (identical to the old value when resize == 100).
        self.fc1 = nn.Sequential(
            nn.Linear(8 * resize * resize, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 5))

    def forward_once(self, x):
        """Embed a single image batch into a 5-d vector per sample."""
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        """Return embeddings for both sides of the pair (shared weights)."""
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

    def training_step(self, batch, batch_idx):
        x0, x1, y = batch
        output1, output2 = self(x0, x1)
        loss = self.criterion(output1, output2, y)
        return loss

    def validation_step(self, batch, batch_idx):
        # NOTE(review): no val_dataloader is defined, so this only runs when
        # test_step delegates to it (or a val loader is supplied externally).
        x0, x1, y = batch
        output1, output2 = self(x0, x1)
        loss = self.criterion(output1, output2, y)

        self.log('val_loss', loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def prepare_data(self):
        # Folder scanning only; transforms are applied by the pair dataset.
        self.DatasetFolder = dset.ImageFolder(self.imageFolderTrain)
        self.DatasetFolder_testing = dset.ImageFolder(self.imageFolderTest)

    def setup(self, stage=None):
        self.siamese_dataset_train = SiameseNetworkDataset(
            imageFolderDataset=self.DatasetFolder,
            transform=self.transform,
            should_invert=self.should_invert)
        self.siamese_dataset_test = SiameseNetworkDataset(
            imageFolderDataset=self.DatasetFolder_testing,
            transform=self.transform,
            should_invert=self.should_invert)

    def train_dataloader(self):
        return DataLoader(self.siamese_dataset_train, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.siamese_dataset_test, batch_size=self.batch_size)
171
+
172
if __name__ == '__main__':

    def _str2bool(value):
        """Parse a boolean CLI value.

        BUGFIX: with ``default=False`` and no ``type``, argparse passed the
        raw string through, so ``--should_invert False`` yielded the truthy
        string ``"False"``. This keeps the ``--flag VALUE`` form working
        while parsing it correctly.
        """
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() in ('true', '1', 'yes', 'y')

    parser = argparse.ArgumentParser(
        description='Siamese Network - Face Recognition',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--gpus', default=1, type=int)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--pretrain_epochs', default=5000, type=int)
    parser.add_argument('--margin', default=1.0, type=float)
    parser.add_argument('--should_invert', default=False, type=_str2bool)
    parser.add_argument('--imageFolderTrain', default=None)
    parser.add_argument('--imageFolderTest', default=None)
    parser.add_argument('--learning_rate', default=2e-2, type=float)
    parser.add_argument('--resize', default=100, type=int)

    args = parser.parse_args()
    print(args)

    model = SiameseNetwork(margin=args.margin,
                           learning_rate=args.learning_rate,
                           resize=args.resize,
                           imageFolderTrain=args.imageFolderTrain,
                           imageFolderTest=args.imageFolderTest,
                           batch_size=args.batch_size,
                           should_invert=args.should_invert)
    trainer = pl.Trainer(gpus=args.gpus, max_epochs=args.pretrain_epochs,
                         progress_bar_refresh_rate=20)
    trainer.fit(model)
    trainer.save_checkpoint("siamese_face_recognition.ckpt")
    # Pass the fitted model explicitly so testing does not rely on
    # implicit Trainer state.
    trainer.test(model)
0 commit comments