@@ -86,58 +86,25 @@ def sample(cfg, logger):
86
86
channels = 3 )
87
87
model .eval ()
88
88
89
- #torch.cuda.manual_seed_all(110)
90
-
91
89
logger .info ("Trainable parameters generator:" )
92
90
count_parameters (model .generator )
93
91
94
- if False :
95
- model_dict = {
96
- 'generator' : model .generator ,
97
- 'mapping' : model .mapping ,
98
- 'dlatent_avg' : model .dlatent_avg ,
99
- }
100
- else :
101
- model_dict = {
102
- 'generator_s' : model .generator ,
103
- 'mapping_fl_s' : model .mapping ,
104
- 'dlatent_avg' : model .dlatent_avg ,
105
- }
92
+ model_dict = {
93
+ 'generator_s' : model .generator ,
94
+ 'mapping_fl_s' : model .mapping ,
95
+ 'dlatent_avg' : model .dlatent_avg ,
96
+ }
106
97
107
98
checkpointer = Checkpointer (cfg ,
108
99
model_dict ,
109
100
logger = logger ,
110
101
save = True )
111
102
112
- file_name = 'results/karras2019stylegan-ffhq_new'
113
- # file_name = 'results/model_final'
114
-
115
103
checkpointer .load ()
116
104
117
- rgbs = []
118
- for i in range (model .generator .layer_count ):
119
- rgbs .append ((model .generator .to_rgb [i ].to_rgb .weight [:].cpu ().detach ().numpy (),
120
- model .generator .to_rgb [i ].to_rgb .bias [:].cpu ().detach ().numpy ()))
121
-
122
- #with open('rgbs.pkl', 'wb') as handle:
123
- # pickle.dump(rgbs, handle, protocol=pickle.HIGHEST_PROTOCOL)
124
-
125
- # checkpointer.save('final_stripped')
126
-
127
- #sample_b = torch.randn(1, cfg.MODEL.LATENT_SPACE_SIZE).view(-1, cfg.MODEL.LATENT_SPACE_SIZE)
128
-
129
- # for i in range(100):
130
- # if i % 20 == 0:
131
- # sample_a = sample_b
132
- # sample_b = torch.randn(1, cfg.MODEL.LATENT_SPACE_SIZE).view(-1, cfg.MODEL.LATENT_SPACE_SIZE)
133
- # x = (i % 20) / 20.0
134
- # sample = sample_a * (1.0 - x) + sample_b * x
135
- # save_sample(model, sample, i)
136
-
137
- # print(model.discriminator.get_statistics(8))
138
-
139
105
ctx = bimpy .Context ()
140
106
remove = bimpy .Bool (False )
107
+ layers = bimpy .Int (8 )
141
108
142
109
ctx .init (1800 , 1600 , "Styles" )
143
110
@@ -147,8 +114,9 @@ def sample(cfg, logger):
147
114
148
115
def update_image (sample ):
149
116
with torch .no_grad ():
117
+ torch .manual_seed (0 )
150
118
model .eval ()
151
- x_rec = model .generate (8 , remove .value , z = sample )
119
+ x_rec = model .generate (layers . value , remove .value , z = sample )
152
120
#model.generator.set(l.value, c.value)
153
121
resultsample = ((x_rec * 0.5 + 0.5 ) * 255 ).type (torch .long ).clamp (0 , 255 )
154
122
resultsample = resultsample .cpu ()[0 , :, :, :]
@@ -161,13 +129,18 @@ def update_image(sample):
161
129
im = bimpy .Image (update_image (sample ))
162
130
while (not ctx .should_close ()):
163
131
with ctx :
132
+
133
+ bimpy .set_window_font_scale (2.0 )
134
+
164
135
if bimpy .checkbox ('REMOVE BLOB' , remove ):
165
136
im = bimpy .Image (update_image (sample ))
166
- bimpy .image (im )
167
137
if bimpy .button ('NEXT' ):
168
138
latents = rnd .randn (1 , cfg .MODEL .LATENT_SPACE_SIZE )
169
139
sample = torch .tensor (latents ).float ().cuda ()
170
140
im = bimpy .Image (update_image (sample ))
141
+ if bimpy .slider_int ("Layers" , layers , 0 , 8 ):
142
+ im = bimpy .Image (update_image (sample ))
143
+ bimpy .image (im , bimpy .Vec2 (1024 , 1024 ))
171
144
172
145
173
146
if __name__ == '__main__' :
0 commit comments