Skip to content

Commit 6e4dfbf

Browse files
committed
code refactoring
1 parent e8205b3 commit 6e4dfbf

26 files changed

+254
-240
lines changed

constrained_opt.py

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sys
66
from lib import utils
77
from PyQt4.QtCore import *
8-
import cv2
8+
99

1010
class Constrained_OPT(QThread):
1111
def __init__(self, opt_solver, batch_size=32, n_iters=25, topK=16, morph_steps=16, interp='linear'):
@@ -35,7 +35,7 @@ def __init__(self, opt_solver, batch_size=32, n_iters=25, topK=16, morph_steps=1
3535
self.order = None
3636
self.init_constraints() # initialize
3737
self.init_z() # initialize latent vectors
38-
self.just_fixed=True
38+
self.just_fixed = True
3939
self.weights = None
4040

4141
def is_fixed(self):
@@ -55,7 +55,7 @@ def init_z(self, frame_id=-1, image_id=-1):
5555
print('set z as image %d, frame %d' % (image_id, frame_id))
5656
self.prev_z = self.z_seq[image_id, frame_id]
5757

58-
if self.prev_z is None: #random initialization
58+
if self.prev_z is None: # random initialization
5959
self.z_init = np_rng.uniform(-1.0, 1.0, size=(self.batch_size, nz))
6060
self.opt_solver.set_smoothness(0.0)
6161
self.z_const = self.z_init
@@ -79,25 +79,19 @@ def update(self): # update ui
7979

8080
def save_constraints(self):
8181
[im_c, mask_c, im_e, mask_e] = self.combine_constraints(self.constraints)
82-
# write image
83-
# im_c2 = cv2.cvtColor(im_c, cv2.COLOR_RGB2BGR)
84-
# cv2.imwrite('input_color_image.png', im_c2)
85-
# cv2.imwrite('input_color_mask.png', mask_c)
86-
# cv2.imwrite('input_edge_map.png', im_e)
8782
self.prev_im_c = im_c.copy()
8883
self.prev_mask_c = mask_c.copy()
8984
self.prev_im_e = im_e.copy()
90-
self.prev_mask_e =mask_e.copy()
85+
self.prev_mask_e = mask_e.copy()
9186

9287
def init_constraints(self):
9388
self.prev_im_c = None
9489
self.prev_mask_c = None
9590
self.prev_im_e = None
9691
self.prev_mask_e = None
9792

98-
9993
def combine_constraints(self, constraints):
100-
if constraints is not None: #[hack]
94+
if constraints is not None: # [hack]
10195
# print('combine strokes')
10296
[im_c, mask_c, im_e, mask_e] = constraints
10397
if self.prev_im_c is None:
@@ -113,22 +107,21 @@ def combine_constraints(self, constraints):
113107
if self.prev_im_c is None:
114108
im_c_f = im_c
115109
else:
116-
im_c_f = self.prev_im_c.copy()
117-
mask_c3 = np.tile(mask_c, [1,1, im_c.shape[2]])
118-
np.copyto(im_c_f, im_c, where=mask_c3.astype(np.bool)) #[hack]
110+
im_c_f = self.prev_im_c.copy()
111+
mask_c3 = np.tile(mask_c, [1, 1, im_c.shape[2]])
112+
np.copyto(im_c_f, im_c, where=mask_c3.astype(np.bool)) # [hack]
119113

120114
if self.prev_im_e is None:
121115
im_e_f = im_e
122116
else:
123117
im_e_f = self.prev_im_e.copy()
124-
mask_e3 = np.tile(mask_e, [1,1,im_e.shape[2]])
118+
mask_e3 = np.tile(mask_e, [1, 1, im_e.shape[2]])
125119
np.copyto(im_e_f, im_e, where=mask_e3.astype(np.bool))
126120

127121
return [im_c_f, mask_c_f, im_e_f, mask_e_f]
128122
else:
129123
return [self.prev_im_c, self.prev_mask_c, self.prev_im_e, self.prev_mask_e]
130124

131-
132125
def set_constraints(self, constraints):
133126
self.constraints = constraints
134127

@@ -140,7 +133,6 @@ def get_z(self, image_id, frame_id):
140133
else:
141134
return None
142135

143-
144136
def get_image(self, image_id, frame_id, useAverage=False):
145137
if self.to_update:
146138
if self.current_ims is None or self.current_ims.size == 0:
@@ -158,7 +150,7 @@ def get_image(self, image_id, frame_id, useAverage=False):
158150
frame_id = frame_id % self.img_seq.shape[1]
159151
image_id = image_id % self.img_seq.shape[0]
160152
if useAverage and self.weights is not None:
161-
return utils.average_image(self.img_seq[:,frame_id,...], self.weights)
153+
return utils.average_image(self.img_seq[:, frame_id, ...], self.weights)
162154
else:
163155
return self.img_seq[image_id, frame_id]
164156

@@ -188,10 +180,10 @@ def get_current_results(self):
188180
return self.current_ims
189181

190182
def run(self): # main function
191-
time_to_wait = 33 # 33 millisecond
183+
time_to_wait = 33 # 33 millisecond
192184
while (1):
193-
t1 =time()
194-
if self.to_set_constraints:# update constraints
185+
t1 = time()
186+
if self.to_set_constraints: # update constraints
195187
self.to_set_constraints = False
196188

197189
if self.constraints is not None and self.iter_count < self.max_iters:
@@ -204,11 +196,11 @@ def run(self): # main function
204196
self.to_update = False
205197
self.iter_count += 1
206198

207-
t_c = int(1000*(time()-t1))
199+
t_c = int(1000 * (time() - t1))
208200
print('update one iteration: %03d ms' % t_c, end='\r')
209201
sys.stdout.flush()
210202
if t_c < time_to_wait:
211-
self.msleep(time_to_wait-t_c)
203+
self.msleep(time_to_wait - t_c)
212204

213205
def update_invert(self, constraints):
214206
constraints_c = self.combine_constraints(constraints)
@@ -218,7 +210,7 @@ def update_invert(self, constraints):
218210

219211
if self.topK > 1:
220212
cost_sort = cost_all[order]
221-
thres_top = 2 * np.mean(cost_sort[0:min(int(self.topK / 2.0), len(cost_sort))])
213+
thres_top = 2 * np.mean(cost_sort[0:min(int(self.topK / 2.0), len(cost_sort))])
222214
ids = cost_sort - thres_top < 1e-10
223215
topK = np.min([self.topK, sum(ids)])
224216
else:
@@ -233,7 +225,7 @@ def update_invert(self, constraints):
233225
self.current_ims = gx_t[order]
234226
# compute weights
235227
cost_weights = cost_all[order]
236-
self.weights = np.exp(-(cost_weights-np.mean(cost_weights)) / (np.std(cost_weights)+1e-10))
228+
self.weights = np.exp(-(cost_weights - np.mean(cost_weights)) / (np.std(cost_weights) + 1e-10))
237229
self.current_zs = z_t[order]
238230
self.emit(SIGNAL('update_image'))
239231

@@ -248,14 +240,14 @@ def gen_morphing(self, interp='linear', n_steps=8):
248240
z_seq = []
249241

250242
for n in range(n_steps):
251-
ratio = n / float(n_steps- 1)
243+
ratio = n / float(n_steps - 1)
252244
z_t = utils.interp_z(z1, z2, ratio, interp=interp)
253245
seq = self.opt_solver.gen_samples(z0=z_t)
254246
img_seq.append(seq[:, np.newaxis, ...])
255-
z_seq.append(z_t[:,np.newaxis,...])
247+
z_seq.append(z_t[:, np.newaxis, ...])
256248
self.img_seq = np.concatenate(img_seq, axis=1)
257249
self.z_seq = np.concatenate(z_seq, axis=1)
258-
print('generate morphing sequence (%.3f seconds)' % (time()-t))
250+
print('generate morphing sequence (%.3f seconds)' % (time() - t))
259251

260252
def reset(self):
261253
self.prev_z = self.z0
@@ -271,7 +263,4 @@ def reset(self):
271263
self.to_set_constraints = False
272264
self.iter_total = 0
273265
self.iter_count = 0
274-
self.weights =None
275-
276-
277-
266+
self.weights = None

constrained_opt_theano.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from lib.theano_utils import floatX, sharedX
77
import numpy as np
88

9+
910
class OPT_Solver():
1011
def __init__(self, model, batch_size=32, d_weight=0.0):
1112
self.model = model
@@ -16,7 +17,7 @@ def __init__(self, model, batch_size=32, d_weight=0.0):
1617
self.transform = model.transform
1718
self.transform_mask = model.transform_mask
1819
self.inverse_transform = model.inverse_transform
19-
BS = 4 if self.nc == 1 else 8 # [hack]
20+
BS = 4 if self.nc == 1 else 8 # [hack]
2021
self.hog = HOGNet.HOGNet(use_bin=True, NO=8, BS=BS, nc=self.nc)
2122
self.opt_model = self.def_invert(model, batch_size=batch_size, d_weight=d_weight, nc=self.nc)
2223
self.batch_size = batch_size
@@ -26,8 +27,8 @@ def get_image_size(self):
2627

2728
def invert(self, constraints, z_i):
2829
[_invert, z_updates, z, beta_r, z_const] = self.opt_model
29-
constraints_t = self.preprocess_constraints(constraints)
30-
[im_c_t, mask_c_t, im_e_t, mask_e_t] = constraints_t # [im_c_t, mask_c_t, im_e_t, mask_e_t]
30+
constraints_t = self.preprocess_constraints(constraints)
31+
[im_c_t, mask_c_t, im_e_t, mask_e_t] = constraints_t # [im_c_t, mask_c_t, im_e_t, mask_e_t]
3132

3233
results = _invert(im_c_t, mask_c_t, im_e_t, mask_e_t, z_i.astype(np.float32))
3334

@@ -38,7 +39,6 @@ def invert(self, constraints, z_i):
3839
z_t = np.tanh(z.get_value()).copy()
3940
return gx_t, z_t, cost_all
4041

41-
4242
def preprocess_constraints(self, constraints):
4343
[im_c_o, mask_c_o, im_e_o, mask_e_o] = constraints
4444
im_c = self.transform(im_c_o[np.newaxis, :], self.nc)
@@ -65,7 +65,7 @@ def set_smoothness(self, l):
6565
def gen_samples(self, z0):
6666
samples = self.model.gen_samples(z0=z0)
6767
if self.nc == 1:
68-
samples = np.tile(samples, [1,1,1,3])
68+
samples = np.tile(samples, [1, 1, 1, 3])
6969
return samples
7070

7171
def def_invert(self, model, batch_size=1, d_weight=0.5, nc=1, lr=0.1, b1=0.9, nz=100, use_bin=True):
@@ -79,8 +79,8 @@ def def_invert(self, model, batch_size=1, d_weight=0.5, nc=1, lr=0.1, b1=0.9, nz
7979
gx = model.model_G(z)
8080
# input: im_c: 255: no edge; 0: edge; transform=> 1: no edge, 0: edge
8181

82-
if nc == 1: # gx, range [0, 1] => edge, 1
83-
gx3 = 1.0-gx #T.tile(gx, (1, 3, 1, 1))
82+
if nc == 1: # gx, range [0, 1] => edge, 1
83+
gx3 = 1.0 - gx # T.tile(gx, (1, 3, 1, 1))
8484
else:
8585
gx3 = gx
8686
mm_c = T.tile(m_c, (1, gx3.shape[1], 1, 1))
@@ -116,4 +116,3 @@ def def_invert(self, model, batch_size=1, d_weight=0.5, nc=1, lr=0.1, b1=0.9, nz
116116
_invert = theano.function(inputs=[x_c, m_c, x_e, m_e, z0], outputs=output, updates=z_updates)
117117
print('%.2f seconds to compile _invert function' % (time() - t))
118118
return [_invert, z_updates, z, d_weight_r, z_const]
119-

generate_samples.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from lib import utils
55
import cv2
66

7+
78
def parse_args():
89
parser = argparse.ArgumentParser(description='generated random samples (dcgan_theano)')
910
parser.add_argument('--model_name', dest='model_name', help='the model name', default='outdoor_64', type=str)
@@ -15,9 +16,10 @@ def parse_args():
1516
args = parser.parse_args()
1617
return args
1718

19+
1820
if __name__ == '__main__':
1921
args = parse_args()
20-
if not args.model_file: #if model directory is not specified
22+
if not args.model_file: # if model directory is not specified
2123
args.model_file = './models/%s.%s' % (args.model_name, args.model_type)
2224

2325
if not args.output_image:
@@ -37,4 +39,4 @@ def parse_args():
3739
im_vis = cv2.cvtColor(im_vis, cv2.COLOR_BGR2RGB)
3840
cv2.imwrite(args.output_image, im_vis)
3941
print('samples_shape', samples.shape)
40-
print('save image to %s' % args.output_image)
42+
print('save image to %s' % args.output_image)

iGAN_main.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pydoc import locate
99
import constrained_opt
1010

11+
1112
def parse_args():
1213
parser = argparse.ArgumentParser(description='iGAN: Interactive Visual Synthesis Powered by GAN')
1314
parser.add_argument('--model_name', dest='model_name', help='the model name', default='outdoor_64', type=str)
@@ -21,14 +22,15 @@ def parse_args():
2122
parser.add_argument('--model_file', dest='model_file', help='the file that stores the generative model', type=str, default=None)
2223
parser.add_argument('--d_weight', dest='d_weight', help='captures the visual realism based on GAN discriminator', type=float, default=0.0)
2324
parser.add_argument('--interp', dest='interp', help='the interpolation method (linear or slerp)', type=str, default='linear')
24-
parser.add_argument('--average', dest='average', help='averageExplorer mode',action="store_true", default=False)
25+
parser.add_argument('--average', dest='average', help='averageExplorer mode', action="store_true", default=False)
2526
parser.add_argument('--shadow', dest='shadow', help='shadowDraw mode', action="store_true", default=False)
2627
args = parser.parse_args()
2728
return args
2829

30+
2931
if __name__ == '__main__':
3032
args = parse_args()
31-
if not args.model_file: #if the model_file is not specified
33+
if not args.model_file: # if the model_file is not specified
3234
args.model_file = './models/%s.%s' % (args.model_name, args.model_type)
3335

3436
for arg in vars(args):
@@ -54,4 +56,4 @@ def parse_args():
5456
window.setWindowTitle('Interactive GAN')
5557
window.setWindowFlags(window.windowFlags() & ~Qt.WindowMaximizeButtonHint) # fix window siz
5658
window.show()
57-
app.exec_()
59+
app.exec_()

iGAN_predict.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pydoc import locate
1717
from lib import activations
1818

19+
1920
def def_feature(layer='conv4', up_scale=4):
2021
print('COMPILING...')
2122
t = time()
@@ -55,7 +56,8 @@ def def_bfgs(model_G, layer='conv4', npx=64, alpha=0.002):
5556
_invert = theano.function(inputs=[z, x, x_f], outputs=output)
5657

5758
print('%.2f seconds to compile _bfgs function' % (time() - t))
58-
return _invert,z
59+
return _invert, z
60+
5961

6062
def def_predict(model_P):
6163
print('COMPILING...')
@@ -74,8 +76,6 @@ def def_invert_models(gen_model, layer='conv4', alpha=0.002):
7476
return gen_model, bfgs_model, ftr_model, predict_model
7577

7678

77-
78-
7979
def predict_z(gen_model, _predict, ims, batch_size=32):
8080
n = ims.shape[0]
8181
n_gen = 0
@@ -97,28 +97,26 @@ def invert_bfgs_batch(gen_model, invert_model, ftr_model, ims, z_predict=None, n
9797
zs = []
9898
recs = []
9999
fs = []
100-
t = time()
101100
n_imgs = ims.shape[0]
102101
print('reconstruct %d images using bfgs' % n_imgs)
103102

104103
for n in range(n_imgs):
105-
im_n = ims[[n], :, :,:]
104+
im_n = ims[[n], :, :, :]
106105
if z_predict is not None:
107-
z0_n = z_predict[[n],...]
106+
z0_n = z_predict[[n], ...]
108107
else:
109108
z0_n = None
110-
gx, z_value, f_value = invert_bfgs(gen_model, invert_model, ftr_model,im=im_n, z_predict=z0_n, npx=npx)
109+
gx, z_value, f_value = invert_bfgs(gen_model, invert_model, ftr_model, im=im_n, z_predict=z0_n, npx=npx)
111110
rec_im = (gx * 255).astype(np.uint8)
112-
fs.append(f_value[np.newaxis,...])
113-
zs.append(z_value[np.newaxis,...])
111+
fs.append(f_value[np.newaxis, ...])
112+
zs.append(z_value[np.newaxis, ...])
114113
recs.append(rec_im)
115114
recs = np.concatenate(recs, axis=0)
116115
zs = np.concatenate(zs, axis=0)
117116
fs = np.concatenate(fs, axis=0)
118117
return recs, zs, fs
119118

120119

121-
122120
def invert_bfgs(gen_model, invert_model, ftr_model, im, z_predict=None, npx=64):
123121
_f, z = invert_model
124122
nz = gen_model.nz
@@ -131,14 +129,14 @@ def invert_bfgs(gen_model, invert_model, ftr_model, im, z_predict=None, npx=64):
131129
ftr = ftr_model(im_t)
132130

133131
prob = optimize.minimize(f_bfgs, z_predict, args=(_f, im_t, ftr),
134-
tol=1e-6, jac=True, method='L-BFGS-B', options={'maxiter':200})
132+
tol=1e-6, jac=True, method='L-BFGS-B', options={'maxiter': 200})
135133
print('n_iters = %3d, f = %.3f' % (prob.nit, prob.fun))
136134
z_opt = prob.x
137135
z_opt_n = floatX(z_opt[np.newaxis, :])
138136
[f_opt, g, gx] = _f(z_opt_n, im_t, ftr)
139137
gx = gen_model.inverse_transform(gx, npx=npx)
140138
z_opt = np.tanh(z_opt)
141-
return gx, z_opt,f_opt
139+
return gx, z_opt, f_opt
142140

143141

144142
def f_bfgs(z0, _f, x, x_f):
@@ -181,11 +179,12 @@ def parse_args():
181179
args = parser.parse_args()
182180
return args
183181

182+
184183
if __name__ == "__main__":
185184
args = parse_args()
186185
if not args.model_file: # if the model file is not specified
187186
args.model_file = './models/%s.%s' % (args.model_name, args.model_type)
188-
if not args.output_image:# if the output image path is not specified
187+
if not args.output_image: # if the output image path is not specified
189188
args.output_image = args.input_image.replace('.png', '_%s.png' % args.solver)
190189

191190
for arg in vars(args):
@@ -205,10 +204,10 @@ def parse_args():
205204
im = np.array(im)
206205
im_pre = im[np.newaxis, :, :, :]
207206
# run the model
208-
rec, _, _ = invert_images_CNN_opt(invert_models, im_pre, solver=args.solver)
207+
rec, _, _ = invert_images_CNN_opt(invert_models, im_pre, solver=args.solver)
209208
rec = np.squeeze(rec)
210209
rec_im = Image.fromarray(rec)
211210
# resize the image (input aspect ratio)
212211
rec_im = rec_im.resize((h, w))
213212
print('write result to %s' % args.output_image)
214-
rec_im.save(args.output_image)
213+
rec_im.save(args.output_image)

0 commit comments

Comments
 (0)