Skip to content

Commit 8e618a1

Browse files
author
Stanislav Podhorskiy
committed
Cleanup
Visualize maps
1 parent bd1b4cc commit 8e618a1

File tree

2 files changed

+23
-44
lines changed

2 files changed

+23
-44
lines changed

Sample.py

Lines changed: 14 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -86,58 +86,25 @@ def sample(cfg, logger):
8686
channels=3)
8787
model.eval()
8888

89-
#torch.cuda.manual_seed_all(110)
90-
9189
logger.info("Trainable parameters generator:")
9290
count_parameters(model.generator)
9391

94-
if False:
95-
model_dict = {
96-
'generator': model.generator,
97-
'mapping': model.mapping,
98-
'dlatent_avg': model.dlatent_avg,
99-
}
100-
else:
101-
model_dict = {
102-
'generator_s': model.generator,
103-
'mapping_fl_s': model.mapping,
104-
'dlatent_avg': model.dlatent_avg,
105-
}
92+
model_dict = {
93+
'generator_s': model.generator,
94+
'mapping_fl_s': model.mapping,
95+
'dlatent_avg': model.dlatent_avg,
96+
}
10697

10798
checkpointer = Checkpointer(cfg,
10899
model_dict,
109100
logger=logger,
110101
save=True)
111102

112-
file_name = 'results/karras2019stylegan-ffhq_new'
113-
# file_name = 'results/model_final'
114-
115103
checkpointer.load()
116104

117-
rgbs = []
118-
for i in range(model.generator.layer_count):
119-
rgbs.append((model.generator.to_rgb[i].to_rgb.weight[:].cpu().detach().numpy(),
120-
model.generator.to_rgb[i].to_rgb.bias[:].cpu().detach().numpy()))
121-
122-
#with open('rgbs.pkl', 'wb') as handle:
123-
# pickle.dump(rgbs, handle, protocol=pickle.HIGHEST_PROTOCOL)
124-
125-
# checkpointer.save('final_stripped')
126-
127-
#sample_b = torch.randn(1, cfg.MODEL.LATENT_SPACE_SIZE).view(-1, cfg.MODEL.LATENT_SPACE_SIZE)
128-
129-
# for i in range(100):
130-
# if i % 20 == 0:
131-
# sample_a = sample_b
132-
# sample_b = torch.randn(1, cfg.MODEL.LATENT_SPACE_SIZE).view(-1, cfg.MODEL.LATENT_SPACE_SIZE)
133-
# x = (i % 20) / 20.0
134-
# sample = sample_a * (1.0 - x) + sample_b * x
135-
# save_sample(model, sample, i)
136-
137-
# print(model.discriminator.get_statistics(8))
138-
139105
ctx = bimpy.Context()
140106
remove = bimpy.Bool(False)
107+
layers = bimpy.Int(8)
141108

142109
ctx.init(1800, 1600, "Styles")
143110

@@ -147,8 +114,9 @@ def sample(cfg, logger):
147114

148115
def update_image(sample):
149116
with torch.no_grad():
117+
torch.manual_seed(0)
150118
model.eval()
151-
x_rec = model.generate(8, remove.value, z=sample)
119+
x_rec = model.generate(layers.value, remove.value, z=sample)
152120
#model.generator.set(l.value, c.value)
153121
resultsample = ((x_rec * 0.5 + 0.5) * 255).type(torch.long).clamp(0, 255)
154122
resultsample = resultsample.cpu()[0, :, :, :]
@@ -161,13 +129,18 @@ def update_image(sample):
161129
im = bimpy.Image(update_image(sample))
162130
while(not ctx.should_close()):
163131
with ctx:
132+
133+
bimpy.set_window_font_scale(2.0)
134+
164135
if bimpy.checkbox('REMOVE BLOB', remove):
165136
im = bimpy.Image(update_image(sample))
166-
bimpy.image(im)
167137
if bimpy.button('NEXT'):
168138
latents = rnd.randn(1, cfg.MODEL.LATENT_SPACE_SIZE)
169139
sample = torch.tensor(latents).float().cuda()
170140
im = bimpy.Image(update_image(sample))
141+
if bimpy.slider_int("Layers", layers, 0, 8):
142+
im = bimpy.Image(update_image(sample))
143+
bimpy.image(im, bimpy.Vec2(1024, 1024))
171144

172145

173146
if __name__ == '__main__':

net.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -257,7 +257,7 @@ def decode(self, styles, lod, remove_blob=True):
257257
x = self.decode_block[i].forward(x, styles[:, 2 * i + 0], styles[:, 2 * i + 1])
258258
if remove_blob and i == 3:
259259
_x = x.clone()
260-
_x[x > 200.0] = 0
260+
_x[x > 300.0] = 0
261261

262262
# plt.hist((torch.max(torch.max(_x, dim=2)[0], dim=2)[0]).cpu().flatten().numpy(), bins=300)
263263
# plt.show()
@@ -267,8 +267,14 @@ def decode(self, styles, lod, remove_blob=True):
267267

268268
if _x is not None:
269269
x = _x
270-
271-
x = self.to_rgb[lod](x)
270+
if lod == 8:
271+
x = self.to_rgb[lod](x)
272+
else:
273+
x = x.max(dim=1, keepdim=True)[0]
274+
x = x - x.min()
275+
x = x / x.max()
276+
x = torch.pow(x, 1.0/2.2)
277+
x = x.repeat(1, 3, 1, 1)
272278
return x
273279

274280
def forward(self, styles, lod, remove_blob=True):

0 commit comments

Comments (0)