Skip to content

Commit 0ca1898

Browse files
committed
Add PGGAN
1 parent 26d855b commit 0ca1898

12 files changed

+2581
-0
lines changed

README.md

+24
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Basically, this repository is a collection of my PyTorch implementation of Gener
1111
+ [Coupled GAN](#coupled-gan)
1212
+ [CycleGAN](#cyclegan)
1313
+ [GAN](#gan)
14+
+ [PGGAN](#pggan)
1415
+ [Wasserstein GAN](#wasserstein-gan)
1516
+ [Wasserstein GAN GP](#wasserstein-gan-gp)
1617

@@ -88,6 +89,29 @@ $ cd src/gan/
8889
$ python3 gan.py
8990
```
9091

92+
### PGGAN
93+
94+
_Progressive Growing of GANs for Improved Quality, Stability, and Variation_
95+
96+
#### Authors
97+
98+
Tero Karras, Timo Aila, Samuli Laine, Jaakko Lehtinen
99+
100+
#### Abstract
101+
102+
We describe a new training methodology for generative adversarial networks. The key idea is to grow both the generator and discriminator progressively: starting from a low resolution, we add new layers that model increasingly fine details as training progresses. This both speeds the training up and greatly stabilizes it, allowing us to produce images of unprecedented quality, e.g., CelebA images at 1024². We also propose a simple way to increase the variation in generated images, and achieve a record inception score of 8.80 in unsupervised CIFAR10. Additionally, we describe several implementation details that are important for discouraging unhealthy competition between the generator and discriminator. Finally, we suggest a new metric for evaluating GAN results, both in terms of image quality and variation. As an additional contribution, we construct a higher-quality version of the CelebA dataset.
103+
104+
[[paper]](https://research.nvidia.com/publication/2017-10_Progressive-Growing-of) [[Code]](./src/pggan/main.py)
105+
106+
#### Example Running
107+
108+
Before running the "main.py", you need to download the dataset from [here](https://drive.google.com/drive/folders/1j6uZ_a6zci0HyKZdpDq9kSa8VihtEPCp) to '/data' directory. You could find more information about downloading dataset from [the official PGGAN repository](https://github.com/tkarras/progressive_growing_of_gans/). My implementation uses the celeb dataset, so if you want to use other dataset, please follow the instructions in [the official PGGAN repository](https://github.com/tkarras/progressive_growing_of_gans/).
109+
110+
```
111+
$ cd src/pggan
112+
$ python3 main.py
113+
```
114+
91115
### Wasserstein GAN
92116

93117
_Wasserstein GAN_

src/pggan/PGGAN.py

+212
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import os
2+
from glob import glob
3+
import copy
4+
5+
import torch
6+
7+
from preset import resl_to_batch, resl_to_lr, resl_to_ch
8+
from train_step import Train_LSGAN, Train_WGAN_GP
9+
10+
11+
def get_optim(net, optim_type, resl, beta, decay, momentum, nesterov=True):
12+
lr = resl_to_lr[resl]
13+
return {
14+
"adam" : torch.optim.Adam(net.parameters(), lr=lr, betas=beta, weight_decay=decay),
15+
"rmsprop" : torch.optim.RMSprop(net.parameters(), lr=lr, weight_decay=decay),
16+
"sgd" : torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=decay, nesterov=True)
17+
}[optim_type]
18+
19+
20+
class PGGAN:
21+
def __init__(self, arg, G, D, scalable_loader, torch_device, loss, tensorboard):
22+
self.arg = arg
23+
self.device = torch_device
24+
self.save_dir = arg.save_dir
25+
self.scalable_loader = scalable_loader
26+
27+
self.img_num = arg.img_num
28+
self.batch = resl_to_batch[arg.start_resl]
29+
self.tran_step = self.img_num // self.batch
30+
self.stab_step = self.img_num // self.batch
31+
32+
self.G = G
33+
34+
self.G_ema = copy.deepcopy(G.module).cpu()
35+
self.G_ema.eval()
36+
for p in self.G_ema.parameters():
37+
p.requires_grad_(False)
38+
39+
self.D = D
40+
self.optim_G = get_optim(self.G, self.arg.optim_G, self.arg.start_resl, self.arg.beta, self.arg.decay, self.arg.momentum)
41+
self.optim_D = get_optim(self.D, self.arg.optim_G, self.arg.start_resl, self.arg.beta, self.arg.decay, self.arg.momentum)
42+
43+
self.tensorboard = tensorboard
44+
45+
if loss == "lsgan":
46+
self.step = Train_LSGAN(self.G, self.D, self.optim_G, self.optim_D, self.arg.label_smoothing, self.batch, self.device)
47+
elif loss == "wgangp":
48+
self.step = Train_WGAN_GP(self.G, self.D, self.optim_G, self.optim_D, self.arg.gp_lambda, self.arg.eps_drift ,self.batch, self.device)
49+
50+
self.load_resl = -1
51+
self.load_global_step = -1
52+
self.load()
53+
54+
55+
def save(self, global_step, resl, mode):
56+
"""Save current step model
57+
Save Elements:
58+
model_type : arg.model
59+
start_step : current step
60+
network : network parameters
61+
optimizer: optimizer parameters
62+
best_metric : current best score
63+
Parameters:
64+
step : current step
65+
filename : model save file name
66+
"""
67+
torch.save({"global_step" : global_step,
68+
"resl" : resl,
69+
"G" : self.G.state_dict(),
70+
"G_ema" : self.G_ema.state_dict(),
71+
"D" : self.D.state_dict(),
72+
"optim_G" : self.optim_G.state_dict(),
73+
"optim_D" : self.optim_D.state_dict(),
74+
}, self.save_dir + "/step_%07d_resl_%d_%s.pth.tar" % (global_step, resl, mode))
75+
print("Model saved %d step" % (global_step))
76+
77+
def load(self, filename=None):
78+
""" Model load. same with save"""
79+
if filename is None:
80+
# load last epoch model
81+
filenames = sorted(glob(self.save_dir + "/*.pth.tar"))
82+
if len(filenames) == 0:
83+
print("Not Load")
84+
return
85+
else:
86+
filename = os.path.basename(filenames[-1])
87+
88+
file_path = self.save_dir + "/" + filename
89+
90+
if os.path.exists(file_path) is True:
91+
print("Load %s to %s File" % (self.save_dir, filename))
92+
ckpoint = torch.load(file_path)
93+
94+
self.load_resl = ckpoint["resl"]
95+
96+
resl = self.arg.start_resl
97+
while resl < self.load_resl:
98+
self.G.module.grow_network()
99+
self.D.module.grow_network()
100+
self.G_ema.grow_network()
101+
self.G.to(self.device)
102+
self.D.to(self.device)
103+
resl *= 2
104+
105+
self.G.load_state_dict(ckpoint["G"])
106+
self.G_ema.load_state_dict(ckpoint["G_ema"])
107+
self.D.load_state_dict(ckpoint["D"])
108+
self.optim_G.load_state_dict(ckpoint['optim_G'])
109+
self.optim_D.load_state_dict(ckpoint['optim_D'])
110+
self.load_global_step = ckpoint["global_step"]
111+
print("Load Model, Global step : %d / Resolution : %d " % (self.load_global_step, self.load_resl))
112+
113+
else:
114+
print("Load Failed, not exists file")
115+
116+
117+
118+
def grow_architecture(self, resl, global_step):
119+
resl *= 2
120+
121+
self.batch = resl_to_batch[resl]
122+
self.stab_step = (self.img_num // self.batch) * resl_to_ch[resl]
123+
self.tran_step = (self.img_num // self.batch) * resl_to_ch[resl]
124+
125+
self.optim_G.param_groups = []
126+
self.optim_G.add_param_group({"params": list(self.G.parameters())})
127+
self.optim_D.param_groups = []
128+
self.optim_D.add_param_group({"params": list(self.D.parameters())})
129+
130+
lr = resl_to_lr[resl]
131+
for x in self.optim_G.param_groups + self.optim_D.param_groups:
132+
x["lr"] = lr
133+
self.step.grow(self.batch, self.optim_G, self.optim_D)
134+
135+
136+
# When the saved model is loaded, self.load() already grows the architecture
137+
# To prevent additional growing, this condition is required
138+
if global_step >= self.load_global_step:
139+
self.G.module.grow_network()
140+
self.G_ema.grow_network()
141+
self.D.module.grow_network()
142+
self.G.to(self.device)
143+
self.D.to(self.device)
144+
torch.cuda.empty_cache()
145+
return resl
146+
else:
147+
self.G.module.alpha = 0
148+
self.G_ema.alpha = 0
149+
self.D.module.alpha = 0
150+
return resl
151+
152+
153+
def update_ema(self):
154+
with torch.no_grad():
155+
named_param = dict(self.G.module.named_parameters())
156+
for k, v in self.G_ema.named_parameters():
157+
param = named_param[k].detach().cpu()
158+
v.copy_(self.arg.ema_decay * v + (1 - self.arg.ema_decay) * param)
159+
160+
161+
def train(self):
162+
# Initialize Train
163+
global_step, resl = 0, self.arg.start_resl
164+
loader = self.scalable_loader(resl)
165+
166+
def _step(step, loader, mode, LOG_PER_STEP=50):
167+
# When the saved model is loaded,
168+
# skips network train until loaded step
169+
nonlocal global_step
170+
if global_step <= self.load_global_step:
171+
global_step += 1
172+
return
173+
174+
input_, _ = next(loader)
175+
input_ = input_.to(self.device)
176+
log_D = self.step.train_D(input_, mode, d_iter=self.arg.d_iter)
177+
log_G = self.step.train_G(mode)
178+
self.update_ema()
179+
180+
# Save images and record logs
181+
if (step % LOG_PER_STEP) == 0:
182+
print("[% 6d/% 6d : % 3.2f %%]" % (step, self.tran_step, (step / self.tran_step) * 100))
183+
self.G_ema.eval()
184+
with torch.no_grad():
185+
self.tensorboard.log_image(self.G_ema, mode, resl, global_step)
186+
self.tensorboard.log_scalar("Loss/%d" % (resl), {**log_D, **log_G}, global_step)
187+
188+
if (step % (LOG_PER_STEP * 10)) == 0:
189+
self.save(global_step, resl, mode)
190+
global_step += 1
191+
192+
193+
# Stabilization on initial resolution (default: 4 * 4)
194+
for step in range(self.stab_step):
195+
_step(step, loader, "stabilization")
196+
197+
while (resl < self.arg.end_resl):
198+
# Grow and update resolution, batch size, etc. Load the models on GPUs
199+
resl = self.grow_architecture(resl, global_step)
200+
loader = self.scalable_loader(resl)
201+
for step in range(self.tran_step):
202+
_step(step, loader, "transition")
203+
self.G.module.update_alpha(1 / self.tran_step)
204+
self.G_ema.update_alpha(1 / self.tran_step)
205+
self.D.module.update_alpha(1 / self.tran_step)
206+
207+
# Stabilization
208+
for step in range(self.stab_step):
209+
_step(step, loader, "stabilization")
210+
211+
for step in range(self.arg.extra_training_img_num):
212+
_step(step, loader, "stabilization")

src/pggan/ScalableLoader.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import random
2+
from PIL import Image
3+
4+
from torch.utils.data import DataLoader
5+
6+
from torchvision import transforms
7+
from torchvision.datasets import ImageFolder
8+
9+
from preset import resl_to_batch
10+
11+
12+
class ScalableLoader:
13+
def __init__(self, path, shuffle=True, drop_last=False, num_workers=4, shuffled_cycle=True):
14+
self.path = path
15+
self.shuffle = shuffle
16+
self.drop_last = drop_last
17+
self.num_workers = num_workers
18+
self.shuffled_cycle = shuffled_cycle
19+
20+
def __call__(self, resl):
21+
batch = resl_to_batch[resl]
22+
23+
transform = transforms.Compose([transforms.Resize(size=(resl, resl), interpolation=Image.NEAREST), transforms.ToTensor()])
24+
25+
root = self.path + str(max(64, resl))
26+
print("Data root : %s" % root)
27+
28+
loader = DataLoader(
29+
dataset=ImageFolder(root=root, transform=transform),
30+
batch_size=batch,
31+
shuffle=self.shuffle,
32+
drop_last=self.drop_last,
33+
num_workers=self.num_workers
34+
)
35+
36+
loader = self.cycle(loader)
37+
return loader
38+
39+
def cycle(self, loader):
40+
while True:
41+
for element in loader:
42+
yield element
43+
if self.shuffled_cycle:
44+
random.shuffle(loader.dataset.imgs)

0 commit comments

Comments
 (0)