
Commit 8e09378

Added model.py
1 parent b25293a commit 8e09378

File tree

2 files changed (+242, -0 lines)


model.py

+196
@@ -0,0 +1,196 @@
import random
import torch
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from PIL import Image
import PIL.ImageOps
import pytorch_lightning as pl
import numpy as np
from pytorch_lightning.callbacks import ModelCheckpoint
import argparse


class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # label is 0 for a same-class pair and 1 for a different-class pair
        # (see SiameseNetworkDataset below).
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
        # Similar pairs are pulled together (d^2); dissimilar pairs are pushed apart
        # until their distance exceeds the margin (clamp(margin - d, 0)^2).
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

        return loss_contrastive

class SiameseNetworkDataset(Dataset):

    def __init__(self, imageFolderDataset, transform=None, should_invert=None):
        self.imageFolderDataset = imageFolderDataset
        self.transform = transform
        self.should_invert = should_invert

    def __getitem__(self, index):
        img0_tuple = random.choice(self.imageFolderDataset.imgs)
        # we need to make sure approx 50% of pairs are from the same class
        should_get_same_class = random.randint(0, 1)
        if should_get_same_class:
            while True:
                # keep looping till a same-class image is found
                img1_tuple = random.choice(self.imageFolderDataset.imgs)
                if img0_tuple[1] == img1_tuple[1]:
                    break
        else:
            while True:
                # keep looping till a different-class image is found
                img1_tuple = random.choice(self.imageFolderDataset.imgs)
                if img0_tuple[1] != img1_tuple[1]:
                    break

        img0 = Image.open(img0_tuple[0])
        img1 = Image.open(img1_tuple[0])
        img0 = img0.convert("L")
        img1 = img1.convert("L")

        if self.should_invert:
            img0 = PIL.ImageOps.invert(img0)
            img1 = PIL.ImageOps.invert(img1)

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        # label is 0 when both images share a class, 1 otherwise
        return img0, img1, torch.from_numpy(np.array([int(img1_tuple[1] != img0_tuple[1])], dtype=np.float32))

    def __len__(self):
        return len(self.imageFolderDataset.imgs)


class SiameseNetwork(pl.LightningModule):
    def __init__(self, margin, learning_rate, resize, imageFolderTrain, imageFolderTest, batch_size, should_invert):
        super().__init__()
        # store init args in the checkpoint so load_from_checkpoint() (used in utils.py)
        # can re-instantiate the model without passing them again
        self.save_hyperparameters()
        self.imageFolderTrain = imageFolderTrain
        self.imageFolderTest = imageFolderTest
        self.learning_rate = learning_rate
        self.criterion = ContrastiveLoss(margin=margin)
        self.batch_size = batch_size
        self.should_invert = should_invert
        self.transform = transforms.Compose([transforms.Resize((resize, resize)),
                                             transforms.RandomHorizontalFlip(),
                                             transforms.ToTensor()])

        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),

            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
        )

        # Note: the first linear layer assumes 100x100 inputs (8 channels * 100 * 100 features),
        # so the flattened size only matches when --resize is left at its default of 100.
        self.fc1 = nn.Sequential(
            nn.Linear(8 * 100 * 100, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 5))

    def forward_once(self, x):
        # embed a single image into a 5-dimensional vector
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        # run both images through the same weight-shared branch
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

    def training_step(self, batch, batch_idx):
        x0, x1, y = batch
        output1, output2 = self(x0, x1)
        loss = self.criterion(output1, output2, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x0, x1, y = batch
        output1, output2 = self(x0, x1)
        loss = self.criterion(output1, output2, y)

        self.log('val_loss', loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def prepare_data(self):
        # scan the ImageFolder roots once; setup() then wraps them in pair datasets
        self.DatasetFolder = dset.ImageFolder(self.imageFolderTrain)
        self.DatasetFolder_testing = dset.ImageFolder(self.imageFolderTest)

    def setup(self, stage=None):
        self.siamese_dataset_train = SiameseNetworkDataset(imageFolderDataset=self.DatasetFolder,
                                                           transform=self.transform,
                                                           should_invert=self.should_invert)
        self.siamese_dataset_test = SiameseNetworkDataset(imageFolderDataset=self.DatasetFolder_testing,
                                                          transform=self.transform,
                                                          should_invert=self.should_invert)

    def train_dataloader(self):
        return DataLoader(self.siamese_dataset_train, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.siamese_dataset_test, batch_size=self.batch_size)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(
        description='Siamese Network - Face Recognition',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--gpus', default=1, type=int)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--pretrain_epochs', default=5000, type=int)
    parser.add_argument('--margin', default=1.0, type=float)
    # store_true: pass --should_invert to enable inversion (a bare default=False would
    # treat any supplied string, even "False", as truthy)
    parser.add_argument('--should_invert', action='store_true')
    parser.add_argument('--imageFolderTrain', default=None)
    parser.add_argument('--imageFolderTest', default=None)
    parser.add_argument('--learning_rate', default=2e-2, type=float)
    parser.add_argument('--resize', default=100, type=int)

    args = parser.parse_args()
    print(args)

    model = SiameseNetwork(margin=args.margin, learning_rate=args.learning_rate, resize=args.resize,
                           imageFolderTrain=args.imageFolderTrain,
                           imageFolderTest=args.imageFolderTest, batch_size=args.batch_size,
                           should_invert=args.should_invert)
    # Trainer flags follow the PyTorch Lightning 1.x API (gpus=, progress_bar_refresh_rate=)
    trainer = pl.Trainer(gpus=args.gpus, max_epochs=args.pretrain_epochs, progress_bar_refresh_rate=20)
    trainer.fit(model)
    trainer.save_checkpoint("siamese_face_recognition.ckpt")
    trainer.test()
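To make the label convention concrete, here is a small sketch (not part of the commit) that exercises ContrastiveLoss on toy embeddings; the tensor values are illustrative only.

# Illustrative toy example of the ContrastiveLoss defined above.
import torch
from model import ContrastiveLoss

criterion = ContrastiveLoss(margin=1.0)
anchor = torch.tensor([[0.0, 0.0]])    # embedding of the first image
close = torch.tensor([[0.1, 0.0]])     # nearby embedding (same person)
far = torch.tensor([[2.0, 0.0]])       # distant embedding (different person)

# label 0 = same class: loss is the squared distance, so a close pair scores ~0.01
print(criterion(anchor, close, torch.tensor([[0.0]])))
# label 1 = different class: pairs already farther apart than the margin (1.0) contribute 0
print(criterion(anchor, far, torch.tensor([[1.0]])))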

utils.py

+46
@@ -0,0 +1,46 @@
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torchvision
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
from model import SiameseNetwork
# %matplotlib inline  (IPython magic; only valid inside a notebook, not in a plain .py file)

model = SiameseNetwork.load_from_checkpoint("siamese_face_recognition.ckpt")

def imshow(img, text=None, should_save=False):
    npimg = img.numpy()
    plt.axis("off")
    if text:
        plt.text(75, 8, text, style='italic', fontweight='bold',
                 bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

def show_plot(iteration, loss):
    plt.plot(iteration, loss)
    plt.show()


# def view_dissimilarity(testDirectory):
#     folder_dataset_test = dset.ImageFolder(root=testDirectory)
#     siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test,
#                                             transform=transforms.Compose([transforms.Resize((100,100)),
#                                                                           transforms.ToTensor()]),
#                                             should_invert=False)

#     test_dataloader = DataLoader(siamese_dataset, num_workers=6, batch_size=1, shuffle=True)
#     dataiter = iter(test_dataloader)
#     x0, _, _ = next(dataiter)

#     for i in range(4):
#         _, x1, label2 = next(dataiter)
#         concatenated = torch.cat((x0, x1), 0)
#         model.eval()
#         output1, output2 = model(Variable(x0).cuda(), Variable(x1).cuda())
#         euclidean_distance = F.pairwise_distance(output1, output2)
#         imshow(torchvision.utils.make_grid(concatenated), 'Dissimilarity: {:.2f}'.format(euclidean_distance.item()))
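The commented-out block above relies on the old Variable API and a GPU. A CPU-only sketch of the same idea (not part of the commit; the pair_distance helper name and the image paths are placeholders) could look like this, using the checkpoint loaded above:

# Hypothetical helper: score how dissimilar two face crops are under the trained model.
from PIL import Image

transform = transforms.Compose([transforms.Resize((100, 100)),
                                transforms.ToTensor()])

def pair_distance(path0, path1):
    # convert to grayscale to match the single-channel Conv2d(1, 4, ...) stem in model.py
    x0 = transform(Image.open(path0).convert("L")).unsqueeze(0)
    x1 = transform(Image.open(path1).convert("L")).unsqueeze(0)
    model.eval()
    with torch.no_grad():
        output1, output2 = model(x0, x1)
    return F.pairwise_distance(output1, output2).item()

# pair_distance("faces/personA/img1.jpg", "faces/personB/img1.jpg")  # larger value = more dissimilar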
